In [None]:
import random

from torch.utils.tensorboard import SummaryWriter
import MinkowskiEngine as ME
from torch import nn
import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline
## Default figure styling for every plot in the notebook
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

## Tell pytorch we have a GPU if we do
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Seed every RNG the notebook draws from: the augmentations below use the
## stdlib `random` module as well as numpy and torch
SEED=12345
random.seed(SEED)
_=np.random.seed(SEED)
_=torch.manual_seed(SEED)
writer = SummaryWriter("log")

In [None]:
from torch.utils.data import Dataset
import os
import time
import h5py
from glob import glob
from bisect import bisect

class SingleModuleImage2D_MultiHDF5_ME(Dataset):
    """Dataset of sparse 2D images spread over every ``*.h5`` file in a directory.

    Each file stores its event count in the ``N`` attribute and one group per
    event (named by the event's index within the file) holding the ``data``,
    ``row`` and ``col`` datasets.

    Inputs:
        - infile_dir : Directory that is globbed for ``*.h5`` files
        - nom_transform : Callable ``(coords, feats) -> (coords, feats)`` applied to the nominal image
        - aug_transform : Optional callable with the same signature, used to make the two augmented views
        - max_events : Optional cap on the number of events exposed by the dataset
    """

    def __init__(self, infile_dir, nom_transform, aug_transform=None, max_events=None):
        self.hdf5_files = sorted(glob(os.path.join(infile_dir, '*.h5')))
        self.file_indices = []
        self.nom_transform = nom_transform
        self.aug_transform = aug_transform
        self.max_events = max_events

        ## Sort out the file map
        self.create_file_indices()

        ## Apply some limitation to the size
        if self.max_events and self.max_events < self.length:
            self.length = self.max_events

    def create_file_indices(self):
        """Record the global index of the first event in each file, plus the total."""
        ## Start from scratch so that calling this twice cannot duplicate entries
        self.file_indices = []
        cumulative_size = 0

        for file in self.hdf5_files:
            self.file_indices.append(cumulative_size)
            ## The context manager closes the handle even if the attribute is missing
            with h5py.File(file, 'r', libver='latest') as f:
                cumulative_size += int(f.attrs['N'])
        self.file_indices.append(cumulative_size)
        self.length = cumulative_size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(aug1_coords, aug1_feats, aug2_coords, aug2_feats, raw_coords, raw_feats)``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``
        """
        ## Support negative indices and refuse anything outside the (possibly capped) range
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError("index out of range for dataset of length %d" % self.length)

        ## Find which file holds this event, and its index within that file
        file_index = bisect(self.file_indices, idx)-1
        this_idx = idx - self.file_indices[file_index]

        ## The [:] reads copy everything into memory, so the file can be closed straight away
        with h5py.File(self.hdf5_files[file_index], 'r') as f:
            group = f[str(this_idx)]
            data = group['data'][:]
            row = group['row'][:]
            col = group['col'][:]

        ## Use the format that ME requires
        ## Note that we can't build the sparse tensor here because ME uses some sort of global indexing
        ## And this function is replicated * num_workers
        raw_coords = np.vstack((row, col)).T
        raw_feats = data.reshape(-1, 1)  # Reshape data to be of shape (N, 1)

        ## Apply transforms to augment the data
        if not self.aug_transform:
            ## No augmentation: all three views are the nominal image
            raw_coords, raw_feats = self.nom_transform(raw_coords, raw_feats)
            aug1_coords,aug1_feats = raw_coords,raw_feats
            aug2_coords,aug2_feats = raw_coords,raw_feats
        else:
            aug1_coords, aug1_feats = self.aug_transform(raw_coords, raw_feats)
            aug2_coords, aug2_feats = self.aug_transform(raw_coords, raw_feats)

            ## Make sure the images aren't empty...
            ## NOTE(review): this retries forever if the stored image itself is empty
            while aug1_feats.size == 0: aug1_coords, aug1_feats = self.aug_transform(raw_coords, raw_feats)
            while aug2_feats.size == 0: aug2_coords, aug2_feats = self.aug_transform(raw_coords, raw_feats)

            raw_coords, raw_feats   = self.nom_transform(raw_coords, raw_feats)

        return aug1_coords, aug1_feats, aug2_coords, aug2_feats, raw_coords, raw_feats
    
def triple_ME_collate_fn(batch):
    """Collate per-event (coords, feats) triples into batched ME inputs.

    Each element of ``batch`` is the 6-tuple returned by the dataset:
    two augmented views followed by the nominal view. Returns the batched
    coordinates and float feature tensors in that same order.
    """
    aug1_coords, aug1_feats, aug2_coords, aug2_feats, raw_coords, raw_feats = zip(*batch)

    def _batched_feats(per_event_feats):
        ## Stack every event's feature rows into one float tensor
        return torch.from_numpy(np.concatenate(per_event_feats, 0)).float()

    ## batched_coordinates builds the coordinate tensor the SparseTensor constructor expects
    return (
        ME.utils.batched_coordinates(aug1_coords),
        _batched_feats(aug1_feats),
        ME.utils.batched_coordinates(aug2_coords),
        _batched_feats(aug2_feats),
        ME.utils.batched_coordinates(raw_coords),
        _batched_feats(raw_feats),
    )


In [None]:
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F
import random

## This is a transformation for the nominal image
class CenterCrop:
    """Crop sparse (row, col) coordinates to the central window of the image.

    Inputs:
        - orig_y, orig_x : Size of the uncropped image
        - new_y, new_x : Size of the cropped window
    """

    def __init__(self, orig_y=280, orig_x=140, new_y=256, new_x=128):
        self.orig_y = orig_y
        self.orig_x = orig_x
        self.new_y = new_y
        self.new_x = new_x
        ## True division is kept, so the returned coordinates are floats
        self.pad_y = (self.orig_y - self.new_y)/2
        self.pad_x = (self.orig_x - self.new_x)/2

    def __call__(self, coords, feats):
        """Shift coords so the window starts at (0, 0) and drop everything outside it."""
        coords = coords - np.array([self.pad_y, self.pad_x])

        ## Index 0 is a valid pixel of the cropped window, so the lower bounds are inclusive
        mask = (coords[:,0] >= 0) & (coords[:,0] < self.new_y) \
             & (coords[:,1] >= 0) & (coords[:,1] < self.new_x)

        return coords[mask], feats[mask]

    
## This just takes a 256x128 subimage from the original 280x140 block
class RandomCrop:
    """Crop sparse (row, col) coordinates to a randomly placed window.

    Inputs:
        - orig_y, orig_x : Size of the uncropped image
        - new_y, new_x : Size of the cropped window
    """

    def __init__(self, orig_y=280, orig_x=140, new_y=256, new_x=128):
        self.orig_y = orig_y
        self.orig_x = orig_x
        self.new_y = new_y
        self.new_x = new_x

    def __call__(self, coords, feats):
        """Pick a random window offset, shift coords to it and drop everything outside."""
        ## Need to copy the array
        new_coords = coords.copy()
        new_feats = feats.copy()

        shift_y = random.randint(0, self.orig_y - self.new_y)
        shift_x = random.randint(0, self.orig_x - self.new_x)

        ## Column 0 is the row (y) and column 1 is the column (x), as in CenterCrop
        new_coords = new_coords - np.array([shift_y, shift_x])

        ## Index 0 is a valid pixel of the cropped window, so the lower bounds are inclusive
        mask = (new_coords[:,0] >= 0) & (new_coords[:,0] < self.new_y) \
             & (new_coords[:,1] >= 0) & (new_coords[:,1] < self.new_x)

        return new_coords[mask], new_feats[mask]
    
    
class RandomHorizontalFlip:
    """Mirror the column coordinate about a 128-column image with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p
        self.ncols = 128

    def __call__(self, coords, feats):
        """Return a (possibly mirrored) copy of coords; feats are passed through untouched."""
        ## Work on a copy so the caller's array is never modified
        flipped = coords.copy()

        do_flip = torch.rand(1) < self.p
        if do_flip:
            last_col = self.ncols - 1
            flipped[:,1] = last_col - flipped[:,1]

        return flipped, feats
    
    
