In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from cifar_model import resnet50
from tqdm import tqdm
import copy
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import classification_report

In [3]:
# Library versions this notebook was executed with (recorded for reproducibility).
tuple(module.__version__ for module in (torch, torchvision))

('1.4.0', '0.5.0')

In [4]:
# Use the first GPU when one is available; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Seed the torch RNG (weight init, DataLoader shuffling / worker seeding) so reruns are comparable.
# NOTE(review): cuDNN kernels may still be nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED);

In [5]:
def _run_epoch(model, loader, epoch, optimizer=None, prefix=''):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the pass trains: ``model.train()``, then a
    backward pass and ``optimizer.step()`` per batch. Otherwise it evaluates:
    ``model.eval()`` with autograd disabled. Each batch is moved to the
    notebook-level ``device``.

    Both metrics are averaged per *sample*, so a smaller final batch is not
    over-weighted. The running values are shown in a tqdm bar under the keys
    ``f'{prefix}loss'`` and ``f'{prefix}acc'``. An empty loader yields
    ``(nan, nan)``.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is the same as model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    pbar = tqdm(loader, position=0, leave=True)
    for data, label in pbar:
        data, label = data.to(device), label.to(device)
        # Only build the autograd graph when we are going to backpropagate.
        with torch.set_grad_enabled(training):
            logits = model(data)
            loss = F.cross_entropy(logits, label)
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Keep O(1) running totals instead of re-concatenating every past batch on each step.
        batch_size = label.size(0)
        loss_sum += loss.item() * batch_size  # cross_entropy returns the per-sample mean
        n_correct += (logits.argmax(dim=1) == label).sum().item()
        n_seen += batch_size
        pbar.set_postfix({'epoch': epoch,
                          f'{prefix}loss': f'{loss_sum / n_seen:.2f}',
                          f'{prefix}acc': f'{n_correct / n_seen:.2f}'})
    if n_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / n_seen, n_correct / n_seen


def train_val(model, optimizer, train_loader, val_loader, epochs=10):
    """Alternate one training epoch and one validation epoch, ``epochs`` times.

    Progress and running loss/accuracy are reported through tqdm bars.
    Training bars use the keys 'loss' and 'acc'; validation bars use
    'val loss' and 'val acc'.

    Args:
        model: network mapping a batch of inputs to class logits.
        optimizer: optimizer over ``model``'s parameters; stepped once per training batch.
        train_loader: iterable of ``(data, label)`` batches used for training.
        val_loader: iterable of ``(data, label)`` batches used for evaluation (no gradient updates).
        epochs: number of train + validation rounds.

    Returns:
        None. Results are only displayed. Nothing is returned so that cells
        ending in ``train_val(...)`` show no extra Out[] value.
    """
    for epoch in range(epochs):
        _run_epoch(model, train_loader, epoch, optimizer=optimizer)
        _run_epoch(model, val_loader, epoch, prefix='val ')

In [11]:
# Per-channel normalization statistics, shared by the train and test transforms so they cannot drift apart.
# NOTE(review): these std values are the ones widely copied from public CIFAR-10 training repos.
# The per-channel std of the training pixels is usually reported as ~(0.247, 0.243, 0.261).
# Confirm before changing anything: the trained weights depend on the values used here.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training-time augmentation: random 32x32 crop of the 4-pixel-padded image, plus a random horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation: no augmentation, only tensor conversion plus the same normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

# NOTE(review): "val_loader" is built from the official CIFAR-10 *test* split (train=False).
# If it is used to choose learning rates or epoch counts, the accuracy it reports is not
# an unbiased test estimate. Consider holding out part of the training set instead.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
val_loader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# ResNet-50 from the local cifar_model module, sized to the dataset's class count (presumably the
# classifier width). DataParallel splits each batch across the visible GPUs.
model = torch.nn.DataParallel(resnet50(num_classes=len(trainset.classes)))
model = model.to(device)

In [12]:
# Stage 1: train from scratch with Adam at the higher learning rate (L2 weight decay 5e-4).
optimizer = torch.optim.Adam(params=model.parameters(), weight_decay=5e-4, lr=5e-4)
train_val(model=model, optimizer=optimizer,
          train_loader=train_loader, val_loader=val_loader, epochs=20)

