In [None]:
pip install torch-lucent

In [None]:
from lucent.optvis.transform import pad, jitter, random_rotate, random_scale
from lucent.optvis import render, param, transform, objectives

In [None]:
import torch

from lucent.optvis import render, param, transform, objectives

# Use the first CUDA GPU when PyTorch reports one as available; otherwise use the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small CNN: input batch-norm, four 3x3 conv stages and a linear head.

    Every conv stage is ``conv -> ReLU -> 4x4 average pool``; the first three
    stages are additionally followed by batch normalisation. The pooled output
    of the fourth stage is flattened and fed to ``fc``, which maps 64 features
    to 4 outputs.

    Input: a float tensor of shape ``(N, 3, H, W)``. Because ``fc`` expects
    exactly 64 features, the final feature map must be 1x1, which (unpadded 3x3
    convs, stride-4 pooling) holds for ``H`` and ``W`` between 426 and 681.
    """

    def __init__(self):
        super().__init__()
        self.bn0 = nn.BatchNorm2d(3)
        self.conv1 = nn.Conv2d(3, 32, 3)
        # A single parameter-free pooling module is shared by all four stages.
        # It is registered right after conv1 so the order of child modules is
        # the same as before; it contributes no entries to the state dict.
        self.pool = nn.AvgPool2d(4, 4)
        self.conv1_bn = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, 3)
        self.conv3_bn = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, 3)
        self.fc = nn.Linear(64, 4)

    def forward(self, x):
        """Return the ``(N, 4)`` output of ``fc`` for an input batch ``x``."""
        x = self.bn0(x)
        x = self.conv1_bn(self.pool(F.relu(self.conv1(x))))
        x = self.conv2_bn(self.pool(F.relu(self.conv2(x))))
        x = self.conv3_bn(self.pool(F.relu(self.conv3(x))))
        # The last stage has no batch norm before the classifier.
        x = self.pool(F.relu(self.conv4(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = self.fc(x)
        return x


# Build a Net with its default initialisation (no checkpoint is loaded here)
# and move it onto the selected device.
model = Net().to(device)

In [None]:
 # List the files in the checkpoint directory that the next cell loads from.
 !ls saved_models/smallboy_v2

In [None]:

def _load_checkpoint(path):
    """Build a new ``Net`` on ``device`` and load the state dict stored at ``path``.

    Args:
        path: location of a file readable by ``torch.load`` that holds a
            state dict compatible with ``Net``.

    Returns:
        The ``Net`` instance with the loaded weights, on ``device``.
    """
    loaded = Net().to(device)
    # map_location remaps the stored tensors onto the device in use, so a
    # checkpoint can be loaded regardless of the device it was saved from.
    loaded.load_state_dict(torch.load(path, map_location=device))
    return loaded


# Three checkpoints of the same architecture. Judging by the file names they
# are progressively later points of one training run (TODO confirm).
model_start = _load_checkpoint('saved_models/smallboy_v2/epoch_0_batch_0.pth')
model_early = _load_checkpoint('saved_models/smallboy_v2/epoch_0_batch_1800.pth')
model_late = _load_checkpoint('saved_models/smallboy_v2/epoch_10_batch_0.pth')


In [None]:
# Put every checkpoint model on the device and switch it to evaluation mode.
for checkpoint_model in (model_early, model_start, model_late):
    checkpoint_model.to(device).eval()

# Trailing expression: show the module summary as the cell output.
model_late

In [None]:
%pylab inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg


In [None]:

# Objective: channel 1 of `conv3` plus channels 5..24 of the same layer,
# combined by adding the individual channel objectives.
obj = objectives.channel('conv3', n_channel=1, batch=None)
for extra_channel in range(5, 25):
    obj += objectives.channel('conv3', n_channel=extra_channel, batch=None)

# Augmentation pipeline. It is defined here but NOT applied in this cell:
# the renders below explicitly pass transforms=[].
tfms = [pad(12, mode="constant", constant_value=.5),
        jitter(8),
        random_scale([1 + (i - 5) / 50. for i in range(11)]),   # scales 0.90 .. 1.10
        random_rotate(list(range(-10, 11)) + 5 * [0]),          # -10 .. 10, extra weight on 0
        jitter(4),]


def param_f():
    """Return a new 128-pixel image parameterisation (no FFT, no decorrelation)."""
    return param.image(128, fft=False, decorrelate=False, batch=1)


# Render the same objective on each checkpoint. Calling render_vis inside a
# loop keeps the returned image arrays from being echoed as the cell output;
# render_vis is left to display its result itself.
for checkpoint_model in (model_start, model_early, model_late):
    render.render_vis(checkpoint_model, obj, param_f, transforms=[], progress=False)

In [None]:
# Augmentations applied to the optimised image before each forward pass.
tfms = [pad(12, mode="constant", constant_value=.5),
        jitter(8),
        random_scale([1 + (i - 5) / 50. for i in range(11)]),   # scales 0.90 .. 1.10
        random_rotate(list(range(-10, 11)) + 5 * [0]),          # -10 .. 10, extra weight on 0
        jitter(4),]


def param_f():
    """Return a new 320-pixel, FFT-parameterised image (no decorrelation)."""
    return param.image(320, fft=True, decorrelate=False, batch=1)


# Render the current `obj` on each checkpoint with the transforms above.
# Calling render_vis inside a loop keeps the returned image arrays from being
# echoed as the cell output.
for checkpoint_model in (model_start, model_early, model_late):
    render.render_vis(checkpoint_model, obj, param_f, transforms=tfms, progress=False)

In [None]:
# Run render_vis on `model` for the objective string "conv1:3" (layer `conv1`,
# index 3), with no transforms and no progress bar, and keep the first element
# of the returned sequence in `a`.
a = render.render_vis(model, "conv1:3", transforms=[], progress=False)[0]

In [None]:
# Same call as in the previous cell; running it again overwrites `a` with the
# first element of a new render for "conv1:3".
a = render.render_vis(model, "conv1:3", transforms=[], progress=False)[0]

In [None]:
# Inspect the shape of the first entry of the rendered result stored in `a`.
a[0].shape

In [None]:
# Render one visualisation per `conv1` index 0 .. n_row * n_col - 1 and tile
# the results in an n_row x n_col grid.
n_row = 2
n_col = 3
_, axs = plt.subplots(n_row, n_col, figsize=(18, 14))
axs = axs.flatten()
for ix, ax in enumerate(axs):
    # `model` is the Net instance created earlier in this notebook; the name
    # `net` is not defined anywhere in it and raised NameError on a fresh kernel.
    img = render.render_vis(model, f"conv1:{ix}", progress=False, show_image=False)[0]
    # Collapse to a single H x W x C image for imshow.
    # Assumes the render holds exactly one 128x128 RGB image - TODO confirm.
    img = np.reshape(img, (128, 128, 3))
    ax.imshow(img)
    ax.set_xticklabels([])
    ax.set_yticklabels([])

plt.subplots_adjust(wspace=0, hspace=0)
plt.show()

In [None]:
# Render one visualisation per `conv2` index 0 .. n_row * n_col - 1 and tile
# the results in an n_row x n_col grid.
n_row = 4
n_col = 4


def param_f():
    """Return a new 128-pixel, FFT-parameterised image (no decorrelation)."""
    return param.image(128, fft=True, decorrelate=False)


_, axs = plt.subplots(n_row, n_col, figsize=(25, 25))
axs = axs.flatten()
for ix, ax in enumerate(axs):
    # `model` is the Net instance created earlier in this notebook; the name
    # `net` is not defined anywhere in it and raised NameError on a fresh kernel.
    img = render.render_vis(model, f"conv2:{ix}", param_f, transforms=[], progress=False, show_image=False)[0]
    # Collapse to a single H x W x C image for imshow.
    # Assumes the render holds exactly one 128x128 RGB image - TODO confirm.
    img = np.reshape(img, (128, 128, 3))
    ax.imshow(img)
    ax.set_xticklabels([])
    ax.set_yticklabels([])

plt.subplots_adjust(wspace=0, hspace=0)
plt.show()

In [None]:
def channel(n):
    """Return the Lucent channel objective for index ``n`` of layer ``conv2``."""
    return objectives.channel("conv2", n)


# Joint objective: the sum of the objectives for conv2 channels 0 and 1.
obj = channel(0) + channel(1)
# `model` is the Net instance created earlier in this notebook; the name `net`
# is not defined anywhere in it and raised NameError on a fresh kernel.
_ = render.render_vis(model, obj, show_inline=True)

# Captum (to the rescue?)

In [None]:
# Install Captum from source on its "optim-wip" branch, which is where the
# `captum.optim` import used below comes from.
# NOTE(review): `git clone` fails if ./captum already exists, so this cell is
# not safely re-runnable as written.
!git clone https://github.com/pytorch/captum
%cd captum
!git checkout "optim-wip"
# Editable install of the checked-out branch.
!pip3 install -e .
import sys
# Hard-coded absolute path; presumably the clone location on Google Colab,
# where the working directory is /content - adjust when running elsewhere.
sys.path.append('/content/captum')
# Return to the directory the notebook started in.
%cd ..


In [None]:
import captum.optim as optimviz
import torchvision

In [None]:
from typing import Callable, Iterable, Optional

In [None]:
def vis_neuron_large(
    model: torch.nn.Module,
    target: torch.nn.Module,
    channel: int,
    image_size: int = 640,
    n_steps: int = 128,
) -> None:
    """Optimise and display an input image for one channel of ``target``.

    Builds a square ``optimviz.images.NaturalImage`` on the notebook-level
    ``device``, optimises it against an ``optimviz.loss.NeuronActivation``
    loss for ``(target, channel)`` under random pad/jitter/scale/rotate
    augmentation, then shows the resulting image.

    Args:
        model: network that is run on the image during optimisation.
        target: the layer (a submodule of ``model``) the loss is attached to.
        channel: channel index within ``target`` passed to the loss.
        image_size: side length in pixels of the optimised image and of the
            final centre crop. Defaults to 640.
        n_steps: number of optimisation steps. Defaults to 128.

    Returns:
        None. The image is displayed as a side effect.
    """
    image = optimviz.images.NaturalImage((image_size, image_size)).to(device)
    transforms = torch.nn.Sequential(
        torch.nn.ReflectionPad2d(2),
        optimviz.transforms.RandomSpatialJitter(8),
        optimviz.transforms.RandomScale(scale=(2, 2, 2, 0.95, 1.05)),
        torchvision.transforms.RandomRotation(degrees=(-5, 5)),
        optimviz.transforms.RandomSpatialJitter(2),
        # Crop back to the parameterised size after the augmentations above.
        optimviz.transforms.CenterCrop((image_size, image_size)),
    )
    loss_fn = optimviz.loss.NeuronActivation(target, channel)
    input_opt = optimviz.InputOptimization(model, loss_fn, image, transforms)
    # The value returned by optimize() is not used, so it is not kept.
    # The second argument to n_steps presumably turns off progress output -
    # confirm against the captum.optim API.
    input_opt.optimize(optimviz.optimization.n_steps(n_steps, False))
    image().show()

In [None]:
# Visualise channel 56 of `conv4` (a Conv2d with 64 output channels) on the
# `model_early` checkpoint.
vis_neuron_large(model_early, model_early.conv4, 56)

In [None]:
# Visualise channel 4 of `conv4` on the `model_late` checkpoint.
vis_neuron_large(model_late, model_late.conv4, 4)

In [None]:
# Run the `nvidia-smi` command-line tool and show its report on the GPU(s).
!nvidia-smi