In [None]:
# Colab-only: mount Google Drive at /content/drive so the paths used below
# (everything under /content/drive/Shareddrives/...) are reachable.
from google.colab import drive
drive.mount('/content/drive')

import os
from pathlib import Path

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# Prefer the GPU when the runtime exposes one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Project root on the shared drive and the training split beneath it.
BASE_DIR = Path("/content/drive/Shareddrives/TissueMotionForecasting")
TRAIN_ROOT = BASE_DIR.joinpath("scared_data", "train")
print("TRAIN_ROOT:", TRAIN_ROOT)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
TRAIN_ROOT: /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train


In [None]:
CROP_H, CROP_W = 256, 320  # (height, width) of the spatial crops fed to the models


def random_crop_torch(stack, ch, cw):
    """
    Take one random spatial crop that is shared by every channel of `stack`.

    stack : tensor of shape [C, H, W] or [C+1, H, W]
    ch, cw: requested crop height / width
    returns: a view of `stack` with the same channel dim. Each spatial axis is
             cropped to the requested size when the input is larger along
             that axis, and left untouched when it is already equal or
             smaller (the two axes are handled independently).

    The offsets are drawn from torch's RNG rather than NumPy's global RNG:
    torch seeds its generator differently in every DataLoader worker, while
    NumPy's global state is duplicated across forked workers, which would make
    all workers produce the same crop offsets.
    """
    _, H, W = stack.shape

    def _window(full, crop):
        # Nothing to crop along this axis: keep it whole.
        if full <= crop:
            return 0, full
        # The upper bound of torch.randint is exclusive, hence the +1 so that
        # the last valid offset (full - crop) can also be sampled.
        start = int(torch.randint(0, full - crop + 1, (1,)).item())
        return start, crop

    top, out_h = _window(H, ch)
    left, out_w = _window(W, cw)
    return stack[:, top:top + out_h, left:left + out_w]


In [None]:
# Forecasting setup: number of past frames given to the model, and how many
# frames after the last context frame the target is taken from.
CONTEXT_LEN = 3
FORECAST_HORIZON = 5
# Divisor applied to the raw disparity PNG values when frames are loaded.
DISP_SCALE = 256.0

# Sub-directories of the training root that are scanned for disparity frames.
DATASETS_TO_USE = ["dataset_" + str(n) for n in (1, 2, 3)]
KEYFRAME_NAMES = ["keyframe_%d" % k for k in range(1, 6)]


