# Visualisation

In this notebook, we explore more of what can be done with `interpret`. For a basic introduction, see the [Interpret Intro Notebook](https://github.com/ttumiel/interpret/blob/master/nbs/Interpret-Intro.ipynb).

## Install

In [None]:
# install from PyPI
!pip install interpret-pytorch

# Or install from GitHub
# !pip install git+https://github.com/ttumiel/interpret

## Channel Visualisations

In [None]:
from interpret import OptVis, ImageParam, denorm, get_layer_names
import torchvision, torch

In [None]:
# Load a pretrained network. We will choose an objective from one of its layers.
network = torchvision.models.googlenet(pretrained=True)

# Print the layer names so that we can choose one to optimise for
get_layer_names(network)
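
The printed names seem to mirror the network's module hierarchy, with `/` as the separator. For GoogLeNet, `inception4c/branch1/conv` matches PyTorch's own dotted module name `inception4c.branch1.conv`. Assuming that convention, a similar list can be sketched with `nn.Module.named_modules`. The helper `layer_names` and the toy `Block`/`Net` modules are hypothetical and do not necessarily reflect how `interpret` implements `get_layer_names`.

```python
import torch.nn as nn


def layer_names(model: nn.Module):
    """Hypothetical helper: list every submodule name, using '/' instead of '.'."""
    # named_modules() yields ('', model) first; skip that root entry
    return [name.replace(".", "/") for name, _ in model.named_modules() if name]


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, 3)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = Block()


print(layer_names(Net()))  # ['branch1', 'branch1/conv']
```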

Suppose we want to optimise for the layer `'inception4c/branch1/conv'`. We can select this layer by passing it to the class method `OptVis.from_layer`, which creates an `OptVis` object with that layer as the objective. We can also choose which channel in that layer to optimise for.

In [None]:
layer = 'inception4c/branch1/conv' # choose layer
channel = 32 # choose channel in layer

# Create an OptVis object that will create a layer objective to optimise
optvis = OptVis.from_layer(network, layer=layer, channel=channel)

# Parameterise the input noise in a colour-decorrelated Fourier domain
img_param = ImageParam(128, fft=True, decorrelate=True)

# Create visualisation
# thresh is a tuple containing the iterations at which to display the image
optvis.vis(img_param, thresh=(250,500))
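
To give an idea of what a Fourier-domain parameterisation involves, here is a minimal PyTorch sketch in the style popularised by the Lucid library. It is an assumption about the general technique, not `interpret`'s `ImageParam` code, and it leaves out colour decorrelation. The optimised parameters are Fourier coefficients; each one is scaled by roughly 1/frequency so that low frequencies dominate, and the image is recovered with an inverse real FFT followed by a sigmoid. `fourier_image` is a hypothetical helper.

```python
import torch


def fourier_image(h, w, channels=3, decay=1.0, std=0.01):
    """Hypothetical helper: returns learnable Fourier coefficients and a function
    that renders them as a (channels, h, w) image with values in (0, 1)."""
    # Frequency magnitude of every coefficient that rfft2 produces for an (h, w) image
    fy = torch.fft.fftfreq(h)[:, None]
    fx = torch.fft.rfftfreq(w)[None, :]
    freqs = torch.sqrt(fx ** 2 + fy ** 2)
    # Scale roughly by 1/frequency; clamp so the DC term does not divide by zero
    scale = 1.0 / torch.clamp(freqs, min=1.0 / max(h, w)) ** decay

    # Small random complex coefficients: these are the parameters to optimise
    spectrum = (std * torch.randn(channels, h, w // 2 + 1, dtype=torch.complex64)).requires_grad_()

    def render():
        img = torch.fft.irfft2(spectrum * scale, s=(h, w))  # back to pixel space
        return torch.sigmoid(img)                          # squash into (0, 1)

    return spectrum, render


spectrum, render = fourier_image(32, 32)
img = render()
print(img.shape)  # torch.Size([3, 32, 32])
```

In an optimisation loop, `spectrum` would be handed to the optimiser and `render()` called at every step to produce the image fed to the network.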

In [None]:
channel = 14 # choose channel in layer
optvis = OptVis.from_layer(network, layer=layer, channel=channel)
optvis.vis() # you can leave out the image parameterisation to use the default

## Manually setting objectives

We can also create an objective manually and pass it to the constructor of the `OptVis` class. Building our own objective lets us do interesting things, such as combining two different objectives:

In [None]:
from interpret.vis import LayerObjective

In [None]:
objective32 = LayerObjective(network, layer, channel=32)
objective14 = LayerObjective(network, layer, channel=14)
objective = objective32 + objective14

optvis = OptVis(network, objective)
optvis.vis()

In [None]:
# And you can interpolate between them:
objective = 0.75*objective32 + 0.25*objective14

optvis = OptVis(network, objective)
optvis.vis()
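
The `+` and scalar `*` operators above suggest that objectives overload Python's arithmetic operators. The toy class below is a hypothetical, minimal illustration of that design, not `interpret`'s implementation: each operator returns a new objective whose loss is the sum, or the weighted value, of the underlying losses. Because each loss is a negative activation, weighting the losses weights the activations being maximised.

```python
class ToyObjective:
    """Hypothetical stand-in for an objective: wraps a function input -> scalar loss."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)

    def __add__(self, other):
        # The combined loss is the sum of both losses
        return ToyObjective(lambda x: self(x) + other(x))

    def __mul__(self, weight):
        # Scale the loss by a constant weight
        return ToyObjective(lambda x: weight * self(x))

    __rmul__ = __mul__  # so that `0.75 * objective` works as well as `objective * 0.75`


# Two toy "channel" objectives: negative activation of element 0 and of element 1
obj_a = ToyObjective(lambda acts: -acts[0])
obj_b = ToyObjective(lambda acts: -acts[1])

mix = 0.75 * obj_a + 0.25 * obj_b
print(mix([4.0, 8.0]))  # -5.0
```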

## Other Objectives

Additionally, you can optimise for other objectives. An objective is just a class that stores a `self.loss` value to optimise. For example, `LayerObjective` hooks into the PyTorch model and grabs the output of the chosen layer. It then sets the negative mean of this output as the loss, so minimising the loss maximises the activation of that layer.
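
A minimal sketch of this mechanism in plain PyTorch, assuming a layer objective is built on a forward hook. `ChannelObjective` and the small convolutional `model` are hypothetical and only illustrate the idea; they are not `interpret`'s `LayerObjective`. The loop then minimises the stored loss with respect to the input image, which is the basic idea behind the visualisations above.

```python
import torch
import torch.nn as nn


class ChannelObjective:
    """Hypothetical layer objective: a forward hook stores the loss on `self.loss`."""

    def __init__(self, layer: nn.Module, channel: int):
        self.channel = channel
        self.loss = None
        self.handle = layer.register_forward_hook(self.hook)

    def hook(self, module, inputs, output):
        # output: (batch, channels, H, W). Take the negative mean so that
        # minimising the loss maximises the chosen channel's average activation.
        self.loss = -output[:, self.channel].mean()

    def remove(self):
        self.handle.remove()


torch.manual_seed(0)
# Tiny stand-in network; the objective hooks the intermediate ReLU
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
)
model.requires_grad_(False)  # freeze the weights; only the image is optimised
objective = ChannelObjective(model[1], channel=5)

img = torch.randn(1, 3, 32, 32, requires_grad=True)  # the input being optimised
opt = torch.optim.Adam([img], lr=0.05)
losses = []
for _ in range(20):
    opt.zero_grad()
    model(img)                 # the forward pass runs the hook, which sets objective.loss
    objective.loss.backward()  # gradients flow back through the network to the image
    opt.step()
    losses.append(objective.loss.item())
objective.remove()             # detach the hook once we are done
print(losses[0], losses[-1])
```

With the network frozen, only the image changes, so the recorded loss is expected to fall as the channel's activation rises.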

Another objective to optimise for is `DeepDream`. This creates a dream-like effect on an input image. If you'd like to see other objectives, please make a PR!
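
Deep Dream differs from the channel visualisations mainly in its starting point: instead of noise, it starts from a real image and amplifies whatever a layer already responds to in it. A minimal PyTorch sketch of that idea follows. It assumes the objective is the mean squared activation of the chosen layer (a common choice in Deep Dream implementations), which may not match `interpret`'s `DeepDream` objective; `deep_dream` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def deep_dream(model, layer, image, steps=10, lr=0.02):
    """Hypothetical sketch of Deep Dream: repeatedly nudge `image` so that
    `layer` responds more strongly to it (mean squared activation objective)."""
    store = {}
    handle = layer.register_forward_hook(lambda mod, inp, out: store.update(act=out))
    image = image.detach().clone().requires_grad_(True)
    try:
        for _ in range(steps):
            model(image)
            loss = -(store["act"] ** 2).mean()  # minimising this amplifies activations
            (grad,) = torch.autograd.grad(loss, image)
            with torch.no_grad():
                # Normalise the step so its size does not depend on the gradient scale
                image -= lr * grad / (grad.abs().mean() + 1e-8)
    finally:
        handle.remove()  # do not leave the hook attached to the model
    return image.detach()


torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 16, 3, padding=1))
start = torch.rand(1, 3, 32, 32)  # stands in for a real photo
dreamed = deep_dream(model, model[1], start)
```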

In [None]:
from interpret.vis import ImageFile

In [None]:
# Download an image to apply Deep Dream to (-L follows redirects)
!curl -L https://www.yourpurebredpuppy.com/dogbreeds/photos2-G/german-shepherd-05.jpg -o dog.jpg

In [None]:
# Parameterise the image
img_param = ImageFile("dog.jpg", size=256)
img_param.cuda()

In [None]:
# Deep Dream
optvis = OptVis.from_dream(network, layer=layer)
optvis.vis(img_param, thresh=(30,));

## Creating Objectives

To create an objective, you can subclass `Objective`; this is particularly useful if you want to keep some state. If your objective has no state, you can instead write a function that takes the network input `x` and returns a loss to minimise, then decorate it with `@Objective` so that it gets all the `Objective` properties.
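
The two patterns can be illustrated in plain Python. `SketchObjective` below is a hypothetical stand-in for `interpret`'s `Objective` (the real class does more), and `CountingDistance` and `negative_mean` are made-up examples: the subclass keeps state between calls, while the decorator turns a stateless function into an objective instance.

```python
class SketchObjective:
    """Hypothetical base class standing in for interpret's `Objective`.
    Calling an instance returns a loss to minimise."""

    def __init__(self, fn=None):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)


# Option 1: subclass when the objective needs to keep state between calls.
class CountingDistance(SketchObjective):
    def __init__(self, target):
        super().__init__()
        self.target = target  # state: the point we want x to approach
        self.calls = 0        # state: how many times the loss was evaluated

    def __call__(self, x):
        self.calls += 1
        return sum((xi - ti) ** 2 for xi, ti in zip(x, self.target))


# Option 2: decorate a plain function when no state is needed.
# The decorator replaces `negative_mean` with a SketchObjective wrapping it.
@SketchObjective
def negative_mean(x):
    return -sum(x) / len(x)


dist = CountingDistance(target=[0.0, 0.0])
print(dist([3.0, 4.0]), dist.calls)    # 25.0 1
print(negative_mean([1.0, 2.0, 3.0]))  # -2.0
```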

## Improving Visualisations

Visualisations don't always behave, and you may need to adjust a few things. You might try the following:
- Add a little weight decay using the `wd` parameter of `optvis.vis`.
- Change the transformations. Make sure they operate on tensors so that gradients can propagate through them; see `transforms.py`. This seems particularly useful for layers deep in the network, such as the final output.
- Add other regularisation terms, such as the L1 or L2 norm or total variation, to help reduce noise.
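
The regularisers in the last bullet can be written as small differentiable functions of the image tensor. The sketch below is one reasonable way to do it in PyTorch, under the assumption that the penalties are simply added to the objective's loss before calling `backward()`; the function names and default weights are hypothetical, and `interpret` may expose these terms differently.

```python
import torch


def total_variation(img):
    """Anisotropic total variation of a (batch, channels, H, W) image: mean absolute
    difference between vertically and horizontally adjacent pixels."""
    dh = (img[..., 1:, :] - img[..., :-1, :]).abs().mean()
    dw = (img[..., :, 1:] - img[..., :, :-1]).abs().mean()
    return dh + dw


def l1_norm(img):
    return img.abs().mean()


def l2_norm(img):
    return img.pow(2).mean()


def regularised_loss(objective_loss, img, tv_weight=1e-2, l1_weight=0.0, l2_weight=1e-3):
    """Hypothetical combination: objective loss plus weighted penalties on the image."""
    return (objective_loss
            + tv_weight * total_variation(img)
            + l1_weight * l1_norm(img)
            + l2_weight * l2_norm(img))


# Vertical stripes: no change down the rows, a jump of 1 across every column pair
stripes = torch.tensor([[[[0.0, 1.0], [0.0, 1.0]]]])
print(total_variation(stripes).item(), l1_norm(stripes).item(), l2_norm(stripes).item())  # 1.0 0.5 0.5
```

Total variation penalises differences between neighbouring pixels, so it pushes the optimiser away from high-frequency noise, while the L1 and L2 terms keep pixel values from growing without bound.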