In [None]:
!pip install --quiet torch-lucent

In [None]:
from PIL import Image
import numpy as np
import scipy.ndimage as nd
import torch

from lucent.optvis import render, param, transform, objectives
from lucent.misc.io import show

In [None]:
# Use the first CUDA device when PyTorch reports one as available, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class KittyNet(nn.Module):
    """Small CNN: input batch-norm, four conv/ReLU/average-pool stages, linear head.

    ``conv{k}_bn`` normalises the pooled output of ``conv{k}`` (k = 1..3);
    ``bn0`` normalises the raw input. The linear head maps 324 flattened
    features (36 channels x 3 x 3, e.g. for a 640x640 input) to 4 outputs.

    Attribute names are kept stable: they are the keys of the saved
    checkpoint and the layer names looked up by the visualisation cells.
    """

    def __init__(self):
        super().__init__()
        self.bn0 = nn.BatchNorm2d(num_features=3)
        self.conv1 = nn.Conv2d(3, 9, kernel_size=3)
        self.pool1 = nn.AvgPool2d(kernel_size=4, stride=4)

        self.conv1_bn = nn.BatchNorm2d(num_features=9)
        self.conv2 = nn.Conv2d(9, 16, kernel_size=3)
        self.pool2 = nn.AvgPool2d(kernel_size=4, stride=4)

        self.conv2_bn = nn.BatchNorm2d(num_features=16)
        self.conv3 = nn.Conv2d(16, 25, kernel_size=3)
        self.pool3 = nn.AvgPool2d(kernel_size=4, stride=4)

        self.conv3_bn = nn.BatchNorm2d(num_features=25)
        self.conv4 = nn.Conv2d(25, 36, kernel_size=3)
        self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.fc = nn.Linear(in_features=324, out_features=4)

    def forward(self, x):
        """Return the 4-way output for a batch ``x`` of 3-channel images."""
        x = self.bn0(x)

        # Stages 1-3: conv -> ReLU -> average pool -> batch norm.
        normalised_stages = (
            (self.conv1, self.pool1, self.conv1_bn),
            (self.conv2, self.pool2, self.conv2_bn),
            (self.conv3, self.pool3, self.conv3_bn),
        )
        for conv, pool, norm in normalised_stages:
            x = norm(pool(F.relu(conv(x))))

        # Stage 4 has no trailing batch norm.
        x = self.pool4(F.relu(self.conv4(x)))

        # Flatten everything except the batch dimension before the linear head.
        flat = torch.flatten(x, 1)
        return self.fc(flat)
    
