# Weights & Biases Lab

**Notebook created by [JuanJo Nieto](https://www.linkedin.com/in/juan-jose-nieto-salas/) for Artificial Intelligence with Deep Learning Postgraduate course.**

**Update 1: 07/03/2021**

**Update 2: 25/10/2021 Laia Tarrés and JuanJo Nieto**


This lab follows a previous lab [1](https://colab.research.google.com/drive/182VXgrR08KIAWP8h-xKY6w8NX8oKV9bE?usp=sharing). 

*  In that lab, we trained an autoencoder and a classifier with plain PyTorch and used TensorBoard to log variables, images and graphs.

In this lab we are going to replace TensorBoard with a cloud-based alternative called Weights & Biases (wandb). Like TensorBoard, wandb offers many different logging options; in this lab we will explore some of the most useful ones. 

In [None]:
# Install the wandb client library (the leading "!" runs this line as a shell command; --quiet suppresses pip's output).
!pip install wandb --quiet

In [None]:
import copy
import os
import numpy as np

import itertools
import tensorflow
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

import datetime
from time import time

import wandb

import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST

To log in with the next cell, you need your own wandb account. If you don't have one yet, create it by following this [link](https://app.wandb.ai/login?signup=true).

In [None]:
# Authenticate this notebook with your wandb account so that runs can be logged to the cloud.
wandb.login()

## Make sure your runtime has a GPU


In [None]:
# Use the first CUDA device when one is visible; this lab is meant to run on a GPU runtime,
# so fail loudly instead of silently training on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
assert device.type != 'cpu', "Change Runtime Type -> GPU"

# Set a fixed seed

In [None]:
# Seed NumPy's global RNG and PyTorch's CPU and CUDA generators so runs are reproducible.
seed = 0
for _seed_fn in (np.random.seed,
                 torch.manual_seed,
                 torch.cuda.manual_seed,
                 torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn  # keep the notebook namespace the same as before

## Hyperparameters

In [None]:
# Counter used to name the wandb runs of this session (run_<num_run>_ae / run_<num_run>_cl).
num_run = 0

# Hyperparameters shared by the autoencoder and the classifier.
conf = dict(
    latent_dims=64,      # size of the encoder's output vector
    num_epochs=10,
    batch_size=128,
    capacity=64,         # channel count of the first conv layer (the second uses twice this)
    learning_rate=1e-3,
    subset_len=5120,     # number of MNIST samples kept in each data subset
)

## Create train and validation dataloaders

In [None]:
# Define image transformations
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Define training and validation sets
train_dataset = MNIST(root='./data/MNIST', download=True, transform=img_transform, train=True)
val_dataset = MNIST(root='./data/MNIST', download=True, transform=img_transform, train=False)

# We don't need the whole dataset, we'll pick a subset
train_dataset = Subset(train_dataset, list(range(conf["subset_len"])))
val_dataset = Subset(train_dataset, list(range(conf["subset_len"])))

# To iterate over batches of data we will use these dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=conf["batch_size"], shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=conf["batch_size"], shuffle=False)

## Define Encoder and Decoder networks

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder mapping 1x28x28 images to a latent vector.

    Two 4x4 stride-2 convolutions (padding 1) halve the spatial size twice
    (28 -> 14 -> 7), which is why the linear layer expects ``c*2*7*7`` inputs.

    Args:
        capacity: channel count of the first conv layer (the second uses
            ``2 * capacity``). Defaults to ``conf["capacity"]``.
        latent_dims: size of the output vector. Defaults to ``conf["latent_dims"]``.
    """
    def __init__(self, capacity=None, latent_dims=None):
        super(Encoder, self).__init__()
        # Fall back to the global hyperparameters only when no explicit value is given.
        c = conf["capacity"] if capacity is None else capacity
        latent_dims = conf["latent_dims"] if latent_dims is None else latent_dims
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c*2, kernel_size=4, stride=2, padding=1)
        self.linear = nn.Linear(in_features=c*2*7*7, out_features=latent_dims)

    def forward(self, x):
        """Encode a batch of shape (N, 1, 28, 28) into a tensor of shape (N, latent_dims)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # flatten (N, 2c, 7, 7) -> (N, 2c*7*7)
        x = self.linear(x)
        return x

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector back to a 1x28x28 image.

    The linear layer produces a (2c, 7, 7) feature map. Two 4x4 stride-2
    transposed convolutions (padding 1) upsample it 7 -> 14 -> 28, and a final
    tanh keeps pixel values in [-1, 1].

    Args:
        capacity: channel count ``c`` (the first feature map has ``2 * c``
            channels). Defaults to ``conf["capacity"]``.
        latent_dims: size of the input vector. Defaults to ``conf["latent_dims"]``.
    """
    def __init__(self, capacity=None, latent_dims=None):
        super(Decoder, self).__init__()
        # Fall back to the global hyperparameters only when no explicit value is given.
        self.c = conf["capacity"] if capacity is None else capacity
        latent_dims = conf["latent_dims"] if latent_dims is None else latent_dims
        self.fc = nn.Linear(in_features=latent_dims, out_features=self.c*2*7*7)
        self.conv2 = nn.ConvTranspose2d(in_channels=self.c*2, out_channels=self.c, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=self.c, out_channels=1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Decode a batch of shape (N, latent_dims) into images of shape (N, 1, 28, 28)."""
        x = self.fc(x)
        x = x.view(x.size(0), self.c*2, 7, 7)  # unflatten to a (2c, 7, 7) feature map
        x = F.relu(self.conv2(x))
        x = torch.tanh(self.conv1(x))
        return x

## Join them within the AutoEncoder class

In [None]:
class AutoEncoder(nn.Module):
    """Encoder followed by decoder; the output is the reconstruction of the input batch."""

    def __init__(self):
        super().__init__()
        # Registration order (encoder first, then decoder) determines parameter order.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [None]:
# Instantiate the autoencoder; layer sizes come from the global `conf` dict.
model = AutoEncoder()

In [None]:
# Display the module structure (a notebook shows the repr of a cell's last expression).
model

In [None]:
# Put the model in the GPU
model = model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of parameters: %d' % num_params)

# **Exercise 1.  Log weights and gradients histograms**
# **Exercise 2.  Log train and validation reconstruction loss**
# **Exercise 3.  Log reconstructed images from validation set**

In [None]:
def forward_image(image_batch):
    """Reconstruct a batch with the global autoencoder `model` and score it.

    The batch is moved to the global `device` first.

    Returns:
        A ``(loss, reconstruction)`` tuple, where ``loss`` is the MSE between
        the reconstruction and the (device-resident) input batch.
    """
    inputs = image_batch.to(device)
    reconstruction = model(inputs)
    loss = F.mse_loss(reconstruction, inputs)
    return loss, reconstruction

In [None]:
# Train the autoencoder for conf['num_epochs'] epochs inside a fresh wandb run.
# Each epoch first evaluates on val_dataloader and then trains on train_dataloader, so the
# validation loss logged for epoch k reflects the weights after k training epochs
# (epoch 0 measures the untrained model).
# Adam with a small L2 weight decay.
optimizer = torch.optim.Adam(params=model.parameters(), lr=conf['learning_rate'], weight_decay=1e-5)

# Initialize wandb run
wandb.finish() # This is needed just in case there was a wandb run from a previous execution
wandb.init(project="mnist_colab")
wandb.run.name = f'run_{num_run}_ae'


# TODO: Log weights and gradients to wandb
wandb... # use wandb.watch with log_freq=100.


ini = time()
for epoch in range(conf['num_epochs']):
    
    train_loss = []
    val_loss = []    

    # --- Validation pass: no gradient tracking, model in eval mode ---
    model.eval()
    for i,(image_batch, _) in enumerate(val_dataloader):
        with torch.no_grad():
            loss, recon = forward_image(image_batch)
            # Log images from the first validation batch only, i.e. once per epoch.
            if i == 0:

                # make_grid returns an image tensor from a batch of data (https://pytorch.org/vision/stable/utils.html#torchvision.utils.make_grid)
                # NOTE(review): recon comes from the decoder's tanh, so values lie in [-1, 1];
                # consider make_grid(recon, normalize=True) (or un-normalizing) so the images display correctly.
                grid = make_grid(recon)

                # TODO: Log a batch of reconstructed images from the validation set. Use the grid variable defined above.
                wandb... #use wandb.log and wandb.Image() to log the reconstructed images once per epoch

            val_loss.append(loss.item())
        
    # Mean of the per-batch MSE values.
    val_loss_avg = np.mean(val_loss)

    # TODO: Log validation reconstruction loss to wandb
    wandb... #use label 'Reconstruction/val_loss', and log the loss once per epoch


    # --- Training pass ---
    model.train()
    for image_batch, _ in train_dataloader:
        
        loss, _ = forward_image(image_batch)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss.append(loss.item())
    
    train_loss_avg = np.mean(train_loss)

    # TODO: Log train reconstruction loss to wandb
    wandb... #use label 'Reconstruction/train_loss', and log the loss once per epoch

    
    print(f"Epoch [{epoch} / {conf['num_epochs']}] average reconstruction error: {train_loss_avg}")

print(f"Training took {time()-ini} seconds")
# Close the run so the next cell can start a new one.
wandb.finish()

## Define Classifier module

In [None]:
class Classifier(nn.Module):
    """Digit classifier: an ``Encoder`` followed by a linear layer producing 10 class logits."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.linear = nn.Linear(conf['latent_dims'], 10)

    def forward(self, x):
        latent = self.encoder(x)
        return self.linear(latent)

## Copy the encoder weights from AutoEncoder to our Classifier

In [None]:
# Build a classifier on the target device, then replace its freshly initialised encoder
# with an independent (deep) copy of the trained autoencoder's encoder. The copy keeps
# `model` unchanged while the classifier is fine-tuned. `model` was already moved to
# `device`, so the copied encoder lives there too.
classifier = Classifier().to(device)
classifier.encoder = copy.deepcopy(model.encoder)

In [None]:
# The optimizer must be created after `classifier.encoder` has been replaced, so that it
# tracks the copied encoder's parameters rather than the discarded ones.
optimizer = torch.optim.Adam(classifier.parameters(), lr=conf['learning_rate'], weight_decay=1e-5)
# Applied to the classifier's raw logits and the integer digit labels.
criterion = nn.CrossEntropyLoss()

In [None]:
def compute_accuracy(preds, labels):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Args:
        preds: class scores of shape (N, num_classes); the predicted class is the
            arg-max over dim 1.
        labels: N integer class labels (reshaped to match the predictions).

    Returns:
        A Python float in [0, 1].
    """
    predicted = preds.argmax(dim=1)
    hits = predicted.eq(labels.view_as(predicted))
    return hits.sum().item() / predicted.size(0)

def log_confusion_matrix(preds, labels, num_classes=10):
    """Build a row-normalised confusion-matrix figure for a set of predictions.

    Args:
        preds: class scores of shape (N, num_classes); the predicted class is the
            arg-max over dim 1.
        labels: N integer ground-truth labels.
        num_classes: number of classes shown on each axis (10 for MNIST digits).

    Returns:
        The matplotlib figure. The caller owns it and should ``plt.close(fig)``
        after logging; otherwise pyplot keeps every figure alive.
    """
    predictions = preds.argmax(dim=1)
    class_ids = np.arange(num_classes)

    # Passing `labels=` fixes the matrix at num_classes x num_classes even when some
    # class never appears in `labels` or `predictions`; without it the matrix shrinks
    # and the 0..num_classes-1 tick labels no longer line up with the rows/columns.
    cm = confusion_matrix(labels.cpu().numpy(), predictions.cpu().numpy(), labels=class_ids)

    # Normalise each row (true class) to fractions. A class with no true samples has a
    # zero row sum; leave that row at 0 instead of producing NaN from 0/0.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm = np.divide(cm.astype('float'), row_sums,
                   out=np.zeros(cm.shape, dtype='float'), where=row_sums != 0)
    cm = np.around(cm, decimals=2)

    fig = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='none', cmap=plt.cm.Blues)
    plt.colorbar()

    plt.xticks(class_ids, class_ids)
    plt.yticks(class_ids, class_ids)

    # White text on dark cells, black text on light cells.
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(num_classes), range(num_classes)):
        color = "white" if cm[i, j] > threshold else "black"
        plt.text(j, i, cm[i, j], horizontalalignment="center", color=color)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.title("Confusion matrix")
    # Lay out only after all labels and the title exist, so none of them are clipped.
    plt.tight_layout()
    return fig

