In [None]:
from torch.utils.tensorboard import SummaryWriter
import MinkowskiEngine as ME
from torch import nn
import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

## Tell pytorch we have a GPU if we do
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.device(device)

SEED=12345
_=np.random.seed(SEED)
_=torch.manual_seed(SEED)
writer = SummaryWriter("log")

In [None]:
## Includes from my libraries for this project                                                                                                                                           
from ME_NN_libs import AsymmetricL2LossME, EuclideanDistLoss
from ME_NN_libs import EncoderME, DecoderME, DeepEncoderME, DeepDecoderME, DeeperEncoderME, DeeperDecoderME
from ME_NN_libs import ProjectionHead
from ME_dataset_libs import CenterCrop, RandomCrop, RandomHorizontalFlip, RandomRotation2D, RandomShear2D, \
    RandomBlockZero, RandomJitterCharge, RandomScaleCharge, RandomElasticDistortion2D, RandomGridDistortion2D, BilinearInterpolation
from ME_dataset_libs import SingleModuleImage2D_MultiHDF5_ME, triple_ME_collate_fn
from ME_dataset_libs import SingleModuleImage2D_solo_ME, solo_ME_collate_fn
from ME_dataset_libs import make_dense, make_dense_from_tensor

In [None]:
class ProjectionHead(nn.Module):
    """Three-stage sparse MLP mapping encoder latents into the contrastive-loss space.

    NOTE: this definition shadows the ``ProjectionHead`` imported from ``ME_NN_libs``
    in an earlier cell.

    Args:
        dim: indexable of (at least) four ints ``d0, d1, d2, d3``; the head maps
            ``d0 -> d1 -> d2 -> d3`` features.
        act_fn: zero-argument factory for the activation module; a fresh instance
            follows every linear layer, including the last one.
    """
    def __init__(self, dim, act_fn=ME.MinkowskiReLU):
        super(ProjectionHead, self).__init__()

        widths = [dim[0], dim[1], dim[2], dim[3]]
        stages = []
        for stage_idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            stages.append(ME.MinkowskiLinear(n_in, n_out, bias=False))
            # Only the first two stages are batch-normalised; the output stage is not.
            if stage_idx < 2:
                stages.append(ME.MinkowskiBatchNorm(n_out))
            stages.append(act_fn())
        self.linear_proj = nn.Sequential(*stages)

        self.initialize_weights()

    def initialize_weights(self):
        """Kaiming-normal (fan_out, relu gain) linear weights; batch-norm scale 1, shift 0."""
        for module in self.modules():
            if isinstance(module, ME.MinkowskiLinear):
                ME.utils.kaiming_normal_(module.linear.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, ME.MinkowskiBatchNorm):
                nn.init.constant_(module.bn.weight, 1)
                nn.init.constant_(module.bn.bias, 0)

    def forward(self, x):
        """Run the sparse tensor ``x`` through the projection MLP."""
        return self.linear_proj(x)

