In [None]:
!pip install omegaconf


Collecting omegaconf
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting antlr4-python3-runtime==4.9.* (from omegaconf)
  Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.0/117.0 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading omegaconf-2.3.0-py3-none-any.whl (79 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: antlr4-python3-runtime
  Building wheel for antlr4-python3-runtime (setup.py) ... [?25l[?25hdone
  Created wheel for antlr4-python3-runtime: filename=antlr4_python3_runtime-4.9.3-py3-none-any.whl size=144554 sha256=74c336c80062f3841396b460841068f5f769c1e0eca5e98ff9c3736e81583ebe
  Stored in directory: /root/.cache/pip/wheels/1a/97/32/461f837398029ad76911109f07047fde1d7b661a147c7c56d1
Successfull

In [None]:
# Mount Google Drive at /content/drive (Colab only). load_data() and load_model()
# further down read the dataset and checkpoint from /content/drive/MyDrive/jepa.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Schedulers
from enum import auto, Enum
import math


class LRSchedule(Enum):
    """Learning-rate schedules that ``Scheduler`` distinguishes between."""

    # Values are written out explicitly; they are the same numbers ``auto()``
    # assigns (1, 2, ... in declaration order).
    Constant = 1
    Cosine = 2


class Scheduler:
    """Per-step learning-rate scheduler that writes ``lr`` into an optimizer's param groups.

    Two schedules are supported:

    * ``LRSchedule.Constant`` - the optimizer is left untouched and ``base_lr``
      is returned.
    * anything else (cosine) - linear warm-up over the first 10% of all steps,
      then cosine decay from the peak lr down to 0.1% of the peak. The peak is
      ``base_lr * batch_size / 256``.

    Args:
        schedule: which schedule to apply (an ``LRSchedule`` member).
        base_lr: learning rate used for param groups without a ``"base_lr"`` key.
        data_loader: only consulted for defaults, when ``batch_size`` or
            ``batch_steps`` is ``None``.
        epochs: number of epochs the run will last.
        optimizer: object exposing ``param_groups`` (a list of dicts).
        batch_steps: steps per epoch; defaults to ``len(data_loader)``.
        batch_size: batch size; defaults to ``data_loader.config.batch_size``
            when the loader has a ``config`` attribute, otherwise
            ``data_loader.batch_size``.
    """

    def __init__(
        self,
        schedule: "LRSchedule",
        base_lr: float,
        data_loader,
        epochs: int,
        optimizer,
        batch_steps=None,
        batch_size=None,
    ):
        self.schedule = schedule
        self.base_lr = base_lr
        self.data_loader = data_loader
        self.epochs = epochs
        self.optimizer = optimizer

        if batch_size is None:
            # Loaders without a ``config`` attribute are supported through
            # their own ``batch_size`` attribute.
            loader_config = getattr(data_loader, "config", None)
            if loader_config is not None:
                self.batch_size = loader_config.batch_size
            else:
                self.batch_size = data_loader.batch_size
        else:
            self.batch_size = batch_size

        if batch_steps is None:
            self.batch_steps = len(data_loader)
        else:
            self.batch_steps = batch_steps

    def adjust_learning_rate(self, step: int):
        """Set the learning rate for global step ``step`` and return it.

        For the constant schedule nothing is written to the optimizer. For the
        cosine schedule every param group's ``"lr"`` is updated and the value
        written to the last group is returned (or the value for ``self.base_lr``
        when the optimizer has no param groups).
        """
        if self.schedule == LRSchedule.Constant:
            return self.base_lr

        max_steps = self.epochs * self.batch_steps
        warmup_steps = int(0.10 * max_steps)
        in_warmup = step < warmup_steps

        # The position on the cosine curve is computed once, outside the loop.
        # It used to be derived by mutating ``step``/``max_steps`` inside the
        # per-group loop, which shifted the schedule for every group after the
        # first one.
        q = None
        if not in_warmup:
            # max(1, ...) only matters for a degenerate zero-step run.
            decay_steps = max(1, max_steps - warmup_steps)
            q = 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / decay_steps))

        def scheduled_lr(group_base_lr):
            peak_lr = group_base_lr * self.batch_size / 256
            if in_warmup:
                return peak_lr * step / warmup_steps
            end_lr = peak_lr * 0.001
            return peak_lr * q + end_lr * (1 - q)

        lr = scheduled_lr(self.base_lr)
        for param_group in self.optimizer.param_groups:
            lr = scheduled_lr(param_group.get("base_lr", self.base_lr))
            param_group["lr"] = lr
        return lr
## Models
from typing import List
import numpy as np
from torch import nn
from torch.nn import functional as F
import torch


def build_mlp(layers_dims: List[int]):
    """Build an MLP as an ``nn.Sequential`` from a list of layer widths.

    Every consecutive pair of widths except the last becomes
    ``Linear -> BatchNorm1d -> ReLU(inplace)``; the final pair becomes a bare
    ``Linear`` so the output is unnormalised and unactivated.

    Args:
        layers_dims: widths ``[in, hidden..., out]``; needs at least two entries.

    Returns:
        The assembled ``nn.Sequential``.
    """
    modules = []
    hidden_pairs = zip(layers_dims[:-2], layers_dims[1:-1])
    for fan_in, fan_out in hidden_pairs:
        modules += [
            nn.Linear(fan_in, fan_out),
            nn.BatchNorm1d(fan_out),
            nn.ReLU(True),
        ]
    # Output layer: no normalisation or activation.
    modules.append(nn.Linear(layers_dims[-2], layers_dims[-1]))
    return nn.Sequential(*modules)


class MockModel(torch.nn.Module):
    """Stand-in model that returns random representations. For testing only."""

    def __init__(self, device="cuda", output_dim=256):
        super().__init__()
        self.repr_dim = output_dim
        self.device = device

    def forward(self, states, actions):
        """Return Gaussian noise shaped like a sequence of predicted representations.

        Args:
            states: ignored. Documented upstream as [B, T, Ch, H, W] during
                training and [B, 1, Ch, H, W] during inference.
            actions: 3-D tensor [B, T-1, 2]; only its first two sizes are used.

        Returns:
            Tensor [B, T, D] with D = ``self.repr_dim``, i.e. one more timestep
            than there are actions, placed on ``self.device``.
        """
        batch_size, n_actions, _ = actions.shape
        noise = torch.randn(batch_size, n_actions + 1, self.repr_dim)
        return noise.to(self.device)


class Prober(torch.nn.Module):
    """MLP probe mapping an embedding vector to a flattened target.

    Args:
        embedding: size of the input embedding.
        arch: hidden widths joined by ``"-"`` (e.g. ``"256"`` or ``"512-256"``);
            the empty string means a single linear layer.
        output_shape: shape of one target; the probe emits its flattened size.

    Attributes:
        output_dim: number of output units (a plain ``int``).
        output_shape: ``output_shape`` as passed in. ``forward`` does not
            reshape to it.
        prober: the underlying ``nn.Sequential``.
    """

    def __init__(
        self,
        embedding: int,
        arch: str,
        output_shape: List[int],
    ):
        super().__init__()
        # np.prod returns a NumPy scalar, and the float 1.0 for an empty shape.
        # nn.Linear needs a real int, so normalise it here.
        self.output_dim = int(np.prod(output_shape))
        self.output_shape = output_shape
        self.arch = arch

        hidden_widths = [int(width) for width in arch.split("-")] if arch != "" else []
        dims = [embedding] + hidden_widths + [self.output_dim]

        layers = []
        for fan_in, fan_out in zip(dims[:-2], dims[1:-1]):
            layers.append(torch.nn.Linear(fan_in, fan_out))
            layers.append(torch.nn.ReLU(True))
        # Final projection has no activation.
        layers.append(torch.nn.Linear(dims[-2], dims[-1]))
        self.prober = torch.nn.Sequential(*layers)

    def forward(self, e):
        """Apply the probe to embeddings ``e`` (last dim = ``embedding``)."""
        return self.prober(e)

###Configs
import argparse
import dataclasses
from dataclasses import dataclass
from enum import Enum
from typing import Any, Iterable, Tuple, Union, cast, List

from omegaconf import OmegaConf

DataClass = Any
DataClassType = Any


@dataclass
class ConfigBase:
    """Base class for config dataclasses.

    Provides constructors that build an instance from the command line, a
    config file, or a (nested or flat) dict, plus ``save`` for writing the
    config back to disk.

    NOTE(review): ``omegaconf_parse`` and ``DataclassArgParser`` are referenced
    below but are not defined in the cells shown here; the methods that use
    them depend on those names being available at call time.
    """

    @classmethod
    def parse_from_command_line(cls):
        """Build an instance by delegating to ``omegaconf_parse(cls)``."""
        return omegaconf_parse(cls)

    @classmethod
    def parse_from_file(cls, path: str):
        """Load ``path`` with OmegaConf and populate an instance from its contents."""
        loaded = OmegaConf.load(path)
        as_container = OmegaConf.to_container(loaded)
        return cls.parse_from_dict(as_container)

    @classmethod
    def parse_from_command_line_deprecated(cls):
        """Older command-line parser based on ``DataclassArgParser``.

        Raises:
            RuntimeError: if the parser returns anything beyond the one
                populated dataclass (reported as unrecognised arguments).
        """
        parser = DataclassArgParser(cls, fromfile_prefix_chars="@")
        parsed = parser.parse_args_into_dataclasses()
        if len(parsed) > 1:
            raise RuntimeError(
                f"The following arguments were not recognized: {parsed[1:]}"
            )
        return parsed[0]

    @classmethod
    def parse_from_dict(cls, inputs):
        """Populate an instance from a dict; a shallow copy is passed to the parser."""
        values = inputs.copy()
        return DataclassArgParser._populate_dataclass_from_dict(cls, values)

    @classmethod
    def parse_from_flat_dict(cls, inputs):
        """Populate an instance from a flat dict; a shallow copy is passed to the parser."""
        values = inputs.copy()
        return DataclassArgParser._populate_dataclass_from_flat_dict(cls, values)

    def save(self, path: str):
        """Write this config to ``path`` via ``OmegaConf.save``."""
        with open(path, "w") as handle:
            OmegaConf.save(config=self, f=handle)

### dataset
from typing import NamedTuple, Optional
import torch
import numpy as np


class WallSample(NamedTuple):
    """Container returned by ``WallDataset.__getitem__``.

    The probing code further down reads the same three fields
    (``batch.states``, ``batch.actions``, ``batch.locations``) from the
    batches produced by the DataLoader.
    """

    # Observation frames. MockModel's docstring describes batched states as
    # [B, T, Ch, H, W] — TODO confirm against the data files.
    states: torch.Tensor
    # Target locations; an empty tensor when the dataset has no locations
    # (i.e. it was created with probing=False).
    locations: torch.Tensor
    # Actions. MockModel's docstring describes batched actions as [B, T-1, 2]
    # — TODO confirm against the data files.
    actions: torch.Tensor


class WallDataset:
    """Map-style dataset over ``.npy`` arrays stored in one directory.

    Reads ``states.npy`` (memory-mapped read-only), ``actions.npy`` and, when
    ``probing`` is true, ``locations.npy`` from ``data_path``. Items come back
    as float32 tensors already moved to ``device``.
    """

    def __init__(
        self,
        data_path,
        probing=False,
        device="cuda",
    ):
        self.device = device
        # States are memory-mapped rather than loaded eagerly.
        self.states = np.load(f"{data_path}/states.npy", mmap_mode="r")
        self.actions = np.load(f"{data_path}/actions.npy")
        self.locations = (
            np.load(f"{data_path}/locations.npy") if probing else None
        )

    def __len__(self):
        """Number of items, taken from the leading axis of ``states``."""
        return len(self.states)

    def _as_tensor(self, array):
        """Convert a NumPy array to a float32 tensor on ``self.device``."""
        return torch.from_numpy(array).float().to(self.device)

    def __getitem__(self, i):
        """Return item ``i`` as a ``WallSample``.

        ``locations`` is an empty tensor when the dataset was built without
        probing targets.
        """
        if self.locations is None:
            locations = torch.empty(0).to(self.device)
        else:
            locations = self._as_tensor(self.locations[i])

        return WallSample(
            states=self._as_tensor(self.states[i]),
            locations=locations,
            actions=self._as_tensor(self.actions[i]),
        )


def create_wall_dataloader(
    data_path,
    probing=False,
    device="cuda",
    batch_size=64,
    train=True,
):
    """Wrap a ``WallDataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        data_path: directory handed to ``WallDataset``.
        probing: forwarded to ``WallDataset``.
        device: forwarded to ``WallDataset``.
        batch_size: number of items per batch.
        train: when true the loader shuffles.

    Returns:
        A DataLoader that drops the last incomplete batch and does not pin memory.
    """
    dataset = WallDataset(data_path=data_path, probing=probing, device=device)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,
        drop_last=True,
        pin_memory=False,
    )