# **Exercise 4. Log confusion matrix figures**
# **Exercise 5. Log train and validation loss and accuracy**

In [None]:
# Initialize wandb run
wandb.finish() # This is needed just in case there was a wandb run from a previous execution
wandb.init(project="mnist_colab")
wandb.run.name = f'run_{num_run}_cl'

for epoch in range(conf['num_epochs']):

    classifier.eval()
    val_loss_epoch = []
    val_acc_epoch = []
    with torch.no_grad():
        for image_batch, label_batch in val_dataloader:

            img, label = image_batch.to(device), label_batch.to(device)
            predictions = classifier(img)
            loss = criterion(predictions, label.long())
            
            val_loss_epoch.append(loss.item())
            val_acc_epoch.append(compute_accuracy(predictions, label))

        fig = log_confusion_matrix(predictions, label)
        mean_loss = np.mean(val_loss_epoch)
        mean_acc = np.mean(val_acc_epoch)

        # TODO: Log confusion matrix figure to wandb
        wandb... #use wandb-log and fig directly, and log the confusion matrix once for every epoch.

        # TODO: Log validation loss to wandb
        wandb... #use the tag 'Classification/val_loss'

        # TODO: Log validation accuracy to wandb
        wandb... #use the tag 'Classification/val_acc'
        

    classifier.train()
    loss_epoch = []
    acc_epoch = []
    for image_batch, label_batch in train_dataloader:
        
        img, label = image_batch.to(device), label_batch.to(device)
        predictions = classifier(img)
        loss = criterion(predictions, label.long())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_epoch.append(loss.item())     
        acc_epoch.append(compute_accuracy(predictions, label))

        
    mean_loss = np.mean(loss_epoch)
    mean_acc = np.mean(acc_epoch)


    # TODO: Log training loss to wandb
        wandb... #use the tag 'Classification/train_loss'

    # TODO: Log train accuracy to wandb
    wandb... #use the tag 'Classification/train_acc'
   
    
    print(f"Train Epoch: {epoch} Loss: {mean_loss} Acc: {mean_acc}")

wandb.finish()



# **EXTRA**

If you have time left, try experimenting with different values for the embedding (latent) dimension, and create a report that compares the performance of the resulting models. 

*Suggestion*: try replicating [this](https://wandb.ai/juanjo3ns/mnist_colab/reports/MNIST_COLAB--Vmlldzo1MDIxOTE?accessToken=tyl6j6yot3es3s1iisam3gzjhaoyoxyml2ini4blvbmyzcny2lb0v32t9xs8rfu6) report