# Model and Data Download

In [None]:
# Change the paths as you need
!mkdir /content/models
!mkdir /content/models/vgg16
!mkdir /content/models/alexnet
!wget https://download.pytorch.org/models/vgg16-397923af.pth -O /content/models/vgg16/model.pth
!wget https://download.pytorch.org/models/alexnet-owt-7be5be79.pth -O /content/models/alexnet/model.pth

In [None]:
!mkdir /content/models/resnet
!wget https://download.pytorch.org/models/resnet50-0676ba61.pth -O /content/models/resnet/model.pth

In [None]:
#Optional you can get images from anywhere you want
!wget http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar
!tar -xf /content/images.tar

# Load Input Image

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn


In [None]:
from PIL import Image,ImageEnhance
from torchvision import transforms
# Any image works, e.g. '/content/Images/n02089867-Walker_hound/n02089867_1105.jpg'.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# convert('RGB') makes grayscale / RGBA / palette images 3-channel so they match
# the three per-channel statistics used by Normalize (otherwise Normalize fails).
input_image = Image.open('test_image.jpg').convert('RGB')
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Try with and without normalizing
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)
# Undo the normalization for display: ToPILImage multiplies floats by 255 and
# casts to uint8, so normalized values (negative or > 1) would wrap around.
# If you drop Normalize above, drop this de-normalization as well.
_mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
_std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
transforms.ToPILImage()((input_batch[0] * _std + _mean).clamp(0, 1)).convert("RGB")

# Deconv Wrapper Class