### normalizer
import torch


class Normalizer:
    """Standardises 2-D locations with fixed, hard-coded mean and std."""

    def __init__(self):
        self.location_mean = torch.tensor([31.5863, 32.0618])
        self.location_std = torch.tensor([16.1025, 16.1353])

    def normalize_location(self, location: torch.Tensor) -> torch.Tensor:
        """Return ``(location - mean) / (std + 1e-6)`` on ``location``'s device."""
        mean = self.location_mean.to(location.device)
        std = self.location_std.to(location.device)
        return (location - mean) / (std + 1e-6)

    def unnormalize_location(self, location: torch.Tensor) -> torch.Tensor:
        """Return ``location * std + mean`` (no epsilon is applied here)."""
        std = self.location_std.to(location.device)
        mean = self.location_mean.to(location.device)
        return location * std + mean

    def unnormalize_mse(self, mse):
        """Rescale a squared error from normalised units by multiplying with ``std ** 2``."""
        std = self.location_std.to(mse.device)
        return mse * (std ** 2)


In [None]:
# ---------------------- Encoder ----------------------
class SimpleChannelEncoder(nn.Module):
    """Two-branch CNN that encodes input channels 0 and 1 separately.

    Each branch is conv/ReLU/max-pool twice, then conv/ReLU and an adaptive
    average pool to 4x4, producing 32 feature maps. The two branch outputs are
    concatenated along the channel axis (``out_channels`` = 64).
    """

    def __init__(self):
        super().__init__()
        self.path1 = self._make_branch()
        self.path2 = self._make_branch()
        self.out_channels = 64

    @staticmethod
    def _make_branch():
        """Build one single-channel branch: 1 -> 8 -> 16 -> 32 feature maps."""
        return nn.Sequential(
            nn.Conv2d(1, 8, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, 1, 1), nn.ReLU(), nn.AdaptiveAvgPool2d((4, 4)),
        )

    def forward(self, x):
        """Encode ``x`` (channels on axis 1); channel 0 goes to ``path1``, channel 1 to ``path2``."""
        first = self.path1(x[:, 0:1])
        second = self.path2(x[:, 1:2])
        return torch.cat([first, second], 1)



# ---------------------- Predictor ----------------------
class Predictor(nn.Module):
    """Action-conditioned update of the first ``ball_dim`` channels of a latent map.

    The action is pushed through a small MLP (with one residual connection)
    to produce an additive term of shape (ball_dim, 4, 4). That term is added
    to the "ball" half of the latent; the remaining "wall" channels are passed
    through detached.

    Args:
        map_channels: accepted for call compatibility; not used.
        action_dim: size of the action vector.
        ball_dim: number of leading channels treated as the ball part.
    """

    def __init__(self, map_channels=64, action_dim=2, ball_dim=32):
        super().__init__()
        self.ball_ch = ball_dim

        # Action MLP: action_dim -> 64 -> 128 -> 128 -> ball_ch * 4 * 4
        self.fc1 = nn.Linear(action_dim, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 128)
        self.fc_out = nn.Linear(128, self.ball_ch * 4 * 4)

        self.act = nn.LeakyReLU()

    def forward(self, z_map, action):
        """Return the next latent map for ``z_map`` given ``action``.

        ``z_map`` must split into exactly two chunks of ``ball_ch`` channels
        along axis 1.
        """
        ball, wall = torch.split(z_map, self.ball_ch, dim=1)

        hidden = self.act(self.fc2(self.act(self.fc1(action))))
        # Residual connection around the third linear layer.
        hidden = self.act(self.fc3(hidden)) + hidden

        delta = self.fc_out(hidden).view(-1, self.ball_ch, 4, 4)
        updated_ball = ball + delta
        # No gradient flows into the wall channels through this output.
        return torch.cat([updated_ball, wall.detach()], dim=1)

