In [None]:
import bootstrap
from torch.utils.tensorboard import SummaryWriter
import MinkowskiEngine as ME
from torch import nn
import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline
# Global matplotlib defaults applied to every figure drawn after this cell.
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

## Tell pytorch we have a GPU if we do
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
# NOTE(review): the statement below builds a torch.device and discards the result;
# `device` above is the object the rest of the notebook actually uses.
torch.device(device)

# Seed the numpy and torch random number generators; the `_=` only swallows the
# return values so they are not echoed as cell output.
SEED=12345
_=np.random.seed(SEED)
_=torch.manual_seed(SEED)
# Module-level TensorBoard writer pointed at "log". run_training() builds its own
# writer from its log_dir argument, so this one is not used by the visible code.
writer = SummaryWriter("log")

In [None]:
## Includes from my libraries for this project                                                                                                                                           
from core.losses.ntxent import NTXentMerged
from core.losses.clustering import ClusteringLossMerged
from core.models.projection_head import ProjectionHeadLogits
from core.models.clustering_head import ClusteringHeadOneLayer, ClusteringHeadTwoLayer
from core.models.encoder import CCEncoderFSD12x4Opt, CCEncoderFSD24x8Opt       

## Import transformations                                                                                                                                                                
from core.data.augmentations_2d import CenterCrop, DoNothing
from datasets.fsd.augmentations_2d import get_transform

## Import dataset                                                                                                                                                                        
from core.data.datasets import paired_2d_dataset_ME, cat_ME_collate_fn
from core.data.datasets import single_2d_dataset_ME, solo_ME_collate_fn

## For later visualization
from core.analysis.image_utils import make_dense_from_tensor

from core.analysis.metrics import argmax_consistency

In [None]:
## Training function
def sharpen(p, temperature=0.5):
    """Raise each entry of ``p`` to ``1 / temperature`` and renormalise along dim 1.

    Args:
        p: Tensor with at least two dimensions; normalisation runs over dim 1.
        temperature: Exponent divisor; values below 1 increase the exponent.

    Returns:
        Tensor of the same shape as ``p`` in which every slice along dim 1 has
        been divided by its own sum after exponentiation.
    """
    exponent = 1.0 / temperature
    powered = torch.pow(p, exponent)
    row_totals = powered.sum(dim=1, keepdim=True)
    return powered / row_totals

