In [1]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from datetime import datetime
import numpy as np
from torch.utils.data import random_split
from collections import Counter
from classes.MyMLP import MyMLP


In [2]:
# Reproducibility: seed torch's global random number generator.
SEED = 808
torch.manual_seed(SEED)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

# From here on, newly created floating-point tensors default to float64.
torch.set_default_dtype(torch.double)

Using device: cuda


In [3]:
def _keep_and_relabel(samples, label_map):
    """Return ``(img, new_label)`` pairs for samples whose label is a key of ``label_map``."""
    return [(img, label_map[label]) for img, label in samples if label in label_map]


def load_CIFAR2(train_val_split=.9, data_path='data', preprocessor=None):
    """Load CIFAR-10 and reduce it to a two-class (airplane vs. bird) problem.

    The official training set is split into train/validation parts with
    ``random_split``; afterwards every split is filtered down to the original
    classes 0 and 2, which are relabelled to 0 and 1 respectively.

    Args:
        train_val_split: Fraction of the official training set assigned to
            the training split; the remainder becomes the validation split.
        data_path: Directory handed to ``datasets.CIFAR10`` as its root
            (the dataset is requested with ``download=True``).
        preprocessor: Value forwarded as the ``transform`` argument of
            ``datasets.CIFAR10``; ``None`` applies no transform.

    Returns:
        A tuple ``(data_train, data_val, data_test)`` of lists of
        ``(image, label)`` pairs, with ``label`` in ``{0, 1}``.
    """
    data_train_val = datasets.CIFAR10(
        data_path,
        train=True,
        download=True,
        transform=preprocessor)

    data_test = datasets.CIFAR10(
        data_path,
        train=False,
        download=True,
        transform=preprocessor)

    n_train = int(len(data_train_val) * train_val_split)
    n_val = len(data_train_val) - n_train

    # NOTE(review): the split uses its own freshly constructed generator, which
    # is never seeded from the notebook's SEED - confirm that is intended.
    data_train, data_val = random_split(
        data_train_val,
        [n_train, n_val],
        generator=torch.Generator()
    )

    # Original CIFAR-10 class id -> binary label (0: airplane, 2: bird).
    # This mapping is the single source of truth for which classes are kept.
    label_map = {0: 0, 2: 1}

    data_train = _keep_and_relabel(data_train, label_map)
    data_val = _keep_and_relabel(data_val, label_map)
    data_test = _keep_and_relabel(data_test, label_map)

    print("Size of training set: ", len(data_train))
    print("Size of validation set: ", len(data_val))
    print("Size of test set: ", len(data_test))

    return (data_train, data_val, data_test)

# Build the airplane-vs-bird splits using load_CIFAR2's defaults
# (train_val_split=.9, data_path='data', preprocessor=None).
# NOTE(review): with preprocessor=None no transform is applied to the images;
# confirm they are converted to tensors before reaching train(), which calls
# imgs.to(...) on each batch.
cifar_train, cifar_val, cifar_test = load_CIFAR2()

Files already downloaded and verified
Files already downloaded and verified
Size of training set:  8956
Size of validation set:  1044
Size of test set:  2000


In [4]:
def train(n_epochs, optimizer, model, loss_fn, train_loader):
    """Train ``model`` for ``n_epochs`` passes over ``train_loader``.

    Every batch is moved to the module-level ``DEVICE`` (images are also cast
    to ``torch.double``), the loss is back-propagated and ``optimizer`` takes
    one step. The mean training loss is logged for epoch 1 and for every
    10th epoch.

    Args:
        n_epochs: Number of epochs to run.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        model: Callable model providing ``train()``; invoked as ``model(imgs)``.
        loss_fn: Callable invoked as ``loss_fn(outputs, labels)``; its result
            must support ``backward()`` and ``item()``.
        train_loader: Sized iterable yielding ``(imgs, labels)`` tensor batches.

    Returns:
        list[float]: The mean training loss of each epoch, in epoch order.
    """
    n_batch = len(train_loader)
    losses_train = []

    for epoch in range(1, n_epochs + 1):
        model.train()
        loss_train = 0.0
        for imgs, labels in train_loader:
            imgs = imgs.to(device=DEVICE, dtype=torch.double)
            labels = labels.to(device=DEVICE)

            # Clear gradients before the backward pass so that gradients left
            # over from before this call cannot leak into the first update.
            optimizer.zero_grad()

            outputs = model(imgs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            loss_train += loss.item()

        epoch_loss = loss_train / n_batch
        losses_train.append(epoch_loss)

        if epoch == 1 or epoch % 10 == 0:
            # Log the epoch mean, not the last batch's loss.
            print(f"{datetime.now().time()}, {epoch}, train_loss: {epoch_loss}")

    return losses_train

In [None]:
def train_manual_update(n_epochs, lr, model, loss_fn, train_loader):
    """Placeholder that is not implemented yet.

    All arguments are currently ignored and the call returns ``None``.

    NOTE(review): the signature matches ``train`` with ``lr`` in place of
    ``optimizer``, which suggests a training loop with hand-written parameter
    updates is planned - confirm the intended behaviour before implementing.
    """
    return None