In [None]:
import mne
import numpy as np
import pandas as pd
import seaborn as sns
import plotly.express as plx
from mne.datasets import misc
import mne_connectivity
import pyxdf
import matplotlib.pyplot as plt

In [None]:
# Standard 10-20 electrodes kept for the analysis, grouped by scalp region and
# listed front to back. T3/T4/T5/T6 are the legacy 10-20 labels
# (T7/T8/P7/P8 in the 10-10 system).
_electrodes_by_region = {
    'prefrontal': ['Fp1', 'Fp2'],
    'frontal': ['F7', 'F3', 'Fz', 'F4', 'F8'],
    'central and temporal': ['T3', 'C3', 'Cz', 'C4', 'T4'],
    'parietal': ['T5', 'P3', 'Pz', 'P4', 'T6'],
    'occipital': ['O1', 'O2'],
}
channels_to_use = [
    electrode
    for region_electrodes in _electrodes_by_region.values()
    for electrode in region_electrodes
]

In [None]:
# Path to the XDF recording to analyse.
# TODO: machine-specific absolute path — move it to a configurable DATA_DIR so
# the notebook runs on other machines.
fname = "/Users/saracruz/Desktop/N-Pulse/BCI/DSI/data/sub-TEST1/ses-S001/eeg/sub-TEST1_ses-S001_task-Default_run-001_eeg.xdf"
streams, header = pyxdf.load_xdf(fname)

# Select the EEG stream by its name rather than by its position in the file,
# matching how the next cell locates it.
eeg_stream = next(
    (stream for stream in streams if stream["info"]["name"][0] == "EEG-stream"),
    None,
)
if eeg_stream is None:
    raise ValueError("No 'EEG-stream' found in the dataset.")

# Transpose to channels x samples, the layout mne.io.RawArray expects.
data = eeg_stream["time_series"].T
data.shape

In [None]:
# Locate the EEG and stimulus (marker) streams, build the MNE Info object and
# convert the stimulus markers into an MNE events array.

# Indices of all streams named "EEG-stream"; only the first one is used below.
eeg_index = [
    idx for idx, stream in enumerate(streams)
    if stream["info"]["name"][0] == "EEG-stream"
]
if not eeg_index:
    raise ValueError("No 'EEG-stream' found in the dataset.")

# The EEG channels are assumed to be constant across streams
# because this is built into the DSI-24 system
eeg_index1 = eeg_index[0]
eeg_stream = streams[eeg_index1]
ch_names = [
    channel["label"][0]
    for channel in eeg_stream["info"]["desc"][0]["channels"][0]["channel"]
]

# Create the info object
samp_frq = float(eeg_stream["info"]["nominal_srate"][0])
# NOTE(review): every channel is typed as EEG; if the stream also carries
# non-EEG channels (e.g. a trigger channel) they need another type — confirm.
ch_types = ['eeg'] * len(ch_names)

# Find the (first) stimulus stream in streams
stimulus_stream = next(
    (stream for stream in streams if stream["info"]["name"][0] == "stimulus_stream"),
    None,
)
if stimulus_stream is None:
    raise ValueError("No 'stimulus_stream' found in the dataset.")

# Extract stimulus timestamps and event markers
first_timestamp = float(stimulus_stream["footer"]["info"]["first_timestamp"][0])

event_timestamps = np.asarray(stimulus_stream["time_stamps"])
eeg_timestamps = eeg_stream["time_stamps"]
# Index of the first EEG sample at or after each marker timestamp.
event_index = np.searchsorted(eeg_timestamps, event_timestamps)
# MNE event codes must be integers; pyxdf may hand back marker values as a
# (nested) list, so normalise to a flat integer array.
event_dict = np.asarray(stimulus_stream["time_series"]).ravel().astype(int)

# A marker after the last EEG sample would map to sample len(eeg_timestamps),
# which does not exist in the recording — drop such markers explicitly.
in_recording = event_index < len(eeg_timestamps)
if not in_recording.all():
    print(f"Dropping {np.count_nonzero(~in_recording)} marker(s) after the end of the EEG recording.")
    event_timestamps = event_timestamps[in_recording]
    event_index = event_index[in_recording]
    event_dict = event_dict[in_recording]

