In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
from pathlib import Path

import numpy as np
import imageio.v2 as imageio
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import cv2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed torch's RNGs (CPU and CUDA) so the Monte-Carlo noise drawn later with
# torch.randn_like is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

BASE_DIR = Path("/content/drive/Shareddrives/TissueMotionForecasting")
TRAIN_ROOT = BASE_DIR / "scared_data" / "train"
print("TRAIN_ROOT:", TRAIN_ROOT)

# Disparity PNGs store integer values; dividing by DISP_SCALE recovers
# disparity (see load_disp_float).
DISP_SCALE = 256.0


Mounted at /content/drive
Using device: cuda
TRAIN_ROOT: /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train


In [None]:
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two (3x3 conv -> ReLU) stages; padding=1 keeps the spatial size.

    Layers live in ``self.block`` (Sequential indices 0..3), which fixes the
    checkpoint keys ``block.0.*`` and ``block.2.*``.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UNet(nn.Module):
    """Four-level U-Net: [B, in_channels, H, W] -> [B, out_channels, H, W].

    In this notebook it takes three disparity frames (t-2, t-1, t) and its
    output is used as the t+5 forecast ``mu_5``.

    The encoder halves the resolution four times, so the skip connections
    only line up when H and W are multiples of 16. Other sizes are
    replicate-padded on the bottom/right before encoding, and the output is
    cropped back to the input size. Inputs that are already multiples of 16
    are processed exactly as before.

    Attribute names (enc*/bottleneck/up*/dec*/out_conv) define the
    state_dict keys of saved checkpoints and must not be renamed.
    """

    # Total downsampling factor of the encoder: four 2x max-pools.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        self.enc1 = DoubleConv(in_channels, 32)
        self.enc2 = DoubleConv(32, 64)
        self.enc3 = DoubleConv(64, 128)
        self.enc4 = DoubleConv(128, 256)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(256, 512)

        self.up4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(256 + 256, 256)

        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(128 + 128, 128)

        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(64 + 64, 64)

        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(32 + 32, 32)

        self.out_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        in_h, in_w = x.shape[-2:]
        pad_h = (-in_h) % self._SIZE_MULTIPLE
        pad_w = (-in_w) % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # F.pad's tuple is (left, right, top, bottom) for the last two dims.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        out = self.out_conv(d1)
        if pad_h or pad_w:
            # Drop the padded border so the output matches the input size.
            out = out[..., :in_h, :in_w]
        return out


