In [None]:
import MinkowskiEngine as ME
import matplotlib.pyplot as plt
import matplotlib as mpl
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F
from torch import nn

## Jupyter magic
%matplotlib inline
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.device(device)
import numpy as np
SEED=12345
_=np.random.seed(SEED)
_=torch.manual_seed(SEED)

In [None]:
## Includes from my libraries for this project                                                                                                                                           
from ME_NN_libs import CCEncoderFSD, ProjectionHead, ClusteringHead
from ME_dataset_libs import CenterCrop, MaxRegionCrop, ConstantCharge, RandomCrop, RandomPixelNoise2D
from ME_dataset_libs import SingleModuleImage2D_solo_ME, solo_ME_collate_fn, solo_ME_collate_fn_with_meta
from ME_dataset_libs import make_dense, make_dense_from_tensor, Label

In [None]:
## Load the pretrained model, set a few other parameters.
## These must match the settings the checkpoint was trained with: they fix the network
## shapes and are also used below to build the checkpoint filename.
nchan=64
nlatent=128
nclusters=20
temp=0.5
enc_act_fn=ME.MinkowskiSiLU
hidden_act_fn=nn.SiLU
latent_act_fn=nn.Tanh
dropout=0
lr="5E-6"          ## training learning rate, kept as the string that appears in the filename
batch_size=1024    ## training batch size, as it appears in the filename
aug_type="unitcharge"

## Define the model
encoder = CCEncoderFSD(nchan, enc_act_fn, dropout)
proj_head = ProjectionHead(nchan, nlatent, hidden_act_fn, latent_act_fn)
clust_head = ClusteringHead(nchan, nclusters, hidden_act_fn)

## Load in the pre-calculated model weights.
## The filename is built from the parameters above (including aug_type) so they cannot drift apart.
file_dir = "/pscratch/sd/c/cwilk"
chk_file = (f"{file_dir}/state_lat{nlatent}_clust{nclusters}_nchan{nchan}_{lr}_{batch_size}"
            f"_PROJ0.5CLUST0.5_onecycle50_{aug_type}_2M_FSDCC.pth")

print(chk_file)

## NOTE: torch.load unpickles the file -- only load checkpoints from a trusted location
checkpoint = torch.load(chk_file, map_location=device)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
proj_head.load_state_dict(checkpoint['proj_head_state_dict'])
clust_head.load_state_dict(checkpoint['clust_head_state_dict'])

## Inference only: put every module in eval mode and move it to the chosen device
for net in (encoder, proj_head, clust_head):
    net.eval().to(device)


In [None]:
## Setup the dataloader
from torch.utils.data import ConcatDataset
import time
## perf_counter measures elapsed wall-clock time; process_time would only count this
## process's CPU time and miss any time spent waiting on file I/O.
start = time.perf_counter()

## Modify the nominal transform
nom_transform = transforms.Compose([
            MaxRegionCrop((256, 800), (256,512)),
            ConstantCharge(),
            ])

data_dir = "/pscratch/sd/c/cwilk/FSD/DATA"
sim_dir = "/pscratch/sd/c/cwilk/FSD/SIMULATION"
single_sim_dataset = SingleModuleImage2D_solo_ME(sim_dir, transform=nom_transform, max_events=250000, return_metadata=True)
single_data_dataset = SingleModuleImage2D_solo_ME(data_dir, transform=nom_transform, max_events=250000, return_metadata=True)
## Data entries come first, then simulation
single_mixed_dataset = ConcatDataset([single_data_dataset, single_sim_dataset])

print("Time taken to load", len(single_data_dataset),"data and", len(single_sim_dataset), "images:", time.perf_counter() - start)

## Fixed-order batching (shuffle=False): entry i of the encoded outputs then corresponds to
## single_mixed_dataset[i], which the example-image plots below rely on when looking images up by index.
single_loader = torch.utils.data.DataLoader(single_mixed_dataset,
                                            collate_fn=solo_ME_collate_fn_with_meta,
                                            batch_size=1024,
                                            shuffle=False,
                                            num_workers=4)

In [None]:
import numpy as np

## Encode the images we'll work with here (can take a while).
## Per-image results, in loader (= dataset) order:
latent = []      # (1, nlatent) projection-head rows
cluster = []     # (1, nclusters) cluster-head rows
nhits  = []      # number of sparse hits in each input image
filenames = []
event_ids = []
labels = []

encoder.eval()
proj_head.eval()
clust_head.eval()

