# Captum tests

The model code was modified slightly for this notebook (the model is called with only coordinates and features as input), so the rest of the repo will not work with it.

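Before running, download the pretrained models as described in the README; the first cell expects the checkpoint at `./models/softgroup_s3dis_spconv2.pth`.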

In [None]:
import argparse
import multiprocessing as mp
import os
import os.path as osp

import numpy as np
import torch
import yaml
from munch import Munch
from softgroup.data import build_dataloader, build_dataset
from softgroup.model import SoftGroup
from softgroup.util import (collect_results_cpu, get_dist_info, get_root_logger, init_dist,
                            is_main_process, load_checkpoint, rle_decode)
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm

args = argparse.Namespace(
    config="./configs/softgroup/softgroup_s3dis_fold5.yaml",
    # Download and extract this checkpoint first (see the README).
    checkpoint="./models/softgroup_s3dis_spconv2.pth",
    dist=False,
)

# Read the YAML config inside `with` so the file handle is closed promptly.
with open(args.config, 'r') as cfg_file:
    cfg_txt = cfg_file.read()
cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
logger = get_root_logger()

model = SoftGroup(**cfg.model).cuda()
logger.info(f'Load state dict from {args.checkpoint}')
load_checkpoint(args.checkpoint, logger, model)

dataset = build_dataset(cfg.data.test, logger)
dataloader = build_dataloader(dataset, training=False, dist=args.dist, **cfg.dataloader.test)

# Only a few batches are run: this is a quick smoke test, not a full evaluation.
# After the loop, `batch` holds the last batch processed; later cells use it.
NUM_BATCHES = 2
results = []

with torch.no_grad():
    model.eval()
    # The loop iterates over batches, so the progress total is counted in batches
    # (len(dataset) counts samples) and capped at the number actually processed.
    for i, batch in tqdm(enumerate(dataloader), total=min(NUM_BATCHES, len(dataloader))):
        # Set params here so the model only takes coords and feats as input
        model.set_params(**batch)
        result = model(batch['coords_float'], batch['feats'])
        results.append(result)
        print(batch["scan_ids"])
        if i + 1 >= NUM_BATCHES:
            break

Input tensors

In [None]:
# List the keys of `batch` (the last batch taken from the dataloader loop above).
batch.keys()

Show point cloud

In [None]:
import k3d
def feats_to_hex(feats):
    """Convert per-point RGB features to packed 0xRRGGBB integers for k3d.

    Only the first three channels of ``feats`` are used. They are assumed to
    be RGB scaled to [-1, 1]; this is inferred from the ``(x + 1) / 2 * 255``
    mapping, so confirm it against the dataset. Values are rounded and clipped
    to [0, 255] before the unsigned cast, so out-of-range features cannot wrap
    around into garbage colours.

    Returns a ``uint32`` NumPy array with one packed colour per point.
    """
    scaled = ((feats + 1) / 2 * 255).detach().cpu().numpy()
    rgb = np.clip(np.rint(scaled), 0, 255).astype(np.uint32)
    return (rgb[:, 0] << 16) + (rgb[:, 1] << 8) + rgb[:, 2]


def show_color_cloud(batch, point_size=0.1):
    """Display the batch's point cloud in k3d, coloured by its RGB features.

    Args:
        batch: mapping with a 'coords_float' tensor (one xyz row per point) and
            a 'feats' tensor whose first three channels are used as colours.
        point_size: marker size passed to ``k3d.points`` (default 0.1).
    """
    colors_hex = feats_to_hex(batch['feats'])
    coords = batch["coords_float"].detach().cpu().numpy()

    plot = k3d.plot(grid_visible=False)
    plot += k3d.points(coords, colors_hex, point_size=point_size, shader="simple")
    plot.display()

In [None]:
# Render the last batch's points, coloured by their RGB features.
show_color_cloud(batch=batch)

In [None]:
import matplotlib.pyplot as plt

# NOTE(review): the model is put in train mode for this forward pass, presumably
# because the modified forward returns per-point class scores in that mode;
# confirm. In train mode, any BatchNorm layers update their running statistics
# on every call.
model.train()
model.set_params(**batch)
out = model(batch['coords_float'], batch['feats'])
# N = number of points
# K = number of classes
print(out.shape)  # N x K

# Point of interest: a fixed, hand-picked point index.
poi = 145
# This forward pass is not under torch.no_grad(), so `out` may require grad;
# .numpy() raises on such tensors, so detach before converting.
confidences = out[poi].detach().cpu().numpy()
plt.imshow(confidences.reshape(-1, 1))
plt.title(f"Confidences for point of interest {poi}")
# Highest score and its class index for the point of interest; the index is
# used as the Saliency target later.
confidence, classification = out[poi].max(0)

In [None]:
from captum.attr import Saliency

        
# Re-apply train mode and the batch parameters so the wrapper below can call
# the model with only coordinates and features.
model.train()
model.set_params(**batch)


def model_wrapper(coords_float, feats):
    """Run the model and keep only the scores of point ``poi`` as a single row.

    ``poi`` is read from the notebook's globals at call time. The scores are
    reshaped to one row so that Saliency's ``target`` (a class index) selects
    a single score to differentiate.
    """
    all_scores = model(coords_float, feats)
    poi_scores = all_scores[poi]
    return poi_scores.reshape(1, -1)


att = Saliency(model_wrapper)

In [None]:
# Saliency w.r.t. both model inputs, targeting the class picked for the point of interest.
saliency_inputs = (batch['coords_float'], batch['feats'])
coord_attr, feat_attr = att.attribute(saliency_inputs, target=classification)

In [None]:
# First line: input shapes; second line: attribution shapes. Each attribution
# is expected to match the shape of the input it explains.
for shown in ((batch['coords_float'], batch['feats']), (coord_attr, feat_attr)):
    print(*(tensor.shape for tensor in shown))

In [None]:
# Collapse the per-channel attributions into one score per point. Each tensor
# is summed over its own channel axis: this equals summing the elementwise sum
# when both have the same channel count, and it still works when coords and
# feats have different channel counts (where `coord_attr + feat_attr` would
# fail to broadcast).
point_attrs = coord_attr.sum(dim=1) + feat_attr.sum(dim=1)
print(point_attrs.shape)

# NumPy copies for plotting; detach in case the tensors are attached to the autograd graph.
coords = batch["coords_float"].detach().cpu().numpy()
attrs = point_attrs.detach().cpu().numpy()

Show attribution

In [None]:
import plotly.express as px
import plotly.graph_objects as go


# Plot only points with a positive attribution score; marker size and colour
# both encode the score.
positive = attrs > 0
coords_plot = coords[positive]
attrs_plot = attrs[positive]
p1 = px.scatter_3d(
    x=coords_plot[:, 0],
    y=coords_plot[:, 1],
    z=coords_plot[:, 2],
    size=attrs_plot,
    color=attrs_plot,
    opacity=0.8,
    color_continuous_scale='viridis',
)
# Optional (disabled): overlay every point, coloured by its predicted class.
# p2 = go.Scatter3d(x=coords[:, 0], y=coords[:, 1], z=coords[:, 2], mode="markers",
#                   marker=dict(color=out.argmax(axis=1).cpu().numpy(), size=1, opacity=0.1))
# p1.add_trace(p2)
p1

Show classifications

In [None]:
import plotly.express as px
import plotly.graph_objects as go

# NOTE(review): train mode mirrors the earlier cells (presumably required by the
# modified forward); confirm.
model.train()
model.set_params(**batch)
# Gradients are not needed to read off predictions, so skip building the autograd graph.
with torch.no_grad():
    out = model(batch['coords_float'], batch['feats'])

# Predicted class index per point. `np.int` was removed in NumPy 1.24, so use the builtin int.
pred_classes = out.argmax(dim=-1).cpu().numpy().astype(int)
# plotly treats numeric colours as continuous. Class ids are passed as strings
# so each class gets its own discrete colour, ordered by class id in the legend.
# `color_discrete_map` expects a dict rather than a palette name, so a
# qualitative palette is passed via `color_discrete_sequence` instead.
px.scatter_3d(
    x=coords[:, 0],
    y=coords[:, 1],
    z=coords[:, 2],
    color=pred_classes.astype(str),
    color_discrete_sequence=px.colors.qualitative.Light24,
    category_orders={"color": [str(c) for c in np.unique(pred_classes)]},
)