In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import layers, regularizers
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc
import h5py

def distance_correlation_tf(x, y, eps=1e-8):
    x = tf.reshape(x, [-1, 1])
    y = tf.reshape(y, [-1, 1])
    
    # Pairwise distances
    a = tf.norm(x[:, None, :] - x[None, :, :], axis=-1)
    b = tf.norm(y[:, None, :] - y[None, :, :], axis=-1)
    
    # Double centering
    A = a - tf.reduce_mean(a, axis=0, keepdims=True) \
           - tf.reduce_mean(a, axis=1, keepdims=True) \
           + tf.reduce_mean(a)
    B = b - tf.reduce_mean(b, axis=0, keepdims=True) \
           - tf.reduce_mean(b, axis=1, keepdims=True) \
           + tf.reduce_mean(b)
    
    #eps for avoinding NaN situation
    dcov = tf.sqrt(tf.maximum(tf.reduce_mean(A*B), eps))
    dvar_x = tf.sqrt(tf.maximum(tf.reduce_mean(A*A), eps))
    dvar_y = tf.sqrt(tf.maximum(tf.reduce_mean(B*B), eps))
    
    return dcov / (tf.sqrt(dvar_x*dvar_y) + eps)

def disco_loss(y_true, y_pred, hlt_batch, lambda_disco=10.0):
    """
    Compute per-batch loss: recon MSE + lambda * DisCo penalty
    Returns recon_loss, disco_penalty, total_loss
    """
    # Reconstruction MSE per sample
    recon_per_sample = tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)
    tf.print("Per-sample MSE shape:", tf.shape(recon_per_sample), "values:", recon_per_sample)

    #Batch average
    recon_loss = tf.reduce_mean(recon_per_sample)
    tf.print("Final loss shape:", tf.shape(recon_loss), "value:", recon_loss)

    # Offline anomaly score (same as recon here)
    offline_scores = recon_per_sample
    
    # DisCo penalty
    disco_penalty = distance_correlation_tf(offline_scores, hlt_batch)
    
    # Total loss, add HLT AD score (MSE)
    total_loss = recon_loss + lambda_disco * disco_penalty
    
    return recon_loss, disco_penalty, total_loss

In [3]:
file_path = "/eos/user/s/ssaha/SWAN_projects/AD_Trigger_1/ntuples/lam_output/lam_test/EB_482596.h5"

with h5py.File(file_path, "r") as f:
    datasets = {k: np.array(f[k]) for k in f}
    datasets["Offline_data"] = datasets["Offline_data"].reshape(-1, 45)
print("Offline_data shape after reshape:", datasets["Offline_data"].shape)
print("Keys in file:", list(datasets.keys()))
print("Offline_data shape:", datasets["Offline_data"].shape)

output_file = '/eos/user/s/ssaha/SWAN_projects/AD_Trigger_1/output_files_DisCo/ad_scores_noDisCo.h5'
with h5py.File(output_file, 'r') as f:
    group = f['EB_test']
    HLT_AD_scores = group['HLT_AD_scores'][:].astype(np.float32)
    #Offline_AD_scores = group['Offline_AD_scores'][:].astype(np.float32)

print('HLT_AD scores',HLT_AD_scores, HLT_AD_scores.shape)

Offline_data shape after reshape: (1048336, 45)
Keys in file: ['HLT_data', 'L1_data', 'Offline_data', 'event_numbers', 'passHLT', 'passL1', 'pileups', 'run_numbers', 'topo2A_AD_scores', 'weights']
Offline_data shape: (1048336, 45)
HLT_AD scores [  0.          0.          0.        ...   2.109743    3.8268185
 515.9285   ] (1048336,)


