In [None]:
import MinkowskiEngine as ME
import matplotlib.pyplot as plt
import matplotlib as mpl
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F

%matplotlib inline
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

import torch
device = 'cpu' #torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.device(device)
import numpy as np
SEED=12345
_=np.random.seed(SEED)
_=torch.manual_seed(SEED)

In [None]:
## Includes from my libraries for this project                                                                                                                                           
from ME_NN_libs import AsymmetricL2LossME, EuclideanDistLoss
from ME_NN_libs import EncoderME, DecoderME, DeepEncoderME, DeepDecoderME, DeeperEncoderME, DeeperDecoderME
from ME_dataset_libs import CenterCrop, RandomCrop, RandomHorizontalFlip, RandomRotation2D, RandomShear2D, RandomBlockZero
from ME_dataset_libs import SingleModuleImage2D_MultiHDF5_ME, triple_ME_collate_fn
from ME_dataset_libs import SingleModuleImage2D_solo_ME, solo_ME_collate_fn
from ME_dataset_libs import make_dense, make_dense_from_tensor

In [None]:
## Model hyper-parameters: these must match the architecture the checkpoint was trained with
nchan = 16
nlatent = 128
hidden_act_fn = ME.MinkowskiSiLU
latent_act_fn = ME.MinkowskiTanh
dropout = 0

## Maximum number of events to read from inDir
nevents = 100000

## Define the models
encoder = EncoderME(nchan, nlatent, hidden_act_fn, latent_act_fn, dropout)
decoder = DecoderME(nchan, nlatent, hidden_act_fn)

## Locations of the pre-trained weights and of the input HDF5 files
scratch_dir = "/pscratch/sd/c/cwilk/"
chk_file = (f"{scratch_dir}state_lat{nlatent}_nchan{nchan}"
            "_5e-6_archsimple_SOLO_actfns_silu_tanh_2M_onecycle_ME.pth")
inDir = scratch_dir + "h5_inputs/"

## Load the pre-calculated model weights.
## NOTE: torch.load unpickles the file, so only point it at trusted checkpoints.
checkpoint = torch.load(chk_file, map_location=device)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])
encoder.to(device)
decoder.to(device)

## Inference only: put both networks in eval mode (trailing ';' hides the module repr)
encoder.eval()
decoder.eval();

In [None]:
import time

## Load up to `nevents` centre-cropped images. Timed with a wall clock because
## loading is dominated by file I/O, which time.process_time() does not count.
start = time.perf_counter()
train_dataset = SingleModuleImage2D_solo_ME(inDir, transform=CenterCrop(), max_events=nevents)
print("Time taken to load", len(train_dataset), "images:", time.perf_counter() - start)

## Sequential (unshuffled) single-event batches: batch i is dataset item i
single_loader = torch.utils.data.DataLoader(train_dataset,
                                            collate_fn=solo_ME_collate_fn,
                                            batch_size=1,
                                            shuffle=False,
                                            num_workers=1)

