In [None]:
import MinkowskiEngine as ME
import matplotlib.pyplot as plt
import matplotlib as mpl
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F

%matplotlib inline
## Default figure style for every plot in the notebook
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

import torch
## Use the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import numpy as np

## Seed numpy's and torch's global RNGs for reproducibility.
## sklearn estimators called later with random_state=None draw from numpy's global RNG,
## so their results also depend on this seed (and on how much of the stream was used before).
SEED = 12345
np.random.seed(SEED)
_ = torch.manual_seed(SEED)  # assigned so the notebook does not echo the returned Generator

In [None]:
## Includes from my libraries for this project                                                                                                                                           
from ME_NN_libs import AsymmetricL2LossME, EuclideanDistLoss
from ME_NN_libs import EncoderME, DecoderME, DeepEncoderME, DeepDecoderME, DeeperEncoderME, DeeperDecoderME
from ME_NN_libs import ProjectionHead
from ME_dataset_libs import CenterCrop, MaxRegionCrop, RandomCrop, RandomHorizontalFlip, RandomRotation2D, RandomShear2D, RandomBlockZero, ConstantCharge
from ME_dataset_libs import SingleModuleImage2D_MultiHDF5_ME, triple_ME_collate_fn
from ME_dataset_libs import SingleModuleImage2D_solo_ME, solo_ME_collate_fn
from ME_dataset_libs import make_dense, make_dense_from_tensor

In [None]:
## Model hyper-parameters. Presumably these must agree with the checkpoint below,
## otherwise load_state_dict will reject mismatched parameter shapes.
nchan = 32
nlatent = 128
temp = 0.5
hidden_act_fn = ME.MinkowskiSiLU
latent_act_fn = ME.MinkowskiTanh
dropout = 0

final_layer = 128

## Build the three networks
encoder = EncoderME(nchan, nlatent, hidden_act_fn, latent_act_fn, dropout)
decoder = DecoderME(nchan, nlatent, hidden_act_fn)
project = ProjectionHead([nlatent] * 3 + [final_layer], latent_act_fn)

## Checkpoint holding the pre-trained weights of all three networks
# Older run: f"/pscratch/sd/c/cwilk/state_lat{nlatent}_nchan{nchan}_5e-6_PROJECT_TEMP{temp}_onecycle_smallmodFIXCROP_2M_ME.pth"
chk_file = (f"/pscratch/sd/c/cwilk/state_lat{nlatent}_nchan{nchan}_5e-6_PROJECT_NTXentMerged{temp}"
            f"_onecycle_unitcharge_projsize{final_layer}_2M_ME.pth")

## Input data location and how many events to read
inDir = "/pscratch/sd/c/cwilk/h5_inputs/"
nevents = 50000

## Restore the weights (tensors mapped onto `device`), then switch to inference mode
checkpoint = torch.load(chk_file, map_location=device)
state_keys = (
    (encoder, 'encoder_state_dict'),
    (decoder, 'decoder_state_dict'),
    (project, 'project_state_dict'),
)
for net, key in state_keys:
    net.load_state_dict(checkpoint[key])
for net, _key in state_keys:
    net.eval()

for net in (encoder, decoder):
    net.to(device)
project.to(device)

In [None]:
import time

## Load the events with the nominal transforms (MaxRegionCrop + ConstantCharge) applied
load_start = time.perf_counter()
nom_transform=transforms.Compose([
    MaxRegionCrop(),
    ConstantCharge()
])
train_dataset = SingleModuleImage2D_solo_ME(inDir, transform=nom_transform, max_events=nevents)
# perf_counter measures elapsed wall time; process_time only counts this process's CPU time and
# so leaves out time spent blocked (e.g. waiting on disk reads)
print("Time taken to load", len(train_dataset),"images:", time.perf_counter() - load_start)

## Sequential batching (shuffle=False): later cells map rows of the collected latent/projection
## arrays (and hence cluster labels) back to train_dataset by position, so order must be kept.
single_loader = torch.utils.data.DataLoader(train_dataset,
                                            collate_fn=solo_ME_collate_fn,
                                            batch_size=512,
                                            shuffle=False,
                                            num_workers=4)

