Kaggle/Drive

In [None]:
# Colab setup: install kagglehub and mount Google Drive at /content/drive.
# NOTE(review): `kagglehub` is imported but not used by any visible cell; the next
# cell downloads through the `kaggle` CLI instead, so this cell may be redundant.
!pip install kagglehub --quiet
import kagglehub
from google.colab import drive
drive.mount('/content/drive')

KeyboardInterrupt: 

Download the dataset from Kaggle and store a copy on Google Drive

In [None]:
# Download the FatigueSet dataset with the Kaggle CLI and keep a copy on Google Drive,
# so later cells can read it from /content/drive/MyDrive/Fatigue_Set/fatigueset
# without re-downloading.

# Mount Google Drive (if not already done)
from google.colab import drive
drive.mount('/content/drive')

# Install Kaggle CLI if missing
!pip install -q kaggle

# Setup Kaggle API credentials (make sure kaggle.json is in your Drive)
# The cp fails if /content/drive/MyDrive/kaggle.json does not exist.
!mkdir -p ~/.kaggle
!cp /content/drive/MyDrive/kaggle.json ~/.kaggle/
# Restrict the credentials file to owner read/write.
!chmod 600 ~/.kaggle/kaggle.json

# Download and unzip dataset locally
!kaggle datasets download -d tanjemahamed/mental-fatigue-level-detection-fatigueset-data --unzip -p /content/fatigue_data

# List local downloaded files to verify
!ls /content/fatigue_data

import shutil
import os

# NOTE(review): assumes the archive unzips into a `fatigueset/` subfolder;
# the `!ls` above is where to confirm this.
source_dir = '/content/fatigue_data/fatigueset'  # The actual dataset folder
drive_dest = '/content/drive/MyDrive/Fatigue_Set'

# Create destination folder if it doesn't exist
os.makedirs(drive_dest, exist_ok=True)

# Define full destination path for the folder copy
dest_dir = os.path.join(drive_dest, 'fatigueset')

# Remove destination folder if it exists to avoid copytree error
# (destructive: any existing copy on Drive is deleted first).
if os.path.exists(dest_dir):
    shutil.rmtree(dest_dir)

# Recursively copy entire directory
shutil.copytree(source_dir, dest_dir)

print(f'Dataset folder copied recursively to: {dest_dir}')

Column structure of the selected sensor files

In [None]:
import os
import pandas as pd

BASE_PATH = '/content/drive/MyDrive/Fatigue_Set/fatigueset'
persons = [f'{i:02d}' for i in range(1, 13)]   # '01' .. '12'
sessions = [f'{i:02d}' for i in range(1, 4)]   # '01' .. '03'

# Files whose header row is inspected below: five wrist sensors plus the fatigue survey.
selected_sensor_files = [
    'wrist_hr.csv',
    'wrist_ibi.csv',
    'wrist_acc.csv',
    'wrist_eda.csv',
    'wrist_skin_temperature.csv',
    'exp_fatigue.csv'
]

# Print each selected file's column names so later cells can be checked against the
# real schema. Only the first 3 rows are parsed per file; unreadable files are
# reported and skipped rather than stopping the scan.
for person in persons:
    for session in sessions:
        session_folder = os.path.join(BASE_PATH, person, session)
        print(f'\nPerson: {person}, Session: {session}')
        for sensor_file in selected_sensor_files:
            file_path = os.path.join(session_folder, sensor_file)
            if not os.path.exists(file_path):
                print(f'{sensor_file} not found in session {session} of person {person}')
                continue
            try:
                df = pd.read_csv(file_path, nrows=3)
            except Exception as e:
                print(f"Error reading {file_path}: {e}")
            else:
                print(f'{sensor_file} columns: {list(df.columns)}')


Sample of the selected columns after resampling and merging


In [None]:
import os
import pandas as pd
import numpy as np

# Location of the dataset copy made by the download cell: BASE_PATH/<person>/<session>/<file>.csv
BASE_PATH = '/content/drive/MyDrive/Fatigue_Set/fatigueset'
PERSONS = [f'{i:02d}' for i in range(1, 13)]   # '01' .. '12'
SESSIONS = [f'{i:02d}' for i in range(1, 4)]   # '01' .. '03'

# Sensor file -> columns kept from it for resampling/merging.
SENSOR_FILES = {
    'wrist_hr.csv': ['hr'],
    'wrist_ibi.csv': ['duration'],
    'wrist_acc.csv': ['ax', 'ay', 'az'],
    'wrist_eda.csv': ['eda'],
    'wrist_skin_temperature.csv': ['temp']
}

# Windowing parameters (seconds). Not used in this preview cell; the feature cell
# below defines and uses its own copies.
WINDOW_SIZE_SEC = 30
STEP_SIZE_SEC = 15

def get_min_sampling_interval(filepath):
    """Return the smallest positive gap between consecutive timestamps in a sensor CSV.

    The ``timestamp`` column is parsed as epoch milliseconds and sorted before
    differencing, so unsorted rows cannot produce negative gaps. Duplicate
    timestamps (zero gaps) are ignored.

    Args:
        filepath: path to a CSV file with a ``timestamp`` column.

    Returns:
        A ``pd.Timedelta``. Returns None when the file is empty, has no
        ``timestamp`` column, or has fewer than two distinct timestamps
        (so no positive gap exists).
    """
    df = pd.read_csv(filepath)
    if 'timestamp' not in df.columns or df.empty:
        return None
    ts = pd.to_datetime(df['timestamp'], unit='ms').sort_values()
    intervals = ts.diff().dropna()
    # Zero gaps from duplicate rows would turn into an unusable "0ms" resampling
    # frequency downstream; an empty result (single row) used to come back as NaT.
    positive = intervals[intervals > pd.Timedelta(0)]
    if positive.empty:
        return None
    return positive.min()

def load_and_resample(filepath, cols, resample_freq):
    """Load a sensor CSV, resample ``cols`` onto a regular grid and print coverage.

    Rows without a timestamp are dropped. The epoch-millisecond ``timestamp``
    column becomes the index. Each bin holds the mean of the raw samples that fall
    in it, and empty bins are filled by linear interpolation. For each column, the
    raw count, the resampled count and their ratio (percent) are printed.

    Args:
        filepath: path to the sensor CSV.
        cols: value columns to keep.
        resample_freq: pandas frequency string, e.g. ``'1000ms'``.

    Returns:
        DataFrame indexed by bin start, one column per entry of ``cols``.
    """
    raw = pd.read_csv(filepath)
    raw = raw.dropna(subset=['timestamp'])
    raw['timestamp'] = pd.to_datetime(raw['timestamp'], unit='ms')
    raw = raw.set_index('timestamp')

    selected = raw[cols]
    before = selected.count()
    resampled = selected.resample(resample_freq).mean().interpolate()
    after = resampled.count()

    print(f"File: {os.path.basename(filepath)}")
    for col in cols:
        n_raw = before[col]
        n_res = after[col]
        percent = (n_res / n_raw * 100) if n_raw > 0 else 0
        print(f"  Column: {col}, Original entries: {n_raw}, Resampled entries: {n_res}, Percentage: {percent:.2f}%")
    return resampled

def load_and_merge_session(person, session):
    """Resample every available sensor of one session onto a shared grid and merge them.

    The grid spacing is the largest of the per-sensor minimum sampling intervals
    (the slowest sensor's finest spacing), floored at 1 ms. Each sensor is
    resampled with ``load_and_resample``, which prints a coverage summary. The
    frames are joined column-wise, remaining gaps are linearly interpolated, and
    rows that are still incomplete are dropped. The first 20 merged rows are
    printed.

    Args:
        person: zero-padded person folder name, e.g. ``'01'``.
        session: zero-padded session folder name, e.g. ``'01'``.

    Returns:
        The merged DataFrame. Returns None if no sensor file gives a usable
        (positive) sampling interval.
    """
    session_path = os.path.join(BASE_PATH, person, session)

    sensor_min_intervals = []
    for file_name in SENSOR_FILES:
        file_path = os.path.join(session_path, file_name)
        if os.path.exists(file_path):
            min_intv = get_min_sampling_interval(file_path)
            # Skip None, NaT (e.g. single-sample files) and non-positive gaps:
            # NaT makes max() order-dependent and <= 0 is not a usable frequency.
            if min_intv is not None and min_intv > pd.Timedelta(0):
                sensor_min_intervals.append(min_intv)
    if not sensor_min_intervals:
        print(f"No sensor data found for person {person} session {session}")
        return None

    best_interval = max(sensor_min_intervals)
    # int() truncates sub-millisecond gaps to 0, which is not a usable frequency.
    resample_milliseconds = max(1, int(best_interval.total_seconds() * 1000))
    resample_freq_str = f"{resample_milliseconds}ms"
    print(f"\nPerson {person}, Session {session}, Resampling freq chosen: every {resample_freq_str}")

    data_frames = []
    for file_name, cols in SENSOR_FILES.items():
        file_path = os.path.join(session_path, file_name)
        if os.path.exists(file_path):
            data_frames.append(load_and_resample(file_path, cols, resample_freq_str))
    if not data_frames:
        return None

    merged_df = pd.concat(data_frames, axis=1).interpolate().dropna()
    print(f"\nSynchronized merged data sample for Person {person} Session {session}:")
    print(merged_df.head(20))

    return merged_df

# Run for all persons and sessions
# Resample + merge every person/session and print a preview of each.
# Only the last session's frame is left in `merged_data`.
for person, session in [(p, s) for p in PERSONS for s in SESSIONS]:
    print(f"Processing Person {person}, Session {session}")
    merged_data = load_and_merge_session(person, session)


Windowed features, interpolated labels, and the combined dataset

In [None]:
import os
import pandas as pd
import numpy as np

# Dataset root produced by the download cell: BASE_PATH/<person>/<session>/<file>.csv
BASE_PATH = '/content/drive/MyDrive/Fatigue_Set/fatigueset'
PERSONS   = [f'{i:02d}' for i in range(1, 13)]   # '01' .. '12'
SESSIONS  = [f'{i:02d}' for i in range(1, 4)]    # '01' .. '03'

# Sensor file -> columns used as raw signals for windowed features.
SENSOR_FILES = {
    'wrist_hr.csv': ['hr'],
    'wrist_ibi.csv': ['duration'],
    'wrist_acc.csv': ['ax', 'ay', 'az'],
    'wrist_eda.csv': ['eda'],
    'wrist_skin_temperature.csv': ['temp']
}

# Sliding-window length and hop, in seconds (50% overlap).
WINDOW_SIZE_SEC = 30
STEP_SIZE_SEC   = 15

# Session folder -> class index.
# NOTE(review): not referenced anywhere in this cell; the model cells derive the
# session class from the integer `session` column instead.
session_type_map = {
    '01': 0,  # baseline
    '02': 1,  # physical
    '03': 2   # mental
}

# ========== Utilities for Data Processing ==========

def get_min_sampling_interval(filepath):
    """Return the smallest positive gap between consecutive timestamps in a sensor CSV.

    Timestamps are epoch milliseconds. They are sorted before differencing, and
    zero gaps (duplicate rows) are ignored.

    Args:
        filepath: path to a CSV file with a ``timestamp`` column.

    Returns:
        A ``pd.Timedelta``. Returns None if the file is empty, lacks
        ``timestamp``, or has fewer than two distinct timestamps.
    """
    df = pd.read_csv(filepath)
    if 'timestamp' not in df.columns or df.empty:
        return None
    ts = pd.to_datetime(df['timestamp'], unit='ms').sort_values()
    intervals = ts.diff().dropna()
    # Without this filter duplicates give 0 (-> "0ms" frequency) and a single
    # row gives NaT instead of None.
    positive = intervals[intervals > pd.Timedelta(0)]
    if positive.empty:
        return None
    return positive.min()

def load_and_resample(filepath, cols, resample_freq):
    """Read a sensor CSV and return ``cols`` resampled onto a regular time grid.

    Rows lacking a timestamp are discarded. ``timestamp`` (epoch milliseconds)
    becomes the index. Bins take the mean of their samples, and empty bins are
    linearly interpolated.

    Args:
        filepath: path to the sensor CSV.
        cols: value columns to keep.
        resample_freq: pandas frequency string such as ``'1000ms'``.

    Returns:
        DataFrame indexed by bin start with one column per entry of ``cols``.
    """
    frame = pd.read_csv(filepath)
    frame = frame.dropna(subset=['timestamp'])
    frame['timestamp'] = pd.to_datetime(frame['timestamp'], unit='ms')
    indexed = frame.set_index('timestamp')
    binned = indexed[cols].resample(resample_freq).mean()
    return binned.interpolate()

def load_and_merge_session(person, session):
    """Resample all available sensors of one session onto a shared grid and merge them.

    The grid spacing is the largest per-sensor minimum sampling interval,
    floored at 1 ms. Sensor frames are joined column-wise, gaps are linearly
    interpolated, and rows that are still incomplete are dropped.

    Args:
        person: zero-padded person folder name, e.g. ``'01'``.
        session: zero-padded session folder name, e.g. ``'01'``.

    Returns:
        ``(merged_df, None, None)``. Returns ``(None, None, None)`` when no
        sensor gives a usable interval or no sensor file exists. The two
        trailing Nones are unused placeholders kept for the caller's unpacking.
    """
    session_path = os.path.join(BASE_PATH, person, session)

    sensor_min_intervals = []
    for file_name in SENSOR_FILES:
        file_path = os.path.join(session_path, file_name)
        if os.path.exists(file_path):
            min_intv = get_min_sampling_interval(file_path)
            # Skip None, NaT and non-positive gaps (unusable as a frequency).
            if min_intv is not None and min_intv > pd.Timedelta(0):
                sensor_min_intervals.append(min_intv)
    if not sensor_min_intervals:
        return None, None, None

    best_interval = max(sensor_min_intervals)
    # Sub-millisecond gaps would otherwise truncate to an unusable "0ms".
    resample_milliseconds = max(1, int(best_interval.total_seconds() * 1000))
    resample_freq_str = f"{resample_milliseconds}ms"

    data_frames = []
    for file_name, cols in SENSOR_FILES.items():
        file_path = os.path.join(session_path, file_name)
        if os.path.exists(file_path):
            data_frames.append(load_and_resample(file_path, cols, resample_freq_str))
    if not data_frames:
        return None, None, None

    merged_df = pd.concat(data_frames, axis=1).interpolate().dropna()
    return merged_df, None, None

