In [None]:
# Credits: Andrei Miroshnikov andrej.miroshnikow@gmail.com
# Useful links:
# https://mne.tools/stable/auto_tutorials/index.html
# https://mne.tools/mne-nirs/stable/auto_examples/index.html

############################################################

# Installs for EEG and fNIRS processing. Uncomment the lines below if the packages need to be installed from within Jupyter.

# !pip install mne
# !pip install mne_nirs
# !pip install numpy
# !pip install matplotlib

In [None]:
import os.path as op
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from itertools import compress

import mne
import mne_nirs
from mne.preprocessing.nirs import optical_density, beer_lambert_law

from mne.preprocessing.nirs import (optical_density,
                                    temporal_derivative_distribution_repair)
from mne_nirs.signal_enhancement import enhance_negative_correlation, short_channel_regression
from mne import Epochs, events_from_annotations
%matplotlib auto

In [None]:
%matplotlib auto

### Read our fNIRS files

In [None]:
# Directory of the NIRx recording to analyse. The original absolute path stays
# the default; set the FNIRS_DATA_DIR environment variable to point the notebook
# at another recording without editing this cell.
fnirs_dirname = os.environ.get(
    'FNIRS_DATA_DIR',
    '/mnt/diskus/fNIRS data ME_MI_TS_TI_SA/AB/AB_ME',
)
raw_intensity = mne.io.read_raw_nirx(fnirs_dirname)

# Both wavelength channels of the S2-D4 pair are removed before any processing.
raw_intensity.drop_channels(['S2_D4 760', 'S2_D4 850'])

# Measurement info of the recording; the (commented-out) sensor plot uses it.
raw_info = raw_intensity.info

#### We can look at our data and sensors' positions

In [None]:
# raw_intensity.plot()

In [None]:
# sensors_pos = mne.viz.plot_sensors(raw_info, kind='select')

In [None]:
# Source-detector pairs the notebook keeps, written as
# {source number: [detector numbers]}. Insertion order (sources top to bottom,
# detectors left to right) is the channel order of both lists built below.
_PICKABLE_PAIRS = {
    5: [6, 12],
    6: [3, 6, 7, 10, 13],
    7: [4, 8, 9, 11, 14],
    8: [9, 15],
    9: [6, 10, 12, 13, 18],
    10: [7, 13, 16, 19],
    11: [8, 14, 17, 20],
    12: [9, 11, 14, 15, 21],
    13: [12, 18, 34],
    14: [15, 21, 35],
    15: [12, 18, 22],
    16: [13, 16, 18, 19, 23],
    17: [14, 17, 20, 21, 24],
    18: [15, 21, 25],
    19: [18, 22],
    20: [18, 23, 36],
    21: [21, 24, 37],
    22: [21, 25],
    23: [18, 22, 23],
    24: [19],
    25: [20],
    26: [21, 24, 25],
}
_pickable_pair_names = [
    f'S{source}_D{detector}'
    for source, detectors in _PICKABLE_PAIRS.items()
    for detector in detectors
]

# Intensity / optical-density channel names: every pair followed by its
# '760' and '850' suffix, e.g. 'S5_D6 760', 'S5_D6 850'.
channels_we_can_pick = [
    f'{pair} {suffix}'
    for pair in _pickable_pair_names
    for suffix in ('760', '850')
]

# Haemoglobin channel names for the same pairs in the same order,
# e.g. 'S5_D6 hbo', 'S5_D6 hbr'.
hb_channels_we_can_pick = [
    f'{pair} {suffix}'
    for pair in _pickable_pair_names
    for suffix in ('hbo', 'hbr')
]



### Go from raw intensity to optical density and enhance our signal

In [None]:
# Convert the raw intensity recording to optical density; every later
# preprocessing step in this notebook starts from raw_od.
raw_od = mne.preprocessing.nirs.optical_density(raw_intensity)

In [None]:
# Short channels only, extracted from a copy so raw_od itself keeps all channels.
raw_od_shorts = mne_nirs.channels.get_short_channels(raw_od.copy())

### Check scalp coupling index in short and long channels

In [None]:
# Scalp coupling index (SCI) of the short channels.
raw_to_sci = raw_od_shorts
sci = mne.preprocessing.nirs.scalp_coupling_index(raw_to_sci)