## Need to define some fairly standard functions that work for ME tensors
class RandomRotation2D:
    """Rotate 2D coordinates about the origin by a random angle.

    Inputs:
        - min_angle, max_angle : Bounds, in degrees, of the uniformly drawn angle
    """

    def __init__(self, min_angle, max_angle):
        self.min_angle = min_angle
        self.max_angle = max_angle

    def _M(self, theta):
        """Build the 2x2 rotation matrix for an angle theta given in radians."""
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)
        return np.array([
            [cos_t, -sin_t],
            [sin_t,  cos_t]
        ])

    def __call__(self, coords, feats):
        """Return the rotated coordinates together with the unchanged features."""
        ## Draw the angle in degrees from torch's RNG, then convert to radians
        degrees = torch.FloatTensor(1).uniform_(self.min_angle, self.max_angle).item()
        theta = np.deg2rad(degrees)

        ## Row vectors on the left, rotation matrix on the right
        return coords @ self._M(theta), feats
    
class RandomShear2D:
    """Shear 2D coordinates by random amounts along each axis.

    Inputs:
        - max_shear_x, max_shear_y : Each shear factor is drawn uniformly from [-max, max]
    """

    def __init__(self, max_shear_x, max_shear_y):
        self.max_shear_x = max_shear_x
        self.max_shear_y = max_shear_y

    def __call__(self, coords, feats):
        """Return the sheared coordinates together with the unchanged features."""
        ## Two draws from numpy's global RNG: x first, then y
        kx = np.random.uniform(-self.max_shear_x, self.max_shear_x)
        ky = np.random.uniform(-self.max_shear_y, self.max_shear_y)

        ## Identity plus the off-diagonal shear terms
        shear = np.array([[1, kx], [ky, 1]])

        return coords @ shear, feats
    
    
## A function to randomly remove some number of square blocks of side block_size
## This has to be called before the cropping as it uses the original image size
class RandomBlockZero:
    """Drop the hits that fall inside a random number of randomly placed square blocks.

    Inputs:
        - max_blocks : Upper bound (inclusive) on the number of blocks removed
        - block_size : Side length used for every block
    """

    def __init__(self, max_blocks=4, block_size=6):
        self.max_blocks = max_blocks
        self.block_size = block_size
        self.xmax = 140
        self.ymax = 280

    def __call__(self,  coords, feats):
        """Return copies of coords and feats without the hits inside the removed blocks."""
        ## One flag per hit; True means the hit survives
        keep = np.ones(feats.size, dtype=bool)
        rows, cols = coords[:,0], coords[:,1]

        n_blocks = random.randint(0, self.max_blocks)
        for _ in range(n_blocks):
            size = self.block_size
            ## Draw order matters for reproducibility: x first, then y
            x0 = random.randint(0, self.xmax - size - 1)
            y0 = random.randint(0, self.ymax - size - 1)

            ## Strict inequalities: only the block interior is removed
            inside = (rows > y0) & (rows < (y0+size)) \
                   & (cols > x0) & (cols < (x0+size))
            keep = np.logical_and(keep, ~inside)

        ## Hand back fresh arrays rather than views of the input
        return coords.copy()[keep], feats.copy()[keep]

    
## Augmentation pipeline applied independently to make each of the two augmented views
## NOTE(review): RandomHorizontalFlip mirrors about a 128-column width but runs here before
## RandomCrop, while the uncropped block is 140 columns wide - confirm this ordering is intended
aug_transform = transforms.Compose([
    RandomShear2D(0.1, 0.1),
    RandomHorizontalFlip(),
    RandomRotation2D(-10,10),
    RandomBlockZero(5, 6),
    RandomCrop()
])

## Get a concrete dataset and data loader
# inFile = "/global/cfs/cdirs/dune/users/cwilk/single_module_images/sparse_joblib_fixdupes_pluscuts_noneg_transform/training_images_200k.joblib"
inDir = "/pscratch/sd/c/cwilk/h5_inputs/"
## process_time() counts CPU time only, so time spent waiting on disk is not included
start = time.process_time() 
train_dataset = SingleModuleImage2D_MultiHDF5_ME(inDir, nom_transform=CenterCrop(), aug_transform=aug_transform, max_events=100000)
print("Time taken to load", len(train_dataset),"images:", time.process_time() - start)