def windowed_segmentation(data, window_size_sec=30, step_size_sec=15, fs_hz=None):
    """Cut a time-indexed frame into fixed-length, possibly overlapping windows.

    Args:
        data: DataFrame with a DatetimeIndex, assumed regularly sampled.
        window_size_sec: window length in seconds.
        step_size_sec: hop between window starts in seconds.
        fs_hz: sampling rate. If None, it is inferred from the gap between the
            first two index entries.

    Returns:
        ``(segments, indices)``: the list of window DataFrames (only full windows)
        and the index label of each window's first row. Both lists are empty when
        ``data`` is too short.

    Raises:
        ValueError: if the sampling rate is not a positive number (e.g. NaN).
    """
    if fs_hz is None:
        if len(data) < 2:
            return [], []
        timedelta = (data.index[1] - data.index[0]).total_seconds()
        fs_hz = 1 / timedelta
    if not fs_hz > 0:
        raise ValueError(f"Sampling rate must be a positive number, got {fs_hz!r}")

    # Clamp to at least one sample: a 0-sample step makes range() raise and a
    # 0-sample window produces empty slices whose index[0] fails.
    window_size_samples = max(1, int(window_size_sec * fs_hz))
    step_size_samples   = max(1, int(step_size_sec * fs_hz))

    segments, indices = [], []
    for start in range(0, len(data) - window_size_samples + 1, step_size_samples):
        segment = data.iloc[start:start + window_size_samples]
        segments.append(segment)
        indices.append(segment.index[0])
    return segments, indices

def extract_features(segment):
    """Summarise a window with the mean and sample standard deviation of each column.

    Args:
        segment: DataFrame holding one window of signals.

    Returns:
        Dict mapping ``'<col>_mean'`` and ``'<col>_std'`` to floats, ordered
        column by column (mean before std).
    """
    return {
        f'{col}_{stat}': getattr(segment[col], stat)()
        for col in segment.columns
        for stat in ('mean', 'std')
    }

def load_fatigue_labels(person, session, session_start_timestamp):
    """Read a session's fatigue survey file and attach absolute submission times.

    Adds a ``fatigueSurveySubmissionDatetime`` column computed as
    ``session_start_timestamp + fatigueSurveySubmissionTime`` seconds.
    NOTE(review): this treats ``fatigueSurveySubmissionTime`` as seconds elapsed
    since the first sensor sample. Confirm against the dataset documentation;
    if it is an epoch timestamp the result is wrong.

    Args:
        person: zero-padded person folder name.
        session: zero-padded session folder name.
        session_start_timestamp: reference time (anything ``pd.Timestamp`` accepts).

    Returns:
        The survey DataFrame with the extra column, or None if the file is absent.
    """
    fatigue_path = os.path.join(BASE_PATH, person, session, 'exp_fatigue.csv')
    if os.path.exists(fatigue_path):
        surveys = pd.read_csv(fatigue_path)
        surveys['fatigueSurveySubmissionDatetime'] = surveys['fatigueSurveySubmissionTime'].apply(
            lambda offset_s: pd.Timestamp(session_start_timestamp) + pd.Timedelta(seconds=offset_s)
        )
        return surveys
    return None

def align_labels_to_windows_time_based(label_df, window_starts):
    """Linearly interpolate fatigue scores at each window start time.

    Instead of nearest-label assignment, each window gets the scores
    interpolated between the surrounding survey submissions. Windows before the
    first or after the last survey get that survey's score (``np.interp``
    clamps at the ends).

    Args:
        label_df: DataFrame with ``fatigueSurveySubmissionDatetime``,
            ``physicalFatigueScore`` and ``mentalFatigueScore`` columns (rows in
            any order; rows with a missing time are ignored).
        window_starts: sequence of ``pd.Timestamp`` window start times.

    Returns:
        One dict per window: ``{'physicalFatigueScore': float, 'mentalFatigueScore': float}``.

    Raises:
        ValueError: if ``label_df`` has no row with a valid submission time.
    """
    surveys = pd.DataFrame({
        't': pd.to_datetime(label_df['fatigueSurveySubmissionDatetime']).to_numpy(),
        'phys': label_df['physicalFatigueScore'].to_numpy(),
        'ment': label_df['mentalFatigueScore'].to_numpy(),
    }).dropna(subset=['t'])
    if surveys.empty:
        raise ValueError("No timestamped fatigue surveys to interpolate from.")
    # np.interp silently returns wrong values unless sample points increase.
    surveys = surveys.sort_values('t', kind='stable')

    xp = [t.value for t in surveys['t']]          # ns since epoch
    x = [ts.value for ts in window_starts]
    phys_interp = np.interp(x, xp, surveys['phys'].to_numpy())
    ment_interp = np.interp(x, xp, surveys['ment'].to_numpy())

    return [
        {'physicalFatigueScore': p, 'mentalFatigueScore': m}
        for p, m in zip(phys_interp, ment_interp)
    ]


def process_person_session(person, session):
    """Build the windowed feature/label table for one person and session.

    Steps:
    1. Merge the resampled sensors.
    2. Load the fatigue surveys relative to the first merged timestamp.
    3. Cut WINDOW_SIZE_SEC / STEP_SIZE_SEC windows and compute mean/std per signal.
    4. Interpolate the fatigue scores at each window start.

    Args:
        person: zero-padded person folder name.
        session: zero-padded session folder name.

    Returns:
        DataFrame with one row per window: feature columns, the two fatigue
        scores, ``window_start``, ``person`` and ``session``. Returns None when
        sensor data is missing or too short (< 2 rows), or the survey file is
        absent.
    """
    merged_df, _, _ = load_and_merge_session(person, session)
    # At least two rows are needed to estimate a sampling rate; with fewer, the
    # mean delta is NaT and the rate becomes NaN.
    if merged_df is None or len(merged_df) < 2:
        return None

    # Load labels before the comparatively expensive windowing, so sessions
    # without a survey file are skipped cheaply.
    session_start_timestamp = merged_df.index.min()
    fatigue_labels_df = load_fatigue_labels(person, session, session_start_timestamp)
    if fatigue_labels_df is None:
        return None

    ts_deltas = merged_df.index.to_series().diff().dropna()
    fs_hz = 1 / ts_deltas.mean().total_seconds()
    windows, window_starts = windowed_segmentation(merged_df, WINDOW_SIZE_SEC, STEP_SIZE_SEC, fs_hz)

    features_df = pd.DataFrame([extract_features(window) for window in windows])
    labels_df = pd.DataFrame(align_labels_to_windows_time_based(fatigue_labels_df, window_starts))

    merged_features_labels_df = pd.concat([features_df.reset_index(drop=True),
                                           labels_df.reset_index(drop=True)], axis=1)
    merged_features_labels_df['window_start'] = pd.to_datetime(window_starts).values
    merged_features_labels_df['person'] = person
    merged_features_labels_df['session'] = session
    return merged_features_labels_df

# ========== Build Whole Dataset ==========

# Process every person/session and stack the per-window tables into one dataset.
all_data = []
for person in PERSONS:
    for session in SESSIONS:
        df = process_person_session(person, session)
        if df is not None:
            all_data.append(df)

# pd.concat raises ValueError if no session produced data (all_data empty).
final_df = pd.concat(all_data, ignore_index=True)
print("Final dataframe shape:", final_df.shape)

# Persist to Drive; the normalization cell reads this file back.
save_path = '/content/drive/MyDrive/Fatigue_Set/final_feature_label_dataset_interpolated.csv'
final_df.to_csv(save_path, index=False)
print(f"Saved merged dataset to {save_path}")


In [None]:
# Quick check: reload the saved feature/label dataset and show its columns.
ffld=pd.read_csv('/content/drive/MyDrive/Fatigue_Set/final_feature_label_dataset_interpolated.csv')
ffld.columns

Train/test split and data normalization

In [None]:
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import joblib
import random

# Load dataset
df = pd.read_csv('/content/drive/MyDrive/Fatigue_Set/final_feature_label_dataset_interpolated.csv')

# Sensor feature columns
feature_cols = [
    'hr_mean', 'hr_std', 'duration_mean', 'duration_std',
    'ax_mean', 'ax_std', 'ay_mean', 'ay_std', 'az_mean', 'az_std',
    'eda_mean', 'eda_std', 'temp_mean', 'temp_std'
]

# Target label columns
label_cols = ['physicalFatigueScore', 'mentalFatigueScore']

# ---- Hold-out split BEFORE scaling ----
all_pairs = df.groupby(['person','session']).size().index.tolist()
random.seed(42); random.shuffle(all_pairs)
train_pairs, test_pairs = all_pairs[:31], all_pairs[31:]

train_df = pd.concat([df[(df.person==p) & (df.session==s)] for (p,s) in train_pairs])
test_df  = pd.concat([df[(df.person==p) & (df.session==s)] for (p,s) in test_pairs])

# ---- Fit scaler only on training data ----
feat_scaler = MinMaxScaler()
train_df[feature_cols] = feat_scaler.fit_transform(train_df[feature_cols])
test_df[feature_cols]  = feat_scaler.transform(test_df[feature_cols])
joblib.dump(feat_scaler, 'feature_scaler.save')

# ---- Scale labels (optional: usually for regression stability) ----
label_scaler = MinMaxScaler()
train_df[label_cols] = label_scaler.fit_transform(train_df[label_cols])
test_df[label_cols]  = label_scaler.transform(test_df[label_cols])
joblib.dump(label_scaler, 'label_scaler.save')

# ---- Save normalized train/test sets ----
train_df.to_csv('/content/drive/MyDrive/Fatigue_Set/final_train_normalized.csv', index=False)
test_df.to_csv('/content/drive/MyDrive/Fatigue_Set/final_test_normalized.csv', index=False)

print("Features + Labels normalized (train/test separately) and saved.")

In [None]:
# Preview the normalized training split written by the previous cell.
df=pd.read_csv('/content/drive/MyDrive/Fatigue_Set/final_train_normalized.csv')
df.head()

In [None]:
import pandas as pd

import os

# No cell in this notebook writes this combined file (the normalization cell writes
# final_train_normalized.csv / final_test_normalized.csv). On a fresh run, rebuild
# the full normalized table from those two splits instead of failing.
NORMALIZED_PATH = '/content/drive/MyDrive/Fatigue_Set/final_feature_label_dataset_normalized.csv'
if os.path.exists(NORMALIZED_PATH):
    df = pd.read_csv(NORMALIZED_PATH)
else:
    df = pd.concat(
        [pd.read_csv('/content/drive/MyDrive/Fatigue_Set/final_train_normalized.csv'),
         pd.read_csv('/content/drive/MyDrive/Fatigue_Set/final_test_normalized.csv')],
        ignore_index=True,
    )

# Count entries per (person, session)
counts = df.groupby(['person', 'session']).size().reset_index(name='count')

print(counts)

# Count entries per person across all sessions
total_counts = df.groupby('person').size().reset_index(name='total_count')

print(total_counts)

# Single client: train on 3 persons × 3 sessions, test on 1 person × 1 session (the intended final model is the time-sequenced one further below).

Baseline model (without time sequencing)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from copy import deepcopy

# ========== Model Components ==========
class ModalityLSTM(nn.Module):
    """Encode one modality's features with a single-layer, batch-first LSTM.

    A 2-D input ``[batch, input_dim]`` is treated as a sequence of length 1.
    A 3-D input is used as ``[batch, seq_len, input_dim]``. Returns the LSTM
    output at the last time step, shape ``[batch, hidden_dim]``.
    """
    def __init__(self, input_dim, hidden_dim=32):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        seq = x.unsqueeze(1) if x.ndim == 2 else x
        outputs = self.lstm(seq)[0]
        return outputs[:, -1]

class CrossModalAttention(nn.Module):
    """Self-attention across modalities, computed independently for each sample.

    The previous version projected the concatenated features and computed
    ``Q @ K.T`` over the *batch*. Each sample therefore attended to the other
    samples in its batch, so predictions depended on batch composition. Here
    each modality embedding attends to the other modalities of the same sample,
    and the attended outputs are averaged.
    NOTE: the q/k/v layers now take ``hidden_dim`` inputs, so state dicts saved
    from the old layout will not load.

    Args:
        n_modalities: number of modality embeddings (kept for reference).
        hidden_dim: size of each modality embedding.
        fusion_dim: size of the q/k/v projections and of the output.

    Forward args:
        features: list of ``[batch, hidden_dim]`` tensors, one per modality.
        domain_disc: optional bias added to the attention scores. One value per
            sample (``[batch, 1]``) or a single scalar. A constant shift per
            sample does not change the softmax.

    Returns:
        ``[batch, fusion_dim]`` fused representation.
    """
    def __init__(self, n_modalities, hidden_dim=32, fusion_dim=64):
        super().__init__()
        self.n_modalities = n_modalities
        self.query = nn.Linear(hidden_dim, fusion_dim)
        self.key   = nn.Linear(hidden_dim, fusion_dim)
        self.value = nn.Linear(hidden_dim, fusion_dim)

    def forward(self, features, domain_disc):
        x = torch.stack(features, dim=1)                 # [batch, M, hidden_dim]
        Q, K, V = self.query(x), self.key(x), self.value(x)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.size(-1) ** 0.5)  # [batch, M, M]
        if domain_disc is not None:
            batch = x.size(0)
            # One value per sample -> [batch, 1, 1] so it stays within that sample.
            bias = domain_disc.reshape(batch, 1, 1) if domain_disc.numel() == batch else domain_disc
            attn_scores = attn_scores + bias
        attn_weights = F.softmax(attn_scores, dim=-1)
        out = torch.matmul(attn_weights, V)              # [batch, M, fusion_dim]
        return out.mean(dim=1)

