In [None]:
import MinkowskiEngine as ME
import matplotlib.pyplot as plt
import matplotlib as mpl
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F

## Jupyter magic
%matplotlib inline
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.device(device)
import numpy as np
SEED=12345
_=np.random.seed(SEED)
_=torch.manual_seed(SEED)

In [None]:
## Includes from my libraries for this project                                                                                                                                           
from ME_NN_libs import ContrastiveEncoderME
from ME_dataset_libs import CenterCrop, MaxRegionCrop, RandomCrop, RandomHorizontalFlip, RandomRotation2D, RandomShear2D, RandomBlockZero, ConstantCharge
from ME_dataset_libs import SingleModuleImage2D_solo_ME, solo_ME_collate_fn, solo_ME_collate_fn_with_meta
from ME_dataset_libs import make_dense, make_dense_from_tensor

In [None]:
## Hyper-parameters of the pretrained model; they are also used to build the checkpoint file name
nchan=32
nlatent=64
temp=0.5
hidden_act_fn=ME.MinkowskiSiLU
latent_act_fn=ME.MinkowskiTanh
dropout=0
lr="5e-6"
batch_size=1536
aug_type="block10x10"  ## the other augmentation tag tried here was "bigmodblock10x10"

## Build the encoder with the architecture parameters above
encoder=ContrastiveEncoderME(nchan, nlatent, hidden_act_fn, latent_act_fn, dropout)

## Locations of the pre-calculated weights and of the input images
file_dir = "/global/cfs/cdirs/dune/users/cwilk/single_module_unsupervised"
chk_file = f"{file_dir}/state_CONTONLY_lat{nlatent}_nchan{nchan}_{lr}_{batch_size}_NTXentMerged{temp}_onecycle_{aug_type}_5M_ME_v9.pth"
inDir = f"{file_dir}/cwilk/h5_inputs_v9/"
nevents = 500000

## Restore the encoder weights onto the chosen device and switch to inference mode
checkpoint = torch.load(chk_file, map_location=device)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
encoder.eval()
encoder.to(device)

In [None]:
import time
## Time the dataset construction with a wall-clock timer: process_time() only counts CPU time,
## so it under-reports any time spent waiting on file I/O
start = time.perf_counter()
nom_transform=transforms.Compose([
    MaxRegionCrop(),
])
train_dataset = SingleModuleImage2D_solo_ME(inDir, transform=nom_transform, max_events=nevents, return_metadata=True)
print("Time taken to load", len(train_dataset),"images:", time.perf_counter() - start)

## Sequential (unshuffled) batching: the encoded outputs are later used to index back into
## train_dataset, so the loader order has to match the dataset order
single_loader = torch.utils.data.DataLoader(train_dataset,
                                            collate_fn=solo_ME_collate_fn_with_meta,
                                            batch_size=1024,
                                            shuffle=False,
                                            num_workers=4)

In [None]:
import numpy as np

## Run every image through the encoder and collect per-event outputs (can take a while)
latent = []
nhits  = []
filenames = []
event_ids = []

encoder.eval()

## The metadata-aware loader also yields file names and event IDs, so each latent vector can be traced back to its input file
for orig_bcoords, orig_bfeats, batch_filenames, batch_eventids in single_loader:

    ## Move the batch to the compute device and wrap it as a sparse tensor
    orig_bcoords = orig_bcoords.to(device)
    orig_bfeats = orig_bfeats.to(device)
    orig_batch = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

    ## Forward pass only, so no gradients are needed
    with torch.no_grad():
        encoded_batch = encoder(orig_batch)

    ## One entry per event: number of input feature rows, latent output, and provenance
    nhits.extend(feats.shape[0] for feats in orig_batch.decomposed_features)
    latent.extend(feats.cpu().numpy() for feats in encoded_batch.decomposed_features)
    filenames.extend(batch_filenames)
    event_ids.extend(batch_eventids)

## Stack into (event, latent) arrays
lat_nonorm = np.vstack(latent)
hit_vect = np.array(nhits)

## Scale each latent vector to unit L2 norm
lat_norms = np.linalg.norm(lat_nonorm, axis=1, keepdims=True)
lat_vect = lat_nonorm / lat_norms

In [None]:
# Scatter two of the latent dimensions against each other, coloured by hit count (not particularly useful)
x_coord = 0
y_coord = 1
fig, ax = plt.subplots()
points = ax.scatter(lat_vect[:,x_coord], lat_vect[:,y_coord], s=1, vmin=100, vmax=500, c=hit_vect)
ax.set_xlabel('Latent #'+str(x_coord))
ax.set_ylabel('Latent #'+str(y_coord))
fig.colorbar(points, ax=ax, label='N. hits')
plt.show()

In [None]:
from cuml.manifold import TSNE as cuML_TSNE
import cupy as cp