# Format the events array the way MNE expects:
# (sample index, previous trigger value -> 0, event code)
events = np.column_stack([
    event_index.astype(int),
    np.zeros(len(event_timestamps), dtype=int),
    event_dict,
])

info = mne.create_info(ch_names, sfreq=samp_frq, ch_types=ch_types, verbose=None)

In [None]:
# Compare each marker timestamp with the timestamp of the EEG sample it was
# mapped to, then show the sample indices themselves. The timestamps differ
# slightly because np.searchsorted picks the first EEG sample at or after
# each marker, and markers rarely coincide exactly with a sample.
for values in (event_timestamps, eeg_timestamps[event_index], event_index):
    print(values)

In [None]:
# Convert the EEG signal from microvolts to volts (the unit MNE expects for EEG).
# The array is rebuilt from the EEG stream instead of scaling `data` in place,
# so re-running this cell cannot apply the 1e-6 factor twice.
data = np.asarray(streams[eeg_index1]["time_series"]).T * 1e-6

In [None]:
raw = mne.io.RawArray(data, info)

# Quick look at one second of signal starting at t = 14 s. Automatic scaling is
# used for now because the current recording is noisy; with clean data a fixed
# EEG scale, raw.plot(scalings=dict(eeg=100e-6)), is the preferred choice.
raw.plot(scalings="auto", duration=1, start=14)

# Sampling frequency (Hz) and recording length (s).
fs = raw.info['sfreq']
duration_s = len(raw) / fs
print(f'Frequency of Sampling: {fs} Hz')
print(f'Duration: {duration_s} seconds')

In [None]:
# Keep only the 10-20 channels of interest (on a copy, so `raw` is untouched)
# and attach standard 10-20 electrode positions for topographic plots.
sample_1020 = raw.copy().pick(channels_to_use)
assert len(channels_to_use) == len(sample_1020.ch_names), "some requested 10-20 channels are missing"

ten_twenty_montage = mne.channels.make_standard_montage('standard_1020')
# match_case=False matches recorded labels to montage labels regardless of
# capitalisation, replacing the manual renaming of the montage's channel names.
sample_1020.set_montage(ten_twenty_montage, match_case=False)

In [None]:
# Sanity-check the electrode layout, then inspect the power spectrum before filtering.
sample_1020.plot_sensors(show_names=True)
unfiltered_psd = sample_1020.compute_psd()
unfiltered_psd.plot()


In [None]:
# Band-pass 1-50 Hz: the 1 Hz high-pass removes slow drifts (and is the usual
# recommendation before ICA); 50 Hz caps the band at the conventional EEG range.
# TODO: confirm the band edges against the literature for this paradigm.
# NOTE: both filters modify sample_1020 in place, so re-running this cell
# filters the data a second time — restart from the montage cell instead.
LINE_FREQ = 50  # mains frequency in Hz (50 Hz in Europe, 60 Hz in the Americas)
sample_1020.filter(l_freq=1, h_freq=50, method='iir')
sample_1020.notch_filter(
    freqs=[LINE_FREQ],        # or [LINE_FREQ, 2 * LINE_FREQ] to also remove the first harmonic
    notch_widths=1,           # 1 Hz stop band (wider than MNE's default of freqs / 200)
    trans_bandwidth=1,        # transition-band width in Hz (same as MNE's default)
    filter_length='auto',
    phase='zero-double',      # non-causal: applied forward and then backward
    fir_design='firwin'
)
sample_1020.compute_psd().plot()

In [None]:
# EEG signals scaled to 100 microvolts for better readability
# Twenty seconds of filtered signal, 8 channels per screen, on a fixed 100 µV
# (1e-4 V) scale so amplitudes stay comparable while scrolling.
sample_1020.plot(scalings={'eeg': 100e-6}, duration=20, n_channels=8)