## Note that this uses the loader including metadata so it's possible to trace back to the input files.
## No gradients are needed anywhere in this loop.
with torch.no_grad():
    for orig_bcoords, orig_bfeats, batch_labels, batch_filenames, batch_eventids in single_loader:

        ## Deliberately not named `batch_size`: that global holds the training batch size
        ## used to build the checkpoint filename and must not be overwritten here.
        n_in_batch = len(batch_filenames)
        orig_bcoords = orig_bcoords.to(device)
        orig_bfeats = orig_bfeats.to(device)
        orig_batch = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

        ## Now do the forward passes
        encoded_batch = encoder(orig_batch, n_in_batch)
        clust_np = clust_head(encoded_batch.F).cpu().numpy()
        proj_np = proj_head(encoded_batch.F).cpu().numpy()

        nhits.extend(feats.shape[0] for feats in orig_batch.decomposed_features)
        ## Iterating a (B, 1, K) array yields B arrays of shape (1, K): one row per image
        cluster.extend(clust_np[:, np.newaxis, :])
        latent.extend(proj_np[:, np.newaxis, :])
        filenames.extend(batch_filenames)
        event_ids.extend(batch_eventids)
        labels.extend(batch_labels)

if not latent:
    raise RuntimeError("single_loader produced no batches; nothing was encoded")

## The image-wise latent space
latent_vect = np.vstack(latent)

## The pre-clustered space
cluster_vect = np.vstack(cluster)

## The index of the maximum value in cluster space
clust_index = np.argmax(cluster_vect, axis=1)

## The maximum value in cluster space
clust_max = np.max(cluster_vect, axis=1)

## The number of hits in the input image
hit_vect = np.array(nhits)

## The label of the input image (-1 for data)
label_vect = np.array(labels)

In [None]:
# Plot histogram of the largest cluster-space value of each image
plt.hist(clust_max, bins=50, edgecolor='black')
plt.xlabel('Maximum value (per row)')
plt.ylabel('Count')
## Use nclusters rather than a hard-coded count so the title follows the model configuration
plt.title(f'Distribution of maximum values across {nclusters} features')
plt.grid(True)
plt.show()

In [None]:
# Histogram of the assigned cluster (argmax in cluster space) for every image.
# Labels/title are set first; the histogram is then drawn onto the same current axes.
plt.xlabel('Index of max value')
plt.ylabel('Count')
plt.title('Distribution of max indices')
# Edges at -0.5, 0.5, ..., nclusters-0.5: one bar centred on each cluster ID
plt.hist(clust_index, bins=np.arange(-0.5, nclusters), edgecolor='black')
plt.grid(True)
plt.show()

In [None]:
from cuml.manifold import TSNE as cuML_TSNE
import cupy as cp
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import normalize
import matplotlib.colors as mcolors
import matplotlib

## Define a function for running t-SNE using the cuml version
def run_tsne_cuml(perp=30, exag=100, lr=200.0, input_vect=latent_vect, zvect=clust_index):
    """Embed ``input_vect`` in 2D with cuML's GPU t-SNE and scatter-plot it coloured by ``zvect``.

    Parameters
    ----------
    perp : float
        t-SNE perplexity.
    exag : float
        Early-exaggeration factor.
    lr : float
        t-SNE learning rate; used as given because ``learning_rate_method=None`` below.
        Must be numeric (200.0 is cuML's own default).
    input_vect : 2D array-like, (n_samples, n_features)
        Vectors to embed; rows are L2-normalised first. Defaults to the global
        ``latent_vect`` as it was when this cell was executed.
    zvect : array-like of non-negative int, (n_samples,)
        ID used to colour each point. Defaults to the global ``clust_index`` as it was
        when this cell was executed.

    Returns
    -------
    numpy.ndarray, (n_samples, 2)
        The embedding, standardised to zero mean and unit variance per dimension.
    """
    print("Running cuML t-SNE with: perplexity =", perp, "early exaggeration =", exag)

    ## L2-normalise each row, then move to the GPU as float32 for cuML
    input_vect = normalize(input_vect, norm='l2')
    input_vect = cp.asarray(input_vect, dtype=cp.float32)

    print("Input shape:", input_vect.shape)
    print("Input range:", input_vect.min(), input_vect.max())

    ## I haven't played with most of cuml's t-SNE parameters.
    ## learning_rate_method=None makes cuML use `lr` as given instead of choosing its own.
    tsne = cuML_TSNE(n_components=2, perplexity=perp, n_iter=5000,
                     early_exaggeration=exag, learning_rate=lr,
                     learning_rate_method=None,
                     metric='cosine', method='barnes_hut', verbose=True)
    tsne_results = tsne.fit_transform(input_vect)

    tsne_results = cp.asnumpy(tsne_results)  # Convert to NumPy for matplotlib
    tsne_results = StandardScaler().fit_transform(tsne_results)

    ## One colour bin per integer ID from 0 to max(zvect), with edges at half-integers so each
    ## ID sits in the middle of its bin. Sizing by the largest ID (not by the number of distinct
    ## IDs) keeps every ID in its own bin when some clusters are empty.
    zvect = np.asarray(zvect)
    n_bins = int(zvect.max()) + 1
    cmap = matplotlib.colormaps['tab20']
    norm = mcolors.BoundaryNorm(boundaries=np.arange(n_bins + 1) - 0.5, ncolors=n_bins)

    gr = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], s=0.005, alpha=0.9, c=zvect, cmap=cmap, norm=norm)
    plt.colorbar(gr, label='Cluster ID', ticks=np.arange(n_bins))
    plt.xlabel('t-SNE #0')
    plt.ylabel('t-SNE #1')
    plt.show()

    print("t-SNE output min/max:", tsne_results.min(), tsne_results.max())
    print("t-SNE output std per dim:", tsne_results.std(axis=0))
    return tsne_results