## Define a function for running t-SNE using the cuml version
def run_tsne_cuml(perp=300, exag=100, input_vect=None, nhits=None):
    """Embed a set of vectors in 2D with cuML's t-SNE and scatter-plot the result.

    Parameters
    ----------
    perp : perplexity passed to cuML's TSNE.
    exag : early exaggeration passed to cuML's TSNE.
    input_vect : array-like with one row per event. When None, the notebook-level
        `lat_vect` is used, looked up when the function is called.
    nhits : per-event values used to colour the points. When None, the
        notebook-level `hit_vect` is used, looked up when the function is called.

    Returns
    -------
    The 2-component embedding returned by `fit_transform`, converted to NumPy.
    """
    ## Resolve the defaults at call time. Binding them in the signature froze whatever arrays
    ## existed when this cell was run and made the definition fail before the encoding cell had run.
    if input_vect is None:
        input_vect = lat_vect
    if nhits is None:
        nhits = hit_vect

    print("Running cuML t-SNE with: perplexity =", perp, "early exaggeration =", exag)

    input_vect = cp.asarray(input_vect, dtype=cp.float32)

    ## I haven't played with most of cuml's t-SNE parameters
    tsne = cuML_TSNE(n_components=2, perplexity=perp, n_iter=1000, early_exaggeration=exag, late_exaggeration=1, metric='cosine', learning_rate=100, n_neighbors=1000)
    tsne_results = tsne.fit_transform(input_vect)

    tsne_results = cp.asnumpy(tsne_results)  # Convert to NumPy for matplotlib

    gr = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], s=0.2, alpha=0.8, vmin=100, vmax=500, c=nhits)
    plt.colorbar(gr, label='N.hits')
    plt.xlabel('t-SNE #0')
    plt.ylabel('t-SNE #1')
    plt.show()

    return tsne_results


In [None]:
## Run t-SNE on the normalised latent vectors with a chosen perplexity / early exaggeration (not always that useful)
perp, exag = 100, 20
tsne_results = run_tsne_cuml(perp=perp, exag=exag, input_vect=lat_vect, nhits=hit_vect)

In [None]:
from cuml.neighbors import NearestNeighbors as cuML_NearestNeighbors

## Make a function to show nearest neighbours (not all that useful)
def run_knn_cuml(lat_vect, k=5):
    """Plot the sorted cosine distance from every point to its k-th neighbour.

    The same vectors are used both to fit the cuML neighbour search and to query it.
    NOTE(review): each query point is therefore presumably returned as its own first
    neighbour - confirm if the exact neighbour rank matters.

    Parameters
    ----------
    lat_vect : array-like with one row per point.
    k : number of neighbours requested; column k-1 of the distance matrix is plotted.

    Returns
    -------
    (kth_distances, indices): the sorted k-th-neighbour distances and the neighbour
    index matrix as NumPy arrays. The indices are in the original row order, not the
    sorted order of the distances.
    """
    # Put the vectors on the GPU as float32
    gpu_vectors = cp.asarray(lat_vect, dtype=cp.float32)

    # Fit the cuML neighbour search and query it with the same points
    knn = cuML_NearestNeighbors(n_neighbors=k, metric='cosine')
    knn.fit(gpu_vectors)
    distances, indices = knn.kneighbors(gpu_vectors)

    # Back on the CPU: take the last requested neighbour's distance and sort ascending
    kth_distances = np.sort(cp.asnumpy(distances)[:, k-1])

    # Draw the sorted-distance curve
    plt.figure(figsize=(10, 6))
    plt.plot(kth_distances)
    plt.title(f'k-NN Distance Plot (k={k})')
    plt.xlabel(f'Points sorted by distance to {k}-th nearest neighbor')
    plt.ylabel('Distance')
    plt.show()

    return kth_distances, cp.asnumpy(indices)

In [None]:
## Actually run knn; bind the returned arrays to names so the cell shows only the plot
## instead of echoing the (distances, indices) tuple as its output
knn_kth_distances, knn_indices = run_knn_cuml(lat_vect, 20)

In [None]:
from cuml.cluster import DBSCAN
from sklearn.preprocessing import normalize

