In [25]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import numpy as np
from resnet import build_model  # Import the model builder from your existing file

In [26]:
BATCH_SIZE = 128            # Batch size for training.
EPOCHS = 100                # Number of epochs to train for.
SEED = 42                   # Random seed for reproducibility.
LR = 0.0001                 # Learning rate.
WEIGHT_DECAY = 5e-5         # Weight decay (L2 penalty).
VALIDATE_FREQ = 2           # Frequency of validation.
SAVE_DIR = "checkpoints"    # Directory to save checkpoints.
device = "cuda:2"
print(f"Using device: {device}")

Using device: cuda:2


In [27]:
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [28]:
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [29]:
# Training dataset with training transforms
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform
)

# Validation dataset with test transforms
val_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=False, transform=test_transform
)

# Optionally, split indices if needed (or use Subset)
train_indices = list(range(0, int(0.8 * len(train_dataset))))
val_indices = list(range(int(0.8 * len(train_dataset)), len(train_dataset)))

train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
val_dataset = torch.utils.data.Subset(val_dataset, val_indices)


Files already downloaded and verified


In [30]:
# Create data loaders
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
    num_workers=4, pin_memory=True
)

val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, 
    num_workers=4, pin_memory=True
)

In [31]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training')
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        # Update progress bar
        pbar.set_postfix({
            'loss': running_loss / (pbar.n + 1),
            'acc': 100. * correct / total
        })

    return running_loss / len(train_loader), 100. * correct / total


In [32]:
def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            # Update progress bar
            pbar.set_postfix({
                'loss': running_loss / (pbar.n + 1),
                'acc': 100. * correct / total
            })

    return running_loss / len(val_loader), 100. * correct / total


In [33]:
# Create save directory if it doesn't exist
os.makedirs(SAVE_DIR, exist_ok=True)



In [None]:
# Create model configuration for CIFAR-10
config = {
    'backbone_depth': 18,  # ResNet-18
    'num_classes': 10,  # CIFAR-10 has 10 classes
    'hidden_dim': 256,
    'dropout_prob': 0.5,
    'use_dropout': False,
    'pnp_class': None,  # No plug-and-play module for this example
}

# Build the model
model = build_model(config)
model = model.to(device)

In [35]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Training loop
best_val_acc = 0.0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    
    # Train for one epoch
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")

    if epoch % VALIDATE_FREQ == 0:
        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)
                
        # Print statistics
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        
        # Save checkpoint if validation accuracy improved
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            checkpoint_path = os.path.join(SAVE_DIR, 'best_model.pth')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'config': config,
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")


Epoch 1/100


Training: 100%|██████████| 313/313 [00:12<00:00, 25.76it/s, loss=1.58, acc=41.8]


Train Loss: 1.5845 | Train Acc: 41.78%


Validation: 100%|██████████| 79/79 [00:00<00:00, 110.35it/s, loss=1.54, acc=49.1]


Val Loss: 1.4645 | Val Acc: 49.11%
Checkpoint saved to checkpoints/best_model.pth

Epoch 2/100


Training: 100%|██████████| 313/313 [00:11<00:00, 26.19it/s, loss=1.27, acc=54.9]


Train Loss: 1.2578 | Train Acc: 54.88%

Epoch 3/100


Training: 100%|██████████| 313/313 [00:12<00:00, 25.71it/s, loss=1.09, acc=61.7]


Train Loss: 1.0839 | Train Acc: 61.69%


Validation: 100%|██████████| 79/79 [00:00<00:00, 97.80it/s, loss=1.12, acc=65.5]  


Val Loss: 0.9908 | Val Acc: 65.51%
Checkpoint saved to checkpoints/best_model.pth

Epoch 4/100


Training: 100%|██████████| 313/313 [00:12<00:00, 25.99it/s, loss=0.979, acc=65.7]


Train Loss: 0.9732 | Train Acc: 65.74%

Epoch 5/100


Training: 100%|██████████| 313/313 [00:11<00:00, 26.17it/s, loss=0.903, acc=68.4]


Train Loss: 0.8971 | Train Acc: 68.44%


Validation: 100%|██████████| 79/79 [00:00<00:00, 109.09it/s, loss=0.869, acc=70.4]


Val Loss: 0.8472 | Val Acc: 70.42%
Checkpoint saved to checkpoints/best_model.pth

Epoch 6/100


Training: 100%|██████████| 313/313 [00:12<00:00, 25.19it/s, loss=0.841, acc=70.6]


Train Loss: 0.8359 | Train Acc: 70.63%

Epoch 7/100


Training: 100%|██████████| 313/313 [00:12<00:00, 25.01it/s, loss=0.79, acc=72.6] 


Train Loss: 0.7853 | Train Acc: 72.59%


Validation: 100%|██████████| 79/79 [00:00<00:00, 107.69it/s, loss=0.845, acc=72.5]


Val Loss: 0.8026 | Val Acc: 72.49%
Checkpoint saved to checkpoints/best_model.pth

Epoch 8/100


Training: 100%|██████████| 313/313 [00:12<00:00, 25.90it/s, loss=0.752, acc=74]  


Train Loss: 0.7471 | Train Acc: 73.99%

Epoch 9/100


Training: 100%|██████████| 313/313 [00:12<00:00, 26.08it/s, loss=0.717, acc=75]  


Train Loss: 0.7123 | Train Acc: 74.99%


Validation: 100%|██████████| 79/79 [00:00<00:00, 117.59it/s, loss=0.963, acc=73.7]


Val Loss: 0.7926 | Val Acc: 73.66%
Checkpoint saved to checkpoints/best_model.pth

Epoch 10/100


Training: 100%|██████████| 313/313 [00:12<00:00, 25.48it/s, loss=0.689, acc=76.1]


Train Loss: 0.6844 | Train Acc: 76.11%

Epoch 11/100


Training: 100%|██████████| 313/313 [00:12<00:00, 25.92it/s, loss=0.655, acc=77.4]


Train Loss: 0.6509 | Train Acc: 77.44%


Validation: 100%|██████████| 79/79 [00:00<00:00, 97.43it/s, loss=0.807, acc=75.9] 


Val Loss: 0.6943 | Val Acc: 75.92%
Checkpoint saved to checkpoints/best_model.pth

Epoch 12/100


Training: 100%|██████████| 313/313 [00:12<00:00, 25.70it/s, loss=0.629, acc=78.3]


Train Loss: 0.6253 | Train Acc: 78.33%

Epoch 13/100


Training: 100%|██████████| 313/313 [00:12<00:00, 26.06it/s, loss=0.618, acc=78.8]


Train Loss: 0.6137 | Train Acc: 78.78%


Validation: 100%|██████████| 79/79 [00:00<00:00, 115.03it/s, loss=0.639, acc=78.3]


Val Loss: 0.6386 | Val Acc: 78.32%
Checkpoint saved to checkpoints/best_model.pth

Epoch 14/100


Training: 100%|██████████| 313/313 [00:12<00:00, 25.21it/s, loss=0.583, acc=80.1]


Train Loss: 0.5789 | Train Acc: 80.09%

Epoch 15/100


Training:  14%|█▎        | 43/313 [00:01<00:10, 25.70it/s, loss=0.554, acc=80.7]