In [None]:
## Actually run tsne (not always that useful).
## The learning rate gets its own name: `lr` already holds the training learning-rate string
## from the model cell, and overwriting it makes results depend on cell execution order.
perp=30
exag=6
tsne_lr=2000.0
tsne_results = run_tsne_cuml(perp, exag, tsne_lr, latent_vect, clust_index)

In [None]:
# A helper function to make column- or row-normalized histograms
from matplotlib.colors import LogNorm
def make_2D_histogram(x_vect, y_vect, norm='column', label_enum=Label):
    """Plot a normalised 2D histogram of two paired integer vectors, one cell per integer pair.

    Parameters
    ----------
    x_vect, y_vect : array-like of int, same length
        Paired values; x is drawn on the horizontal axis, y (labelled "Cluster ID") on the vertical.
    norm : {'column', 'row'}
        'column' scales each x column to sum to 1, 'row' scales each y row to sum to 1.
        Any other value prints a message and returns without plotting.
    label_enum : optional
        If not None, ``label_enum.name_from_index(i)`` supplies the text of each x tick.
    """
    if norm not in ('column', 'row'):
        print("Unknown norm option:", norm)
        return

    x_vect = np.asarray(x_vect)
    y_vect = np.asarray(y_vect)

    # Determine range of unique integer values
    x_min, x_max = x_vect.min(), x_vect.max()
    y_min, y_max = y_vect.min(), y_vect.max()

    # One bin per integer, with edges at half-integers so each value sits at its bin centre
    x_bins = np.arange(x_min, x_max + 2) - 0.5
    y_bins = np.arange(y_min, y_max + 2) - 0.5

    # Compute the 2D histogram. After the transpose H has shape (len(y_bins)-1, len(x_bins)-1):
    # rows follow y and columns follow x, which is the layout pcolormesh expects.
    H, _, _ = np.histogram2d(x_vect, y_vect, bins=[x_bins, y_bins])
    H = H.T

    # Normalise each column (per x value) or each row (per y value). `out` must be supplied:
    # np.divide leaves entries where `where` is False uninitialised otherwise. Empty ones stay 0.
    sums = H.sum(axis=0 if norm == 'column' else 1, keepdims=True)
    H_normalized = np.divide(H, sums, out=np.zeros_like(H), where=sums != 0)

    plt.figure(figsize=(8, 6))
    mesh = plt.pcolormesh(x_bins, y_bins, H_normalized, cmap='viridis', shading='auto')
    plt.colorbar(mesh, label='Normalized Frequency (per '+norm+')')
    plt.ylabel("Cluster ID")

    ax = plt.gca()
    if label_enum is not None:
        x_ticks = np.arange(x_min, x_max + 1)
        x_labels = [label_enum.name_from_index(i) for i in x_ticks]
        # rotation_mode='anchor' keeps the end of each rotated label under its (centred) tick
        plt.xticks(ticks=x_ticks, labels=x_labels, rotation=45, ha='right', rotation_mode='anchor')
    plt.yticks(ticks=np.arange(y_min, y_max + 1))

    # Grid lines along the bin edges (minor ticks) rather than through the cell centres
    ax.set_xticks(x_bins, minor=True)
    ax.set_yticks(y_bins, minor=True)
    ax.grid(False, which='major')
    ax.grid(True, which='minor', linestyle='--', linewidth=0.5)
    plt.show()

