In [None]:
from torch.utils.tensorboard import SummaryWriter
import MinkowskiEngine as ME
import matplotlib.pyplot as plt
import matplotlib as mpl
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F

%matplotlib inline
## Default figure styling for every plot drawn in this notebook
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

import torch
import numpy as np

## Use the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Seed numpy and torch so repeated runs start from the same random state.
## Neither call is the last statement of the cell, so nothing is echoed and
## there is no need to bind the return values to `_` (which IPython itself
## uses for the most recent cell output).
SEED = 12345
np.random.seed(SEED)
torch.manual_seed(SEED)

## Notebook-level TensorBoard writer, pointed at the "log" directory
writer = SummaryWriter("log")

In [None]:
## Includes from my libraries for this project                                                                                                                                           
from ME_NN_libs import AsymmetricL2LossME, EuclideanDistLoss
from ME_NN_libs import EncoderME, DecoderME, DeepEncoderME, DeepDecoderME, DeeperEncoderME, DeeperDecoderME
from ME_dataset_libs import CenterCrop, RandomCrop, RandomHorizontalFlip, RandomRotation2D, RandomShear2D, RandomBlockZero, RandomJitterCharge, RandomScaleCharge, RandomElasticDistortion2D, RandomGridDistortion2D
from ME_dataset_libs import SingleModuleImage2D_solo_ME, solo_ME_collate_fn
from ME_dataset_libs import make_dense, make_dense_from_tensor

In [None]:
import time

## Augmentation pipeline; Compose runs the transforms in the order listed
aug_transform = transforms.Compose([
    RandomGridDistortion2D(5,5),
    RandomShear2D(0.1, 0.1),
    RandomHorizontalFlip(),
    RandomRotation2D(-10,10),
    RandomBlockZero(5, 6),
    RandomScaleCharge(0.02),
    RandomJitterCharge(0.02),
    RandomCrop()
])

## How many events shall we use here?
nevents = 200000
inDir = "/pscratch/sd/c/cwilk/h5_inputs/"

## Time the dataset construction with a wall clock. time.process_time() only
## counts CPU time of this process, so time spent blocked on file reads would
## be missing from the reported figure.
start = time.perf_counter()
train_dataset = SingleModuleImage2D_solo_ME(inDir, transform=aug_transform, max_events=nevents)
print("Time taken to load", len(train_dataset), "images:", time.perf_counter() - start)

## Randomly chosen batching
## drop_last=True discards a final partial batch, so every batch has batch_size images
batch_size=512
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           collate_fn=solo_ME_collate_fn,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           drop_last=True,
                                           num_workers=16)

