# Training Pipeline Experiment

This notebook runs training experiments: a baseline ViT, hyperparameter sweeps, and performance logging.

In [None]:
# Imports and setup
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
from src.models.vit_model import ViT

In [None]:
# Data transforms and loaders
# Both splits live under ../data/<split> in ImageFolder layout.
train_dir, test_dir = (os.path.join('..', 'data', split) for split in ('train', 'test'))

# Shared preprocessing for train and test (no augmentation): resize to
# 460x460, convert to a tensor, then normalize per channel with these
# mean/std constants.
transform = transforms.Compose(
    [
        transforms.Resize((460, 460)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Train is built first, then test.
train_ds, test_ds = (
    datasets.ImageFolder(root_dir, transform=transform)
    for root_dir in (train_dir, test_dir)
)

# Only the training loader shuffles; evaluation order is fixed.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

In [None]:
# Model, loss, optimizer, device
# Run on CUDA when it is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# NOTE(review): 460 is not a multiple of patch_size=8 (460 / 8 = 57.5); whether
# src.models.vit_model.ViT pads, crops, or rejects this is not visible here — confirm.
# NOTE(review): no class count is passed to ViT; confirm its default output size
# matches the number of ImageFolder classes in the dataset.
model = ViT(img_size=460, patch_size=8)
model = model.to(device)

# Classification loss and Adam optimizer over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# TensorBoard event files go to ../logs.
writer = SummaryWriter(log_dir=os.path.join('..', 'logs'))

In [None]:
# Training loop
# Train for `num_epochs` epochs. After each epoch, evaluate accuracy on the
# test set, log train loss / test accuracy to TensorBoard, and save the
# state_dict whenever test accuracy improves on the best seen so far.
num_epochs = 5
checkpoint_dir = os.path.join('..', 'models')
checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
# torch.save does not create missing parent directories; make sure it exists
# before the first checkpoint instead of failing after a full epoch of work.
os.makedirs(checkpoint_dir, exist_ok=True)

best_acc = 0.0
try:
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        seen = 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # The loss is a per-batch mean (CrossEntropyLoss default); weight it
            # by batch size so a short final batch doesn't skew the epoch mean.
            running_loss += loss.item() * labels.size(0)
            seen += labels.size(0)
        # Guard against an empty training loader.
        avg_loss = running_loss / seen if seen else 0.0

        # Evaluate
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                preds = model(imgs).argmax(dim=1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()
        # Guard against an empty test loader (would otherwise divide by zero).
        acc = correct / total if total else 0.0

        writer.add_scalar('Loss/train', avg_loss, epoch)
        writer.add_scalar('Accuracy/test', acc, epoch)
        print(f'Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={acc:.4f}')
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), checkpoint_path)
finally:
    # Flush and close the TensorBoard writer even if training raises or is
    # interrupted, so already-logged scalars are not lost.
    writer.close()