In [1]:
import keras
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Input files for run 42: one .npy array each for signal and background hits.
DATA_DIR = "mu3e_trigger_data"
SIGNAL_DATA_FILE = DATA_DIR + "/run42_sig_positions.npy"
BACKGROUND_DATA_FILE = DATA_DIR + "/run42_bg_positions.npy"

# Detector extents used to rescale the hit coordinates in the next cell.
# NOTE(review): units presumably mm (same units as the stored positions) — confirm.
max_barrel_radius = 86.3  # scale for the x and y columns
max_endcap_distance = 372.6  # scale for the z column

In [3]:
def normalize_hit_positions(positions, barrel_radius, endcap_distance):
    """Return a copy of ``positions`` with real hit coordinates rescaled.

    ``positions`` is indexed ``[event, row, coordinate]``. Columns 0 and 1 (x, y)
    are scaled by ``barrel_radius`` and column 2 (z) by ``endcap_distance``.
    Rows whose x value is exactly -1 are padding and are copied unchanged.

    For real hits:
        x, y -> (v + R) / R
        z    -> (z + D / 2) / D
    where R = ``barrel_radius`` and D = ``endcap_distance``. Coordinates inside
    [-R, R] and [-D/2, D/2] therefore land in [0, 2] and [0, 1], so a rescaled
    x can never coincide with the -1 padding marker.

    The padding mask is computed once, before any column is modified. All three
    columns are therefore rescaled for exactly the same rows.
    """
    is_hit = positions[:, :, 0] != -1
    normalized = positions.copy()
    normalized[is_hit, 0] = (normalized[is_hit, 0] + barrel_radius) / barrel_radius
    normalized[is_hit, 1] = (normalized[is_hit, 1] + barrel_radius) / barrel_radius
    normalized[is_hit, 2] = (
        normalized[is_hit, 2] + endcap_distance / 2
    ) / endcap_distance
    return normalized


# Signal and background must share one normalization. Previously the signal was
# only divided by the extents (ranges [-1, 1] / [-0.5, 0.5]), while the
# background was shifted and scaled (ranges [0, 2] / [0, 1]). That made signal
# look different from background purely because of preprocessing. The
# background transform is kept unchanged so the model trained on it is
# unaffected.
signal_data = normalize_hit_positions(
    np.load(SIGNAL_DATA_FILE), max_barrel_radius, max_endcap_distance
)
background_data = normalize_hit_positions(
    np.load(BACKGROUND_DATA_FILE), max_barrel_radius, max_endcap_distance
)

In [4]:
# Transform the normalized background hits to cylindrical coordinates (r, phi, z).
# The previous cell maps x -> (x + R) / R = x / R + 1 (likewise for y), so the
# shift of +1 must be removed before taking radius and azimuth. Otherwise r and
# phi are measured around (x, y) = (-R, -R) instead of the origin.
# NOTE(review): assumes the raw x/y are measured from the detector axis — confirm.
is_padding_bg = background_data[:, :, 0] == -1
x_bg = background_data[:, :, 0] - 1.0  # = x / max_barrel_radius
y_bg = background_data[:, :, 1] - 1.0  # = y / max_barrel_radius
r_bg = np.sqrt(x_bg**2 + y_bg**2)  # radius in units of max_barrel_radius
phi_bg = np.arctan2(y_bg, x_bg)  # azimuth in [-pi, pi]
z_bg = background_data[:, :, 2]  # z is kept as normalized in the previous cell
background_data_cyclindric = np.stack([r_bg, phi_bg, z_bg], axis=-1)
# NOTE(review): -1 is also a valid phi value, so this marker is ambiguous in the
# phi column. Use the x column of `background_data` as the padding mask.
background_data_cyclindric[is_padding_bg, :] = -1

In [5]:
# Model hyperparameters.
(
    sequence_length,  # rows per event fed to keras.Input (hits plus -1 padding)
    feature_dim,  # coordinates per hit: columns 0, 1, 2 of the position arrays
    hidden_dim,  # width of the per-hit embedding; also the attention key_dim
    latent_dim,  # number of pooling seed vectors / size of the fixed-size encoding
) = (256, 3, 6, 32)