In [None]:
## Make a few example plots for comparison
def plot_ae_outputs(encoder,decoder,loader,n=10, start=0, save_name=None):
    """Plot `n` input images (top row) above their autoencoder reconstructions (bottom row).

    One image is taken per batch: the first `start` batches of `loader` are skipped, then
    element 0 of each of the next `n` batches is encoded, decoded and densified for display.

    Parameters
    ----------
    encoder, decoder : torch.nn.Module
        Networks to run; both are switched to eval mode.
    loader : iterable
        Yields ``(batched_coords, batched_feats)`` pairs (presumably from ``solo_ME_collate_fn``;
        confirm for other loaders).
    n : int
        Number of examples (columns) to plot.
    start : int
        Number of batches to skip before plotting.
    save_name : str or None
        If given, the figure is also written to this path; its parent directory is created
        if it does not exist.

    Raises
    ------
    StopIteration
        If `loader` runs out of batches before `n` examples have been drawn.
    """
    import itertools
    import os

    # islice discards the first `start` batches (they are still loaded, just not used)
    loader_iter = itertools.islice(iter(loader), start, None)
    plt.figure(figsize=(12,6))

    encoder.eval()
    decoder.eval()

    ## Two rows are drawn (inputs, reconstructions), so use a 2 x n grid; the previous
    ## 3 x n grid left the bottom third of the figure empty
    n_rows = 2

    ## Loop over columns
    for i in range(n):
        ax = plt.subplot(n_rows, n, i + 1)

        orig_bcoords, orig_bfeats = next(loader_iter)

        with torch.no_grad():
            orig_bcoords = orig_bcoords.to(device)
            orig_bfeats  = orig_bfeats.to(device)
            orig = ME.SparseTensor(orig_bfeats.float(), orig_bcoords.int(), device=device)

            enc_orig = encoder(orig)
            rec_orig = decoder(enc_orig)

        inputs  = make_dense_from_tensor(orig)
        outputs = make_dense_from_tensor(rec_orig)

        # Only the first event of each batch is displayed
        this_input  = inputs[0].cpu().squeeze().numpy()
        this_output = outputs[0].cpu().squeeze().numpy()

        ## Input row
        plt.imshow(this_input, cmap='viridis', origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2: ax.set_title('Original images')

        ## Reconstructed row
        ax = plt.subplot(n_rows, n, i + 1 + n)
        plt.imshow(this_output, cmap='viridis', origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2: ax.set_title('Reconstructed images')

    if save_name:
        # savefig does not create missing directories (e.g. 'cluster_plots/')
        save_dir = os.path.dirname(save_name)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_name, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
## 15 input/reconstruction pairs, one per batch, starting after the first 30 batches
plot_ae_outputs(encoder, decoder, single_loader,
                n=15, start=30,
                save_name='cluster_plots/AE_reco_example.png')

In [None]:
import numpy as np

## Run the encoder and projection head over every event and collect, per event:
##   latent -> encoder output, proj -> projection-head output, nhits -> number of input hits.
## single_loader is unshuffled, so row i of the resulting arrays belongs to train_dataset[i].
latent = []
proj   = []
nhits  = []

encoder.eval()
project.eval()
with torch.no_grad():
    for orig_bcoords, orig_bfeats in single_loader:
        orig_bcoords = orig_bcoords.to(device)
        orig_bfeats = orig_bfeats.to(device)
        # Same dtype handling as in plot_ae_outputs
        orig_batch = ME.SparseTensor(orig_bfeats.float(), orig_bcoords.int(), device=device)

        encoded_batch = encoder(orig_batch)
        project_batch = project(encoded_batch)

        # No filtering on empty events: filtering only nhits (as before) but not latent/proj
        # would silently misalign hit_vect with lat_vect/proj_vect.
        nhits += [feats.shape[0] for feats in orig_batch.decomposed_features]
        latent += [x.cpu().numpy() for x in encoded_batch.decomposed_features]
        proj += [x.cpu().numpy() for x in project_batch.decomposed_features]

lat_vect = np.vstack(latent)
hit_vect = np.array(nhits)
proj_nonorm = np.vstack(proj)

# Every later plot colours lat_vect/proj_vect rows by hit_vect, so they must line up 1:1
assert len(hit_vect) == len(lat_vect) == len(proj_nonorm), (
    "per-event arrays are misaligned: %d hits, %d latents, %d projections"
    % (len(hit_vect), len(lat_vect), len(proj_nonorm)))

## Unit-normalise the projections (cosine geometry is used by t-SNE / k-NN / DBSCAN below)
proj_vect = proj_nonorm / np.linalg.norm(proj_nonorm, axis=1, keepdims=True)

In [None]:
## Sanity check: every row of proj_vect was unit-normalised, so all norms should be ~1
norms = np.linalg.norm(proj_vect, axis=1)
print(norms)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(norms, marker='o', linestyle='', markersize=2, alpha=0.7)
ax.set_xlabel("Index", fontsize=12)
ax.set_ylabel("Norm", fontsize=12)
ax.grid(alpha=0.4)
plt.show()

In [None]:
## Two chosen latent dimensions, each event coloured by its number of hits
x_coord, y_coord = 0, 1
points = plt.scatter(lat_vect[:, x_coord], lat_vect[:, y_coord], c=hit_vect, s=1, vmin=100, vmax=500)
plt.xlabel(f'Latent #{x_coord}')
plt.ylabel(f'Latent #{y_coord}')
plt.colorbar(points, label='N. hits')
plt.show()

In [None]:
## Two chosen (normalised) projection dimensions, each event coloured by its number of hits
x_coord, y_coord = 0, 1
points = plt.scatter(proj_vect[:, x_coord], proj_vect[:, y_coord], c=hit_vect, s=1, vmin=100, vmax=500)
plt.xlabel(f'Proj. #{x_coord}')
plt.ylabel(f'Proj. #{y_coord}')
plt.colorbar(points, label='N. hits')
plt.show()

In [None]:
## Embedding fed to t-SNE / k-NN / clustering below (swap in lat_vect to use the raw encoder latents)
which_vect = proj_vect

In [None]:
## Now TSNE it up
from sklearn.manifold import TSNE

## Baseline cosine-metric t-SNE of which_vect; later plotting cells reuse tsne_results
perp, exag = 200, 50
print("Perplexity =", perp, "early exaggeration =", exag)
tsne = TSNE(
    n_components=2,
    perplexity=perp,
    early_exaggeration=exag,
    max_iter=1000,
    metric='cosine',
)
tsne_results = tsne.fit_transform(which_vect)



In [None]:
## Now TSNE it up
from sklearn.manifold import TSNE

def run_tsne(perp=300, exag=100, vect=None, colour=None, random_state=None):
    """Fit a 2-D cosine-metric t-SNE embedding, scatter-plot it, and return the coordinates.

    Parameters
    ----------
    perp : float
        t-SNE perplexity.
    exag : float
        t-SNE early exaggeration.
    vect : array-like of shape (n_samples, n_features), optional
        Vectors to embed; defaults to the notebook-global ``which_vect``.
    colour : array-like of length n_samples, optional
        Per-point colour values; defaults to the notebook-global ``nhits``.
    random_state : int or None
        Forwarded to ``TSNE``. ``None`` (the previous behaviour) draws from numpy's global
        RNG, so results then depend on how much of that stream earlier cells consumed.

    Returns
    -------
    The ``(n_samples, 2)`` array returned by ``TSNE.fit_transform``.
    """
    if vect is None:
        vect = which_vect
    if colour is None:
        colour = nhits

    print("Running t-SNE with: perplexity =", perp, "early exaggeration =", exag)
    tsne = TSNE(n_components=2, perplexity=perp, max_iter=1000, early_exaggeration=exag,
                metric='cosine', random_state=random_state)
    tsne_results = tsne.fit_transform(vect)

    # Column slicing instead of transposing via list(zip(*...)), which builds a tuple per point
    gr = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], s=1, alpha=0.8, vmin=100, vmax=500, c=colour)
    plt.colorbar(gr, label='N.hits')
    plt.xlabel('t-SNE #0')
    plt.ylabel('t-SNE #1')
    plt.show()

    return tsne_results

