<a href="https://colab.research.google.com/github/papertuc2000/CL-Drive/blob/dev/CL_Drive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================
# CL-Drive Multi-Modal Data Generator
# ------------------------------------------------------------
# This generator:
#   • Automatically scans dataset structure
#   • Synchronizes modalities per (participant, level)
#   • Ensures label alignment
#   • Performs fixed-length, non-overlapping window segmentation
#   • Resamples all signals to a unified sampling rate
#   • Supports classification or regression
#   • Uses file-level caching for speed optimization
#   • Returns dictionary input compatible with fusion models
# ============================================================

In [None]:
# --- 1. Connect to Google Drive ---
# Colab-only helper module; it is imported first so the status message and the
# mount call sit together.
from google.colab import drive

print("Connecting to Google Drive...")
# force_remount=True is passed so the mount is redone even when the Drive is
# already mounted in this session.
drive.mount('/content/drive/', force_remount=True)

Connecting to Google Drive...
Mounted at /content/drive/


In [None]:

import os
from glob import glob
from collections import defaultdict

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.utils import Sequence
from scipy.signal import resample

In [None]:
# Quieten the runtime before the heavy imports below.
import os
# TensorFlow logging / oneDNN switches, set through environment variables.
# NOTE(review): presumably '3' hides TensorFlow's native log output and '0'
# disables the oneDNN custom ops - confirm against the TensorFlow docs.
# NOTE(review): the previous cell already ran `import tensorflow`, so when the
# cells are executed in order these variables are probably set too late to
# take effect - confirm, and move this cell above the first TensorFlow import
# if so.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import warnings
# Suppresses every Python warning for the rest of the session, including
# deprecation and numerical warnings raised by the libraries used below.
warnings.filterwarnings('ignore')

from glob import glob
from collections import defaultdict

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.utils import Sequence
from scipy.signal import resample


# ============================================================
# CL-Drive Multi-Modal Data Generator (FINAL REVISED)
# ============================================================


