In [26]:
import torch
from torch import nn
from lucent.optvis import render
from lucent.modelzoo import inceptionv1
import lucent
import torchvision.models as models
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
class RedirectedReLU(torch.autograd.Function):
    """
    A workaround when there is no gradient flow from an initial random input
    See https://github.com/tensorflow/lucid/blob/master/lucid/misc/redirected_relu_grad.py
    Note: this means that the gradient is technically "wrong"
    TODO: the original Lucid library has a more sophisticated way of doing this
    """
    @staticmethod
    def forward(ctx, input_tensor):
        ctx.save_for_backward(input_tensor)
        return input_tensor.clamp(min=0)
    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input_tensor < 0] = grad_input[input_tensor < 0] * 1e-1
        return grad_input


class RedirectedReluLayer(nn.Module):
    def forward(self, tensor):
        return RedirectedReLU.apply(tensor)

def redirect_relu(model):
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            print("replacing relu", child_name)
            setattr(model, child_name, RedirectedReluLayer())
        else:
            redirect_relu(child)

In [50]:
def load_model(name):
    m=getattr(models, name)(pretrained=True)
    redirect_relu(m)
    m.to(device).eval()
    return m

resnet50 = load_model("resnet50")
# inceptionv2 = load_model("inception_v3")
# alexnet = load_model('alexnet')

replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu
replacing relu relu


In [51]:
print([n for n, x in resnet50.named_modules()])
print([n for n, x in alexnet.named_modules()])
# print(lucent.modelzoo.util.get_model_layers(resnet50))
# render.render_vis(resnet50, "avgpool:7")
render.render_vis(resnet50, "fc:10")
render.render_vis(resnet50, "layer3:11")

['', 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer1.0', 'layer1.0.conv1', 'layer1.0.bn1', 'layer1.0.conv2', 'layer1.0.bn2', 'layer1.0.conv3', 'layer1.0.bn3', 'layer1.0.relu', 'layer1.0.downsample', 'layer1.0.downsample.0', 'layer1.0.downsample.1', 'layer1.1', 'layer1.1.conv1', 'layer1.1.bn1', 'layer1.1.conv2', 'layer1.1.bn2', 'layer1.1.conv3', 'layer1.1.bn3', 'layer1.1.relu', 'layer1.2', 'layer1.2.conv1', 'layer1.2.bn1', 'layer1.2.conv2', 'layer1.2.bn2', 'layer1.2.conv3', 'layer1.2.bn3', 'layer1.2.relu', 'layer2', 'layer2.0', 'layer2.0.conv1', 'layer2.0.bn1', 'layer2.0.conv2', 'layer2.0.bn2', 'layer2.0.conv3', 'layer2.0.bn3', 'layer2.0.relu', 'layer2.0.downsample', 'layer2.0.downsample.0', 'layer2.0.downsample.1', 'layer2.1', 'layer2.1.conv1', 'layer2.1.bn1', 'layer2.1.conv2', 'layer2.1.bn2', 'layer2.1.conv3', 'layer2.1.bn3', 'layer2.1.relu', 'layer2.2', 'layer2.2.conv1', 'layer2.2.bn1', 'layer2.2.conv2', 'layer2.2.bn2', 'layer2.2.conv3', 'layer2.2.bn3', 'layer2.2.relu', 'layer2.

100%|█████████████████████████████| 512/512 [00:23<00:00, 21.80it/s]
100%|█████████████████████████████| 512/512 [00:20<00:00, 25.44it/s]


[array([[[[0.2367459 , 0.27923757, 0.06287524],
          [0.8577448 , 0.89851636, 0.82486796],
          [0.24325831, 0.29001126, 0.28956234],
          ...,
          [0.24543612, 0.08558498, 0.23577204],
          [0.30836508, 0.11068077, 0.37057713],
          [0.7890136 , 0.5117073 , 0.72716236]],
 
         [[0.24291992, 0.33009845, 0.4123458 ],
          [0.5348774 , 0.4893551 , 0.47835368],
          [0.74113566, 0.73734206, 0.6040661 ],
          ...,
          [0.25935367, 0.09364453, 0.25645468],
          [0.6756637 , 0.6204074 , 0.56250894],
          [0.55517566, 0.37849906, 0.50153595]],
 
         [[0.33293498, 0.32150447, 0.11175639],
          [0.42920393, 0.48505506, 0.6034013 ],
          [0.7156328 , 0.7514193 , 0.62032527],
          ...,
          [0.17285085, 0.05920639, 0.10277893],
          [0.74947137, 0.6235415 , 0.63907343],
          [0.42556205, 0.12283821, 0.31467655]],
 
         ...,
 
         [[0.50020397, 0.56715006, 0.4256577 ],
          [0.45848