<a href="https://colab.research.google.com/github/ssundar6087/Deep-Learning-Mini-Course/blob/main/Pytorch/DL_Minicourse_Pytorch_Day_7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>

# It's the final countdown!
Today, we'll wrap up the course by implementing checkpointing and early stopping, and by trying out some transfer learning. Congratulations on making it this far! 🙌

# Image Classification with PyTorch

In [None]:
# Imports
import torch 
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

## Model Definition
Replace the definition below with your best conv net from the previous exercise

In [None]:
class BabyThanos(nn.Module):
  """Small CNN: two conv + ReLU + max-pool stages followed by three linear layers.

  Args:
    in_dims: Side length of the square input images (height == width).
    in_channels: Number of channels in the input images.
    n_classes: Number of logits produced by the last linear layer.

  Raises:
    ValueError: If `in_dims` is too small to survive the conv/pool stages.
  """

  def __init__(self, in_dims, in_channels, n_classes=10):
    super().__init__()

    self.in_dims = in_dims
    self.in_channels = in_channels
    self.n_classes = n_classes
    self.k = 5            # conv kernel size (shared by both conv layers)
    self.n_filters1 = 6   # output channels of conv1
    self.n_filters2 = 16  # output channels of conv2
    self.fc1_dim = 100
    self.fc2_dim = 40
    self.pool_size = 2
    self.pool_stride = 2
    # Side length of the feature map that reaches fc1 (after 2 conv+pool stages).
    self.final_dim = self.__compute_flattened_dim__(num_conv_pools=2)

    # define the layers here
    self.conv1 = nn.Conv2d(self.in_channels, self.n_filters1, self.k)
    self.pool = nn.MaxPool2d(self.pool_size, self.pool_stride)
    self.conv2 = nn.Conv2d(self.n_filters1, self.n_filters2, self.k)
    self.fc1 = nn.Linear(self.n_filters2 * self.final_dim * self.final_dim, self.fc1_dim)
    self.fc2 = nn.Linear(self.fc1_dim, self.fc2_dim)
    self.fc3 = nn.Linear(self.fc2_dim, self.n_classes)

  def __compute_flattened_dim__(self, num_conv_pools):
    """Return the feature-map side length after `num_conv_pools` conv+pool stages.

    Each stage is an unpadded, stride-1 convolution with kernel `self.k`
    followed by max pooling with window `self.pool_size` and stride
    `self.pool_stride`.

    Raises:
      ValueError: If the feature map shrinks below 1 pixel at any stage.
        Without this check a negative size would be squared into a positive
        `fc1` input width and the error would only surface in `forward`.
    """
    final_dim = self.in_dims
    for stage in range(num_conv_pools):
      # Convolution: padding 0, stride 1.
      final_dim = final_dim - self.k + 1
      if final_dim < 1:
        raise ValueError(
            f"in_dims={self.in_dims} is too small: conv stage {stage + 1} "
            f"(kernel {self.k}) leaves no output pixels")
      # Max pooling: floor((d - window) / stride) + 1.
      final_dim = (final_dim - self.pool_size) // self.pool_stride + 1
      if final_dim < 1:
        raise ValueError(
            f"in_dims={self.in_dims} is too small: pool stage {stage + 1} "
            f"(window {self.pool_size}) leaves no output pixels")
    return final_dim

  def forward(self, x):
    """Map a batch of images to class logits of shape (batch, n_classes)."""
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    return self.fc3(x)
    

In [None]:
# Input geometry and label count handed to the network constructor below.
IN_DIMS = 32  # side length passed as `in_dims`
IN_CHANNELS = 3  # channel count passed as `in_channels`
N_CLASSES = 10  # number of output logits
net = BabyThanos(in_dims=IN_DIMS, in_channels=IN_CHANNELS, n_classes=N_CLASSES)

In [None]:
# Show the layers registered on the model.
print(net)

## Hyperparameters 
**Note:** Use the best values from the previous exercise

In [None]:
BATCH_SIZE = 64  # batch size used by both the train and test DataLoaders
EPOCHS = 20  # number of iterations of the training loop
LEARNING_RATE = 1e-3  # learning rate passed to the Adam optimizer

