In [1]:
%%capture
from pathlib import Path

# When the kernel's working directory is a folder named "features", move two
# levels up and enable autoreload.
# Note: the extension is loaded inside the `if`, so it is only enabled when
# this branch runs (i.e. not when the cwd has a different name).
if Path.cwd().stem == "features":
    # Change the working directory two levels up.
    %cd ../..
    # Load IPython's autoreload extension and set it to mode 2.
    %load_ext autoreload
    %autoreload 2

In [2]:
import logging
from pathlib import Path

import holoviews as hv
import hvplot.polars  # noqa
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import plotly.io as pio
import polars as pl
from icecream import ic
from ipywidgets import interact
from polars import col
from scipy import signal

from src.data.database_manager import DatabaseManager
from src.data.quality_checks import check_sample_rate
from src.experiments.measurement.stimulus_generator import StimulusGenerator
from src.features.eeg import (
    decimate_eeg,
    highpass_filter_eeg,
    preprocess_eeg,
    remove_line_noise,
)
from src.features.labels import add_labels, process_labels
from src.features.resampling import decimate, interpolate_and_fill_nulls
from src.features.scaling import scale_min_max
from src.features.transforming import map_trials, merge_dfs
from src.features.utils import add_time_column
from src.log_config import configure_logging
from src.plots.plot_modality import plot_modality_over_trials
from src.plots.utils import prepare_multiline_hvplot

# Name the logger after the last dot-separated component of `__name__`.
logger = logging.getLogger(__name__.rsplit(".", 1)[-1])

# NOTE(review): `configure_logging` is a project helper; presumably this turns
# on DEBUG-level stream output and quiets the listed third-party loggers —
# confirm in src.log_config.
configure_logging(
    stream_level=logging.DEBUG,
    ignore_libs=["matplotlib", "Comm", "bokeh", "tornado"],
)

# Display settings for polars tables and holoviews output, plus MNE logging.
pl.Config.set_tbl_rows(12)  # for the 12 trials
hv.output(widget_location="bottom", size=130)
# NOTE(review): presumably silences MNE's informational logging — confirm.
mne.set_log_level(verbose=False, return_old_level=False, add_frames=None)

In [3]:
# Goal: aggregate EEG time-frequency data across trials x channels (12 * 8)

In [4]:
def plot_psd_df(
    df: pl.DataFrame,
    channel: str = "c4",
    fs: int = 500,
):
    """Draw the power spectral density of a single column of ``df``.

    Args:
        df: Frame that contains the signal as a column.
        channel: Name of the column to analyse.
        fs: Sampling frequency, forwarded to ``plt.psd`` as ``Fs``.
    """
    channel_series = df.get_column(channel)
    samples = channel_series.to_numpy()
    plt.psd(samples, Fs=fs)
    plt.title(f"Power Spectral Density of {channel}")


In [5]:
def add_normalized_timestamp(
    df: pl.DataFrame,
    time_column: str = "timestamp",
    trial_column: str = "trial_id",
):
    """Add a per-trial, zero-based time column and sort the frame.

    For every row the minimum of ``time_column`` within the row's
    ``trial_column`` group is subtracted, so each trial starts at 0.

    Args:
        df: Frame containing ``time_column`` and ``trial_column``.
        time_column: Name of the column holding the timestamps.
        trial_column: Name of the column identifying the trial.

    Returns:
        ``df`` with an extra ``"normalized_timestamp"`` column, sorted by
        ``trial_column`` and then ``time_column``.
    """
    trial_start = col(time_column).min().over(trial_column)
    normalized = (col(time_column) - trial_start).alias("normalized_timestamp")
    # Sort by the columns the caller passed in; the previous hard-coded
    # "trial_id"/"timestamp" names ignored non-default arguments.
    return df.with_columns(normalized).sort(trial_column, time_column)


In [6]:
# Project database handle; it is used as a context manager (`with db:`) when
# the tables are loaded.
db = DatabaseManager()

