In [None]:
import MinkowskiEngine as ME
import matplotlib.pyplot as plt
import matplotlib as mpl
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F
from torch import nn

## Jupyter magic
%matplotlib inline
## Notebook-wide matplotlib defaults: figure size, font size, and grid lines on all axes
mpl.rcParams.update({
    'figure.figsize': [8, 6],
    'font.size': 16,
    'axes.grid': True,
})

import torch
import numpy as np

## Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Seed the global numpy and torch random number generators with one shared value.
## The return values are bound to `_`, presumably to keep them out of the cell output.
SEED = 12345
_ = np.random.seed(SEED)
_ = torch.manual_seed(SEED)

In [None]:
## Includes from my libraries for this project
from ME_dataset_libs import CenterCrop, MaxRegionCrop, ConstantCharge, RandomCrop, RandomPixelNoise2D, FirstRegionCrop
from ME_dataset_libs import SingleModuleImage2D_solo_ME, solo_ME_collate_fn, solo_ME_collate_fn_with_meta
from ME_dataset_libs import make_dense, make_dense_from_tensor, Label

In [None]:
from FSD_training_analysis import get_models_from_checkpoint

## Settings that identify the pretrained model; each one is spelled into the checkpoint file name below
nlatent = 256
nclusters = 20
lr = "1E-5"
batch_size = 1024
nevts = "2M"
nsteps = 50
data_frac = 1

aug_type = "bigaugbilin"
nchan = 64
clust_arch = "two"
proj_arch = "logits_"
enc_arch = "12x4"
enc_arch_pool = "max"
enc_arch_flatten = 1
enc_arch_slow_growth = 1
enc_arch_first_kernel = 7
enc_arch_sep_heads = 1
softmax_temp = 1.0
clust_temp = 0.5
proj_temp = 0.5
ent = "_ent1E-3"
match = ""

## Directory holding the checkpoint
file_dir = "/pscratch/sd/c/cwilk"

## Checkpoint file name: adjacent f-strings are joined into a single string
chk_file = (
    f"state_lat{nlatent}_clust{nclusters}{match}_nchan{nchan}_{lr}_{batch_size}"
    f"_PROJ{proj_temp}{proj_arch}CLUST{clust_temp}{clust_arch}{ent}_soft{softmax_temp}"
    f"_arch{enc_arch}_pool{enc_arch_pool}_flat{enc_arch_flatten}"
    f"_grow{enc_arch_slow_growth}_kern{enc_arch_first_kernel}_sep{enc_arch_sep_heads}"
    f"_onecycle{nsteps}_{aug_type}_{nevts}_DATA{data_frac}_FSDCCFIX.pth"
)

## Load the encoder, the two heads and the stored args from <file_dir>/<chk_file>
encoder, proj_head, clust_head, args = get_models_from_checkpoint(f"{file_dir}/{chk_file}")

## Switch every network to evaluation mode and move it onto the selected device
encoder.eval()
encoder.to(device)

proj_head.eval()
proj_head.to(device)

clust_head.eval()
## Left as the last bare expression of the cell, so its return value remains the cell output
clust_head.to(device)

In [None]:
## Setup the dataloader
from torch.utils.data import ConcatDataset
import time
## Wall-clock timer for the dataset construction below. time.perf_counter() is used because
## time.process_time() only counts CPU time of this process and would leave out any time
## spent waiting (presumably the file reads matter here - confirm against the dataset class).
start = time.perf_counter()

## Modify the nominal transform
nom_transform = transforms.Compose([
    FirstRegionCrop((800, 256), (768, 256)),
])

data_dir = "/pscratch/sd/c/cwilk/FSD/DATA"
sim_dir = "/pscratch/sd/c/cwilk/FSD/SIMULATION"

## One dataset for simulation and one for data, built with the same transform and event cap.
## return_metadata=True is requested for both; the metadata-aware collate function is used below.
single_sim_dataset = SingleModuleImage2D_solo_ME(sim_dir, transform=nom_transform, max_events=250000, return_metadata=True)
single_data_dataset = SingleModuleImage2D_solo_ME(data_dir, transform=nom_transform, max_events=250000, return_metadata=True)

## Concatenated view: data entries first, simulation entries after them
single_mixed_dataset = ConcatDataset([single_data_dataset, single_sim_dataset])

print("Time taken to load", len(single_data_dataset), "data and", len(single_sim_dataset), "images:", time.perf_counter() - start)

