# Read single-trial EEG epochs

In [1]:
import pandas as pd
from matplotlib import pyplot as plt
from spudtr import get_demo_df, P3_1500_FEATHER
from spudtr import epf
from spudtr import filters

SyntaxError: EOL while scanning string literal (__init__.py, line 115)

In [None]:
# load the packaged demo single-trial epochs
epochs_df = get_demo_df(P3_1500_FEATHER)

# the demonstrations below use these four midline EEG channels
eeg_channels = ["MiPf", "MiCe", "MiPa", "MiOc"]

# confirm the dataframe is in epochs format, with epoch_id and time_ms
# as the epoch and time columns
epf.check_epochs(
    epochs_df,
    eeg_channels,
    epoch_id="epoch_id",
    time="time_ms",
)

# last expression -> rich display of the dataframe (Jupyter truncates long frames)
epochs_df

# Common A1 to average mastoid reference

Also known as a "bimastoid" or "linked mastoid" reference.

**Warning, only valid for EEG recorded with a common A1 reference**

In [None]:
# re-reference the EEG channels from common A1 to the linked (A1, A2) mastoid
# pair, with the A2 channel as the second member of the pair.
# NOTE(review): this display assumes re_reference returns a new dataframe;
# confirm it does not modify epochs_df in place, since later cells reuse epochs_df.
linked_mastoid_df = epf.re_reference(
    epochs_df,
    eeg_channels,
    "A2",
    "linked_pair",
    epoch_id="epoch_id",
    time="time_ms",
)
linked_mastoid_df

# New common reference

> Note: Only valid for common reference EEG data.

For example, change from a common A1 reference to a common vertex or nose-tip reference.

Note: after re-referencing, the new reference channel (here `MiCe`) is 0, as expected.

In [None]:
# new common reference at the vertex, which is channel MiCe in this montage;
# after re-referencing, MiCe itself should be 0.
# NOTE(review): confirm re_reference returns a new dataframe rather than
# modifying epochs_df in place, since later cells reuse epochs_df.
vertex_ref_df = epf.re_reference(
    epochs_df,
    eeg_channels,
    "MiCe",
    "new_common",
    epoch_id="epoch_id",
    time="time_ms",
)
vertex_ref_df

# Common average reference

Note: this is for demonstration only; a real application would use all scalp locations.

In [None]:
# channels averaged to form the reference (a demo subset; a real analysis
# would average over every scalp location)
reference_channels = ["MiPf", "MiCe", "MiPa", "MiOc", "A2"]

# NOTE(review): confirm re_reference returns a new dataframe rather than
# modifying epochs_df in place, since later cells reuse epochs_df.
common_avg_df = epf.re_reference(
    epochs_df,
    eeg_channels,
    reference_channels,
    "common_average",
    epoch_id="epoch_id",
    time="time_ms",
)
common_avg_df

# Center EEG data in an interval (baseline)

> The `start` and `stop` interval bounds are in the same units as the time column (here `time_ms`).

In [None]:
# baseline interval, in the units of the time column (time_ms);
# -4 is presumably the last pre-event sample at the 250 samples/s
# rate used in the filter section below -- confirm against the data
start, stop = -500, -4

# subtract each epoch's baseline-interval level from the EEG channels
centered_eeg_df = epf.center_eeg(
    epochs_df,
    eeg_channels,
    start,
    stop,
    epoch_id="epoch_id",
    time="time_ms",
)
centered_eeg_df

# Exclude previously tagged artifacts

This special-purpose filter drops entire epochs where the time-locking event at time 0 is tagged as bad (for any reason) in the specified `bads_column`.

This implements a simple convention for pruning epochs based on tags generated by artifact screening functions.

Any column can be used or constructed for this purpose.

Example: drop all epochs where `eeg_artifact` is nonzero at `time_ms == 0`.

In [None]:
# keep only epochs whose eeg_artifact tag is clean at the time-locking event
good_epochs = epf.drop_bad_epochs(
    epochs_df,
    bads_column="eeg_artifact",
    epoch_id="epoch_id",
    time="time_ms",
)

# report how many distinct epochs survive the screening;
# nunique(dropna=False) counts exactly what len(unique()) counts
n_total_epochs = epochs_df["epoch_id"].nunique(dropna=False)
print("Total number of epoch ids: ", n_total_epochs)
n_good_epochs = good_epochs["epoch_id"].nunique(dropna=False)
print("Number of good epoch ids: ", n_good_epochs)

good_epochs

# Filter EEG epochs

**Configure the filter**

If left unspecified in `show_filter()`, the optional parameters `width_hz`, `ripple_db` and `window` default to reasonable values.

However, to **apply** the filter, **all** parameters must be explicitly specified. You can use the default values or choose your own.

In [None]:
# FIR lowpass specification with only the required parameters.
# show_filter() fills in default transition width, ripple, and window
# values for the optional parameters left out here; applying the filter
# later requires all of them to be set explicitly.
lp_params = {
    "ftype": "lowpass",  # lowpass, bandpass, highpass
    "cutoff_hz": 20,     # 1/2 amplitude frequency Hz
    "sfreq": 250.0,      # sampling rate in samples/second
}

# inspect the filter design; **lp_params expands the dict into keyword arguments
bode, imp, s_edge, n_edge = filters.show_filter(**lp_params)

**Apply the filter**

With **all** the parameters specified.

Note the `**lp_params` Python idiom, which expands the filter parameter dictionary into keyword arguments.

In [None]:
# update the filter specs with default values
lp_params["window"] = "kaiser"
lp_params["width_hz"] = 5.0
lp_params["ripple_db"] = 53.0
display(lp_params)

epochs_df_lp = epf.fir_filter_epochs(
    epochs_df,
    data_columns=eeg_channels,
    epoch_id="epoch_id",
    time="time_ms",
    trim_edges=False,
    **lp_params,
    )

times = epochs_df["time_ms"].unique()
epoch_ids = epochs_df["epoch_id"].unique()

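**Compare the output**: the comparison plots at the end of this notebook overlay the unfiltered, filtered, and filtered-and-trimmed signals.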

# Filter EEG epochs, trim distorted edges


Zero-padding reduces edge distortion but does not eliminate it, so the distorted samples at the start and end of each epoch can be trimmed with `trim_edges=True`.

How much of each epoch to trim is up to you.

In [None]:
# same filter as above, but trim_edges=True drops the edge-distorted samples
trim_kwargs = dict(
    data_columns=eeg_channels,
    epoch_id="epoch_id",
    time="time_ms",
    trim_edges=True,
    **lp_params,
)
epochs_df_lp_trimmed = epf.fir_filter_epochs(epochs_df, **trim_kwargs)

# time stamps that remain after trimming; their first and last values mark
# the trimmed regions in the comparison plots
trimmed_times = epochs_df_lp_trimmed["time_ms"].unique()

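**Compare the output**: unfiltered (black), filtered (red), and filtered with trimmed edges (blue) for a few example epochs.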

In [None]:
# select some epochs to show
# select some epochs to show, by position in the epoch_ids array
epidxs = [0, 6, 20]
channel = "MiPa"

for epidx in epidxs:

    # the actual epoch id stored at this position; the title shows this
    # value rather than the array position
    this_epoch_id = epoch_ids[epidx]

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.set_title(f"epoch_id={this_epoch_id}", fontsize=18)
    ax.set_xlabel("time_ms")
    ax.set_ylabel(channel)

    # unfiltered
    ax.plot(
        times,
        epochs_df.loc[epochs_df["epoch_id"] == this_epoch_id, channel],
        color="black",
        label="unfiltered",
    )

    # filtered, phase compensated, edge distortion still present
    ax.plot(
        times,
        epochs_df_lp.loc[epochs_df_lp["epoch_id"] == this_epoch_id, channel],
        color="red",
        label="filtered",
    )

    # filtered, phase compensated, distortion trimmed
    ax.plot(
        trimmed_times,
        epochs_df_lp_trimmed.loc[
            epochs_df_lp_trimmed["epoch_id"] == this_epoch_id, channel
        ],
        color="blue",
        lw=3,
        label="filtered, trimmed",
    )

    # mark the beginning and end of the trimmed data
    for xtime in (trimmed_times[0], trimmed_times[-1]):
        ax.axvline(xtime, color="red")
        # annotation text is passed positionally: the `s=` keyword was
        # deprecated in Matplotlib 3.3 and removed in 3.5
        ax.annotate(
            f"{xtime} ms",
            xy=(xtime, ax.get_ylim()[1]),
            fontsize=18,
            ha="center",
            va="bottom",
            bbox=dict(boxstyle="round", ec="red", fc="white"),
        )

    # shade each trimmed (edge-distorted) region exactly once; drawing it
    # per annotation would stack the translucent spans and darken them
    for bound in (0, -1):
        ax.axvspan(
            times[bound],
            trimmed_times[bound],
            color="pink",
            alpha=0.2,
        )

    ax.legend()