class GradReverse(torch.autograd.Function):
    """Gradient reversal: identity in the forward pass; the backward pass returns
    ``-lambd * grad`` for ``x`` and no gradient for ``lambd``."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd          # stash the scale for backward
        return x.view_as(x)        # new tensor object, same data

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = -grad_output * ctx.lambd
        return reversed_grad, None

class GradientReversal(nn.Module):
    """Module wrapper around ``GradReverse``.

    Identity in the forward pass; gradients are multiplied by ``-lambd``
    in the backward pass.
    """
    def __init__(self, lambd=1.0):
        super().__init__()
        self.lambd = lambd  # reversal strength

    def forward(self, x):
        reversed_x = GradReverse.apply(x, self.lambd)
        return reversed_x

class DomainAdaptiveLayer(nn.Module):
    """Gradient reversal followed by a square linear projection.

    Forward: ``fc(x)``. Backward: the gradient reaching ``x`` is negated
    (``GradientReversal`` with its default strength).
    NOTE(review): gradient reversal is normally placed in front of a
    domain-discriminator branch; check that callers do not route task features
    through it.
    """
    def __init__(self, input_dim):
        super().__init__()
        self.grl = GradientReversal()
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        reversed_x = self.grl(x)
        return self.fc(reversed_x)

class FMAL_Daf(nn.Module):
    """Multimodal fatigue model.

    Pipeline: per-modality LSTMs -> cross-modal attention -> linear adaptation
    -> global LSTM -> three heads. The heads are physical-fatigue regression,
    mental-fatigue regression and a 3-way session-type classifier.

    Args:
        modalities_dim: input feature size of each modality, in the order the
            modal inputs are passed to ``forward``.
        lstm_hidden: hidden size of each per-modality LSTM.
        fusion_dim: output size of the attention fusion.
    """
    def __init__(self, modalities_dim, lstm_hidden=32, fusion_dim=64):
        super().__init__()
        self.modality_lstms = nn.ModuleList([ModalityLSTM(inp_dim, lstm_hidden) for inp_dim in modalities_dim])
        self.attn_fusion = CrossModalAttention(len(modalities_dim), lstm_hidden, fusion_dim)
        self.domain_adapt = DomainAdaptiveLayer(fusion_dim)
        self.global_lstm = nn.LSTM(fusion_dim, 32, batch_first=True)
        self.reg_head_phys = nn.Linear(32, 1)
        self.reg_head_ment = nn.Linear(32, 1)
        self.class_head = nn.Linear(32, 3)

    def forward(self, modal_inputs, domain_disc):
        """Return ``(phys_pred [B,1], ment_pred [B,1], class_logits [B,3])``."""
        feats = [mod(mod_inp) for mod, mod_inp in zip(self.modality_lstms, modal_inputs)]
        fused = self.attn_fusion(feats, domain_disc)
        # Bypass the gradient-reversal layer. There is no domain-classifier branch,
        # so reversing gradients on the task path made the encoders and attention
        # *ascend* the regression/classification loss. Only the linear projection
        # is applied; parameter names (domain_adapt.fc.*) are unchanged.
        adapted = self.domain_adapt.fc(fused)
        glstm_out, _ = self.global_lstm(adapted.unsqueeze(1))
        feat = glstm_out[:, -1, :]
        return self.reg_head_phys(feat), self.reg_head_ment(feat), self.class_head(feat)

# ========== Dataset Loader ==========
class FatigueSessionDataset(Dataset):
    """One sample per feature window, split into per-modality tensors.

    Each item is ``([HR, IBI, ACC, EDA, Temp], domain_disc, (phys, ment, session_type))``.
    ``session_type`` is always 0 and ``domain_disc`` is always ``[0.0]``; both
    are placeholders in this baseline.
    """
    # Feature columns per modality, in the order the model expects them.
    _MODALITY_COLUMNS = (
        ('hr_mean', 'hr_std'),
        ('duration_mean', 'duration_std'),
        ('ax_mean', 'ax_std', 'ay_mean', 'ay_std', 'az_mean', 'az_std'),
        ('eda_mean', 'eda_std'),
        ('temp_mean', 'temp_std'),
    )

    def __init__(self, df):
        self.data = df.reset_index(drop=True)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        record = self.data.iloc[idx]

        def as_vec(names):
            return torch.tensor([record[name] for name in names], dtype=torch.float32)

        modalities = [as_vec(cols) for cols in self._MODALITY_COLUMNS]
        phys = as_vec(('physicalFatigueScore',))
        ment = as_vec(('mentalFatigueScore',))
        session_type = torch.tensor(0, dtype=torch.long)  # placeholder
        domain_disc = torch.tensor([0.0], dtype=torch.float32)  # placeholder
        return modalities, domain_disc, (phys, ment, session_type)

def collate(batch):
    """Stack dataset samples into batched tensors.

    Each sample is ``(modal_list, domain_disc, (phys, ment, session_type))``.
    The first five modalities are stacked into ``[batch, dim]`` tensors, and
    the other fields are stacked along a new leading batch dimension.
    """
    modal_lists = [sample[0] for sample in batch]
    modal_inputs = [torch.stack([mods[k] for mods in modal_lists]) for k in range(5)]
    domain_disc = torch.stack([sample[1] for sample in batch])
    targets = [sample[2] for sample in batch]
    phys, ment, stype = (torch.stack([t[j] for t in targets]) for j in range(3))
    return modal_inputs, domain_disc, (phys, ment, stype)

# ========== Data Preparation ==========
import os

# No cell in this notebook writes this combined file (the normalization cell writes
# final_train_normalized.csv / final_test_normalized.csv). On a fresh run, rebuild
# the full normalized table from those two splits instead of failing.
NORMALIZED_PATH = '/content/drive/MyDrive/Fatigue_Set/final_feature_label_dataset_normalized.csv'
if os.path.exists(NORMALIZED_PATH):
    df = pd.read_csv(NORMALIZED_PATH)
else:
    df = pd.concat(
        [pd.read_csv('/content/drive/MyDrive/Fatigue_Set/final_train_normalized.csv'),
         pd.read_csv('/content/drive/MyDrive/Fatigue_Set/final_test_normalized.csv')],
        ignore_index=True,
    )

# groupby sorts the keys, so pairs come out ordered by person, then session.
all_pairs = df.groupby(['person', 'session']).size().index.tolist()

# One client: the first 9 pairs (3 persons x 3 sessions, assuming every person has all 3 sessions)
train_pairs = all_pairs[:9]
# One test pair
test_pair = all_pairs[9]

# ✅ Print which person–session pairs are chosen
print("Training person-session pairs:")
for p, s in train_pairs:
    print(f"  Person {p}, Session {s}")

print("\nTesting person-session pair:")
print(f"  Person {test_pair[0]}, Session {test_pair[1]}")

# Build client dataset (only 1 client here)
client_df = pd.concat([df[(df.person==p) & (df.session==s)] for (p, s) in train_pairs]).drop(columns=['person'])
clients_datasets = [FatigueSessionDataset(client_df)]

# Test dataset
test_df = df[(df.person == test_pair[0]) & (df.session == test_pair[1])]
test_df= test_df.drop(columns=['person'])
test_dataset = FatigueSessionDataset(test_df)

# Loaders
clients_loaders = [DataLoader(clients_datasets[0], batch_size=32, shuffle=True, collate_fn=collate)]
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate)

# ========== Training ==========
# Feature counts per modality (HR, IBI, ACC, EDA, Temp); must match FatigueSessionDataset.
modal_dims = [2, 2, 6, 2, 2]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch before weight initialisation so re-runs give comparable results.
# DataLoader shuffling without an explicit generator also draws from torch's
# global RNG.
SEED = 42
torch.manual_seed(SEED)
global_model = FMAL_Daf(modal_dims).to(device)

criterion_reg = nn.MSELoss()
criterion_cls = nn.CrossEntropyLoss()

def train_one_client(model, loader, client_id, rnd):
    """Run one local training epoch for a single client and return its weights.

    A fresh Adam optimizer (lr=1e-3) is created on every call. The loss is
    MSE(phys) + MSE(ment) + CE(session type). Uses the module-level ``device``,
    ``criterion_reg`` and ``criterion_cls``.
    NOTE(review): not called anywhere in this cell; the loop below trains
    ``global_model`` directly.

    Args:
        model: the model to train in place.
        loader: batches of ``(modal_inputs, domain_disc, (phys, ment, stype))``.
        client_id: zero-based client index (printed 1-based).
        rnd: zero-based round index (printed 1-based).

    Returns:
        ``model.state_dict()`` after the epoch.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    running = 0
    for modal_inputs, domain_disc, (phys, ment, stype) in loader:
        modal_inputs = [m.to(device) for m in modal_inputs]
        domain_disc = domain_disc.to(device)
        phys = phys.to(device)
        ment = ment.to(device)
        stype = stype.to(device)
        optimizer.zero_grad()
        phys_pred, ment_pred, stype_pred = model(modal_inputs, domain_disc)
        loss = criterion_reg(phys_pred, phys) + criterion_reg(ment_pred, ment)
        loss = loss + criterion_cls(stype_pred, stype)
        loss.backward()
        optimizer.step()
        running += loss.item()
    avg_loss = running / len(loader)
    print(f"[Round {rnd+1}] Client {client_id+1} Avg Training Loss: {avg_loss:.4f}")
    return model.state_dict()

# Train global_model for 10 epochs on the single client's loader.
# Loss = MSE(phys) + MSE(ment) + CE(session type). The session-type target is a
# constant placeholder (0) in this cell's dataset, so that term is trivial.
optimizer = torch.optim.Adam(global_model.parameters(), lr=1e-3)

for epoch in range(10):
    global_model.train()
    total_loss = 0
    for modal_inputs, domain_disc, (phys, ment, stype) in clients_loaders[0]:
        modal_inputs = [x.to(device) for x in modal_inputs]
        domain_disc = domain_disc.to(device)
        phys, ment, stype = phys.to(device), ment.to(device), stype.to(device)

        optimizer.zero_grad()
        phys_pred, ment_pred, stype_pred = global_model(modal_inputs, domain_disc)
        loss = criterion_reg(phys_pred, phys) + criterion_reg(ment_pred, ment)
        loss += criterion_cls(stype_pred, stype)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of per-batch losses for this epoch.
    print(f"[Epoch {epoch+1}] Avg Training Loss: {total_loss / len(clients_loaders[0]):.4f}")


# ========== Testing ==========
global_model.eval()
reg_losses, cls_losses, batch_sizes = [], [], []
with torch.no_grad():
    for modal_inputs, domain_disc, (phys, ment, stype) in test_loader:
        modal_inputs = [x.to(device) for x in modal_inputs]
        domain_disc = domain_disc.to(device)
        phys, ment, stype = phys.to(device), ment.to(device), stype.to(device)
        phys_pred, ment_pred, stype_pred = global_model(modal_inputs, domain_disc)
        reg_loss = criterion_reg(phys_pred, phys) + criterion_reg(ment_pred, ment)
        cls_loss = criterion_cls(stype_pred, stype)
        reg_losses.append(reg_loss.item())
        cls_losses.append(cls_loss.item())
        batch_sizes.append(phys.size(0))

# Both criteria use mean reduction, so weighting each batch mean by its size gives
# the per-sample average. A plain mean over batches over-weights a short last batch.
if batch_sizes:
    test_reg_loss = np.average(reg_losses, weights=batch_sizes)
    test_cls_loss = np.average(cls_losses, weights=batch_sizes)
else:
    test_reg_loss = test_cls_loss = float('nan')

print("\n=== Final Test Performance ===")
print(f"Test Regression Loss: {test_reg_loss:.4f}")
print(f"Test Classification Loss: {test_cls_loss:.4f}")


Previous approach: without time sequencing

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ========== Dataset Loader ==========
class FatigueSessionDataset(Dataset):
    """One sample per feature window, split into per-modality tensors.

    Each item is ``([HR, IBI, ACC, EDA, Temp], domain_disc, (phys, ment, session_type))``.
    The integer ``session`` column (1, 2, 3) is mapped to class 0, 1, 2 via
    ``session_map``. ``domain_disc`` is a ``[0.0]`` placeholder.
    """
    # Feature columns per modality, in the order the model expects them.
    _MODALITY_COLUMNS = (
        ('hr_mean', 'hr_std'),
        ('duration_mean', 'duration_std'),
        ('ax_mean', 'ax_std', 'ay_mean', 'ay_std', 'az_mean', 'az_std'),
        ('eda_mean', 'eda_std'),
        ('temp_mean', 'temp_std'),
    )

    def __init__(self, df):
        self.data = df.reset_index(drop=True)
        self.session_map = {1: 0, 2: 1, 3: 2}

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        record = self.data.iloc[idx]

        def vec(names):
            return torch.tensor([record[name] for name in names], dtype=torch.float32)

        modalities = [vec(cols) for cols in self._MODALITY_COLUMNS]
        phys = vec(['physicalFatigueScore'])
        ment = vec(['mentalFatigueScore'])
        session_type = torch.tensor(self.session_map[record['session']], dtype=torch.long)
        domain_disc = torch.tensor([0.0], dtype=torch.float32)  # placeholder
        return modalities, domain_disc, (phys, ment, session_type)


def collate(batch):
    """Batch dataset samples: stack the five modality tensors, domain_disc and each target."""
    def stack(pick):
        return torch.stack([pick(sample) for sample in batch])

    modal_inputs = [stack(lambda s, k=k: s[0][k]) for k in range(5)]
    domain_disc = stack(lambda s: s[1])
    phys = stack(lambda s: s[2][0])
    ment = stack(lambda s: s[2][1])
    stype = stack(lambda s: s[2][2])
    return modal_inputs, domain_disc, (phys, ment, stype)


# ========== Model Architecture ==========

class ModalityLSTM(nn.Module):
    """Single-layer batch-first LSTM encoder for one modality.

    Input ``[batch, input_dim]`` (treated as a length-1 sequence) or
    ``[batch, seq_len, input_dim]``. Output: last-step hidden output,
    ``[batch, hidden_dim]``.
    """
    def __init__(self, input_dim, hidden_dim=32):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        if x.ndim != 2:
            seq = x
        else:
            seq = x.unsqueeze(1)
        all_steps, _state = self.lstm(seq)
        return all_steps[:, -1]