In [None]:
## Dump out some of the input and reconstructed images to see how the autoencoder is getting on
def plot_ae_outputs(encoder, decoder, n=10, loader=None):
    """Plot ``n`` original images (top row) above their autoencoder reconstructions.

    Args:
        encoder: sparse encoder; run in eval mode under ``torch.no_grad``.
        decoder: sparse decoder; run in eval mode under ``torch.no_grad``.
        n: number of example images; capped at the number available in one batch.
        loader: dataloader yielding 6-tuples whose last two entries are
            ``(orig_bcoords, orig_bfeats)``; defaults to the global ``train_loader``.

    Both models are returned to their previous train/eval mode afterwards.
    """
    if loader is None:
        loader = train_loader

    enc_was_training = encoder.training
    dec_was_training = decoder.training
    encoder.eval()
    decoder.eval()

    try:
        ## Draw a single batch and take all examples from it. Each iter() on a
        ## multi-worker DataLoader starts a new set of worker processes, so doing it
        ## once per panel would be very slow.
        _, _, _, _, orig_bcoords, orig_bfeats = next(iter(loader))

        with torch.no_grad():
            orig_bcoords = orig_bcoords.to(device)
            orig_bfeats = orig_bfeats.to(device)
            orig = ME.SparseTensor(orig_bfeats.float(), orig_bcoords.int(), device=device)
            rec_orig = decoder(encoder(orig))

            inputs = make_dense_from_tensor(orig)
            outputs = make_dense_from_tensor(rec_orig)
    finally:
        encoder.train(enc_was_training)
        decoder.train(dec_was_training)

    ## Assumes the dense tensors are indexed by batch element first -- TODO confirm
    n = min(n, len(inputs), len(outputs))

    plt.figure(figsize=(12,5))
    for i in range(n):
        this_input = inputs[i].cpu().squeeze().numpy()
        this_output = outputs[i].cpu().squeeze().numpy()

        ## Input row
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(this_input, cmap='viridis', origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2: ax.set_title('Original images')

        ## Reconstructed row
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(this_output, cmap='viridis', origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2: ax.set_title('Reconstructed images')

    plt.show()

In [None]:
class EuclideanDistLoss(torch.nn.Module):
    """Scaled batch mean of a penalty on the distance between L2-normalised latent pairs.

    NOTE: this definition shadows the ``EuclideanDistLoss`` imported from ``ME_NN_libs``.

    Args:
        scale: multiplier applied to the batch-mean penalty.
        cutoff, pressure: stored for an alternative thresholded penalty
            (``(d - cutoff)**2`` above ``cutoff``, ``pressure * (cutoff - d)**2`` below)
            that is currently disabled; ``calc_penalty`` does not read them. If that
            variant is restored it must stay elementwise (e.g. via ``torch.where``).
    """
    def __init__(self, scale, cutoff=0.1, pressure=100):
        super(EuclideanDistLoss, self).__init__()
        self.cutoff = cutoff
        self.pressure = pressure
        self.scale = scale

    def forward(self, latent1, latent2):
        """Return ``scale * mean_k calc_penalty(||n1_k - n2_k||_2)``.

        Rows of ``latent1``/``latent2`` are paired and each is normalised to unit
        L2 norm along dim 1 before the distance is taken.
        """
        norm_lat1 = nn.functional.normalize(latent1, p=2, dim=1)
        norm_lat2 = nn.functional.normalize(latent2, p=2, dim=1)
        distances = torch.norm(norm_lat1 - norm_lat2, p=2, dim=1)
        # calc_penalty is elementwise, so apply it to the whole distance vector at once
        # rather than looping over rows in Python and re-stacking the results.
        mod_penalty = self.calc_penalty(distances)
        return mod_penalty.mean()*self.scale

    def calc_penalty(self, value):
        """Squared distance; works elementwise on scalars or tensors."""
        return value**2


In [None]:
class NTXentOld(torch.nn.Module):
    """NT-Xent (SimCLR) contrastive loss in its cross-entropy formulation.

    The two views are interleaved as ``[a_0, b_0, a_1, b_1, ...]`` so that row ``2k``
    and row ``2k+1`` are positive partners; every other row of the ``2N x 2N``
    similarity matrix (except the anchor itself) acts as a negative.

    Args:
        temperature: softmax temperature applied to the cosine similarities.
    """
    def __init__(self, temperature=0.5):
        super(NTXentOld, self).__init__()
        self.temperature = temperature

    def forward(self, latent1, latent2):
        """Mean NT-Xent loss for row-paired batches ``latent1``/``latent2`` of shape ``(N, D)``."""
        z_i = nn.functional.normalize(latent1, p=2, dim=1)
        z_j = nn.functional.normalize(latent2, p=2, dim=1)

        # Interleave the views so the 0<->1, 2<->3, ... target pattern below addresses the
        # true positive pairs. Comparing z_i against z_j alone and masking the diagonal
        # would hide exactly the positives and target neighbouring samples instead.
        x = torch.stack([z_i, z_j], dim=1).reshape(-1, z_i.shape[1])
        n_rows = x.shape[0]

        xcs = torch.matmul(x, x.T)
        # An anchor must never be its own candidate; build the mask on the logits' device.
        self_mask = torch.eye(n_rows, dtype=torch.bool, device=xcs.device)
        xcs = xcs.masked_fill(self_mask, float("-inf"))

        # Positive partner index of every row: 0<->1, 2<->3, ...
        target = torch.arange(n_rows, device=xcs.device)
        target[0::2] += 1
        target[1::2] -= 1

        loss = nn.functional.cross_entropy(xcs / self.temperature, target, reduction="mean")
        return loss

In [None]:
class NTXent(nn.Module):
    """NT-Xent (normalised temperature-scaled cross-entropy) loss as in SimCLR.

    For ``N`` pairs the two views are concatenated into ``2N`` unit vectors. Each
    vector's positive is its partner view; the softmax denominator runs over all
    other ``2N - 1`` vectors (partner included, self excluded).

    Args:
        temperature: softmax temperature; stored as a buffer so it appears in the
            state dict and follows ``.to()``.
    """
    def __init__(self, temperature=0.5):
        super().__init__()
        self.register_buffer("temperature", torch.tensor(temperature))

    def forward(self, emb_i, emb_j):
        """
        emb_i and emb_j are batches of embeddings, where corresponding indices are pairs
        z_i, z_j as per SimCLR paper

        Returns the loss averaged over all ``2N`` anchors.
        """
        batch_size = emb_i.shape[0]
        z_i = nn.functional.normalize(emb_i, dim=1)
        z_j = nn.functional.normalize(emb_j, dim=1)

        representations = torch.cat([z_i, z_j], dim=0)
        # Rows are unit-norm, so the Gram matrix is the cosine-similarity matrix; this
        # avoids the (2N, 2N, D) tensor a broadcast cosine_similarity would materialise.
        similarity_matrix = torch.matmul(representations, representations.T)

        # Positive pairs sit on the +N / -N off-diagonals: (k, k+N) and (k+N, k).
        sim_ij = torch.diag(similarity_matrix, batch_size)
        sim_ji = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([sim_ij, sim_ji], dim=0)

        # log(sum_{k != i} exp(s_ik / T)) computed in log space: mask self with -inf and
        # use logsumexp, which cannot overflow for small temperatures.
        logits = similarity_matrix / self.temperature
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=logits.device)
        log_denominator = torch.logsumexp(logits.masked_fill(self_mask, float("-inf")), dim=1)

        # -log(exp(pos / T) / denominator) for every anchor
        loss_partial = log_denominator - positives / self.temperature
        loss = torch.sum(loss_partial) / (2 * batch_size)
        return loss

In [None]:
class NTXentDCL(nn.Module):
    """Debiased contrastive loss (DCL): NT-Xent with a debiased negative term.

    Args:
        temperature: softmax temperature.
        tau_plus: class-prior weight used to subtract the expected false-negative
            contribution from the negative sum.
    """
    def __init__(self, temperature=0.5, tau_plus=0.1):
        super().__init__()
        self.temperature = temperature
        self.tau_plus = tau_plus

    def get_negative_mask(self, batch_size):
        """Boolean ``(2B, 2B)`` mask: False at each row's own column and its partner's column."""
        cols = torch.arange(2 * batch_size).unsqueeze(0)
        rows = torch.arange(batch_size).unsqueeze(1)
        top_half = (cols != rows) & (cols != rows + batch_size)
        # Rows B..2B-1 exclude the same two columns as rows 0..B-1.
        return torch.cat((top_half, top_half), 0)

    def forward(self, emb_i, emb_j):
        """
        emb_i and emb_j are batches of embeddings, where corresponding indices are pairs
        z_i, z_j as per SimCLR paper
        """
        n_pairs = emb_i.shape[0]
        z_i = nn.functional.normalize(emb_i, dim=1)
        z_j = nn.functional.normalize(emb_j, dim=1)

        # Negative scores: exp(sim / T) for every entry that is neither self nor partner
        both_views = torch.cat((z_i, z_j), 0)
        all_scores = torch.exp(torch.mm(both_views, both_views.t().contiguous()) / self.temperature)
        keep = self.get_negative_mask(n_pairs).to(both_views.device)
        neg = all_scores.masked_select(keep).view(2 * n_pairs, -1)

        # Positive scores, duplicated so each view acts as an anchor once
        pos_one_side = torch.exp((z_i * z_j).sum(dim=-1) / self.temperature)
        pos = torch.cat((pos_one_side, pos_one_side), 0)

        # Debiased negative estimator, clamped from below at its theoretical minimum
        num_neg = 2 * n_pairs - 2
        debiased_neg = (neg.sum(dim=-1) - self.tau_plus * num_neg * pos) / (1 - self.tau_plus)
        floor = num_neg * np.e**(-1 / self.temperature)
        debiased_neg = debiased_neg.clamp(min=floor)

        # Contrastive loss averaged over the 2N anchors
        return (-torch.log(pos / (pos + debiased_neg))).mean()

In [None]:
class NTXentExtended(nn.Module):
    """Experimental NT-Xent variant with a dynamic hard-negative cutoff and reweighted negatives.

    Only the first-view rows ``z_i[k]`` act as anchors. For anchor ``k`` the candidate
    negatives are all other rows of the ``2N x 2N`` similarity matrix whose similarity
    is below ``sim(z_i[k], z_j[k]) - epsilon``; candidates above that cutoff are dropped.

    Args:
        temperature: softmax temperature.
        epsilon: margin below the positive similarity that a negative must fall under.
    """
    def __init__(self, temperature=0.5, epsilon=0.05):
        super().__init__()
        self.temperature = temperature
        self.epsilon = epsilon

    def forward(self, emb_i, emb_j):
        """
        emb_i and emb_j are batches of embeddings, where corresponding indices are pairs
        z_i, z_j as per SimCLR paper

        Returns the loss averaged over the ``N`` first-view anchors.
        """
        batch_size = emb_i.shape[0]
        z_i = nn.functional.normalize(emb_i, dim=1)
        z_j = nn.functional.normalize(emb_j, dim=1)

        embeddings = torch.cat([z_i, z_j], dim=0)

        # Cosine similarity matrix (2N x 2N); only the N anchor rows are needed
        similarity_matrix = torch.matmul(embeddings, embeddings.T)
        anchor_sims = similarity_matrix[:batch_size]

        # Positive pair similarities sim(z_i[k], z_j[k])
        idx = torch.arange(batch_size, device=embeddings.device)
        positive_similarities = anchor_sims[idx, idx + batch_size]

        # Dynamic thresholds for negative pairs based on positive pair similarities
        dynamic_thresholds = positive_similarities - self.epsilon

        # exp(sim(z_i, z_j) / temperature) for the positive pairs
        numerator = torch.exp(positive_similarities / self.temperature)

        # Valid negatives: every column except the anchor itself, below the anchor's threshold.
        # (The positive column is excluded by the threshold whenever epsilon > 0.)
        not_self = ~torch.eye(batch_size, 2 * batch_size, dtype=torch.bool, device=embeddings.device)
        valid = not_self & (anchor_sims < dynamic_thresholds.unsqueeze(1))

        # NOTE(review): the intent was for weights to shrink for highly similar negatives, but
        # 1 - exp(-s/T) grows with s and is negative for s < 0. Since
        # weight * exp(s/T) == exp(s/T) - 1, dissimilar negatives *reduce* the denominator,
        # which can drive it <= 0 (NaN loss) for large batches. Formula kept pending a decision.
        weights = 1 - torch.exp(-anchor_sims / self.temperature)
        weighted = torch.where(valid, weights * torch.exp(anchor_sims / self.temperature),
                               torch.zeros_like(anchor_sims))
        denominator = weighted.sum(dim=1) + numerator

        # Per-anchor NT-Xent term, averaged over anchors
        loss = -torch.log(numerator / denominator)
        return loss.mean()

In [None]:
## Wrap the training in a nicer function...
def run_training(num_iterations, log_dir, encoder, decoder, projection, dataloader, optimizer, batch_size, scheduler=None, prof=None, latent_loss_fn=None):
    """Jointly train the sparse autoencoder and the projection head.

    Every batch supplies two augmented views. Both are encoded and decoded; the loss
    is the sum of the two reconstruction losses (``AsymmetricL2LossME``) and a
    contrastive loss between the projected latents of the two views.

    Args:
        num_iterations: number of passes (epochs) over ``dataloader``.
        log_dir: TensorBoard directory for per-epoch losses; falsy disables logging.
        encoder, decoder, projection: sparse modules, moved to the global ``device``.
        dataloader: yields ``(aug1_bcoords, aug1_bfeats, aug2_bcoords, aug2_bfeats,
            orig_bcoords, orig_bfeats)``; the ``orig`` tensors are not used here.
        optimizer: optimizer over the parameters to update.
        batch_size: forwarded to ``AsymmetricL2LossME``.
        scheduler: optional LR scheduler, stepped once per epoch.
        prof: optional profiler object, stepped once per epoch.
        latent_loss_fn: contrastive loss on the projected features; defaults to ``NTXent(1)``.

    Raises:
        RuntimeError: if ``dataloader`` yields no batches in an epoch.
    """
    import time  # otherwise `time` is only imported in a later notebook cell

    print("Training with", num_iterations, "iterations")
    tstart = time.process_time()

    writer = SummaryWriter(log_dir=log_dir) if log_dir else None

    reco_loss_fn = AsymmetricL2LossME(10, 1, batch_size)
    if latent_loss_fn is None:
        latent_loss_fn = NTXent(1)

    encoder.to(device, non_blocking=True)
    decoder.to(device)
    projection.to(device)

    try:
        ## Loop over the desired iterations
        for iteration in range(num_iterations):

            total_loss = 0
            total_aug1_loss = 0
            total_aug2_loss = 0
            total_latent_loss = 0
            nbatches = 0

            # Set train mode for all three models
            encoder.train()
            decoder.train()
            projection.train()

            # Iterate over batches of images with the dataloader
            for aug1_bcoords, aug1_bfeats, aug2_bcoords, aug2_bfeats, orig_bcoords, orig_bfeats in dataloader:

                ## Send to the device, then make the sparse tensors
                aug1_bcoords = aug1_bcoords.to(device, non_blocking=True)
                aug1_bfeats = aug1_bfeats.to(device, non_blocking=True)
                aug2_bcoords = aug2_bcoords.to(device, non_blocking=True)
                aug2_bfeats = aug2_bfeats.to(device) ## This one has to block or it can try to run the forward function without the data on the GPU...

                aug1_batch = ME.SparseTensor(aug1_bfeats, aug1_bcoords, device=device)
                aug2_batch = ME.SparseTensor(aug2_bfeats, aug2_bcoords, device=device)

                ## Forward passes for both views
                encoded_batch1 = encoder(aug1_batch)
                decoded_batch1 = decoder(encoded_batch1)
                encoded_batch2 = encoder(aug2_batch)
                decoded_batch2 = decoder(encoded_batch2)

                # Reconstruction losses
                aug1_loss = reco_loss_fn(decoded_batch1, aug1_batch)
                aug2_loss = reco_loss_fn(decoded_batch2, aug2_batch)

                # Contrastive loss on the projected latents. decomposed_features groups rows
                # per batch element, so after concatenation row k of each view presumably
                # belongs to the same image.
                proj1 = projection(encoded_batch1)
                proj2 = projection(encoded_batch2)
                latent_loss = latent_loss_fn(torch.cat(proj1.decomposed_features),
                                             torch.cat(proj2.decomposed_features))

                loss = aug1_loss + aug2_loss + latent_loss

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                total_aug1_loss += aug1_loss.item()
                total_aug2_loss += aug2_loss.item()
                total_latent_loss += latent_loss.item()
                nbatches += 1

                torch.cuda.empty_cache()

            if nbatches == 0:
                raise RuntimeError("dataloader yielded no batches (check dataset size, batch_size and drop_last)")

            ## Step the LR scheduler once per epoch, if there is one
            if scheduler: scheduler.step()

            av_loss = total_loss/nbatches
            av_aug1_loss = total_aug1_loss/nbatches
            av_aug2_loss = total_aug2_loss/nbatches
            av_latent_loss = total_latent_loss/nbatches

            if writer is not None:
                writer.add_scalar('loss/train', av_loss, iteration)
                writer.add_scalar('loss/train_aug1', av_aug1_loss, iteration)
                writer.add_scalar('loss/train_aug2', av_aug2_loss, iteration)
                writer.add_scalar('loss/train_latent', av_latent_loss, iteration)
            print("Processed", iteration, "/", num_iterations, "; loss =", av_loss, "(", av_aug1_loss, "+", av_aug2_loss, "+", av_latent_loss, ")")
            print("Time taken:", time.process_time() - tstart)

            if prof: prof.step()

            ## dump some images so we can see how the training is going
            if iteration%10 == 0: plot_ae_outputs(encoder,decoder,10)

            ## End so empty cache because MinkowskiEngine can't be trusted
            torch.cuda.empty_cache()
    finally:
        # Make sure buffered TensorBoard events are written even if training is interrupted
        if writer is not None:
            writer.close()

In [None]:
import time
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F

## Augmentation pipeline, presumably applied to produce the two augmented views
## (aug1/aug2) that the training loader yields alongside the original image.
aug_transform = transforms.Compose(
    [
        RandomGridDistortion2D(5, 5),
        RandomShear2D(0.1, 0.1),
        RandomHorizontalFlip(),
        RandomRotation2D(-10, 10),
        RandomBlockZero(5, 6),
        RandomScaleCharge(0.02),
        RandomJitterCharge(0.02),
        RandomCrop(),
    ]
)

## Input HDF5 directory (site-specific absolute path)
inDir = "/pscratch/sd/c/cwilk/h5_inputs/"

start = time.process_time()
train_dataset = SingleModuleImage2D_MultiHDF5_ME(
    inDir,
    nom_transform=CenterCrop(),
    aug_transform=aug_transform,
    max_events=100000,
)
print("Time taken to load", len(train_dataset), "images:", time.process_time() - start)

## Shuffled training loader; incomplete final batches are dropped
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=triple_ME_collate_fn,
    num_workers=8,
    prefetch_factor=1,
    pin_memory=False,
)

