In [None]:
from itertools import product
import sys
from PIL import Image

import numpy as np
import torch
import torchvision

from lucent.modelzoo import *
from lucent.misc.io import show
import lucent.optvis.objectives as objectives
import lucent.optvis.param as param
import lucent.optvis.render as render
import lucent.optvis.transform as transform
from lucent.misc.channel_reducer import ChannelReducer
from lucent.misc.io import show

In [None]:
# Prefer the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class KittyNet(nn.Module):
    """Small four-stage CNN mapping a 3-channel image batch to 4 output scores.

    Stages 1-3 are conv(3x3) -> ReLU -> 4x4 average pool -> batch norm; stage 4
    is conv(3x3) -> ReLU -> 2x2 average pool with no trailing norm. The linear
    head expects 324 flattened features (36 channels x 3 x 3).

    Attribute names and their creation order are deliberate: the names are the
    ``state_dict`` keys and are also used to look layers up by name.
    """

    def __init__(self):
        super().__init__()
        # Normalises the raw input channels before the first convolution.
        self.bn0 = nn.BatchNorm2d(num_features=3)

        # Stage 1
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=9, kernel_size=3)
        self.pool1 = nn.AvgPool2d(kernel_size=4, stride=4)
        self.conv1_bn = nn.BatchNorm2d(num_features=9)

        # Stage 2
        self.conv2 = nn.Conv2d(in_channels=9, out_channels=16, kernel_size=3)
        self.pool2 = nn.AvgPool2d(kernel_size=4, stride=4)
        self.conv2_bn = nn.BatchNorm2d(num_features=16)

        # Stage 3
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=25, kernel_size=3)
        self.pool3 = nn.AvgPool2d(kernel_size=4, stride=4)
        self.conv3_bn = nn.BatchNorm2d(num_features=25)

        # Stage 4 (no batch norm after the pool)
        self.conv4 = nn.Conv2d(in_channels=25, out_channels=36, kernel_size=3)
        self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)

        # Classifier head
        self.fc = nn.Linear(in_features=324, out_features=4)

    @staticmethod
    def _stage(x, conv, pool, norm=None):
        """Apply conv -> ReLU -> pool, followed by ``norm`` when one is given."""
        x = pool(F.relu(conv(x)))
        return x if norm is None else norm(x)

    def forward(self, x):
        x = self.bn0(x)
        x = self._stage(x, self.conv1, self.pool1, self.conv1_bn)
        x = self._stage(x, self.conv2, self.pool2, self.conv2_bn)
        x = self._stage(x, self.conv3, self.pool3, self.conv3_bn)
        x = self._stage(x, self.conv4, self.pool4)
        # Keep the batch dimension, flatten everything else for the linear head.
        return self.fc(x.flatten(start_dim=1))
    
