In [None]:
from torch.utils.tensorboard import SummaryWriter
import MinkowskiEngine as ME
import matplotlib.pyplot as plt
import matplotlib as mpl
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F

%matplotlib inline
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.device(device)
import numpy as np
SEED=12345
_=np.random.seed(SEED)
_=torch.manual_seed(SEED)
writer = SummaryWriter("log")

In [None]:
## Includes from my libraries for this project                                                                                                                                           
from ME_NN_libs import AsymmetricL2LossME, NTXentMerged, NTXentMergedTopTenNeg
from ME_NN_libs import EncoderME, DecoderME, DeepEncoderME, DeepDecoderME, DeeperEncoderME, DeeperDecoderME
from ME_NN_libs import ProjectionHead
from ME_dataset_libs import CenterCrop, MaxRegionCrop, RandomCrop, RandomHorizontalFlip, RandomRotation2D, RandomShear2D, RandomBlockZero, RandomJitterCharge, RandomScaleCharge, RandomElasticDistortion2D, RandomGridDistortion2D
from ME_dataset_libs import SingleModuleImage2D_MultiHDF5_ME, triple_ME_collate_fn, cat_ME_collate_fn
from ME_dataset_libs import make_dense, make_dense_from_tensor

In [None]:
import time

## Augmentations handed to the dataset as `aug_transform`; Compose applies them in the order listed
aug_transform = transforms.Compose([
    RandomGridDistortion2D(5,5),
    RandomShear2D(0.1, 0.1),
    RandomHorizontalFlip(),
    RandomRotation2D(-10,10),
    RandomBlockZero(5, 6),
    RandomScaleCharge(0.02),
    RandomJitterCharge(0.02),
    RandomCrop()
])

## How many events shall we use here?
nevents = 200000
inDir = "/pscratch/sd/c/cwilk/h5_inputs/"

## NOTE: process_time() reports CPU time of this process, not wall-clock time
start = time.process_time()
train_dataset = SingleModuleImage2D_MultiHDF5_ME(inDir,
                                                 nom_transform=MaxRegionCrop(),
                                                 aug_transform=aug_transform,
                                                 max_events=nevents)

print("Time taken to load", len(train_dataset),"images:", time.process_time() - start)

## Randomly chosen batching
## drop_last=True discards a final incomplete batch, so every batch has exactly batch_size entries
batch_size=512
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           collate_fn=cat_ME_collate_fn,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           drop_last=True,
                                           num_workers=16)

