In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import numpy as np
from fpn_resnet import build_model  # Import the model builder from your existing file

In [None]:
BATCH_SIZE = 128            # Batch size for training.
EPOCHS = 100                # Number of epochs to train for.
SEED = 42                   # Random seed for reproducibility.
LR = 0.0001                 # Learning rate.
WEIGHT_DECAY = 5e-5         # Weight decay (L2 penalty).
VALIDATE_FREQ = 5           # Frequency of validation.
SAVE_DIR = "checkpoints"    # Directory to save checkpoints.

In [3]:
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [5]:
# Load the full training dataset
full_train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform
)

# Split the training dataset: 80% training, 20% validation
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

# Create a validation dataset with test transforms
val_dataset.dataset.transform = test_transform

Files already downloaded and verified


In [6]:
# Create data loaders
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
    num_workers=4, pin_memory=True
)

val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, 
    num_workers=4, pin_memory=True
)

In [7]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training')
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        # Update progress bar
        pbar.set_postfix({
            'loss': running_loss / (pbar.n + 1),
            'acc': 100. * correct / total
        })

    return running_loss / len(train_loader), 100. * correct / total


In [8]:
def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            # Update progress bar
            pbar.set_postfix({
                'loss': running_loss / (pbar.n + 1),
                'acc': 100. * correct / total
            })

    return running_loss / len(val_loader), 100. * correct / total


In [9]:
# Create save directory if it doesn't exist
os.makedirs(SAVE_DIR, exist_ok=True)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create model configuration for CIFAR-10
config = {
    'backbone_depth': 50,  # ResNet-50
    'num_classes': 10,  # CIFAR-10 has 10 classes
    'fpn_out_channels': 256,
    'hidden_dim': 256,
    'dropout_prob': 0.25,
    'use_dropout': True,
    'pnp_class': None,  # No plug-and-play module for this example
    'neck_final_only': True
}


Using device: cuda


In [10]:
# Build the model
model = build_model(config)
model = model.to(device)

In [11]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [12]:
# Training loop
best_val_acc = 0.0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    
    # Train for one epoch
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")

    if epoch % VALIDATE_FREQ == 0:
        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        
        # Update learning rate
        scheduler.step()
        
        # Print statistics
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        
        # Save checkpoint if validation accuracy improved
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            checkpoint_path = os.path.join(SAVE_DIR, 'best_model.pth')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'config': config,
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")


Epoch 1/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.78it/s, loss=2.05, acc=21.7]


Train Loss: 2.0545 | Train Acc: 21.68%


Validation: 100%|██████████| 79/79 [00:01<00:00, 45.81it/s, loss=1.81, acc=31.4]


Val Loss: 1.8119 | Val Acc: 31.45%
Checkpoint saved to checkpoints/best_model.pth

Epoch 2/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.75it/s, loss=1.7, acc=36.3] 


Train Loss: 1.6966 | Train Acc: 36.26%

Epoch 3/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.64it/s, loss=1.52, acc=44.1]


Train Loss: 1.5191 | Train Acc: 44.14%

Epoch 4/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.63it/s, loss=1.38, acc=49.8]


Train Loss: 1.3827 | Train Acc: 49.83%

Epoch 5/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.57it/s, loss=1.26, acc=55.1]


Train Loss: 1.2572 | Train Acc: 55.08%

Epoch 6/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.64it/s, loss=1.15, acc=59.1]


Train Loss: 1.1455 | Train Acc: 59.13%


Validation: 100%|██████████| 79/79 [00:01<00:00, 45.17it/s, loss=1.39, acc=53.5]


Val Loss: 1.3002 | Val Acc: 53.55%
Checkpoint saved to checkpoints/best_model.pth

Epoch 7/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.64it/s, loss=1.04, acc=63.3] 


Train Loss: 1.0371 | Train Acc: 63.28%

Epoch 8/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.64it/s, loss=0.912, acc=67.8]


Train Loss: 0.9120 | Train Acc: 67.80%

Epoch 9/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.62it/s, loss=0.796, acc=72.2]


Train Loss: 0.7960 | Train Acc: 72.23%

Epoch 10/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.65it/s, loss=0.687, acc=76.1]


Train Loss: 0.6870 | Train Acc: 76.14%

Epoch 11/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.66it/s, loss=0.579, acc=80]  


Train Loss: 0.5788 | Train Acc: 80.00%


Validation: 100%|██████████| 79/79 [00:01<00:00, 45.16it/s, loss=1.43, acc=57.3]


