# Contrastive Learning (SimCLR/InfoNCE)

Two-crop pipeline: two stochastic views of the same sequence (positives).

Projection head (MLP): improves contrastive training stability/quality (SimCLR finding).

NT-Xent / InfoNCE loss with temperature τ and cosine similarities.

Tiny eval: in-batch k-NN / retrieval sanity check.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math, torch, torch.nn as nn, torch.nn.functional as F

In [None]:
def nt_xent(z1: torch.Tensor, z2: torch.Tensor, tau: float = 0.1) -> torch.Tensor:
    """
    z1, z2: [B, D], L2-normalized.
    Returns scalar loss (average over 2B positives).
    """
    B = z1.size(0)
    z = torch.cat([z1, z2], dim=0)              # [2B, D]
    # Similarity matrix (cosine) scaled by temperature
    sim = z @ z.t()                              # cosine since z's are normalized -> [2B, 2B]
    # mask self-similarity
    diag = torch.eye(2*B, device=z.device, dtype=torch.bool)
    sim = sim / tau
    sim = sim - 1e9 * diag                       # remove self-pairs

    # positives: (i <-> i+B) and (i+B <-> i)
    # each anchor sees 2B−2 negatives
    targets = torch.cat([torch.arange(B, 2*B), torch.arange(0, B)], dim=0).to(z.device)  # [2B]
    # Cross-entropy over rows (softmax over all 2B-1 others)
    loss = F.cross_entropy(sim, targets)
    return loss


In [39]:
class TSView:
    def __init__(self, jitter_std=0.02, scale_min=0.9, scale_max=1.1, cutout_p=0.3, cutout_len=12):
        self.jitter_std = jitter_std
        self.scale_min = scale_min
        self.scale_max = scale_max
        self.cutout_p = cutout_p
        self.cutout_len = cutout_len

    def __call__(self, x: torch.Tensor):   # x: [T, C]
        # jitter
        x = x + torch.randn_like(x) * self.jitter_std
        # per-channel scaling
        scales = torch.empty(x.size(1), device=x.device).uniform_(self.scale_min, self.scale_max)
        x = x * scales
        # cutout (time masking) — zeros a contiguous block
        if torch.rand(()) < self.cutout_p and x.size(0) > self.cutout_len:
            start = int(torch.randint(0, x.size(0) - self.cutout_len + 1, (1,)))
            x[start:start+self.cutout_len] = 0.0
        return x


In [41]:
from torch.utils.data import Dataset

class TwoCropTSDataset(Dataset):
    """
    Time-series two-crop dataset.
    X: [N, T, C] (float)
    M: [N, T] bool pad mask (True=PAD). Pass None if no padding.
    view: callable that applies augmentations to a [T, C] tensor and returns [T, C]
    crop_len: fixed crop window length in timesteps
    pad_to_crop_len: if valid length < crop_len, right-pad with zeros and set mask=True on padded tail
    """
    def __init__(self, X: torch.Tensor, M: torch.Tensor | None, view, *,
                 crop_len: int = 128, pad_to_crop_len: bool = True):
        super().__init__()
        assert X.ndim == 3, "X must be [N,T,C]"
        self.X = X.float()
        self.M = M if M is None else M.bool()
        self.view = view
        self.crop_len = crop_len
        self.pad_to_crop_len = pad_to_crop_len

    def __len__(self): return self.X.size(0)

    def __getitem__(self, i):
        x = self.X[i]                         # [T, C]
        m = None if self.M is None else self.M[i]  # [T] bool (True=PAD)

        x1, m1 = self._crop_once(x, m)
        x2, m2 = self._crop_once(x, m)
        # return one sample [T, C]
        return (x1, m1), (x2, m2)

    def _crop_once(self, x: torch.Tensor, m: torch.Tensor | None):
        T = x.size(0)
        if m is None:
            # whole sequence is valid
            valid_T = T
            start_max = max(0, valid_T - self.crop_len)
            start = int(torch.randint(0, start_max + 1, (1,)))
            end = start + self.crop_len
            sub = x[start:end] if self.crop_len <= T else x
            if self.crop_len > T and self.pad_to_crop_len:
                sub, pad_mask = self._right_pad(sub, self.crop_len)
            else:
                pad_mask = torch.zeros(sub.size(0), dtype=torch.bool, device=x.device)
        else:
            # m: True=PAD, so valid = ~m
            valid_len = int((~m).sum().item())
            if valid_len == 0:
                # degenerate: make a single-timestep zero crop and pad
                sub = torch.zeros((0, x.size(1)), dtype=x.dtype, device=x.device)
                sub, pad_mask = self._right_pad(sub, self.crop_len)
                return self.view(sub), pad_mask

            w = min(valid_len, self.crop_len)
            start_max = max(0, valid_len - w)
            start = int(torch.randint(0, start_max + 1, (1,)))
            end = start + w
            # take from the *valid* prefix of length valid_len
            sub = x[:valid_len][start:end]     # [w, C]

            if w < self.crop_len and self.pad_to_crop_len:
                sub, pad_mask = self._right_pad(sub, self.crop_len)
            else:
                pad_mask = torch.zeros(sub.size(0), dtype=torch.bool, device=x.device)

        # apply view (augmentations) in [T, C] then return sub + mask
        sub = self.view(sub) if self.view is not None else sub
        return sub, pad_mask

    @staticmethod
    def _right_pad(sub: torch.Tensor, target_len: int):
        """Right-pad sub [t,C] to target_len with zeros; return padded sub and pad mask [target_len]."""
        t, C = sub.size(0), sub.size(1) if sub.ndim == 2 else (sub.size(0), 1)
        if t == target_len:
            pad_mask = torch.zeros(t, dtype=torch.bool, device=sub.device)
            return sub, pad_mask
        pad = torch.zeros((target_len - t, sub.size(1)), dtype=sub.dtype, device=sub.device)
        out = torch.cat([sub, pad], dim=0)
        mask = torch.zeros(target_len, dtype=torch.bool, device=sub.device)
        mask[t:] = True  # True=PAD
        return out, mask