In [None]:
## Flag where synchronising CUDA calls are made (useful but experimental PyTorch API; 0 = off).
## It triggers CUDA initialisation, so only call it when a GPU is actually present --
## otherwise the CPU fallback chosen for `device` would crash here.
if torch.cuda.is_available():
    torch.cuda.set_sync_debug_mode(0)

## Various config parameters
nchan=32
nlatent=128
hidden_act_fn=ME.MinkowskiSiLU
latent_act_fn=ME.MinkowskiTanh
dropout = 0
num_iterations=50
log_dir="log"

## Define the models
encoder=EncoderME(nchan, nlatent, hidden_act_fn, latent_act_fn, dropout)
decoder=DecoderME(nchan, nlatent, hidden_act_fn)

## Load in the pre-calculated model weights (onto the CPU first; moved to `device` below)
chk_file="/pscratch/sd/c/cwilk/state_lat"+str(nlatent)+"_nchan"+str(nchan)+"_5e-6_archsimple_SOLO_actfns_silu_tanh_2M_onecycle_ME.pth"

checkpoint = torch.load(chk_file, map_location='cpu')
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

## Freshly initialised projection head: nlatent -> nlatent -> nlatent -> nlatent
projection = ProjectionHead([nlatent, nlatent, nlatent, nlatent], latent_act_fn)

