Original code from:
https://colab.research.google.com/github/pytorch/pytorch.github.io/blob/master/assets/hub/pytorch_vision_googlenet.ipynb

# Get output probabilities

In [None]:
%%capture
import torch
import torchvision.models as models
from torchvision import transforms
import copy
from PIL import Image

# Load an ImageNet-pretrained GoogLeNet. Alternative loaders are kept below for reference.
# model = torch.hub.load('pytorch/vision:v0.9.0', 'googlenet', pretrained=True)
model = models.googlenet(pretrained=True)  # without pretrained=True the weights would not be the pretrained ones
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favor of `weights=`; confirm against the installed version.
# model = models.vgg16(pretrained=True)

# Put the model in inference mode (affects layers such as dropout / batch norm):
# https://stackoverflow.com/questions/60018578/what-does-model-eval-do-in-pytorch
model.eval()

# Download the ImageNet class names, one per line; categories[i] is the name of class index i.
# NOTE(review): re-running this cell makes wget save imagenet_classes.txt.1, .2, ...; the open() below still reads the first file.
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

In [None]:
# Download an example image from the pytorch website
import urllib.request

url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
# `urllib.URLopener` is Python 2 only, so the old try/bare-except always fell through
# to this call anyway. Importing `urllib.request` explicitly guarantees the submodule
# is loaded (a plain `import urllib` does not). Errors now surface instead of being swallowed.
urllib.request.urlretrieve(url, filename)

In [None]:
from google.colab import files

# files.upload() returns {filename: raw file bytes}. Keep it in a variable so the raw
# image bytes are not dumped into the notebook output; display only the file names.
uploaded = files.upload()
list(uploaded)

Saving greatwhiteshark.jpg to greatwhiteshark.jpg
Saving tigershark.jpg to tigershark.jpg


