# In [None]:
import os
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import timm


# -------------------------
# Config
# -------------------------
@dataclass
class CFG:
    """Hyper-parameters and runtime settings for the UNI patch-classifier run."""

    # First CUDA device if one is visible to torch, otherwise the CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Side length (pixels) that every patch is resized to.
    img_size: int = 224
    batch_size: int = 64
    # Number of DataLoader worker processes.
    num_workers: int = 4
    # Learning rate for the optimizer.
    lr: float = 0.001
    epochs: int = 10
    # Number of target classes -- set this to match your label space.
    num_classes: int = 2
    # Whether to request automatic mixed precision.
    amp: bool = True


# -------------------------
# Dataset (replace with your loader)
# -------------------------
class PatchDataset(Dataset):
    """Map-style dataset over image patches stored on disk.

    Args:
        samples: sequence of ``(img_path, label)`` pairs; ``label`` is converted
            to a ``torch.long`` tensor.
        transform: optional callable applied to the RGB ``PIL.Image`` before it
            is returned; when ``None`` the PIL image is returned unchanged.
    """

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label_tensor)``."""
        path, y = self.samples[idx]
        # Use a context manager so the underlying file handle is closed as soon
        # as the pixels are decoded, rather than lingering until garbage
        # collection (many workers x many files can exhaust file descriptors).
        # ``convert`` returns a new, fully loaded image, so it outlives the ``with``.
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.tensor(y, dtype=torch.long)


def build_transforms(img_size: int):
    """Return the preprocessing pipeline applied to every UNI input patch.

    Each patch is resized to a square ``img_size`` x ``img_size``, converted to a
    tensor with ``ToTensor`` and normalised with the ImageNet per-channel
    mean/std (RGB order).
    """
    # ImageNet channel statistics; UNI is generally fed patches normalised this way.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)


# -------------------------
# Model
# -------------------------
class UNIClassifier(nn.Module):
    """Classification head on top of the UNI (MahmoodLab) ViT backbone.

    The backbone is loaded from the Hugging Face Hub through ``timm`` with its
    classifier removed (``num_classes=0``), so it returns one feature vector per
    image. A LayerNorm + Linear head maps that vector to ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        freeze_backbone: when True (the default) backbone parameters get
            ``requires_grad=False`` and the backbone is kept in eval mode even
            when ``model.train()`` is called, so only the head is trained.

    NOTE: loading requires a Hugging Face login and accepting the UNI model terms.
    """

    def __init__(self, num_classes: int, freeze_backbone: bool = True):
        super().__init__()
        # Set first so ``train()`` can consult it no matter when it is called.
        self.freeze_backbone = freeze_backbone

        # Per the UNI model card, ``init_values`` must be passed so timm builds
        # the LayerScale modules whose weights (e.g. blocks.0.ls1.gamma) are in
        # the checkpoint; ``dynamic_img_size`` mirrors the same documented snippet.
        self.backbone = timm.create_model(
            "hf-hub:MahmoodLab/uni",
            pretrained=True,
            init_values=1e-5,
            dynamic_img_size=True,
            num_classes=0,  # feature extractor (no classifier head)
        )

        # Embedding dim: timm backbones usually expose ``num_features``; otherwise
        # probe with a dummy forward pass in eval mode.
        feat_dim = getattr(self.backbone, "num_features", None)
        if feat_dim is None:
            was_training = self.backbone.training
            self.backbone.eval()
            with torch.no_grad():
                dummy = torch.zeros(1, 3, 224, 224)
                feat_dim = self.backbone(dummy).shape[-1]
            self.backbone.train(was_training)

        # Small head (edit as needed).
        self.head = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, num_classes),
        )

        if freeze_backbone:
            self.backbone.requires_grad_(False)
            self.backbone.eval()

    def train(self, mode: bool = True):
        """Set train/eval mode; a frozen backbone always stays in eval mode.

        Returns ``self`` like ``nn.Module.train``.
        """
        super().train(mode)
        if getattr(self, "freeze_backbone", False):
            # Frozen weights should not see train-time behaviour (dropout,
            # stochastic depth); only the head switches to training mode.
            self.backbone.eval()
        return self

    def forward(self, x):
        """Return class logits ``[B, num_classes]`` for an image batch ``x``."""
        feats = self.backbone(x)  # [B, D] pooled embedding (num_classes=0)
        return self.head(feats)   # [B, C]


# -------------------------
# Train / Eval loops
# -------------------------
def train_one_epoch(model, loader, optimizer, scaler, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: module mapping an input batch to logits ``[B, C]``.
        loader: iterable of ``(x, y)`` batches with integer class targets ``y``.
        optimizer: optimizer over the trainable parameters.
        scaler: a ``GradScaler`` to train with autocast + loss scaling, or
            ``None`` for plain full-precision training.
        device: target device, as a string (``"cpu"``, ``"cuda:0"``) or a
            ``torch.device``.

    Returns:
        ``(mean_loss, accuracy)`` averaged per sample, or ``(nan, nan)`` when
        ``loader`` yields no samples.
    """
    model.train()
    ce = nn.CrossEntropyLoss()
    # torch.device() accepts both strings and torch.device objects and strips
    # any index, giving the bare type autocast expects ("cuda", "cpu", ...).
    device_type = torch.device(device).type
    use_amp = scaler is not None

    total_loss, correct, n = 0.0, 0, 0
    for x, y in loader:
        # non_blocking only has an effect with pinned host memory; harmless otherwise.
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device_type, enabled=use_amp):
            logits = model(x)
            loss = ce(logits, y)

        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        batch = x.size(0)
        total_loss += loss.item() * batch  # CE is a batch mean -> re-weight by size
        correct += (logits.argmax(dim=1) == y).sum().item()
        n += batch

    if n == 0:
        return float("nan"), float("nan")
    return total_loss / n, correct / n


