In [1]:
# Torch
import torch
import torch.optim as optim
from torcheval.metrics import *

import pickle
from captum.attr import *
import random
import numpy as np

# Custom modules
from preprocessing_post_fastsurfer.subject import *
from preprocessing_post_fastsurfer.vis import *
from ozzy_torch_utils.split_dataset import *
from ozzy_torch_utils.subject_dataset import *
from ozzy_torch_utils.plot import *
from ozzy_torch_utils.train_nn import *
from ozzy_torch_utils.model_parameters import *
from ozzy_torch_utils.init_dataloaders import *
from explain_pointnet import *

In [2]:
# Load model
# Restore the parameters object that was pickled for a previous training run.
# The path is a hard-coded absolute location; edit it for your environment.
pickle_pathname = "/uolstore/home/student_lnxhome01/sc22olj/Compsci/year3/individual-project-COMP3931/individual-project-sc22olj/runs/run_18-03-2025_15-35-05/run_18-03-2025_15-35-05_params.pkl"

with open(pickle_pathname, 'rb') as file:
    
    # SECURITY: unpickling can execute arbitrary code, so only load pickle files
    # that you produced yourself or otherwise trust.
    model_parameters = pickle.load(file)
    
# Only the `.model` attribute of the unpickled object is used by this notebook.
# NOTE(review): presumably this is the trained network — confirm it is in eval mode.
model = model_parameters.model

In [3]:
# Load dataset
# Root directory of the cohort to analyse (hard-coded absolute path; edit it
# for your environment).
data_path = "/uolstore/home/users/sc22olj/Compsci/year3/individual-project-COMP3931/individual-project-sc22olj/scratch-disk/full-datasets/hcampus-1.5T-cohort"

# find_subjects_parallel is not defined in this notebook; it comes from one of
# the wildcard imports. Presumably it scans data_path and returns a list of
# subject objects — TODO confirm. Later cells pick an element from the result
# and read its `.path` and `.subject_metadata` attributes.
subject_list = find_subjects_parallel(data_path)

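Experiment: compare the attributions computed for two different point orderings (permutations) of the same cloud. If the attributions do not depend on point order, un-shuffling the second result should reproduce the first, so their difference should be close to zero.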

In [4]:
# Draw one subject at random for the experiment. No seed is set in the visible
# cells, so a different subject may be picked on every run.
# NOTE(review): `sample` and `os` are not imported explicitly in this notebook;
# presumably they arrive through the wildcard imports (`sample` looks like
# random.sample) — confirm.
subject = sample(subject_list, 1)[0]

# Load the point cloud stored in this subject's directory.
# assumes the array is (n_points, 3): later cells permute it along axis 0 and
# sum attributions over axis 1 — TODO confirm
cloud = np.load(os.path.join(subject.path, "Left-Hippocampus_aligned_cropped_mesh_downsampledcloud.npy"))
    

In [None]:
# Permutation experiment: run the attribution once on the cloud as loaded and
# once on a randomly re-ordered copy, then undo the re-ordering on the second
# result so the two attribution arrays can be compared point by point.

# Use the available accelerator if there is one, otherwise fall back to the CPU.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"

# Attributions and predicted class for the original point ordering.
attributions_orig, pred_research_group_orig = pointnet_ig(model, cloud, device)

# `shuffler` is a random permutation of the point indices. argsort of a
# permutation is its inverse, so indexing a shuffled array with `unshuffler`
# restores the original order.
shuffler = np.random.permutation(cloud.shape[0])
unshuffler = np.argsort(shuffler)

# Integer-array indexing builds the permuted cloud directly. (Passing a
# generator expression to np.array does NOT do this: it produces a 0-d object
# array wrapping the generator.)
cloud_shuffled = cloud[shuffler]

# Attributions and predicted class for the shuffled ordering.
attributions_shuffled, pred_research_group_shuffle = pointnet_ig(model, cloud_shuffled, device)

# Sanity check: undoing the shuffle must give back exactly the original cloud,
# otherwise the point-by-point comparison below is meaningless.
cloud_unshuffled = cloud_shuffled[unshuffler]
assert np.array_equal(cloud_unshuffled, cloud), "un-shuffling did not restore the original point order"

# Put the shuffled-run attributions back into the original point order.
# assumes pointnet_ig returns something np.asarray can convert (e.g. an
# ndarray with one row per point) — TODO confirm
attributions_unshuffled = np.asarray(attributions_shuffled)[unshuffler]

# Per-point, per-axis difference between the two runs.
attributions_diff = attributions_orig - attributions_unshuffled

print(attributions_diff)

if pred_research_group_orig != pred_research_group_shuffle:
    print("Research groups are different after shuffle")
    

[[6.65154829e-08 1.01864301e-07 2.99499747e-07]
 [1.47840784e-02 4.18064712e-02 2.85079568e-02]
 [7.34231075e-13 3.12145172e-13 6.60628809e-13]
 ...
 [1.33077205e-07 6.12107915e-09 2.18990218e-07]
 [1.14355741e-10 7.65692545e-12 4.75774655e-12]
 [1.41251399e-06 2.98809966e-06 2.19691583e-06]]


In [None]:
def vis_attributions(attributions, subject, cloud, pred_research_group, power=0.1):
    """Show a point cloud in a pyvista plotter, coloured by per-point attribution.

    The attributions are summed over axis 1 to give one signed score per point,
    compressed with a signed power transform, scaled symmetrically about zero
    and mapped through the 'seismic' colormap.

    NOTE(review): `Normalize`, `plt` and `pv` are not imported explicitly in
    this notebook; presumably they are matplotlib.colors.Normalize,
    matplotlib.pyplot and pyvista arriving via the wildcard imports — confirm.

    Args:
        attributions: 2-D array-like of attributions; axis 1 is summed away.
            (Assumed to be one row per point of `cloud` — TODO confirm.)
        subject: Object whose `subject_metadata['Group'].iloc[0]` is displayed
            as the true class.
        cloud: Point coordinates passed to `pv.PolyData`.
        pred_research_group: Predicted class, displayed in the plot text.
        power: Exponent of the signed power transform applied to the summed
            attributions. Defaults to 0.1, the value previously hard-coded.

    Returns:
        None. The function opens the plotter as a side effect.
    """

    # One signed score per point: sum the attribution over axis 1.
    point_scores = np.sum(attributions, axis=1)

    # Signed power transform: keeps the sign, and with the default exponent
    # (< 1) lifts small magnitudes so they remain visible next to large ones.
    point_scores = np.sign(point_scores) * np.power(np.abs(point_scores), power)

    # Symmetric limits so that a score of 0 always lands on the middle of the
    # colormap. If every score is exactly 0 the limits would collapse to a
    # single value and zero would no longer map to the midpoint, so fall back
    # to a unit range in that case.
    max_abs = np.max(np.abs(point_scores))
    if max_abs == 0:
        max_abs = 1.0

    # Normalize rescales [-max_abs, max_abs] linearly onto [0, 1], which is the
    # input range the colormap expects; 0.5 corresponds to a score of zero.
    norm = Normalize(vmin=-max_abs, vmax=max_abs)
    norm_attributions = norm(point_scores)

    # Per-point colours for pyvista, taken from the colormap.
    cmap = plt.get_cmap('seismic')
    colours = cmap(norm_attributions)

    pv_cloud = pv.PolyData(cloud)

    plotter = pv.Plotter()
    plotter.add_points(pv_cloud, scalars=colours, rgb=True)
    plotter.set_background("black")

    # THIS IS NOT FOR USE, IT IS RUNNING PREDICTIONS ON TRAINING DATA!!
    plotter.add_text("This is just a test running on training data!", color='white')
    plotter.add_text(f"True class: {str(subject.subject_metadata['Group'].iloc[0])} \n Predicted class: {pred_research_group} ", color='white', position='upper_right')

    plotter.show()

In [None]:
# Visualise the attributions computed on the original point ordering.
vis_attributions(attributions_orig, subject, cloud, pred_research_group_orig)

# Alternative view: attributions from the shuffled run, restored to the original
# point order so that they line up with `cloud`.
#vis_attributions(attributions_unshuffled, subject, cloud, pred_research_group_shuffle)

# Alternative view: per-point difference between the two runs.
#vis_attributions(attributions_diff, subject, cloud, pred_research_group_orig)

Widget(value='<iframe src="http://localhost:34717/index.html?ui=P_0x7f27091ba3f0_0&reconnect=auto" class="pyvi…

Widget(value='<iframe src="http://localhost:34717/index.html?ui=P_0x7f27081d9ac0_1&reconnect=auto" class="pyvi…

Widget(value='<iframe src="http://localhost:34717/index.html?ui=P_0x7f2701d98f80_2&reconnect=auto" class="pyvi…