In [None]:
class HorizonHead(nn.Module):
    """Small conv head that predicts the disparity at horizon t+h.

    Input is the channel-wise concatenation [B, 3, H, W] of:
      - the last observed disparity frame D_t           [B,1,H,W]
      - the UNet's t+5 forecast mu_5                    [B,1,H,W]
      - a constant horizon plane h_norm                 [B,1,H,W]
    Output is the predicted D_{t+h} as [B, 1, H, W].

    ``self.net`` keeps Sequential indices 0/2/4 for its convolutions, which
    are the checkpoint keys (net.0.*, net.2.*, net.4.*).
    """
    def __init__(self, in_ch=3, hidden=32):
        super().__init__()
        feature_layers = [
            nn.Conv2d(in_ch, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        # 1x1 projection down to a single disparity channel.
        projection = nn.Conv2d(hidden, 1, 1)
        self.net = nn.Sequential(*feature_layers, projection)

    def forward(self, x):
        prediction = self.net(x)
        return prediction


In [None]:
# --- Frozen UNet backbone: produces the t+5 disparity forecast mu_5 ---------
backbone = UNet(in_channels=3, out_channels=1).to(device)

unet_ckpt_path = BASE_DIR.joinpath("models", "unet_forecast_kf1to3_7_epochs.pth")
print("Loading UNet from:", unet_ckpt_path)
# NOTE: depending on the torch version's defaults, torch.load may unpickle
# arbitrary objects -- only load checkpoints from trusted locations.
unet_state = torch.load(str(unet_ckpt_path), map_location=device)
backbone.load_state_dict(unet_state)
backbone.eval()
backbone.requires_grad_(False)  # freeze every parameter
print("UNet loaded & frozen.")

# --- HorizonHead: checkpoint dict holds weights under "state_dict" and the
# horizon normaliser under "H_MAX" -------------------------------------------
horizon_head = HorizonHead(in_ch=3, hidden=32).to(device)
h_ckpt_path = BASE_DIR.joinpath("models", "horizon_head_after_epoch3.pth")
print("Loading HorizonHead from:", h_ckpt_path)
h_state = torch.load(str(h_ckpt_path), map_location=device)
horizon_head.load_state_dict(h_state["state_dict"])
H_MAX = h_state["H_MAX"]
horizon_head.eval()

print("HorizonHead loaded for inference. H_MAX =", H_MAX)


Loading UNet from: /content/drive/Shareddrives/TissueMotionForecasting/models/unet_forecast_kf1to3_7_epochs.pth
UNet loaded & frozen.
Loading HorizonHead from: /content/drive/Shareddrives/TissueMotionForecasting/models/horizon_head_after_epoch3.pth
HorizonHead loaded for inference. H_MAX = 9


In [None]:
def load_disp_float(path: Path):
    """Read a uint16 disparity PNG and return it as float32 disparity.

    The stored integer values are divided by DISP_SCALE.
    """
    raw = imageio.imread(str(path))
    return np.divide(raw.astype(np.float32), DISP_SCALE)

def disp_to_raft_color_raft_style(disp_float):
    """Colorize one disparity map with the 'turbo' colormap (RAFT style).

    Matches the RAFT-style coloring from the step1 notebook. Pixels that are
    <= 0 or non-finite are treated as invalid and painted with the bottom
    colour of the map. The colour range is the 5th-95th percentile of the
    valid pixels, or [0, 1] when there are none.

    Args:
        disp_float: array-like disparity map.

    Returns:
        uint8 array with a trailing RGB axis of size 3.
    """
    # Use the colormap registry: matplotlib.cm.get_cmap is deprecated and
    # scheduled for removal.
    from matplotlib import colormaps

    disp = np.asarray(disp_float, dtype=np.float32).copy()

    invalid = (disp <= 0)
    disp[invalid] = np.nan

    valid = np.isfinite(disp)
    if valid.any():
        # disp[valid] contains no NaNs, so a plain percentile is sufficient.
        vmin, vmax = np.percentile(disp[valid], [5, 95])
    else:
        vmin, vmax = 0.0, 1.0

    norm = np.clip((disp - vmin) / (vmax - vmin + 1e-6), 0.0, 1.0)
    norm[~valid] = 0.0

    turbo = colormaps["turbo"]
    color = (turbo(norm)[..., :3] * 255.0).astype(np.uint8)
    return color


def color_gt_and_pred(gt_disp, pred_disp):
    """Colorize a GT / prediction pair on a shared 'turbo' scale.

    Both images share:
      - the same valid mask (gt > 0); masked-out pixels are black in both,
      - the same [vmin, vmax], taken from the 5th/95th percentiles of the valid
        GT values ([0, 1] when GT has no valid pixel).
    Values outside [vmin, vmax] are clipped before colouring.

    Args:
        gt_disp: ground-truth disparity map.
        pred_disp: predicted disparity map of the same shape.

    Returns:
      gt_color, pred_color: uint8 [H,W,3] images in RAFT-style 'turbo' colors.
    """
    # Use the colormap registry: matplotlib.cm.get_cmap is deprecated and
    # scheduled for removal.
    from matplotlib import colormaps

    turbo = colormaps["turbo"]  # looked up once, shared by both images

    gt = np.asarray(gt_disp, np.float32)
    pred = np.asarray(pred_disp, np.float32)

    mask = gt > 0

    if np.any(mask):
        vmin, vmax = np.percentile(gt[mask], (5, 95))
        if vmax <= vmin:
            vmax = vmin + 1e-6
    else:
        vmin, vmax = 0.0, 1.0

    def _colorize(d):
        d = np.clip(np.asarray(d, np.float32), vmin, vmax)
        norm = (d - vmin) / (vmax - vmin + 1e-8)
        norm[~mask] = 0.0

        col = turbo(norm)[..., :3]
        col[~mask] = 0.0  # paint invalid GT pixels black in both images
        return (col * 255.0).astype(np.uint8)

    return _colorize(gt), _colorize(pred)


def latency_to_horizon(eta_ms: float, dt_ms: float, allowed_horizons=(3,5,7,9)):
    """Pick the supported forecast horizon closest to a display latency.

    The latency is converted to frames, n0 = eta_ms / dt_ms, and the horizon
    in ``allowed_horizons`` nearest to n0 is returned. When two horizons are
    equally close, the larger one is chosen.

    Args:
        eta_ms: end-to-end display delay in ms.
        dt_ms: frame interval in ms (e.g. 40 ms for 25 fps). A value of 0
            raises ZeroDivisionError.
        allowed_horizons: candidate horizons, in frames.

    Returns:
        The chosen horizon, one of ``allowed_horizons``.
    """
    frames_of_delay = eta_ms / dt_ms
    candidates = sorted(allowed_horizons)

    def _distance_then_prefer_longer(horizon):
        # Primary key: distance to n0; tie-break: larger horizon sorts first.
        return abs(horizon - frames_of_delay), -horizon

    return min(candidates, key=_distance_then_prefer_longer)

chosen_h = latency_to_horizon(eta_ms=160.0, dt_ms=20.0)
print("Example latency mapping → horizon:", chosen_h)


Example latency mapping → horizon: 9


In [None]:
@torch.no_grad()
def sample_future_with_uncertainty(ctx_tensor, h, K=5, sigma=0.02):
    """Monte-Carlo forecast of D_{t+h} with a per-pixel variance map.

    The frozen UNet ``backbone`` predicts the t+5 disparity from the context.
    The ``horizon_head`` is then evaluated K times. Each pass receives that
    forecast plus fresh Gaussian noise of std ``sigma``, together with the
    last context frame and a constant horizon plane h / H_MAX. The sample
    mean is the prediction and the population variance (no Bessel
    correction) is the uncertainty.

    Args:
        ctx_tensor: [1,3,H,W] disparity context (t-2, t-1, t); it is moved
            to ``device`` before use.
        h: int horizon (e.g. 3, 5, 7, 9).
        K: number of MC samples.
        sigma: noise scale in disparity units.

    Returns:
        mean_pred: [1,1,H,W]
        var_pred : [1,1,H,W]
    """
    for module in (backbone, horizon_head):
        module.eval()

    context = ctx_tensor.to(device)
    mu_t5 = backbone(context)          # [1,1,H,W]
    latest = context[:, -1:, :, :]     # D_t, [1,1,H,W]
    horizon_plane = (
        (torch.tensor(float(h), device=device) / H_MAX)
        .view(1, 1, 1, 1)
        .expand_as(mu_t5)
    )

    def _draw_sample():
        # One RNG draw per sample, then a single HorizonHead pass.
        noisy_mu = mu_t5 + torch.randn_like(mu_t5) * sigma
        return horizon_head(torch.cat([latest, noisy_mu, horizon_plane], dim=1))

    samples = torch.stack([_draw_sample() for _ in range(K)], dim=0)  # [K,1,1,H,W]
    mean_pred = samples.mean(dim=0)
    var_pred = samples.var(dim=0, unbiased=False)
    return mean_pred, var_pred


In [None]:
import imageio.v2 as imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from IPython.display import display, clear_output
import time


# --- Streaming configuration -------------------------------------------------
DS_NAME   = "dataset_4"
KF_ID     = 1
eta_ms    = 160.0   # end-to-end display latency to compensate, in ms
dt_ms     = 20      # frame interval, in ms
H_STAR = latency_to_horizon(eta_ms, dt_ms, allowed_horizons=(3,5,7,9))
n0 = eta_ms / dt_ms
print(f"η={eta_ms} ms, Δt={dt_ms} ms → n0={n0:.2f}, using horizon H_STAR={H_STAR}")


K_SAMPLES = 5        # MC samples per frame
NOISE_SIGMA = 0.1    # noise std added to the UNet t+5 forecast (disparity units)
STEP      = 1
SLEEP_SEC = 0
SHOW_LIVE = False

MAX_FRAMES = 100     # cap on the number of frames loaded from disk

# --- Disparity frames --------------------------------------------------------
disp_dir = TRAIN_ROOT / DS_NAME / f"keyframe_{KF_ID}" / "data" / "disparity"
print("Using disparity dir:", disp_dir)

# Skip the pre-coloured visualisation PNGs stored alongside the raw ones.
disp_paths = sorted(
    p for p in disp_dir.glob("*.png")
    if not p.name.startswith("colored")
)

num_frames_total = len(disp_paths)
print("Total disparity frames on disk:", num_frames_total)

num_frames = min(num_frames_total, MAX_FRAMES)
disp_paths = disp_paths[:num_frames]
print("Using disparity frames:", num_frames)

if num_frames == 0:
    raise RuntimeError("No disparity PNGs found.")


all_disp = [load_disp_float(p) for p in disp_paths]

# --- RGB frames (left view) --------------------------------------------------
rgb_path = TRAIN_ROOT / DS_NAME / f"keyframe_{KF_ID}" / "data" / "rgb.mp4"
print("RGB video:", rgb_path)
rgb_reader = imageio.get_reader(str(rgb_path), "ffmpeg")

left_rgb = []
try:
    for i in range(num_frames):
        frame = rgb_reader.get_data(i)
        H_rgb, W_rgb, _ = frame.shape
        # The top half of each video frame is taken as the left view.
        # NOTE(review): assumes the video stacks the two views vertically -- confirm.
        mid = H_rgb // 2
        left = frame[:mid, :, :]
        left_rgb.append(left.astype(np.uint8))
finally:
    # Release the ffmpeg reader even if a frame index is out of range.
    rgb_reader.close()

# Valid "current" frames t need t-2 >= 0 and a target t+H_STAR <= num_frames-1,
# i.e. min_t <= t <= max_t (both ends valid).
min_t = 2
max_t = num_frames - H_STAR - 1
if max_t <= min_t:
    raise RuntimeError("Sequence too short for this horizon.")
print(f"Streaming from t={min_t} to t={max_t} (step={STEP})")

# Reference colour range from one mid-sequence context window and its target.
t_center = (min_t + max_t) // 2
ctx_sample = [all_disp[i] for i in [t_center-2, t_center-1, t_center]]
gt_sample  = all_disp[t_center + H_STAR]

frames_for_range = ctx_sample + [gt_sample]
all_vals = []
valid_union = np.zeros_like(gt_sample, dtype=bool)

for f in frames_for_range:
    f = np.asarray(f, dtype=np.float32)
    m = f > 0
    if np.any(m):
        valid_union |= m
        all_vals.append(f[m])

if all_vals:
    all_vals = np.concatenate(all_vals)
    vmin, vmax = np.percentile(all_vals, (5, 95))
    if vmax <= vmin:
        vmax = vmin + 1e-6
else:
    # No valid disparity in any sampled frame: use the same [0, 1] fallback
    # as the colouring helpers instead of crashing in np.concatenate([]).
    all_vals = np.empty(0, dtype=np.float32)
    vmin, vmax = 0.0, 1.0

# NOTE(review): the streaming loop colours each frame with its own per-frame
# range (color_gt_and_pred / disp_to_raft_color_raft_style), so vmin/vmax and
# valid_union are currently informational only.
print("Global color range vmin, vmax:", vmin, vmax)

η=160.0 ms, Δt=20 ms → n0=8.00, using horizon H_STAR=9
Using disparity dir: /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_4/keyframe_1/data/disparity
Total disparity frames on disk: 728
Using disparity frames: 100
RGB video: /content/drive/Shareddrives/TissueMotionForecasting/scared_data/train/dataset_4/keyframe_1/data/rgb.mp4
Streaming from t=2 to t=90 (step=1)
Global color range vmin, vmax: 31.16796875 51.3671875


In [None]:
video_dir = BASE_DIR.joinpath("videos")
video_dir.mkdir(parents=True, exist_ok=True)

FPS = 10.0    # playback frame rate of the rendered video
REPEATS = 1   # copies per rendered frame; only matters if the loop writes each frame REPEATS times

# NOTE(review): with DS_NAME = "dataset_4" the "ds" prefix yields
# "stream_dsdataset_4_..."; kept unchanged so existing output paths stay stable.
video_path = video_dir.joinpath(f"stream_ds{DS_NAME}_kf{KF_ID}_h{H_STAR}.mp4")
writer = imageio.get_writer(str(video_path), fps=FPS)
print("Saving streaming video to:", video_path)

Saving streaming video to: /content/drive/Shareddrives/TissueMotionForecasting/videos/stream_dsdataset_4_kf1_h9.mp4


In [None]:
def _figure_to_rgb(fig):
    """Render ``fig`` and return its pixels as a contiguous uint8 [H, W, 3] array."""
    fig.canvas.draw()
    # np.asarray on the Agg RGBA buffer already yields an (H, W, 4) array, so
    # no manual reshape from get_width_height() is needed.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Copy (drop alpha) so the frame stays valid after the figure is closed.
    return np.ascontiguousarray(rgba[..., :3])


backbone.eval()
horizon_head.eval()

try:
    # max_t is the last t whose target frame t+H_STAR exists, so include it.
    for t in range(min_t, max_t + 1, STEP):

        ctx_indices = [t-2, t-1, t]
        ctx_np = np.stack([all_disp[i] for i in ctx_indices], axis=0)   # [3,H,W]
        rgb_ctx = [left_rgb[i] for i in ctx_indices]

        gt_t = all_disp[t + H_STAR]                                    # [H,W]
        rgb_future = left_rgb[t + H_STAR]                              # RGB at t+H*

        ctx_tensor = torch.from_numpy(ctx_np).unsqueeze(0).float().to(device)
        mean_pred, var_pred = sample_future_with_uncertainty(
            ctx_tensor, h=H_STAR, K=K_SAMPLES, sigma=NOISE_SIGMA
        )
        pred_np = mean_pred[0, 0].cpu().numpy()
        var_np  = var_pred[0, 0].cpu().numpy()

        # Metrics are computed only where the GT disparity is valid (> 0).
        valid_mask = gt_t > 0

        if np.any(valid_mask):
            abs_err = np.abs(pred_np - gt_t)
            abs_err_valid = abs_err[valid_mask]

            mae = abs_err_valid.mean()
            rmse = np.sqrt((abs_err_valid ** 2).mean())
            absrel_disp = (abs_err_valid / (np.abs(gt_t[valid_mask]) + 1e-8)).mean()

            unc_mean = np.nanmean(var_np[valid_mask])
            unc_max  = np.nanmax(var_np[valid_mask])

            # err_map aliases abs_err; abs_err_valid was already extracted above.
            err_map = abs_err
            err_map[~valid_mask] = np.nan
            vmax_err = np.nanpercentile(err_map, 95)
        else:
            mae = rmse = absrel_disp = np.nan
            unc_mean = unc_max = np.nan
            err_map = np.full_like(pred_np, np.nan, dtype=np.float32)
            vmax_err = 1.0

        ctx_disp_colors = [disp_to_raft_color_raft_style(f) for f in ctx_np]
        gt_color, pred_color = color_gt_and_pred(gt_t, pred_np)

        # ---- plotting: 5x3 panel grid ----
        n_rows, n_cols = 5, 3
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(11, 14))

        # Row 1: RGB t-2, t-1, t
        for ax, img, title in zip(axes[0], rgb_ctx, ["RGB t-2", "RGB t-1", "RGB t"]):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis("off")

        # Row 2: Disp t-2, t-1, t
        for ax, img, title in zip(axes[1], ctx_disp_colors, ["Disp t-2", "Disp t-1", "Disp t"]):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis("off")

        # Row 3: RGB t+H*, GT t+H*, Pred t+H*
        ax_rf, ax_gt, ax_pred = axes[2]
        ax_rf.imshow(rgb_future)
        ax_rf.set_title(f"RGB t+{H_STAR}")
        ax_rf.axis("off")

        ax_gt.imshow(gt_color)
        ax_gt.set_title(f"GT t+{H_STAR}")
        ax_gt.axis("off")

        ax_pred.imshow(pred_color)
        ax_pred.set_title(f"Pred t+{H_STAR}")
        ax_pred.axis("off")

        # Row 4: |Pred-GT|, variance map, |Disp(t) - Disp(t+H*)|
        ax_err, ax_varmap, ax_dispdiff = axes[3]

        # |Disp(t) - Disp(t+H*)|: how much the scene actually moved.
        disp_diff = np.abs(all_disp[t] - gt_t)     # H,W
        disp_diff[~valid_mask] = np.nan
        vmax_dd = np.nanpercentile(disp_diff, 95) if np.any(valid_mask) else 1.0

        im_dd = ax_dispdiff.imshow(disp_diff, cmap="coolwarm", vmin=0, vmax=vmax_dd)
        ax_dispdiff.set_title(f"|Disp(t) - Disp(t+{H_STAR})|")
        ax_dispdiff.axis("off")
        fig.colorbar(im_dd, ax=ax_dispdiff, fraction=0.046, pad=0.04)

        # Error |Pred - GT|
        im_err = ax_err.imshow(err_map, cmap="magma", vmin=0, vmax=vmax_err)
        ax_err.set_title("|Pred - GT|")
        ax_err.axis("off")
        fig.colorbar(im_err, ax=ax_err, fraction=0.046, pad=0.04)

        # Variance of the MC samples
        vmax_u = np.nanpercentile(var_np[valid_mask], 95) if np.any(valid_mask) else 1.0
        im_u = ax_varmap.imshow(var_np, cmap="viridis", vmin=0, vmax=vmax_u)
        ax_varmap.set_title("Variance σ²")
        ax_varmap.axis("off")
        fig.colorbar(im_u, ax=ax_varmap, fraction=0.046, pad=0.04)

        # Row 5: metrics text
        for ax in axes[4]:
            ax.axis("off")

        metrics_text = (
            f"MAE (EPE): {mae:.3f} RMSE: {rmse:.3f} AbsRel-d: {absrel_disp:.3f}  \n"
            f"mean σ²: {unc_mean:.2e}  max σ²: {unc_max:.2e}"
        )

        axes[4][0].text(
            0.01, 0.5, metrics_text,
            fontsize=12,
            ha="left", va="center",
            transform=axes[4][0].transAxes
        )

        fig.suptitle(
            f"Streaming @ t={t}, horizon={H_STAR} "
            f"(η={eta_ms:.1f}ms, Δt={dt_ms:.1f}ms, n*={H_STAR}), K={K_SAMPLES}",
            fontsize=14
        )

        if SHOW_LIVE:
            clear_output(wait=True)
            display(fig)
            plt.pause(0.001)

        frame_img = _figure_to_rgb(fig)
        # Duplicate each rendered frame REPEATS times in the output video.
        for _ in range(REPEATS):
            writer.append_data(frame_img)

        plt.close(fig)
        time.sleep(SLEEP_SEC)
finally:
    # Finalise the video file and release the writer even if a frame raised.
    writer.close()

print("Streaming finished. Video saved to:", video_path)

  turbo = cm.get_cmap("turbo")
  turbo = cm.get_cmap("turbo")


Streaming finished. Video saved to: /content/drive/Shareddrives/TissueMotionForecasting/videos/stream_dsdataset_4_kf1_h9.mp4
