In [1]:
!wget https://repository.gatech.edu/bitstreams/03f9679f-28ce-4d8b-b195-4b3b1aa4adc9/download -O biomechanics_data.zip

--2025-10-16 21:35:14--  https://repository.gatech.edu/bitstreams/03f9679f-28ce-4d8b-b195-4b3b1aa4adc9/download
Resolving repository.gatech.edu (repository.gatech.edu)... 143.215.137.31
Connecting to repository.gatech.edu (repository.gatech.edu)|143.215.137.31|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://repository.gatech.edu/server/api/core/bitstreams/03f9679f-28ce-4d8b-b195-4b3b1aa4adc9/content [following]
--2025-10-16 21:35:15--  https://repository.gatech.edu/server/api/core/bitstreams/03f9679f-28ce-4d8b-b195-4b3b1aa4adc9/content
Reusing existing connection to repository.gatech.edu:443.
HTTP request sent, awaiting response... 200 200
Length: 13378286332 (12G) [application/octet-stream]
Saving to: ‘biomechanics_data.zip’


2025-10-16 21:38:20 (68.8 MB/s) - ‘biomechanics_data.zip’ saved [13378286332/13378286332]



In [None]:
import zipfile
import pandas as pd
import numpy as np
from scipy.signal import butter, filtfilt
import os
import pickle
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv1D, Dropout, ReLU, Add
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import backend as K
from tensorflow.keras.mixed_precision import Policy, set_global_policy
from tensorflow.keras.mixed_precision import LossScaleOptimizer
import shutil
import glob
import gc
import matplotlib.pyplot as plt
import seaborn as sns

# ---------------------- Setup ----------------------
policy = Policy('mixed_float16')
set_global_policy(policy)

gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except:
        pass

tf.config.optimizer.set_jit(True)

# ---------------------- Subject Mass Data ----------------------
SUBJECT_MASSES = {

    
    "AB01": 78.9,
    "AB02": 82.2,
    "AB03": 113.5,
    "AB05": 71.5,
    "AB06": 79.1,
    "AB07": 62.3,
    "AB08": 87.6,
    "AB09": 84.1,
    "AB10": 67.5,
    "AB11": 65.1,
    "AB12": 64.0,
    "AB13": 67.6,
    
}

# ---------------------- Paths & Groups ----------------------
ZIP_PATH = "/kaggle/working/biomechanics_data.zip"

GROUPS = {
    "stairs up": ["stairs_1_3_up", "stairs_1_7_up", "stairs_1_9_up", "stairs_1_1_up", "stairs_1_11_up", "stairs_1_5_up"],
    "stairs down": ["stairs_1_2_down", "stairs_1_10_down", "stairs_1_12_down", "stairs_1_6_down", "stairs_1_8_down", "stairs_1_4_down"],
    "incline": ["incline_walk_1_up5", "incline_walk_2_up10"],
    "decline": ["incline_walk_2_down10", "incline_walk_1_down5"],
    "cutting": ["cutting_1_right-slow", "cutting_1_left-slow", "cutting_1_left-fast", "cutting_1_right-fast"],
    "jump": ["jump_1_fb", "jump_2_lateral", "jump_2_180", "jump_3_90-1", "jump_3_90-2", "jump_1_vertical", "jump_1_hop"],
    "run": ["tire_run_1"],
    "obstacle": ["obstacle_walk_1"],
    "ball": ["ball_toss_1_right", "ball_toss_1_center", "ball_toss_1_left"],
    "normal_walk": ["normal_walk_1_shuffle", "normal_walk_1_2-5", "normal_walk_1_0-6", "normal_walk_1_skip", "normal_walk_1_1-2", "normal_walk_1_1-8", "normal_walk_1_2-0"],
    "poses": ["poses_1"],
    "lift": ["lift_weight_2_0lbs-r-r", "lift_weight_1_25lbs-r-r", "lift_weight_1_25lbs-r-c", "lift_weight_1_25lbs-l-l", "lift_weight_2_0lbs-r-c", "lift_weight_1_25lbs-l-c", "lift_weight_2_0lbs-l-l", "lift_weight_2_0lbs-l-c"],
    "step": ["step_ups_1_right", "step_ups_1_left"],
    "sit": ["sit_to_stand_2_tall-noarm", "sit_to_stand_1_short-noarm", "sit_to_stand_1_short-arm"],
    "curb": ["curb_up_1", "curb_down_1"],
    "walk_backward": ["walk_backward_1_1-0", "walk_backward_1_0-6", "walk_backward_1_0-8"],
    "squats": ["squats_1_25lbs", "squats_1_0lbs"],
    "start_stop": ["start_stop_1"],
    "calisthenics": ["dynamic_walk_1_heel-walk", "dynamic_walk_1_toe-walk", "dynamic_walk_1_high-knees", "dynamic_walk_1_butt-kicks"],
    "push": ["push_1"],
    "tug of war": ["tug_of_war_1"],
    "lunge": ["lunges_2_right", "lunges_2_left", "lunges_1"],
    "side": ["side_shuffle_1"],
    "twister": ["twister_1"],
    "turn": ["turn_and_step_1_left-turn", "turn_and_step_1_right-turn"],
    "meander": ["meander_1"],
    "weighted_walk": ["weighted_walk_1_25lbs"]
}

# ---------------------- Column definitions ----------------------
IMU_COLUMNS = [
    'LShank_ACCX', 'LShank_ACCY', 'LShank_ACCZ', 'LShank_GYROX', 'LShank_GYROY', 'LShank_GYROZ',
    'LAThigh_ACCX', 'LAThigh_ACCY', 'LAThigh_ACCZ', 'LAThigh_GYROX', 'LAThigh_GYROY', 'LAThigh_GYROZ',
    'LPThigh_ACCX', 'LPThigh_ACCY', 'LPThigh_ACCZ', 'LPThigh_GYROX', 'LPThigh_GYROY', 'LPThigh_GYROZ',
    'LPelvis_ACCX', 'LPelvis_ACCY', 'LPelvis_ACCZ', 'LPelvis_GYROX', 'LPelvis_GYROY', 'LPelvis_GYROZ'
]
INSOLE_COLUMNS = ['LCOP_AP', 'LCOP_ML', 'LVerticalF', 'LShearF_AP', 'LShearF_ML']
EMG_COLUMNS =  ['LGMED', 'RGMED', 'LGMAX', 'RGMAX', 'LGRAC', 'RGRAC', 'LTA', 'RTA']
MOMENT_COLUMNS = ['hip_flexion_l_moment', 'knee_angle_l_moment']

