In [1]:
import torch
import torchvision
from torch.utils.data import random_split
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader

from resnet import resnet50

In [2]:
# Define data transforms
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the CIFAR-10 dataset
trainval_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Split trainval_dataset into train and validation sets
train_size = int(0.8 * len(trainval_dataset))
val_size = len(trainval_dataset) - train_size
train_dataset, val_dataset = random_split(trainval_dataset, [train_size, val_size])

# Create data loaders for the train, validation, and test sets
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# Load the pre-trained ResNet50 model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = resnet50().to(device)

In [4]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)

In [5]:
# Define the validation function
def validate(model, testloader, criterion, device):
    correct = 0
    total = 0
    running_loss = 0.0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss = criterion(outputs, labels)
            running_loss += loss.item()
    
    val_accuracy = 100 * correct / total
    val_loss = running_loss / len(testloader)
    return val_accuracy, val_loss

In [6]:
best_loss = float('inf')
patience = 15
counter = 0

# Train the model
for epoch in range(200):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Validate the model every 1 epoch
    if (epoch + 1) % 1 == 0:
        val_accuracy, val_loss = validate(model, val_loader, criterion, device)
        print(f'Epoch {epoch + 1}, Loss {running_loss / len(train_loader):.4f}, Validation Loss {val_loss:.4f}, Validation Accuracy {val_accuracy:.2f}%')
        running_loss = 0.0

        if val_loss < best_loss:
            best_loss = val_loss
            model_name = f'models/resnet50_best_epoch.pth'
            torch.save(model.state_dict(), model_name)
            counter = 0

        else:
            counter += 1
            if counter >= patience:
                print(f'Validation loss did not improve for {patience} epochs, stopping training')
                break

Epoch 1, Loss 1.9802, Validation Loss 1.7223, Validation Accuracy 38.32%
Epoch 2, Loss 1.6507, Validation Loss 1.5618, Validation Accuracy 43.43%
Epoch 3, Loss 1.5087, Validation Loss 1.5345, Validation Accuracy 45.53%
Epoch 4, Loss 1.5304, Validation Loss 1.4134, Validation Accuracy 48.67%
Epoch 5, Loss 1.4357, Validation Loss 1.4067, Validation Accuracy 50.70%
Epoch 6, Loss 1.2985, Validation Loss 1.3060, Validation Accuracy 54.35%
Epoch 7, Loss 1.2170, Validation Loss 1.1512, Validation Accuracy 58.74%
Epoch 8, Loss 1.1680, Validation Loss 1.1657, Validation Accuracy 58.23%
Epoch 9, Loss 1.0706, Validation Loss 1.0541, Validation Accuracy 62.81%
Epoch 10, Loss 1.0041, Validation Loss 1.0256, Validation Accuracy 63.85%
Epoch 11, Loss 0.9567, Validation Loss 0.9981, Validation Accuracy 65.48%
Epoch 12, Loss 0.9054, Validation Loss 0.9648, Validation Accuracy 66.73%
Epoch 13, Loss 0.8741, Validation Loss 0.9539, Validation Accuracy 67.44%
Epoch 14, Loss 0.8485, Validation Loss 0.8750, 

In [8]:
model.load_state_dict(torch.load(f'models/resnet50_best_epoch.pth'))

# Evaluate the final model on the test set
test_accuracy, _ = validate(model, test_loader, criterion, device)
print(f'Test Accuracy: {test_accuracy}')

Test Accuracy: 81.48
