In [1]:
import keras
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Input files; paths are relative to the notebook's working directory.
DATA_DIR = "mu3e_trigger_data"
SIGNAL_DATA_FILE = DATA_DIR + "/run42_sig_positions.npy"
BACKGROUND_DATA_FILE = DATA_DIR + "/run42_bg_positions.npy"

# Detector extents used to normalise hit coordinates
# (units presumably mm — TODO confirm against the data producer).
max_barrel_radius = 86.3    # scale and offset for the x and y coordinates
max_endcap_distance = 372.6  # scale for z; z is offset by half of this value

In [3]:
signal_data = np.load(SIGNAL_DATA_FILE)
background_data = np.load(BACKGROUND_DATA_FILE)


def normalize_hit_positions(positions, barrel_radius, endcap_distance, pad_value=-1):
    """Shift and scale hit coordinates in place, leaving padded hits untouched.

    Valid hits are mapped as::

        x -> (x + barrel_radius) / barrel_radius
        y -> (y + barrel_radius) / barrel_radius
        z -> (z + endcap_distance / 2) / endcap_distance

    For hits with |x|, |y| <= barrel_radius and |z| <= endcap_distance / 2 this
    puts x and y in [0, 2] and z in [0, 1]. No valid coordinate can then equal a
    negative padding sentinel.

    Args:
        positions: float array whose last axis holds (x, y, z); modified in place.
            A hit counts as padding when its raw x equals ``pad_value``.
        barrel_radius: scale (and offset) applied to x and y.
        endcap_distance: scale applied to z; z is offset by half of it.
        pad_value: sentinel marking padded hits.

    Returns:
        ``positions`` (the same array object).
    """
    # Build the mask once from the raw x column. Re-deriving it after x has been
    # rescaled could misclassify a valid hit whose new x equals the sentinel.
    valid = positions[..., 0] != pad_value
    hits = positions[valid]  # boolean indexing copies -> (n_valid, n_features)
    hits[:, 0] = (hits[:, 0] + barrel_radius) / barrel_radius
    hits[:, 1] = (hits[:, 1] + barrel_radius) / barrel_radius
    hits[:, 2] = (hits[:, 2] + endcap_distance / 2) / endcap_distance
    positions[valid] = hits
    return positions


# Signal and background must go through exactly the same transform. With
# different offsets, a model could tell the samples apart from the
# normalisation alone.
background_data = normalize_hit_positions(background_data, max_barrel_radius, max_endcap_distance)
signal_data = normalize_hit_positions(signal_data, max_barrel_radius, max_endcap_distance)

In [4]:
# Transform the normalised background hits to cylindrical coordinates (r, phi, z).
# The normalisation cell stores x and y as (v + max_barrel_radius) / max_barrel_radius,
# which moves the raw origin to (1, 1). Undo that offset first. Otherwise r and phi
# are measured from the corner (-R, -R) of the bounding box rather than from the
# detector axis (this assumes the raw coordinates are centred on that axis).
bg_padding = background_data[:, :, 0] == -1
x_bg = background_data[:, :, 0] - 1.0  # x / max_barrel_radius
y_bg = background_data[:, :, 1] - 1.0  # y / max_barrel_radius
r_bg = np.hypot(x_bg, y_bg)            # radius in units of max_barrel_radius
phi_bg = np.arctan2(y_bg, x_bg)        # azimuth in radians, range [-pi, pi]
z_bg = background_data[:, :, 2]        # z kept in its normalised form
# Variable name kept as-is (sic) because other cells may refer to it.
background_data_cyclindric = np.stack([r_bg, phi_bg, z_bg], axis=-1)
background_data_cyclindric[bg_padding, :] = -1