class LongcatNet(nn.Module):
    """Ten-convolution CNN mapping a 3-channel image batch to 4 output scores.

    Layout:
      * input batch norm (``bn1``);
      * convs 1-4, each followed by ReLU, a 2x2 max pool and the batch norm
        that feeds the next conv;
      * convs 5-8, each followed by ReLU and a batch norm (no pooling);
      * convs 9-10, each followed by ReLU and a 2x2 max pool (batch norm only
        between them);
      * a linear head expecting 1764 flattened features (49 channels x 6 x 6).

    Attribute names and their creation order are deliberate: the names are the
    ``state_dict`` keys and are also used to look layers up by name.
    """

    def __init__(self):
        super().__init__()
        # Normalises the raw input channels before the first convolution.
        self.bn1 = nn.BatchNorm2d(num_features=3)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=9, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2_bn = nn.BatchNorm2d(num_features=9)
        self.conv2 = nn.Conv2d(in_channels=9, out_channels=16, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3_bn = nn.BatchNorm2d(num_features=16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=25, kernel_size=3)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4_bn = nn.BatchNorm2d(num_features=25)
        self.conv4 = nn.Conv2d(in_channels=25, out_channels=36, kernel_size=3)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Un-pooled middle section.
        self.conv5_bn = nn.BatchNorm2d(num_features=36)
        self.conv5 = nn.Conv2d(in_channels=36, out_channels=36, kernel_size=3)

        self.conv6_bn = nn.BatchNorm2d(num_features=36)
        self.conv6 = nn.Conv2d(in_channels=36, out_channels=49, kernel_size=3)

        self.conv7_bn = nn.BatchNorm2d(num_features=49)
        self.conv7 = nn.Conv2d(in_channels=49, out_channels=49, kernel_size=3)

        self.conv8_bn = nn.BatchNorm2d(num_features=49)
        self.conv8 = nn.Conv2d(in_channels=49, out_channels=49, kernel_size=3)

        # Final pooled section.
        self.conv9_bn = nn.BatchNorm2d(num_features=49)
        self.conv9 = nn.Conv2d(in_channels=49, out_channels=49, kernel_size=3)
        self.pool9 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv10_bn = nn.BatchNorm2d(num_features=49)
        self.conv10 = nn.Conv2d(in_channels=49, out_channels=49, kernel_size=3)
        self.pool10 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc = nn.Linear(in_features=1764, out_features=4)

    def forward(self, x):
        x = self.bn1(x)

        # conv -> ReLU -> pool, then the norm that prepares the next conv's input.
        pooled_stages = (
            (self.conv1, self.pool1, self.conv2_bn),
            (self.conv2, self.pool2, self.conv3_bn),
            (self.conv3, self.pool3, self.conv4_bn),
            (self.conv4, self.pool4, self.conv5_bn),
        )
        for conv, pool, next_norm in pooled_stages:
            x = next_norm(pool(F.relu(conv(x))))

        # conv -> ReLU, then the norm that prepares the next conv's input.
        plain_stages = (
            (self.conv5, self.conv6_bn),
            (self.conv6, self.conv7_bn),
            (self.conv7, self.conv8_bn),
            (self.conv8, self.conv9_bn),
        )
        for conv, next_norm in plain_stages:
            x = next_norm(F.relu(conv(x)))

        x = self.conv10_bn(self.pool9(F.relu(self.conv9(x))))
        x = self.pool10(F.relu(self.conv10(x)))

        # Keep the batch dimension, flatten everything else for the linear head.
        return self.fc(torch.flatten(x, 1))

In [None]:
# from https://colab.research.google.com/github/greentfrapp/lucent-notebooks/blob/master/notebooks/activation_grids.ipynb

@torch.no_grad()
def get_layer(model, layer, X):
    """Run ``model`` on ``X`` and return what the hook on ``layer`` captured.

    The forward pass runs with gradient tracking disabled.

    Args:
        model: Callable model; ``layer`` is resolved on it with ``getattr``.
        layer: Attribute name of the sub-module to hook.
        X: Input handed to ``model`` unchanged.

    Returns:
        The ``features`` attribute of the ``render.ModuleHook`` after the
        forward pass.

    Raises:
        Whatever ``model(X)`` raises. The hook is still closed in that case.
    """
    hook = render.ModuleHook(getattr(model, layer))
    try:
        model(X)
    finally:
        # Always detach the hook: a failing forward pass must not leave it
        # registered on a model that later cells keep using.
        hook.close()
    return hook.features


@objectives.wrap_objective()
def dot_compare(layer, acts, batch=1):
    """Build an objective that rewards agreement between a layer and ``acts``.

    Args:
        layer: Name passed to the activation accessor ``T``.
        acts: Reference values multiplied element-wise with the layer output.
        batch: Index selected along the first dimension of ``T(layer)``.

    Returns:
        A function of ``T`` returning the negated mean of the products summed
        over dimension 0, so minimising it maximises the agreement.
    """

    def negative_dot(T):
        layer_acts = T(layer)[batch]
        products = layer_acts * acts
        summed = products.sum(dim=0, keepdims=True)
        return -summed.mean()

    return negative_dot


def render_activation_grid_less_naive(
    img,
    model,
    layer="inception5b",
    cell_image_size=60,
    n_groups=6,
    n_steps=1024,
    batch_size=64,
):
    """Optimise and display an activation grid for one layer of ``model``.

    One small image is optimised per spatial position ("cell") of the layer's
    activation map, so that the layer's response to that image has a large dot
    product with the activation vector ``img`` produced at that position.

    Args:
        img: Array indexed as (height, width, channel) on a 0-255 scale (it is
            divided by 255 before normalisation).
        model: Network to visualise; ``layer`` must be an attribute of it.
        layer: Name of the layer whose activations are visualised.
        cell_image_size: Side length, in pixels, of each optimised cell image.
        n_groups: Number of NMF components shared between cells; ``0``
            disables the shared components.
        n_steps: Number of optimisation steps given to ``render.render_vis``.
        batch_size: Number of cells sampled (with replacement) at each step.

    Returns:
        CPU tensor holding one image per cell, with a 2-pixel border cropped
        from every side. The assembled grid is also shown via ``render.show``.

    Note:
        The input image is placed on the module-level ``device``.
    """
    # First we need to resize and normalize the image
    img = torch.tensor(np.transpose(img, [2, 0, 1])).to(device)
    transforms = transform.standard_transforms.copy() + [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        lambda x: x/255,
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    transforms_f = transform.compose(transforms)
    # shape: (1, 3, original height of img, original width of img)
    img = img.unsqueeze(0)
    # shape: (1, 3, 224, 224)
    img = transforms_f(img)

    # Here we compute the activations of the layer `layer` using `img` as input
    # shape: (layer_channels, layer_height, layer_width), the shape depends on the layer
    acts = get_layer(model, layer, img)[0]
    # shape: (layer_height, layer_width, layer_channels)
    acts = acts.permute(1, 2, 0)
    # shape: (layer_height*layer_width, layer_channels)
    acts = acts.view(-1, acts.shape[-1])
    acts_np = acts.cpu().numpy()
    nb_cells = acts.shape[0]

    # Non-negative matrix factorization (`NMF`) is used to reduce the number
    # of channels to n_groups. It is used as follows.
    # Each cell image in the grid is decomposed into a sum of
    # (n_groups+1) images. First, each cell has its own set of parameters;
    # this is what is called `cells_params` (see below). At the same time, we have
    # a group of `n_groups` shared images, which have their own parametrization,
    # `groups_params`. The resulting image for a given cell in the grid
    # is the sum of its own image (parametrized by `cells_params`)
    # plus a weighted sum of the images of the group. Each image from the group
    # is weighted by `groups[cell_index, group_idx]`. Basically, this makes it
    # possible for cells with similar activations to have a similar image, because
    # cells with similar activations will have a similar weighting for the elements
    # of the group.
    if n_groups > 0:
        reducer = ChannelReducer(n_groups, "NMF")
        groups = reducer.fit_transform(acts_np)
        # Scale every component to a maximum of 1. A component that is zero for
        # every cell is left at zero: dividing by its zero maximum would give
        # NaN weights and turn every cell image into NaN.
        group_max = groups.max(0)
        group_max[group_max == 0] = 1
        groups /= group_max
    else:
        groups = np.zeros([])
    # shape: (layer_height*layer_width, n_groups)
    groups = torch.from_numpy(groups)

    # Parametrization of the images of the groups (we have 'n_groups' groups)
    groups_params, groups_image_f = param.fft_image(
        [n_groups, 3, cell_image_size, cell_image_size]
    )
    # Parametrization of the images of each cell in the grid (we have 'layer_height*layer_width' cells)
    cells_params, cells_image_f = param.fft_image(
        [nb_cells, 3, cell_image_size, cell_image_size]
    )

    # First, we need to construct the images of the grid
    # from the parameterizations

    def image_f():
        groups_images = groups_image_f()
        cells_images = cells_image_f()
        X = []
        for i in range(nb_cells):
            x = 0.7 * cells_images[i] + 0.5 * sum(
                groups[i, j] * groups_images[j] for j in range(n_groups)
            )
            X.append(x)
        X = torch.stack(X)
        return X

    # make sure the images are between 0 and 1
    image_f = param.to_valid_rgb(image_f, decorrelate=True)

    # After constructing the cell images, we randomly sample a mini-batch of cells
    # from the grid. This is to prevent memory overflow, especially if the grid
    # is large. The sampled indices are remembered so the objective can pick
    # the matching target activations.
    sampled = {"inds": None}

    def image_f_sampled():
        X = image_f()
        inds = torch.randint(0, len(X), size=(batch_size,))
        sampled["inds"] = inds
        return X[inds]

    # Now, we define the objective function

    def objective_func(model):
        # shape: (batch_size, layer_channels, cell_layer_height, cell_layer_width)
        pred = model(layer)
        # use the indices of the current mini-batch to get the corresponding targets
        target = acts[sampled["inds"]].to(pred.device)
        # shape: (batch_size, layer_channels, 1, 1)
        target = target.view(target.shape[0], target.shape[1], 1, 1)
        dot = (pred * target).sum(dim=1).mean()
        return -dot

    obj = objectives.Objective(objective_func)

    def param_f():
        # We optimize the parametrizations of both the groups and the cells
        params = list(groups_params) + list(cells_params)
        return params, image_f_sampled

    # The return value is not needed: the final images are rebuilt from the
    # optimised parameters below, for every cell rather than one mini-batch.
    render.render_vis(
        model,
        obj,
        param_f,
        thresholds=(n_steps,),
        show_image=False,
        progress=True,
        fixed_image_size=cell_image_size,
    )
    # shape: (layer_height*layer_width, 3, cell_image_size, cell_image_size)
    imgs = image_f()
    imgs = imgs.cpu().data
    # drop a 2-pixel border on every side of each cell image
    imgs = imgs[:, :, 2:-2, 2:-2]
    # turn imgs into a grid
    grid = torchvision.utils.make_grid(imgs, nrow=int(np.sqrt(nb_cells)), padding=0)
    grid = grid.permute(1, 2, 0)
    grid = grid.numpy()
    render.show(grid)
    return imgs

## Kitty 

In [None]:
# Load the sample image as a float32 (height, width, 3) array.
# `Image.open` is lazy and keeps the file open, so decode inside a context
# manager to release the handle as soon as the pixels are copied. Forcing RGB
# guarantees the 3 channels the normalisation step downstream expects.
with Image.open("dataset/test/LasVegas/0000075_0018745_0000004_0002564.jpg") as image_file:
    img = np.array(image_file.convert("RGB"), np.float32)

In [None]:
# Fetch the pretrained GoogLeNet from the torchvision hub repository pinned at v0.10.0.
model = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'googlenet',
    pretrained=True,
)
# Optional: load a locally saved checkpoint into the model instead.
#model.load_state_dict(torch.load('saved_models/inception/epoch_7_batch_5000.pth', map_location=device))
# Move the model to the selected device, then switch to evaluation mode. As the
# last expression of the cell, `eval()` returns the model so it is displayed.
model = model.to(device)
model.eval()

In [None]:
# Activation grid for the `inception4b` layer of GoogLeNet on the sample image.
render_activation_grid_less_naive(
    img,
    model,
    layer='inception4b',
    cell_image_size=120,
    n_groups=8,
    n_steps=1024,
    batch_size=64,
)

In [None]:
# Display the structure of the model's `inception5b` block, which is the
# default `layer` of `render_activation_grid_less_naive`.
model.inception5b

In [None]:
# Build KittyNet and restore the weights saved at epoch 7, batch 5000.
kitty = KittyNet()
# NOTE(review): `torch.load` unpickles the file, so only load checkpoints that
# come from a trusted source.
kitty.load_state_dict(
    torch.load('saved_models/kitty/epoch_7_batch_5000.pth', map_location=device)
)
# Move the model to the selected device, then switch to evaluation mode. As the
# last expression of the cell, `eval()` returns the model so it is displayed.
kitty = kitty.to(device)
kitty.eval()

In [None]:
# Sweep cell image sizes to see which ones KittyNet's `conv3` can be rendered with.
# NOTE: every size that works runs a full 1024-step optimisation, so this cell is slow.
for x in range(1, 50):
    try:
        render_activation_grid_less_naive(
            img, kitty, 'conv3', cell_image_size=x, n_steps=1024, batch_size=64, n_groups=2
        )
    except RuntimeError as err:
        # Report the reason as well as the size: RuntimeError covers unrelated
        # failures too (e.g. running out of GPU memory), and printing only the
        # size would make them indistinguishable.
        print(f"cell_image_size={x} failed: {err}")
    else:
        # Success marker for this size.
        print("aloha")

In [None]:
## Longcat

In [None]:
## Inception