<a href="https://colab.research.google.com/github/papertuc2000/CL-Drive/blob/dev/CL_Drive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================
# CL-Drive Multi-Modal Data Generator
# ------------------------------------------------------------
# This generator:
#   • Automatically scans dataset structure
#   • Synchronizes modalities per (participant, level)
#   • Ensures label alignment
#   • Segments each recording into consecutive fixed-length windows (one per label)
#   • Resamples all signals to a unified sampling rate
#   • Supports classification or regression
#   • Uses file-level caching for speed optimization
#   • Returns dictionary input compatible with fusion models
# ============================================================

In [1]:
# --- 1. Connect to Google Drive ---
# Notebook setup cell: mounts the user's Google Drive inside the Colab
# runtime. Only works inside Google Colab, because `google.colab` is
# provided by that environment.
print("Connecting to Google Drive...")
from google.colab import drive
# Mount point is '/content/drive/'. `force_remount=True` is passed so the
# mount is redone even if a previous mount exists -- presumably to pick up
# a fresh state when the cell is re-run; confirm against the Colab docs.
drive.mount('/content/drive/', force_remount=True)

Connecting to Google Drive...
Mounted at /content/drive/


In [2]:

import os
from glob import glob
from collections import defaultdict

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.utils import Sequence
from scipy.signal import resample

In [3]:



class CLDriveMultiModalGenerator(Sequence):
    """
    Keras-compatible multi-modal generator for CL-Drive dataset.

    Each batch contains:
        X = {
            'ecg_input':  (B, T, C1),
            'eeg_input':  (B, T, C2),
            'eda_input':  (B, T, C3),
            'gaze_input': (B, T, C4)
        }
        y = (B, n_classes) or (B, 1)

    Where:
        B = batch size
        T = window_sec * target_fs (unified temporal length)
        Cx = channel dimension of each modality (number of CSV columns)

    Only the modalities listed in `modalities` appear in X; the key of a
    modality is its lower-cased name followed by '_input'.

    Args:
        dataset_path: Root directory that is scanned recursively for CSVs.
        modalities: Modality names; each must be a key of `fs_dict`.
        batch_size: Number of windows per batch (>= 1).
        window_sec: Window length in seconds.
        target_fs: Sampling rate every window is resampled to.
        shuffle: Shuffle the window order at construction and epoch end.
        task: 'classification' (one-hot labels) or 'regression' (scalar).
        n_classes: Number of classes for the one-hot encoding.
        require_all_modalities: If True, keep only (participant, level)
            pairs for which every requested modality has a file. If False,
            missing modalities are fed as all-zero windows.
        use_cache: Keep every loaded CSV in memory.
        **kwargs: Forwarded to the Keras base class constructor
            (e.g. `workers`, `use_multiprocessing` on Keras 3).

    Raises:
        ValueError: On an unknown `task`, an unknown modality, or a
            `batch_size` smaller than 1.
    """

    def __init__(
        self,
        dataset_path,
        modalities=('ECG', 'EEG', 'EDA', 'Gaze'),
        batch_size=8,
        window_sec=10,
        target_fs=128,
        shuffle=True,
        task='classification',   # 'classification' or 'regression'
        n_classes=3,
        require_all_modalities=True,
        use_cache=True,
        **kwargs
    ):
        # Keras 3 expects data-loader subclasses to initialise the base
        # class; without it a warning is emitted on every fit() call.
        super().__init__(**kwargs)

        # ----------------------------
        # Basic configuration
        # ----------------------------
        self.dataset_path = dataset_path
        # Materialise once so one-shot iterables can be re-iterated.
        self.modalities = tuple(modalities)
        self.batch_size = batch_size
        self.window_sec = window_sec
        self.target_fs = target_fs
        self.shuffle = shuffle
        self.task = task
        self.n_classes = n_classes
        self.require_all_modalities = require_all_modalities
        self.use_cache = use_cache

        # ----------------------------
        # Original sampling rates
        # These are used to compute window boundaries
        # before resampling to target_fs
        # ----------------------------
        self.fs_dict = {
            'ECG': 512,
            'EEG': 256,
            'EDA': 128,
            'Gaze': 50
        }

        # ----------------------------
        # Fail fast on configuration errors instead of surfacing them
        # as KeyError / silent regression in the middle of training.
        # ----------------------------
        if task not in ('classification', 'regression'):
            raise ValueError(
                f"task must be 'classification' or 'regression', got {task!r}"
            )
        unknown = [m for m in self.modalities if m not in self.fs_dict]
        if unknown:
            raise ValueError(
                f"Unknown modalities {unknown}; "
                f"expected a subset of {sorted(self.fs_dict)}"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        # ----------------------------
        # File-level cache:
        # Stores loaded CSV files in memory to
        # avoid repeated disk reads during training
        # ----------------------------
        self.cache = {}

        # Channel count per modality, filled lazily; only needed to build
        # zero windows for modalities a sample does not have.
        self._channel_counts = {}

        # --------------------------------------------------
        # Step 1: Build synchronized dataset index
        # --------------------------------------------------
        self.samples = self._build_index()

        # --------------------------------------------------
        # Step 2: Create window index mapping
        # (sample_idx, segment_idx)
        # --------------------------------------------------
        self.indices = self._create_windows()

        self.on_epoch_end()

    # =====================================================
    # Dataset Synchronization
    # =====================================================
    def _build_index(self):
        """
        Scans dataset directory recursively and builds
        synchronized multi-modal index.

        Synchronization logic:
            key = (participant_id, level)

        The participant id is the name of the directory that directly
        contains the file; the level is the integer after '_level_' in the
        file name. A file belongs to a modality (or to the labels) when one
        of its directories below `dataset_path` is named after that
        modality (or 'Labels').

        Only samples satisfying:
            - label file exists
            - required modalities exist
        are included.

        Returns:
            List of dicts with keys 'p_id', 'level' (int), 'paths'
            (modality -> file path) and 'label_path', ordered by
            (participant id, level).
        """

        # Collect all CSV files recursively. Sorted so that the index (and
        # the winner among duplicate files) does not depend on the
        # filesystem's listing order.
        pattern = os.path.join(self.dataset_path, "**", "*.csv")
        all_files = sorted(glob(pattern, recursive=True))

        synced_data = defaultdict(dict)
        label_files = {}
        root = self.dataset_path or os.curdir

        for file_path in all_files:

            # Match on the path below the dataset root only, so directory
            # names in the root itself (e.g. '.../baseline_runs/CL-Drive')
            # cannot hide or mis-classify files.
            rel_path = os.path.relpath(file_path, root)

            # Ignore baseline recordings
            if 'baseline' in rel_path.lower():
                continue

            *rel_dirs, filename = rel_path.split(os.sep)
            if not rel_dirs:
                # File sits directly in the root: no participant folder.
                continue
            p_id = rel_dirs[-1]

            # --------------------------
            # Detect label files
            # One label file per participant
            # --------------------------
            if 'Labels' in rel_dirs:
                label_files[p_id] = file_path
                continue

            # --------------------------
            # Detect modality files (first requested modality that
            # appears as a directory name wins)
            # --------------------------
            modality = next(
                (m for m in self.modalities if m in rel_dirs), None
            )
            if modality is None or '_level_' not in filename:
                continue

            # Extract difficulty level from filename; skip files whose
            # suffix is not an integer instead of crashing the whole scan.
            level_text = filename.split('_level_')[-1].replace('.csv', '')
            try:
                level = int(level_text)
            except ValueError:
                continue

            synced_data[(p_id, level)][modality] = file_path

        # --------------------------
        # Final filtering
        # --------------------------
        # In non-strict mode ECG acts as the anchor modality, but only if
        # it was requested at all; otherwise any available modality is
        # enough (every entry in synced_data has at least one).
        anchor = 'ECG' if 'ECG' in self.modalities else None

        final_samples = []

        for key in sorted(synced_data):
            p_id, level = key
            files = synced_data[key]

            # Must have labels
            if p_id not in label_files:
                continue

            if self.require_all_modalities:
                # Strict multi-modal requirement
                if any(m not in files for m in self.modalities):
                    continue
            elif anchor is not None and anchor not in files:
                continue

            final_samples.append({
                'p_id': p_id,
                'level': level,
                'paths': files,
                'label_path': label_files[p_id]
            })

        print(f"Total synchronized samples: {len(final_samples)}")

        return final_samples

    # =====================================================
    # Window Construction
    # =====================================================
    def _create_windows(self):
        """
        Creates one window per row of each sample's label file.

        This method creates index pairs:
            (sample_idx, segment_idx)

        NOTE(review): the number of windows comes from the label file
        alone. If a recording is shorter than its labels, the trailing
        windows are (partly) zero-filled but still paired with a label.
        """

        window_indices = []

        for sample_idx, sample in enumerate(self.samples):
            # Load label file to determine number of segments
            n_segments = len(self._load_csv(sample['label_path']))
            window_indices.extend(
                (sample_idx, seg_idx) for seg_idx in range(n_segments)
            )

        return window_indices

    # =====================================================
    # CSV Loader with Optional Caching
    # =====================================================
    def _load_csv(self, path):
        """
        Loads CSV file with optional in-memory caching.

        With `use_cache` enabled the same DataFrame object is returned on
        every call, so callers must not modify it in place.
        """

        if not self.use_cache:
            return pd.read_csv(path)

        if path not in self.cache:
            self.cache[path] = pd.read_csv(path)
        return self.cache[path]

    # =====================================================
    # Required by Keras Sequence
    # =====================================================
    def __len__(self):
        """
        Returns number of batches per epoch (last batch may be partial).
        """
        return int(np.ceil(len(self.indices) / self.batch_size))

    def on_epoch_end(self):
        """
        Shuffle window indices at epoch end
        to prevent ordering bias.
        """
        if self.shuffle:
            np.random.shuffle(self.indices)

    # =====================================================
    # Resampling
    # =====================================================
    def _resample_signal(self, signal, orig_fs):
        """
        Resamples signal to unified target_fs.

        Ensures that all modalities share
        identical temporal length:
            window_sec * target_fs

        A segment covering a full window (or more) is resampled straight
        to the target length. A shorter segment (recording ended inside
        the window) is resampled to its proportional length, keeping its
        time scale, and zero-padded at the end. An empty segment yields
        an all-zero window.

        Args:
            signal: Array of shape (n_samples, n_channels) or (n_samples,).
            orig_fs: Sampling rate of `signal` in Hz.

        Returns:
            Array whose first axis has length window_sec * target_fs and
            whose remaining axes match `signal`.
        """

        target_length = int(self.window_sec * self.target_fs)
        signal = np.asarray(signal)
        channel_shape = signal.shape[1:]   # () for 1-D input
        n_in = signal.shape[0]

        if n_in == 0:
            return np.zeros((target_length,) + channel_shape)

        full_in = self.window_sec * orig_fs
        if full_in <= 0 or n_in >= int(full_in):
            n_out = target_length
        else:
            # Truncated window: keep the real duration of the data.
            n_out = min(
                target_length,
                int(round(target_length * n_in / full_in))
            )

        if n_out <= 0:
            return np.zeros((target_length,) + channel_shape)

        resampled = resample(signal, n_out, axis=0)
        if n_out == target_length:
            return resampled

        padded = np.zeros(
            (target_length,) + channel_shape, dtype=resampled.dtype
        )
        padded[:n_out] = resampled
        return padded

    # =====================================================
    # Segment Extraction
    # =====================================================
    def _load_segment(self, path, modality, seg_idx):
        """
        Extracts a fixed-length time window corresponding
        to the seg_idx-th subjective label.

        All columns of the CSV are used as channels.

        Returns:
            float32 array of shape (window_sec * target_fs, n_columns).
        """

        df = self._load_csv(path)
        orig_fs = self.fs_dict[modality]

        # Compute temporal window boundaries (in rows at the original rate)
        start = int(seg_idx * self.window_sec * orig_fs)
        end = int((seg_idx + 1) * self.window_sec * orig_fs)

        # Slice before converting so only the window is turned into an
        # array, not the whole recording on every call.
        segment = df.iloc[start:end].to_numpy()

        # Resample to target frequency
        segment = self._resample_signal(segment, orig_fs)

        return segment.astype(np.float32)

    def _channel_count(self, modality):
        """
        Returns the number of CSV columns of `modality`, taken from the
        first indexed sample that has a file for it (memoised).

        Raises:
            ValueError: If no indexed sample has this modality.
        """

        if modality not in self._channel_counts:
            for sample in self.samples:
                path = sample['paths'].get(modality)
                if path is not None:
                    self._channel_counts[modality] = (
                        self._load_csv(path).shape[1]
                    )
                    break
            else:
                raise ValueError(
                    f"No file found for modality {modality!r}; "
                    "cannot infer its channel count."
                )
        return self._channel_counts[modality]

    # =====================================================
    # Label Processing
    # =====================================================
    def _process_label(self, label_value):
        """
        Converts raw subjective score into:
            - one-hot class vector (classification), float32, shape
              (n_classes,); the score is truncated to an int and used
              directly as the class index
            - scalar value (regression), float32, shape (1,)

        Raises:
            IndexError: In classification mode, if the class index is
                outside [0, n_classes).
        """

        if self.task == 'classification':
            class_idx = int(label_value)
            # Reject negatives explicitly: plain array indexing would wrap
            # around and silently mark the wrong class.
            if not 0 <= class_idx < self.n_classes:
                raise IndexError(
                    f"Label {label_value!r} is outside the valid class "
                    f"range [0, {self.n_classes})"
                )
            one_hot = np.zeros(self.n_classes, dtype=np.float32)
            one_hot[class_idx] = 1.0
            return one_hot

        return np.array([label_value], dtype=np.float32)

    # =====================================================
    # Batch Construction
    # =====================================================
    def __getitem__(self, index):
        """
        Generates one batch of multi-modal data.

        Returns:
            (X, y) where X maps '<modality>_input' to a float32 array of
            shape (B, T, C) and y is a float32 array of shape
            (B, n_classes) or (B, 1). Modalities a sample has no file for
            (possible only with require_all_modalities=False) are
            zero-filled.
        """

        batch_indices = self.indices[
            index * self.batch_size:(index + 1) * self.batch_size
        ]
        target_length = int(self.window_sec * self.target_fs)

        # Initialize per-modality containers
        X_batches = {f"{m.lower()}_input": [] for m in self.modalities}
        y_batch = []

        for sample_idx, seg_idx in batch_indices:

            sample = self.samples[sample_idx]

            # --------------------------
            # Load each modality window
            # --------------------------
            for m in self.modalities:
                path = sample['paths'].get(m)
                if path is None:
                    segment = np.zeros(
                        (target_length, self._channel_count(m)),
                        dtype=np.float32
                    )
                else:
                    segment = self._load_segment(path, m, seg_idx)
                X_batches[f"{m.lower()}_input"].append(segment)

            # --------------------------
            # Load corresponding label
            # --------------------------
            # NOTE(review): the label is always read from the FIRST column
            # of the participant's label file, independent of
            # sample['level']. If the label file stores one column per
            # level, this pairs windows with the wrong labels -- confirm
            # the file schema.
            label_df = self._load_csv(sample['label_path'])
            label_value = label_df.iat[seg_idx, 0]
            y_batch.append(self._process_label(label_value))

        # Convert lists to numpy arrays
        X = {
            k: np.array(v, dtype=np.float32)
            for k, v in X_batches.items()
        }

        y = np.array(y_batch, dtype=np.float32)

        return X, y