# ---------------------- EMG Debug & Fix Functions ----------------------
def debug_emg_columns(all_subjects_data):
    """Debug function to see what EMG columns are actually available"""
    print("\n🔍 DEBUG: Checking available EMG columns in each subject:")
    
    all_emg_columns = set()
    for subject, df in all_subjects_data.items():
        # Find all columns that might be EMG-related
        emg_cols = [col for col in df.columns if any(emg_keyword in col for emg_keyword in 
                    ['GMED', 'GMAX', 'GRAC', 'TA', 'RF', 'BF', 'MGAS', 'VL'])]
        all_emg_columns.update(emg_cols)
        print(f"   {subject}: {sorted(emg_cols)}")
    
    print(f"\n📊 All unique EMG-related columns: {sorted(all_emg_columns)}")
    return sorted(all_emg_columns)

def fix_emg_column_names_complete(df):
    """Complete EMG column mapping that handles all cases"""
    
    # First, let's see what columns we actually have
    available_cols = df.columns.tolist()
    
    # Map all possible EMG column variations to the expected names
    emg_mapping = {}
    
    # Check for various naming patterns and create mappings
    emg_variations = {
        'VL': ['VL', 'VL_L', 'VL_R', 'LVL', 'RVL'],
        'RF': ['RF', 'RF_L', 'RF_R', 'LRF', 'RRF'], 
        'BF': ['BF', 'BF_L', 'BF_R', 'LBF', 'RBF'],
        'MGAS': ['MGAS', 'MGAS_L', 'MGAS_R', 'LMGAS', 'RMGAS'],
        'GMED': ['GMED', 'GMED_L', 'GMED_R', 'LGMED', 'RGMED'],
        'GMAX': ['GMAX', 'GMAX_L', 'GMAX_R', 'LGMAX', 'RGMAX'],
        'GRAC': ['GRAC', 'GRAC_L', 'GRAC_R', 'LGRAC', 'RGRAC'],
        'TA': ['TA', 'TA_L', 'TA_R', 'LTA', 'RTA']
    }
    
    # Create mappings for available columns
    mappings_applied = []
    for target_col, variations in emg_variations.items():
        for variation in variations:
            if variation in available_cols and variation != target_col:
                emg_mapping[variation] = target_col
                mappings_applied.append(f"{variation} -> {target_col}")
    
    # Apply mapping
    if emg_mapping:
        df = df.rename(columns=emg_mapping)
        print(f"   🔧 Applied EMG mappings: {mappings_applied}")
    
    return df

def ensure_emg_columns_exist(df):
    """Ensure all required EMG columns exist, create missing ones as zeros"""
    required_emg = ['VL', 'RF', 'BF', 'MGAS', 'GMED', 'GMAX', 'GRAC', 'TA']
    
    missing_emg = []
    for emg_col in required_emg:
        if emg_col not in df.columns:
            df[emg_col] = 0.0  # Add as zeros if missing
            missing_emg.append(emg_col)
    
    if missing_emg:
        print(f"   ⚠️ Added missing EMG columns as zeros: {missing_emg}")
    
    return df

def fix_emg_duplicate_columns(df):
    """Fix duplicate EMG column names by taking the first occurrence"""
    # Get all column names
    cols = df.columns.tolist()
    
    # Find duplicates
    seen = set()
    duplicates = set()
    for col in cols:
        if col in seen:
            duplicates.add(col)
        else:
            seen.add(col)
    
    # Remove duplicates by keeping first occurrence
    if duplicates:
        print(f"   🔧 Removing duplicate columns: {list(duplicates)}")
        df = df.loc[:, ~df.columns.duplicated()]
    
    return df

def fix_all_emg_duplicates(all_subjects_data):
    """Fix duplicate EMG columns in all subjects"""
    print("🔧 Fixing duplicate EMG columns in all subjects...")
    
    for subject, df in all_subjects_data.items():
        # Remove duplicate columns (keep first occurrence)
        if df.columns.duplicated().any():
            duplicates = df.columns[df.columns.duplicated()].tolist()
            print(f"   {subject}: Removing duplicates {duplicates}")
            all_subjects_data[subject] = df.loc[:, ~df.columns.duplicated()]
    
    return all_subjects_data

def get_available_emg_features_simple(df):
    """Simpler version - just check if columns exist"""
    available_emg = []
    expected_emg = ['VL', 'RF', 'BF', 'MGAS', 'GMED', 'GMAX', 'GRAC', 'TA']
    
    for emg_col in expected_emg:
        if emg_col in df.columns:
            available_emg.append(emg_col)
    
    print(f"   ✅ Found {len(available_emg)} EMG features: {available_emg}")
    return available_emg
 
# ---------------------- Mass Normalization Functions ----------------------
def normalize_by_mass(df, subject, columns_to_normalize):
    """Normalize specified columns by subject mass (Paper-compliant)"""
    if subject in SUBJECT_MASSES:
        mass = SUBJECT_MASSES[subject]
        for col in columns_to_normalize:
            if col in df.columns:
                df[col] = df[col] / mass
        print(f"   📊 Mass normalization applied for {subject} ({mass} kg)")
    else:
        print(f"   ⚠️ No mass data for {subject}, skipping mass normalization")
    return df

def denormalize_by_mass(df, subject, columns_to_normalize):
    """Denormalize mass-normalized columns"""
    if subject in SUBJECT_MASSES:
        mass = SUBJECT_MASSES[subject]
        for col in columns_to_normalize:
            if col in df.columns:
                df[col] = df[col] * mass
    return df

# ---------------------- Signal processing functions ----------------------
def butter_bandpass(lowcut, highcut, fs, order=4):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a

def butter_lowpass(cutoff, fs, order=4):
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='low')
    return b, a

def process_emg_signal(signal, original_fs, target_times):
    signal = signal - np.mean(signal)
    b, a = butter_bandpass(30, 300, fs=original_fs, order=4)
    signal = filtfilt(b, a, signal)
    signal = np.abs(signal)
    b, a = butter_lowpass(6, fs=original_fs, order=4)
    signal = filtfilt(b, a, signal)
    signal *= 10000
    original_times = np.linspace(0, len(signal)/original_fs, len(signal))
    signal_sync = np.interp(target_times, original_times, signal)
    return signal_sync

