In [None]:
# Mount Google Drive so the shared SCARED data is reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

import os, glob
from pathlib import Path

# --- Paths ---------------------------------------------------------------
BASE_DIR = Path("/content/drive/Shareddrives/TissueMotionForecasting")
TRAIN_ROOT = BASE_DIR.joinpath("scared_data", "train")

for label, value in (("BASE_DIR  :", BASE_DIR), ("TRAIN_ROOT:", TRAIN_ROOT)):
    print(label, value)

# --- Selection -----------------------------------------------------------
DATASETS_TO_USE = ["dataset_1", "dataset_2", "dataset_3"]
KEYFRAME_NAMES = ["keyframe_%d" % n for n in range(1, 6)]

# Every keyframe directory (across the selected datasets) whose name is listed
# in KEYFRAME_NAMES; everything found on disk is printed while scanning.
KEYFRAMES_TO_USE = []

print("\nScanning datasets and keyframes:")
for dataset_name in DATASETS_TO_USE:
    found = sorted((TRAIN_ROOT / dataset_name).glob("keyframe_*"))
    print(f"\n[{dataset_name}] all keyframes:")
    for kf_dir in found:
        print(" ", kf_dir)
    KEYFRAMES_TO_USE.extend(d for d in found if d.name in KEYFRAME_NAMES)

print("\nKeyframes to use (all datasets, kf1–5):")
for kf_dir in KEYFRAMES_TO_USE:
    print(" ", kf_dir)

print("Total keyframes used:", len(KEYFRAMES_TO_USE))


# Per-keyframe inventory of the disparity folder: RAW PNGs are the data the
# model uses; files starting with "colored" are visualisations and are skipped.
for kf in KEYFRAMES_TO_USE:
    disp_dir = kf / "data" / "disparity"
    raw_files = sorted(filter(lambda p: not p.name.startswith("colored"),
                              disp_dir.glob("*.png")))
    colored_files = sorted(disp_dir.glob("colored_*.png"))

    print(f"\n[{kf.name}]")
    for label, value in (
        ("  disparity dir :", disp_dir),
        ("  RAW disparity :", len(raw_files)),
        ("  COLORED       :", len(colored_files)),
    ):
        print(label, value)

    # Show one file name of each kind, when any exist.
    for label, files in (("   example RAW     :", raw_files),
                         ("   example COLORED :", colored_files)):
        if files:
            print(label, files[0].name)



Mounted at /content/drive
BASE_DIR  : /content/drive/Shareddrives/TissueMotionForecasting
TRAIN_ROOT: /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train

Scanning datasets and keyframes:

[dataset_1] all keyframes:
  /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_1/keyframe_1
  /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_1/keyframe_2
  /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_1/keyframe_3
  /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_1/keyframe_4
  /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_1/keyframe_5

[dataset_2] all keyframes:
  /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_2/keyframe_1
  /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_2/keyframe_2
  /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_2/keyfr

In [None]:
import torch
from torch.utils.data import Dataset

# Forecasting setup: CONTEXT_LEN consecutive disparity frames are stacked as the
# input channels, the target lies FORECAST_HORIZON frames after the last
# context frame, and raw PNG values are divided by DISP_SCALE when loaded.
CONTEXT_LEN, FORECAST_HORIZON, DISP_SCALE = 3, 5, 256.0