In [None]:
## Dump out some of the input and reconstructed images to see how the autoencoder is getting on
def plot_ae_outputs(encoder,decoder,n=10):
    """Plot `n` input images above their autoencoder reconstructions.

    For each of the `n` columns one batch is taken from the module-level
    `train_loader`, run through `encoder` and then `decoder` with gradients
    disabled, and the first densified image of the batch is drawn: the input
    in the top row and its reconstruction in the row below.

    Args:
        encoder: model applied to the ME.SparseTensor built from the batch.
        decoder: model applied to the encoder output.
        n: number of columns (one batch is consumed per column).

    Side effects:
        Puts both models into eval mode (they are not switched back) and
        shows a matplotlib figure.
    """
    encoder.eval()
    decoder.eval()

    plt.figure(figsize=(12,5))

    ## Build the loader iterator once. Calling iter(train_loader) for every
    ## column would restart the loader (and its worker processes) each time.
    batch_iter = iter(train_loader)

    ## Loop over figures
    for i in range(n):

        try:
            bcoords, bfeats = next(batch_iter)
        except StopIteration:
            ## Loader has fewer than n batches: start over rather than fail
            batch_iter = iter(train_loader)
            bcoords, bfeats = next(batch_iter)

        with torch.no_grad():

            bcoords = bcoords.to(device)
            bfeats = bfeats.to(device)
            orig = ME.SparseTensor(bfeats.float(), bcoords.int(), device=device)

            enc_orig  = encoder(orig)
            rec_orig  = decoder(enc_orig)

        inputs  = make_dense_from_tensor(orig)
        outputs = make_dense_from_tensor(rec_orig)

        ## Only the first image of each batch is shown
        this_input = inputs[0].cpu().squeeze().numpy()
        this_output = outputs[0].cpu().squeeze().numpy()

        ## Input row (the grid has 3 rows; only the first two are used)
        ax = plt.subplot(3, n, i + 1)
        ax.imshow(this_input, cmap='viridis', origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2: ax.set_title('Original images')

        ## Reconstructed row
        ax = plt.subplot(3, n, i + 1 + n)
        ax.imshow(this_output, cmap='viridis', origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2: ax.set_title('Reconstructed images')

    plt.show()

In [None]:
from torch import nn
class NTXentMergedTopTenNeg(nn.Module):
    """NT-Xent style contrastive loss with a restricted denominator.

    For every row of the pairwise similarity matrix only the 10% least
    similar entries (never the row's own diagonal entry) contribute to the
    denominator; the numerator is the similarity to the row's paired view.

    Note: this notebook-level definition shadows the class of the same name
    imported from ME_NN_libs earlier in the notebook.
    """

    def __init__(self, temperature=0.5):
        """Store the temperature that divides similarities before exponentiation."""
        super().__init__()
        self.temperature = temperature

    def forward(self, emb_cat):
        """
        emb_cat are the concatenated batches of pairs emb_cat = z_i + z_j

        Row k (k < B) and row k + B of emb_cat are treated as a positive
        pair, where B = emb_cat.shape[0] // 2.

        Returns:
            Scalar tensor: the mean over all 2B rows of
            -log(exp(pos / T) / sum(exp(sim / T) over the kept entries)).

        Raises:
            ValueError: if emb_cat has an odd number of rows and therefore
                cannot be split into two equally sized views.
        """
        n_rows = emb_cat.shape[0]
        if n_rows % 2 != 0:
            raise ValueError(
                "emb_cat must contain an even number of rows (two concatenated views), got %d" % n_rows)
        batch_size = n_rows // 2

        # Rows are L2-normalised, so their dot products are the cosine
        # similarities. The matmul avoids the (2B, 2B, D) intermediate that a
        # broadcast cosine_similarity call materialises.
        representations = nn.functional.normalize(emb_cat, dim=1)
        similarity_matrix = torch.matmul(representations, representations.T)

        # 1 everywhere except the diagonal (a row is never its own negative)
        negatives_mask = (~torch.eye(n_rows, n_rows, dtype=torch.bool, device=emb_cat.device)).float()

        # Keep the 10% least similar entries of every row. At least one entry
        # is kept so that small batches do not end up with an empty (zero)
        # denominator and an infinite loss.
        num_kept = min(n_rows, max(1, int(n_rows * 0.1)))
        least_similar = torch.topk(similarity_matrix, num_kept, dim=1, largest=False).indices
        filtered_mask = torch.zeros_like(negatives_mask).scatter_(1, least_similar, 1.0)

        # Adjust mask to include only top 10% least similar negatives
        final_negatives_mask = negatives_mask * filtered_mask

        # Off-diagonals at +/- batch_size hold the similarity of each row to its pair
        sim_ij = torch.diag(similarity_matrix, batch_size)
        sim_ji = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([sim_ij, sim_ji], dim=0)

        nominator = torch.exp(positives / self.temperature)
        denominator = final_negatives_mask * torch.exp(similarity_matrix / self.temperature)

        loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))
        loss = torch.sum(loss_partial) / (2*batch_size)

        return loss

