In [None]:
!pip install -q "monai[all]" nibabel einops

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.6/52.6 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.0/40.0 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.2/47.2 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.5/266.5 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m53.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m80.9/80.9 MB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.8/67.8 MB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Mount Google Drive so the /content/drive/... data and checkpoint paths used below resolve.
from google.colab import drive
drive.mount('/content/drive')

import os
from pathlib import Path

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from monai.networks.nets import SwinUNETR

# Prefer a GPU when one is available; every later cell moves tensors/models
# with .to(device), so CPU remains a transparent fallback.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Shared-drive layout: <BASE_DIR>/scared_data/train/<dataset>/<keyframe>/...
BASE_DIR = Path("/content/drive/Shareddrives/TissueMotionForecasting")
TRAIN_ROOT = BASE_DIR / "scared_data" / "train"
print("TRAIN_ROOT:", TRAIN_ROOT)


Mounted at /content/drive




Using device: cpu
TRAIN_ROOT: /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train


In [None]:
CROP_H, CROP_W = 256, 320


def random_crop_torch(stack, ch, cw):
    """
    Randomly crop the spatial dims of a channel-first stack.

    The same window is applied to every channel, so context frames and the
    target stacked together stay spatially aligned.

    Args:
        stack: array/tensor of shape [C, H, W] (e.g. [C+1, H, W] with the
            target appended as the last channel).
        ch, cw: crop height and width.

    Returns:
        The stack with the channel dim unchanged and each spatial dim cropped
        independently: a dim larger than its crop size is cut to that size at
        a uniformly random offset; a dim already <= its crop size is kept whole.
    """
    _, H, W = stack.shape
    # randint's upper bound is exclusive: +1 makes the last valid offset
    # (H - ch / W - cw) reachable.
    top = np.random.randint(0, H - ch + 1) if H > ch else 0
    left = np.random.randint(0, W - cw + 1) if W > cw else 0
    return stack[:, top:top + ch, left:left + cw]


In [None]:
# Each sample uses CONTEXT_LEN consecutive frames to predict the frame
# FORECAST_HORIZON steps after the last one; raw PNG values are divided by
# DISP_SCALE when loaded.
CONTEXT_LEN, FORECAST_HORIZON = 3, 10
DISP_SCALE = 256.0

DATASETS_TO_USE = ["dataset_%d" % k for k in (1, 2, 3)]
KEYFRAME_NAMES = ["keyframe_%d" % k for k in range(1, 6)]  # keyframe_1 .. keyframe_5