class CLDriveMultiModalGenerator(Sequence):
    """Keras ``Sequence`` serving windowed, resampled multi-modal CL-Drive batches.

    Layout expected below ``dataset_path`` (as parsed by ``_build_index``)::

        <modality>/<participant>/<prefix>level_<N>.csv   one recording
        Labels/<participant>.csv                         columns ``lvl_<N>``

    Files whose path (relative to ``dataset_path``) contains "baseline" are
    ignored. A (participant, level) recording is used only when every
    requested modality and the participant's label file exist. Row ``i`` of
    label column ``lvl_<N>`` labels the ``i``-th ``window_sec``-second window
    of level ``N``; rows without a label value produce no window.

    Every batch is a tuple ``(X, y)``:

    * ``X`` maps ``"<modality in lower case>_input"`` to a float32 array of
      shape ``(batch, window_sec * target_fs, n_columns)``. Every column of
      the recording CSV is treated as a signal channel.
    * ``y`` has shape ``(batch, n_classes)`` (one-hot) for classification and
      ``(batch, 1)`` (raw label value, float32) for any other ``task``.

    Args:
        dataset_path: Root folder of the dataset.
        modalities: Modality folder names to load; each must be a key of
            ``fs_dict``.
        batch_size: Number of windows per batch (the last batch may be smaller).
        window_sec: Window length in seconds.
        target_fs: Sampling rate (Hz) every modality is resampled to.
        shuffle: Shuffle the window order at construction and after each epoch.
        task: ``'classification'`` for one-hot targets; any other value yields
            regression targets.
        n_classes: Width of the one-hot vector. The binning in
            ``_process_label`` always produces classes 0..2, so values below 3
            fail inside ``to_categorical``.
        use_cache: Keep every CSV that has been read in memory.
        **kwargs: Forwarded to the Keras base-class constructor (for example
            ``workers`` / ``use_multiprocessing`` on Keras 3).
    """

    def __init__(
        self,
        dataset_path,
        modalities=('ECG', 'EEG', 'EDA', 'Gaze'),
        batch_size=8,
        window_sec=10,
        target_fs=128,
        shuffle=True,
        task='classification',
        n_classes=3,
        use_cache=True,
        **kwargs
    ):
        # Keras 3 expects data-set subclasses to initialise the base class;
        # with the older tf.keras Sequence this is a harmless no-op.
        super().__init__(**kwargs)

        self.dataset_path = dataset_path
        self.modalities = modalities
        self.batch_size = batch_size
        self.window_sec = window_sec
        self.target_fs = target_fs
        self.shuffle = shuffle
        self.task = task
        self.n_classes = n_classes
        self.use_cache = use_cache

        # Original sampling rates (Hz) used to convert window indices to rows.
        self.fs_dict = {
            'ECG': 512,
            'EEG': 256,
            'EDA': 128,
            'Gaze': 50
        }

        # path -> DataFrame, filled lazily by _load_csv when use_cache is set.
        self.cache = {}

        # One entry per (participant, level) with all modalities and a label file.
        self.samples = self._build_index()

        # One (sample_idx, seg_idx) pair per labelled window.
        self.indices = self._create_windows()

        self.on_epoch_end()

    # =====================================================
    # 1. Dataset scan and modality alignment
    # =====================================================
    def _build_index(self):
        """Scan ``dataset_path`` and group recording files per (participant, level).

        Returns:
            list[dict]: one dict per complete recording with the keys
            ``'participant'`` (str), ``'level'`` (int), ``'paths'``
            (modality -> CSV path) and ``'label_path'``.
        """
        aligned_data = {}
        label_files = {}

        pattern = os.path.join(self.dataset_path, "**", "*.csv")

        # sorted(): glob order depends on the file system; sorting keeps the
        # sample order reproducible between runs.
        for raw_path in sorted(glob(pattern, recursive=True)):

            f_path = raw_path.replace("\\", "/")

            # Classify files by their location *inside* the dataset so that a
            # root folder whose own path contains "baseline" or "Labels"
            # cannot hide or mislabel every file.
            rel_parts = (
                os.path.relpath(raw_path, self.dataset_path)
                .replace("\\", "/")
                .split("/")
            )
            filename = rel_parts[-1]

            # Skip baseline recordings
            if any("baseline" in part.lower() for part in rel_parts):
                continue

            # Label file: Labels/<participant>.csv
            if "Labels" in rel_parts[:-1]:
                participant = filename.split('.')[0]
                label_files[participant] = f_path
                continue

            # Modality file: <modality>/<participant>/<...>level_<N>.csv
            if len(rel_parts) < 3:
                continue

            modality = rel_parts[-3]        # e.g. EEG
            participant = rel_parts[-2]     # e.g. 1372

            if modality not in self.modalities:
                continue

            if "level_" not in filename:
                continue

            level_text = filename.split("level_")[-1].replace(".csv", "")
            try:
                level = int(level_text)
            except ValueError:
                # An unexpected suffix must not abort the whole scan.
                print(f"Skipping {f_path}: cannot parse level '{level_text}'")
                continue

            aligned_data.setdefault((participant, level), {})[modality] = f_path

        # Keep only recordings that have every modality and a label file.
        final_samples = []

        for (participant, level), found_modalities in aligned_data.items():

            missing = set(self.modalities) - set(found_modalities.keys())

            if missing:
                print(f"Skipping {participant} Level {level}: Missing {missing}")
                continue

            if participant not in label_files:
                print(f"Skipping {participant} Level {level}: Missing label file")
                continue

            final_samples.append({
                'participant': participant,
                'level': level,
                'paths': found_modalities,
                'label_path': label_files[participant]
            })

        print(f"\n✔ Total fully synchronized samples: {len(final_samples)}\n")

        return final_samples

    # =====================================================
    # 2. Window construction
    # =====================================================
    def _create_windows(self):
        """Enumerate every labelled window of every sample.

        Only rows of the sample's own ``lvl_<level>`` column that hold a value
        become windows; a missing label cannot be trained on and would
        otherwise be turned into class 0 (or a NaN regression target).

        Returns:
            list[tuple[int, int]]: ``(sample_idx, seg_idx)`` pairs, where
            ``seg_idx`` is the positional row index in the label file.
        """
        window_indices = []

        for sample_idx, sample in enumerate(self.samples):

            label_df = self._load_csv(sample['label_path'])
            column = f"lvl_{sample['level']}"

            if column not in label_df.columns:
                print(
                    f"Skipping {sample['participant']} Level {sample['level']}: "
                    f"Missing label column {column}"
                )
                continue

            labelled_rows = np.flatnonzero(label_df[column].notna().to_numpy())
            window_indices.extend(
                (sample_idx, int(seg_idx)) for seg_idx in labelled_rows
            )

        return window_indices

    # =====================================================
    # 3. CSV loader (file-level cache)
    # =====================================================
    def _load_csv(self, path):
        """Read ``path`` as a DataFrame, reusing the cached copy when enabled."""
        if not self.use_cache:
            return pd.read_csv(path)

        if path not in self.cache:
            self.cache[path] = pd.read_csv(path)
        return self.cache[path]

    # =====================================================
    # 4. Resampling
    # =====================================================
    def _resample_signal(self, signal, orig_fs):
        """Resample ``signal`` along axis 0 to ``window_sec * target_fs`` rows.

        Args:
            signal: Array of shape ``(n_rows, n_columns)``.
            orig_fs: Original sampling rate; kept for interface compatibility,
                the output length depends only on the window settings.

        Returns:
            Array with ``int(window_sec * target_fs)`` rows; all zeros when
            ``signal`` has no rows.
        """
        target_length = int(self.window_sec * self.target_fs)

        if len(signal) == 0:
            return np.zeros((target_length,) + tuple(signal.shape[1:]))

        return resample(signal, target_length, axis=0)

    # =====================================================
    # 5. Segment extraction
    # =====================================================
    def _load_segment(self, path, modality, seg_idx):
        """Return window ``seg_idx`` of a recording, resampled to ``target_fs``.

        A window that runs past the end of the recording is zero-padded to the
        full window length *before* resampling, so the samples that do exist
        keep their time scale instead of being stretched over the whole window.
        A window entirely past the end is all zeros.

        Returns:
            float32 array of shape ``(window_sec * target_fs, n_columns)``.
        """
        signal = self._load_csv(path).values
        orig_fs = self.fs_dict[modality]

        start = int(seg_idx * self.window_sec * orig_fs)
        end = int((seg_idx + 1) * self.window_sec * orig_fs)

        segment = signal[start:end]

        n_missing = (end - start) - segment.shape[0]
        if n_missing > 0:
            padding = np.zeros((n_missing,) + tuple(signal.shape[1:]))
            segment = np.concatenate([segment, padding], axis=0)

        segment = self._resample_signal(segment, orig_fs)

        return segment.astype(np.float32)

    # =====================================================
    # 6. Label processing
    # =====================================================
    def _process_label(self, label_value):
        """Convert a raw label value into the training target.

        Classification bins: ``< 4`` -> class 0, ``< 7`` -> class 1, everything
        else -> class 2 (values above the expected range are clamped to the
        top class rather than falling back to class 0).

        Returns:
            One-hot vector of length ``n_classes`` for classification,
            otherwise a float32 array of shape ``(1,)``.

        Raises:
            ValueError: for a missing (NaN/None) label in classification mode.
        """
        if self.task != 'classification':
            return np.array([label_value], dtype=np.float32)

        if pd.isna(label_value):
            # Every comparison with NaN is False, which used to end in class 0.
            raise ValueError("Cannot derive a class from a missing label value")

        if label_value < 4:
            cls = 0
        elif label_value < 7:
            cls = 1
        else:
            cls = 2

        return tf.keras.utils.to_categorical(cls, self.n_classes)

    # =====================================================
    # 7. Required by Keras
    # =====================================================
    def __len__(self):
        """Number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.indices) / self.batch_size))

    def on_epoch_end(self):
        """Reshuffle the window order in place when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)

    # =====================================================
    # 8. Batch generation
    # =====================================================
    def __getitem__(self, index):
        """Build batch number ``index``.

        Returns:
            tuple: ``(X, y)`` as described in the class docstring.
        """
        batch_indices = self.indices[
            index * self.batch_size:(index + 1) * self.batch_size
        ]

        X_batches = {f"{m.lower()}_input": [] for m in self.modalities}
        y_batch = []

        for sample_idx, seg_idx in batch_indices:

            sample = self.samples[sample_idx]

            for m in self.modalities:
                segment = self._load_segment(sample['paths'][m], m, seg_idx)
                X_batches[f"{m.lower()}_input"].append(segment)

            # The level comes from the index entry rather than from the ECG
            # file name, so the generator also works when ECG is not one of
            # the requested modalities.
            label_df = self._load_csv(sample['label_path'])
            label_value = label_df[f"lvl_{sample['level']}"].iloc[seg_idx]
            y_batch.append(self._process_label(label_value))

        X = {k: np.array(v, dtype=np.float32) for k, v in X_batches.items()}
        y = np.array(y_batch)

        return X, y

