In [8]:
# Imports

import re
import time
from itertools import accumulate

import numpy as np
import spektral
import tensorflow as tf
import tensorflow.keras.backend as K
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import Isomap
from sklearn.manifold import TSNE
from spektral.utils import normalized_laplacian, rescale_laplacian
from tqdm import tqdm
from umap import UMAP

def set_mpl_style():
    """Install the notebook-wide matplotlib rcParams in a single update.

    Mutates the global ``plt.rcParams``; returns nothing.
    """
    # Figure, axes, line and legend appearance.
    drawing = {
        "figure.dpi": 300,
        "axes.linewidth": 0.8,
        "axes.grid": False,
        "lines.linewidth": 1.0,
        "lines.markersize": 2,
        "legend.frameon": True,
    }

    # Tick geometry, identical for both axes.
    ticks = {
        "xtick.direction": "out",
        "ytick.direction": "out",
        "xtick.major.size": 4,
        "ytick.major.size": 4,
        "xtick.major.width": 0.8,
        "ytick.major.width": 0.8,
        "xtick.minor.visible": True,
        "ytick.minor.visible": True,
    }

    # Defaults used by savefig.
    saving = {
        "savefig.dpi": 300,
        "savefig.format": "pdf",
        "savefig.bbox": "tight",
    }

    plt.rcParams.update({**drawing, **ticks, **saving})
    
class CheckpointCallback(tf.keras.callbacks.Callback):
    """Keras callback that calls ``ckpt_manager.save()`` at the end of every epoch."""

    def __init__(self, ckpt_manager):
        """Store the manager whose ``save()`` method is invoked after each epoch."""
        super().__init__()
        self.ckpt_manager = ckpt_manager

    def on_epoch_end(self, epoch, logs=None):
        """Save a checkpoint and print where it went (epoch shown 1-based)."""
        checkpoint_path = self.ckpt_manager.save()
        human_epoch = epoch + 1
        print(f'\nsaved checkpoint for epoch {human_epoch}: {checkpoint_path}')

In [9]:
# Plots and visualizations

def plot_events(events, title='hit_pattern', subtitles=None):
    """Render events as an animated GIF: one panel per event, one frame per sample.

    Args:
        events: 4D array shaped (num_events, num_rows, num_cols, num_samples).
            The last axis is iterated to produce the GIF frames; each
            (num_rows, num_cols) slice is drawn with ``imshow``.
        title: Figure super-title. Lower-cased and stripped to alphanumerics,
            it also forms the output filename (``<title>_gif.gif``).
        subtitles: Optional sequence of panel titles; panel ``i`` is titled
            only when ``subtitles[i]`` exists.

    Returns:
        The (relative) filename of the GIF that was written.

    Raises:
        AssertionError: if ``events`` is not 4-dimensional.
        ValueError: if ``events`` contains no events or no samples.
    """
    if subtitles is None:
        subtitles = ()

    assert len(events.shape) == 4, 'Events must be a 4D array with shape (num_events, num_rows, num_cols, num_samples)'

    num_events, num_samples = events.shape[0], events.shape[3]
    if num_events == 0 or num_samples == 0:
        # Without this guard the failure would surface later as an opaque
        # subplots() error or an IndexError on gif_frames[0].
        raise ValueError(f'Nothing to plot: events has shape {tuple(events.shape)}')

    # Move the sample axis first so that iterating yields one GIF frame at a time.
    samples_first = np.transpose(events, axes=[3, 0, 1, 2])

    gif_frames = []
    for sample in tqdm(samples_first):
        # squeeze=False always returns a 2D axes array, so the single-event
        # case needs no special handling.
        fig, axes = plt.subplots(1, num_events, figsize=(3 * num_events, 3), dpi=100, squeeze=False)
        try:
            fig.suptitle(title, fontsize=16)

            for i, (panel, hit_pattern) in enumerate(zip(axes[0], sample)):
                panel.imshow(hit_pattern, vmin=0, vmax=5)
                panel.set_xticks([])
                panel.set_yticks([])
                panel.grid(False)

                if i < len(subtitles):
                    panel.set_title(subtitles[i], fontsize=8)

            fig.canvas.draw()
            # The RGBA buffer already has shape (height, width, 4); reading it
            # directly avoids re-deriving the size from inches * dpi, which can
            # disagree with the real canvas size. np.array copies the pixels so
            # the frame does not alias the canvas memory.
            image_array = np.array(fig.canvas.buffer_rgba())

            # NOTE(review): `Image` (presumably PIL.Image, given fromarray/save)
            # is not imported in the imports cell visible here - confirm it is
            # imported elsewhere in the notebook.
            gif_frames.append(Image.fromarray(image_array))
        finally:
            # Close this specific figure even if drawing failed, so figures do
            # not pile up across frames.
            plt.close(fig)

    filename = re.sub(r'[^a-zA-Z0-9]', '', title.lower()) + '_gif.gif'
    gif_frames[0].save(
        filename,
        save_all=True,
        duration=20,
        loop=0,
        append_images=gif_frames[1:],
    )

    return filename