In [5]:
class MMDLoss(tf.keras.losses.Loss):
    """Maximum Mean Discrepancy between a batch of latent codes and N(0, I).

    ``y_true`` is ignored. The loss only regularises ``y_pred`` (the latent codes)
    towards a standard normal prior, which is freshly sampled on every call.

    Args:
        latent_dim: width of the latent codes (last axis of ``y_pred``).
        kernel: kernel used to compare samples; only ``'rbf'`` is implemented.
        sigma: bandwidth of the RBF kernel.
        weight: multiplicative factor applied to the MMD estimate.
        **kwargs: forwarded to ``tf.keras.losses.Loss`` (e.g. ``name``, ``dtype``).

    Raises:
        ValueError: if ``kernel`` is not a supported kernel name.
    """

    _SUPPORTED_KERNELS = ("rbf",)

    def __init__(self, latent_dim, kernel='rbf', sigma=1.0, weight=1.0, **kwargs):
        super().__init__(**kwargs)
        if kernel not in self._SUPPORTED_KERNELS:
            # Fail loudly instead of silently computing an RBF kernel anyway.
            raise ValueError(
                f"Unsupported kernel {kernel!r}; expected one of {self._SUPPORTED_KERNELS}."
            )
        self.latent_dim = latent_dim
        self.kernel = kernel
        self.sigma = sigma
        self.weight = weight

    def call(self, y_true, y_pred):
        z = y_pred  # assumed shape: (batch_size, latent_dim)
        batch_size = tf.shape(z)[0]
        # Sample the prior in z's dtype so the kernel arithmetic also works when the
        # loss is built with a non-default dtype (e.g. dtype="float64").
        prior = tf.random.normal(shape=(batch_size, self.latent_dim), dtype=z.dtype)

        return self.weight * self._mmd(z, prior)

    def _mmd(self, x, y):
        """Biased (V-statistic) squared-MMD estimate; diagonal kernel terms are included."""
        xx = self._compute_kernel(x, x)
        yy = self._compute_kernel(y, y)
        xy = self._compute_kernel(x, y)
        return tf.reduce_mean(xx + yy - 2 * xy)

    def _compute_kernel(self, x, y):
        """RBF kernel matrix k[i, j] = exp(-||x_i - y_j||^2 / (2 * sigma^2))."""
        x = tf.expand_dims(x, 1)  # shape: (n, 1, dim)
        y = tf.expand_dims(y, 0)  # shape: (1, m, dim)
        dist = tf.reduce_sum((x - y) ** 2, axis=2)  # shape: (n, m)
        return tf.exp(-dist / (2 * self.sigma ** 2))

    def get_config(self):
        """Include constructor arguments so ``from_config`` can rebuild the loss."""
        config = super().get_config()
        config.update(
            latent_dim=self.latent_dim,
            kernel=self.kernel,
            sigma=self.sigma,
            weight=self.weight,
        )
        return config


