# Multiple Subject Processing

In [None]:
# Reload edited project modules (config, s00_*..s10_*, visualization, ...) before each
# cell runs, so changes to the pipeline helpers take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import gc
import os
import sys

import matplotlib.pyplot as plt
import mne
from mne_bids import (BIDSPath,read_raw_bids)

# Make the project-local pipeline modules below importable from the notebook's directory.
sys.path.insert(0,'.')

import ccs_eeg_utils
import config
from tools import get_valid_input
from visualization import show_single_psd, psd_compare, iclabel_visualize, plot_erp
from s00_add_reference import add_reference_channel
from s01_downsample_filter import down_sampling, band_filter, notch_filter, zapline_filter
from s02_drop_bad_channels import drop_bad_channels, reref
from s03_07_trial_rejection import trial_rejection_cust, trial_rejection_mne
from s04_ICA import get_ica, get_iclabel, iccomponent_removal, iccomponent_removal_author, iccomponent_removal_new
from s05_interpolation import interpolation
from s06_early_trial_removal import exclude_early_trials
from s07_epoching import epoching, epoching_cust
from s08_find_bad_channels import find_bad_channels
from s09_make_erps import get_evoked, get_evoked_difference
from s10_rewp_calculation import rewp_calculation

In [None]:
# Subjects to process: every subject listed in the project configuration.
SUBJECTS = list(config.SUBJECT_INFO.keys())

# Choose whose local BIDS root to use. The prompt is built from the users actually
# configured in config.BIDS_ROOT so it always matches the accepted options.
_user_options = list(config.BIDS_ROOT.keys())
USER = get_valid_input(
    f"Select the user (options: {'/'.join(map(str, _user_options))}): ",
    _user_options
)

# Authors' pipeline

In [None]:
ACTIVE_PIPELINE = 'authors'
group_evokeds = {}

# ------ LOOP-INVARIANT SETUP ------
# Nothing below depends on the subject, so it is read once instead of per iteration.
cfg = config.PIPELINES[ACTIVE_PIPELINE]
# NOTE(review): the rejection thresholds, ICA method and bad-channel criteria come from
# config.PIPELINES['original'], while early-trial deletion and the trimmed-mean
# proportion come from ACTIVE_PIPELINE ('authors') — confirm this mix is intended.
original_cfg = config.PIPELINES['original']
ica_rejection_params = original_cfg['rejection_params']['ica']
erp_rejection_params = original_cfg['rejection_params']['erp']
bad_channel_criteria = original_cfg['bad_channels_rejection_criteria']
stim_event_names = config.CONDITIONS_DICT['onset_locked']     # events used to cut ICA trials
epoch_dict = config.CONDITIONS_DICT['feedback_locked']        # conditions that are epoched / averaged
root = config.BIDS_ROOT[USER]

### montage setup
montage_site2_path = os.path.join(root, 'code', config.LOCS_FILENAME['site2'])
montage_site2 = mne.channels.read_custom_montage(montage_site2_path)
# NOTE(review): montage_common is loaded but not used in this cell — confirm it is still needed.
montage_common_path = os.path.join(root, 'code', config.LOCS_FILENAME['common'])
montage_common = mne.channels.read_custom_montage(montage_common_path)

for subject_id in SUBJECTS:
    print(f"Processing subject {subject_id}...")

    # ------ LOAD DATA ------
    # Path for the current subject only; BIDSPath entities are strings.
    bids_path = BIDSPath(subject=str(subject_id), task='casinos',
                         datatype='eeg', suffix='eeg',
                         root=root)
    # read the file
    raw = read_raw_bids(bids_path)
    # fix the annotations readin
    ccs_eeg_utils.read_annotations_core(bids_path, raw)
    raw.load_data()

    # ------- PREPROCESSING PIPELINE -------
    # Add reference channel Fz
    raw = add_reference_channel(raw, 'Fz')

    # Set custom montage
    raw.set_montage(montage_site2, match_case=False)

    # Downsample to 250 Hz
    eeg_down = down_sampling(raw)

    # Bandpass filter 0.1-30 Hz
    eeg_band = band_filter(eeg_down)

    # Notch filter at 50 Hz and harmonics
    eeg_band_notch = notch_filter(eeg_band)

    # Drop bad channels and reference channel; this copy is only used to fit ICA.
    # NOTE(review): if drop_bad_channels/reref modify their input in place,
    # eeg_band_notch below is the same object — confirm against s02_drop_bad_channels.
    eeg_ica = drop_bad_channels(subject_id, eeg_band_notch)
    eeg_ica = reref(eeg_ica)

    ## get the dictionary for the events, keeping only the stimulus-onset codes
    evts, evts_dict = mne.events_from_annotations(eeg_ica)
    evts_dict_stim = {k: v for k, v in evts_dict.items() if k in stim_event_names}

    # Trial rejection using customized function (the surviving trials only feed ICA)
    trials, ica_rejection_info = trial_rejection_cust(eeg_ica, evts, evts_dict_stim, **ica_rejection_params)

    # ICA and ICLabel: fit on the cleaned trials, remove components from the continuous data
    ica = get_ica(trials, original_cfg['ica_method'])
    ic_labels = get_iclabel(trials, ica, method='iclabel')
    eeg_band_notch = iccomponent_removal_author(eeg_band_notch, ica)

    # Interpolation of bad channels
    eeg_band_notch = interpolation(eeg_band_notch)

    # exclude early trials
    eeg_final = exclude_early_trials(eeg_band_notch, cfg['early_trial_deletion'])

    # Epoching
    epochs_all, rejection_info = epoching_cust(epoch_dict, eeg_final, **erp_rejection_params)
    find_bad_channels(epochs_all, reject_criteria=bad_channel_criteria, custom=True, rejection_info=rejection_info)

    # Get evoked responses
    all_evokeds = get_evoked(epoch_dict, epochs_all, proportiontocut=cfg['evoked_proportiontocut'])

    group_evokeds[subject_id] = all_evokeds

    # Release this subject's continuous/epoched data before loading the next one.
    del raw, eeg_down, eeg_band, eeg_band_notch, eeg_ica, trials, eeg_final, epochs_all
    gc.collect()

