In [1]:
import gc
import numpy as np
import polars as pl
import tensorflow as tf
from pathlib import Path
from tensorflow import shape, minimum
from tensorflow.keras import backend as k
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import pad_sequences, Sequence, to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import (
    Dense, Input, Conv1D, MaxPooling1D, GlobalAveragePooling1D, GlobalMaxPooling1D, Concatenate,
    BatchNormalization, GRU, Dropout, add, Activation, Multiply, Reshape,
    LayerNormalization, Add, Bidirectional, LSTM, UpSampling1D, Lambda, GaussianNoise, MultiHeadAttention
)
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder, RobustScaler,StandardScaler
from sklearn.metrics import classification_report, accuracy_score

from src.nn_blocks import tof_block, residual_se_cnn_block, TransformerBlock, tof_block_2, features_processing, unet_se_cnn

NUM_CLASSES = 18


# --- Gated Model 1: Based on CNN-RNN Hybrid ---
def create_gated_cnn_rnn(input_shape, imu_dim, wd=1e-4):
    """Build the gated CNN + BiLSTM hybrid model.

    The input is split along its last axis: the first ``imu_dim`` channels
    feed the IMU branch, the remaining channels feed the ToF branch.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to the residual SE-CNN blocks, to the LSTM
            kernel regularizer and to ``tof_block_2``.

    Returns:
        A Keras ``Model`` whose outputs form a dict with the keys
        ``"main_output"`` (softmax over ``NUM_CLASSES``) and ``"tof_gate"``
        (a single sigmoid unit).
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # IMU path: two residual SE-CNN stages followed by a bidirectional LSTM.
    # Shapes below are the original author's notes — TODO confirm against
    # src.nn_blocks: (None, 64, 64) -> (None, 32, 128) -> (None, 32, 256).
    imu_branch = residual_se_cnn_block(imu_feats, 64, 3, drop=0.2, wd=wd)
    imu_branch = residual_se_cnn_block(imu_branch, 128, 5, drop=0.2, wd=wd)
    recurrent = Bidirectional(LSTM(128, return_sequences=True, kernel_regularizer=l2(wd)))
    imu_branch = recurrent(imu_branch)

    # ToF path (author's note: output is (None, 32, 128)).
    tof_branch = tof_block_2(tof_feats, wd)

    # Widen the ToF features to 256 so both inputs of features_processing
    # carry the same number of channels.
    tof_branch_256 = Dense(256, activation='relu')(tof_branch)

    fused = features_processing(imu_branch, tof_branch_256)
    fused = Dropout(0.3)(fused)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

# --- Gated Model 2: Based on UNet_Style ---
def create_gated_unet(input_shape, imu_dim, wd=1e-4):
    """Build the gated U-Net style model with pooled late fusion.

    The IMU channels go through ``unet_se_cnn`` and the ToF channels through
    ``tof_block_2``. Each branch is global-average-pooled over time before
    the two vectors are concatenated, so the branches do not need matching
    sequence lengths.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to ``tof_block_2`` (the U-Net call does not
            receive it).

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # Author's shape notes (TODO confirm): U-Net -> (None, 128, 64),
    # ToF block -> (None, 32, 128).
    imu_branch = unet_se_cnn(imu_feats, unet_depth=4, base_filters=64, kernel_size=5, drop=0.3)
    tof_branch = tof_block_2(tof_feats, wd)

    # Collapse the time axis of each branch, then join the two vectors.
    imu_vector = GlobalAveragePooling1D()(imu_branch)
    tof_vector = GlobalAveragePooling1D()(tof_branch)
    fused = Concatenate()([imu_vector, tof_vector])

    # Classifier head.
    fused = Dropout(0.4)(fused)
    fused = Dense(256, activation='relu')(fused)
    fused = Dropout(0.3)(fused)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

# --- Gated Model 3: Based on CNN_Transformer ---
def create_gated_cnn_transformer(input_shape, imu_dim, wd=1e-4):
    """Build the gated CNN + Transformer model.

    The IMU branch repeats the same stage twice: two residual SE-CNN blocks
    (64 filters / kernel 3, then 128 filters / kernel 5) followed by a
    ``TransformerBlock``. The ToF branch is ``tof_block_2`` followed by a
    max-pool of size 4.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to the residual SE-CNN blocks and ``tof_block_2``.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # Two identical CNN-CNN-Transformer stages. Author's shape notes
    # (TODO confirm): the first stage ends at (None, 32, 128), the second at
    # (None, 8, 128).
    imu_branch = imu_feats
    for _ in range(2):
        imu_branch = residual_se_cnn_block(imu_branch, 64, 3, drop=0.2, wd=wd)
        imu_branch = residual_se_cnn_block(imu_branch, 128, 5, drop=0.2, wd=wd)
        imu_branch = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=128, rate=0.3)(imu_branch)

    # ToF path; the pool of 4 is there to bring the author-stated 32 steps
    # down to the 8 steps of the IMU branch.
    tof_branch = tof_block_2(tof_feats, wd)
    tof_branch = MaxPooling1D(4)(tof_branch)

    fused = features_processing(imu_branch, tof_branch)
    fused = Dropout(0.3)(fused)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

def best_unet_1(input_shape, imu_dim, wd=1e-4):
    """Build the U-Net (IMU) + ``tof_block`` (ToF) gated model.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to ``tof_block``.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` (softmax over
        ``NUM_CLASSES``) and ``"tof_gate"`` (single sigmoid unit).
    """
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    x1 = unet_se_cnn(imu, 3, base_filters=128, kernel_size=3)
    x2 = tof_block(tof, wd)

    x = features_processing(x1, x2)
    x = tf.keras.layers.Dropout(0.3)(x)
    # Use the shared class-count constant instead of a literal 18 so this
    # head stays in sync with the other model factories in this file.
    main_out = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = tf.keras.layers.Dense(1, activation="sigmoid", name="tof_gate")(x)

    return tf.keras.models.Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

def best_unet_2(input_shape, imu_dim, wd=1e-4):
    """Build the U-Net (IMU) + ``tof_block_2`` (ToF) gated model.

    Same layout as ``best_unet_1`` except that the ToF branch uses
    ``tof_block_2``.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to ``tof_block_2``.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` (softmax over
        ``NUM_CLASSES``) and ``"tof_gate"`` (single sigmoid unit).
    """
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    x1 = unet_se_cnn(imu, 3, base_filters=128, kernel_size=3)
    x2 = tof_block_2(tof, wd)

    x = features_processing(x1, x2)
    x = tf.keras.layers.Dropout(0.3)(x)
    # Use the shared class-count constant instead of a literal 18 so this
    # head stays in sync with the other model factories in this file.
    main_out = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = tf.keras.layers.Dense(1, activation="sigmoid", name="tof_gate")(x)

    return tf.keras.models.Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

2025-09-06 09:19:54.525213: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757146794.721344 1934972 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757146794.781828 1934972 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1757146795.422333 1934972 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1757146795.422399 1934972 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1757146795.422402 1934972 computation_placer.cc:177] computation placer alr

In [2]:
# =====================================================================================
# 5 NEW ADVANCED MODEL ARCHITECTURES
# =====================================================================================

from src.nn_blocks import match_time_steps, wave_block, res_se_cnn_decoder_block

# --- Advanced Model 2: Stacked Transformer Tower ---
def create_advanced_model_2_transformer_tower(input_shape, imu_dim, wd=1e-4):
    """Build the CNN backbone + four-block Transformer tower model.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to the residual SE-CNN blocks and ``tof_block_2``.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # CNN backbone that prepares features for the Transformer tower
    # (author's note: output is (None, 32, 128) — TODO confirm).
    imu_branch = residual_se_cnn_block(imu_feats, 64, 3, drop=0.2, wd=wd)
    imu_branch = residual_se_cnn_block(imu_branch, 128, 5, drop=0.2, wd=wd)

    # Tower of four identically configured Transformer blocks, each
    # consuming the previous block's output.
    for _ in range(4):
        imu_branch = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=256, rate=0.3)(imu_branch)

    # ToF path (author's note: output is (None, 32, 128)).
    tof_branch = tof_block_2(tof_feats, wd)

    fused = features_processing(imu_branch, tof_branch)
    fused = Dropout(0.3)(fused)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

# --- Advanced Model 3: Hybrid UNet + WaveNet ---
def create_advanced_model_3_unet_wave(input_shape, imu_dim, wd=1e-4):
    """Build the hybrid model that runs a U-Net and a wave block on the IMU data.

    Both IMU sub-branches read the same raw IMU slice; their outputs are
    aligned with ``match_time_steps`` and concatenated before being fused
    with the ToF branch.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to ``tof_block``.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    unet_out = unet_se_cnn(imu_feats, unet_depth=3, base_filters=64, kernel_size=5)
    # Author's note: n=5 gives dilations up to 16 — TODO confirm in wave_block.
    wave_out = wave_block(imu_feats, 64, 3, n=5, dropout_rate=0.3)

    # Align the two IMU sub-branches in time, then stack their channels.
    unet_aligned, wave_aligned = match_time_steps(unet_out, wave_out)
    imu_branch = Concatenate()([unet_aligned, wave_aligned])

    tof_branch = tof_block(tof_feats, wd)

    fused = features_processing(imu_branch, tof_branch)
    fused = Dropout(0.3)(fused)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

def create_wave_net(input_shape, imu_dim, wd=1e-4):
    """Build the gated model whose IMU branch is a single ``wave_block``.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to ``tof_block``.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    imu_branch = wave_block(imu_feats, 128, 3, n=4, dropout_rate=0.3)
    tof_branch = tof_block(tof_feats, wd)

    fused = features_processing(imu_branch, tof_branch)
    fused = Dropout(0.3)(fused)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

# --- Advanced Model 4: Triple Stacked Block Design ---
def cnn_gru_block(x, filters, kernel_size, wd=1e-4):
    """Apply a residual SE-CNN block followed by a bidirectional GRU.

    This is the single surviving definition: an earlier variant (GRU with
    ``filters // 2`` units and no kernel regularizer) was defined directly
    above and immediately shadowed by this one, so it was dead code and has
    been removed.

    Args:
        x: Input sequence tensor.
        filters: Filter count for the CNN block and unit count of each GRU
            direction. With Keras' default ``Bidirectional`` merge mode
            (concatenation) the output therefore has ``2 * filters`` features.
        kernel_size: Kernel size forwarded to ``residual_se_cnn_block``.
        wd: L2 factor for the CNN block and the GRU kernel regularizer.

    Returns:
        The per-time-step output sequence of the bidirectional GRU.
    """
    # 1. CNN part (feature extraction; the author notes it also downsamples).
    x = residual_se_cnn_block(x, filters, kernel_size, wd=wd)

    # 2. GRU part for sequence processing; keeps the full sequence.
    x = Bidirectional(GRU(filters, return_sequences=True, kernel_regularizer=l2(wd)))(x)

    return x

