In [None]:
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
# Module-level TensorBoard writer pointed at ./log.
# NOTE(review): nothing in this notebook logs through this object — run_training()
# creates its own SummaryWriter(log_dir=...) — so presumably this only creates an
# (empty) event file under "log". Consider removing it.
writer = SummaryWriter("log")

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from torchvision import datasets, transforms

%matplotlib inline
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

import torch
## Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
## (A bare `torch.device(device)` statement used to follow; it built a throwaway
## object and had no effect, so it was dropped.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import numpy as np

## Seed numpy and torch to make runs more reproducible.
## torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
SEED = 12345
np.random.seed(SEED)
_ = torch.manual_seed(SEED)  # assigned only to suppress the Generator repr in the cell output

In [None]:
## Make the single module dataset
from torch.utils.data import Dataset
import os
import h5py
import time
import numpy as np
import joblib
import scipy
from glob import glob
from itertools import chain

class SingleModuleImage2D_sparse_joblib(Dataset):
    """Map-style dataset of 2D single-module images stored sparsely in a joblib file.

    The joblib file must hold an indexable sequence whose items provide
    ``toarray()`` (presumably scipy.sparse matrices — TODO confirm). For each item:

      1. it is densified with ``toarray()``;
      2. ``transform`` (if given) is applied to that dense array;
      3. the result is wrapped with ``torch.Tensor``;
      4. if ``normalize`` is truthy, a constant 1.2 is subtracted from every pixel.

    Items are returned as bare tensors (no label), so the DataLoader needs a
    custom ``collate_fn``.
    """

    def __init__(self, infilename, normalize=None, transform=None):
        # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
        # only point this at trusted inputs.
        self._data = joblib.load(infilename)
        self._length = len(self._data)
        self._normalize = normalize
        self._transform = transform

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        dense = self._data[idx].toarray()
        if self._transform:
            dense = self._transform(dense)
        image = torch.Tensor(dense)

        # Optional normalisation. Alternatives considered: unit sum, unit maximum,
        # or a sqrt to amplify faint features. The scheme in use is a constant
        # offset of -1.2 applied to every pixel.
        if not self._normalize:
            return image
        return image - 1.2

def collate(batch):
    """Combine a list of equally-shaped image tensors into one batch tensor.

    Each (H, W) sample gains a leading batch axis and a channel axis of size 1,
    giving a (B, 1, H, W) result. This is needed because the dataset yields bare
    tensors rather than (tensor, label) pairs.
    """
    stacked = torch.stack(batch, dim=0)
    return stacked.unsqueeze(1)

In [None]:
from torchvision import transforms
def sqrt_transform(x):
    """Element-wise square root, used to compress the dynamic range of pixel values."""
    return np.sqrt(x)

def cbrt_transform(x):
    """Element-wise cube root (defined for negative values as well)."""
    return np.cbrt(x)

def log_transform(x):
    """Element-wise log10(1 + x); maps 0 to 0 and compresses large values."""
    shifted = x + 1
    return np.log10(shifted)

## Per-image pixel transform; receives the dense numpy array from the dataset.
## NOTE(review): this object is not currently passed to the dataset (the
## `transform=transform` argument below is commented out), so it has no effect.
transform = transforms.Compose([
    transforms.Lambda(log_transform)
])

## Get a concrete dataset.
## Two inputs are available: the full 200k-image training file and a much smaller
## file for quick iteration. Toggle USE_SMALL_SAMPLE to switch between them.
## (Previously both paths were assigned to inFile in turn, and the first was silently dead.)
FULL_TRAIN_FILE = "/global/cfs/cdirs/dune/users/cwilk/single_module_images/sparse_joblib_fixdupes_pluscuts_noneg/training_images_200k.joblib"
SMALL_TRAIN_FILE = "/global/cfs/cdirs/dune/users/cwilk/single_module_images/sparse_joblib_fixdupes_pluscuts_noneg_transform/packet_2022_02_07_23_09_05_CET_0cd913fb_20220207_230906.data.module1_flow_images.joblib"
USE_SMALL_SAMPLE = True
inFile = SMALL_TRAIN_FILE if USE_SMALL_SAMPLE else FULL_TRAIN_FILE

start = time.process_time()
train_dataset = SingleModuleImage2D_sparse_joblib(inFile, normalize=False)  # , transform=transform)
print("Time taken to load", len(train_dataset), "images:", time.process_time() - start)

In [None]:
## Shuffled mini-batches of 128 images, loaded by 4 worker processes.
## drop_last=True discards the final incomplete batch, so every training batch holds
## exactly 128 images. This also avoids tiny final batches, which would give noisy
## BatchNorm statistics in the Deep models.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           collate_fn=collate,
                                           batch_size=128,
                                           shuffle=True, 
                                           num_workers=4,
                                           drop_last=True)

In [None]:
## Scan the training set once to find the range of pixel values.
## NOTES: time.process_time() counts CPU time of this process only, so work done by
## the DataLoader worker processes is not included. With drop_last=True, the final
## partial batch is not scanned.
start = time.process_time()
## Start from +/-inf rather than 0; a 0 start would wrongly report min <= 0 and max >= 0.
minVal = float('inf')
maxVal = float('-inf')
for image_batch in train_loader:
    # Reduce each batch once and keep plain Python floats rather than 0-d tensors
    minVal = min(minVal, image_batch.min().item())
    maxVal = max(maxVal, image_batch.max().item())
print("Time taken to loop:", time.process_time() - start)
print("Found a minimum value of:", minVal)
print("Found a maximum value of:", maxVal)

In [None]:
## Visualise one training image (index 2) to sanity-check the input.
## __getitem__ returns a bare 2D tensor, not a dict or a (tensor, label) pair.
data = train_dataset[2]
print(data.size())