In [7]:
# Passed as the second positional argument to `db.get_table` below.
# NOTE(review): presumably toggles exclusion of invalid trials/participants —
# confirm against DatabaseManager.get_table.
exclude_ = False
# NOTE(review): `eeg_query` and `trial_query` are defined but never used in
# this cell; the tables are loaded in full via `db.get_table`. Either run the
# queries or remove them.
eeg_query = """
SELECT * 
FROM Raw_EEG
WHERE trial_id < 60
"""
trial_query = """
SELECT *
FROM Trials
WHERE trial_id < 60
"""
# Load the three tables while the database context is open.
with db:
    eeg = db.get_table("Raw_EEG", exclude_)
    trials = db.get_table("Trials", exclude_)
    stimulus = db.get_table("Feature_Stimulus", exclude_)

# Combine stimulus and EEG first, then attach the trial metadata on the
# trial/participant keys and drop the columns that are not needed afterwards.
df = merge_dfs([stimulus, eeg])
df = merge_dfs(
    dfs=[trials, df],
    on=[
        "trial_id",
        "participant_id",
        "trial_number",
    ],
).drop("timestamp_start", "timestamp_end", "duration", "skin_area", "rownumber")
# we only interpolate stimulus columns to keep the sampling rate of the eeg data
# Rows that still contain nulls after the interpolation are dropped.
df = interpolate_and_fill_nulls(df, ["rating", "temperature"]).drop_nulls()
df
# add_normalized_timestamp(df).write_parquet("raw_eeg.parquet")

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

trial_id,trial_number,participant_id,stimulus_seed,timestamp,temperature,rating,f3,f4,c3,cz,c4,p3,p4,oz
u16,u8,u8,u16,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
1,1,1,396,294198.4628,0.0,0.425,9948.540039,12283.850586,5801.344238,18263.294922,12240.93457,17119.123047,13553.095703,5451.536133
1,1,1,396,294200.3867,0.0,0.425,9941.483398,12282.277344,5801.916016,18267.203125,12242.74707,17122.318359,13557.387695,5453.300781
1,1,1,396,294202.5715,0.0,0.425,9942.960938,12290.860352,5800.676758,18270.494141,12251.998047,17126.800781,13561.726562,5462.074219
1,1,1,396,294204.4057,0.0,0.425,9963.942383,12311.935547,5811.35791,18279.601562,12267.97168,17134.095703,13567.115234,5467.701172
1,1,1,396,294206.6215,0.0,0.425,9981.442383,12317.563477,5804.109863,18269.015625,12266.636719,17120.648438,13552.333008,5452.824219
1,1,1,396,294208.5568,0.0,0.425,9996.795898,12321.043945,5787.086914,18254.662109,12260.152344,17105.152344,13534.927734,5438.041992
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
332,12,28,133,2.7771e6,0.155438,0.85,-25060.083984,-25377.228516,-23833.611328,-20785.285156,-23060.753906,-26660.158203,-26556.542969,-24024.822266
332,12,28,133,2.7771e6,0.155438,0.85,-25062.753906,-25377.275391,-23837.044922,-20784.761719,-23065.902344,-26662.78125,-26562.980469,-24024.154297
332,12,28,133,2.7771e6,0.155438,0.85,-25066.140625,-25378.324219,-23841.574219,-20787.050781,-23072.435547,-26670.267578,-26570.65625,-24037.744141


In [8]:
def compute_time_frequency(
    data: np.ndarray,
    fs: float,
    nperseg: int = 256,
    noverlap: int = 128,
    freq_range: tuple = (0.5, 50),
    window: str = "hann",
) -> tuple:
    """Compute a band-limited spectrogram in decibels.

    The spectrogram is computed along the last axis of ``data`` with
    ``scipy.signal.spectrogram`` (density scaling, constant detrending),
    restricted to ``freq_range`` and converted to ``10 * log10(power)``.

    Args:
        data: EEG signal. Either 1-D (samples) or N-D with samples on the
            last axis, e.g. (channels, samples).
        fs: Sampling frequency
        nperseg: Length of each segment
        noverlap: Number of points to overlap
        freq_range: Tuple of (min_freq, max_freq), both bounds inclusive
        window: Window function to use

    Returns:
        Tuple of (frequencies, times, power) where ``power`` has shape
        ``(..., n_frequencies, n_times)`` and is expressed in dB.
    """
    # Compute spectrogram
    f, t, Sxx = signal.spectrogram(
        data,
        fs=fs,
        nperseg=nperseg,
        noverlap=noverlap,
        window=window,
        scaling="density",
        detrend="constant",
    )

    # Keep only the requested frequency band
    min_freq, max_freq = freq_range
    freq_mask = (f >= min_freq) & (f <= max_freq)
    f = f[freq_mask]
    # Frequencies live on the second-to-last axis of Sxx; indexing that axis
    # explicitly keeps 1-D input unchanged and makes multi-channel input work.
    Sxx = Sxx[..., freq_mask, :]

    # Convert to dB; the epsilon avoids log10(0) for bins with zero power
    Sxx = 10 * np.log10(Sxx + np.finfo(float).eps)

    return f, t, Sxx


