# Miscellaneous Interpretations

This notebook covers the interpretation methods that don't fit neatly into the attribution or visualisation sections. The first section trains a network on Imagenette, which the remaining sections then interpret.

Most methods in the misc module are split into one function that generates the data and another that plots it, so you can also do your own plotting. For example, compute the top losses once with `top_losses`, then pass that output to `plot_top_losses` to plot the inputs responsible for those losses. This gives you more control over what gets plotted and means you don't have to rerun an expensive computation just because you want to change the plot.

- [Top Losses](#Top-Losses)
- [Confusion Matrix](#Confusion-Matrix)
- [Dataset Examples](#Dataset-Examples)
- [Loss Landscape](#Loss-Landscape)

In [None]:
# Install interpret
!pip install git+https://github.com/ttumiel/interpret

In [None]:
import torch
import torchvision
import numpy as np
from torch import nn
from torchvision import transforms, datasets
from pathlib import Path
from functools import partial

from interpret.misc import *

In [None]:
# download the imagenette dataset
!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz
!tar xf imagenette2-160.tgz

imagenette_mean = [0.4616, 0.4538, 0.4254]
imagenette_std = [0.2681, 0.2643, 0.2865]

def get_transforms(size, mean, std, rotate=10, flip_lr=True, flip_ud=False):
    "Get some basic transforms for the dataset"
    val_tfms = [
        transforms.Resize(size),
        transforms.CenterCrop((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]

    tfms = [transforms.RandomRotation(rotate)] if rotate != 0 else []
    if flip_lr: tfms += [transforms.RandomHorizontalFlip()]
    if flip_ud: tfms += [transforms.RandomVerticalFlip()]

    train_tfms = transforms.Compose(tfms+val_tfms)
    valid_tfms = transforms.Compose(val_tfms)
    return train_tfms, valid_tfms

def imagenette(path, imsize):
    "Load the imagenette datasets"
    path = Path(path)
    train_tfms, val_tfms = get_transforms(imsize, imagenette_mean, imagenette_std)
    train_ds = datasets.ImageFolder(path/'train', transform=train_tfms)
    valid_ds = datasets.ImageFolder(path/'val', transform=val_tfms)
    return train_ds, valid_ds

tds, ds = imagenette("imagenette2-160/", 128)
dl = torch.utils.data.DataLoader(tds, batch_size=128, shuffle=True)
val_dl = torch.utils.data.DataLoader(ds, batch_size=128)

## Train a ResNet18 on the Imagenette Dataset

Here we train a pretrained ResNet18 on the [Imagenette](https://github.com/fastai/imagenette) dataset. Imagenette is a 10-class subset of ImageNet, so it is an easy task for a network pretrained on ImageNet: we keep most of the pretrained weights frozen and retrain a new head that outputs the 10 classes. The next code block sets up the network and defines a simple training loop.

The classes are:

0. tench
1. English springer
2. cassette player
3. chain saw
4. church
5. French horn
6. garbage truck
7. gas pump
8. golf ball
9. parachute

In [None]:
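network = torchvision.models.resnet18(pretrained=True)
# Freeze the pretrained weights, except for the BatchNorm layers.
# recurse=False matters: requires_grad_ on a parent module (including
# `network` itself) would otherwise also freeze its BatchNorm children.
for m in network.modules():
    if not isinstance(m, nn.BatchNorm2d):
        for p in m.parameters(recurse=False):
            p.requires_grad_(False)

# Replace the head with a new, trainable layer for the 10 classes
network.fc = nn.Linear(512, 10)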

def accuracy(preds, tgt): return torch.mean((preds==tgt).float()).cpu().item()

def train(network, dataloader, loss_fn, epochs=3):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    network.train().to(device)
    losses, accs = [], []
    optim = torch.optim.Adam(network.parameters(), lr=1e-3)
    for e in range(epochs):
        network.train()
        for x,y in dataloader:
            x,y = x.to(device), y.to(device)
            y_hat = network(x)
            loss = loss_fn(y_hat, y)
            loss.backward()
            optim.step()
            optim.zero_grad()
            losses.append(loss.cpu().item())
            accs.append(accuracy(y_hat.argmax(1), y))
        
        preds, tgts = validate(network, val_dl)
        print(e, "-  Loss:", np.mean(losses), "Train Acc:", np.mean(accs), 
              "Val Acc:", np.mean(accuracy(preds.argmax(1), tgts)))
        losses, accs = [], []

In [None]:
train(network, dl, nn.CrossEntropyLoss())
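One subtlety when freezing layers with a loop over `modules()`: `nn.Module.requires_grad_` applies to all of a module's parameters recursively, and `modules()` yields the top-level container first. Calling `requires_grad_(False)` on every non-BatchNorm module therefore freezes the BatchNorm layers as well. The toy model below is a hypothetical stand-in for ResNet18 (not part of the notebook) that shows the effect and the `recurse=False` workaround.

```python
import torch
from torch import nn

# Hypothetical toy stand-in for a pretrained backbone plus head.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

def num_trainable(m):
    "Count the parameters that will receive gradients."
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

# Naive version: modules() yields the Sequential container first, and
# requires_grad_ recurses into children, so the BatchNorm layer is frozen too.
for m in model.modules():
    if not isinstance(m, nn.BatchNorm2d):
        m.requires_grad_(False)
naive = num_trainable(model)   # 0

# Fixed version: only touch each module's own parameters (recurse=False).
model.requires_grad_(True)
for m in model.modules():
    if not isinstance(m, nn.BatchNorm2d):
        for p in m.parameters(recurse=False):
            p.requires_grad_(False)
fixed = num_trainable(model)   # 16: BatchNorm weight (8) + bias (8)
print(naive, fixed)
```

A head created after the loop, such as `network.fc = nn.Linear(512, 10)`, starts with `requires_grad=True` and is trained either way.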

## Top Losses

Plot the inputs from a dataset that produce the largest loss. This is useful for finding where your network struggles most, or where an input doesn't actually match its given label (a mislabelled image).

`top_losses` returns a tuple containing the sorted top predictions, targets and losses, along with the indices of all items ranked by loss. This tuple can be passed directly to `plot_top_losses` for plotting.
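As a minimal sketch of the idea behind top losses (not the library's implementation), the functions below compute an unreduced cross-entropy for each sample from a hypothetical array of already-collected logits, then sort by loss. In PyTorch, `nn.CrossEntropyLoss(reduction='none')` gives the per-sample losses directly.

```python
import numpy as np

def cross_entropy_per_sample(logits, targets):
    "Unreduced cross-entropy: one loss value per sample."
    # Subtract the row max for numerical stability before the log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets]

def rank_by_loss(logits, targets, k=3):
    "Return the top-k predictions, targets and losses, plus all indices ranked by loss."
    losses = cross_entropy_per_sample(logits, targets)
    order = np.argsort(-losses)          # descending: largest loss first
    preds = logits.argmax(axis=1)
    top = order[:k]
    return preds[top], targets[top], losses[top], order

logits = np.array([[2.0, 0.0], [0.0, 3.0], [1.0, 1.0], [4.0, 0.0]])
targets = np.array([0, 0, 1, 0])
worst_preds, worst_tgts, worst_losses, order = rank_by_loss(logits, targets, k=2)
print(order)  # [1 2 0 3]
```

Sample 1 ranks first because it is confidently predicted as class 1 while its label is 0; sample 2 follows because its prediction is split evenly between the two classes.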

In [None]:
losses = top_losses(network, val_dl, nn.CrossEntropyLoss())

In [None]:
plot_top_losses(losses, val_dl, network=network, gradcam=True, layer='layer3');

## Confusion Matrix

Plot a confusion matrix for a multi-class classification objective or a binned (rounded) regression objective. The true labels are on the y-axis and the predictions on the x-axis. This shows which classes your network favours and where it makes its mistakes across the entire dataset.

Pass the output of `confusion_matrix` to `plot_confusion_matrix` to plot the matrix as an image. Use a dict for the `decode_label` parameter to translate the target labels into readable names.
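To make the layout concrete, here is a minimal sketch of how the counts can be accumulated, using the same convention as the plot (rows are true labels, columns are predictions). The helper `confusion_counts` is hypothetical and is not the library's `confusion_matrix`; the second example shows the binned regression case, where outputs are rounded to the nearest class before counting.

```python
import numpy as np

def confusion_counts(targets, preds, num_classes):
    "Rows are true labels (y-axis), columns are predictions (x-axis)."
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (targets, preds), 1)   # unbuffered, so repeated pairs all count
    return cm

targets = np.array([0, 0, 1, 2, 2, 2])
preds   = np.array([0, 1, 1, 2, 0, 2])
cm = confusion_counts(targets, preds, num_classes=3)
print(cm)
# [[1 1 0]
#  [0 1 0]
#  [1 0 2]]

# Regression objective: round (bin) the outputs to the nearest class first.
reg_out = np.array([0.2, 1.4, 1.6, 2.1])
reg_tgt = np.array([0, 1, 2, 2])
binned = np.clip(np.rint(reg_out).astype(int), 0, 2)
cm_reg = confusion_counts(reg_tgt, binned, num_classes=3)
```

Off-diagonal entries are the mistakes: in `cm`, the `1` in row 2, column 0 means one class-2 example was predicted as class 0.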

In [None]:
cm = confusion_matrix(network, val_dl, num_classes=10)

In [None]:
plot_confusion_matrix(cm);

## Dataset Examples

Plot the dataset examples that most strongly activate a particular `LayerObjective` (see the visualisation tutorial for the available objectives). Comparing these real images with the optimised visualisation helps confirm what the network is actually looking for.
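A minimal sketch of how such examples can be ranked, assuming the objective is the mean activation of a single channel (a common choice; the library's `LayerObjective` may define it differently). The function takes a hypothetical array of layer activations collected over the dataset; in PyTorch these could be gathered with `register_forward_hook` on the chosen layer.

```python
import numpy as np

def top_activating(activations, channel, k=3):
    """Rank inputs by the mean activation of one channel.

    `activations` is assumed to be an (N, C, H, W) array of layer outputs
    collected over the dataset.
    """
    scores = activations[:, channel].mean(axis=(1, 2))  # one score per input
    order = np.argsort(-scores)[:k]                     # highest first
    return order, scores[order]

rng = np.random.default_rng(0)
acts = rng.random((10, 4, 5, 5))   # fake activations for 10 inputs
acts[7, 2] += 5.0                  # make input 7 fire strongly on channel 2
idx, scores = top_activating(acts, channel=2, k=3)
print(idx[0])  # 7
```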

In [None]:
# First, let's generate a visualisation of what the network is looking for
# in one channel of a layer. We can then compare this to some dataset examples.
from interpret import OptVis
channel = 176           # Choose a channel. layer3 has 256, so try: np.random.randint(256)
layer = 'layer3'        # Choose a layer
OptVis.from_layer(network, layer=layer, channel=channel).vis()

In [None]:
ds_examples = dataset_examples(network, val_dl, layer, channel=channel)

In [None]:
# Plot some examples that activate the same objective
plot_dataset_examples(ds_examples, val_dl);

## Loss Landscape

Plot the loss landscape along two random directions in parameter space around the trained network's weights. This shows how smooth the loss surface is around the network's current optimum. See https://arxiv.org/abs/1712.09913 for more details.

A loss landscape requires computing the loss over the dataset at every point of a grid, so it can take a long time. However, we can subsample the dataset (here to 10% of the validation set) and still get a very similar landscape.
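The core of the method can be sketched in a few lines. This is a simplified, hypothetical version that works on a flat NumPy parameter vector and a toy quadratic loss rather than a real network. It also skips the per-filter normalisation of the random directions used in the paper, rescaling each whole direction to the norm of the parameters instead.

```python
import numpy as np

def loss_surface(loss_fn, theta, span=1.0, steps=21, seed=0):
    """Evaluate loss_fn(theta + a*d1 + b*d2) on a (steps x steps) grid.

    theta is a flat (1-D) parameter vector; d1 and d2 are random Gaussian
    directions, each rescaled to the norm of theta (no filter normalisation).
    """
    rng = np.random.default_rng(seed)
    d1, d2 = rng.standard_normal((2, theta.size))
    d1 *= np.linalg.norm(theta) / np.linalg.norm(d1)
    d2 *= np.linalg.norm(theta) / np.linalg.norm(d2)
    alphas = np.linspace(-span, span, steps)
    grid = np.empty((steps, steps))
    for i, a in enumerate(alphas):
        for j, b in enumerate(alphas):
            grid[i, j] = loss_fn(theta + a * d1 + b * d2)
    return alphas, grid

# Toy example: a quadratic loss whose minimum is exactly at theta_star.
theta_star = np.array([1.0, -2.0, 0.5])

def quadratic_loss(w):
    return float(np.sum((w - theta_star) ** 2))

alphas, grid = loss_surface(quadratic_loss, theta_star)
print(grid.shape)  # (21, 21)
```

Each of the `steps × steps` grid points needs a full loss evaluation, so a 21 × 21 grid means 441 passes over the data for a real network, which is why subsampling the dataset helps so much.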

In [None]:
# Subsample the validation set to reduce computation
indices = np.random.choice(np.arange(len(ds)), int(0.1*len(ds)), replace=False)
sampler = torch.utils.data.sampler.SubsetRandomSampler(indices)
subset_val_dl = torch.utils.data.DataLoader(ds, batch_size=256, sampler=sampler)

In [None]:
ll = loss_landscape(network, subset_val_dl, nn.CrossEntropyLoss())

In [None]:
plot_loss_landscape(ll, angle=30);

In [None]:
plot_loss_landscape(ll, mode='contour');