In [15]:

import os, math, numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from load_data import create_dataloaders, CFG
from timesformer_min import TimeSformerEncoder

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Same loaders; here we *use* labels
dl_train, dl_val, dl_test = create_dataloaders(CFG, batch_size=128, num_workers=4)
for loader in (dl_train, dl_val, dl_test):
    if hasattr(loader, "pin_memory"):
        loader.pin_memory = True

In [18]:
class ThresholdClassifier(nn.Module):
    def __init__(self, encoder: nn.Module, num_classes=8):
        super().__init__()
        self.encoder = encoder
        # infer feature dim
        with torch.no_grad():
            dummy = torch.zeros(2, 7, 1, 32, 64, device=next(encoder.parameters()).device)
            feat_dim = self.encoder(dummy).shape[-1]
        self.classifier = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, num_classes)
        )

    def forward(self, x):            # x: [B,T,1,32,64]
        h = self.encoder(x)          # [B,D]
        logits = self.classifier(h)  # [B,8]
        return logits

# Build encoder with same sizes you used in Step 2
encoder = TimeSformerEncoder(
    in_ch=1, embed_dim=384, depth=6, num_heads=6,
    mlp_ratio=4.0, drop=0.1, attn_drop=0.0,
    patch=(8,8), T=7, H=32, W=64
).to(device)

# Load pretrained SimCLR weights
ckpt = r"C:\Users\shrua\OneDrive\Desktop\threshold project\threshold\models\timesformer_ssl_50.pth"
state = torch.load(ckpt, map_location=device)
missing, unexpected = encoder.load_state_dict(state, strict=False), None
print("Loaded pretrain from:", ckpt, "\n(strict=False)")

model = ThresholdClassifier(encoder).to(device)


Loaded pretrain from: C:\Users\shrua\OneDrive\Desktop\threshold project\threshold\models\timesformer_ssl_50.pth 
(strict=False)


## metrics

In [23]:
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    total, correct = 0, 0
    # confusion counts for macro-F1
    num_classes = 8
    TP = torch.zeros(num_classes, dtype=torch.long, device=device)
    FP = torch.zeros(num_classes, dtype=torch.long, device=device)
    FN = torch.zeros(num_classes, dtype=torch.long, device=device)

    for batch in loader:
        x = batch["video"].to(device)           # [B,T,1,32,64]
        y = batch["label"].to(device)           # [B]
        logits = model(x)                       # [B,8]
        pred = logits.argmax(dim=1)             # [B]
        correct += (pred == y).sum().item()
        total += y.numel()

        # per-class stats
        for c in range(num_classes):
            TP[c] += ((pred == c) & (y == c)).sum()
            FP[c] += ((pred == c) & (y != c)).sum()
            FN[c] += ((pred != c) & (y == c)).sum()

    acc = correct / max(1, total)
    # macro-F1
    precision = TP.float() / torch.clamp(TP + FP, min=1)
    recall    = TP.float() / torch.clamp(TP + FN, min=1)
    f1_per_c  = 2 * precision * recall / torch.clamp(precision + recall, min=1e-8)
    macro_f1  = f1_per_c.mean().item()

    return acc, macro_f1, f1_per_c.detach()


## training

In [26]:
use_amp = (device == "cuda")

def train_step(model, batch, opt, scaler):
    model.train()
    x = batch["video"].to(device)
    y = batch["label"].to(device)

    opt.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type='cuda', enabled=use_amp):
        logits = model(x)
        loss = F.cross_entropy(logits, y)

    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
    return loss.item()


# --- optimizers & schedulers ---
# Start with linear probe (freeze encoder), then unfreeze
for p in model.encoder.parameters():
    p.requires_grad = False

opt = torch.optim.AdamW(model.classifier.parameters(), lr=5e-4, weight_decay=1e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=5)
scaler = torch.amp.GradScaler(enabled=use_amp)

# --- linear probe epochs ---
lin_epochs = 30
hist = {"train_loss": [], "val_acc": [], "val_f1": []}

for ep in range(1, lin_epochs+1):
    running = 0.0
    for batch in dl_train:
        running += train_step(model, batch, opt, scaler)
    sched.step()

    acc, mf1, _ = evaluate(model, dl_val)
    hist["train_loss"].append(running / max(1, len(dl_train)))
    hist["val_acc"].append(acc)
    hist["val_f1"].append(mf1)
    print(f"[Linear] epoch {ep:02d}  loss {hist['train_loss'][-1]:.4f}  val_acc {acc:.3f}  val_F1 {mf1:.3f}")

# --- unfreeze encoder: full fine‑tune ---
for p in model.encoder.parameters():
    p.requires_grad = True

opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=5e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=20)
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

best_val = -1.0
os.makedirs("checkpoints", exist_ok=True)

ft_epochs = 30
for ep in range(1, ft_epochs+1):
    running = 0.0
    for batch in dl_train:
        running += train_step(model, batch, opt, scaler)
    sched.step()

    acc, mf1, _ = evaluate(model, dl_val)
    tr_loss = running / max(1, len(dl_train))
    print(f"[Finetune] epoch {ep:02d}  loss {tr_loss:.4f}  val_acc {acc:.3f}  val_F1 {mf1:.3f}")

    if mf1 > best_val:
        best_val = mf1
        torch.save(model.state_dict(), r"C:\Users\shrua\OneDrive\Desktop\threshold project\threshold\models\timesformer_cls.pth")
        print("saved")


[Linear] epoch 01  loss 2.0824  val_acc 0.150  val_F1 0.055
[Linear] epoch 02  loss 2.0662  val_acc 0.237  val_F1 0.096
[Linear] epoch 03  loss 2.0631  val_acc 0.263  val_F1 0.146
[Linear] epoch 04  loss 2.0636  val_acc 0.237  val_F1 0.128
[Linear] epoch 05  loss 2.0612  val_acc 0.237  val_F1 0.112
[Linear] epoch 06  loss 2.0590  val_acc 0.237  val_F1 0.112
[Linear] epoch 07  loss 2.0607  val_acc 0.225  val_F1 0.090
[Linear] epoch 08  loss 2.0610  val_acc 0.225  val_F1 0.090
[Linear] epoch 09  loss 2.0610  val_acc 0.200  val_F1 0.095
[Linear] epoch 10  loss 2.0620  val_acc 0.125  val_F1 0.028
[Linear] epoch 11  loss 2.0647  val_acc 0.250  val_F1 0.100
[Linear] epoch 12  loss 2.0634  val_acc 0.200  val_F1 0.079
[Linear] epoch 13  loss 2.0628  val_acc 0.225  val_F1 0.093
[Linear] epoch 14  loss 2.0591  val_acc 0.100  val_F1 0.033
[Linear] epoch 15  loss 2.0573  val_acc 0.125  val_F1 0.062
[Linear] epoch 16  loss 2.0585  val_acc 0.125  val_F1 0.062
[Linear] epoch 17  loss 2.0585  val_acc 