def create_advanced_model_4_stacked_blocks(input_shape, imu_dim, wd=1e-4):
    """Build the stacked CNN+GRU hybrid model with vector-level fusion.

    The IMU branch is two ``cnn_gru_block`` stages followed by a final
    bidirectional GRU that returns only its last state. The ToF branch is
    ``tof_block_2`` averaged over time. The two vectors are concatenated and
    classified by a small MLP.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor. It is forwarded to both ``cnn_gru_block`` stages and
            to ``tof_block_2``. Previously the hybrid blocks silently fell
            back to their own default, ignoring the caller's value.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # Two hybrid CNN+GRU stages (author's shape notes, TODO confirm:
    # (None, 64, 128) then (None, 32, 256)).
    x1 = cnn_gru_block(imu, 64, 3, wd=wd)
    x1 = cnn_gru_block(x1, 128, 5, wd=wd)

    # Third recurrent stage: returns only the final state so the merge below
    # works on fixed-size vectors. 2 * 128 = 256 features.
    x1 = Bidirectional(GRU(128, return_sequences=False))(x1)

    # ToF branch, averaged over time so it is also a fixed-size vector.
    x2 = tof_block_2(tof, wd)
    x2 = GlobalAveragePooling1D()(x2)

    # Merge the two aggregated feature vectors.
    x = Concatenate()([x1, x2])

    # Final classifier MLP.
    x = Dropout(0.4)(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.3)(x)
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)

    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

# --- Advanced Model 5: UNet with BiLSTM Bottleneck ---
def unet_se_cnn_bilstm(x, unet_depth=3, base_filters=64, kernel_size=3, drop=0.3):
    """Run a U-Net style encoder/decoder with a BiLSTM between the two halves.

    Args:
        x: Input sequence tensor.
        unet_depth: Number of encoder stages (and of decoder stages).
        base_filters: Filter count of the first encoder stage; each later
            stage doubles it.
        kernel_size: Kernel size used by every encoder and decoder block.
        drop: Dropout rate forwarded to every encoder and decoder block.

    Returns:
        The output tensor of the last decoder block.
    """
    # Encoder: filter count doubles at every level; keep each output as a skip.
    encoder_widths = [base_filters * (2 ** level) for level in range(unet_depth)]
    encoder_outputs = []
    for width in encoder_widths:
        x = residual_se_cnn_block(x, width, kernel_size, drop=drop)
        encoder_outputs.append(x)

    # Bottleneck: a BiLSTM over the deepest encoder output. Each direction
    # gets half of the "next" filter count.
    next_width = base_filters * (2 ** unet_depth)
    x = Bidirectional(LSTM(next_width // 2, return_sequences=True))(x)

    # Decoder: walk the encoder levels deepest-first, pairing each decoder
    # block with the matching skip connection and filter count.
    for width, skip in zip(reversed(encoder_widths), reversed(encoder_outputs)):
        x = res_se_cnn_decoder_block(x, width, kernel_size, drop=drop, skip_connection=skip)
    return x

def create_advanced_model_1_deep_unet(input_shape, imu_dim, wd=1e-4):
    """Build the deep U-Net + multi-kernel CNN model with pooled fusion.

    Three IMU sub-branches (a depth-4 U-Net and two single residual SE-CNN
    blocks with kernel sizes 3 and 7) and the ToF branch are each averaged
    over time. The four vectors are then concatenated and classified.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to ``tof_block_2`` only; the IMU sub-branches
            are built with their own defaults.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # IMU sub-branches, all reading the same IMU slice.
    imu_branches = [
        unet_se_cnn(imu_feats, unet_depth=4, base_filters=128, kernel_size=5, drop=0.3),
        residual_se_cnn_block(imu_feats, 64, 3),
        residual_se_cnn_block(imu_feats, 64, 7),
    ]

    # Pool every branch to a fixed-size vector before merging, so differing
    # sequence lengths between branches cannot clash.
    pooled = [GlobalAveragePooling1D()(branch) for branch in imu_branches]

    # ToF branch, pooled the same way.
    tof_branch = tof_block_2(tof_feats, wd)
    pooled.append(GlobalAveragePooling1D()(tof_branch))

    fused = Concatenate()(pooled)

    # Classifier MLP.
    fused = Dropout(0.4)(fused)
    fused = Dense(512, activation='relu')(fused)
    fused = Dropout(0.3)(fused)
    fused = Dense(256, activation='relu')(fused)

    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

def create_advanced_model_5_unet_bilstm(input_shape, imu_dim, wd=1e-4):
    """Build the model whose IMU branch is a U-Net with a BiLSTM bottleneck.

    Both branches are global-average-pooled before concatenation rather than
    being merged through ``features_processing``.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to ``tof_block_2``.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # IMU: U-Net encoder/decoder with the BiLSTM bottleneck.
    imu_branch = unet_se_cnn_bilstm(imu_feats, unet_depth=3, base_filters=128, kernel_size=3)

    # ToF: standard block.
    tof_branch = tof_block_2(tof_feats, wd)

    # Pool each branch to a vector, then concatenate.
    imu_vector = GlobalAveragePooling1D()(imu_branch)
    tof_vector = GlobalAveragePooling1D()(tof_branch)
    fused = Concatenate()([imu_vector, tof_vector])

    # Classifier MLP.
    fused = Dropout(0.4)(fused)
    fused = Dense(512, activation='relu')(fused)
    fused = Dropout(0.3)(fused)
    fused = Dense(256, activation='relu')(fused)

    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

In [3]:
# =====================================================================================
# 3 NEW ADVANCED MODEL ARCHITECTURES
# =====================================================================================
from src.nn_blocks import attention_layer

def create_advanced_model_A_dual_unet(input_shape, imu_dim, wd=1e-4):
    """Build the dual U-Net model: one U-Net per modality.

    The IMU slice goes through a depth-4 U-Net and the ToF slice through a
    depth-3 U-Net. Both outputs are projected to 128 channels with a 1x1
    convolution before ``features_processing``.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: Accepted for signature parity with the other factories; nothing
            in this function reads it.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # Deep U-Net on IMU, slightly lighter U-Net on ToF/thermal.
    imu_unet = unet_se_cnn(imu_feats, unet_depth=4, base_filters=128, kernel_size=5, drop=0.3)
    tof_unet = unet_se_cnn(tof_feats, unet_depth=3, base_filters=64, kernel_size=5, drop=0.3)

    # 1x1 convolutions bring both branches to a common width of 128 channels
    # (author's note: both become (None, 128, 128) — TODO confirm).
    imu_branch = Conv1D(128, 1, padding='same', activation='relu', name='imu_projection')(imu_unet)
    tof_branch = Conv1D(128, 1, padding='same', activation='relu', name='tof_projection')(tof_unet)

    fused = features_processing(imu_branch, tof_branch)
    fused = Dropout(0.3)(fused)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

# --- Advanced Model B: Cross-Attention Fusion ---
# Hypothesis: Instead of just concatenating the IMU and ToF branches, we can create
# richer features by allowing them to "talk to each other." The IMU branch will learn
# what to pay attention to in the ToF data, and vice-versa.
def create_advanced_model_B_cross_attention(input_shape, imu_dim, wd=1e-4):
    """Build the cross-attention fusion model.

    Each modality is encoded separately, then each queries the other with a
    Keras ``Attention`` layer. The two original encodings and the two
    attended encodings are concatenated, passed through a bidirectional GRU
    and summarised with ``attention_layer``.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to the residual SE-CNN blocks, ``tof_block_2``
            and the GRU kernel regularizer.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # 1. Per-modality encoders (author's note: both yield (None, 32, 128) —
    #    TODO confirm).
    imu_enc = residual_se_cnn_block(imu_feats, 64, 3, drop=0.2, wd=wd)
    imu_enc = residual_se_cnn_block(imu_enc, 128, 5, drop=0.2, wd=wd)

    tof_enc = tof_block_2(tof_feats, wd)

    # 2. Cross-attention: inputs are [query, value].
    imu_query_tof = tf.keras.layers.Attention()([imu_enc, tof_enc])
    tof_query_imu = tf.keras.layers.Attention()([tof_enc, imu_enc])

    # 3. Stack original and context-aware features along the channel axis.
    enriched = Concatenate()([imu_enc, imu_query_tof, tof_enc, tof_query_imu])

    # 4. Sequence model over the stacked features, then attention pooling.
    sequence_model = Bidirectional(GRU(256, return_sequences=True, kernel_regularizer=l2(wd)))
    enriched = sequence_model(enriched)
    summary = attention_layer(enriched)

    summary = Dropout(0.4)(summary)
    summary = Dense(256, activation='relu')(summary)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(summary),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(summary),
    }
    return Model(inputs=model_input, outputs=outputs)

# --- Advanced Model C: Stacked Hybrid Blocks ---
# Hypothesis: A single block of (CNN -> RNN) is good. Repeatedly stacking this
# hybrid block will allow the model to learn progressively more abstract and
# powerful spatio-temporal features.
def cnn_lstm_block(x, filters, kernel_size, drop=0.2, wd=1e-4):
    """Apply a residual SE-CNN block followed by a bidirectional LSTM.

    Args:
        x: Input sequence tensor.
        filters: Filter count for the CNN block and unit count of each LSTM
            direction.
        kernel_size: Kernel size forwarded to ``residual_se_cnn_block``.
        drop: Dropout rate forwarded to ``residual_se_cnn_block``.
        wd: L2 factor for the CNN block and the LSTM kernel regularizer.

    Returns:
        The per-time-step output sequence of the bidirectional LSTM.
    """
    conv_out = residual_se_cnn_block(x, filters, kernel_size, drop=drop, wd=wd)
    recurrent = Bidirectional(LSTM(filters, return_sequences=True, kernel_regularizer=l2(wd)))
    return recurrent(conv_out)

def create_advanced_model_C_stacked_hybrid(input_shape, imu_dim, wd=1e-4):
    """Build the stacked CNN+LSTM hybrid model.

    The IMU branch is two ``cnn_lstm_block`` stages. The ToF branch is
    ``tof_block_2`` projected to 256 features so it matches the IMU branch
    before ``features_processing``.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor. It is forwarded to both ``cnn_lstm_block`` stages and
            to ``tof_block_2``. Previously the hybrid blocks silently used
            their own default, ignoring the caller's value.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # --- IMU branch: stacked hybrid blocks ---
    # Author's shape notes (TODO confirm):
    # (128, D) -> block 1: (64, 128) -> block 2: (32, 256)
    x1 = cnn_lstm_block(imu, 64, 3, wd=wd)
    x1 = cnn_lstm_block(x1, 128, 5, wd=wd)

    # --- ToF branch --- (author's note: output is (32, 128))
    x2 = tof_block_2(tof, wd)
    # Project ToF features to the IMU branch's final width (2 * 128 = 256).
    x2_projected = Dense(256, activation='relu')(x2)

    x = features_processing(x1, x2_projected)
    x = Dropout(0.3)(x)
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)

    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

In [4]:
# =====================================================================================
# 3 NEW ADVANCED MODEL ARCHITECTURES
# =====================================================================================