## Fixed-order batching (shuffle=False): batches follow the order of the concatenated dataset
single_loader = torch.utils.data.DataLoader(single_mixed_dataset,
                                            collate_fn=solo_ME_collate_fn_with_meta,
                                            batch_size=1024,
                                            shuffle=False,
                                            num_workers=4)

In [None]:
import numpy as np

## Encode the images we'll work with here (can take a while)
## NOTE(review): these lists are created but never filled in this cell, because the loop
## below stops after the first batch.
nhits  = []
filenames = []
event_ids = []
labels = []

encoder.eval()

## Note that this uses the loader including metadata so it's possible to trace back to the input files
for orig_bcoords, orig_bfeats, batch_labels, batch_filenames, batch_eventids in single_loader:

    ## Number of entries in this batch. Kept under its own name so the notebook-level
    ## `batch_size` setting (part of the checkpoint file name) is not overwritten.
    n_in_batch = len(batch_filenames)
    orig_bcoords = orig_bcoords.to(device)
    orig_bfeats = orig_bfeats.to(device)
    orig_batch = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

    ## Now do the forward passes
    with torch.no_grad():
        feature_maps = encoder(orig_batch, n_in_batch, return_maps=True)

        ## NOTE(review): assumes feature_maps[2] is indexed channel-first and that every
        ## maps[i] is a 2D array imshow can draw - TODO confirm against the encoder.
        maps = feature_maps[2].cpu()

        C = maps.shape[0]  # number of channels
        num_to_show = min(C, 16)  # show at most 16 channels

        ## 4x4 grid, filled row by row
        fig, axes = plt.subplots(4, 4, figsize=(8,8))
        for i in range(num_to_show):
            ax = axes[i//4, i%4]
            ax.imshow(maps[i], cmap="viridis", aspect="auto")
            ax.set_title(f"Ch {i}")
            ax.axis("off")

        ## With fewer than 16 channels the remaining panels would be drawn as empty axes;
        ## switch those off as well.
        for ax in axes.flat[num_to_show:]:
            ax.axis("off")

        plt.tight_layout()
        plt.show()

        ## Only the first batch is visualised
        break

    # nhits += [i.shape[0] for i in orig_batch.decomposed_features]


In [None]:
## Function to show examples for each cluster
def plot_cluster_examples(dataset, cluster_ids, index, max_images=8):
    """Draw up to `max_images` input images from `dataset` that belong to cluster `index`.

    Relies on the notebook-level names `device`, `ME`, `make_dense_from_tensor`,
    `torch`, `np` and `plt`.

    Parameters
    ----------
    dataset : indexable
        `dataset[i]` must unpack into five values; only the first two (coordinates
        and features) are used here.
    cluster_ids : sequence
        Cluster assignment of every entry; position `i` refers to `dataset[i]`.
    index : scalar
        Cluster id whose members are shown.
    max_images : int, default 8
        Upper limit on the number of images drawn.

    Returns
    -------
    None
        The figure is shown with `plt.show()`. When there is nothing to draw, a short
        message is printed and no figure is created.
    """

    ## Positions of the entries assigned to the requested cluster
    indices = np.where(np.array(cluster_ids) == index)[0]

    ## Show the first `max_images` matches, or all of them if there are fewer
    n_show = min(max_images, len(indices))
    if n_show <= 0:
        ## Nothing to draw: return before an empty figure gets created
        print(f"No images to show for cluster {index}")
        return

    plt.figure(figsize=(12,10))

    ## Plot
    for i in range(n_show):
        ## Panels 1..n_show are the top row of a 2 x n_show grid; the bottom row is not used here
        ax = plt.subplot(2, n_show, i+1)

        numpy_coords, numpy_feats, _, _, _ = dataset[indices[i]]

        # Create batched coordinates for the SparseTensor input
        orig_bcoords = ME.utils.batched_coordinates([numpy_coords])
        orig_bfeats = torch.from_numpy(np.concatenate([numpy_feats], 0)).float()

        orig_bcoords = orig_bcoords.to(device)
        orig_bfeats = orig_bfeats.to(device)
        orig = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

        ## NOTE(review): presumably batch entry 0 densified to a 768 x 256 image - confirm
        ## against make_dense_from_tensor.
        inputs = make_dense_from_tensor(orig, 0, 768, 256)
        inputs = inputs.cpu().squeeze().numpy()

        ax.imshow(inputs, origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()