In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import numpy as np
from fpn_resnet import build_model  # Import the model builder from your existing file

In [2]:
BATCH_SIZE = 128            # Batch size for training.
EPOCHS = 100                # Number of epochs to train for.
SEED = 42                   # Random seed for reproducibility.
LR = 0.0001                 # Learning rate.
WEIGHT_DECAY = 5e-5         # Weight decay (L2 penalty).
VALIDATE_FREQ = 5           # Frequency of validation.
SAVE_DIR = "checkpoints"    # Directory to save checkpoints.

In [3]:
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [5]:
# Load the full training dataset
full_train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform
)

# Split the training dataset: 80% training, 20% validation
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

# Create a validation dataset with test transforms
val_dataset.dataset.transform = test_transform

Files already downloaded and verified


In [6]:
# Create data loaders
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
    num_workers=4, pin_memory=True
)

val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, 
    num_workers=4, pin_memory=True
)

In [7]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training')
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        # Update progress bar
        pbar.set_postfix({
            'loss': running_loss / (pbar.n + 1),
            'acc': 100. * correct / total
        })

    return running_loss / len(train_loader), 100. * correct / total


In [8]:
def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            # Update progress bar
            pbar.set_postfix({
                'loss': running_loss / (pbar.n + 1),
                'acc': 100. * correct / total
            })

    return running_loss / len(val_loader), 100. * correct / total


In [9]:
# Create save directory if it doesn't exist
os.makedirs(SAVE_DIR, exist_ok=True)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create model configuration for CIFAR-10
config = {
    'backbone_depth': 18,  # ResNet-16
    'num_classes': 10,  # CIFAR-10 has 10 classes
    'fpn_out_channels': 256,
    'hidden_dim': 256,
    'dropout_prob': 0.25,
    'use_dropout': True,
    'pnp_class': None,  # No plug-and-play module for this example
    'neck_final_only': True
}


Using device: cuda


In [10]:
# Build the model
model = build_model(config)
model = model.to(device)

In [11]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [None]:
# Training loop
best_val_acc = 0.0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    
    # Train for one epoch
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")

    if epoch % VALIDATE_FREQ == 0:
        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        
        # Update learning rate
        scheduler.step()
        
        # Print statistics
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        
        # Save checkpoint if validation accuracy improved
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            checkpoint_path = os.path.join(SAVE_DIR, 'best_model.pth')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'config': config,
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")


Epoch 1/100


Training: 100%|██████████| 313/313 [00:14<00:00, 20.97it/s, loss=1.67, acc=38.7]


Train Loss: 1.6650 | Train Acc: 38.69%


Validation: 100%|██████████| 79/79 [00:00<00:00, 88.46it/s, loss=1.55, acc=49.1] 


Val Loss: 1.4090 | Val Acc: 49.14%
Checkpoint saved to checkpoints/best_model.pth

Epoch 2/100


Training: 100%|██████████| 313/313 [00:14<00:00, 22.16it/s, loss=1.26, acc=55.5]


Train Loss: 1.2503 | Train Acc: 55.45%

Epoch 3/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.46it/s, loss=1.03, acc=64.1]


Train Loss: 1.0258 | Train Acc: 64.09%

Epoch 4/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.02it/s, loss=0.819, acc=71.7]


Train Loss: 0.8142 | Train Acc: 71.71%

Epoch 5/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.35it/s, loss=0.621, acc=78.8]


Train Loss: 0.6171 | Train Acc: 78.76%

Epoch 6/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.24it/s, loss=0.45, acc=84.6] 


Train Loss: 0.4476 | Train Acc: 84.60%


Validation: 100%|██████████| 79/79 [00:00<00:00, 86.91it/s, loss=1.57, acc=59.5]


Val Loss: 1.3691 | Val Acc: 59.45%
Checkpoint saved to checkpoints/best_model.pth

Epoch 7/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.45it/s, loss=0.325, acc=88.9]


Train Loss: 0.3229 | Train Acc: 88.93%

Epoch 8/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.26it/s, loss=0.238, acc=91.9]


Train Loss: 0.2368 | Train Acc: 91.91%

Epoch 9/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.34it/s, loss=0.19, acc=93.5] 


Train Loss: 0.1885 | Train Acc: 93.45%

Epoch 10/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.33it/s, loss=0.153, acc=94.9]


Train Loss: 0.1521 | Train Acc: 94.87%

Epoch 11/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.33it/s, loss=0.129, acc=95.7]


Train Loss: 0.1285 | Train Acc: 95.66%


Validation: 100%|██████████| 79/79 [00:00<00:00, 86.81it/s, loss=2.11, acc=59.9] 


Val Loss: 1.8973 | Val Acc: 59.94%
Checkpoint saved to checkpoints/best_model.pth

Epoch 12/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.56it/s, loss=0.116, acc=96]   


Train Loss: 0.1155 | Train Acc: 95.99%

Epoch 13/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.36it/s, loss=0.108, acc=96.4]


Train Loss: 0.1071 | Train Acc: 96.36%

Epoch 14/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.39it/s, loss=0.107, acc=96.3] 


Train Loss: 0.1065 | Train Acc: 96.28%

Epoch 15/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.42it/s, loss=0.0932, acc=96.9]


Train Loss: 0.0926 | Train Acc: 96.86%

Epoch 16/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.34it/s, loss=0.0847, acc=97.1]


Train Loss: 0.0841 | Train Acc: 97.11%


Validation: 100%|██████████| 79/79 [00:00<00:00, 85.95it/s, loss=2.42, acc=60.2] 


Val Loss: 2.1478 | Val Acc: 60.24%
Checkpoint saved to checkpoints/best_model.pth

Epoch 17/100


Training: 100%|██████████| 313/313 [00:14<00:00, 22.13it/s, loss=0.0793, acc=97.4]


Train Loss: 0.0788 | Train Acc: 97.38%

Epoch 18/100


Training: 100%|██████████| 313/313 [00:14<00:00, 22.10it/s, loss=0.0861, acc=97.1]


Train Loss: 0.0855 | Train Acc: 97.11%

Epoch 19/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.60it/s, loss=0.0767, acc=97.4]


Train Loss: 0.0763 | Train Acc: 97.42%

Epoch 20/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.41it/s, loss=0.067, acc=97.8] 


Train Loss: 0.0666 | Train Acc: 97.77%

Epoch 21/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.41it/s, loss=0.0763, acc=97.4]


Train Loss: 0.0759 | Train Acc: 97.38%


Validation: 100%|██████████| 79/79 [00:00<00:00, 83.96it/s, loss=2.35, acc=60.1] 


Val Loss: 2.1097 | Val Acc: 60.08%

Epoch 22/100


Training: 100%|██████████| 313/313 [00:14<00:00, 22.12it/s, loss=0.0738, acc=97.5]


Train Loss: 0.0733 | Train Acc: 97.47%

Epoch 23/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.42it/s, loss=0.0734, acc=97.5]


Train Loss: 0.0729 | Train Acc: 97.49%

Epoch 24/100


Training: 100%|██████████| 313/313 [00:14<00:00, 21.38it/s, loss=0.0652, acc=97.8]


Train Loss: 0.0648 | Train Acc: 97.83%

Epoch 25/100


Training:  26%|██▌       | 82/313 [00:03<00:10, 21.73it/s, loss=0.0535, acc=98.2]