# ---------------------- Data collection function ----------------------
def collect_subject_data(subject_name, original_emg_fs=2000, joint='knee'):
    all_data = []

    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_files = zip_ref.namelist()
        for group_name, task_list in GROUPS.items():
            for task in task_list:
                task_folder = f"{subject_name}/{task}/"
                if not any(task_folder in f for f in zip_files):
                    continue

                files_dict = {}
                required_suffixes = ["angle.csv", "velocity.csv", "imu_real.csv", "insole_sim.csv", "emg.csv", "moment_filt.csv"]
                for suffix in required_suffixes:
                    file_path = f"{subject_name}/{task}/{subject_name}_{task}_{suffix}"
                    if file_path not in zip_files:
                        print(f"⚠️ File {suffix} for task '{task}' not found. Skipping.")
                        continue
                    files_dict[suffix] = file_path

                if len(files_dict) != len(required_suffixes):
                    continue

                try:
                    with zip_ref.open(files_dict["angle.csv"]) as file:
                        df_angle = pd.read_csv(file, usecols=['time', 'hip_flexion_l', 'knee_angle_l'])
                    with zip_ref.open(files_dict["velocity.csv"]) as file:
                        df_vel = pd.read_csv(file, usecols=['time', 'hip_flexion_velocity_l', 'knee_velocity_l'])
                    with zip_ref.open(files_dict["imu_real.csv"]) as file:
                        df_imu = pd.read_csv(file, usecols=['time'] + IMU_COLUMNS)
                    with zip_ref.open(files_dict["insole_sim.csv"]) as file:
                        df_insole = pd.read_csv(file, usecols=['time'] + INSOLE_COLUMNS)
                    with zip_ref.open(files_dict["emg.csv"]) as file:
                        df_emg_raw = pd.read_csv(file, usecols=['time'] + EMG_COLUMNS)
                    with zip_ref.open(files_dict["moment_filt.csv"]) as file:
                        df_moment = pd.read_csv(file)

                    target_times = df_angle['time'].values
                    df_emg = pd.DataFrame({'time': target_times})
                    for col in EMG_COLUMNS:
                        df_emg[col] = process_emg_signal(df_emg_raw[col].values, original_emg_fs, target_times)

                    df = pd.merge(df_angle, df_vel, on='time', how='left')
                    df = pd.merge(df, df_imu, on='time', how='left')
                    df = pd.merge(df, df_insole, on='time', how='left')
                    df = pd.merge(df, df_emg, on='time', how='left')

                    for col in MOMENT_COLUMNS:
                        df[col] = np.interp(df['time'].values, df_moment['time'].values, df_moment[col].values)

                    df['subject'] = subject_name
                    df['group'] = group_name
                    df['task'] = task

                    all_data.append(df)
                    
                except Exception as e:
                    print(f"❌ Error processing {subject_name}/{task}: {str(e)}")
                    continue

    if all_data:
        full_df = pd.concat(all_data, ignore_index=True)
        
        # Remove rows with NaN in important columns
        columns_to_check = ['LCOP_AP', 'LCOP_ML', 'LVerticalF', 'LShearF_AP', 'LShearF_ML',
                           'hip_flexion_l_moment', 'knee_angle_l_moment']
        all_nan_rows = full_df[columns_to_check].isna().all(axis=1)
        full_df = full_df[~all_nan_rows]

        # Apply MASS NORMALIZATION (Paper-compliant)
        force_columns_to_normalize = ['LVerticalF', 'LShearF_AP', 'LShearF_ML', 
                                    'hip_flexion_l_moment', 'knee_angle_l_moment']
        full_df = normalize_by_mass(full_df, subject_name, force_columns_to_normalize)

        # Fix EMG column names
        print(f"   🔧 Fixing EMG column names for {subject_name}...")
        full_df = fix_emg_column_names_complete(full_df)
        full_df = ensure_emg_columns_exist(full_df)
        full_df = fix_emg_duplicate_columns(full_df)

        # Mirror data for side independence
        def _left_to_base_columns_map(columns):
            mapping = {}
            for col in columns:
                if col in ['time', 'subject', 'group', 'task']:
                    mapping[col] = col
                elif col.endswith('_l_moment'):
                    mapping[col] = col.replace('_l_moment', '_moment')
                elif col.endswith('_l'):
                    mapping[col] = col[:-2]
                elif col.startswith('L') and len(col) > 1:
                    mapping[col] = col[1:]
                else:
                    mapping[col] = col
            return mapping

        col_map = _left_to_base_columns_map(full_df.columns)
        base_df = full_df.rename(columns=col_map).copy()
        mirrored_df = base_df.copy()

        flip_accy = [c for c in base_df.columns if 'ACCY' in c]
        flip_gyro = [c for c in base_df.columns if ('GYROX' in c) or ('GYROZ' in c)]
        flip_cop_ml = [c for c in base_df.columns if 'COP_ML' in c]
        flip_shear_ml = [c for c in base_df.columns if 'ShearF_ML' in c]
        flip_cols = list(set(flip_accy + flip_gyro + flip_cop_ml + flip_shear_ml))
        numeric_cols = [c for c in flip_cols if np.issubdtype(mirrored_df[c].dtype, np.number)]
        mirrored_df[numeric_cols] = mirrored_df[numeric_cols] * -1

        base_df['side'] = 'L'
        mirrored_df['side'] = 'R'
        combined_df = pd.concat([base_df, mirrored_df], ignore_index=True)

        # Select joint-specific columns
        joint = joint.lower()
        if joint == 'knee':
            joint_cols = [
                'time', 'knee_angle', 'knee_velocity',
                'hip_flexion', 'hip_flexion_velocity',  # Using base names after mirroring
                'Shank_ACCX', 'Shank_ACCY', 'Shank_ACCZ',
                'Shank_GYROX', 'Shank_GYROY', 'Shank_GYROZ',
                'AThigh_ACCX', 'AThigh_ACCY', 'AThigh_ACCZ',
                'AThigh_GYROX', 'AThigh_GYROY', 'AThigh_GYROZ',
                'COP_AP', 'COP_ML', 'VerticalF', 'ShearF_AP', 'ShearF_ML',
                'VL', 'RF', 'BF', 'MGAS', 'GMED', 'GMAX', 'GRAC', 'TA',
                'knee_angle_moment', 'subject', 'group', 'task', 'side'
            ]
        elif joint == 'hip':
            joint_cols = [
                'time', 'hip_flexion', 'hip_flexion_velocity',
                'PThigh_ACCX', 'PThigh_ACCY', 'PThigh_ACCZ',
                'PThigh_GYROX', 'PThigh_GYROY', 'PThigh_GYROZ',
                'Pelvis_ACCX', 'Pelvis_ACCY', 'Pelvis_ACCZ',
                'Pelvis_GYROX', 'Pelvis_GYROY', 'Pelvis_GYROZ',
                'COP_AP', 'COP_ML', 'VerticalF', 'ShearF_AP', 'ShearF_ML',
                'RF', 'BF', 'GMED', 'GMAX', 'GRAC', 'TA',
                'hip_flexion_moment', 'subject', 'group', 'task', 'side'
            ]
        else:
            raise ValueError("joint must be 'hip' or 'knee'")
            
        available_cols = [col for col in joint_cols if col in combined_df.columns]
        joint_df = combined_df[available_cols].copy()
        
        print(f"✅ Processed {subject_name}: {len(joint_df)} rows, {len(joint_df.columns)} columns")
        return joint_df
    else:
        print(f"⚠️ No data found for subject {subject_name}.")
        return None