## (TBC) Our pipeline

In [None]:
# Our pipeline (work in progress). Results are stored in `group_evokeds_ours` so that
# running this cell does not overwrite the authors' `group_evokeds`, which the
# grand-average and plotting cells below are labelled as using.
group_evokeds_ours = {}

bids_root = config.BIDS_ROOT[USER]
# Subjects processed by this pipeline (learners alone).
ours_subjects = ['27', '28', '29', '31', '34', '35', '36', '37', '38']
# Stimulus-onset event codes used to cut the trials that ICA is fitted on.
ours_stim_events = ['Stimulus:S  1', 'Stimulus:S 11', 'Stimulus:S 21', 'Stimulus:S 31']
# Same feedback-locked condition mapping the authors' pipeline epochs on.
epoch_dict = config.CONDITIONS_DICT['feedback_locked']

# Site-2 montage; depends only on bids_root, so it is read once.
locs_filename = 'site2channellocations.locs'
locs_path = os.path.join(bids_root, 'code', locs_filename)
custom_montage = mne.channels.read_custom_montage(locs_path)

for subject_id in ours_subjects:
    print(f"Processing subject {subject_id}...")
    bids_path = BIDSPath(subject=subject_id, task='casinos',
                         datatype='eeg', suffix='eeg',
                         root=bids_root)
    # read the file
    raw = read_raw_bids(bids_path)
    # fix the annotations readin
    ccs_eeg_utils.read_annotations_core(bids_path, raw)

    raw.load_data()

    # Add reference channel Fz
    raw = add_reference_channel(raw, 'Fz')

    # Set custom montage
    raw.set_montage(custom_montage, match_case=False)

    # Downsample to 250 Hz
    eeg_down = down_sampling(raw)

    # Bandpass filter 0.1-30 Hz
    eeg_band = band_filter(eeg_down)

    # Notch filter at 50 Hz and harmonics
    eeg_band_notch = notch_filter(eeg_band)

    # Drop bad channels and reference channel (copy used only to fit ICA)
    eeg_ica = drop_bad_channels(subject_id, eeg_band_notch)
    eeg_ica = reref(eeg_ica)

    ## get the dictionary for the events, keeping only the stimulus-onset codes
    evts, evts_dict = mne.events_from_annotations(eeg_ica)
    evts_dict_stim = {k: v for k, v in evts_dict.items() if k in ours_stim_events}

    # Trial rejection using MNE built-in function
    # NOTE(review): thresholds presumably in volts (500 µV / 0.1 µV) — confirm in s03_07_trial_rejection.
    trials_mne = trial_rejection_mne(eeg_ica, evts, evts_dict_stim, max=500e-6, min=0.1e-6)

    # ICA and ICLabel
    ica = get_ica(trials_mne, method="picard")
    ic_labels = get_iclabel(trials_mne, ica, method='iclabel')
    eeg_band_notch = iccomponent_removal_new(eeg_band_notch, ica)

    # Interpolation of bad channels
    eeg_band_notch = interpolation(eeg_band_notch)

    # exclude early trials
    eeg_final = exclude_early_trials(eeg_band_notch, num_to_exclude=10)

    # Epoching
    epochs_all = epoching(epoch_dict, eeg_final, max=150e-6, min=0.1e-6, tmin=-0.2, tmax=0.6, baseline=(-0.2, 0))

    # Get evoked responses
    all_evokeds = get_evoked(epoch_dict, epochs_all, proportiontocut=0.05)

    group_evokeds_ours[subject_id] = all_evokeds

    # Release this subject's continuous/epoched data before loading the next one.
    del raw, eeg_down, eeg_band, eeg_band_notch, eeg_ica, trials_mne, eeg_final, epochs_all
    gc.collect()