In [None]:
## Run t-SNE and KEEP the embedding: later cells plot `tsne_results`, which previously still held
## the output of the earlier standalone t-SNE cell because this return value was discarded.
perp=200
exag=50
tsne_results = run_tsne(perp, exag)

In [None]:
## Re-draw the current t-SNE embedding coloured by number of hits
tsne_x = tsne_results[:, 0]
tsne_y = tsne_results[:, 1]
gr = plt.scatter(tsne_x, tsne_y, c=nhits, s=1, alpha=0.8, vmin=100, vmax=500)
plt.colorbar(gr, label='N.hits')
plt.xlabel('t-SNE #0')
plt.ylabel('t-SNE #1')
plt.show()

In [None]:
## Try k-NN algorithm
from sklearn.neighbors import NearestNeighbors

# k-distance curve under the cosine metric, for choosing a DBSCAN eps.
# kneighbors() is queried on the fitted data itself, so each point's first neighbour is itself
# (distance 0); this matches DBSCAN, whose min_samples also counts the point itself.
k = 5  # set equal to DBSCAN's min_samples
neighbors = NearestNeighbors(n_neighbors=k, metric='cosine')
neighbors_fit = neighbors.fit(which_vect)
distances, indices = neighbors_fit.kneighbors(which_vect)

