In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from datetime import datetime
import numpy as np
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
from collections import Counter
from classes.MyMLP import MyMLP


In [2]:
SEED = 808
torch.manual_seed(SEED)

DEVICE = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
print(f"Using device: {DEVICE}")

torch.set_default_dtype(torch.double)

Using device: cuda


In [3]:
def load_CIFAR2(train_val_split=.9, data_path='data', preprocessor=None):
    if preprocessor is None:
        preprocessor = transforms.Compose([
            transforms.ToTensor(),
        ])

    data_train_val = datasets.CIFAR10(
        data_path,
        train=True,
        download=True,
        transform=preprocessor)
    
    data_test = datasets.CIFAR10(
        data_path,
        train=False,
        download=True,
        transform=preprocessor)

    n_train = int(len(data_train_val)*train_val_split)
    n_val = len(data_train_val) - n_train

    data_train, data_val = random_split(
        data_train_val,
        [n_train, n_val],
        generator=torch.Generator()
    )
    
    label_map = {0: 0, 2: 1}
    class_names = ['airplane', 'bird']
    
    data_train = [(img, label_map[label]) for img, label in data_train if label in [0, 2]]
    data_val = [(img, label_map[label]) for img, label in data_val if label in [0, 2]]
    data_test = [(img, label_map[label]) for img, label in data_test if label in [0, 2]]

    print("Size of training set: ", len(data_train))
    print("Size of validation set: ", len(data_val))
    print("Size of test set: ", len(data_test))

    return (data_train, data_val, data_test)

cifar_train, cifar_val, cifar_test = load_CIFAR2()

Files already downloaded and verified
Files already downloaded and verified
Size of training set:  8956
Size of validation set:  1044
Size of test set:  2000


In [4]:
def train(n_epochs, optimizer, model, loss_fn, train_loader):
    print(f"Training {model} with optimizer")
    n_batch = len(train_loader)
    losses_train = []
    losses_val = []

    for epoch in range(1, n_epochs+1):
        model.train()
        loss_train = 0.0
        for imgs, labels in train_loader:
            imgs = imgs.to(device=DEVICE, dtype=torch.double)
            labels = labels.to(device=DEVICE)

            outputs = model(imgs)

            loss = loss_fn(outputs,labels)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            loss_train += loss.item()

        losses_train.append(loss_train / n_batch)

        if epoch == 1 or epoch % 5 == 0:
            print(f"{datetime.now().time()}, {epoch}, train_loss: {loss/n_batch}")

In [5]:
def train_manual_update(n_epochs, lr, model, loss_fn, train_loader, weight_decay=0, momentum=0):
    print(f"Training {model} with manual update")
    n_batch = len(train_loader)
    losses_train = []
    losses_val = []

    for epoch in range(1, n_epochs+1):
        model.train()
        loss_train = 0.0
        for imgs, labels in train_loader:
            imgs = imgs.to(device=DEVICE, dtype=torch.double)
            labels = labels.to(device=DEVICE)

            outputs = model(imgs)
            loss = loss_fn(outputs,labels)
            loss.backward()

            # Implement gradient descent here:
            with torch.no_grad():
                for p in model.parameters():
                    g_t = p.grad
                    if weight_decay != 0:
                        g_t = g_t + weight_decay*p.data

                    # if momentum != 0:
                    #     b_t_prev = None
                    #     if b_t_prev is not None:
                    #         b_t = momentum*b_t_prev + g_t
                    #     else:
                    #         b_t = g_t
                    #         b_t_prev = b_t

                    #     g_t = b_t
                    new_params = p.data-lr*g_t
                    p.copy_(new_params)
                
                        
                model.zero_grad()

            loss_train += loss.item()

        losses_train.append(loss_train / n_batch)

        if epoch == 1 or epoch % 5 == 0:
            print(f"{datetime.now().time()}, {epoch}, train_loss: {loss/n_batch}")
    pass

In [6]:
def compute_accuracy(model, loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device=DEVICE)
            labels = labels.to(device=DEVICE)

            outputs = model(imgs)
            _, predicted = torch.max(outputs, dim=1)
            total += labels.shape[0]
            correct += int((predicted == labels).sum())
    acc = correct / total
    return acc