# Grand average across subjects

In [None]:
# Grand average across subjects: for every feedback-locked condition, average that
# condition's Evoked over all subjects stored in group_evokeds.
if not group_evokeds:
    raise RuntimeError("group_evokeds is empty - run a processing pipeline cell first.")

grand_averages = {
    condition: mne.grand_average(
        [subject_evokeds[condition] for subject_evokeds in group_evokeds.values()]
    )
    for condition in config.CONDITIONS_DICT['feedback_locked']
}
# # Save grand averages
# output_dir = os.path.join(config.BIDS_ROOT[USER], 'derivatives', 'grand_averages')
# os.makedirs(output_dir, exist_ok=True)
# for condition, evoked in grand_averages.items():
#     evoked.save(os.path.join(output_dir, f'grand_average_{condition.replace(" ", "_")}-ave.fif'))

# Plotting

In [None]:
# Electrode and time window (in seconds, 240-340 ms) handed to plot_erp as mean_window.
channel, mean_window = 'FCz', [240e-3, 340e-3]

In [None]:
# Plot styling: colour encodes the outcome (Win red / Loss blue), line style encodes
# the probability case. Keys are built in the order Low-Low Win, Low-Low Loss, ...
_outcome_colors = {'Win': 'red', 'Loss': 'blue'}
_case_linestyles = {'Low-Low': '-', 'Mid-Low': '--', 'Mid-High': '-.', 'High-High': ':'}

colors = {
    f'{case} {outcome}': color
    for case in _case_linestyles
    for outcome, color in _outcome_colors.items()
}
linestyles = {
    f'{case} {outcome}': style
    for case, style in _case_linestyles.items()
    for outcome in _outcome_colors
}
plot_erp(grand_averages, channel, mean_window, colors=colors, linestyles=linestyles, title="ERP Results")

In [None]:
# RewP difference waves: for each probability case, grand-average Win minus Loss.
cases = [
    (name, f'{name} Win', f'{name} Loss')
    for name in ('Low-Low', 'Mid-Low', 'Mid-High', 'High-High')
]

diff_evokeds = {}
for case_name, win_cond, loss_cond in cases:
    diff = mne.combine_evoked(
        [grand_averages[win_cond], grand_averages[loss_cond]],
        weights=[1, -1],  # Win - Loss
    )
    diff.comment = case_name  # name shown when plotting
    diff_evokeds[case_name] = diff
    print(f"Calculated difference for: {case_name}")

colors_diff = dict(
    [
        ('Low-Low', '#4C72B0'),    # Muted Blue
        ('Mid-Low', '#64B5CD'),    # Soft Cyan
        ('Mid-High', '#E1BC66'),   # Sand/Gold
        ('High-High', '#C44E52'),  # Muted Crimson
    ]
)

plot_erp(
    diff_evokeds,
    channel,
    mean_window,
    colors=colors_diff,
    title="RewP Difference Waves (Win-Loss) - averaged across subjects (authors pipeline)",
)



In [None]:
# Per-subject scalp maps of a Win-Loss difference at one latency, to compare the
# topography across individuals. Each group_evokeds value is a {condition: Evoked}
# dict (not an Evoked), so the difference wave is built per subject first.
time_point = 0.300  # latency in seconds (300 ms)
topo_win, topo_loss = 'Low-Low Win', 'Low-Low Loss'  # condition pair to contrast; change as needed

n_subs = len(group_evokeds)
if n_subs == 0:
    raise RuntimeError("group_evokeds is empty - run a processing pipeline cell first.")

# squeeze=False keeps `axes` 2-D even when only one subject is plotted.
fig, axes = plt.subplots(1, n_subs, figsize=(n_subs * 2, 3), squeeze=False)

for ax, (subject_id, subject_evokeds) in zip(axes[0], group_evokeds.items()):
    subject_diff = mne.combine_evoked(
        [subject_evokeds[topo_win], subject_evokeds[topo_loss]],
        weights=[1, -1],
    )
    subject_diff.plot_topomap(times=time_point, axes=ax, colorbar=False, show=False)
    ax.set_title(f'Sub {subject_id}')

plt.tight_layout()
plt.show()

In [None]:
# Scalp maps of every grand-average difference wave at six latencies (seconds).
times = [0.18, 0.22, 0.26, 0.30, 0.34, 0.38]  # kept short to save space
vlimit = (-5, 5)  # one shared colour range for all maps

for condition, evoked in diff_evokeds.items():
    print(f"Plotting Topomap for: {condition}")
    evoked.plot_topomap(
        times=times,
        ch_type='eeg',
        colorbar=True,
        vlim=vlimit,
    )