encoder.to(device, non_blocking=True)
decoder.to(device)
projection.to(device)

## Encoder, decoder and projection head are all optimised jointly
params_to_optimize = [
        {'params': encoder.parameters()},
        {'params': decoder.parameters()},
        {'params': projection.parameters()}
    ]

lr=1e-5
weight_decay=0
optimizer = torch.optim.AdamW(params_to_optimize, lr=lr, weight_decay=weight_decay)

## Scheduler options
scheduler = None
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.5)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=3)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50,100,150], gamma=0.2, last_epoch=-1)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=num_iterations, cycle_momentum=False)

## To profile the training, replace the run_training call below with:
#from torch.profiler import profile, record_function, ProfilerActivity
#with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#             with_stack=False, record_shapes=False) as prof:
#    with record_function("model_inference"):
#        run_training(num_iterations, log_dir, encoder, decoder, projection, train_loader, optimizer, batch_size, scheduler, prof)

run_training(num_iterations, log_dir, encoder, decoder, projection, train_loader, optimizer, batch_size, scheduler)

In [None]:
## Now take the trained model and try to run some unsupervised learning on it...
import numpy as np

## Make a single (un-augmented, unshuffled) loader to loop over for ease
single_dataset = SingleModuleImage2D_solo_ME(inDir, transform=CenterCrop(), max_events=50000)
single_loader = torch.utils.data.DataLoader(single_dataset,
                                            collate_fn=solo_ME_collate_fn,
                                            batch_size=512,
                                            shuffle=False,
                                            num_workers=4)