def run_training(num_iterations, log_dir, encoder, proj_head, clust_head, temperature, dataloader, optimizer, batch_size, scheduler=None, entropy_weight=1.0):
    """Train the encoder and both heads on the summed instance + cluster loss.

    One "iteration" is a full pass over ``dataloader``. For every batch the
    encoder output is sent through the projection head (scored by
    ``NTXentMerged``) and through the clustering head (scored by
    ``ClusteringLossMerged``, which returns a cluster term and an entropy term).
    The three terms are added and optimised together. Tensors and models are
    moved to the notebook-level ``device``.

    Args:
        num_iterations: Number of passes over ``dataloader``.
        log_dir: TensorBoard log directory; a falsy value disables logging.
        encoder: Called as ``encoder(sparse_batch, this_batch_size)`` and
            expected to return ``(instance_features, cluster_features)``.
        proj_head: Module applied to the instance features.
        clust_head: Module applied to the cluster features.
        temperature: Passed to both loss constructors.
        dataloader: Iterable yielding ``(coords, feats, this_batch_size)``.
        optimizer: Optimizer holding the parameters of all three modules.
        batch_size: Unused; kept so existing callers keep working. The
            per-batch size comes from ``dataloader``.
        scheduler: Optional LR scheduler, stepped once per iteration.
        entropy_weight: Forwarded to ``ClusteringLossMerged``.

    Returns:
        None. Progress is printed and, if ``log_dir`` is set, the average
        training loss per iteration is written to TensorBoard.

    Raises:
        RuntimeError: If ``dataloader`` yields no batches during an iteration.
    """
    # Imported locally: the notebook only imports `time` in a later cell, so the
    # function must not rely on that cell having been executed first.
    import time

    print("Training with", num_iterations, "iterations")
    tstart = time.process_time()

    writer = SummaryWriter(log_dir=log_dir) if log_dir else None

    instance_loss_fn = NTXentMerged(temperature)
    cluster_loss_fn  = ClusteringLossMerged(temperature, entropy_weight=entropy_weight)

    encoder.to(device)
    proj_head.to(device)
    clust_head.to(device)

    try:
        ## Loop over the desired iterations
        for iteration in range(num_iterations):

            total_loss = 0
            total_instance_loss = 0
            total_cluster_loss = 0
            total_entropy_loss = 0
            total_acc = 0
            nbatches = 0

            # Set train mode for the encoder and both heads
            encoder.train()
            proj_head.train()
            clust_head.train()

            # Iterate over the loader that was passed in (not a notebook global)
            for cat_bcoords, cat_bfeats, this_batch_size in dataloader:

                ## Send to the device, then make the sparse tensors
                cat_bcoords = cat_bcoords.to(device, non_blocking=True)
                cat_bfeats  = cat_bfeats.to(device)
                cat_batch   = ME.SparseTensor(cat_bfeats, cat_bcoords, device=device)

                ## Now do the forward pass
                encoded_instance_batch, encoded_cluster_batch = encoder(cat_batch, this_batch_size)
                proj_batch = proj_head(encoded_instance_batch)
                clust_batch = clust_head(encoded_cluster_batch)

                total_acc += argmax_consistency(clust_batch, device).item()

                # Evaluate the three loss terms and combine them
                instance_loss = instance_loss_fn(proj_batch)
                clust_loss, clust_entropy = cluster_loss_fn(clust_batch)
                loss = instance_loss + clust_loss + clust_entropy

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                total_instance_loss += instance_loss.item()
                total_cluster_loss += clust_loss.item()
                total_entropy_loss += clust_entropy.item()
                nbatches += 1

                torch.cuda.empty_cache()

            if nbatches == 0:
                # Previously surfaced as an opaque ZeroDivisionError below.
                raise RuntimeError("dataloader produced no batches; check the dataset size, "
                                   "batch_size and drop_last settings")

            ## See if we have an LR scheduler...
            if scheduler is not None:
                # Only ReduceLROnPlateau takes a metric; other schedulers would
                # misinterpret a positional argument as an epoch index.
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(total_loss)
                else:
                    scheduler.step()

            av_loss = total_loss/nbatches

            if writer is not None: writer.add_scalar('loss/train', av_loss, iteration)
            print("Processed", iteration, "/", num_iterations, "; loss =", av_loss, "(", total_instance_loss/nbatches, \
                  total_cluster_loss/nbatches, total_entropy_loss/nbatches,");", "acc =", total_acc/nbatches)
            print("Time taken:", time.process_time() - tstart)

            ## End so empty cache because MinkowskiEngine can't be trusted
            torch.cuda.empty_cache()
    finally:
        # Flush and release the event file even if training is interrupted.
        if writer is not None:
            writer.close()

In [None]:
import time
from torch.utils.data import ConcatDataset
# NOTE(review): `start` is not read anywhere in the visible notebook.
start = time.process_time()

## Get a set of augmentations to use in training
aug_transform = get_transform('fsd', 'baseaug')

## Make a mixed dataset of data and simulation
data_dir = "/pscratch/sd/c/cwilk/FSD/DATA"
sim_dir = "/pscratch/sd/c/cwilk/FSD/SIMULATION"
# Both datasets pair an un-transformed view (DoNothing) with an augmented view.
data_dataset = paired_2d_dataset_ME(data_dir, nom_transform=DoNothing(), aug_transform=aug_transform, max_events=100000)
sim_dataset  = paired_2d_dataset_ME(sim_dir, nom_transform=DoNothing(), aug_transform=aug_transform, max_events=50000)
# With the ConcatDataset commented out, training uses the data sample only; the
# simulation sample is still loaded above and counted in the print below.
mixed_dataset = data_dataset #ConcatDataset([data_dataset, sim_dataset])
print("Loaded:", data_dataset.__len__(), "data and", sim_dataset.__len__(), "simulated events")

# `train_loader` and `batch_size` are read by later cells (the training call).
# drop_last=True discards a final partial batch.
batch_size=1024
train_loader = torch.utils.data.DataLoader(mixed_dataset,
                                           collate_fn=cat_ME_collate_fn,
                                           batch_size=batch_size,
                                           shuffle=True, 
                                           num_workers=8,
                                           drop_last=True,
                                           pin_memory=False,
                                           prefetch_factor=1)

In [None]:
## This is a useful but experimental pytorch function which flags where synchronization calls are made
## (useful for debugging only)
## NOTE(review): no such call is present in this cell any more; the comment above looks stale.
from torch import optim