In [8]:
# Create a large Autoencoder (encoder + decoder + AE model)
# ---------------------------------------------------------
def create_large_AE(input_dim, h_dim_1, h_dim_2, h_dim_3, h_dim_4, latent_dim, 
                    l2_reg=0.01, dropout_rate=0.0):
    
    # ---------------- Encoder ----------------
    encoder_inputs = layers.Input(shape=(input_dim,))
    x = layers.Dense(h_dim_1, kernel_regularizer=regularizers.l2(l2_reg))(encoder_inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(dropout_rate)(x)

    x = layers.Dense(h_dim_2, kernel_regularizer=regularizers.l2(l2_reg))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(dropout_rate)(x)

    x = layers.Dense(h_dim_3, kernel_regularizer=regularizers.l2(l2_reg))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(dropout_rate)(x)

    x = layers.Dense(h_dim_4, kernel_regularizer=regularizers.l2(l2_reg))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(dropout_rate)(x)

    z = layers.Dense(latent_dim, kernel_regularizer=regularizers.l2(l2_reg))(x)
    z = layers.BatchNormalization()(z)
    z = layers.Activation('relu')(z)

    encoder = Model(inputs=encoder_inputs, outputs=z, name="encoder")

    # ---------------- Decoder ----------------
    decoder_inputs = layers.Input(shape=(latent_dim,))
    x = layers.Dense(h_dim_4, kernel_regularizer=regularizers.l2(l2_reg))(decoder_inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(dropout_rate)(x)

    x = layers.Dense(h_dim_3, kernel_regularizer=regularizers.l2(l2_reg))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(dropout_rate)(x)

    x = layers.Dense(h_dim_2, kernel_regularizer=regularizers.l2(l2_reg))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(dropout_rate)(x)

    x = layers.Dense(h_dim_1, kernel_regularizer=regularizers.l2(l2_reg))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Dropout(dropout_rate)(x)

    outputs = layers.Dense(input_dim, kernel_regularizer=regularizers.l2(l2_reg))(x)

    decoder = Model(inputs=decoder_inputs, outputs=outputs, name="decoder")

    # ---------------- Autoencoder ----------------
    ae_outputs = decoder(encoder(encoder_inputs))
    ae = Model(encoder_inputs, outputs=ae_outputs, name="autoencoder")

    return ae, encoder, decoder


INPUT_DIM = 45 #No MET in Offline data yet! 
H_DIM_1 = 100
H_DIM_2 = 100
H_DIM_3 = 64
H_DIM_4 = 32
LATENT_DIM = 4
L2_reg_coupling = 0.01
dropout_p = 0.2

ae, encoder, decoder = create_large_AE(INPUT_DIM, H_DIM_1, H_DIM_2, H_DIM_3, H_DIM_4,
                                       LATENT_DIM, l2_reg=L2_reg_coupling, dropout_rate=dropout_p)

X = datasets["Offline_data"].astype(np.float32)
HLT_scores = HLT_AD_scores 

# Split train/validation
X_train, X_val, hlt_train, hlt_val = train_test_split(
    X, HLT_scores, test_size=0.2, random_state=42
)

optimizer = tf.keras.optimizers.Adam(1e-5)

BATCH_SIZE = 16773
EPOCHS = 1
num_batches = int(np.ceil(len(X_train) / BATCH_SIZE))
print("Total samples in X_train:", len(X_train))
print("Number of batches:", num_batches)

all_total_losses_disco = []

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    
    # Shuffle the training data - check 
    idx = np.arange(len(X_train))
    np.random.shuffle(idx)
    X_train_shuffled = X_train[idx]
    hlt_train_shuffled = hlt_train[idx]
 
    total_recon, total_disco, total_total = 0.0, 0.0, 0.0

    for i in range(num_batches):
        batch_X = X_train_shuffled[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
        batch_hlt = hlt_train_shuffled[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
        
        # Model predictions
        batch_pred = ae(batch_X, training=False)

        # Compute loss with DisCo
        recon_val, disco_val, total_val = disco_loss(
            tf.constant(batch_X, dtype=tf.float32),
            batch_pred,
            tf.constant(batch_hlt, dtype=tf.float32),
            lambda_disco=10.0
        )

        total_recon += np.sum(recon_val.numpy())           # sum over batch
        total_disco += float(disco_val.numpy())            # scalar
        total_total += np.mean(total_val.numpy())          # mean over batch

        # Save per-batch total loss
        all_total_losses_disco.append(np.mean(total_val.numpy()))
print('loss',all_total_losses_disco )

#No back propagation here, use gradient tape for that?
# Keras - back propagation

Total samples in X_train: 838668
Number of batches: 51

Epoch 1/1
Per-sample MSE shape: [16773] values: [24.3773708 0 0 ... 288.348358 85.7589 803.734]
Final loss shape: [] value: 2182.31665
Per-sample MSE shape: [16773] values: [145.69577 72.9674072 160.783432 ... 52.0121155 71.2855606 53.9495926]
Final loss shape: [] value: 1292.48633
Per-sample MSE shape: [16773] values: [436.773651 77.4311523 137.724106 ... 757.737915 211.847092 1176.2782]
Final loss shape: [] value: 1015.87671
Per-sample MSE shape: [16773] values: [19.9373856 68.7154846 3.66431928 ... 763.366943 0 31.0943584]
Final loss shape: [] value: 4365.98096
Per-sample MSE shape: [16773] values: [433.362152 0 95.693367 ... 549.44342 22.8241653 296.668488]
Final loss shape: [] value: 11408397
Per-sample MSE shape: [16773] values: [306.771179 0 136.311295 ... 346.629089 6.1140213 1274.83203]
Final loss shape: [] value: 47026.0625
Per-sample MSE shape: [16773] values: [70.8425 298.470215 99.7860641 ... 79.253479 79.3036575 76.5

In [24]:
def recon_loss(x_true, x_pred):
    # Mean Squared Error
    recon = tf.reduce_mean(tf.square(x_true - x_pred))
    return recon

all_recon_losses = []

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    
    # Shuffle training data
    idx = np.arange(len(X_train))
    np.random.shuffle(idx)
    X_train_shuffled = X_train[idx]
    
    num_batches = int(np.ceil(len(X_train) / BATCH_SIZE))
    
    for i in range(num_batches):
        batch_X = X_train_shuffled[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
        
        # Model predictions
        batch_pred = ae(batch_X, training=False)

        # Compute reconstruction loss per sample
        recon_val = tf.reduce_mean(tf.square(tf.constant(batch_X, dtype=tf.float32) - batch_pred), axis=1)

        # Convert to numpy and extend the global array
        all_recon_losses.extend(recon_val.numpy())

        print(f"  Batch {i+1}/{num_batches} | Batch Recon Mean = {np.mean(recon_val.numpy()):.4f}")

print(f"Total samples recorded: {len(all_recon_losses)}")


Epoch 1/1
  Batch 1/1639 | Batch Recon Mean = 352.5190
  Batch 2/1639 | Batch Recon Mean = 756.3524
  Batch 3/1639 | Batch Recon Mean = 405.0544
  Batch 4/1639 | Batch Recon Mean = 352.4970
  Batch 5/1639 | Batch Recon Mean = 361.9637
  Batch 6/1639 | Batch Recon Mean = 404.6236
  Batch 7/1639 | Batch Recon Mean = 400.2768
  Batch 8/1639 | Batch Recon Mean = 425.4569
  Batch 9/1639 | Batch Recon Mean = 346.7185
  Batch 10/1639 | Batch Recon Mean = 425.2041
  Batch 11/1639 | Batch Recon Mean = 3406.2261
  Batch 12/1639 | Batch Recon Mean = 3232.5256
  Batch 13/1639 | Batch Recon Mean = 477.9403
  Batch 14/1639 | Batch Recon Mean = 687.3617
  Batch 15/1639 | Batch Recon Mean = 334.0822
  Batch 16/1639 | Batch Recon Mean = 531.6884
  Batch 17/1639 | Batch Recon Mean = 344.9128
  Batch 18/1639 | Batch Recon Mean = 440.1750
  Batch 19/1639 | Batch Recon Mean = 473.9649
  Batch 20/1639 | Batch Recon Mean = 430.2592
  Batch 21/1639 | Batch Recon Mean = 439.6924
  Batch 22/1639 | Batch Recon 

  Batch 217/1639 | Batch Recon Mean = 879.4343
  Batch 218/1639 | Batch Recon Mean = 308.9157
  Batch 219/1639 | Batch Recon Mean = 5902.1572
  Batch 220/1639 | Batch Recon Mean = 199884.4375
  Batch 221/1639 | Batch Recon Mean = 421.3971
  Batch 222/1639 | Batch Recon Mean = 3415.2512
  Batch 223/1639 | Batch Recon Mean = 451.0280
  Batch 224/1639 | Batch Recon Mean = 473.2113
  Batch 225/1639 | Batch Recon Mean = 549.3217
  Batch 226/1639 | Batch Recon Mean = 530.9839
  Batch 227/1639 | Batch Recon Mean = 328674.1250
  Batch 228/1639 | Batch Recon Mean = 332.7794
  Batch 229/1639 | Batch Recon Mean = 988.9750
  Batch 230/1639 | Batch Recon Mean = 546.7418
  Batch 231/1639 | Batch Recon Mean = 363.2619
  Batch 232/1639 | Batch Recon Mean = 382.8813
  Batch 233/1639 | Batch Recon Mean = 470.9095
  Batch 234/1639 | Batch Recon Mean = 589.8806
  Batch 235/1639 | Batch Recon Mean = 851.7124
  Batch 236/1639 | Batch Recon Mean = 321.7053
  Batch 237/1639 | Batch Recon Mean = 420.2737
  Bat

  Batch 422/1639 | Batch Recon Mean = 848.1891
  Batch 423/1639 | Batch Recon Mean = 420.5546
  Batch 424/1639 | Batch Recon Mean = 1910.1812
  Batch 425/1639 | Batch Recon Mean = 416.6086
  Batch 426/1639 | Batch Recon Mean = 357.4554
  Batch 427/1639 | Batch Recon Mean = 600.8380
  Batch 428/1639 | Batch Recon Mean = 394.7303
  Batch 429/1639 | Batch Recon Mean = 424.0294
  Batch 430/1639 | Batch Recon Mean = 1639119.1250
  Batch 431/1639 | Batch Recon Mean = 489.5123
  Batch 432/1639 | Batch Recon Mean = 506.0142
  Batch 433/1639 | Batch Recon Mean = 17056.8223
  Batch 434/1639 | Batch Recon Mean = 28995.8516
  Batch 435/1639 | Batch Recon Mean = 37841.8789
  Batch 436/1639 | Batch Recon Mean = 628.9310
  Batch 437/1639 | Batch Recon Mean = 3190.0835
  Batch 438/1639 | Batch Recon Mean = 935.5486
  Batch 439/1639 | Batch Recon Mean = 4106.2510
  Batch 440/1639 | Batch Recon Mean = 414.4799
  Batch 441/1639 | Batch Recon Mean = 497.2188
  Batch 442/1639 | Batch Recon Mean = 34643.160

  Batch 626/1639 | Batch Recon Mean = 1222.3726
  Batch 627/1639 | Batch Recon Mean = 305.4504
  Batch 628/1639 | Batch Recon Mean = 539.9421
  Batch 629/1639 | Batch Recon Mean = 365.5411
  Batch 630/1639 | Batch Recon Mean = 439.1593
  Batch 631/1639 | Batch Recon Mean = 1871.8524
  Batch 632/1639 | Batch Recon Mean = 457.5872
  Batch 633/1639 | Batch Recon Mean = 327.9374
  Batch 634/1639 | Batch Recon Mean = 466.0433
  Batch 635/1639 | Batch Recon Mean = 971.2928
  Batch 636/1639 | Batch Recon Mean = 464.9757
  Batch 637/1639 | Batch Recon Mean = 574.8829
  Batch 638/1639 | Batch Recon Mean = 732.5135
  Batch 639/1639 | Batch Recon Mean = 357.4946
  Batch 640/1639 | Batch Recon Mean = 512.6268
  Batch 641/1639 | Batch Recon Mean = 522.8133
  Batch 642/1639 | Batch Recon Mean = 969.2991
  Batch 643/1639 | Batch Recon Mean = 453.3822
  Batch 644/1639 | Batch Recon Mean = 430.4913
  Batch 645/1639 | Batch Recon Mean = 501.0124
  Batch 646/1639 | Batch Recon Mean = 315.9433
  Batch 647

  Batch 830/1639 | Batch Recon Mean = 394.3433
  Batch 831/1639 | Batch Recon Mean = 606.1627
  Batch 832/1639 | Batch Recon Mean = 454.5139
  Batch 833/1639 | Batch Recon Mean = 445.8496
  Batch 834/1639 | Batch Recon Mean = 350.5640
  Batch 835/1639 | Batch Recon Mean = 389.0025
  Batch 836/1639 | Batch Recon Mean = 543.7013
  Batch 837/1639 | Batch Recon Mean = 363.6602
  Batch 838/1639 | Batch Recon Mean = 22912.3691
  Batch 839/1639 | Batch Recon Mean = 607.4343
  Batch 840/1639 | Batch Recon Mean = 569.1500
  Batch 841/1639 | Batch Recon Mean = 2059.7080
  Batch 842/1639 | Batch Recon Mean = 8730.1738
  Batch 843/1639 | Batch Recon Mean = 503.1559
  Batch 844/1639 | Batch Recon Mean = 633.9711
  Batch 845/1639 | Batch Recon Mean = 484.1166
  Batch 846/1639 | Batch Recon Mean = 350.5714
  Batch 847/1639 | Batch Recon Mean = 2289.7095
  Batch 848/1639 | Batch Recon Mean = 434.7591
  Batch 849/1639 | Batch Recon Mean = 507.6905
  Batch 850/1639 | Batch Recon Mean = 372.3290
  Batch 

  Batch 1038/1639 | Batch Recon Mean = 417.3666
  Batch 1039/1639 | Batch Recon Mean = 503.3537
  Batch 1040/1639 | Batch Recon Mean = 18114.9902
  Batch 1041/1639 | Batch Recon Mean = 19996.6934
  Batch 1042/1639 | Batch Recon Mean = 360.8119
  Batch 1043/1639 | Batch Recon Mean = 441.8965
  Batch 1044/1639 | Batch Recon Mean = 991.2458
  Batch 1045/1639 | Batch Recon Mean = 495.2239
  Batch 1046/1639 | Batch Recon Mean = 278.0438
  Batch 1047/1639 | Batch Recon Mean = 672.4716
  Batch 1048/1639 | Batch Recon Mean = 563.4738
  Batch 1049/1639 | Batch Recon Mean = 448.3307
  Batch 1050/1639 | Batch Recon Mean = 318.7101
  Batch 1051/1639 | Batch Recon Mean = 530.1174
  Batch 1052/1639 | Batch Recon Mean = 1079.4573
  Batch 1053/1639 | Batch Recon Mean = 1255.2803
  Batch 1054/1639 | Batch Recon Mean = 411.5474
  Batch 1055/1639 | Batch Recon Mean = 502.9629
  Batch 1056/1639 | Batch Recon Mean = 960.2173
  Batch 1057/1639 | Batch Recon Mean = 428.2696
  Batch 1058/1639 | Batch Recon Me

  Batch 1248/1639 | Batch Recon Mean = 727.0869
  Batch 1249/1639 | Batch Recon Mean = 312.9745
  Batch 1250/1639 | Batch Recon Mean = 414.7506
  Batch 1251/1639 | Batch Recon Mean = 383.9324
  Batch 1252/1639 | Batch Recon Mean = 346.2285
  Batch 1253/1639 | Batch Recon Mean = 381.7446
  Batch 1254/1639 | Batch Recon Mean = 389.1370
  Batch 1255/1639 | Batch Recon Mean = 421.0818
  Batch 1256/1639 | Batch Recon Mean = 627.6968
  Batch 1257/1639 | Batch Recon Mean = 1667.6157
  Batch 1258/1639 | Batch Recon Mean = 582.7667
  Batch 1259/1639 | Batch Recon Mean = 366.1027
  Batch 1260/1639 | Batch Recon Mean = 427.9088
  Batch 1261/1639 | Batch Recon Mean = 547.6089
  Batch 1262/1639 | Batch Recon Mean = 2552.5850
  Batch 1263/1639 | Batch Recon Mean = 422.6458
  Batch 1264/1639 | Batch Recon Mean = 452.8612
  Batch 1265/1639 | Batch Recon Mean = 430.0946
  Batch 1266/1639 | Batch Recon Mean = 6492.3354
  Batch 1267/1639 | Batch Recon Mean = 591.6754
  Batch 1268/1639 | Batch Recon Mean 

  Batch 1457/1639 | Batch Recon Mean = 709.6163
  Batch 1458/1639 | Batch Recon Mean = 504.2379
  Batch 1459/1639 | Batch Recon Mean = 452.3064
  Batch 1460/1639 | Batch Recon Mean = 367.8025
  Batch 1461/1639 | Batch Recon Mean = 395.6343
  Batch 1462/1639 | Batch Recon Mean = 414.5881
  Batch 1463/1639 | Batch Recon Mean = 455.9862
  Batch 1464/1639 | Batch Recon Mean = 357.2433
  Batch 1465/1639 | Batch Recon Mean = 1139.7009
  Batch 1466/1639 | Batch Recon Mean = 323.4390
  Batch 1467/1639 | Batch Recon Mean = 367.8623
  Batch 1468/1639 | Batch Recon Mean = 513.7874
  Batch 1469/1639 | Batch Recon Mean = 730.2029
  Batch 1470/1639 | Batch Recon Mean = 342.1722
  Batch 1471/1639 | Batch Recon Mean = 4425.8662
  Batch 1472/1639 | Batch Recon Mean = 479.3406
  Batch 1473/1639 | Batch Recon Mean = 507.9540
  Batch 1474/1639 | Batch Recon Mean = 357.8586
  Batch 1475/1639 | Batch Recon Mean = 417.9428
  Batch 1476/1639 | Batch Recon Mean = 346.0545
  Batch 1477/1639 | Batch Recon Mean =

In [None]:
#cross check per epoch 
