# Your first deep neural network

# imports

In [None]:
import numpy as np
from random import randint
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

from torchsummary import summary

## Data

In [None]:
# Number of target classes; matches the 10 outputs of the classifier's final
# layer. NOTE(review): not referenced by any other visible cell.
NUM_CLASSES = 10

In [None]:
# Mini-batch size shared by the training and test loaders.
batch_size = 32

# Training split: downloaded into ./data on first use, images converted to
# tensors, and reshuffled by the loader on every pass.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
)

# Test split: same preprocessing, served in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
)

# architecture

In [None]:
class FirstCNN(nn.Module):
    """Fully connected image classifier: in_features -> 200 -> 150 -> num_classes.

    Despite its name the network contains no convolutional layers; the input
    image is flattened and passed through three linear layers with ReLU
    activations in between.

    ``forward`` returns unnormalised logits. ``F.cross_entropy`` applies
    log-softmax internally, so adding a softmax layer here would apply it twice
    and flatten the gradients. Use ``F.softmax(logits, dim=1)`` when actual
    probabilities are needed; the arg-max prediction is the same either way.

    Args:
        num_classes: Number of output classes (size of the final layer).
        in_features: Flattened input size; 3072 = 3 * 32 * 32.
    """

    def __init__(self, num_classes=10, in_features=3072):
        super().__init__()
        # Layer indices (and therefore state_dict keys) are kept stable:
        # model.1, model.3 and model.5 are the linear layers.
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 150),
            nn.ReLU(inplace=True),
            nn.Linear(150, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        return self.model(x)

In [None]:
# Prefer the GPU when one is present, otherwise fall back to the CPU so this
# cell does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = FirstCNN().to(device)
print(net)

In [None]:
# torchsummary builds its dummy input on "cuda" unless told otherwise; pass the
# device type the model actually lives on so the summary also works on CPU.
# The input size is (channels, height, width) of a single image.
summary(net, (3, 32, 32), device=next(net.parameters()).device.type)

# train

In [None]:
# Multi-class classification objective. F.cross_entropy applies log-softmax
# internally and takes integer class indices as targets.
loss_fn = F.cross_entropy

# Adam over every parameter of the network.
opt = torch.optim.Adam(params=net.parameters(), lr=5e-4)

In [None]:
def evaluation(dataloader, model=None):
    """Return the classification accuracy (in percent) over ``dataloader``.

    Args:
        dataloader: Iterable yielding ``(inputs, labels)`` batches, where
            ``labels`` holds integer class indices.
        model: Module to evaluate. Defaults to the notebook-level ``net``.

    Returns:
        Percentage of correctly classified samples as a float in [0, 100],
        or ``0.0`` when ``dataloader`` yields no samples.
    """
    if model is None:
        model = net

    # Run on whatever device the model lives on instead of hard-coding "cuda";
    # a module without parameters leaves the batches where they are.
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    # Evaluate in eval mode and restore the previous mode afterwards so that
    # calling this in the middle of training does not change training behaviour.
    was_training = model.training
    model.eval()
    total, correct = 0, 0
    try:
        # No gradients are needed for scoring; skipping them saves memory.
        with torch.no_grad():
            for inputs, labels in dataloader:
                if device is not None:
                    inputs, labels = inputs.to(device), labels.to(device)
                pred = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (pred == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        return 0.0
    return 100 * correct / total

In [None]:
%%time
loss_arr = []        # loss of every mini-batch, in training order
loss_epoch_arr = []  # mean per-sample loss of each epoch
max_epochs = 20

# Train on whatever device the model lives on rather than hard-coding "cuda".
train_device = next(net.parameters()).device

for epoch in range(max_epochs):
    net.train()
    running_loss, seen = 0.0, 0

    for inputs, labels in trainloader:
        inputs, labels = inputs.to(train_device), labels.to(train_device)

        opt.zero_grad()

        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()

        batch_loss = loss.item()
        loss_arr.append(batch_loss)
        # Weight by batch size so a smaller final batch does not skew the mean.
        running_loss += batch_loss * labels.size(0)
        seen += labels.size(0)

    # Record the epoch mean instead of only the last batch's (noisy) loss;
    # max() guards against an empty loader.
    loss_epoch_arr.append(running_loss / max(seen, 1))

    test_acc = evaluation(testloader)
    train_acc = evaluation(trainloader)
    # Epochs are reported 1-based so the last line reads "20/20".
    print(f'Epoch: {epoch + 1}/{max_epochs}, Test acc: {test_acc:.2f}, Train acc: {train_acc:.2f}')


plt.plot(loss_epoch_arr)
plt.xlabel('epoch')
plt.ylabel('mean training loss')
plt.show()

# analysis

In [None]:
# Final accuracies, test split first and then the training split.
final_scores = (evaluation(testloader), evaluation(trainloader))
print('Test acc: %0.2f, Train acc: %0.2f' % final_scores)

In [None]:
fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
n_to_show = 10

# Human-readable label names come from the dataset itself.
class_names = testloader.dataset.classes
# Run on whatever device the model lives on rather than hard-coding "cuda".
viz_device = next(net.parameters()).device

# Show one randomly chosen image from each of the first n_to_show test batches.
with torch.no_grad():
    for i, (inputs, labels) in zip(range(n_to_show), testloader):
        inputs = inputs.to(viz_device)
        pred = net(inputs).argmax(dim=1)

        # Pick within the actual batch size so a short batch cannot overflow.
        idx = randint(0, inputs.size(0) - 1)

        ax = fig.add_subplot(1, n_to_show, i + 1)
        ax.axis('off')
        ax.text(0.5, -0.35, 'pred = ' + str(class_names[pred[idx].item()]), fontsize=10, ha='center', transform=ax.transAxes)
        ax.text(0.5, -0.7, 'act = ' + str(class_names[labels[idx].item()]), fontsize=10, ha='center', transform=ax.transAxes)
        # imshow expects channels last, so move (C, H, W) to (H, W, C).
        ax.imshow(inputs[idx].permute(1, 2, 0).cpu())