class DisparityForecastDataset(Dataset):
    """
    Sliding-window disparity forecasting dataset.

    Uses RAW disparity from:
      <keyframe>/data/disparity/*.png   for every keyframe in keyframes_paths

    - Ignores colored_*.png (any file whose name starts with "colored").
    - Frames are ordered with a natural (digit-aware) sort, so "frame_10.png"
      follows "frame_9.png"; for zero-padded names this is the same order as a
      plain lexicographic sort.
    - Input : context_len consecutive frames ending at t -> tensor [C, H, W]
    - Target: frame at t + forecast_horizon             -> tensor [1, H, W]
    - Values are the raw PNG values divided by `scale`.
    - Windows never span two keyframes; keyframes with fewer than
      context_len + forecast_horizon raw frames are skipped.

    Args:
        data_root: accepted for backward compatibility; not used (frames are
            located through keyframes_paths).
        keyframes_paths: iterable of pathlib.Path keyframe directories.
        context_len: number of past frames stacked as input channels.
        forecast_horizon: offset (in frames) of the target after t.
        scale: divisor applied to raw PNG values.
    """
    def __init__(self, data_root, keyframes_paths,
                 context_len=3, forecast_horizon=1, scale=256.0):
        self.samples = []
        self.context_len = context_len
        self.forecast_horizon = forecast_horizon
        self.scale = scale

        for kf in keyframes_paths:
            disp_dir = kf / "data" / "disparity"

            # Temporal order matters for forecasting: sort numerically on the
            # digit runs in the file name rather than character by character.
            frame_paths = sorted(
                (p for p in disp_dir.glob("*.png")
                 if not p.name.startswith("colored")),
                key=self._natural_key,
            )

            if len(frame_paths) < context_len + forecast_horizon:
                continue

            # i is the index of the last context frame; the target at
            # i + forecast_horizon must still be inside frame_paths.
            for i in range(context_len - 1,
                           len(frame_paths) - forecast_horizon):
                ctx_paths = frame_paths[i - (context_len - 1): i + 1]
                tgt_path = frame_paths[i + forecast_horizon]
                self.samples.append((ctx_paths, tgt_path))

        print(f"DisparityForecastDataset: {len(self.samples)} samples")

    @staticmethod
    def _natural_key(path):
        """Sort key comparing digit runs of the file name as integers (ties broken by name)."""
        import re
        # re.split with a capturing group alternates text / digits, always
        # starting with text, so int and str never meet at the same position.
        tokens = [int(tok) if tok.isdigit() else tok
                  for tok in re.split(r"(\d+)", path.name)]
        return (tokens, path.name)

    def _load_disp(self, path):
        """Load one raw disparity PNG as a float32 [H, W] array divided by self.scale."""
        from PIL import Image
        import numpy as np

        # `with` releases the file handle even if decoding fails.
        with Image.open(path) as img:
            arr = np.array(img.convert("I"), dtype=np.float32)
        return arr / self.scale

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (ctx [context_len, H, W], tgt [1, H, W]) float32 tensors for sample idx."""
        import numpy as np

        ctx_paths, tgt_path = self.samples[idx]

        ctx = np.stack([self._load_disp(p) for p in ctx_paths], axis=0)  # [C,H,W]
        tgt = self._load_disp(tgt_path)                                  # [H,W]

        ctx = torch.from_numpy(ctx.astype("float32"))                # [C,H,W]
        tgt = torch.from_numpy(tgt.astype("float32")).unsqueeze(0)   # [1,H,W]

        return ctx, tgt


dataset = DisparityForecastDataset(
    data_root=TRAIN_ROOT,
    keyframes_paths=KEYFRAMES_TO_USE,
    context_len=CONTEXT_LEN,
    forecast_horizon=FORECAST_HORIZON,
    scale=DISP_SCALE,
)

print("Total samples in dataset:", len(dataset))

# Fail fast with a readable message (instead of an IndexError on dataset[0])
# when the Drive paths are wrong or every keyframe is too short.
assert len(dataset) > 0, (
    f"No samples built from {len(KEYFRAMES_TO_USE)} keyframes; "
    f"check TRAIN_ROOT={TRAIN_ROOT} and CONTEXT_LEN/FORECAST_HORIZON."
)

ctx_sample, tgt_sample = dataset[0]
# Input has CONTEXT_LEN channels, target has 1, and both share H x W.
assert ctx_sample.shape[0] == CONTEXT_LEN and tgt_sample.shape[0] == 1
assert ctx_sample.shape[1:] == tgt_sample.shape[1:]

print("Context shape:", ctx_sample.shape)
print("Target shape :", tgt_sample.shape)
print("Context min/max:", ctx_sample.min().item(), ctx_sample.max().item())
print("Target  min/max:", tgt_sample.min().item(), tgt_sample.max().item())


DisparityForecastDataset: 7952 samples
Total samples in dataset: 7952
Context shape: torch.Size([3, 1024, 1280])
Target shape : torch.Size([1, 1024, 1280])
Context min/max: 0.0 60.48828125
Target  min/max: 0.0 57.32421875


In [None]:
from torch.utils.data import random_split, DataLoader, Subset

SEED = 42             # makes the train/val split and the shuffle order reproducible
TRAIN_FRACTION = 0.8

# NOTE(review): neighbouring samples are overlapping sliding windows from the
# same keyframe (stride 1, so they share CONTEXT_LEN - 1 context frames and
# have adjacent targets). A random per-sample split therefore puts
# near-duplicates in both train and val, and the val loss is optimistic.
# SPLIT_BY_KEYFRAME = True holds out whole keyframes instead; the default
# keeps the original per-sample split.
SPLIT_BY_KEYFRAME = False

split_generator = torch.Generator().manual_seed(SEED)

if SPLIT_BY_KEYFRAME:
    # Every target path is <keyframe>/data/disparity/<frame>.png.
    sample_kf = [tgt_path.parents[2] for _, tgt_path in dataset.samples]
    kf_dirs = sorted(set(sample_kf))
    perm = torch.randperm(len(kf_dirs), generator=split_generator).tolist()
    n_train_kf = max(1, min(len(kf_dirs) - 1, int(TRAIN_FRACTION * len(kf_dirs))))
    train_kf = {kf_dirs[j] for j in perm[:n_train_kf]}
    train_ds = Subset(dataset, [i for i, k in enumerate(sample_kf) if k in train_kf])
    val_ds = Subset(dataset, [i for i, k in enumerate(sample_kf) if k not in train_kf])
else:
    train_size = int(TRAIN_FRACTION * len(dataset))
    val_size = len(dataset) - train_size
    train_ds, val_ds = random_split(dataset, [train_size, val_size],
                                    generator=split_generator)

print("Train samples:", len(train_ds))
print("Val samples  :", len(val_ds))

BATCH_SIZE = 2
NUM_WORKERS = 2
PIN_MEMORY = torch.cuda.is_available()  # pinned host memory only helps CPU->GPU copies

train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
    generator=torch.Generator().manual_seed(SEED),
)

val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
)

# Smoke test: pull one batch and check its shapes.
ctx, tgt = next(iter(train_loader))
print("Batch ctx shape:", ctx.shape)   # [B, 3, H, W]
print("Batch tgt shape:", tgt.shape)   # [B, 1, H, W]


Train samples: 6361
Val samples  : 1591




Batch ctx shape: torch.Size([2, 3, 1024, 1280])
Batch tgt shape: torch.Size([2, 1, 1024, 1280])


In [None]:

import torch.nn as nn
import torch

class DoubleConv(nn.Module):
    """
    (Conv3x3 -> ReLU) x 2.

    padding=1 keeps H and W unchanged; only the channel count changes in_c -> out_c.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UNet(nn.Module):
    """
    4-level U-Net: encoder 32-64-128-256 channels, 512 at the bottleneck.

    Maps [B, in_channels, H, W] -> [B, out_channels, H, W] (raw regression
    output, no final activation).

    The encoder max-pools four times, so the skip connections only line up when
    H and W are multiples of 16. Other sizes are replicate-padded on the
    bottom/right up to the next multiple of 16 and the output is cropped back.
    Inputs that are already multiples of 16 are processed exactly as before, and
    parameter names/shapes are unchanged, so existing state_dicts still load.
    """

    # Total down-sampling factor of the encoder (4 x MaxPool2d(2)).
    SIZE_MULTIPLE = 16

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        self.enc1 = DoubleConv(in_channels, 32)
        self.enc2 = DoubleConv(32, 64)
        self.enc3 = DoubleConv(64, 128)
        self.enc4 = DoubleConv(128, 256)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(256, 512)

        self.up4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(256 + 256, 256)

        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(128 + 128, 128)

        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(64 + 64, 64)

        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(32 + 32, 32)

        self.out_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        H, W = x.shape[-2], x.shape[-1]
        pad_h = (-H) % self.SIZE_MULTIPLE
        pad_w = (-W) % self.SIZE_MULTIPLE
        if pad_h or pad_w:
            # pad order for the last two dims is (left, right, top, bottom).
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        # Encoder (Hp, Wp = padded size, multiples of 16)
        e1 = self.enc1(x)                   # [B,32,Hp,Wp]
        e2 = self.enc2(self.pool(e1))       # [B,64,Hp/2,Wp/2]
        e3 = self.enc3(self.pool(e2))       # [B,128,Hp/4,Wp/4]
        e4 = self.enc4(self.pool(e3))       # [B,256,Hp/8,Wp/8]

        # Bottleneck
        b = self.bottleneck(self.pool(e4))  # [B,512,Hp/16,Wp/16]

        # Decoder: upsample, concatenate the matching skip, fuse.
        d4 = self.up4(b)                    # [B,256,Hp/8,Wp/8]
        d4 = self.dec4(torch.cat([d4, e4], dim=1))

        d3 = self.up3(d4)                   # [B,128,Hp/4,Wp/4]
        d3 = self.dec3(torch.cat([d3, e3], dim=1))

        d2 = self.up2(d3)                   # [B,64,Hp/2,Wp/2]
        d2 = self.dec2(torch.cat([d2, e2], dim=1))

        d1 = self.up1(d2)                   # [B,32,Hp,Wp]
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        out = self.out_conv(d1)             # [B,out_channels,Hp,Wp]
        return out[..., :H, :W]             # crop any padding back off