In [None]:
class Deconv(nn.Module):
  """Deconvnet-style visualiser wrapped around a model.

  The wrapped ``model`` must expose two containers indexed in parallel:
  ``features`` (forward layers) and ``rev_features`` (the matching reverse
  layers, applied in the opposite order on the way back). A forward layer
  whose class name is ``MaxPool2d`` must return ``(output, indices)`` and its
  reverse layer must accept ``(input, indices)``.

  Calling ``deconv(x, j)`` runs the first sample of ``x`` through forward
  layers 1..j, plots the ``top_k`` channels of layer ``j`` with the largest L2
  norm, then projects each of those channels (all other channels zeroed) back
  through the reverse layers and plots the reconstructions.

  Args:
    model: network providing ``features`` and ``rev_features``.
    top_k: number of strongest channels to visualise (default 18, a 3x6 grid).
  """

  # Number of subplot columns; the row count is derived from the image count.
  GRID_COLS = 6

  def __init__(self, model, top_k=18):
    super().__init__()
    self.model = model
    self.top_k = top_k
    self.layer_counter = 0  # index of the layer currently being processed
    self.pass_forward = 0   # forward layers still to run after the current one

  @staticmethod
  def _display_range(t):
    """Clamp negatives to 0 and rescale so the maximum becomes 1.

    ToPILImage multiplies float tensors by 255 and casts to uint8, so any value
    above 1 would wrap around; rescaling keeps the image in range.
    """
    t = t.detach().clamp(min=0)
    peak = t.max()
    return t / peak if peak > 0 else t

  @staticmethod
  def _top_channels(fmap, k):
    """Return indices of the ``k`` channels of ``fmap`` (C, H, W) with the largest L2 norm, strongest first."""
    norms = fmap.detach().flatten(1).norm(dim=1)
    return torch.argsort(norms, descending=True)[:k]

  def _show_grid(self, tensors):
    """Plot each tensor as a contrast-enhanced image in a grid of GRID_COLS columns."""
    rows = max(1, -(-len(tensors) // self.GRID_COLS))  # ceiling division
    plt.figure(figsize=(20, 10))
    for pos, t in enumerate(tensors, start=1):
      plt.subplot(rows, self.GRID_COLS, pos)
      img = transforms.ToPILImage()(self._display_range(t)).convert("RGB")
      plt.imshow(ImageEnhance.Contrast(img).enhance(2))
      plt.axis('off')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

  def forward_one_layer(self, x):
    """Apply layer ``layer_counter`` to ``x`` and return the back-projection through it.

    While ``pass_forward`` > 0 this recurses into the next layer and then
    applies this layer's reverse op to what the deeper layers returned. At the
    target layer (``pass_forward == 0``) it plots the strongest channels and
    builds the batch to project back: one copy of the activation per selected
    channel, with every other channel zeroed.

    Returns:
      Tensor whose first dimension indexes the selected channels.
    """
    idx = self.layer_counter
    layer = self.model.features[idx]
    rev_layer = self.model.rev_features[idx]
    if layer._get_name() == 'MaxPool2d':
      y, switches = layer(x)
    else:
      y, switches = layer(x), None

    if self.pass_forward == 0:
      fmap = y[0].detach()
      chosen = self._top_channels(fmap, self.top_k)
      self._show_grid([fmap[c] for c in chosen])
      isolated = []
      for c in chosen:
        single = torch.zeros_like(fmap)
        single[c] = fmap[c]
        isolated.append(single)
      back_input = torch.stack(isolated, dim=0)
    else:
      self.layer_counter += 1
      self.pass_forward -= 1
      back_input = self.forward_one_layer(y)

    if switches is None:
      return rev_layer(back_input)
    # One copy of the pooling indices per back-projected channel. The batch size
    # is min(top_k, channels of the target layer), not a fixed 18.
    switches = switches[0].repeat(back_input.shape[0], 1, 1, 1)
    return rev_layer(back_input, switches)

  def forward(self, x, j = 1):
    """Visualise layer ``j`` (1-based) of ``model.features`` for the first sample of ``x``.

    Shows two grids: the strongest activations at layer ``j`` and their
    reconstructions projected back to input space.

    Raises:
      ValueError: if ``j`` is not in ``1..len(model.features)``.
    """
    depth = len(self.model.features)
    if not 1 <= j <= depth:
      raise ValueError(f"j must be between 1 and {depth}, got {j}")
    self.pass_forward = j - 1
    self.layer_counter = 0
    reconstructions = self.forward_one_layer(x)
    self._show_grid(list(reconstructions))
# Quick check of the wrapper on the first layer. `vgg16` only exists after the
# `%run vgg.py` cell in the VGG 16 section below has been run, so skip instead
# of raising NameError on a top-to-bottom run.
if 'vgg16' in globals():
  dvgg = Deconv(vgg16)
  dvgg(input_batch, 1)
else:
  print('vgg16 is not loaded yet - run `%run vgg.py` first')

# VGG 16 (without batch normalization)

In [None]:
# Presumably defines `vgg16` (used below) with the `features` / `rev_features`
# containers that Deconv requires - confirm against vgg.py.
%run vgg.py

In [None]:
# Deconvolve VGG-16 layers 1..10 one at a time. A single wrapper can be reused
# because Deconv.forward resets its layer bookkeeping on every call.
dvgg = Deconv(vgg16)
for layer_no in range(1, 11):
  print('Activations and Features at layer ', layer_no)
  dvgg(input_batch, layer_no)

# AlexNet

In [None]:
# Presumably defines `an`, the AlexNet model visualised below (it must expose
# `features` / `rev_features` for Deconv) - confirm against alexnet.py.
%run alexnet.py

In [None]:
# Deconvolve every AlexNet feature layer. The loop bound comes from `an`, so the
# wrapper must wrap `an` too (it previously wrapped vgg16 by mistake).
dalex = Deconv(an)
for i in range(1, len(an.features) + 1):
  print('Features at layer ', i)
  dalex(input_batch, i)

# ResNet 50
The deconvolution method is not defined for skip connections or batch normalization, so vanilla backpropagation is used instead.

In [None]:
# Presumably defines `rn`, the ResNet-50 model wrapped by VanillaBackprop
# below - confirm against resnet.py.
%run resnet.py

## Backprop Wrapper Class

In [None]:
# Forward hooks stay attached to the wrapped model for as long as it lives, so
# wrapping the same model again makes the old wrapper's hooks fire too. Call
# `remove_hooks()` on the old wrapper first (or restart the kernel; the
# downloaded files do not need to be fetched again).
class VanillaBackprop(nn.Module):
  """Visualise Conv2d activations and their input gradients during a forward pass.

  A forward hook is registered on up to ``max_layers`` ``Conv2d`` layers. Each
  time one fires, it plots the ``top_k`` output channels (first sample) with
  the largest L2 norm. For each of those channels it also plots the gradient
  of the channel's sum with respect to the input image.

  Args:
    model: network to inspect. If it has a ``features`` attribute, only the
      layers directly in ``features`` are searched for Conv2d. Otherwise every
      submodule of ``model`` is searched.
    threshold_for_vis: unused; kept for backward compatibility.
    max_layers: maximum number of Conv2d layers to hook (default 15).
    top_k: number of channels shown per hooked layer (default 18, a 3x6 grid).
  """

  # Number of subplot columns; the row count is derived from the image count.
  GRID_COLS = 6

  def __init__(self, model, threshold_for_vis=0, max_layers=15, top_k=18):
    super().__init__()
    self.model = model
    self.top_k = top_k
    self.layer_counter = 0
    self._hook_handles = []  # RemovableHandles returned by register_forward_hook
    features = getattr(model, 'features', None)
    candidates = features if features is not None else model.modules()
    for layer in candidates:
      if len(self._hook_handles) >= max_layers:
        break
      if layer._get_name() == 'Conv2d':
        print('Registering hook on ', layer)
        self._hook_handles.append(layer.register_forward_hook(self.hook))

  def remove_hooks(self):
    """Detach every forward hook this wrapper registered on the model."""
    for handle in self._hook_handles:
      handle.remove()
    self._hook_handles.clear()

  @staticmethod
  def _to_unit_range(t):
    """Min-max rescale ``t`` to [0, 1] (all zeros if ``t`` is constant).

    ToPILImage multiplies float tensors by 255 and casts to uint8, so the
    negative or > 1 values typical of conv outputs and gradients would wrap.
    """
    t = t.detach()
    lo, hi = t.min(), t.max()
    if hi > lo:
      return (t - lo) / (hi - lo)
    return torch.zeros_like(t)

  def _show_grid(self, tensors, cmap=None):
    """Plot each tensor as a contrast-enhanced image in a grid of GRID_COLS columns."""
    rows = max(1, -(-len(tensors) // self.GRID_COLS))  # ceiling division
    plt.figure(figsize=(20, 10))
    for pos, t in enumerate(tensors, start=1):
      plt.subplot(rows, self.GRID_COLS, pos)
      img = transforms.ToPILImage()(self._to_unit_range(t))
      plt.imshow(ImageEnhance.Contrast(img).enhance(2), cmap=cmap)
      plt.axis('off')
    plt.show()

  def hook( self,module, input, output):
    """Forward hook: plot the strongest channels of ``output`` and their input gradients."""
    print('Activations and Visualizations layer ',self.layer_counter+1,' module is ',module)
    self.layer_counter += 1
    fmap = output[0]
    with torch.no_grad():
      norms = fmap.flatten(1).norm(dim=1)
    chosen = torch.argsort(norms, descending=True)[:self.top_k]
    # Activation maps stay single-channel ('L') images so the colormap applies
    # (matplotlib ignores cmap for RGB images).
    self._show_grid([fmap[c] for c in chosen], cmap='magma')
    grads = []
    for c in chosen:
      # Gradient of this channel's total activation w.r.t. the input image only.
      # retain_graph because the forward pass that built the graph is still running.
      (grad,) = torch.autograd.grad(fmap[c].sum(), self.x, retain_graph=True)
      grads.append(grad[0])
    self._show_grid(grads)

  def forward(self,x):
    """Run ``model`` on a gradient-tracking copy of ``x``; the hooks fire during this call.

    The copy is a fresh leaf tensor stored as ``self.x``, so the caller's tensor
    is not switched to ``requires_grad=True``.
    """
    self.layer_counter = 0
    self.x = x.detach().requires_grad_(True)
    return self.model(self.x)

In [None]:
# Wrap the ResNet model `rn` and run one forward pass on the input batch; the
# hooks registered by VanillaBackprop plot activations and input gradients as
# each hooked Conv2d layer executes. `logits` holds the model's output.
# NOTE(review): the hooks stay registered on `rn`, so re-running this cell adds
# a second set of hooks.
rn_backprop = VanillaBackprop(rn)
logits= rn_backprop(input_batch)