In [None]:
## Dump out some of the input and reconstructed images to see how the autoencoder is getting on
def plot_ae_outputs(encoder,decoder,n=10,dataloader=None):
    """Draw up to `n` input images above their autoencoder reconstructions.

    One batch is taken from the loader per column and only the first image of
    that batch is drawn. If the loader runs out before `n` batches have been
    read, the remaining columns are simply left empty.

    Args:
        encoder: model called on an ME.SparseTensor built from each batch.
        decoder: model called on the encoder output.
        n: number of image columns to draw.
        dataloader: iterable yielding (coords, feats) batches. Defaults to the
            notebook-level `train_loader` when omitted.

    Side effects:
        Puts both models in eval mode (they are not switched back here) and
        shows a matplotlib figure.
    """
    if dataloader is None: dataloader = train_loader

    plt.figure(figsize=(12,5))

    encoder.eval()
    decoder.eval()

    ## Walk ONE iterator over the loader. Building a new iterator per column
    ## (next(iter(loader))) restarts every DataLoader worker each time just to
    ## fetch a single batch. range(n) comes first in the zip so no batch is
    ## requested beyond the n that are drawn.
    for i, (bcoords, bfeats) in zip(range(n), dataloader):

        ## Inference only, so skip gradient tracking
        with torch.no_grad():

            bcoords = bcoords.to(device)
            bfeats = bfeats.to(device)
            orig = ME.SparseTensor(bfeats.float(), bcoords.int(), device=device)

            enc_orig  = encoder(orig)
            rec_orig  = decoder(enc_orig)

        inputs  = make_dense_from_tensor(orig)
        outputs = make_dense_from_tensor(rec_orig)

        ## Only the first image of each batch is shown
        this_input = inputs[0].cpu().squeeze().numpy()
        this_output = outputs[0].cpu().squeeze().numpy()

        ## Input row
        ## (the subplot grid has three rows; only the first two are filled)
        ax = plt.subplot(3,n,i+1)
        ax.imshow(this_input, cmap='viridis', origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2: ax.set_title('Original images')

        ## Reconstructed row
        ax = plt.subplot(3, n, i + 1 + n)
        ax.imshow(this_output, cmap='viridis', origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n//2: ax.set_title('Reconstructed images')

    plt.show()

In [None]:
## Wrap the training in a nicer function...
def run_training(num_iterations, log_dir, encoder, decoder, dataloader, optimizer, batch_size, scheduler=None, prof=None):
    """Train the encoder/decoder pair on a reconstruction loss.

    Args:
        num_iterations: number of full passes over `dataloader`.
        log_dir: TensorBoard output directory; a falsy value disables scalar logging.
        encoder: encoder model; moved to the notebook-level `device`.
        decoder: decoder model; moved to the notebook-level `device`.
        dataloader: iterable yielding (coords, feats) batches to train on.
        optimizer: optimizer whose zero_grad()/step() are called for every batch.
        batch_size: forwarded to AsymmetricL2LossME.
        scheduler: optional LR scheduler, stepped once per iteration.
        prof: optional profiler, stepped once per iteration.

    Raises:
        ZeroDivisionError: if `dataloader` yields no batches in an iteration.
    """

    print("Training with", num_iterations, "iterations")
    ## Wall-clock timer: process_time() would only count CPU time of this process
    tstart = time.perf_counter()

    writer = SummaryWriter(log_dir=log_dir) if log_dir else None

    reco_loss_fn = AsymmetricL2LossME(10, 1, batch_size)

    encoder.to(device, non_blocking=True)
    decoder.to(device)

    try:
        ## Loop over the desired iterations
        for iteration in range(num_iterations):

            total_loss = 0
            nbatches   = 0

            # Set train mode for both the encoder and the decoder
            # (plot_ae_outputs below leaves them in eval mode)
            encoder.train()
            decoder.train()

            # Iterate over batches from the loader that was passed in,
            # not the notebook-level train_loader
            for orig_bcoords, orig_bfeats in dataloader:

                ## Send to the device, then make the sparse tensors
                orig_bcoords = orig_bcoords.to(device, non_blocking=True)
                orig_bfeats = orig_bfeats.to(device)
                orig_batch = ME.SparseTensor(orig_bfeats, orig_bcoords, device=device)

                ## Now do the forward passes
                encoded_orig   = encoder(orig_batch)
                decoded_orig   = decoder(encoded_orig)

                # Evaluate loss
                loss = reco_loss_fn(decoded_orig, orig_batch)

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                nbatches += 1

                torch.cuda.empty_cache()

            ## Same exception type as the bare division would give, but with a
            ## message that says what actually went wrong
            if nbatches == 0:
                raise ZeroDivisionError("dataloader produced no batches, so there is no loss to average")

            ## See if we have an LR scheduler...
            if scheduler: scheduler.step()

            av_loss = total_loss/nbatches

            if writer is not None: writer.add_scalar('loss/train', av_loss, iteration)
            print("Processed", iteration, "/", num_iterations, "; loss =", av_loss)
            print("Time taken:", time.perf_counter() - tstart)

            ## For profiling, it can be helpful to add a break here
            #break

            if prof: prof.step()

            ## dump some images so we can see how the training is going
            if iteration%10 == 0: plot_ae_outputs(encoder,decoder,10)

            ## End so empty cache because MinkowskiEngine can't be trusted
            torch.cuda.empty_cache()
    finally:
        ## Flush and release the event file even if training is interrupted
        if writer is not None: writer.close()

In [None]:
## Visual sanity check: draw some inputs next to their reconstructions.
## NOTE(review): `encoder` and `decoder` are first assigned in the model-setup
## cell further down, so on a fresh kernel that cell has to run before this one.
plot_ae_outputs(encoder=encoder, decoder=decoder)

In [None]:
## Mode 0 is the default sync-debug setting (no warning or error on implicit
## synchronisation). Only call into the CUDA API when a GPU is present, in
## line with the CPU fallback used when `device` is chosen.
if torch.cuda.is_available(): torch.cuda.set_sync_debug_mode(0)

## Model and training hyperparameters
nchan=32
nlatent=96
hidden_act_fn=ME.MinkowskiSiLU
latent_act_fn=ME.MinkowskiTanh
dropout=0
num_iterations=50
log_dir="log"

## Define the models
encoder=EncoderME(nchan, nlatent, hidden_act_fn, latent_act_fn, dropout)
decoder=DecoderME(nchan, nlatent, hidden_act_fn)

## Load in the pre-calculated model weights
chk_file=f"/pscratch/sd/c/cwilk/state_lat{nlatent}_nchan{nchan}_5e-6_archsimple_SOLO_actfns_silu_tanh_2M_onecycle_ME.pth"

## NOTE(review): torch.load is pickle-based, so only load checkpoint files you trust
checkpoint = torch.load(chk_file, map_location=device)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])
## run_training switches both models back to train mode at the start of every iteration
encoder.eval()
decoder.eval()

## Move the models to the device before the optimizer is built from their parameters
encoder.to(device)
decoder.to(device)

## One parameter group per model
params_to_optimize = [
        {'params': encoder.parameters()},
        {'params': decoder.parameters()}
    ]

lr=1e-6
weight_decay=0
optimizer = torch.optim.AdamW(params_to_optimize, lr=lr, weight_decay=weight_decay)

## Scheduler options
scheduler = None 

run_training(num_iterations, log_dir, encoder, decoder, train_loader, optimizer, batch_size, scheduler)