# --- Advanced Model A: BERT-Fusion (Keras Implementation) ---
# Hypothesis: Using a Transformer (BERT) as a late-stage fusion layer for features
# from three separate, specialized branches will create the most powerful representation.
# This is a direct translation of the PyTorch model's core idea.
def create_advanced_model_A_bert_fusion(input_shape, imu_dim, wd=1e-4, thm_dim=5):
    """Build the three-branch model fused by a stack of Transformer blocks.

    The input is split along its last axis into IMU (first ``imu_dim``
    channels), thermal (next ``thm_dim`` channels) and ToF (the rest). Each
    modality gets its own pair of residual SE-CNN blocks. The three 256-wide
    outputs are concatenated to 768 channels and passed through two
    ``TransformerBlock`` layers before pooling and classification.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: Accepted for signature parity with the other factories.
            NOTE(review): nothing in this function reads it, so the CNN
            blocks use their own default weight decay — confirm whether it
            should be forwarded.
        thm_dim: Number of thermal channels that directly follow the IMU
            channels. Defaults to 5, the value that was previously
            hard-coded. This layout is an assumption about the caller's
            column order.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof_and_thm = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # Split the non-IMU channels: thermal first, ToF after.
    thm = Lambda(lambda t: t[:, :, :thm_dim])(tof_and_thm)
    tof = Lambda(lambda t: t[:, :, thm_dim:])(tof_and_thm)

    # 1. Three separate feature-extraction branches.
    # IMU branch (author's note: ends at (None, 32, 256) — TODO confirm).
    x_imu = residual_se_cnn_block(imu, 128, 3)
    x_imu = residual_se_cnn_block(x_imu, 256, 5)

    # Thermal branch, then a 1x1 conv to widen it to 256 channels.
    x_thm = residual_se_cnn_block(thm, 64, 3)
    x_thm = residual_se_cnn_block(x_thm, 128, 5)
    x_thm = Conv1D(256, 1, padding='same', activation='relu')(x_thm)

    # ToF branch.
    x_tof = residual_se_cnn_block(tof, 128, 3)
    x_tof = residual_se_cnn_block(x_tof, 256, 5)

    # 2. Concatenate along the feature axis: 256 + 256 + 256 = 768 channels,
    #    which is why embed_dim below is 768.
    x = Concatenate()([x_imu, x_thm, x_tof])

    # Transformer layers for fusion.
    x = TransformerBlock(embed_dim=768, num_heads=8, ff_dim=1024, rate=0.2)(x)
    x = TransformerBlock(embed_dim=768, num_heads=8, ff_dim=1024, rate=0.2)(x)

    # 3. Aggregate over time.
    x = GlobalAveragePooling1D()(x)

    # 4. Final classifier MLP.
    x = Dropout(0.4)(x)
    x = Dense(512, activation='relu')(x)
    x = Dropout(0.3)(x)
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)

    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

# --- Advanced Model B: Hyper-UNet ---
# Hypothesis: Since U-Nets are the top performers, an even deeper and wider U-Net
# with more filters and a deeper encoder/decoder structure will capture more complex features.
def create_advanced_model_B_hyper_unet(input_shape, imu_dim, wd=1e-4):
    """Build the "hyper" U-Net model: a depth-5, 128-base-filter IMU U-Net.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to ``tof_block_2``.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # IMU: very deep (unet_depth=5) and wide (base_filters=128) U-Net.
    imu_unet = unet_se_cnn(imu_feats, unet_depth=5, base_filters=128, kernel_size=5, drop=0.3)

    # ToF: standard block.
    tof_out = tof_block_2(tof_feats, wd)

    # 1x1 convolutions map both branches to 128 channels before merging.
    imu_branch = Conv1D(128, 1, padding='same', activation='relu')(imu_unet)
    tof_branch = Conv1D(128, 1, padding='same', activation='relu')(tof_out)

    fused = features_processing(imu_branch, tof_branch)
    fused = Dropout(0.3)(fused)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

# --- Advanced Model C: Parallel UNet-Transformer Hybrid ---
# Hypothesis: The IMU signal contains both local patterns (best for U-Net) and global
# context (best for Transformer). Processing the IMU with both backbones in parallel
# and fusing their outputs will create the ultimate feature representation.
def create_advanced_model_C_parallel_hybrid(input_shape, imu_dim, wd=1e-4):
    """Build the parallel U-Net / CNN-Transformer hybrid model.

    The IMU slice is processed by two parallel streams (a depth-4 U-Net and
    a CNN -> ``TransformerBlock`` stack). Their outputs are aligned with
    ``match_time_steps`` and concatenated. The IMU and ToF branches are then
    both projected to 256 channels before ``features_processing``.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to ``tof_block_2`` only.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # Stream 1: U-Net for multi-resolution analysis.
    unet_stream = unet_se_cnn(imu_feats, unet_depth=4, base_filters=128, kernel_size=5)

    # Stream 2: CNN -> Transformer for global context
    # (author's note: CNN output is (None, 32, 128) — TODO confirm).
    cnn_stream = residual_se_cnn_block(imu_feats, 64, 3)
    cnn_stream = residual_se_cnn_block(cnn_stream, 128, 5)
    transformer_stream = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=256)(cnn_stream)

    # Align both IMU streams in time and stack their channels.
    unet_aligned, transformer_aligned = match_time_steps(unet_stream, transformer_stream)
    imu_merged = Concatenate()([unet_aligned, transformer_aligned])

    # ToF branch (author's note: (None, 32, 128)).
    tof_out = tof_block_2(tof_feats, wd)

    # Project both branches to a common width of 256 channels.
    imu_branch = Conv1D(256, 1, padding='same', activation='relu', name='imu_projection')(imu_merged)
    tof_branch = Conv1D(256, 1, padding='same', activation='relu', name='tof_projection')(tof_out)

    fused = features_processing(imu_branch, tof_branch)
    fused = Dropout(0.3)(fused)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

In [5]:
# import polars as pl
# df = pl.read_parquet('output/imu_physics_feats.parquet')
# df.columns

In [6]:
from tensorflow.keras import Layer, Sequential

def ImuFeatureExtractorLayer(imu_input):
    """Derive extra IMU channels from the raw IMU tensor on the fly.

    Despite its name this is a plain function, not a ``Layer`` subclass. It
    slices the input on its last axis and returns the concatenation of:
    channels 0-2, channels 3-5, the norm of each of those two groups, the
    first temporal difference of channels 0-2 (zero-padded at the first
    step), and the square of channels 0-2.

    Args:
        imu_input: Rank-3 tensor indexed as ``[batch, time, channel]``.
            Assumes channels 0-2 are acc_x/y/z and channels 3-5 are a
            three-axis angular quantity — TODO confirm against the data
            pipeline.

    Returns:
        The concatenation of the six feature groups along the last axis.
    """
    linear_acc = imu_input[:, :, :3]
    angular = imu_input[:, :, 3:6]

    # Per-step magnitudes of both three-axis groups.
    acc_norm = tf.norm(linear_acc, axis=-1, keepdims=True)
    gyro_norm = tf.norm(angular, axis=-1, keepdims=True)

    # Jerk: step-to-step difference, left-padded with one zero step so the
    # time dimension is unchanged.
    acc_delta = linear_acc[:, 1:, :] - linear_acc[:, :-1, :]
    jerk = tf.pad(acc_delta, [[0, 0], [1, 0], [0, 0]])

    # Element-wise squared acceleration.
    acc_squared = tf.square(linear_acc)

    derived = [linear_acc, angular, acc_norm, gyro_norm, jerk, acc_squared]
    return Concatenate()(derived)

def create_new_model_1_in_model_fe(input_shape, imu_dim, wd=1e-4):
    """Build the model that engineers IMU features inside the graph.

    IMPORTANT: this model expects the RAW acc/rot channels in the first
    ``imu_dim`` positions, not pre-engineered features; the data pipeline
    must be adjusted accordingly.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels holding the raw IMU signals.
        wd: L2 factor passed to ``tof_block_2`` only.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    raw_imu = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # 1. In-graph feature engineering on the raw IMU channels.
    imu_branch = ImuFeatureExtractorLayer(raw_imu)

    # 2. CNN backbone over the derived features.
    imu_branch = residual_se_cnn_block(imu_branch, 128, 5)
    imu_branch = residual_se_cnn_block(imu_branch, 256, 7)

    # 3. ToF branch.
    tof_branch = tof_block_2(tof_feats, wd)

    # 4. Merge and classify.
    fused = features_processing(imu_branch, tof_branch)
    fused = Dropout(0.3)(fused)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

# =====================================================================================
# 3 NEW ADVANCED PANNs-BASED MODEL ARCHITECTURES
# =====================================================================================

