In [17]:
import torch
from torchvision import models, transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import warnings
warnings.filterwarnings("ignore")

In [18]:
# Load a pre-trained deep learning model
model = models.vgg16(pretrained=True)
model.eval()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [19]:
# Define the input image preprocessing transformation
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [20]:
# Load and preprocess the input image
image_path = '/Users/xxxyy/PycharmProjects/UoB/visulisation/video_sampled_frame/7.png'
image = Image.open(image_path)
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)
# input_batch

In [21]:
# Dictionary to save the features of each layer
features = {}

In [22]:
# Define the model hook function to save the features of each layer
def save_features(name):
    def hook(module, input, output):
        features[name] = output.detach()
    return hook

In [23]:
# Register the model hook function
hooks = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        hook = module.register_forward_hook(save_features(name))
        hooks.append(hook)

In [24]:
# Use the pre-trained model for forward propagation
output = model(input_batch)
output

tensor([[ 9.7795e-01, -4.3383e+00, -2.5600e-01, -1.9550e+00, -5.3776e-01,
         -2.3525e+00, -2.1363e+00, -1.6020e+00, -4.9646e-01, -1.4015e+00,
         -2.6557e+00, -3.6273e+00, -2.9918e+00, -1.6687e+00, -3.3448e+00,
         -4.3427e+00, -3.3815e+00,  2.9765e-01, -1.2855e-01, -1.0438e+00,
         -2.8982e+00, -1.7948e+00,  1.0888e+00, -5.6991e-01,  3.5009e+00,
         -3.0179e+00, -5.6532e+00, -3.0325e+00, -3.0167e+00, -7.2720e+00,
         -4.4232e+00, -3.9834e+00, -4.6253e+00, -1.0330e+00,  2.1813e+00,
         -3.3619e+00, -1.9814e+00, -4.0117e+00, -2.9534e+00, -8.7990e-01,
         -3.4365e+00, -3.3295e+00, -1.9446e+00, -2.2538e+00, -3.9583e+00,
         -1.0898e+00, -4.3982e+00, -4.9961e+00, -3.0081e+00, -1.7991e+00,
          6.6251e-01, -1.5115e+00, -3.8117e+00, -3.8402e+00, -3.2338e+00,
         -5.5393e+00, -1.9041e+00, -5.4470e+00, -2.7533e+00, -3.0394e+00,
         -4.3945e-02, -2.2833e+00, -2.2868e+00, -1.5722e+00, -3.9776e+00,
         -3.1006e+00, -4.3494e+00,  1.

In [25]:
# Create a folder to save the feature maps
save_folder = '/Users/xxxyy/PycharmProjects/UoB/visulisation/vgg/deep_feautures_visulisation_2D_max/'
os.makedirs(save_folder, exist_ok=True)

In [26]:
features.keys()

dict_keys(['features.0', 'features.2', 'features.5', 'features.7', 'features.10', 'features.12', 'features.14', 'features.17', 'features.19', 'features.21', 'features.24', 'features.26', 'features.28'])

In [27]:
# Iterate over each layer and save the slice images
for name, output in features.items():
    # Get the size of the feature map
    num_channels, height, width = output.shape[1:]

    # Find the channel with the maximum activation
    max_channel = torch.argmax(output.view(num_channels, -1).mean(dim=1))

    # Get the slice image of the maximum activation channel
    slice_image = output[0, max_channel, :, :].cpu().numpy()

    # Generate the save path
    save_path = os.path.join(save_folder, f'{name}_max_activation.png')

    # Add title to the plot
    plt.imshow(slice_image, cmap='viridis')
    plt.axis('off')
    plt.title(f'{name}__max activation', fontsize=26)  # Add the title here

    # Save the slice image
    plt.savefig(save_path)
    plt.close()

In [28]:
# plot all the max activation feature maps
import math

# Calculate the number of rows and columns in the final grid
num_images = len(features)
num_columns = 7
num_rows = math.ceil(num_images / num_columns)

# Create a new figure with the appropriate size
fig, axs = plt.subplots(num_rows, num_columns, figsize=(16, 16))

# Iterate over each layer and save the slice images
for i, (name, output) in enumerate(features.items()):
    # Get the size of the feature map
    num_channels, height, width = output.shape[1:]

    # Find the channel with the maximum activation
    max_channel = torch.argmax(output.view(num_channels, -1).mean(dim=1))

    # Get the slice image of the maximum activation channel
    slice_image = output[0, max_channel, :, :].cpu().numpy()

    # Calculate the row and column index for the current image
    row_index = i // num_columns
    col_index = i % num_columns

    # Plot the slice image on the corresponding subplot
    axs[row_index, col_index].imshow(slice_image, cmap='viridis')
    axs[row_index, col_index].axis('off')

    # Add text label with the coordinates
    axs[row_index, col_index].text(0.5, -0.1, f'({name})',
                                   transform=axs[row_index, col_index].transAxes,
                                   fontsize=15, ha='center')

# Remove any empty subplots
if num_images < num_rows * num_columns:
    for i in range(num_images, num_rows * num_columns):
        row_index = i // num_columns
        col_index = i % num_columns
        axs[row_index, col_index].axis('off')

# Add title to the plot
fig.suptitle('Max Activation Feature Maps for different layers of VGG16', fontsize=26)

# Adjust the spacing between subplots
fig.tight_layout()

# Save the figure
save_path = os.path.join(save_folder, 'vgg16_max_activation_all_images.eps')
plt.savefig(save_path)
plt.close()

In [29]:
# Remove the model hooks
for hook in hooks:
    hook.remove()