# ---------------------- Process All Subjects ----------------------
def process_all_subjects(joint_input):
    # Use subjects from the mass data
    subjects = list(SUBJECT_MASSES.keys())
    all_subjects_data = {}

    print("📊 Subject Mass Distribution:")
    for subject, mass in SUBJECT_MASSES.items():
        print(f"   {subject}: {mass} kg")
    print(f"   Mean: {np.mean(list(SUBJECT_MASSES.values())):.1f} ± {np.std(list(SUBJECT_MASSES.values())):.1f} kg")
    
    for subject in subjects:
        print(f"\n==================================================")
        print(f"🚀 Processing subject: {subject}, joint: {joint_input}")
        print("==================================================\n")

        try:
            joint_df = collect_subject_data(subject_name=subject, joint=joint_input)

            if joint_df is not None:
                all_subjects_data[subject] = joint_df
                file_name = f"{subject}_{joint_input}_df.pkl"
                joint_df.to_pickle(file_name)
                print(f"✅ DataFrame for {subject} saved as '{file_name}'")
            else:
                print(f"⚠️ No data generated for subject {subject}.")
        except Exception as e:
            print(f"❌ Error processing subject {subject}: {str(e)}")

    print("\n==================================================")
    print("📊 Final Report: Processed Subjects")
    for subject, df in all_subjects_data.items():
        print(f"🗂 {subject}: {len(df)} rows, {len(df.columns)} columns")
    print("==================================================\n")

    return all_subjects_data

# ---------------------- Feature Sets ----------------------
def create_dynamic_feature_sets(all_subjects_data):
    """Create feature sets that MATCH THE PAPER exactly"""
    
    # Use the first subject to check available features
    sample_subject = list(all_subjects_data.keys())[0]
    
    # Use the simple version to avoid pandas issues
    available_emg = get_available_emg_features_simple(all_subjects_data[sample_subject])
    
    print(f"🎯 Available EMG features: {available_emg}")
    
    # PAPER-COMPLIANT kinematic features - USING CORRECT COLUMN NAMES AFTER MIRRORING
    base_kinematic = [
        # Joint angles and velocities (CRITICAL - using base names after mirroring)
        'hip_flexion', 'hip_flexion_velocity',
        'knee_angle', 'knee_velocity',
        
        # Thigh IMU (AThigh in your data)
        'AThigh_ACCX', 'AThigh_ACCY', 'AThigh_ACCZ',
        'AThigh_GYROX', 'AThigh_GYROY', 'AThigh_GYROZ',
        
        # Shank IMU  
        'Shank_ACCX', 'Shank_ACCY', 'Shank_ACCZ',
        'Shank_GYROX', 'Shank_GYROY', 'Shank_GYROZ',
        
        # Insole data
        'COP_AP', 'COP_ML', 'VerticalF', 'ShearF_AP', 'ShearF_ML'
    ]
    
    # Create dynamic feature sets
    features_sets_dynamic = {
        "kinematic": base_kinematic,
        "kinematic_emg": base_kinematic + available_emg,
       # "kinematic_insole": base_kinematic,  # Already includes insole
       # "all": base_kinematic + available_emg
    }
    
    print(f"\n📊 PAPER-COMPLIANT Feature Sets Created:")
    for feat_name, features in features_sets_dynamic.items():
        print(f"   {feat_name}: {len(features)} features")
        # Show hip features to verify they're included
        hip_features = [f for f in features if 'hip' in f]
        if hip_features:
            print(f"      Includes hip data: {hip_features}")
    
    return features_sets_dynamic

# ---------------------- Normalization ----------------------
def normalize_data(X):
    mean = np.mean(X, axis=0, keepdims=True)
    std = np.std(X, axis=0, keepdims=True) + 1e-8
    return (X - mean) / std, mean, std

def denormalize_data(X, mean, std):
    return X * std + mean

# ---------------------- WeightNormalizedConv1D ----------------------
class WeightNormalizedConv1D(Conv1D):
    def build(self, input_shape):
        super().build(input_shape)
        self.g = self.add_weight(name='g', shape=(self.filters,), initializer='ones', trainable=True, dtype=self.dtype)

    def call(self, inputs):
        w_norm = tf.math.l2_normalize(self.kernel, axis=[0,1])
        g_reshaped = tf.reshape(self.g, (1,1,-1))
        w_normalized = w_norm * g_reshaped
        padding = self.padding
        if padding == 'causal':
            k = self.kernel_size[0] if isinstance(self.kernel_size, (list, tuple)) else self.kernel_size
            d = self.dilation_rate[0] if isinstance(self.dilation_rate, (list, tuple)) else self.dilation_rate
            pad_len = (k-1)*d
            inputs = tf.pad(inputs, [[0,0],[pad_len,0],[0,0]])
            conv_padding = 'VALID'
        else:
            conv_padding = 'SAME' if padding=='same' else 'VALID'
        strides = (self.strides[0],) if isinstance(self.strides,(list,tuple)) else (self.strides,)
        dilations = (self.dilation_rate[0],) if isinstance(self.dilation_rate,(list,tuple)) else (self.dilation_rate,)
        outputs = tf.nn.convolution(inputs, w_normalized, padding=conv_padding, strides=strides, dilations=dilations)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        return outputs