def vis_latent_space_categories(data_categories, data_categories_labels, title):
    """Scatter-plot each category's first two components and save the figure as a PNG.

    Args:
        data_categories: Iterable of 2D arrays; columns 0 and 1 are plotted.
        data_categories_labels: Legend labels, paired positionally with
            ``data_categories``.
        title: Plot title; lower-cased and stripped to alphanumerics it also
            forms the output filename (``<title>.png``).

    Draws on the current pyplot figure and clears it afterwards.
    """
    out_path = re.sub(r'[^a-zA-Z0-9]', '', title.lower()) + '.png'

    for points, legend_name in zip(data_categories, data_categories_labels):
        xs, ys = points[:, 0], points[:, 1]
        plt.scatter(xs, ys, label=legend_name, s=0.5)

    plt.gca().set_aspect('equal', adjustable='box')
    plt.xlabel('Component 1')
    plt.ylabel('Component 2')
    plt.title(title)
    plt.legend()

    plt.savefig(out_path)
    plt.clf()
    
def vis_latent_space_gradients(latent_space, labels, title, colorbar_label):
    """Scatter the first two latent components coloured by ``labels`` and save as PNG.

    Args:
        latent_space: 2D array; columns 0 and 1 are plotted.
        labels: Per-point values mapped through the 'viridis' colormap.
        title: Plot title; lower-cased and stripped to alphanumerics it also
            forms the output filename (``<title>.png``).
        colorbar_label: Text shown next to the colorbar.

    Draws on the current pyplot figure and clears it afterwards.
    """
    out_path = re.sub(r'[^a-zA-Z0-9]', '', title.lower()) + '.png'
    xs, ys = latent_space[:, 0], latent_space[:, 1]

    plt.scatter(xs, ys, c=labels, cmap='viridis', s=0.5)
    plt.title(title)

    colorbar = plt.colorbar()
    colorbar.set_label(colorbar_label)

    plt.xlabel('Component 1')
    plt.ylabel('Component 2')

    plt.savefig(out_path)
    plt.clf()
    

def vis_latent_space_num_scatters(fit_model, latent_space, Y, N=4):
    """Plot the latent space grouped by how many entries of ``Y`` are positive per row.

    Args:
        fit_model: Only its class name is used, upper-cased, in the plot title.
        latent_space: Array indexed by the per-group selections below.
        Y: Array whose positive entries are counted over axes 1 and 2.
        N: Groups are built for counts 1 through ``N`` inclusive.
    """
    positive_flags = np.where(Y > 0, 1, 0)
    num_scatters = positive_flags.sum(axis=(1, 2))

    categories = []
    category_labels = []
    for count in range(1, N + 1):
        categories.append(latent_space[np.where(num_scatters == count)])
        plural = 's' if count > 1 else ''
        category_labels.append(f'{count} scatter' + plural)

    model_tag = type(fit_model).__name__.upper()
    vis_latent_space_categories(categories, category_labels, f'Autoencoder Latent Space by Number of Scatters {model_tag}')