class DisparityForecastDataset(Dataset):
    """
    Build (context, target) forecast pairs from RAW disparity PNGs.

    - Uses disparity PNGs from:
        TRAIN_ROOT/dataset_X/keyframe_Y/data/disparity/*.png
    - Ignores colored_*.png
    - context: [C,H,W], C = context_len
    - target : [1,H,W] at t + forecast_horizon

    Args:
        train_root: root directory of the training data (str or Path).
        dataset_names: names of the dataset_X sub-directories to scan.
        keyframe_names: names of the keyframe_Y sub-directories to scan.
        context_len: number of consecutive frames in the context (>= 1).
        forecast_horizon: offset, in frames, from the last context frame to
            the target frame (>= 0).
        scale: divisor applied to the raw pixel values on load.

    Raises:
        ValueError: if context_len < 1 or forecast_horizon < 0.

    Frames are ordered by sorting their paths, so the temporal order is only
    correct if file names sort chronologically (e.g. zero-padded indices) --
    NOTE(review): confirm against the data on disk.
    """
    def __init__(self, train_root, dataset_names, keyframe_names,
                 context_len=3, forecast_horizon=10, scale=256.0):
        if context_len < 1:
            raise ValueError(f"context_len must be >= 1, got {context_len}")
        if forecast_horizon < 0:
            raise ValueError(
                f"forecast_horizon must be >= 0, got {forecast_horizon}")

        # Accept plain strings as well as Path objects.
        train_root = Path(train_root)

        self.samples = []
        self.context_len = context_len
        self.forecast_horizon = forecast_horizon
        self.scale = scale

        for ds_name in dataset_names:
            for kf_name in keyframe_names:
                disp_dir = train_root / ds_name / kf_name / "data" / "disparity"
                if not disp_dir.exists():
                    # Missing dataset/keyframe combinations are skipped silently.
                    continue

                frame_paths = sorted(
                    p for p in disp_dir.glob("*.png")
                    if not p.name.startswith("colored")
                )

                # Too short to hold even one context window plus its target.
                if len(frame_paths) < context_len + forecast_horizon:
                    continue

                # slide window: [i-(C-1) ... i] -> target at i+H
                for i in range(context_len - 1,
                               len(frame_paths) - forecast_horizon):
                    ctx_paths = frame_paths[i - (context_len - 1): i + 1]
                    tgt_path  = frame_paths[i + forecast_horizon]
                    self.samples.append((ctx_paths, tgt_path))

        print(f"DisparityForecastDataset: {len(self.samples)} samples")

    def _load_disp(self, path: Path):
        """Read one PNG as a float32 [H,W] array of pixel values / self.scale."""
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the array.
        with Image.open(path) as img:
            arr = np.array(img.convert("I"), dtype=np.float32)
        return arr / self.scale

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (ctx [C,ch,cw], tgt [1,ch,cw]) cut by one shared random crop."""
        ctx_paths, tgt_path = self.samples[idx]
        ctx_frames = [self._load_disp(p) for p in ctx_paths]  # list of [H,W]
        ctx = np.stack(ctx_frames, axis=0)                    # [C,H,W]
        tgt = self._load_disp(tgt_path)                       # [H,W]

        # copy=False: no extra copy when the arrays are already float32, but
        # still guarantees float32 if `scale` promoted the division result.
        ctx = torch.from_numpy(ctx.astype(np.float32, copy=False))  # [C,H,W]
        tgt = torch.from_numpy(
            tgt.astype(np.float32, copy=False)).unsqueeze(0)        # [1,H,W]

        # Crop context and target together so they stay spatially aligned.
        stack = torch.cat([ctx, tgt], dim=0)                  # [C+1,H,W]
        stack = random_crop_torch(stack, CROP_H, CROP_W)      # [C+1,ch,cw]
        ctx = stack[:-1]                                      # [C,ch,cw]
        tgt = stack[-1:].contiguous()                         # [1,ch,cw]

        return ctx, tgt


# Build the forecasting dataset from the notebook-level configuration.
dataset = DisparityForecastDataset(
    TRAIN_ROOT,
    DATASETS_TO_USE,
    KEYFRAME_NAMES,
    context_len=CONTEXT_LEN,
    forecast_horizon=FORECAST_HORIZON,
    scale=DISP_SCALE,
)

# Sanity check: report the dataset size and the tensor shapes of one sample.
print("Total samples:", dataset.__len__())
ctx_sample, tgt_sample = dataset[0]
print("Context shape:", ctx_sample.size())
print("Target  shape:", tgt_sample.size())


DisparityForecastDataset: 7952 samples
Total samples: 7952
Context shape: torch.Size([3, 256, 320])
Target  shape: torch.Size([1, 256, 320])


In [None]:
from torch.utils.data import random_split, DataLoader

VAL_FRACTION = 0.2
SPLIT_SEED = 0  # fixes which samples land in the validation set
val_len = int(len(dataset) * VAL_FRACTION)
train_len = len(dataset) - val_len

# A seeded generator makes the partition identical on every run. Without it a
# new session (e.g. one that reloads a checkpoint and keeps training) would
# draw a different split, mixing former validation samples into training.
# NOTE(review): samples are overlapping sliding windows over the same
# sequences, so a per-sample random split still places near-identical windows
# on both sides; a per-sequence split would be a stricter hold-out.
# NOTE(review): val_set wraps the same dataset object, so validation samples
# are also randomly cropped on every access.
train_set, val_set = random_split(
    dataset, [train_len, val_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(f"Train samples: {train_len}, Val samples: {val_len}")

BATCH_SIZE = 2

# Training batches are reshuffled every epoch; validation order is fixed.
train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=2, pin_memory=True
)
val_loader = DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=2, pin_memory=True
)

# Pull one batch to confirm the collated shapes.
ctx_batch, tgt_batch = next(iter(train_loader))
print("ctx_batch:", ctx_batch.shape, "tgt_batch:", tgt_batch.shape)


Train samples: 6362, Val samples: 1590
ctx_batch: torch.Size([2, 3, 256, 320]) tgt_batch: torch.Size([2, 1, 256, 320])


In [None]:
import torch.nn as nn
import torch

class DoubleConv(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by an in-place ReLU.

    Maps [B, in_c, H, W] to [B, out_c, H, W]; the spatial size is unchanged.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        stages = []
        # First conv changes the channel count, the second keeps it.
        for src_c in (in_c, out_c):
            stages.append(nn.Conv2d(src_c, out_c, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        y = self.block(x)
        return y


class UNet(nn.Module):
    """
    Four-level encoder/decoder built from DoubleConv blocks with skip
    connections.

    The encoder halves the resolution four times with 2x2 max-pooling while
    widening the features 32 -> 64 -> 128 -> 256, followed by a 512-channel
    bottleneck. The decoder mirrors it: each stage doubles the resolution with
    a stride-2 transposed convolution, concatenates the matching encoder
    output along the channel axis and refines it with a DoubleConv. A final
    1x1 convolution produces `out_channels` maps at the input resolution.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        c1, c2, c3, c4, c_mid = 32, 64, 128, 256, 512

        # Encoder
        self.enc1 = DoubleConv(in_channels, c1)
        self.enc2 = DoubleConv(c1, c2)
        self.enc3 = DoubleConv(c2, c3)
        self.enc4 = DoubleConv(c3, c4)

        # One parameter-free pooling layer is reused between all levels.
        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(c4, c_mid)

        # Decoder: each DoubleConv sees upsampled features + skip, hence 2x width.
        self.up4 = nn.ConvTranspose2d(c_mid, c4, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(2 * c4, c4)

        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(2 * c3, c3)

        self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(2 * c2, c2)

        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(2 * c1, c1)

        self.out_conv = nn.Conv2d(c1, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder pass: remember every level's output for the skip connections.
        skips = []
        feat = x
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            feat = encoder(feat)
            skips.append(feat)
            feat = self.pool(feat)

        feat = self.bottleneck(feat)

        # Decoder pass: deepest level first, consuming the skips in reverse.
        decoder_stages = (
            (self.up4, self.dec4),
            (self.up3, self.dec3),
            (self.up2, self.dec2),
            (self.up1, self.dec1),
        )
        for upsample, refine in decoder_stages:
            merged = torch.cat([upsample(feat), skips.pop()], dim=1)
            feat = refine(merged)

        return self.out_conv(feat)


In [None]:
_, C, Hc, Wc = ctx_batch.shape   # ctx_batch: [B,C,H,W]
print(f"Inferred crop size from loader: H={Hc}, W={Wc}")

# Size the backbone's input from the batch it will actually receive, so a
# mismatch with the checkpoint shows up when the weights are loaded.
backbone = UNet(in_channels=C, out_channels=1).to(device)

CKPT_PATH = "/content/drive/Shareddrives/TissueMotionForecasting/models/unet_forecast_kf1to3_7_epochs.pth"
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs and avoids executing arbitrary pickled code.
state_dict = torch.load(CKPT_PATH, map_location=device, weights_only=True)
backbone.load_state_dict(state_dict)
backbone.eval()

# Freeze the backbone: it is only used as a fixed predictor from here on.
backbone.requires_grad_(False)

print("Loaded UNet checkpoint and froze backbone.")

# Smoke test: one forward pass on the batch drawn from the loader.
with torch.no_grad():
    out = backbone(ctx_batch.to(device))
print("UNet output shape:", out.shape)


Inferred crop size from loader: H=256, W=320
Loaded UNet checkpoint and froze backbone.
UNet output shape: torch.Size([2, 1, 256, 320])


In [None]:
import torch.nn as nn

class ConvBlock(nn.Module):
    """
    Two (3x3 conv -> GroupNorm -> SiLU) stages; the spatial size is preserved.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (>= 1).
        num_groups: upper bound on the GroupNorm group count (>= 1). The count
            actually used is the largest divisor of `out_ch` that does not
            exceed min(num_groups, out_ch), because GroupNorm requires the
            channel count to be divisible by the number of groups. Whenever
            min(num_groups, out_ch) already divides `out_ch`, that value is
            used unchanged.

    Raises:
        ValueError: if num_groups < 1 or out_ch < 1.
    """
    def __init__(self, in_ch, out_ch, num_groups=4):
        super().__init__()
        if num_groups < 1 or out_ch < 1:
            raise ValueError(
                f"num_groups and out_ch must be >= 1, got "
                f"num_groups={num_groups}, out_ch={out_ch}")
        # Walk down from the requested count to the first divisor of out_ch;
        # 1 always divides, so the search cannot come up empty.
        g = next(d for d in range(min(num_groups, out_ch), 0, -1)
                 if out_ch % d == 0)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(g, out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GroupNorm(g, out_ch),
            nn.SiLU(),
        )

    def forward(self, x):
        return self.conv(x)

class DiffusionHead(nn.Module):
    """
    Compact two-level U-Net used as the noise predictor.

      base_channels = 8 by default
      input : [B,4,H,W] (x_t, ctx_last, mu, t_channel)
      output: [B,1,H,W] (eps_hat)

    Resolution is halved twice with stride-2 convolutions and restored with
    stride-2 transposed convolutions; the encoder output of each level is
    concatenated back in on the way up.
    """

    def __init__(self, in_channels=4, base_channels=8):
        super().__init__()
        width = base_channels

        # Encoder level 1 (full resolution) and its downsampler -> H/2
        self.enc1 = ConvBlock(in_channels, width)
        self.down1 = nn.Conv2d(width, width, kernel_size=4, stride=2, padding=1)

        # Encoder level 2 (half resolution) and its downsampler -> H/4
        self.enc2 = ConvBlock(width, 2 * width)
        self.down2 = nn.Conv2d(2 * width, 2 * width, kernel_size=4, stride=2, padding=1)

        # Bottleneck at quarter resolution
        self.bottleneck = ConvBlock(2 * width, 4 * width)

        # Decoder: upsample, then fuse with the skip (doubling the channels)
        self.up2 = nn.ConvTranspose2d(4 * width, 2 * width, kernel_size=4, stride=2, padding=1)
        self.dec2 = ConvBlock(4 * width, 2 * width)

        self.up1 = nn.ConvTranspose2d(2 * width, width, kernel_size=4, stride=2, padding=1)
        self.dec1 = ConvBlock(2 * width, width)

        # 1x1 projection to the single predicted-noise channel
        self.out_conv = nn.Conv2d(width, 1, kernel_size=1)

    def forward(self, x_in):
        # Encoder: keep both levels' outputs for the skip connections.
        skip_full = self.enc1(x_in)                     # [B,c,H,W]
        skip_half = self.enc2(self.down1(skip_full))    # [B,2c,H/2,W/2]

        feat = self.bottleneck(self.down2(skip_half))   # [B,4c,H/4,W/4]

        # Decoder: upsample, concatenate the skip, refine.
        feat = self.dec2(torch.cat([self.up2(feat), skip_half], dim=1))  # [B,2c,H/2,W/2]
        feat = self.dec1(torch.cat([self.up1(feat), skip_full], dim=1))  # [B,c,H,W]

        return self.out_conv(feat)                      # [B,1,H,W]

# Create the noise-prediction head on the active device and report its size
# in millions of parameters.
diff_head = DiffusionHead(in_channels=4, base_channels=8)
diff_head = diff_head.to(device)
print("Diffusion head params:",
      sum(map(torch.numel, diff_head.parameters())) / 1e6, "M")



Diffusion head params: 0.042681 M


In [None]:
# --- Forward-diffusion noise schedule ---------------------------------------
T = 10                                   # number of diffusion timesteps
beta_start, beta_end = 1e-4, 0.02
betas = torch.linspace(beta_start, beta_end, T, device=device)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)  # running product of (1 - beta) up to each step
# NOTE(review): with only T=10 steps of this linear schedule the final
# alpha_bar is about 0.90, so even the noisiest x_t is still dominated by the
# clean target rather than by noise. Confirm that this is intended before
# sampling from pure Gaussian noise with this head.

print("Defined diffusion schedule with T =", T)


# --- Train the diffusion head to predict the injected noise -----------------
# NOTE(review): the continuation loop further down clips gradients to
# max_norm=1.0, this loop does not -- confirm the difference is deliberate.
optimizer = torch.optim.Adam(diff_head.parameters(), lr=3e-4)
NUM_DIFF_EPOCHS = 5
for epoch in range(1, NUM_DIFF_EPOCHS + 1):
    diff_head.train()
    running_loss = 0.0
    n_batches = 0

    for ctx, tgt in train_loader:
        # The loaders pin host memory, so the copies can be issued asynchronously.
        ctx = ctx.to(device, non_blocking=True)   # [B,C,H,W]
        tgt = tgt.to(device, non_blocking=True)   # [B,1,H,W]

        # Conditioning signals come from the backbone and the raw context;
        # no gradients are needed for either.
        with torch.no_grad():
            mu = backbone(ctx)            # [B,1,H,W]
            ctx_last = ctx[:, -1:, :, :]  # [B,1,H,W] most recent context frame

        cond = torch.cat([ctx_last, mu], dim=1)  # [B,2,H,W]

        # One random timestep and one noise map per sample.
        B = tgt.shape[0]
        t_idx = torch.randint(0, T, (B,), device=device)
        eps = torch.randn_like(tgt)

        # Noised target: sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps
        a_bar = alpha_bar[t_idx].view(B, 1, 1, 1)
        x_t = torch.sqrt(a_bar) * tgt + torch.sqrt(1.0 - a_bar) * eps

        # Timestep scaled to [0, 1] and broadcast to a full image plane;
        # max(..., 1) avoids a 0/0 when the schedule has a single step.
        t_channel = (t_idx.float() / max(T - 1, 1)).view(B, 1, 1, 1)
        t_channel = t_channel.expand_as(tgt)

        x_in = torch.cat([x_t, cond, t_channel], dim=1)  # [B,4,H,W]
        eps_hat = diff_head(x_in)

        # Only pixels whose target value is positive contribute to the loss
        # (presumably zero marks missing disparity -- confirm).
        mask = (tgt > 0).float()
        sq_err = (eps_hat - eps) ** 2 * mask
        loss = sq_err.sum() / (mask.sum() + 1e-8)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    avg_loss = running_loss / max(1, n_batches)
    print(f"[Diffusion] Epoch {epoch:02d} | train noise-MSE(masked): {avg_loss:.4f}")


In [None]:
import torch

RAW_SAVE_PATH = "/content/drive/Shareddrives/TissueMotionForecasting/models/diffusion_head_with_5_epochs.pth"

# Store the head's weights together with the noise schedule in one file, so a
# later load can restore both from the same checkpoint.
torch.save(
    dict(
        state_dict=diff_head.state_dict(),   # raw weights
        betas=betas,                         # diffusion schedule
        alpha_bar=alpha_bar,
    ),
    RAW_SAVE_PATH,
)

print(f"Saved RAW diffusion head → {RAW_SAVE_PATH}")

[✓] Saved RAW diffusion head → /content/drive/Shareddrives/TissueMotionForecasting/models/diffusion_head_with_5_epochs.pth


In [None]:
import torch

# NOTE(review): the save cell above writes diffusion_head_with_5_epochs.pth;
# the 15-epoch file loaded here is not produced by any cell visible in this
# notebook -- confirm it exists before relying on Restart & Run All.
CKPT_PATH = "/content/drive/Shareddrives/TissueMotionForecasting/models/diffusion_head_with_15_epochs.pth"

# weights_only=True restricts unpickling to tensors and plain containers,
# which covers everything stored in this checkpoint (a state dict plus two
# schedule tensors) and avoids executing arbitrary pickled code.
ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=True)

