In [None]:
import MinkowskiEngine as ME
import matplotlib.pyplot as plt
import matplotlib as mpl
import importlib
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F
from torch import nn

## Jupyter magic
%matplotlib inline
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.device(device)
import numpy as np
SEED=12345
_=np.random.seed(SEED)
_=torch.manual_seed(SEED)

In [None]:
## Includes from my libraries for this project                                                                                                                                           
from ME_dataset_libs import make_dense, make_dense_from_tensor, Label

In [None]:
from FSD_training_analysis import get_models_from_checkpoint

## Directory that holds the pre-calculated model weights
file_dir = "/pscratch/sd/c/cwilk"

## Earlier checkpoints, kept commented out for reference.
## This is interesting, but limited so the best performance really is for ~N=20-30. The best silhouette is ~0.25
# chk_file = "state_lat64_hid128_clust25_nchan64_5E-6_1024_PROJ0.5one_CLUST0.5one_ent1E-1_soft1.0_arch24x8silu_poolmax_flat1_grow1_kern7_sep1_onecycle50_bigaugbilinfix0.5_DROP0_WEIGHT_DECAY0.05_10M_DATA1_FSDCCFIX.pth"
# chk_file="state_lat32_hid256_clust25_nchan48_1E-5_1024_PROJ0.5_CLUST0.5two_ent1E-1_soft1.0_arch12x4_poolmax_flat1_grow1_kern7_sep1_onecycle50_bigaugbilinfix_1M_DATA1_FSDCCFIX.pth"

## Try one with 24x8
# chk_file="state_lat64_hid256_clust50_nchan64_5E-6_1024_PROJ0.5_CLUST0.5two_ent1E-1_soft1.0_arch24x8_poolmax_flat1_grow1_kern7_sep1_onecycle50_bigaugbilinfix0.5_5M_DATA1_FSDCCFIX.pth"

## Checkpoint used by the rest of the notebook
chk_file="state_lat128_hid256_clust30_nchan64_5E-6_1024_PROJ0.5two_CLUST0.5two_ent1E-1_soft1.0_archd4silu_poolmax_flat0_grow1_kern7_sep0_onecycle50_newbaseaug0.5_DROP0_WEIGHT_DECAY0_5M_DATA1_FSDCCFIX.pth"

## Rebuild the encoder and its heads (a dict, iterated via .values()) from the checkpoint
encoder, heads, args = get_models_from_checkpoint(f"{file_dir}/{chk_file}")

## Switch every network to evaluation mode and move it to the selected device
encoder.eval()
encoder.to(device)
for h in heads.values():
    h.eval()
    h.to(device)

print("Loaded:", chk_file)

In [None]:
## Setup the dataloader
from FSD_training_analysis import get_dataset
import time

## Input directories and the cap on the number of events read from each
data_dir = "/pscratch/sd/c/cwilk/FSD/DATA"
sim_dir = "/pscratch/sd/c/cwilk/FSD/SIMULATIONv2"
max_data_events = max_sim_events = 100000

## Build the simulation loader first, then the data loader; both are timed together
start = time.time()
sim_dataset, sim_loader = get_dataset(sim_dir, max_sim_events, return_metadata=True)
data_dataset, data_loader = get_dataset(data_dir, max_data_events, return_metadata=True)
print(f"Time taken to load {len(data_dataset)} data and {len(sim_dataset)} images: {time.time() - start}")

In [None]:
import numpy as np
import FSD_training_analysis
## Reload the module so on-disk edits are visible in the running kernel
importlib.reload(FSD_training_analysis)
from FSD_training_analysis import image_loop, reorder_clusters
import time

## Get the processed vectors of interest from the datasets (data first, then simulation).
## NOTE(review): the meaning of the trailing False flag is defined by image_loop - confirm there.
start = time.time()
data_processed = image_loop(encoder, heads, data_loader, False)
sim_processed = image_loop(encoder, heads, sim_loader, False)
print(f"Time to process events: {time.time() - start}")

## Re-order the clusters for presentation purposes.
## The return value is not used, so this presumably updates both dicts in place - confirm.
start = time.time()
reorder_clusters(data_processed, sim_processed)
print(f"Time to reorder events: {time.time() - start}")