class LongcatNet(nn.Module):
    """Ten-convolution CNN with max pooling and a 4-output linear head.

    ``bn1`` normalises the raw input and ``conv{k}_bn`` (k = 2..10)
    normalises the input fed to ``conv{k}``. Max pooling follows conv1-conv4,
    conv9 and conv10. The head maps 1764 flattened features
    (49 channels x 6 x 6, e.g. for a 640x640 input) to 4 outputs.

    Attribute names are kept stable: they are the keys of the saved
    checkpoint and the layer names looked up by the visualisation cells.
    """

    def __init__(self):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(num_features=3)
        self.conv1 = nn.Conv2d(3, 9, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2_bn = nn.BatchNorm2d(num_features=9)
        self.conv2 = nn.Conv2d(9, 16, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3_bn = nn.BatchNorm2d(num_features=16)
        self.conv3 = nn.Conv2d(16, 25, kernel_size=3)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4_bn = nn.BatchNorm2d(num_features=25)
        self.conv4 = nn.Conv2d(25, 36, kernel_size=3)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5_bn = nn.BatchNorm2d(num_features=36)
        self.conv5 = nn.Conv2d(36, 36, kernel_size=3)

        self.conv6_bn = nn.BatchNorm2d(num_features=36)
        self.conv6 = nn.Conv2d(36, 49, kernel_size=3)

        self.conv7_bn = nn.BatchNorm2d(num_features=49)
        self.conv7 = nn.Conv2d(49, 49, kernel_size=3)

        self.conv8_bn = nn.BatchNorm2d(num_features=49)
        self.conv8 = nn.Conv2d(49, 49, kernel_size=3)

        self.conv9_bn = nn.BatchNorm2d(num_features=49)
        self.conv9 = nn.Conv2d(49, 49, kernel_size=3)
        self.pool9 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv10_bn = nn.BatchNorm2d(num_features=49)
        self.conv10 = nn.Conv2d(49, 49, kernel_size=3)
        self.pool10 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc = nn.Linear(in_features=1764, out_features=4)

    def forward(self, x):
        """Return the 4-way output for a batch ``x`` of 3-channel images."""
        x = self.bn1(x)

        # conv1-conv4: conv -> ReLU -> max pool, then the next conv's batch norm.
        pooled_stages = (
            (self.conv1, self.pool1, self.conv2_bn),
            (self.conv2, self.pool2, self.conv3_bn),
            (self.conv3, self.pool3, self.conv4_bn),
            (self.conv4, self.pool4, self.conv5_bn),
        )
        for conv, pool, next_norm in pooled_stages:
            x = next_norm(pool(F.relu(conv(x))))

        # conv5-conv8: no pooling between these convolutions.
        unpooled_stages = (
            (self.conv5, self.conv6_bn),
            (self.conv6, self.conv7_bn),
            (self.conv7, self.conv8_bn),
            (self.conv8, self.conv9_bn),
        )
        for conv, next_norm in unpooled_stages:
            x = next_norm(F.relu(conv(x)))

        # conv9 and conv10 are pooled again; nothing normalises conv10's output.
        x = self.conv10_bn(self.pool9(F.relu(self.conv9(x))))
        x = self.pool10(F.relu(self.conv10(x)))

        # Flatten everything except the batch dimension before the linear head.
        flat = torch.flatten(x, 1)
        return self.fc(flat)

In [None]:
@objectives.wrap_objective()
def dot_compare(layer, batch=1, cossim_pow=0):
    """Objective comparing two batch entries' activations at ``layer``.

    Args:
        layer: Name of the layer whose activations are read through ``T``.
        batch: Index of the batch entry compared against entry 0.
        cossim_pow: Exponent applied to the normalised term; with 0 the
            result is just the negated dot product.

    Returns:
        A function of the activation lookup ``T`` that returns the scalar
        ``-dot * cossim ** cossim_pow``.
    """
    def inner(T):
        activations = T(layer)
        reference = activations[0]
        other = activations[batch]

        dot = torch.sum(other * reference)
        # Only entry 0's L2 norm is divided out, so this is not a full cosine
        # similarity; the epsilon guards against division by zero.
        reference_norm = torch.sum(reference ** 2).sqrt()
        cossim = dot / (1e-6 + reference_norm)
        return -dot * (cossim ** cossim_pow)
    return inner

In [None]:
def feature_inversion(img, model, layer=None, n_steps=512, cossim_pow=0.0):
    """Optimise an image whose ``layer`` activations match those of ``img`` and show it.

    A parameterised image is stacked with ``img`` into a two-entry batch so
    that ``dot_compare`` can compare entry 1 (the target) against entry 0
    (the image being optimised).

    Args:
        img: Target image as a (height, width, channels) array. It is not
            rescaled here; the notebook's callers pass float32 values
            already divided by 255.
        model: Network handed to ``render.render_vis``.
        layer: Name of the layer to invert. Required despite the ``None``
            default.
        n_steps: Number of optimisation steps (the single render threshold).
        cossim_pow: Exponent forwarded to ``dot_compare``.

    Raises:
        ValueError: If ``layer`` is None or ``img`` is not 3-dimensional.
    """
    if layer is None:
        raise ValueError("feature_inversion needs a layer name, e.g. layer='conv1'")

    # float32 matches what the notebook's callers already pass; the cast keeps
    # a float64 input (e.g. uint8 / 255) from reaching the stack below.
    # NOTE(review): presumably the parameterised image is float32 too - confirm.
    target_np = np.asarray(img, dtype=np.float32)
    if target_np.ndim != 3:
        raise ValueError(
            f"img must be a 3-D (height, width, channels) array, got shape {target_np.shape}"
        )
    height, width = target_np.shape[:2]

    # Channels-first layout, moved to the notebook's global device.
    target = torch.tensor(np.transpose(target_np, [2, 0, 1])).to(device)

    obj = objectives.Objective.sum([
        1.0 * dot_compare(layer, cossim_pow=cossim_pow),
        objectives.blur_input_each_step(),
    ])

    # The parameterised image must have the target's spatial size, otherwise
    # torch.stack below fails; for 640x640 targets this equals param.image(640).
    params, image_f = param.image(width, h=height)

    def stacked_param_f():
        # Batch entry 0: image being optimised; entry 1: fixed target.
        return params, lambda: torch.stack([image_f()[0], target])

    transforms = [
        transform.pad(8, mode='constant', constant_value=.5),
        transform.jitter(8),
        transform.random_scale([0.9, 0.95, 1.05, 1.1] + [1]*4),
        transform.random_rotate(list(range(-5, 5)) + [0]*5),
        transform.jitter(2),
    ]

    renders = render.render_vis(
        model,
        obj,
        stacked_param_f,
        transforms=transforms,
        thresholds=(n_steps,),
        show_image=False,
        progress=False,
    )

    # First (only) threshold, first batch entry: the optimised image.
    show(renders[0][0])

## Kitty

In [None]:
# Rebuild the KittyNet architecture, restore the saved checkpoint, and
# switch to eval mode on the selected device.
kitty = KittyNet()
kitty.load_state_dict(
    torch.load('saved_models/kitty/epoch_7_batch_5000.pth', map_location=device)
)
kitty = kitty.to(device)
kitty.eval()

In [None]:
# Names of KittyNet's convolution layers: conv1 .. conv4.
kitty_layers = ['conv' + str(index) for index in range(1, 5)]

In [None]:
# Load the reference image as float32 and divide by 255 before inversion.
img = np.array(Image.open("dataset/test/Amsterdam/0000093_0002505_0000002_0000780.jpg"), np.float32)
img = img/255

# The layer names for this model are already defined as `kitty_layers`;
# the duplicate `layers` list that used to follow was never read.

In [None]:
# Invert the loaded image from each KittyNet convolution layer in turn.
for layer in kitty_layers:
    print(layer)
    feature_inversion(img, model=kitty, layer=layer)
    print()

In [None]:
img = np.array(Image.open("dataset/test/LasVegas/0000075_0018745_0000004_0002564.jpg"), np.float32)
# Divide by 255 like the other image-loading cells do; this cell previously
# passed the unscaled pixel values.
img = img/255

# feature_inversion requires the model argument; this section inverts KittyNet.
for layer in kitty_layers:
    print(layer)
    feature_inversion(img, kitty, layer=layer)
    print()

## Longcat

In [None]:
# Rebuild the LongcatNet architecture, restore the saved checkpoint, and
# switch to eval mode on the selected device.
longcat = LongcatNet()
longcat.load_state_dict(
    torch.load('saved_models/longcat/epoch_7_batch_5000.pth', map_location=device)
)
longcat = longcat.to(device)
longcat.eval()

In [None]:
# Names of LongcatNet's convolution layers: conv1 .. conv10.
longcat_layer = ['conv' + str(index) for index in range(1, 11)]


In [None]:
# Load the reference image as float32 and divide by 255.
img = np.array(
    Image.open("dataset/test/LasVegas/0000075_0018745_0000004_0002564.jpg"),
    np.float32,
)
img = img / 255

# Invert the image from each LongcatNet convolution layer in turn.
for layer in longcat_layer:
    print(layer)
    feature_inversion(img, model=longcat, layer=layer)
    print()

## Inception, pretrained

In [None]:
from lucent.modelzoo import inceptionv1

# Same device selection as at the top of the notebook, repeated so this
# section can be run on its own.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained InceptionV1 from lucent's model zoo, moved to the device.
model = inceptionv1(pretrained=True).to(device)
# Assigning the result keeps the module repr out of the cell output.
_ = model.eval()

In [None]:
# Layers to invert: conv2d0 .. conv2d2 followed by mixed3a .. mixed5b.
layers = [
    *(f'conv2d{i}' for i in range(3)),
    'mixed3a', 'mixed3b',
    'mixed4a', 'mixed4b', 'mixed4c', 'mixed4d', 'mixed4e',
    'mixed5a', 'mixed5b',
]

In [None]:
# Invert the most recently loaded image from each listed Inception layer.
for layer in layers:
    print(layer)
    feature_inversion(img, model=model, layer=layer)
    print()