In [None]:
def info_nce_two_way(z1, z2, tau=0.2):
    # z1,z2: [B,D], MUST be L2-normalized along dim=-1
    sim = (z1 @ z2.t()) / tau           # [B,B]
    y = torch.arange(z1.size(0), device=z1.device)
    # Each anchor sees only cross-view negatives (B−1 per anchor)
    return 0.5 * (F.cross_entropy(sim, y) + F.cross_entropy(sim.t(), y))

In [168]:
def train_contrastive(model: nn.Module,loader, *,
                      epochs=20, lr=3e-4, tau=0.1, device=None):
    device = device or (torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu"))
    model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    for ep in range(1, epochs+1):
        model.train()
        total = 0.0
        for (ids1, m1), (ids2, m2) in loader:
            optim.zero_grad(set_to_none=True)
            ids1, m1 = ids1.to(device), m1.to(device)
            ids2, m2 = ids2.to(device), m2.to(device)

            z1 = model.encode(ids1)              # [B,d]
            z2 = model.encode(ids2)              # [B,d]
            loss = nt_xent(z1, z2, tau=tau)
            # loss = info_nce_two_way(z1, z2, tau=tau)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            g = sum((p.grad is not None and p.grad.abs().sum().item()) for p in model.parameters())
            assert g > 0, "no gradients flowed"
            optim.step()
            total += loss.item() * ids1.size(0)

        avg = total / len(loader.dataset)
        with torch.no_grad():
            # simple health checks
            z = torch.cat([z1, z2], 0)
            std = z.float().std(dim=0).mean().item()

            # intra-branch cosine structure (should be diag >> offdiag)
            sim_intra = (z1 @ z1.t()).float()
            B = sim_intra.size(0)
            diag = sim_intra.diag().mean().item()
            offd = (sim_intra.sum() - sim_intra.diag().sum()) / max(1, (B*B - B))
            # cross-branch diag/offdiag for the actual loss logits
            sim_cross = (z1 @ z2.t()).float()
            diag_x = sim_cross.diag().mean().item()
            offd_x = (sim_cross.sum() - sim_cross.diag().sum()) / max(1, (B*B - B))
        print(f"[epoch {ep:03d}] loss={avg:.4f} | z-std={std:.3f} |"
              f"intra diag/off={diag:.3f}/{offd:.3f} | cross diag/off={diag_x:.3f}/{offd_x:.3f}")



In [None]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
import numpy as np
from src.encoder_classifier_wrapper import EncoderClassifier, LinearFrontend

In [15]:
from src.synthetic_data import synth_trips

In [160]:
X, road, wthr, t = synth_trips(N=2000, T=128, use_accel=False, seed=7)

In [161]:
X.shape, road.shape, wthr.shape, t.shape

(torch.Size([2000, 128, 1]),
 torch.Size([2000]),
 torch.Size([2000]),
 torch.Size([128]))

In [85]:
from src.mha_block import SDPAMHA, PreLNEncoderBlockSDPA, ClippedRelPosBias, SinusoidalPositionalEncoding

In [86]:
pos = SinusoidalPositionalEncoding(d_model=64)

In [99]:
# Build attention block
def make_block(d_model):
    attn = SDPAMHA(d_model=d_model, num_heads=4, p_drop=0.1, use_rope=False, rel_bias=None)
    return PreLNEncoderBlockSDPA(d_model, attn=attn, ff_mult=4, p_drop=0.1, norm="ln", resid_mode="plain")

In [162]:
frontend = LinearFrontend(in_channels=1, d_model=64)  # C=1 (speed) or 2 (speed+accel)
model = EncoderClassifier(
    d_model=64, num_layers=2, block_ctor=make_block,
    pool="mean", posenc = pos,
    final_norm="ln", final_norm_pos="post_pool", proj_dim=64,
    frontend=frontend
)

In [163]:
view = TSView()
ds = TwoCropTSDataset(X, M=None, view=view, crop_len=64)
loader = torch.utils.data.DataLoader(ds, batch_size=256, shuffle=True, drop_last=True)
ds.__getitem__(0)[0][0].shape

torch.Size([64, 1])

In [169]:
train_contrastive(model=model, loader=loader, tau=0.2)

[epoch 001] loss=4.1874 | z-std=0.119 |intra diag/off=1.000/0.023 | cross diag/off=0.790/0.024
[epoch 002] loss=4.1895 | z-std=0.118 |intra diag/off=1.000/0.058 | cross diag/off=0.789/0.054
[epoch 003] loss=4.1477 | z-std=0.119 |intra diag/off=1.000/0.034 | cross diag/off=0.798/0.027
[epoch 004] loss=4.1278 | z-std=0.119 |intra diag/off=1.000/0.025 | cross diag/off=0.795/0.024
[epoch 005] loss=4.1785 | z-std=0.119 |intra diag/off=1.000/0.017 | cross diag/off=0.790/0.020
[epoch 006] loss=4.1449 | z-std=0.120 |intra diag/off=1.000/0.008 | cross diag/off=0.793/0.007
[epoch 007] loss=4.2068 | z-std=0.120 |intra diag/off=1.000/0.017 | cross diag/off=0.779/0.017
[epoch 008] loss=4.1541 | z-std=0.118 |intra diag/off=1.000/0.052 | cross diag/off=0.767/0.044
[epoch 009] loss=4.1274 | z-std=0.120 |intra diag/off=1.000/0.017 | cross diag/off=0.774/0.016
[epoch 010] loss=4.1487 | z-std=0.118 |intra diag/off=1.000/0.052 | cross diag/off=0.772/0.054
[epoch 011] loss=4.1197 | z-std=0.119 |intra diag/

In [170]:
model.train()
dev = torch.device("mps")
(x1, m1), (x2, m2), *_ = next(iter(loader))
x1, x2 = x1.to(dev), x2.to(dev)
z1 = model.encode(x1, mask=m1.to(dev))
z2 = model.encode(x2, mask=m2.to(dev))
assert z1.requires_grad and z2.requires_grad, "encode() is producing no-grad tensors"

In [171]:
print("Δ(x1,x2):", (x1 - x2).abs().mean().item())

Δ(x1,x2): 0.6974043846130371


In [172]:
# 1) frontend signal
h1f, _ = model.frontend(x1, m1)  # [B,T,D]
h2f, _ = model.frontend(x2, m2)
print("frontend std:", h1f.std().item())

frontend std: 0.7535977363586426


In [173]:
# 2) after first block (no masks for fixed crops)
h1 = model.layers[0](h1f)
h2 = model.layers[0](h2f)
print("after block std:", h1.std().item())

after block std: 0.9703962802886963


In [174]:
# 3) pooled before proj
def masked_mean(x, pad_mask):
    return x.mean(1) if pad_mask is None else (x * (~pad_mask).float().unsqueeze(-1)).sum(1) / ((~pad_mask).float().sum(1, keepdim=True).clamp_min(1.0))

h1p = masked_mean(h1, None)
h2p = masked_mean(h2, None)
print("pooled std:", h1p.std().item())

pooled std: 0.7788299322128296


In [175]:
# 4) projection (pre-norm)
z1p = model.proj(h1p)
z2p = model.proj(h2p)
print("proj pre-norm std:", z1p.std().item())

proj pre-norm std: 0.14979314804077148


In [176]:
# 5) final norm + cosine stats
z1 = F.normalize(z1p, dim=-1)
z2 = F.normalize(z2p, dim=-1)
sim = z1 @ z2.T
B = z1.size(0)
diag = sim.diag().mean().item()
offd = (sim.sum() - sim.diag().sum()) / (B*B - B)
print("cos diag mean:", diag, "offdiag mean:", offd.item())

cos diag mean: 0.9398491978645325 offdiag mean: 0.5329511165618896


In [177]:
# 1) Cosines BEFORE normalization
sim_raw = torch.nn.functional.cosine_similarity(
    z1p.unsqueeze(1), z1p.unsqueeze(0), dim=-1
)  # [B,B]
B = sim_raw.size(0)
print("raw cos diag:", sim_raw.diag().mean().item(),
      "raw cos offdiag:", ((sim_raw.sum()-sim_raw.diag().sum())/(B*B-B)).item())

raw cos diag: 1.0 raw cos offdiag: 0.5339733362197876


In [178]:
# 2) Cosines AFTER normalization you use in loss
sim_norm = (z1 @ z1.t())  # if you normalize correctly, offdiag should drop vs diag
print("norm cos diag:", sim_norm.diag().mean().item(),
      "norm cos offdiag:", ((sim_norm.sum()-sim_norm.diag().sum())/(B*B-B)).item())

norm cos diag: 1.0 norm cos offdiag: 0.5339733362197876


In [179]:
# grads flow?
loss = nt_xent(z1, z2, tau=0.2)
loss.backward()
print(sum(int(p.grad is not None) for p in model.parameters()), "params have grads")

32 params have grads


In [180]:
# batch similarity structure inside each branch (should NOT be ~all ones)
sim_intra = (z1 @ z1.t()).detach()
print("intra z1 diag:", sim_intra.diag().mean().item(),
      "offdiag:", ((sim_intra.sum()-sim_intra.diag().sum())/(B*B-B)).item())

intra z1 diag: 1.0 offdiag: 0.5339733362197876


In [181]:
with torch.no_grad():
    B = z1.size(0)
    s = (z1 @ z2.t())
    print("cos diag mean:", s.diag().mean().item(), "offdiag mean:", (s.sum() - s.diag().sum())/ (B*B-B))

cos diag mean: 0.9398491978645325 offdiag mean: tensor(0.5330, device='mps:0')


In [189]:
class ProjectionHead(nn.Module):
    """
    z = head(h) for contrastive training.
    SimCLR-style: MLP -> (norm) -> nonlinearity -> Linear
    """
    def __init__(self, in_dim: int, hid: int = 256, out_dim: int = 128, use_ln: bool = True):
        super().__init__()
        layers = [nn.Linear(in_dim, hid)]
        if use_ln: layers += [nn.LayerNorm(hid)]
        layers += [nn.ReLU(), nn.Linear(hid, out_dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, h):           # h: [B, D]
        z = self.net(h)             # [B, out_dim]
        z = F.normalize(z, dim=-1)  # L2-normalize for cosine/InfoNCE
        return z

In [183]:
# Projector with BatchNorm
class SimCLRProjector(nn.Module):
    def __init__(self, d, p):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, d, bias=False), nn.BatchNorm1d(d), nn.ReLU(inplace=True),
            nn.Linear(d, p, bias=False), nn.BatchNorm1d(p, affine=False)
        )
    def forward(self, x): return self.net(x)

