# Wandb Lab

**Notebook created by [JuanJo Nieto](https://www.linkedin.com/in/juan-jose-nieto-salas/) for Artificial Intelligence with Deep Learning Postgraduate course.**

**Last update: 07/03/2021**



This lab follows previous labs [1](https://colab.research.google.com/drive/182VXgrR08KIAWP8h-xKY6w8NX8oKV9bE?usp=sharing) and [2](https://colab.research.google.com/drive/1Riz1h9-gk01Jl80R1scvutYCfC85bnJy?usp=sharing). 

*  In the former, we trained an autoencoder and a classifier with plain PyTorch and used TensorBoard to log all the different variables, images and graphs.
*  In the latter, we reorganized the code with PyTorch Lightning to make it more readable, but we still used TensorBoard for logging.

In this lab we are going to replace TensorBoard with a cloud-based alternative called Weights & Biases (wandb).


In [None]:
!pip install wandb pytorch-lightning==1.0.8 --quiet

In [None]:
import copy
import os
import numpy as np

import itertools
import tensorflow
import torch
import tensorboard
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

from torchvision.utils import make_grid

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

import pytorch_lightning as pl


When you initialize wandb you’ll be asked for a token. You can get it [here](https://wandb.ai/authorize).

In [None]:
import wandb
wandb.login()

With PyTorch Lightning we can use custom loggers very easily:

In [None]:
from pytorch_lightning.loggers import WandbLogger

In [None]:
# Avoid MNIST download crashing
# Install a process-wide opener that sends a browser-like User-Agent header
# (presumably the download host rejects urllib's default one).
# The standard library is used directly; the `six` compatibility shim is not
# needed on Python 3.
import urllib.request

opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

## Hyperparameters

In [None]:
# Run counter.  NOTE(review): not referenced by any other cell visible here.
num_run = 0

# Hyperparameters shared by the models, optimizers, dataloaders and trainers below.
conf = dict(
    latent_dims=64,      # size of the encoder output / decoder input vector
    num_epochs=10,       # passed to the Trainer as max_epochs
    batch_size=128,      # batch size of the train/val DataLoaders
    capacity=64,         # base number of convolution channels ("c")
    learning_rate=1e-3,  # Adam learning rate
)

In [None]:
# TODO: Initialize WandbLogger
wandb_logger = ...

# TODO: Log your configuration for quick comparison with other runs.
wandb_logger...

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder: two stride-2 convolutions followed by a linear layer.

    The channel width (``capacity``) and the size of the output vector
    (``latent_dims``) are read from the module-level ``conf`` when the encoder
    is constructed.  The linear layer's input size (``2c * 7 * 7``) assumes
    1 x 28 x 28 input images.
    """

    def __init__(self):
        super().__init__()
        capacity = conf['capacity']
        # 1 x 28 x 28 -> c x 14 x 14
        self.conv1 = nn.Conv2d(1, capacity, kernel_size=4, stride=2, padding=1)
        # c x 14 x 14 -> 2c x 7 x 7
        self.conv2 = nn.Conv2d(capacity, capacity * 2, kernel_size=4, stride=2, padding=1)
        # flattened 2c*7*7 feature vector -> latent vector
        self.linear = nn.Linear(capacity * 2 * 7 * 7, conf['latent_dims'])

    def forward(self, x):
        """Map a batch of images to a batch of latent vectors."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        # Collapse every non-batch dimension into one feature axis.
        flat = features.view(features.size(0), -1)
        return self.linear(flat)

class Decoder(nn.Module):
    """Mirror of the encoder: a linear layer followed by two transposed convolutions.

    Layer sizes are read from the module-level ``conf`` (``capacity`` and
    ``latent_dims``) when the decoder is constructed.  ``forward`` maps a batch
    of latent vectors to a batch of 1 x 28 x 28 images with values in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        c = conf['capacity']
        # latent vector -> flattened 2c x 7 x 7 feature maps
        self.fc = nn.Linear(in_features=conf['latent_dims'], out_features=c*2*7*7)
        # 2c x 7 x 7 -> c x 14 x 14
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)
        # c x 14 x 14 -> 1 x 28 x 28
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Decode a batch of latent vectors into a batch of images."""
        x = self.fc(x)
        # Unflatten the feature vectors into multi-channel 7 x 7 feature maps.
        # The channel count comes from the layer that consumes them, so the
        # reshape always agrees with the layers built in __init__ even if the
        # global conf is modified afterwards.
        x = x.view(x.size(0), self.conv2.in_channels, 7, 7)
        x = F.relu(self.conv2(x))
        # tanh bounds the output to [-1, 1].
        x = torch.tanh(self.conv1(x))
        return x

In [None]:
class AutoEncoder(pl.LightningModule):
    """LightningModule that trains an ``Encoder``/``Decoder`` pair to reconstruct MNIST digits.

    The reconstruction loss (MSE) is logged once per epoch for training and
    validation.  For the first validation batch of every validation run, a grid
    of the reconstructed images is also sent to the logger's experiment object.
    """

    def __init__(self):
        super().__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()

        # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) rescales them to [-1, 1].
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.5,), (0.5,))
        self.image_transform = transforms.Compose([to_tensor, normalize])

        self.criterion = F.mse_loss

        # Random batch with the input shape the model expects
        # (presumably consumed by Lightning for its model summary — confirm).
        self.example_input_array = torch.rand(128, 1, 28, 28)

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        return self.decoder(self.encoder(x))

    def training_step(self, batch, batch_idx):
        """Return the reconstruction loss for one training batch (labels are ignored)."""
        images, _ = batch
        loss = self.criterion(self(images), images)
        self.log('Reconstruction/train_loss', loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the reconstruction loss for one validation batch; log images for batch 0."""
        images, _ = batch
        with torch.no_grad():
            reconstruction = self(images)
            val_loss = self.criterion(reconstruction, images)
            self.log('Reconstruction/val_loss', val_loss, on_step=False, on_epoch=True)

            if batch_idx == 0:
                # make_grid tiles the batch into one C x H x W image; move the
                # channel axis last before converting to a numpy array.
                tiled = make_grid(reconstruction).permute(1, 2, 0)
                picture = wandb.Image(tiled.detach().cpu().numpy())
                self.logger.experiment.log({"Images": [picture]})
        return val_loss

    def configure_optimizers(self):
        """Adam over all parameters, with the learning rate taken from ``conf``."""
        return torch.optim.Adam(self.parameters(), lr=conf['learning_rate'], weight_decay=1e-5)

    def train_dataloader(self):
        """Shuffled loader over the MNIST training split (downloaded if missing)."""
        dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=self.image_transform)
        return DataLoader(dataset, batch_size=conf['batch_size'], shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the MNIST test split (downloaded if missing)."""
        dataset = MNIST(root='./data/MNIST', download=True, train=False, transform=self.image_transform)
        return DataLoader(dataset, batch_size=conf['batch_size'], shuffle=False)

    def log_model(self, model, file_path):
        """Upload the file at ``file_path`` to the active wandb run as a 'model' artifact.

        ``model`` is the name given to the artifact.  This stores the checkpoint
        file in the cloud; it does not display the model graph.
        """
        artifact = wandb.Artifact(model, type='model')
        artifact.add_file(file_path)
        wandb.run.log_artifact(artifact)

In [None]:
# Build the autoencoder; its Encoder/Decoder read their layer sizes from `conf`.
ae = AutoEncoder()

In [None]:
# TODO: Watch model's updates every 100 steps. (Will log weights and gradients distributions)
wandb_logger...

In [None]:
# Train the autoencoder on a 10% subset of the train/val batches per epoch.
# A GPU is requested only when one is visible, so the cell also runs on a
# CPU-only runtime (gpus=None is the Trainer's CPU default).
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else None,
    max_epochs=conf['num_epochs'],
    progress_bar_refresh_rate=20,
    limit_train_batches=0.1,
    limit_val_batches=0.1,
    weights_summary='full',
    logger=wandb_logger,
    default_root_dir=f'/content/{wandb.run.name}/'
    )

#trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(ae)


In [None]:
# Locate a checkpoint file under the Trainer's default_root_dir.
# NOTE(review): 'mnist_colab' presumably has to match the project name given
# to WandbLogger — confirm.
path = f'/content/{wandb.run.name}/mnist_colab/'
# Candidate run/version directories: entries without a dot in their name.
version = [entry for entry in os.listdir(path) if '.' not in entry]
checkpoint_dir = os.path.join(path, version[0], 'checkpoints')
file = os.listdir(checkpoint_dir)
# Full path of the first checkpoint listed in the first version directory.
model_path = os.path.join(checkpoint_dir, file[0])

In [None]:
# TODO: Store model checkpoint making use of log_model function defined in the class above


In [None]:
# Tell the logger that this run ended with status 'success'.
wandb_logger.finalize(status='success')

# Classifier

In [None]:
# TODO: Initialize WandbLogger
wandb_logger = ...

# TODO: Log your configuration for quick comparison with other runs.
wandb_logger...

In [None]:
class Classifier(pl.LightningModule):
    def __init__(self, encoder):

        super().__init__()


        self.encoder = encoder
        self.linear = nn.Sequential(
              nn.Linear(conf['latent_dims'], 10),
              nn.ReLU())

        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        self.criterion = nn.CrossEntropyLoss()

        self.example_input_array = torch.rand(128, 1, 28, 28)


    def forward(self, x):
        x = self.encoder(x)
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        img, label = batch
        prediction = self(img)
        loss = self.criterion(prediction, label.long())
        self.log('Classification/train_loss', loss, on_step=False, on_epoch=True)
        _ = self.log_accuracy(prediction, label, 'train')
        return loss


    def validation_step(self, batch, batch_idx):
        img, label = batch
        prediction = self(img)
        
        loss = self.criterion(prediction, label.long())
        self.log('Classification/val_loss', loss, on_step=False, on_epoch=True)
        pred_labels = self.log_accuracy(prediction, label, 'val')

        return {'predictions': pred_labels, 'labels': label}



    def validation_epoch_end(self, outputs):
        predictions = outputs[0]['predictions']
        labels = outputs[0]['labels']
        fig = self.log_confusion_matrix(predictions, labels)

        # TODO: Log figure to Wandb
        self.logger.experiment...

        plt.close()

    def log_accuracy(self, preds, labls, type):
        pred_labels = preds.argmax(dim=1, keepdim=True)
        acc = pred_labels.eq(labls.view_as(pred_labels)).sum().item()/len(pred_labels)
        self.log(f"Classification/{type}_acc", acc, on_step=False, on_epoch=True)
        return pred_labels

    
    def log_confusion_matrix(self, predictions, labels):
        
        cm = confusion_matrix(labels.cpu(), predictions.cpu())
        cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)

        fig = plt.figure(figsize=(8,8))
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)

        plt.colorbar()
        tick_marks = np.arange(10)

        plt.xticks(tick_marks, np.arange(0,10))
        plt.yticks(tick_marks, np.arange(0,10))

        plt.tight_layout()
        threshold = cm.max() / 2.

        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            color = "white" if cm[i, j] > threshold else "black"
            plt.text(j, i, cm[i, j], horizontalalignment="center", color=color)

        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.title("Confusion matrix")
        return fig



    def configure_optimizers(self):
        return torch.optim.Adam(params=self.parameters(), lr=conf['learning_rate'], weight_decay=1e-5)

    def train_dataloader(self):
        train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=self.image_transform)
        train_dataloader = DataLoader(train_dataset, batch_size=conf['batch_size'], shuffle=True)
        return train_dataloader

    def val_dataloader(self):
        test_dataset = MNIST(root='./data/MNIST', download=True, train=False, transform=self.image_transform)
        test_dataloader = DataLoader(test_dataset, batch_size=conf['batch_size'], shuffle=False)
        return test_dataloader