fig, ax = plt.subplots()
gr = ax.imshow(data, origin='lower')
fig.colorbar(gr, ax=ax)
plt.show()

In [None]:
## Define the encoder and decoders that do the business
from torch import nn
class EncoderSimple(nn.Module):
    """Plain convolutional encoder: (B, 1, H, W) image batch -> (B, latent_dim).

    The spatial sizes in the layer comments assume a 280x140 input. The
    hard-coded Linear input size (4*n_chan*35*17) only matches that input size.
    Every layer is followed by act_fn except the final Linear.
    """
    
    def __init__(self, 
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.LeakyReLU):
        """
        Inputs:
            - base_channel_size : Number of channels in the first convolutional layers. Deeper layers use 2x and then 4x this many.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function class; a new instance is created after each layer except the last Linear
        """
        super().__init__()
        n_chan = base_channel_size
        
        ### Convolutional section: three stride-2 downsamplings (the first two are each followed by two size-preserving convs)
        self.encoder_cnn = nn.Sequential(
            ## Note the assumption that the input image has a single channel
            nn.Conv2d(in_channels=1, 
                      out_channels=n_chan, 
                      kernel_size=3, stride=2, padding=1), ## 280x140 ==> 140x70
            act_fn(),
            nn.Conv2d(in_channels=n_chan, 
                      out_channels=n_chan, 
                      kernel_size=3, padding=1), ## No change in size
            act_fn(),
            nn.Conv2d(in_channels=n_chan, 
                      out_channels=n_chan, 
                      kernel_size=3, padding=1), ## No change in size
            act_fn(),
            nn.Conv2d(in_channels=n_chan, 
                      out_channels=2*n_chan, 
                      kernel_size=3, stride=2, padding=1), ## 140x70 ==> 70x35
            act_fn(),
            nn.Conv2d(in_channels=2*n_chan, 
                      out_channels=2*n_chan, 
                      kernel_size=3, padding=1), ## No change in size
            act_fn(),
            nn.Conv2d(in_channels=2*n_chan, 
                      out_channels=2*n_chan, 
                      kernel_size=3, padding=1), ## No change in size
            act_fn(),
            nn.Conv2d(in_channels=2*n_chan, 
                      out_channels=4*n_chan, 
                      kernel_size=3, stride=2, padding=(1,0)), ## 70x35 ==> 35x17 (no padding along the width)
            act_fn(),
            ## Optional 1x1 convolution to reduce the number of channels (currently disabled)
            #nn.Conv2d(in_channels=4*n_chan, out_channels=n_chan, kernel_size=1),
            #act_fn()
        )
        
        ### Flatten layer: (B, 4*n_chan, 35, 17) -> (B, 4*n_chan*35*17)
        self.flatten = nn.Flatten(start_dim=1)
        
        ### Linear section, simple for now
        ## Flattened size is 4*n_chan*35*17 (= 38080 for n_chan=16)
        self.encoder_lin = nn.Sequential(
            ## Channels in the last conv layer (4*n_chan) times the pixels in the deepest feature map (35x17)
            nn.Linear(4*n_chan*35*17, 1000),
            #act_fn(),      
            #nn.Linear(1000,500),
            act_fn(),      
            nn.Linear(1000,latent_dim)
        )
        
        # Initialize conv weights using Xavier initialization (Linear layers keep PyTorch's default init)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-uniform init for Conv2d/ConvTranspose2d weights, with zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        
    def forward(self, x):
        """Encode a (B, 1, H, W) batch into a (B, latent_dim) latent batch."""
        x = self.encoder_cnn(x)
        x = self.flatten(x)
        x = self.encoder_lin(x)
        return x
    
class DecoderSimple(nn.Module):
    """Mirror of EncoderSimple: (B, latent_dim) -> (B, 1, 280, 140) image batch.

    The latent vector is expanded to a (4*n_chan, 35, 17) feature map and
    upsampled by three stride-2 transposed convolutions. The output passes through
    act_fn (the sigmoid in forward is commented out), so it is not bounded to [0, 1].
    """
    
    def __init__(self, 
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.LeakyReLU):
        """
        Inputs:
            - base_channel_size : Number of channels in the last convolutional layers. Earlier layers use 4x and then 2x this many.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function class; a new instance is created after every layer, including the output layer
        """
        super().__init__()
        n_chan = base_channel_size

        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dim, 1000),
            act_fn(),
            #nn.Linear(500, 1000),
            #act_fn(),
            nn.Linear(1000, 4*n_chan*35*17),
            act_fn()
        )

        ## (B, 4*n_chan*35*17) -> (B, 4*n_chan, 35, 17), matching EncoderSimple's deepest feature map
        self.unflatten = nn.Unflatten(dim=1, 
        unflattened_size=(4*n_chan, 35, 17))

        self.decoder_conv = nn.Sequential(  
            nn.ConvTranspose2d(in_channels=4*n_chan, 
                               out_channels=2*n_chan, 
                               kernel_size=3, stride=2, padding=(1,0), output_padding=(1,0)), ## 35x17 ==> 70x35
            act_fn(),
            nn.Conv2d(in_channels=2*n_chan,
                      out_channels=2*n_chan,
                      kernel_size=3, padding=1), ## No change in size
            act_fn(), 
            nn.Conv2d(in_channels=2*n_chan,
                      out_channels=2*n_chan,
                      kernel_size=3, padding=1), ## No change in size
            act_fn(), 
            nn.ConvTranspose2d(in_channels=2*n_chan, 
                               out_channels=n_chan, 
                               kernel_size=3, stride=2, padding=1, output_padding=1), ## 70x35 ==> 140x70
            act_fn(),
            nn.Conv2d(in_channels=n_chan,
                      out_channels=n_chan,
                      kernel_size=3, padding=1), ## No change in size
            act_fn(),
            nn.Conv2d(in_channels=n_chan,
                      out_channels=n_chan,
                      kernel_size=3, padding=1), ## No change in size
            act_fn(),
            nn.ConvTranspose2d(in_channels=n_chan, 
                               out_channels=1, 
                               kernel_size=3, stride=2, padding=1, output_padding=1), ## 140x70 ==> 280x140
            act_fn()
        )
        
        # Initialize conv weights using Xavier initialization (Linear layers keep PyTorch's default init)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-uniform init for Conv2d/ConvTranspose2d weights, with zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        
    def forward(self, x):
        """Decode a (B, latent_dim) batch into a (B, 1, 280, 140) image batch."""
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        # x = torch.sigmoid(x)
        return x

