In [31]:
import torch
from torchvision import models, transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import warnings
# Install a blanket "ignore" filter: no category or module is given, so every
# Python warning raised after this point is suppressed, including deprecation
# notices from the libraries imported above.
warnings.filterwarnings("ignore")

In [32]:
# Build VGG16 with its pre-trained weights (pretrained=True).
model = models.vgg16(pretrained=True)
# Put the network in evaluation mode. Because this call is the cell's last
# expression, its return value (the module itself) is what the notebook
# prints as the layer listing below.
model.eval()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [33]:
# Input pipeline applied to the PIL image before it is fed to the model:
#   1. resize to a fixed 224x224,
#   2. convert the PIL image to a tensor,
#   3. normalise each of the three channels with the mean/std below.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
])

In [34]:
# Load the input frame and turn it into a one-image batch for the model
image_path = '/Users/xxxyy/PycharmProjects/UoB/visulisation/video_sampled_frame/7.png'
# A PNG may decode as RGBA, palette or grayscale, but the Normalize step and
# the network's first Conv2d both require exactly three channels, so force
# RGB here. The `with` block releases the file handle as soon as the pixel
# data has been copied into `image`.
with Image.open(image_path) as source_image:
    image = source_image.convert('RGB')
input_tensor = preprocess(image)
# Add a leading batch dimension: the model is called on a batch, not one image
input_batch = input_tensor.unsqueeze(0)

In [35]:
# Layer name -> captured output tensor; filled in by the forward hooks
features = dict()

In [36]:
# Define model hook functions to get the features of each layer
def save_features(name):
    """Build a forward hook that records a module's output under ``name``.

    Args:
        name: Key under which the output is stored in the module-level
            ``features`` dictionary.

    Returns:
        A callable with the ``(module, input, output)`` forward-hook
        signature. Each call overwrites ``features[name]`` with a snapshot
        of ``output``.
    """
    def hook(module, input, output):
        # detach() alone returns a tensor that shares storage with `output`.
        # Every Conv2d in this model is followed by ReLU(inplace=True), which
        # overwrites that storage, so without clone() the stored "conv"
        # features would silently become the post-ReLU values.
        features[name] = output.detach().clone()
    return hook

In [37]:
# Attach a recording hook to every Conv2d in the model and keep the returned
# handles so the hooks can be removed again later
hooks = [
    layer.register_forward_hook(save_features(layer_name))
    for layer_name, layer in model.named_modules()
    if isinstance(layer, torch.nn.Conv2d)
]

In [38]:
# Run one forward pass so the registered hooks fire and fill `features`.
# Nothing below uses gradients, so disable autograd tracking for the pass to
# avoid building and holding the computation graph.
with torch.no_grad():
    output = model(input_batch)

In [39]:
# Output directory for the per-layer feature-map images; created together
# with any missing parents, and left untouched if it already exists
save_folder = '/Users/xxxyy/PycharmProjects/UoB/visulisation/vgg/deep_features_visulisation_3D/'
os.makedirs(name=save_folder, exist_ok=True)

In [40]:
# Set the font size for the entire figure (global matplotlib setting)
plt.rcParams['font.size'] = 26

# Grid layout shared by every layer's figure
max_channels = 64  # draw at most this many channels per layer
grid_cols = 8      # feature maps displayed per row

# Visualize and save feature maps for each layer: one PNG per hooked layer.
# The loop variable is deliberately not called `output`, so the model's
# output from the forward-pass cell is not overwritten.
for name, layer_features in features.items():
    # Channels are indexed on dimension 1; only the first `max_channels` are drawn
    num_channels = min(layer_features.shape[1], max_channels)

    # Ceiling division: enough rows to hold `num_channels` cells
    grid_rows = (num_channels - 1) // grid_cols + 1
    # squeeze=False keeps `axes` a 2-D array whatever the grid shape
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(30, 30), squeeze=False)

    # Display feature maps for each channel
    for i, ax in enumerate(axes.flat):
        if i < num_channels:
            # First item of the batch, channel i
            feature_map = layer_features[0, i, :, :].cpu().numpy()
            ax.imshow(feature_map, cmap='viridis')
            ax.set_title(f'Channel {i+1}', fontsize=26)
        # Hide ticks and frame on every cell, including unused trailing ones.
        # Axis labels are not set because a hidden axis would not show them.
        ax.axis('off')

    # Adjust the spacing between subplots
    fig.tight_layout()

    # Leave room at the top, then add the layer name as the figure title
    fig.subplots_adjust(top=0.95)
    fig.suptitle(name, fontsize=26)

    # Save and close this specific figure rather than relying on pyplot's
    # implicit "current figure"; closing frees its memory before the next layer
    save_path = os.path.join(save_folder, f'{name}.png')
    fig.savefig(save_path)
    plt.close(fig)

In [41]:
# Detach every registered forward hook so later forward passes through the
# model no longer write into `features`
for handle in hooks:
    handle.remove()