# Imports

In [None]:
import pickle
import torch

from utils.graph_helpers import plot_activation_heatmap_and_density
from utils.graph_model import GNNSAGERecommenderwithSkipConnections

In [None]:
# Model configuration used to rebuild the network before loading its weights.
# NOTE(review): these presumably have to match the configuration the checkpoint
# in results/final_model/ was trained with — confirm before changing them.
num_users, num_products = 474_892, 89_060            # entity counts
user_feature_dim, product_feature_dim = 776, 770     # input feature widths
embedding_dim = 256
dropout_prob = 0.2

In [None]:
# Instantiate the model and load the trained weights from results/final_model/model.pt.
model = GNNSAGERecommenderwithSkipConnections(num_users, num_products, user_feature_dim, product_feature_dim, embedding_dim, dropout_prob)
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a CPU-only machine;
# the parameters end up on CPU either way because the model was built on CPU.
state_dict = torch.load("results/final_model/model.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# Inference mode: disables dropout for any forward pass used to record activations.
# The trailing ';' suppresses the (large) model repr that eval() would otherwise display.
model.eval();

In [None]:
# Record activations for visualization
activations = {}  # Filled by the forward hooks below, keyed by module object.


def hook_fn(module, inputs, output):
    """Forward hook that stores ``output`` in the global ``activations`` dict.

    Args:
        module: The module the hook is attached to; used as the dict key.
        inputs: The module's forward inputs (unused). It is named so that it
            does not shadow the builtin ``input``.
        output: The module's forward output, stored as-is.
    """
    activations[module] = output


# Remove hooks left over from a previous run of this cell. Otherwise re-running
# it would stack duplicate hooks on the same modules.
for handle in globals().get("hook_handles", []):
    handle.remove()

# Keep the handles so the hooks can be removed later with handle.remove().
hook_handles = [
    layer.register_forward_hook(hook_fn)
    for layer in (
        model.conv1,
        model.conv2,
        model.user_feature_transform,
        model.product_feature_transform,
    )
]

In [None]:
# Load activations recorded in an earlier run. This rebinds `activations`, so the
# dict created in the hook cell (which no forward pass has filled here) is replaced.
# NOTE: pickle.load can execute arbitrary code — only load files you produced yourself.
with open("results/final_model/activations.pkl", "rb") as activations_file:
    activations = pickle.load(activations_file)

# Analysing layer activations

In [None]:
# Map each recorded activation back to the layer that produced it. The keys are
# matched by position because they are not the live modules on `model`.
# NOTE(review): this assumes the dict kept the order in which the hooked layers
# ran during the forward pass (user transform, product transform, conv1, conv2)
# — confirm against the model's forward().
keys = list(activations.keys())
# Unpacking raises ValueError unless exactly four layers were recorded, rather
# than silently picking the wrong entries.
user_feature_transform, product_feature_transform, conv1, conv2 = keys

In [None]:
# Visualize each hooked layer's output: one heatmap + density plot per layer,
# written under results/final_model/. Order matches the original per-layer cells.
layers_to_plot = [
    (conv2, "results/final_model/conv_2"),
    (conv1, "results/final_model/conv_1"),
    (product_feature_transform, "results/final_model/product_feature_transform"),
    (user_feature_transform, "results/final_model/user_feature_transform"),
]
for layer_key, output_path in layers_to_plot:
    # Detach from autograd and move to host memory so the helper gets a NumPy array.
    activations_of_interest = activations[layer_key].detach().cpu().numpy()
    plot_activation_heatmap_and_density(activations_of_interest, output_path)