In [None]:
## Randomly chosen batching
## Shuffled loader over train_dataset. triple_ME_collate_fn turns the per-event
## arrays into batched coordinates/features; drop_last=True discards a final
## incomplete batch, so every batch has exactly batch_size events.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           collate_fn=triple_ME_collate_fn,
                                           batch_size=512,
                                           shuffle=True, 
                                           num_workers=16,
                                           drop_last=True,
                                           pin_memory=True,
                                           prefetch_factor=2)

In [None]:
## Utility function to make a dense image
def make_dense(coords_batch, feats_batch, device, max_i=256, max_j=128):
    """Densify the first image of a batch of sparse coordinates/features.

    Inputs:
        - coords_batch : Batched coordinates as produced by the collate function
        - feats_batch : Matching feature tensor
        - device : Device on which the sparse tensor is built
        - max_i, max_j : Height and width of the dense image

    Returns the first image of the batch as a numpy array.
    """
    img = ME.SparseTensor(feats_batch.float(), coords_batch.int(), device=device)
    coords, _ = img.decomposed_coordinates_and_features
    batch_size = len(coords)
    ## Pin the dense origin to (0, 0), as make_dense_from_tensor does, so that pixels
    ## keep their absolute positions instead of depending on the occupied coordinates
    img_dense,_,_ = img.dense(torch.Size([batch_size, 1, max_i, max_j]), min_coordinate=torch.IntTensor([0,0]))
    ## .cpu() makes the numpy conversion valid when device is a GPU
    return img_dense[0].squeeze().cpu().numpy()

def make_dense_from_tensor(sparse_batch, max_i=256, max_j=128):
    """Densify an existing sparse batch and return its first image (still a tensor).

    The dense grid is max_i x max_j with a single channel and its origin fixed at (0, 0).
    """
    per_image_coords, _ = sparse_batch.decomposed_coordinates_and_features
    dense_shape = torch.Size([len(per_image_coords), 1, max_i, max_j])
    origin = torch.IntTensor([0,0])

    ## dense() returns three values; only the dense tensor itself is needed
    dense_batch, _, _ = sparse_batch.dense(dense_shape, min_coordinate=origin)
    return dense_batch[0]
    

In [None]:
## Visualise data
# Access a specific instance
## Pull one batch from a fresh iterator over the training loader
aug1_bcoords, aug1_bfeats, aug2_bcoords, aug2_bfeats, orig_bcoords, orig_bfeats = next(iter(train_loader))

# Visualize the image
## Three panels side by side, each showing the first event of the batch:
## first augmentation, second augmentation, nominal image
plt.figure(figsize=(10,3))
ax = plt.subplot(1,3,1)

gr1 = plt.imshow(make_dense(aug1_bcoords, aug1_bfeats, 'cpu'), origin='lower')
plt.colorbar(gr1)
ax = plt.subplot(1,3,2)
gr2 = plt.imshow(make_dense(aug2_bcoords, aug2_bfeats, 'cpu'), origin='lower')
plt.colorbar(gr2)
ax = plt.subplot(1,3,3)
gr3 = plt.imshow(make_dense(orig_bcoords, orig_bfeats, 'cpu'), origin='lower')
plt.colorbar(gr3)


In [None]:
## This is now modified to work with/produce 256x128 tensors
class EncoderME(nn.Module):
    
    def __init__(self, 
                 nchan : int,
                 latent_dim : int,
                 act_fn : object = ME.MinkowskiReLU,
                 drop_fract : float = 0.2):
        """
        Sparse convolutional encoder for 256x128 images.

        Inputs:
            - n_chan : Number of channels we use in the first convolutional layers. Deeper layers use nchan*2, nchan*4, ... nchan*128.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the encoder network
            - drop_fract : Dropout fraction used after every activation
        """
        super().__init__()
        
        ## Channel widths double at every level: nchan, 2*nchan, ..., 128*nchan
        self.ch = [nchan * 2**level for level in range(8)]
        
        def conv_stage(n_in, n_out, **conv_kwargs):
            ## One convolution followed by batch norm, activation and dropout
            return [
                ME.MinkowskiConvolution(in_channels=n_in, out_channels=n_out, bias=False, dimension=2, **conv_kwargs),
                ME.MinkowskiBatchNorm(n_out),
                act_fn(),
                ME.MinkowskiDropout(drop_fract),
            ]
        
        ### Convolutional section
        ## Seven levels, each a stride-2 downsampling stage followed by a size-preserving
        ## 3x3 stage: 256x128 ==> 128x64 ==> 64x32 ==> 32x16 ==> 16x8 ==> 8x4 ==> 4x2 ==> 2x1
        cnn_layers = []
        n_in = 1
        for n_out in self.ch[:7]:
            cnn_layers += conv_stage(n_in, n_out, kernel_size=2, stride=2)
            cnn_layers += conv_stage(n_out, n_out, kernel_size=3)
            n_in = n_out
        ## Final anisotropic stage: 2x1 ==> 1x1
        cnn_layers += conv_stage(self.ch[6], self.ch[7], kernel_size=(2,1), stride=(2,1))
        self.encoder_cnn = nn.Sequential(*cnn_layers)
        
        ### Linear section, simple for now
        self.encoder_lin = nn.Sequential(
            ME.MinkowskiLinear(self.ch[7], self.ch[4]),
            act_fn(),
            ME.MinkowskiDropout(drop_fract),
            ME.MinkowskiLinear(self.ch[4], self.ch[3]),
            act_fn(),
            ME.MinkowskiDropout(drop_fract),
            ME.MinkowskiLinear(self.ch[3], latent_dim),
        )
        
        # Initialize weights (Kaiming-normal for kernels, constants for batch norm)
        self.initialize_weights()

    def initialize_weights(self):
        """Kaiming-normal (fan_out, relu) for conv/linear weights; batch norm set to weight 1, bias 0."""
        conv_types = (ME.MinkowskiConvolution, ME.MinkowskiGenerativeConvolutionTranspose)
        for module in self.modules():
            if isinstance(module, conv_types):
                ME.utils.kaiming_normal_(module.kernel, mode="fan_out", nonlinearity="relu")
            elif isinstance(module, ME.MinkowskiLinear):
                ME.utils.kaiming_normal_(module.linear.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, ME.MinkowskiBatchNorm):
                nn.init.constant_(module.bn.weight, 1)
                nn.init.constant_(module.bn.bias, 0)
        
    def forward(self, x):
        """Run the convolutional section, then the linear section."""
        return self.encoder_lin(self.encoder_cnn(x))
    