# ---------------------- TCN (PAPER-COMPLIANT HYPERPARAMETERS) ----------------------
def temporal_block(x, n_filters, kernel_size, dropout, dilation_rate):
    prev = x
    conv1 = WeightNormalizedConv1D(n_filters, kernel_size, padding='causal', dilation_rate=dilation_rate)(x)
    conv1 = ReLU()(conv1)
    conv1 = Dropout(dropout)(conv1)
    conv2 = WeightNormalizedConv1D(n_filters, kernel_size, padding='causal', dilation_rate=dilation_rate)(conv1)
    conv2 = ReLU()(conv2)
    conv2 = Dropout(dropout)(conv2)
    if prev.shape[-1] != conv2.shape[-1]:
        prev = WeightNormalizedConv1D(conv2.shape[-1], 1)(prev)
    out = Add()([prev, conv2])
    return out

def build_tcn_paper_compliant(input_shape, n_filters=80, kernel_size=5, depth=5, dropout=0.1):
    """
    PAPER-COMPLIANT TCN architecture:
    - 5 residual blocks (depth=5)
    - 80 filters per layer (n_filters=80) 
    - Kernel size 5 (kernel_size=5)
    - Receptive field: 249 samples
    """
    inputs = Input(shape=input_shape)
    x = inputs
    
    # Paper: "five residual block layers with channel size of 80"
    for i in range(depth):
        x = temporal_block(x, n_filters, kernel_size, dropout, dilation_rate=2**i)
    
    outputs = WeightNormalizedConv1D(1, 1, dtype='float32')(x)
    return Model(inputs, outputs)

# ---------------------- Sequence Helper (PAPER-COMPLIANT) ----------------------
def create_sequence_data(X, y, seq_len=248, stride=124):
    """
    Paper-compliant sequence creation:
    - 248 samples = 1.24 seconds at 200 Hz (matches paper)
    - 124 stride = 50% overlap
    """
    X_seq, y_seq = [], []
    for start in range(0, len(X) - seq_len + 1, stride):
        end = start + seq_len
        X_seq.append(X[start:end])
        y_seq.append(y[start:end])
    
    if len(X_seq) == 0:
        # Handle edge case with padding
        if len(X) < seq_len:
            X_padded = np.pad(X, ((0, seq_len - len(X)), (0, 0)), mode='edge')
            y_padded = np.pad(y, ((0, seq_len - len(y)), (0, 0)), mode='edge')
            return np.array([X_padded]), np.array([y_padded])
        return np.zeros((0, seq_len, X.shape[1]), dtype=X.dtype), np.zeros((0, seq_len, y.shape[1]), dtype=y.dtype)
    
    return np.stack(X_seq), np.stack(y_seq)

# ---------------------- CORRECTED LOSO Implementation ----------------------
def create_proper_loso_splits(subjects):
    """
    Create proper Leave-One-Subject-Out splits where each subject is test exactly once
    AND validation subjects rotate properly to avoid data leakage
    """
    splits = []
    
    for i, test_subject in enumerate(subjects):
        # Remaining subjects for train/validation
        remaining_subjects = [s for s in subjects if s != test_subject]
        
        # Use different validation subject for each fold (rotate through remaining subjects)
        # This ensures no subject is overused as validation
        val_idx = i % len(remaining_subjects)
        val_subject = remaining_subjects[val_idx]
        
        # The rest are training subjects
        train_subjects = [s for s in remaining_subjects if s != val_subject]
        
        splits.append({
            'test_subject': test_subject,
            'val_subject': val_subject,
            'train_subjects': train_subjects
        })
        
        print(f"🎯 LOSO Fold: Test: {test_subject}, Val: {val_subject}, Train: {len(train_subjects)} subjects")
    
    return splits

