In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from fpn_resnet import build_model  # Import the model builder from your existing file

In [2]:
# --- Optimisation hyperparameters ---
BATCH_SIZE = 128          # Samples per mini-batch handed to the data loaders.
EPOCHS = 100              # Number of passes over the training split.
LR = 1e-4                 # Learning rate.
WEIGHT_DECAY = 5e-5       # Weight decay (L2 penalty) coefficient.

# --- Run bookkeeping ---
SEED = 42                 # Seed applied to the random number generators.
VALIDATE_FREQ = 5         # Validation runs on epochs whose 0-based index is a multiple of this.
SAVE_DIR = "checkpoints"  # Directory that receives saved checkpoints.

In [3]:
# Seed every random number generator the pipeline may draw from so that a
# fresh "Restart & Run All" reproduces the same run.
random.seed(SEED)        # Python's built-in RNG; some third-party code draws from it.
np.random.seed(SEED)     # NumPy's legacy global RNG.
torch.manual_seed(SEED)  # PyTorch RNG.
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

# Request deterministic cuDNN kernels and disable the cuDNN autotuner, whose
# benchmark-based kernel selection can differ from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Per-channel mean / std applied by Normalize in both pipelines.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: padded random crop and random horizontal flip for
# augmentation, followed by tensor conversion and normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion and the same
# normalisation as used for training.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

In [None]:
# Two views over the same CIFAR-10 training files: one with the augmenting
# training transform, one with the plain test transform for validation.
cifar_train_kwargs = dict(root='./data', train=True)
train_dataset = torchvision.datasets.CIFAR10(
    download=True, transform=train_transform, **cifar_train_kwargs
)
val_dataset = torchvision.datasets.CIFAR10(
    download=False, transform=test_transform, **cifar_train_kwargs
)

# Contiguous 80/20 split by index: the first 80% of samples are used for
# training and the remaining 20% for validation, so the two subsets are disjoint.
num_samples = len(train_dataset)
split_at = int(0.8 * num_samples)
train_indices = list(range(split_at))
val_indices = list(range(split_at, num_samples))

train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
val_dataset = torch.utils.data.Subset(val_dataset, val_indices)


Files already downloaded and verified


In [6]:
# Settings shared by the training and validation loaders.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)

# Training batches are drawn in shuffled order; validation keeps dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [7]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader`` and report its metrics.

    Args:
        model: Module to train; put into training mode with ``model.train()``.
        train_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar
            loss tensor.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called
            once per batch.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the sum of per-batch loss values
        divided by ``len(train_loader)``, and the percentage (0-100) of samples
        whose highest-scoring output index equals the target.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training')
    # Count batches explicitly. tqdm only synchronises ``pbar.n`` when it
    # redraws the bar (which is rate-limited), so ``pbar.n`` can lag behind the
    # real batch count and would skew the running loss shown in the postfix.
    for batch_count, (inputs, targets) in enumerate(pbar, start=1):
        inputs, targets = inputs.to(device), targets.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Update progress bar with the running averages over batches seen so far
        pbar.set_postfix({
            'loss': running_loss / batch_count,
            'acc': 100. * correct / total
        })

    return running_loss / len(train_loader), 100. * correct / total


In [8]:
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Module to evaluate; put into evaluation mode with ``model.eval()``.
        val_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar
            loss tensor.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the sum of per-batch loss values
        divided by ``len(val_loader)``, and the percentage (0-100) of samples
        whose highest-scoring output index equals the target.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        # Count batches explicitly. tqdm only synchronises ``pbar.n`` when it
        # redraws the bar (which is rate-limited), so ``pbar.n`` can lag behind
        # the real batch count and would skew the running loss in the postfix.
        for batch_count, (inputs, targets) in enumerate(pbar, start=1):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Update progress bar with the running averages over batches seen so far
            pbar.set_postfix({
                'loss': running_loss / batch_count,
                'acc': 100. * correct / total
            })

    return running_loss / len(val_loader), 100. * correct / total


In [9]:
# Make sure the checkpoint directory is present (no error if it already exists).
os.makedirs(SAVE_DIR, exist_ok=True)

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Settings passed to build_model for the CIFAR-10 classifier.
config = dict(
    # NOTE(review): this was annotated "ResNet-16" although the value is 18;
    # presumably it selects a ResNet-18 backbone — confirm in fpn_resnet.
    backbone_depth=18,
    num_classes=10,        # CIFAR-10 has 10 classes
    fpn_out_channels=256,
    hidden_dim=256,
    dropout_prob=0.25,
    use_dropout=True,
    pnp_class=None,        # No plug-and-play module for this example
    neck_final_only=True,
)


Using device: cuda


In [10]:
# Instantiate the network from the config and place it on the selected device.
model = build_model(config).to(device)

In [None]:
# Loss function and optimizer.
# Pass the configured hyperparameters explicitly: when they are omitted Adam
# falls back to its own defaults (lr=1e-3, weight_decay=0), so the LR and
# WEIGHT_DECAY constants would never take effect.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)


In [None]:
# Training loop: train every epoch, validate periodically, and keep the
# checkpoint with the best validation accuracy seen so far.
best_val_acc = 0.0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    # Train for one epoch
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")

    # Validate on the regular schedule (0-based epochs 0, VALIDATE_FREQ, ...)
    # and additionally on the last epoch. Without the extra check the final
    # epochs are never evaluated whenever (EPOCHS - 1) is not a multiple of
    # VALIDATE_FREQ, so the finished model could never be checkpointed.
    is_scheduled = epoch % VALIDATE_FREQ == 0
    is_last_epoch = epoch == EPOCHS - 1
    if is_scheduled or is_last_epoch:
        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Print statistics
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # Save checkpoint if validation accuracy improved
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            checkpoint_path = os.path.join(SAVE_DIR, 'best_model.pth')
            torch.save({
                'epoch': epoch + 1,  # stored 1-based, matching the printed epoch number
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'config': config,
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")