# instantiate once to check shapes
# Build the network on the available accelerator and push one full-resolution
# dummy input through it to confirm the prediction keeps the input's H x W.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = UNet(in_channels=3, out_channels=1)
model = model.to(device)

dummy = torch.randn(1, 3, 1024, 1280)
dummy = dummy.to(device)
with torch.no_grad():
    out = model(dummy)
print("Dummy out shape:", out.shape)


Dummy out shape: torch.Size([1, 1, 1024, 1280])


In [None]:
def masked_l1_loss(pred, target):
    """
    Mean absolute error over pixels that carry valid ground truth.

    pred, target: tensors of identical shape (used here as [B,1,H,W]).
    A pixel counts as valid when target > 0; all other pixels are ignored.
    If no pixel is valid, the result is a zero that still depends on `pred`,
    so backward() runs and yields all-zero gradients.
    """
    valid = target > 0
    if not valid.any():
        return (pred - target).mean() * 0.0
    diff = pred[valid] - target[valid]
    return diff.abs().mean()


# Adam over every UNet parameter at a fixed learning rate (no scheduler).
LR: float = 1e-4
optimizer = torch.optim.Adam(params=list(model.parameters()), lr=LR)
print("Loss + optimizer initialized, LR =", LR)


Loss + optimizer initialized, LR = 0.0001