In [7]:
def compare_models(n_epochs, batch_size):
    loss_fn = nn.CrossEntropyLoss()
    seed = SEED

    print("\tGlobal parameters:")
    print(f"Batch size: {batch_size}")
    print(f"Epochs: {n_epochs}")
    print(f"Loss function: {loss_fn}")
    print(f"Seed: {seed}")
    
    train_loader = DataLoader(cifar_train, shuffle=False, batch_size=batch_size)
    val_loader = DataLoader(cifar_val, shuffle=False, batch_size=batch_size)
    
    hyper_params = [
        {"lr": 0.01, "weight_decay": 0, "momentum": 0},
        {"lr": 0.01, "weight_decay": 0.01, "momentum": 0},
    ]

    models = []
    accuracies = []

    for hparam in hyper_params:
        print("\n", "="*50)
        print("\tCurrent parameters: ")
        [print(f"{key}: {value}") for key, value in hparam.items()]

        print("\n", "-"*6, "Using pytorch SGD", "-"*6)
        torch.manual_seed(SEED)
        model_auto = MyMLP()
        model_auto.to(device=DEVICE)
        optimizer = optim.SGD(model_auto.parameters(), **hparam)
        train(n_epochs, optimizer, model_auto, loss_fn, train_loader)
        train_acc_auto = compute_accuracy(model_auto, train_loader)
        val_acc_auto = compute_accuracy(model_auto, val_loader)

        models.append(model_auto)
        accuracies.append(val_acc_auto)
        
        print("\n", "-"*3, "Accuracies", "-"*3)
        print(f"Training accuracy: {train_acc_auto:.2f}")
        print(f"Validation accuracy: {val_acc_auto:.2f}")
        
        print("\n", "-"*6, "Using manual update", "-"*6)
        torch.manual_seed(SEED)
        model_manual = MyMLP()
        model_manual.to(device=DEVICE)
        train_manual_update(n_epochs, model=model_manual, train_loader=train_loader, loss_fn=loss_fn, **hparam)
        train_acc_manual = compute_accuracy(model_manual, train_loader)
        val_acc_manual = compute_accuracy(model_manual, val_loader)
        
        models.append(model_manual)
        accuracies.append(val_acc_manual)
        
        print("\n", "-"*3, "Accuracies", "-"*3)
        print(f"Training accuracy: {train_acc_manual:.2f}")
        print(f"Validation accuracy: {val_acc_manual:.2f}")


    return models, accuracies

In [8]:
n_epochs = 1
batch_size = 1
models, accuracies = compare_models(n_epochs, batch_size)

	Global parameters:
Batch size: 1
Epochs: 1
Loss function: CrossEntropyLoss()
Seed: 808

	Current parameters: 
lr: 0.01
weight_decay: 0
momentum: 0

 ------ Using pytorch SGD ------
Training MyMLP with optimizer
23:00:45.980795, 1, train_loss: 0.00010952912793141069

 --- Accuracies ---
Training accuracy: 0.79
Validation accuracy: 0.77

 ------ Using manual update ------
Training MyMLP with manual update
23:00:54.610955, 1, train_loss: 0.00010952912793141069

 --- Accuracies ---
Training accuracy: 0.79
Validation accuracy: 0.77

	Current parameters: 
lr: 0.01
weight_decay: 0.01
momentum: 0

 ------ Using pytorch SGD ------
Training MyMLP with optimizer
23:01:01.886148, 1, train_loss: 0.000131260954048647

 --- Accuracies ---
Training accuracy: 0.78
Validation accuracy: 0.77

 ------ Using manual update ------
Training MyMLP with manual update
23:01:12.019116, 1, train_loss: 0.000131260954048647

 --- Accuracies ---
Training accuracy: 0.78
Validation accuracy: 0.77


In [9]:
highest_acc = max(accuracies)
selected_model = models[accuracies.index(highest_acc)]

In [10]:
test_loader = DataLoader(cifar_test, shuffle=False, batch_size=batch_size)
acc = compute_accuracy(selected_model, test_loader)
print(f"Accuracy of selected model: {acc:.2f}")

Accuracy of selected model: 0.79