latent = []
proj   = []
nhits  = []

## Inference only: switch to eval mode once, outside the batch loop
encoder.eval()
projection.eval()

for orig_bcoords, orig_bfeats in single_loader:

    orig_bcoords = orig_bcoords.to(device)
    orig_bfeats = orig_bfeats.to(device)
    orig_batch = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

    ## Forward passes without building an autograd graph
    with torch.no_grad():
        encoded_batch = encoder(orig_batch)
        proj_batch = projection(encoded_batch)

    ## Hit count of every non-empty input image, plus per-image latent/projection rows
    nhits += [i.shape[0] for i in orig_batch.decomposed_features if i.shape[0] != 0]
    latent += [x.cpu().numpy() for x in encoded_batch.decomposed_features]
    proj += [x.cpu().numpy() for x in proj_batch.decomposed_features]

lat_vect = np.vstack(latent)
proj_vect = np.vstack(proj)
hit_vect = np.array(nhits)

## Later cells colour proj_vect rows by hit_vect, which needs exactly one projection
## row per non-empty image -- fail loudly here rather than mis-colour the plots.
assert len(hit_vect) == len(proj_vect), "hit counts and projections are misaligned"

In [None]:
## First two projection dimensions, coloured by the number of hits per image
fig, ax = plt.subplots()
gr = ax.scatter(proj_vect[:,0], proj_vect[:,1], s=1, vmin=100, vmax=500, c=hit_vect)
fig.colorbar(gr, ax=ax, label="Number of hits")
ax.set_xlabel("Projection dim 0")
ax.set_ylabel("Projection dim 1")
plt.show()