@torch.no_grad()
def eval_one_epoch(model, loader, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: module mapping an input batch to logits ``[B, C]``.
        loader: iterable of ``(x, y)`` batches with integer class targets ``y``.
        device: target device, as a string or a ``torch.device``.

    Returns:
        ``(mean_loss, accuracy)`` averaged per sample, or ``(nan, nan)`` when
        ``loader`` yields no samples.
    """
    model.eval()
    ce = nn.CrossEntropyLoss()

    total_loss, correct, n = 0.0, 0, 0
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        loss = ce(logits, y)

        batch = x.size(0)
        total_loss += loss.item() * batch  # CE is a batch mean -> re-weight by size
        correct += (logits.argmax(dim=1) == y).sum().item()
        n += batch

    if n == 0:
        return float("nan"), float("nan")
    return total_loss / n, correct / n


# -------------------------
# Main (skeleton)
# -------------------------
def main(train_samples, val_samples, cfg=None):
    """Train the UNI classifier head and save its weights.

    Args:
        train_samples: list of ``(img_path, label)`` pairs used for training.
        val_samples: list of ``(img_path, label)`` pairs evaluated after each epoch.
        cfg: optional ``CFG`` instance; a default ``CFG()`` is used when omitted.

    Side effects:
        Prints per-epoch metrics and writes ``uni_head.pt`` (head ``state_dict``)
        and ``uni_full.pt`` (full model ``state_dict``) to the working directory.
    """
    cfg = CFG() if cfg is None else cfg
    on_cuda = cfg.device.startswith("cuda")

    tfm = build_transforms(cfg.img_size)
    train_ds = PatchDataset(train_samples, transform=tfm)
    val_ds = PatchDataset(val_samples, transform=tfm)

    # Pinned host memory only speeds up host->GPU copies; requesting it without
    # an accelerator just produces a warning, so enable it only on CUDA.
    loader_kwargs = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        pin_memory=on_cuda,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    model = UNIClassifier(num_classes=cfg.num_classes).to(cfg.device)

    # Only parameters that still require grad (the head, when the backbone is frozen).
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=cfg.lr,
        weight_decay=1e-2,
    )

    # Loss scaling / autocast is only used on CUDA.
    scaler = torch.cuda.amp.GradScaler() if (cfg.amp and on_cuda) else None

    for epoch in range(cfg.epochs):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, scaler, cfg.device)
        va_loss, va_acc = eval_one_epoch(model, val_loader, cfg.device)
        print(f"epoch {epoch:02d} | train loss {tr_loss:.4f} acc {tr_acc:.4f} | val loss {va_loss:.4f} acc {va_acc:.4f}")

    # Save only the head if you want (typical for FL)
    torch.save(model.head.state_dict(), "uni_head.pt")
    # Or save whole model state
    torch.save(model.state_dict(), "uni_full.pt")


if __name__ == "__main__":
    # Placeholder split -- substitute your own data.
    # Each entry is an (image_path, integer_label) pair.
    demo_train = [
        ("path/to/img1.png", 0),
        ("path/to/img2.png", 1),
    ]
    demo_val = [
        ("path/to/img3.png", 0),
    ]
    main(demo_train, demo_val)
