In [2]:
!pip install transformers datasets evaluate timm --quiet

In [None]:
import os, random
import torch
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from io import BytesIO

### 1 - Fine-tuning

In [4]:
# 1) Fix seeds
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# 2) Config
config = {
    "dataset_name":     "cifar100",
    "num_labels":       100,
    "train_frac":       0.80,
    "val_frac":         0.10,
    "holdout_frac":     0.10,
    "batch_size_train": 128,
    "batch_size_val":   256,
    "max_epochs":       25,
    "patience":         5,
    "warmup_epochs":    3,
    "lr":               5e-5,
    "weight_decay":     0.05,
    "output_dir":       "./vit-cifar100-checkpoints",
}
os.makedirs(config["output_dir"], exist_ok=True)

# 3) Load & split
ds = load_dataset(config["dataset_name"])["train"]
# carve out hold‑out
split1 = ds.train_test_split(test_size=config["holdout_frac"], seed=seed)
tmp, holdout_ds = split1["train"], split1["test"]
# then split tmp -> train/val
val_prop = config["val_frac"]/(config["train_frac"] + config["val_frac"])
split2   = tmp.train_test_split(test_size=val_prop, seed=seed)
train_ds, val_ds = split2["train"], split2["test"]

# 4) Transforms
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
])
eval_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
])

# 5) Robust preprocess helper
def preprocess(example, train=True):
    img = example["img"]
    # handle raw bytes
    if isinstance(img, (bytes, bytearray)):
        img = Image.open(BytesIO(img))
    # handle list of images
    if isinstance(img, list):
        img = img[0]
        if isinstance(img, (bytes, bytearray)):
            img = Image.open(BytesIO(img))
    # handle numpy arrays
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # ensure RGB
    img = img.convert("RGB")
    # apply the correct transform
    example["pixel_values"] = (train and train_tf or eval_tf)(img)
    example["labels"] = example["fine_label"]
    return example

# 6) Map & format
train_ds   = train_ds.map(lambda x: preprocess(x, True),    remove_columns=["img"])
val_ds     = val_ds.map(lambda x:   preprocess(x, False),   remove_columns=["img"])
holdout_ds = holdout_ds.map(lambda x: preprocess(x, False), remove_columns=["img"])
for d in (train_ds, val_ds, holdout_ds):
    d.set_format("torch", columns=["pixel_values","labels"])

# 7) Dataloaders
train_loader   = DataLoader(train_ds,    batch_size=config["batch_size_train"],
                            shuffle=True,  num_workers=4)
val_loader     = DataLoader(val_ds,      batch_size=config["batch_size_val"],
                            shuffle=False, num_workers=4)
holdout_loader = DataLoader(holdout_ds,  batch_size=config["batch_size_val"],
                            shuffle=False, num_workers=4)

print(f"Train/Val/Holdout sizes: {len(train_ds)}/{len(val_ds)}/{len(holdout_ds)}")

Map:   0%|          | 0/40000 [00:00<?, ? examples/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Train/Val/Holdout sizes: 40000/5000/5000


In [5]:
import torch.nn as nn
from transformers import get_cosine_schedule_with_warmup
import timm
from torch.optim import AdamW

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1) Model init (swap in your DyT version here when ready)
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=config["num_labels"])
model.to(device)

# 2) Optimizer, loss, mixed‑precision scaler
optimizer = AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
criterion = nn.CrossEntropyLoss()
scaler    = torch.cuda.amp.GradScaler()

# 3) LR schedule: warmup → cosine decay
total_steps   = len(train_loader) * config["max_epochs"]
warmup_steps  = len(train_loader) * config["warmup_epochs"]
scheduler     = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# 4) Training + early stopping
best_val_acc, wait = 0.0, 0
for epoch in range(1, config["max_epochs"]+1):
    # — train
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        inputs, labels = batch["pixel_values"].to(device), batch["labels"].to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            logits = model(inputs)
            loss   = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        running_loss += loss.item()
    avg_train_loss = running_loss / len(train_loader)

    # — validate
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = batch["pixel_values"].to(device), batch["labels"].to(device)
            preds = model(inputs).argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total   += labels.size(0)
    val_acc = correct / total

    print(f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Acc: {val_acc:.4f}")

    # — early stop & checkpoint
    if val_acc > best_val_acc:
        best_val_acc, wait = val_acc, 0
        ckpt_path = os.path.join(config["output_dir"], "best_val_acc.pt")
        torch.save(model.state_dict(), ckpt_path)
        print(f" → new best, saved to {ckpt_path}")
    else:
        wait += 1
        if wait >= config["patience"]:
            print("Early stopping triggered.")
            break

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

  scaler    = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():


Epoch 01 | Train Loss: 3.2967 | Val Acc: 0.8372
 → new best, saved to ./vit-cifar100-checkpoints/best_val_acc.pt
Epoch 02 | Train Loss: 0.9911 | Val Acc: 0.8640
 → new best, saved to ./vit-cifar100-checkpoints/best_val_acc.pt
Epoch 03 | Train Loss: 0.6414 | Val Acc: 0.8716
 → new best, saved to ./vit-cifar100-checkpoints/best_val_acc.pt
Epoch 04 | Train Loss: 0.4235 | Val Acc: 0.8756
 → new best, saved to ./vit-cifar100-checkpoints/best_val_acc.pt
Epoch 05 | Train Loss: 0.2312 | Val Acc: 0.8788
 → new best, saved to ./vit-cifar100-checkpoints/best_val_acc.pt
Epoch 06 | Train Loss: 0.1091 | Val Acc: 0.8860
 → new best, saved to ./vit-cifar100-checkpoints/best_val_acc.pt
Epoch 07 | Train Loss: 0.0605 | Val Acc: 0.8862
 → new best, saved to ./vit-cifar100-checkpoints/best_val_acc.pt
Epoch 08 | Train Loss: 0.0465 | Val Acc: 0.8822
Epoch 09 | Train Loss: 0.0537 | Val Acc: 0.8794
Epoch 10 | Train Loss: 0.0636 | Val Acc: 0.8768
Epoch 11 | Train Loss: 0.0396 | Val Acc: 0.8780
Epoch 12 | Train 

# 2 - Interpretability