# Vizualizacija neurona u kitty mreži



### Definicije i importovi

In [None]:
pip install --quiet torch-lucent

In [None]:
from lucent.optvis.transform import pad, jitter, random_rotate, random_scale
from lucent.optvis import render, param, transform, objectives

In [None]:
import torch

from lucent.optvis import render, param, transform, objectives

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small CNN: four conv -> ReLU -> average-pool stages and a linear head.

    The input is batch-normalised, then passed through ``conv1``..``conv4``
    (3x3 kernels with 9, 16, 25 and 36 output channels).  The first three
    stages are batch-normalised after pooling; the fourth is flattened and
    fed to a linear layer with 4 outputs.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the state-dict keys (and the layer names used
        # when visualising), so they are kept exactly as in the checkpoints.
        self.bn0 = nn.BatchNorm2d(3)
        self.conv1 = nn.Conv2d(3, 9, kernel_size=3)
        self.pool1 = nn.AvgPool2d(kernel_size=4, stride=4)

        self.conv1_bn = nn.BatchNorm2d(9)
        self.conv2 = nn.Conv2d(9, 16, kernel_size=3)
        self.pool2 = nn.AvgPool2d(kernel_size=4, stride=4)

        self.conv2_bn = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(16, 25, kernel_size=3)
        self.pool3 = nn.AvgPool2d(kernel_size=4, stride=4)

        self.conv3_bn = nn.BatchNorm2d(25)
        self.conv4 = nn.Conv2d(25, 36, kernel_size=3)
        self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)

        # 324 = 36 channels x 9 positions, i.e. a 3x3 map after pool4
        # (which is what a 640x640 input produces).
        self.fc = nn.Linear(324, 4)

    def forward(self, x):
        """Return the 4 output scores for a batch of 3-channel images."""
        stages = (
            (self.conv1, self.pool1, self.conv1_bn),
            (self.conv2, self.pool2, self.conv2_bn),
            (self.conv3, self.pool3, self.conv3_bn),
        )

        out = self.bn0(x)
        for conv, pool, norm in stages:
            out = norm(pool(F.relu(conv(out))))

        # The last stage has no batch norm after pooling.
        out = self.pool4(F.relu(self.conv4(out)))
        # Flatten everything except the batch dimension.
        features = torch.flatten(out, 1)
        return self.fc(features)

In [None]:
# One copy of the architecture per saved training checkpoint, each placed on
# the selected device before its weights are loaded.
model_start = Net().to(device)
model_early = Net().to(device)
model_late = Net().to(device)

model_start.load_state_dict(
    torch.load('saved_models/kitty/epoch_0_batch_0.pth', map_location=device)
)
model_early.load_state_dict(
    torch.load('saved_models/kitty/epoch_0_batch_4001.pth', map_location=device)
)
model_late.load_state_dict(
    torch.load('saved_models/kitty/epoch_7_batch_0.pth', map_location=device)
)


In [None]:
# Put all three checkpoints in evaluation mode before visualising them
# (the .to(device) repeats the move already done when they were created).
model_early.to(device).eval()
model_start.to(device).eval()
model_late.to(device).eval()

In [None]:
%pylab inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg


In [None]:
def lucent_show_layer(model, layer, grid_dim,
                      param_f=None, transforms=None,
                      optimizer=None, image_size=128):
    """Render the first ``grid_dim ** 2`` channels of ``layer`` as an image grid.

    Each cell shows the Lucent feature visualisation for the objective
    ``"{layer}:{channel}"``; channels are laid out row by row.

    Args:
        model: Network handed to ``render.render_vis``.
        layer: Layer name used in the objective string, e.g. ``'conv1'``.
        grid_dim: Number of rows and of columns of the grid.
        param_f: Image parameterisation factory forwarded to
            ``render.render_vis``.
        transforms: Transforms forwarded to ``render.render_vis``.
        optimizer: Optimizer factory forwarded to ``render.render_vis``.
        image_size: Side length in pixels of the rendered image.  It has to
            agree with what ``param_f`` produces, otherwise the reshape below
            raises.
    """
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, so the
    # flatten() below always works.
    _, axs = plt.subplots(grid_dim, grid_dim, figsize=(19.55, 20),
                          squeeze=False)
    for channel, ax in enumerate(axs.flatten()):
        # render_vis returns a list of images; only the first entry is shown.
        rendered = render.render_vis(model, f"{layer}:{channel}",
                                     param_f=param_f, optimizer=optimizer,
                                     transforms=transforms, progress=False,
                                     show_image=False)[0]
        ax.imshow(np.reshape(rendered, (image_size, image_size, 3)))
        # Strip ticks, labels and margins so the tiles touch each other.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.margins(x=0, y=0, tight=True)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