class DecoderME(nn.Module):
    
    def __init__(self, 
                 nchan : int,
                 latent_dim : int,
                 act_fn : object):
        """
        Sparse generative decoder that mirrors the encoder, producing 256x128 images.

        Inputs:
            - n_chan : Number of channels we use in the last convolutional layers. Earlier layers use nchan*2, nchan*4, ... nchan*128.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the decoder network
        """
        super().__init__()
        
        ## Channel widths double at every level: nchan, 2*nchan, ..., 128*nchan
        self.ch = [nchan * 2**level for level in range(8)]
        
        self.decoder_lin = nn.Sequential(
            ME.MinkowskiLinear(latent_dim, self.ch[3]),
            act_fn(),
            ME.MinkowskiLinear(self.ch[3], self.ch[4]),
            act_fn(),
            ME.MinkowskiLinear(self.ch[4], self.ch[7]),
            act_fn()
        )

        def up_stage(n_in, n_out, kernel):
            ## Generative transpose convolution that upsamples, then a size-preserving 3x3
            ## convolution; each is followed by batch norm and the activation
            return [
                ME.MinkowskiGenerativeConvolutionTranspose(in_channels=n_in, out_channels=n_out, kernel_size=kernel, stride=kernel, bias=False, dimension=2),
                ME.MinkowskiBatchNorm(n_out),
                act_fn(),
                ME.MinkowskiConvolution(in_channels=n_out, out_channels=n_out, kernel_size=3, stride=1, bias=False, dimension=2),
                ME.MinkowskiBatchNorm(n_out),
                act_fn(),
            ]

        ## 1x1 ==> 2x1 ==> 4x2 ==> 8x4 ==> 16x8 ==> 32x16 ==> 64x32 ==> 128x64
        conv_layers = []
        for level in range(6, -1, -1):
            ## Only the very first stage is anisotropic (1x1 ==> 2x1)
            kernel = (2,1) if level == 6 else 2
            conv_layers += up_stage(self.ch[level+1], self.ch[level], kernel)

        ## Output stage back to a single channel, with a bias: 128x64 ==> 256x128
        conv_layers += [
            ME.MinkowskiGenerativeConvolutionTranspose(in_channels=self.ch[0], out_channels=1, kernel_size=2, stride=2, bias=True, dimension=2),
            act_fn(),
        ]
        self.decoder_conv = nn.Sequential(*conv_layers)
        
        # Initialize weights (Kaiming-normal for kernels, constants for batch norm)
        self.initialize_weights()
        
    def initialize_weights(self):
        """Kaiming-normal (fan_out, relu) for conv/linear weights; batch norm set to weight 1, bias 0."""
        conv_types = (ME.MinkowskiConvolution, ME.MinkowskiGenerativeConvolutionTranspose)
        for module in self.modules():
            if isinstance(module, conv_types):
                ME.utils.kaiming_normal_(module.kernel, mode="fan_out", nonlinearity="relu")
            elif isinstance(module, ME.MinkowskiLinear):
                ME.utils.kaiming_normal_(module.linear.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, ME.MinkowskiBatchNorm):
                nn.init.constant_(module.bn.weight, 1)
                nn.init.constant_(module.bn.bias, 0)

    def forward(self, x, target_key=None):
        """Run the linear section, then the generative convolutions. target_key is accepted but unused."""
        return self.decoder_conv(self.decoder_lin(x))

In [None]:
## Dump out some of the input and reconstructed images to see how the autoencoder is getting on
def plot_ae_outputs(encoder,decoder,n=10,dataloader=None):
    """Plot n nominal images above their autoencoder reconstructions.

    Inputs:
        - encoder, decoder : Networks to evaluate (both are switched to eval mode)
        - n : Number of image columns; one batch is drawn per column and its first event is shown
        - dataloader : Loader yielding the 6-tuples made by the collate function;
                       defaults to the notebook-level train_loader
    """
    loader = train_loader if dataloader is None else dataloader
    
    plt.figure(figsize=(12,5))
    
    encoder.eval()
    decoder.eval()
    
    ## Build the iterator once: calling iter() on the loader for every panel
    ## would start a fresh set of worker processes each time
    batches = iter(loader)
    
    ## Loop over figures
    for i in range(n):
        ax = plt.subplot(3,n,i+1)

        try:
            batch = next(batches)
        except StopIteration:
            ## Ran out of batches, so start another pass over the loader
            batches = iter(loader)
            batch = next(batches)
        
        ## Only the nominal (un-augmented) view is plotted
        orig_bcoords, orig_bfeats = batch[4], batch[5]
        
        with torch.no_grad():
            
            orig_bcoords = orig_bcoords.to(device)
            orig_bfeats = orig_bfeats.to(device)
            orig = ME.SparseTensor(orig_bfeats.float(), orig_bcoords.int(), device=device)

            enc_orig  = encoder(orig)
            rec_orig  = decoder(enc_orig)
            
        inputs  = make_dense_from_tensor(orig)
        outputs = make_dense_from_tensor(rec_orig)
        
        this_input = inputs[0].cpu().squeeze().numpy()
        this_output = outputs[0].cpu().squeeze().numpy()
                
        ## Input row
        plt.imshow(this_input, cmap='viridis', origin='lower')            
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)  
        if i == n//2: ax.set_title('Original images')
        
        ## Reconstructed row
        ax = plt.subplot(3, n, i + 1 + n)
        plt.imshow(this_output, cmap='viridis', origin='lower')  
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)  
        if i == n//2: ax.set_title('Reconstructed images')
    
    plt.show()   