class DisparityForecastDataset(Dataset):
    """
    Build (context, target) forecast pairs from RAW disparity PNGs.

    - Reads disparity PNGs from:
        train_root/<dataset>/<keyframe>/data/disparity/*.png
      skipping files whose name starts with "colored".
    - Frames are ordered by sorted path. Each sample is a sliding window of
      `context_len` consecutive frames plus the frame `forecast_horizon`
      steps after the last context frame.
    - Pixel values are divided by `scale` on load.
    - __getitem__ stacks context and target and crops them together with
      random_crop_torch(stack, CROP_H, CROP_W), so it returns
        context: [context_len, h, w]
        target : [1, h, w]
      where (h, w) is the cropped spatial size.
    """
    def __init__(self, train_root, dataset_names, keyframe_names,
                 context_len=3, forecast_horizon=10, scale=256.0):
        self.samples = []  # list of (ctx_paths: list[Path], tgt_path: Path)
        self.context_len = context_len
        self.forecast_horizon = forecast_horizon
        self.scale = scale

        for ds_name in dataset_names:
            for kf_name in keyframe_names:
                disp_dir = train_root / ds_name / kf_name / "data" / "disparity"
                if not disp_dir.exists():
                    continue

                frame_paths = sorted(
                    p for p in disp_dir.glob("*.png")
                    if not p.name.startswith("colored")
                )

                # Too short for even one (context, target) window.
                if len(frame_paths) < context_len + forecast_horizon:
                    continue

                # i is the index of the last context frame; the target at
                # i + forecast_horizon must stay inside the sequence.
                for i in range(context_len - 1,
                               len(frame_paths) - forecast_horizon):
                    ctx_paths = frame_paths[i - (context_len - 1): i + 1]
                    tgt_path = frame_paths[i + forecast_horizon]
                    self.samples.append((ctx_paths, tgt_path))

        print(f"DisparityForecastDataset: {len(self.samples)} samples")

    def _load_disp(self, path: Path):
        """Load one disparity PNG as a float32 [H, W] array divided by self.scale."""
        # Context manager releases the file handle deterministically.
        with Image.open(path) as img:
            arr = np.array(img.convert("I"), dtype=np.float32)
        return arr / self.scale

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (context [C, h, w], target [1, h, w]) float32 tensors for sample idx."""
        ctx_paths, tgt_path = self.samples[idx]
        ctx = torch.from_numpy(np.stack([self._load_disp(p) for p in ctx_paths], axis=0))
        tgt = torch.from_numpy(self._load_disp(tgt_path)).unsqueeze(0)  # [1,H,W]

        # Crop context and target jointly so they share the same window.
        stack = torch.cat([ctx, tgt], dim=0)                  # [C+1,H,W]
        stack = random_crop_torch(stack, CROP_H, CROP_W)
        ctx = stack[:-1]                                      # [C,h,w]
        tgt = stack[-1:].contiguous()                         # [1,h,w]

        return ctx, tgt


dataset = DisparityForecastDataset(
    train_root=TRAIN_ROOT,
    dataset_names=DATASETS_TO_USE,
    keyframe_names=KEYFRAME_NAMES,
    context_len=CONTEXT_LEN,
    forecast_horizon=FORECAST_HORIZON,
    scale=DISP_SCALE,
)

print("Total samples:", len(dataset))
# Fail early with a clear message instead of an IndexError on dataset[0]
# when no disparity folders were found (e.g. Drive not mounted / wrong path).
assert len(dataset) > 0, f"No disparity samples found under {TRAIN_ROOT}"
ctx_sample, tgt_sample = dataset[0]
print("Context shape:", ctx_sample.shape)
print("Target  shape:", tgt_sample.shape)


DisparityForecastDataset: 7902 samples
Total samples: 7902
Context shape: torch.Size([3, 1024, 1280])
Target  shape: torch.Size([1, 1024, 1280])


In [None]:
from torch.utils.data import random_split, DataLoader

VAL_FRACTION = 0.2
SPLIT_SEED = 0
val_len = int(len(dataset) * VAL_FRACTION)
train_len = len(dataset) - val_len

# Seeded generator -> the same train/val split on every Restart & Run All.
# NOTE(review): samples are sliding windows over the same sequences and
# neighbouring windows share frames, so a per-sample random split puts
# overlapping windows in both train and val; consider splitting by
# keyframe directory instead.
train_set, val_set = random_split(
    dataset, [train_len, val_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(f"Train samples: {train_len}, Val samples: {val_len}")

BATCH_SIZE = 1
# Pinned memory only speeds up host->GPU copies; on CPU it just warns.
PIN_MEMORY = device.type == "cuda"

train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=2, pin_memory=PIN_MEMORY
)
val_loader = DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=2, pin_memory=PIN_MEMORY
)

ctx_batch, tgt_batch = next(iter(train_loader))
print("ctx_batch:", ctx_batch.shape, "tgt_batch:", tgt_batch.shape)


Train samples: 6322, Val samples: 1580


'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.


ctx_batch: torch.Size([1, 3, 1024, 1280]) tgt_batch: torch.Size([1, 1, 1024, 1280])


In [None]:

_, C, Hc, Wc = ctx_batch.shape
print(f"Inferred img_size from loader: H={Hc}, W={Wc}")
# The stacked context frames are the network's input channels.
assert C == CONTEXT_LEN, f"expected {CONTEXT_LEN} context channels, got {C}"

FEATURE_SIZE = 48

swin_unetr = SwinUNETR(
    in_channels=CONTEXT_LEN,
    out_channels=1,
    feature_size=FEATURE_SIZE,
    spatial_dims=2,
    use_checkpoint=True,
)

CKPT_PATH = "/content/drive/Shareddrives/TissueMotionForecasting/models/swin_unetr_forecast.pth"
# The file is loaded straight into load_state_dict, i.e. it is a plain
# state_dict; weights_only=True refuses arbitrary pickled objects.
state_dict = torch.load(CKPT_PATH, map_location=device, weights_only=True)
swin_unetr.load_state_dict(state_dict)
swin_unetr.to(device)
swin_unetr.eval()

# Frozen backbone: only the diffusion head is trained later.
for p in swin_unetr.parameters():
    p.requires_grad = False

print("Loaded Swin-UNETR checkpoint and froze backbone.")

with torch.no_grad():
    out = swin_unetr(ctx_batch.to(device))
print("Swin-UNETR output shape:", out.shape)


Inferred img_size from loader: H=1024, W=1280
Loaded Swin-UNETR checkpoint and froze backbone.
Swin-UNETR output shape: torch.Size([1, 1, 1024, 1280])


In [None]:
import torch.nn as nn

class ConvBlock(nn.Module):
    """
    Two stacked (3x3 Conv2d -> GroupNorm -> SiLU) stages.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (both convs emit out_ch).
        num_groups: upper bound on the GroupNorm group count.

    Spatial size is preserved (3x3 kernels with padding=1).
    """
    def __init__(self, in_ch, out_ch, num_groups=4):
        super().__init__()
        # GroupNorm requires out_ch % groups == 0. Use the largest divisor of
        # out_ch not exceeding num_groups; this equals min(num_groups, out_ch)
        # whenever that value already divides out_ch (e.g. out_ch = 8, 16, 32).
        upper = max(1, min(num_groups, out_ch))
        g = max(d for d in range(1, upper + 1) if out_ch % d == 0)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(g, out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GroupNorm(g, out_ch),
            nn.SiLU(),
        )

    def forward(self, x):
        """[B, in_ch, H, W] -> [B, out_ch, H, W]."""
        return self.conv(x)

class DiffusionHead(nn.Module):
    """
    Small two-level U-Net that predicts the diffusion noise.

      base_channels = c (default 8); widths c -> 2c -> 4c (bottleneck)
      input : [B, in_channels, H, W]; the training loop feeds
              (x_t, ctx_last, mu, t_channel) -> in_channels = 4
      output: [B, 1, H, W] (eps_hat)

    H and W need not be multiples of 4: the input is zero-padded so both
    stride-2 levels line up with the skip connections, and the output is
    cropped back to the original size.
    """
    def __init__(self, in_channels=4, base_channels=8):
        super().__init__()
        c = base_channels

        # encoder
        self.enc1 = ConvBlock(in_channels, c)
        self.down1 = nn.Conv2d(c, c, 4, stride=2, padding=1)  # H/2

        self.enc2 = ConvBlock(c, c * 2)
        self.down2 = nn.Conv2d(c * 2, c * 2, 4, stride=2, padding=1)  # H/4

        # bottleneck
        self.bottleneck = ConvBlock(c * 2, c * 4)

        # decoder (input channels doubled by the skip concatenation)
        self.up2 = nn.ConvTranspose2d(c * 4, c * 2, 4, stride=2, padding=1)
        self.dec2 = ConvBlock(c * 4, c * 2)

        self.up1 = nn.ConvTranspose2d(c * 2, c, 4, stride=2, padding=1)
        self.dec1 = ConvBlock(c * 2, c)

        self.out_conv = nn.Conv2d(c, 1, 1)

    def forward(self, x_in):
        # x_in: [B,in_channels,H,W]
        H, W = x_in.shape[-2:]
        # Two stride-2 downsamplings: pad to multiples of 4 so the upsampled
        # maps match the skip tensors in torch.cat (no-op when divisible).
        pad_h = (-H) % 4
        pad_w = (-W) % 4
        if pad_h or pad_w:
            x_in = F.pad(x_in, (0, pad_w, 0, pad_h))

        e1 = self.enc1(x_in)     # [B,c,H,W]
        d1 = self.down1(e1)      # [B,c,H/2,W/2]

        e2 = self.enc2(d1)       # [B,2c,H/2,W/2]
        d2 = self.down2(e2)      # [B,2c,H/4,W/4]

        b = self.bottleneck(d2)  # [B,4c,H/4,W/4]

        u2 = self.up2(b)         # [B,2c,H/2,W/2]
        u2 = torch.cat([u2, e2], dim=1)
        u2 = self.dec2(u2)       # [B,2c,H/2,W/2]

        u1 = self.up1(u2)        # [B,c,H,W]
        u1 = torch.cat([u1, e1], dim=1)
        u1 = self.dec1(u1)       # [B,c,H,W]

        out = self.out_conv(u1)  # [B,1,H,W] (padded size)
        return out[..., :H, :W]

# Instantiate the noise-prediction head and report its size in millions.
diff_head = DiffusionHead(in_channels=4, base_channels=8).to(device)
n_params_millions = sum(param.numel() for param in diff_head.parameters()) / 1e6
print("Diffusion head params:", n_params_millions, "M")



Diffusion head params: 0.042681 M


In [None]:
# --- Noise schedule -------------------------------------------------------
T = 10
beta_start, beta_end = 1e-4, 0.02
# NOTE(review): these endpoints match the DDPM linear schedule for T=1000.
# With T=10 the betas sum to ~0.1, so alpha_bar[-1] is ~0.9 and x_t at the
# last step is still mostly signal rather than noise -- confirm intended.
betas = torch.linspace(beta_start, beta_end, T, device=device)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)

print("Defined diffusion schedule with T =", T)

# --- Training -------------------------------------------------------------
DIFF_SEED = 0
torch.manual_seed(DIFF_SEED)  # reproducible timestep / noise draws

optimizer = torch.optim.Adam(diff_head.parameters(), lr=1e-4)
NUM_DIFF_EPOCHS = 1
# Normalizer for the timestep channel (maps t to [0, 1]); guards T == 1.
t_denom = max(T - 1, 1)

for epoch in range(1, NUM_DIFF_EPOCHS + 1):
    diff_head.train()
    running_loss = 0.0
    n_batches = 0

    for ctx, tgt in train_loader:
        ctx = ctx.to(device)   # [B,C,H,W]
        tgt = tgt.to(device)   # [B,1,H,W]

        # Frozen backbone prediction and last context frame are conditioning only.
        with torch.no_grad():
            mu = swin_unetr(ctx)          # [B,1,H,W]
            ctx_last = ctx[:, -1:, :, :]  # [B,1,H,W]

        cond = torch.cat([ctx_last, mu], dim=1)  # [B,2,H,W]

        # Forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
        B = tgt.shape[0]
        t_idx = torch.randint(0, T, (B,), device=device)
        eps = torch.randn_like(tgt)

        a_bar = alpha_bar[t_idx].view(B, 1, 1, 1)
        x_t = torch.sqrt(a_bar) * tgt + torch.sqrt(1.0 - a_bar) * eps

        t_channel = (t_idx.float() / t_denom).view(B, 1, 1, 1)
        t_channel = t_channel.expand_as(tgt)

        x_in = torch.cat([x_t, cond, t_channel], dim=1)  # [B,4,H,W]
        eps_hat = diff_head(x_in)

        # Only pixels with positive target disparity contribute to the loss.
        mask = (tgt > 0).float()
        sq_err = (eps_hat - eps) ** 2 * mask
        loss = sq_err.sum() / (mask.sum() + 1e-8)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    avg_loss = running_loss / max(1, n_batches)
    print(f"[Diffusion] Epoch {epoch:02d} | train noise-MSE(masked): {avg_loss:.4f}")


Defined diffusion schedule with T = 20