# Histogram of the SCI values over the 0..1 range.
fig, ax = plt.subplots(layout="constrained")
ax.hist(sci)
ax.set(xlabel="Scalp Coupling Index", ylabel="Count", xlim=[0, 1])

# Names of the short channels whose SCI is below 0.5.
bad_shorts_sci = [
    ch_name
    for ch_name, below_threshold in zip(raw_to_sci.ch_names, sci < 0.5)
    if below_threshold
]

In [None]:
# Remove the short channels that fell below the SCI threshold.
raw_od.drop_channels(ch_names=bad_shorts_sci)

In [None]:
# Same SCI check, now on every channel still present in raw_od.
raw_to_sci = raw_od
sci = mne.preprocessing.nirs.scalp_coupling_index(raw_to_sci)

fig, ax = plt.subplots(layout="constrained")
ax.hist(sci)
ax.set(xlabel="Scalp Coupling Index", ylabel="Count", xlim=[0, 1])

# Channels below 0.5 are marked as bad (not dropped); this replaces whatever
# was stored in info['bads'] before.
bad_long_sci = [
    ch_name
    for ch_name, below_threshold in zip(raw_to_sci.ch_names, sci < 0.5)
    if below_threshold
]
raw_od.info['bads'] = bad_long_sci


### Perform TDDR and short channels regression

In [None]:
# Order matters: TDDR (movement-artifact repair) first, then short-channel
# regression while the short channels are still part of raw_od.
raw_od = short_channel_regression(
    temporal_derivative_distribution_repair(raw_od)
)

# Keep only the listed channels, then fill the channels marked bad from their
# nearest neighbours; reset_bads=False leaves them listed in info['bads'].
raw_od.pick_channels(channels_we_can_pick).interpolate_bads(
    reset_bads=False,
    method={'fnirs': 'nearest'},
)

### Let's convert our optical density to haemoglobin concentration. We will also perform negative correlation enhancement

In [None]:
# Optical density -> HbO/HbR concentration (Beer-Lambert law, ppf=0.1),
# restricted to the long channels, followed by negative-correlation enhancement.
raw_haemo = enhance_negative_correlation(
    mne_nirs.channels.get_long_channels(
        beer_lambert_law(raw_od, ppf=0.1)
    )
)

### Filtering -- one of the trickier parts of fNIRS data processing

In [None]:
# Band-pass settings: pass band edges and the width of each transition band.
l_freq, h_freq = 0.01, 0.1
l_trans_bandwidth, h_trans_bandwidth = 0.01, 0.1

# Power spectrum before filtering.
fig = raw_haemo.plot_psd(average=True)
fig.suptitle('Before filtering', weight='bold', size='x-large')
fig.subplots_adjust(top=0.88)

raw_haemo = raw_haemo.filter(
    l_freq,
    h_freq,
    l_trans_bandwidth=l_trans_bandwidth,
    h_trans_bandwidth=h_trans_bandwidth,
)

# Power spectrum after filtering, drawn the same way for comparison.
fig = raw_haemo.plot_psd(average=True)
fig.suptitle('After filtering', weight='bold', size='x-large')
fig.subplots_adjust(top=0.88)


### Let's also resample to make it more convenient

In [None]:
# Sampling frequency after resampling; epochs_rejector uses it as its default
# `sfreq` when converting seconds to sample indices.
SFREQ = 1
raw_haemo.resample(sfreq=SFREQ)

### This cell selects a subgroup of channels. In our case, we wanted the channels around the C3 and C4 leads of the 10-20 system (the corresponding locations are listed below). We end up with a pair of groups of ~20 long channels, one for each side's sensorimotor zone.

In [None]:
# Source-detector pairs selected around C3.
_C3_PAIRS = (
    'S9_D13', 'S9_D18',
    'S10_D13', 'S10_D16', 'S10_D19',
    'S13_D18',
    'S16_D13', 'S16_D16', 'S16_D18', 'S16_D19', 'S16_D23',
    'S24_D19',
)
# Source-detector pairs selected around C4.
_C4_PAIRS = (
    'S11_D14', 'S11_D17', 'S11_D20',
    'S12_D14', 'S12_D21',
    'S17_D14', 'S17_D17', 'S17_D20', 'S17_D21', 'S17_D24',
    'S18_D15',
    'S25_D20',
)