### Lucent vizualizacija

Lucent je PyTorch library nastao po uzoru na TensorFlow library Lucid, koji su razvili ljudi iz Google Braina za _circuits research_.

### Sloj po sloj

#### 1.

##### Start

In [None]:
# conv1 has 9 output channels, so a 3x3 grid shows all of them.
lucent_show_layer(
    model_start,
    layer='conv1',
    grid_dim=3,
    param_f=lambda: param.image(128),
    image_size=128,
)

##### Early

In [None]:
# conv1 has 9 output channels, so a 3x3 grid shows all of them.
lucent_show_layer(
    model_early,
    layer='conv1',
    grid_dim=3,
    param_f=lambda: param.image(128),
    image_size=128,
)

##### Late

In [None]:
# conv1 has 9 output channels, so a 3x3 grid shows all of them.
lucent_show_layer(
    model_late,
    layer='conv1',
    grid_dim=3,
    param_f=lambda: param.image(128),
    image_size=128,
)

#### 2.

##### Start

In [None]:
# conv2 has 16 output channels, so a 4x4 grid shows all of them.
lucent_show_layer(
    model_start,
    layer='conv2',
    grid_dim=4,
    param_f=lambda: param.image(128),
    image_size=128,
)

##### Early

In [None]:
# conv2 has 16 output channels, so a 4x4 grid shows all of them.
lucent_show_layer(
    model_early,
    layer='conv2',
    grid_dim=4,
    param_f=lambda: param.image(128),
    image_size=128,
)

##### Late

In [None]:
# conv2 has 16 output channels, so a 4x4 grid shows all of them.
lucent_show_layer(
    model_late,
    layer='conv2',
    grid_dim=4,
    param_f=lambda: param.image(128),
    image_size=128,
)

#### 3.

##### Start

In [None]:
# conv3 has 25 output channels, so a 5x5 grid shows all of them.
lucent_show_layer(
    model_start,
    layer='conv3',
    grid_dim=5,
    param_f=lambda: param.image(128),
    image_size=128,
)

##### Early

In [None]:
# conv3 has 25 output channels, so a 5x5 grid shows all of them.
lucent_show_layer(
    model_early,
    layer='conv3',
    grid_dim=5,
    param_f=lambda: param.image(128),
    image_size=128,
)

##### Late

In [None]:
# conv3 has 25 output channels, so a 5x5 grid shows all of them.
lucent_show_layer(
    model_late,
    layer='conv3',
    grid_dim=5,
    param_f=lambda: param.image(128),
    image_size=128,
)

#### 4.

##### Start

In [None]:
# Section 4 compares conv4 across the three checkpoints, so the Start panel
# uses the same layer ('conv4') and image size (128) as the Early and Late
# cells of this section; conv4 has 36 output channels, hence the 6x6 grid.
lucent_show_layer(model_start, 'conv4', 6,
                  param_f=lambda: param.image(128),
                  image_size=128)

In [None]:
# Move model_early to `device` and put it in evaluation mode.
model_early.to(device).eval()

##### Early

In [None]:
# conv4 has 36 output channels, so a 6x6 grid shows all of them.
lucent_show_layer(
    model_early,
    layer='conv4',
    grid_dim=6,
    param_f=lambda: param.image(128),
    image_size=128,
)

##### Late

In [None]:
# conv4 has 36 output channels, so a 6x6 grid shows all of them.
lucent_show_layer(
    model_late,
    layer='conv4',
    grid_dim=6,
    param_f=lambda: param.image(128),
    image_size=128,
)

In [None]:
# možda ovdje, možda negdje drugdje staviti nešto o kombinaciji:

# u drugom layeru, probati sumu neurona 1 i 4

In [None]:
def channel(n):
    """Return the Lucent objective for channel ``n`` of layer conv2."""
    return objectives.channel("conv2", n)


# Optimise a single image for the sum of the channel-1 and channel-4 objectives.
obj = channel(1) + channel(4)
_ = render.render_vis(model_late, obj, show_inline=True)

In [None]:
def channel(n):
    """Return the Lucent objective for channel ``n`` of layer conv3."""
    return objectives.channel("conv3", n)


# Sum the objectives of every fourth conv3 channel: 0, 4, ..., 24.
obj = sum(channel(n) for n in range(0, 25, 4))
_ = render.render_vis(model_late, obj, show_inline=True)