In [None]:
# Convert images to tensors, then normalize each of the three channels
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split: downloaded to ./data if needed, reshuffled every epoch.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Held-out split: same preprocessing, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

# Human-readable class names; also used for the size of the transfer-learning head.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

## Define Optimizer & Loss Function 
These two functions allow us to help baby thanos learn from his mistakes.

Use the best choices from the previous exercise

In [None]:
criterion = nn.CrossEntropyLoss() # Loss function comparing the network outputs with the labels
# NOTE: the optimizer holds references to this `net`'s parameters; rebuild it if `net` is replaced.
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE) # Optimizer that updates net's parameters

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Train and Evaluate the Network
The training and validation loops call the same set of functions over and over, so we'll package them into separate functions. Note that the validation loop does not have any optimizer calls. 

In [None]:
from tqdm.notebook import tqdm
def train_step(model, train_loader, optimizer, criterion, device=None):
  """Run one training epoch and return `(mean batch loss, accuracy)`.

  Args:
    model: Network to train; it is put into training mode.
    train_loader: Sized iterable yielding `(images, labels)` batches.
    optimizer: Optimizer; `zero_grad()` and `step()` are called once per batch.
    criterion: Callable `criterion(predictions, labels)` returning a scalar loss.
    device: Device the batches are moved to. When omitted, the device of the
      model's first parameter is used (CPU if the model has no parameters),
      so the batches always land where the model lives.

  Returns:
    Tuple `(loss, accuracy)`: the mean of the per-batch losses and the
    fraction of samples whose arg-max prediction equals the label.

  Raises:
    ValueError: If `train_loader` yields no samples.
  """
  if device is None:
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

  model.train()
  epoch_loss = []
  total, correct = 0, 0

  for images, labels in tqdm(train_loader,
                             total=len(train_loader),
                             leave=False,
                             ):
    images = images.to(device)
    labels = labels.to(device)

    optimizer.zero_grad() # Erase history - clean slate

    predictions = model(images) # forward pass: guesses for every label
    loss = criterion(predictions, labels) # how did I do?
    loss.backward() # backward pass
    optimizer.step() # Update the weights using gradients

    # Bookkeeping only: detach so the metrics never touch the autograd graph.
    epoch_loss.append(loss.item())
    predicted = predictions.detach().argmax(dim=1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item() # Accuracy score

  if total == 0:
    raise ValueError("train_loader yielded no samples")

  return np.mean(epoch_loss), correct / total


In [None]:
def valid_step(model, val_loader, criterion, device=None):
  """Evaluate the model for one pass and return `(mean batch loss, accuracy)`.

  No optimizer is involved and gradients are disabled for the whole pass.

  Args:
    model: Network to evaluate; it is put into evaluation mode.
    val_loader: Sized iterable yielding `(images, labels)` batches.
    criterion: Callable `criterion(predictions, labels)` returning a scalar loss.
    device: Device the batches are moved to. When omitted, the device of the
      model's first parameter is used (CPU if the model has no parameters),
      so the batches always land where the model lives.

  Returns:
    Tuple `(loss, accuracy)`: the mean of the per-batch losses and the
    fraction of samples whose arg-max prediction equals the label.

  Raises:
    ValueError: If `val_loader` yields no samples.
  """
  if device is None:
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

  model.eval()
  epoch_loss = []
  total, correct = 0, 0

  with torch.no_grad():
    for images, labels in tqdm(val_loader,
                               total=len(val_loader),
                               leave=False,
                               ):
      images = images.to(device)
      labels = labels.to(device)
      # Note that there's no optimizer here
      predictions = model(images)
      loss = criterion(predictions, labels)
      epoch_loss.append(loss.item())
      predicted = predictions.argmax(dim=1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  if total == 0:
    raise ValueError("val_loader yielded no samples")

  return np.mean(epoch_loss), correct / total

 ## **YOUR EXERCISE HERE: Checkpointing and Early Stopping 👇** 

 Implement early stopping and checkpointing below. If you are stuck, check out these hints

 **HINT:** https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py and https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html

In [None]:
# Move the model onto the selected device before any batches are fed to it.
net = net.to(device)
# Per-epoch metric histories; the plotting cells read these.
losses = {"train_loss": [], "val_loss": []}
accuracies = {"train_acc": [], "val_acc": []}
epochs = []
for epoch in tqdm(range(EPOCHS), total=EPOCHS):
  # One full pass over the training data (updates the weights).
  train_loss, train_acc = train_step(net, 
                                     trainloader, 
                                     optimizer, 
                                     criterion,)
  
  # One evaluation pass; note that `testloader` doubles as the validation set here.
  val_loss, val_acc = valid_step(net, 
                                 testloader, 
                                 criterion,
                                 )
  
  #TODO: YOUR CODE HERE -
  # 1. Checkpoint the model 
  #    (e.g. save the model/optimizer state when val_loss improves)

  # 2. Check for Early Stopping
  #    (e.g. break out of the loop once val_loss has stopped improving)
  
  # END OF YOUR CODE
  
  # Record this epoch's metrics for the plots.
  losses["train_loss"].append(train_loss)
  losses["val_loss"].append(val_loss)
  accuracies["train_acc"].append(train_acc)
  accuracies["val_acc"].append(val_acc)
  epochs.append(epoch)  # stored 0-based; the log line below prints it 1-based

  print(f'[{epoch + 1}] train loss: {train_loss}  train accuracy: {train_acc}  val loss: {val_loss}  val accuracy: {val_acc}')

## Plot the Loss and Accuracy of our Model

In [None]:
# Train vs. validation loss, one point per recorded epoch.
plt.figure(figsize=(8, 8))
plt.plot(epochs, losses["train_loss"], label="train")
plt.plot(epochs, losses["val_loss"], label="val")
plt.xlabel("epochs")
plt.ylabel("loss")
# grid() takes a boolean; the string "on" only worked because it is truthy.
plt.grid(True)
plt.legend()
plt.title("Loss vs Epochs")

In [None]:
# Train vs. validation accuracy, one point per recorded epoch.
plt.figure(figsize=(8, 8))
plt.plot(epochs, accuracies["train_acc"], label="train")
plt.plot(epochs, accuracies["val_acc"], label="val")
plt.xlabel("epochs")
plt.ylabel("accuracy")
# grid() takes a boolean; the string "on" only worked because it is truthy.
plt.grid(True)
plt.legend()
plt.title("Accuracy vs Epochs")

## **YOUR EXERCISE HERE: Transfer Learning 👇** 

Repeat the same exercise, but first swap out the model definition for the one below.

Try out other models from here https://pytorch.org/vision/stable/models.html

In [None]:
class ThanosWithInfinityStones(nn.Module):
  """Transfer-learning model: a torchvision ResNet-18 plus a small classifier head.

  The backbone's final `fc` layer is swapped for a linear projection to
  `fc1_dim` features; a ReLU and a second linear layer then produce one
  logit per entry of the module-level `classes` tuple.
  """

  def __init__(self):
    super().__init__()
    self.fc1_dim = 128

    # Load ResNet-18 with its pretrained weights.
    resnet = torchvision.models.resnet18(pretrained=True)
    # Width of the features feeding the original fc layer
    # (tip: print(torchvision.models.resnet18()) to inspect the architecture).
    backbone_features = resnet.fc.in_features
    # Replace the stock classifier with our own projection layer.
    resnet.fc = nn.Linear(backbone_features, self.fc1_dim)
    self.backbone = resnet

    self.clf_head = nn.Linear(self.fc1_dim, len(classes))

  def forward(self, x):
    """Return class logits of shape (batch_size, num_classes) for images `x`."""
    features = self.backbone(x)
    activated = F.relu(features)
    return self.clf_head(activated)

In [None]:
# Replace the model with the transfer-learning network.
# NOTE: the optimizer created earlier still references the previous net's
# parameters, so re-run the optimizer cell before training this model.
net = ThanosWithInfinityStones()