def train_and_evaluate_loso_corrected(all_subjects_data, features_sets, selected_tasks, joint_type="knee", save_dir="./loso_output_corrected"):
    """CORRECTED LOSO: Each subject is test exactly once with proper validation rotation"""
    
    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    
    results_file = os.path.join(save_dir, "loso_joint_moment_results_corrected.pkl")
    all_results = {}

    # Load previous results if exist
    if os.path.exists(results_file):
        with open(results_file, "rb") as f:
            all_results = pickle.load(f)

    # Get list of subjects
    subjects = list(all_subjects_data.keys())
    print(f"📊 Performing CORRECTED LOSO CV with {len(subjects)} subjects")
    print("🎯 Structure: Each subject is test exactly once with rotating validation")
    print(f"📏 Using paper-compliant window: 248 samples (1.24s) with 50% overlap")
    
    # PAPER-COMPLIANT HYPERPARAMETERS
    PAPER_PARAMS = {
        'learning_rate': 5e-5,
        'batch_size': 64,
        'n_filters': 80,
        'kernel_size': 5,
        'depth': 5,
        'dropout': 0.1,
        'early_stopping_patience': 20,
        'epochs': 15
    }
    
    # Set label column based on joint type
    if joint_type.lower() == "knee":
        label_col = "knee_angle_moment"
    elif joint_type.lower() == "hip":
        label_col = "hip_flexion_moment"
    else:
        raise ValueError("joint_type must be 'knee' or 'hip'")
    
    print(f"🎯 Predicting {label_col}")
    
    # Create proper LOSO splits with rotating validation
    loso_splits = create_proper_loso_splits(subjects)
    
    for fold_idx, split in enumerate(loso_splits):
        test_subject = split['test_subject']
        val_subject = split['val_subject']
        train_subjects = split['train_subjects']
        
        print(f"\n🎯 LOSO Fold {fold_idx + 1}/{len(loso_splits)} - Test: {test_subject}, Val: {val_subject}, Train: {len(train_subjects)} subjects")
        
        # Process each feature set
        for feat_idx, feat_name in enumerate(selected_tasks, 1):
            print(f"\n=== [{feat_idx}/{len(selected_tasks)}] Feature Set: {feat_name} ===")
            
            if feat_name not in all_results:
                all_results[feat_name] = {}
            
            # Create unique key for this fold
            result_key = f"test_{test_subject}_val_{val_subject}"
            
            # Skip if already processed
            if result_key in all_results[feat_name]:
                print("✅ Already processed, skipping...")
                continue

            # Combine data with proper splits
            train_dfs = []
            val_dfs = []
            test_dfs = []
            
            feat_cols = features_sets[feat_name]
            
            # Verify all required columns exist
            missing_cols_all = []
            for subject in [test_subject, val_subject] + train_subjects:
                if subject in all_subjects_data:
                    missing_cols = [col for col in feat_cols + [label_col] if col not in all_subjects_data[subject].columns]
                    if missing_cols:
                        missing_cols_all.extend(missing_cols)
            
            if missing_cols_all:
                print(f"⚠️ Missing columns: {list(set(missing_cols_all))}. Skipping fold.")
                continue
            
            for subject, df in all_subjects_data.items():
                if subject in train_subjects:
                    train_dfs.append(df[feat_cols + [label_col]])
                elif subject == val_subject:
                    val_dfs.append(df[feat_cols + [label_col]])
                elif subject == test_subject:
                    test_dfs.append(df[feat_cols + [label_col]])
            
            if not train_dfs or not val_dfs or not test_dfs:
                print(f"⚠️ Insufficient data for this split. Skipping.")
                continue
            
            # Process and normalize data
            train_combined = pd.concat(train_dfs, ignore_index=True)
            val_combined = pd.concat(val_dfs, ignore_index=True)
            test_combined = pd.concat(test_dfs, ignore_index=True)
            
            # Normalize using TRAINING data only (no data leakage)
            X_train = train_combined[feat_cols].values.astype(np.float32)
            y_train = train_combined[[label_col]].values.astype(np.float32)
            X_val = val_combined[feat_cols].values.astype(np.float32)
            y_val = val_combined[[label_col]].values.astype(np.float32)
            X_test = test_combined[feat_cols].values.astype(np.float32)
            y_test = test_combined[[label_col]].values.astype(np.float32)
            
            X_train_norm, X_mean, X_std = normalize_data(X_train)
            y_train_norm, y_mean, y_std = normalize_data(y_train)
            
            # Apply same normalization to val and test
            X_val_norm = (X_val - X_mean) / X_std
            y_val_norm = (y_val - y_mean) / y_std
            X_test_norm = (X_test - X_mean) / X_std
            y_test_norm = (y_test - y_mean) / y_std
            
            # Create sequences with PAPER-COMPLIANT parameters
            X_train_seq, y_train_seq = create_sequence_data(X_train_norm, y_train_norm, seq_len=248, stride=124)
            X_val_seq, y_val_seq = create_sequence_data(X_val_norm, y_val_norm, seq_len=248, stride=124)
            X_test_seq, y_test_seq = create_sequence_data(X_test_norm, y_test_norm, seq_len=248, stride=124)
            
            print(f"📊 Sequences - Train: {X_train_seq.shape}, Val: {X_val_seq.shape}, Test: {X_test_seq.shape}")
            
            if X_train_seq.shape[0] == 0 or X_val_seq.shape[0] == 0 or X_test_seq.shape[0] == 0:
                print(f"⚠️ Not enough sequences. Skipping.")
                continue

            # Create datasets with PAPER batch size
            train_dataset = tf.data.Dataset.from_tensor_slices((X_train_seq, y_train_seq)) \
                                          .shuffle(len(X_train_seq)).batch(PAPER_PARAMS['batch_size']).prefetch(tf.data.AUTOTUNE)
            val_dataset = tf.data.Dataset.from_tensor_slices((X_val_seq, y_val_seq)) \
                                        .batch(PAPER_PARAMS['batch_size']).prefetch(tf.data.AUTOTUNE)
            test_dataset = tf.data.Dataset.from_tensor_slices((X_test_seq, y_test_seq)) \
                                         .batch(PAPER_PARAMS['batch_size']).prefetch(tf.data.AUTOTUNE)

            # Build and train model with PAPER-COMPLIANT architecture
            model = build_tcn_paper_compliant(
                input_shape=(248, X_train.shape[1]),
                n_filters=PAPER_PARAMS['n_filters'],
                kernel_size=PAPER_PARAMS['kernel_size'],
                depth=PAPER_PARAMS['depth'],
                dropout=PAPER_PARAMS['dropout']
            )
            
            # Paper-compliant optimizer
            optimizer = LossScaleOptimizer(Adam(PAPER_PARAMS['learning_rate']))
            model.compile(optimizer=optimizer, loss="mse", jit_compile=True)
            
            # Early stopping on validation loss
            es = EarlyStopping(
                monitor='val_loss', 
                patience=PAPER_PARAMS['early_stopping_patience'], 
                restore_best_weights=True
            )

            print("🚀 Training model with PAPER-COMPLIANT hyperparameters...")
            history = model.fit(
                train_dataset, 
                epochs=PAPER_PARAMS['epochs'], 
                validation_data=val_dataset,
                callbacks=[es], 
                verbose=2
            )

            # Evaluate model on test subject
            y_pred = model.predict(test_dataset, verbose=0)
            y_true = y_test_seq
            
            if y_pred.shape != y_true.shape:
                y_pred = y_pred.reshape(y_true.shape)

            # Denormalize
            y_pred_denorm = denormalize_data(y_pred, y_mean, y_std)
            y_true_denorm = denormalize_data(y_true, y_mean, y_std)

            # Calculate metrics
            mse = np.mean((y_true_denorm - y_pred_denorm)**2)
            rmse = np.sqrt(mse)
            ss_res = np.sum((y_true_denorm - y_pred_denorm)**2)
            ss_tot = np.sum((y_true_denorm - np.mean(y_true_denorm))**2)
            r2 = 1 - ss_res / ss_tot if ss_tot > 0 else np.nan
            mae = np.mean(np.abs(y_true_denorm - y_pred_denorm))

            # Store results
            all_results[feat_name][result_key] = {
                "rmse": rmse, 
                "r2": r2,
                "mae": mae,
                "train_subjects": train_subjects,
                "val_subject": val_subject,
                "test_subject": test_subject,
                "final_loss": history.history['loss'][-1],
                "final_val_loss": history.history['val_loss'][-1],
                "feature_count": len(feat_cols),
                "test_subject_mass": SUBJECT_MASSES.get(test_subject, "Unknown"),
                "val_subject_mass": SUBJECT_MASSES.get(val_subject, "Unknown"),
                "epochs_trained": len(history.history['loss'])
            }

            print(f"📊 LOSO {test_subject} (Test) + {val_subject} (Val) -> RMSE={rmse:.3f}, R²={r2:.3f}, MAE={mae:.3f}")

            # Save model
            model_path = os.path.join(save_dir, f"loso_{joint_type}_{feat_name}_{result_key}.keras")
            model.save(model_path)
            print(f"💾 Saved model: {model_path}")

            # Save results after each fold
            with open(results_file, "wb") as f:
                pickle.dump(all_results, f)

            # Free RAM
            del model, train_dataset, val_dataset, test_dataset
            del X_train, X_val, X_test, y_train, y_val, y_test
            del X_train_seq, y_train_seq, X_val_seq, y_val_seq, X_test_seq, y_test_seq
            del y_pred, y_pred_denorm, y_true_denorm
            for _ in range(3):
                gc.collect()
            K.clear_session()

    print(f"\n✅ Completed {len(loso_splits)} CORRECTED LOSO folds")
    return all_results