In [None]:
# Decompose the filtered 10-20 data into 10 independent components; the fixed
# random_state makes the decomposition reproducible between runs.
ica = mne.preprocessing.ICA(
    n_components=10,
    random_state=42,
)
ica.fit(sample_1020)

In [None]:
# Scalp maps of the fitted ICA components, for identifying artifact components by eye.
# NOTE(review): no later cell sets ica.exclude or calls ica.apply(), so the
# epochs below are built from data that has not been cleaned by ICA.
ica.plot_components()

In [None]:
# Epochs are segments of the continuous EEG time-locked to events of interest
# (e.g. a stimulus presentation or a response). Each epoch covers a fixed
# time window around its event:
#   tmin = -0.5 -> each epoch starts 0.5 s before the event
#   tmax =  0.8 -> each epoch ends 0.8 s after the event

# Event-code mapping (example; adjust to the codes sent by the stimulus stream).
event_id = {
    'left': 1,
    'right': 2,
    'foot': 3
}

# TODO: the epoch window may need adapting to the stimulus duration.
epochs = mne.Epochs(
    sample_1020, events, event_id=event_id, tmin=-0.5, tmax=0.8, preload=True,
    on_missing='ignore',  # skip conditions with no events instead of raising
)

# Alternative approach used by Campus Biotech: "time-shifted averaging", which
# helps extract more stable EEG features. This window spans 0.5-1.5 s after
# each event, has no baseline correction, and keeps every event code (no event_id).
epochs_time_shift = mne.Epochs(sample_1020, events, tmin=0.5, tmax=1.5, preload=True, baseline=None)

In [None]:
# converts to dataframe and counts the number of events per ID
# Number of epochs per event code (the third column of the MNE events array).
pd.Series(epochs.events[:, 2], name='event_id').value_counts()

In [None]:
# Long-format table of the epoched data; show a small corner (3 rows x 10 columns).
# (The previous `sample_1020.to_data_frame().shape` converted the whole recording
# and discarded the result, since only a cell's last expression is displayed.)
df = epochs.to_data_frame()
df.head(3).iloc[:, :10]

In [None]:
# Data-driven rejection threshold, as an alternative to a fixed 300e-6 V
# (previously proposed: 600e-6 V).
# MNE's `reject` compares each channel's peak-to-peak (PTP) amplitude within an
# epoch against the threshold, so the threshold is derived from the distribution
# of those same PTP values rather than from individual sample amplitudes.
N_STD_REJECT = 6

epochs_data = epochs.get_data()  # (n_epochs, n_channels, n_times)
if epochs_data.size == 0:
    raise ValueError("No epochs available to derive a rejection threshold from.")
ptp_amplitudes = np.ptp(epochs_data, axis=-1)  # one PTP value per epoch and channel
mean = np.mean(ptp_amplitudes)
std = np.std(ptp_amplitudes)

threshold = mean + N_STD_REJECT * std
print(f'Threshold: {threshold}')

reject_criteria = {'eeg': threshold}

# Alternative: a fixed PTP threshold, e.g. reject_criteria = {'eeg': 300e-6}

In [None]:
# Re-epoch with the data-driven rejection threshold and a -100..0 ms baseline.
# event_id is required: without it epochs are labelled only by numeric code,
# and selecting epochs['left'] / epochs['foot'] afterwards raises KeyError.
epochs = mne.Epochs(
    sample_1020, events, event_id=event_id, tmin=-0.5, tmax=0.8,
    reject=reject_criteria, preload=True, baseline=(-.1, 0),
    on_missing='ignore',  # conditions absent from this recording are skipped
)

In [None]:
# Evoked response (average over epochs) for each condition.
# 'right' is not averaged, presumably because this recording has no 'right'
# events — TODO confirm and add it back once such events exist.
left, foot = (epochs[condition].average() for condition in ('left', 'foot'))

In [None]:
# Butterfly plot of each condition's evoked response, with traces coloured by sensor location.
for evoked in (left, foot):
    evoked.plot(spatial_colors=True)