# HbO and HbR channel names of each group, in matching order.
C3_chans_of_interest_hbo = [f'{pair} hbo' for pair in _C3_PAIRS]
C3_chans_of_interest_hbr = [f'{pair} hbr' for pair in _C3_PAIRS]

C4_chans_of_interest_hbo = [f'{pair} hbo' for pair in _C4_PAIRS]
C4_chans_of_interest_hbr = [f'{pair} hbr' for pair in _C4_PAIRS]

###############################################

### Now we make a transition from continuous data to segmented data -- we will make epochs

In [None]:
events, ids = mne.events_from_annotations(raw_haemo)

# Readable condition names for the two trigger codes.
# NOTE(review): the codes 1 and 2 are hard-coded; this assumes
# events_from_annotations numbered the "1"/"1.0" and "2"/"2.0" annotations
# as 1 and 2 -- confirm if the recording carries other annotations.
ids["Rest"] = 2
ids["Sensorimotor"] = 1

# Remove the original annotation keys so each event code appears under one name
# only. A trigger may be spelled "2.0" or "2"; each trigger is handled on its
# own, so a mix of spellings works, and a trigger present under neither
# spelling still raises KeyError instead of being silently ignored.
for trigger in ("2", "1"):
    for key in (f"{trigger}.0", trigger):
        if key in ids:
            del ids[key]
            break
    else:
        raise KeyError(
            f"no annotation named '{trigger}.0' or '{trigger}' found; "
            f"available keys: {sorted(ids)}"
        )

### We can set different epoch parameters such as length, baseline, and others

In [None]:
# Epoch window around each event and the baseline interval passed to mne.Epochs.
tmin, tmax = 0, 14
baseline = (0, 0)

# Optional settings, currently disabled:
#   reject_criteria = dict(hbo=35e-6)  ->  reject=reject_criteria
#   reject_by_annotation=True
epochs = mne.Epochs(
    raw_haemo,
    events,
    event_id=ids,
    tmin=tmin,
    tmax=tmax,
    baseline=baseline,
    detrend=None,
    preload=True,
    verbose=True,
)
# epochs.plot_drop_log()

### Optional step but can be of use: a function which will remove extreme epochs in 'remove the outliers' fashion 

In [None]:
def epochs_rejector(epochs, criterion='median',
                    sfreq=SFREQ,
                    time_limits=(4, 12),
                    lower=0.10, upper=0.90):
    """Flag epochs whose typical amplitude is an outlier among the given epochs.

    Each epoch is summarised by one number: the median over the channels in
    ``C3_chans_of_interest_hbo`` followed by the median over the samples inside
    ``time_limits``. Epochs whose summary lies strictly below the ``lower``
    quantile or strictly above the ``upper`` quantile of all summaries are
    flagged.

    Parameters
    ----------
    epochs : object with ``get_data(picks=...)``
        Epochs to inspect (an ``mne.Epochs`` in this notebook). It is not
        modified.
    criterion : str
        Summary statistic; only ``'median'`` is implemented.
    sfreq : number
        Samples per second, used to turn ``time_limits`` into sample indices.
    time_limits : tuple of (start, stop)
        Window in seconds, counted from the first sample of the epoch; ``stop``
        is exclusive. The conversion does not account for the epoch's own
        start time, so it matches epoch time only when epochs start at 0.
    lower, upper : float
        Quantiles (0..1) below / above which an epoch is rejected.

    Returns
    -------
    numpy.ndarray of bool, shape (n_epochs,)
        True for every epoch that should be rejected.

    Raises
    ------
    ValueError
        If ``criterion`` is not ``'median'``.
    """
    if criterion != 'median':
        raise ValueError(
            f"unsupported criterion {criterion!r}; only 'median' is implemented"
        )

    # Slice bounds must be integers even when sfreq is a float such as 1.0.
    start, stop = (int(round(limit * sfreq)) for limit in time_limits)

    # Restrict to the C3 HbO channels through `picks`: picking on a throw-away
    # copy and then reading from `epochs` would use every channel instead.
    window = epochs.get_data(picks=C3_chans_of_interest_hbo)[:, :, start:stop]

    # Nothing to rank when there are no epochs.
    if window.shape[0] == 0:
        return np.zeros(0, dtype=bool)

    # Median across channels, then across time: one value per epoch.
    per_epoch = np.median(np.median(window, axis=1), axis=1)
    lower_quantile = np.quantile(per_epoch, lower)
    upper_quantile = np.quantile(per_epoch, upper)

    return np.logical_or(per_epoch < lower_quantile,
                         per_epoch > upper_quantile)