In [190]:
projector = ProjectionHead(in_dim=64, out_dim=64)
model2 = EncoderClassifier(
    d_model=64, num_layers=2, block_ctor=make_block,
    pool="mean", posenc = pos,
    final_norm="ln", final_norm_pos="post_pool", proj_dim=64,
    frontend=frontend,
    projector=projector
)

In [191]:
train_contrastive(model=model2, loader=loader, tau=0.2)

[epoch 001] loss=4.8491 | z-std=0.094 |intra diag/off=1.000/0.322 | cross diag/off=0.958/0.321
[epoch 002] loss=4.6115 | z-std=0.109 |intra diag/off=1.000/0.097 | cross diag/off=0.910/0.095
[epoch 003] loss=4.5561 | z-std=0.106 |intra diag/off=1.000/0.123 | cross diag/off=0.917/0.131
[epoch 004] loss=4.4737 | z-std=0.111 |intra diag/off=1.000/0.058 | cross diag/off=0.910/0.058
[epoch 005] loss=4.4256 | z-std=0.114 |intra diag/off=1.000/0.054 | cross diag/off=0.857/0.056
[epoch 006] loss=4.3676 | z-std=0.114 |intra diag/off=1.000/0.059 | cross diag/off=0.875/0.057
[epoch 007] loss=4.3644 | z-std=0.115 |intra diag/off=1.000/0.042 | cross diag/off=0.844/0.041
[epoch 008] loss=4.3317 | z-std=0.116 |intra diag/off=1.000/0.028 | cross diag/off=0.865/0.029
[epoch 009] loss=4.3087 | z-std=0.114 |intra diag/off=1.000/0.046 | cross diag/off=0.874/0.051
[epoch 010] loss=4.3310 | z-std=0.117 |intra diag/off=1.000/0.019 | cross diag/off=0.848/0.021
[epoch 011] loss=4.3042 | z-std=0.116 |intra diag/

