# Evaluating Epoch Models

In [9]:
import os
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.models import resnet50
import torchvision
import random
import pickle


# set root dirs
grey_dir = "/home/local/data/sophie/imagenet/output/grey"
base_dir = "/home/local/data/sophie/imagenet/output/base"
# configure GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# fetch imagenet classes 
with open('/home/local/data/sophie/imagenet/imagenet_classes.pkl', 'rb') as f:
    clses = pickle.load(f)

# define data augmentations
img_size = 224
transforms = torchvision.transforms.Compose([
      torchvision.transforms.RandomCrop(img_size, padding=random.randint(0, 8)), # jitter
      torchvision.transforms.RandomRotation((-45, 45)), # rotate
      torchvision.transforms.RandomResizedCrop(img_size, scale=(0.9, 1.2), ratio=(1.0, 1.,0)) # scale
    ])


def load_models(base_model_path,grey_model_path):
    # initalize models
    base_model = resnet50(pretrained=True)
    grey_model = resnet50(pretrained=True)

    # load weights
    base_weights = torch.load(base_model_path, map_location='cpu')
    # configure state dict
    new_state_dict = {}
    for k, v in base_weights['model'].items():
        k = k.replace("module.", "")
        new_state_dict[k] = v
    # load model with state dict
    base_model.load_state_dict(new_state_dict)
    # disable grad
    for param in base_model.parameters():
      param.requires_grad_(False)
    
    # load weights
    grey_weights = torch.load(grey_model_path, map_location='cpu')
    # configure state dict
    new_state_dict = {}
    for k, v in grey_weights['model'].items():
        k = k.replace("module.", "")
        new_state_dict[k] = v
    # load model with state dict
    grey_model.load_state_dict(new_state_dict)
    # disable grad
    for param in grey_model.parameters():
      param.requires_grad_(False)

    # send to GPU
    base_model.to(device)
    grey_model.to(device)
    # set as eval 
    base_model.eval()
    grey_model.eval()
    return base_model, grey_model
    
def visualize_neuron(model, layer, neuron_indices, img_shape=(3, 224, 224), iterations=30, lr=0.1, device='cuda', transforms=None):
    # Initialize the input image with requires_grad=True
    input_img = torch.randn(1, *img_shape, requires_grad=True, device=device)
    # input_img = torch.randn(1, *img_shape, requires_grad=True, device='cpu')
    optimizer = torch.optim.Adam([input_img], lr=lr)

    activations = None

    # Non-intrusive hook function that captures activations
    def non_intrusive_hook_fn(module, input, output):
        nonlocal activations
        activations = output  # Do not detach to keep the graph intact

    # Register the hook to capture the layer's output
    hook = layer.register_forward_hook(non_intrusive_hook_fn)

    for i in range(iterations):
        optimizer.zero_grad()

        # Apply transformations if provided
        if transforms is not None:
            trs_img = transforms(input_img)
        else:
            trs_img = input_img

        # Forward pass
        model(trs_img)

        # Ensure that neuron_indices is within bounds
        if neuron_indices >= activations.shape[1]:
            raise ValueError(f"neuron_indices {neuron_indices} is out of bounds for activations with shape {activations.shape}")

        # Loss calculation as in Method 2
        if activations.dim() == 2:
            channel_activations = activations[:, neuron_indices]
        else:
            channel_activations = activations[:, neuron_indices, :, :]

        loss = -channel_activations.mean()

        # Backward pass to calculate gradients
        loss.backward()
        optimizer.step()

    # Remove the hook
    hook.remove()

    return input_img.detach()

def generateClassImages(base_model, grey_model, neuron, its=3000, lr=0.05, layer_name='fc', seed=0,trs=transforms):
    # get the target layers
    colour_layer1 = dict([*base_model.named_modules()])[layer_name]
    grey_layer1 = dict([*grey_model.named_modules()])[layer_name]
    # generate images
    colour_mean_img1 = visualize_neuron(base_model, colour_layer1, neuron, iterations=its, 
                                        lr=lr, transforms=trs, device=device)
    grey_mean_img1 = visualize_neuron(grey_model, grey_layer1, neuron, iterations=its, 
                                      lr=lr, transforms=trs, device=device)
    # normalise images
    colour_mean_img1_norm = (colour_mean_img1.cpu().numpy()-colour_mean_img1.cpu().numpy().min())/(colour_mean_img1.cpu().numpy().max()-colour_mean_img1.cpu().numpy().min())
    grey_mean_img1_norm = (grey_mean_img1.cpu().numpy()-grey_mean_img1.cpu().numpy().min())/(grey_mean_img1.cpu().numpy().max()-grey_mean_img1.cpu().numpy().min())
    
    return colour_mean_img1, colour_mean_img1_norm, grey_mean_img1, grey_mean_img1_norm
    