def vis_latent_space_phd(fit_model, latent_space, XC):
    """Plot the latent space coloured by the sum of ``XC`` over axes 1 and 2.

    ``fit_model`` contributes only its upper-cased class name to the title.
    """
    model_tag = type(fit_model).__name__.upper()
    plot_title = f'Autoencoder Latent Space by Total Photoelectrons Deposited {model_tag}'
    total_per_event = XC.sum(axis=(1, 2))
    vis_latent_space_gradients(latent_space, total_per_event, plot_title, colorbar_label='phd')
    
def vis_latent_space_footprint(fit_model, latent_space, XC, threshold=4):
    """Plot the latent space coloured by footprint size.

    The footprint of an event is the number of positions along the
    second-to-last axis of ``XC`` whose sum over the last axis exceeds
    ``threshold``.

    Args:
        fit_model: Only its class name is used, upper-cased, in the plot title.
        latent_space: Passed through to ``vis_latent_space_gradients``.
        XC: Array summed over its last axis before thresholding.
            NOTE(review): presumably (events, PMTs, samples), judging by the
            'Total # PMT' colorbar label - confirm against the data loader.
        threshold: A position counts towards the footprint only when its sum
            is strictly greater than this value. Defaults to 4, the value
            previously hard-coded here.
    """
    per_position_total = XC.sum(axis=-1)
    footprint = np.where(per_position_total > threshold, 1, 0).sum(axis=-1)

    model_tag = type(fit_model).__name__.upper()
    vis_latent_space_gradients(
        latent_space,
        footprint,
        f'Autoencoder Latent Space by Footprint Size {model_tag}',
        colorbar_label='Total # PMT',
    )
    

def codebook_usage_histogram(vqvae, XC):
    """Plot the normalised codebook usage, most-used code first, to 'codebook_usage.png'.

    Args:
        vqvae: Object exposing ``encode_to_indices_probabilistic(XC)`` (whose
            result has ``.numpy()``) and an integer ``num_embeddings``.
        XC: Input forwarded unchanged to the encoder call above.

    Draws on the current pyplot figure and clears it afterwards.
    """
    encoded = vqvae.encode_to_indices_probabilistic(XC)
    code_ids = encoded.numpy().flatten().astype(int)

    # One count per codebook entry, including entries that were never selected.
    counts = np.bincount(code_ids, minlength=vqvae.num_embeddings)

    # Descending usage, normalised so the values sum to 1.
    usage_desc = np.sort(counts)[::-1] / counts.sum()
    positions = np.arange(len(usage_desc))

    plt.fill_between(positions, usage_desc, color='blue', alpha=0.3)
    plt.plot(positions, usage_desc, color='blue', label='Usage')

    plt.title('VQ-VAE Codebook Usage Distribution')
    plt.xlabel('Codebook Index')
    plt.ylabel('Usage (PDF)')
    plt.margins(0)

    plt.savefig('codebook_usage.png')
    plt.clf()


def get_encoded_data_generator(compression_func, data_generator):
    """Yield ``(compression_func(XC), XYZ, P)`` for every ``(XC, XYZ, P)`` batch.

    Args:
        compression_func: Callable applied to the first element of each batch.
        data_generator: Iterator or re-iterable producing 3-tuples.

    Yields:
        3-tuples with the first element replaced by its compressed form.

    A one-shot iterator is consumed once and the generator then finishes
    cleanly. A re-iterable source (e.g. a list) is cycled indefinitely, one
    full pass at a time. Calling ``next(iter(source))`` on every step instead
    would restart a re-iterable each time and yield its first batch forever,
    and would turn exhaustion of an iterator into a RuntimeError.
    """
    while True:
        produced_any = False
        for XC, XYZ, P in data_generator:
            produced_any = True
            yield compression_func(XC), XYZ, P

        if not produced_any:
            # Empty source or exhausted one-shot iterator: stop rather than spin.
            return