def create_panns_model_A_rnn_head(input_shape, imu_dim, wd=1e-4):
    """Build the parallel-CNN (PANNs-style) backbone with an RNN head.

    Three residual SE-CNN blocks with kernel sizes 3, 5 and 7 read the same
    IMU slice and are concatenated channel-wise (3 * 128 = 384). The ToF
    branch is projected to 384 channels, concatenated with the IMU features
    and summarised by a bidirectional GRU plus ``attention_layer``.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to ``tof_block`` and the GRU kernel regularizer.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # --- IMU branch: multi-scale parallel CNNs ---
    multi_scale = [residual_se_cnn_block(imu_feats, 128, kernel) for kernel in (3, 5, 7)]
    imu_branch = Concatenate()(multi_scale)

    # --- ToF branch --- (author's note: (None, 32, 128))
    tof_out = tof_block(tof_feats, wd)

    # Widen ToF to the IMU branch's 384 channels, then stack both.
    # NOTE(review): the author states both branches have 32 time steps; the
    # plain Concatenate below requires that — confirm against src.nn_blocks.
    tof_branch = Conv1D(384, 1, padding='same', activation='relu')(tof_out)
    stacked = Concatenate()([imu_branch, tof_branch])

    # RNN head over the stacked features, then attention pooling.
    rnn_head = Bidirectional(GRU(384, return_sequences=True, kernel_regularizer=l2(wd)))
    stacked = rnn_head(stacked)
    summary = attention_layer(stacked)

    # --- Classifier MLP ---
    summary = Dropout(0.4)(summary)
    summary = Dense(256, activation='relu')(summary)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(summary),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(summary),
    }
    return Model(inputs=model_input, outputs=outputs)

def pann_rnn_head_feat_processing(input_shape, imu_dim, wd=1e-4):
    """Build the parallel-CNN backbone merged through ``features_processing``.

    Same multi-scale IMU backbone as ``create_panns_model_A_rnn_head``, but
    the IMU and projected ToF branches are fused with ``features_processing``
    instead of an explicit GRU + attention head.

    Args:
        input_shape: Per-sample shape handed to ``Input`` (no batch axis).
        imu_dim: Number of leading channels that belong to the IMU branch.
        wd: L2 factor passed to ``tof_block``.

    Returns:
        A Keras ``Model`` with dict outputs ``"main_output"`` and
        ``"tof_gate"``.
    """
    model_input = Input(shape=input_shape)
    imu_feats = Lambda(lambda seq: seq[:, :, :imu_dim])(model_input)
    tof_feats = Lambda(lambda seq: seq[:, :, imu_dim:])(model_input)

    # --- IMU branch: kernel sizes 3, 5 and 7 in parallel, 128 filters each,
    # concatenated to 384 channels.
    multi_scale = [residual_se_cnn_block(imu_feats, 128, kernel) for kernel in (3, 5, 7)]
    imu_branch = Concatenate()(multi_scale)

    # --- ToF branch, widened to 384 channels with a 1x1 convolution.
    tof_out = tof_block(tof_feats, wd)
    tof_branch = Conv1D(384, 1, padding='same', activation='relu')(tof_out)

    fused = features_processing(imu_branch, tof_branch)
    fused = Dropout(0.4)(fused)
    fused = Dense(256, activation='relu')(fused)
    outputs = {
        "main_output": Dense(NUM_CLASSES, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=model_input, outputs=outputs)

In [7]:
from scipy.spatial.transform import Rotation as R

def ImuFeatureExtractorLayer(imu_input, sampling_rate_hz=200):
    """Derive extra IMU channels from raw accelerometer + rotation input.

    Despite its name this is a plain function that wires Keras ``Lambda``
    layers together. It returns the channel-wise concatenation of: the three
    acceleration channels, a three-channel "gyro" proxy derived from the
    rotation channels, the norm of each, the jerk (first temporal difference
    of acceleration, zero-padded at the first step) and the squared
    acceleration.

    Changes from the previous version:
      * The eager helper no longer indexes the tensor with a Python list
        (``quats[i, :, [1, 2, 3, 0]]``), which TensorFlow tensors are not
        expected to support. The resulting scipy ``Rotation`` was never used.
      * The per-sequence Python loop is replaced by one vectorised
        ``np.diff`` over the time axis, with identical values.
      * The ``tf.py_function`` output gets its static shape restored, since
        ``py_function`` drops it and the ``Concatenate`` below needs it.

    NOTE(review): the "gyro" proxy is the finite difference of the first
    three rotation channels divided by ``dt`` (unchanged from the original
    arithmetic). That is not a true angular velocity computed from
    quaternions — confirm whether this is intended.

    The layer names ``acc_mag``, ``gyro_mag``, ``jerk`` and ``acc_pow`` are
    fixed, so calling this function twice within one model will presumably
    clash on duplicate names.

    Args:
        imu_input: Rank-3 tensor indexed ``[batch, time, channel]``; channels
            0-2 are acc_x/y/z and the remaining ones rot_w/x/y/z (per the
            original inline comments).
        sampling_rate_hz: Sampling rate used to turn differences into rates
            (``dt = 1 / sampling_rate_hz``).

    Returns:
        The concatenated feature tensor.
    """
    acc = imu_input[:, :, :3]  # acc_x, acc_y, acc_z
    rot = imu_input[:, :, 3:]  # rot_w, rot_x, rot_y, rot_z

    dt = 1.0 / sampling_rate_hz

    def _calculate_angular_velocity_tf(quats):
        # Runs eagerly inside tf.py_function, so convert to NumPy first.
        q = np.asarray(quats)
        head = q[:, :, :3]
        # Prepending the first step makes the first difference zero and
        # keeps the time dimension unchanged.
        vel = np.diff(head, axis=1, prepend=head[:, :1, :]) / dt
        return vel.astype(np.float32)

    def _gyro_from_rot(t):
        out = tf.py_function(func=_calculate_angular_velocity_tf, inp=[t], Tout=tf.float32)
        # py_function loses static shape information; restore it (same batch
        # and time dimensions as the input, three output channels).
        out.set_shape(t.shape[:-1].concatenate([3]))
        return out

    gyro = tf.keras.layers.Lambda(_gyro_from_rot)(rot)

    # --- On-the-fly feature creation (wrapped in Lambda layers) ---
    acc_mag = tf.keras.layers.Lambda(lambda t: tf.norm(t, axis=-1, keepdims=True), name='acc_mag')(acc)
    gyro_mag = tf.keras.layers.Lambda(lambda t: tf.norm(t, axis=-1, keepdims=True), name='gyro_mag')(gyro)
    jerk = tf.keras.layers.Lambda(
        lambda t: tf.pad(t[:, 1:, :] - t[:, :-1, :], [[0, 0], [1, 0], [0, 0]]), name='jerk'
    )(acc)
    acc_pow = tf.keras.layers.Lambda(tf.square, name='acc_pow')(acc)

    # Concatenate all the resulting KerasTensors.
    return Concatenate()([acc, gyro, acc_mag, gyro_mag, jerk, acc_pow])

def create_stacked_fe_unet(input_shape, raw_imu_dim, engineered_imu_dim, wd=1e-4):
    """Two-branch gated model: U-Net over IMU features fused with a ToF branch.

    The last input axis is split, in order, into ``raw_imu_dim`` raw IMU
    channels, ``engineered_imu_dim`` pre-computed IMU features and the
    remaining channels, which feed the ToF branch.

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer.
        raw_imu_dim: Number of leading raw IMU channels.
        engineered_imu_dim: Number of engineered IMU channels that follow.
        wd: Weight-decay value handed to ``tof_block``.

    Returns:
        Model with outputs ``main_output`` (18-unit softmax) and
        ``tof_gate`` (single sigmoid unit).
    """
    imu_end = raw_imu_dim + engineered_imu_dim
    inp = Input(shape=input_shape)

    # Channel-wise split of the stacked input.
    imu_raw = Lambda(lambda t: t[:, :, :raw_imu_dim])(inp)
    imu_engineered = Lambda(lambda t: t[:, :, raw_imu_dim:imu_end])(inp)
    tof_engineered = Lambda(lambda t: t[:, :, imu_end:])(inp)

    # IMU branch: features derived in-graph from the raw channels are joined
    # with the pre-computed ones before entering the U-Net.
    derived = ImuFeatureExtractorLayer(imu_raw)
    derived, engineered = match_time_steps(derived, imu_engineered)
    imu_features = Concatenate()([derived, engineered])
    imu_branch = unet_se_cnn(imu_features, 3, base_filters=128, kernel_size=3)

    # ToF branch.
    tof_branch = tof_block(tof_engineered, wd)

    fused = features_processing(imu_branch, tof_branch)
    fused = Dropout(0.3)(fused)

    heads = {
        "main_output": Dense(18, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=inp, outputs=heads)

In [8]:
def transformer_encoder_block(inputs, head_size, num_heads, ff_dim, dropout=0.0, wd=1e-4):
    """Post-norm Transformer encoder block.

    Self-attention and a two-layer feed-forward network, each followed by a
    residual addition and LayerNormalization. The output keeps the feature
    width of ``inputs``.

    NOTE: a function with this same name is defined again later in this file
    (a pre-norm variant); once that definition has executed, callers resolve
    to it instead of this one.
    """
    model_dim = inputs.shape[-1]

    # --- Self-attention sub-layer ---
    attention = MultiHeadAttention(
        num_heads=num_heads, key_dim=head_size, dropout=dropout,
        kernel_regularizer=l2(wd),
    )
    attended = attention(inputs, inputs)
    attended = Dropout(dropout)(attended)
    attn_residual = Add()([inputs, attended])
    normed = LayerNormalization(epsilon=1e-6)(attn_residual)

    # --- Position-wise feed-forward sub-layer ---
    hidden = Dense(ff_dim, activation="relu", kernel_regularizer=l2(wd))(normed)
    projected = Dense(model_dim, kernel_regularizer=l2(wd))(hidden)
    projected = Dropout(dropout)(projected)
    ff_residual = Add()([normed, projected])
    return LayerNormalization(epsilon=1e-6)(ff_residual)

def cross_attention_block(query, value, key_dim, num_heads, dropout=0.1, wd=1e-4):
    """Attend from ``query`` onto ``value`` and return the normalised residual.

    ``value`` also serves as the key. The attention output is added back to
    ``query`` and passed through LayerNormalization.
    """
    attention = MultiHeadAttention(
        num_heads=num_heads, key_dim=key_dim, dropout=dropout,
        kernel_regularizer=l2(wd),
    )
    attended = attention(query, value)
    residual = Add()([query, attended])
    return LayerNormalization(epsilon=1e-6)(residual)

def wavenet_residual_block(inputs, filters, kernel_size, dilation_rate, wd=1e-4):
    """Residual block with a gated (tanh * sigmoid) dilated causal convolution.

    Two parallel dilated causal convolutions are multiplied element-wise,
    projected with a 1x1 convolution and added to the (optionally projected)
    input.
    """
    dilated = dict(dilation_rate=dilation_rate, padding='causal')

    # Gated activation unit: filter (tanh) modulated by gate (sigmoid).
    filter_branch = Conv1D(filters, kernel_size, activation='tanh',
                           kernel_regularizer=l2(wd), **dilated)(inputs)
    gate_branch = Conv1D(filters, kernel_size, activation='sigmoid',
                         kernel_regularizer=l2(wd), **dilated)(inputs)
    gated = Multiply()([filter_branch, gate_branch])

    # 1x1 projection of the gated signal.
    projected = Conv1D(filters, 1, padding='same', kernel_regularizer=l2(wd))(gated)

    # The skip path only needs a projection when the channel count changes.
    skip = inputs
    if skip.shape[-1] != filters:
        skip = Conv1D(filters, 1, padding='same', kernel_regularizer=l2(wd))(skip)

    return Add()([skip, projected])


def create_conv_transformer_model(input_shape, imu_dim, wd=1e-4):
    """Two CNN feature streams (IMU U-Net, ToF block) fused by a Transformer head.

    Both streams are projected to ``fusion_dim`` channels, concatenated on
    the channel axis, refined by two Transformer encoder blocks, average
    pooled over time and fed to the ``main_output`` / ``tof_gate`` heads.
    """
    fusion_dim = 256

    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    imu_feats = unet_se_cnn(imu, unet_depth=3, base_filters=128, kernel_size=3)
    tof_feats = tof_block(tof, wd)

    # Bring both streams to the same channel width before concatenation.
    projected = [
        Conv1D(fusion_dim, 1, padding='same', activation='relu')(stream)
        for stream in (imu_feats, tof_feats)
    ]
    x = Concatenate(axis=-1)(projected)

    for _ in range(2):
        x = transformer_encoder_block(
            x, head_size=64, num_heads=4, ff_dim=fusion_dim * 2, dropout=0.2, wd=wd
        )

    pooled = GlobalAveragePooling1D()(x)
    pooled = Dropout(0.3)(pooled)

    heads = {
        "main_output": Dense(18, activation="softmax", name="main_output")(pooled),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(pooled),
    }
    return Model(inputs=inp, outputs=heads)

def create_cross_fusion_unet(input_shape, imu_dim, wd=1e-4):
    """Parallel IMU/ToF encoders with a decoder that cross-attends to ToF skips.

    Each decoder level upsamples, adds the IMU skip of that level, queries
    the matching ToF skip through cross-attention, concatenates the result
    and applies two Conv-BN-ReLU stages.
    """
    depth, base_filters, kernel_size = 3, 128, 5
    level_filters = [base_filters * 2 ** level for level in range(depth)]

    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # --- Encoders (the ToF encoder uses half the IMU width at each level) ---
    imu_skips, tof_skips = [], []
    x_imu, x_tof = imu, tof
    for width in level_filters:
        x_imu = residual_se_cnn_block(x_imu, width, kernel_size, drop=0.3, wd=wd)
        imu_skips.append(x_imu)

        x_tof = residual_se_cnn_block(x_tof, width // 2, kernel_size, drop=0.3, wd=wd)
        tof_skips.append(x_tof)

    # --- Bottleneck (IMU stream only), twice the deepest encoder width ---
    x = residual_se_cnn_block(x_imu, level_filters[-1] * 2, kernel_size, drop=0.3, wd=wd)

    # --- Decoder, deepest level first ---
    decoder_levels = zip(reversed(level_filters), reversed(imu_skips), reversed(tof_skips))
    for width, imu_skip, tof_skip in decoder_levels:
        x = UpSampling1D(size=2)(x)
        x = Conv1D(width, 2, padding='same', activation='relu')(x)

        x = Add()([x, imu_skip])
        tof_context = cross_attention_block(x, tof_skip, key_dim=64, num_heads=4, wd=wd)
        x = Concatenate()([x, tof_context])

        for _ in range(2):
            x = Conv1D(width, kernel_size, padding='same', use_bias=False)(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)

    pooled = GlobalAveragePooling1D()(x)
    pooled = Dropout(0.3)(pooled)

    heads = {
        "main_output": Dense(18, activation="softmax", name="main_output")(pooled),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(pooled),
    }
    return Model(inputs=inp, outputs=heads)

def create_wavenet_style_model(input_shape, imu_dim, wd=1e-4):
    """Two parallel WaveNet-style stacks (IMU and ToF) with summed skip outputs.

    Each stream is embedded with a 1x1 convolution and passed through
    ``num_blocks`` gated residual blocks whose dilation cycles through
    1, 2, 4, 8. Per-block 1x1 skip projections are summed per stream, the
    two sums are concatenated, mixed, average pooled and classified.
    """
    embed_dim, num_blocks = 128, 8

    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    imu_stream = Conv1D(embed_dim, 1, padding='same')(imu)
    tof_stream = Conv1D(embed_dim, 1, padding='same')(tof)

    # --- Parallel backbones; one skip projection per block and stream ---
    imu_skips, tof_skips = [], []
    for block_idx in range(num_blocks):
        rate = 2 ** (block_idx % 4)

        imu_stream = wavenet_residual_block(
            imu_stream, embed_dim, kernel_size=3, dilation_rate=rate, wd=wd
        )
        imu_skips.append(Conv1D(embed_dim, 1, padding='same')(imu_stream))

        tof_stream = wavenet_residual_block(
            tof_stream, embed_dim, kernel_size=3, dilation_rate=rate, wd=wd
        )
        tof_skips.append(Conv1D(embed_dim, 1, padding='same')(tof_stream))

    imu_sum = Add()(imu_skips)
    imu_sum = Activation('relu')(imu_sum)
    tof_sum = Add()(tof_skips)
    tof_sum = Activation('relu')(tof_sum)

    merged = Concatenate()([imu_sum, tof_sum])
    merged = Conv1D(256, 1, padding='same', activation='relu')(merged)

    pooled = GlobalAveragePooling1D()(merged)
    pooled = Dropout(0.3)(pooled)

    heads = {
        "main_output": Dense(18, activation="softmax", name="main_output")(pooled),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(pooled),
    }
    return Model(inputs=inp, outputs=heads)

In [9]:
def create_multiscale_unet_model_definitive(input_shape, imu_dim, wd=1e-4):
    """U-Net IMU backbone whose last two decoder scales are combined, plus a ToF branch.

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer.
        imu_dim: Number of leading channels routed to the IMU backbone; the
            remaining channels go to ``tof_block``.
        wd: L2 weight-decay factor for the convolution kernels; also
            forwarded to ``tof_block`` and ``features_processing``.

    Returns:
        Model with outputs ``main_output`` (``NUM_CLASSES``-way softmax) and
        ``tof_gate`` (single sigmoid unit).
    """
    unet_depth = 3   # number of encoder / decoder levels
    num_scales = 2   # how many of the final decoder outputs are combined

    # --- Internal helper blocks (local to this function) ---
    def _conv_block(x, filters, kernel_size):
        # Two Conv-BN stages with a residual shortcut and a final ReLU.
        shortcut = x
        x = Conv1D(filters, kernel_size, padding='same', use_bias=False, kernel_regularizer=l2(wd))(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv1D(filters, kernel_size, padding='same', use_bias=False, kernel_regularizer=l2(wd))(x)
        x = BatchNormalization()(x)

        # Ensure the shortcut can be added if the number of filters changes
        if shortcut.shape[-1] != filters:
            shortcut = Conv1D(filters, 1, padding='same', use_bias=False, kernel_regularizer=l2(wd))(shortcut)
            shortcut = BatchNormalization()(shortcut)

        x = Add()([x, shortcut])
        x = Activation('relu')(x)
        return x

    def _crop_time(tensor, length):
        # Crop the time axis to `length` steps. The length is bound as a
        # default argument so the Lambda does not depend on an outer variable
        # that is rebound later (late-binding closure).
        return Lambda(lambda t, n=length: t[:, :n, :])(tensor)

    # --- Model Definition Starts Here ---
    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # --- 1. Multi-Scale IMU Backbone ---
    filters = 128
    skips = []
    x = imu

    # --- Encoder: conv block, remember skip, halve the time axis ---
    for _ in range(unet_depth):
        x = _conv_block(x, filters, kernel_size=3)
        skips.append(x)
        x = MaxPooling1D(pool_size=2)(x)
        filters *= 2

    # --- Bottleneck ---
    x = _conv_block(x, filters, kernel_size=3)

    # --- Decoder ---
    decoder_outputs = []
    for level, skip in enumerate(reversed(skips)):
        filters //= 2
        x = UpSampling1D(size=2)(x)

        # Pooling floors odd lengths, so the upsampled tensor can be shorter
        # than the skip; crop the skip to match.
        if x.shape[1] != skip.shape[1]:
            skip = _crop_time(skip, x.shape[1])

        x = Concatenate()([x, skip])
        x = _conv_block(x, filters, kernel_size=3)
        if level >= unet_depth - num_scales:  # keep the last `num_scales` outputs
            decoder_outputs.append(x)

    # --- Multi-Scale Combination ---
    small_res_out = decoder_outputs[0]
    large_res_out = decoder_outputs[1]
    small_res_out_upsampled = UpSampling1D(size=2)(small_res_out)

    common_dim = 256
    proj_small = Conv1D(common_dim, 1, padding='same')(small_res_out_upsampled)
    proj_large = Conv1D(common_dim, 1, padding='same')(large_res_out)

    # Match time steps after projection
    if proj_small.shape[1] != proj_large.shape[1]:
        proj_large = _crop_time(proj_large, proj_small.shape[1])

    x1 = Concatenate()([proj_small, proj_large])

    # --- 2. Standard ToF Backbone ---
    x2 = tof_block(tof, wd)

    # --- 3. Fusion & RNN Head ---
    x_final = features_processing(x1, x2, wd=wd)

    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x_final)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x_final)

    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

# Pre-norm variant of transformer_encoder_block. NOTE: this redefines the function of the same name declared earlier in this file.
def transformer_encoder_block(inputs, head_size, num_heads, ff_dim, dropout=0.0, wd=1e-4):
    """Pre-norm Transformer encoder block.

    LayerNormalization is applied before self-attention and before the
    feed-forward network; each sub-layer output is added to its un-normalised
    input. The output keeps the feature width of ``inputs``.

    NOTE: a function with this same name is defined earlier in this file;
    being later, this definition is the one in effect for code that runs
    after it.
    """
    width = inputs.shape[-1]

    # --- Self-attention sub-layer ---
    normed = LayerNormalization(epsilon=1e-6)(inputs)
    attention = MultiHeadAttention(
        num_heads=num_heads, key_dim=head_size, dropout=dropout,
        kernel_regularizer=l2(wd),
    )
    attended = attention(normed, normed)
    attended = Dropout(dropout)(attended)
    after_attention = Add()([inputs, attended])

    # --- Feed-forward sub-layer ---
    normed = LayerNormalization(epsilon=1e-6)(after_attention)
    hidden = Dense(ff_dim, activation="relu", kernel_regularizer=l2(wd))(normed)
    hidden = Dense(width, kernel_regularizer=l2(wd))(hidden)
    hidden = Dropout(dropout)(hidden)
    return Add()([after_attention, hidden])

def create_transformer_refined_model(input_shape, imu_dim, wd=1e-4):
    """U-Net IMU backbone refined by four Transformer blocks, fused with a ToF branch.

    The refined IMU features and the ``tof_block`` output are handed to
    ``features_processing``, whose result feeds the ``main_output`` and
    ``tof_gate`` heads.
    """
    num_refiner_blocks = 4

    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # --- 1. U-Net backbone on the IMU channels ---
    imu_feats = unet_se_cnn(imu, unet_depth=3, base_filters=128, kernel_size=3)

    # --- 2. Transformer refinement; feed-forward width is 4x the feature width ---
    for _ in range(num_refiner_blocks):
        imu_feats = transformer_encoder_block(
            imu_feats, head_size=128, num_heads=8,
            ff_dim=imu_feats.shape[-1]*4, dropout=0.1, wd=wd,
        )

    # --- 3. ToF backbone ---
    tof_feats = tof_block(tof, wd)

    # --- 4. Fusion & RNN head ---
    fused = features_processing(imu_feats, tof_feats, wd=wd)

    heads = {
        "main_output": Dense(18, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=inp, outputs=heads)

def gated_fusion_block(x1, x2, wd=1e-4):
    """Blend two streams with a learned element-wise gate.

    The narrower stream is projected up to the wider feature width, the
    streams are passed through ``match_time_steps``, and the result is
    ``gate * x1 + (1 - gate) * x2`` where ``gate`` is a sigmoid Dense layer
    over the concatenation of both streams.
    """
    width = max(x1.shape[-1], x2.shape[-1])

    def _widen(stream):
        # Only project when the stream is narrower than the common width.
        if stream.shape[-1] == width:
            return stream
        return Dense(width, kernel_regularizer=l2(wd))(stream)

    x1 = _widen(x1)
    x2 = _widen(x2)

    first, second = match_time_steps(x1, x2)

    # Gate is computed from both streams jointly.
    joint = Concatenate()([first, second])
    gate = Dense(width, activation='sigmoid', kernel_regularizer=l2(wd))(joint)

    kept_first = Multiply()([gate, first])
    inverse_gate = Lambda(lambda t: 1.0 - t)(gate)
    kept_second = Multiply()([inverse_gate, second])

    return Add()([kept_first, kept_second])

def create_gated_fusion_model(input_shape, imu_dim, wd=1e-4):
    """Gated fusion of IMU and ToF features followed by an RNN/attention head.

    The U-Net (IMU) and ``tof_block`` (ToF) outputs are merged by
    ``gated_fusion_block``; the merged sequence is processed by a BiLSTM, a
    BiGRU and a noisy Dense path in parallel, pooled by ``attention_layer``
    and classified through two Dense-BN-ReLU-Dropout stages.
    """
    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # --- 1. Backbones ---
    imu_feats = unet_se_cnn(imu, unet_depth=3, base_filters=128, kernel_size=3)
    tof_feats = tof_block(tof, wd)

    # --- 2. Gated fusion into a single sequence ---
    fused = gated_fusion_block(imu_feats, tof_feats, wd=wd)

    # --- 3. Parallel sequence paths over the fused tensor ---
    lstm_path = Bidirectional(LSTM(128, return_sequences=True, kernel_regularizer=l2(wd)))(fused)
    gru_path = Bidirectional(GRU(128, return_sequences=True, kernel_regularizer=l2(wd)))(fused)
    noisy_path = GaussianNoise(0.09)(fused)
    noisy_path = Dense(16, activation='elu')(noisy_path)

    x = Concatenate()([lstm_path, gru_path, noisy_path])
    x = Dropout(0.4)(x)
    x = attention_layer(x)

    # --- 4. Dense classifier stack: (units, dropout rate) per stage ---
    dense_stages = ((256, 0.5), (128, 0.3))
    for units, rate in dense_stages:
        x = Dense(units, use_bias=False, kernel_regularizer=l2(wd))(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Dropout(rate)(x)

    heads = {
        "main_output": Dense(18, activation="softmax", name="main_output")(x),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(x),
    }
    return Model(inputs=inp, outputs=heads)

In [10]:
def create_dual_refiner_model(input_shape, imu_dim, wd=1e-4):
    """Refine both the IMU and the ToF stream with one Transformer block each.

    The IMU stream (U-Net output) uses 4 heads of size 64, the ToF stream
    (``tof_block`` output) 2 heads of size 32; both use a feed-forward width
    of four times their feature width. The refined streams are fused by
    ``features_processing``.
    """
    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # --- 1. Backbones ---
    imu_feats = unet_se_cnn(imu, unet_depth=3, base_filters=128, kernel_size=3)
    tof_feats = tof_block(tof, wd)

    # --- 2. One Transformer refinement block per stream ---
    imu_refined = transformer_encoder_block(
        imu_feats, head_size=64, num_heads=4,
        ff_dim=imu_feats.shape[-1]*4, dropout=0.1, wd=wd,
    )
    tof_refined = transformer_encoder_block(
        tof_feats, head_size=32, num_heads=2,
        ff_dim=tof_feats.shape[-1]*4, dropout=0.1, wd=wd,
    )

    # --- 3. Fusion & RNN head ---
    fused = features_processing(imu_refined, tof_refined, wd=wd)

    heads = {
        "main_output": Dense(18, activation="softmax", name="main_output")(fused),
        "tof_gate": Dense(1, activation="sigmoid", name="tof_gate")(fused),
    }
    return Model(inputs=inp, outputs=heads)

# Conformer building block (feed-forward, self-attention, convolution, feed-forward).
def conformer_block(inputs, ff_dim, num_heads, kernel_size=3, dropout=0.1, wd=1e-4):
    """Conformer-style block: FFN -> self-attention -> convolution -> FFN.

    Every module is pre-normalised with LayerNormalization and wrapped in a
    residual connection; the output keeps the feature width of ``inputs``.
    ``wd`` regularises the feed-forward Dense kernels only.
    """
    dim = inputs.shape[-1]

    def _feed_forward(tensor):
        # LayerNorm -> Dense(swish) -> Dropout -> Dense, added to the input.
        out = LayerNormalization()(tensor)
        out = Dense(ff_dim, activation='swish', kernel_regularizer=l2(wd))(out)
        out = Dropout(dropout)(out)
        out = Dense(dim, kernel_regularizer=l2(wd))(out)
        return Add()([tensor, out])

    # --- Feed Forward 1 ---
    x = _feed_forward(inputs)

    # --- Multi-Head Self-Attention ---
    attn_in = LayerNormalization()(x)
    attention = MultiHeadAttention(num_heads=num_heads, key_dim=dim//num_heads, dropout=dropout)
    attn_out = attention(attn_in, attn_in)
    attn_out = Dropout(dropout)(attn_out)
    x = Add()([x, attn_out])

    # --- Convolution Module ---
    conv = LayerNormalization()(x)
    conv = Conv1D(filters=dim*2, kernel_size=1, activation='swish')(conv)
    # Grouped convolution with one group per output channel.
    conv = Conv1D(filters=dim, kernel_size=kernel_size, padding='same', groups=dim)(conv)
    conv = BatchNormalization()(conv)
    conv = Conv1D(filters=dim, kernel_size=1)(conv)
    conv = Dropout(dropout)(conv)
    x = Add()([x, conv])

    # --- Feed Forward 2 ---
    return _feed_forward(x)

def create_conformer_model(input_shape, imu_dim, wd=1e-4):
    """Conformer backbones for the IMU and ToF streams, fused by ``features_processing``.

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer.
        imu_dim: Number of leading channels routed to the IMU stream; the
            remaining channels form the ToF stream.
        wd: Weight-decay factor. It is now forwarded to every
            ``conformer_block`` (previously the blocks silently fell back to
            their own default) as well as to ``features_processing``.

    Returns:
        Model with outputs ``main_output`` (``NUM_CLASSES``-way softmax) and
        ``tof_gate`` (single sigmoid unit).
    """
    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # --- 1. Initial Projection ---
    # Project both streams to a common embedding dimension
    embed_dim = 192
    x1 = Conv1D(embed_dim, 1, padding='same')(imu)
    x2 = Conv1D(embed_dim, 1, padding='same')(tof)

    # --- 2. Conformer Backbone ---
    # Three blocks for the IMU stream, two for the ToF stream (hyperparameters).
    for _ in range(3):
        x1 = conformer_block(x1, ff_dim=embed_dim*4, num_heads=4, kernel_size=7, wd=wd)

    for _ in range(2):
        x2 = conformer_block(x2, ff_dim=embed_dim*4, num_heads=4, kernel_size=7, wd=wd)

    # --- 3. Fusion & RNN Head ---
    x_final = features_processing(x1, x2, wd=wd)

    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x_final)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x_final)

    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

def create_hybrid_head_model(input_shape, imu_dim, wd=1e-4):
    """Internal ensemble of a direct and a Transformer-refined head over shared backbones.

    Both heads consume the same U-Net (IMU) and ``tof_block`` (ToF)
    features. Head A feeds them straight to ``features_processing``; head B
    first refines the U-Net features with two Transformer blocks. The two
    head outputs are concatenated and mixed by a Dense layer.

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer.
        imu_dim: Number of leading channels routed to the IMU backbone.
        wd: Weight-decay factor. It is now also forwarded to the Transformer
            refinement blocks (previously they silently used their own
            default).

    Returns:
        Model with outputs ``main_output`` (``NUM_CLASSES``-way softmax) and
        ``tof_gate`` (single sigmoid unit).
    """
    inp = Input(shape=input_shape)
    imu = Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # --- 1. Common Backbones (shared by both heads) ---
    unet_out = unet_se_cnn(imu, unet_depth=3, base_filters=128, kernel_size=3)
    tof_out = tof_block(tof, wd)

    # --- 2. Parallel Processing Paths ---
    # Path A: direct
    head_a_out = features_processing(unet_out, tof_out, wd=wd)

    # Path B: U-Net features refined by two Transformer blocks
    ff_dim = unet_out.shape[-1]*4
    unet_refined = transformer_encoder_block(
        unet_out, head_size=128, num_heads=8, ff_dim=ff_dim, dropout=0.1, wd=wd
    )
    unet_refined = transformer_encoder_block(
        unet_refined, head_size=128, num_heads=8, ff_dim=ff_dim, dropout=0.1, wd=wd
    )

    head_b_out = features_processing(unet_refined, tof_out, wd=wd)

    # --- 3. Final Fusion of Heads ---
    x_final = Concatenate()([head_a_out, head_b_out])
    x_final = Dense(256, activation='relu')(x_final)
    x_final = Dropout(0.4)(x_final)

    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x_final)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x_final)

    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

In [11]:
# =====================================================================================
# CONFIGURATION
# =====================================================================================
# Passed to train_model as `initial_learning_rate`.
LR_INIT = 5e-4
# Passed to train_model as `weight_decay`.
WD = 3e-3
# Number of gesture classes; used for one-hot encoding the labels.
NUM_CLASSES = 18
# Batch size for both the training generator and the validation dataset.
BATCH_SIZE = 64
# Number of StratifiedKFold folds.
N_SPLITS = 4
# Every sequence is padded / truncated to this many time steps.
MAX_PAD_LEN = 128


from src.merge_feats_dynamic import merge_feature_sets
from src.functions import create_sequence_dataset, generate_gate_targets, train_model
from src.nn_blocks import GatedMixupGenerator

# =====================================================================================
# TRAINING LOGIC
# =====================================================================================

# Directory holding the pre-computed feature parquet files and the cleaned base data.
FEATURE_DIR = Path('output')
# Directory holding the raw competition files (demographics CSV is read from here).
RAW_DIR = Path('input/cmi-detect-behavior-with-sensor-data')
# Seed used for the cross-validation shuffle.
RANDOM_STATE = 42

# --- Step 1: Define the feature sets to merge for this experiment ---
files_to_merge = [
    'imu_basic_physics_feats.parquet',
    'tof_basic_kaggle_feats.parquet',
    'tof_features_advanced_train_polars.parquet',
    'imu_physics_feats.parquet'
    # 'thermal_features.parquet'
]

feature_paths = [FEATURE_DIR / f for f in files_to_merge]
base_df = pl.read_parquet(FEATURE_DIR / "cleaned_base_train_data.parquet")
demographics_df = pl.read_csv(RAW_DIR / "train_demographics.csv")
# NOTE(review): the select below keeps only metadata, raw IMU and thermal
# columns, so no column brought in by this join survives - confirm whether
# the join is still needed.
base_df = base_df.join(demographics_df, on='subject', how='left')
meta_cols = ['sequence_id', 'sequence_counter', 'subject', 'gesture']
imu_raw_cols =['acc_x', 'acc_y', 'acc_z', 'rot_w', 'rot_x', 'rot_y', 'rot_z', ]
thm_cols = ['thm_1', 'thm_2', 'thm_3', 'thm_4', 'thm_5']
base_df = base_df.select(meta_cols + imu_raw_cols + thm_cols)
# base_df = base_df.select(meta_cols)

# Encode the gesture strings as integers and store them next to the raw label.
le = LabelEncoder()
gesture_encoded = le.fit_transform(base_df.get_column('gesture'))
base_df = base_df.with_columns(pl.Series("gesture_int", gesture_encoded))  
# Join every feature file listed above onto the base frame.
final_df = merge_feature_sets(base_df, feature_paths)
print(f"  Final merged DataFrame created with shape: {final_df.shape}")

▶ Starting merge process...
  Loading and joining features from: imu_basic_physics_feats.parquet
  Loading and joining features from: tof_basic_kaggle_feats.parquet
  Loading and joining features from: tof_features_advanced_train_polars.parquet
  Loading and joining features from: imu_physics_feats.parquet
  Merge complete.
  Final merged DataFrame created with shape: (556380, 126)


In [12]:
from tensorflow.keras.callbacks import Callback

class EMACallback(Callback):
    """Track an exponential moving average (EMA) of the model weights.

    The average is only maintained during training; the model keeps its own
    weights until ``apply_ema_weights`` is called explicitly.
    """

    def __init__(self, decay=0.999):
        super().__init__()
        self.decay = decay
        self.ema_weights = None

    def on_train_begin(self, logs=None):
        """Seed the running average with a copy of the current weights."""
        self.ema_weights = list(map(tf.identity, self.model.get_weights()))
        print("EMA Callback: EMA weights initialized.")

    def on_batch_end(self, batch, logs=None):
        """Blend the weights produced by this batch into the running average."""
        keep = self.decay
        take = 1 - self.decay
        latest = self.model.get_weights()
        for idx, shadow in enumerate(self.ema_weights):
            self.ema_weights[idx] = (keep * shadow) + (take * latest[idx])

    def on_epoch_end(self, epoch, logs=None):
        """Deliberate no-op: the EMA weights are only applied after training."""
        return None

    def apply_ema_weights(self):
        """Copy the averaged weights into the model.

        Intended to be called once training has finished, before evaluating
        or saving. Does nothing if training never started.
        """
        if self.ema_weights is None:
            return
        self.model.set_weights(self.ema_weights)
        print("EMA Callback: EMA weights have been applied to the model.")

In [None]:
# Re-fit the label encoder and rebuild the merged frame (same statements as
# the data-loading cell) so this cell starts from a freshly merged final_df.
le = LabelEncoder()
gesture_encoded = le.fit_transform(base_df.get_column('gesture'))
base_df = base_df.with_columns(pl.Series("gesture_int", gesture_encoded))  
final_df = merge_feature_sets(base_df, feature_paths)
print(f"  Final merged DataFrame created with shape: {final_df.shape}")

# 1. Feature columns fed to the IMU branch. They come first in the model
#    input, so the builders can slice them off with `imu_dim`.
imu_cols = [
    'acc_x', 'acc_y', 'acc_z', 'rot_w', 'rot_x', 'rot_y', 'rot_z', 
    'acc_mag', 'acc_mag_jerk', 'linear_acc_x_right', 'linear_acc_y_right', 'linear_acc_z_right',
    'angular_vel_x_right', 'angular_vel_y_right', 'angular_vel_z_right',
    'angular_accel_x', 'angular_accel_y', 'angular_accel_z',
    'grav_orient_x', 'grav_orient_y', 'grav_orient_z',
    'linear_acc_mag_right', 'angular_vel_mag', 'angular_accel_mag',
    'linear_acc_mag_jerk_right', 'angular_vel_mag_jerk', 'angular_accel_mag_jerk',
    'linear_acc_x', 'linear_acc_y', 'linear_acc_z', 'linear_acc_mag', 'linear_acc_mag_jerk', 'angular_vel_x',
    'angular_vel_y', 'angular_vel_z', 'angular_distance',
]

# Thermal (thm_*) and time-of-flight (tof_*) columns; they follow the IMU
# columns in the model input and are also used to derive the gate targets.
tof_cols = [
 'thm_1', 'thm_2', 'thm_3', 'thm_4', 'thm_5', 
 'tof_1_mean', 'tof_1_std', 'tof_1_min', 'tof_1_max', 'tof_2_mean', 'tof_2_std', 'tof_2_min', 'tof_2_max',
 'tof_3_mean', 'tof_3_std', 'tof_3_min', 'tof_3_max', 'tof_4_mean', 'tof_4_std', 'tof_4_min', 'tof_4_max',
 'tof_5_mean', 'tof_5_std', 'tof_5_min', 'tof_5_max', 'tof_1_mean_right', 'tof_1_std_right', 'tof_1_min_right',
 'tof_1_max_right', 'tof_1_diff_mean', 'tof_1_mean_decay', 'tof_1_active_pixels', 'tof_1_centroid_x', 'tof_1_centroid_y', 'tof_2_mean_right',
 'tof_2_std_right', 'tof_2_min_right', 'tof_2_max_right', 'tof_2_diff_mean', 'tof_2_mean_decay', 
 'tof_2_active_pixels', 'tof_2_centroid_x', 'tof_2_centroid_y', 'tof_3_mean_right', 'tof_3_std_right', 'tof_3_min_right',
 'tof_3_max_right', 'tof_3_diff_mean', 'tof_3_mean_decay', 'tof_3_active_pixels', 'tof_3_centroid_x', 'tof_3_centroid_y', 'tof_4_mean_right', 'tof_4_std_right', 'tof_4_min_right',
 'tof_4_max_right', 'tof_4_diff_mean', 'tof_4_mean_decay', 'tof_4_active_pixels', 'tof_4_centroid_x', 'tof_4_centroid_y', 'tof_5_mean_right', 'tof_5_std_right', 'tof_5_min_right',
 'tof_5_max_right', 'tof_5_diff_mean', 'tof_5_mean_decay', 'tof_5_active_pixels', 'tof_5_centroid_x', 'tof_5_centroid_y',]

# 2. Create the final, ordered list of all features (IMU first, then ToF/thermal).
all_feature_cols = imu_cols + tof_cols
imu_dim = len(imu_cols)
print(f"  Training with {len(all_feature_cols)} total features ({imu_dim} IMU, {len(tof_cols)} ToF/Thm).")    

# 3. Reorder the DataFrame to match the required structure for the model.
metadata_to_keep = ['sequence_id', 'sequence_counter', 'gesture', 'gesture_int', 'subject']
final_df = final_df.select(metadata_to_keep + all_feature_cols)
print("  DataFrame columns have been reordered for the model.")

# --- Step 3: Prepare for Cross-Validation ---
# One row per sequence with its label, sorted by sequence_id; the folds are
# drawn over sequences, not over individual rows.
cv_info = final_df.group_by("sequence_id").agg(pl.first("gesture_int")).sort("sequence_id")
all_sequence_ids = cv_info.get_column("sequence_id").to_numpy()
y_for_split = cv_info.get_column("gesture_int").to_numpy()

# (time steps, features) of a single padded sequence.
input_shape = (MAX_PAD_LEN, len(all_feature_cols)) 

# Filled by the training loop: model name -> mean accuracy and average best epoch.
model_results = {}
# model_builders = [
# #     # ("Best_unet_2", lambda: best_unet_2(input_shape, imu_dim)),
# #     ("Advanced_Dual_UNet", lambda: create_advanced_model_A_dual_unet(input_shape, imu_dim)),
#     ("Best_unet_1", lambda: best_unet_1(input_shape, imu_dim)),
# #     # ("gated_cnn_transformer_2_blocks", lambda: create_gated_cnn_transformer(input_shape, imu_dim)),
# ]

# Registry of (name, zero-argument builder) pairs to train. Each builder is
# called once per fold so every fold starts from a fresh model.
# NOTE: every entry is currently commented out, so the training loop below
# does nothing until at least one is enabled.
model_builders = [
        # ("Best_unet_1", lambda: best_unet_1(input_shape, imu_dim)),
        # ("conformer_model", lambda: create_conformer_model(input_shape, imu_dim)),
        # ("dual_refiner_model", lambda: create_dual_refiner_model(input_shape, imu_dim)),
        # ("pann_rnn_head", lambda:pann_rnn_head_feat_processing(input_shape, imu_dim))
    ]

for model_name, model_builder in model_builders:
    print("\n" + "="*60)
    print(f"▶ Training and Evaluating Model: {model_name}")
    print("="*60)

    kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)
    fold_accuracies = []
    all_preds = []
    all_labels = []
    best_epochs = []

    for fold_idx, (train_indices, val_indices) in enumerate(kf.split(all_sequence_ids, y_for_split)):
        print(f"\n=== Fold {fold_idx + 1}/{N_SPLITS} for {model_name} ===")
        train_ids = all_sequence_ids[train_indices]
        val_ids = all_sequence_ids[val_indices]

        train_df = final_df.filter(pl.col('sequence_id').is_in(train_ids))
        val_df = final_df.filter(pl.col('sequence_id').is_in(val_ids))
        
        scaler = StandardScaler()
        # Use the explicitly ordered 'all_feature_cols' list
        train_features_scaled = scaler.fit_transform(train_df.select(all_feature_cols))
        val_features_scaled = scaler.transform(val_df.select(all_feature_cols))
        
        X_train_scaled_features = pl.DataFrame(train_features_scaled, schema=all_feature_cols)
        X_val_scaled_features = pl.DataFrame(val_features_scaled, schema=all_feature_cols)

        meta_cols_to_keep = ['sequence_id', 'sequence_counter', 'gesture_int']
        train_df_final = train_df.select(meta_cols_to_keep).with_columns(X_train_scaled_features)
        val_df_final = val_df.select(meta_cols_to_keep).with_columns(X_val_scaled_features)

        # Use the explicitly ordered 'all_feature_cols' list
        X_train, y_train, train_gate_target = create_sequence_dataset(train_df_final, all_feature_cols, generate_gate_targets(train_df, tof_cols))
        # Validation split for this fold: sequences, class labels and gate targets.
        # NOTE(review): the gate targets are generated from `val_df` while the
        # features come from `val_df_final` -- confirm both frames describe the same
        # sequences in the same order.
        X_val, y_val, val_gate_target = create_sequence_dataset(val_df_final, all_feature_cols, generate_gate_targets(val_df, tof_cols))

        # Bring every sequence to MAX_PAD_LEN time steps: shorter sequences are padded
        # at the end ('post'), longer ones lose their earliest steps (truncating='pre').
        X_train_padded = pad_sequences(X_train, maxlen=MAX_PAD_LEN, padding='post', truncating='pre', dtype='float32')
        X_val_padded = pad_sequences(X_val, maxlen=MAX_PAD_LEN, padding='post', truncating='pre', dtype='float32')

        # One-hot encode the class labels (NUM_CLASSES columns).
        y_train_cat = to_categorical(y_train, num_classes=NUM_CLASSES)
        y_val_cat = to_categorical(y_val, num_classes=NUM_CLASSES)

        # Training batches come from the project's GatedMixupGenerator
        # (presumably mixup + random masking, judging by `alpha` / `masking_prob`;
        # see its definition for the exact augmentation).
        train_dataset = GatedMixupGenerator(
            X=X_train_padded, y=y_train_cat, gate_targets=train_gate_target,
            batch_size=BATCH_SIZE, imu_dim=imu_dim, alpha=0.2, masking_prob=0.25
        )
        # Validation data goes through a plain tf.data pipeline. Targets are keyed by
        # output name; the gate target gets a trailing axis of size 1.
        val_dataset = tf.data.Dataset.from_tensor_slices((
            X_val_padded, {'main_output': y_val_cat, 'tof_gate': val_gate_target[:, np.newaxis]}
        )).batch(BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)

        # Shape of the first *unpadded* training sequence, printed before the raw
        # arrays are dropped.
        print(f'Size of X_train: {X_train[0].shape}')
        # Drop the names that are no longer needed; `y_val_cat` is kept because the
        # evaluation below reads the true labels from it.
        del X_train, y_train, X_val, y_val, X_train_padded, X_val_padded
        gc.collect()

        # Fresh model for this fold.
        model = model_builder()
        # The callback is kept in a variable because its `apply_ema_weights()` method
        # is called once training has finished.
        ema_callback = EMACallback(decay=0.999)

        history = train_model(
            model,
            train_dataset,
            val_dataset,
            epochs=150,
            initial_learning_rate=LR_INIT,
            weight_decay=WD,
            extra_callbacks=[ema_callback]
        )

        # Switch to the EMA weights before evaluating.
        print("--- Training complete. Applying EMA weights for evaluation. ---")
        ema_callback.apply_ema_weights()

        # Best epoch (1-based) according to the validation accuracy logged during fit.
        # NOTE(review): this comes from the training history, whereas the evaluation
        # below runs after apply_ema_weights() -- confirm both refer to the same weights.
        monitor_metric = 'val_main_output_accuracy' if isinstance(model.output, dict) else 'val_accuracy'
        best_epoch = np.argmax(history.history[monitor_metric]) + 1
        best_epochs.append(best_epoch)
        print(f"--- Fold {fold_idx + 1} Best Epoch: {best_epoch} ---")

        # --- EVALUATION (runs after apply_ema_weights()) ---
        val_preds = model.predict(val_dataset)
        # Dict-output models return a dict keyed by output name. A bare array is used
        # as-is, mirroring the single-output case already handled for `monitor_metric`.
        main_output_preds = val_preds['main_output'] if isinstance(val_preds, dict) else val_preds

        y_pred_fold = np.argmax(main_output_preds, axis=1)
        y_true_fold = np.argmax(y_val_cat, axis=1)
        fold_acc = accuracy_score(y_true_fold, y_pred_fold)
        fold_accuracies.append(fold_acc)
        print(f"Fold {fold_idx + 1} Accuracy: {fold_acc:.4f}")
        all_preds.append(y_pred_fold)
        all_labels.append(y_true_fold)

        # Release per-fold resources. Deleting the Python names alone does not reset
        # the global state Keras accumulates when many models are built in a loop, so
        # clear the session explicitly before the next fold builds its model.
        del train_dataset, model, val_dataset
        tf.keras.backend.clear_session()
        gc.collect()

    # --- Out-of-fold (OOF) report for the architecture that just finished cross-validation ---
    mean_accuracy = np.mean(fold_accuracies)
    print(f"\n=== OOF Summary for {model_name} ===")
    print(f"Per-fold Accuracies: {[round(a, 4) for a in fold_accuracies]}")
    print(f"Mean Accuracy: {mean_accuracy:.4f} ± {np.std(fold_accuracies):.4f}")

    # Mean of the per-fold best epochs; int() drops the fractional part.
    avg_best_epoch = int(np.mean(best_epochs))
    print(f"Best epochs per fold: {best_epochs}")
    print(f"Average best epoch: {avg_best_epoch}")

    # Keep the headline numbers so the cross-model summary can print them later.
    model_results[model_name] = {
        'mean_accuracy': mean_accuracy,
        'avg_best_epoch': avg_best_epoch
    }

    # Stitch the per-fold predictions and labels into one OOF vector each.
    y_all_pred = np.concatenate(all_preds)
    y_all_true = np.concatenate(all_labels)
    print("\n=== Overall Classification Report ===")
    # Pass `labels` explicitly: without it, classification_report raises a ValueError
    # whenever a class is missing from both y_true and y_pred, because the number of
    # observed labels no longer matches len(target_names).
    # Assumes class index i corresponds to le.classes_[i], which the positional
    # `target_names` already relied on -- TODO confirm against the label encoding step.
    print(classification_report(
        y_all_true, y_all_pred,
        labels=np.arange(len(le.classes_)), target_names=le.classes_, digits=4
    ))

# --- Cross-model recap: one line per architecture stored in `model_results` ---
summary_rule = "=" * 60
# A single print with a newline separator emits the same three banner lines.
print("\n" + summary_rule, "▶ FINAL MODEL EXPERIMENT SUMMARY", summary_rule, sep="\n")
for model_name, results in model_results.items():
    print(f"  - {model_name}: Mean Accuracy = {results['mean_accuracy']:.4f}, Avg Best Epoch = {results['avg_best_epoch']}")

▶ Starting merge process...
  Loading and joining features from: imu_basic_physics_feats.parquet
  Loading and joining features from: tof_basic_kaggle_feats.parquet
  Loading and joining features from: tof_features_advanced_train_polars.parquet
  Loading and joining features from: imu_physics_feats.parquet
  Merge complete.
  Final merged DataFrame created with shape: (556380, 126)
  Training with 106 total features (36 IMU, 70 ToF/Thm).
  DataFrame columns have been reordered for the model.

▶ Training and Evaluating Model: hybrid_head_model

=== Fold 1/4 for hybrid_head_model ===


I0000 00:00:1757146843.041968 1934972 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 4714 MB memory:  -> device: 0, name: NVIDIA GeForce GTX 1060, pci bus id: 0000:01:00.0, compute capability: 6.1


Size of X_train: (57, 106)
LR Scheduler: 92 steps per epoch, 13800 total decay steps.


  self._warn_if_super_not_called()


EMA Callback: EMA weights initialized.
Epoch 1/150


I0000 00:00:1757146883.673650 1935193 cuda_dnn.cc:529] Loaded cuDNN version 90300
2025-09-06 09:21:27.602908: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:310] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.24GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2025-09-06 09:21:27.650462: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:310] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.09GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2025-09-06 09:21:27.674472: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:310] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.08GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if m

[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m97s[0m 598ms/step - loss: 3.9721 - main_output_accuracy: 0.1201 - main_output_loss: 2.9605 - tof_gate_loss: 0.4930 - val_loss: 3.4353 - val_main_output_accuracy: 0.2491 - val_main_output_loss: 2.4861 - val_tof_gate_loss: 0.3661
Epoch 2/150
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 519ms/step - loss: 3.3554 - main_output_accuracy: 0.2555 - main_output_loss: 2.4345 - tof_gate_loss: 0.3024 - val_loss: 3.0290 - val_main_output_accuracy: 0.3072 - val_main_output_loss: 2.1886 - val_tof_gate_loss: 0.1326
Epoch 3/150
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 541ms/step - loss: 3.1351 - main_output_accuracy: 0.3149 - main_output_loss: 2.2820 - tof_gate_loss: 0.2784 - val_loss: 2.6899 - val_main_output_accuracy: 0.4096 - val_main_output_loss: 1.9087 - val_tof_gate_loss: 0.1645
Epoch 4/150
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m51s[0m 551ms/step - loss: 2.9325 - main_output_a

  self._warn_if_super_not_called()


EMA Callback: EMA weights initialized.
Epoch 1/150
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m107s[0m 652ms/step - loss: 4.0142 - main_output_accuracy: 0.1143 - main_output_loss: 3.0030 - tof_gate_loss: 0.4997 - val_loss: 3.4697 - val_main_output_accuracy: 0.2099 - val_main_output_loss: 2.5314 - val_tof_gate_loss: 0.3412
Epoch 2/150
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 598ms/step - loss: 3.3366 - main_output_accuracy: 0.2672 - main_output_loss: 2.4129 - tof_gate_loss: 0.3419 - val_loss: 3.1589 - val_main_output_accuracy: 0.2919 - val_main_output_loss: 2.3254 - val_tof_gate_loss: 0.1429
Epoch 3/150
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 589ms/step - loss: 3.0804 - main_output_accuracy: 0.3345 - main_output_loss: 2.2345 - tof_gate_loss: 0.2834 - val_loss: 2.5665 - val_main_output_accuracy: 0.4513 - val_main_output_loss: 1.8084 - val_tof_gate_loss: 0.0918
Epoch 4/150
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m 

  self._warn_if_super_not_called()


EMA Callback: EMA weights initialized.
Epoch 1/150
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m102s[0m 635ms/step - loss: 4.0036 - main_output_accuracy: 0.1155 - main_output_loss: 2.9951 - tof_gate_loss: 0.4785 - val_loss: 3.4079 - val_main_output_accuracy: 0.2620 - val_main_output_loss: 2.4583 - val_tof_gate_loss: 0.3660
Epoch 2/150
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 592ms/step - loss: 3.3718 - main_output_accuracy: 0.2483 - main_output_loss: 2.4479 - tof_gate_loss: 0.3145 - val_loss: 2.8090 - val_main_output_accuracy: 0.4179 - val_main_output_loss: 1.9547 - val_tof_gate_loss: 0.1954
Epoch 3/150
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m57s[0m 622ms/step - loss: 3.1182 - main_output_accuracy: 0.3083 - main_output_loss: 2.2703 - tof_gate_loss: 0.2460 - val_loss: 2.5919 - val_main_output_accuracy: 0.4567 - val_main_output_loss: 1.8143 - val_tof_gate_loss: 0.1225
Epoch 4/150
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m 

In [None]:
# import traceback
# # =====================================================================================
# # ARCHITECTURE SANITY CHECK
# # =====================================================================================

# # --- Step 1: Get a sample batch and define shapes ---
# # (This part of your code is correct)
# # Make sure your train_dataset is created before this block
# try:
#     sample_batch = next(iter(train_dataset))
#     sample_input = sample_batch[0]
#     input_shape = sample_input.shape[1:]
#     imu_dim = len(imu_cols) # Assuming imu_cols is defined
#     print(f"Sample input shape for testing: {sample_input.shape}\n")
# except Exception as e:
#     print(f"Could not get a sample from the dataset. Error: {e}")
#     # Exit if we can't get a sample to test with
#     exit()

# # --- Step 2: Create a list of all model-building functions ---
# # (This part of your code is correct)
# # model_builders = [
# #     ("CNN_Baseline", lambda: create_model_1_cnn_baseline(input_shape)),
# #     ("GRU_Baseline", lambda: create_model_2_gru_baseline(input_shape)),
# #     ("CNN_RNN_Hybrid", lambda: create_model_3_cnn_rnn_hybrid(input_shape)),
# #     ("WaveNet_Style", lambda: create_model_4_wavenet_style(input_shape)),
# #     ("UNet_Style", lambda: create_model_5_unet_style(input_shape)),
# #     ("Transformer", lambda: create_model_6_transformer(input_shape)),
# #     ("CNN_Transformer", lambda: create_model_7_cnn_transformer(input_shape)),
# #     # For your two-branch model, you'll need the full IMU+ToF dataset
# #     # ("Two_Branch", lambda: create_model_8_two_branch(input_shape, imu_dim)),
# # ]

# model_builders = [
#     ("wave_net", lambda: create_wave_net(input_shape, imu_dim)),
#     ("unet_wave", lambda: create_advanced_model_3_unet_wave(input_shape, imu_dim)),
#     ("Advanced_Dual_UNet", lambda: create_advanced_model_A_dual_unet(input_shape, imu_dim)),
#     ("Hyper UNet", lambda: create_advanced_model_B_hyper_unet(input_shape, imu_dim)),
# ]

# # --- Step 3: Loop through the models, build them, and test with the sample ---
# print("--- Testing all model architectures with a sample batch ---")
# for model_name, model_builder in model_builders:
#     print("\n" + "="*60)
#     print(f"▶ Testing Model: {model_name}")
#     print("="*60)
    
#     try:
#         # 1. Build the model using the builder function
#         model = model_builder()
        
#         # Optional: Print the model summary to check its structure
#         print(f"Model Summary for {model_name}:")
#         model.summary()
        
#         # 2. Pass the sample input through the model
#         print(f"\nPerforming forward pass for {model_name}...")
#         output = model(sample_input)
        
#         # 3. Print the output shape to verify it's correct
#         print(f"✅ SUCCESS: Model '{model_name}' ran successfully.")
#         # For multi-output models, output might be a list/dict. For single, it's a tensor.
#         if isinstance(output, dict):
#             for key, value in output.items():
#                 print(f"   Output '{key}' shape: {value.shape}")
#         elif isinstance(output, list):
#             for i, value in enumerate(output):
#                 print(f"   Output {i} shape: {value.shape}")
#         else:
#             print(f"   Output shape: {output.shape}")

#     except Exception as e:
#         print(f"❌ ERROR: Model '{model_name}' failed to build or run.")
#         traceback.print_exc() # This will print the full error traceback
        
#     # Clean up the created model to save memory
#     del model
#     gc.collect()

# print("\n--- Model architecture testing complete. ---")