# ---------------------- Visualization Functions ----------------------
def plot_loso_results_corrected(results_file_path):
    """Plot CORRECTED LOSO cross-validation results"""
    
    # Load results
    with open(results_file_path, "rb") as f:
        results = pickle.load(f)
    
    # Create dataframes for plotting
    rmse_data = []
    r2_data = []
    
    for feat_name in results:
        for fold_key in results[feat_name]:
            fold_data = results[feat_name][fold_key]
            test_subject = fold_data['test_subject']
            val_subject = fold_data['val_subject']
            
            rmse_data.append({
                'Feature Set': feat_name,
                'Test Subject': test_subject,
                'Val Subject': val_subject,
                'RMSE': fold_data['rmse'],
                'Mass': fold_data.get('test_subject_mass', 'Unknown')
            })
            r2_data.append({
                'Feature Set': feat_name,
                'Test Subject': test_subject,
                'Val Subject': val_subject,
                'R²': fold_data['r2'],
                'Mass': fold_data.get('test_subject_mass', 'Unknown')
            })
    
    rmse_df = pd.DataFrame(rmse_data)
    r2_df = pd.DataFrame(r2_data)
    
    # Create subplots
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('CORRECTED LOSO Results - Knee Joint Moment Prediction\n(Each Subject Tested Exactly Once with Rotating Validation)', 
                 fontsize=16, fontweight='bold', y=0.98)
    
    # Plot 1: RMSE by Feature Set
    sns.boxplot(data=rmse_df, x='Feature Set', y='RMSE', ax=axes[0,0])
    sns.stripplot(data=rmse_df, x='Feature Set', y='RMSE', ax=axes[0,0], 
                  color='black', alpha=0.6, jitter=True)
    axes[0,0].set_title('RMSE by Feature Set (Mass-Normalized)')
    axes[0,0].set_ylabel('RMSE (Nm/kg)')
    axes[0,0].tick_params(axis='x', rotation=45)
    
    # Plot 2: R² by Feature Set
    sns.boxplot(data=r2_df, x='Feature Set', y='R²', ax=axes[0,1])
    sns.stripplot(data=r2_df, x='Feature Set', y='R²', ax=axes[0,1], 
                  color='black', alpha=0.6, jitter=True)
    axes[0,1].set_title('R² by Feature Set')
    axes[0,1].set_ylabel('R²')
    axes[0,1].tick_params(axis='x', rotation=45)
    
    # Plot 3: RMSE across Test Subjects
    subject_rmse = rmse_df.groupby(['Test Subject', 'Feature Set'])['RMSE'].mean().unstack()
    subject_rmse.plot(kind='bar', ax=axes[1,0], width=0.8)
    axes[1,0].set_title('Average RMSE Across Test Subjects')
    axes[1,0].set_ylabel('RMSE (Nm/kg)')
    axes[1,0].legend(title='Feature Set', bbox_to_anchor=(1.05, 1), loc='upper left')
    axes[1,0].tick_params(axis='x', rotation=45)
    
    # Plot 4: R² across Test Subjects
    subject_r2 = r2_df.groupby(['Test Subject', 'Feature Set'])['R²'].mean().unstack()
    subject_r2.plot(kind='bar', ax=axes[1,1], width=0.8)
    axes[1,1].set_title('Average R² Across Test Subjects')
    axes[1,1].set_ylabel('R²')
    axes[1,1].legend(title='Feature Set', bbox_to_anchor=(1.05, 1), loc='upper left')
    axes[1,1].tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    return rmse_df, r2_df

def print_detailed_results_corrected(results_file_path):
    """Print detailed CORRECTED results table"""
    
    # Load results
    with open(results_file_path, "rb") as f:
        results = pickle.load(f)
    
    print("="*100)
    print("📊 DETAILED CORRECTED LOSO RESULTS - KNEE JOINT MOMENT PREDICTION")
    print("="*100)
    print("🎯 CORRECTED STRUCTURE: Each subject is test exactly once with rotating validation")
    print("📏 Using paper-compliant window: 248 samples (1.24s) with 50% overlap")
    print("⚖️ Using mass normalization for all force and moment signals")
    print("="*100)
    
    for feat_name in results:
        print(f"\n🎯 Feature Set: {feat_name}")
        print("-" * 70)
        
        rmse_values = []
        r2_values = []
        mae_values = []
        
        for key in results[feat_name]:
            fold_data = results[feat_name][key]
            test_subject = fold_data['test_subject']
            val_subject = fold_data['val_subject']
            rmse = fold_data['rmse']
            r2 = fold_data['r2']
            mae = fold_data['mae']
            test_mass = fold_data.get('test_subject_mass', 'Unknown')
            val_mass = fold_data.get('val_subject_mass', 'Unknown')
            
            rmse_values.append(rmse)
            r2_values.append(r2)
            mae_values.append(mae)
            
            print(f"  👤 Test: {test_subject} ({test_mass} kg) | Val: {val_subject} ({val_mass} kg)")
            print(f"     RMSE={rmse:.3f}, R²={r2:.3f}, MAE={mae:.3f}")
        
        if rmse_values:
            print(f"\n  📈 Summary ({len(rmse_values)} folds):")
            print(f"     RMSE: {np.mean(rmse_values):.3f} ± {np.std(rmse_values):.3f} Nm/kg")
            print(f"     R²:   {np.mean(r2_values):.3f} ± {np.std(r2_values):.3f}")
            print(f"     MAE:  {np.mean(mae_values):.3f} ± {np.std(mae_values):.3f} Nm/kg")