## Run DBSCAN using the cuml implementation
def run_dbscan_gpu(eps=0.1, min_samples=20, input_vect=None):
    """Cluster vectors with cuML's DBSCAN (cosine metric) and print summary statistics.

    Parameters
    ----------
    eps : DBSCAN neighbourhood radius.
    min_samples : DBSCAN minimum neighbourhood size.
    input_vect : array-like with one row per point (required).

    Returns
    -------
    (labels, n_clusters_, n_noise_, n_points, dbscan) where `labels` is the per-point
    cluster label on the CPU (-1 marks noise), `n_clusters_` the number of distinct
    non-noise labels, `n_noise_` the number of noise points, `n_points` an array with
    the population of each cluster (empty when nothing was clustered) and `dbscan`
    the fitted estimator.

    Raises
    ------
    ValueError if `input_vect` is None.
    """
    if input_vect is None:
        raise ValueError("input_vect must be provided.")

    print(f"Running GPU-accelerated DBSCAN with eps={eps}, min_samples={min_samples}")

    # Normalize vectors for cosine similarity (same as CPU version)
    input_vect = normalize(input_vect, norm='l2', axis=1)

    # Move data to GPU using CuPy
    input_vect_gpu = cp.asarray(input_vect)

    # Run DBSCAN on GPU
    # NOTE(review): confirm that `index_type` is accepted by the installed cuML DBSCAN
    dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='cosine', index_type='int64')
    labels = dbscan.fit_predict(input_vect_gpu).get()  # Move result back to CPU

    # Compute cluster statistics
    n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
    n_noise_ = np.sum(labels == -1)
    # bincount of an empty selection is an empty array, so n_points is always an ndarray;
    # the previous `[]` fallback crashed on `.tolist()` below when no cluster was found
    n_points = np.bincount(labels[labels >= 0])

    print(f"Estimated number of clusters: {n_clusters_}")
    print(f"N. points in clusters: {n_points.tolist()}")
    print(f"Estimated number of noise points: {n_noise_} (out of {len(input_vect)})")

    return labels, n_clusters_, n_noise_, n_points, dbscan

In [None]:
## Cluster the normalised latent vectors with an example choice of DBSCAN parameters
eps, min_samples = 0.06, 100
labels, n_clusters_, n_noise_, n_points, dbscan = run_dbscan_gpu(eps=eps, min_samples=min_samples, input_vect=lat_vect)

In [None]:
## Assign a colour to each label and plot the first two latent dimensions per cluster
unique_labels = set(labels)

## Boolean mask flagging the points DBSCAN marked as core samples
core_samples_mask = np.zeros_like(labels, dtype=bool)
core_samples_mask[dbscan.core_sample_indices_.get()] = True

colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))]
for k, col in zip(unique_labels, colors):
    ## Noise (label -1) is always drawn in black
    face = (0, 0, 0, 1) if k == -1 else tuple(col)
    class_member_mask = labels == k

    ## Core samples get large markers, the remaining members of the cluster tiny ones
    for selection, size in ((core_samples_mask, 5), (~core_samples_mask, 0.1)):
        xy = lat_vect[class_member_mask & selection]
        plt.plot(xy[:, 0], xy[:, 1], "o", markerfacecolor=face, markeredgecolor="k", markersize=size)

plt.title(f"Estimated number of clusters: {n_clusters_}")
plt.show()