In [None]:
## Wrap the training in a nicer function...
def run_training(num_iterations, log_dir, encoder, decoder, project, ntx_temp, dataloader, optimizer, batch_size, scheduler=None, prof=None):
    """Train encoder, decoder and projection head jointly.

    Every batch from `dataloader` is turned into an ME.SparseTensor, encoded,
    then both decoded and projected. The optimised quantity is
    `rec_loss + 0.05 * lat_loss`, with `rec_loss` from AsymmetricL2LossME and
    `lat_loss` from NTXentMergedTopTenNeg applied to the projected features.

    Args:
        num_iterations: number of passes over `dataloader`.
        log_dir: TensorBoard log directory; nothing is logged when falsy.
        encoder, decoder, project: models to train; moved to `device` here.
        ntx_temp: temperature passed to NTXentMergedTopTenNeg.
        dataloader: iterable yielding (coords, feats) pairs.
        optimizer: optimizer stepped once per batch.
        batch_size: forwarded to AsymmetricL2LossME.
        scheduler: optional LR scheduler, stepped once per pass.
        prof: currently unused (kept for interface compatibility).

    Raises:
        RuntimeError: if `dataloader` yields no batches during a pass.
    """
    print("Training with", num_iterations, "iterations")
    ## NOTE: process_time() is CPU time of this process, not wall-clock time
    tstart = time.process_time()

    writer = SummaryWriter(log_dir=log_dir) if log_dir else None

    reco_loss_fn = AsymmetricL2LossME(10, 1, batch_size)
    latent_loss_fn = NTXentMergedTopTenNeg(ntx_temp)

    encoder.to(device, non_blocking=True)
    decoder.to(device, non_blocking=True)
    project.to(device)

    try:
        ## Loop over the desired iterations
        for iteration in range(num_iterations):

            tot_loss_sum = 0
            rec_loss_sum = 0
            lat_loss_sum = 0
            nbatches     = 0

            # Set train mode for the encoder, the decoder and the projection head
            encoder.train()
            decoder.train()
            project.train()

            # Iterate over the loader that was passed in (not a notebook global)
            for cat_bcoords, cat_bfeats in dataloader:

                ## Send to the device, then make the sparse tensors
                cat_bcoords = cat_bcoords.to(device, non_blocking=True)
                cat_bfeats  = cat_bfeats.to(device)
                cat_batch   = ME.SparseTensor(cat_bfeats, cat_bcoords, device=device)

                ## Now do the forward passes
                encoded_batch = encoder(cat_batch)
                decoded_batch = decoder(encoded_batch)
                project_batch = project(encoded_batch)

                # Evaluate losses; the latent term is weighted by 0.05
                rec_loss = reco_loss_fn(decoded_batch, cat_batch)
                lat_loss = latent_loss_fn(project_batch.F)
                tot_loss = rec_loss + 0.05*lat_loss

                # Backward pass
                optimizer.zero_grad()
                tot_loss.backward()
                optimizer.step()

                ## keep track of losses
                tot_loss_sum += tot_loss.item()
                rec_loss_sum += rec_loss.item()
                lat_loss_sum += lat_loss.item()
                nbatches += 1

                torch.cuda.empty_cache()

            if nbatches == 0:
                raise RuntimeError("dataloader yielded no batches; cannot compute average losses")

            ## See if we have an LR scheduler...
            if scheduler: scheduler.step()

            av_tot_loss = tot_loss_sum/nbatches
            av_rec_loss = rec_loss_sum/nbatches
            av_lat_loss = lat_loss_sum/nbatches

            if writer is not None: writer.add_scalar('loss/train', av_tot_loss, iteration)
            print("Processed", iteration, "/", num_iterations, "; loss =", av_tot_loss,\
                      "(", av_rec_loss, "+", av_lat_loss, ")")
            print("Time taken:", time.process_time() - tstart)

            ## For profiling, it can be helpful to add a break here
            #break

            ## End so empty cache because MinkowskiEngine can't be trusted
            torch.cuda.empty_cache()
    finally:
        ## Flush pending events and release the log file, also on error/interrupt
        if writer is not None: writer.close()

In [None]:
## Check that the trained autoencoder actually does something
## NOTE(review): in the visible notebook `encoder` and `decoder` are first created in a
## later cell, so on a fresh kernel this cell only works after that cell has been run.
## Consider moving this cell below the model definition / training cell.
plot_ae_outputs(encoder,decoder)

In [None]:
## Sync debug mode 0 -- presumably the default (no warning/error on implicit syncs); confirm against the PyTorch docs
torch.cuda.set_sync_debug_mode(0)

## Network shape
nchan, nlatent, final_layer = 32, 128, 128
dropout = 0
hidden_act_fn = ME.MinkowskiSiLU
latent_act_fn = ME.MinkowskiTanh

## Training settings
num_iterations = 50
log_dir = "log"
temp = 0.5
lr = 1e-5
weight_decay = 0

## Build the three networks (encoder first, then decoder, then projection head)
encoder = EncoderME(nchan, nlatent, hidden_act_fn, latent_act_fn, dropout)
decoder = DecoderME(nchan, nlatent, hidden_act_fn)
project = ProjectionHead([nlatent, nlatent, nlatent, final_layer], latent_act_fn)

## Restore previously saved weights; the file name encodes nlatent, nchan and temp
chk_file = f"/pscratch/sd/c/cwilk/state_lat{nlatent}_nchan{nchan}_5e-6_PROJECT_TEMP{temp}_onecycle_smallmodFIXCROP_2M_ME.pth"
checkpoint = torch.load(chk_file, map_location=device)

encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])
project.load_state_dict(checkpoint['project_state_dict'])

encoder.to(device)
decoder.to(device)
project.to(device)

## One parameter group per network, all sharing the same optimizer settings
params_to_optimize = [{'params': net.parameters()} for net in (encoder, decoder, project)]
optimizer = torch.optim.AdamW(params_to_optimize, lr=lr, weight_decay=weight_decay)

## No learning-rate scheduler for this run
scheduler = None

run_training(num_iterations=num_iterations,
             log_dir=log_dir,
             encoder=encoder,
             decoder=decoder,
             project=project,
             ntx_temp=temp,
             dataloader=train_loader,
             optimizer=optimizer,
             batch_size=batch_size,
             scheduler=scheduler)

In [None]:
import torch