# ---------------------- VICReg Predictive ----------------------
class VICRegPredictive(nn.Module):
    """Encoder + action-conditioned predictor trained with a VICReg-style loss.

    This is the single definition of the class; an earlier, partial copy that
    was immediately shadowed has been folded into it.

    Args:
        embed_dim: output size of the linear projection used only for the loss.
        action_dim: size of one action vector.
        sim, var, cov: weights of the invariance, variance and covariance terms.

    Attributes:
        repr_dim: flattened size of one latent map (``out_channels * 4 * 4``).
    """

    def __init__(self, embed_dim=64, action_dim=2, sim=25., var=10., cov=0.5):
        super().__init__()
        self.encoder = SimpleChannelEncoder()
        self.predictor = Predictor(self.encoder.out_channels, action_dim)

        # Flattened latent size, read by the probing code as ``model.repr_dim``.
        self.repr_dim = self.encoder.out_channels * 4 * 4
        self.proj = nn.Linear(self.repr_dim, embed_dim, bias=False)

        # Loss term weights.
        self.sim_c, self.var_c, self.cov_c = sim, var, cov

    def _proj(self, z_map):
        """Flatten a latent map per sample and project it to ``embed_dim``."""
        return self.proj(z_map.flatten(1))

    def forward(self, s0, s1=None, a0=None, evaluation=False):
        """Run one training step pair, or an autoregressive rollout.

        Args:
            s0: training: first frames (B, C, H, W);
                evaluation: initial frame with a singleton time axis (B, 1, C, H, W).
            s1: training only: next frames (B, C, H, W).
            a0: training: actions (B, action_dim);
                evaluation: action sequence (B, T, action_dim).
            evaluation: selects rollout mode.

        Returns:
            evaluation: flattened predicted latents (B, T, D), one per action;
            the encoding of the initial frame itself is not included.
            training: ``(z_pred, z_next)`` latent maps.
        """
        if evaluation:
            batch_size = s0.shape[0]
            n_steps = a0.shape[1]

            z = self.encoder(s0.squeeze(1))
            rollout = []
            for t in range(n_steps):
                # Feed the prediction back in as the next input latent.
                z = self.predictor(z, a0[:, t])
                rollout.append(z.unsqueeze(1))

            pred_seq = torch.cat(rollout, dim=1)
            return pred_seq.reshape(batch_size, n_steps, -1)

        z0 = self.encoder(s0)
        z_pred = self.predictor(z0, a0)
        z_next = self.encoder(s1)
        return z_pred, z_next

    @staticmethod
    def _covariance_penalty(z):
        """Squared distance between the batch covariance of ``z`` and the identity, divided by d.

        NOTE(review): this includes the diagonal (variances are pulled towards
        1), whereas the published VICReg covariance term sums only the
        off-diagonal entries. Kept as is so training behaviour is unchanged.
        """
        n, d = z.shape
        centered = z - z.mean(0)
        covariance = (centered.T @ centered) / (n - 1)
        return (covariance - torch.eye(d, device=z.device)).pow(2).sum() / d

    def compute_loss(self, states, actions):
        """Loss on the first transition of each sequence.

        Uses only ``states[:, 0]``, ``states[:, 1]`` and ``actions[:, 0]``.

        Returns:
            ``(total, logs)``: the weighted total loss tensor and a dict of
            Python floats with keys ``inv_loss``, ``var_loss``, ``cov_loss``
            and ``total_loss``.
        """
        z_pred, z_next = self(states[:, 0], states[:, 1], actions[:, 0])
        zp, zn = self._proj(z_pred), self._proj(z_next)

        inv = F.mse_loss(zp, zn)
        # Hinge on the per-dimension standard deviation across the batch.
        var = (F.relu(1 - torch.sqrt(zp.var(0) + 1e-4)).mean() +
               F.relu(1 - torch.sqrt(zn.var(0) + 1e-4)).mean())
        cov = self._covariance_penalty(zp) + self._covariance_penalty(zn)

        total = self.sim_c * inv + self.var_c * var + self.cov_c * cov
        return total, dict(inv_loss=inv.item(), var_loss=var.item(),
                           cov_loss=cov.item(), total_loss=total.item())

In [None]:
from typing import NamedTuple, List, Any, Optional, Dict
from itertools import chain
from dataclasses import dataclass
import itertools
import os
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
import numpy as np
from matplotlib import pyplot as plt
from typing import Tuple, Dict, NamedTuple, Optional
from typing import Optional


@dataclass
class ProbingConfig(ConfigBase):
    """Hyper-parameters for training the location probe (see ProbingEvaluator)."""

    # Name of the batch field to probe. NOTE(review): the evaluator code in
    # this notebook reads "locations" directly and does not consult this field.
    probe_targets: str = "locations"
    # Learning rate given to Adam and used as the Scheduler's base_lr.
    lr: float = 0.0002
    # Number of passes over the probe training loader.
    epochs: int = 20
    # Learning-rate schedule handed to Scheduler.
    schedule: LRSchedule = LRSchedule.Cosine
    # If not None and smaller than the number of timesteps in a batch, this
    # many timesteps are randomly sampled per sequence when training the probe.
    sample_timesteps: int = 30
    # Hidden widths of the Prober MLP, "-"-separated.
    prober_arch: str = "256"


class ProbeResult(NamedTuple):
    """Bundle of probing outputs.

    NOTE(review): nothing in the cells shown here constructs or returns this
    type; the field meanings below are inferred from the names only.
    """

    # Presumably the trained probe — TODO confirm.
    model: torch.nn.Module
    # Presumably the mean evaluation loss — TODO confirm.
    average_eval_loss: float
    # Presumably one evaluation loss per timestep — TODO confirm.
    eval_losses_per_step: List[float]
    # Presumably plot objects produced during evaluation — TODO confirm.
    plots: List[Any]


# Module-level ProbingConfig with all defaults; it is the default value of
# ProbingEvaluator's `config` argument, so that one instance is shared by every
# evaluator created without an explicit config.
default_config = ProbingConfig()