In [None]:
## Show the tSNE output (assuming it's been made), with the colours from the clustering
from matplotlib.colors import ListedColormap, BoundaryNorm
## One colour bin per integer label, from the smallest label (-1 when there is noise) to the largest.
## The bin edges sit on half-integers so each label falls inside its own bin; with edges starting
## at 0 the noise label fell below the range and was drawn in the same colour as cluster 0.
label_values = np.arange(np.min(labels), np.max(labels) + 1)
cmap = plt.get_cmap('gist_ncar', len(label_values))
norm = BoundaryNorm(np.arange(label_values[0], label_values[-1] + 2) - 0.5, cmap.N)
## Slice the embedding columns directly rather than transposing through zip()
plt.scatter(tsne_results[:, 0], tsne_results[:, 1], s=1, cmap=cmap, norm=norm, alpha=0.8, c=labels)
## Put the colourbar ticks on (a subset of) the integer labels rather than on the half-integer bin edges
plt.colorbar(label='Cluster', ticks=label_values[::max(1, len(label_values) // 10)])
plt.xlabel('t-SNE #0')
plt.ylabel('t-SNE #1')
plt.show()

In [None]:
## Function to show examples for each cluster
def plot_cluster_examples(dataset, labels, index, max_images=10): 
    """Show a row of example images whose cluster label equals `index`.

    Parameters
    ----------
    dataset : indexable dataset whose items unpack to (coords, feats, _, _).
    labels : per-event cluster labels, aligned with `dataset`.
    index : the label to display (e.g. -1 for DBSCAN noise).
    max_images : upper limit on the number of images drawn.
    """
    plt.figure(figsize=(12,4.5))

    ## Dataset positions of the events carrying the requested label
    members = np.where(np.array(labels) == index)[0]

    ## Show at most max_images, fewer if the cluster is smaller than that
    n_show = min(max_images, len(members))

    ## The grid has two rows but only the first n_show slots (the top row) are filled
    for slot, event_idx in enumerate(members[:n_show], start=1):
        ax = plt.subplot(2, n_show, slot)

        numpy_coords, numpy_feats, _, _ = dataset[event_idx]

        # Wrap the single event as a batch of one so it can become a SparseTensor
        batch_coords = ME.utils.batched_coordinates([numpy_coords]).to(device)
        batch_feats = torch.from_numpy(np.concatenate([numpy_feats], 0)).float().to(device)
        sparse_event = ME.SparseTensor(batch_feats, batch_coords, device=device)

        # Densify and bring back to the CPU as a plain array for imshow
        image = make_dense_from_tensor(sparse_event).cpu().squeeze().numpy()

        plt.imshow(image, origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
## Show a strip of example images for every cluster in turn, then one for the noise points
for index in range(n_clusters_):
    cluster_size = n_points[index]
    print("Showing examples for cluster:", index, "which has", cluster_size, "values")
    plot_cluster_examples(train_dataset, labels, index, max_images=15)

## Label -1 is what DBSCAN assigns to noise
print("Showing examples for the noise, which has", n_noise_, "values")
plot_cluster_examples(train_dataset, labels, -1, max_images=12)

In [None]:
## Function to show a big block of examples for each cluster
## index == None will just give an unclustered set
def plot_cluster_bigblock(dataset, labels, index, max_x=10, max_y=10, save_name=None): 
    """Draw a max_x-by-max_y grid of example images for one cluster.

    Parameters
    ----------
    dataset : indexable, sized dataset whose items unpack to (coords, feats, _, _).
    labels : per-event cluster labels, aligned with `dataset` (unused when index is None).
    index : the label to display; None shows the first events of the dataset regardless of label.
    max_x, max_y : number of grid rows and columns.
    save_name : optional path; when given, the figure is also written there.
    """
    plt.figure(figsize=(max_y*2, max_x*1.8*2))

    ## Pick the dataset positions to draw
    if index is not None:
        indices = np.where(np.array(labels) == index)[0]
    else:
        ## Unclustered view: the first events, capped so the grid never indexes past the end of a small dataset
        indices = np.arange(min(max_x*max_y, len(dataset)))
    max_images = min(len(indices), max_x*max_y)
    print(len(indices))
    
    ## Plot
    for i in range(max_images):
        ax = plt.subplot(max_x,max_y,i+1)
        
        numpy_coords, numpy_feats, _, _ = dataset[indices[i]]
    
        # Create batched coordinates for the SparseTensor input
        orig_bcoords  = ME.utils.batched_coordinates([numpy_coords])
        orig_bfeats  = torch.from_numpy(np.concatenate([numpy_feats], 0)).float()

        orig_bcoords = orig_bcoords.to(device)
        orig_bfeats = orig_bfeats.to(device)
        orig = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)
            
        # Densify and bring back to the CPU as a plain array for imshow
        inputs  = make_dense_from_tensor(orig)
        inputs  = inputs .cpu().squeeze().numpy()
        
        plt.imshow(inputs, origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)    
    plt.tight_layout()
    if save_name: plt.savefig(save_name, dpi=150, bbox_inches='tight')
    plt.show()  

In [None]:
## Draw a 10x10 grid of images for cluster 1 (pass save_name, e.g. 'cluster_plots/v9_michel_like.png', to also write it to disk)
plot_cluster_bigblock(train_dataset, labels, 1, max_x=10, max_y=10)

In [None]:
import json
from collections import defaultdict

## Dump out a file including the filenames and indices for the clustered images (for going back to the original files)
def dump_cluster_indices(index_label, cluster_labels, filenames, event_ids):
    """Write the events belonging to one cluster to `cluster_<index_label>_events.json`.

    The JSON maps each input file name to the list of event IDs from that file whose
    cluster label equals `index_label`. The file is written to the current directory.

    Parameters
    ----------
    index_label : the cluster label to select.
    cluster_labels : per-event cluster labels (NumPy array or plain sequence).
    filenames : per-event file names, aligned with `cluster_labels`.
    event_ids : per-event IDs, aligned with `cluster_labels`.
    """
    # Coerce to an array first: comparing a plain list with `==` yields a single False,
    # which silently selected nothing and produced an empty file
    indices = np.where(np.asarray(cluster_labels) == index_label)[0]

    selected_filenames = np.array(filenames)[indices]
    selected_event_ids = np.array(event_ids)[indices]

    # Group by filename
    grouped = defaultdict(list)
    for fname, eid in zip(selected_filenames, selected_event_ids):
        grouped[fname].append(int(eid))  # ensure JSON serializability

    # Save to JSON
    output_file = f'cluster_{index_label}_events.json'
    with open(output_file, 'w') as f:
        json.dump(grouped, f, indent=2)

    print(f"Saved grouped event list for cluster {index_label} to {output_file}")

In [None]:
## dump_cluster_indices(1, labels, filenames, event_ids)