# restore weights
diff_head.load_state_dict(ckpt["state_dict"])

# restore diffusion schedule
betas = ckpt["betas"].to(device)
alpha_bar = ckpt["alpha_bar"].to(device)
if alpha_bar.shape != betas.shape:
    raise ValueError(
        f"Checkpoint schedule is inconsistent: betas {tuple(betas.shape)} "
        f"vs alpha_bar {tuple(alpha_bar.shape)}")
# The number of timesteps follows from the stored schedule.
T = betas.shape[0]

print(f"Reloaded diffusion head from {CKPT_PATH} with T = {T}")


Reloaded diffusion head from /content/drive/Shareddrives/TissueMotionForecasting/models/diffusion_head_with_15_epochs.pth with T = 10


In [None]:
# Continue training the reloaded diffusion head with a fresh Adam optimizer
# (the optimizer state is not part of the checkpoint, so its moment estimates
# start from zero again).
LR = 3e-4
optimizer = torch.optim.Adam(diff_head.parameters(), lr=LR)
print("Continuing training diffusion head with LR =", LR)

NUM_EXTRA_EPOCHS = 20

for epoch in range(1, NUM_EXTRA_EPOCHS + 1):
    diff_head.train()
    running_loss = 0.0
    n_batches = 0

    for ctx, tgt in train_loader:
        # The loaders pin host memory, so the copies can be issued asynchronously.
        ctx = ctx.to(device, non_blocking=True)   # [B,C,H,W]
        tgt = tgt.to(device, non_blocking=True)   # [B,1,H,W]

        # Conditioning signals come from the backbone and the raw context;
        # no gradients are needed for either.
        with torch.no_grad():
            mu = backbone(ctx)           # [B,1,H,W]
            ctx_last = ctx[:, -1:, :, :] # [B,1,H,W] most recent context frame

        cond = torch.cat([ctx_last, mu], dim=1)  # [B,2,H,W]

        # One random timestep and one noise map per sample.
        B = tgt.shape[0]
        t_idx = torch.randint(0, T, (B,), device=device)
        eps = torch.randn_like(tgt)

        # Noised target: sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps
        a_bar = alpha_bar[t_idx].view(B, 1, 1, 1)
        x_t = torch.sqrt(a_bar) * tgt + torch.sqrt(1.0 - a_bar) * eps

        # Timestep scaled to [0, 1] and broadcast to a full image plane;
        # max(..., 1) avoids a 0/0 when the schedule has a single step.
        t_channel = (t_idx.float() / max(T - 1, 1)).view(B, 1, 1, 1)
        t_channel = t_channel.expand_as(tgt)

        x_in = torch.cat([x_t, cond, t_channel], dim=1)  # [B,4,H,W]
        eps_hat = diff_head(x_in)

        # Only pixels whose target value is positive contribute to the loss
        # (presumably zero marks missing disparity -- confirm).
        mask = (tgt > 0).float()
        sq_err = (eps_hat - eps) ** 2 * mask
        loss = sq_err.sum() / (mask.sum() + 1e-8)

        optimizer.zero_grad()
        loss.backward()
        # Cap the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(diff_head.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    avg_loss = running_loss / max(1, n_batches)
    print(f"[Diffusion] Extra Epoch {epoch:02d} | train noise-MSE(masked): {avg_loss:.4f}")


