In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader

import numpy as np
import copy
import time
import os
from tqdm.notebook import tqdm

# Select the compute device: the default CUDA GPU when one is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Reproducibility: seed every RNG the notebook draws from.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
# Seeding alone does not make GPU runs repeatable: cuDNN may select
# non-deterministic convolution algorithms (and autotune differently per run
# when benchmark mode is on). Force deterministic kernels, at some speed cost.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Using device: cuda


In [None]:
# MNIST digits are 1x28x28 grayscale images. Normalising with mean = std = 0.5
# maps ToTensor's [0, 1] pixel range onto [-1, 1].
MNIST_IMG_SIZE = 28
mnist_mean = 0.5
mnist_std = 0.5

# Training augmentation: small random translations only (pad by 4 px, then crop
# back to 28x28). The crop size must equal the test-time resolution; cropping to
# 32 would train the network on 32x32 inputs while evaluating it on 28x28 ones.
# No horizontal flip: mirror images of asymmetric digits (2, 3, 4, 7, ...) are
# not realistic samples of their class.
transform_train = transforms.Compose([
    transforms.RandomCrop(MNIST_IMG_SIZE, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((mnist_mean,), (mnist_std,)),  # one entry per channel
])

# Evaluation: no augmentation, same normalisation as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mnist_mean,), (mnist_std,)),
])

# Load datasets (downloaded into ./data on first run).
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)

# Create DataLoaders.
batch_size = 64
NUM_WORKERS = 2
# pin_memory speeds up host->GPU copies when training on CUDA.
# persistent_workers keeps the worker processes alive across epochs instead of
# re-spawning them for every train and test pass.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=NUM_WORKERS,
                          pin_memory=(device.type == "cuda"),
                          persistent_workers=NUM_WORKERS > 0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=NUM_WORKERS,
                         pin_memory=(device.type == "cuda"),
                         persistent_workers=NUM_WORKERS > 0)

In [None]:
def get_resnet18_for_mnist(pretrained=True, num_classes=10, in_channels=1):
    """Build a torchvision ResNet-18 adapted to small images such as MNIST.

    Changes relative to the stock network:
      * ``conv1`` becomes a 3x3, stride-1, padding-1 convolution taking
        ``in_channels`` input channels. The stock 7x7/stride-2 stem plus the
        max-pool would downsample a 28x28 image by 4x before the first
        residual stage.
      * ``maxpool`` is replaced by ``nn.Identity``.
      * ``fc`` is replaced by a new linear head with ``num_classes`` outputs.

    Args:
        pretrained: If True, load the ImageNet ``IMAGENET1K_V1`` weights before
            the surgery above. The replaced ``conv1`` and ``fc`` layers are
            freshly (randomly) initialised either way; only the remaining
            layers keep the pretrained weights.
        num_classes: Number of output logits (10 for MNIST digits).
        in_channels: Channels of the input images (1 for grayscale MNIST).

    Returns:
        The modified ``torchvision.models.ResNet`` module.
    """
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet18(weights=weights)
    model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    # Size the new head from the existing one instead of hard-coding 512.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

In [None]:
# --- Shared training settings (read by train_model) --------------------------
NUM_EPOCHS = 15                     # full passes over the training set
CRITERION = nn.CrossEntropyLoss()   # loss used for both the train and test phases

# --- ResNet-18 SGD hyper-parameters -----------------------------------------
RESNET_LR = 0.00115                 # initial learning rate (before cosine decay)
RESNET_MOMENTUM = 0.95
# NOTE(review): 0.025 is far larger than the weight decay typically paired with
# SGD on ResNets (~5e-4). It produced the logged results, so it is kept here;
# confirm it is intentional.
RESNET_WD = 0.025