## Various config parameters
## NOTE(review): `dropout` and `latent_act_fn` are not used in the visible code, and
## `flatten` / `sep_heads` are not forwarded to the encoder below (see the note there).
nchan=24
nhidden=256
nlatent=64
enc_act_fn=ME.MinkowskiSiLU
hidden_act_fn=nn.SiLU
latent_act_fn=nn.Tanh
dropout = 0
temperature = 0.5
num_iterations=50
log_dir="log_jupyter"
nclusters = 50
first_kernel = 7
flatten = 1
pool = "max"
slow_growth = 1
sep_heads = 0
softmax_temp=1.0
entropy_weight=0.1

## Define the models
## NOTE(review): the encoder is built with literal flatten=True and sep_heads=True, so the
## `flatten` and `sep_heads` values above are ignored; sep_heads=0 above disagrees with the
## True passed here. Confirm which value is intended.
encoder=CCEncoderFSD24x8Opt(nchan, \
                  act_fn=enc_act_fn, \
                  first_kernel=first_kernel, \
                  flatten=True, \
                  pool=pool, \
                  slow_growth=bool(slow_growth),
                  sep_heads=True)
# Head input widths come from the encoder so they follow its configuration.
proj_head = ProjectionHeadLogits(encoder.get_nchan_instance(), nlatent, nhidden, hidden_act_fn)
clust_head = ClusteringHeadTwoLayer(encoder.get_nchan_cluster(), nclusters, softmax_temp)

## Load in the pre-calculated model weights if they exist
## Only the encoder state is restored; both heads keep their fresh initialisation.
chk_file=None 
if chk_file:
    checkpoint = torch.load(chk_file, map_location='cpu')
    encoder.load_state_dict(checkpoint['encoder_state_dict'])

encoder.to(device)
proj_head.to(device)
clust_head.to(device)

# One parameter group per module, all sharing the optimizer-level lr / weight_decay.
params_to_optimize = [
        {'params': encoder.parameters()},
        {'params': proj_head.parameters()},
        {'params': clust_head.parameters()},
    ]

lr=5e-4
weight_decay=0 #1e-5
optimizer = torch.optim.AdamW(params_to_optimize, lr=lr, weight_decay=weight_decay)

## Scheduler options
## Training currently runs without a scheduler; the alternatives below are disabled.
scheduler = None 
#scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
#                                                 mode='min',
#                                                 factor=0.2,
#                                                 patience=1,
#                                                 cooldown=2,
#                                                 threshold=5e-3,
#                                                 threshold_mode='rel')
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=num_iterations, cycle_momentum=False)

# Depends on `train_loader` and `batch_size` from the dataset cell and on `run_training`.
run_training(num_iterations, log_dir, encoder, proj_head, clust_head, temperature, train_loader, optimizer, batch_size, scheduler, entropy_weight=entropy_weight)

In [None]:
import numpy as np
from core.data.datasets import single_2d_dataset_ME, solo_ME_collate_fn
from core.data.augmentations_2d import MaxRegionCrop, FirstRegionCrop

## Make a single loader to loop over for ease
max_transform = FirstRegionCrop((800, 256), (768, 256))
single_sim_dataset = single_2d_dataset_ME(sim_dir, transform=max_transform, max_events=100000) #nevents)
single_data_dataset = single_2d_dataset_ME(data_dir, transform=max_transform, max_events=100000) #nevents)

## Data events come first, then simulation; shuffle=False below keeps that order so the
## row order of the output arrays matches the indices of single_mixed_dataset.
single_mixed_dataset = ConcatDataset([single_data_dataset, single_sim_dataset])
#single_mixed_dataset = single_sim_dataset

single_loader = torch.utils.data.DataLoader(single_mixed_dataset,
                                            collate_fn=solo_ME_collate_fn,
                                            batch_size=512,
                                            shuffle=False,
                                            num_workers=4)

cluster = []
latent = []
nhits  = []
labels = []

## Switch to inference mode once, up front, rather than on every batch
encoder.eval()
clust_head.eval()
proj_head.eval()

