# ELM fNIRS Data Preprocessing

This notebook preprocesses fNIRS data from the ELM study.

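**Run only the first code cell.** It discovers every session folder under `input_base` that contains a `.tri` file and re-executes this whole notebook once per session via `jupyter nbconvert`, passing the session directory in the `INPUT_DIR` environment variable. This processes multiple subjects in a row automatically.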

In [None]:
import os
import json
import sys
from pathlib import Path

# Root folders for the filtered input data and the preprocessed output.
# Either can be overridden with the ELM_INPUT_BASE / ELM_OUTPUT_BASE environment
# variables. The nbconvert child runs launched from this cell receive a copy of
# os.environ, so an override applies to them as well.
input_base = Path(os.environ.get(
    "ELM_INPUT_BASE",
    "/Users/saewonchung/Desktop/ELM_MW_data_analysis/ELM_filtered_data",
))
output_base = Path(os.environ.get(
    "ELM_OUTPUT_BASE",
    "/Users/saewonchung/Desktop/ELM_MW_data_analysis/ELM_preprocessed",
))

# Auto-discover every <date>/<session> directory that contains a .tri file.
# Entries are stored as "date/session" strings relative to input_base.
input_dirs = []
for date_folder in sorted(input_base.iterdir()):
    if not date_folder.is_dir():
        continue
    for session_folder in sorted(date_folder.iterdir()):
        # any() stops at the first match instead of listing every .tri file
        if session_folder.is_dir() and any(session_folder.glob("*.tri")):
            input_dirs.append(f"{date_folder.name}/{session_folder.name}")

print(f"Found {len(input_dirs)} datasets with .tri files")
print("First 5:", input_dirs[:5])

# Trigger configuration for segmentation. Keys are the raw annotation
# descriptions (trigger codes); values are the readable names they are renamed
# to before segmentation.
segment_by_triggers = {
    "2.0": "Video/Begin",
    "3.0": "Video/End",
    "0.0": "Video/TaskEnd",
}

# ============================================================================
# Reprocess only the subjects that are missing Zima/Splitscreen label files.
# After the duration-matching tolerance was relaxed, only these 14 subjects
# need to be rerun:
#   40, 41, 42, 43, 44, 45, 46, 47, 50, 55, 56, 57, 109, 113
# (sub-53 is unrecoverable due to a trigger recording error.)
#
# Usage: set REPROCESS_MISSING_ONLY = True and run the first cell.
# ============================================================================
REPROCESS_MISSING_ONLY = True  # Set to True to reprocess only missing subjects

missing_subjects = ['40', '41', '42', '43', '44', '45', '46', '47', '50', '55', '56', '57', '109', '113']


def get_subject_id(input_dir_path):
    """Return the subject ID stored in a session's ``*_description.json``.

    Parameters
    ----------
    input_dir_path : pathlib.Path
        Session directory to search (non-recursively).

    Returns
    -------
    str
        The ``subject`` field converted to ``str``, so numeric IDs stored in
        the JSON also work. Returns ``''`` when no description file exists
        or it has no ``subject`` field.
    """
    desc_files = list(input_dir_path.glob("*_description.json"))
    if not desc_files:
        return ''
    with open(desc_files[0], 'r', encoding='utf-8') as f:
        desc = json.load(f)
    return str(desc.get('subject', ''))


if REPROCESS_MISSING_ONLY:
    # Narrow input_dirs down to the sessions whose subject is in missing_subjects
    missing_set = set(missing_subjects)
    filtered_dirs = []
    for rel_path in input_dirs:
        subj_id = get_subject_id(input_base / rel_path)
        # Also compare without leading zeros (e.g., "055" -> "55")
        subj_id_clean = subj_id.lstrip('0') or '0'
        if subj_id_clean in missing_set or subj_id in missing_set:
            filtered_dirs.append(rel_path)
            print(f"  → Will reprocess: {rel_path} (subject {subj_id})")

    input_dirs = filtered_dirs
    print(f"\n🔄 Reprocessing {len(input_dirs)} missing subjects only")

if 'INPUT_DIR' not in os.environ:
    # Driver mode: INPUT_DIR is only set for the child runs launched below, so
    # this branch runs when the cell is executed interactively. It re-executes
    # the whole notebook once per session directory.
    import subprocess

    from IPython import get_ipython

    # Locate this notebook's file. VS Code exposes it as __vsc_ipynb_file__;
    # otherwise (or when not running under IPython at all) fall back to
    # the ipynbname package.
    ip = get_ipython()
    if ip is not None and '__vsc_ipynb_file__' in ip.user_ns:
        nb_path = ip.user_ns['__vsc_ipynb_file__']
    else:
        import ipynbname
        nb_path = ipynbname.path()

    failed_dirs = []
    # For each input directory, run nbconvert on this notebook in a subprocess
    for target_input_dir in input_dirs:
        date_folder, session_folder = target_input_dir.split('/')[:2]
        full_input_dir = input_base / date_folder / session_folder

        if not full_input_dir.exists():
            print(f"⚠️  Skipping: {full_input_dir} does not exist.")
            continue

        # Check if tri file is present
        if not any(full_input_dir.glob("*.tri")):
            print(f"❌ Skipping: No .tri file found in {full_input_dir}")
            continue

        # The child run reads INPUT_DIR to know which session to process
        env = os.environ.copy()
        env["INPUT_DIR"] = str(full_input_dir)

        output_name = f"{session_folder}.ipynb"

        print(f"\n{'='*60}")
        print(f"Processing: {target_input_dir}")
        print(f"{'='*60}")
        try:
            # Use python -m jupyter (same interpreter) instead of the jupyter command
            subprocess.run([
                sys.executable, "-m", "jupyter", "nbconvert",
                "--to", "notebook",
                "--execute", str(nb_path),
                "--output", output_name,
                "--output-dir", str(output_base)
            ], env=env, check=True)
        except subprocess.CalledProcessError as e:
            print(f"❌ Notebook execution failed for {target_input_dir}: {e}")
            failed_dirs.append(target_input_dir)
        else:
            print(f"✅ Success: {target_input_dir}")

    # Summarise failures so they are not lost in the long per-run output
    if failed_dirs:
        print(f"\n❌ {len(failed_dirs)} of {len(input_dirs)} datasets failed:")
        for failed_dir in failed_dirs:
            print(f"   - {failed_dir}")

# Preprocessing Pipeline

Code adapted from https://mne.tools/stable/auto_tutorials/preprocessing/70_fnirs_processing.html

## Setup

In [None]:
# This cell only works inside an nbconvert child run launched by the first cell
if 'INPUT_DIR' not in os.environ:
    raise RuntimeError('Missing INPUT_DIR variable! Need to run using nbconvert; see first code cell')

from itertools import compress
import matplotlib.pyplot as plt
import numpy as np
import mne
import mne_nirs

input_dir = os.environ['INPUT_DIR']

input_path = Path(input_dir)
raw_intensity = mne.io.read_raw_nirx(input_path, verbose=True).load_data()

# Subject identifier from the recording's subject_info (used in output names)
subj = raw_intensity.info['subject_info']['his_id']

# Extract session identifier from input directory (e.g., "2025-03-12_001")
session_id = input_path.name

# Check if accelerometer is available. A missing config file is not fatal:
# it only drives the (not yet implemented) accelerometer step, so default
# to "no accelerometer" and warn instead of crashing with an IndexError.
config_files = sorted(input_path.glob("*_config.json"))
if config_files:
    with open(config_files[0], 'r', encoding='utf-8') as f:
        config = json.load(f)
else:
    print(f"⚠️  No *_config.json found in {input_path}; assuming no accelerometer")
    config = {}
has_accelerometer = config.get('use_accelerometer', False)
print(f"Dataset has accelerometer: {has_accelerometer}")
print(f"Session ID: {session_id}")

# Create the full output directory
output_dir = output_base / f'sub-{subj}'
output_dir.mkdir(parents=True, exist_ok=True)

## Step 1: Compute Optical Density Time Series

In [None]:
# Overview of every raw-intensity channel over a 500 s window
raw_intensity.plot(
    show_scrollbars=False,
    duration=500,
    n_channels=len(raw_intensity.ch_names),
)

# Raw intensity -> optical density, followed by the same overview plot
raw_od = mne.preprocessing.nirs.optical_density(raw_intensity)
raw_od.plot(
    show_scrollbars=False,
    duration=500,
    n_channels=len(raw_od.ch_names),
)

## Step 2: Signal Quality Assessment (SCI)

Identify and remove channels with poor signal quality using the Scalp Coupling Index (SCI). Channels with SCI < 0.5 are dropped.

In [None]:
# Scalp Coupling Index per channel, before any rejection, as a histogram
sci = mne.preprocessing.nirs.scalp_coupling_index(raw_od)
fig, ax = plt.subplots(layout="constrained")
ax.hist(sci)
ax.set_xlabel("Scalp Coupling Index")
ax.set_ylabel("Count")
ax.set_xlim(0, 1)
plt.show()

In [None]:
# Filter out bad channels: every channel whose Scalp Coupling Index is below
# the threshold. The boolean mask is aligned with od.ch_names (od is a copy of
# raw_od, the signal the SCI was computed on).
SCI_THRESHOLD = 0.5
od = raw_od.copy()
bad_channels = list(compress(od.ch_names, sci < SCI_THRESHOLD))
# Mark them as bad (replacing any existing bads list) and drop them. A copy is
# stored in info so that later updates to info["bads"] cannot alter the
# bad_channels list printed below.
od.info["bads"] = list(bad_channels)
od.drop_channels(bad_channels)
print(f"Dropped bad channels based on SCI < {SCI_THRESHOLD}: {bad_channels}")

In [None]:
# Recompute SCI on the retained channels to check the rejection step
sci_clean = mne.preprocessing.nirs.scalp_coupling_index(od)
fig, ax = plt.subplots(layout="constrained")
ax.hist(sci_clean, bins=20)
ax.set_xlabel("Scalp Coupling Index (after removal of bad channels)")
ax.set_ylabel("Count")
ax.set_xlim(0, 1)
plt.show()

## Step 3: Signal Enhancement - Short Channel Regression

In [None]:
# Short-channel regression to remove systemic noise, then an overview plot
od = mne_nirs.signal_enhancement.short_channel_regression(od)
od.plot(
    show_scrollbars=False,
    duration=500,
    n_channels=len(od.ch_names),
)

## Step 4: Motion Artifact Correction - TDDR

In [None]:
# Motion-artifact removal with TDDR (Temporal Derivative Distribution Repair)
od = mne.preprocessing.nirs.temporal_derivative_distribution_repair(od)
od.plot(
    show_scrollbars=False,
    duration=500,
    n_channels=len(od.ch_names),
)

## Step 5: Motion Artifact Correction - Accelerometer (if available)

TODO: Implement accelerometer-based motion correction for datasets with `use_accelerometer: true`

In [None]:
# Placeholder for Step 5: accelerometer-based correction is not implemented,
# so this cell only reports whether the dataset has accelerometer data.
if has_accelerometer:
    print("⚠️  Accelerometer data detected but correction not yet implemented")
    print("TODO: Add accelerometer-based motion artifact correction")
    # Future implementation here
else:
    print("No accelerometer data - skipping accelerometer-based correction")

## Step 6: Convert to Haemoglobin Concentration

In [None]:
# Optical density -> haemoglobin concentration (Beer-Lambert law), then keep
# only the long channels. NOTE(review): ppf=0.1 presumably mirrors the MNE
# fNIRS tutorial this notebook is adapted from - confirm it is intended.
raw_haemo = mne_nirs.channels.get_long_channels(
    mne.preprocessing.nirs.beer_lambert_law(od, ppf=0.1)
)

# Overview of the haemoglobin time series
raw_haemo.plot()

## Step 7: Signal Filtering - Band-Pass 0.05–0.7 Hz (Remove Heart Rate and Slow Drift)

In [None]:
# Band-pass 0.05-0.7 Hz on a copy, so raw_haemo stays unfiltered for the
# before/after comparison below
haemo = raw_haemo.copy()
haemo.filter(0.05, 0.7, h_trans_bandwidth=0.2, l_trans_bandwidth=0.02)

# Average power spectral density before vs. after filtering
for stage_label, stage_data in (("Before", raw_haemo), ("After", haemo)):
    psd_fig = stage_data.compute_psd().plot(
        average=True, amplitude=False, picks="data", exclude="bads"
    )
    psd_fig.suptitle(f"{stage_label} filtering", weight="bold", size="x-large")

## Step 8: Trigger-Based Segmentation

In [None]:
# Replace raw trigger codes with readable names (modifies haemo's annotations)
haemo.annotations.rename(segment_by_triggers)

# Drop every annotation that is not one of the renamed trigger events
wanted_annotations = list(segment_by_triggers.values())
is_wanted = np.isin(haemo.annotations.description, wanted_annotations)
haemo.annotations.delete(np.flatnonzero(~is_wanted))

In [None]:
events, event_dict = mne.events_from_annotations(haemo)

# Remember every event's original name, then relabel the events 1..N so each
# one (and the annotation later built from it) is unique. The assignment into
# events[:, 2] modifies the events array in place.
code_to_name = {code: name for name, code in event_dict.items()}
original_event_ids = [code_to_name[code] for code in events[:, 2]]
events[:, 2] = np.arange(1, len(events) + 1)

# Shift every event 6 s later to account for the haemodynamic response delay
shift_seconds_for_hrf_delay = 6.0
events_hrf_shifted = mne.event.shift_time_events(
    events, ids=None, tshift=shift_seconds_for_hrf_delay, sfreq=haemo.info["sfreq"]
)

# Turn the shifted events into annotations (descriptions are "1".."N")
annot_from_events = mne.annotations_from_events(
    events=events_hrf_shifted,
    sfreq=haemo.info["sfreq"],
    orig_time=haemo.info["meas_date"],
    first_samp=haemo.first_samp,
)

# Each annotation lasts until the next one starts. The last annotation is left
# out of the mapping and keeps the duration annotations_from_events gave it.
onsets = annot_from_events.onset
mapping = {
    description: next_onset - onset
    for description, onset, next_onset in zip(
        annot_from_events.description, onsets, onsets[1:]
    )
}
annot_from_events.set_durations(mapping, verbose=True)

# Put the original trigger names back on the annotations
annot_from_events.rename(
    {str(number): name for number, name in enumerate(original_event_ids, start=1)}
)

# One cropped copy of haemo per annotation, in annotation order
haemo_segments = haemo.crop_by_annotations(annot_from_events)

## Step 9: Export to CSV

Save full and segmented data using BIDS-like naming conventions.

In [None]:
# Video duration to name mapping.
# NOTE: The original exact matching (508 s -> 'Zima', 145 s -> 'Splitscreen')
# caused 15 subjects to miss video labels because of trigger timing variance.
# Affected subjects: sub-40~47, 50, 53, 55~57, 109, 113 (durations such as
# 507, 504, 144 instead of 508, 145).
# Additionally, 5 subjects have no source data: sub-1, 49, 63 (recorded as
# P63), 74, 77.
# Matching is therefore tolerance-based.
def get_video_name(duration, tolerance=5, reference_durations=None):
    """Match a segment duration to a video name, allowing for trigger timing variance.

    Parameters
    ----------
    duration : int or float
        Segment duration in seconds.
    tolerance : int or float, optional
        Maximum absolute difference in seconds (inclusive) between
        ``duration`` and a reference duration for a match. Default 5.
    reference_durations : dict or None, optional
        Mapping of video name -> expected duration in seconds. Entries are
        checked in insertion order and the first match wins. Defaults to
        ``{'Zima': 508, 'Splitscreen': 145}``.

    Returns
    -------
    str or None
        The matching video name, or ``None`` if no reference is within
        ``tolerance``.
    """
    if reference_durations is None:
        reference_durations = {'Zima': 508, 'Splitscreen': 145}
    for name, expected in reference_durations.items():
        if abs(duration - expected) <= tolerance:
            return name
    return None

# Export the full preprocessed recording and its HRF-shifted annotations
haemo.to_data_frame().to_csv(output_dir / f"sub-{subj}_ses-{session_id}_desc-preproc_haemo.csv", index=False)
annot_from_events.to_data_frame().to_csv(output_dir / f"sub-{subj}_ses-{session_id}_annotations.csv", index=False)

# Export every segment whose duration, truncated to whole seconds, is positive.
# segment_index counts only exported segments. haemo_segments holds one entry
# per annotation, in the same order as annot_from_events.
segment_index = 0
for segment, raw_duration, description in zip(
    haemo_segments, annot_from_events.duration, annot_from_events.description
):
    duration = int(raw_duration)
    if duration <= 0:
        continue
    segment_index += 1
    task_label = description.split("/")[0]
    # Convert once: the same frame may be written under two filenames
    segment_df = segment.to_data_frame()

    # Check if this duration corresponds to a named video (with tolerance)
    video_name = get_video_name(duration)
    if video_name:
        print(f"Segment {segment_index}: duration={duration}s, video={video_name}")
        # Save with video name
        segment_df.to_csv(
            output_dir / f"sub-{subj}_ses-{session_id}_task-{task_label}_label-{video_name}_haemo.csv",
            index=False
        )

    # Always save with acq number and duration
    segment_df.to_csv(
        output_dir / f"sub-{subj}_ses-{session_id}_task-{task_label}_acq-{segment_index}_dur-{duration}_desc-preproc_haemo.csv",
        index=False
    )

print(f"\n✅ Preprocessing complete for subject {subj}, session {session_id}")
print(f"Output directory: {output_dir}")
print(f"Exported {segment_index} segments")