In [None]:
## Fraction of each true label that lands in each cluster (each label column sums to 1)
make_2D_histogram(label_vect, clust_index, norm='column')

In [None]:
## Label make-up of each cluster (each cluster row sums to 1)
make_2D_histogram(label_vect, clust_index, norm='row')

In [None]:
## Function to show examples for each cluster
def _load_dense_image(dataset, idx):
    """Fetch entry ``idx`` of ``dataset`` and return it as a dense 2D numpy image for imshow.

    Expects ``dataset[idx]`` to be a 5-tuple whose first two entries are the sparse
    coordinates and features (the remaining metadata entries are ignored).
    """
    numpy_coords, numpy_feats, _, _, _ = dataset[idx]

    # Wrap the single event as a batch of one, as the loader would
    bcoords = ME.utils.batched_coordinates([numpy_coords]).to(device)
    bfeats = torch.from_numpy(np.array(numpy_feats)).float().to(device)
    sparse = ME.SparseTensor(bfeats, bcoords, device=device)

    # Densify batch entry 0; the (512, 256) size arguments are the ones used throughout this notebook
    return make_dense_from_tensor(sparse, 0, 512, 256).cpu().squeeze().numpy()


def plot_cluster_examples(dataset, cluster_ids, index, max_images=8):
    """Show up to ``max_images`` example images from ``dataset`` that were assigned to cluster ``index``.

    ``cluster_ids[i]`` must be the cluster of ``dataset[i]`` (same ordering). The first
    matching images in dataset order are drawn, laid out on two rows. Nothing is drawn
    when the cluster is empty.
    """
    ## Dataset positions of every image in this cluster
    indices = np.flatnonzero(np.asarray(cluster_ids) == index)
    n_images = min(max_images, len(indices))
    if n_images == 0:
        return

    ## Two rows, filled left to right
    ncols = (n_images + 1) // 2
    plt.figure(figsize=(12,6))
    for i, data_idx in enumerate(indices[:n_images]):
        ax = plt.subplot(2, ncols, i + 1)
        ax.imshow(_load_dense_image(dataset, data_idx), origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
## For every cluster: report how many images it holds, then show a bank of up to 8 examples
for index in range(nclusters):
    print("Showing examples for cluster:", index,
          "which has", np.count_nonzero(clust_index == index), "values")
    plot_cluster_examples(single_mixed_dataset, clust_index, index, max_images=8)

In [None]:
## Function to show a big block of examples for each cluster
## index=None will just give an unclustered set
def plot_cluster_bigblock(dataset, cluster_ids, index, max_x=10, max_y=10, save_name=None):
    """Draw up to ``max_x * max_y`` images from ``dataset`` on a ``max_x``-row by ``max_y``-column grid.

    Parameters
    ----------
    dataset : indexable
        ``dataset[i]`` is expected to be a 5-tuple whose first two entries are the sparse
        coordinates and features of image i.
    cluster_ids : array-like of int
        Cluster assigned to each dataset entry, in dataset order.
    index : int or None
        Cluster to show; None shows the first entries of the dataset regardless of cluster.
    max_x, max_y : int
        Number of grid rows and columns.
    save_name : str, optional
        If given, the figure is also saved to this path at 150 dpi.
    """
    plt.figure(figsize=(max_y*2, max_x*1.8*2))
    n_cells = max_x*max_y
    if index is None:
        ## Unclustered: the first entries, capped by the dataset length to avoid indexing past the end
        indices = np.arange(min(n_cells, len(dataset)))
    else:
        indices = np.where(np.array(cluster_ids) == index)[0]
    max_images = min(len(indices), n_cells)
    print(len(indices))

    ## Plot
    for i in range(max_images):
        ax = plt.subplot(max_x, max_y, i+1)

        numpy_coords, numpy_feats, _, _, _ = dataset[indices[i]]

        # Create batched coordinates for the SparseTensor input (a batch of one event)
        orig_bcoords = ME.utils.batched_coordinates([numpy_coords]).to(device)
        orig_bfeats = torch.from_numpy(np.array(numpy_feats)).float().to(device)
        orig = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

        inputs = make_dense_from_tensor(orig, 0, 512, 256).cpu().squeeze().numpy()

        plt.imshow(inputs, origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.tight_layout()
    if save_name:
        plt.savefig(save_name, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
## Dump out a large block of images for one cluster: cluster 17 on a 10 x 10 grid.
## To also write the figure to disk, pass e.g. save_name='cluster_plots/v9_michel_like.png'.
plot_cluster_bigblock(single_mixed_dataset, clust_index, 17, max_x=10, max_y=10)