class CrossModalAttention(nn.Module):
    """Self-attention across modalities, computed separately for each sample.

    Every modality embedding attends to all modalities of the same sample. The
    attended outputs are then averaged into one fused vector.

    Args:
        n_modalities: number of modalities (not used by the computation).
        hidden_dim: size of each modality embedding.
        fusion_dim: size of the q/k/v projections and of the output.

    Forward args:
        features: list of ``[batch, hidden_dim]`` tensors, one per modality.
        domain_disc: bias added to the ``[batch, M, M]`` attention scores when it
            has more than one element. ``[batch, 1]`` becomes a per-sample constant,
            which does not change the softmax. ``[batch, M]`` becomes a per-key
            bias. Tensors with a single element are ignored.

    Returns:
        ``[batch, fusion_dim]`` fused representation.
    """
    def __init__(self, n_modalities, hidden_dim=32, fusion_dim=64):
        super().__init__()
        self.query = nn.Linear(hidden_dim, fusion_dim)
        self.key   = nn.Linear(hidden_dim, fusion_dim)
        self.value = nn.Linear(hidden_dim, fusion_dim)

    def forward(self, features, domain_disc):
        stacked = torch.stack(features, dim=1)                 # [batch, M, hidden_dim]
        queries = self.query(stacked)
        keys = self.key(stacked)
        values = self.value(stacked)

        scale = keys.size(-1) ** 0.5
        scores = queries @ keys.transpose(-2, -1) / scale      # [batch, M, M]

        has_bias = domain_disc is not None and domain_disc.numel() > 1
        if has_bias:
            n_rows, n_cols = scores.size(1), scores.size(2)
            scores = scores + domain_disc.unsqueeze(1).expand(-1, n_rows, n_cols)

        weights = F.softmax(scores, dim=-1)                     # over modalities
        attended = weights @ values                             # [batch, M, fusion_dim]
        return attended.mean(dim=1)


class FMAL_Daf_Modified(nn.Module):
    """Modality-attention fatigue model with dropout and optional adaptation layer.

    Pipeline: per-modality LSTMs -> cross-modal attention -> dropout(0.2) ->
    adaptation (Identity, or a plain Linear when ``use_grl=True``) -> global
    LSTM -> heads for physical fatigue, mental fatigue and 3 session classes.
    NOTE(review): despite its name, ``use_grl=True`` adds no gradient reversal,
    only a Linear layer.
    """
    def __init__(self, modalities_dim, lstm_hidden=32, fusion_dim=64, use_grl=False):
        super().__init__()
        self.modality_lstms = nn.ModuleList(
            [ModalityLSTM(dim_in, lstm_hidden) for dim_in in modalities_dim]
        )
        self.attn_fusion = CrossModalAttention(len(modalities_dim), lstm_hidden, fusion_dim)
        self.attn_dropout = nn.Dropout(0.2)
        # Domain adaptation disabled by default (Identity)
        self.domain_adapt = nn.Linear(fusion_dim, fusion_dim) if use_grl else nn.Identity()
        self.global_lstm = nn.LSTM(fusion_dim, 32, batch_first=True)
        self.reg_head_phys = nn.Linear(32, 1)
        self.reg_head_ment = nn.Linear(32, 1)
        self.class_head = nn.Linear(32, 3)

    def forward(self, modal_inputs, domain_disc):
        """Return ``(phys_pred [B,1], ment_pred [B,1], class_logits [B,3])``."""
        encoded = [
            encoder(inp) for encoder, inp in zip(self.modality_lstms, modal_inputs)
        ]
        fused = self.attn_dropout(self.attn_fusion(encoded, domain_disc))
        adapted = self.domain_adapt(fused)
        seq_out = self.global_lstm(adapted.unsqueeze(1))[0]
        last = seq_out[:, -1, :]
        return self.reg_head_phys(last), self.reg_head_ment(last), self.class_head(last)


# ========== Training & Evaluation Functions ==========
def train_model(model, loader, optimizer, criterion_reg, criterion_cls, epochs=10,
                device=None, cls_weight=0.1):
    """Train ``model`` for ``epochs`` passes over ``loader``, printing the mean loss per epoch.

    Per-batch loss: MSE(phys) + MSE(ment) + ``cls_weight`` * CE(session type).

    Args:
        model: module called as ``model(modal_inputs, domain_disc)`` that returns
            ``(phys_pred, ment_pred, cls_logits)``.
        loader: iterable of ``(modal_inputs, domain_disc, (phys, ment, stype))``
            batches that supports ``len()``.
        optimizer: optimizer over the model's parameters.
        criterion_reg: regression loss applied to both fatigue scores.
        criterion_cls: classification loss for the session type.
        epochs: number of passes over ``loader``.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter, or CPU if the model has none.
        cls_weight: weight of the classification term (previously hard-coded 0.1).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        for modal_inputs, domain_disc, (phys, ment, stype) in loader:
            # Move the whole batch to the model's device; a CUDA model cannot
            # consume the CPU tensors produced by the DataLoader.
            modal_inputs = [x.to(device) for x in modal_inputs]
            domain_disc = domain_disc.to(device)
            phys, ment, stype = phys.to(device), ment.to(device), stype.to(device)

            optimizer.zero_grad()
            phys_pred, ment_pred, cls_pred = model(modal_inputs, domain_disc)
            loss_reg = criterion_reg(phys_pred, phys) + criterion_reg(ment_pred, ment)
            loss_cls = criterion_cls(cls_pred, stype)
            loss = loss_reg + cls_weight * loss_cls
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"[Epoch {epoch}] Avg Training Loss: {avg_loss:.4f}")


def evaluate_and_show_predictions(model, test_loader, device, criterion_reg, criterion_cls):
    """Evaluate on the full test set and print per-sample predictions for the first batch.

    The previous version ``break``-ed after the first batch but still divided
    by ``len(test_loader)``, so the reported losses covered one batch only and
    were shrunk by the number of batches. Losses are now accumulated over every
    batch and averaged per sample. This assumes mean-reduced criteria.

    Args:
        model: module called as ``model(features, domain_disc)`` that returns
            ``(phys_pred, ment_pred, cls_logits)``.
        test_loader: iterable of ``(features, domain_disc, (phys, ment, stype))``.
        device: device to move each batch to.
        criterion_reg: regression loss for both fatigue scores.
        criterion_cls: classification loss for the session type.

    Returns:
        ``(avg_reg_loss, avg_cls_loss)`` per-sample averages (NaN if the loader
        is empty). Existing callers ignore the return value.
    """
    model.eval()
    total_reg_loss, total_cls_loss, n_samples = 0.0, 0.0, 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            features, domain_disc, (true_phys, true_ment, true_stype) = batch
            features = [x.to(device) for x in features]
            true_phys = true_phys.to(device)
            true_ment = true_ment.to(device)
            true_stype = true_stype.to(device)
            domain_disc = domain_disc.to(device)

            pred_phys, pred_ment, pred_cls = model(features, domain_disc)
            loss_reg = criterion_reg(pred_phys, true_phys) + criterion_reg(pred_ment, true_ment)
            loss_cls = criterion_cls(pred_cls, true_stype)

            batch_n = true_phys.size(0)
            total_reg_loss += loss_reg.item() * batch_n
            total_cls_loss += loss_cls.item() * batch_n
            n_samples += batch_n

            if batch_idx == 0:  # only show predictions for the first batch
                pred_cls_labels = torch.argmax(pred_cls, dim=1)
                print("=== Predictions vs Ground Truth ===")
                for i in range(len(pred_phys)):
                    print(f"Sample {i}:")
                    print(f"  True phys={true_phys[i].item():.3f}, True ment={true_ment[i].item():.3f}, "
                          f"True session_type={true_stype[i].item()}")
                    print(f"  Pred phys={pred_phys[i].item():.3f}, Pred ment={pred_ment[i].item():.3f}, "
                          f"Pred session_type={pred_cls_labels[i].item()}")

    avg_reg_loss = total_reg_loss / n_samples if n_samples else float('nan')
    avg_cls_loss = total_cls_loss / n_samples if n_samples else float('nan')
    print("\n=== Final Test Performance ===")
    print(f"Test Regression Loss: {avg_reg_loss:.4f}")
    print(f"Test Classification Loss: {avg_cls_loss:.4f}")
    return avg_reg_loss, avg_cls_loss


# ========== Main ==========
import os
import random

# No cell in this notebook writes this combined file (the normalization cell writes
# final_train_normalized.csv / final_test_normalized.csv). On a fresh run, rebuild
# the full normalized table from those two splits instead of failing.
NORMALIZED_PATH = '/content/drive/MyDrive/Fatigue_Set/final_feature_label_dataset_normalized.csv'
if os.path.exists(NORMALIZED_PATH):
    df = pd.read_csv(NORMALIZED_PATH)
else:
    df = pd.concat(
        [pd.read_csv('/content/drive/MyDrive/Fatigue_Set/final_train_normalized.csv'),
         pd.read_csv('/content/drive/MyDrive/Fatigue_Set/final_test_normalized.csv')],
        ignore_index=True,
    )

# Extract all unique (person, session) pairs
all_pairs = df.groupby(['person', 'session']).size().index.tolist()

SEED = 42
random.seed(SEED)
random.shuffle(all_pairs)

# Select 9 pairs for training and 3 for testing
train_pairs = all_pairs[:9]
test_pairs = all_pairs[9:12]  # next 3 pairs

print("Training person-session pairs:")
for (p, s) in train_pairs:
    print(f"  Person {int(p)}, Session {int(s)}")

print("\nTesting person-session pairs:")
for (p, s) in test_pairs:
    print(f"  Person {int(p)}, Session {int(s)}")

# Concatenate data for train and test sets respectively
train_df = pd.concat([df[(df.person==p) & (df.session==s)] for (p, s) in train_pairs]).drop(columns=['person'])
test_df = pd.concat([df[(df.person==p) & (df.session==s)] for (p, s) in test_pairs]).drop(columns=['person'])

train_dataset = FatigueSessionDataset(train_df)
test_dataset = FatigueSessionDataset(test_df)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate)

print("\n=== Sample from Training Data ===")
for i in range(3):
    features, _, (phys, ment, stype) = train_dataset[i]
    print(f"Sample {i}: phys={phys.item():.3f}, ment={ment.item():.3f}, session_type={stype.item()}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
modal_dims = [2, 2, 6, 2, 2]  # HR, IBI, ACC, EDA, Temp feature counts
# Seed torch as well: weight init, dropout and DataLoader shuffling draw from it.
torch.manual_seed(SEED)
model = FMAL_Daf_Modified(modal_dims, use_grl=False).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion_reg = nn.MSELoss()
criterion_cls = nn.CrossEntropyLoss()

train_model(model, train_loader, optimizer, criterion_reg, criterion_cls, epochs=10)
evaluate_and_show_predictions(model, test_loader, device, criterion_reg, criterion_cls)


AFTER TIME SEQUENCING (SESSION-WISE, SEQUENCE-WISE)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd, numpy as np, random
from torch.nn.utils.rnn import pad_sequence


# ========== Dataset ==========
class FatigueSessionDataset(Dataset):
    """
    Row-level dataset ordered by (session, window_start); one item per timestep.

    Row granularity keeps per-sample printing ("Sample 0: phys=..") working.
    Each item is ``([HR, IBI, ACC, EDA, Temp], domain_disc, (phys, ment, session_type))``:
    modality tensors hold 2, 2, 6, 2 and 2 features, ``domain_disc`` is a constant
    ``[0.0]`` placeholder, and the session is mapped to a class index through
    ``session_map`` (default ``{1: 0, 2: 1, 3: 2}``; an empty map also falls back to
    the default).
    """

    # Feature columns per modality, in the order the model consumes them.
    _MODALITY_COLUMNS = (
        ('hr_mean', 'hr_std'),
        ('duration_mean', 'duration_std'),
        ('ax_mean', 'ax_std', 'ay_mean', 'ay_std', 'az_mean', 'az_std'),
        ('eda_mean', 'eda_std'),
        ('temp_mean', 'temp_std'),
    )

    def __init__(self, df, session_map=None):
        ordered = df.sort_values(["session", "window_start"])
        self.data = ordered.reset_index(drop=True)
        self.session_map = session_map or {1: 0, 2: 1, 3: 2}

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        modalities = [
            torch.tensor([row[col] for col in group], dtype=torch.float32)
            for group in self._MODALITY_COLUMNS
        ]
        targets = (
            torch.tensor([row['physicalFatigueScore']], dtype=torch.float32),
            torch.tensor([row['mentalFatigueScore']], dtype=torch.float32),
            torch.tensor(self.session_map[row['session']], dtype=torch.long),
        )
        return modalities, torch.tensor([0.0], dtype=torch.float32), targets


def collate(batch):
    """
    Stack a list of dataset items into batch tensors.

    Returns ``(modal_inputs, domain_disc, (phys, ment, stype))``, where
    ``modal_inputs`` is a list of the 5 per-modality tensors, each stacked along a
    new leading batch dimension.
    """
    def stack_field(pick):
        # Stack one field, selected by `pick`, across every sample in the batch.
        return torch.stack([pick(sample) for sample in batch])

    modal_inputs = [stack_field(lambda sample, k=k: sample[0][k]) for k in range(5)]
    domain_disc = stack_field(lambda sample: sample[1])
    phys, ment, stype = (stack_field(lambda sample, k=k: sample[2][k]) for k in range(3))
    return modal_inputs, domain_disc, (phys, ment, stype)


# ========== Model ==========
class ModalityLSTM(nn.Module):
    """Single-layer LSTM encoder for one modality; returns the output at the last timestep."""

    def __init__(self, input_dim, hidden_dim=32):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        # A 2-D input [B, feat] is treated as a length-1 sequence [B, 1, feat].
        seq = x.unsqueeze(1) if x.ndim == 2 else x
        outputs, _ = self.lstm(seq)
        return outputs[:, -1, :]


class CrossModalAttention(nn.Module):
    """
    Scaled dot-product self-attention across modality embeddings.

    ``features`` is a list of M tensors ``[B, hidden_dim]``. They are stacked to
    ``[B, M, hidden_dim]`` and attended over. The attended values are then averaged
    across the M queries, giving ``[B, fusion_dim]``. ``n_modalities`` and
    ``domain_disc`` are accepted for interface compatibility but are not used.
    """

    def __init__(self, n_modalities, hidden_dim=32, fusion_dim=64):
        super().__init__()
        self.query = nn.Linear(hidden_dim, fusion_dim)
        self.key = nn.Linear(hidden_dim, fusion_dim)
        self.value = nn.Linear(hidden_dim, fusion_dim)

    def forward(self, features, domain_disc):
        tokens = torch.stack(features, 1)                      # [B, M, H]
        q, k, v = (proj(tokens) for proj in (self.query, self.key, self.value))
        logits = torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5)
        weights = F.softmax(logits, dim=-1)                    # [B, M, M]
        return torch.matmul(weights, v).mean(1)                # [B, fusion_dim]


class DomainAdaptiveLayer(nn.Module):
    """Learned square linear map applied to the fused representation (no normalisation)."""

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        projected = self.fc(x)
        return projected


class FMAL_Daf(nn.Module):
    """
    Multi-task fusion model.

    Pipeline: per-modality LSTM encoders -> cross-modal attention -> domain-adaptive
    linear layer -> single-step global LSTM (hidden 32). Three heads follow: physical
    fatigue (1), mental fatigue (1) and session-type logits (3).
    """

    def __init__(self, modalities_dim, lstm_hidden=32, fusion_dim=64):
        super().__init__()
        # Submodules are created in a fixed order: their default initialisation draws
        # from the global RNG, so reordering would change the initial weights.
        self.modality_lstms = nn.ModuleList(
            ModalityLSTM(dim, lstm_hidden) for dim in modalities_dim
        )
        self.attn_fusion = CrossModalAttention(len(modalities_dim), lstm_hidden, fusion_dim)
        self.domain_adapt = DomainAdaptiveLayer(fusion_dim)
        self.global_lstm = nn.LSTM(fusion_dim, 32, batch_first=True)
        self.reg_phys = nn.Linear(32, 1)
        self.reg_ment = nn.Linear(32, 1)
        self.class_head = nn.Linear(32, 3)

    def forward(self, modal_inputs, domain_disc):
        encoded = [encoder(inp) for encoder, inp in zip(self.modality_lstms, modal_inputs)]
        z = self.domain_adapt(self.attn_fusion(encoded, domain_disc))
        # The fused vector is fed to the global LSTM as a length-1 sequence.
        seq_out, _ = self.global_lstm(z.unsqueeze(1))
        h = seq_out[:, -1, :]
        return self.reg_phys(h), self.reg_ment(h), self.class_head(h)


# ========== Train / Eval ==========
def train_model(model,loader,optim,crit_reg,crit_cls,epochs=10,device=None):
    """
    Train ``model`` for ``epochs`` passes over ``loader``, printing the mean loss per epoch.

    Per-batch loss = crit_reg(phys) + crit_reg(ment) + 0.1 * crit_cls(session type).

    Args:
        model: called as ``model(modals, dom)`` and returning ``(phys, ment, session_logits)``.
        loader: iterable of ``(modals, dom, (phys, ment, stype))`` batches.
        optim: optimizer over ``model``'s parameters.
        crit_reg / crit_cls: regression and classification criteria.
        epochs: number of passes over ``loader``.
        device: where batches are moved. Defaults to the device of the model's first
            parameter (CPU for a parameter-less model). This replaces the reliance on
            a notebook-global ``device`` defined in a later cell.

    Returns:
        list of per-epoch average training losses (NaN for an epoch with no batches).
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    history = []
    for ep in range(1, epochs + 1):
        model.train()
        tot, n_batches = 0.0, 0
        for modals, dom, (phys, ment, stype) in loader:
            modals = [m.to(device) for m in modals]
            dom = dom.to(device)
            phys, ment, stype = phys.to(device), ment.to(device), stype.to(device)
            optim.zero_grad()
            p, m, s = model(modals, dom)
            loss = crit_reg(p, phys) + crit_reg(m, ment) + 0.1 * crit_cls(s, stype)
            loss.backward()
            optim.step()
            tot += loss.item()
            n_batches += 1
        # Guard against an empty loader instead of raising ZeroDivisionError.
        avg = tot / n_batches if n_batches else float("nan")
        print(f"[Epoch {ep}] Avg Training Loss: {avg:.4f}")
        history.append(avg)
    return history