# ---------------------- Main Execution ----------------------
if __name__ == "__main__":
    # Process all subjects
    joint_input = "knee"
    print(f"🦵 Processing data for {joint_input} joint moments")
    
    # Process all subjects data
    all_subjects_data = process_all_subjects(joint_input)
    
    if not all_subjects_data:
        print("❌ No data processed. Exiting.")
        exit()
    
    # DEBUG: Check what EMG columns we have
    debug_emg_columns(all_subjects_data)
    
    # FIX: Remove duplicate columns first
    all_subjects_data = fix_all_emg_duplicates(all_subjects_data)
    
    # Check again after fixing duplicates
    print("\n🔍 AFTER FIXING DUPLICATES:")
    debug_emg_columns(all_subjects_data)
    
    # Create dynamic feature sets based on available EMG
    features_sets_dynamic = create_dynamic_feature_sets(all_subjects_data)
    
    # Define which feature sets to test
    #selected_tasks = ["kinematic", "kinematic_emg", "kinematic_insole", "all"]
    selected_tasks = ["kinematic", "kinematic_emg"]
    
    # Run CORRECTED LOSO cross validation
    KAGGLE_SAVE_DIR = "/kaggle/working/loso_joint_moment_output_corrected"
    
    results = train_and_evaluate_loso_corrected(
        all_subjects_data, 
        features_sets_dynamic,
        selected_tasks, 
        joint_type=joint_input,
        save_dir=KAGGLE_SAVE_DIR
    )

    # Plot results
    try:
        rmse_df, r2_df = plot_loso_results_corrected("/kaggle/working/loso_joint_moment_output_corrected/loso_joint_moment_results_corrected.pkl")
        print_detailed_results_corrected("/kaggle/working/loso_joint_moment_output_corrected/loso_joint_moment_results_corrected.pkl")
    except Exception as e:
        print(f"⚠️ Could not plot results: {str(e)}")

    # Print final summary
    print("\n" + "="*100)
    print("🎯 CORRECTED LOSO CROSS VALIDATION COMPLETED - JOINT MOMENT PREDICTION")
    print("="*100)
    print("✅ PROPER LOSO IMPLEMENTATION:")
    print("   - Each subject is test exactly once")
    print("   - Validation subjects rotate properly (no data leakage)")
    print("   - Proper train/validation/test splits")
    print("   - No data leakage between folds")
    print("   - Valid results comparable to published research")
    print("="*100)

2025-10-16 21:38:29.673974: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760650710.100006      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760650710.233418      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


🦵 Processing data for knee joint moments
📊 Subject Mass Distribution:
   AB11: 65.1 kg
   AB12: 64.0 kg
   AB13: 67.6 kg
   AB01: 78.9 kg
   AB02: 82.2 kg
   AB03: 113.5 kg
   AB05: 71.5 kg
   AB06: 79.1 kg
   AB07: 62.3 kg
   AB08: 87.6 kg
   AB09: 84.1 kg
   AB10: 67.5 kg
   Mean: 77.0 ± 13.7 kg

🚀 Processing subject: AB11, joint: knee

   📊 Mass normalization applied for AB11 (65.1 kg)
   🔧 Fixing EMG column names for AB11...
   🔧 Applied EMG mappings: ['LGMED -> GMED', 'RGMED -> GMED', 'LGMAX -> GMAX', 'RGMAX -> GMAX', 'LGRAC -> GRAC', 'RGRAC -> GRAC', 'LTA -> TA', 'RTA -> TA']
   ⚠️ Added missing EMG columns as zeros: ['VL', 'RF', 'BF', 'MGAS']
   🔧 Removing duplicate columns: ['TA', 'GRAC', 'GMAX', 'GMED']
✅ Processed AB11: 789054 rows, 35 columns
✅ DataFrame for AB11 saved as 'AB11_knee_df.pkl'

🚀 Processing subject: AB12, joint: knee

   📊 Mass normalization applied for AB12 (64.0 kg)
   🔧 Fixing EMG column names for AB12...
   🔧 Applied EMG mappings: ['LGMED -> GMED', 'RGMED -

I0000 00:00:1760651166.376790      37 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1760651166.378601      37 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


🚀 Training model with PAPER-COMPLIANT hyperparameters...
Epoch 1/15


I0000 00:00:1760651178.041979     101 service.cc:148] XLA service 0x79a1d0014b00 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760651178.044894     101 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1760651178.044922     101 service.cc:156]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1760651178.177143     101 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1760651178.374730     101 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


858/858 - 49s - 57ms/step - loss: 0.4742 - val_loss: 0.1589
Epoch 2/15
858/858 - 8s - 9ms/step - loss: 0.1640 - val_loss: 0.1155
Epoch 3/15
858/858 - 8s - 9ms/step - loss: 0.1201 - val_loss: 0.1049
Epoch 4/15
858/858 - 8s - 9ms/step - loss: 0.1002 - val_loss: 0.0967
Epoch 5/15
858/858 - 8s - 9ms/step - loss: 0.0887 - val_loss: 0.0878
Epoch 6/15
858/858 - 8s - 9ms/step - loss: 0.0805 - val_loss: 0.0880
Epoch 7/15
858/858 - 8s - 9ms/step - loss: 0.0745 - val_loss: 0.0856
Epoch 8/15
858/858 - 8s - 9ms/step - loss: 0.0695 - val_loss: 0.0799
Epoch 9/15
858/858 - 7s - 8ms/step - loss: 0.0652 - val_loss: 0.0782
Epoch 10/15
858/858 - 8s - 9ms/step - loss: 0.0617 - val_loss: 0.0772
Epoch 11/15
858/858 - 8s - 9ms/step - loss: 0.0587 - val_loss: 0.0753
Epoch 12/15
858/858 - 8s - 9ms/step - loss: 0.0559 - val_loss: 0.0779
Epoch 13/15
858/858 - 8s - 9ms/step - loss: 0.0538 - val_loss: 0.0726
Epoch 14/15
858/858 - 8s - 9ms/step - loss: 0.0518 - val_loss: 0.0696
Epoch 15/15
858/858 - 7s - 9ms/step - 