In [11]:
# Torch
import torch
import torch.optim as optim
from torcheval.metrics import *

# Numerics: `np` is used directly below, so import it explicitly rather than
# relying on it leaking out of one of the star imports.
import numpy as np

# Benny pointnet
from pointnet2_benny import pointnet2_cls_msg

import dill as pickle
import shap
from captum.attr import *
import pyvista as pv
from matplotlib.colors import Normalize
from random import sample
from matplotlib.colors import LinearSegmentedColormap

# Custom modules
from final_models_explainability.explain_pointnet import *

In [12]:
def pointnet_saliency(model, cloud, device, target=1):
    """Compute signed gradient (captum ``Saliency``) attributions for one point cloud.

    Note: this moves ``model`` to ``device`` and puts it in eval mode; the
    caller's model object is modified in place.

    Args:
        model: PointNet++ classifier whose forward pass returns a tuple; the
            first element is treated as the class scores passed to captum.
        cloud: numpy array of points laid out n x 3 (one row per point).
        device: torch device (or device string) to run the model on.
        target: index of the output class to attribute. Defaults to 1, the
            value previously hard-coded.

    Returns:
        numpy array of signed attributions with the same n x 3 layout as
        ``cloud`` (``abs=False``, so gradient signs are preserved).
    """
    model.to(device)
    model.eval()

    # The classifier returns a tuple; captum needs a callable returning only
    # the scores tensor.
    def forward_scores(points):
        return model(points)[0]

    saliency = Saliency(forward_scores)

    # The network weights are float32, so the input must match.
    points = torch.from_numpy(cloud).to(device=device, dtype=torch.float32)

    # Add a batch dimension, then transpose n x 3 -> 3 x n as PointNet expects.
    points = points.unsqueeze(0).transpose(2, 1)

    attributions = saliency.attribute(points, target=target, abs=False)

    # Back to n x 3 and drop the batch dimension.
    attributions = attributions.transpose(1, 2).squeeze(0)

    # Detach defensively before leaving autograd-land, then move to CPU numpy.
    return attributions.detach().cpu().numpy()

In [13]:
# Left open (as before) so later cells can still read other arrays from it.
results = np.load("pointnet_eval.npz")

# Class label whose correctly classified samples (true positives) are kept.
TARGET_CLASS = 1

# Indexing an NpzFile re-reads and decompresses the array from the archive on
# every `results[key]` access, so pull each array out exactly once instead of
# indexing the archive per element inside a loop (which was quadratic I/O).
true_labels = results['true']
pred_classes = results['pred_classes']

# Indices (as plain Python ints) where label == prediction == TARGET_CLASS.
correct_indices = np.flatnonzero(
    (true_labels == TARGET_CLASS) & (pred_classes == TARGET_CLASS)
).tolist()

# Same structure as before: each key maps to a list of per-sample entries.
filtered_results = {
    key: list(results[key][correct_indices])
    for key in (
        'true',
        'pred_probs',
        'data',
        'attributions_zero_list',
        'attributions_mean_list',
    )
}

print(len(filtered_results['data']))

70


In [14]:
# Show the precomputed zero-baseline attributions for the first retained
# cloud on a two-stop colormap: low attribution -> white, high -> red.
colours = [(0, 'white'), (1, 'red')]
custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', colours)

# normalise_attributions (project helper) presumably rescales into [0, 1]
# to match clim below — NOTE(review): confirm in explain_pointnet.
norm_xyz_sum = normalise_attributions(
    filtered_results['attributions_zero_list'][0],
    power=0.3,
)
pv_cloud = pv.PolyData(filtered_results['data'][0])

plotter = pv.Plotter()
plotter.set_background("black")
plotter.add_points(
    pv_cloud,
    scalars=norm_xyz_sum,
    cmap=custom_cmap,
    clim=[0, 1],
)
plotter.show()

Widget(value='<iframe src="http://localhost:53245/index.html?ui=P_0x31a946360_6&reconnect=auto" class="pyvista…

In [15]:
# Recompute saliency from the trained checkpoint and visualise it for the
# first retained cloud (white = low attribution, red = high).
plotter = pv.Plotter()

colours = [(0, 'white'), (1, 'red')]
custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', colours)

# Single place to choose where the checkpoint is loaded and saliency is run.
SALIENCY_DEVICE = 'cpu'

# Presumably 2 = number of output classes and normal_channel=False = xyz-only
# input — NOTE(review): confirm against pointnet2_cls_msg.get_model.
model = pointnet2_cls_msg.get_model(2, normal_channel=False)

# map_location remaps tensors saved from a GPU session onto SALIENCY_DEVICE,
# so the checkpoint also loads on a CPU-only machine.
state_dict = torch.load(
    "final_models_explainability/pointnet.pth",
    map_location=SALIENCY_DEVICE,
    weights_only=True,
)
model.load_state_dict(state_dict)

cloud = filtered_results['data'][0]

norm_xyz_sum = normalise_attributions(
    pointnet_saliency(model, cloud, SALIENCY_DEVICE),
    power=0.3,
)

pv_cloud = pv.PolyData(cloud)

plotter.add_points(pv_cloud, scalars=norm_xyz_sum, cmap=custom_cmap, clim=[0, 1])

plotter.set_background("black")

plotter.show()



Widget(value='<iframe src="http://localhost:53245/index.html?ui=P_0x16ddb45c0_7&reconnect=auto" class="pyvista…