def location_losses(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Squared error between ``pred`` and ``target``, averaged over axis 0.

    Both tensors must have identical shapes (checked with ``assert``); the
    result keeps every axis except the first.
    """
    assert pred.shape == target.shape
    squared_error = (pred - target) ** 2
    return squared_error.mean(dim=0)


class ProbingEvaluator:
    """Trains a location probe on a frozen model's embeddings and evaluates it.

    The model is rolled out autoregressively from the first frame of every
    sequence; a ``Prober`` is trained to regress the normalised locations from
    the resulting embeddings and is then scored on each validation loader.

    Args:
        device: device the probe and the targets are placed on.
        model: model exposing ``repr_dim``, ``encoder`` and
            ``forward(s0=..., a0=..., evaluation=True)``. It is switched to
            eval mode here.
        probe_train_ds: loader of probe-training batches (``states``,
            ``actions``, ``locations``); must expose ``batch_size``.
        probe_val_ds: mapping from a name to a validation loader.
        config: probing hyper-parameters.
        quick_debug: train for a single epoch and stop after 3 steps.
    """

    def __init__(
        self,
        device: "torch.device | str",
        model: torch.nn.Module,
        probe_train_ds,
        probe_val_ds: dict,
        config: ProbingConfig = default_config,
        quick_debug: bool = False,
    ):
        self.device = device
        self.config = config

        # The probed model is only ever used for inference.
        self.model = model
        self.model.eval()

        self.quick_debug = quick_debug

        self.ds = probe_train_ds
        self.val_ds = probe_val_ds

        self.normalizer = Normalizer()

    @torch.no_grad()
    def _predict_encodings(self, batch) -> torch.Tensor:
        """Return time-major embeddings for a batch: initial frame plus rollout.

        The model is rolled out from ``batch.states[:, 0:1]`` with
        ``batch.actions``. The flattened encoding of the initial frame is
        prepended, so (assuming the model returns (B, T, D) as its rollout) the
        result is (T + 1, B, D). Runs without autograd, because the model is
        frozen and only the probe is trained.
        """
        init_states = batch.states[:, 0:1]

        pred_encs = self.model(s0=init_states, a0=batch.actions, evaluation=True)
        init_enc = self.model.encoder(init_states.squeeze(1)).flatten(start_dim=1)

        # (B, T, D) -> (T, B, D), then prepend the (1, B, D) initial encoding.
        return torch.cat([init_enc.unsqueeze(0), pred_encs.transpose(0, 1)], dim=0)

    @staticmethod
    def _sample_timesteps(pred_encs, target, n_samples):
        """Pick ``n_samples`` random timesteps independently for every sequence.

        Args:
            pred_encs: time-major embeddings (T, B, D).
            target: batch-major targets (B, T, ...).
            n_samples: timesteps to keep per sequence (must not exceed T).

        Returns:
            ``(sampled_encs, sampled_target)`` with shapes (n_samples, B, D)
            and (B, n_samples, ...), on the devices of their inputs.
        """
        n_steps, bs = pred_encs.shape[:2]
        sampled_encs = torch.empty(
            (n_samples,) + pred_encs.shape[1:],
            dtype=pred_encs.dtype,
            device=pred_encs.device,
        )
        sampled_target = torch.empty(
            (bs, n_samples) + target.shape[2:],
            dtype=target.dtype,
            device=target.device,
        )
        for i in range(bs):
            indices = torch.randperm(n_steps)[:n_samples]
            sampled_encs[:, i, :] = pred_encs[indices, i, :]
            sampled_target[i, :] = target[i, indices]
        return sampled_encs, sampled_target

    def train_pred_prober(self):
        """
        Probes whether the predicted embeddings capture the future locations.

        Returns:
            The trained ``Prober``.
        """
        repr_dim = self.model.repr_dim
        dataset = self.ds
        config = self.config

        epochs = 1 if self.quick_debug else config.epochs

        # One batch is peeked at only to learn the shape of a single target.
        test_batch = next(iter(dataset))
        prober_output_shape = test_batch.locations[0, 0].shape
        prober = Prober(
            repr_dim,
            config.prober_arch,
            output_shape=prober_output_shape,
        ).to(self.device)

        optimizer_pred_prober = torch.optim.Adam(list(prober.parameters()), config.lr)

        scheduler = Scheduler(
            schedule=config.schedule,
            base_lr=config.lr,
            data_loader=dataset,
            epochs=epochs,
            optimizer=optimizer_pred_prober,
            batch_steps=None,
            batch_size=dataset.batch_size,
        )

        step = 0
        for epoch in tqdm(range(epochs), desc=f"Probe prediction epochs"):
            for batch in tqdm(dataset, desc="Probe prediction step"):
                pred_encs = self._predict_encodings(batch)
                n_steps = pred_encs.shape[0]

                target = batch.locations.to(self.device)
                target = self.normalizer.normalize_location(target)

                if (
                    config.sample_timesteps is not None
                    and config.sample_timesteps < n_steps
                ):
                    # Only a random subset of timesteps is used per sequence
                    # (the original comment suggests this is to avoid OOM).
                    pred_encs, target = self._sample_timesteps(
                        pred_encs, target, config.sample_timesteps
                    )

                # The probe is applied per timestep; stacking on dim 1 makes
                # the prediction batch-major like ``target``.
                pred_locs = torch.stack([prober(x) for x in pred_encs], dim=1)
                loss = location_losses(pred_locs, target).mean()

                if step % 100 == 0:
                    print(f"normalized pred locations loss {loss.item()}")

                optimizer_pred_prober.zero_grad()
                loss.backward()
                optimizer_pred_prober.step()

                scheduler.adjust_learning_rate(step)

                step += 1

                if self.quick_debug and step > 2:
                    break

        return prober

    @torch.no_grad()
    def evaluate_all(
        self,
        prober,
    ):
        """
        Evaluates on all the different validation datasets.

        Returns:
            Dict mapping each validation-set name to its average loss.
        """
        avg_losses = {}

        for prefix, val_ds in self.val_ds.items():
            avg_losses[prefix] = self.evaluate_pred_prober(
                prober=prober,
                val_ds=val_ds,
                prefix=prefix,
            )

        return avg_losses

    @torch.no_grad()
    def evaluate_pred_prober(
        self,
        prober,
        val_ds,
        prefix="",
    ):
        """Average un-normalised location MSE of ``prober`` over ``val_ds``.

        Args:
            prober: trained probe; switched to eval mode here.
            val_ds: loader of validation batches.
            prefix: name of the validation set (currently unused).

        Returns:
            Python float: the MSE averaged over batches, timesteps and
            coordinates, converted back to un-normalised units.

        Raises:
            ValueError: if ``val_ds`` yields no batches.
        """
        probing_losses = []
        prober.eval()

        for batch in tqdm(val_ds, desc="Eval probe pred"):
            pred_encs = self._predict_encodings(batch)

            target = batch.locations.to(self.device)
            target = self.normalizer.normalize_location(target)

            pred_locs = torch.stack([prober(x) for x in pred_encs], dim=1)
            probing_losses.append(location_losses(pred_locs, target).cpu())

        if not probing_losses:
            raise ValueError("validation loader produced no batches")

        losses_t = torch.stack(probing_losses, dim=0).mean(dim=0)
        losses_t = self.normalizer.unnormalize_mse(losses_t)

        losses_t = losses_t.mean(dim=-1)
        return losses_t.mean().item()

In [None]:
# -----------------------------------------------------------------------------
# ░░ BaselineM1-twostage Model code  ░░
# -----------------------------------------------------------------------------
class SimpleChannelEncoder(nn.Module):
    """Two-branch CNN encoder; input channels 0 and 1 are encoded independently.

    Same layer layout as the ``SimpleChannelEncoder`` defined in an earlier
    cell: per branch three 3x3 convolutions (1 -> 8 -> 16 -> 32) with ReLU,
    max-pooling after the first two and a 4x4 adaptive average pool after the
    third. Output has ``out_channels`` = 64 channels.
    """

    # Feature-map counts along one branch.
    _WIDTHS = (1, 8, 16, 32)

    def __init__(self):
        super().__init__()
        self.path1 = self._conv_stack()
        self.path2 = self._conv_stack()
        self.out_channels = 64

    @classmethod
    def _conv_stack(cls):
        """Assemble one branch from ``_WIDTHS``."""
        stages = list(zip(cls._WIDTHS[:-1], cls._WIDTHS[1:]))
        layers = []
        for index, (c_in, c_out) in enumerate(stages):
            layers.append(nn.Conv2d(c_in, c_out, 3, 1, 1))
            layers.append(nn.ReLU())
            if index < len(stages) - 1:
                layers.append(nn.MaxPool2d(2))
            else:
                # Last stage pools to a fixed 4x4 grid regardless of input size.
                layers.append(nn.AdaptiveAvgPool2d((4, 4)))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Concatenate both branch outputs along the channel axis."""
        branches = (self.path1(x[:, 0:1]), self.path2(x[:, 1:2]))
        return torch.cat(list(branches), 1)

class Predictor(nn.Module):
    """Adds an action-dependent term to the first ``ball_dim`` latent channels.

    Same layout as the ``Predictor`` defined in an earlier cell. The remaining
    ("wall") channels are returned unchanged and detached.

    Args:
        map_channels: accepted but not used.
        action_dim: size of the action vector.
        ball_dim: number of leading channels that receive the update.
    """

    def __init__(self, map_channels=64, action_dim=2, ball_dim=32):
        super().__init__()
        self.ball_ch = ball_dim
        self.fc1 = nn.Linear(action_dim, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 128)
        self.fc_out = nn.Linear(128, self.ball_ch * 4 * 4)
        self.act = nn.LeakyReLU()

    def _action_features(self, action):
        """Action MLP with a residual connection around ``fc3``."""
        features = self.act(self.fc1(action))
        features = self.act(self.fc2(features))
        skip = features
        features = self.act(self.fc3(features))
        return features + skip

    def forward(self, z_map, action):
        """Return the updated latent map; ``z_map`` must hold exactly 2 * ``ball_ch`` channels."""
        ball, wall = z_map.split(self.ball_ch, dim=1)
        features = self._action_features(action)
        offset = self.fc_out(features).view(-1, self.ball_ch, 4, 4)
        return torch.cat([ball + offset, wall.detach()], dim=1)

# -----------------------------------------------------------------------------
# ░░ Stage‑2 wrapper with multi‑step MSE ░░
# -----------------------------------------------------------------------------
class VICRegPredictiveStage2(nn.Module):
    """Stage-2 model: the predictor learns multi-step dynamics with a pure MSE loss.

    The encoder is created on ``device``. Only when ``encoder_ckpt`` is given
    are its weights loaded (non-strictly), its parameters frozen and the
    encoder put in eval mode. ``compute_loss`` never back-propagates into the
    encoder either way, because it runs it under ``torch.no_grad()``.

    Args:
        encoder_ckpt: optional path to an encoder state dict.
        action_dim: size of one action vector.
        device: device the encoder and predictor are moved to.
    """

    def __init__(self, encoder_ckpt: Optional[str] = None, action_dim=2, device="cuda"):
        super().__init__()
        self.encoder = SimpleChannelEncoder().to(device)
        if encoder_ckpt is not None:
            encoder_weights = torch.load(encoder_ckpt, map_location=device)
            self.encoder.load_state_dict(encoder_weights, strict=False)
            self.encoder.requires_grad_(False)
            self.encoder.eval()
        self.predictor = Predictor(self.encoder.out_channels, action_dim).to(device)
        self.repr_dim = self.encoder.out_channels * 4 * 4

    def compute_loss(self, states, actions):
        """Mean per-step MSE between rolled-out and encoded latents.

        Args:
            states: (B, T + 1, C, H, W) frames.
            actions: (B, T, action_dim) actions.

        Each step's prediction is detached before being fed to the next step,
        so gradients do not flow across steps.
        """
        _, n_frames = states.shape[:2]
        n_steps = n_frames - 1

        with torch.no_grad():
            current = self.encoder(states[:, 0])

        total = 0.
        for t in range(n_steps):
            with torch.no_grad():
                encoded_next = self.encoder(states[:, t + 1])
            predicted_next = self.predictor(current, actions[:, t])
            total = total + F.mse_loss(predicted_next, encoded_next)
            current = predicted_next.detach()

        return total / n_steps

    def forward(self, s0, s1=None, a0=None, evaluation=False):
        """Single-step prediction (training) or autoregressive rollout (evaluation).

        Args:
            s0: training: (B, C, H, W); evaluation: (B, 1, C, H, W).
            s1: training only: next frames (B, C, H, W).
            a0: training: (B, action_dim); evaluation: (B, T, action_dim).
            evaluation: selects rollout mode.

        Returns:
            evaluation: (B, T, D) flattened predicted latents;
            training: ``(z_pred, z_next)``.
        """
        if not evaluation:
            z_pred = self.predictor(self.encoder(s0), a0)
            z_next = self.encoder(s1)
            return z_pred, z_next

        # Unpacking also enforces that the initial state tensor is 5-D.
        _batch, _one, _channels, _height, _width = s0.shape

        latent = self.encoder(s0.squeeze(1))
        rollout = []
        for t in range(a0.shape[1]):
            latent = self.predictor(latent, a0[:, t])
            rollout.append(latent.unsqueeze(1))

        # (B, T, C, H, W) -> (B, T, C*H*W)
        return torch.cat(rollout, dim=1).flatten(2)


In [None]:
# -----------------------------------------------------------------------------
# ░░ Baseline M1-twostage Model Probing Code  ░░
# -----------------------------------------------------------------------------

import glob
import os
from typing import Tuple, Dict, NamedTuple, Optional

def get_device(device = None):
    """Resolve the torch device to run on and print it.

    Args:
        device: explicit device (string or ``torch.device``). When ``None``,
            CUDA is used if available, otherwise the CPU.

    Returns:
        The resolved ``torch.device``.
    """
    # An explicit request is honoured; auto-detection is only the fallback.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    print("Using device:", device)
    return device


def load_data(device, data_path="/content/drive/MyDrive/jepa"):
    """Create the probe-training loader and the probe-validation loaders.

    Args:
        device: device the datasets move their tensors to.
        data_path: root directory containing ``probe_normal/train``,
            ``probe_normal/val`` and ``probe_wall/val``. Defaults to the Google
            Drive location used by this notebook (an alternative seen in the
            original code was ``/scratch/DL24FA``).

    Returns:
        ``(probe_train_ds, probe_val_ds)`` where ``probe_val_ds`` is a dict with
        the keys ``"normal"`` and ``"wall"``.
    """
    probe_train_ds = create_wall_dataloader(
        data_path=f"{data_path}/probe_normal/train",
        probing=True,
        device=device,
        train=True,
    )

    # Validation loaders are not shuffled (train=False).
    probe_val_ds = {
        name: create_wall_dataloader(
            data_path=f"{data_path}/{subdir}/val",
            probing=True,
            device=device,
            train=False,
        )
        for name, subdir in (("normal", "probe_normal"), ("wall", "probe_wall"))
    }

    return probe_train_ds, probe_val_ds




def load_model(checkpoint_path: str = "/content/drive/MyDrive/jepa/Baseline_M1_RE.pth",
               device: str | None = None) -> VICRegPredictiveStage2:
    """
    Load a VICRegPredictive model from checkpoint for evaluation.

    Args:
        checkpoint_path (str): path to .pth file saved in training loop.
        device (str | None): 'cuda' or 'cpu'; auto-detect if None.

    Returns:
        model (VICRegPredictive) in .eval() mode.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    device = get_device(device)

    # initialise the new baseline model and push to device
    model = VICRegPredictiveStage2().to(device)

    # load checkpoint
    ckpt = torch.load(checkpoint_path, map_location=device)
    try:
        model.load_state_dict(ckpt["model_state_dict"])
    except KeyError as e:
        raise KeyError(f"Checkpoint missing expected key: {e}")

    model.eval()
    return model


def evaluate_model(device, model, probe_train_ds, probe_val_ds):
    """Train a probe on ``model`` and print the average loss per validation set.

    Builds a ``ProbingEvaluator`` (with ``quick_debug`` off), trains the probe,
    evaluates it on every validation loader and prints one
    ``"<name> loss: <value>"`` line per set. Returns nothing.
    """
    evaluator_kwargs = dict(
        device=device,
        model=model,
        probe_train_ds=probe_train_ds,
        probe_val_ds=probe_val_ds,
        quick_debug=False,
    )
    evaluator = ProbingEvaluator(**evaluator_kwargs)

    trained_prober = evaluator.train_pred_prober()
    results = evaluator.evaluate_all(prober=trained_prober)

    for set_name, set_loss in results.items():
        print(f"{set_name} loss: {set_loss}")


# Entry point: pick a device, build the probe loaders, load the checkpointed
# model, then train and evaluate the probe. The names bound here are
# module-level and stay available to later cells.
if __name__ == "__main__":
    device = get_device()
    probe_train_ds, probe_val_ds = load_data(device)
    # NOTE(review): load_model() is called without `device`, so it calls
    # get_device() again itself — hence "Using device" is printed twice in the
    # recorded output.
    model = load_model()
    evaluate_model(device, model, probe_train_ds, probe_val_ds)


Using device: cuda
Using device: cuda


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.0113166570663452
normalized pred locations loss 1.0257712602615356


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.9988100528717041
normalized pred locations loss 0.9609401822090149


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.9914831519126892


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.0618687868118286
normalized pred locations loss 0.9227306246757507


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.0380175113677979


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.0702626705169678
normalized pred locations loss 1.0133963823318481


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.0297143459320068


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.9406649470329285
normalized pred locations loss 0.9891737699508667


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.048850178718567
normalized pred locations loss 0.9970170259475708


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.012986183166504


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.0114541053771973
normalized pred locations loss 1.0870314836502075


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.9726529121398926


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.9871707558631897
normalized pred locations loss 1.0190587043762207


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.9804393649101257


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.9498564600944519
normalized pred locations loss 1.0187326669692993


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.8810274600982666


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.1005717515945435
normalized pred locations loss 1.1101164817810059


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.070177674293518
normalized pred locations loss 0.9885451197624207


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.9448111057281494


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.2055623531341553
normalized pred locations loss 1.09724760055542


Eval probe pred:   0%|          | 0/62 [00:00<?, ?it/s]

Eval probe pred:   0%|          | 0/62 [00:00<?, ?it/s]

normal loss: 263.5338134765625
wall loss: 189.05747985839844


In [None]:
# -----------------------------------------------------------------------------
# ░░ baselineM4_Collider_ContinueTrainingcode  ░░
# -----------------------------------------------------------------------------

# model_stage2.py
# -------------------------------------------------------------
# Baseline-M4 (Collider) : encoder + gated predictor for stage-2
# -------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

# ---------------- Encoder (frozen in stage-2) -----------------
class SimpleChannelEncoder(nn.Module):
    """
    Dual-branch CNN encoder.

    Input channel 0 (ball) and input channel 1 (wall) are each processed by
    their own convolutional stack; the two resulting feature maps are
    concatenated along the channel axis.  Output: (B, 64, 4, 4)
    """

    # (in_channels, out_channels) of the three 3x3 convolutions in one branch.
    _CONV_WIDTHS = ((1, 8), (8, 16), (16, 32))

    def __init__(self):
        super().__init__()
        # Ball branch is built before the wall branch so that parameter
        # initialisation under a fixed seed keeps the same order.
        self.path1 = self._make_branch()   # ball channel
        self.path2 = self._make_branch()   # wall channel
        self.out_channels = 64  # 32 + 32

    @classmethod
    def _make_branch(cls) -> nn.Sequential:
        """Build one conv/ReLU/pool stack ending in a 4x4 adaptive average pool."""
        last_idx = len(cls._CONV_WIDTHS) - 1
        layers = []
        for idx, (c_in, c_out) in enumerate(cls._CONV_WIDTHS):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU())
            if idx < last_idx:
                layers.append(nn.MaxPool2d(2))
            else:
                # The final stage pins the spatial size to 4x4.
                layers.append(nn.AdaptiveAvgPool2d((4, 4)))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, 2, H, W)  – channel-0 = ball, channel-1 = wall
        ball_feat = self.path1(x[:, 0:1])
        wall_feat = self.path2(x[:, 1:2])
        return torch.cat((ball_feat, wall_feat), dim=1)


# ---------------- Gated Stage-2 Predictor ---------------------
class Stage2Predictor(nn.Module):
    """
    Gated residual predictor operating on a (B, z_channels, 4, 4) latent.

    The first ``ball_ch`` channels of the latent are treated as the *ball*
    part and receive an additive update ``sigmoid(gate) * bias``; a gate near
    zero therefore leaves the ball latent unchanged.  The remaining channels
    (the *wall* part) are copied through detached from the autograd graph.

    Args:
        z_channels: number of channels in the incoming latent map.
        action_dim: size of the action vector.
        ball_ch:    number of leading latent channels that belong to the ball.
        hidden:     width of the fusion / residual MLP layers.

    Raises:
        ValueError: if ``ball_ch`` is not in ``1..z_channels``.
    """
    def __init__(
        self,
        z_channels: int = 64,
        action_dim:  int = 2,
        ball_ch:     int = 32,
        hidden:      int = 256,
    ):
        super().__init__()
        if not 0 < ball_ch <= z_channels:
            raise ValueError(
                f"ball_ch must be in 1..z_channels, got ball_ch={ball_ch}, "
                f"z_channels={z_channels}"
            )
        self.z_channels = z_channels
        self.ball_ch = ball_ch
        flat = z_channels * 4 * 4          # flattened latent size

        # ---- map action → spatial bias (ball_ch×4×4) --------------
        self.action_mapper = nn.Sequential(
            nn.Linear(action_dim, 64), nn.LeakyReLU(0.1),
            nn.Linear(64, ball_ch * 4 * 4), nn.LeakyReLU(0.1),
        )

        # ---- fusion MLP -------------------------------------------
        self.fc_in = nn.Linear(flat + ball_ch * 4 * 4, hidden)
        self.fc_h1 = nn.Linear(hidden, hidden)

        # ---- gate head (collision scalar) -------------------------
        self.fc_gate = nn.Linear(hidden, 1)

        # ---- residual bias head -----------------------------------
        self.fc_b1   = nn.Linear(hidden, hidden)
        self.fc_b2   = nn.Linear(hidden, hidden)
        self.fc_bias_out  = nn.Linear(hidden, ball_ch * 4 * 4)

        self.act = nn.LeakyReLU(0.1)

    # -------------------------------------------------------------
    def forward(
        self, z_map: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        z_map : (B, z_channels, 4, 4)
        action: (B, action_dim)
        returns
            z_pred  : (B, z_channels, 4, 4)
            g_logits: (B, 1)  (before sigmoid)
        """
        B = z_map.size(0)
        # Explicit slicing instead of torch.split(z_map, ball_ch): split only
        # yields exactly two chunks when z_channels == 2 * ball_ch, whereas
        # slicing works for any ball/wall channel ratio.
        ball = z_map[:, :self.ball_ch]
        wall = z_map[:, self.ball_ch:]

        a_hat = self.action_mapper(action)              # (B, ball_ch*4*4)

        # The fusion MLP sees the whole latent (ball + wall) plus the action.
        fusion = torch.cat([z_map.flatten(1), a_hat], dim=1)
        h = self.act(self.fc_in(fusion))
        h = self.act(self.fc_h1(h))

        g_logits = self.fc_gate(h)                # (B,1)

        # residual bias head with one skip connection around fc_b2
        b = self.act(self.fc_b1(h))
        b = self.act(self.fc_b2(b) + b)
        bias = self.fc_bias_out(b).view(B, self.ball_ch, 4, 4)

        # One gate scalar per sample, broadcast over channels and space.
        gate = torch.sigmoid(g_logits).view(B, 1, 1, 1)
        ball_pred = ball + gate * bias

        # Wall channels are passed through unchanged and carry no gradient.
        z_pred = torch.cat([ball_pred, wall.detach()], dim=1)
        return z_pred, g_logits


# ---------------- Stage-2 wrapper ------------------------------
class VICRegPredictiveStage2(nn.Module):
    """
    Stage-2 wrapper: a SimpleChannelEncoder plus a Stage2Predictor.

    * when ``encoder_ckpt`` is given, the encoder weights are loaded from it
      and the encoder is put in eval mode with gradients disabled
    * the predictor learns multi-step dynamics in the encoder's latent space

    Args:
        encoder_ckpt: path to a stage-1 checkpoint containing a
            ``"model_state_dict"`` entry whose encoder weights are stored
            under the ``"encoder."`` prefix; ``None`` keeps the random init.
        action_dim: size of one action vector.
        device: device the sub-modules are created on.
    """
    def __init__(
        self,
        encoder_ckpt: Optional[str] = None,   # path to stage-1 ckpt; None = random init
        action_dim: int = 2,
        device: str = "cuda",
    ):
        super().__init__()
        self.encoder = SimpleChannelEncoder().to(device)
        self.predictor = Stage2Predictor(
            z_channels=self.encoder.out_channels,
            action_dim=action_dim,
        ).to(device)

        # load & freeze encoder if ckpt supplied
        if encoder_ckpt is not None:
            # NOTE(review): torch.load unpickles the file; only pass
            # checkpoints from a trusted source.
            ckpt = torch.load(encoder_ckpt, map_location=device)
            prefix = "encoder."
            # Strip only the leading prefix.  str.replace would also remove
            # any later "encoder." inside a nested parameter name.
            enc_state = {
                k[len(prefix):]: v
                for k, v in ckpt["model_state_dict"].items()
                if k.startswith(prefix)
            }
            self.encoder.load_state_dict(enc_state, strict=True)
            self.encoder.eval()
            for p in self.encoder.parameters():
                p.requires_grad = False

        # expose flattened dimension for probing
        self.repr_dim = self.encoder.out_channels * 4 * 4

    # -------- multi-step rollout loss used during stage-2 training --------
    def compute_loss(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """
        Average latent MSE over an autoregressive rollout.

        Starting from the encoding of frame 0, the predictor is applied once
        per action; each prediction is compared against the encoding of the
        next ground-truth frame and then fed back in (detached) as the input
        of the following step.

        states : (B, T+1, C, H, W)   e.g. frames [0..16]
        actions: (B, T, action_dim)  one action per transition
        -> MSE averaged over the T steps

        Raises:
            ValueError: if ``states`` holds fewer than two frames.
        """
        T = states.shape[1] - 1
        if T < 1:
            raise ValueError("compute_loss needs at least two frames per sequence")
        loss = 0.0

        # Encoder outputs are targets/inputs only; no gradient flows into it.
        with torch.no_grad():
            z_t = self.encoder(states[:, 0])      # t = 0

        for t in range(T):
            with torch.no_grad():
                z_target = self.encoder(states[:, t + 1])

            z_pred, _ = self.predictor(z_t, actions[:, t])
            loss += F.mse_loss(z_pred, z_target)
            # Detach so each step is trained as a one-step prediction from
            # the previous (fixed) rollout state.
            z_t = z_pred.detach()

        return loss / T

    # -------- autoregressive rollout for evaluation --------------
    def forward(self, s0: torch.Tensor, a0: torch.Tensor, evaluation: bool = True) -> torch.Tensor:
        """
        s0 : (B, F, C, H, W)   frames; only frame 0 is used as the start state
        a0 : (B, T, action_dim)
        returns (B, T, repr_dim)  flattened latents, one per action

        Raises:
            ValueError: if ``s0`` is not 5-dimensional.
        """
        assert evaluation, "Only evaluation=True supported for rollout"
        if s0.dim() != 5:
            raise ValueError(
                f"s0 must have shape (B, F, C, H, W), got {tuple(s0.shape)}"
            )
        T = a0.shape[1]

        with torch.no_grad():
            # Index the first frame explicitly.  squeeze(1) is a no-op when
            # more than one frame is passed and would hand the encoder a
            # 5-D tensor.
            z = self.encoder(s0[:, 0])           # (B,64,4,4)

        preds = []
        for t in range(T):
            z, _ = self.predictor(z, a0[:, t])
            preds.append(z.flatten(1))           # (B,1024)

        return torch.stack(preds, dim=1)         # (B,T,1024)


In [None]:
# -----------------------------------------------------------------------------
# ░░ baseline_M4_Collider_ContinueTraining Probing Code  ░░
# -----------------------------------------------------------------------------

import glob
import os
from typing import Tuple, Dict, NamedTuple, Optional

def get_device(device = None):
    """Resolve the torch device to run on.

    Args:
        device: optional device spec (string or ``torch.device``).  When
            given it is honoured as-is; when ``None`` CUDA is used if it is
            available, otherwise the CPU.

    Returns:
        torch.device: the resolved device (also printed to stdout).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        # Previously the argument was silently overwritten by auto-detection.
        device = torch.device(device)
    print("Using device:", device)
    return device


def load_data(device):
    """Create the probing dataloaders.

    Builds, in this order, the probe training loader and the two validation
    loaders ("normal" and "wall") via ``create_wall_dataloader``.

    Returns:
        (probe_train_ds, probe_val_ds) where ``probe_val_ds`` is a dict with
        the keys ``"normal"`` and ``"wall"``.
    """
    # Alternative dataset root used on the cluster: "/scratch/DL24FA"
    root = "/content/drive/MyDrive/jepa"

    def _probe_loader(subdir, is_train):
        # All loaders share probing=True and the same target device.
        return create_wall_dataloader(
            data_path=f"{root}/{subdir}",
            probing=True,
            device=device,
            train=is_train,
        )

    probe_train_ds = _probe_loader("probe_normal/train", True)
    probe_val_ds = {
        "normal": _probe_loader("probe_normal/val", False),
        "wall": _probe_loader("probe_wall/val", False),
    }
    return probe_train_ds, probe_val_ds




def load_model(checkpoint_path: str = "/content/drive/MyDrive/jepa/Baseline_M4_Collider_ContinueTraining.pth",
               device: Optional[str] = None) -> VICRegPredictiveStage2:
    """
    Load a VICRegPredictiveStage2 model from checkpoint for evaluation.

    Args:
        checkpoint_path (str): path to a .pth file holding a
            ``"model_state_dict"`` entry.
        device (str | None): 'cuda' or 'cpu'; auto-detect if None.

    Returns:
        model (VICRegPredictiveStage2) in .eval() mode, on ``device``.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is not an existing file.
        KeyError: if the checkpoint has no ``"model_state_dict"`` entry.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # Honour an explicit device; only auto-detect when none was requested.
    device = get_device() if device is None else torch.device(device)

    # Build the model directly on the target device.  The constructor's
    # default is "cuda", which would fail on a CPU-only machine before a
    # later .to(device) could take effect.
    model = VICRegPredictiveStage2(encoder_ckpt=None, device=device)

    # load checkpoint
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device)
    try:
        state_dict = ckpt["model_state_dict"]
    except KeyError as e:
        raise KeyError(f"Checkpoint missing expected key: {e}") from e

    # Shape/key mismatches surface as RuntimeError from load_state_dict.
    model.load_state_dict(state_dict)

    model.eval()
    return model


def evaluate_model(device, model, probe_train_ds, probe_val_ds):
    """Train a probe on the model's predictions and print validation losses.

    A ``ProbingEvaluator`` is built from the arguments (with
    ``quick_debug=False``), its prober is trained, and the result of
    ``evaluate_all`` is printed as one ``"<name> loss: <value>"`` line per
    entry.  Nothing is returned.
    """
    evaluator_kwargs = dict(
        device=device,
        model=model,
        probe_train_ds=probe_train_ds,
        probe_val_ds=probe_val_ds,
        quick_debug=False,
    )
    evaluator = ProbingEvaluator(**evaluator_kwargs)

    # Fit the prober first, then score it on every validation split.
    trained_prober = evaluator.train_pred_prober()
    avg_losses = evaluator.evaluate_all(prober=trained_prober)

    for probe_attr, loss in avg_losses.items():
        print(f"{probe_attr} loss: {loss}")


if __name__ == "__main__":
    # Resolve the device once and use it for both the data and the model.
    device = get_device()
    probe_train_ds, probe_val_ds = load_data(device)
    # Pass the device explicitly so the checkpoint is loaded where the
    # probing data lives instead of being auto-detected a second time.
    model = load_model(device=device)
    evaluate_model(device, model, probe_train_ds, probe_val_ds)


Using device: cuda
Using device: cuda


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 1.0937353372573853
normalized pred locations loss 0.7831757664680481


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.7597733736038208
normalized pred locations loss 0.829097330570221


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.6390058994293213


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.3911798298358917
normalized pred locations loss 0.25936761498451233


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.20930488407611847


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.11758735775947571
normalized pred locations loss 0.07800734788179398


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.056923214346170425


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.0415324829518795
normalized pred locations loss 0.04273030906915665


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.021051829680800438
normalized pred locations loss 0.028812984004616737


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.03637881204485893


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.01965867541730404
normalized pred locations loss 0.026207003742456436


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.019235646352171898


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.022876447066664696
normalized pred locations loss 0.024984944611787796


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.017992762848734856


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.022979890927672386
normalized pred locations loss 0.02436106838285923


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.0238339863717556


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.0196219552308321
normalized pred locations loss 0.03060464933514595


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.022278355434536934
normalized pred locations loss 0.02077462710440159


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.020541058853268623


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.02406962588429451
normalized pred locations loss 0.01858516037464142


Eval probe pred:   0%|          | 0/62 [00:00<?, ?it/s]

Eval probe pred:   0%|          | 0/62 [00:00<?, ?it/s]

normal loss: 5.319822788238525
wall loss: 9.957267761230469


In [None]:
import torch, pprint, itertools

# Load only the weights dictionary of the stage-2 checkpoint (on CPU).
ckpt = torch.load(
    "/content/drive/MyDrive/jepa/Baseline_M4_Collider_ContinueTraining.pth",
    map_location="cpu",
)["model_state_dict"]

# Collect every parameter name that lives under the predictor sub-module.
pred_keys = list(filter(lambda name: name.startswith("predictor."), ckpt))
print("First few predictor keys in checkpoint:")
pprint.pprint(pred_keys[:10])

First few predictor keys in checkpoint:
['predictor.action_mapper.0.weight',
 'predictor.action_mapper.0.bias',
 'predictor.action_mapper.2.weight',
 'predictor.action_mapper.2.bias',
 'predictor.fc_in.weight',
 'predictor.fc_in.bias',
 'predictor.fc_h1.weight',
 'predictor.fc_h1.bias',
 'predictor.fc_gate.weight',
 'predictor.fc_gate.bias']


In [None]:
# Instantiate a fresh predictor and list the names of its direct sub-modules,
# to compare against the key names stored in the checkpoint.
model = Stage2Predictor()
print(list(dict(model.named_children())))


['action_mapper', 'fc_in', 'fc_h1', 'fc_gate', 'fc_b1', 'fc_b2', 'fc_bias_out', 'act']


In [None]:
# -------------------------------------------------------------
# Baseline-M4 (Fusion) – encoder + residual predictor
# -------------------------------------------------------------



# ------------------------------------------------------------------
#  Encoder (frozen)
# ------------------------------------------------------------------
class SimpleChannelEncoder(nn.Module):
    """Two independent conv branches, one per input channel.

    Each branch maps a single-channel image to a (32, 4, 4) feature map; the
    two maps are concatenated into a (B, 64, 4, 4) latent.
    """

    def __init__(self):
        super().__init__()

        def conv_relu(c_in, c_out):
            # 3x3 convolution, stride 1, padding 1, followed by ReLU
            return [nn.Conv2d(c_in, c_out, 3, 1, 1), nn.ReLU()]

        def branch():
            return nn.Sequential(
                *conv_relu(1, 8), nn.MaxPool2d(2),
                *conv_relu(8, 16), nn.MaxPool2d(2),
                *conv_relu(16, 32), nn.AdaptiveAvgPool2d((4, 4)),
            )

        # path1 is created before path2 to keep the initialisation order.
        self.path1 = branch()
        self.path2 = branch()
        self.out_channels = 64

    def forward(self, x):
        # Channel 0 goes through path1, channel 1 through path2.
        per_channel = (self.path1(x[:, 0:1]), self.path2(x[:, 1:2]))
        return torch.cat(per_channel, dim=1)

# ------------------------------------------------------------------
#  New Stage-2 Predictor
# ------------------------------------------------------------------

class Predictor(nn.Module):
    """Residual predictor on a (B, map_channels, 4, 4) latent.

    The first ``ball_dim`` channels (ball part) receive an additive bias
    computed from the full latent and the action; the remaining channels
    (wall part) are copied through detached from the autograd graph.

    Args:
        map_channels: number of channels in the incoming latent map.
        action_dim:   size of the action vector.
        ball_dim:     number of leading latent channels treated as the ball.
        hidden_dim:   width of the fusion MLP.

    Raises:
        ValueError: if ``ball_dim`` is not in ``1..map_channels``.
    """
    def __init__(
        self,
        map_channels: int = 64,
        action_dim:   int = 2,
        ball_dim:     int = 32,
        hidden_dim:   int = 128
    ):
        super().__init__()
        if not 0 < ball_dim <= map_channels:
            raise ValueError(
                f"ball_dim must be in 1..map_channels, got ball_dim={ball_dim}, "
                f"map_channels={map_channels}"
            )
        self.ball_ch = ball_dim

        # 1) Action expander: action_dim → 64 → ball_dim*4*4 (512 by default)
        self.action_expander = nn.Sequential(
            nn.Linear(action_dim, 64),
            nn.LeakyReLU(0.1),
            nn.Linear(64, ball_dim * 4 * 4),
            nn.LeakyReLU(0.1),
        )

        # 2) Fusion MLP; input = flattened z_map + expanded action
        in_feats = map_channels * 4 * 4 + ball_dim * 4 * 4  # default 1024 + 512 = 1536
        self.fc1    = nn.Linear(in_feats,    hidden_dim)
        self.fc2    = nn.Linear(hidden_dim,  hidden_dim)
        self.fc3    = nn.Linear(hidden_dim,  hidden_dim)
        self.fc4    = nn.Linear(hidden_dim,  hidden_dim)
        # project to bias map
        self.fc_out = nn.Linear(hidden_dim,  ball_dim * 4 * 4)

        self.act = nn.LeakyReLU(0.1)

    def forward(self, z_map: torch.Tensor, action: torch.Tensor):
        """
        z_map : (B, map_channels, 4, 4)
        action: (B, action_dim)
        returns (B, map_channels, 4, 4) predicted latent
        """
        B = z_map.size(0)

        # split ball vs. wall; slicing (unlike torch.split with a chunk
        # size) does not require map_channels == 2 * ball_dim
        ball = z_map[:, :self.ball_ch]
        wall = z_map[:, self.ball_ch:]

        # expand action to size ball_dim*4*4
        a_flat = self.action_expander(action)     # (B,512)
        # flatten full z_map; flatten() also accepts non-contiguous inputs,
        # where view() would raise
        z_flat = z_map.flatten(1)                 # (B,1024)

        # fuse and run through 4-layer MLP + one residual
        fusion   = torch.cat([z_flat, a_flat], dim=1)  # (B,1536)
        x        = self.act(self.fc1(fusion))          # (B,hidden)
        x        = self.act(self.fc2(x))               # (B,hidden)
        residual = x                                   # skip connection around fc3/fc4
        x        = self.act(self.fc3(x))               # (B,hidden)
        x        = self.act(self.fc4(x) + residual)    # (B,hidden)

        # build and inject bias
        bias_flat = self.fc_out(x)                     # (B,512)
        bias      = bias_flat.view(B, self.ball_ch, 4, 4)
        ball_pred = ball + bias

        # wall channels pass through without gradient; concat back
        return torch.cat([ball_pred, wall.detach()], dim=1)


# ------------------------------------------------------------------
#  Stage-2 wrapper & loss
# ------------------------------------------------------------------
class VICRegPredictiveStage2(nn.Module):
    """Stage-2 wrapper: a SimpleChannelEncoder plus a residual Predictor.

    Args:
        encoder_ckpt: optional path to a checkpoint with a
            ``"model_state_dict"`` entry; weights stored under the
            ``"encoder."`` prefix are loaded into the encoder, which is then
            put in eval mode with gradients disabled.  ``None`` keeps the
            random initialisation.
        action_dim: size of one action vector.
        device: device the sub-modules are created on.
    """
    def __init__(self, encoder_ckpt: Optional[str] = None, action_dim=2, device="cuda"):
        super().__init__()

        # 1) Create encoder
        self.encoder = SimpleChannelEncoder().to(device)

        # 2) Create predictor using encoder.out_channels
        #    (first half of the channels = ball, second half = wall)
        self.predictor = Predictor(
            map_channels=self.encoder.out_channels,
            action_dim=action_dim,
            ball_dim=self.encoder.out_channels // 2,
            hidden_dim=128
        ).to(device)

        # 3) Load and freeze encoder (optional)
        if encoder_ckpt is not None:
            # NOTE(review): torch.load unpickles the file; only pass
            # checkpoints from a trusted source.
            ckpt = torch.load(encoder_ckpt, map_location=device)
            state = ckpt["model_state_dict"]
            prefix = "encoder."
            # Strip only the leading prefix.  str.replace would also remove
            # any later "encoder." inside a nested parameter name.
            enc_state = {
                k[len(prefix):]: v
                for k, v in state.items()
                if k.startswith(prefix)
            }
            self.encoder.load_state_dict(enc_state, strict=True)
            self.encoder.eval()
            for p in self.encoder.parameters():
                p.requires_grad = False

        # 4) Store repr_dim for probing
        self.repr_dim = self.encoder.out_channels * 4 * 4

    def compute_loss(self, states, actions, timestep=None):
        """Compute the 1-step latent MSE loss for a single transition.

        Args:
            states:  (B, T+1, C, H, W) frames.
            actions: (B, T, action_dim) actions.
            timestep: index t of the transition ``states[:, t] ->
                states[:, t + 1]``; ``None`` means t = 0.

        Returns:
            Scalar tensor: MSE between the predicted and encoded next latent.
        """
        t = 0 if timestep is None else timestep
        state_t   = states[:, t]
        state_tp1 = states[:, t + 1]
        action_t  = actions[:, t]

        # The encoder only provides inputs/targets; no gradient flows into it.
        with torch.no_grad():
            z_t      = self.encoder(state_t)
            z_target = self.encoder(state_tp1)

        z_pred = self.predictor(z_t, action_t)
        return F.mse_loss(z_pred, z_target)

    def forward(self, s0, a0, evaluation=True):
        """
        Autoregressive rollout of future embeddings.
        s0: (B, F, C, H, W) frames (only frame 0 is used), a0: (B, T, action_dim)
        returns: (B, T, repr_dim)

        Raises:
            ValueError: if ``s0`` is not 5-dimensional.
        """
        assert evaluation, "Only evaluation=True supported."
        if s0.dim() != 5:
            raise ValueError(
                f"s0 must have shape (B, F, C, H, W), got {tuple(s0.shape)}"
            )
        T = a0.shape[1]

        with torch.no_grad():
            # Index the first frame explicitly.  squeeze(1) is a no-op when
            # more than one frame is passed and would hand the encoder a
            # 5-D tensor.
            z = self.encoder(s0[:, 0])  # (B,64,4,4)

        preds = []
        for t in range(T):
            # Feed each prediction back in as the next step's input.
            z = self.predictor(z, a0[:, t])
            preds.append(z.flatten(1))  # (B,1024)

        return torch.stack(preds, dim=1)  # (B,T,1024)


In [None]:
# -----------------------------------------------------------------------------
# ░░ Baseline M4-Fusion Model Probing Code  ░░
# -----------------------------------------------------------------------------

import glob
import os
from typing import Tuple, Dict, NamedTuple, Optional

def get_device(device = None):
    """Resolve the torch device to run on.

    Args:
        device: optional device spec (string or ``torch.device``).  When
            given it is honoured as-is; when ``None`` CUDA is used if it is
            available, otherwise the CPU.

    Returns:
        torch.device: the resolved device (also printed to stdout).
    """
    # The argument used to be overwritten unconditionally; only fall back to
    # auto-detection when the caller did not ask for a specific device.
    if device is not None:
        device = torch.device(device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    return device


def load_data(device):
    """Create the probing dataloaders via ``create_wall_dataloader``.

    Returns:
        (probe_train_ds, probe_val_ds): the probe training loader and a dict
        of validation loaders keyed by ``"normal"`` and ``"wall"``.
    """
    # Alternative dataset root used on the cluster: "/scratch/DL24FA"
    data_path = "/content/drive/MyDrive/jepa"
    shared = dict(probing=True, device=device)

    probe_train_ds = create_wall_dataloader(
        data_path=f"{data_path}/probe_normal/train", train=True, **shared
    )

    # Validation loaders are created in the order normal -> wall.
    probe_val_ds = {}
    for split, folder in (("normal", "probe_normal"), ("wall", "probe_wall")):
        probe_val_ds[split] = create_wall_dataloader(
            data_path=f"{data_path}/{folder}/val", train=False, **shared
        )

    return probe_train_ds, probe_val_ds




def load_model(checkpoint_path: str = "/content/drive/MyDrive/jepa/Baseline_M4_Fusion_ContinueTraining.pth",
               device: Optional[str] = None) -> VICRegPredictiveStage2:
    """
    Restore a VICRegPredictiveStage2 model from a checkpoint for evaluation.

    The checkpoint must contain a ``"model_state_dict"`` entry that matches
    the model exactly (strict loading).  When ``device`` is None it is
    auto-detected via ``get_device``; otherwise the given device is used.
    The returned model is in eval mode.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is not an existing file.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    if device is None:
        device = get_device(device)
    else:
        device = torch.device(device)

    # The model is built before the checkpoint is read.
    model = VICRegPredictiveStage2(encoder_ckpt=None, device=device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model.eval()
    return model


def evaluate_model(device, model, probe_train_ds, probe_val_ds):
    """Run the probing evaluation and print one loss line per result entry.

    Builds a ``ProbingEvaluator`` (``quick_debug=False``), trains its
    prediction prober, evaluates it with ``evaluate_all`` and prints each
    returned entry as ``"<name> loss: <value>"``.  Returns nothing.
    """
    evaluator = ProbingEvaluator(
        device=device, model=model,
        probe_train_ds=probe_train_ds, probe_val_ds=probe_val_ds,
        quick_debug=False,
    )

    # Train the prober and hand it straight to the evaluation step.
    avg_losses = evaluator.evaluate_all(prober=evaluator.train_pred_prober())

    for probe_attr, loss in avg_losses.items():
        print(f"{probe_attr} loss: {loss}")


if __name__ == "__main__":
    # Resolve the device once and use it for both the data and the model.
    device = get_device()
    probe_train_ds, probe_val_ds = load_data(device)
    # Pass the device explicitly so the checkpoint is loaded where the
    # probing data lives instead of being auto-detected a second time.
    model = load_model(device=device)
    evaluate_model(device, model, probe_train_ds, probe_val_ds)


Using device: cuda
Using device: cuda


Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]

Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 2.057464599609375
normalized pred locations loss 0.8895341157913208


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.7888920903205872
normalized pred locations loss 0.5611903667449951


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.4355781376361847


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.304225355386734
normalized pred locations loss 0.210633784532547


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.12017997354269028


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.08954346179962158
normalized pred locations loss 0.07998412102460861


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.0504586435854435


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.035301946103572845
normalized pred locations loss 0.05094463750720024


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.03964448347687721
normalized pred locations loss 0.03513237461447716


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.03675047680735588


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.026029052212834358
normalized pred locations loss 0.028370250016450882


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.02306382544338703


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.03146805241703987
normalized pred locations loss 0.02740958146750927


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.03321618214249611


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.03412136808037758
normalized pred locations loss 0.030009929090738297


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.03108244016766548


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.025740155950188637
normalized pred locations loss 0.021949732676148415


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.030199198052287102
normalized pred locations loss 0.02093612402677536


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.02774716354906559


Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s]

normalized pred locations loss 0.029295938089489937
normalized pred locations loss 0.021241208538413048


Eval probe pred:   0%|          | 0/62 [00:00<?, ?it/s]

Eval probe pred:   0%|          | 0/62 [00:00<?, ?it/s]

normal loss: 6.432509899139404
wall loss: 9.19483757019043