def run_aux_task(models, compression_funcs, labels, data_generator, val_data_generator, fname_suffix='',
                 epochs=50, steps_per_epoch=64, val_steps=8):
    """Train each model on the auxiliary task and plot best validation loss vs samples and time.

    Args:
        models: Objects exposing ``step(x, y, training=...)`` (returning a
            loss, or a list/tuple whose first item is the loss) and ``name``.
        compression_funcs: One entry per model. ``None`` feeds raw batches;
            otherwise batches go through ``get_encoded_data_generator``.
        labels: Legend labels, paired positionally with ``models``.
        data_generator: Source of ``(x, y, _)`` training batches.
        val_data_generator: Source of ``(x, y, _)`` validation batches.
        fname_suffix: Appended to both output PNG filenames.
        epochs: Training epochs per model.
        steps_per_epoch: Training batches per epoch.
        val_steps: Validation batches per evaluation.

    Writes ``raw_vs_compressed_sample_efficiency{fname_suffix}.png`` and
    ``raw_vs_compressed_time_efficiency{fname_suffix}.png``. Returns nothing.
    """
    def precompute_batches(generator, steps):
        # Pull the batches before timing starts so data loading (and any
        # compression) is excluded from the measured training time.
        return [next(generator) for _ in tqdm(range(steps))]

    def summed_loss(model, batches, training):
        # Sum of per-batch losses; shared by the training and validation passes.
        total = 0.0
        for x, y, _ in batches:
            loss = model.step(x, y, training=training)
            if isinstance(loss, (list, tuple)):
                loss = loss[0]
            total += loss
        return total

    def train_model(model, compression_func, train_source, val_source):
        # Returns (cumulative training seconds, running-best validation loss),
        # each with epochs + 1 entries; entry 0 is the untrained model.
        if compression_func is not None:
            train_source = get_encoded_data_generator(compression_func, train_source)
            val_source = get_encoded_data_generator(compression_func, val_source)

        initial_val_loss = summed_loss(model, precompute_batches(val_source, val_steps), training=False) / val_steps
        train_times, val_losses = [0], [initial_val_loss]

        for epoch in range(epochs):
            train_batches = precompute_batches(train_source, steps_per_epoch)
            # Validation batches must come from the validation source; drawing
            # them from the training source would report a training loss.
            val_batches = precompute_batches(val_source, val_steps)

            desc = f'epoch {epoch + 1}/{epochs} ({model.name})'
            t0 = time.perf_counter()
            loss = summed_loss(model, tqdm(train_batches, desc=desc, ncols=100), training=True)
            train_time = time.perf_counter() - t0

            avg_loss = loss / steps_per_epoch
            avg_val_loss = summed_loss(model, val_batches, training=False) / val_steps
            print(desc + f' - loss: {avg_loss:.3f}, val_loss: {avg_val_loss:.3f}')

            train_times.append(train_time)
            val_losses.append(avg_val_loss)

        K.clear_session()

        cumulative_train_times = list(accumulate(train_times))
        best_val_losses = list(accumulate(val_losses, min))
        return cumulative_train_times, best_val_losses

    # Resolve each source to a single iterator so that next() works for both
    # generators (iter() is a no-op) and re-iterable containers.
    train_source = iter(data_generator)
    val_source = iter(val_data_generator)

    model_cumulative_times, model_best_val_losses = [], []
    for model, compression_func in zip(models, compression_funcs):
        cumulative_times, best_val_losses = train_model(model, compression_func, train_source, val_source)
        model_cumulative_times.append(cumulative_times)
        model_best_val_losses.append(best_val_losses)

    # One extra batch is drawn only to read the batch size for the x-axis.
    sample_batch, _, _ = next(train_source)
    batch_size = sample_batch.shape[0]

    fig, ax = plt.subplots()
    samples_axis = np.arange(0, epochs + 1) * steps_per_epoch * batch_size
    for label, val_losses in zip(labels, model_best_val_losses):
        ax.plot(samples_axis, val_losses, '-o', label=label, markersize=2)
    ax.set_xlabel('Training Samples')
    ax.set_ylabel('Validation Loss')
    ax.set_title('Sample Efficiency: Validation Loss (Best) vs Training Samples')
    ax.legend()
    fig.savefig(f'raw_vs_compressed_sample_efficiency{fname_suffix}.png')

    fig, ax = plt.subplots()
    for label, cumulative_times, val_losses in zip(labels, model_cumulative_times, model_best_val_losses):
        ax.plot(cumulative_times, val_losses, '-o', label=label, markersize=2)
    ax.set_xlabel('Training Time (s)')
    ax.set_ylabel('Validation Loss')
    ax.set_title('Time Efficiency: Validation Loss (Best) vs Training Time')
    ax.legend()
    fig.savefig(f'raw_vs_compressed_time_efficiency{fname_suffix}.png')