In [None]:
# Function to plot model weight distribution
def plot_weight_distribution(model):
    """Histogram every parameter value of ``model`` in a single 50-bin plot."""
    # Flatten each parameter tensor and gather all values into one list
    all_weights = [
        value
        for tensor in model.parameters()
        for value in tensor.data.cpu().numpy().flatten()
    ]

    # Plot histogram of weights
    plt.figure(figsize=(8, 6))
    plt.hist(all_weights, bins=50, color='blue', alpha=0.7)
    plt.title('Weight Distribution')
    plt.xlabel('Weight Value')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()

In [None]:
## Function to plot the distribution of charge in the images for the input and reconstructed images
## Works on the sparse batches produced by the MinkowskiEngine collate function
def plot_distribution_from_dataloader(data_loader, encoder, decoder):
    """Histogram the feature values of the first augmented view and of its reconstruction.

    Inputs:
        - data_loader : Loader yielding (aug1_coords, aug1_feats, ...) tuples from the collate function
        - encoder, decoder : Networks to evaluate (both are switched to eval mode)

    Only the features stored in the sparse tensors are histogrammed.
    """
    encoder.eval()
    decoder.eval()
    
    # Accumulated histogram counts over all batches
    num_bins=50
    input_hist = np.zeros(num_bins, dtype=int)
    output_hist = np.zeros(num_bins, dtype=int)

    # Linearly spaced bins (log-spaced alternative kept for reference)
    # bins = np.logspace(0, np.log10(1000), num=num_bins+1)
    bins = np.linspace(0.1, 3.1, num=num_bins+1)
    with torch.no_grad():
        ## Each batch is a 6-tuple; only the first augmented view is used here
        for aug1_bcoords, aug1_bfeats, *_ in data_loader:

            aug1_bcoords = aug1_bcoords.to(device, non_blocking=True)
            aug1_bfeats = aug1_bfeats.to(device)
            image_batch1 = ME.SparseTensor(aug1_bfeats.float(), aug1_bcoords.int(), device=device)
            
            # Encode data
            encoded_batch1 = encoder(image_batch1)
            # Decode data
            decoded_batch1 = decoder(encoded_batch1)
            
            # Update input histogram from the sparse features
            input_hist += np.histogram(image_batch1.F.cpu().numpy(), bins=bins)[0]

            # Update output histogram from the reconstructed sparse features
            output_hist += np.histogram(decoded_batch1.F.cpu().numpy(), bins=bins)[0]

    # Plot distribution of input and output values
    plt.figure(figsize=(10, 5))
    plt.plot(bins[:-1], input_hist, color='blue', label='Input')
    plt.plot(bins[:-1], output_hist, color='orange', label='Output')
    plt.title('Distribution of Input and Output Values')
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    # plt.xscale('log')
    plt.legend()
    plt.show()

In [None]:
## This is MSE loss, except the penalty for getting something that should be 0 wrong is reduced w.r.t. 
## the penalty for getting something that should be nonzero wrong
class AsymmetricL2LossME(torch.nn.Module):
    """Squared-error loss between two sparse tensors with a reduced penalty on empty targets.

    Inputs:
        - nonzero_cost : Weight for voxels that are present in the target
        - zero_cost : Weight for voxels that are predicted but absent from the target
        - normalisation : Divisor applied to the summed loss; the default matches the
                          batch size (512) and 256x128 image size used in this notebook
    """

    def __init__(self, nonzero_cost=2.0, zero_cost=1.0, normalisation=512*128*256):
        super(AsymmetricL2LossME, self).__init__()
        self.nonzero_cost = nonzero_cost
        self.zero_cost = zero_cost
        self.normalisation = normalisation
    
    def forward(self, pred, targ):
        """Return the weighted, normalised sum of squared differences.

        ``pred`` and ``targ`` need ``.C`` (coordinates, one row per voxel) and ``.F``
        (features) on the same device; coordinates are assumed unique within each tensor.
        """
        ## To make this just MSE loss, these three lines would suffice
        #diff = pred - targ
        #loss = torch.sum(diff.F**2)
        #return loss/512/128/256
        
        pred_C, pred_F = pred.C, pred.F
        targ_C, targ_F = targ.C, targ.F
        n_pred = pred_C.size(0)
        
        ## Give every distinct coordinate one slot; `slot` maps each input row to its slot
        unique_C, slot = torch.cat([pred_C, targ_C], dim=0).unique(dim=0, return_inverse=True)
        pred_slot = slot[:n_pred]
        targ_slot = slot[n_pred:]
        n_slots = unique_C.size(0)
        
        ## Scatter both feature sets onto the shared slots so that rows are matched by
        ## coordinate, not by their position in each tensor; absent voxels stay at zero
        pred_aligned = pred_F.new_zeros((n_slots, pred_F.size(1))).index_add(0, pred_slot, pred_F)
        targ_aligned = targ_F.new_zeros((n_slots, targ_F.size(1))).index_add(0, targ_slot, targ_F)
        
        ## Slots that exist in the target get nonzero_cost, the rest get zero_cost
        in_targ = torch.zeros(n_slots, dtype=torch.bool, device=pred_F.device)
        in_targ[targ_slot] = True
        weight = self.zero_cost + (self.nonzero_cost - self.zero_cost)*in_targ.to(pred_F.dtype)
        
        sq_err = (pred_aligned - targ_aligned)**2
        return torch.sum(weight.unsqueeze(1)*sq_err)/self.normalisation