In [None]:
import time

EPOCHS = 7
best_val = float("inf")

# Save checkpoints on Drive rather than /content: /content lives on the Colab VM
# and is wiped when the runtime resets. Epochs in the recorded run took
# ~2500 s each, so losing the best checkpoint is expensive.
CKPT_DIR = BASE_DIR / "models"
CKPT_DIR.mkdir(parents=True, exist_ok=True)
save_path = CKPT_DIR / "unet_forecast_kf1to3.pth"


def run_train_epoch(model, loader, optimizer):
    """
    One optimisation pass over `loader` with masked_l1_loss.

    Returns the mean of the per-batch losses (every batch weighted equally).
    """
    model.train()
    running = 0.0
    for ctx, tgt in loader:
        ctx, tgt = ctx.to(device), tgt.to(device)
        optimizer.zero_grad()
        loss = masked_l1_loss(model(ctx), tgt)
        loss.backward()
        optimizer.step()
        running += loss.item()
    return running / len(loader)


def run_val_epoch(model, loader):
    """Mean per-batch masked_l1_loss over `loader`, in eval mode and without gradients."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for ctx, tgt in loader:
            pred = model(ctx.to(device))
            running += masked_l1_loss(pred, tgt.to(device)).item()
    return running / len(loader)


for epoch in range(1, EPOCHS + 1):
    start_time = time.time()
    train_loss = run_train_epoch(model, train_loader, optimizer)
    val_loss = run_val_epoch(model, val_loader)
    epoch_time = time.time() - start_time  # includes validation

    print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | Time: {epoch_time:.1f}s")

    # Keep only the weights with the lowest validation loss seen so far.
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), save_path)
        print(f"  -> Saved best model to {save_path}")


Epoch 01 | Train Loss: 2.7200 | Val Loss: 2.5598 | Time: 3481.8s
  -> Saved best model to /content/unet_forecast_kf1to3.pth
Epoch 02 | Train Loss: 1.4476 | Val Loss: 1.6169 | Time: 2555.0s
  -> Saved best model to /content/unet_forecast_kf1to3.pth
Epoch 03 | Train Loss: 1.3997 | Val Loss: 1.2574 | Time: 2548.2s
  -> Saved best model to /content/unet_forecast_kf1to3.pth
Epoch 04 | Train Loss: 1.3399 | Val Loss: 1.5327 | Time: 2540.4s
Epoch 05 | Train Loss: 1.4732 | Val Loss: 1.6378 | Time: 2527.1s
Epoch 06 | Train Loss: 1.4247 | Val Loss: 1.3590 | Time: 2522.3s
Epoch 07 | Train Loss: 1.3888 | Val Loss: 1.2169 | Time: 2528.5s
  -> Saved best model to /content/unet_forecast_kf1to3.pth


In [None]:
# Rebuild the network and load trained weights from the shared Drive folder.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet(in_channels=3, out_channels=1).to(device)

CKPT_PATH = BASE_DIR / "models" / "unet_forecast_kf1to3_7_epochs.pth"  # adjust if saved elsewhere
if not CKPT_PATH.exists():
    raise FileNotFoundError(f"Checkpoint not found: {CKPT_PATH}")

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run code; a plain state_dict loads fine this way.
state_dict = torch.load(CKPT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

print("Loaded UNet checkpoint and set to eval().")

Loaded UNet checkpoint and set to eval().


In [None]:
import imageio.v2 as imageio
import numpy as np
import matplotlib.cm as cm
import cv2
from pathlib import Path

# Divisor for raw disparity PNGs; reuses DISP_SCALE so frames here are decoded
# with the same scale as the training dataset.
SCALE_FACTOR = DISP_SCALE


def load_disp_float(path):
    """
    Read one raw disparity PNG and return its values as float32 divided by SCALE_FACTOR.

    The PNG is expected to be uint16 (TODO confirm for every dataset); the
    division is applied here, not stored in the file.
    """
    raw = imageio.imread(path)
    scaled = raw.astype(np.float32) / SCALE_FACTOR
    return scaled

def disp_to_raft_color_fixed_range(disp_float, vmin, vmax, valid_mask=None):
    """
    Colour-map a disparity map with matplotlib's "turbo" colormap over a fixed range.

    disp_float: [H,W] float disparity.
    vmin, vmax: values mapped to the two ends of the colormap; values outside
        the range are clipped to it.
    valid_mask: [H,W] bool. Pixels where it is False are set to NaN and drawn
        with the colormap's "bad" colour (black with matplotlib's default).
        None (optional) treats every pixel as valid.
    Returns: [H,W,3] uint8 RGB image.
    """
    # cm.get_cmap is deprecated and was removed in matplotlib 3.9; the
    # `matplotlib.colormaps` registry is the supported replacement.
    import matplotlib

    d = np.asarray(disp_float, dtype=np.float32)
    d_clipped = np.clip(d, vmin, vmax)
    d_norm = (d_clipped - vmin) / (vmax - vmin + 1e-8)
    if valid_mask is not None:
        d_norm[~valid_mask] = np.nan

    turbo = matplotlib.colormaps["turbo"]
    colored = turbo(d_norm)[:, :, :3]          # drop alpha -> [H,W,3] in [0,1]
    colored = np.nan_to_num(colored) * 255.0
    return colored.astype(np.uint8)


In [None]:
# Render a side-by-side video for one keyframe of dataset_4 (not among
# DATASETS_TO_USE): left RGB | right RGB | GT disparity | predicted disparity.
# The RGB panels show frame t (last context frame); the disparity panels show
# the forecast target t + HORIZON.
# Window settings come from the training config so the loaded checkpoint sees
# the input layout it was trained with.
HORIZON = FORECAST_HORIZON
assert model.enc1.block[0].in_channels == CONTEXT_LEN, (
    "model input channels must equal CONTEXT_LEN"
)
VIDEO_FPS = 10

DATA_ROOT = BASE_DIR / "scared_data" / "train" / "dataset_4"
kf_id = 1
kf_dir = DATA_ROOT / f"keyframe_{kf_id}"
rgb_path = kf_dir / "data" / "rgb.mp4"
disp_dir = kf_dir / "data" / "disparity"

print("Keyframe dir:", kf_dir)
print("RGB:", rgb_path)
print("Disp dir:", disp_dir)

# disparity PNGs (raw only)
disp_paths = sorted(
    p for p in disp_dir.glob("*.png")
    if not p.name.startswith("colored")
)
print("Num disparity frames:", len(disp_paths))

out_dir = BASE_DIR / "videos"
out_dir.mkdir(parents=True, exist_ok=True)
out_path = out_dir / "dt4_kf1_left-right_gt-pred_disp_unet_7epochs.mp4"

model.eval()

# RGB video reader (stacked vertically: top=left, bottom=right)
rgb_reader = imageio.get_reader(str(rgb_path), "ffmpeg")
writer = None
try:
    num_rgb_frames = rgb_reader.count_frames()
    print("Num RGB frames:", num_rgb_frames)

    writer = imageio.get_writer(str(out_path), fps=VIDEO_FPS)
    print("Writing to:", out_path)

    # t is the last context frame: frames t-CONTEXT_LEN+1 .. t and the target
    # t+HORIZON must exist in both the disparity folder and the RGB video.
    start_t = CONTEXT_LEN - 1
    max_t = min(len(disp_paths), num_rgb_frames) - HORIZON

    for t in range(start_t, max_t):
        rgb_frame = rgb_reader.get_data(t)
        # Slice both halves to `mid` rows so they have equal height (needed for
        # the side-by-side concatenation) even if the stacked height is odd.
        mid = rgb_frame.shape[0] // 2
        left_t = rgb_frame[:mid]
        right_t = rgb_frame[mid:2 * mid]

        ctx_disps = [load_disp_float(disp_paths[i])
                     for i in range(t - CONTEXT_LEN + 1, t + 1)]   # CONTEXT_LEN x [Hc,Wc]
        ctx_arr = np.stack(ctx_disps, axis=0)                      # [CONTEXT_LEN,Hc,Wc]
        ctx_tensor = torch.from_numpy(ctx_arr).unsqueeze(0).float().to(device)

        gt_disp = load_disp_float(disp_paths[t + HORIZON])         # [Hc,Wc]
        valid_mask = gt_disp > 0

        with torch.no_grad():
            pred_disp = model(ctx_tensor)[0, 0].cpu().numpy()      # [Hc,Wc]

        # Shared colour range for GT and prediction: 5th-95th percentile of the
        # valid GT. np.percentile raises on an empty array, so a frame with no
        # valid GT (both panels are fully masked anyway) uses a dummy range.
        gt_vals = gt_disp[valid_mask]
        if gt_vals.size:
            vmin, vmax = np.percentile(gt_vals, (5, 95))
        else:
            vmin, vmax = 0.0, 1.0
        if vmax <= vmin:
            vmax = vmin + 1e-6

        gt_color = disp_to_raft_color_fixed_range(gt_disp, vmin, vmax, valid_mask)
        pred_color = disp_to_raft_color_fixed_range(pred_disp, vmin, vmax, valid_mask)

        # Resize the disparity panels to the RGB panel height, keeping aspect ratio.
        Hc, Wc, _ = gt_color.shape
        target_h = left_t.shape[0]
        new_w = int(Wc * (target_h / float(Hc)))
        gt_color_res = cv2.resize(gt_color, (new_w, target_h), interpolation=cv2.INTER_LINEAR)
        pred_color_res = cv2.resize(pred_color, (new_w, target_h), interpolation=cv2.INTER_LINEAR)

        frame_out = np.concatenate(
            [left_t.astype(np.uint8), right_t.astype(np.uint8),
             gt_color_res, pred_color_res],
            axis=1,
        )
        writer.append_data(frame_out)
finally:
    # Release the (ffmpeg-backed) writer and reader even if a frame fails mid-way.
    if writer is not None:
        writer.close()
    rgb_reader.close()

print("Done, video saved at:", out_path)


Keyframe dir: /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_4/keyframe_1
RGB: /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_4/keyframe_1/data/rgb.mp4
Disp dir: /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_4/keyframe_1/data/disparity
Num disparity frames: 728
Num RGB frames: 728
Writing to: /content/drive/Shareddrives/TissueMotionForecasting/videos/dt4_kf1_left-right_gt-pred_disp_unet_7epochs.mp4


  turbo = cm.get_cmap("turbo")


Done, video saved at: /content/drive/Shareddrives/TissueMotionForecasting/videos/dt4_kf1_left-right_gt-pred_disp_unet_7epochs.mp4
