In [1]:
import gc
import numpy as np
import pandas as pd
import polars as pl
import tensorflow as tf
from pathlib import Path
from tensorflow import shape, minimum
from tensorflow.keras import backend as k
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import pad_sequences, Sequence, to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import (
    Dense, Input, Conv1D, MaxPooling1D, GlobalAveragePooling1D, GlobalMaxPooling1D, Concatenate,
    BatchNormalization, GRU, Dropout, add, Activation, Multiply, Reshape,
    LayerNormalization, Add, Bidirectional, LSTM, UpSampling1D, Lambda, GaussianNoise, MultiHeadAttention
)
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import classification_report, accuracy_score

from src.nn_blocks import tof_block, residual_se_cnn_block, TransformerBlock, tof_block_2, features_processing, unet_se_cnn

NUM_CLASSES = 18


# --- Gated Model 1: Based on CNN-RNN Hybrid ---
def create_gated_cnn_rnn(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # IMU branch
    x1 = residual_se_cnn_block(imu, 64, 3, drop=0.2, wd=wd) # Output: (None, 64, 64)
    x1 = residual_se_cnn_block(x1, 128, 5, drop=0.2, wd=wd) # Output: (None, 32, 128)
    x1 = Bidirectional(LSTM(128, return_sequences=True, kernel_regularizer=l2(wd)))(x1) # Output: (None, 32, 256)
    
    # Standard ToF branch
    x2 = tof_block_2(tof, wd) # Output: (None, 32, 128)

    # --- FIX: Project x2 to match x1's feature dimension before processing ---
    x2_projected = Dense(256, activation='relu')(x2)

    # Now both inputs to features_processing have shape (None, 32, 256)
    x = features_processing(x1, x2_projected)
    x = Dropout(0.3)(x) 
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

# --- Gated Model 2: Based on UNet_Style ---
def create_gated_unet(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # IMU branch
    x1 = unet_se_cnn(imu, unet_depth=4, base_filters=64, kernel_size=5, drop=0.3) # Output: (None, 128, 64)
    
    # Standard ToF branch
    x2 = tof_block_2(tof, wd) # Output: (None, 32, 128)

    # --- FIX: This architecture is fundamentally incompatible. We will use a simpler merge. ---
    # The U-Net doesn't downsample time, making it hard to merge.
    # We will use a simpler approach for this model.
    x1_pooled = GlobalAveragePooling1D()(x1)
    x2_pooled = GlobalAveragePooling1D()(x2)
    x = Concatenate()([x1_pooled, x2_pooled])
    
    x = Dropout(0.4)(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.3)(x) 
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

# --- Gated Model 3: Based on CNN_Transformer ---
def create_gated_cnn_transformer(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # IMU branch
    x1 = residual_se_cnn_block(imu, 64, 3, drop=0.2, wd=wd) # Output: (None, 64, 64)
    x1 = residual_se_cnn_block(x1, 128, 5, drop=0.2, wd=wd) # Output: (None, 32, 128)
    x1 = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=128, rate=0.3)(x1) # Output: (None, 32, 128)
    x1 = residual_se_cnn_block(x1, 64, 3, drop=0.2, wd=wd) # Output: (None, 16, 64)
    x1 = residual_se_cnn_block(x1, 128, 5, drop=0.2, wd=wd) # Output: (None, 8, 128)    
    x1 = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=128, rate=0.3)(x1) # Output: (None, 8, 128)
    
    # Standard ToF branch
    x2 = tof_block_2(tof, wd) # Output: (None, 32, 128)
    x2 = tf.keras.layers.MaxPooling1D(4)(x2)

    x = features_processing(x1, x2)
    x = Dropout(0.3)(x) 
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

def best_unet_1(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    x1 = unet_se_cnn(imu, 3, base_filters=128, kernel_size=3)
    x2 = tof_block(tof, wd)

    x = features_processing(x1, x2)
    x = tf.keras.layers.Dropout(0.3)(x) 
    main_out = tf.keras.layers.Dense(18, activation="softmax", name="main_output")(x)
    gate_out = tf.keras.layers.Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return tf.keras.models.Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

def best_unet_2(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    x1 = unet_se_cnn(imu, 3, base_filters=128, kernel_size=3)
    x2 = tof_block_2(tof, wd)

    x = features_processing(x1, x2)
    x = tf.keras.layers.Dropout(0.3)(x) 
    main_out = tf.keras.layers.Dense(18, activation="softmax", name="main_output")(x)
    gate_out = tf.keras.layers.Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return tf.keras.models.Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

2025-08-15 20:46:06.734002: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755287166.776980 2128177 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755287166.789545 2128177 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1755287166.838058 2128177 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755287166.838085 2128177 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755287166.838087 2128177 computation_placer.cc:177] computation placer alr

In [3]:
import polars as pl

from pathlib import Path
from src.merge_feats_dynamic import merge_feature_sets
from src.functions import create_sequence_dataset, generate_gate_targets, train_model
from src.nn_blocks import GatedMixupGenerator

# =====================================================================================
# TRAINING LOGIC
# =====================================================================================

FEATURE_DIR = Path('output')
RAW_DIR = Path('input/cmi-detect-behavior-with-sensor-data')
RANDOM_STATE = 42

files_to_merge = [
    'imu_basic_physics_feats.parquet',
    'tof_basic_kaggle_feats.parquet'
]
feature_paths = [FEATURE_DIR / f for f in files_to_merge]
base_df = pl.read_parquet(FEATURE_DIR / "cleaned_base_train_data.parquet")
demographics_df = pl.read_csv(RAW_DIR / "train_demographics.csv")
base_df = base_df.join(demographics_df, on='subject', how='left')

meta_cols = ['sequence_id', 'sequence_counter', 'subject', 'gesture']
thm_cols = ['thm_1', 'thm_2', 'thm_3', 'thm_4', 'thm_5']
base_df = base_df.select(meta_cols + thm_cols)

final_df = merge_feature_sets(base_df, feature_paths)
print(f"  Final merged DataFrame created with shape: {final_df.shape}")
final_df.columns

 Starting merge process...
  Loading and joining features from: imu_basic_physics_feats.parquet
  Loading and joining features from: tof_basic_kaggle_feats.parquet
  Merge complete.
  Final merged DataFrame created with shape: (574945, 45)


['sequence_id',
 'sequence_counter',
 'subject',
 'gesture',
 'thm_1',
 'thm_2',
 'thm_3',
 'thm_4',
 'thm_5',
 'acc_x',
 'acc_y',
 'acc_z',
 'rot_w',
 'rot_x',
 'rot_y',
 'rot_z',
 'linear_acc_x',
 'linear_acc_y',
 'linear_acc_z',
 'linear_acc_mag',
 'linear_acc_mag_jerk',
 'angular_vel_x',
 'angular_vel_y',
 'angular_vel_z',
 'angular_distance',
 'tof_1_mean',
 'tof_1_std',
 'tof_1_min',
 'tof_1_max',
 'tof_2_mean',
 'tof_2_std',
 'tof_2_min',
 'tof_2_max',
 'tof_3_mean',
 'tof_3_std',
 'tof_3_min',
 'tof_3_max',
 'tof_4_mean',
 'tof_4_std',
 'tof_4_min',
 'tof_4_max',
 'tof_5_mean',
 'tof_5_std',
 'tof_5_min',
 'tof_5_max']

In [4]:
df = pl.read_parquet('output/kaggle_0.8_feats.parquet')
df.columns

['sequence_id',
 'subject',
 'gesture',
 'gesture_int',
 'linear_acc_x',
 'linear_acc_y',
 'linear_acc_z',
 'rot_w',
 'rot_x',
 'rot_y',
 'rot_z',
 'linear_acc_mag',
 'linear_acc_mag_jerk',
 'angular_vel_x',
 'angular_vel_y',
 'angular_vel_z',
 'angular_distance',
 'thm_1',
 'thm_2',
 'thm_3',
 'thm_4',
 'thm_5',
 'tof_1_mean',
 'tof_1_std',
 'tof_1_min',
 'tof_1_max',
 'tof_2_mean',
 'tof_2_std',
 'tof_2_min',
 'tof_2_max',
 'tof_3_mean',
 'tof_3_std',
 'tof_3_min',
 'tof_3_max',
 'tof_4_mean',
 'tof_4_std',
 'tof_4_min',
 'tof_4_max',
 'tof_5_mean',
 'tof_5_std',
 'tof_5_min',
 'tof_5_max']

In [2]:
# =====================================================================================
# 5 NEW ADVANCED MODEL ARCHITECTURES
# =====================================================================================

from src.nn_blocks import match_time_steps, wave_block, res_se_cnn_decoder_block

# --- Advanced Model 2: Stacked Transformer Tower ---
# Hypothesis: Stacking multiple Transformer layers after a strong CNN backbone will
# allow the model to learn very complex global relationships within the sequence.
def create_advanced_model_2_transformer_tower(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # Strong CNN backbone to create rich features for the Transformer
    x1 = residual_se_cnn_block(imu, 64, 3, drop=0.2, wd=wd)
    x1 = residual_se_cnn_block(x1, 128, 5, drop=0.2, wd=wd) # Output shape: (None, 32, 128)
    
    # Stacked Transformer Tower
    # Each block attends to the output of the previous one, building deeper context.
    x1 = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=256, rate=0.3)(x1)
    x1 = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=256, rate=0.3)(x1)
    x1 = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=256, rate=0.3)(x1)
    x1 = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=256, rate=0.3)(x1)
    
    # Standard ToF branch
    x2 = tof_block_2(tof, wd) # Output shape: (None, 32, 128)

    # Merge and classify
    x = features_processing(x1, x2)
    x = Dropout(0.3)(x) 
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