def evaluate_and_show_predictions(model,test_loader,device,crit_reg,crit_cls):
    """
    Evaluate ``model`` over the whole ``test_loader`` and print first-batch predictions.

    Losses are accumulated over every batch; the sample-by-sample comparison is
    printed for the first batch only.

    Args:
        model: called as ``model(modals, dom)`` and returning ``(phys, ment, session_logits)``.
        test_loader: iterable of ``(modals, dom, (phys, ment, stype))`` batches.
        device: device the batch tensors are moved to.
        crit_reg: regression criterion (applied to both scores and summed).
        crit_cls: classification criterion for the session-type logits.

    Returns:
        ``(avg_reg_loss, avg_cls_loss)``; NaN for both if the loader is empty.
    """
    model.eval()
    regL, clsL, n_batches = 0.0, 0.0, 0
    with torch.no_grad():
        for modals,dom,(phys,ment,stype) in test_loader:
            modals=[m.to(device) for m in modals]
            phys,ment,stype=phys.to(device),ment.to(device),stype.to(device)
            dom=dom.to(device)
            p_pred,m_pred,s_pred=model(modals,dom)
            loss_r=crit_reg(p_pred,phys)+crit_reg(m_pred,ment)
            loss_c=crit_cls(s_pred,stype)
            regL+=loss_r.item(); clsL+=loss_c.item(); n_batches+=1
            if n_batches > 1:
                continue  # printout is limited to the first batch; losses still accumulate
            pred_cls=torch.argmax(s_pred,1)
            print("=== Predictions vs Ground Truth ===")
            for i in range(len(phys)):
                print(f"Sample {i}:")
                print(f"  True phys={phys[i].item():.3f}, True ment={ment[i].item():.3f}, True session_type={stype[i].item()}")
                print(f"  Pred phys={p_pred[i].item():.3f}, Pred ment={m_pred[i].item():.3f}, Pred session_type={pred_cls[i].item()}")
    # Divide by the batches actually evaluated. (Previously the loop broke after one
    # batch but divided by len(test_loader), under-reporting the test loss.)
    avg_reg = regL / n_batches if n_batches else float("nan")
    avg_cls = clsL / n_batches if n_batches else float("nan")
    print("\n=== Final Test Performance ===")
    print(f"Test Regression Loss: {avg_reg:.4f}")
    print(f"Test Classification Loss: {avg_cls:.4f}")
    return avg_reg, avg_cls


# ========== Main ==========
# Same (person, session) split procedure as the previous cell: seed 42, the first 9
# shuffled pairs for training and the next 3 for testing. Unlike that cell, the
# 'person' column is kept in the frames.
# NOTE(review): only 12 of the available pairs are used, and a person can appear in
# both splits. Confirm this is intended.
df=pd.read_csv('/content/drive/MyDrive/Fatigue_Set/final_feature_label_dataset_normalized.csv')
all_pairs=df.groupby(['person','session']).size().index.tolist()
random.seed(42); random.shuffle(all_pairs)
train_pairs,test_pairs=all_pairs[:9],all_pairs[9:12]

print("Training person-session pairs:")
for p,s in train_pairs: print(f"  Person {p}, Session {s}")
print("\nTesting person-session pairs:")
for p,s in test_pairs: print(f"  Person {p}, Session {s}")

train_df=pd.concat([df[(df.person==p)&(df.session==s)] for (p,s) in train_pairs])
test_df=pd.concat([df[(df.person==p)&(df.session==s)] for (p,s) in test_pairs])

# FatigueSessionDataset (this cell) sorts rows by (session, window_start) and maps
# sessions with its default {1: 0, 2: 1, 3: 2}.
train_dataset=FatigueSessionDataset(train_df)
test_dataset=FatigueSessionDataset(test_df)
train_loader=DataLoader(train_dataset,batch_size=32,shuffle=True,collate_fn=collate)
test_loader=DataLoader(test_dataset,batch_size=32,shuffle=False,collate_fn=collate)

print("\n=== Sample from Training Data ===")
for i in range(3):
    feats,_,(phys,ment,stype)=train_dataset[i]
    print(f"Sample {i}: phys={phys.item():.3f}, ment={ment.item():.3f}, session_type={stype.item()}")

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-modality feature counts: HR 2, IBI 2, ACC 6, EDA 2, Temp 2 (see FatigueSessionDataset).
# NOTE(review): no torch seed is set in this cell, so weight init and shuffling vary per run.
model=FMAL_Daf([2,2,6,2,2]).to(device)
optim=torch.optim.Adam(model.parameters(),lr=1e-3)
crit_reg,crit_cls=nn.MSELoss(),nn.CrossEntropyLoss()

train_model(model,train_loader,optim,crit_reg,crit_cls,epochs=10)
evaluate_and_show_predictions(model,test_loader,device,crit_reg,crit_cls)


# Centralized Model — With Novelty

---



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import math
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import pprint # Import pprint for nice dictionary printing
from collections import defaultdict # To easily append metrics

# =========================
# Config / Debug
# =========================
# When True, CrossModalAttention prints the domain-scaling factors and attention
# weights of the first sample in the batch. This happens only for the first
# training batch of epoch 1, the only place train_model passes debug=True.
DEBUG_ATTENTION = False

# ========== Dataset ==========
class FatigueSessionDataset(Dataset):
    """
    Per-timestep dataset: each ``__getitem__`` returns one row (one feature window).

    Preprocessing performed once in ``__init__`` (on a copy of ``df``):
    - ``session`` and ``person`` are coerced to int; unparseable values become 0.
    - ``window_start`` is parsed with ``pd.to_datetime``; rows that fail to parse are dropped.
    - ``elapsed_time`` is the number of seconds since the first window of the same
      (person, session) pair. It is 0.0 everywhere when ``compute_elapsed`` is False.
    - Rows are sorted by (person, session, window_start) to keep temporal order.

    Each item is ``([HR, IBI, ACC, EDA, Temp], domain_disc, session_feat, elapsed, (phys, ment))``.
    The modality tensors hold 2, 2, 6, 2 and 2 features, and ``domain_disc`` is a
    constant ``[0.0]`` placeholder. ``session_feat`` is the mapped session index as a
    1-element float tensor. Missing feature or label columns default to 0.0.

    Args:
        df: windowed feature/label frame.
        session_map: raw session number -> embedding index. When None, a copy of
            ``DEFAULT_SESSION_MAP`` is used, so train and test datasets agree even
            if a split does not contain every session. Sessions missing from the
            map fall back to index 0.
        compute_elapsed: whether to compute ``elapsed_time`` (otherwise 0.0).
    """

    # Fixed session -> index mapping (FatigueSet sessions are numbered 1..3). The old
    # default derived the map from the sessions present in each split. A test split
    # without session 1 then mapped session 2 to index 0, misaligning it with the
    # embedding the model learned for session 2 during training.
    DEFAULT_SESSION_MAP = {1: 0, 2: 1, 3: 2}

    # Feature columns per modality, in the order the model consumes them.
    _MODALITY_COLUMNS = (
        ('hr_mean', 'hr_std'),
        ('duration_mean', 'duration_std'),
        ('ax_mean', 'ax_std', 'ay_mean', 'ay_std', 'az_mean', 'az_std'),
        ('eda_mean', 'eda_std'),
        ('temp_mean', 'temp_std'),
    )

    def __init__(self, df, session_map=None, compute_elapsed=True):
        df = df.copy()
        # Coerce ids to int. NOTE(review): unparseable ids silently become 0 and are
        # merged into person/session 0; consider dropping such rows instead.
        df['session'] = pd.to_numeric(df['session'], errors='coerce').fillna(0).astype(int)
        df['person'] = pd.to_numeric(df['person'], errors='coerce').fillna(0).astype(int)
        df['window_start'] = pd.to_datetime(df['window_start'], errors='coerce')
        df.dropna(subset=['window_start'], inplace=True)  # rows whose timestamp failed to parse

        # Seconds since the first window of each (person, session); grouping an empty
        # frame is skipped.
        if compute_elapsed and not df.empty:
            df['elapsed_time'] = (
                df.groupby(['person', 'session'], observed=True)['window_start']
                  .transform(lambda x: (x - x.min()).dt.total_seconds())
            )
        else:
            df['elapsed_time'] = 0.0

        # Sort to preserve temporal order within each person-session.
        self.data = df.sort_values(['person', 'session', 'window_start']).reset_index(drop=True)

        self.session_map = dict(self.DEFAULT_SESSION_MAP) if session_map is None else session_map

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Modalities; missing columns default to 0.0.
        modal_tensors = [
            torch.tensor([row.get(col, 0.0) for col in cols], dtype=torch.float32)
            for cols in self._MODALITY_COLUMNS
        ]

        # Regression targets.
        phys = torch.tensor([row.get('physicalFatigueScore', 0.0)], dtype=torch.float32)
        ment = torch.tensor([row.get('mentalFatigueScore', 0.0)], dtype=torch.float32)

        # Session index (unknown sessions map to 0) as a float feature; the model
        # casts it back to long for its embedding lookup.
        session_index = self.session_map.get(int(row.get('session', 0)), 0)
        session_feat = torch.tensor([float(session_index)], dtype=torch.float32)

        elapsed = torch.tensor([float(row.get('elapsed_time', 0.0))], dtype=torch.float32)

        # Domain discrepancy placeholder (always 0.0 here).
        domain_disc = torch.tensor([0.0], dtype=torch.float32)

        return modal_tensors, domain_disc, session_feat, elapsed, (phys, ment)