In [187]:
projector_bn = SimCLRProjector(d=64, p=64)
model3 = EncoderClassifier(
    d_model=64, num_layers=2, block_ctor=make_block,
    pool="mean", posenc = pos,
    final_norm="ln", final_norm_pos="post_pool", proj_dim=64,
    frontend=frontend,
    projector=projector_bn
)

In [188]:
train_contrastive(model=model3, loader=loader, tau=0.2)

[epoch 001] loss=4.4894 | z-std=0.124 |intra diag/off=1.000/0.010 | cross diag/off=0.822/0.011
[epoch 002] loss=4.3682 | z-std=0.124 |intra diag/off=1.000/0.008 | cross diag/off=0.778/0.010
[epoch 003] loss=4.3462 | z-std=0.124 |intra diag/off=1.000/0.011 | cross diag/off=0.802/0.011
[epoch 004] loss=4.3184 | z-std=0.124 |intra diag/off=1.000/0.007 | cross diag/off=0.789/0.008
[epoch 005] loss=4.2859 | z-std=0.124 |intra diag/off=1.000/0.007 | cross diag/off=0.789/0.008
[epoch 006] loss=4.2615 | z-std=0.124 |intra diag/off=1.000/0.008 | cross diag/off=0.799/0.009
[epoch 007] loss=4.2587 | z-std=0.124 |intra diag/off=1.000/0.005 | cross diag/off=0.793/0.006
[epoch 008] loss=4.2579 | z-std=0.124 |intra diag/off=1.000/0.005 | cross diag/off=0.764/0.005
[epoch 009] loss=4.2714 | z-std=0.124 |intra diag/off=1.000/0.001 | cross diag/off=0.767/0.003
[epoch 010] loss=4.2191 | z-std=0.124 |intra diag/off=1.000/0.002 | cross diag/off=0.756/0.003
[epoch 011] loss=4.2061 | z-std=0.124 |intra diag/