# Sorting just column k-1 gives the same curve as sorting every column and then taking column k-1
distances = np.sort(distances[:, k - 1])

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(distances)
ax.set_title('k-NN Distance Plot')
ax.set_xlabel('Points sorted by distance to {}-th nearest neighbor'.format(k))
ax.set_ylabel('Distance')
plt.show()

In [None]:
from sklearn.cluster import DBSCAN
## 0.1, 25 kind of works
## 0.1, 20 works with the 0.5 unitcharge one

def run_dbscan(eps=0.1, min_samples=20, vect=None):
    """Cluster embedding vectors with cosine-metric DBSCAN and print a summary.

    Parameters
    ----------
    eps : float
        DBSCAN neighbourhood radius (cosine distance).
    min_samples : int
        Neighbourhood size, including the point itself, needed for a core point.
    vect : array-like of shape (n_samples, n_features), optional
        Vectors to cluster; defaults to the notebook-global ``which_vect``.

    Returns
    -------
    (clusters, labels, n_clusters_, n_noise_, n_points, dbscan)
        ``clusters`` is the fitted estimator (``DBSCAN.fit`` returns the estimator itself, so it
        is the same object as ``dbscan``); ``labels`` has one label per row, -1 for noise;
        ``n_points[i]`` is the number of members of cluster ``i``.
    """
    if vect is None:
        vect = which_vect

    print("Running DBSCAN with eps =", eps, "; min_samples =", min_samples)
    dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='cosine')
    clusters = dbscan.fit(vect)
    labels = clusters.labels_

    # Clusters are labelled 0..max; -1 is noise
    n_clusters_ = max(int(labels.max()) + 1, 0) if labels.size else 0
    n_noise_ = int(np.count_nonzero(labels == -1))
    # One bincount pass instead of a full list.count() scan per cluster
    n_points = np.bincount(labels[labels >= 0], minlength=n_clusters_).tolist()

    print("Estimated number of clusters: %d" % n_clusters_)
    print("N. points in clusters:", n_points)
    print("Estimated number of noise points: %d" % n_noise_)
    print("(Out of a total of %d images)" % len(vect))
    return clusters, labels, n_clusters_, n_noise_, n_points, dbscan
## Default clustering used by the rest of the notebook: DBSCAN at eps = 0.1
(clusters, labels, n_clusters_, n_noise_, n_points, dbscan) = run_dbscan(eps=0.1)

In [None]:
## Sweep eps and compare the printed summaries. run_dbscan returns six values, so the old
## five-name unpacking raised ValueError; results are now discarded so the clustering chosen
## in the cell above (labels, dbscan, ...) stays consistent.
for eps in [0.07, 0.08, 0.09, 0.1, 0.11, 0.12]:
    run_dbscan(eps)

In [None]:
from sklearn.cluster import HDBSCAN

## HDBSCAN alternative. Running this cell overwrites clusters/labels/n_clusters_/n_noise_/n_points,
## so the plotting cells below then show HDBSCAN results.
hdbscan = HDBSCAN(min_cluster_size=25, min_samples=5, cluster_selection_epsilon=0.05, metric='cosine')
clusters = hdbscan.fit(which_vect)