{'greatwhiteshark.jpg': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00\x84\x00\t\x06\x07\x0f\x12\x12\x15\x0f\x0f\x12\x0f\x15\x15\x15\x15\x0f\x0f\x15\x15\x15\x0f\x0f\x10\x10\x0f\x10\x15\x15\x16\x16\x15\x15\x15\x15\x18\x1d( \x18\x1a%\x1d\x15\x15!1!%)+...\x17\x1f383-7(-.+\x01\n\n\n\x0e\r\x0e\x17\x10\x10\x1a-% %--++++---/-+-----+------------------------+-------\xff\xc0\x00\x11\x08\x00\xb7\x01\x13\x03\x01"\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x1b\x00\x00\x03\x00\x03\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\xff\xc4\x00;\x10\x00\x02\x02\x01\x02\x03\x05\x05\x06\x04\x06\x03\x01\x00\x00\x00\x00\x01\x02\x11\x03\x04!\x121A\x05Qaq\x81"\x91\xa1\xb1\xc1\x06\x132b\xd1\xf0\x14R\x92\xe1#3r\x82\xa2\xf1\x16BC\x15\xff\xc4\x00\x1a\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x03\x02\x04\x05\x06\xff\xc4\x00$\x11\x01\x00\x02\x02\x02\x01\x05\x00\x03\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x11\x041!\

In [None]:
def get_output(model, input_image):
    """Run ``model`` on a single PIL image and return its raw (unnormalized) scores.

    The image is converted to RGB, then goes through the standard ImageNet
    preprocessing (resize to 256, center-crop to 224, convert to a tensor,
    normalize with the ImageNet mean/std). It is wrapped in a batch of size 1.

    Args:
        model: the classifier to run. If CUDA is available, the model is moved
            to the GPU in place (``model.to`` mutates the module).
        input_image: a PIL image. It is converted to RGB first so grayscale or
            RGBA images (e.g. PNGs with an alpha channel) match the 3-channel
            normalization below.

    Returns:
        Tensor of unnormalized scores with a leading batch dimension of 1
        (1000 class scores for the ImageNet models used here). Apply a softmax
        to get probabilities.
    """
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Normalize expects exactly 3 channels; convert("RGB") is a harmless copy for RGB images.
    input_tensor = preprocess(input_image.convert("RGB"))
    input_batch = input_tensor.unsqueeze(0)  # add a batch dimension of size 1, as the model expects

    # Move the input and model to the GPU for speed if one is available.
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        output = model(input_batch)
    return output

def output_probs(model, input_image):
    """Return the class probabilities that ``model`` assigns to ``input_image``.

    Takes the single row of scores produced by ``get_output`` and turns it into
    a probability distribution with a softmax over the class dimension.
    """
    logits = get_output(model, input_image)[0]
    return torch.nn.functional.softmax(logits, dim=0)

def get_top_categories(probabilities, categories, k=5):
    """Print the ``k`` most probable categories with their probabilities.

    Args:
        probabilities: 1-D tensor of class probabilities (e.g. from ``output_probs``).
        categories: sequence of class names, indexed like ``probabilities``.
        k: number of entries to print (default 5). It is clamped to the number
            of classes so short inputs don't make ``torch.topk`` raise.
    """
    k = min(k, probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    for prob, catid in zip(top_prob, top_catid):
        print(categories[catid.item()], prob.item())


In [None]:
# Open the two uploaded shark photos, then compute the model's class probabilities for each.
input_image_1, input_image_2 = (
    Image.open(path) for path in ('tigershark.jpg', 'greatwhiteshark.jpg')
)

probs_tiger_shark, probs_great_white_shark = (
    output_probs(model, image) for image in (input_image_1, input_image_2)
)

In [None]:
get_top_categories(probabilities=probs_tiger_shark, categories=categories)  # top-5 for the tiger shark photo

tiger shark 0.9285188317298889
gar 0.008483723737299442
great white shark 0.006604191847145557
sturgeon 0.005556200165301561
axolotl 0.0019637865480035543


In [None]:
get_top_categories(probabilities=probs_great_white_shark, categories=categories)  # top-5 for the great white photo

great white shark 0.6037811636924744
tiger shark 0.35362929105758667
hammerhead 0.02359415590763092
sturgeon 0.001169938943348825
killer whale 0.0006463401368819177


---

# Compare probabilities after knocking out one whole layer (conv3)

In [None]:
# Knock out the whole conv3 layer: zero its convolution weights in a deep copy,
# so the original `model` keeps its pretrained weights.
new_model_1 = copy.deepcopy(model)
with torch.no_grad():
    new_model_1.conv3.conv.weight.mul_(0)

new_probs_tiger_shark = output_probs(new_model_1, input_image_1)

def compare_probs(probabilities, new_probabilities, categories, k=5):
    """Print the top-``k`` predictions before and after a modification, side by side.

    Each printed row pairs the i-th most probable class of ``probabilities``
    with the i-th most probable class of ``new_probabilities``. Probabilities
    are rounded to 5 decimals.

    Args:
        probabilities: 1-D tensor of class probabilities from the original model.
        new_probabilities: 1-D tensor of class probabilities from the modified model.
        categories: sequence of class names, indexed like the probability tensors.
        k: number of rows to print (default 5), clamped to the number of classes.
    """
    k = min(k, probabilities.numel(), new_probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    new_top_prob, new_top_catid = torch.topk(new_probabilities, k)
    # Name columns are 20 wide so long names such as "great white shark" stay aligned.
    row_format = "{: >10} {: >20} {: >10} {: >10} {: >20} {: >10}"
    for i in range(k):
        row = ['before:', categories[top_catid[i].item()], round(top_prob[i].item(), 5),
               'after:', categories[new_top_catid[i].item()], round(new_top_prob[i].item(), 5)]
        print(row_format.format(*row))

compare_probs(probabilities=probs_tiger_shark, new_probabilities=new_probs_tiger_shark, categories=categories)

   before:     tiger shark    0.92852     after:    space heater     0.4274
   before:             gar    0.00848     after:     loudspeaker    0.40686
   before: great white shark     0.0066     after:   window screen    0.08127
   before:        sturgeon    0.00556     after:        strainer    0.04828
   before:         axolotl    0.00196     after:           radio    0.00625


In [None]:
# Find the average change in probability (original minus knocked-out), skipping class indices 0-4

def Average(lst):
    """Return the arithmetic mean of a non-empty sized sequence (ZeroDivisionError if empty)."""
    total = sum(lst)
    count = len(lst)
    return total / count

# Per-class change in probability (original minus knocked-out), computed element-wise.
# list() keeps prob_change as a Python list of 0-d tensors, like the original comprehension.
prob_change = list(probs_tiger_shark - new_probs_tiger_shark)

# NOTE(review): [5:] skips class indices 0-4, which include the shark classes 2-4 used as
# abstract_class later in this notebook. Confirm that this exclusion is intended.
print(Average(prob_change[5:]))

tensor(-0.0009, device='cuda:0')


In [None]:
# Get the 3 most frequent (rounded) changes in probability across classes.
from collections import Counter

# torch.Tensor hashes by object identity, so Counter(prob_change) counted every element
# exactly once and most_common() returned arbitrary entries. Converting to Python floats
# rounded to 4 decimals lets equal changes fall into the same bucket.
c = Counter(round(float(change), 4) for change in prob_change)
c.most_common(3)

[(tensor(0.0003, device='cuda:0'), 1),
 (tensor(0.0002, device='cuda:0'), 1),
 (tensor(0.0066, device='cuda:0'), 1)]

In [None]:
# TODO: remove outliers from prob_change, then recompute the mean

---

# Knockout neuron communities

**KNOCKOUT LAYERS**

The next cell knocks out each convolutional layer in turn (by zeroing its weights) and records the resulting predicted probability for Tiger Shark. The output lists these probabilities in ascending order. The layers at the top of the list, whose knockout lowers the Tiger Shark probability the most, are the most important for predicting Tiger Shark.

In [None]:
# Knock out each Conv2d layer in turn (zero its weights in a copy of the model) and record
# the resulting Tiger Shark probability (class index 3). Lower scores mean more important layers.
torch.cuda.empty_cache()

layer_scores = {}
new_model_2 = copy.deepcopy(model)
for layer_name, layer in new_model_2.named_modules():
    if not isinstance(layer, torch.nn.Conv2d):
        continue
    original_weight = layer.weight.detach().clone()
    try:
        with torch.no_grad():
            layer.weight.zero_()
        new_probabilities = output_probs(new_model_2, input_image_1)
        layer_scores[layer_name] = new_probabilities[3].item()
    finally:
        # Restore even if the forward pass fails or is interrupted, so new_model_2 stays intact.
        with torch.no_grad():
            layer.weight.copy_(original_weight)

sorted(layer_scores.items(), key=lambda item: item[1])


**KNOCKOUT FILTERS**

Next, try knocking out individual filters within an important layer to further pinpoint which filters in it are important.

In [None]:
def knockout_layer(layer_name, model, input_image, class_idx=3):
    """Knock out each filter of one layer in turn and score the effect on one class.

    Works on a deep copy of ``model``. For every output filter of ``layer_name``,
    the filter's weights are zeroed, ``input_image`` is classified, the probability
    of ``class_idx`` is recorded, and the weights are restored.

    Args:
        layer_name: dotted module name as reported by ``model.named_modules()``
            (e.g. ``'conv3.conv'``).
        model: the model to probe. It is deep-copied and left unmodified.
        input_image: image passed to ``output_probs``.
        class_idx: class whose probability is tracked (default 3, the Tiger Shark
            index used throughout this notebook).

    Returns:
        List of ``(filter_index, probability)`` pairs sorted ascending by probability,
        so the filters whose knockout hurts ``class_idx`` the most come first.

    Raises:
        KeyError: if ``layer_name`` is not a module of ``model``.
    """
    new_model = copy.deepcopy(model)
    modules = dict(new_model.named_modules())
    # Previously an unknown name silently fell through to the last module in the loop.
    if layer_name not in modules:
        raise KeyError(f"no module named {layer_name!r} in model")
    layer = modules[layer_name]

    filter_scores = {}
    # Dim 0 of the weight indexes output filters (Conv2d: [filters, channels, height, width]).
    for filter_ind in range(layer.weight.shape[0]):
        original_filter = layer.weight[filter_ind].detach().clone()
        try:
            with torch.no_grad():
                layer.weight[filter_ind].zero_()
            new_probabilities = output_probs(new_model, input_image)
            filter_scores[filter_ind] = new_probabilities[class_idx].item()
        finally:
            # Restore the filter even if the forward pass raises.
            with torch.no_grad():
                layer.weight[filter_ind].copy_(original_filter)

    return sorted(filter_scores.items(), key=lambda item: item[1])

knockout_layer(layer_name='conv3.conv', model=model, input_image=input_image_1)
# knockout_layer(layer_name='inception3a.branch2.1.conv', model=model, input_image=input_image_1)
# knockout_layer(layer_name='features.26', model=model, input_image=input_image_1)  # VGG16

**KNOCKOUT RANDOM COMBOS OF FILTERS**

Instead of knocking out one filter at a time, knock out random combinations of filters: zero out a random percentage of a layer's filters, then average the results over several different random combinations of the same percentage.



In [None]:
import random

def knockout_layer_perc(layer_name, base_model=None, input_image=None, class_idx=3,
                        n_trials=1, seed=None):
    """Knock out random fractions (10%..90%) of a layer's filters and score the effect.

    For each fraction, ``n_trials`` random subsets of that size are zeroed (one
    at a time, on a deep copy of the model). ``input_image`` is classified, and
    the probability of ``class_idx`` is averaged over the trials.

    Args:
        layer_name: dotted module name as reported by ``named_modules()``.
        base_model: model to probe (deep-copied, left unmodified). Defaults to the
            notebook's global ``model``.
        input_image: image passed to ``output_probs``. Defaults to the notebook's
            global ``input_image_1``. (The previous version used an undefined
            global ``input_batch`` and failed on a fresh kernel.)
        class_idx: class whose probability is tracked (default 3, Tiger Shark).
        n_trials: number of random subsets averaged per fraction (default 1).
        seed: optional seed for a private ``random.Random``. ``None`` uses the
            global ``random`` module, as before.

    Returns:
        List of ``(fraction, mean_probability)`` pairs sorted ascending by probability.

    Raises:
        KeyError: if ``layer_name`` is not a module of the model.
        ValueError: if ``n_trials`` is less than 1.
    """
    if n_trials < 1:
        raise ValueError("n_trials must be at least 1")
    base_model = model if base_model is None else base_model
    input_image = input_image_1 if input_image is None else input_image
    rng = random if seed is None else random.Random(seed)

    new_model = copy.deepcopy(base_model)
    modules = dict(new_model.named_modules())
    if layer_name not in modules:
        raise KeyError(f"no module named {layer_name!r} in model")
    layer = modules[layer_name]

    total_num_filters = layer.weight.shape[0]
    original_weight = layer.weight.detach().clone()
    filter_scores = {}
    for tenths in range(1, 10):
        perc = tenths / 10
        num_knocked_out = round(perc * total_num_filters)
        trial_scores = []
        for _ in range(n_trials):
            knockedout_filts = rng.sample(range(total_num_filters), num_knocked_out)  # no repeats
            try:
                with torch.no_grad():
                    for filter_ind in knockedout_filts:
                        layer.weight[filter_ind].zero_()
                new_probabilities = output_probs(new_model, input_image)
                trial_scores.append(new_probabilities[class_idx].item())
            finally:
                # Restore the full layer before the next trial, even on error.
                with torch.no_grad():
                    layer.weight.copy_(original_weight)
        filter_scores[perc] = sum(trial_scores) / len(trial_scores)

    return sorted(filter_scores.items(), key=lambda item: item[1])

# knockout_layer(layer_name='conv3.conv', model=model, input_image=input_image_1)
knockout_layer_perc(layer_name='inception3a.branch2.1.conv')

[(0.9, 0.0006269084988161922),
 (0.7, 0.038260217756032944),
 (0.8, 0.04518599808216095),
 (0.6, 0.13788191974163055),
 (0.5, 0.5507573485374451),
 (0.4, 0.6387174725532532),
 (0.2, 0.761429488658905),
 (0.3, 0.8238069415092468),
 (0.1, 0.8852053880691528)]

---

Get the probabilities of the classes in the abstract (shark) class, versus all other classes, after knockout.

In [None]:
# Class indices forming the abstract "sharks" class; index 3 is tiger shark.
# TODO: confirm that indices 2 and 4 are the other shark classes in imagenet_classes.txt.
abstract_class = list(range(2, 5))

In [None]:
# For each class in the abstract class, print its probability before and after the conv3 knockout.
for i in abstract_class:
    row = [categories[i], 'before:', probs_tiger_shark[i].item(), 'after:', new_probs_tiger_shark[i].item()]
    # One placeholder per row item: the old format string had 4 placeholders for 5 items,
    # so the "after" probability was silently dropped. `.6g` keeps tiny probabilities readable.
    print("{: >20} {: >10} {: >12.6g} {: >10} {: >12.6g}".format(*row))

   before: 1.5476764758659556e-07     after: 1.5476764758659556e-07
   before: 9.497205155639676e-08     after: 9.497205155639676e-08
   before: 4.6568541733904567e-07     after: 4.6568541733904567e-07