def collate(batch):
    """
    Stack dataset items into ``(modal_inputs, domain_disc, session_feat, elapsed_feat, (phys, ment))``.

    ``None`` items are dropped first. The function returns ``None`` when nothing is
    left, or when stacking fails (after printing the error). Note that the DataLoader
    does NOT skip a ``None`` batch: it yields it, and the training/evaluation loops
    are expected to check for ``None`` and skip it themselves.
    """
    samples = [item for item in batch if item is not None]
    if len(samples) == 0:
        return None

    try:
        # Five modality tensors, each [B, feature_dim].
        modal_inputs = [torch.stack([s[0][k] for s in samples]) for k in range(5)]
        # domain_disc, session_feat and elapsed_feat are each [B, 1].
        domain_disc, session_feat, elapsed_feat = (
            torch.stack([s[pos] for s in samples]) for pos in (1, 2, 3)
        )
        targets = tuple(torch.stack([s[4][k] for s in samples]) for k in range(2))
    except Exception as e:
        print(f"Error during collation: {e}")
        return None

    return modal_inputs, domain_disc, session_feat, elapsed_feat, targets

# ========== Model ==========
# ModalityLSTM matches the previous cell; DomainScalingMLP is new, CrossModalAttention now gates its attention with it, and DomainAdaptiveLayer adds BatchNorm1d.
class ModalityLSTM(nn.Module):
    """
    Single-layer LSTM encoder for one modality.

    Accepts ``[B, feat]`` (treated as a length-1 sequence) or ``[B, T, feat]`` and
    returns the LSTM output at the last timestep, ``[B, hidden_dim]``. If the LSTM
    call fails, the input's shape, dtype and device are printed before the
    exception is re-raised.
    """

    def __init__(self, input_dim, hidden_dim=32):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        seq = (x.unsqueeze(1) if x.ndim == 2 else x).contiguous()
        try:
            outputs, _ = self.lstm(seq)
        except Exception as e:
            print(f"Error in ModalityLSTM forward: {e}")
            print(f"Input shape: {seq.shape}, Input dtype: {seq.dtype}, Input device: {seq.device}")
            raise
        return outputs[:, -1, :]

class DomainScalingMLP(nn.Module):
    """
    Map the domain-discrepancy signal to one gate in (0, 1) per modality.

    Architecture: Linear(input_dim, hidden) -> ReLU -> Linear(hidden, n_modalities) -> Sigmoid.
    """

    def __init__(self, input_dim=1, n_modalities=5, hidden=16):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_modalities),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, domain_disc):
        gates = self.net(domain_disc)
        return gates


class CrossModalAttention(nn.Module):
    """
    Self-attention across modality embeddings, gated by domain-dependent scales.

    ``features`` is a list of M tensors ``[B, hidden_dim]``, stacked to
    ``[B, M, hidden_dim]``. ``domain_disc`` is fed to ``DomainScalingMLP``, which
    yields one gate in (0, 1) per modality. The scaled dot-product attention weights
    ``[B, M, M]`` are multiplied by the gates (per key modality) and renormalised so
    each row sums to 1. The attended values are averaged over the M queries, giving
    ``[B, fusion_dim]``.
    """

    def __init__(self, n_modalities, hidden_dim=32, fusion_dim=64):
        super().__init__()
        self.query = nn.Linear(hidden_dim, fusion_dim)
        self.key   = nn.Linear(hidden_dim, fusion_dim)
        self.value = nn.Linear(hidden_dim, fusion_dim)
        self.scaler = DomainScalingMLP(input_dim=1, n_modalities=n_modalities, hidden=16)

    def forward(self, features, domain_disc, debug=False):
        x = torch.stack(features, dim=1)  # [B, M, H]
        Q, K, V = self.query(x), self.key(x), self.value(x)  # -> [B, M, F]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.size(-1) ** 0.5)
        attn = F.softmax(scores, dim=-1)  # [B, M, M], rows sum to 1

        # Gate each key modality by its domain scaling factor, then renormalise the rows.
        # The previous code applied a second softmax here. Softmax of values already in
        # [0, 1] pushes the weights toward uniform and largely cancels the gating, and
        # adding an epsilon inside a softmax has no effect. Dividing by the row sum keeps
        # the gating, and the clamp guards against an all-zero row.
        scaling = self.scaler(domain_disc).unsqueeze(1)  # [B, 1, M]; assumes domain_disc is [B, 1]
        attn = attn * scaling
        attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9)

        if debug and DEBUG_ATTENTION:
            print("Domain scaling factors:", scaling[0, 0].detach().cpu().numpy())
            print("Attention weights:", attn[0].detach().cpu().numpy())

        fused_per_query = torch.matmul(attn, V)  # [B, M, F]
        fused = fused_per_query.mean(dim=1)      # [B, F] (average across queries)
        return fused


class DomainAdaptiveLayer(nn.Module):
    """Linear projection followed by BatchNorm1d over the feature dimension."""

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, x):
        """
        Args:
            x: fused features, ``[B, dim]``.

        Returns:
            ``bn(fc(x))``. The exception is a training-mode batch of at most one
            sample, where BatchNorm cannot estimate batch statistics; ``fc(x)`` is
            returned un-normalised in that case.
        """
        projected = self.fc(x)
        # Only training-mode BatchNorm needs >1 sample, because it uses batch statistics.
        # In eval mode it uses running statistics, so a lone sample (e.g. a size-1 final
        # test batch) must still be normalised like every other sample. Previously it
        # bypassed BN here and produced inconsistent outputs.
        if self.training and projected.size(0) <= 1:
            return projected
        return self.bn(projected)


class FMAL_Daf(nn.Module):
    """
    Fusion model for physical / mental fatigue regression.

    Pipeline: per-modality LSTM encoders -> domain-scaled cross-modal attention
    -> DomainAdaptiveLayer -> (optional 8-d session embedding + 8-d elapsed-time
    encoding) -> single-step global LSTM (hidden 32) -> two scalar regression heads.

    Args:
        modalities_dim: input feature count of each modality encoder.
        lstm_hidden: hidden size of each modality LSTM.
        fusion_dim: output size of the attention fusion and domain-adaptive layer.
        num_sessions: size of the session embedding table (no embedding if <= 0).
        use_time_and_session: if True, forward() requires session_feat and
            elapsed_feat and appends their encodings to the fused features.
    """

    def __init__(self, modalities_dim, lstm_hidden=64, fusion_dim=128,
                 num_sessions=3, use_time_and_session=True):
        super().__init__()
        self.use_time_and_session = use_time_and_session

        # Submodules are created in a fixed order: both default initialisation and
        # _init_weights draw from the global RNG in registration order.
        self.modality_lstms = nn.ModuleList(
            [ModalityLSTM(dim, lstm_hidden) for dim in modalities_dim]
        )
        self.attn_fusion = CrossModalAttention(len(modalities_dim), lstm_hidden, fusion_dim)
        self.domain_adapt = DomainAdaptiveLayer(fusion_dim)

        self.num_sessions = num_sessions
        self.session_emb = nn.Embedding(num_sessions, 8) if num_sessions > 0 else None

        self.time_enc = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 8))

        # 8 session dims + 8 time dims are appended only when both are in use.
        extra_dim = 16 if (use_time_and_session and self.session_emb is not None) else 0
        self.global_lstm = nn.LSTM(fusion_dim + extra_dim, 32, batch_first=True)

        self.reg_phys = nn.Linear(32, 1)
        self.reg_ment = nn.Linear(32, 1)

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        """Xavier-uniform weights and zero biases for Linear/LSTM; U(-0.1, 0.1) for Embedding."""
        if isinstance(m, nn.Embedding):
            nn.init.uniform_(m.weight, -0.1, 0.1)
            return
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            return
        if isinstance(m, nn.LSTM):
            for name, param in m.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    nn.init.zeros_(param)

    def forward(self, modal_inputs, domain_disc, session_feat=None, elapsed_feat=None, debug=False):
        encoded = [encoder(inp) for encoder, inp in zip(self.modality_lstms, modal_inputs)]
        z = self.domain_adapt(self.attn_fusion(encoded, domain_disc, debug=debug))

        if self.use_time_and_session:
            if session_feat is None or elapsed_feat is None:
                raise ValueError("Need session_feat and elapsed_feat when use_time_and_session is True")
            if self.session_emb is None:
                raise ValueError("Session embedding is not initialized, likely num_sessions was 0.")
            # session_feat carries the mapped index as a float; clamp keeps it a valid row.
            idx = torch.clamp(session_feat.squeeze(-1).long(), 0, self.num_sessions - 1)
            z = torch.cat([z, self.session_emb(idx), self.time_enc(elapsed_feat)], dim=1)

        # The fused vector is fed to the global LSTM as a length-1 sequence.
        seq_out, _ = self.global_lstm(z.contiguous().unsqueeze(1))
        last = seq_out[:, -1, :]
        return self.reg_phys(last), self.reg_ment(last)


# ========== Train / Eval ==========

# --- MODIFICATION START ---
# Helper function to evaluate model performance for one epoch (used during training)
def _regression_metrics(labels, preds):
    """Return ``(rmse, mae, r2)`` for 1-D label/prediction arrays; all NaN with fewer than 2 samples."""
    if len(labels) < 2:
        return float('nan'), float('nan'), float('nan')
    return (math.sqrt(mean_squared_error(labels, preds)),
            mean_absolute_error(labels, preds),
            r2_score(labels, preds))


def evaluate_epoch(model, loader, device, crit_reg):
    """
    Evaluate ``model`` on ``loader`` without gradients and return regression metrics.

    ``None`` batches (rejected by collate) are skipped. ``loss`` is the mean, over the
    batches actually evaluated, of ``crit_reg(phys) + crit_reg(ment)``; it is NaN if
    none were evaluated. RMSE / MAE / R^2 are computed per target over all samples and
    are NaN when fewer than two samples were seen.

    Args:
        model: called as ``model(modals, dom, session_feat, elapsed_feat)`` and
            returning ``(phys_pred, ment_pred)``.
        loader: iterable of ``(modals, dom, session_feat, elapsed_feat, (phys, ment))`` or ``None``.
        device: device the batch tensors are moved to.
        crit_reg: regression criterion.

    Returns:
        dict with keys ``loss``, ``rmse_physical``, ``mae_physical``, ``r2_physical``,
        ``rmse_mental``, ``mae_mental``, ``r2_mental``.
    """
    model.eval()
    regL = 0.0
    n_batches = 0
    all_phys_preds, all_phys_labels = [], []
    all_ment_preds, all_ment_labels = [], []

    with torch.no_grad():
        for batch_data in loader:
            if batch_data is None:
                print("Skipping a batch due to collation error.")
                continue
            modals, dom, session_feat, elapsed_feat, (phys, ment) = batch_data

            modals = [m.to(device) for m in modals]
            dom = dom.to(device)
            session_feat = session_feat.to(device)
            elapsed_feat = elapsed_feat.to(device)
            phys = phys.to(device)
            ment = ment.to(device)

            p_pred, m_pred = model(modals, dom, session_feat, elapsed_feat)
            regL += (crit_reg(p_pred, phys) + crit_reg(m_pred, ment)).item()
            n_batches += 1

            # reshape(-1) rather than squeeze(): a batch of one squeezes to a 0-d
            # array, which list.extend cannot iterate.
            all_phys_preds.extend(p_pred.reshape(-1).cpu().tolist())
            all_phys_labels.extend(phys.reshape(-1).cpu().tolist())
            all_ment_preds.extend(m_pred.reshape(-1).cpu().tolist())
            all_ment_labels.extend(ment.reshape(-1).cpu().tolist())

    # Mean over evaluated batches; skipped batches no longer dilute the loss.
    avg_regL = regL / n_batches if n_batches > 0 else float('nan')

    # NaNs are replaced by 0 so the sklearn metrics don't raise.
    # NOTE(review): this also hides NaN predictions from the reported metrics.
    phys_labels = pd.Series(all_phys_labels, dtype=float).fillna(0).values
    phys_preds = pd.Series(all_phys_preds, dtype=float).fillna(0).values
    ment_labels = pd.Series(all_ment_labels, dtype=float).fillna(0).values
    ment_preds = pd.Series(all_ment_preds, dtype=float).fillna(0).values

    rmse_phys, mae_phys, r2_phys = _regression_metrics(phys_labels, phys_preds)
    rmse_ment, mae_ment, r2_ment = _regression_metrics(ment_labels, ment_preds)

    return {
        'loss': avg_regL,
        'rmse_physical': rmse_phys,
        'mae_physical': mae_phys,
        'r2_physical': r2_phys,
        'rmse_mental': rmse_ment,
        'mae_mental': mae_ment,
        'r2_mental': r2_ment,
    }
# --- MODIFICATION END ---