labels = clusters.labels_

# Clusters are labelled 0..max. HDBSCAN also uses negative codes besides -1 (-2 infinite, -3 missing
# data), which the old len(set(labels)) count would have treated as clusters.
n_clusters_ = max(int(labels.max()) + 1, 0) if labels.size else 0
n_noise_ = int(np.count_nonzero(labels == -1))
# One bincount pass instead of a full list.count() scan per cluster
n_points = np.bincount(labels[labels >= 0], minlength=n_clusters_).tolist()

print("Estimated number of clusters: %d" % n_clusters_)
print("N. points in clusters:", n_points)
print("Estimated number of noise points: %d" % n_noise_)
print("(Out of a total of %d images)" % len(which_vect))


In [None]:
## Scatter the first two components of which_vect coloured by cluster label: core samples drawn
## large, other members tiny, noise (-1) in black.
## Core indices come from `clusters`, the estimator that actually produced `labels`. Using the
## DBSCAN object here mixed DBSCAN core indices with HDBSCAN labels after the HDBSCAN cell ran.
unique_labels = set(labels)
core_samples_mask = np.zeros_like(labels, dtype=bool)
core_indices = getattr(clusters, "core_sample_indices_", None)
if core_indices is None:
    # Estimators without core samples (e.g. HDBSCAN): draw every member as a core point
    core_samples_mask[:] = True
else:
    core_samples_mask[core_indices] = True

colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))]
for label, col in zip(unique_labels, colors):
    if label == -1:
        # Black used for noise.
        col = [0, 0, 0, 1]

    class_member_mask = labels == label

    xy = which_vect[class_member_mask & core_samples_mask]
    plt.plot(
        xy[:, 0],
        xy[:, 1],
        "o",
        markerfacecolor=tuple(col),
        markeredgecolor="k",
        markersize=5,
    )

    xy = which_vect[class_member_mask & ~core_samples_mask]
    plt.plot(
        xy[:, 0],
        xy[:, 1],
        "o",
        markerfacecolor=tuple(col),
        markeredgecolor="k",
        markersize=0.1,
    )

plt.title(f"Estimated number of clusters: {n_clusters_}")
plt.show()


In [None]:
## Visualise the results including the DB cluster info
from matplotlib.colors import ListedColormap, BoundaryNorm
## t-SNE embedding coloured by cluster label, with noise (-1) drawn in black.
## Previously BoundaryNorm(range(0, L+1)) sent -1 to the colormap's "under" colour (the same as
## cluster 0), and more than 20 clusters raised because tab20 was truncated to 20 colours.
label_arr = np.asarray(labels)
n_cluster_colours = max(int(label_arr.max()) + 1, 0) if label_arr.size else 0
# Colormap index 0 is reserved for noise; tab20 colours are cycled for clusters 0..max
cmap = ListedColormap([(0.0, 0.0, 0.0, 1.0)] + [plt.cm.tab20(i % 20) for i in range(n_cluster_colours)])
# Half-integer bin edges so label L falls in [L - 0.5, L + 0.5); the first bin holds -1
norm = BoundaryNorm(np.arange(-1.5, n_cluster_colours), cmap.N)
sc = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], s=1, cmap=cmap, norm=norm, alpha=0.8, c=label_arr)
plt.colorbar(sc, label='Cluster', ticks=np.arange(-1, n_cluster_colours))
plt.xlabel('t-SNE #0')
plt.ylabel('t-SNE #1')
plt.show()