### We can set different quantile limits for different states/epochs

In [None]:
# Quantile bounds for outlier rejection, set separately for each condition,
# and the window (in seconds) inside which epochs are compared.
SMR_LOWER_QUANTILE, SMR_UPPER_QUANTILE = 0.1, 0.9
REST_LOWER_QUANTILE, REST_UPPER_QUANTILE = 0.1, 0.9
time_limits = (4, 13)

# Split the epochs by condition.
rest_epochs = epochs["Rest"]
smr_epochs = epochs["Sensorimotor"]

# Boolean masks: True marks an epoch to drop.
smr_reject_bool = epochs_rejector(
    smr_epochs,
    time_limits=time_limits,
    lower=SMR_LOWER_QUANTILE,
    upper=SMR_UPPER_QUANTILE,
)
rest_reject_bool = epochs_rejector(
    rest_epochs,
    time_limits=time_limits,
    lower=REST_LOWER_QUANTILE,
    upper=REST_UPPER_QUANTILE,
)

# Drop the flagged epochs from each condition.
smr_epochs = smr_epochs.drop(smr_reject_bool)
rest_epochs = rest_epochs.drop(rest_reject_bool)

# Data arrays and condition averages of the epochs that remain.
smr_epochs_data, rest_epochs_data = smr_epochs.get_data(), rest_epochs.get_data()
evoked_smr, evoked_rest = smr_epochs.average(), rest_epochs.average()

In [None]:
# Labels used in the figure titles and as the task-condition key prefix.
CONDITION = 'CONDITION'
SUBJECT = 'SKOLTECH'
# Used both to average epochs into evokeds and to combine channels when plotting.
averaging_method = 'mean'

picks_hbo_left, picks_hbr_left = C3_chans_of_interest_hbo, C3_chans_of_interest_hbr
picks_hbo_right, picks_hbr_right = C4_chans_of_interest_hbo, C4_chans_of_interest_hbr


def _hemisphere_evokeds(picks_hbo, picks_hbr):
    """Average the task and rest epochs over one hemisphere's HbO/HbR channels."""
    evokeds = {
        f'{CONDITION}/HbO': smr_epochs.average(picks=picks_hbo, method=averaging_method),
        f'{CONDITION}/HbR': smr_epochs.average(picks=picks_hbr, method=averaging_method),
        'Rest/HbO': rest_epochs.average(picks=picks_hbo, method=averaging_method),
        'Rest/HbR': rest_epochs.average(picks=picks_hbr, method=averaging_method),
    }
    # Rename channels until the encoding of frequency in ch_name is fixed:
    # the last four characters (' hbo' / ' hbr') are stripped.
    for evoked in evokeds.values():
        evoked.rename_channels(lambda ch_name: ch_name[:-4])
    return evokeds


evoked_dict_left = _hemisphere_evokeds(picks_hbo_left, picks_hbr_left)
evoked_dict_right = _hemisphere_evokeds(picks_hbo_right, picks_hbr_right)

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

color_dict = dict(HbO='#AA3377', HbR='b')
styles_dict = dict(Rest=dict(linestyle='dashed'))

# Shared y-limits for both panels, taken from the task HbO evokeds of both
# hemispheres. Only the two data extrema are compared; mixing any non-numeric
# argument into min()/max() fails on current NumPy.
# NOTE(review): the 10**6 factor presumably converts to the micromolar scale
# MNE uses on the y-axis -- confirm.
hbo_key = f'{CONDITION}/HbO'
y_min = min(evoked_dict_left[hbo_key].data.min(),
            evoked_dict_right[hbo_key].data.min()) * 10**6
y_max = max(evoked_dict_left[hbo_key].data.max(),
            evoked_dict_right[hbo_key].data.max()) * 10**6
ylim = {'hbo': [y_min, y_max],
        'hbr': [y_min, y_max]}