In [None]:
# Root folder of the CL-Drive dataset on the Google Drive mounted above.
dataset_path='/content/drive/MyDrive/Colab Notebooks/CL-Drive'

In [None]:
# Signal modalities to index; the names are used below as dictionary keys and
# are matched against folder names taken from the file paths.
modalities = ['ECG', 'EEG', 'EDA', 'Gaze']

In [None]:
# One empty path list per signal modality, plus a separate bucket for the
# label files.
modalities_path_dic = {str(name): [] for name in modalities}
modalities_path_dic['Labels'] = []

In [None]:
# Recursively collect every CSV file below the dataset root ("**" together
# with recursive=True matches any directory depth).
all_files = glob(os.path.join(dataset_path, "**", "*.csv"), recursive=True)

In [None]:
# Split the first discovered path into its components and display them, to see
# which position holds the modality folder, the participant and the file name.
temp = all_files[0].split(os.sep)
temp

['',
 'content',
 'drive',
 'MyDrive',
 'Colab Notebooks',
 'CL-Drive',
 'ECG',
 '1372',
 'ecg_data_level_1.csv']

In [None]:
# Bucket every CSV under its top-level dataset folder (a modality name or
# 'Labels'). The folder is taken relative to dataset_path rather than from a
# fixed component index, so the loop keeps working when the dataset root is
# moved to a path of a different depth. A folder that has no bucket still
# raises KeyError, as before.
for f_path in all_files:
    top_folder = os.path.relpath(f_path, dataset_path).split(os.sep)[0]
    modalities_path_dic[top_folder].append(f_path)