In [11]:
# Encoder base class
@tf.keras.utils.register_keras_serializable(package='Custom', name='Encoder')
class Encoder(tf.keras.Model):
    """Base class listing the methods an encoder model provides.

    Every method here is a placeholder that returns None; subclasses supply
    the real implementations.
    """

    def __init__(self, name='encoder'):
        """Initialise the underlying Keras model with the given name."""
        super().__init__(name=name)

    def call(self, x):
        """Placeholder; returns None."""
        return None

    def encode(self, x):
        """Placeholder; returns None."""
        return None

    def decode(self, x):
        """Placeholder; returns None."""
        return None

    def compress(self, x):
        """Placeholder; returns None."""
        return None

    def get_data_size_reduction(self):
        """Placeholder; returns None."""
        return None

In [12]:
# Autoencoder
@tf.keras.utils.register_keras_serializable(package='Custom', name='Autoencoder')
class Autoencoder(Encoder):
    """Fully-connected autoencoder.

    The encoder flattens the input and applies ReLU Dense layers followed by
    a linear Dense layer of ``latent_dim`` units. The decoder applies ReLU
    Dense layers followed by a softplus Dense layer of ``prod(input_shape)``
    units and reshapes the result back to ``input_shape``.
    """

    def __init__(self, input_shape, latent_dim, encoder_layer_sizes=None, decoder_layer_sizes=None, name='autoencoder'):
        """Build the encoder and decoder stacks.

        Args:
            input_shape: Shape of one input sample (no batch dimension).
            latent_dim: Width of the latent code.
            encoder_layer_sizes: Hidden Dense widths for the encoder, in
                order. ``None`` (the default) means no hidden layers.
            decoder_layer_sizes: Hidden Dense widths for the decoder, in
                order. ``None`` (the default) means no hidden layers.
            name: Model name; also prefixes the two sub-model names.
        """
        super().__init__(name=name)

        # Copy into fresh lists so instances never share a default list.
        encoder_layer_sizes = list(encoder_layer_sizes or [])
        decoder_layer_sizes = list(decoder_layer_sizes or [])

        self.input_shape = input_shape
        self.encoder_layer_sizes = encoder_layer_sizes
        self.decoder_layer_sizes = decoder_layer_sizes
        self.latent_dim = latent_dim

        self.encoder = tf.keras.Sequential(
            [tf.keras.layers.Input(input_shape)] +
            [tf.keras.layers.Flatten()] +
            [tf.keras.layers.Dense(sz, activation='relu') for sz in encoder_layer_sizes] +
            [tf.keras.layers.Dense(latent_dim)], name=f'{name}_encoder'
        )

        self.decoder = tf.keras.Sequential(
            [tf.keras.layers.Input((latent_dim,))] +
            [tf.keras.layers.Dense(sz, activation='relu') for sz in decoder_layer_sizes] +
            [tf.keras.layers.Dense(np.prod(input_shape), activation='softplus'),
             tf.keras.layers.Reshape(input_shape)], name=f'{name}_decoder'
        )

    def call(self, x):
        """Return the reconstruction ``decoder(encoder(x))``."""
        return self.decoder(self.encoder(x))

    def encode(self, x):
        """Return the latent code for ``x``."""
        return self.encoder(x)

    def compress(self, x):
        """Return the compressed form of ``x``; for this model, the latent code."""
        return self.encoder(x)

    def decode(self, x):
        """Return the reconstruction for latent code ``x``."""
        return self.decoder(x)

    def compile(self, optimizer, loss, metrics=None, **kwargs):
        """Compile the model; ``optimizer`` and ``loss`` are required here.

        Extra keyword arguments are forwarded to Keras' ``compile``. The
        optimizer and loss are left as Keras stores them: re-assigning the raw
        arguments afterwards would replace an optimizer that Keras resolved
        (for example from a string identifier) with the unresolved value.
        """
        super().compile(optimizer=optimizer, loss=loss, metrics=metrics, **kwargs)

    def summary(self, *args, **kwargs):
        """Print the encoder summary then the decoder summary.

        Positional and keyword arguments are forwarded to both calls.
        """
        self.encoder.summary(*args, **kwargs)
        self.decoder.summary(*args, **kwargs)

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from ``get_config()`` output.

        This is a classmethod so it can be called on the class itself, and it
        builds ``cls`` so subclasses that do not override it get an instance
        of their own type.
        """
        return cls(**config)

    def get_config(self):
        """Return the constructor arguments, other than ``name``, needed to rebuild the model."""
        return {
            'input_shape': self.input_shape,
            'latent_dim': self.latent_dim,
            'encoder_layer_sizes': self.encoder_layer_sizes,
            'decoder_layer_sizes': self.decoder_layer_sizes
        }

    def build(self, input_shape):
        """Build the model for ``input_shape``.

        ``self.input_shape`` is deliberately left at the per-sample shape
        given to the constructor, because ``get_config`` and
        ``get_data_size_reduction`` read it. NOTE(review): the shape Keras
        passes to ``build`` normally includes the batch dimension, which
        would corrupt both - confirm nothing relied on the overwrite.
        """
        super().build(input_shape)

    def get_data_size_reduction(self):
        """Return the compression ratio ``prod(input_shape) / latent_dim``."""
        return np.prod(self.input_shape) / self.latent_dim

    # Backward-compatible alias for the original, misspelled method name.
    get_data_size_reducton = get_data_size_reduction

In [None]:
# Variational Autoencoder
@tf.keras.utils.register_keras_serializable(package='Custom', name='VariationalAutoencoder')
class VariationalAutoencoder(Autoencoder):
    """Fully-connected variational autoencoder.

    The encoder emits ``2 * latent_dim`` values that are split into a mean and
    a log-variance. The decoder maps a ``latent_dim`` sample back to
    ``input_shape`` through a softplus output layer.

    ``train_step`` and ``test_step`` call ``vae_loss`` and
    ``reconstruction_loss``, which are not defined in this class and must be
    available at module level.
    """

    def __init__(self, input_shape, latent_dim, encoder_layer_sizes=None, decoder_layer_sizes=None, name='variational_autoencoder'):
        """Build the VAE encoder and decoder and the two loss trackers.

        Args:
            input_shape: Shape of one input sample (no batch dimension).
            latent_dim: Width of the sampled latent code.
            encoder_layer_sizes: Hidden Dense widths for the encoder, in
                order. ``None`` (the default) means no hidden layers.
            decoder_layer_sizes: Hidden Dense widths for the decoder, in
                order. ``None`` (the default) means no hidden layers.
            name: Model name; also prefixes the two sub-model names.
        """
        # Normalise to fresh lists before delegating, so the parent always
        # receives real lists and instances never share a default.
        encoder_layer_sizes = list(encoder_layer_sizes or [])
        decoder_layer_sizes = list(decoder_layer_sizes or [])

        super().__init__(input_shape, latent_dim, encoder_layer_sizes, decoder_layer_sizes, name)

        # NOTE(review): the parent constructor has already built an encoder
        # and a decoder. Both are rebuilt here exactly as before - the encoder
        # because its output width differs, the decoder only to keep the set
        # of created layers unchanged for existing checkpoints. Confirm
        # whether the decoder rebuild can be dropped.
        self.encoder = tf.keras.Sequential(
            [tf.keras.layers.Input(input_shape)] +
            [tf.keras.layers.Flatten()] +
            [tf.keras.layers.Dense(sz, activation='relu') for sz in encoder_layer_sizes] +
            [tf.keras.layers.Dense(latent_dim * 2)], name=f'{name}_encoder'  # output mean & log_var
        )

        self.decoder = tf.keras.Sequential(
            [tf.keras.layers.Input((latent_dim,))] +
            [tf.keras.layers.Dense(sz, activation='relu') for sz in decoder_layer_sizes] +
            [tf.keras.layers.Dense(np.prod(input_shape), activation='softplus'),
             tf.keras.layers.Reshape(input_shape)], name=f'{name}_decoder'
        )

        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name='reconstruction_loss')

    def reparameterize(self, mean, log_var):
        """Return ``mean + exp(0.5 * log_var) * eps`` with ``eps`` drawn from N(0, 1)."""
        epsilon = tf.random.normal(shape=tf.shape(mean))
        return mean + tf.exp(0.5 * log_var) * epsilon

    def encode(self, x):
        """Return ``(mean, log_var)``, the two halves of the encoder output's last axis."""
        encoder_output = self.encoder(x)
        mean, log_var = tf.split(encoder_output, num_or_size_splits=2, axis=-1)
        return mean, log_var

    def decode(self, z):
        """Return the reconstruction for latent sample ``z``."""
        return self.decoder(z)

    def call(self, x):
        """Return ``(reconstructed, mean, log_var)`` for input ``x``."""
        mean, log_var = self.encode(x)
        z = self.reparameterize(mean, log_var)
        reconstructed = self.decode(z)
        return reconstructed, mean, log_var

    def compress(self, x):
        """Return one random latent sample for ``x`` (encode, then reparameterize).

        NOTE(review): the result is stochastic and has ``latent_dim`` values,
        whereas ``get_data_size_reduction`` divides by ``2 * latent_dim``.
        The original code carried a "check this" comment here - confirm which
        representation is the intended compressed form.
        """
        mean, log_var = self.encode(x)
        return self.reparameterize(mean, log_var)

    def train_step(self, data):
        """Run one optimisation step on ``data[0]`` and return the running metrics."""
        x = data[0]
        x = tf.cast(x, tf.float32)

        with tf.GradientTape() as tape:
            reconstructed, mean, log_var = self(x, training=True)
            loss = vae_loss(x, reconstructed, mean, log_var)
            r_loss = reconstruction_loss(x, reconstructed)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self.loss_tracker.update_state(loss)
        self.reconstruction_loss_tracker.update_state(r_loss)

        return {'loss': self.loss_tracker.result(), 'reconstruction_loss': self.reconstruction_loss_tracker.result()}

    def test_step(self, data):
        """Evaluate on ``data[0]`` without updating weights and return the running metrics."""
        x = data[0]
        x = tf.cast(x, tf.float32)

        reconstructed, mean, log_var = self(x, training=False)
        loss = vae_loss(x, reconstructed, mean, log_var)
        r_loss = reconstruction_loss(x, reconstructed)

        self.loss_tracker.update_state(loss)
        self.reconstruction_loss_tracker.update_state(r_loss)

        return {'loss': self.loss_tracker.result(), 'reconstruction_loss': self.reconstruction_loss_tracker.result()}

    @property
    def metrics(self):
        """Return the two loss trackers updated by ``train_step`` and ``test_step``."""
        return [self.loss_tracker, self.reconstruction_loss_tracker]

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from ``get_config()`` output.

        Builds ``cls`` rather than a hard-coded class, so subclasses get an
        instance of their own type.
        """
        return cls(**config)

    def get_config(self):
        """Return the constructor arguments, other than ``name``, needed to rebuild the model."""
        return {
            'input_shape': self.input_shape,
            'latent_dim': self.latent_dim,
            'encoder_layer_sizes': self.encoder_layer_sizes,
            'decoder_layer_sizes': self.decoder_layer_sizes
        }

    def get_data_size_reduction(self):
        """Return ``prod(input_shape) / (2 * latent_dim)``, counting mean and log-variance."""
        return np.prod(self.input_shape) / (self.latent_dim * 2)