class SlicedWassersteinLoss(tf.keras.losses.Loss):
    """Squared sliced 2-Wasserstein distance between latent codes and N(0, I).

    Both the latent batch and an equally sized prior sample are projected onto
    ``num_projections`` random unit directions. The sorted 1-D projections are then
    matched one-to-one and their mean squared difference is returned (times
    ``weight``). ``y_true`` is ignored; directions and prior are resampled per call.

    Args:
        latent_dim: width of the latent codes (last axis of ``y_pred``).
        num_projections: number of random projection directions.
        weight: multiplicative factor applied to the distance.
        **kwargs: forwarded to ``tf.keras.losses.Loss`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, latent_dim, num_projections=100, weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.num_projections = num_projections
        self.weight = weight

    def call(self, y_true, y_pred):
        z = y_pred  # assumed shape: (batch_size, latent_dim)
        batch_size = tf.shape(z)[0]

        # Draw random tensors in z's dtype: matmul and subtraction require matching
        # dtypes, which otherwise breaks for a non-float32 loss dtype.
        proj_vectors = tf.random.normal((self.num_projections, self.latent_dim), dtype=z.dtype)
        proj_vectors = tf.math.l2_normalize(proj_vectors, axis=1)  # shape: (num_proj, latent_dim)

        z_proj = tf.matmul(z, proj_vectors, transpose_b=True)  # shape: (batch, num_proj)
        prior = tf.random.normal((batch_size, self.latent_dim), dtype=z.dtype)
        prior_proj = tf.matmul(prior, proj_vectors, transpose_b=True)

        # Sorting each projection column yields the 1-D optimal transport matching.
        z_sorted = tf.sort(z_proj, axis=0)
        prior_sorted = tf.sort(prior_proj, axis=0)

        return self.weight * tf.reduce_mean(tf.square(z_sorted - prior_sorted))

    def get_config(self):
        """Include constructor arguments so ``from_config`` can rebuild the loss."""
        config = super().get_config()
        config.update(
            latent_dim=self.latent_dim,
            num_projections=self.num_projections,
            weight=self.weight,
        )
        return config


In [6]:
class EncodingLoss(keras.losses.Loss):
    """Latent-autoencoder reconstruction loss plus an optional variance regulariser.

    ``y_pred`` must be the concatenation ``[embedding, reconstruction]`` along the
    last axis, as two halves of equal width; ``y_true`` is ignored. The loss is::

        mean((embedding - reconstruction) ** 2)
        + diversity_encouragement * sum_j (1 - Var_batch(embedding[:, j])) ** 2

    The second term is only added when ``diversity_encouragement > 0``. It pushes
    every embedding feature towards unit variance across the batch, which keeps
    the embedding from collapsing to a constant.

    Args:
        latent_dim: stored and forwarded to the auxiliary MMD loss.
        diversity_encouragement: weight of the variance regulariser; <= 0 disables it.
        name: optional loss name.
        **kwargs: forwarded to ``keras.losses.Loss`` (e.g. ``reduction``, ``dtype``).
    """

    def __init__(self, latent_dim, diversity_encouragement=1, name=None, **kwargs):
        # Pass `name` by keyword: in tf.keras 2 the first positional parameter of
        # Loss.__init__ is `reduction`, not `name`.
        super().__init__(name=name, **kwargs)
        self.latent_dim = latent_dim
        self.diversity_encouragement = diversity_encouragement
        # NOTE(review): constructed but not used in `call`; kept for callers that
        # may access `loss.mmd_loss`.
        self.mmd_loss = MMDLoss(latent_dim=latent_dim, weight=1)

    def call(self, y_true, y_pred):
        half = y_pred.shape[-1] // 2
        embedding, reconstruction = tf.split(y_pred, [half, half], axis=-1)
        ae_loss = tf.reduce_mean(tf.square(embedding - reconstruction))

        if self.diversity_encouragement <= 0:
            return ae_loss

        # Penalise each feature's squared deviation from unit batch variance. The
        # scalar 1.0 broadcasts in the embedding's dtype; a tf.ones tensor would be
        # fixed to float32.
        per_feature_variance = tf.math.reduce_variance(embedding, axis=0)
        diversity_loss = tf.reduce_sum(tf.square(1.0 - per_feature_variance))
        return ae_loss + self.diversity_encouragement * diversity_loss

    def get_config(self):
        """Include constructor arguments so ``from_config`` can rebuild the loss."""
        config = super().get_config()
        config.update(
            latent_dim=self.latent_dim,
            diversity_encouragement=self.diversity_encouragement,
        )
        return config


class ReconstructionQuality(keras.metrics.Metric):
    """Per-sample mean squared reconstruction error of the latent autoencoder.

    Reads ``y_pred`` as ``[embedding, reconstruction]`` (equal halves along the last
    axis); ``y_true`` and ``sample_weight`` are ignored. Each batch's MSE is weighted
    by the batch's row count, so a short final batch does not count as much as a
    full one.
    """

    def __init__(self, name="reconstruction_quality", **kwargs):
        super().__init__(name=name, **kwargs)
        self.total_loss = self.add_weight(name="total_loss", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        half = y_pred.shape[-1] // 2
        embedding, reconstruction = tf.split(y_pred, [half, half], axis=-1)
        batch_mse = tf.reduce_mean(tf.square(embedding - reconstruction))
        rows = tf.cast(tf.shape(y_pred)[0], batch_mse.dtype)
        self.total_loss.assign_add(batch_mse * rows)
        self.count.assign_add(rows)

    def result(self):
        # divide_no_nan reports 0 instead of NaN before the first update.
        return tf.math.divide_no_nan(self.total_loss, self.count)


class FeatureVariance(keras.metrics.Metric):
    """Feature-averaged batch variance of the embedding half of ``y_pred``.

    Reads ``y_pred`` as ``[embedding, reconstruction]`` (equal halves along the last
    axis) and only uses the embedding half; ``y_true`` and ``sample_weight`` are
    ignored. For each batch, the variance over rows is computed per feature and then
    averaged over features. Batch values are combined with weights equal to the
    batch's row count. The variance regulariser in ``EncodingLoss`` targets 1 for
    each feature.
    """

    def __init__(self, name="feature_variance", **kwargs):
        super().__init__(name=name, **kwargs)
        self.variance = self.add_weight(name="variance", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        half = y_pred.shape[-1] // 2
        embedding, _ = tf.split(y_pred, [half, half], axis=-1)
        batch_variance = tf.reduce_mean(tf.math.reduce_variance(embedding, axis=0))
        rows = tf.cast(tf.shape(y_pred)[0], batch_variance.dtype)
        self.variance.assign_add(batch_variance * rows)
        self.count.assign_add(rows)

    def result(self):
        # divide_no_nan reports 0 instead of NaN before the first update.
        return tf.math.divide_no_nan(self.variance, self.count)

In [7]:
# Model dimensions shared by the cells below.
# sequence_length: padded hits per event (model input is (sequence_length, feature_dim));
#   assumed to match the hit axis of the .npy data — TODO confirm.
# feature_dim: per-hit features, here the three normalised coordinates.
sequence_length, feature_dim = 256, 3
# hidden_dim: only referenced by the disabled pooling-attention variant.
# latent_dim: width of the per-event embedding and of the decoder output.
hidden_dim, latent_dim = 6, 64

In [None]:
from src.model.components import (
    SelfAttentionStack,
    MLP,
    PoolingAttentionBlock,
    PointTransformerFromCoords,
    DecoderQueries,
    MultiHeadAttentionBlock,
)
from src.model.components import (
    GenerateDecoderMask,
    GenerateMask,
    MaskOutput,
    GetSequenceLength,
)

# Fixed size encoding models
# Graph: padded hit sequence -> per-hit MLP embedding -> self-attention -> mean
# pooling over hits -> fixed-size event embedding.
# NOTE(review): `input` shadows the Python builtin `input()` for the rest of the kernel.
input = keras.Input(shape=(sequence_length, feature_dim), name="input")
# Presumably flags padded (-1) hits; semantics live in src.model.components — TODO confirm.
mask = GenerateMask(name="generate_mask")(input)
# NOTE(review): `sequence_length_layer` is not used anywhere in this cell.
sequence_length_layer = GetSequenceLength(name="get_sequence_length")(mask)

input_embedding = MLP(name="input_embedding", output_dim=latent_dim, num_layers=3)(
    input
)

attention_block = SelfAttentionStack(
    name="self_attention_block", num_heads=4, key_dim=latent_dim, stack_size=1
)(input_embedding, mask)
# The mask is passed as the second positional argument, which binds to the layer's
# `mask` parameter, so padded hits are presumably excluded from the average.
pooling = keras.layers.GlobalAveragePooling1D(name="global_average_pooling")(attention_block, mask)

# Disabled experiment (never executes): attention-based pooling with learned seed
# vectors. `pooling` above is what actually feeds the Flatten layer below.
if False:
    attention_pooling = PoolingAttentionBlock(
        name="pooling_attention_block",
        num_heads=8,
        key_dim=hidden_dim,
        dropout_rate=0.0,
        num_seed_vectors=latent_dim,
    )(attention_block, mask)
# Global pooling already yields a 2-D (batch, channels) tensor, so this Flatten
# presumably does not change the shape.
fixed_size_encoding = keras.layers.Flatten(name="flatten")(pooling)

#fixed_size_encoding = MLP(name="mlp", output_dim=1)(fixed_size_encoding)

# Autoencoder model
# Bottleneck width int(latent_dim / 1.5), i.e. 42 for latent_dim = 64.
encoder = MLP(name="encoder", output_dim=int(latent_dim / 1.5), num_layers=4)
# Decoder output width must equal the embedding width: EncodingLoss splits the
# concatenated output into two equal halves.
decoder = MLP(name="decoder", output_dim=latent_dim, num_layers=4)


# Output layout is [embedding, reconstruction], as expected by EncodingLoss and the metrics.
output = keras.layers.Concatenate(name="concatenate")([fixed_size_encoding, decoder(encoder(fixed_size_encoding))])

# Define model
# NOTE(review): `autoencoder_layers` is not used anywhere in this cell.
autoencoder_layers = [encoder, decoder]
# A: Embedding Model (input to fixed-size embedding)
embedding_model = keras.Model(inputs=input, outputs=fixed_size_encoding, name="embedding_model")

# B: Autoencoder Model (fixed-size embedding to reconstructed latent vector)
# `encoder` and `decoder` are the same layer instances used in `output` above, so
# this model shares their weights with full_model. Training it updates the
# autoencoder inside full_model.
autoencoder_input = keras.Input(shape=(fixed_size_encoding.shape[-1],), name="ae_input")
encoded = encoder(autoencoder_input)
decoded = decoder(encoded)
ae_output = keras.layers.Concatenate(name="ae_concat")([autoencoder_input, decoded])
autoencoder_model = keras.Model(inputs=autoencoder_input, outputs=ae_output, name="autoencoder_model")

# C: Full Model (input to concatenated fixed + decoded vector)
full_model = keras.Model(inputs=input, outputs=output, name="full_model")

In [9]:
# Full model: embedding and autoencoder trained jointly, variance regulariser on.
loss_fn = EncodingLoss(latent_dim=latent_dim, diversity_encouragement=2)
metrics = [ReconstructionQuality(), FeatureVariance()]
full_model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=loss_fn,
    metrics=metrics,
)

# Autoencoder-only model: pure reconstruction loss (regulariser off), 10x larger learning rate.
autoencoder_model.compile(
    optimizer=keras.optimizers.Adam(1e-2),
    loss=EncodingLoss(latent_dim=latent_dim, diversity_encouragement=0),
)


In [10]:
# Alternating schedule: on even epochs the autoencoder is trained alone on frozen
# embeddings; on odd epochs the whole model (embedding + autoencoder) is trained.
x_train = background_data[:10000]
epochs = 20
autoencoder_train_every = 2
autoencoder_train_steps = 100  # autoencoder *epochs* per AE phase (not optimizer steps)
batch_size = 1024

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")

    if epoch % autoencoder_train_every == 0:
        print("➡️  Training only autoencoder")

        # Step 1: encode the training set with the current embedding model.
        # autoencoder_model contains only the encoder/decoder layers, so the
        # embedding weights stay fixed during this phase.
        fixed_embeddings = embedding_model.predict(x_train, batch_size=batch_size, verbose=0)

        # Step 2: train only the autoencoder on these embeddings. One fit call with
        # epochs=N makes the same number of passes as N separate fit(epochs=1)
        # calls, without rebuilding the data pipeline on every pass.
        if autoencoder_train_steps > 0:
            autoencoder_model.fit(
                x=fixed_embeddings,
                y=fixed_embeddings,
                batch_size=batch_size,
                epochs=autoencoder_train_steps,
                verbose=0,
            )
    else:
        print("🔁 Training full model")
        # y is ignored by EncodingLoss but still passed so Keras runs the loss.
        full_model.fit(
            x=x_train,
            y=x_train,
            batch_size=batch_size,
            epochs=1,
            verbose=1,
        )



Epoch 1/20
➡️  Training only autoencoder

Epoch 2/20
🔁 Training full model
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 3s/step - feature_variance: 0.0886 - loss: 107.5750 - reconstruction_quality: 0.2653

Epoch 3/20
➡️  Training only autoencoder

Epoch 4/20
🔁 Training full model
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 6s/step - feature_variance: 0.3062 - loss: 69.4411 - reconstruction_quality: 0.2238

Epoch 5/20
➡️  Training only autoencoder

Epoch 6/20
🔁 Training full model
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 5s/step - feature_variance: 0.6096 - loss: 31.6149 - reconstruction_quality: 0.3630

Epoch 7/20
➡️  Training only autoencoder

Epoch 8/20
🔁 Training full model
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 5s/step - feature_variance: 0.8046 - loss: 9.0890 - reconstruction_quality: 0.4476

Epoch 9/20
➡️  Training only autoencoder

Epoch 10/20
🔁 Training full model
[1m10/10[0m [32m━━━━

In [11]:
# Spot-check the embeddings of the first few training events.
n_preview = 10
embedding_model.predict(x_train[:n_preview], batch_size=n_preview)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 50ms/step


array([[ 0.2918582 , -0.5106309 ,  0.14294906,  1.7450398 , -0.5682388 ,
         0.2878471 , -0.14214495, -0.26114702, -0.30662382,  0.29714516,
        -1.1316632 ,  0.09863343,  1.3607783 ,  0.9319748 , -1.2696887 ,
         0.5054124 ,  2.3402386 , -0.3812863 ,  0.18824877,  0.41202492,
         1.1504692 , -1.0950174 ,  0.43162268,  0.6681028 , -0.96428645,
         0.03551219,  0.34769547,  0.5842959 ,  2.9174266 , -1.0390228 ,
         0.7652624 ,  0.17830877, -1.1441555 ,  0.60953224, -0.8939077 ,
        -1.6499591 ,  0.556143  , -0.9901701 , -0.8074555 ,  2.7105122 ,
         0.33631784,  0.07599775, -0.9133696 ,  0.2875812 ,  0.6705343 ,
         0.12424276, -1.7945004 , -1.4842741 ,  0.26408294,  0.44468483,
        -1.1845734 ,  0.46260887,  0.8979039 , -0.00468242,  0.45253587,
        -0.6554974 ,  0.30323333,  0.6867631 , -2.0339274 , -0.47890267,
        -0.06434529,  0.43443152,  0.7792276 ,  0.59222937],
       [-0.4387758 , -0.00682363, -0.6070815 ,  2.256508  ,  0.