@interact(trial_id=(0, 60))
def plot_spectrogram(trial_id):
    """Plot the C4 spectrogram of one trial with the stimulus curve overlaid.

    The trial is decimated, high-pass filtered and cleaned of line noise
    before the spectrogram is computed. The stimulus generated from the
    trial's seed is rescaled to the plotted frequency range and drawn on top.

    Args:
        trial_id: Value of the ``trial_id`` column to plot (set by the slider).
    """
    SAMPLE_RATE = 500
    decimation_factor = 2
    freq_range = (0.5, 50)

    # Get trial data
    eeg = df.filter(col("trial_id") == trial_id)
    if eeg.is_empty():
        # The slider range can include ids without any rows; bail out instead
        # of failing further down in preprocessing or `.item()`.
        logger.warning("No EEG data for trial_id=%s; nothing to plot.", trial_id)
        return

    # Preprocess
    decimated = decimate_eeg(eeg, decimation_factor)
    fs = SAMPLE_RATE / decimation_factor
    filtered = highpass_filter_eeg(decimated, 0.5, fs)
    denoised = remove_line_noise(filtered, fs)

    # Get EEG data
    data = denoised.get_column("c4").to_numpy()

    # Compute time-frequency representation
    f, t, Sxx = compute_time_frequency(
        data,
        fs=fs,
        nperseg=int(fs),  # window length of int(fs) samples
        noverlap=int(fs * 0.9),  # 90% overlap
        freq_range=freq_range,
        window="hann",
    )

    # Plot
    plt.figure(figsize=(12, 6))
    plt.pcolormesh(t, f, Sxx, cmap="RdBu_r", shading="gouraud")

    # Get and plot stimulus
    seed = (
        denoised.filter(col("trial_id") == trial_id)
        .get_column("stimulus_seed")
        .unique()
        .item()
    )
    stimulus = StimulusGenerator(seed=seed)

    # Scale stimulus to frequency range
    stimulus_scaled = np.interp(
        stimulus.y, (stimulus.y.min(), stimulus.y.max()), (f.min(), f.max())
    )
    # The stimulus is stretched from 0 to the last spectrogram time bin
    stimulus_time = np.linspace(0, t[-1], len(stimulus.y))

    # Plot stimulus overlay
    plt.plot(stimulus_time, stimulus_scaled, color="black", alpha=0.5, linewidth=1)

    # Customize plot
    plt.colorbar(label="Power (dB)")
    plt.ylabel("Frequency (Hz)")
    plt.xlabel("Time (s)")
    plt.title(f"Trial {trial_id} - Channel C4")

    # Limit the y-axis to the band that was computed
    plt.ylim(*freq_range)

    plt.tight_layout()
    plt.show()