In [None]:
neuron = 963
img_class = clses[neuron].split(",")[0]
plt_dir = "/home/jovyan/pytrain_imagenet/viz_plots/{}".format(img_class)
if not os.path.exists(plt_dir):
    os.makedirs(plt_dir)

# iterate over models
for model in os.listdir(grey_dir):
    if "model" in model:
        # set current model paths
        grey_model_path = os.path.join(grey_dir, model)
        base_model_path = os.path.join(base_dir, model.replace("grey", "base"))
        # load models
        base_model, grey_model = load_models(grey_model_path, base_model_path)
        
        colour_image, norm_colour_image, grey_image, norm_grey_image, = generateClassImages(base_model, grey_model, neuron)
        
        # generate plot
        plt.figure(figsize=(10,5))
        plt.subplot(1,2,1)
        plt.imshow(norm_colour_image[0].transpose(1,2,0))
        plt.axis('off')
        plt.title("Colour FC: {}".format(img_class))
        plt.subplot(1,2,2)
        plt.imshow(norm_grey_image[0].transpose(1,2,0))
        plt.axis('off')
        plt.title("Grey FC: {}".format(img_class))
        plt.savefig(os.path.join(plt_dir, "{}.png".format(model.split("_")[-1].split(".")[0])), bbox_inches='tight')

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/jovyan/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 87.8MB/s]


In [None]:
neuron = 388
img_class = clses[neuron].split(",")[0]
plt_dir = "/home/jovyan/pytrain_imagenet/viz_plots/{}".format(img_class)
if not os.path.exists(plt_dir):
    os.makedirs(plt_dir)

# iterate over models
for model in os.listdir(grey_dir):
    if "model" in model:
        # set current model paths
        grey_model_path = os.path.join(grey_dir, model)
        base_model_path = os.path.join(base_dir, model.replace("grey", "base"))
        # load models
        base_model, grey_model = load_models(grey_model_path, base_model_path)
        
        colour_image, norm_colour_image, grey_image, norm_grey_image, = generateClassImages(base_model, grey_model, neuron)
        
        # generate plot
        plt.figure(figsize=(10,5))
        plt.subplot(1,2,1)
        plt.imshow(norm_colour_image[0].transpose(1,2,0))
        plt.axis('off')
        plt.title("Colour FC: {}".format(img_class))
        plt.subplot(1,2,2)
        plt.imshow(norm_grey_image[0].transpose(1,2,0))
        plt.axis('off')
        plt.title("Grey FC: {}".format(img_class))
        plt.savefig(os.path.join(plt_dir, "{}.png".format(model.split("_")[-1].split(".")[0])), bbox_inches='tight')

  plt.figure(figsize=(10,5))


In [None]:
neuron = 470
img_class = clses[neuron].split(",")[0]
plt_dir = "/home/jovyan/pytrain_imagenet/viz_plots/{}".format(img_class)
if not os.path.exists(plt_dir):
    os.makedirs(plt_dir)

# iterate over models
for model in os.listdir(grey_dir):
    if "model" in model:
        # set current model paths
        grey_model_path = os.path.join(grey_dir, model)
        base_model_path = os.path.join(base_dir, model.replace("grey", "base"))
        # load models
        base_model, grey_model = load_models(base_model_path,grey_model_path)
        
        colour_image, norm_colour_image, grey_image, norm_grey_image, = generateClassImages(base_model, grey_model, neuron)
        
        # generate plot
        plt.figure(figsize=(10,5))
        plt.subplot(1,2,1)
        plt.imshow(norm_colour_image[0].transpose(1,2,0))
        plt.axis('off')
        plt.title("Colour FC: {}".format(img_class))
        plt.subplot(1,2,2)
        plt.imshow(norm_grey_image[0].transpose(1,2,0))
        plt.axis('off')
        plt.title("Grey FC: {}".format(img_class))
        plt.savefig(os.path.join(plt_dir, "{}.png".format(model.split("_")[-1].split(".")[0])), bbox_inches='tight')