In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import time
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.models.resnet import resnet18

In [None]:
# Pick the compute device: the first CUDA GPU when one is visible to PyTorch,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Create datasets and loaders - use data augmentation on the training split only
mean = (0.5,0.5,0.5)
std = (0.5,0.5,0.5)

# Mini-batch size shared by both loaders.
# NOTE(review): no batch size was defined anywhere in the visible notebook even
# though the training cell needs loaders; 128 is an assumed default - adjust to
# match the original experiment if it is known.
batch_size = 128

classes = ["airplane",'automobile','bird','cat','deer',
          'dog','frog','horse','ship','truck']

# Training pipeline: random shear/scale, rotation, horizontal flip and colour
# jitter are applied to the PIL image before it is converted to a tensor.
train_transforms = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.RandomAffine(0,shear=7,scale=(0.9,1.1)),
    transforms.RandomRotation(15),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1,0.1,0.1),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] range to [-1, 1].
    transforms.Normalize(mean,std)
])
# Validation pipeline: deterministic - only resize, tensor conversion and the
# same normalisation as training.
val_transforms = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Normalize(mean,std)
])

# Both splits are downloaded into ./cifar10_data on first use.
train_ds = CIFAR10(root='cifar10_data',train=True,download=True,transform=train_transforms)
val_ds = CIFAR10(root='cifar10_data',train=False,download=True,transform=val_transforms)

# Loaders consumed by the training cell. Previously these names were never
# defined, so the notebook failed with NameError on a fresh kernel.
# Shuffle only the training data; validation order does not affect the metrics.
train_loader = DataLoader(train_ds,batch_size=batch_size,shuffle=True)
val_loader = DataLoader(val_ds,batch_size=batch_size,shuffle=False)

In [None]:
# Randomly initialised ResNet-18 with one output logit per dataset class
# (the torchvision default is a 1000-way head, which lets argmax predict
# classes that do not exist in this dataset).
# The model must live on the same device as the batches the training loop
# sends to `device`; otherwise the forward pass fails on a CUDA machine.
# Move it BEFORE building the optimizer so the optimizer holds the on-device
# parameters.
net = resnet18(num_classes=len(classes)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(),lr=0.001)

In [None]:
# Train for a fixed number of epochs. After every epoch the model is evaluated
# on the validation loader, and per-epoch loss/accuracy (rounded to 3 decimals)
# are appended to the history lists below.
epochs = 30
train_accuracies=[]
val_accuracies=[]
train_losses = []
val_losses = []
for epoch in range(epochs):
    t1 = time.time()

    # ---- training pass ----
    # Training mode: BatchNorm layers use batch statistics and update their
    # running estimates.
    net.train()
    total_loss = 0.0      # sum of per-sample losses over the epoch
    total_correct = 0     # number of correctly classified training samples
    train_seen = 0        # number of training samples actually processed
    for index,(samples,labels) in enumerate(train_loader):
        samples = samples.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # Call the module (not .forward) so registered hooks are honoured.
        outputs = net(samples)
        loss = criterion(outputs,labels)
        loss.backward()
        optimizer.step()

        # Accumulate plain Python numbers: adding the `loss` tensor itself
        # would keep autograd history alive for the whole epoch.
        # `loss` is a batch mean, so weight it by the batch size to stay exact
        # when the last batch is smaller than the others.
        n_batch = labels.size(0)
        total_loss += loss.item() * n_batch
        total_correct += (torch.argmax(outputs,1) == labels).sum().item()
        train_seen += n_batch

        # Lightweight progress indicator: one dot per 50 batches.
        if index % 50==49:
            print('.',end='')

    # ---- validation pass ----
    # Eval mode: BatchNorm uses its running statistics and does not update
    # them, so validation neither depends on batch composition nor alters the
    # model.
    net.eval()
    val_loss = 0.0
    val_correct = 0
    val_seen = 0
    with torch.no_grad():
        for samples,labels in val_loader:
            samples = samples.to(device)
            labels = labels.to(device)
            outputs = net(samples)
            loss = criterion(outputs,labels)

            n_batch = labels.size(0)
            val_loss += loss.item() * n_batch
            val_correct += (torch.argmax(outputs,1) == labels).sum().item()
            val_seen += n_batch

    # Per-sample averages over the samples actually seen; max(..., 1) guards
    # against an empty loader.
    epoch_train_loss = round(total_loss / max(train_seen, 1), 3)
    epoch_train_acc = round(total_correct / max(train_seen, 1), 3)
    epoch_val_loss = round(val_loss / max(val_seen, 1), 3)
    epoch_val_acc = round(val_correct / max(val_seen, 1), 3)

    print('\nEpoch:',epoch+1,
          'Training loss:',epoch_train_loss,
          'Training accuracy:',epoch_train_acc,
          'Val loss:',epoch_val_loss,
          'Val accuracy:',epoch_val_acc,
          'Time taken:',round(time.time()-t1))
    train_accuracies.append(epoch_train_acc)
    val_accuracies.append(epoch_val_acc)
    train_losses.append(epoch_train_loss)
    val_losses.append(epoch_val_loss)
print("Training losses:",train_losses)
print("Training accuracies:",train_accuracies)
print("Validation losses:",val_losses)
print("Validation accuracies:",val_accuracies)