In [None]:
import importlib
import ME_analysis_libs
# Reload so that on-disk edits to ME_analysis_libs are visible in the running kernel
importlib.reload(ME_analysis_libs)
from ME_analysis_libs import plot_metric_data_vs_sim

# Compare the per-event 'clust_index' values of data and simulation.
# The simulation 'labels' are passed as the third argument (presumably used to
# split the simulation by label - confirm against plot_metric_data_vs_sim).
plot_metric_data_vs_sim(data_processed['clust_index'],
                        sim_processed['clust_index'], 
                        sim_processed['labels'],
                        xtitle="Max. cluster index")

In [None]:
## Input-gradient saliency for a chosen cluster

def calc_saliency(orig, encoder, clust_head, cluster_index, device=device):
    """
    Gradient-times-input saliency of one cluster score w.r.t. the input features.

    Args:
        orig: ME.SparseTensor input; its features (.F) and coordinates (.C) are read
        encoder: encoder module, called as encoder(sparse_tensor, 1)
        clust_head: head applied to the second value returned by the encoder
        cluster_index: column of the head output whose summed value is differentiated
        device: torch device used for the forward and backward pass
    Returns:
        (saliency, sal_coords), both on the CPU. saliency is feats * d(score)/d(feats)
        and has the same shape as the input features; sal_coords is a copy of the
        input coordinates.
    """
    # Detached copy of the features that acts as the leaf tensor for the gradient
    feats = orig.F.clone().detach().to(device)
    feats.requires_grad_(True)

    coords = orig.C.to(device)

    st = ME.SparseTensor(feats, coords, device=device)

    # Forward pass through encoder + head
    _, cluster_batch = encoder(st, 1)
    clust_probs = clust_head(cluster_batch)

    # Scalar score for the requested cluster (summed over the batch dimension)
    score = clust_probs[:, cluster_index].sum()

    # Backprop down to the input features
    score.backward()

    # Gradient x input, moved off the device for plotting
    saliency = (feats * feats.grad).detach().cpu()

    # Return the CPU copy of the coordinates so both outputs live on the same
    # device; returning the on-device `coords` paired a CPU saliency with
    # device coordinates.
    sal_coords = coords.cpu()

    return saliency, sal_coords