Continuing training diffusion head with LR = 0.0003
[Diffusion] Extra Epoch 01 | train noise-MSE(masked): 0.7411
[Diffusion] Extra Epoch 02 | train noise-MSE(masked): 0.7386
[Diffusion] Extra Epoch 03 | train noise-MSE(masked): 0.7247
[Diffusion] Extra Epoch 04 | train noise-MSE(masked): 0.7250
[Diffusion] Extra Epoch 05 | train noise-MSE(masked): 0.7242
[Diffusion] Extra Epoch 06 | train noise-MSE(masked): 0.7221
[Diffusion] Extra Epoch 07 | train noise-MSE(masked): 0.7218
[Diffusion] Extra Epoch 08 | train noise-MSE(masked): 0.7371
[Diffusion] Extra Epoch 09 | train noise-MSE(masked): 0.7146
[Diffusion] Extra Epoch 10 | train noise-MSE(masked): 0.7066
[Diffusion] Extra Epoch 11 | train noise-MSE(masked): 0.7026
[Diffusion] Extra Epoch 12 | train noise-MSE(masked): 0.7082
[Diffusion] Extra Epoch 13 | train noise-MSE(masked): 0.7099
[Diffusion] Extra Epoch 14 | train noise-MSE(masked): 0.7080
[Diffusion] Extra Epoch 15 | train noise-MSE(masked): 0.7030
[Diffusion] Extra Epoch 16 | trai