In [None]:
def _run_epoch(model, loader, device, criterion, optimizer=None, desc=''):
    """Make one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Returns:
        ``(mean_loss, accuracy)`` averaged over every sample the loader yielded.

    Raises:
        ValueError: if ``loader`` yields no samples (instead of a bare ZeroDivisionError).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    # Autograd bookkeeping is only needed when we will call backward().
    with torch.set_grad_enabled(training):
        for inputs, labels in tqdm(loader, desc=desc, leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            batch_n = labels.size(0)
            # criterion returns a batch mean; re-weight so the epoch loss is a per-sample mean.
            loss_sum += loss.item() * batch_n
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_n
    if n_seen == 0:
        raise ValueError(f"{desc or 'loader'}: loader yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen


def train_model(model, optimizer, scheduler, train_loader, test_loader, device, model_name,
                num_epochs=None, criterion=None):
    """Train ``model``, evaluating on ``test_loader`` after every epoch.

    Each epoch appends the mean loss and accuracy of both phases to the history.
    It also snapshots the weights whenever the test accuracy improves. The best
    snapshot is loaded back into ``model`` before returning.

    NOTE(review): the checkpoint is selected on the *test* set, so the best
    accuracy reported here is optimistically biased. Pass a held-out validation
    split as ``test_loader`` if an unbiased final test score is needed.

    Args:
        model: Network to train; expected to already live on ``device``.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        scheduler: Optional LR scheduler, stepped once per epoch; ``None`` to skip.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        device: Device each batch is moved to.
        model_name: Label used in the progress-bar descriptions.
        num_epochs: Epochs to run; defaults to the notebook-level ``NUM_EPOCHS``.
        criterion: Loss module; defaults to the notebook-level ``CRITERION``.

    Returns:
        ``(model, history)``, where ``history`` maps 'train_loss', 'train_acc',
        'test_loss', and 'test_acc' to per-epoch lists.

    Raises:
        ValueError: if either loader yields no samples in an epoch.
    """
    if num_epochs is None:
        num_epochs = NUM_EPOCHS
    if criterion is None:
        criterion = CRITERION

    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    # -inf (not 0.0) so the first epoch is always snapshotted, even at 0% accuracy.
    best_acc = float('-inf')
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        epoch_loss, epoch_acc = _run_epoch(model, train_loader, device, criterion,
                                           optimizer=optimizer, desc=f'Training {model_name}')
        # Per-epoch schedule (e.g. CosineAnnealingLR with T_max = number of epochs).
        if scheduler is not None:
            scheduler.step()
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)

        test_epoch_loss, test_epoch_acc = _run_epoch(model, test_loader, device, criterion,
                                                     desc=f'Testing {model_name}')
        history['test_loss'].append(test_epoch_loss)
        history['test_acc'].append(test_epoch_acc)

        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        print(f'Test Loss: {test_epoch_loss:.4f} Acc: {test_epoch_acc:.4f}')

        # Snapshot the best weights seen so far (deep copy: state_dict aliases live tensors).
        if test_epoch_acc > best_acc:
            best_acc = test_epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            print(f"New best accuracy: {best_acc:.4f}")

    model.load_state_dict(best_model_wts)
    return model, history

In [None]:
# Build the MNIST-adapted ResNet-18 (ImageNet weights for the layers that were
# not replaced) and place it on the selected device.
resnet18 = get_resnet18_for_mnist(pretrained=True)
resnet18 = resnet18.to(device)

# SGD with momentum and weight decay. train_model steps the scheduler once per
# epoch, so the cosine curve spans exactly NUM_EPOCHS epochs.
optimizer_resnet = optim.SGD(
    resnet18.parameters(),
    lr=RESNET_LR,
    momentum=RESNET_MOMENTUM,
    weight_decay=RESNET_WD,
)
scheduler_resnet = lr_scheduler.CosineAnnealingLR(optimizer_resnet, T_max=NUM_EPOCHS)

print("Training ResNet-18...")
resnet18, resnet_history = train_model(
    resnet18,
    optimizer_resnet,
    scheduler_resnet,
    train_loader,
    test_loader,
    device,
    "ResNet-18",
)

Training ResNet-18...
Epoch 1/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.1443 Acc: 0.9567
Test Loss: 0.1192 Acc: 0.9785
New best accuracy: 0.9785
Epoch 2/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.0706 Acc: 0.9828
Test Loss: 0.1292 Acc: 0.9786
New best accuracy: 0.9786
Epoch 3/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.0889 Acc: 0.9822
Test Loss: 0.1135 Acc: 0.9832
New best accuracy: 0.9832
Epoch 4/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.1142 Acc: 0.9817
Test Loss: 0.1780 Acc: 0.9665
Epoch 5/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.1406 Acc: 0.9784
Test Loss: 0.1853 Acc: 0.9762
Epoch 6/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.1429 Acc: 0.9795
Test Loss: 0.1515 Acc: 0.9787
Epoch 7/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.1379 Acc: 0.9800
Test Loss: 0.1264 Acc: 0.9821
Epoch 8/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.1270 Acc: 0.9829
Test Loss: 0.1403 Acc: 0.9811
Epoch 9/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.1199 Acc: 0.9848
Test Loss: 0.1140 Acc: 0.9849
New best accuracy: 0.9849
Epoch 10/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.1113 Acc: 0.9866
Test Loss: 0.1024 Acc: 0.9875
New best accuracy: 0.9875
Epoch 11/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.1061 Acc: 0.9882
Test Loss: 0.1113 Acc: 0.9867
Epoch 12/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.0992 Acc: 0.9900
Test Loss: 0.1040 Acc: 0.9906
New best accuracy: 0.9906
Epoch 13/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.0928 Acc: 0.9918
Test Loss: 0.1027 Acc: 0.9910
New best accuracy: 0.9910
Epoch 14/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.0876 Acc: 0.9930
Test Loss: 0.0887 Acc: 0.9924
New best accuracy: 0.9924
Epoch 15/15
----------


Training ResNet-18:   0%|          | 0/938 [00:00<?, ?it/s]

Testing ResNet-18:   0%|          | 0/157 [00:00<?, ?it/s]

Train Loss: 0.0844 Acc: 0.9940
Test Loss: 0.0940 Acc: 0.9928
New best accuracy: 0.9928