# Options shared by the left and right panels.
compare_kwargs = dict(combine=averaging_method,
                      ci=0.95,
                      colors=color_dict,
                      styles=styles_dict,
                      ylim=ylim,
                      truncate_xaxis=False)

left_plot = mne.viz.plot_compare_evokeds(
    evoked_dict_left,
    title=f'{CONDITION} and Rest trials LEFT hemisphere\nSubject {SUBJECT}',
    axes=axes[0],  # first subplot
    **compare_kwargs)
right_plot = mne.viz.plot_compare_evokeds(
    evoked_dict_right,
    title=f'{CONDITION} and Rest trials RIGHT hemisphere\nSubject {SUBJECT}',
    axes=axes[1],  # second subplot
    **compare_kwargs)

In [None]:
def topomaps_plotter(haemo_picks, smr_epochs, rest_epochs, CONDITION, SUBJECT, times=None):
    """Draw two rows of topomaps: task epochs on top, rest epochs below.

    Both conditions are averaged with the median, plotted at the same time
    points and on the same colour range, and one colorbar is added on the left.

    Parameters
    ----------
    haemo_picks : str
        Channel selection passed to ``average(picks=...)``. ``'hbo'`` labels
        the figure as HbO; any other value labels it as HbR.
    smr_epochs, rest_epochs
        Epochs of the task and rest conditions.
    CONDITION : str
        Task name shown in the title of the top row.
    SUBJECT : str
        Accepted for a uniform call signature; not used in the figure.
    times : sequence of float, optional
        Time points to plot, one column each. Defaults to
        ``np.arange(2, 14, 2)``.

    Returns
    -------
    None
    """
    if times is None:
        times = np.arange(2, 14, 2)
    topo_haemo = 'HbO' if haemo_picks == 'hbo' else 'HbR'

    smr_evoked = smr_epochs.average(picks=haemo_picks, method='median')
    rest_evoked = rest_epochs.average(picks=haemo_picks, method='median')

    # One colour range for both rows, scaled by 10**6 like the plotted values.
    vmin = min(smr_evoked.data.min(), rest_evoked.data.min()) * 10**6
    vmax = max(smr_evoked.data.max(), rest_evoked.data.max()) * 10**6
    vlim = (vmin, vmax)
    cmap = 'RdBu_r'
    sm = plt.cm.ScalarMappable(cmap=cmap,
                               norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax))

    # The maps receive the same limits and colormap as the colorbar's mappable,
    # so the colorbar describes what is actually drawn.
    topomap_args = dict(extrapolate='local',
                        vlim=vlim,
                        cmap=cmap,
                        colorbar=False,
                        show=False)

    # One figure holds both rows of topomaps.
    fig, axes = plt.subplots(2, len(times), figsize=(14, 7))
    smr_evoked.plot_topomap(times, axes=axes[0, :], **topomap_args)
    rest_evoked.plot_topomap(times, axes=axes[1, :], **topomap_args)

    # Dedicated axes for the shared colorbar, left of the maps.
    cbaxes = fig.add_axes([0.095, 0.25, 0.02, 0.5])
    cbar = plt.colorbar(mappable=sm, cax=cbaxes, orientation='vertical')
    # The backslash is escaped so the label text is exactly the one shown before.
    cbar.set_label(f'{topo_haemo} concentration, Δ μM\\L', loc='center', size=12)

    fig.subplots_adjust(top=0.910,
                        bottom=0.06,
                        left=0.150,
                        right=0.950,
                        hspace=0.195,
                        wspace=0.0)

    # Row titles in figure coordinates: one above each row of maps.
    x_top, y_top = 0.55, 0.95
    x_bottom, y_bottom = 0.55, 0.5

    fig.text(x=x_top, y=y_top,
             s=f'{CONDITION} {topo_haemo} changes timeline',
             fontsize='x-large',
             horizontalalignment='center',
             verticalalignment='center')
    fig.text(x=x_bottom, y=y_bottom,
             s=f'Rest {topo_haemo} changes timeline',
             fontsize='x-large',
             horizontalalignment='center',
             verticalalignment='center')


# The function is defined above its first use so the notebook runs top to bottom
# on a fresh kernel.
topomaps_plotter('hbo', smr_epochs=smr_epochs, rest_epochs=rest_epochs, CONDITION=CONDITION, SUBJECT=SUBJECT)