In [6]:
class EncodingLoss(keras.losses.Loss):
    """Autoencoder loss on a concatenated ``[encoding, reconstruction]`` output.

    ``y_pred`` is split into two equal halves along the last axis: the first
    half is the encoding, the second half is the autoencoder's reconstruction
    of it. The returned scalar is::

        mean((encoding - reconstruction) ** 2)
        + diversity_encouragement * sum_f (1 - Var_batch(encoding[:, f])) ** 2

    The second term pushes each encoding feature's batch variance towards 1.
    Presumably it keeps the encoding from collapsing to a constant, which would
    otherwise reconstruct trivially.

    ``y_true`` is ignored. Because the value is a single batch-level scalar,
    per-sample ``sample_weight`` cannot reweight individual samples.

    Args:
        diversity_encouragement: weight of the variance penalty.
        name: optional loss name.
        **kwargs: forwarded to ``keras.losses.Loss`` (e.g. ``reduction``,
            ``dtype``). This also lets ``from_config`` rebuild the loss from
            ``get_config()``.
    """

    def __init__(self, diversity_encouragement=1, name=None, **kwargs):
        # Pass `name` by keyword: in tf.keras 2.x the first positional
        # parameter of Loss.__init__ is `reduction`, not `name`.
        super().__init__(name=name, **kwargs)
        self.diversity_encouragement = diversity_encouragement

    def call(self, y_true, y_pred):
        latent_dim = y_pred.shape[-1] // 2
        encoding, reconstruction = tf.split(y_pred, [latent_dim, latent_dim], axis=-1)
        ae_loss = tf.reduce_mean(tf.square(encoding - reconstruction))

        # Regularizer: penalize each encoding feature whose variance over the
        # batch deviates from 1. Broadcasting `1.0 - ...` avoids depending on
        # a static feature count and keeps the dtype of `encoding`.
        feature_variance = tf.math.reduce_variance(encoding, axis=0)
        diversity_loss = tf.reduce_sum(tf.square(1.0 - feature_variance))
        return ae_loss + self.diversity_encouragement * diversity_loss

    def get_config(self):
        """Include `diversity_encouragement` so the loss survives model save/load."""
        config = super().get_config()
        config["diversity_encouragement"] = self.diversity_encouragement
        return config