In [None]:
from torch import nn
class EncoderDeep(nn.Module):
    """Deeper convolutional encoder with BatchNorm and Dropout: (B, 1, 280, 140) -> (B, latent_dim).

    There are four stride-2 downsamplings: 280x140 -> 140x70 -> 70x35 -> 35x18 -> 18x9.
    The Linear input size 8*n_chan*18*9 only matches a 280x140 input.
    NOTE(review): encoder_lin contains BatchNorm1d. In training mode, PyTorch rejects a
    batch of size 1 for it (train_loader uses drop_last=True, which avoids that).
    """
    
    def __init__(self, 
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.LeakyReLU,
                 drop_fract : float = 0.2):
        """
        Inputs:
            - base_channel_size : Number of channels in the first convolutional layers. Deeper layers use 2x, 4x and 8x this many.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function class used throughout the encoder network
            - drop_fract : Dropout probability used by every nn.Dropout layer (0 disables dropout)
        """
        super().__init__()
        n_chan = base_channel_size
        
        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            ## Note the assumption that the input image has a single channel
            nn.Conv2d(in_channels=1, out_channels=n_chan, kernel_size=3, stride=2, padding=1), ## 280x140 ==> 140x70
            nn.BatchNorm2d(n_chan),
            #nn.LayerNorm([n_chan, 140, 70]),
            act_fn(),
            nn.Dropout(drop_fract),
            ## NOTE(review): with the conv below commented out, the features go through
            ## BatchNorm -> act -> Dropout -> BatchNorm -> act with no conv in between.
            ## This looks like a leftover; confirm it is intended.
            #nn.Conv2d(in_channels=n_chan, out_channels=n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(n_chan),
            #nn.LayerNorm([n_chan, 140, 70]),
            act_fn(),
            # nn.Dropout(drop_fract),
            nn.Conv2d(in_channels=n_chan, out_channels=n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(n_chan),
            #nn.LayerNorm([n_chan, 140, 70]),          
            act_fn(),
            nn.Dropout(drop_fract),
            nn.Conv2d(in_channels=n_chan, out_channels=2*n_chan, kernel_size=3, stride=2, padding=1), ## 140x70 ==> 70x35
            nn.BatchNorm2d(2*n_chan),
            #nn.LayerNorm([2*n_chan, 70, 35]),
            act_fn(),
            nn.Dropout(drop_fract),
            nn.Conv2d(in_channels=2*n_chan, out_channels=2*n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(2*n_chan),
            #nn.LayerNorm([2*n_chan, 70, 35]),
            act_fn(),
            nn.Dropout(drop_fract),
            nn.Conv2d(in_channels=2*n_chan, out_channels=2*n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(2*n_chan),
            #nn.LayerNorm([2*n_chan, 70, 35]),
            act_fn(),
            nn.Dropout(drop_fract),
            nn.Conv2d(in_channels=2*n_chan, out_channels=4*n_chan, kernel_size=3, stride=2, padding=1), ## 70x35 ==> 35x18
            nn.BatchNorm2d(4*n_chan),
            #nn.LayerNorm([4*n_chan, 35, 18]),
            act_fn(),
            nn.Dropout(drop_fract),
            nn.Conv2d(in_channels=4*n_chan, out_channels=4*n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(4*n_chan),
            #nn.LayerNorm([4*n_chan, 35, 18]),
            ## NOTE(review): act_fn() is commented out here, so this conv block has no
            ## nonlinearity, unlike every other block. Confirm this is intended.
            #act_fn(),
            nn.Dropout(drop_fract),
            nn.Conv2d(in_channels=4*n_chan, out_channels=4*n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(4*n_chan),
            #nn.LayerNorm([4*n_chan, 35, 18]),
            act_fn(),
            nn.Dropout(drop_fract),
            nn.Conv2d(in_channels=4*n_chan, out_channels=8*n_chan, kernel_size=3, stride=2, padding=1), ## 35x18 ==> 18x9
            nn.BatchNorm2d(8*n_chan),
            #nn.LayerNorm([8*n_chan, 18, 9]),
            act_fn(),
            nn.Dropout(drop_fract)

        )
        
        ### Flatten layer: (B, 8*n_chan, 18, 9) -> (B, 8*n_chan*18*9)
        self.flatten = nn.Flatten(start_dim=1)
        
        ### Linear section, simple for now
        self.encoder_lin = nn.Sequential(
            ## Number of channels in last conv layer multiplied by number of pixels in deepest layer (18x9)
            nn.Linear(8*n_chan*18*9, 1024),
            nn.BatchNorm1d(1024),
            act_fn(),
            # nn.Dropout(drop_fract),
            #nn.Linear(512, 512),
            # nn.BatchNorm1d(1000),
            #act_fn(),
            ## https://openaccess.thecvf.com/content_CVPR_2019/papers/Li_Understanding_the_Disharmony_Between_Dropout_and_Batch_Normalization_by_Variance_CVPR_2019_paper.pdf
            ## This paper suggests that dropout and batchnorm don't play well, but adding dropout at the last stage can help
            nn.Dropout(drop_fract),
            nn.Linear(1024, latent_dim),
            # act_fn()
        )
        
        # Initialize conv/linear weights with Kaiming (He) uniform init and zero biases
        self.initialize_weights()

    def initialize_weights(self):
        """Kaiming-uniform init (PyTorch defaults) for Conv2d/ConvTranspose2d/Linear weights, zero biases.

        BatchNorm layers are not touched and keep their default initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        
    def forward(self, x):
        """Encode a (B, 1, H, W) batch into a (B, latent_dim) latent batch."""
        x = self.encoder_cnn(x)
        x = self.flatten(x)
        x = self.encoder_lin(x)
        return x
    
class DecoderDeep(nn.Module):
    """Mirror of EncoderDeep: (B, latent_dim) -> (B, 1, 280, 140) image batch.

    Upsampling path: 18x9 -> 35x18 -> 70x35 -> 140x70 -> 280x140. The output passes
    through act_fn, so it is not bounded (for LeakyReLU, negative values are only damped).
    """
    
    def __init__(self, 
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.LeakyReLU):
        """
        Inputs:
            - base_channel_size : Number of channels in the last convolutional layers. Earlier layers use 8x, 4x and 2x this many.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function class used throughout the decoder network
        """
        super().__init__()
        n_chan = base_channel_size

        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dim, 1024),
            # nn.BatchNorm1d(1000),
            #act_fn(),
            #nn.Linear(512, 512),
            act_fn(),
            nn.Linear(1024, 8*n_chan*18*9),
            # nn.BatchNorm1d(8*n_chan*17*8),
            act_fn()
        )

        ## (B, 8*n_chan*18*9) -> (B, 8*n_chan, 18, 9), matching EncoderDeep's deepest feature map
        self.unflatten = nn.Unflatten(dim=1, 
        unflattened_size=(8*n_chan, 18, 9))

        self.decoder_conv = nn.Sequential(  
            nn.ConvTranspose2d(in_channels=8*n_chan, out_channels=4*n_chan, kernel_size=3, stride=2, padding=1, output_padding=(0,1)), ## 18x9 ==> 35x18
            nn.BatchNorm2d(4*n_chan),
            #nn.LayerNorm([4*n_chan, 35, 18]),
            act_fn(),
            nn.Conv2d(in_channels=4*n_chan, out_channels=4*n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(4*n_chan),
            #nn.LayerNorm([4*n_chan, 35, 18])
            act_fn(),
            nn.Conv2d(in_channels=4*n_chan, out_channels=4*n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(4*n_chan),
            #nn.LayerNorm([4*n_chan, 35, 18]),
            act_fn(),
            nn.ConvTranspose2d(in_channels=4*n_chan, out_channels=2*n_chan, kernel_size=3, stride=2, padding=1, output_padding=(1,0)), ## 35x18 ==> 70x35
            nn.BatchNorm2d(2*n_chan),
            #nn.LayerNorm([2*n_chan, 70, 35]),            
            act_fn(),
            ## NOTE(review): with the conv below commented out, the features go through
            ## BatchNorm -> act -> BatchNorm -> act with no conv in between.
            ## This looks like a leftover; confirm it is intended.
            #nn.Conv2d(in_channels=2*n_chan, out_channels=2*n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(2*n_chan),
            #nn.LayerNorm([2*n_chan, 70, 35]),            
            act_fn(), 
            nn.Conv2d(in_channels=2*n_chan, out_channels=2*n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(2*n_chan),
            #nn.LayerNorm([2*n_chan, 70, 35]),                      
            act_fn(), 
            nn.ConvTranspose2d(in_channels=2*n_chan, out_channels=n_chan, kernel_size=3, stride=2, padding=1, output_padding=1), ## 70x35 ==> 140x70
            nn.BatchNorm2d(n_chan),
            #nn.LayerNorm([n_chan, 140, 70]),            
            act_fn(),
            nn.Conv2d(in_channels=n_chan, out_channels=n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(n_chan),
            #nn.LayerNorm([n_chan, 140, 70]),            
            act_fn(),
            nn.Conv2d(in_channels=n_chan, out_channels=n_chan, kernel_size=3, padding=1), ## No change in size
            nn.BatchNorm2d(n_chan),
            #nn.LayerNorm([n_chan, 140, 70]),                        
            act_fn(),
            nn.ConvTranspose2d(in_channels=n_chan, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1), ## 140x70 ==> 280x140
            # nn.BatchNorm2d(1),
            act_fn()
        )
        # Initialize conv/linear weights with Kaiming (He) uniform init and zero biases
        self.initialize_weights()

    def initialize_weights(self):
        """Kaiming-uniform init (PyTorch defaults) for Conv2d/ConvTranspose2d/Linear weights, zero biases.

        BatchNorm layers are not touched and keep their default initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        
    def forward(self, x):
        """Decode a (B, latent_dim) batch into a (B, 1, 280, 140) image batch."""
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        return x


### Define plotting, loss and training functions

In [None]:
def plot_ae_outputs(encoder,decoder,n=10,loader=None):
    """Show n input images, their reconstructions and the residuals in a 3-row grid.

    Rows, top to bottom: original images, reconstructed images, input - reconstruction.

    Args:
        encoder, decoder: autoencoder halves. Both are switched to eval mode and left there.
        n: number of images (columns) to show. Reduced if the loader has fewer images.
        loader: iterable of (B, 1, H, W) image batches. Defaults to the global train_loader.
    """
    if loader is None:
        loader = train_loader

    encoder.eval()
    decoder.eval()

    ## Collect all n images from ONE iterator over the loader. The previous code called
    ## next(iter(train_loader)) once per image, which restarted the loader every time:
    ## worker processes were re-spawned, and without shuffling it always returned the
    ## same first image.
    batches, n_collected = [], 0
    for batch in loader:
        batches.append(batch)
        n_collected += batch.shape[0]
        if n_collected >= n:
            break
    n = min(n, n_collected)
    if n == 0:
        return

    ## One forward pass for all displayed images
    with torch.no_grad():
        inputs = torch.cat(batches, 0)[:n].to(device)
        recos = decoder(encoder(inputs))
    inputs = inputs.cpu().numpy()
    recos = recos.cpu().numpy()

    plt.figure(figsize=(12,6))
    for i in range(n):
        this_input  = inputs[i].squeeze()
        this_output = recos[i].squeeze()
        rows = (
            (this_input, 'Original images'),
            (this_output, 'Reconstructed images'),
            (this_input - this_output, 'Input - reco images'),
        )
        for row, (image, title) in enumerate(rows):
            ax = plt.subplot(3, n, i + 1 + row*n)
            ax.imshow(image, cmap='viridis', origin='lower')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            if i == n//2: ax.set_title(title)

    plt.show()

In [None]:
# Function to plot weight distribution
def plot_weight_distribution(model):
    """Plot a 50-bin histogram of every parameter value in ``model``.

    This covers all of ``model.parameters()``: weights, biases and BatchNorm affine
    parameters, not only weights.

    Values are gathered with a single np.concatenate of per-tensor arrays. The previous
    list.extend over individual numpy scalars was very slow and memory hungry for the
    deep models here (e.g. EncoderDeep's first Linear layer alone has ~21M weights
    for base_channel_size=16).
    """
    flat_params = [param.detach().cpu().numpy().ravel() for param in model.parameters()]
    # np.concatenate rejects an empty list; a parameter-less model gives an empty histogram
    all_weights = np.concatenate(flat_params) if flat_params else np.empty(0)

    # Plot histogram of weights
    plt.figure(figsize=(8, 6))
    plt.hist(all_weights, bins=50, color='blue', alpha=0.7)
    plt.title('Weight Distribution')
    plt.xlabel('Weight Value')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()

In [None]:
def plot_distribution_from_dataloader(data_loader, encoder, decoder,
                                      num_bins=50, value_range=(0.1, 3.1)):
    """Plot histograms of input and reconstructed pixel values over a whole data loader.

    Args:
        data_loader: iterable of image batches, each fed through encoder then decoder.
        encoder, decoder: autoencoder halves. Both are switched to eval mode and left there.
        num_bins: number of equal-width bins.
        value_range: (low, high) histogram range. Values outside it are not counted.
            With the default low edge of 0.1, blank (exactly 0) pixels are excluded.
    """
    encoder.eval()
    decoder.eval()

    bins = np.linspace(value_range[0], value_range[1], num=num_bins+1)
    # int64 matches np.histogram's count dtype, so in-place accumulation never needs a downcast
    input_hist = np.zeros(num_bins, dtype=np.int64)
    output_hist = np.zeros(num_bins, dtype=np.int64)

    with torch.no_grad():
        for image_batch in data_loader:
            image_batch = image_batch.to(device)
            decoded_batch = decoder(encoder(image_batch))

            input_hist += np.histogram(image_batch.cpu().numpy(), bins=bins)[0]
            output_hist += np.histogram(decoded_batch.cpu().numpy(), bins=bins)[0]

    # Plot each count at its bin centre (previously at the left edge, which shifted the curves by half a bin)
    centres = 0.5 * (bins[:-1] + bins[1:])
    plt.figure(figsize=(10, 5))
    plt.plot(centres, input_hist, color='blue', label='Input')
    plt.plot(centres, output_hist, color='orange', label='Output')
    plt.title('Distribution of Input and Output Values')
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.legend()
    plt.show()

In [None]:
## This is a loss function to deweight the penalty for getting blank pixels wrong
class AsymmetricLoss(torch.nn.Module):
    """Squared-error reconstruction loss with separate weights for blank and non-blank pixels.

    Pixels whose *target* equals ``null_value`` ("blank") are weighted by
    ``zero_cost``; all other pixels are weighted by ``nonzero_cost``. The weighted
    per-pixel squared errors are averaged over all elements.

    If ``l1_weight`` is non-zero, an L1 penalty on every encoder/decoder parameter
    is added, scaled by the detached reconstruction loss:
        total = reco + l1_weight * reco.item() * sum(|theta|) / n_params

    Args:
        nonzero_cost: weight for pixels whose target is not ``null_value``.
        zero_cost: weight for pixels whose target equals ``null_value``.
        l1_weight: strength of the parameter L1 penalty (0 disables it).
        null_value: pixel value that marks a blank pixel. Set this to match any offset
            applied to the data (e.g. the dataset's ``normalize`` option shifts blank
            pixels to -1.2).
    """
    def __init__(self, nonzero_cost=2.0, zero_cost=1.0, l1_weight=0, null_value=0):
        super(AsymmetricLoss, self).__init__()
        self.nonzero_cost = nonzero_cost
        self.zero_cost = zero_cost
        self.l1_weight = l1_weight
        self.null_value = null_value
    
    def forward(self, predictions, targets, encoder, decoder):
        ## Squared difference between predictions and targets
        sq_err = (predictions - targets)**2
        zeros = torch.zeros_like(sq_err)
        blank = targets == self.null_value

        ## Weighted error on non-blank and blank target pixels. A separate
        ## "prediction != null" check for blank pixels is unnecessary: where the
        ## prediction equals a blank target, sq_err is already exactly 0.
        nonzero = self.nonzero_cost * torch.where(blank, zeros, sq_err)
        zero = self.zero_cost * torch.where(blank, sq_err, zeros)
        reco_loss = torch.mean(zero + nonzero)

        ## Optional L1 penalty. When disabled, skip walking every parameter tensor on every batch.
        if not self.l1_weight:
            return reco_loss
        params = list(encoder.parameters()) + list(decoder.parameters())
        num_params = sum(p.numel() for p in params)
        if num_params == 0:
            return reco_loss
        l1_norm = sum(p.abs().sum() for p in params)
        return reco_loss + self.l1_weight*reco_loss.item()*l1_norm/num_params


In [None]:
## This is a loss function to deweight the penalty for getting blank pixels wrong
class AsymmetricL1Loss(torch.nn.Module):
    """Absolute-error reconstruction loss with separate weights for blank and non-blank pixels.

    Pixels whose *target* equals ``null_value`` ("blank") are weighted by
    ``zero_cost``; all other pixels are weighted by ``nonzero_cost``. The weighted
    per-pixel absolute errors are averaged over all elements.

    If ``l1_weight`` is non-zero, an L1 penalty on every encoder/decoder parameter
    is added, scaled by the detached reconstruction loss:
        total = reco + l1_weight * reco.item() * sum(|theta|) / n_params

    Args:
        nonzero_cost: weight for pixels whose target is not ``null_value``.
        zero_cost: weight for pixels whose target equals ``null_value``.
        l1_weight: strength of the parameter L1 penalty (0 disables it).
        null_value: pixel value that marks a blank pixel (default 0).
    """
    def __init__(self, nonzero_cost=2.0, zero_cost=1.0, l1_weight=0, null_value=0):
        super(AsymmetricL1Loss, self).__init__()
        self.nonzero_cost = nonzero_cost
        self.zero_cost = zero_cost
        self.l1_weight = l1_weight
        self.null_value = null_value
    
    def forward(self, predictions, targets, encoder, decoder):
        ## Absolute difference between predictions and targets
        diff = torch.abs(predictions - targets)
        zeros = torch.zeros_like(diff)
        blank = targets == self.null_value

        ## Weighted error on non-blank and blank target pixels. A separate
        ## "prediction != null" check for blank pixels is unnecessary: where the
        ## prediction equals a blank target, diff is already exactly 0.
        nonzero = self.nonzero_cost * torch.where(blank, zeros, diff)
        zero = self.zero_cost * torch.where(blank, diff, zeros)
        reco_loss = torch.mean(zero + nonzero)

        ## Optional L1 penalty. When disabled, skip walking every parameter tensor on every batch.
        if not self.l1_weight:
            return reco_loss
        params = list(encoder.parameters()) + list(decoder.parameters())
        num_params = sum(p.numel() for p in params)
        if num_params == 0:
            return reco_loss
        l1_norm = sum(p.abs().sum() for p in params)
        return reco_loss + self.l1_weight*reco_loss.item()*l1_norm/num_params


In [None]:
## Wrap the training in a nicer function...
def run_training(num_iterations, log_dir, encoder, decoder, dataloader, optimizer, scheduler=None,
                 loss_fn=None, plot_every=20):
    """Train an encoder/decoder pair as an autoencoder for ``num_iterations`` epochs.

    Args:
        num_iterations: number of full passes over ``dataloader``.
        log_dir: TensorBoard log directory. Falsy disables logging.
        encoder, decoder: modules, put in train mode at the start of every epoch.
        dataloader: yields image batches used as both encoder input and loss target.
        optimizer: optimizer over the encoder and decoder parameters.
        scheduler: optional LR scheduler, stepped once per epoch (not per batch).
        loss_fn: callable (decoded, target, encoder, decoder) -> scalar loss tensor.
            Defaults to AsymmetricLoss(10, 1, 0).
        plot_every: plot reconstructions and value distributions on every epoch
            divisible by this (including epoch 0). 0/None disables the plots.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    print("Training with", num_iterations, "iterations")
    tstart = time.time()

    writer = SummaryWriter(log_dir=log_dir) if log_dir else None

    ## Default: non-blank pixels cost 10x more to get wrong than blank ones; no L1 penalty
    if loss_fn is None:
        loss_fn = AsymmetricLoss(10, 1, 0)

    try:
        for iteration in range(num_iterations):
            total_loss = 0.
            nbatches   = 0

            # The plotting helpers switch the models to eval mode, so restore train mode every epoch
            encoder.train()
            decoder.train()

            for image_batch in dataloader:
                image_batch = image_batch.to(device)
                decoded_batch = decoder(encoder(image_batch))
                loss = loss_fn(decoded_batch, image_batch, encoder, decoder)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                nbatches += 1

            if nbatches == 0:
                raise ValueError("dataloader yielded no batches; cannot compute an average loss")

            ## Per-epoch LR schedule (ReduceLROnPlateau would need scheduler.step(metric))
            if scheduler: scheduler.step()

            av_loss = total_loss/nbatches
            if writer is not None:
                writer.add_scalar('loss/train', av_loss, iteration)
            print("Processed", iteration, "/", num_iterations, "; loss =", av_loss)
            # Wall-clock time since this call started (previously this subtracted a global
            # `start` set in another cell, using a different clock)
            print("Time taken:", time.time() - tstart)

            if plot_every and iteration % plot_every == 0:
                plot_ae_outputs(encoder,decoder,10)
                plot_distribution_from_dataloader(dataloader, encoder, decoder)
    finally:
        # Flush pending TensorBoard events even if training is interrupted
        if writer is not None:
            writer.close()

In [None]:
## Train the deep autoencoder
num_iterations=101
log_dir="log"
base_channel_size=16
latent_dim=32
act_fn=nn.LeakyReLU

## NOTE: earlier runs suggested that results are sensitive to the decoder's final
## activation layer (see DecoderDeep).

encoder=EncoderDeep(base_channel_size, latent_dim, act_fn, 0)   # drop_fract=0 disables dropout
decoder=DecoderDeep(base_channel_size, latent_dim, act_fn)

## Alternative, shallower architecture:
#encoder=EncoderSimple(base_channel_size, latent_dim, act_fn)
#decoder=DecoderSimple(base_channel_size, latent_dim, act_fn)

encoder.to(device)
decoder.to(device)

params_to_optimize = [
        {'params': encoder.parameters()},
        {'params': decoder.parameters()}
    ]

## OneCycleLR below overrides this learning rate: it resets the optimizer's lr to
## max_lr/div_factor (4e-4 with the default div_factor=25) and then follows its cycle.
lr=1e-4
weight_decay=0 # 1e-3
optimizer = torch.optim.AdamW(params_to_optimize, lr=lr, weight_decay=weight_decay)
# optimizer = torch.optim.SGD(params_to_optimize, lr=lr, weight_decay=weight_decay, momentum=0.5)

## One LR cycle over the whole run. run_training steps the scheduler once per epoch,
## so total_steps must equal num_iterations.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2, total_steps=num_iterations, cycle_momentum=False)
## Other schedulers tried:
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.5)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5)

run_training(num_iterations, log_dir, encoder, decoder, train_loader, optimizer, scheduler)

## Look at the post-training parameters
print("Plotting encoder weights")
plot_weight_distribution(encoder)
print("Plotting decoder weights")
plot_weight_distribution(decoder)

In [None]:
## Now take the trained model and try to run some unsupervised learning on it...
import pandas as pd 

## Encode every training image once, one image per batch.
## shuffle=False keeps row i of the results aligned with train_dataset[i], so latent
## vectors can be traced back to their source images.
single_loader = torch.utils.data.DataLoader(train_dataset,
                                            collate_fn=collate,
                                            batch_size=1,
                                            shuffle=False,
                                            num_workers=1)

encoded_samples = []
encoded_images  = []
encoder.eval()  # inference: dropout off, BatchNorm uses its running statistics
with torch.no_grad():
    for img in single_loader:
        img = img.to(device)
        encoded_img = encoder(img).flatten().cpu().numpy()
        encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
        # Count of non-zero pixels, used later to colour plots.
        # NOTE: this assumes blank pixels are exactly 0 (not true if the dataset's normalize offset is on).
        encoded_sample['nhits'] = np.count_nonzero(img.cpu().numpy())
        encoded_samples.append(encoded_sample)
        encoded_images.append(encoded_img)
encoded_samples = pd.DataFrame(encoded_samples)

#labels=pd.DataFrame(labels)

In [None]:
## Scatter the first two latent coordinates, coloured by each image's hit count
## (the colour scale is clipped to the range 100-500).
plt.scatter(encoded_samples["Enc. Variable 0"], encoded_samples["Enc. Variable 1"],vmin=100, vmax=500, c=encoded_samples["nhits"])

In [None]:
## Embed the latent vectors in 2D with t-SNE
from sklearn.manifold import TSNE

perp=500
exag=200
print("Perplexity =", perp, "early exaggeration =", exag)
# NOTE: scikit-learn >= 1.5 renamed n_iter to max_iter; adjust for the installed version.
tsne = TSNE(n_components=2, perplexity=perp, n_iter=1000, early_exaggeration=exag)
## Embed only the latent variables. 'nhits' is a per-image label used to colour the
## plot. Feeding it in as a feature (as before) leaks it into the embedding, and,
## being presumably much larger in magnitude than the latent coordinates, it would
## dominate the distance metric.
latent_features = encoded_samples.drop(columns=['nhits'])
tsne_results = tsne.fit_transform(latent_features)

In [None]:
## t-SNE embedding of the latent space, coloured by number of hits (clipped to [100, 500])
gr = plt.scatter(
    tsne_results[:, 0], tsne_results[:, 1],
    c=encoded_samples["nhits"], vmin=100, vmax=500,
    s=1, alpha=0.8,
)
plt.colorbar(gr)
plt.show()

In [None]:
## Make a single, *ordered* loader to loop over for ease.
## shuffle=False is required: the row order of `encoded_images` defines the order of the
## DBSCAN `labels` computed from it, and those labels are matched back to images purely by
## position (a second pass over `single_loader`, and `train_dataset[i]` indexing).
## A shuffling loader would re-permute on every pass and pair clusters with the wrong images.
single_loader = torch.utils.data.DataLoader(train_dataset,
                                            collate_fn=collate,
                                            batch_size=1,
                                            shuffle=False,
                                            num_workers=1)

encoded_images = []
# Inference only: eval mode once, no autograd bookkeeping for the whole pass
encoder.eval()
with torch.no_grad():
    for img in single_loader:
        img = img.to(device)
        encoded_images.append(encoder(img).flatten().cpu().numpy())
# One row per image, in dataset order
encoded_images = np.vstack(encoded_images)

In [None]:
## Try k-NN algorithm
## k-distance plot used to choose DBSCAN's `eps`: sort every point's distance to its
## k-th nearest neighbour and look for the "knee" of the curve, with min_samples = k.
from sklearn.neighbors import NearestNeighbors

latent_space = encoded_images

# Find the distances to the k-nearest neighbors
k = 5  # keep equal to DBSCAN's min_samples
neighbors = NearestNeighbors(n_neighbors=k)
neighbors_fit = neighbors.fit(latent_space)
# Querying the fitted points themselves: column 0 is each point's zero distance to itself,
# so column k-1 is the distance to the k-th point of its neighbourhood counting itself --
# the same convention DBSCAN uses for `min_samples` (it includes the point itself).
distances, indices = neighbors_fit.kneighbors(latent_space)

# k-distance of every point, sorted ascending
distances = np.sort(distances[:, k - 1])

# Plot the distances
plt.figure(figsize=(10, 6))
plt.plot(distances)
plt.title('k-NN Distance Plot')
plt.xlabel('Points sorted by distance to {}-th nearest neighbor'.format(k))
plt.ylabel('Distance')
plt.show()

In [None]:
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler

# NOTE: despite the name no scaling is applied; use StandardScaler().fit_transform(encoded_images) to standardise
scaled_encoded_images = encoded_images

# Quick look at the first two latent dimensions before clustering
plt.scatter(scaled_encoded_images[:, 0], scaled_encoded_images[:, 1], s=1)
plt.show()

dbscan = DBSCAN(eps=200, min_samples=5)
clusters = dbscan.fit(scaled_encoded_images)
labels = clusters.labels_

# DBSCAN labels noise as -1; every other label is a cluster id in 0..n_clusters_-1
n_clusters_ = int(np.count_nonzero(np.unique(labels) != -1))
n_noise_ = int(np.count_nonzero(labels == -1))
n_points = [int(np.count_nonzero(labels == cluster_id)) for cluster_id in range(n_clusters_)]

print("Estimated number of clusters: %d" % n_clusters_)
print("N. points in clusters:", n_points)
print("Estimated number of noise points: %d" % n_noise_)
print("(Out of a total of %d images)" % len(scaled_encoded_images))

In [None]:
## Show the DBSCAN result in the first two latent dimensions:
## large markers are core samples, tiny markers are border samples, black is noise.
unique_labels = set(labels)
core_samples_mask = np.zeros_like(labels, dtype=bool)
core_samples_mask[dbscan.core_sample_indices_] = True

colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))]
for k, col in zip(unique_labels, colors):
    # Noise points (label -1) are always drawn in black
    if k == -1:
        col = [0, 0, 0, 1]

    class_member_mask = labels == k

    # Core samples first (big markers), then the remaining members (tiny markers)
    for point_mask, marker_size in ((core_samples_mask, 14), (~core_samples_mask, 0.1)):
        xy = scaled_encoded_images[class_member_mask & point_mask]
        plt.plot(
            xy[:, 0], xy[:, 1], "o",
            markerfacecolor=tuple(col),
            markeredgecolor="k",
            markersize=marker_size,
        )

plt.title(f"Estimated number of clusters: {n_clusters_}")
plt.show()


In [None]:
## Now take all of the above, and put it into t-SNE for visualization
## Re-encode every non-noise image and attach its DBSCAN cluster id.
## `labels` is indexed by position in the pass that produced `encoded_images`, so this pass
## must visit images in the same order; that is verified explicitly below.
import pandas as pd

encoded_samples = []
encoder.eval()
with torch.no_grad():
    for index, img in enumerate(single_loader):
        ## Skip noise
        if labels[index] == -1:
            continue
        img = img.to(device)
        encoded_img = encoder(img).flatten().cpu().numpy()
        # Fail loudly rather than silently pairing cluster labels with the wrong image
        # (e.g. when single_loader was built with shuffle=True).
        if not np.allclose(encoded_img, encoded_images[index], rtol=1e-4, atol=1e-5):
            raise RuntimeError(
                f"Image {index} does not match the encoding DBSCAN was run on; "
                "single_loader must iterate in a fixed order (shuffle=False)"
            )
        encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
        encoded_sample['nhits'] = np.count_nonzero(img.cpu().numpy())
        encoded_sample['db_cluster'] = labels[index]
        encoded_samples.append(encoded_sample)

encoded_samples = pd.DataFrame(encoded_samples)

In [None]:
## Now TSNE it up
from sklearn.manifold import TSNE

perp=50
exag=50
print("Perplexity =", perp, "early exaggeration =", exag)
# NOTE: scikit-learn >= 1.5 renames `n_iter` to `max_iter`
tsne = TSNE(n_components=2, perplexity=perp, n_iter=1000, early_exaggeration=exag)
# Embed only the latent variables. 'nhits' and 'db_cluster' are metadata for colouring;
# feeding the DBSCAN cluster id into t-SNE would make the clusters separate by construction.
tsne_results = tsne.fit_transform(encoded_samples.filter(like="Enc. Variable"))

In [None]:
## Visualise the results: t-SNE embedding coloured by DBSCAN cluster id
plt.scatter(
    tsne_results[:, 0], tsne_results[:, 1],
    c=encoded_samples["db_cluster"], s=4,
)

In [None]:
## Function to show examples for each cluster
def plot_cluster_examples(raw_images, labels, index, max_images=10, n_rows=2):
    """Show up to `max_images` example images that carry cluster label `index`.

    Parameters
    ----------
    raw_images : indexable
        Images indexed by position (e.g. the training dataset). Element ``i`` must
        correspond to ``labels[i]`` and be displayable with ``plt.imshow``.
    labels : array-like of int
        One cluster label per image; DBSCAN uses -1 for noise.
    index : int
        Label whose members are shown (-1 shows the noise points).
    max_images : int, optional
        Maximum number of examples; the first matching images are used.
    n_rows : int, optional
        Number of rows in the example grid (fewer if there are fewer images).

    Returns
    -------
    None. Draws and shows a matplotlib figure, or prints a message if no image has the label.
    """
    ## Positions of every image carrying the requested label
    indices = np.flatnonzero(np.asarray(labels) == index)
    n_show = min(max_images, len(indices))
    if n_show == 0:
        print("No images with label", index)
        return

    ## Grid sized to the images actually shown, so no subplot rows are left empty
    rows = max(1, min(n_rows, n_show))
    cols = int(np.ceil(n_show / rows))

    plt.figure(figsize=(12, 4.5))
    for i in range(n_show):
        ax = plt.subplot(rows, cols, i + 1)
        plt.imshow(raw_images[indices[i]], origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
## Now pull out a bank of example images for each cluster, then for the noise (label -1).
## n_points[i] holds the size of cluster i, so enumerating it visits every cluster id in order.
for index, count in enumerate(n_points):
    print("Showing examples for cluster:", index, "which has", count, "values")
    plot_cluster_examples(train_dataset, labels, index)

print("Showing examples for the noise, which has", n_noise_, "values")
plot_cluster_examples(train_dataset, labels, -1)