In [None]:
# The classifier receives the autoencoder's encoder module itself (the same
# object, not a copy), so training the classifier also updates `ae.encoder`.
classifier = Classifier(ae.encoder)

In [None]:
# Train the classifier on 30% of the training batches and 20% of the
# validation batches per epoch.
# A GPU is requested only when one is visible, so the cell also runs on a
# CPU-only runtime (gpus=None is the Trainer's CPU default).
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else None,
    max_epochs=conf['num_epochs'],
    progress_bar_refresh_rate=20,
    limit_train_batches=0.3,
    limit_val_batches=0.2,
    weights_summary='full',
    logger=wandb_logger,
    default_root_dir=f'/content/{wandb.run.name}/'
    )

#trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(classifier)

In [None]:
# Locate a checkpoint file under the Trainer's default_root_dir.
# NOTE(review): 'mnist_colab' presumably has to match the project name given
# to WandbLogger — confirm.
path = f'/content/{wandb.run.name}/mnist_colab/'
# Candidate run/version directories: entries without a dot in their name.
version = [entry for entry in os.listdir(path) if '.' not in entry]
checkpoint_dir = os.path.join(path, version[0], 'checkpoints')
file = os.listdir(checkpoint_dir)
# Full path of the first checkpoint listed in the first version directory.
model_path = os.path.join(checkpoint_dir, file[0])

In [None]:
# TODO: Store model checkpoint making use of log_model function


In [None]:
# Tell the logger that this run ended with status 'success'.
wandb_logger.finalize(status='success')