In [192]:
train_contrastive(model=model3, loader=loader, tau=0.1)

[epoch 001] loss=4.1835 | z-std=0.125 |intra diag/off=1.000/0.001 | cross diag/off=0.828/0.001
[epoch 002] loss=4.2480 | z-std=0.125 |intra diag/off=1.000/0.001 | cross diag/off=0.822/0.001
[epoch 003] loss=4.2483 | z-std=0.124 |intra diag/off=1.000/0.005 | cross diag/off=0.819/0.007
[epoch 004] loss=4.1947 | z-std=0.124 |intra diag/off=1.000/0.008 | cross diag/off=0.845/0.007
[epoch 005] loss=4.2182 | z-std=0.124 |intra diag/off=1.000/0.008 | cross diag/off=0.821/0.009
[epoch 006] loss=4.1470 | z-std=0.124 |intra diag/off=1.000/0.011 | cross diag/off=0.837/0.011
[epoch 007] loss=4.1523 | z-std=0.124 |intra diag/off=1.000/0.011 | cross diag/off=0.829/0.012
[epoch 008] loss=4.1179 | z-std=0.124 |intra diag/off=1.000/0.007 | cross diag/off=0.838/0.009
[epoch 009] loss=4.1939 | z-std=0.124 |intra diag/off=1.000/0.012 | cross diag/off=0.835/0.014
[epoch 010] loss=4.1107 | z-std=0.124 |intra diag/off=1.000/0.012 | cross diag/off=0.863/0.014
[epoch 011] loss=4.1881 | z-std=0.123 |intra diag/