# Modified train_model to collect metrics each epoch
def train_model(model, train_loader, eval_loader, optim, crit_reg, epochs=10, device=torch.device("cpu")):
    """
    Train ``model`` and evaluate it on ``eval_loader`` after every epoch.

    Per-batch training loss = crit_reg(physical) + 2 * crit_reg(mental). Some batches
    are skipped with a printed message: those collate rejected (``None``), those with
    a NaN loss, and those that raise during forward/backward. The reported epoch loss
    is the mean over batches that were actually optimised (NaN if none were).

    Args:
        model: called as ``model(modals, dom, session_feat, elapsed_feat, debug=...)``
            and returning ``(phys_pred, ment_pred)``.
        train_loader: yields ``(modals, dom, session_feat, elapsed_feat, (phys, ment))`` or ``None``.
        eval_loader: passed to ``evaluate_epoch`` after each epoch.
        optim: optimizer over ``model``'s parameters.
        crit_reg: regression criterion.
        epochs: number of training epochs.
        device: device the batch tensors are moved to.

    Returns:
        dict mapping each metric name returned by ``evaluate_epoch`` to a list of
        ``(epoch, value)`` tuples.
    """
    metrics_history = defaultdict(list)

    for ep in range(1, epochs + 1):
        model.train()
        tot_train_loss = 0.0
        n_updates = 0  # batches that actually reached optim.step()
        for i, batch_data in enumerate(train_loader):
            if batch_data is None:
                print(f"Skipping training batch {i} due to collation error.")
                continue
            modals, dom, session_feat, elapsed_feat, (phys, ment) = batch_data

            modals = [m.to(device) for m in modals]
            dom = dom.to(device)
            session_feat = session_feat.to(device)
            elapsed_feat = elapsed_feat.to(device)
            phys = phys.to(device)
            ment = ment.to(device)

            optim.zero_grad()
            # Attention debug output only for the very first batch, and only if enabled.
            debug_flag = (ep == 1 and i == 0 and DEBUG_ATTENTION)
            try:
                p_pred, m_pred = model(modals, dom, session_feat, elapsed_feat, debug=debug_flag)
                loss = crit_reg(p_pred, phys) + 2*crit_reg(m_pred, ment)  # weight mental loss more
                if torch.isnan(loss):
                    print(f"Warning: NaN loss detected at epoch {ep}, batch {i}. Skipping backward pass.")
                    continue
                loss.backward()
                optim.step()
                tot_train_loss += loss.item()
                n_updates += 1
            except Exception as e:
                # Deliberately best-effort: report the failure and move on to the next batch.
                print(f"Error during training forward/backward pass at epoch {ep}, batch {i}: {e}")

        # Average over optimised batches only. Dividing by len(train_loader) understated
        # the loss whenever batches were skipped.
        avg_train_loss = tot_train_loss / n_updates if n_updates > 0 else float('nan')
        print(f"[Epoch {ep}] Avg Training Loss: {avg_train_loss:.6f}")
        if n_updates == 0:
            print(f"Warning: no training batch was successfully processed in epoch {ep}.")

        epoch_metrics = evaluate_epoch(model, eval_loader, device, crit_reg)
        print(f"[Epoch {ep}] Evaluation Metrics:")
        pprint.pprint(epoch_metrics)

        for key, value in epoch_metrics.items():
            metrics_history[key].append((ep, value))

    return dict(metrics_history)


# Final (post-training) evaluation on the test set, with sample predictions.
def evaluate_and_show_predictions(model, test_loader, device, crit_reg):
    """Evaluate ``model`` once on ``test_loader`` and print a detailed report.

    Prints predictions vs. ground truth for up to 10 samples of the first
    usable batch, then RMSE / MAE / R² for physical and mental fatigue plus
    the average per-batch regression loss, and pretty-prints a metrics dict.

    Args:
        model: called as ``model(modals, dom, session_feat, elapsed_feat)``;
            must return ``(physical_pred, mental_pred)``.
        test_loader: iterable of collated batches
            ``(modals, dom, session_feat, elapsed_feat, (phys, ment))``.
            ``None`` batches (collation failures) are skipped.
        device: torch device the inputs are moved to.
        crit_reg: loss used for the reported loss, applied unweighted as
            ``crit_reg(p, phys) + crit_reg(m, ment)``. Training weights the
            mental term 2x, so this is not numerically the training loss.

    Returns:
        dict with keys 'loss', 'rmse_physical', 'mae_physical', 'r2_physical',
        'rmse_mental', 'mae_mental', 'r2_mental'. Metrics are NaN when fewer
        than two finite label/prediction pairs are available.
    """
    def _flat(t):
        # reshape(-1) instead of squeeze(): squeeze() turns a 1-sample batch into
        # a 0-d array whose .tolist() is a bare float, which list.extend() rejects.
        # Assumes one target value per sample (shape (B,) or (B, 1)) — TODO confirm.
        return t.detach().reshape(-1).cpu().tolist()

    def _metrics(labels, preds):
        # (rmse, mae, r2) over finite pairs only; zero-filling NaNs would bias every metric.
        pairs = [(y, p) for y, p in zip(labels, preds) if math.isfinite(y) and math.isfinite(p)]
        dropped = len(labels) - len(pairs)
        if dropped:
            print(f"Warning: ignoring {dropped} non-finite label/prediction pair(s) in metric computation.")
        if len(pairs) < 2:
            return float('nan'), float('nan'), float('nan')
        y_true, y_pred = zip(*pairs)
        return (math.sqrt(mean_squared_error(y_true, y_pred)),
                mean_absolute_error(y_true, y_pred),
                r2_score(y_true, y_pred))

    model.eval()
    reg_loss_sum = 0.0
    n_batches = 0  # batches actually evaluated (skipped None batches excluded)
    all_phys_preds, all_phys_labels = [], []
    all_ment_preds, all_ment_labels = [], []
    first_batch_printed = False

    with torch.no_grad():
        for batch_data in test_loader:
            if batch_data is None:
                print("Skipping a batch due to collation error during final evaluation.")
                continue
            modals, dom, session_feat, elapsed_feat, (phys, ment) = batch_data

            modals = [m.to(device) for m in modals]
            dom = dom.to(device)
            session_feat = session_feat.to(device)
            elapsed_feat = elapsed_feat.to(device)
            phys = phys.to(device)
            ment = ment.to(device)

            p_pred, m_pred = model(modals, dom, session_feat, elapsed_feat)
            reg_loss_sum += (crit_reg(p_pred, phys) + crit_reg(m_pred, ment)).item()
            n_batches += 1

            batch_p, batch_m = _flat(p_pred), _flat(m_pred)
            batch_phys, batch_ment = _flat(phys), _flat(ment)
            all_phys_preds.extend(batch_p)
            all_phys_labels.extend(batch_phys)
            all_ment_preds.extend(batch_m)
            all_ment_labels.extend(batch_ment)

            if not first_batch_printed:
                print("\n=== Predictions vs Ground Truth (first test batch - FINAL EVAL) ===")
                for j in range(min(len(batch_phys), 10)):
                    print(f"Sample {j}: True phys={batch_phys[j]:.3f}, True ment={batch_ment[j]:.3f} | "
                          f"Pred phys={batch_p[j]:.3f}, Pred ment={batch_m[j]:.3f}")
                first_batch_printed = True

    avg_regL = reg_loss_sum / n_batches if n_batches > 0 else float('nan')
    rmse_phys, mae_phys, r2_phys = _metrics(all_phys_labels, all_phys_preds)
    rmse_ment, mae_ment, r2_ment = _metrics(all_ment_labels, all_ment_preds)

    # --- FINAL EVALUATION METRICS ---
    final_metrics_dict = {
        'loss': avg_regL,
        'rmse_physical': rmse_phys,
        'mae_physical': mae_phys,
        'r2_physical': r2_phys,
        'rmse_mental': rmse_ment,
        'mae_mental': mae_ment,
        'r2_mental': r2_ment
    }

    print("\n=== Final Test Performance (after all epochs) ===")
    print(f"Test Regression Loss (avg): {avg_regL:.6f}")
    print(f"Physical Fatigue - RMSE: {rmse_phys:.6f}, MAE: {mae_phys:.6f}, R²: {r2_phys:.6f}")
    print(f"Mental Fatigue   - RMSE: {rmse_ment:.6f}, MAE: {mae_ment:.6f}, R²: {r2_ment:.6f}")

    print("\n=== Final Metrics Dictionary (after all epochs) ===")
    pprint.pprint(final_metrics_dict)
    return final_metrics_dict


# ========== Main ==========
if __name__ == "__main__":
    # Seed torch so weight initialisation and DataLoader shuffling are reproducible;
    # df.sample(random_state=...) below only fixes the train/test split.
    SEED = 42
    torch.manual_seed(SEED)

    # replace with your CSV path - ENSURE THIS PATH IS CORRECT
    csv_path = '/content/drive/MyDrive/Fatigue_Set/final_feature_label_dataset_normalized_interpolated.csv'
    # csv_path = 'path/to/your/local/file.csv' # Example for local machine

    # Fatal problems raise SystemExit: in IPython/Colab, `exit()` only asks the
    # kernel to shut down and does not stop the rest of the cell from executing.
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError as err:
        raise SystemExit(
            f"Error: CSV file not found at {csv_path}\n"
            "Please ensure the path is correct and the file exists."
        ) from err
    print(f"Successfully loaded data from {csv_path}. Shape: {df.shape}")

    # Data Cleaning and Preparation
    required_cols = ['session', 'person', 'window_start', 'hr_mean', 'hr_std',
                     'duration_mean', 'duration_std', 'ax_mean', 'ax_std',
                     'ay_mean', 'ay_std', 'az_mean', 'az_std', 'eda_mean', 'eda_std',
                     'temp_mean', 'temp_std', 'physicalFatigueScore', 'mentalFatigueScore']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise SystemExit(f"Error: Missing required columns in CSV: {missing_cols}")

    # Coerce every required column except the timestamp to numeric (bad values -> NaN)
    numeric_cols = [col for col in required_cols if col != 'window_start']
    df[numeric_cols] = df[numeric_cols].apply(pd.to_numeric, errors='coerce')

    # Drop rows with NaN in features OR targets
    initial_len = len(df)
    df = df.dropna(subset=required_cols)
    if len(df) < initial_len:
        print(f"Dropped {initial_len - len(df)} rows containing NaN values in required columns.")
    if df.empty:
        raise SystemExit("Error: DataFrame is empty after dropping NaN values. Cannot proceed.")

    # Map raw session ids to contiguous indices (df is non-empty here, so >= 1 session)
    unique_sessions = sorted(df['session'].astype(int).unique())
    session_map = {s: i for i, s in enumerate(unique_sessions)}
    num_sessions = len(session_map)
    print(f"Found {num_sessions} unique sessions.")

    # Simple 80/20 random split over windows.
    # NOTE(review): windows of the same person/session land in both splits, so
    # adjacent, correlated windows can leak between train and test; a person- or
    # session-level split would give a more honest generalisation estimate.
    shuffled = df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)
    cut = int(0.8 * len(shuffled))
    train_df = shuffled.iloc[:cut].reset_index(drop=True)
    test_df  = shuffled.iloc[cut:].reset_index(drop=True)

    print(f"Training set size: {len(train_df)}")
    print(f"Test set size: {len(test_df)}")
    if len(train_df) == 0 or len(test_df) == 0:
        raise SystemExit("Error: Train or test set is empty after splitting.")

    train_dataset = FatigueSessionDataset(train_df, session_map=session_map, compute_elapsed=True)
    test_dataset  = FatigueSessionDataset(test_df,  session_map=session_map, compute_elapsed=True)

    # num_workers=0 for debugging; drop_last=True avoids a trailing batch of size 1
    # (which breaks BatchNorm in training mode).
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate, num_workers=0, drop_last=True)
    test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate, num_workers=0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = FMAL_Daf([2,2,6,2,2], lstm_hidden=32, fusion_dim=64,
                     num_sessions=num_sessions, use_time_and_session=True).to(device)

    optim = torch.optim.Adam(model.parameters(), lr=5e-4)
    crit_reg = nn.MSELoss()

    # Train and collect per-epoch metrics.
    # NOTE(review): the test set doubles as the per-epoch evaluation set; choosing
    # num_epochs from these curves tunes on test data (a validation split is safer).
    num_epochs = 47
    all_epoch_metrics = train_model(model, train_loader, test_loader, optim, crit_reg,
                                    epochs=num_epochs, device=device)

    print("\n=== Full Metrics History (All Epochs) ===")
    pprint.pprint(all_epoch_metrics)

    # --- FINAL EVALUATION ---
    # Largely repeats the last epoch's metrics, but also prints sample predictions.
    print("\n--- Running Final Evaluation on Test Set ---")
    evaluate_and_show_predictions(model, test_loader, device, crit_reg)

Successfully loaded data from /content/drive/MyDrive/Fatigue_Set/final_feature_label_dataset_normalized_interpolated.csv. Shape: (2819, 19)
Found 3 unique sessions.
Training set size: 2255
Test set size: 564
Using device: cpu
[Epoch 1] Avg Training Loss: 0.357502
[Epoch 1] Evaluation Metrics:
{'loss': 0.0962678889433543,
 'mae_mental': 0.17874468863010406,
 'mae_physical': 0.17359048128128052,
 'r2_mental': 0.0024542808532714844,
 'r2_physical': 0.004254043102264404,
 'rmse_mental': 0.2262439296761638,
 'rmse_physical': 0.21509712190768646}
[Epoch 2] Avg Training Loss: 0.148031
[Epoch 2] Evaluation Metrics:
{'loss': 0.08527866311164366,
 'mae_mental': 0.16470257937908173,
 'mae_physical': 0.156401127576828,
 'r2_mental': 0.10842150449752808,
 'r2_physical': 0.1252657175064087,
 'rmse_mental': 0.21388992641573712,
 'rmse_physical': 0.20160365756431362}