In [None]:
## Now TSNE it up
from sklearn.manifold import TSNE

## t-SNE of the projection vectors (cosine metric)
perp = 50
exag = 50
print("Perplexity =", perp, "early exaggeration =", exag)
tsne = TSNE(
    n_components=2,
    perplexity=perp,
    early_exaggeration=exag,
    max_iter=1000,
    metric='cosine',
    verbose=1,
)
tsne_results = tsne.fit_transform(proj_vect)

In [None]:
## t-SNE map coloured by hits per image; take the coordinate columns directly
gr = plt.scatter(tsne_results[:, 0], tsne_results[:, 1],
                 s=1, alpha=0.8, vmin=100, vmax=500, c=hit_vect)
plt.colorbar(gr)
plt.show()

In [None]:
## Try k-NN algorithm
from sklearn.neighbors import NearestNeighbors

# k-distance curve for choosing DBSCAN's eps (cosine metric, min_samples = k).
# Querying the fitted set returns each point as its own 0-distance first neighbour,
# so column k-1 is the distance to the k-th point counting itself -- the distance
# that DBSCAN's min_samples=k core-point test compares against eps.
k = 100
neighbors = NearestNeighbors(n_neighbors=k, metric='cosine')
neighbors_fit = neighbors.fit(proj_vect)
distances, indices = neighbors_fit.kneighbors(proj_vect)