In [None]:
class EuclideanDistLoss(torch.nn.Module):
    """Penalise the Euclidean distance between paired latent vectors.

    Inputs:
        - scale : Overall multiplier applied to the mean penalty
        - cutoff : Distance at which the penalty is zero
        - pressure : Extra multiplier for pairs that are closer than the cutoff
    """

    def __init__(self, scale, cutoff=0.1, pressure=100):
        super(EuclideanDistLoss, self).__init__()
        self.cutoff = cutoff
        self.pressure = pressure
        self.scale = scale
        
    def forward(self, latent1, latent2):
        """Return scale * mean penalty over the batch (rows are paired one to one)."""
        # Compute the Euclidean distance between each pair of corresponding tensors in the batch
        distances = torch.norm(latent1 - latent2, p=2, dim=1)
        ## Evaluate the penalty for the whole batch at once rather than looping over samples
        ## (an empty batch now gives nan from mean() rather than an error)
        return self.calc_penalty(distances).mean()*self.scale
        
    def calc_penalty(self, value):
        """Penalty is (value - cutoff)**2 above the cutoff and pressure*(cutoff - value)**2 at or below it.

        Accepts a tensor of any shape (evaluated elementwise) or a plain number.
        """
        gap_sq = (value - self.cutoff)**2
        if not torch.is_tensor(value):
            return gap_sq if value > self.cutoff else self.pressure*gap_sq
        return torch.where(value > self.cutoff, gap_sq, self.pressure*gap_sq)


In [None]:
## Wrap the training in a nicer function...
def run_training(num_iterations, log_dir, encoder, decoder, dataloader, optimizer, latent_scale=1, scheduler=None, prof=None):
    """Train the encoder/decoder pair on pairs of augmented sparse images.

    Every batch is pushed through the encoder and decoder twice, once per
    augmentation. The optimised loss is the sum of the two reconstruction
    losses (``AsymmetricL2LossME(10, 1)``) and an
    ``EuclideanDistLoss(latent_scale)`` term between the two encoded feature
    matrices.

    Args:
        num_iterations: number of passes over ``dataloader``.
        log_dir: if truthy, the mean loss of every iteration is written as
            'loss/train' to a TensorBoard ``SummaryWriter`` in this directory.
        encoder, decoder: modules to train; both are moved to the global ``device``.
        dataloader: iterable yielding 6-tuples
            (aug1 coords, aug1 feats, aug2 coords, aug2 feats, orig coords, orig feats).
            The two ``orig`` entries are unpacked but currently unused.
        optimizer: optimizer holding the encoder and decoder parameters.
        latent_scale: weight handed to ``EuclideanDistLoss``.
        scheduler: optional LR scheduler, stepped once per iteration.
        prof: optional profiler; ``prof.step()`` is called once per iteration.

    Returns:
        None. Progress is printed and ``plot_ae_outputs`` is called every 10th iteration.
    """

    print("Training with", num_iterations, "iterations")
    tstart = time.process_time()

    writer = SummaryWriter(log_dir=log_dir) if log_dir else None

    reco_loss_fn = AsymmetricL2LossME(10, 1)
    latent_loss_fn = EuclideanDistLoss(latent_scale)

    encoder.to(device, non_blocking=True)
    decoder.to(device)

    ## Set a maximum for thresholding. Only query the GPU when there is one,
    ## because `device` falls back to the CPU when CUDA is unavailable.
    use_cuda = torch.cuda.is_available()
    gpu_threshold = 0.8 * torch.cuda.get_device_properties(0).total_memory if use_cuda else None

    ## Loop over the desired iterations
    for iteration in range(num_iterations):

        total_loss = 0
        total_aug1_loss = 0
        total_aug2_loss = 0
        total_latent_loss = 0
        nbatches   = 0

        # Set train mode for both the encoder and the decoder
        encoder.train()
        decoder.train()

        # Iterate over batches of images with the dataloader
        for aug1_bcoords, aug1_bfeats, aug2_bcoords, aug2_bfeats, orig_bcoords, orig_bfeats in dataloader:

            ## Send to the device, then make the sparse tensors
            aug1_bcoords = aug1_bcoords.to(device, non_blocking=True)
            aug1_bfeats = aug1_bfeats.to(device, non_blocking=True)
            aug2_bcoords = aug2_bcoords.to(device, non_blocking=True)
            aug2_bfeats = aug2_bfeats.to(device) ## This one has to block or it can try to run the forward function without the data on the GPU...

            aug1_batch = ME.SparseTensor(aug1_bfeats, aug1_bcoords, device=device)
            aug2_batch = ME.SparseTensor(aug2_bfeats, aug2_bcoords, device=device)

            ## Now do the forward passes (the un-augmented batch is not used at present)
            encoded_batch1 = encoder(aug1_batch)
            decoded_batch1 = decoder(encoded_batch1)
            encoded_batch2 = encoder(aug2_batch)
            decoded_batch2 = decoder(encoded_batch2)

            # Evaluate loss
            aug1_loss = reco_loss_fn(decoded_batch1, aug1_batch)
            aug2_loss = reco_loss_fn(decoded_batch2, aug2_batch)
            latent_loss = latent_loss_fn(encoded_batch1.F, encoded_batch2.F)

            loss = aug1_loss + aug2_loss + latent_loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_aug1_loss += aug1_loss.item()
            total_aug2_loss += aug2_loss.item()
            total_latent_loss += latent_loss.item()
            nbatches += 1

            if use_cuda and torch.cuda.memory_allocated() > gpu_threshold:
                torch.cuda.empty_cache()

        ## See if we have an LR scheduler...
        ## Only ReduceLROnPlateau takes the monitored metric; for every other
        ## scheduler a positional argument to step() is read as an epoch number.
        if scheduler:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(total_loss)
            else:
                scheduler.step()

        av_loss = total_loss/nbatches
        av_aug1_loss = total_aug1_loss/nbatches
        av_aug2_loss = total_aug2_loss/nbatches
        av_latent_loss = total_latent_loss/nbatches

        if writer is not None: writer.add_scalar('loss/train', av_loss, iteration)
        print("Processed", iteration, "/", num_iterations, "; loss =", av_loss, "(", av_aug1_loss, "+", av_aug2_loss, "+", av_latent_loss)
        ## NB: process_time() counts CPU time of this process, not wall-clock time
        print("Time taken:", time.process_time() - tstart)

        ## For profiling, it can be helpful to add a break here
        #break

        if prof: prof.step()

        ## dump some images so we can see how the training is going
        if iteration%10 == 0: plot_ae_outputs(encoder,decoder,10)

        ## End so empty cache because MinkowskiEngine can't be trusted
        if use_cuda: torch.cuda.empty_cache()

    ## Make sure the logged scalars are flushed to disk
    if writer is not None: writer.close()