# --- Advanced Model 3: Hybrid UNet + WaveNet ---
def create_advanced_model_3_unet_wave(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # U-Net branch
    x1_unet = unet_se_cnn(imu, unet_depth=3, base_filters=64, kernel_size=5)
    
    # Parallel WaveNet branch
    x1_wave = wave_block(imu, 64, 3, n=5, dropout_rate=0.3) # n=5 -> dilations up to 16
    
    # Match time steps and concatenate
    x1_unet_matched, x1_wave_matched = match_time_steps(x1_unet, x1_wave)
    x1 = Concatenate()([x1_unet_matched, x1_wave_matched])
    
    # Standard ToF branch
    x2 = tof_block_2(tof, wd)

    # Merge and classify
    x = features_processing(x1, x2)
    x = Dropout(0.3)(x) 
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

# --- Advanced Model 4: Triple Stacked Block Design ---
# Hypothesis: Repeatedly applying a powerful hybrid block (CNN + RNN) will
# progressively refine the feature representation.
def cnn_gru_block(x, filters, kernel_size, wd=1e-4):
    # A self-contained block combining CNN and GRU
    x_cnn = residual_se_cnn_block(x, filters, kernel_size, wd=wd)
    x_gru = Bidirectional(GRU(filters // 2, return_sequences=True))(x_cnn)
    return x_gru

def cnn_gru_block(x, filters, kernel_size, wd=1e-4):
    """
    A simplified and robust block that first applies a CNN, then a GRU.
    """
    # 1. CNN part for feature extraction and downsampling
    x = residual_se_cnn_block(x, filters, kernel_size, wd=wd)
    
    # 2. GRU part for sequence processing
    x = Bidirectional(GRU(filters, return_sequences=True, kernel_regularizer=l2(wd)))(x)
    
    return x

def create_advanced_model_4_stacked_blocks(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # Apply the hybrid block three times
    x1 = cnn_gru_block(imu, 64, 3)  # Output: (None, 64, 128)
    x1 = cnn_gru_block(x1, 128, 5) # Output: (None, 32, 256)
    
    # The final block will not return sequences to simplify the final merge
    x1 = Bidirectional(GRU(128, return_sequences=False))(x1) # Output: (None, 256)
    
    # Standard ToF branch, but we need to aggregate it to match x1
    x2 = tof_block_2(tof, wd) # Output: (None, 32, 128)
    x2 = GlobalAveragePooling1D()(x2) # Output: (None, 128)

    # Merge the two aggregated feature vectors
    x = Concatenate()([x1, x2]) # Output: (None, 256 + 128) = (None, 384)
    
    # Final classifier MLP
    x = Dropout(0.4)(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.3)(x) 
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

# --- Advanced Model 5: UNet with BiLSTM Bottleneck ---
# Hypothesis: Adding sequential processing (BiLSTM) at the point of maximum feature
# compression in the U-Net (the bottleneck) will improve context understanding.
def unet_se_cnn_bilstm(x, unet_depth=3, base_filters=64, kernel_size=3, drop=0.3):
    filters = base_filters
    skips = []
    for _ in range(unet_depth):
        x = residual_se_cnn_block(x, filters, kernel_size, drop=drop)
        skips.append(x)
        filters *= 2
    
    # --- BiLSTM Bottleneck ---
    # Process the most compressed representation sequentially
    x = Bidirectional(LSTM(filters // 2, return_sequences=True))(x)
    
    for skip in reversed(skips):
        filters //= 2
        x = res_se_cnn_decoder_block(x, filters, kernel_size, drop=drop, skip_connection=skip)
    return x

def create_advanced_model_1_deep_unet(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # --- IMU Branches ---
    x1_unet = unet_se_cnn(imu, unet_depth=4, base_filters=128, kernel_size=5, drop=0.3)
    x1_conv_k3 = residual_se_cnn_block(imu, 64, 3)
    x1_conv_k7 = residual_se_cnn_block(imu, 64, 7)
    
    # --- FIX: Aggregate each branch BEFORE merging ---
    # This creates a fixed-size vector from each branch, avoiding shape conflicts.
    p1 = GlobalAveragePooling1D()(x1_unet)
    p2 = GlobalAveragePooling1D()(x1_conv_k3)
    p3 = GlobalAveragePooling1D()(x1_conv_k7)
    
    # --- ToF Branch ---
    x2 = tof_block_2(tof, wd)
    p4 = GlobalAveragePooling1D()(x2)

    # Concatenate the aggregated feature vectors
    x = Concatenate()([p1, p2, p3, p4])
    
    # --- Final Classifier MLP ---
    x = Dropout(0.4)(x)
    x = Dense(512, activation='relu')(x)
    x = Dropout(0.3)(x)
    x = Dense(256, activation='relu')(x)
    
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

def create_advanced_model_5_unet_bilstm(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # Use the UNet with the BiLSTM bottleneck
    x1 = unet_se_cnn_bilstm(imu, unet_depth=3, base_filters=128, kernel_size=3)
    
    # Standard ToF branch
    x2 = tof_block_2(tof, wd)

    # --- FIX: Use the robust aggregation strategy instead of features_processing ---
    x1_pooled = GlobalAveragePooling1D()(x1)
    x2_pooled = GlobalAveragePooling1D()(x2)
    x = Concatenate()([x1_pooled, x2_pooled])

    # --- Final Classifier MLP ---
    x = Dropout(0.4)(x)
    x = Dense(512, activation='relu')(x)
    x = Dropout(0.3)(x)
    x = Dense(256, activation='relu')(x)
    
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

In [None]:
# =====================================================================================
# 3 NEW ADVANCED MODEL ARCHITECTURES
# =====================================================================================
from src.nn_blocks import attention_layer

def create_advanced_model_A_dual_unet(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # Branch 1: A deep U-Net for the IMU data
    x1_raw = unet_se_cnn(imu, unet_depth=4, base_filters=128, kernel_size=5, drop=0.3)
    
    # Branch 2: A parallel, slightly lighter U-Net for the ToF/Thermal data
    x2_raw = unet_se_cnn(tof, unet_depth=3, base_filters=64, kernel_size=5, drop=0.3)

    # --- FIX: Project both branches to a common feature dimension (e.g., 128) ---
    # This ensures the input to features_processing is consistent.
    x1 = Conv1D(128, 1, padding='same', activation='relu', name='imu_projection')(x1_raw)
    x2 = Conv1D(128, 1, padding='same', activation='relu', name='tof_projection')(x2_raw)
    
    # Now both x1 and x2 have shape (None, 128, 128)
    # They can be safely passed to the features_processing block.
    x = features_processing(x1, x2)
    x = Dropout(0.3)(x) 
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

# --- Advanced Model B: Cross-Attention Fusion ---
# Hypothesis: Instead of just concatenating the IMU and ToF branches, we can create
# richer features by allowing them to "talk to each other." The IMU branch will learn
# what to pay attention to in the ToF data, and vice-versa.
def create_advanced_model_B_cross_attention(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # 1. Create strong, downsampled feature representations for both branches
    # Output Shape for both: (None, 32, 128)
    x1 = residual_se_cnn_block(imu, 64, 3, drop=0.2, wd=wd)
    x1 = residual_se_cnn_block(x1, 128, 5, drop=0.2, wd=wd)
    
    x2 = tof_block_2(tof, wd)

    # 2. Cross-Attention Fusion
    # The IMU branch queries the ToF branch for relevant context
    imu_attends_tof = tf.keras.layers.Attention()([x1, x2])
    # The ToF branch queries the IMU branch for relevant context
    tof_attends_imu = tf.keras.layers.Attention()([x2, x1])
    
    # 3. Create an enriched representation by concatenating all perspectives
    # The final tensor contains the original features plus the context-aware features.
    # Shape: (None, 32, 128 + 128 + 128 + 128) = (None, 32, 512)
    x = Concatenate()([x1, imu_attends_tof, x2, tof_attends_imu])
    
    # 4. Final Processing
    # We use a powerful sequence processor on this ultra-rich tensor
    x = Bidirectional(GRU(256, return_sequences=True, kernel_regularizer=l2(wd)))(x)
    x = attention_layer(x)
    
    x = Dropout(0.4)(x)
    x = Dense(256, activation='relu')(x)
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

# --- Advanced Model C: Stacked Hybrid Blocks ---
# Hypothesis: A single block of (CNN -> RNN) is good. Repeatedly stacking this
# hybrid block will allow the model to learn progressively more abstract and

# powerful spatio-temporal features.
def cnn_lstm_block(x, filters, kernel_size, drop=0.2, wd=1e-4):
    # A self-contained, reusable block
    x = residual_se_cnn_block(x, filters, kernel_size, drop=drop, wd=wd)
    x = Bidirectional(LSTM(filters, return_sequences=True, kernel_regularizer=l2(wd)))(x)
    return x

def create_advanced_model_C_stacked_hybrid(input_shape, imu_dim, wd=1e-4):
    inp = tf.keras.layers.Input(shape=input_shape)
    imu = tf.keras.layers.Lambda(lambda t: t[:, :, :imu_dim])(inp)
    tof = tf.keras.layers.Lambda(lambda t: t[:, :, imu_dim:])(inp)

    # --- IMU Branch: Stacked Hybrid Blocks ---
    # Each block refines the output of the previous one
    # Input: (128, D) -> Block1: (64, 128) -> Block2: (32, 256)
    x1 = cnn_lstm_block(imu, 64, 3)
    x1 = cnn_lstm_block(x1, 128, 5)
    
    # --- ToF Branch ---
    # Output: (32, 128)
    x2 = tof_block_2(tof, wd)
    # Project ToF features to match the final IMU feature dimension (256)
    x2_projected = Dense(256, activation='relu')(x2)

    # Now both inputs have shape (None, 32, 256) and can be processed
    x = features_processing(x1, x2_projected)
    x = Dropout(0.3)(x) 
    main_out = Dense(NUM_CLASSES, activation="softmax", name="main_output")(x)
    gate_out = Dense(1, activation="sigmoid", name="tof_gate")(x)
    
    return Model(inputs=inp, outputs={"main_output": main_out, "tof_gate": gate_out})

In [5]:
# =====================================================================================
# CONFIGURATION
# =====================================================================================
LR_INIT = 5e-4
WD = 3e-3
NUM_CLASSES = 18
BATCH_SIZE = 64
N_SPLITS = 4 
MAX_PAD_LEN = 128


from src.merge_feats_dynamic import merge_feature_sets
from src.functions import create_sequence_dataset, generate_gate_targets, train_model
from src.nn_blocks import GatedMixupGenerator

# =====================================================================================
# TRAINING LOGIC
# =====================================================================================

FEATURE_DIR = Path('output')
RAW_DIR = Path('input/cmi-detect-behavior-with-sensor-data')
RANDOM_STATE = 42

files_to_merge = [
    'imu_basic_physics_feats.parquet',
    'tof_basic_kaggle_feats.parquet'
]
feature_paths = [FEATURE_DIR / f for f in files_to_merge]
base_df = pl.read_parquet(FEATURE_DIR / "cleaned_base_train_data.parquet")
demographics_df = pl.read_csv(RAW_DIR / "train_demographics.csv")
base_df = base_df.join(demographics_df, on='subject', how='left')

meta_cols = ['sequence_id', 'sequence_counter', 'subject', 'gesture']
thm_cols = ['thm_1', 'thm_2', 'thm_3', 'thm_4', 'thm_5']
base_df = base_df.select(meta_cols + thm_cols)

le = LabelEncoder()
gesture_encoded = le.fit_transform(base_df.get_column('gesture'))
base_df = base_df.with_columns(pl.Series("gesture_int", gesture_encoded))  

final_df = merge_feature_sets(base_df, feature_paths)
print(f"  Final merged DataFrame created with shape: {final_df.shape}")

# --- Step 2: Define FINAL Feature Columns ---
all_final_columns = final_df.columns
final_meta_cols = {'gesture', 'gesture_int', 'subject', 'sequence_id', 'sequence_counter'}
demographic_cols = {'adult_child', 'age', 'sex', 'handedness', 'height_cm', 'shoulder_to_wrist_cm', 'elbow_to_wrist_cm'}
all_feature_cols = [c for c in all_final_columns if c not in final_meta_cols and c not in demographic_cols]

# Define IMU and ToF columns from the final feature list
imu_cols = [c for c in all_feature_cols if c.startswith(('acc_', 'rot_', 'linear_', 'angular_'))]
tof_cols = [c for c in all_feature_cols if c.startswith(('tof_', 'thm_'))]
imu_dim = len(imu_cols)
print(f"  Training with {len(all_feature_cols)} total features ({imu_dim} IMU, {len(tof_cols)} ToF/Thm).")    

# --- Step 3: Prepare for Cross-Validation ---
cv_info = final_df.group_by("sequence_id").agg(pl.first("gesture_int")).sort("sequence_id")
all_sequence_ids = cv_info.get_column("sequence_id").to_numpy()
y_for_split = cv_info.get_column("gesture_int").to_numpy()

input_shape = (MAX_PAD_LEN, len(all_feature_cols)) 

# --- Step 4: Define the Model Architectures to Test ---
# model_builders = [
#     # ("Advanced_Transformer_Tower", lambda: create_advanced_model_2_transformer_tower(input_shape, imu_dim)),
#     # ("Advanced_UNet_Wave", lambda: create_advanced_model_3_unet_wave(input_shape, imu_dim)),
#     # ("Advanced_Stacked_Blocks", lambda: create_advanced_model_4_stacked_blocks(input_shape, imu_dim)),
#     # ("Advanced_Deep_UNet", lambda: create_advanced_model_1_deep_unet(input_shape, imu_dim)),
#     # ("Advanced_UNet_BiLSTM", lambda: create_advanced_model_5_unet_bilstm(input_shape, imu_dim)),
#     ("gated_cnn_transformer_2_blocks", lambda: create_gated_cnn_transformer(input_shape, imu_dim))
    
# ]

model_builders = [
    # ("Advanced_Dual_UNet", lambda: create_advanced_model_A_dual_unet(input_shape, imu_dim)),
    ("Advanced_Cross_Attention", lambda: create_advanced_model_B_cross_attention(input_shape, imu_dim)),
    ("Advanced_Stacked_Hybrid", lambda: create_advanced_model_C_stacked_hybrid(input_shape, imu_dim)),
]

# --- Step 5: Loop Through Each Model Architecture ---
for model_name, model_builder in model_builders:
    print("\n" + "="*60)
    print(f"▶ Training and Evaluating Model: {model_name}")
    print("="*60)

    kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)
    fold_accuracies = []
    all_preds = []
    all_labels = []

    for fold_idx, (train_indices, val_indices) in enumerate(kf.split(all_sequence_ids, y_for_split)):
        print(f"\n=== Fold {fold_idx + 1}/{N_SPLITS} for {model_name} ===")
        train_ids = all_sequence_ids[train_indices]
        val_ids = all_sequence_ids[val_indices]

        train_df = final_df.filter(pl.col('sequence_id').is_in(train_ids))
        val_df = final_df.filter(pl.col('sequence_id').is_in(val_ids))
        
        scaler = StandardScaler()
        train_features_scaled = scaler.fit_transform(train_df[all_feature_cols])
        val_features_scaled = scaler.transform(val_df[all_feature_cols])
        X_train_scaled_features = pl.DataFrame(train_features_scaled, schema=all_feature_cols)
        X_val_scaled_features = pl.DataFrame(val_features_scaled, schema=all_feature_cols)

        meta_cols_to_keep = ['sequence_id', 'sequence_counter', 'gesture_int']
        train_df_final = train_df.select(meta_cols_to_keep).with_columns(X_train_scaled_features)
        val_df_final = val_df.select(meta_cols_to_keep).with_columns(X_val_scaled_features)

        # Use the full create_sequence_dataset for gate targets
        X_train, y_train, train_gate_target = create_sequence_dataset(train_df_final, all_feature_cols, generate_gate_targets(train_df, tof_cols))
        X_val, y_val, val_gate_target = create_sequence_dataset(val_df_final, all_feature_cols, generate_gate_targets(val_df, tof_cols))

        X_train_padded = pad_sequences(X_train, maxlen=MAX_PAD_LEN, padding='post', truncating='post', dtype='float32')
        X_val_padded = pad_sequences(X_val, maxlen=MAX_PAD_LEN, padding='post', truncating='post', dtype='float32')
        
        y_train_cat = to_categorical(y_train, num_classes=NUM_CLASSES)
        y_val_cat = to_categorical(y_val, num_classes=NUM_CLASSES)
        
        # Use the GatedMixupGenerator for the two-branch models
        train_dataset = GatedMixupGenerator(
            X=X_train_padded, y=y_train_cat, gate_targets=train_gate_target,
            batch_size=BATCH_SIZE, imu_dim=imu_dim, alpha=0.2, masking_prob=0.25
        )
        val_dataset = tf.data.Dataset.from_tensor_slices((
            X_val_padded, {'main_output': y_val_cat, 'tof_gate': val_gate_target[:, np.newaxis]}
        )).batch(BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)

        del X_train, y_train, X_val, y_val, X_train_padded, X_val_padded
        gc.collect()
        
        model = model_builder()
        
        # Use the original train_model function for multi-output models
        history = train_model(model, train_dataset, val_dataset, 150, LR_INIT, WD)
        
        val_preds = model.predict(val_dataset)
        main_output_preds = val_preds['main_output']
        
        y_pred_fold = np.argmax(main_output_preds, axis=1)
        y_true_fold = np.argmax(y_val_cat, axis=1)
        fold_acc = accuracy_score(y_true_fold, y_pred_fold)
        fold_accuracies.append(fold_acc)
        print(f"Fold {fold_idx + 1} Accuracy: {fold_acc:.4f}")
        all_preds.append(y_pred_fold)
        all_labels.append(y_true_fold)

        del train_dataset, model, val_dataset
        gc.collect()

    # --- FINAL OOF REPORT for this model architecture ---
    print(f"\n=== OOF Summary for {model_name} ===")
    print(f"Per-fold Accuracies: {[round(a, 4) for a in fold_accuracies]}")
    print(f"Mean Accuracy: {np.mean(fold_accuracies):.4f} ± {np.std(fold_accuracies):.4f}")
    y_all_pred = np.concatenate(all_preds)
    y_all_true = np.concatenate(all_labels)
    print("\n=== Overall Classification Report ===")
    print(classification_report(y_all_true, y_all_pred, target_names=le.classes_, digits=4))

 Starting merge process...
  Loading and joining features from: imu_basic_physics_feats.parquet
  Loading and joining features from: tof_basic_kaggle_feats.parquet
  Merge complete.
  Final merged DataFrame created with shape: (574945, 46)
  Training with 41 total features (16 IMU, 25 ToF/Thm).

▶ Training and Evaluating Model: Advanced_Cross_Attention

=== Fold 1/4 for Advanced_Cross_Attention ===
LR Scheduler: 96 steps per epoch, 14400 total decay steps.
Epoch 1/150


  self._warn_if_super_not_called()
I0000 00:00:1755287275.909473 2128368 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 120ms/step - loss: 3.0076 - main_output_accuracy: 0.1456 - main_output_loss: 2.7659 - tof_gate_loss: 0.3349 - val_loss: 2.3775 - val_main_output_accuracy: 0.3445 - val_main_output_loss: 2.1937 - val_tof_gate_loss: 0.0718
Epoch 2/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 91ms/step - loss: 2.5718 - main_output_accuracy: 0.2896 - main_output_loss: 2.3603 - tof_gate_loss: 0.2244 - val_loss: 2.0886 - val_main_output_accuracy: 0.4411 - val_main_output_loss: 1.9117 - val_tof_gate_loss: 0.0665
Epoch 3/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 94ms/step - loss: 2.2848 - main_output_accuracy: 0.3811 - main_output_loss: 2.0969 - tof_gate_loss: 0.1339 - val_loss: 1.9252 - val_main_output_accuracy: 0.4863 - val_main_output_loss: 1.7620 - val_tof_gate_loss: 0.0311
Epoch 4/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 92ms/step - loss: 2.1501 - main_output_accurac

  self._warn_if_super_not_called()


[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 114ms/step - loss: 3.0609 - main_output_accuracy: 0.1277 - main_output_loss: 2.8101 - tof_gate_loss: 0.3772 - val_loss: 2.3897 - val_main_output_accuracy: 0.3204 - val_main_output_loss: 2.2016 - val_tof_gate_loss: 0.0951
Epoch 2/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 95ms/step - loss: 2.5160 - main_output_accuracy: 0.2808 - main_output_loss: 2.3155 - tof_gate_loss: 0.1600 - val_loss: 2.1634 - val_main_output_accuracy: 0.3891 - val_main_output_loss: 1.9916 - val_tof_gate_loss: 0.0481
Epoch 3/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 99ms/step - loss: 2.3171 - main_output_accuracy: 0.3806 - main_output_loss: 2.1289 - tof_gate_loss: 0.1525 - val_loss: 1.9754 - val_main_output_accuracy: 0.4760 - val_main_output_loss: 1.8120 - val_tof_gate_loss: 0.0339
Epoch 4/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 104ms/step - loss: 2.2175 - main_output_accu

  self._warn_if_super_not_called()


[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 113ms/step - loss: 3.0185 - main_output_accuracy: 0.1495 - main_output_loss: 2.7817 - tof_gate_loss: 0.3118 - val_loss: 2.3747 - val_main_output_accuracy: 0.3356 - val_main_output_loss: 2.1810 - val_tof_gate_loss: 0.1232
Epoch 2/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 100ms/step - loss: 2.5305 - main_output_accuracy: 0.2912 - main_output_loss: 2.3278 - tof_gate_loss: 0.1772 - val_loss: 2.1427 - val_main_output_accuracy: 0.4107 - val_main_output_loss: 1.9520 - val_tof_gate_loss: 0.1423
Epoch 3/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 96ms/step - loss: 2.3216 - main_output_accuracy: 0.3658 - main_output_loss: 2.1333 - tof_gate_loss: 0.1387 - val_loss: 1.9751 - val_main_output_accuracy: 0.4539 - val_main_output_loss: 1.8091 - val_tof_gate_loss: 0.0468
Epoch 4/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 99ms/step - loss: 2.1964 - main_output_accu

  self._warn_if_super_not_called()


[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 120ms/step - loss: 3.0073 - main_output_accuracy: 0.1460 - main_output_loss: 2.7696 - tof_gate_loss: 0.3113 - val_loss: 2.3568 - val_main_output_accuracy: 0.3397 - val_main_output_loss: 2.1736 - val_tof_gate_loss: 0.0704
Epoch 2/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 91ms/step - loss: 2.5158 - main_output_accuracy: 0.2887 - main_output_loss: 2.3149 - tof_gate_loss: 0.1617 - val_loss: 2.0984 - val_main_output_accuracy: 0.4183 - val_main_output_loss: 1.9236 - val_tof_gate_loss: 0.0619
Epoch 3/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 90ms/step - loss: 2.3394 - main_output_accuracy: 0.3717 - main_output_loss: 2.1424 - tof_gate_loss: 0.1777 - val_loss: 1.9387 - val_main_output_accuracy: 0.4737 - val_main_output_loss: 1.7750 - val_tof_gate_loss: 0.0342
Epoch 4/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 96ms/step - loss: 2.2258 - main_output_accurac

  self._warn_if_super_not_called()


[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 167ms/step - loss: 3.7012 - main_output_accuracy: 0.1004 - main_output_loss: 3.2555 - tof_gate_loss: 0.4760 - val_loss: 2.9820 - val_main_output_accuracy: 0.2326 - val_main_output_loss: 2.5741 - val_tof_gate_loss: 0.3218
Epoch 2/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 151ms/step - loss: 3.0780 - main_output_accuracy: 0.2184 - main_output_loss: 2.6631 - tof_gate_loss: 0.3743 - val_loss: 2.4589 - val_main_output_accuracy: 0.3974 - val_main_output_loss: 2.0727 - val_tof_gate_loss: 0.2712
Epoch 3/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 164ms/step - loss: 2.7884 - main_output_accuracy: 0.2967 - main_output_loss: 2.4014 - tof_gate_loss: 0.2937 - val_loss: 2.1719 - val_main_output_accuracy: 0.4578 - val_main_output_loss: 1.8058 - val_tof_gate_loss: 0.2312
Epoch 4/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 158ms/step - loss: 2.5538 - main_output_a

  self._warn_if_super_not_called()


[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 178ms/step - loss: 3.6931 - main_output_accuracy: 0.1038 - main_output_loss: 3.2220 - tof_gate_loss: 0.6007 - val_loss: 3.0422 - val_main_output_accuracy: 0.3047 - val_main_output_loss: 2.6133 - val_tof_gate_loss: 0.4287
Epoch 2/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 153ms/step - loss: 2.9782 - main_output_accuracy: 0.2411 - main_output_loss: 2.5731 - tof_gate_loss: 0.3240 - val_loss: 2.4999 - val_main_output_accuracy: 0.3999 - val_main_output_loss: 2.1123 - val_tof_gate_loss: 0.2860
Epoch 3/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 152ms/step - loss: 2.7928 - main_output_accuracy: 0.2981 - main_output_loss: 2.4049 - tof_gate_loss: 0.3023 - val_loss: 2.2221 - val_main_output_accuracy: 0.4490 - val_main_output_loss: 1.8695 - val_tof_gate_loss: 0.1725
Epoch 4/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 146ms/step - loss: 2.5683 - main_output_a

  self._warn_if_super_not_called()


[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 178ms/step - loss: 3.6593 - main_output_accuracy: 0.1127 - main_output_loss: 3.1897 - tof_gate_loss: 0.5914 - val_loss: 2.9582 - val_main_output_accuracy: 0.3116 - val_main_output_loss: 2.5224 - val_tof_gate_loss: 0.4599
Epoch 2/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 161ms/step - loss: 2.9721 - main_output_accuracy: 0.2516 - main_output_loss: 2.5619 - tof_gate_loss: 0.3472 - val_loss: 2.5049 - val_main_output_accuracy: 0.3906 - val_main_output_loss: 2.1147 - val_tof_gate_loss: 0.2898
Epoch 3/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 159ms/step - loss: 2.7607 - main_output_accuracy: 0.3063 - main_output_loss: 2.3724 - tof_gate_loss: 0.2974 - val_loss: 2.1997 - val_main_output_accuracy: 0.4578 - val_main_output_loss: 1.8431 - val_tof_gate_loss: 0.1829
Epoch 4/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 167ms/step - loss: 2.5887 - main_output_a

  self._warn_if_super_not_called()


[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 176ms/step - loss: 3.6896 - main_output_accuracy: 0.1074 - main_output_loss: 3.2350 - tof_gate_loss: 0.5571 - val_loss: 2.9439 - val_main_output_accuracy: 0.3196 - val_main_output_loss: 2.5202 - val_tof_gate_loss: 0.4020
Epoch 2/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 163ms/step - loss: 2.9494 - main_output_accuracy: 0.2490 - main_output_loss: 2.5475 - tof_gate_loss: 0.3079 - val_loss: 2.4505 - val_main_output_accuracy: 0.3702 - val_main_output_loss: 2.0696 - val_tof_gate_loss: 0.2462
Epoch 3/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 171ms/step - loss: 2.7357 - main_output_accuracy: 0.3086 - main_output_loss: 2.3482 - tof_gate_loss: 0.2956 - val_loss: 2.1865 - val_main_output_accuracy: 0.4659 - val_main_output_loss: 1.8392 - val_tof_gate_loss: 0.1393
Epoch 4/150
[1m96/96[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 159ms/step - loss: 2.5806 - main_output_a

In [None]:
import traceback
# =====================================================================================
# ARCHITECTURE SANITY CHECK
# =====================================================================================

# --- Step 1: Get a sample batch and define shapes ---
# (This part of your code is correct)
# Make sure your train_dataset is created before this block
try:
    sample_batch = next(iter(train_dataset))
    sample_input = sample_batch[0]
    input_shape = sample_input.shape[1:]
    imu_dim = len(imu_cols) # Assuming imu_cols is defined
    print(f"Sample input shape for testing: {sample_input.shape}\n")
except Exception as e:
    print(f"Could not get a sample from the dataset. Error: {e}")
    # Exit if we can't get a sample to test with
    exit()

# --- Step 2: Create a list of all model-building functions ---
# (This part of your code is correct)
model_builders = [
    ("CNN_Baseline", lambda: create_model_1_cnn_baseline(input_shape)),
    ("GRU_Baseline", lambda: create_model_2_gru_baseline(input_shape)),
    ("CNN_RNN_Hybrid", lambda: create_model_3_cnn_rnn_hybrid(input_shape)),
    ("WaveNet_Style", lambda: create_model_4_wavenet_style(input_shape)),
    ("UNet_Style", lambda: create_model_5_unet_style(input_shape)),
    ("Transformer", lambda: create_model_6_transformer(input_shape)),
    ("CNN_Transformer", lambda: create_model_7_cnn_transformer(input_shape)),
    # For your two-branch model, you'll need the full IMU+ToF dataset
    # ("Two_Branch", lambda: create_model_8_two_branch(input_shape, imu_dim)),
]

# --- Step 3: Loop through the models, build them, and test with the sample ---
print("--- Testing all model architectures with a sample batch ---")
for model_name, model_builder in model_builders:
    print("\n" + "="*60)
    print(f"▶ Testing Model: {model_name}")
    print("="*60)
    
    try:
        # 1. Build the model using the builder function
        model = model_builder()
        
        # Optional: Print the model summary to check its structure
        print(f"Model Summary for {model_name}:")
        model.summary()
        
        # 2. Pass the sample input through the model
        print(f"\nPerforming forward pass for {model_name}...")
        output = model(sample_input)
        
        # 3. Print the output shape to verify it's correct
        print(f"✅ SUCCESS: Model '{model_name}' ran successfully.")
        # For multi-output models, output might be a list/dict. For single, it's a tensor.
        if isinstance(output, dict):
            for key, value in output.items():
                print(f"   Output '{key}' shape: {value.shape}")
        elif isinstance(output, list):
            for i, value in enumerate(output):
                print(f"   Output {i} shape: {value.shape}")
        else:
            print(f"   Output shape: {output.shape}")

    except Exception as e:
        print(f"❌ ERROR: Model '{model_name}' failed to build or run.")
        traceback.print_exc() # This will print the full error traceback
        
    # Clean up the created model to save memory
    del model
    gc.collect()

print("\n--- Model architecture testing complete. ---")

Could not get a sample from the dataset. Error: name 'train_dataset' is not defined
--- Testing all model architectures with a sample batch ---

▶ Testing Model: CNN_Baseline
❌ ERROR: Model 'CNN_Baseline' failed to build or run.


Traceback (most recent call last):
  File "/tmp/ipykernel_1943241/937420362.py", line 43, in <module>
    model = model_builder()
            ^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_1943241/937420362.py", line 23, in <lambda>
    ("CNN_Baseline", lambda: create_model_1_cnn_baseline(input_shape)),
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
NameError: name 'create_model_1_cnn_baseline' is not defined


NameError: name 'model' is not defined

: 