# Sort the k-th-neighbour distances ascending; the "elbow" suggests eps
distances = np.sort(distances[:, k - 1])

# Plot the distances
plt.figure(figsize=(10, 6))
plt.plot(distances)
plt.title('k-NN Distance Plot')
plt.xlabel('Points sorted by distance to {}-th nearest neighbor'.format(k))
plt.ylabel('Distance')
#plt.ylim(0, 5)
plt.show()

In [None]:
from sklearn.cluster import DBSCAN
## Density-based clustering of the projections (cosine metric)
k=100
dbscan = DBSCAN(eps=0.01, min_samples=k, metric='cosine')

clusters = dbscan.fit(proj_vect)

labels = clusters.labels_

# Number of clusters in labels, ignoring noise (-1) if present.
n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
n_noise_ = int(np.count_nonzero(labels == -1))

# Points per cluster in a single pass (DBSCAN labels clusters 0..n_clusters_-1)
# rather than rescanning the full label list once per cluster.
n_points = np.bincount(labels[labels >= 0], minlength=n_clusters_).tolist()

print("Estimated number of clusters: %d" % n_clusters_)
print("N. points in clusters:", n_points)
print("Estimated number of noise points: %d" % n_noise_)
print("(Out of a total of %d images)" % len(proj_vect))