In [None]:
## Make a few example plots for comparison
def plot_ae_outputs(encoder, decoder, loader, n=10, start=0):
    """Draw ``n`` input images (top row) above their autoencoder reconstructions (bottom row).

    Parameters
    ----------
    encoder, decoder : torch.nn.Module
        Sparse (MinkowskiEngine) encoder/decoder pair; both are switched to eval mode.
    loader : iterable
        Yields ``(coords, feats)`` batches. Only element 0 of each densified batch
        is drawn, so this presumably expects one event per batch -- TODO confirm.
    n : int
        Number of events (columns) to draw.
    start : int
        Number of batches to skip from the start of ``loader`` before plotting.

    Events with an empty feature tensor (blank centre crop) are skipped because
    they break the forward pass. If the loader runs out before ``n`` events are
    found, the figure is shown with the columns drawn so far.
    """
    loader_iter = iter(loader)
    for _ in range(start):
        next(loader_iter)

    plt.figure(figsize=(12, 5))
    encoder.eval()
    decoder.eval()

    i = 0
    while i < n:
        batch = next(loader_iter, None)
        if batch is None:
            break  # loader exhausted before n usable events were found
        orig_bcoords, orig_bfeats = batch
        if orig_bfeats.size(0) == 0:
            continue  # blank crop: nothing to encode

        with torch.no_grad():
            orig = ME.SparseTensor(orig_bfeats.to(device).float(),
                                   orig_bcoords.to(device).int(),
                                   device=device)
            rec_orig = decoder(encoder(orig))

        this_input = make_dense_from_tensor(orig)[0].cpu().squeeze().numpy()
        this_output = make_dense_from_tensor(rec_orig)[0].cpu().squeeze().numpy()

        ## Input row (row 1 of 2)
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(this_input, cmap='viridis', origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n // 2:
            ax.set_title('Original images')

        ## Reconstructed row (row 2 of 2)
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(this_output, cmap='viridis', origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n // 2:
            ax.set_title('Reconstructed images')

        i += 1

    plt.show()

In [None]:
## Compare 12 inputs with their reconstructions, skipping the first 1000 batches
plot_ae_outputs(encoder, decoder, single_loader, start=1000, n=12)

In [None]:
import pandas as pd
import numpy as np

## Encode every event into its latent vector.
## Events with an empty crop (or an empty encoding) are skipped, so row j of
## latent_vectors is train_dataset[kept_indices[j]], not necessarily train_dataset[j].
latent_vectors = []
nhits = []
kept_indices = []

encoder.eval()  # loop-invariant: set once instead of once per event
nevt = 0
with torch.no_grad():
    ## batch_size=1 and shuffle=False, so the batch counter is the dataset index
    for evt_idx, (orig_bcoords, orig_bfeats) in enumerate(single_loader):

        ## If the center crop is blank, everything barfs
        if orig_bfeats.size(0) == 0:
            continue

        if nevt % 10000 == 0:
            print("Processed evt:", nevt)
        orig_batch = ME.SparseTensor(orig_bfeats.to(device).float(),
                                     orig_bcoords.to(device).int(),
                                     device=device)
        encoded_batch = encoder(orig_batch)
        lat_vect = encoded_batch.F.flatten().cpu().numpy()

        ## Remove fringe cases
        if lat_vect.size == 0:
            print("Broken image:", orig_bfeats.size())
            continue

        nhits.append(orig_bfeats.size(0))
        latent_vectors.append(lat_vect)
        kept_indices.append(evt_idx)
        nevt += 1

if not latent_vectors:
    raise RuntimeError("No events could be encoded; nothing to analyse")
latent_vectors = np.vstack(latent_vectors)



In [None]:
# Quick look at the first two latent dimensions, coloured by number of hits
sc = plt.scatter(latent_vectors[:, 0], latent_vectors[:, 1], s=1, vmin=100, vmax=500, c=nhits)
plt.colorbar(sc, label='Number of hits')
plt.xlabel('Latent dim 0')
plt.ylabel('Latent dim 1')
plt.show()

In [None]:
## Now TSNE it up
from sklearn.manifold import TSNE

perp = 50
exag = 20
print("Perplexity =", perp, "early exaggeration =", exag)
## random_state pins the initialisation so the embedding does not depend on
## whatever global RNG state earlier cells left behind.
tsne = TSNE(n_components=2, perplexity=perp, max_iter=1000,
            early_exaggeration=exag, random_state=SEED, verbose=1)
tsne_results = tsne.fit_transform(latent_vectors)

In [None]:
## t-SNE embedding coloured by number of hits (column slicing instead of zip-transposing)
gr = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], s=1, alpha=0.8, vmin=100, vmax=500, c=nhits)
plt.colorbar(gr, label='Number of hits')
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.show()

In [None]:
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler

## Cluster in the raw latent space. To standardise each dimension first, use instead:
##   scaled_encoded_images = StandardScaler().fit_transform(latent_vectors)
scaled_encoded_images = latent_vectors

## First two dimensions of the clustering input
first_dim, second_dim = scaled_encoded_images[:, 0], scaled_encoded_images[:, 1]
plt.scatter(first_dim, second_dim, s=1)
plt.show()

In [None]:
## k-distance plot, used to choose DBSCAN's eps (look for the "knee")
from sklearn.neighbors import NearestNeighbors

latent_space = latent_vectors

k = 20  # set equal to DBSCAN's min_samples
neighbors = NearestNeighbors(n_neighbors=k)
neighbors_fit = neighbors.fit(latent_space)
distances, indices = neighbors_fit.kneighbors(latent_space)

## Querying with the fitted points themselves puts each point's zero self-distance
## in column 0, so column k-1 is the k-th neighbour counting the point itself --
## the distance that matches DBSCAN's min_samples=k. Sort it ascending for the plot.
distances = np.sort(distances[:, k - 1])

plt.figure(figsize=(10, 6))
plt.plot(distances)
plt.title('k-NN Distance Plot')
plt.xlabel('Points sorted by distance to {}-th nearest neighbor'.format(k))
plt.ylabel('Distance')
plt.show()

In [None]:
from sklearn.cluster import DBSCAN