Val Loss: 1.3767 | Val Acc: 57.31%
Checkpoint saved to checkpoints/best_model.pth

Epoch 12/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.67it/s, loss=0.484, acc=83.3]


Train Loss: 0.4844 | Train Acc: 83.28%

Epoch 13/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.68it/s, loss=0.404, acc=86.2]


Train Loss: 0.4041 | Train Acc: 86.17%

Epoch 14/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.68it/s, loss=0.349, acc=88]  


Train Loss: 0.3487 | Train Acc: 88.00%

Epoch 15/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.68it/s, loss=0.287, acc=90.4]


Train Loss: 0.2869 | Train Acc: 90.38%

Epoch 16/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.70it/s, loss=0.242, acc=91.7]


Train Loss: 0.2418 | Train Acc: 91.75%


Validation: 100%|██████████| 79/79 [00:01<00:00, 44.93it/s, loss=1.75, acc=57.2]


Val Loss: 1.7091 | Val Acc: 57.24%

Epoch 17/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.70it/s, loss=0.226, acc=92.3]


Train Loss: 0.2256 | Train Acc: 92.28%

Epoch 18/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.68it/s, loss=0.188, acc=93.7]


Train Loss: 0.1875 | Train Acc: 93.72%

Epoch 19/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.69it/s, loss=0.168, acc=94.4]


Train Loss: 0.1675 | Train Acc: 94.41%

Epoch 20/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.70it/s, loss=0.159, acc=94.7]


Train Loss: 0.1594 | Train Acc: 94.67%

Epoch 21/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.71it/s, loss=0.155, acc=94.8]


Train Loss: 0.1546 | Train Acc: 94.76%


Validation: 100%|██████████| 79/79 [00:01<00:00, 45.70it/s, loss=1.85, acc=58.2]


Val Loss: 1.8486 | Val Acc: 58.25%
Checkpoint saved to checkpoints/best_model.pth

Epoch 22/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.68it/s, loss=0.132, acc=95.6]


Train Loss: 0.1318 | Train Acc: 95.58%

Epoch 23/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.70it/s, loss=0.129, acc=95.6]


Train Loss: 0.1289 | Train Acc: 95.59%

Epoch 24/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.71it/s, loss=0.121, acc=95.9]


Train Loss: 0.1206 | Train Acc: 95.92%

Epoch 25/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.71it/s, loss=0.112, acc=96.2] 


Train Loss: 0.1123 | Train Acc: 96.20%

Epoch 26/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.70it/s, loss=0.113, acc=96.2] 


Train Loss: 0.1125 | Train Acc: 96.19%


Validation: 100%|██████████| 79/79 [00:01<00:00, 45.91it/s, loss=1.98, acc=59.7]


Val Loss: 1.9284 | Val Acc: 59.68%
Checkpoint saved to checkpoints/best_model.pth

Epoch 27/100


Training: 100%|██████████| 313/313 [00:36<00:00,  8.69it/s, loss=0.11, acc=96.3]  


Train Loss: 0.1104 | Train Acc: 96.28%

Epoch 28/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.72it/s, loss=0.099, acc=96.7] 


Train Loss: 0.0990 | Train Acc: 96.66%

Epoch 29/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.73it/s, loss=0.0952, acc=96.8]


Train Loss: 0.0952 | Train Acc: 96.77%

Epoch 30/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.71it/s, loss=0.0979, acc=96.8]


Train Loss: 0.0979 | Train Acc: 96.81%

Epoch 31/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.70it/s, loss=0.09, acc=97]    


Train Loss: 0.0900 | Train Acc: 96.97%


Validation: 100%|██████████| 79/79 [00:01<00:00, 45.15it/s, loss=1.97, acc=59.9]


Val Loss: 1.9216 | Val Acc: 59.93%
Checkpoint saved to checkpoints/best_model.pth

Epoch 32/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.76it/s, loss=0.0808, acc=97.3]


Train Loss: 0.0808 | Train Acc: 97.26%

Epoch 33/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.72it/s, loss=0.0879, acc=97.1]


Train Loss: 0.0879 | Train Acc: 97.07%

Epoch 34/100


Training: 100%|██████████| 313/313 [00:35<00:00,  8.71it/s, loss=0.0895, acc=96.9]


Train Loss: 0.0895 | Train Acc: 96.88%

Epoch 35/100


Training:  18%|█▊        | 55/313 [00:06<00:30,  8.38it/s, loss=0.0675, acc=97.6]


KeyboardInterrupt: 