In [None]:
NEW_SAVE_PATH = "/content/drive/Shareddrives/TissueMotionForecasting/models/diffusion_head_with_35_epochs.pth"

# Write the further-trained weights together with the noise schedule, using
# the same three keys as the earlier checkpoint.
torch.save(
    dict(
        state_dict=diff_head.state_dict(),
        betas=betas,
        alpha_bar=alpha_bar,
    ),
    NEW_SAVE_PATH,
)

print(f"Saved updated diffusion head → {NEW_SAVE_PATH}")


[✓] Saved updated diffusion head → /content/drive/Shareddrives/TissueMotionForecasting/models/diffusion_head_with_35_epochs.pth


In [None]:
DIFF_CKPT_PATH = "/content/drive/Shareddrives/TissueMotionForecasting/models/diffusion_head_with_35_epochs.pth"

# weights_only=True restricts unpickling to tensors and plain containers,
# which covers everything stored in this checkpoint (a state dict plus two
# schedule tensors) and avoids executing arbitrary pickled code.
ckpt = torch.load(DIFF_CKPT_PATH, map_location=device, weights_only=True)

# recreate the diffusion head with the same architecture
diff_head = DiffusionHead(in_channels=4, base_channels=8).to(device)
diff_head.load_state_dict(ckpt["state_dict"])
diff_head.eval()

# restore schedule from checkpoint
betas = ckpt["betas"].to(device)
alpha_bar = ckpt["alpha_bar"].to(device)
if alpha_bar.shape != betas.shape:
    raise ValueError(
        f"Checkpoint schedule is inconsistent: betas {tuple(betas.shape)} "
        f"vs alpha_bar {tuple(alpha_bar.shape)}")
# The number of timesteps follows from the stored schedule.
T = betas.shape[0]

print(f"Loaded diffusion head from {DIFF_CKPT_PATH}")
print(f"T (num timesteps) = {T}")

Loaded diffusion head from /content/drive/Shareddrives/TissueMotionForecasting/models/diffusion_head_with_15_epochs.pth
T (num timesteps) = 10