## eps should be picked from the knee of the k-NN distance plot; min_samples should match its k
DBSCAN_EPS = 200
DBSCAN_MIN_SAMPLES = 20
dbscan = DBSCAN(eps=DBSCAN_EPS, min_samples=DBSCAN_MIN_SAMPLES)

clusters = dbscan.fit(scaled_encoded_images)

labels = clusters.labels_

# Number of clusters in labels, ignoring noise (label -1) if present.
n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
n_noise_ = int(np.count_nonzero(labels == -1))

## Cluster sizes for labels 0..n_clusters_-1 in one pass, instead of one list scan per cluster
n_points = np.bincount(labels[labels >= 0], minlength=n_clusters_).tolist()

print("Estimated number of clusters: %d" % n_clusters_)
print("N. points in clusters:", n_points)
print("Estimated number of noise points: %d" % n_noise_)
print("(Out of a total of %d images)" % len(scaled_encoded_images))

In [None]:
## First two latent dims coloured by DBSCAN label:
## large markers = core samples, tiny markers = non-core members, black = noise (-1)
unique_labels = set(labels)
core_samples_mask = np.zeros_like(labels, dtype=bool)
core_samples_mask[dbscan.core_sample_indices_] = True

colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))]
## Loop variable is `label` (not `k`) so the k-NN neighbour count `k` is not clobbered
for label, col in zip(unique_labels, colors):
    if label == -1:
        # Black used for noise.
        col = [0, 0, 0, 1]

    class_member_mask = labels == label

    for mask, size in ((class_member_mask & core_samples_mask, 14),
                       (class_member_mask & ~core_samples_mask, 0.1)):
        xy = scaled_encoded_images[mask]
        plt.plot(
            xy[:, 0],
            xy[:, 1],
            "o",
            markerfacecolor=tuple(col),
            markeredgecolor="k",
            markersize=size,
        )

plt.title(f"Estimated number of clusters: {n_clusters_}")
plt.xlabel('Latent dim 0')
plt.ylabel('Latent dim 1')
plt.show()


In [None]:
## Now TSNE it up -- on the same data DBSCAN clustered, so `labels` line up row-for-row
from sklearn.manifold import TSNE

perp = 50
exag = 50
print("Perplexity =", perp, "early exaggeration =", exag)
## max_iter (not the deprecated n_iter) for consistency with the first t-SNE cell;
## random_state makes the embedding reproducible.
tsne = TSNE(n_components=2, perplexity=perp, max_iter=1000,
            early_exaggeration=exag, random_state=SEED)
tsne_results = tsne.fit_transform(scaled_encoded_images)

In [None]:
## Visualise the results including the DB cluster info
plt.scatter(tsne_results[:, 0], tsne_results[:, 1], s=1, c=labels)
plt.title('t-SNE coloured by DBSCAN label')
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.show()

In [None]:
## Function to show examples for each cluster
def plot_cluster_examples(raw_images, labels, index, max_images=10, index_map=None):
    """Show up to ``max_images`` example images whose cluster label equals ``index``.

    Parameters
    ----------
    raw_images : indexable
        Image source; ``raw_images[j]`` must be something ``plt.imshow`` can draw.
        NOTE(review): when this is the ME dataset, confirm its items are 2-D
        images rather than sparse (coords, feats) pairs.
    labels : sequence of int
        One cluster label per latent vector (``-1`` = DBSCAN noise).
    index : int
        Cluster label whose members are shown.
    max_images : int
        Maximum number of examples drawn (the first members in label order).
    index_map : sequence of int, optional
        Maps a position in ``labels`` to a position in ``raw_images``. Needed
        when some events were skipped while building the latent vectors.
        Defaults to the identity (the original behaviour).
    """
    indices = np.flatnonzero(np.asarray(labels) == index)
    n_show = min(max_images, len(indices))
    if n_show == 0:
        print("No images with label", index)
        return

    plt.figure(figsize=(12, 4.5))
    for i, row in enumerate(indices[:n_show]):
        img_idx = row if index_map is None else index_map[row]
        ax = plt.subplot(1, n_show, i + 1)
        plt.imshow(raw_images[img_idx], origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
## Now pull out a bank of example images for each cluster, then for the noise.
## NOTE(review): `labels` is indexed by row of the latent vectors, and events with
## an empty crop were skipped when those were built. If any were skipped, label
## position j is not train_dataset[j]; confirm alignment before trusting these images.
for index, count in enumerate(n_points):
    print("Showing examples for cluster:", index, "which has", count, "values")
    plot_cluster_examples(train_dataset, labels, index)

print("Showing examples for the noise, which has", n_noise_, "values")
plot_cluster_examples(train_dataset, labels, -1)