interactive(children=(IntSlider(value=30, description='trial_id', max=60), Output()), _dom_classes=('widget-in…

multitaper

In [9]:
@interact(trial_id=(0, 60))
def plot_spectrogram(trial_id):
    """Plot the full-band C4 spectrogram of one trial with the stimulus overlaid.

    Uses ``scipy.signal.spectrogram`` directly (256-sample segments, 128
    samples overlap, default window) on the decimated, high-pass filtered and
    line-noise-cleaned signal.

    NOTE(review): this redefines ``plot_spectrogram`` from the previous cell,
    so whichever cell ran last wins; consider giving it a distinct name.

    Args:
        trial_id: Value of the ``trial_id`` column to plot (set by the slider).
    """
    RAW_SAMPLE_RATE = 500
    decimation_factor = 2
    # Rate after decimation; kept in its own name instead of overwriting the
    # constant-style raw rate.
    fs = RAW_SAMPLE_RATE / decimation_factor

    eeg = df.filter(col("trial_id") == trial_id)
    if eeg.is_empty():
        # The slider range can include ids without any rows; bail out instead
        # of failing further down in preprocessing or `.item()`.
        logger.warning("No EEG data for trial_id=%s; nothing to plot.", trial_id)
        return

    decimated = decimate_eeg(eeg, decimation_factor)
    filtered = highpass_filter_eeg(decimated, 0.5, fs)
    denoised = remove_line_noise(filtered, fs)

    data = denoised.get_column("c4").to_numpy()

    # Create the spectrogram
    f, t, Sxx = signal.spectrogram(
        data, fs=fs, nperseg=256, noverlap=128, scaling="density"
    )
    # The epsilon avoids log10(0) (-inf plus a RuntimeWarning) for empty bins
    power_db = 10 * np.log10(Sxx + np.finfo(float).eps)

    # Create the plot
    plt.figure(figsize=(10, 6))

    # Plot spectrogram
    plt.pcolormesh(t, f, power_db, cmap="jet", shading="auto")

    # Get stimulus data
    seed = (
        denoised.filter(col("trial_id") == trial_id)
        .get_column("stimulus_seed")
        .unique()
        .item()
    )
    print(seed)
    stimulus = StimulusGenerator(seed=seed)

    # Scale the stimulus data to match the frequency range
    stimulus_scaled = np.interp(
        stimulus.y, (stimulus.y.min(), stimulus.y.max()), (f.min(), f.max())
    )

    # The stimulus is stretched from 0 to the last spectrogram time bin
    stimulus_time = np.linspace(0, t[-1], len(stimulus.y))

    # Plot the scaled stimulus
    plt.plot(stimulus_time, stimulus_scaled, color="black")

    plt.colorbar(label="Power/Frequency (dB/Hz)")
    plt.ylabel("Frequency [Hz]")
    plt.xlabel("Time [s]")
    plt.title(f"Trial {trial_id} - Channel C4")
    plt.show()


interactive(children=(IntSlider(value=30, description='trial_id', max=60), Output()), _dom_classes=('widget-in…

In [10]:
# Normalizing the data
normalized = (
    decimated.sort(["trial_id"])  # Sort within each group if needed
    .group_by("trial_id", maintain_order=True)
    .agg(
        [
            pl.all().limit(180 * 250)  # Take the first 9000 rows for each group
        ]
    )
    .explode(pl.all().exclude("trial_id"))  # Explode the result back into rows
)
seed = normalized.get_column("stimulus_seed").unique().item()

stim = StimulusGenerator(seed=seed, debug=True)
for _, group in normalized.group_by("trial_id", maintain_order=False):
    data = group.get_column("c4").to_numpy().T

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.specgram(data, Fs=250, NFFT=256, noverlap=128, cmap="viridis")

    plt.plot(stim.y[::8] * 12 + 15, color="white", linewidth=0.5)

    plt.title(f"Trial ID: {group['trial_id'][0]}")
    plt.xlabel("Time")
    plt.ylabel("Frequency")
    plt.ylim(3, 30)

    plt.tight_layout()
    plt.show()


NameError: name 'decimated' is not defined

In [None]:
# Truncate every trial of the denoised data to the same maximum number of rows.
# NOTE(review): `denoised` is only ever assigned as a local variable inside
# `plot_spectrogram`, never at notebook level, so this cell would raise
# NameError on a fresh kernel.
normalized = (
    denoised.sort(["trial_id"])  # Sort rows by trial_id
    .group_by("trial_id", maintain_order=True)
    .agg(
        pl.all().limit(180 * 250)  # Keep at most the first 180 * 250 = 45000 rows per trial
    )
    .explode(pl.all().exclude("trial_id"))  # Explode the result back into rows
)
normalized