In [None]:
## This is a useful but experimental pytorch function which flags where synchronization calls are made
## NOTE(review): the argument 0 appears to select the "default" mode of the PyTorch API (no warning
## or error on synchronising calls), i.e. the flagging is switched off here - confirm against the docs.
torch.cuda.set_sync_debug_mode(0)

## Various config parameters
## nchan, latent, act_fn and dropout are passed straight to the encoder/decoder constructors below
nchan=16
latent=16
act_fn=ME.MinkowskiReLU
dropout = 0
num_iterations=50
log_dir="log"

## Set up the model
enc = EncoderME
dec = DecoderME
encoder=enc(nchan, latent, act_fn, dropout)
decoder=dec(nchan, latent, act_fn)

## run_training also moves both models to `device`
encoder.to(device, non_blocking=True)
decoder.to(device)

## One optimizer over the parameters of both networks
params_to_optimize = [
        {'params': encoder.parameters()},
        {'params': decoder.parameters()}
    ]

lr=2e-4
weight_decay=0
optimizer = torch.optim.AdamW(params_to_optimize, lr=lr, weight_decay=weight_decay)

## Scheduler options (none is active; uncomment one of the alternatives to try it)
scheduler = None 
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.5)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=3)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,200], gamma=0.2, last_epoch=-1)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=num_iterations, cycle_momentum=False)

## Hyperparameter to change the relative importance of the latent space distance term in the loss function
## (passed to run_training as `latent_scale`)
scale = 1

## Optional profiling wrapper around the same call
#from torch.profiler import profile, record_function, ProfilerActivity
#with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], use_cuda=True, 
#             with_stack=False, record_shapes=False) as prof:
#    with record_function("model_inference"):
#        run_training(num_iterations, log_dir, encoder, decoder, train_loader, optimizer, scale, scheduler, prof)

run_training(num_iterations, log_dir, encoder, decoder, train_loader, optimizer, scale, scheduler)

In [None]:
## Now take the trained model and try to run some unsupervised learning on it...
import pandas as pd 

## Make a single loader to loop over for ease
single_loader = torch.utils.data.DataLoader(train_dataset,
                                            collate_fn=collate_triplet,
                                            batch_size=1,
                                            shuffle=True,
                                            num_workers=1)

## Inference only: switch to eval mode once, rather than on every event
encoder.eval()

## One dict per event (latent values plus nhits), and the bare latent vectors
encoded_samples = []
encoded_images  = []
with torch.no_grad():
    for aug1, aug2, orig in single_loader:
        orig = orig.to(device)

        # Encode the un-augmented image and flatten the result to a 1D numpy array
        encoded_img = encoder(orig).flatten().cpu().numpy()

        # One column per latent value, plus the number of non-zero entries in the input
        encoded_sample = {f"Enc. Variable {i}": val for i, val in enumerate(encoded_img)}
        encoded_sample['nhits'] = np.count_nonzero(orig.cpu().numpy())
        encoded_samples.append(encoded_sample)
        encoded_images .append(encoded_img)
encoded_samples = pd.DataFrame(encoded_samples)

#labels=pd.DataFrame(labels)

In [None]:
## Make a plot of what it looks like: first two latent variables, coloured by nhits
## (colour scale clipped to the range 100-500)
latent_scatter = plt.scatter(encoded_samples["Enc. Variable 0"], encoded_samples["Enc. Variable 1"],
                             vmin=100, vmax=500, c=encoded_samples["nhits"])
plt.xlabel("Enc. Variable 0")
plt.ylabel("Enc. Variable 1")
plt.colorbar(latent_scatter, label="nhits")
plt.show()

In [None]:
## Now TSNE it up
from sklearn.manifold import TSNE

perp=5
exag=50
print("Perplexity =", perp, "early exaggeration =", exag)
## NOTE(review): recent scikit-learn releases rename `n_iter` to `max_iter` - confirm against the installed version
tsne = TSNE(n_components=2, perplexity=perp, n_iter=1000, early_exaggeration=exag)

## Embed the latent variables only. `nhits` is bookkeeping used to colour the
## plot; leaving it in would make it one of the embedded features.
latent_columns = [name for name in encoded_samples.columns if name.startswith("Enc. Variable")]
tsne_results = tsne.fit_transform(encoded_samples[latent_columns])

In [None]:
## Split the 2D embedding into its two coordinates and colour the points by nhits
tsne_x, tsne_y = zip(*tsne_results)
gr = plt.scatter(tsne_x, tsne_y, s=1, alpha=0.8, vmin=100, vmax=500, c=encoded_samples["nhits"])
plt.colorbar(gr)
plt.show()

In [None]:
## Make a single loader to loop over for ease
## The loader must NOT shuffle: row i of `encoded_images` (and so any per-row
## cluster label derived from it) is later used as an index into `train_dataset`.
single_loader = torch.utils.data.DataLoader(train_dataset,
                                            collate_fn=collate_triplet,
                                            batch_size=1,
                                            shuffle=False,
                                            num_workers=1)

## Inference only: switch to eval mode once, rather than on every event
encoder.eval()

encoded_images  = []
with torch.no_grad():
    for aug1, aug2, orig in single_loader:
        orig = orig.to(device)

        # Encode the un-augmented image and flatten the result to a 1D numpy array
        encoded_images.append(encoder(orig).flatten().cpu().numpy())

## One row per event, one column per latent value
encoded_images = np.vstack(encoded_images)

In [None]:
## Try k-NN algorithm
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

## Latent space representation (standardisation is currently switched off)
latent_space = encoded_images #StandardScaler().fit_transform(encoded_images) #encoded_images

# Find the distances to the k-nearest neighbors
k = 5  # You can set k equal to min_samples
neighbors = NearestNeighbors(n_neighbors=k)
neighbors_fit = neighbors.fit(latent_space)
distances, indices = neighbors_fit.kneighbors(latent_space)