[Epoch 3] Avg Training Loss: 0.143667
[Epoch 3] Evaluation Metrics:
{'loss': 0.08531244357840882,
 'mae_mental': 0.16737863421440125,
 '

# Centralized Model — Baseline (Without Novelty)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import math
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# =========================
# Dataset (unchanged)
# =========================
class FatigueSessionDataset(Dataset):
    """One item per feature window: five wrist-sensor modalities plus fatigue labels.

    The input frame is copied, ``session``/``person`` are cast to int and
    ``window_start`` is parsed as a datetime. Rows are stored sorted by
    (person, session, window_start).

    Args:
        df: frame with the feature/label columns read in ``__getitem__``.
        session_map: optional ``{raw_session_id: index}``; when omitted it is
            built from the sorted unique sessions of ``df``.
        compute_elapsed: when True, ``elapsed_time`` is the number of seconds
            since the earliest window of the same (person, session) *within
            this frame*; otherwise it is 0.0 for every row.
    """

    def __init__(self, df, session_map=None, compute_elapsed=True):
        frame = df.copy()
        frame['session'] = frame['session'].astype(int)
        frame['person'] = frame['person'].astype(int)
        frame['window_start'] = pd.to_datetime(frame['window_start'])

        if not compute_elapsed:
            frame['elapsed_time'] = 0.0
        else:
            group_start = frame.groupby(['person', 'session'])['window_start'].transform('min')
            frame['elapsed_time'] = (frame['window_start'] - group_start).dt.total_seconds()

        self.data = frame.sort_values(['person', 'session', 'window_start']).reset_index(drop=True)

        if session_map is not None:
            self.session_map = session_map
        else:
            ordered = sorted(self.data['session'].unique())
            self.session_map = {sess: pos for pos, sess in enumerate(ordered)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``([HR, IBI, ACC, EDA, Temp], domain_disc, session_feat, elapsed, (phys, ment))``."""
        row = self.data.iloc[idx]

        def as_tensor(*columns):
            return torch.tensor([row[col] for col in columns], dtype=torch.float32)

        modalities = [
            as_tensor('hr_mean', 'hr_std'),
            as_tensor('duration_mean', 'duration_std'),
            as_tensor('ax_mean', 'ax_std', 'ay_mean', 'ay_std', 'az_mean', 'az_std'),
            as_tensor('eda_mean', 'eda_std'),
            as_tensor('temp_mean', 'temp_std'),
        ]
        phys = as_tensor('physicalFatigueScore')
        ment = as_tensor('mentalFatigueScore')

        session_feat = torch.tensor([float(self.session_map[int(row['session'])])], dtype=torch.float32)
        elapsed = torch.tensor([float(row['elapsed_time'])], dtype=torch.float32)
        domain_disc = torch.tensor([0.0], dtype=torch.float32)  # placeholder, always 0
        return modalities, domain_disc, session_feat, elapsed, (phys, ment)


def collate(batch):
    """Stack a list of dataset samples into batched tensors.

    Each sample is ``([m0..m4], domain_disc, session_feat, elapsed, (phys, ment))``;
    the result has the same structure with a new leading batch dimension on
    every tensor. Only the first five modality tensors are stacked.
    """
    def _stack(pick):
        return torch.stack([pick(sample) for sample in batch])

    modal_inputs = [_stack(lambda sample, k=k: sample[0][k]) for k in range(5)]
    domain_disc = _stack(lambda sample: sample[1])
    session_feat = _stack(lambda sample: sample[2])
    elapsed_feat = _stack(lambda sample: sample[3])
    phys = _stack(lambda sample: sample[4][0])
    ment = _stack(lambda sample: sample[4][1])
    return modal_inputs, domain_disc, session_feat, elapsed_feat, (phys, ment)


# =========================
# Model
# =========================
class ModalityLSTM(nn.Module):
    """Single-layer LSTM encoder for one sensor modality.

    A 2-D input is treated as a length-1 sequence by inserting a time axis at
    dim 1; other inputs are passed straight to the ``batch_first`` LSTM
    (presumably (batch, seq, features)). Returns the LSTM output at the last
    time step, shape (batch, hidden_dim).
    """

    def __init__(self, input_dim, hidden_dim=32):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        seq = x.unsqueeze(1) if x.ndim == 2 else x
        outputs, _state = self.lstm(seq)
        return outputs[:, -1, :]


class BaselineFMAL(nn.Module):
    """Baseline multimodal regressor for physical and mental fatigue.

    Pipeline: one ``ModalityLSTM`` per entry of ``modalities_dim`` (each
    producing ``lstm_hidden`` features) -> concatenation, optionally extended
    with an 8-d session embedding and an 8-d elapsed-time encoding -> a global
    LSTM (32 units) run over a length-1 sequence -> two linear heads, each
    producing one value per sample.

    ``domain_disc`` and ``debug`` are accepted by ``forward`` but unused.
    """

    def __init__(self, modalities_dim, lstm_hidden=32, use_time_and_session=True, num_sessions=3):
        super().__init__()
        self.use_time_and_session = use_time_and_session
        # Creation order below is kept fixed: each constructor draws from the
        # global RNG, so reordering would change the initial weights.
        self.modality_lstms = nn.ModuleList([ModalityLSTM(dim, lstm_hidden) for dim in modalities_dim])

        # Session index -> 8-d vector (built even if use_time_and_session is False)
        self.session_emb = nn.Embedding(num_sessions, 8)

        # Elapsed time (scalar) -> 8-d vector.
        # NOTE(review): the dataset supplies raw seconds here, unscaled — confirm intended.
        self.time_enc = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 8))

        extra_dim = 16 if use_time_and_session else 0  # 8 (session) + 8 (time)
        self.global_lstm = nn.LSTM(lstm_hidden * len(modalities_dim) + extra_dim, 32, batch_first=True)

        self.reg_phys = nn.Linear(32, 1)
        self.reg_ment = nn.Linear(32, 1)
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        """Xavier-uniform weights and zero biases for every Linear and LSTM."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            return
        if isinstance(m, nn.LSTM):
            for pname, tensor in m.named_parameters():
                if 'weight' in pname:
                    nn.init.xavier_uniform_(tensor)
                elif 'bias' in pname:
                    nn.init.zeros_(tensor)

    def forward(self, modal_inputs, domain_disc=None, session_feat=None, elapsed_feat=None, debug=False):
        """Return ``(physical_pred, mental_pred)`` for a batch of modality tensors."""
        fused = torch.cat([encoder(x) for encoder, x in zip(self.modality_lstms, modal_inputs)], dim=1)

        if self.use_time_and_session:
            sess_vec = self.session_emb(session_feat.squeeze(-1).long())
            time_vec = self.time_enc(elapsed_feat)
            fused = torch.cat((fused, sess_vec, time_vec), dim=1)

        hidden_seq, _ = self.global_lstm(fused.unsqueeze(1))
        last = hidden_seq[:, -1, :]
        return self.reg_phys(last), self.reg_ment(last)


# =========================
# Training / Evaluation (unchanged)
# =========================
def train_model(model, loader, optim, crit_reg, epochs=10, device=torch.device("cpu")):
    for ep in range(1, epochs + 1):
        model.train()
        tot = 0.0
        for modals, dom, session_feat, elapsed_feat, (phys, ment) in loader:
            modals = [m.to(device) for m in modals]
            session_feat = session_feat.to(device)
            elapsed_feat = elapsed_feat.to(device)
            phys = phys.to(device)
            ment = ment.to(device)

            optim.zero_grad()
            p_pred, m_pred = model(modals, session_feat=session_feat, elapsed_feat=elapsed_feat)
            loss = crit_reg(p_pred, phys) + 2*crit_reg(m_pred, ment)
            loss.backward()
            optim.step()
            tot += loss.item()
        avg_loss = tot / len(loader) if len(loader) > 0 else float('nan')
        print(f"[Epoch {ep}] Avg Training Loss: {avg_loss:.6f}")


def evaluate_and_show_predictions(model, test_loader, device, crit_reg):
    """Evaluate the baseline model on the whole test set and print a report.

    Prints predictions vs. ground truth for every sample of the first batch,
    then RMSE / MAE / R² for physical and mental fatigue computed over ALL
    test batches, and the average per-batch loss
    ``crit_reg(p, phys) + crit_reg(m, ment)`` (unweighted, unlike training).

    Returns:
        dict with keys 'loss', 'rmse_physical', 'mae_physical', 'r2_physical',
        'rmse_mental', 'mae_mental', 'r2_mental' (NaN when < 2 samples).
    """
    def _metrics(labels, preds):
        # (rmse, mae, r2); R² is undefined for fewer than two samples.
        if len(labels) < 2:
            return float('nan'), float('nan'), float('nan')
        return (math.sqrt(mean_squared_error(labels, preds)),
                mean_absolute_error(labels, preds),
                r2_score(labels, preds))

    model.eval()
    reg_loss_sum = 0.0
    n_batches = 0
    all_phys_preds, all_phys_labels = [], []
    all_ment_preds, all_ment_labels = [], []

    with torch.no_grad():
        for batch_idx, (modals, dom, session_feat, elapsed_feat, (phys, ment)) in enumerate(test_loader):
            modals = [m.to(device) for m in modals]
            session_feat = session_feat.to(device)
            elapsed_feat = elapsed_feat.to(device)
            phys = phys.to(device)
            ment = ment.to(device)

            p_pred, m_pred = model(modals, session_feat=session_feat, elapsed_feat=elapsed_feat)
            reg_loss_sum += (crit_reg(p_pred, phys) + crit_reg(m_pred, ment)).item()
            n_batches += 1

            # reshape(-1) rather than squeeze(): squeeze() yields a 0-d array for a
            # 1-sample batch, which list.extend() cannot iterate.
            p_flat = p_pred.reshape(-1).cpu().tolist()
            m_flat = m_pred.reshape(-1).cpu().tolist()
            phys_flat = phys.reshape(-1).cpu().tolist()
            ment_flat = ment.reshape(-1).cpu().tolist()
            all_phys_preds.extend(p_flat)
            all_phys_labels.extend(phys_flat)
            all_ment_preds.extend(m_flat)
            all_ment_labels.extend(ment_flat)

            # Print the first batch only, but keep accumulating metrics over every batch.
            if batch_idx == 0:
                print("=== Predictions vs Ground Truth (first batch) ===")
                for i, (tp, tm, pp, pm) in enumerate(zip(phys_flat, ment_flat, p_flat, m_flat)):
                    print(f"Sample {i}: True phys={tp:.3f}, True ment={tm:.3f} | "
                          f"Pred phys={pp:.3f}, Pred ment={pm:.3f}")

    avg_loss = reg_loss_sum / n_batches if n_batches > 0 else float('nan')
    rmse_phys, mae_phys, r2_phys = _metrics(all_phys_labels, all_phys_preds)
    rmse_ment, mae_ment, r2_ment = _metrics(all_ment_labels, all_ment_preds)

    print("\n=== Final Test Performance ===")
    print(f"Test Regression Loss (avg): {avg_loss:.6f}")
    print(f"Physical Fatigue - RMSE: {rmse_phys:.6f}, MAE: {mae_phys:.6f}, R²: {r2_phys:.6f}")
    print(f"Mental Fatigue  - RMSE: {rmse_ment:.6f}, MAE: {mae_ment:.6f}, R²: {r2_ment:.6f}")

    return {
        'loss': avg_loss,
        'rmse_physical': rmse_phys,
        'mae_physical': mae_phys,
        'r2_physical': r2_phys,
        'rmse_mental': rmse_ment,
        'mae_mental': mae_ment,
        'r2_mental': r2_ment,
    }


# =========================
# Main
# =========================
if __name__ == "__main__":
    # Seed torch so weight initialisation and DataLoader shuffling are reproducible;
    # df.sample(random_state=...) below only fixes the train/test split.
    SEED = 42
    torch.manual_seed(SEED)

    csv_path = '/content/drive/MyDrive/Fatigue_Set/final_feature_label_dataset_normalized_interpolated.csv'
    df = pd.read_csv(csv_path)

    # Same cleaning as the FMAL_Daf main cell (numeric coercion + NaN drop), so the
    # baseline is trained/evaluated on the same rows and, with the same seed, the
    # same 80/20 split — required for a fair comparison.
    required_cols = ['session', 'person', 'window_start', 'hr_mean', 'hr_std',
                     'duration_mean', 'duration_std', 'ax_mean', 'ax_std',
                     'ay_mean', 'ay_std', 'az_mean', 'az_std', 'eda_mean', 'eda_std',
                     'temp_mean', 'temp_std', 'physicalFatigueScore', 'mentalFatigueScore']
    numeric_cols = [col for col in required_cols if col != 'window_start']
    df[numeric_cols] = df[numeric_cols].apply(pd.to_numeric, errors='coerce')
    df = df.dropna(subset=required_cols)
    if df.empty:
        raise SystemExit("Error: DataFrame is empty after dropping NaN values. Cannot proceed.")

    session_map = {s: i for i, s in enumerate(sorted(df['session'].astype(int).unique()))}
    shuffled = df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)
    cut = int(0.8 * len(shuffled))
    train_df = shuffled.iloc[:cut].reset_index(drop=True)
    test_df  = shuffled.iloc[cut:].reset_index(drop=True)

    train_dataset = FatigueSessionDataset(train_df, session_map=session_map, compute_elapsed=True)
    test_dataset  = FatigueSessionDataset(test_df,  session_map=session_map, compute_elapsed=True)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate)
    test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BaselineFMAL([2,2,6,2,2], lstm_hidden=32, use_time_and_session=True, num_sessions=len(session_map)).to(device)

    optim = torch.optim.Adam(model.parameters(), lr=5e-4)
    crit_reg = nn.MSELoss()

    train_model(model, train_loader, optim, crit_reg, epochs=47, device=device)
    evaluate_and_show_predictions(model, test_loader, device, crit_reg)


[Epoch 1] Avg Training Loss: 0.513347


  return datetime.utcnow().replace(tzinfo=utc)


[Epoch 2] Avg Training Loss: 0.161110
[Epoch 3] Avg Training Loss: 0.156698
[Epoch 4] Avg Training Loss: 0.155157
[Epoch 5] Avg Training Loss: 0.154061
[Epoch 6] Avg Training Loss: 0.152778
[Epoch 7] Avg Training Loss: 0.154899
[Epoch 8] Avg Training Loss: 0.154858
[Epoch 9] Avg Training Loss: 0.151549
[Epoch 10] Avg Training Loss: 0.149262
[Epoch 11] Avg Training Loss: 0.150932
[Epoch 12] Avg Training Loss: 0.150330
[Epoch 13] Avg Training Loss: 0.152497
[Epoch 14] Avg Training Loss: 0.148133
[Epoch 15] Avg Training Loss: 0.149063
[Epoch 16] Avg Training Loss: 0.146970
[Epoch 17] Avg Training Loss: 0.147941
[Epoch 18] Avg Training Loss: 0.148184
[Epoch 19] Avg Training Loss: 0.146879
[Epoch 20] Avg Training Loss: 0.146558
[Epoch 21] Avg Training Loss: 0.145402
[Epoch 22] Avg Training Loss: 0.145739
[Epoch 23] Avg Training Loss: 0.142945
[Epoch 24] Avg Training Loss: 0.144016
[Epoch 25] Avg Training Loss: 0.143619
[Epoch 26] Avg Training Loss: 0.141361
[Epoch 27] Avg Training Loss: 0.1

  return datetime.utcnow().replace(tzinfo=utc)


Training and Testing Data Heads (first rows of each split)

In [None]:
# 🔎 Preview the first rows of the train/test splits built by the main cell
for split_name, split_df in (("Training", train_df), ("Testing", test_df)):
    print(f"\n=== {split_name} Dataset Head ===")
    print(split_df.head())