100%|██████████| 391/391 [01:09<00:00,  5.60it/s, epoch=0, loss=1.53, acc=0.44]
100%|██████████| 100/100 [00:07<00:00, 13.40it/s, epoch=0, val loss=1.16, val acc=0.60]
100%|██████████| 391/391 [01:00<00:00,  6.43it/s, epoch=1, loss=1.03, acc=0.64]
100%|██████████| 100/100 [00:07<00:00, 13.16it/s, epoch=1, val loss=0.91, val acc=0.68]
100%|██████████| 391/391 [01:01<00:00,  6.33it/s, epoch=2, loss=0.81, acc=0.72]
100%|██████████| 100/100 [00:07<00:00, 13.12it/s, epoch=2, val loss=0.81, val acc=0.72]
100%|██████████| 391/391 [01:00<00:00,  6.42it/s, epoch=3, loss=0.70, acc=0.76]
100%|██████████| 100/100 [00:07<00:00, 13.21it/s, epoch=3, val loss=0.80, val acc=0.74]
100%|██████████| 391/391 [01:00<00:00,  6.43it/s, epoch=4, loss=0.62, acc=0.79]
100%|██████████| 100/100 [00:07<00:00, 13.09it/s, epoch=4, val loss=0.90, val acc=0.70]
100%|██████████| 391/391 [01:00<00:00,  6.44it/s, epoch=5, loss=0.57, acc=0.80]
100%|██████████| 100/100 [00:07<00:00, 12.90it/s, epoch=5, val loss=0.64, val ac

In [13]:
# Stage 2: drop the learning rate to 1e-4 on the *existing* optimizer rather than building a new Adam.
# A fresh Adam would discard the running moment estimates and step counts accumulated in stage 1.
# Setting both hyperparameters explicitly keeps this cell idempotent if it is re-run.
for param_group in optimizer.param_groups:
    param_group['lr'] = 1e-4
    param_group['weight_decay'] = 5e-4
train_val(model, optimizer, train_loader, val_loader, epochs=2)

100%|██████████| 391/391 [01:00<00:00,  6.42it/s, epoch=0, loss=0.21, acc=0.93]
100%|██████████| 100/100 [00:07<00:00, 12.76it/s, epoch=0, val loss=0.25, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00,  6.44it/s, epoch=1, loss=0.17, acc=0.94]
100%|██████████| 100/100 [00:07<00:00, 13.12it/s, epoch=1, val loss=0.25, val acc=0.92]


In [14]:
# Stage 3: keep training at lr=1e-4 with the same optimizer, carrying its Adam state forward.
# This cell used to re-create an Adam identical to the previous cell's, which only reset that state.
# The explicit hyperparameters make the cell correct even if it is run on its own after stage 1.
for param_group in optimizer.param_groups:
    param_group['lr'] = 1e-4
    param_group['weight_decay'] = 5e-4
train_val(model, optimizer, train_loader, val_loader, epochs=10)

100%|██████████| 391/391 [01:00<00:00,  6.42it/s, epoch=0, loss=0.15, acc=0.95]
100%|██████████| 100/100 [00:07<00:00, 13.56it/s, epoch=0, val loss=0.24, val acc=0.92]
100%|██████████| 391/391 [01:01<00:00,  6.41it/s, epoch=1, loss=0.14, acc=0.95]
100%|██████████| 100/100 [00:07<00:00, 13.62it/s, epoch=1, val loss=0.24, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00,  6.44it/s, epoch=2, loss=0.13, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 13.27it/s, epoch=2, val loss=0.24, val acc=0.93]
100%|██████████| 391/391 [01:00<00:00,  6.46it/s, epoch=3, loss=0.12, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 13.43it/s, epoch=3, val loss=0.25, val acc=0.92]
100%|██████████| 391/391 [01:01<00:00,  6.40it/s, epoch=4, loss=0.12, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 13.44it/s, epoch=4, val loss=0.25, val acc=0.92]
100%|██████████| 391/391 [01:00<00:00,  6.45it/s, epoch=5, loss=0.11, acc=0.96]
100%|██████████| 100/100 [00:07<00:00, 12.52it/s, epoch=5, val loss=0.24, val ac