In [None]:
def smooth_grad(orig, encoder, clust_head, cluster_index, n_samples=50, sigma=0.2, device=device):
    """
    SmoothGrad-style saliency: average input gradient over noisy copies of the input.

    Args:
        orig: ME.SparseTensor input; its features (.F) and coordinates (.C) are read
        encoder: encoder module, called as encoder(sparse_tensor, 1)
        clust_head: head applied to the second value returned by the encoder
        cluster_index: column of the head output whose summed value is differentiated
        n_samples: number of noisy copies to average over (must be >= 1)
        sigma: standard deviation of the Gaussian noise added to the features
        device: torch device used for the forward and backward passes
    Returns:
        (saliency, coords), both on the CPU. saliency is the mean gradient of the
        score w.r.t. the noisy features and has the same shape as the input features.
    Raises:
        ValueError: if n_samples is smaller than 1 (the average would be undefined)
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    feats = orig.F.clone().detach().to(device)
    coords = orig.C.to(device)

    total_grad = torch.zeros_like(feats)
    for _ in range(n_samples):
        # Fresh Gaussian perturbation of the features; coordinates stay fixed
        noise = torch.randn_like(feats) * sigma
        feats_noisy = (feats + noise).clone().detach().requires_grad_(True)
        st_noisy = ME.SparseTensor(feats_noisy, coords, device=device)

        _, cluster_batch = encoder(st_noisy, 1)
        clust_probs = clust_head(cluster_batch)
        score = clust_probs[:, cluster_index].sum()
        score.backward()

        # feats_noisy is a new leaf each iteration, so its .grad holds only this sample
        total_grad += feats_noisy.grad

    saliency = (total_grad / n_samples).detach().cpu()
    # Coordinates are returned on the CPU to match the saliency tensor
    return saliency, coords.cpu()

In [None]:

from MinkowskiEngine import MinkowskiConvolution

def find_last_conv(module):
    """
    Return the MinkowskiConvolution that comes last in module.modules().

    Raises:
        RuntimeError: if the module tree contains no MinkowskiConvolution
    """
    last_conv = None
    # Walk the tree front to back and remember the most recent match
    for layer in module.modules():
        if isinstance(layer, MinkowskiConvolution):
            last_conv = layer
    if last_conv is None:
        raise RuntimeError("No MinkowskiConvolution found")
    return last_conv
    
def grad_cam_sparse_hook(orig, encoder, clust_head, cluster_index, device='cpu',
                         feature_hw=(48, 16), output_hw=(768, 256)):
    """
    Grad-CAM for MinkowskiEngine sparse tensors using a forward hook.

    Args:
        orig: ME.SparseTensor input; its features (.F) and coordinates (.C) are read
        encoder: encoder module; must expose `.encoders` and accept
            (sparse_tensor, batch_size=1, return_maps=False)
        clust_head: classification head applied to the encoder's second output
        cluster_index: column of the head output to compute the CAM for
        device: torch device
        feature_hw: (H, W) of the dense grid the hooked feature map is unpacked into
        output_hw: (H, W) the CAM is upsampled to
    Returns:
        cam_big: upsampled Grad-CAM as a 2D numpy array of shape output_hw
    Raises:
        RuntimeError: if the hooked layer was never called during the forward
            pass, or if no gradient reached its output
    """
    # Imported locally under a private name: the notebook-level `F` is bound to
    # torchvision's functional module in the first cell and only rebound to
    # torch.nn.functional by a later cell, so using it here would make this
    # function depend on cell execution order.
    import torch.nn.functional as torch_F

    # --- 1. Prepare input ---
    feats = orig.F.clone().detach().to(device)
    feats.requires_grad_(True)
    coords = orig.C.clone().to(device)
    orig_st = ME.SparseTensor(feats, coords, device=device)

    # --- 2. Hook to capture features ---
    saved = {}
    def forward_hook(module, input, output):
        # output is the SparseTensor produced by the hooked layer
        saved['features'] = output.F
        saved['output_tensor'] = output
        # Non-leaf tensor: ask autograd to keep its gradient after backward()
        saved['features'].retain_grad()

    # Register hook on last conv layer of encoder
    target_layer = find_last_conv(encoder.encoders)
    handle = target_layer.register_forward_hook(forward_hook)
    try:
        # --- 3. Forward pass ---
        _, cluster_batch = encoder(orig_st, batch_size=1, return_maps=False)
        scores = clust_head(cluster_batch)
        score = scores[:, cluster_index].sum()

        # --- 4. Backward pass ---
        encoder.zero_grad()
        clust_head.zero_grad()
        score.backward()
    finally:
        # Always detach the hook, even on failure, so hooks do not pile up on the model
        handle.remove()

    # --- 5. Grab gradients and compute weights ---
    feature_map = saved.get('output_tensor')
    if feature_map is None:
        raise RuntimeError("Forward hook never fired! The hooked convolution was not used in the forward pass.")
    grads = saved['features'].grad  # (N, C)
    if grads is None:
        raise RuntimeError("Gradients not found! Make sure forward hook and retain_grad are working.")

    H, W = feature_hw
    dense_shape = torch.Size([1, saved['features'].shape[1], H, W])

    with torch.no_grad():
        # Dense activations of the hooked layer
        dense_maps, _, _ = feature_map.dense(shape=dense_shape)

        # Dense gradients: wrap `grads` in a SparseTensor that shares the feature
        # map's coordinates, then densify it the same way. Densifying the
        # activation tensor here instead would leave the gradient, and therefore
        # cluster_index, out of the CAM entirely.
        grad_tensor = ME.SparseTensor(
            grads,
            coordinate_map_key=feature_map.coordinate_map_key,
            coordinate_manager=feature_map.coordinate_manager,
        )
        dense_grads, _, _ = grad_tensor.dense(shape=dense_shape)
        dense_grads = dense_grads.to(device)

        # Compute channel weights (global average pooling over H,W)
        weights = dense_grads.mean(dim=(2,3), keepdim=True)

        # Compute Grad-CAM: weighted channel sum, keep positive evidence, scale to [0, 1]
        cam = (weights * dense_maps).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()

        # Upsample to match original input resolution
        cam_big = torch_F.interpolate(cam, size=tuple(output_hw), mode='nearest')

    return cam_big[0,0].detach().cpu().numpy()

In [None]:
from matplotlib import cm

def plot_saliency_block(dataset, encoder, clust_head, cluster_ids, cluster_index, max_x=10, cluster_probs=None, save_name=None):
    """
    Draw up to max_x events as a two-row figure: input image on top, Grad-CAM map below.

    Args:
        dataset: indexable dataset; dataset[i] must start with (coords, feats) numpy arrays
        encoder: encoder module passed through to grad_cam_sparse_hook
        clust_head: cluster head passed through to grad_cam_sparse_hook
        cluster_ids: per-event cluster assignment, aligned with dataset indices
        cluster_index: cluster to show; if None the first events of the dataset are used
        max_x: maximum number of events (columns) to draw
        cluster_probs: optional per-event score; when given, the selected events are
            shown in descending order of this score
        save_name: optional output path; the figure is saved there when set
    """
    ## Sort colours
    cmap = cm.turbo.copy()
    cmap.set_under("#F0F0F0")
    cmap_sal = cm.YlGn.copy()

    # Lower colour limit of the Grad-CAM panels
    sal_vmin = 1E-8

    plt.figure(figsize=(max_x*2.1, 2*6))

    ## Pick the events to show
    indices = np.arange(max_x*2)
    if cluster_index is not None:
        indices = np.where(np.array(cluster_ids) == cluster_index)[0]
        ## If the probabilities are given, show the top N probabilities
        if cluster_probs is not None:
            indices = indices[np.argsort(np.array(cluster_probs)[indices])][::-1]
    max_images = min(len(indices), max_x)

    ## Plot
    for i in range(max_images):
        ax = plt.subplot(2,max_x,i+1)

        numpy_coords, numpy_feats, *_ = dataset[indices[i]]

        # Create batched coordinates for the SparseTensor input
        orig_bcoords  = ME.utils.batched_coordinates([numpy_coords])
        orig_bfeats  = torch.from_numpy(np.concatenate([numpy_feats], 0)).float()
        orig = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

        # Dense 2D view of the event (presumably batch 0 on a 768x256 grid - confirm in make_dense_from_tensor)
        inputs  = make_dense_from_tensor(orig, 0, 768, 256)
        inputs  = inputs.cpu().squeeze().numpy()

        # Clip the colour scale at the 80th percentile of the filled pixels.
        # An image with no positive pixel has nothing to take a percentile of,
        # so fall back to a fixed upper limit instead of failing.
        nonzero_vals = inputs[inputs > 0]
        vmax = np.percentile(nonzero_vals, 80) if nonzero_vals.size else 1.0
        plt.imshow(inputs, origin='lower', cmap=cmap, vmin=1e-6, vmax=vmax)
        ax.axis('off')

        ax = plt.subplot(2,max_x,i+1+max_x)
        outputs = grad_cam_sparse_hook(orig, encoder, clust_head, cluster_index, device=device)

        # Symmetric-magnitude upper limit, floored at sal_vmin: an all-zero map
        # would otherwise give vmax < vmin, which matplotlib rejects.
        vmax = max(float(np.abs(outputs).max()), sal_vmin)
        plt.imshow(outputs, origin='lower', cmap=cmap_sal, vmin=sal_vmin, vmax=vmax)
        ax.axis('off')

    # One layout pass once every panel exists (instead of one per panel)
    plt.tight_layout()
    if save_name: plt.savefig(save_name, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

In [None]:
from matplotlib import cm
import torch.nn.functional as F

def process_max_val(max_val):
    """
    Expand max_val into an (x, y) pair.

    A scalar v becomes (v, 1); a one-element list/tuple [v] becomes (v, 1);
    a two-element list/tuple is returned as (first, second).

    Raises:
        ValueError: if a list/tuple has neither 1 nor 2 elements
    """
    # Anything that is not a list/tuple is treated as a single scalar value
    if not isinstance(max_val, (list, tuple)):
        return max_val, 1

    n_items = len(max_val)
    if n_items == 1:
        return max_val[0], 1
    if n_items == 2:
        first, second = max_val
        return first, second
    raise ValueError("max_val list/tuple must have 1 or 2 elements")

def plot_saliency_overlay_block(dataset, encoder, clust_head, cluster_ids, cluster_index, max_val=10, cluster_probs=None, save_name=None):
    """
    Draw a grid of events with the input hits overlaid on their Grad-CAM map.

    Args:
        dataset: indexable dataset; dataset[i] must start with (coords, feats) numpy arrays
        encoder: encoder module passed through to grad_cam_sparse_hook
        clust_head: cluster head passed through to grad_cam_sparse_hook
        cluster_ids: per-event cluster assignment, aligned with dataset indices
        cluster_index: cluster to show; if None the first events of the dataset are used
        max_val: grid size, either a scalar n (n columns, 1 row) or [columns, rows]
        cluster_probs: optional per-event score; when given, the selected events are
            shown in descending order of this score
        save_name: optional output path; the figure is saved there when set
    """
    ## Sort colours
    cmap_sal = cm.spring.copy()
    cmap_sal.set_under("#F0F0F0")

    # Lower colour limit of the Grad-CAM layer; smaller values use the "under" colour
    sal_vmin = 1E-8

    max_x, max_y = process_max_val(max_val)

    plt.figure(figsize=(max_x*2.1, max_y*6))

    ## Pick the events to show
    indices = np.arange(max_x*max_y)
    if cluster_index is not None:
        indices = np.where(np.array(cluster_ids) == cluster_index)[0]
        ## If the probabilities are given, show the top N probabilities
        if cluster_probs is not None:
            indices = indices[np.argsort(np.array(cluster_probs)[indices])][::-1]
    max_images = min(len(indices), max_x*max_y)

    ## Plot
    for i in range(max_images):
        ax = plt.subplot(max_y,max_x,i+1)

        numpy_coords, numpy_feats, *_ = dataset[indices[i]]

        # Create batched coordinates for the SparseTensor input
        orig_bcoords  = ME.utils.batched_coordinates([numpy_coords])
        orig_bfeats  = torch.from_numpy(np.concatenate([numpy_feats], 0)).float()
        orig = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

        # Dense 2D view of the event (presumably batch 0 on a 768x256 grid - confirm in make_dense_from_tensor)
        inputs  = make_dense_from_tensor(orig, 0, 768, 256)
        inputs  = inputs.cpu().squeeze().numpy()
        # Min-max scale to [0, 1] so the hits can be drawn as grey levels
        inputs = (inputs - inputs.min()) / (inputs.max() - inputs.min() + 1e-8)

        # Greyscale RGBA image of the hits; empty pixels are fully transparent
        # so the Grad-CAM layer underneath stays visible.
        rgba_inputs = np.zeros((*inputs.shape, 4), dtype=np.float32)
        rgba_inputs[..., 0] = inputs  # R
        rgba_inputs[..., 1] = inputs  # G
        rgba_inputs[..., 2] = inputs  # B
        rgba_inputs[..., 3] = (inputs > 0).astype(np.float32)  # alpha: 0 if zero, 1 if nonzero

        outputs = grad_cam_sparse_hook(orig, encoder, clust_head, cluster_index, device=device)

        # Scale the Grad-CAM colours by the map's own maximum. Taking the limit
        # from the input-image percentile tied the heat-map scale to the hit
        # charges and failed on events with no positive pixel. The floor keeps
        # vmax >= vmin when the map is all zero.
        vmax = max(float(outputs.max()), sal_vmin)
        plt.imshow(outputs, origin='lower', cmap=cmap_sal, alpha=0.8, vmin=sal_vmin, vmax=vmax)
        plt.imshow(rgba_inputs, origin='lower')

        ax.axis("off")

    # One layout pass once every panel exists (instead of one per panel)
    plt.tight_layout()
    if save_name: plt.savefig(save_name, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

In [None]:
#print(encoder)

In [None]:
## One overlay figure per cluster index 0..29: 8 columns x 2 rows of data events,
## ordered by data_processed['clust_max'] (largest first).
## NOTE(review): 30 presumably matches the "clust30" in the checkpoint name - confirm.
for clust in range(30):
    print("Saliency for cluster", clust)
    plot_saliency_overlay_block(
        data_dataset,
        encoder,
        heads["clust"],
        data_processed['clust_index'],
        clust,
        max_val=[8,2],
        cluster_probs=data_processed['clust_max'],
    )