# Extraction of representations in hidden layers

Now we start looking at the representations in the hidden layers of
a trained network.

# Guided exercise

First of all we have to open 

    mnist_cnn.py 

and

- familiarize yourself with the code
- run the script to train the network and save the model to
      
        mnist_cnn.pt

Then we will go back to this notebook and

- load the model (see also
https://pytorch.org/tutorials/beginner/saving_loading_models.html)

- extract and visualize representations with T-SNE

We will not explain T-SNE in detail here, but you can find more in the following resources:

- the documentation on scikit-learn
  https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html

- the original paper: http://www.jmlr.org/papers/volume9/vandermaaten08a/vandermaaten08a.pdf
  (also in the bibliography)

- the *distill* article (which is a wonderful source of information on networks)
  https://distill.pub/2016/misread-tsne/

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchsummary import summary
from matplotlib import pyplot as plt

This is the same network as in the script

    mnist_cnn.py
    
Usually you would put this in a module and import it; for convenience,
I reproduce it here.

In [None]:
class Net(nn.Module):
    """LeNet-style CNN for single-channel 28x28 MNIST digits.

    Pipeline:
        conv(1->20, 5x5) -> ReLU -> 2x2 max-pool ->
        conv(20->50, 5x5) -> ReLU -> 2x2 max-pool ->
        flatten (800) -> fc(800->500) -> ReLU -> fc(500->10) -> log_softmax

    The attribute names conv1/conv2/fc1/fc2 must not change: state_dict keys
    (e.g. 'conv1.weight') are derived from them, and load_state_dict matches
    saved tensors by those keys.
    """

    # After the second pooling stage 50 feature maps of 4x4 remain.
    _FLAT_FEATURES = 4 * 4 * 50

    def __init__(self):
        super(Net, self).__init__()
        # Layers are created in the same order as before so that parameter
        # order (and seeded initialisation) is unchanged.
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Map a batch of images to (batch, 10) class log-probabilities."""
        features = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2, 2)
        flat = features.view(-1, self._FLAT_FEATURES)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [None]:
# Training settings
input_size = (1, 28, 28)  # (channels, height, width): the CNN takes images, not unrolled vectors as the MLP did
batch_size = 64           # training mini-batch size
test_batch_size = 1000    # test-set mini-batch size
epochs = 1
lr = 0.01                 # SGD learning rate
momentum = 0.0            # SGD momentum (disabled)
seed = 1                  # RNG seed passed to torch.manual_seed
log_interval = 100

In [None]:
use_cuda = torch.cuda.is_available()
# Seed the global RNG so weight initialisation and shuffling are reproducible.
torch.manual_seed(seed)
if use_cuda:
    device = torch.device("cuda")
    # DataLoader options that only make sense when batches go to a GPU.
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [None]:
# Both splits share one preprocessing pipeline: convert to a tensor, then
# standardise with fixed mean/std constants (presumably the MNIST statistics).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST('../data', train=True, download=True,
                               transform=mnist_transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, **kwargs)

# The test loader is shuffled too, so the batch inspected below is a random
# sample of the test set.
test_dataset = datasets.MNIST('../data', train=False,
                              transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=True, **kwargs)

In [None]:
# Build the CNN, move it to the selected device, and attach plain SGD.
model = Net()
model = model.to(device)
optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=momentum)

In [None]:
# Summarise the network with torchsummary for an input of shape
# input_size = (1, 28, 28).
summary(model,input_size)

In [None]:
# Show every tensor stored in the model's state_dict with its shape, then the
# optimizer's state_dict entries. Each state_dict() is built once and iterated
# with .items(); the original rebuilt the whole dict on every loop iteration.
print("model's state_dict:")
for name, tensor in model.state_dict().items():
    print(name, "\t", tensor.size())


print("\noptimizer's state_dict:")
for key, value in optimizer.state_dict().items():
    print(key, "\t", value)

In [None]:
#optimizer.state_dict??

In [None]:
# List, by position, whether each parameter will receive gradients
# (nothing in this notebook freezes a layer, so True is expected).
grad_flags = [param.requires_grad for param in model.parameters()]
for idx, flag in enumerate(grad_flags):
    print(idx, flag)

In [None]:
# First 10 weights of fc1's first row (fc1.weight is entry 4 of
# model.parameters()); compare with the same print after loading.
print(model.fc1.weight[0, :10])

In [None]:
# Restore the weights trained by mnist_cnn.py. map_location remaps tensors
# onto `device`, so a checkpoint saved on a GPU also loads on a CPU-only
# machine (without it torch.load tries to restore them to the original GPU).
model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
# Switch to evaluation mode. Net has no dropout/batch-norm layers, so this is
# good practice rather than a behavioral change here.
model.eval()

In [None]:
# Same fc1 weights as before loading; they should now show the trained values.
print(model.fc1.weight[0, :10])

Now that we have the trained model back we extract its representations, but how can we do that?
Think about that for a minute before going on...

In [None]:
# Take one batch of test images (and their labels) to probe the trained model.
test_batches = iter(test_loader)
inputs, labels = next(test_batches)
print(inputs.shape)

In [None]:
# The model lives on `device` (possibly CUDA) while the loader yields CPU
# tensors, so move the batch to the model first. No gradients are needed for
# inspection, and .cpu() is required before .numpy() when running on a GPU.
# Result: (batch, 10) log-probabilities (Net ends with log_softmax).
with torch.no_grad():
    output = model(inputs.to(device)).cpu().numpy()

In [None]:
# Plot the 10 class log-probabilities predicted for the first test image,
# then print its true label for comparison.
first_scores = output[0]
plt.plot(first_scores, '-o')
print(labels[0])

# Representations extraction

Define a new class identical to Net but with a method that extracts the activations at
the following places:

- after the first ReLU        **(h1)**
- after the first pooling     **(h2)**
- after the second ReLU       **(h3)**
- after the second pooling    **(h4)**
- after the third ReLU        **(h5)**
- at output                   **(h6)**

Then extract all these activations during the forward pass of the input
batch and analyze them with T-SNE
(including the raw input in the analysis for comparison).

In [None]:
class Net2(nn.Module):
    """Same architecture as Net, plus ``extract`` to expose hidden activations.

    The layer attribute names match Net, so a state_dict saved from Net
    (e.g. mnist_cnn.pt) can be loaded into Net2 unchanged.
    """

    def __init__(self):
        super(Net2, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return (batch, 10) log-probabilities, exactly as Net.forward."""
        # Reuse extract so forward() and the extracted h6 cannot diverge.
        return self.extract(x)[-1]

    def extract(self, x):
        """Run a forward pass and return every intermediate activation.

        Args:
            x: batch of images. For 1x28x28 inputs (input_size) the
                returned shapes are the ones listed below.

        Returns:
            Tuple (h1, h2, h3, h4, h5, h6):
                h1: after the first ReLU      (N, 20, 24, 24)
                h2: after the first pooling   (N, 20, 12, 12)
                h3: after the second ReLU     (N, 50, 8, 8)
                h4: after the second pooling  (N, 50, 4, 4)
                h5: after the third ReLU      (N, 500)
                h6: network output, log-probabilities (N, 10), same as forward(x)

        The tensors keep their autograd history. Call this inside
        torch.no_grad() for analysis, and flatten each activation with
        h.view(h.size(0), -1) before handing it to T-SNE.
        """
        h1 = F.relu(self.conv1(x))
        h2 = F.max_pool2d(h1, 2, 2)
        h3 = F.relu(self.conv2(h2))
        h4 = F.max_pool2d(h3, 2, 2)
        h5 = F.relu(self.fc1(h4.view(-1, 4*4*50)))
        h6 = F.log_softmax(self.fc2(h5), dim=1)
        return h1, h2, h3, h4, h5, h6