z-std ≈ 0.124 → no collapse.

intra off ≈ 0.015 and cross off ≈ 0.015 → negatives are well spread (uniformity 👍).

cross diag ≈ 0.849 → positives are meaningfully closer than negatives (alignment 👍).

## Freeze & probe (road / weather)

In [None]:
model3.features

<bound method EncoderClassifier.features of EncoderClassifier(
  (frontend): LinearFrontend(
    (proj): Linear(in_features=1, out_features=64, bias=True)
  )
  (posenc): SinusoidalPositionalEncoding()
  (layers): ModuleList(
    (0-1): 2 x PreLNEncoderBlockSDPA(
      (attn): SDPAMHA(
        (qkv): Linear(in_features=64, out_features=192, bias=True)
        (o): Linear(in_features=64, out_features=64, bias=True)
      )
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (drop1): Dropout(p=0.1, inplace=False)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ff): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
      )
      (drop2): Dropout(p=0.1, inplace=False)
    )
  )
  (final_ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=64, out_features=1, bias=True)
  (proj): SimCL

In [196]:
model3.eval()
enc = model3.features  # or a wrapper that returns pooled pre-projection h
with torch.no_grad():
    Z = enc(X.to(dev))   # [N,D]
# simple linear probes (logreg) for road (3) & weather (2)
clf_road = torch.nn.Linear(Z.size(1), 3).to(dev)
clf_wthr = torch.nn.Linear(Z.size(1), 2).to(dev)


In [197]:
Z.shape

torch.Size([2000, 64])

In [198]:
def split_idx(N, val=0.2, seed=0): # train test split
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(N, generator=g)
    n_val = int(N * val)
    return perm[n_val:], perm[:n_val]   # train_idx, val_idx

def train_linear_probe(Z, y, clf, *, epochs=200, lr=1e-2, wd=0.0, device=None):
    device = device or torch.device("cpu")
    Z = Z.to(device)
    y = y.to(device).long()
    W = clf.to(device)
    opt = torch.optim.AdamW(W.parameters(), lr=lr, weight_decay=wd)

    for _ in range(epochs):
        opt.zero_grad(set_to_none=True)
        loss = F.cross_entropy(W(Z), y)
        loss.backward()
        opt.step()
    return W

@torch.no_grad()
def accuracy(W, Z, y, device=None):
    device = device or torch.device("cpu")
    logits = W(Z.to(device))
    pred = logits.argmax(dim=-1).cpu()
    return (pred == y.cpu().long()).float().mean().item()

In [199]:
device = next(model3.parameters()).device
device

device(type='mps', index=0)

In [200]:
# Train/val split
train_idx, val_idx = split_idx(Z.size(0), val=0.2, seed=42)
Ztr, Zva = Z[train_idx], Z[val_idx]
road_tr, road_va = road[train_idx], road[val_idx]
wthr_tr, wthr_va = wthr[train_idx], wthr[val_idx]

In [201]:
# Train probes
W_road = train_linear_probe(Ztr, road_tr, clf_road, epochs=200, lr=1e-2, wd=1e-4, device=device)

In [202]:
W_wthr = train_linear_probe(Ztr, wthr_tr, clf_wthr, epochs=200, lr=1e-2, wd=1e-4, device=device)


In [203]:
# Evaluate
acc_road = accuracy(W_road, Zva, road_va, device=device)
acc_wthr = accuracy(W_wthr, Zva, wthr_va, device=device)
print(f"Linear probe – road: {acc_road:.3f}, weather: {acc_wthr:.3f}")

Linear probe – road: 0.890, weather: 0.600
