<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

# View & analyze model predictions during training

This notebook shows how to track, visualize, and compare model predictions over the course of training, using PyTorch on the MNIST dataset.

With W&B's newest feature for [Dataset and Prediction Visualization (in Beta)](https://docs.wandb.com/datasets-and-predictions), you can
* log metrics, images, text, etc. to a wandb.Table() during model training or evaluation
* view, sort, filter, group, join, interactively query, and otherwise explore these tables
* compare model predictions or results over time: dynamically across different epochs or validation steps

## Compare predicted scores for specific images
[Live example: compare predictions after 1 vs 5 epochs of training →](https://wandb.ai/stacey/mnist-viz/artifacts/predictions/test_samples_222yogf6/620ddb9966343696912d/files/predictions.table.json#e6232e8a664e34740d7f)
<img src="https://i.imgur.com/EwXO0HY.png" alt="MNIST test predictions"/>
The histograms compare per-class scores between the two models. The top bar in each histogram represents model v0, which only trained for 1 epoch. The bottom bar represents model v4, which trained for 5 epochs. For example, the 1 in the middle row has much higher confidence scores for classes 0, 2, 3, and 4 (all incorrect labels) after 1 epoch than after 5 epochs of training.
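The per-class scores in these histograms are softmax probabilities computed from the model's raw outputs (logits), which is what `log_test_predictions` does with `F.softmax` in the training cell. A minimal pure-Python sketch of that conversion follows; the logits are made up for illustration and do not come from the live example:

```python
import math

def softmax(logits):
    """Convert a list of raw class scores (logits) into probabilities that sum to 1."""
    # subtract the max logit for numerical stability; this does not change the result
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# hypothetical logits for one image over the 10 MNIST classes
logits = [0.5, 4.0, 1.2, 0.3, 0.8, -1.0, 0.0, 0.2, 0.1, -0.5]
scores = softmax(logits)

# the model's guess is the class with the highest score
guess = max(range(len(scores)), key=lambda c: scores[c])
print(guess)  # 1
```

The guess is the same whether you take the argmax of the logits or of the softmax scores; the scores only make the confidence of each class comparable on a 0 to 1 scale.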
<img src="https://i.imgur.com/VoGXdsj.png" alt="middle row"/>

## Focus on top errors over time
See incorrect predictions (filter to rows where "guess" != "truth") across the full test data.
<img src="https://i.imgur.com/8KulMp0.png" alt="MNIST errors compare"/>
Note that there are 213 wrong guesses after 1 training epoch, but only 84 after 5 epochs.
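The `"guess" != "truth"` filter is applied interactively in the W&B Table UI, but the same query can be sketched in plain Python over rows shaped like the logged table. The row values below are invented for illustration:

```python
# Hypothetical rows mimicking the table's "id", "guess", and "truth" columns
rows = [
    {"id": "0_0", "guess": 7, "truth": 7},
    {"id": "1_0", "guess": 2, "truth": 2},
    {"id": "2_0", "guess": 3, "truth": 5},
    {"id": "3_0", "guess": 0, "truth": 0},
    {"id": "4_0", "guess": 9, "truth": 4},
]

# Keep only incorrect predictions, the same condition as guess != truth in the UI
errors = [r for r in rows if r["guess"] != r["truth"]]
print(len(errors))                # 2
print([r["id"] for r in errors])  # ['2_0', '4_0']
```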

## Compare model performance and find patterns
Filter out correct answers, then group by the guess to see examples of misclassified images and the underlying distribution of true labels for two models side by side. A baseline model is on the left, and a variant with double the layer sizes is on the right.
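Grouping the errors by guess amounts to counting the true labels within each predicted class. A minimal pure-Python sketch, assuming the incorrect rows have already been reduced to hypothetical `(guess, truth)` pairs:

```python
from collections import Counter, defaultdict

# Hypothetical incorrect predictions as (guess, truth) pairs
errors = [(3, 5), (3, 8), (9, 4), (3, 5), (9, 7), (7, 2)]

# Group by the model's guess and count the true labels within each group
by_guess = defaultdict(Counter)
for guess, truth in errors:
    by_guess[guess][truth] += 1

for guess in sorted(by_guess):
    print(guess, dict(by_guess[guess]))
# 3 {5: 2, 8: 1}
# 7 {2: 1}
# 9 {4: 1, 7: 1}
```

Read this way, a guess of 3 that is most often a true 5 points to a specific confusion pair worth inspecting in the image column.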
<img src="https://i.imgur.com/yOSAiGh.png" alt="MNIST grouped"/>

## Sign up or login

[Sign up or login](https://wandb.ai/login) to W&B to see and interact with your experiments in the browser.

In this example we're using Google Colab as a convenient hosted environment, but you can run your own training scripts from anywhere and visualize metrics with W&B's experiment tracking tool.

<a href="https://colab.research.google.com/github/wandb/examples/blob/master/colabs/datasets-predictions/Visualize_Model_Predictions_over_Time.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install wandb -qqq
WANDB_PROJECT = "mnist-viz"
import wandb
wandb.login()

# 0. Setup

Install dependencies, download MNIST, and create train and test datasets using PyTorch.

In [None]:
%%capture
import torch
import torch.nn as nn
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
import torch.nn.functional as F

# workaround to fetch MNIST data
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

# create train and test sets
transform = transforms.Compose([transforms.ToTensor()])
train_set = MNIST('./', train=True, download=True, transform=transform)
test_set = MNIST('./', train=False, download=True, transform=transform)

# 1. Define the model and training schedule

* Set the number of epochs to run, where each epoch consists of a training step and a test step. Optionally configure how much test data to log per test step. As configured here, up to 78 batches of 128 images are logged, which covers 9,984 of the 10,000 test images; lower these values for a faster, lighter demo.
* Define a simple convolutional neural net (following [pytorch-tutorial](https://github.com/yunjey/pytorch-tutorial) code).
* Create train and test data loaders using PyTorch.
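With the values set in the next cell, each test step logs nearly the whole 10,000-image MNIST test set. The arithmetic can be checked with a small pure-Python sketch; `rows_logged` is a hypothetical helper written for this illustration (not part of W&B or PyTorch) that mirrors the two caps applied during logging:

```python
import math

TEST_SET_SIZE = 10_000  # MNIST test split
BATCH_SIZE = 128
NUM_BATCHES_TO_LOG = 78
LOG_IMAGES_PER_BATCH = 128

# The test loader yields ceil(10000 / 128) batches; the last one is partial
num_test_batches = math.ceil(TEST_SET_SIZE / BATCH_SIZE)

def rows_logged(num_batches, per_batch, batch_size=BATCH_SIZE, n=TEST_SET_SIZE):
    """Count table rows written per test step (hypothetical helper for illustration)."""
    total = 0
    for b in range(min(num_batches, math.ceil(n / batch_size))):
        actual = min(batch_size, n - b * batch_size)  # number of images in batch b
        total += min(per_batch, actual)
    return total

print(num_test_batches)                                       # 79
print(rows_logged(NUM_BATCHES_TO_LOG, LOG_IMAGES_PER_BATCH))  # 9984
print(rows_logged(10, 16))                                    # 160
```

Since every logged row carries an image and ten scores, smaller settings such as 10 batches of 16 images keep each table, and each artifact upload, much lighter.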


In [None]:
# Number of epochs to run
# Each epoch includes a training step and a test step, so this sets
# the number of tables of test predictions to log
EPOCHS = 5

# Number of batches to log from the test data for each test step
# (78 batches x 128 images covers 9,984 of the 10,000 test images;
# lower this for a lighter demo)
NUM_BATCHES_TO_LOG = 78

# Number of images to log per test batch
# (at most BATCH_SIZE; lower this for a lighter demo)
LOG_IMAGES_PER_BATCH = 128

# training configuration and hyperparameters
NUM_CLASSES = 10
BATCH_SIZE = 128
LEARNING_RATE = 0.001
L1_SIZE = 32
L2_SIZE = 64
# changing this may require changing the shape of adjacent layers
CONV_KERNEL_SIZE = 5

# define a two-layer convolutional neural network
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # 28x28 input -> conv (same size with kernel 5, padding 2) -> 2x2 max-pool -> 14x14
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, L1_SIZE, CONV_KERNEL_SIZE, stride=1, padding=2),
            nn.BatchNorm2d(L1_SIZE),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # 14x14 -> conv (same size) -> 2x2 max-pool -> 7x7
        self.layer2 = nn.Sequential(
            nn.Conv2d(L1_SIZE, L2_SIZE, CONV_KERNEL_SIZE, stride=1, padding=2),
            nn.BatchNorm2d(L2_SIZE),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # the model returns raw logits: nn.CrossEntropyLoss applies log-softmax itself,
        # and softmax scores are computed separately when logging predictions
        self.fc = nn.Linear(7*7*L2_SIZE, num_classes)

    def forward(self, x):
        # uncomment to see the shape of a given layer:
        #print("x: ", x.size())
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
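The `7*7*L2_SIZE` input size of the final linear layer comes from the spatial arithmetic of the two conv/pool blocks, which is why the comment on `CONV_KERNEL_SIZE` warns that changing it may require changing adjacent layers. A pure-Python sketch of that calculation, using the standard `Conv2d`/`MaxPool2d` output-size formula with dilation 1; the helper names are hypothetical:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of a Conv2d or MaxPool2d layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

def flattened_features(input_size=28, conv_kernel=5, l2_size=64):
    """Number of features entering the final linear layer of the two-block ConvNet."""
    size = input_size
    for _ in range(2):                      # layer1, then layer2
        size = conv_out(size, conv_kernel, stride=1, padding=2)
        size = conv_out(size, 2, stride=2)  # MaxPool2d(kernel_size=2, stride=2)
    return size * size * l2_size

print(flattened_features())               # 3136 == 7 * 7 * 64
print(flattened_features(conv_kernel=3))  # 4096 == 8 * 8 * 64
```

With a 3x3 kernel and the same `padding=2`, the feature map ends up 8x8 instead of 7x7, so `nn.Linear(7*7*L2_SIZE, ...)` would no longer match and the forward pass would fail with a shape error.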

# 2. Run training and log test predictions

For every epoch, run a training step and a test step. For each test step, create a wandb.Table() in which to store test predictions. These tables can be visualized, dynamically queried, and compared side by side in your browser.
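Each row of the predictions table holds an id, the image, the model's guess, the true label, and one confidence score per class. The row layout that `log_test_predictions` produces can be sketched in plain Python; `make_row` is a hypothetical helper, and a string placeholder stands in for the `wandb.Image`:

```python
# Column layout for each test step's table: id, image, guess, truth, score_0..score_9
columns = ["id", "image", "guess", "truth"] + [f"score_{d}" for d in range(10)]

def make_row(index_in_batch, batch_counter, image, guess, truth, scores):
    """Build one row in column order; the id combines position in batch and batch number."""
    row_id = f"{index_in_batch}_{batch_counter}"
    return [row_id, image, guess, truth, *scores]

# Hypothetical values: "<image>" stands in for a wandb.Image
row = make_row(3, 0, "<image>", 7, 7, [0.0] * 7 + [1.0] + [0.0] * 2)
print(len(columns), len(row))  # 14 14
print(row[0])                  # 3_0
```

Because `test_loader` uses `shuffle=False`, a given id refers to the same test image in every epoch's table, which is what makes row-by-row comparison across epochs meaningful.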

In [None]:
# log to a specific project which will store all your experiments
wandb.init(project=WANDB_PROJECT)
# log model configuration to wandb
cfg = wandb.config
cfg.update({"epochs" : EPOCHS, "batch_size": BATCH_SIZE, "lr" : LEARNING_RATE,
            "l1_size" : L1_SIZE, "l2_size": L2_SIZE,
            "conv_kernel" : CONV_KERNEL_SIZE})

model = ConvNet(NUM_CLASSES).to(device)
# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# optionally log gradients and parameters to wandb
#wandb.watch(model, log="all")

# convenience function to log predictions for a batch of test images
def log_test_predictions(images, labels, outputs, predicted, test_table, log_counter):
    # obtain confidence scores for ALL classes
    scores = F.softmax(outputs.detach(), dim=1)
    log_scores = scores.cpu().numpy()
    log_images = images.cpu().numpy()
    log_labels = labels.cpu().numpy()
    log_preds = predicted.cpu().numpy()
    # assign ids based on each image's position in the batch and the batch counter
    _id = 0
    for i, l, p, s in zip(log_images, log_labels, log_preds, log_scores):
        # add required info to data table:
        # id, image pixels, model's guess, true label, scores for all classes
        img_id = str(_id) + "_" + str(log_counter)
        test_table.add_data(img_id, wandb.Image(i), p, l, *s)
        _id += 1
        if _id == LOG_IMAGES_PER_BATCH:
            break

# train the model
total_step = len(train_loader)
for epoch in range(EPOCHS):
    # training step: re-enable training mode, since the test step below
    # switches the model (and its BatchNorm layers) to eval mode
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        # forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log loss to wandb as a Python float
        wandb.log({"loss" : loss.item()})
        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, EPOCHS, i+1, total_step, loss.item()))

    # create a wandb Artifact to version each test step separately
    test_data_at = wandb.Artifact("test_samples_" + str(wandb.run.id), type="predictions")
    # create a wandb.Table() in which to store predictions for each test step
    columns = ["id", "image", "guess", "truth"]
    for digit in range(10):
        columns.append("score_" + str(digit))
    test_table = wandb.Table(columns=columns)

    # test the model
    model.eval()
    log_counter = 0
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            if log_counter < NUM_BATCHES_TO_LOG:
                log_test_predictions(images, labels, outputs, predicted, test_table, log_counter)
                log_counter += 1
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        acc = 100 * correct / total
        # log acc to wandb (and epoch so you can sync up model performance across batch sizes)
        wandb.log({"epoch" : epoch, "acc" : acc})
        print('Test Accuracy of the model on the {} test images: {} %'.format(total, acc))

    # log predictions table to wandb
    test_data_at.add(test_table, "predictions")
    wandb.run.log_artifact(test_data_at)

wandb.run.finish()