for orig_bcoords, orig_bfeats, blabels in single_loader:

    ## Number of events in this batch. Deliberately NOT called `batch_size`, so the
    ## training batch size defined in an earlier cell is not overwritten.
    n_events = len(blabels)

    orig_bcoords = orig_bcoords.to(device)
    orig_bfeats = orig_bfeats.to(device)
    orig_batch = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

    ## Now do the forward passes
    with torch.no_grad():
        encoded_instance_batch, encoded_cluster_batch = encoder(orig_batch, n_events)
        proj_batch = proj_head(encoded_instance_batch)
        clust_batch = clust_head(encoded_cluster_batch)

    ## Accumulate one entry per event
    nhits.extend(feats.shape[0] for feats in orig_batch.decomposed_features)
    cluster.extend(row[np.newaxis, :] for row in clust_batch.detach().cpu().numpy())
    latent.extend(row[np.newaxis, :] for row in proj_batch.detach().cpu().numpy())
    labels.extend(blabels)

## Stack the per-event rows into (n_events_total, dim) arrays
latent_vect = np.vstack(latent)
cluster_vect = np.vstack(cluster)
hit_vect = np.array(nhits)
label_vect = np.array(labels)

In [None]:
# Hard cluster assignment: column index of the largest clustering-head output per event
clust_index = cluster_vect.argmax(axis=1)

In [None]:
# Largest clustering-head output for each event
max_values = np.max(cluster_vect, axis=1)

# Plot histogram; the title reports the actual number of columns rather than a
# hard-coded count that goes stale when the number of clusters changes
plt.hist(max_values, bins=50, edgecolor='black')
plt.xlabel('Maximum value (per row)')
plt.ylabel('Count')
plt.title('Distribution of maximum values across {} features'.format(cluster_vect.shape[1]))
plt.grid(True)
plt.show()

In [None]:
# Column index of the largest clustering-head output for each event
max_indices = cluster_vect.argmax(axis=1)

# Bin edges sit on half-integers so that every cluster id gets its own bar
plt.hist(max_indices, bins=np.arange(nclusters+1)-0.5, edgecolor='black')
plt.gca().set(xlabel='Index of max value',
              ylabel='Count',
              title='Distribution of max indices')
plt.grid(True)
plt.show()

In [None]:
# Force reload so I can play with changes outside jupyter...
# NOTE(review): the other project imports in this notebook use the `core.` prefix
# (e.g. core.analysis.image_utils); confirm `analysis.plotting_utils` is the intended path.
import importlib
import analysis.plotting_utils
importlib.reload(analysis.plotting_utils)
# Re-import after the reload so these names point at the freshly loaded functions.
from analysis.plotting_utils import compute_cluster_overlap, plot_overlap_matrix

# The meaning of the second argument (4) and of max_val is defined in
# analysis.plotting_utils, which is not visible from this notebook.
overlap_matrix = compute_cluster_overlap(cluster_vect, 4)
plot_overlap_matrix(overlap_matrix, max_val=0.2)

In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def __compute_cluster_centroids(embeddings, assignments, n_clusters):
    """Mean embedding for each cluster id in ``range(n_clusters)``.

    Args:
        embeddings: (N, D) array of per-event vectors.
        assignments: Length-N array of integer cluster ids.
        n_clusters: Number of rows in the result; row ``k`` belongs to id ``k``.

    Returns:
        (n_clusters, D) float32 array. Rows for ids that have no members are
        filled with NaN.
    """
    n_features = embeddings.shape[1]
    # Start from all-NaN so clusters without members need no special handling.
    centroids = np.full((n_clusters, n_features), np.nan, dtype=np.float32)

    for cluster_id in range(n_clusters):
        member_mask = (assignments == cluster_id)
        if not member_mask.any():
            continue
        centroids[cluster_id] = embeddings[member_mask].mean(axis=0)

    return centroids

def compute_cluster_centroids(X, labels):
    """Mean row of ``X`` for every distinct value in ``labels``.

    Args:
        X: (N, D) array of per-event vectors.
        labels: Length-N array of cluster ids.

    Returns:
        (K, D) array with one row per distinct label, ordered as returned by
        ``np.unique(labels)``. Labels that never occur get no row, so the row
        number is not necessarily the cluster id.
    """
    per_cluster_means = [X[labels == cluster_id].mean(axis=0)
                         for cluster_id in np.unique(labels)]
    return np.vstack(per_cluster_means)
    