In [None]:
import os
from glob import glob


def align_modality_files(dataset_path, modalities, label_key='Labels'):
    """Group the dataset's CSV files per (participant, level).

    Recording files are expected at
    ``<dataset_path>/<modality>/<participant>/<prefix>level_<N>.csv`` and label
    files at ``<dataset_path>/<label_key>/<participant>.csv``. Files whose
    path below ``dataset_path`` contains "baseline" are ignored, as are
    recording files without ``level_`` in their name.

    Label files carry no level, so a participant's label file is attached to
    every one of that participant's groups under ``label_key``. Without this,
    no group could ever satisfy a ``modalities`` list that includes the label
    key.

    Args:
        dataset_path: Root folder of the dataset.
        modalities: Keys every group must provide to be kept (may include
            ``label_key``).
        label_key: Name of the label folder and of the key under which the
            label path is stored.

    Returns:
        tuple: ``(aligned, final)`` where ``aligned`` maps
        ``(participant, level)`` to ``{modality: path}`` and ``final`` maps
        each entry of ``modalities`` to a list of paths; the lists are
        index-aligned across modalities and contain complete groups only.
    """
    aligned = {}
    label_files = {}

    pattern = os.path.join(dataset_path, "**", "*.csv")

    # sorted(): glob order is file-system dependent; sorting makes the order
    # of the resulting lists reproducible.
    for raw_path in sorted(glob(pattern, recursive=True)):
        f_path = raw_path.replace("\\", "/")

        # Look only at the part of the path inside the dataset.
        rel_parts = (
            os.path.relpath(raw_path, dataset_path).replace("\\", "/").split("/")
        )
        filename = rel_parts[-1]          # e.g., 'eeg_data_level_9.csv'

        if any("baseline" in part.lower() for part in rel_parts):
            continue

        if label_key in rel_parts[:-1]:
            label_files[filename.split('.')[0]] = f_path
            continue

        if len(rel_parts) < 3 or "level_" not in filename:
            continue

        modality = rel_parts[-3]          # e.g., 'EEG'
        participant = rel_parts[-2]       # e.g., 'participant_ID_1'
        level = filename.split('level_')[-1].replace('.csv', '')

        aligned.setdefault((participant, level), {})[modality] = f_path

    for (participant, _level), found in aligned.items():
        if participant in label_files:
            found[label_key] = label_files[participant]

    final = {m: [] for m in modalities}

    for (participant, level), found in aligned.items():
        missing = set(modalities) - set(found)
        if missing:
            print(f"Skipping {participant} Level {level}: Missing {missing}")
            continue
        for m in modalities:
            final[m].append(found[m])

    return aligned, final


modalities = ['ECG', 'EEG', 'EDA', 'Gaze', 'Labels']
dataset_path = "D:/Projects/PythonProjects/BioSignalClassification/CL-Drive" # Example path

# 1. Nested storage: { (participant, level): { 'EEG': path, 'ECG': path ... } }
# 2. Filter: only groups that have ALL modalities end up in final_paths
aligned_data, final_paths = align_modality_files(dataset_path, modalities)

In [None]:
# Build the training generator with the class defaults for modalities,
# shuffling, task and caching. The Colab path is passed explicitly because the
# variable `dataset_path` was reassigned to a local example path in the
# previous cell. Construction scans the dataset and prints one "Skipping ..."
# line for every (participant, level) recording that lacks a modality or a
# label file, followed by the number of complete recordings.
train_gen = CLDriveMultiModalGenerator(
    dataset_path='/content/drive/MyDrive/Colab Notebooks/CL-Drive',
    batch_size=16,
    window_sec=10,
    target_fs=128,
)

Skipping 1744 Level 7: Missing {'EEG'}
Skipping 1716 Level 2: Missing {'Gaze'}
Skipping 1716 Level 3: Missing {'Gaze'}
Skipping 1716 Level 4: Missing {'Gaze'}
Skipping 1716 Level 7: Missing {'Gaze'}
Skipping 1716 Level 8: Missing {'Gaze'}
Skipping 1547 Level 1: Missing {'Gaze'}
Skipping 1547 Level 2: Missing {'Gaze'}
Skipping 1547 Level 7: Missing {'Gaze'}
Skipping 1547 Level 8: Missing {'Gaze'}
Skipping 1323 Level 4: Missing {'EEG'}
Skipping 1323 Level 2: Missing {'ECG', 'EDA'}

✔ Total fully synchronized samples: 171