In [None]:
## Function to show examples for each cluster
def plot_cluster_examples(dataset, labels, index, max_images=10):
    """Show, in a single row, the first `max_images` input images whose cluster label is `index`.

    Parameters
    ----------
    dataset : indexable
        ``dataset[i]`` is unpacked as ``(coords, feats)`` for event ``i`` (numpy arrays assumed,
        as ``torch.from_numpy`` is applied to the features; confirm for other datasets).
    labels : array-like
        One cluster label per dataset entry, in dataset order.
    index : int
        Cluster label to display (-1 selects the noise points).
    max_images : int
        Maximum number of images to draw.
    """
    ## Positions of all events carrying this label
    indices = np.where(np.array(labels) == index)[0]
    n_images = min(len(indices), max_images)
    if n_images == 0:
        # Nothing to draw; avoid showing an empty figure
        print("No images with label", index)
        return

    plt.figure(figsize=(12,4.5))

    ## One row of images (the previous 2-row grid left the bottom row empty)
    for i in range(n_images):
        ax = plt.subplot(1, n_images, i+1)

        numpy_coords, numpy_feats = dataset[indices[i]]

        # Wrap the single event as a batch of one so it is densified like a loader batch
        orig_bcoords = ME.utils.batched_coordinates([numpy_coords])
        orig_bfeats  = torch.from_numpy(np.concatenate([numpy_feats], 0)).float()

        orig_bcoords = orig_bcoords.to(device)
        orig_bfeats = orig_bfeats.to(device)
        orig = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

        inputs = make_dense_from_tensor(orig)
        inputs = inputs.cpu().squeeze().numpy()

        plt.imshow(inputs, origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
## Function to show a big block of examples for each cluster
## index == None will just give an unclustered set
def plot_cluster_bigblock(dataset, labels, index, max_x=10, max_y=10, save_name=None):
    """Draw a `max_x` x `max_y` grid of input images, optionally restricted to one cluster.

    Parameters
    ----------
    dataset : indexable
        ``dataset[i]`` is unpacked as ``(coords, feats)`` for event ``i`` (numpy arrays assumed).
    labels : array-like
        One cluster label per dataset entry, in dataset order (ignored when `index` is None).
    index : int or None
        Cluster label to show; ``None`` shows the first events in dataset order, unclustered.
    max_x, max_y : int
        Number of grid rows and columns.
    save_name : str or None
        If given, the figure is also written to this path; its parent directory is created
        if it does not exist.
    """
    import os

    ## Choose which events to draw
    if index is None:
        # Cap at the dataset size so a large grid cannot index past the end
        indices = np.arange(min(max_x*max_y, len(dataset)))
    else:
        indices = np.where(np.array(labels) == index)[0]
    max_images = min(len(indices), max_x*max_y)
    print(len(indices))
    if max_images == 0:
        return

    plt.figure(figsize=(max_y*2, max_x*1.8*2))

    ## Plot
    for i in range(max_images):
        ax = plt.subplot(max_x,max_y,i+1)

        numpy_coords, numpy_feats = dataset[indices[i]]

        # Wrap the single event as a batch of one so it is densified like a loader batch
        orig_bcoords = ME.utils.batched_coordinates([numpy_coords])
        orig_bfeats  = torch.from_numpy(np.concatenate([numpy_feats], 0)).float()

        orig_bcoords = orig_bcoords.to(device)
        orig_bfeats = orig_bfeats.to(device)
        orig = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

        inputs = make_dense_from_tensor(orig)
        inputs = inputs.cpu().squeeze().numpy()

        plt.imshow(inputs, origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.tight_layout()
    if save_name:
        # savefig does not create missing directories (e.g. 'cluster_plots/')
        save_dir = os.path.dirname(save_name)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_name, dpi=150, bbox_inches='tight')
    plt.show()


In [None]:
## Now pull out a bank of example images for each cluster

## One row of example images per cluster, followed by the noise points (label -1)
for cluster_id in range(n_clusters_):
    cluster_size = n_points[cluster_id]
    print("Showing examples for cluster:", cluster_id, "which has", cluster_size, "values")
    plot_cluster_examples(train_dataset, labels, cluster_id, max_images=15)

print("Showing examples for the noise, which has", n_noise_, "values")
plot_cluster_examples(train_dataset, labels, -1, max_images=12)

In [None]:
## 5 x 12 block of the first events in dataset order (index=None => no cluster selection).
## To dump one block per cluster instead:
# for i in range(n_clusters_): plot_cluster_bigblock(train_dataset, labels, i, 5, 12, f'cluster_plots/cluster_{i}.png')
plot_cluster_bigblock(train_dataset, labels, None,
                      max_x=5, max_y=12,
                      save_name='cluster_plots/unclustered.png')