def compute_centroid_similarity(centroids, metric="cosine"):
    """
    Pairwise comparison of cluster centroids.

    Rows of ``centroids`` containing any NaN are dropped first, so the result
    can be smaller than (K, K). The shape of the retained rows is printed.

    Args:
        centroids: (K, D) array
        metric: "cosine" for cosine similarity, or "euclidean" for negated
            Euclidean distance (negated so that larger always means closer)

    Returns:
        sim_matrix: square array over the retained centroids

    Raises:
        ValueError: if ``metric`` is neither "cosine" nor "euclidean"
    """
    has_nan = np.isnan(centroids).any(axis=1)
    usable = centroids[~has_nan]

    print(usable.shape)

    if metric == "euclidean":
        from sklearn.metrics import pairwise_distances
        return -pairwise_distances(usable, metric="euclidean")
    if metric == "cosine":
        return cosine_similarity(usable)
    raise ValueError("metric must be 'cosine' or 'euclidean'")

In [None]:
# Centroids here are mean clustering-head outputs per assigned cluster (the inputs are
# cluster_vect and its argmax), one row per cluster id that actually occurs.
centroids = compute_cluster_centroids(cluster_vect, clust_index) #, 30)
# NOTE(review): sim_matrix is computed but not plotted in this cell; the image below
# shows the centroids themselves and the overlap-matrix call is commented out.
sim_matrix = compute_centroid_similarity(centroids)
plt.imshow(centroids, cmap="viridis", vmin=-0.2, vmax=0.2) 
#plot_overlap_matrix(centroids) #sim_matrix)

In [None]:
## Make a plot of what it looks like
## Only the first two projection-head dimensions are drawn; points are coloured by cluster index.
plt.scatter(latent_vect[:,0], latent_vect[:,1], s=1, c=clust_index)

In [None]:
# Cross-tabulate the cluster assignment (first axis of H) against the dataset label
# (second axis). One unit-width bin per cluster id, so no cluster is dropped.
# NOTE(review): 20 label bins is carried over from the original cell; confirm it
# covers every label value in label_vect.
n_label_bins = 20
H, xedges, yedges = np.histogram2d(clust_index, label_vect,
                                   bins=[nclusters, n_label_bins],
                                   range=[[0, nclusters], [0, n_label_bins]])

# Plot it. After the transpose, image rows are labels (vertical axis) and image
# columns are cluster indices (horizontal axis), so the axis titles follow that.
plt.imshow(H.T, origin='lower', aspect='auto', cmap='viridis')
plt.colorbar(label='Count')
plt.xlabel('clust_index')
plt.ylabel('label')
#plt.xticks(np.arange(nclusters))
#plt.yticks(np.arange(n_label_bins))
plt.show()

In [None]:
from cuml.manifold import TSNE as cuML_TSNE
import cupy as cp

## Define a function for running t-SNE using the cuml version
def run_tsne_cuml(perp=300, exag=100, input_vect=latent_vect, nhits=clust_index):
    """Embed ``input_vect`` in 2D with cuML t-SNE and draw a scatter plot.

    The defaults for ``input_vect`` and ``nhits`` are bound when this cell is
    executed, so they refer to whatever ``latent_vect`` / ``clust_index`` were
    at that moment.

    Args:
        perp: t-SNE perplexity; ``n_neighbors`` is set to four times this.
        exag: Early exaggeration factor.
        input_vect: Array of vectors to embed; copied to the GPU as float32.
        nhits: Per-point values used to colour the scatter plot. The colour
            bar is labelled 'N.hits' whatever is passed.

    Returns:
        The 2D embedding as a NumPy array.
    """
    print("Running cuML t-SNE with: perplexity =", perp, "early exaggeration =", exag)

    gpu_input = cp.asarray(input_vect, dtype=cp.float32)

    ## I haven't played with most of cuml's t-SNE parameters
    tsne_options = dict(n_components=2,
                        perplexity=perp,
                        n_neighbors=4*perp,
                        n_iter=1000,
                        early_exaggeration=exag,
                        late_exaggeration=1,
                        metric='cosine')
    model = cuML_TSNE(**tsne_options)

    # Bring the result back to host memory so matplotlib can draw it
    embedding = cp.asnumpy(model.fit_transform(gpu_input))

    points = plt.scatter(embedding[:, 0], embedding[:, 1], s=0.2, alpha=0.8, c=nhits)
    plt.colorbar(points, label='N.hits')
    plt.xlabel('t-SNE #0')
    plt.ylabel('t-SNE #1')
    plt.show()

    return embedding