# Take each point's distance to its k-th nearest neighbour, then sort ascending.
# kneighbors returns each row ordered by increasing distance, so the k-th
# neighbour is the LAST column. Because the fitted points are queried against
# themselves, the first column is expected to be the point itself at distance 0.
distances = np.sort(distances[:, -1])

# Plot the distances
plt.figure(figsize=(10, 6))
plt.plot(distances)
plt.title('k-NN Distance Plot')
plt.xlabel('Points sorted by distance to {}-th nearest neighbor'.format(k))
plt.ylabel('Distance')
#plt.ylim(0, 5)
plt.show()

In [None]:
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler

## Cluster directly on the encoded values (standardisation is currently switched off)
scaled_encoded_images = encoded_images #StandardScaler().fit_transform(encoded_images)

## Quick look at the first two latent dimensions
plt.scatter(scaled_encoded_images[:, 0], scaled_encoded_images[:, 1], s=1)
plt.show()

## Density-based clustering; a label of -1 marks a noise point
dbscan = DBSCAN(eps=0.03, min_samples=5)
clusters = dbscan.fit(scaled_encoded_images)
labels = clusters.labels_

## Tally the labels: noise first, then the size of every real cluster
label_list = labels.tolist()
n_noise_ = label_list.count(-1)
n_clusters_ = len(set(label_list)) - (1 if n_noise_ > 0 else 0)
n_points = [label_list.count(cluster_id) for cluster_id in range(n_clusters_)]

print("Estimated number of clusters: %d" % n_clusters_)
print("N. points in clusters:", n_points)
print("Estimated number of noise points: %d" % n_noise_)
print("(Out of a total of %d images)" % len(scaled_encoded_images))

In [None]:
## Draw every cluster in its own colour: large markers for core samples,
## tiny markers for the remaining members, black for noise
unique_labels = set(labels)
core_samples_mask = np.zeros_like(labels, dtype=bool)
core_samples_mask[dbscan.core_sample_indices_] = True

colors = list(map(plt.cm.Spectral, np.linspace(0, 1, len(unique_labels))))
for k, col in zip(unique_labels, colors):
    # Black used for noise.
    if k == -1:
        col = [0, 0, 0, 1]

    in_cluster = labels == k

    # Core samples first, then the non-core members of the same cluster
    for selection, marker_size in ((in_cluster & core_samples_mask, 14),
                                   (in_cluster & ~core_samples_mask, 0.1)):
        points = scaled_encoded_images[selection]
        plt.plot(points[:, 0], points[:, 1], "o",
                 markerfacecolor=tuple(col),
                 markeredgecolor="k",
                 markersize=marker_size)

plt.title(f"Estimated number of clusters: {n_clusters_}")
plt.show()


In [None]:
## Now take all of the above, and put it into t-SNE for visualization
import pandas as pd 

## Inference only: switch to eval mode once, rather than on every event
encoder.eval()

## NOTE(review): `labels[index]` is only meaningful if `single_loader` yields the
## events in the same order as when `labels` was produced; a loader built with
## shuffle=True does not guarantee that - confirm.
encoded_samples = []
with torch.no_grad():
    ## Unpack the triplet the same way as the other loops over this loader
    for index, (aug1, aug2, orig) in enumerate(single_loader):
        ## Skip noise
        if labels[index] == -1:
            continue
        orig = orig.to(device)

        # Encode the un-augmented image and flatten the result to a 1D numpy array
        encoded_img = encoder(orig).flatten().cpu().numpy()

        # One column per latent value, plus nhits and the DBSCAN cluster label
        encoded_sample = {f"Enc. Variable {i}": val for i, val in enumerate(encoded_img)}
        encoded_sample['nhits'] = np.count_nonzero(orig.cpu().numpy())
        encoded_sample['db_cluster'] = labels[index]
        encoded_samples.append(encoded_sample)

encoded_samples = pd.DataFrame(encoded_samples)

In [None]:
## Now TSNE it up
from sklearn.manifold import TSNE

perp=50
exag=50
print("Perplexity =", perp, "early exaggeration =", exag)
## NOTE(review): recent scikit-learn releases rename `n_iter` to `max_iter` - confirm against the installed version
tsne = TSNE(n_components=2, perplexity=perp, n_iter=1000, early_exaggeration=exag)

## Embed the latent variables only. `db_cluster` is the label the plot is
## coloured by and `nhits` is bookkeeping; embedding them as features would
## let the label shape the very picture used to judge the clustering.
latent_columns = [name for name in encoded_samples.columns if name.startswith("Enc. Variable")]
tsne_results = tsne.fit_transform(encoded_samples[latent_columns])

In [None]:
## Visualise the results: 2D embedding coloured by DBSCAN cluster label
tsne_x, tsne_y = zip(*tsne_results)
plt.scatter(tsne_x, tsne_y, s=4, c=encoded_samples["db_cluster"])

In [None]:
## Function to show examples for each cluster
def plot_cluster_examples(raw_images, labels, index, max_images=10):
    """Show a row of example images for one cluster.

    Args:
        raw_images: indexable collection; element [2] of each item is the image
            that gets drawn (presumably the un-augmented image - confirm against
            the dataset).
        labels: one cluster label per entry of ``raw_images``.
        index: the cluster label to show examples for.
        max_images: upper limit on the number of images drawn.
    """
    plt.figure(figsize=(12,4.5))

    ## Positions of all entries carrying the requested label
    members = np.nonzero(np.array(labels) == index)[0]

    ## Never try to draw more images than the cluster contains
    n_shown = min(max_images, len(members))

    ## Plot the first n_shown members side by side
    for slot in range(n_shown):
        ax = plt.subplot(2, n_shown, slot + 1)
        plt.imshow(raw_images[members[slot]][2].squeeze(), origin='lower')
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
    plt.show()

In [None]:
## Now pull out a bank of example images for each cluster

## Real clusters first...
for index in range(n_clusters_):
    cluster_size = n_points[index]
    print("Showing examples for cluster:", index, "which has", cluster_size, "values")
    plot_cluster_examples(train_dataset, labels, index)

## ...then the points DBSCAN labelled as noise (-1)
print("Showing examples for the noise, which has", n_noise_, "values")
plot_cluster_examples(train_dataset, labels, -1)