In [None]:
unique_labels = set(labels)
core_samples_mask = np.zeros_like(labels, dtype=bool)
core_samples_mask[dbscan.core_sample_indices_] = True

## One colour per cluster label, black for noise (-1). Core samples get large markers,
## non-core members tiny ones. Only the first two projection dimensions are drawn.
## The loop variable is `label` (not `k`) so the neighbour count `k` is not overwritten.
colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))]
for label, col in zip(unique_labels, colors):
    if label == -1:
        # Black used for noise.
        col = [0, 0, 0, 1]

    class_member_mask = labels == label

    for point_mask, marker_size in ((core_samples_mask, 14), (~core_samples_mask, 0.1)):
        xy = proj_vect[class_member_mask & point_mask]
        plt.plot(
            xy[:, 0],
            xy[:, 1],
            "o",
            markerfacecolor=tuple(col),
            markeredgecolor="k",
            markersize=marker_size,
        )

plt.title(f"Estimated number of clusters: {n_clusters_}")
plt.xlabel("Projection dim 0")
plt.ylabel("Projection dim 1")
plt.show()


In [None]:
## Now TSNE it up
from sklearn.manifold import TSNE

perp=50
exag=50
print("Perplexity =", perp, "early exaggeration =", exag)
## Embed the same projections DBSCAN clustered so `labels` lines up row-for-row.
## (`encoded_samples` is never defined in this notebook, and TSNE's `n_iter` argument
## is deprecated in favour of `max_iter`, which the earlier t-SNE cell already uses.)
tsne = TSNE(n_components=2, perplexity=perp, max_iter=1000, early_exaggeration=exag)
tsne_results = tsne.fit_transform(proj_vect)

In [None]:
## Visualise the t-SNE map coloured by DBSCAN cluster label (noise = -1)
plt.scatter(tsne_results[:, 0], tsne_results[:, 1], s=4, c=labels)

In [None]:
## Function to show examples for each cluster
def plot_cluster_examples(dataset, labels, index, max_images=10):
    """Draw up to ``max_images`` example images carrying cluster label ``index``.

    Args:
        dataset: indexable returning a ``(coords, feats)`` numpy pair per image.
        labels: per-image cluster labels; position ``p`` refers to ``dataset[p]``.
        index: label to display (``-1`` selects DBSCAN noise).
        max_images: maximum number of panels to draw.
    """
    plt.figure(figsize=(12,4.5))

    ## Dataset positions of the images with this label
    members = np.where(np.array(labels) == index)[0]
    n_show = min(len(members), max_images)

    for panel, event_idx in enumerate(members[:n_show]):
        ax = plt.subplot(2, n_show, panel + 1)

        numpy_coords, numpy_feats = dataset[event_idx]

        # Wrap the single image as a batch of one so it can become a SparseTensor
        coords = ME.utils.batched_coordinates([numpy_coords]).to(device)
        feats = torch.from_numpy(np.concatenate([numpy_feats], 0)).float().to(device)
        image = ME.SparseTensor(feats, coords, device=device)

        dense_image = make_dense_from_tensor(image).cpu().squeeze().numpy()

        plt.imshow(dense_image, origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
## Show a bank of example images for every cluster, then for the noise points
for index, count in enumerate(n_points):
    print("Showing examples for cluster:", index, "which has", count, "values")
    plot_cluster_examples(single_dataset, labels, index)

print("Showing examples for the noise, which has", n_noise_, "values")
plot_cluster_examples(single_dataset, labels, -1)