In [None]:
## Actually run t-SNE
## Points are coloured by cluster index here (passed as the `nhits` argument), although
## the colour bar drawn by run_tsne_cuml is labelled 'N.hits'.
perp=30
exag=10
tsne_results = run_tsne_cuml(perp, exag, latent_vect, clust_index)

In [None]:
## Function to show examples for each cluster
def plot_cluster_examples(dataset, labels, index, max_images=10):
    """Show up to ``max_images`` events whose entry in ``labels`` equals ``index``.

    Args:
        dataset: Indexable dataset returning ``(coords, feats, label)`` per event.
        labels: One value per event, aligned with ``dataset`` indices.
        index: Value of ``labels`` to select.
        max_images: Upper limit on the number of events drawn.

    Events are drawn in dataset order in the first row of a two-row grid; the
    second row is left empty. Each event is turned into a single-event sparse
    tensor on the notebook-level ``device`` and densified before plotting.
    """
    plt.figure(figsize=(15,9))

    ## Dataset positions of the events carrying the requested label
    members = np.nonzero(np.array(labels) == index)[0]

    ## Never ask for more panels than there are matching events
    n_shown = min(max_images, len(members))

    ## Plot
    for panel, event_idx in zip(range(n_shown), members):
        ax = plt.subplot(2, n_shown, panel + 1)

        event_coords, event_feats, _ = dataset[event_idx]

        # Wrap the single event as a batch of one for the SparseTensor input
        batched_coords = ME.utils.batched_coordinates([event_coords]).to(device)
        batched_feats = torch.from_numpy(np.concatenate([event_feats], 0)).float().to(device)
        sparse_event = ME.SparseTensor(batched_feats, batched_coords, device=device)

        dense_image = make_dense_from_tensor(sparse_event, 0, 768, 256)
        dense_image = dense_image.cpu().squeeze().numpy()

        plt.imshow(dense_image, origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
## Now pull out a bank of example images for each cluster
## clust_index was computed from single_mixed_dataset in loader order (shuffle=False),
## so its positions can be used directly as dataset indices.

for index in range(nclusters):
    print("Showing examples for cluster:", index, "which has", np.count_nonzero(clust_index==index), "values")
    plot_cluster_examples(single_mixed_dataset, clust_index, index)

In [None]:
## Function to show a big block of examples for each cluster
## index == None will just give an unclustered set
def plot_cluster_bigblock(dataset, clust_index, index, max_x=10, max_y=10, save_name=None):
    """Draw a ``max_x`` x ``max_y`` grid of events from one cluster.

    Args:
        dataset: Indexable, sized dataset returning ``(coords, feats, label)``.
        clust_index: One cluster id per event, aligned with ``dataset`` indices.
        index: Cluster id to show. ``None`` shows the first events of the
            dataset regardless of cluster.
        max_x: Number of grid rows.
        max_y: Number of grid columns.
        save_name: If given, the figure is also written to this path.

    The number of candidate events is printed before plotting. Each event is
    turned into a single-event sparse tensor on the notebook-level ``device``
    and densified before being drawn.
    """
    n_cells = max_x * max_y
    plt.figure(figsize=(max_y*2, max_x*1.8*2))

    ## Work out which dataset entries to draw
    if index is not None:
        indices = np.where(np.array(clust_index) == index)[0]
    else:
        # Unclustered mode: never index past the end of a small dataset
        indices = np.arange(min(n_cells, len(dataset)))
    max_images = min(len(indices), n_cells)
    print(len(indices))

    ## Plot
    for i in range(max_images):
        ax = plt.subplot(max_x, max_y, i+1)

        numpy_coords, numpy_feats, _ = dataset[indices[i]]

        # Create batched coordinates for the SparseTensor input
        orig_bcoords = ME.utils.batched_coordinates([numpy_coords])
        orig_bfeats = torch.from_numpy(np.concatenate([numpy_feats], 0)).float()

        orig_bcoords = orig_bcoords.to(device)
        orig_bfeats = orig_bfeats.to(device)
        orig = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

        inputs = make_dense_from_tensor(orig, 0, 768, 256)
        inputs = inputs.cpu().squeeze().numpy()

        plt.imshow(inputs, origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.tight_layout()
    if save_name: plt.savefig(save_name, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
## Dump out a large block of images for one cluster
## Arguments: cluster id 8, on a 10 x 10 grid; the save path is left commented out.
plot_cluster_bigblock(single_mixed_dataset, clust_index, 8, 10, 10) #, 'cluster_plots/v9_michel_like.png')