class ReconstructionQuality(keras.metrics.Metric):
    """Running mean squared error between an encoding and its reconstruction.

    Reads a ``y_pred`` whose last axis is ``[encoding, reconstruction]`` in two
    equal halves. It reports the same term as ``EncodingLoss``'s reconstruction
    part (lower is better).

    Each batch's MSE is weighted by its batch size, so a short final batch does
    not count as much as a full one. This matches how Keras averages the
    training loss over an epoch. ``y_true`` and ``sample_weight`` are ignored.
    """

    def __init__(self, name="reconstruction_quality", **kwargs):
        super().__init__(name=name, **kwargs)
        # Sum of (batch MSE * batch size) and total number of samples seen.
        self.total_loss = self.add_weight(name="total_loss", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        latent_dim = y_pred.shape[-1] // 2
        encoding, reconstruction = tf.split(y_pred, [latent_dim, latent_dim], axis=-1)
        dtype = self.total_loss.dtype
        batch_mse = tf.cast(tf.reduce_mean(tf.square(encoding - reconstruction)), dtype)
        batch_size = tf.cast(tf.shape(y_pred)[0], dtype)
        self.total_loss.assign_add(batch_mse * batch_size)
        self.count.assign_add(batch_size)

    def result(self):
        # divide_no_nan: report 0 instead of NaN before the first update.
        return tf.math.divide_no_nan(self.total_loss, self.count)


class FeatureVariance(keras.metrics.Metric):
    """Running average of the encoding's per-feature batch variance.

    Reads a ``y_pred`` whose last axis is ``[encoding, reconstruction]`` in two
    equal halves and uses only the encoding half. For each batch it computes the
    variance of every encoding feature over the batch axis and averages those
    variances. ``EncodingLoss`` pushes each of them towards 1.

    Batches are weighted by their size, as in ``ReconstructionQuality``, so a
    short final batch does not count as much as a full one. ``y_true`` and
    ``sample_weight`` are ignored.
    """

    def __init__(self, name="feature_variance", **kwargs):
        super().__init__(name=name, **kwargs)
        # Sum of (batch mean variance * batch size) and total samples seen.
        self.variance = self.add_weight(name="variance", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        latent_dim = y_pred.shape[-1] // 2
        encoding = y_pred[..., :latent_dim]  # the reconstruction half is not needed
        dtype = self.variance.dtype
        batch_variance = tf.cast(
            tf.reduce_mean(tf.math.reduce_variance(encoding, axis=0)), dtype
        )
        batch_size = tf.cast(tf.shape(y_pred)[0], dtype)
        self.variance.assign_add(batch_variance * batch_size)
        self.count.assign_add(batch_size)

    def result(self):
        # divide_no_nan: report 0 instead of NaN before the first update.
        return tf.math.divide_no_nan(self.variance, self.count)

In [None]:
from src.model.components import (
    SelfAttentionStack,
    MLP,
    PoolingAttentionBlock,
    PointTransformerFromCoords,
    DecoderQueries,
    MultiHeadAttentionBlock,
)
from src.model.components import (
    GenerateDecoderMask,
    GenerateMask,
    MaskOutput,
    GetSequenceLength,
)

# Fixed-size encoding model.
# Each event is a (sequence_length, feature_dim) array of hits. GenerateMask
# derives a mask from it (presumably flagging the -1 padding rows — confirm in
# src.model.components), and the mask is passed to the attention layers below.
#
# NOTE(review): `input` shadows Python's builtin `input()` for the rest of the
# notebook; the name is kept because later cells may rely on it.
input = keras.Input(shape=(sequence_length, feature_dim), name="input")
mask = GenerateMask(name="generate_mask")(input)
# NOTE(review): this tensor does not feed `output`, so the layer is not part of
# `model`; it is kept only because later cells may reference it.
sequence_length_layer = GetSequenceLength(name="get_sequence_length")(mask)

# Per-hit embedding followed by a stack of masked self-attention blocks.
input_embedding = MLP(
    name="input_embedding",
    output_dim=hidden_dim,
    num_layers=3,
)(input)
attention_block = SelfAttentionStack(
    name="self_attention_block",
    num_heads=8,
    key_dim=hidden_dim,
    stack_size=3,
)(input_embedding, mask)

# Attention pooling onto `latent_dim` seed vectors gives a fixed number of
# vectors per event regardless of hit count. The 1-unit MLP plus Flatten then
# presumably yields a (batch, latent_dim) encoding; the loss's half/half split
# relies on that width.
attention_pooling = PoolingAttentionBlock(
    name="pooling_attention_block",
    num_heads=8,
    key_dim=hidden_dim,
    dropout_rate=0.1,
    num_seed_vectors=latent_dim,
)(attention_block, mask)
fixed_size_encoding = MLP(name="mlp", output_dim=1)(attention_pooling)
fixed_size_encoding = keras.layers.Flatten(name="flatten")(fixed_size_encoding)

# Bottleneck autoencoder over the encoding: latent_dim -> latent_dim // 2 -> latent_dim.
encoder = MLP(
    name="encoder",
    output_dim=latent_dim // 2,
    num_layers=4,
)(fixed_size_encoding)
decoder = MLP(
    name="decoder",
    output_dim=latent_dim,
    num_layers=4,
)(encoder)

# The model emits [encoding, reconstruction] side by side along the last axis;
# EncodingLoss and both metrics split it back into the two halves.
output = keras.layers.Concatenate(name="concatenate")([fixed_size_encoding, decoder])

model = keras.Model(inputs=input, outputs=output, name="fixed_size_encoding_model")
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=EncodingLoss(diversity_encouragement=2),
    metrics=[ReconstructionQuality(), FeatureVariance()],
)
model.summary()

In [None]:
from sklearn.model_selection import train_test_split

# Train on the first N_TRAIN background events. With validation_split=0.2,
# Keras holds out the last 20% of them (taken before shuffling) for validation.
N_TRAIN = 10_000
train_data = background_data[:N_TRAIN]

# EncodingLoss and both metrics never read y_true. The target below is a
# placeholder with one row per sample so `validation_split` can slice x and y
# together. Keep the returned History so the loss/metric curves can be plotted.
history = model.fit(
    x=train_data,
    y=np.concatenate([train_data, train_data], axis=-1),
    batch_size=1024,
    epochs=30,
    validation_split=0.2,
)

Epoch 1/30
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m126s[0m 16s/step - feature_variance: 0.0813 - loss: 54.4903 - reconstruction_quality: 0.1942 - val_feature_variance: 0.2772 - val_loss: 34.0292 - val_reconstruction_quality: 0.5157
Epoch 2/30
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m165s[0m 20s/step - feature_variance: 0.3579 - loss: 27.5135 - reconstruction_quality: 0.7045 - val_feature_variance: 0.6039 - val_loss: 11.0432 - val_reconstruction_quality: 0.9762
Epoch 3/30
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m144s[0m 18s/step - feature_variance: 0.6127 - loss: 10.9696 - reconstruction_quality: 0.9764 - val_feature_variance: 0.9316 - val_loss: 1.4721 - val_reconstruction_quality: 1.1478
Epoch 4/30
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m138s[0m 17s/step - feature_variance: 0.9129 - loss: 1.9198 - reconstruction_quality: 1.1167 - val_feature_variance: 1.1878 - val_loss: 3.4330 - val_reconstruction_quality: 1.1446
Epoch 5/30
