## HGQ training of MLP, DeepSets, MLP-Mixer, JEDI-linear and Lipschitz MLP models on the PerfNano dataset
Based on Chang Sun's implementations using HGQ:
https://github.com/calad0i/JEDI-linear/blob/master/src/

In [1]:
# ─────────────────────────────────────────────────────────────
# Standard library
# ─────────────────────────────────────────────────────────────
import os
import time
import random
import threading
import pickle as pkl
from math import log2, cos, pi
from pathlib import Path
from functools import partial

# ─────────────────────────────────────────────────────────────
# Third-party scientific / data
# ─────────────────────────────────────────────────────────────
import numpy as np
import awkward as ak
import h5py as h5
import psutil

# ─────────────────────────────────────────────────────────────
# TensorFlow / Keras
# (Use ONLY tensorflow.keras — avoid mixing keras.*)
# ─────────────────────────────────────────────────────────────
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K

# Layers
from tensorflow.keras.layers import (
    Dense,
    EinsumDense,
    BatchNormalization,
    Flatten,
    Add,
    GlobalAveragePooling1D,
    Masking,
)

# Callbacks
from tensorflow.keras.callbacks import (
    TerminateOnNaN,
    EarlyStopping,
    ReduceLROnPlateau,
    ModelCheckpoint,
    LearningRateScheduler,
    LambdaCallback,
)

# Constraints
from tensorflow.keras.constraints import MaxNorm, UnitNorm, MinMaxNorm

# Utilities
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.saving import register_keras_serializable

# ─────────────────────────────────────────────────────────────
# Scikit-learn
# ─────────────────────────────────────────────────────────────
from sklearn.metrics import (
    roc_curve,
    auc,
    mean_absolute_error,
    mean_squared_error,
    r2_score,
)
from sklearn.preprocessing import LabelBinarizer, label_binarize
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


# ─────────────────────────────────────────────────────────────
# HGQ + DA4ML + HLS4ML
# ─────────────────────────────────────────────────────────────
from hgq.config import QuantizerConfig, QuantizerConfigScope
from hgq.layers import QDense, QSoftmax, QEinsumDenseBatchnorm
from hgq.utils.sugar import BetaScheduler, Dataset, FreeEBOPs, ParetoFront, PBar, PieceWiseSchedule

from hgq.regularizers import MonoL1
#from da4ml.codegen import VerilogModel
#from da4ml.converter.hgq2.parser import trace_model
#from da4ml.trace import HWConfig, comb_trace
from hls4ml.converters import convert_from_keras_model



# ====================================================
# Environment setup
# NOTE: TensorFlow has already been imported above, so some of these
# variables may only be honoured if set before `import tensorflow`
# (move them to the very top of the notebook if that matters).
# CPU-only execution is additionally enforced via the tf.config API below,
# which works after import as long as no GPU has been initialised yet.
# ====================================================
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"   # 0=all, 1=info, 2=warning, 3=error
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide CUDA GPUs
os.environ["TF_METAL_ENABLE"] = "0"        # intended to disable the Metal plugin (macOS)
#os.environ['KERAS_BACKEND'] = 'jax'         # Sets Jax as Keras backend
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=-1"   # -1 turns XLA auto-clustering off

# Hide every GPU-type device (CUDA or a pluggable device such as Metal) from
# TensorFlow, so the "CPU only" intent holds even though the env vars above
# were set after the import.
try:
    tf.config.set_visible_devices([], "GPU")
except RuntimeError as err:
    # Raised when devices were already initialised; continue, but report it.
    print(f"Could not hide GPUs (devices already initialised): {err}")

# ====================================================
# TensorFlow version and visible physical devices
# ====================================================
print(tf.__version__)
print(tf.config.list_physical_devices())

# ====================================================
# JAX devices
# ====================================================
#print("jax devices:", jax.devices())           # should show METAL device(s) if jax-metal active

# ====================================================
# TensorFlow Debugging
# ====================================================
#tf.debugging.set_log_device_placement(True)

# ====================================================
# Background monitor for CPU and RAM (disabled)
# ====================================================
#def monitor_resources(interval=2):
#    """Print CPU and memory usage every `interval` seconds."""
#    while True:
#        mem = psutil.virtual_memory()
#        cpu = psutil.cpu_percent()
#        print(f"[Resource Monitor] CPU: {cpu:.1f}% | RAM Used: {mem.used/1e9:.2f} GB / {mem.total/1e9:.2f} GB")
#        time.sleep(interval)
#
# Start monitoring in background thread
#thread = threading.Thread(target=monitor_resources, daemon=True)
#thread.start()

# JAX Keras backend
#if keras.backend.backend() == 'jax':
#    jax.config.update("jax_default_matmul_precision", "tensorfloat32")

print("keras backend:", keras.config.backend())


2.18.0
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]
keras backend: tensorflow


## Configurations

In [2]:
## Config
seed = 123                 # RNG seed (NOTE: the training cell below re-seeds with 42)
first_decay_steps = 100    # first cycle length, presumably for cosine_decay_restarts_schedule
merit = 'loss'             # metric used to compare/select models -- confirm against its use
# Other candidate values for `merit` (presumably Keras history keys):
# merit = 'accuracy'
#cls_output_categorical_accuracy
#cls_output_loss
#loss
#pt_output_loss
#pt_output_mean_squared_error
#val_cls_output_categorical_accuracy
#val_cls_output_loss
#val_loss,val_pt_output_loss
#val_pt_output_mean_squared_error

bsz = 128                  # batch size

## Dataloader

In [3]:
# Data PATH
inputdir = '/Users/sznajder/cernbox/Temp/scaled_classes_perjet'


# ------------------------------------------
# 1. Classes and mapping to integer indices
# ------------------------------------------
# Keys are the prefixes of the per-class parquet files; values are the original
# class codes, which are remapped to consecutive indices 0..nclasses-1 below.
class_labels = {
    "Top": 10,
    "HQQ": 20,
    "HTauTau": 40,
    "Wqq": 50,
    "QCD": 70,
}

label_to_idx = {label: i for i, label in enumerate(class_labels.values())}
nclasses = len(class_labels)

# ------------------------------------------
# 2. Features
# ------------------------------------------
# Full list of available constituent features (kept for reference):
#constituents_features_list = [
#    'pt','pt_phys','pt_rel','eta_phys','deta_phys','phi_phys','dphi_phys',
#    'dr','z0','dxy','dxy_custom','id','charge','track_vx','track_vy',
#    'track_vz','track_d0','track_z0','puppiweight'
#]

constituents_features_list = ['pt_phys','pt_rel','deta_phys','dphi_phys','dxy']

jet_features_list = [
    "jet_pt", "jet_eta", "jet_phi", "jet_mass", "jet_energy", "jet_bjetscore",
    "jet_genmatch_pt", "jet_genmatch_eta", "jet_genmatch_phi",
    "jet_genmatch_mass", "jet_genmatch_flav"
]

outdir = "./scaled_classes_perjet"


# ------------------------------------------
# 3. Load per-class parquet files and build arrays
# ------------------------------------------
constituents_list = []
jets_list = []
labels_list = []
nmax = 100000  # max number of jets kept per class
weights_list = []  # per-jet weights used to balance classes during training

for cls_name, cls_label in class_labels.items():
    print(f"\nLoading class: {cls_name}")

    # Load constituents; the selected features are stacked on a new last axis
    const_file = os.path.join(inputdir, f"{cls_name}_Constituents.parquet")
    constituents = ak.from_parquet(const_file)
    constituents_array = np.stack([np.ascontiguousarray(constituents[f]) for f in constituents_features_list], axis=-1)

    # Load jet-level features
    jets_file = os.path.join(inputdir, f"{cls_name}_Jets.parquet")
    jets = ak.from_parquet(jets_file)
    jets_array = np.stack([np.ascontiguousarray(jets[f]) for f in jet_features_list], axis=-1)

    # Labels (consecutive indices)
    label_idx = label_to_idx[cls_label]
    labels_array = np.full(jets_array.shape[0], label_idx)

    # Keep at most nmax jets per class
    constituents_array = constituents_array[0:nmax]
    jets_array = jets_array[0:nmax]
    labels_array = labels_array[0:nmax]

    # An empty class would otherwise fail below with an opaque ZeroDivisionError
    if len(labels_array) == 0:
        raise ValueError(f"No jets found for class '{cls_name}' in {jets_file}")

    # Append to lists
    constituents_list.append(constituents_array)
    jets_list.append(jets_array)
    labels_list.append(labels_array)

    # Class-balancing weight: every class contributes a total weight of 1
    weight = 1./len(labels_array)
    weights_list.append( np.ones(len(labels_array))*weight )


    print(f"{cls_name}: constituents {constituents_array.shape}, jets {jets_array.shape}, labels {labels_array.shape}")


# Map class index -> per-jet weight array of that class.
# NOTE(review): the values are arrays (one entry per jet), not scalars, so this
# dict is NOT in the format Keras expects for `class_weight=`. The per-sample
# weights fed to the train/test split below are `weights`.
class_weights = dict(enumerate(weights_list))




# ------------------------------------------
# 4a. Concatenate all classes
# ------------------------------------------
X_constituents = np.concatenate(constituents_list, axis=0).astype(np.float32)
X_jets = np.concatenate(jets_list, axis=0).astype(np.float32)
labels_array = np.concatenate(labels_list, axis=0)
weights = np.concatenate(weights_list, axis=0)


# One-hot encode labels for Keras
y = to_categorical(labels_array, num_classes=nclasses).astype(np.float32)


# Data shapes
nconstit = X_constituents.shape[-2]
nfeat = X_constituents.shape[-1]

print("\nShapes before train/test split:")
print("X_constituents:", X_constituents.shape)
print("X_jets:", X_jets.shape)
print("y:", y.shape)

# ------------------------------------------
# 4b. Create regression target: jet pT
# ------------------------------------------
# X_jets columns follow jet_features_list, so column 0 is "jet_pt",
# which is used as the regression target.
y_reg = X_jets[:, 0].astype(np.float32)  # shape (n_jets,)
#y_reg = y_reg.reshape(-1, 1)

# ------------------------------------------
# 5. Split into train/test (stratification currently disabled)
# ------------------------------------------
X_train_val, X_test, \
Xj_train_val, Xj_test, y_train_val, \
y_test, y_reg_train_val, y_reg_test, \
weights_train_val, weights_test = train_test_split(
    X_constituents,
    X_jets,
    y,
    y_reg,
    weights,
    test_size=0.3,
    shuffle=True,
    random_state=42,
#    stratify=labels_array
)



X_train, X_val, \
Xj_train, Xj_val, \
y_train, y_val, \
y_reg_train, y_reg_val, \
weights_train, weights_val = train_test_split(
    X_train_val,
    Xj_train_val,
    y_train_val,
    y_reg_train_val,
    weights_train_val,
    test_size=0.3,
    shuffle=True,
    random_state=42,
#    stratify=y_train_val
)

# Constant (unit) weights for the regression target
weights_reg_train = np.ones( len(y_train) )
weights_reg_val = np.ones( len(y_val) )

# Do this once before the loop:
#X_train, X_val, y_cls_train, y_cls_val, y_reg_train, y_reg_val = train_test_split(
#    X_train_val, y_train_val, y_reg_train_val, test_size=0.3, stratify=np.argmax(y_train_val, axis=1), random_state=42)

nsamples = len(X_train_val)

'''
def make_dataset(X, y_cls, y_reg, batch_size=bsz, shuffle=True):
    # y_cls: (N, nclasses) one-hot
    # y_reg: (N,1)
    ds = tf.data.Dataset.from_tensor_slices((X, {"cls_output": y_cls, "pt_output": y_reg}))
    if shuffle:
        ds = ds.shuffle(10000, reshuffle_each_iteration=True)
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds

train_dataset = make_dataset(X_train, y_train, y_reg_train, batch_size=bsz, shuffle=True)
val_dataset = make_dataset(X_val, y_val, y_reg_val, batch_size=bsz, shuffle=True)
test_dataset  = make_dataset(X_test     , y_test     , y_reg_test     , batch_size=bsz, shuffle=False)
'''

print("\nShapes after train/test split:")
print("X_train_val:", X_train_val.shape, "Xc_test:", X_test.shape)
print("Xj_train_val:", Xj_train_val.shape, "Xj_test:", Xj_test.shape)
print("y_train_val (classification):", y_train_val.shape, "y_test:", y_test.shape)
print("y_reg_train_val (regression):", y_reg_train_val.shape, "y_reg_test:", y_reg_test.shape)
print(" ")
print("----------------------------------------------------------------------------------------")
print(" ")
print("y_train_val[:10] (classification):", y_train_val[:10], "y_test[:10]:", y_test[:10])
print("y_reg_train_val[:10] (regression):", y_reg_train_val[:10], "y_reg_test[:10]:", y_reg_test[:10])


# ------------------------------------------
# 6. Arrays ready for Keras
# X_train_val / X_test → (n_jets, Nconstit, n_features)
# Xj_train_val / Xj_test → (n_jets, n_jet_features)
# y_train_val / y_test   → (n_jets, n_classes)
# ------------------------------------------



Loading class: Top
Top: constituents (2815, 16, 5), jets (2815, 11), labels (2815,)

Loading class: HQQ
HQQ: constituents (100000, 16, 5), jets (100000, 11), labels (100000,)

Loading class: HTauTau
HTauTau: constituents (6659, 16, 5), jets (6659, 11), labels (6659,)

Loading class: Wqq
Wqq: constituents (18838, 16, 5), jets (18838, 11), labels (18838,)

Loading class: QCD
QCD: constituents (100000, 16, 5), jets (100000, 11), labels (100000,)

Shapes before train/test split:
X_constituents: (228312, 16, 5)
X_jets: (228312, 11)
y: (228312, 5)

Shapes after train/test split:
X_train_val: (159818, 16, 5) Xc_test: (68494, 16, 5)
Xj_train: (159818, 11) Xj_test: (68494, 11)
y_train_val (classification): (159818, 5) y_test: (68494, 5)
y_reg_train_val (regression): (159818,) y_reg_test: (68494,)
 
----------------------------------------------------------------------------------------
 
y_train_val[:10] (regression): [[0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 1.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0.

## Define the Models
#### https://github.com/calad0i/JEDI-linear/blob/master/src/model.py 

In [4]:
# ─────────────────────────────────────────────────────────────
# HGQ quantizer configuration scopes
# ─────────────────────────────────────────────────────────────
# These scopes are optional: the HGQ defaults are usually adequate, but the
# initial bit-widths (`[bif]0`) may need to be increased if a model does not
# converge.
#
# Alternative configurations (kept for reference):
#scope0 = QuantizerConfigScope(place='all', k0=1, b0=3, i0=0, default_q_type='kbi', overflow_mode='sat_sym')
#scope1 = QuantizerConfigScope(place='datalane', k0=0, default_q_type='kif', overflow_mode='wrap', f0=3, i0=3)
#
#iq_conf = QuantizerConfig(place='datalane', round_mode='RND')
#iq_default = QuantizerConfig(place='datalane')
#
#exp_table_conf = QuantizerConfig('kif', 'table', k0=0, i0=1, f0=8, overflow_mode='sat_sym')
#inv_table_conf = QuantizerConfig('kif', 'table', k0=1, i0=4, f0=4, overflow_mode='sat_sym')
#
#iq_conf = QuantizerConfig(place='datalane', k0=1)
#oq_conf = QuantizerConfig(place='datalane', k0=1, fr=MonoL1(1e-3))
#
# When scopes are combined, the later scope overrides the earlier one: with
# scope0 and scope1 above, the 'datalane' config is taken from scope1.
# (LayerConfigScope is not imported in this notebook.)
#lscope = LayerConfigScope(enable_ebops=True, beta0=1e-5)


# Initial fractional bit-widths for kernels/biases (k) and activations (a).
# Adapted from https://github.com/calad0i/HGQ2-examples/blob/master/jsc/model.py
#init_bw_k=3
#init_bw_a=3
init_bw_k = 8
init_bw_a = 8
#beta0=1e-5

# Kernels: symmetric saturation on overflow; trainable=True.
scope0 = QuantizerConfigScope(
    place='weight',
    overflow_mode='SAT_SYM',
    f0=init_bw_k,
    trainable=True,
)

# Biases: wrap-around on overflow; trainable=True.
scope1 = QuantizerConfigScope(
    place='bias',
    overflow_mode='WRAP',
    f0=init_bw_k,
    trainable=True,
)

# Activations ('datalane'): initial integer bits i0=8, fractional bits f0=init_bw_a.
scope2 = QuantizerConfigScope(
    place='datalane',
    i0=8,
    f0=init_bw_a,
)

# Alternative: 'dummy' quantizers for every scope (presumably no quantization).
#scope0 = QuantizerConfigScope(default_q_type='dummy')
#scope1 = QuantizerConfigScope(default_q_type='dummy')
#scope2 = QuantizerConfigScope(default_q_type='dummy')
#



#################################################################################################################################

#### # MLP using Dense layers
def MLP_dense(nhid=64):
    """Build a two-headed flattened MLP out of HGQ ``QDense`` layers.

    The (nconstit, nfeat) input is batch-normalised and flattened, then fed to
    two independent towers:
      * regression: pt_dense1 -> pt_dense2 -> pt_output (1 unit, linear)
      * classifier: cls_dense1 -> cls_dense2 -> cls_dense3 -> cls_output
        (nclasses units, linear logits)

    All layers are created inside the global quantizer scopes.

    Args:
        nhid: width of every hidden layer.

    Returns:
        keras.Model with outputs ``[cls_output, pt_output]``.
    """
    with (scope0, scope1, scope2):
        inputs = keras.layers.Input((nconstit, nfeat))

        normed = BatchNormalization()(inputs)
        flat = Flatten()(normed)

        # Regression tower (built first)
        reg = flat
        for layer_name in ("pt_dense1", "pt_dense2"):
            reg = QDense(nhid, activation='relu', name=layer_name)(reg)
        pt_head = QDense(1, name="pt_output")(reg)

        # Classification tower
        cls = flat
        for layer_name in ("cls_dense1", "cls_dense2", "cls_dense3"):
            cls = QDense(nhid, activation='relu', name=layer_name)(cls)
        cls_head = QDense(nclasses, name="cls_output")(cls)

    return keras.Model(inputs=inputs, outputs=[cls_head, pt_head])

####################################################################################################

# MLP using EinsumDense layers 
def MLP_einsum(nhid=64):
    """Build a two-headed flattened MLP out of HGQ ``QEinsumDenseBatchnorm`` layers.

    Same topology as ``MLP_dense``, but there is no input BatchNormalization.
    Every layer is a ``QEinsumDenseBatchnorm`` with equation ``'bc,cF->bF'``
    (a dense map on the flattened features).
      * regression: pt_dense1 -> pt_dense2 -> pt_output (1 unit)
      * classifier: cls_einsum1 -> cls_einsum2 -> cls_einsum3 -> cls_output

    Args:
        nhid: width of every hidden layer.

    Returns:
        keras.Model with outputs ``[cls_output, pt_output]``.
    """
    dense_eq = 'bc,cF->bF'

    with (scope0, scope1, scope2):
        inputs = keras.layers.Input((nconstit, nfeat))
        flat = Flatten()(inputs)

        # Pt regression tower
        reg = flat
        for layer_name in ("pt_dense1", "pt_dense2"):
            reg = QEinsumDenseBatchnorm(dense_eq, nhid, bias_axes='F', activation='relu', name=layer_name)(reg)
        pt_head = QEinsumDenseBatchnorm(dense_eq, 1, bias_axes='F', name="pt_output")(reg)

        # Classifier tower
        cls = flat
        for layer_name in ("cls_einsum1", "cls_einsum2", "cls_einsum3"):
            cls = QEinsumDenseBatchnorm(dense_eq, nhid, bias_axes='F', activation='relu', name=layer_name)(cls)
        cls_head = QEinsumDenseBatchnorm(dense_eq, nclasses, bias_axes='F', name="cls_output")(cls)

    return keras.Model(inputs=inputs, outputs=[cls_head, pt_head])


#################################################################################################

def DS_dense(nhid=64):
    """Build a two-headed Deep Sets model out of HGQ ``QDense`` layers.

    Three QDense layers are applied to the (nconstit, nfeat) input
    (assumes QDense, like keras Dense, acts on the last axis, i.e. per
    constituent -- TODO confirm). The result is averaged over constituents
    with GlobalAveragePooling1D and feeds two towers:
      * regression: pt_dense1 -> pt_dense2 -> pt_output (1 unit)
      * classifier: cls_dense1 -> cls_dense2 -> cls_output (nclasses logits)

    Every layer, including the classifier head, is created inside the
    quantizer scopes so that all of them share the same quantizer config.

    Args:
        nhid: width of every hidden layer.

    Returns:
        keras.Model with outputs ``[cls_output, pt_output]``.
    """
    with (scope0, scope1, scope2):

        inp = keras.layers.Input((nconstit, nfeat))

#        x = BatchNormalization()(inp)
        x = inp

        # Per-constituent network
        x = QDense(nhid, activation='relu', name="dense1")(x)
        x = QDense(nhid, activation='relu', name="dense2")(x)
        x = QDense(nhid, activation='relu', name="dense3")(x)
        # Permutation-invariant aggregation over constituents
        x = GlobalAveragePooling1D()(x)

        # Pt Regression
        xr = QDense(nhid, activation='relu', name="pt_dense1" )(x)
        xr = QDense(nhid, activation='relu', name="pt_dense2" )(xr)
        pt_out = QDense(1, name="pt_output")(xr)     # Pt regression head

        # Classifier (must stay inside the `with` block to pick up the scopes)
        x = QDense(nhid, activation='relu', name="cls_dense1")(x)
        x = QDense(nhid, activation='relu', name="cls_dense2")(x)
        cls_out = QDense(nclasses, name="cls_output")(x) # classification head

    model = keras.Model(inputs=inp, outputs=[cls_out, pt_out])
    return model

#################################################################################################

def DS_einsum(nhid=64):
    """Build a two-headed Deep Sets model out of ``QEinsumDenseBatchnorm`` layers.

    Per-constituent layers use ``'bnc,cC->bnC'`` (a shared dense map over
    the feature axis of every constituent). Their output is averaged over
    constituents and fed to:
      * regression: pt_dense1 -> pt_dense2 -> pt_output (1 unit)
      * classifier: cls_einsum1 -> cls_einsum2 -> cls_output (nclasses logits)

    Args:
        nhid: width of every hidden layer.

    Returns:
        keras.Model with outputs ``[cls_output, pt_output]``.
    """
    per_particle_eq = 'bnc,cC->bnC'
    jet_eq = 'bc,cC->bC'

    with (scope0, scope1, scope2):
        n_particles = nconstit
        inputs = keras.layers.Input((nconstit, nfeat))

        # Per-constituent network
        phi = inputs
        for layer_name in ("einsum1", "einsum2", "einsum3"):
            phi = QEinsumDenseBatchnorm(per_particle_eq, (n_particles, nhid), bias_axes='C',
                                        activation='relu', name=layer_name)(phi)
        pooled = GlobalAveragePooling1D()(phi)

        # Pt regression tower
        reg = pooled
        for layer_name in ("pt_dense1", "pt_dense2"):
            reg = QEinsumDenseBatchnorm(jet_eq, nhid, bias_axes='C', activation='relu', name=layer_name)(reg)
        pt_head = QEinsumDenseBatchnorm(jet_eq, 1, bias_axes='C', name="pt_output")(reg)

        # Classifier tower
        cls = pooled
        for layer_name in ("cls_einsum1", "cls_einsum2"):
            cls = QEinsumDenseBatchnorm(jet_eq, nhid, bias_axes='C', activation='relu', name=layer_name)(cls)
        cls_head = QEinsumDenseBatchnorm(jet_eq, nclasses, bias_axes='C', name="cls_output")(cls)

    return keras.Model(inputs=inputs, outputs=[cls_head, pt_head])

#################################################################################################

# MLP mixer network
def MLP_mixer(nhid=16):
    """Build a two-headed MLP-Mixer style model (cls logits + pt regression).

    * A mixing branch (einsum1-3; einsum3 mixes along the constituent axis)
      is added back to the input as a residual.
    * A per-constituent feature MLP follows (einsum4-5).
    * einsum6 (``'bnc,n->bc'``) forms a learned weighted sum over constituents.
    * The result feeds a pt regression tower and a classifier tower.

    Args:
        nhid: width of the hidden layers.

    Returns:
        keras.Model with outputs ``[cls_output, pt_output]``.
    """
    with (scope0, scope1, scope2):

        N = nconstit
        # Channel width of the mixing branch. It must equal the number of input
        # features because the branch output is added back to the input below.
        n = nfeat
        inp = keras.layers.Input((nconstit, nfeat))

#        x = BatchNormalization()(inp)
        x = inp

        # Patches MLP
        x1 = QEinsumDenseBatchnorm('bnc,cC->bnC', (N, nhid), bias_axes='C', activation='relu', name="einsum1")(x)
        x1 = QEinsumDenseBatchnorm('bnc,cC->bnC', (N, n), bias_axes='C', activation='relu', name="einsum2")(x1)
        # Mix along the constituent axis (kernel 'nN' acts on axis n)
        x1 = QEinsumDenseBatchnorm('bnc,nN->bNc', (N, n), bias_axes='N', name="einsum3")(x1)

        # Residual connection to the (un-normalized) input: a modification of
        # Chang Sun's original MLP-Mixer model
        x = Add()([x, x1])

        # Features MLP
        x = QEinsumDenseBatchnorm('bnc,cC->bnC', (N, nhid), bias_axes='C', activation='relu', name="einsum4" )(x)
        x = QEinsumDenseBatchnorm('bnc,cC->bnC', (N, nhid), bias_axes='C', activation='relu', name="einsum5")(x)
        # Learned weighted sum over constituents
        x = QEinsumDenseBatchnorm('bnc,n->bc', nhid, name="einsum6")(x)

        # Pt Regression
        xr = QEinsumDenseBatchnorm('bc,cC->bC', nhid, bias_axes='C', activation='relu', name="pt_dense1"  )(x)
        xr = QEinsumDenseBatchnorm('bc,cC->bC', nhid, bias_axes='C', activation='relu', name="pt_dense2"  )(xr)
        pt_out = QEinsumDenseBatchnorm('bc,cC->bC', 1, bias_axes='C', name="pt_output" )(xr)      # Pt regression head

        # Classifier
        x = QEinsumDenseBatchnorm('bc,cC->bC', nhid, bias_axes='C', activation='relu', name="cls_einsum1")(x)
        x = QEinsumDenseBatchnorm('bc,cC->bC', nhid, bias_axes='C', activation='relu', name="cls_einsum2")(x)
        x = QEinsumDenseBatchnorm('bc,cC->bC', nhid, bias_axes='C', activation='relu', name="cls_einsum3")(x)
        cls_out = QEinsumDenseBatchnorm('bc,cC->bC', nclasses, bias_axes='C', name="cls_output" )(x)  # classification head

    model = keras.Model(inputs=inp, outputs=[cls_out, pt_out])
    return model

####################################################################################################

# JEDI linear InteractionNetwork
def JEDI_linear(nhid=64):
    """Build a two-headed JEDI-linear (linear interaction network) model.

    * Each constituent is embedded (einsum1).
    * The embedding feeds a per-constituent branch ``s`` (einsum2) and a
      jet-level branch ``d``: a sum over constituents scaled by ~1/N, then
      einsum3.
    * ``d`` (constituent axis of size 1) is added to ``s`` and transformed
      per constituent (einsum4).
    * The result is averaged over constituents and fed to a pt regression
      tower and a classifier tower.

    Args:
        nhid: width of the hidden layers.

    Returns:
        keras.Model with outputs ``[cls_output, pt_output]``.
    """
    with (scope0, scope1, scope2):

        N = nconstit

        # Power-of-two approximation of 1/N (exact when N is a power of two),
        # presumably so the pooling reduces to a shift in hardware.
        pool_scale = 2.**-round(log2(N))

        inp = keras.layers.Input((nconstit, nfeat))

#        x = BatchNormalization()(inp)
        x = inp

        # Edges MLP
        x = QEinsumDenseBatchnorm('bnc,cC->bnC', (N, nhid), bias_axes='C', activation='relu', name="einsum1")(x)
        s = QEinsumDenseBatchnorm('bnc,cC->bnC', (N, nhid), bias_axes='C', activation='relu', name="einsum2")(x)
        # Mean-like pooling: scale the sum DOWN by ~1/N. Dividing by pool_scale
        # would instead inflate it by ~N (i.e. ~N^2 times the mean).
        xx = tf.keras.ops.sum(x, axis=1, keepdims=True) * pool_scale
        d = QEinsumDenseBatchnorm('bnc,cC->bnC', (1, nhid), bias_axes='C', activation='relu', name="einsum3")(xx)

        # Nodes MLP
        x = Add()([s, d])
        x = QEinsumDenseBatchnorm('bnc,cC->bnC', (N, nhid), bias_axes='C', activation='relu', name="einsum4" )(x)
        # Mean over constituents (equivalent to GlobalAveragePooling1D)
        x = tf.keras.ops.sum(x, axis=1, keepdims=False) / N


        # Pt Regression
        xr = QEinsumDenseBatchnorm('bc,cC->bC', int(nhid), bias_axes='C', activation='relu', name="pt_dense1" )(x)
        xr = QEinsumDenseBatchnorm('bc,cC->bC', int(nhid), bias_axes='C', activation='relu', name="pt_dense2" )(xr)
        pt_out = QEinsumDenseBatchnorm('bc,cC->bC', 1, bias_axes='C', name="pt_output" )(xr)      # Pt regression head

        # Graph classifier
        x = QEinsumDenseBatchnorm('bc,cC->bC', nhid, bias_axes='C', activation='relu', name="cls_einsum1")(x)
        x = QEinsumDenseBatchnorm('bc,cC->bC', int(nhid), bias_axes='C', activation='relu', name="cls_einsum2" )(x)
        x = QEinsumDenseBatchnorm('bc,cC->bC', int(nhid), bias_axes='C', activation='relu', name="cls_einsum3" )(x)
        cls_out = QEinsumDenseBatchnorm('bc,cC->bC', nclasses, bias_axes='C' ,name='cls_output')(x)  # classification head

    model = keras.Model(inputs=inp, outputs=[cls_out, pt_out])
    return model


###################################################################################################

# Lipschitz MLP based on weights constrain/regularization
@register_keras_serializable()
class LipschitzReg(layers.Wrapper):
    """Wrap a Keras Dense layer and renormalise its kernel to bound its Lipschitz constant.

    The kernel is normalised once when the layer is built, and again before
    every forward pass in training mode.
    """

    def __init__(self, dense_layer, kind="one-inf", **kwargs):
        """Initialize with a pre-defined Dense layer and norm constraint kind.

        Args:
            dense_layer: A keras.layers.Dense instance to wrap.
            kind: Norm constraint type ("one", "inf", "one-inf", "two-inf"). Default: "one-inf".
                  - "one": L1 norm (|W|_1) constraint per row.
                  - "inf": Linf norm (|W|_inf) constraint per column.
                  - "one-inf": L1 norm per row, Linf per column.
                  - "two-inf": L2 norm per row, Linf per column.
            **kwargs: forwarded to ``keras.layers.Wrapper`` (e.g. ``name``).

        Raises:
            ValueError: for an unsupported ``kind`` or a non-Dense ``dense_layer``.
        """
        self.kind = kind.lower()
        if self.kind not in ["one", "inf", "one-inf", "two-inf"]:
            raise ValueError(f"Unsupported kind '{kind}'. Use 'one', 'inf', 'one-inf', or 'two-inf'.")
        # Validate and store the provided Dense layer
        if not isinstance(dense_layer, layers.Dense):
            raise ValueError("dense_layer must be an instance of keras.layers.Dense.")
        super(LipschitzReg, self).__init__(dense_layer, **kwargs)

    def build(self, input_shape):
        """Build the wrapped Dense layer and normalize its weights."""
        super(LipschitzReg, self).build(input_shape)
        self._normalize_weights()

    def _normalize_weights(self):
        """Normalize the wrapped Dense layer's kernel in place according to ``self.kind``.

        Row norms reduce over axis 1 and column norms over axis 0 of the
        kernel. Norms are clamped to at least ``epsilon`` to avoid division by zero.
        """
        W = self.layer.kernel
        if self.kind == "one":
            row_norms = tf.reduce_sum(tf.abs(W), axis=1, keepdims=True)
            self.layer.kernel.assign(W / tf.maximum(row_norms, tf.keras.backend.epsilon()))
        elif self.kind == "inf":
            col_norms = tf.reduce_max(tf.abs(W), axis=0, keepdims=True)
            self.layer.kernel.assign(W / tf.maximum(col_norms, tf.keras.backend.epsilon()))
        elif self.kind == "one-inf":
            # NOTE(review): col_norms come from the original W, not from the
            # row-normalized matrix, so neither bound is guaranteed to hold
            # exactly afterwards -- confirm this is intended.
            row_norms = tf.reduce_sum(tf.abs(W), axis=1, keepdims=True)
            col_norms = tf.reduce_max(tf.abs(W), axis=0, keepdims=True)
            W_normalized = W / tf.maximum(row_norms, tf.keras.backend.epsilon())
            self.layer.kernel.assign(W_normalized / tf.maximum(col_norms, tf.keras.backend.epsilon()))
        elif self.kind == "two-inf":
            # NOTE(review): same caveat as "one-inf" regarding col_norms.
            row_norms = tf.sqrt(tf.reduce_sum(tf.square(W), axis=1, keepdims=True))
            col_norms = tf.reduce_max(tf.abs(W), axis=0, keepdims=True)
            W_normalized = W / tf.maximum(row_norms, tf.keras.backend.epsilon())
            self.layer.kernel.assign(W_normalized / tf.maximum(col_norms, tf.keras.backend.epsilon()))

    def call(self, inputs, training=None):
        """Apply the wrapped Dense layer with dynamic weight normalization during training."""
        if training:
            self._normalize_weights()  # Re-normalize weights each forward pass in training
        return self.layer(inputs)

    def compute_output_shape(self, input_shape):
        return self.layer.compute_output_shape(input_shape)

    def get_config(self):
        """Return the serializable config, including the normalization ``kind``.

        The base Wrapper config does not know about ``kind``. Without this
        override, a saved model would be reloaded with the default
        ``kind="one-inf"`` for every wrapped layer.
        """
        config = super(LipschitzReg, self).get_config()
        config["kind"] = self.kind
        return config


def MLP_LipschitzReg(nhid=64):
    """Build a two-headed flattened MLP whose Dense layers are wrapped in ``LipschitzReg``.

    * Regression tower: pt_dense1 -> pt_dense2 -> pt_output. All use the
      per-column "inf" normalization; pt_output uses ReLU, so its
      predictions are non-negative.
    * Classifier tower: the first hidden layer uses "one-inf"; the remaining
      hidden layers and the cls_output logits layer use "inf".

    Args:
        nhid: width of every hidden layer.

    Returns:
        keras.Model with outputs ``[cls_output, pt_output]``.
    """
    inputs = keras.layers.Input((nconstit, nfeat))

    normed = BatchNormalization()(inputs)
    flat = Flatten()(normed)

    # Pt regression tower
    reg = flat
    for layer_name in ("pt_dense1", "pt_dense2"):
        reg = LipschitzReg(layers.Dense(int(nhid), activation='relu'), name=layer_name, kind="inf")(reg)
    pt_head = LipschitzReg(layers.Dense(1, activation='relu'), name="pt_output", kind="inf")(reg)

    # Classifier tower: one "one-inf" layer followed by two "inf" layers
    cls = LipschitzReg(layers.Dense(nhid, activation='relu'), kind="one-inf")(flat)
    for _ in range(2):
        cls = LipschitzReg(layers.Dense(int(nhid), activation='relu'), kind="inf")(cls)
    cls_head = LipschitzReg(layers.Dense(nclasses), name="cls_output", kind="inf")(cls)

    return keras.Model(inputs=inputs, outputs=[cls_head, pt_head])


#############################################################################################################
#	•	Using kernel_constraint=MaxNorm(1.0) is a cheap approximation for controlling the Lipschitz constant.
#	•	It does not guarantee 1-Lipschitz behavior, so the network is effectively K-Lipschitz, where K depends 
#         on the number of neurons and weight geometry.
#	•	For a provably 1-Lipschitz network, you need either:
#	•	Spectral normalization (largest singular value ≤ 1)
#	•	Or explicit row/column norm normalization, as done by the LipschitzReg wrapper defined above.

# ---- GroupSort2 activation ----
@register_keras_serializable()
class GroupSort2(layers.Layer):
    """GroupSort activation with group size 2 (a.k.a. MaxMin).

    The channel (last) axis is split into two halves ``a`` and ``b``. Each
    pair ``(a[i], b[i])`` is replaced by ``(min, max)``, giving the output
    ``concat([min(a, b), max(a, b)])``. The output is therefore a
    permutation of the input values, which keeps the activation 1-Lipschitz.

    With an odd number of channels the last channel has no partner and is
    passed through unchanged, so the output is still a permutation of the
    input. Inputs of any rank are supported. The output keeps the input
    shape, including the static channel dimension, so following Dense layers
    can be built.
    """

    def call(self, x):
        channels = x.shape[-1]
        if channels is None:
            # Channel count unknown at trace time: fall back to the dynamic shape.
            channels = tf.shape(x)[-1]
        half = channels // 2

        a = x[..., :half]
        b = x[..., half:2 * half]
        leftover = x[..., 2 * half:]  # empty unless the channel count is odd

        return tf.concat([tf.minimum(a, b), tf.maximum(a, b), leftover], axis=-1)

    def compute_output_shape(self, input_shape):
        # Sorting pairs only permutes values along the last axis.
        return input_shape


        

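# SpectralDense only approximates spectral normalization: it applies a MaxNorm(1.0) kernel constraint (each kernel column's L2 norm <= 1), which does not bound the largest singular value.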
def SpectralDense(units, name=None):
    """Return a linear Dense layer whose kernel is constrained by ``MaxNorm(1.0)``.

    Despite the name, this is not true spectral normalization. ``MaxNorm``
    rescales each kernel column (the incoming weights of one output unit) to
    an L2 norm of at most 1 after each weight update. That only loosely
    controls the layer's Lipschitz constant.

    Args:
        units: number of output units.
        name: optional layer name.
    """
    norm_cap = MaxNorm(max_value=1.0)
    return layers.Dense(
        units,
        activation=None,
        kernel_constraint=norm_cap,
        name=name,
    )
    
def MLP_LipschitzGroupSort(nhid=64):
    """Build a two-headed flattened MLP from MaxNorm-constrained Dense layers and GroupSort2.

    Each tower is two ``SpectralDense`` + ``GroupSort2`` blocks followed by
    a linear ``SpectralDense`` output:
      * regression: pt_dense1, pt_dense2 -> pt_output (1 unit)
      * classifier: cls_dense1, cls_dense2 -> cls_output (nclasses logits)

    Args:
        nhid: width of every hidden layer.

    Returns:
        keras.Model with outputs ``[cls_output, pt_output]``.
    """
    inputs = keras.layers.Input(shape=(nconstit, nfeat))

    normed = layers.BatchNormalization()(inputs)
    flat = layers.Flatten()(normed)

    def tower(prefix, n_out, out_name):
        # Two (SpectralDense -> GroupSort2) blocks, then a linear output layer.
        hidden = flat
        for depth in (1, 2):
            hidden = SpectralDense(nhid, name=f"{prefix}_dense{depth}")(hidden)
            hidden = GroupSort2()(hidden)
        return SpectralDense(n_out, name=out_name)(hidden)

    # Regression tower is built first, then the classifier.
    pt_head = tower("pt", 1, "pt_output")
    cls_head = tower("cls", nclasses, "cls_output")

    return keras.Model(inputs=inputs, outputs=[cls_head, pt_head])

############################################################################################################

def cosine_decay_restarts_schedule(
    initial_learning_rate: float, first_decay_steps: int, t_mul=1.0, m_mul=1.0, alpha=0.0, alpha_steps=0
):
    """Build a cosine-decay-with-warm-restarts learning-rate schedule.

    Args:
        initial_learning_rate: learning rate at the start of the first cycle.
        first_decay_steps: length (in steps) of the first cycle; must be > 0.
        t_mul: factor applied to the cycle length after each restart; must be > 0.
            NOTE: with 0 < t_mul < 1 the cycle lengths form a convergent
            series, so steps beyond their sum are never reached by the loop.
        m_mul: factor applied, per restart, to the decay amplitude above ``alpha``.
        alpha: floor the learning rate decays to.
        alpha_steps: number of steps at the end of each cycle held at ``alpha``.

    Returns:
        A function ``schedule(global_step) -> float``.

    Raises:
        ValueError: if ``first_decay_steps <= 0`` or ``t_mul <= 0``. Either
            value would make the cycle-search loop never terminate.
    """
    if first_decay_steps <= 0:
        raise ValueError(f"first_decay_steps must be > 0, got {first_decay_steps}")
    if t_mul <= 0:
        raise ValueError(f"t_mul must be > 0, got {t_mul}")

    def schedule(global_step):
        # Walk through the cycles to find the one containing global_step.
        n_cycle = 1
        cycle_step = global_step
        cycle_len = first_decay_steps
        while cycle_step >= cycle_len:
            cycle_step -= cycle_len
            cycle_len *= t_mul
            n_cycle += 1

        # Fraction of the decaying part of the cycle already completed. When
        # alpha_steps covers the whole cycle, the rate sits at the floor.
        decay_len = cycle_len - alpha_steps
        cycle_t = min(cycle_step / decay_len, 1) if decay_len > 0 else 1
        lr = alpha + 0.5 * (initial_learning_rate - alpha) * (1 + cos(pi * cycle_t)) * m_mul ** max(n_cycle - 1, 0)
        return lr

    return schedule


#############################################################################################################
from tensorflow.keras.models import Model
def inspect_layer_outputs(model, sample_inputs, layers_to_check=None, show_sample=True):
    """
    Print layer-wise output statistics for a Keras functional model.

    Builds a temporary model that exposes the outputs of the requested layers
    and runs one forward pass. For each layer it prints the output shape, min,
    max and mean. This is useful for spotting saturated or collapsed
    activations, e.g. after quantization-aware training.

    Parameters
    ----------
    model : keras.Model
        Your trained (functional) model.
    sample_inputs : array-like
        Input samples to feed through the model (anything ``predict`` accepts).
    layers_to_check : iterable of str, optional
        Names of layers to inspect. If None, every layer in ``model.layers``
        is inspected.
    show_sample : bool
        Whether to print the first sample's output (flattened) in detail.
    """
    # Choose layers to inspect. Materialise the names: they are iterated
    # twice, so a one-shot iterator would otherwise be empty on the second pass.
    if layers_to_check is None:
        layers_to_check = [layer.name for layer in model.layers]
    else:
        layers_to_check = list(layers_to_check)

    # Get outputs of the chosen layers
    layer_outputs = [model.get_layer(name).output for name in layers_to_check]

    # Create a temporary model that returns all layer outputs
    inspect_model = Model(inputs=model.input, outputs=layer_outputs)

    # Forward pass
    outputs = inspect_model.predict(sample_inputs, verbose=0)

    # Keras may return a bare array rather than a 1-element list when only one
    # layer is requested. zip() would then pair layer names with batch rows.
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]

    # Print layer-wise statistics
    for layer_name, out in zip(layers_to_check, outputs):
        print(f"\nLayer: {layer_name}")
        print(f" Output shape: {out.shape}")
        print(f" Min: {out.min():.6f}, Max: {out.max():.6f}, Mean: {out.mean():.6f}")
        if show_sample:
            print(f" First sample output (flattened): {out[0].flatten()}")

## Train the Models
#### https://github.com/calad0i/JEDI-linear/blob/master/src/train.py

In [6]:
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt
import keras
import pickle as pkl
import random

print('Setting seed...')

# Seed the Python, NumPy and TensorFlow RNGs so runs are as repeatable as possible.
seed = 42
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

print('Loading data...')

# Root directory for this notebook's outputs
work_path = '/Users/sznajder/WorkM1/workdir/Notebooks/CMS_L1TML/Models_PerfNanoDataset'
# Filled by the training loop: model name -> model reloaded from its checkpoint
trained_models = {}

# Model configurations (name and hyperparams): minimal models, intended to be about 10K parameters

'''
# Define models dictionary ( name and hyperparams ): Large Models
models_parms = [ {'name': 'mlp_dense',  'nhid':128, 'lr': 0.001, 'nepochs': 100, 'bsz': 25},
                 {'name': 'mlp_einsum', 'nhid':128, 'lr': 0.001, 'nepochs': 100, 'bsz': 258},
                 {'name': 'ds_dense',   'nhid':96, 'lr': 0.001, 'nepochs': 100, 'bsz': 258} ,
                 {'name': 'ds_einsum',  'nhid':96, 'lr': 0.001, 'nepochs': 100, 'bsz': 258} ,
                 {'name': 'mlp_mixer',  'nhid':96, 'lr': 0.001, 'nepochs': 100, 'bsz': 258} ,
                 {'name': 'gnn',        'nhid':96, 'lr': 0.001, 'nepochs': 100, 'bsz': 258} ,
                 {'name': 'mlp_monotonic',   'nhid':64, 'lr': 0.001, 'nepochs': 100, 'bsz': 258} ]
'''

# Each spec: (builder function name, hidden width, learning rate, epochs)
models_parms = [
    {'name': model_name, 'nhid': width, 'lr': rate, 'nepochs': epochs}
    for model_name, width, rate, epochs in (
        ('MLP_dense',   38, 0.001, 10),    # NOTE(review): 10 epochs vs 100 for the others
        ('MLP_einsum',  38, 0.001, 100),
        ('DS_dense',    40, 0.001, 100),
        ('DS_einsum',   40, 0.001, 100),
        ('MLP_mixer',   38, 0.001, 100),
        ('JEDI_linear', 38, 0.001, 100),
        # ('MLP_LipschitzReg',       38, 0.01, 100),
        # ('MLP_LipschitzGroupSort', 38, 0.01, 100),
    )
]


# Loop over models. For each one: build it, pre-train the classification
# head, then the regression head, and store the reloaded best checkpoint in
# `trained_models`.
for m in models_parms:

    # Get model name and hyperparameters
    name, nhid, lr, nepochs = m['name'], m['nhid'], m['lr'], m['nepochs']
    # NOTE(review): `bsz` (batch size) is not set in this cell; it must come
    # from an earlier cell.

    # Clear clutter from previous Keras session graphs.
    K.clear_session()

    # Build the model (builder function looked up by name) and print its params
    #    model = eval(name)(nhid)
    model = globals()[name](nhid)
    model.summary()

    save_path = Path(work_path)
    save_path.mkdir(parents=True, exist_ok=True)

    # Per-model checkpoint file, reloaded at the end of the loop body
    ckpt_path = Path('./models') / '{}.keras'.format(name)

    # Define classification Loss (the classification head outputs logits)
    cls_loss = keras.losses.CategoricalCrossentropy(from_logits=True)
#    cls_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    '''
    # Early stopping callback                       
    es = EarlyStopping(monitor=merit, patience=20)

    # Learning rate scheduler 
    ls = ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, min_lr=0.0000001, min_delta=0.02)
    

    # Define the Score output printou callback
    print_outputs = LambdaCallback(on_epoch_end=lambda epoch, 
                                   logs: print("\nEpoch {} - outputs:\n{}".format(epoch, model.predict(inputs))))
    '''

    # Checkpoint the best classification model. The multi-output model logs
    # its accuracy as 'val_cls_output_accuracy'. 'val_accuracy' is never
    # logged, so monitoring it saved nothing and the final load_model() read a
    # missing or stale file.
    chkp = ModelCheckpoint(str(ckpt_path),
                           monitor='val_cls_output_accuracy', # best generalization
                           mode='max',
                           verbose=0,
                           save_best_only=True,
                           save_freq='epoch')

    # Define list of callbacks
#    callbk = [nan, es, ls, chkp]
#    callbk = [es, ls, chkp]

    # HGQ callbacks
    ebops = FreeEBOPs()

    # NOTE(review): Keras' LearningRateScheduler steps once per epoch, so a
    # 4000-step first cycle keeps the LR near `lr` for runs of ~100 epochs.
    lr_sched = LearningRateScheduler(cosine_decay_restarts_schedule(lr, 4000, t_mul=1.0, m_mul=0.94, alpha=1e-6, alpha_steps=50))

#    beta_sched = BetaScheduler(PieceWiseSchedule([(0, 5e-7, 'constant'), (4000, 5e-7, 'log'), (200000, 1e-3, 'constant')]))
#    beta_sched = BetaScheduler(PieceWiseSchedule([(0, 2e-8, 'linear'), 
#                                                  (int(nepochs/2.), 3e-7, 'log'), 
#                                                  (nepochs, 3.0e-6, 'constant')]))

    # EBOPs regularisation strength (beta) schedule over epochs
    beta_sched = BetaScheduler(PieceWiseSchedule([ (0, 1e-8, 'constant'),
                                                   (int(nepochs/3.), 7e-8, 'linear'),
                                                   (2*int(nepochs/3.), 1e-7, 'linear') ]))

    pbar = PBar('loss: {loss:.3f}/{val_loss:.3f} - ' \
                'acc: {cls_output_accuracy:.3f}/{val_cls_output_accuracy:.3f} - ' \
                'lr: {learning_rate:.2e} - beta: {beta:.1e}')

    # Pareto front of saved models: maximize 'val_cls_output_accuracy', minimize 'ebops'
    fname = f'{name}-epoch={{epoch}}-val_acc={{val_cls_output_accuracy:.3f}}-ebops={{ebops}}-val_loss={{val_loss:.3f}}.keras'

    pareto = ParetoFront( path='./models/', metrics=['val_cls_output_accuracy', 'ebops'], 
                          sides=[1, -1],  # maximize accuracy, minimize EBOPs 
                          fname_format=fname )

    # Define model callbacks
    callbk = [ebops, lr_sched, beta_sched, pbar, pareto, chkp ]

    # Initial classification and regression accuracies
    pred, pred_reg = model.predict(X_test, batch_size=bsz, verbose=0)  # type: ignore
    y_true = np.argmax(y_test, axis=1)  # true class indices
    y_pred = np.argmax(pred, axis=1)  # predicted class indices
    acc = np.mean(y_pred == y_true)
    print(f'pre-training Model accuracy: {acc:.2%}')

    mae = mean_absolute_error(y_reg_test, pred_reg)
    mse = mean_squared_error(y_reg_test, pred_reg)
    r2  = r2_score(y_reg_test, pred_reg)
    print(f"Regression MAE: {mae},  MSE: {mse},  R2: {r2}")

    # Freeze regression layers
    for layer_name in ["pt_output", "pt_dense1", "pt_dense2"]:
        model.get_layer(layer_name).trainable = False

    # Compile Models for pre-training classification  
#    opt = tf.keras.optimizers.legacy.Adam(learning_rate=lr) # Faster on M1 Mac
    opt = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0) # Define optimizer
    model.compile(optimizer=opt, \
                  loss={'cls_output': cls_loss,'pt_output':  'mean_squared_logarithmic_error'}, \
                  loss_weights={'cls_output': 1.0, 'pt_output': 0.0}, \
                  metrics={'cls_output': 'accuracy', 'pt_output':  'mean_squared_logarithmic_error'}, \
                  steps_per_execution=32)

    # Fit the classification head
    print('Pre-Training the model classification:',name)
    history_class = model.fit(X_train, [y_train, y_reg_train], 
                              validation_data=(X_val, [y_val, y_reg_val], [weights_val,weights_reg_val]), 
                              sample_weight=[weights_train,weights_reg_train], 
                              batch_size=bsz, epochs=nepochs, callbacks=callbk, verbose=1)

    # Inspect layer outputs after training to diagnose quantization problems
    inspect_layer_outputs(model, X_test[:5])

    # Start the regression phase from the best classification weights (when a
    # checkpoint was written) rather than from the last epoch, which can be
    # much worse once the EBOPs penalty ramps up.
    if ckpt_path.exists():
        model.load_weights(str(ckpt_path))

    # Unfreeze regression layers and freeze everything else
    for layer in model.layers:
        if layer.name not in ["pt_output", "pt_dense1", "pt_dense2"]:
            layer.trainable = False
        else:
            layer.trainable = True

    # The classification layers are frozen in this phase, so
    # 'val_cls_output_accuracy' is not expected to beat its phase-1 best and
    # `chkp` would not save the regression-trained weights. Checkpoint on the
    # regression validation loss instead. It writes the same file, so the
    # final load picks up the regression-trained model.
    chkp_reg = ModelCheckpoint(str(ckpt_path),
                               monitor='val_pt_output_loss',
                               mode='min',
                               verbose=0,
                               save_best_only=True,
                               save_freq='epoch')
    callbk_reg = [ebops, lr_sched, beta_sched, pbar, pareto, chkp_reg]

    # Compile Models for pre-training regression 
    opt = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0) # Define optimizer
    model.compile( optimizer=opt, \
                   loss={ "cls_output": cls_loss, 'pt_output': 'mean_squared_logarithmic_error'}, \
                   loss_weights={'cls_output': 0, 'pt_output': 1.0}, \
                   metrics={'cls_output': 'accuracy', 'pt_output': 'mean_squared_logarithmic_error' }, \
                   steps_per_execution=32 )

    # Fit the regression head
    print('Pre-Training the model regression:',name)
    history_reg = model.fit(X_train, [y_train, y_reg_train], \
                            validation_data=(X_val, [y_val, y_reg_val], [weights_val,weights_reg_val]), \
                            sample_weight=[weights_train,weights_reg_train], \
                            batch_size=bsz, epochs=nepochs, callbacks=callbk_reg, verbose=1)
    '''
    # Unfreeze all layers
    for layer in model.layers:
            layer.trainable = True
    
    # Compile Models for fine-tunning complete model 
    opt = tf.keras.optimizers.Adam(learning_rate=lr/1000, clipnorm=1.0) # Define optimizer
    model.compile( optimizer=opt, \
                   loss={ 'cls_output': cls_loss, 'pt_output': 'mean_squared_logarithmic_error'}, \
                   loss_weights={'cls_output': 0, 'pt_output': 1.0}, \
                   metrics={'cls_output': 'accuracy', 'pt_output': 'mean_squared_logarithmic_error' }, \
                   steps_per_execution=32 )

 
    # Fit for fine-tune the complete model
    print('Fine-Tune the model:',name)
    history_finetune = model.fit(X_train, [y_train, y_reg_train], \
                                 validation_data=(X_val, [y_val, y_reg_val], [weights_val,weights_reg_val]), \
                                 sample_weight=[weights_train,weights_reg_train], \
                                 batch_size=bsz, epochs=nepochs, callbacks=callbk, verbose=1)



#    with open(save_path / 'history.pkl', 'wb') as f:
#        f.write(pkl.dumps(history))
    '''

#    model.save('last_model.keras')

    # After training finishes retrieve the best model saved by the checkpoint
    best_model = load_model(str(ckpt_path), custom_objects={"LipschitzReg": LipschitzReg})
    trained_models[name] = best_model






Setting seed...
Loading data...


pre-training Model accuracy: 18.50%
Regression MAE: 235.22256469726562,  MSE: 84828.8046875,  R2: -1.873410940170288
Pre-Training the model classification: MLP_dense


  0%|                                                 | 0/10 [00:00<?, ?epoch/s]

Epoch 1/10
[1m864/874[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 11ms/step - cls_output_accuracy: 0.5534 - cls_output_loss: 2.6811e-05 - loss: 0.0071 - pt_output_loss: 26.5830 - pt_output_mean_squared_logarithmic_error: 26.5830

  if self._should_save_model(epoch, batch, logs, filepath):


[1m874/874[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 15ms/step - cls_output_accuracy: 0.7051 - cls_output_loss: 2.1036e-05 - loss: 0.0070 - pt_output_loss: 26.6564 - pt_output_mean_squared_logarithmic_error: 26.6564 - val_cls_output_accuracy: 0.8817 - val_cls_output_loss: 1.2756e-05 - val_loss: 2.5190e-04 - val_pt_output_loss: 26.8561 - val_pt_output_mean_squared_logarithmic_error: 26.8557 - ebops: 658648.0000 - learning_rate: 0.0010 - beta: 1.0000e-08
Epoch 2/10
[1m864/874[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 4ms/step - cls_output_accuracy: 0.8385 - cls_output_loss: 1.5157e-05 - loss: 0.0064 - pt_output_loss: 26.7502 - pt_output_mean_squared_logarithmic_error: 26.7502

loss: 0.006/0.000 - acc: 0.828/0.907 - lr: 1.00e-03 - beta: 1.0e-08 - EBOPs: 566

[1m874/874[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 7ms/step - cls_output_accuracy: 0.8283 - cls_output_loss: 1.5678e-05 - loss: 0.0061 - pt_output_loss: 26.7865 - pt_output_mean_squared_logarithmic_error: 26.7865 - val_cls_output_accuracy: 0.9071 - val_cls_output_loss: 1.2280e-05 - val_loss: 2.2055e-04 - val_pt_output_loss: 26.9193 - val_pt_output_mean_squared_logarithmic_error: 26.9188 - ebops: 566275.0000 - learning_rate: 1.0000e-03 - beta: 1.0000e-08
Epoch 3/10
[1m864/874[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 4ms/step - cls_output_accuracy: 0.8555 - cls_output_loss: 1.3844e-05 - loss: 0.0057 - pt_output_loss: 26.8047 - pt_output_mean_squared_logarithmic_error: 26.8047

loss: 0.006/0.000 - acc: 0.858/0.928 - lr: 1.00e-03 - beta: 1.0e-08 - EBOPs: 548

[1m874/874[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 8ms/step - cls_output_accuracy: 0.8575 - cls_output_loss: 1.3597e-05 - loss: 0.0057 - pt_output_loss: 26.8346 - pt_output_mean_squared_logarithmic_error: 26.8346 - val_cls_output_accuracy: 0.9277 - val_cls_output_loss: 1.0446e-05 - val_loss: 1.8799e-04 - val_pt_output_loss: 26.9454 - val_pt_output_mean_squared_logarithmic_error: 26.9449 - ebops: 548290.0000 - learning_rate: 1.0000e-03 - beta: 1.0000e-08
Epoch 4/10
[1m864/874[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 4ms/step - cls_output_accuracy: 0.4090 - cls_output_loss: 2.5648e-05 - loss: 0.0332 - pt_output_loss: 26.8244 - pt_output_mean_squared_logarithmic_error: 26.8244

loss: 0.031/0.000 - acc: 0.174/0.012 - lr: 1.00e-03 - beta: 7.0e-08 - EBOPs: 369

[1m874/874[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 7ms/step - cls_output_accuracy: 0.1744 - cls_output_loss: 3.1644e-05 - loss: 0.0312 - pt_output_loss: 26.8473 - pt_output_mean_squared_logarithmic_error: 26.8473 - val_cls_output_accuracy: 0.0120 - val_cls_output_loss: 3.5055e-05 - val_loss: 1.3021e-04 - val_pt_output_loss: 26.9391 - val_pt_output_mean_squared_logarithmic_error: 26.9386 - ebops: 369519.0000 - learning_rate: 1.0000e-03 - beta: 7.0000e-08
Epoch 5/10
[1m864/874[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 4ms/step - cls_output_accuracy: 0.0122 - cls_output_loss: 3.5019e-05 - loss: 0.0294 - pt_output_loss: 26.8206 - pt_output_mean_squared_logarithmic_error: 26.8206

loss: 0.029/0.000 - acc: 0.012/0.012 - lr: 1.00e-03 - beta: 8.0e-08 - EBOPs: 298

[1m874/874[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 7ms/step - cls_output_accuracy: 0.0123 - cls_output_loss: 3.5198e-05 - loss: 0.0286 - pt_output_loss: 26.8456 - pt_output_mean_squared_logarithmic_error: 26.8456 - val_cls_output_accuracy: 0.0120 - val_cls_output_loss: 3.5055e-05 - val_loss: 7.1479e-05 - val_pt_output_loss: 26.9391 - val_pt_output_mean_squared_logarithmic_error: 26.9386 - ebops: 298673.0000 - learning_rate: 1.0000e-03 - beta: 8.0000e-08
Epoch 6/10
[1m864/874[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 4ms/step - cls_output_accuracy: 0.0122 - cls_output_loss: 3.5019e-05 - loss: 0.0265 - pt_output_loss: 26.8206 - pt_output_mean_squared_logarithmic_error: 26.8206

KeyboardInterrupt: 

## Test/Compare the Models
#### https://github.com/calad0i/JEDI-linear/blob/master/src/test.py

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import (
    roc_curve, auc, roc_auc_score,
    mean_squared_error, mean_absolute_error, r2_score
)
from sklearn.preprocessing import label_binarize

# ---------------------------------------------------------------------------
# 1. Evaluate classification + regression for all models
# ---------------------------------------------------------------------------

# Result containers, all keyed by model name
results = {}
fpr, tpr, roc_auc = {}, {}, {}
inv_fpr_at_tpr = {}
regression_metrics = {}

# Class layout of the targets and the working point used for background rejection
n_classes = 5
class_names = ['Top', 'HQQ', 'HTauTau', 'Wqq', 'QCD']
target_tpr = 0.8

# Targets are already one-hot encoded and are used directly for one-vs-rest ROC curves
y_test_bin = y_test
y_test_pt = y_reg_test

for name, model in trained_models.items():
    print(f"Evaluating model {name} ...")

    # Forward pass: class logits -> probabilities, pT head -> flat vector
    cls_pred_logits, pt_pred = model.predict(X_test, verbose=0)
    cls_pred = tf.nn.softmax(cls_pred_logits, axis=-1).numpy()
    pt_pred = pt_pred.reshape(-1)

    # One-vs-rest ROC per class, plus 1/FPR at the ROC point whose TPR is
    # nearest target_tpr (infinite when that FPR is exactly zero)
    model_fpr, model_tpr, model_auc, model_rej = {}, {}, {}, {}
    for c in range(n_classes):
        model_fpr[c], model_tpr[c], _ = roc_curve(y_test_bin[:, c], cls_pred[:, c])
        model_auc[c] = auc(model_fpr[c], model_tpr[c])
        fpr_val = model_fpr[c][np.argmin(np.abs(model_tpr[c] - target_tpr))]
        model_rej[c] = np.inf if fpr_val == 0 else 1.0 / fpr_val
    fpr[name], tpr[name] = model_fpr, model_tpr
    roc_auc[name], inv_fpr_at_tpr[name] = model_auc, model_rej

    # Regression quality of the pT head
    mse = mean_squared_error(y_test_pt, pt_pred)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_test_pt, pt_pred)
    r2 = r2_score(y_test_pt, pt_pred)
    regression_metrics[name] = dict(mse=mse, rmse=rmse, mae=mae, r2=r2)

    # Per-model summary (shares the dicts stored above)
    results[name] = dict(
        roc_auc=model_auc,
        inv_fpr_at_tpr=model_rej,
        regression=regression_metrics[name],
    )


# ---------------------------------------------------------------------------
# 2. Plot ROC curves
# ---------------------------------------------------------------------------

# One figure per class, overlaying every model's one-vs-rest ROC curve
for c in range(n_classes):
    plt.figure(figsize=(7,5))
    for name in trained_models:
        curve_label = f"{name} (AUC={roc_auc[name][c]:.3f})"
        plt.plot(fpr[name][c], tpr[name][c], lw=2, label=curve_label)

    # Random-guess diagonal for reference
    plt.plot([0,1], [0,1], "k--", lw=1)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve — {class_names[c]}")
    plt.legend(loc="lower right")
    plt.grid(ls="--", alpha=0.5)
    plt.show()


# ---------------------------------------------------------------------------
# 3. Plot background rejection (1/FPR @ TPR=0.8)
# ---------------------------------------------------------------------------

# One bar chart per class comparing models at the fixed-TPR working point
for c in range(n_classes):
    plt.figure(figsize=(7,5))

    model_names = list(trained_models)
    values = [inv_fpr_at_tpr[mname][c] for mname in model_names]

    plt.bar(model_names, values, alpha=0.8, color='royalblue')
    plt.ylabel(r"$1/\mathrm{FPR}$ @ TPR=0.8")
    plt.title(f"Background Rejection — {class_names[c]}")
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', ls='--')
    plt.tight_layout()
    plt.show()


# ---------------------------------------------------------------------------
# 4. Plot regression: True Pt vs Predicted Pt
# ---------------------------------------------------------------------------

# One scatter plot per model of predicted vs. true jet pT
for name, model in trained_models.items():
    _, pt_pred = model.predict(X_test, verbose=0)
    pt_pred = pt_pred.ravel()

    fig, ax = plt.subplots(figsize=(7,5))
    ax.scatter(y_test_pt, pt_pred, s=3, alpha=0.3)
    ax.set_xlabel("True Jet $p_T$")
    ax.set_ylabel("Predicted Jet $p_T$")
    ax.set_title(f"Regression Scatter Plot — {name}")
    ax.grid(ls="--", alpha=0.5)
    fig.tight_layout()
    plt.show()


# ---------------------------------------------------------------------------
# 5. Plot regression residuals with mean & RMS in legend/text
# ---------------------------------------------------------------------------
for name, model in trained_models.items():
    print(f"Plotting residuals for {name}...")

    _, pt_pred = model.predict(X_test, verbose=0)
    pt_pred = pt_pred.reshape(-1)
    pt_true = y_reg_test.flatten()
    residuals = pt_true - pt_pred

    # Bias, spread around the bias (std) and spread around zero (RMS)
    mean_res = np.mean(residuals)
    std_res = np.std(residuals)
    rms_res = np.sqrt(np.mean(residuals**2))

    plt.figure(figsize=(8, 5.5))

    # Histogram over [-500, 500]; residuals outside this range are not drawn
    counts, bins, _ = plt.hist(residuals, bins=100, range=(-500, 500), alpha=0.75,
                               color='steelblue', edgecolor='black', linewidth=0.5,
                               density=False)

    # Statistics box in the upper-right corner (axes coordinates)
    stats_lines = [
        rf'Mean (bias) = ${mean_res:+.2f}\ \mathrm{{GeV}}$',
        rf'Std. dev.  = ${std_res:.2f}\ \mathrm{{GeV}}$',
        rf'RMS         = ${rms_res:.2f}\ \mathrm{{GeV}}$',
        rf'Entries     = {len(residuals):,}',
    ]
    plt.text(0.95, 0.95, '\n'.join(stats_lines), transform=plt.gca().transAxes,
             fontsize=11, verticalalignment='top', horizontalalignment='right',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.85, edgecolor='gray'))

    plt.xlabel(r'Residual  $(p_T^\mathrm{true} - p_T^\mathrm{pred})$ [GeV]', fontsize=13)
    plt.ylabel('Jets', fontsize=13)
    plt.title(f'Regression Residuals — {name}', fontsize=14, pad=15)
    plt.grid(True, alpha=0.3, ls='--')
    plt.xlim(-500, 500)   # keep in sync with the histogram range
    plt.tight_layout()
    plt.show()

    # Same numbers on the console
    print(f"{name} → Bias: {mean_res:+.3f} GeV | Std: {std_res:.3f} GeV | RMS: {rms_res:.3f} GeV")
    

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import tensorflow as tf

# ------------------------------------------------------------------
# 1. Settings
# ------------------------------------------------------------------
target_tpr = 0.80                         # we want 1/FPR @ 80% signal efficiency
class_names = ['Top', 'HQQ', 'HTauTau', 'Wqq', 'QCD']
n_classes = len(class_names)              # defined here so this cell runs on its own
colors = plt.cm.tab10(np.linspace(0, 1, len(class_names)))

# Pre-compute integer labels once (y_test is one-hot)
y_test_int = np.argmax(y_test, axis=1)

# Containers, keyed by model name and then class index
fpr_dict      = {}
tpr_dict      = {}
auc_dict      = {}
rejection_dict = {}   # this will hold 1/FPR @ TPR=0.8


def _fpr_at_tpr(fpr_curve, tpr_curve, target):
    """
    Return the FPR at which the ROC curve first reaches ``target`` TPR.

    ``tpr_curve`` must be non-decreasing, as returned by
    ``sklearn.metrics.roc_curve``. The two ROC points that bracket ``target``
    are interpolated linearly in log(FPR). When the lower point has FPR == 0,
    a log cannot be taken, so plain linear interpolation in FPR is used.
    """
    hi = int(np.searchsorted(tpr_curve, target, side="left"))
    if hi >= len(tpr_curve):
        # Target never reached: fall back to the last point
        return float(fpr_curve[-1])
    if hi == 0 or tpr_curve[hi] == target:
        return float(fpr_curve[hi])
    lo = hi - 1
    # tpr_curve[lo] < target <= tpr_curve[hi], so the denominator is > 0
    frac = (target - tpr_curve[lo]) / (tpr_curve[hi] - tpr_curve[lo])
    f_lo, f_hi = fpr_curve[lo], fpr_curve[hi]
    if f_lo > 0:
        return float(np.exp((1 - frac) * np.log(f_lo) + frac * np.log(f_hi)))
    return float(f_lo + frac * (f_hi - f_lo))


# ------------------------------------------------------------------
# 2. Loop over all trained models
# ------------------------------------------------------------------
for name, model in trained_models.items():
    print(f"Evaluating ROC for model: {name}")

    # Predict logits → convert to probabilities
    cls_logits, _ = model.predict(X_test, verbose=0)
    cls_proba = tf.nn.softmax(cls_logits, axis=-1).numpy()

    fpr_dict[name]      = {}
    tpr_dict[name]      = {}
    auc_dict[name]      = {}
    rejection_dict[name] = {}

    for c in range(n_classes):
        # One-vs-rest ROC: binary label + probability of class c. Names are
        # suffixed with _c so the global fpr/tpr/roc_auc dicts built by the
        # previous evaluation cell are not overwritten.
        fpr_c, tpr_c, _ = roc_curve(y_test_int == c, cls_proba[:, c])
        auc_c = auc(fpr_c, tpr_c)

        # Store
        fpr_dict[name][c] = fpr_c
        tpr_dict[name][c] = tpr_c
        auc_dict[name][c] = auc_c

        # FPR at TPR = 0.80, interpolated between the two ROC points that
        # bracket the target
        fpr_at_80 = _fpr_at_tpr(fpr_c, tpr_c, target_tpr)

        # Safe handling of FPR = 0 → infinite rejection
        if fpr_at_80 <= 0 or np.isnan(fpr_at_80):
            rejection = 1e6   # visual cap (you can change to any large number)
        else:
            rejection = 1.0 / fpr_at_80

        rejection_dict[name][c] = rejection

# ------------------------------------------------------------------
# 3. Plot Background Rejection (1/FPR @ TPR=0.8) per class
# ------------------------------------------------------------------
model_names = list(trained_models.keys())

for c in range(n_classes):
    plt.figure(figsize=(9, 6))

    values = [rejection_dict[mname][c] for mname in model_names]

    # Clip bar heights at 15000 so a single huge value does not flatten the rest
    capped_values = [15000 if np.isinf(v) else min(v, 15000) for v in values]

    bars = plt.bar(model_names, capped_values, color=colors[c], alpha=0.8,
                   edgecolor='black', linewidth=1.2)

    # Annotate every bar with its real (un-capped) value
    for bar, real_val in zip(bars, values):
        x_mid = bar.get_x() + bar.get_width()/2
        if real_val >= 15000:
            label = "∞" if np.isinf(real_val) else f"{real_val:,.0f}"
            y_pos, text_style = bar.get_height() + 300, dict(fontweight='bold', fontsize=11)
        else:
            label = f"{real_val:,.0f}"
            y_pos, text_style = bar.get_height() + 100, dict(fontsize=10)
        plt.text(x_mid, y_pos, label, ha='center', va='bottom', **text_style)

    plt.ylabel(r"Background Rejection = $1/\mathrm{FPR}$ @ TPR = 80%", fontsize=13)
    plt.title(f"Background Rejection — {class_names[c]}", fontsize=15, pad=20)
    plt.ylim(0, 16000)
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', alpha=0.3, ls='--')
    plt.tight_layout()
    plt.show()

In [None]:
from pathlib import Path
import numpy as np
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import json
from tqdm import tqdm
import matplotlib.pyplot as plt


# ---------------------------------------------------------
# 1. Plot helper
# ---------------------------------------------------------

def plot_history(history, metrics, ylabel="Value", title=None, logy=False):
    """
    Plot per-epoch curves of selected metrics from a Keras ``History`` object.

    Metrics missing from ``history.history`` are skipped with a warning.

    Args:
        history: object with a ``history`` dict mapping metric name -> list of values.
        metrics: metric names to draw, in order.
        ylabel: y-axis label.
        title: optional axes title.
        logy: if True, use a logarithmic y axis.

    Returns:
        (fig, ax) of the new figure.
    """
    fig, ax = plt.subplots(figsize=(6,4))
    logged = history.history
    for key in metrics:
        if key not in logged:
            print(f"[WARN] Metric {key} not found in history.")
            continue
        ax.plot(logged[key], label=key)

    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)
    if logy:
        ax.set_yscale("log")
    ax.legend()
    return fig, ax



# ---------------------------------------------------------
# 2. Unified TEST function (classification + regression)
# ---------------------------------------------------------

def test_model(model, save_path: Path, X_train, X_test, Y_test_cls, Y_test_pt):
    """
    Evaluate every checkpoint of a two-headed (classification + pT regression) model.

    The ``<save_path>/ckpts/*.keras`` files are processed in sorted order. For
    each file:
      * its weights are loaded into ``model`` and the test set is predicted;
      * classification accuracy, regression MSE/RMSE/MAE/R2 and the summed
        ``ebops`` of the layers that expose it are recorded;
      * a copy of the model is saved to ``<save_path>/models/<ckpt stem>.keras``.
    All metrics are finally written to ``<save_path>/test_results.json``.

    Arguments:
      model: model whose ``predict`` returns ``(cls_output, pt_output)`` and
        whose architecture matches the checkpoints.
      save_path: directory containing ``ckpts/``; outputs are written here too.
      X_train: currently unused (kept for the optional trace_minmax step).
      X_test: test inputs.
      Y_test_cls: classification targets. Either integer class indices of
        shape (N,) or (N, 1), or one-hot rows of shape (N, n_classes).
      Y_test_pt: true jet Pt for regression (flattened).

    Returns:
      dict mapping checkpoint file name -> dict of metrics.
    """

    results = {}
    save_path = Path(save_path)
    (save_path / "models").mkdir(parents=True, exist_ok=True)

    # Sorted for a deterministic processing / JSON order
    ckpts = sorted(save_path.glob("ckpts/*.keras"))
    if not ckpts:
        print(f"[WARN] No checkpoints found in {save_path / 'ckpts'}")
    pbar = tqdm(ckpts)

    X_test  = np.array(X_test,  np.float32)

    Y_test_cls = np.asarray(Y_test_cls)
    if Y_test_cls.ndim == 2 and Y_test_cls.shape[1] > 1:
        # One-hot targets: compare class indices (not rows) with argmax(pred).
        Y_test_cls = np.argmax(Y_test_cls, axis=1)
    else:
        Y_test_cls = Y_test_cls.reshape(-1)
    Y_test_pt  = np.array(Y_test_pt).reshape(-1)

    for ckpt in pbar:
        model.load_weights(ckpt)

        # Optional: normalize activations (pass np.array(X_train, np.float32))
        # trace_minmax(model, X_train, batch_size=16384)

        # Predictions
        cls_pred, pt_pred = model.predict(X_test, batch_size=16384, verbose=0)
        cls_pred_labels = np.argmax(cls_pred, axis=1)
        pt_pred = pt_pred.reshape(-1)

        # Classification accuracy
        cls_acc = np.mean(cls_pred_labels == Y_test_cls)

        # Regression metrics
        mse  = mean_squared_error(Y_test_pt, pt_pred)
        rmse = np.sqrt(mse)
        mae  = mean_absolute_error(Y_test_pt, pt_pred)
        r2   = r2_score(Y_test_pt, pt_pred)

        # EBOPs summed over layers that report them
        ebops = sum(float(layer.ebops) for layer in model.layers if hasattr(layer, "ebops"))

        # Store
        results[ckpt.name] = {
            "classification_accuracy": float(cls_acc),
            "regression_mse": float(mse),
            "regression_rmse": float(rmse),
            "regression_mae": float(mae),
            "regression_r2": float(r2),
            "ebops": ebops,
        }

        # Save a copy of the model with these weights
        model.save(save_path / "models" / f"{ckpt.stem}.keras")

        pbar.set_description(
            f"Acc: {cls_acc:.4f}, RMSE: {rmse:.4f}, R2: {r2:.4f} @ {ebops:.0f} EBOPs"
        )

    # Save all results
    with open(save_path / "test_results.json", "w") as f:
        json.dump(results, f, indent=2)

    return results



# ---------------------------------------------------------
# 3. Run test + generate plots
# ---------------------------------------------------------

print("Running unified evaluation...")

# NOTE(review): `model` and `save_path` are whatever earlier cells left bound
# last. The training cell writes checkpoints under './models/', while
# test_model() scans <save_path>/ckpts/*.keras; confirm the checkpoints are
# where it looks.
results = test_model(
    model=model,
    save_path=Path(save_path),
    X_train=X_train_val,
    X_test=X_test,
    Y_test_cls=y_test,        # classification labels
    Y_test_pt=y_reg_test      # regression true Pt values
)

print("Saved test metrics to: {}/test_results.json".format(save_path))



# ---------------------------------------------------------
# 4. PLOTS
# ---------------------------------------------------------

# NOTE(review): the training cell shown produces `history_class` and
# `history_reg`, not `history`; confirm which History object is meant here.

# --- Total + component losses
plot_history(
    history,
    metrics=[
        "loss", "val_loss",
        "cls_output_loss", "val_cls_output_loss",
        "pt_output_loss",  "val_pt_output_loss",
    ],
    ylabel="Loss",
    title="Loss Curves",
    logy=True,
);

# --- Classification accuracy
plot_history(
    history,
    metrics=[
        "cls_output_accuracy",
        "val_cls_output_accuracy"
    ],
    ylabel="Accuracy",
    title="Classification Accuracy",
);

# --- Regression metric. The pt head is compiled with
# 'mean_squared_logarithmic_error', so Keras logs it under that name;
# 'pt_output_mse' never appears in the history.
plot_history(
    history,
    metrics=[
        "pt_output_mean_squared_logarithmic_error",
        "val_pt_output_mean_squared_logarithmic_error",
    ],
    ylabel="MSLE",
    title="Regression MSLE",
);

### Feature importance based on SHAP, Permutation and Gradient Methods

In [None]:
import numpy as np
import shap
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.metrics import log_loss
from pathlib import Path

# Directory for the feature-importance artifacts (created if missing)
save_path = Path(work_path)
save_path.mkdir(exist_ok=True, parents=True)

'''
# ------------------------------------------------------------
# Helper functions
# ------------------------------------------------------------
def permutation_importance(model, X, y, metric=log_loss):
    """Model-agnostic permutation importance."""
    baseline = metric(y, model.predict(X))
    importances = []
    for i in range(X.shape[1]):
        Xp = X.copy()
        np.random.shuffle(Xp[:, i])
        score = metric(y, model.predict(Xp))
        importances.append(score - baseline)
    return np.array(importances)

def gradient_importance(model, X):
    """Mean absolute gradient of output w.r.t. input."""
    X_tf = tf.convert_to_tensor(X)
    with tf.GradientTape() as tape:
        tape.watch(X_tf)
        preds = model(X_tf)
    grads = tape.gradient(preds, X_tf)
    return tf.reduce_mean(tf.abs(grads), axis=0).numpy()

def plot_importances(importances, feature_names, title, savefile):
    """Generic plotting function."""
    sorted_idx = np.argsort(importances)[::-1]
    plt.figure(figsize=(8, 5))
    plt.bar(range(len(importances)), importances[sorted_idx])
    plt.xticks(range(len(importances)), np.array(feature_names)[sorted_idx],
               rotation=90, fontsize=8)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(savefile, bbox_inches="tight")
    plt.close()

# ------------------------------------------------------------
# SHAP background data
# ------------------------------------------------------------
background = X_test[np.random.choice(X_test.shape[0], 100, replace=False)]
X_explain = X_test[:200]

# If you have feature names, define them; otherwise, use indices
try:
    feature_names = X_test.columns
except Exception:
    feature_names = [f"feat_{i}" for i in range(X_test.shape[1])]

# ------------------------------------------------------------
# Loop over trained models
# ------------------------------------------------------------
for name, model in trained_models.items():
    print(f"\n=== Evaluating feature importance for {name} ===")

    # --- SHAP ---
    try:
        print(" → Computing SHAP values...")
        explainer = shap.DeepExplainer(model, background)
        shap_values = explainer.shap_values(X_explain)
        shap.summary_plot(shap_values, X_explain, show=False)
        plt.savefig(save_path / f'shap_summary_{name}.png', bbox_inches='tight')
        plt.close()
        np.save(save_path / f'shap_values_{name}.npy', shap_values)
        print("   ✓ SHAP values saved.")
    except Exception as e:
        print(f"   ⚠️ SHAP failed for {name}: {e}")

    # --- Permutation Importance ---
    try:
        print(" → Computing permutation importances...")
        importances = permutation_importance(model, X_test, y_test)
        np.save(save_path / f'perm_importance_{name}.npy', importances)

        # Plot permutation importances
        plot_importances(importances, feature_names,
                         f"Permutation Importance — {name}",
                         save_path / f'perm_importance_{name}.png')
        print("   ✓ Permutation importances saved and plotted.")
    except Exception as e:
        print(f"   ⚠️ Permutation importance failed for {name}: {e}")

    # --- Gradient-based Importance ---
    try:
        print(" → Computing gradient-based importances...")
        grads = gradient_importance(model, X_test[:500])  # subset for speed
        np.save(save_path / f'gradients_{name}.npy', grads)

        # Plot gradient importances
        plot_importances(grads, feature_names,
                         f"Gradient-based Importance — {name}",
                         save_path / f'gradients_{name}.png')
        print("   ✓ Gradient importances saved and plotted.")
    except Exception as e:
        print(f"   ⚠️ Gradient-based importance failed for {name}: {e}")
'''


import numpy as np
import shap
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.metrics import log_loss

# ------------------------------------------------------------
# Helper functions
# ------------------------------------------------------------
def permutation_importance(model, X, y, metric=log_loss, n_repeats=1, seed=None):
    """Model-agnostic permutation importance along axis 1 of ``X``.

    For each index ``i`` along axis 1, the slice ``X[:, i]`` is permuted across
    samples (axis 0), the model is re-scored with ``metric``, and the increase
    of the score over the unpermuted baseline is recorded. With a
    lower-is-better metric such as ``log_loss``, a larger value means the
    model relies more on that input.

    Parameters
    ----------
    model : object
        Anything with ``predict(X)`` whose output ``metric`` accepts.
    X : numpy.ndarray
        Input array with at least two dimensions. It is not modified.
    y : array-like
        True targets, passed as the first argument to ``metric``.
    metric : callable, default ``log_loss``
        ``metric(y_true, y_pred) -> float``.
    n_repeats : int, default 1
        Number of independent permutations per index; their scores are
        averaged to reduce noise. ``1`` reproduces a single shuffle.
    seed : int | numpy.random.Generator | None, default None
        Seed for reproducible permutations. ``None`` uses the global
        ``np.random`` state, so an earlier ``np.random.seed`` still applies.

    Returns
    -------
    numpy.ndarray
        Shape ``(X.shape[1],)``: mean permuted score minus baseline score.

    Raises
    ------
    ValueError
        If ``n_repeats`` is smaller than 1.
    """
    if n_repeats < 1:
        raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")
    permute = (
        np.random.permutation if seed is None
        else np.random.default_rng(seed).permutation
    )

    baseline = metric(y, model.predict(X))

    # A single working copy; each slice is restored after use instead of
    # copying the whole input once per feature.
    X_perm = X.copy()
    importances = np.empty(X.shape[1], dtype=float)
    for i in range(X.shape[1]):
        original = X_perm[:, i].copy()
        scores = []
        for _ in range(n_repeats):
            X_perm[:, i] = permute(original)  # shuffles along axis 0 only
            scores.append(metric(y, model.predict(X_perm)))
        X_perm[:, i] = original
        importances[i] = np.mean(scores) - baseline
    return importances

def gradient_importance(model, X, reduce="max"):
    """Mean absolute input gradient (saliency) of the model output.

    ``tf.GradientTape.gradient`` differentiates the *sum* of a non-scalar
    target. For a softmax head the class probabilities of every sample sum
    to 1, so a plain summed gradient is identically zero and carries no
    information. By default each output is therefore first reduced to its
    largest entry along the last axis: the top-class score for a classifier,
    and unchanged for a single-unit (e.g. regression) output.

    Parameters
    ----------
    model : callable
        TF/Keras model. It may return a single tensor or a nested structure
        (list/dict) of tensors for multi-output models; all outputs contribute.
    X : array-like
        Input batch. Gradients are averaged over axis 0.
    reduce : {"max", "sum"}, default "max"
        ``"max"``: differentiate the largest last-axis entry of each output
        per sample. ``"sum"``: differentiate the plain sum of all outputs
        (the previous behaviour).

    Returns
    -------
    numpy.ndarray
        Shape ``X.shape[1:]``: mean over samples of ``|d target / d X|``.

    Raises
    ------
    ValueError
        If ``reduce`` is unknown or the outputs are not differentiable with
        respect to ``X`` (e.g. an integer input).
    """
    if reduce not in ("max", "sum"):
        raise ValueError(f"reduce must be 'max' or 'sum', got {reduce!r}")

    def _to_scalar(output):
        # Only reduce over a real class/unit axis; for rank <= 1 outputs the
        # last axis is the batch axis and must not be max-reduced.
        if reduce == "max" and output.shape.rank is not None and output.shape.rank >= 2:
            output = tf.reduce_max(output, axis=-1)
        return tf.reduce_sum(output)

    X_tf = tf.convert_to_tensor(X)
    with tf.GradientTape() as tape:
        tape.watch(X_tf)
        preds = model(X_tf)
        target = tf.add_n([_to_scalar(out) for out in tf.nest.flatten(preds)])
    grads = tape.gradient(target, X_tf)
    if grads is None:
        raise ValueError("Model output is not differentiable with respect to X.")
    return tf.reduce_mean(tf.abs(grads), axis=0).numpy()

def plot_importances_inline(importances, feature_names, title, ax):
    """Draw per-feature importances on ``ax`` as bars, largest first.

    Parameters
    ----------
    importances : 1-D array-like
        One score per feature (lists are accepted as well as arrays).
    feature_names : sequence
        Labels aligned with ``importances``; reordered together with them.
    title : str
        Axes title.
    ax : matplotlib.axes.Axes
        Target axes; nothing else is created or shown.

    Raises
    ------
    ValueError
        If ``importances`` is not 1-D or its length differs from
        ``feature_names``.
    """
    values = np.asarray(importances)
    names = np.asarray(feature_names)
    if values.ndim != 1:
        raise ValueError(
            f"importances must be 1-D, got shape {values.shape}"
        )
    if len(names) != len(values):
        raise ValueError(
            f"{len(values)} importances but {len(names)} feature names"
        )

    order = np.argsort(values)[::-1]  # descending
    positions = range(len(values))
    ax.bar(positions, values[order])
    ax.set_xticks(positions)
    ax.set_xticklabels(names[order], rotation=90, fontsize=8)
    ax.set_title(title)
    ax.grid(alpha=0.3)


# ------------------------------------------------------------
# SHAP background and feature names
# ------------------------------------------------------------
# Background set for DeepExplainer: up to 100 test rows drawn without
# replacement. Capped so a test set with fewer than 100 rows does not make
# np.random.choice raise.
n_background = min(100, X_test.shape[0])
background = X_test[np.random.choice(X_test.shape[0], n_background, replace=False)]
X_explain = X_test[:200]

# Use column names when X_test provides them (e.g. a pandas DataFrame);
# otherwise fall back to positional names along axis 1.
try:
    feature_names = X_test.columns
except AttributeError:
    feature_names = [f"feat_{i}" for i in range(X_test.shape[1])]


# ------------------------------------------------------------
# Loop over trained models
# ------------------------------------------------------------
for name, model in trained_models.items():
    print(f"\n=== Feature importance for model: {name} ===\n")

    # ------------------------------------------------------------
    # 1) SHAP: DeepExplainer against `background`, summary shown inline
    # ------------------------------------------------------------
    try:
        print(" → SHAP values...")
        explainer = shap.DeepExplainer(model, background)
        shap_values = explainer.shap_values(X_explain)
        shap.summary_plot(shap_values, X_explain, feature_names=feature_names, show=True)
    except Exception as e:
        # Best effort: report the failure and continue with the other methods.
        print(f"   ⚠️ SHAP failed: {e}")

    # ------------------------------------------------------------
    # One panel per importance method (permutation, gradient).
    # A larger grid would always leave unused panels blank.
    # ------------------------------------------------------------
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    axes = axes.flatten()
    plot_idx = 0

    # ------------------------------------------------------------
    # 2) Permutation Importance (full test set)
    # ------------------------------------------------------------
    try:
        print(" → Permutation importance...")
        perm_imp = permutation_importance(model, X_test, y_test)
        plot_importances_inline(perm_imp, feature_names,
                                f"Permutation – {name}", axes[plot_idx])
        plot_idx += 1
    except Exception as e:
        print(f"   ⚠️ Permutation failed: {e}")

    # ------------------------------------------------------------
    # 3) Gradient Importance (first 500 test rows, for speed)
    # ------------------------------------------------------------
    try:
        print(" → Gradient importance...")
        grads = gradient_importance(model, X_test[:500])
        plot_importances_inline(grads, feature_names,
                                f"Gradient – {name}", axes[plot_idx])
        plot_idx += 1
    except Exception as e:
        print(f"   ⚠️ Gradient failed: {e}")

    # Hide panels whose method failed.
    for k in range(plot_idx, len(axes)):
        axes[k].axis("off")

    # Only display the figure if at least one method produced a plot.
    if plot_idx:
        plt.tight_layout()
        plt.show()
    # Release the figure explicitly so figures do not accumulate across
    # models on non-inline backends.
    plt.close(fig)

    

## Plot input features per class (sanity check)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Configuration
CLASSES = list(class_labels.keys())       # ["Top","HQQ","HTauTau","Wqq","QCD"]
colors  = {"Top":"red", "HQQ":"blue", "HTauTau":"purple", "Wqq":"green", "QCD":"orange"}
ncols = 3    # number of plots per row
n_bins = 60  # bins per feature panel, shared by all classes of that feature

# Boolean event selection per class. It does not depend on the feature, so it
# is computed once rather than once per (feature, class) pair.
class_masks = {
    cls: labels_array == label_to_idx[class_labels[cls]] for cls in CLASSES
}


def plot_feature_histograms(feature_names, values_for, ncols=ncols, bins=n_bins):
    """Plot one panel per feature with a normalized step histogram per class.

    Parameters
    ----------
    feature_names : sequence of str
        Panel titles / x-axis labels, one per feature.
    values_for : callable
        ``values_for(ifeat, mask)`` returns the values of feature ``ifeat``
        for the events selected by boolean ``mask``. The result is flattened
        before histogramming.
    ncols : int
        Panels per row.
    bins : int
        Number of bins. All classes of a feature share the same bin edges, so
        their curves are directly comparable.
    """
    n_features = len(feature_names)
    if n_features == 0:
        return  # plt.subplots(0, ncols) would raise
    nrows = int(np.ceil(n_features / ncols))

    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), squeeze=False)
    axes = axes.flatten()

    for ifeat, (ax, featname) in enumerate(zip(axes, feature_names)):
        per_class = {}
        for cls, mask in class_masks.items():
            vals = np.ravel(values_for(ifeat, mask))
            # Non-finite values would break automatic binning.
            per_class[cls] = vals[np.isfinite(vals)]

        # Common bin edges across classes, from the pooled finite values.
        pooled = np.concatenate(list(per_class.values())) if per_class else np.empty(0)
        edges = np.histogram_bin_edges(pooled, bins=bins) if pooled.size else bins

        for cls, vals in per_class.items():
            ax.hist(
                vals,
                bins=edges,
                density=True,
                histtype="step",        # contour only
                linewidth=1.5,
                label=cls,
                color=colors.get(cls),  # unknown class -> default colour cycle
            )

        ax.set_title(featname)
        ax.set_xlabel(featname)
        ax.set_ylabel("Normalized entries")
        ax.grid(alpha=0.3)
        ax.legend(fontsize=9)

    # remove empty axes
    for ax in axes[n_features:]:
        fig.delaxes(ax)

    fig.tight_layout()
    plt.show()


# -------------------------------
# Constituent features (step histograms)
# -------------------------------
# NOTE(review): every constituent slot is included. If X_constituents is
# zero-padded, the padding shows up as a spike — confirm whether padded slots
# should be masked out here.
plot_feature_histograms(
    const_feature_names,
    lambda ifeat, mask: X_constituents[mask, :, ifeat],
)

# -------------------------------
# Jet-level features (step histograms)
# -------------------------------
plot_feature_histograms(
    jet_feature_names,
    lambda ifeat, mask: X_jets[mask, ifeat],
)