# Captum vizualizacija

In [None]:
# Replace any installed captum with an editable install of the repository's
# "optim-wip" branch.
!pip3 uninstall --quiet captum --y
!git clone https://github.com/pytorch/captum
%cd captum
!git checkout "optim-wip"
!pip3 install -e .
# Make the cloned checkout importable from this session.
# NOTE(review): '/content/captum' presumably matches the Colab working
# directory - adjust the path when running elsewhere.
import sys
sys.path.append('/content/captum')
# Return to the directory the notebook started in.
%cd ..

In [None]:
import captum.optim as optimviz
import torchvision

In [None]:
from typing import Callable, Iterable, Optional

In [None]:
def vis_neuron_large(
    model: torch.nn.Module,
    target: torch.nn.Module,
    channel: int,
    image_size: int = 640,
    n_steps: int = 512,
) -> torch.Tensor:
    """Optimise an input image for one channel of ``target`` and return it.

    Args:
        model: Network that is run during the optimisation.
        target: Layer (a sub-module of ``model``) whose activation is used
            as the loss.
        channel: Index of the channel of ``target`` to visualise.
        image_size: Side length in pixels of the square image being
            optimised; the transformed image is centre-cropped back to it.
        n_steps: Number of optimisation steps.

    Returns:
        The optimised image as a tensor detached from the autograd graph, so
        it can be converted to NumPy directly.  Presumably shaped
        ``(1, 3, image_size, image_size)`` - TODO confirm against captum.
    """
    size = (image_size, image_size)
    image = optimviz.images.NaturalImage(size).to(device)
    # Random augmentations applied on every step; the final crop restores
    # the image to the size the parameterisation was created with.
    transforms = torch.nn.Sequential(
        torch.nn.ReflectionPad2d(2),
        optimviz.transforms.RandomSpatialJitter(8),
        optimviz.transforms.RandomScale(scale=(2.15, 1.85, 2, 1.95, 2.05)),
        torchvision.transforms.RandomRotation(degrees=(-15, 15)),
        optimviz.transforms.RandomSpatialJitter(64),
        optimviz.transforms.CenterCrop(size),
    )
    loss_fn = optimviz.loss.NeuronActivation(target, channel)
    obj = optimviz.InputOptimization(model, loss_fn, image, transforms)
    # The second argument presumably switches the progress display off.
    obj.optimize(optimviz.optimization.n_steps(n_steps, False))
    return image().detach()

In [None]:
def visualize_layer_captum(model, layer, grid_dim):
    """Show Captum visualisations of the first ``grid_dim ** 2`` channels.

    Args:
        model: Network passed on to ``vis_neuron_large``.
        layer: Target layer module (e.g. ``model.conv1``) whose channels
            are visualised.
        grid_dim: Number of rows and of columns of the image grid.
    """
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, so the
    # flatten() below always works.
    _, axs = plt.subplots(grid_dim, grid_dim, figsize=(19.55, 20),
                          squeeze=False)
    for channel, ax in enumerate(axs.flatten()):
        img = vis_neuron_large(model, layer, channel)
        # detach() first: numpy() rejects tensors that require grad, and on
        # a CPU device .cpu() hands back the very same tensor.
        img = img.detach().permute(0, 2, 3, 1).cpu().numpy()
        # Channels-last batch of one image -> drop the batch axis.
        ax.imshow(img[0])
        # Strip ticks, labels and margins so the tiles touch each other.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.margins(x=0, y=0, tight=True)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

In [None]:
# All 9 conv1 channels of model_start in a 3x3 grid.
visualize_layer_captum(model_start, layer=model_start.conv1, grid_dim=3)

In [None]:
# All 9 conv1 channels of model_early in a 3x3 grid.
visualize_layer_captum(model_early, layer=model_early.conv1, grid_dim=3)

In [None]:
# All 9 conv1 channels of model_late in a 3x3 grid.
visualize_layer_captum(model_late, layer=model_late.conv1, grid_dim=3)

In [None]:
# All 36 conv4 channels of model_start in a 6x6 grid.
visualize_layer_captum(model_start, layer=model_start.conv4, grid_dim=6)

In [None]:
# All 36 conv4 channels of model_early in a 6x6 grid.
visualize_layer_captum(model_early, layer=model_early.conv4, grid_dim=6)

In [None]:
# All 36 conv4 channels of model_late in a 6x6 grid.
visualize_layer_captum(model_late, layer=model_late.conv4, grid_dim=6)