In [1]:
%%capture
# Setup: if the kernel started inside notebooks/, move the working directory up
# one level (presumably the project root, so the `src.*` imports below resolve)
# and enable IPython autoreload. The stem check makes re-running this cell a
# no-op once the cwd has changed. %%capture hides the output printed by %cd.
from pathlib import Path

if Path.cwd().stem == "notebooks":
    %cd ..
    %load_ext autoreload
    %autoreload 2

In [2]:
import json
import logging
import os
import time
from functools import partial, reduce
from pathlib import Path

import hvplot.polars
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
import torch

from src.data.database_manager import DatabaseManager
from src.features.labels import add_labels
from src.features.resampling import add_normalized_timestamp, interpolate_and_fill_nulls
from src.features.scaling import scale_min_max
from src.features.transforming import merge_dfs
from src.log_config import configure_logging
from src.models.data_loader import create_dataloaders
from src.models.data_preparation import prepare_data
from src.models.main_config import RANDOM_SEED
from src.models.utils import load_model
from src.plots.model_performance_per_participant import (
    analyze_per_participant,
    get_summary_statistics_single_model,
    plot_feature_accuracy_comparison,
    plot_participant_accuracy_comparison,
    plot_participant_performance_single_model,
)

# Console logging at DEBUG level; `ignore_libs` presumably quiets the listed
# third-party loggers (TODO confirm against src.log_config).
configure_logging(
    stream_level=logging.DEBUG,
    ignore_libs=["matplotlib", "Comm", "bokeh", "tornado", "filelock"],
)

# Show 12 rows in polars tables (one per trial). set_tbl_rows returns the
# Config class, which would otherwise be echoed as this cell's output; the
# trailing `;` suppresses that repr.
pl.Config.set_tbl_rows(12);

polars.config.Config

In [3]:
def merge_dfs(
    dfs: "list[pl.DataFrame]",
    on: "list[str] | None" = None,
    sort_by: "list[str] | None" = None,
) -> "pl.DataFrame":
    """
    Merge multiple DataFrames into a single DataFrame via successive full joins.

    The DataFrames are folded left to right. Each step full-joins the running
    result with the next DataFrame on the `on` columns, coalescing the key
    columns. It also drops the helper columns ``rownumber_right`` and
    ``samplenumber_right`` if they are present. The final result is sorted by
    `sort_by`.

    Default on and sort_by columns are set for data DataFrames (raw, preprocess,
    feature). They can be adjusted as needed for other DataFrames (e.g.,
    Trials).

    For merging e.g. Stimulus data with Trials information, use:
    ```
    df = merge_dfs(
        dfs=[stimulus, trials],
        on=["trial_id", "participant_id", "trial_number"],
    )
    ```

    NOTE(review): this definition shadows the `merge_dfs` imported from
    `src.features.transforming` in the imports cell. Confirm the two agree and
    keep only one.

    Args:
        dfs: At least two DataFrames to merge.
        on: Join key columns. Defaults to
            ["participant_id", "trial_id", "trial_number", "timestamp"].
        sort_by: Columns to sort the merged result by. Defaults to
            ["trial_id", "timestamp"].

    Returns:
        The merged DataFrame, sorted by `sort_by`.

    Raises:
        ValueError: If fewer than two DataFrames are given.
    """
    if len(dfs) < 2:
        raise ValueError(
            "A list with at least two DataFrames is required for merging."
        )

    # Defaults are built per call instead of as shared mutable default args.
    if on is None:
        on = ["participant_id", "trial_id", "trial_number", "timestamp"]
    if sort_by is None:
        sort_by = ["trial_id", "timestamp"]

    def _join_pair(left, right):
        # Drop the `_right` helper columns after every join. Otherwise, a third
        # frame carrying the same columns would produce duplicate names.
        return left.join(right, on=on, how="full", coalesce=True).drop(
            ["rownumber_right", "samplenumber_right"], strict=False
        )

    # Sorting once at the end gives the same key ordering as sorting after
    # every intermediate join.
    return reduce(_join_pair, dfs).sort(sort_by)


In [4]:
# Reference only (not executed): data-loading logic copied from main.
#     # Load data from database
#     db = DatabaseManager()
#     with db:
#         if any(channel in eeg_features for channel in args.features):
#             eeg = db.get_table(
#                 "Preprocess_EEG",
#                 exclude_trials_with_measurement_problems=True,
#             )
#             trials = db.get_table(
#                 "Trials",
#                 exclude_trials_with_measurement_problems=True,
#             )
#             eeg = add_normalized_timestamp(eeg)
#             df = add_labels(eeg, trials)
#         else:
#             df = db.get_table(
#                 "Merged_and_Labeled_Data", exclude_trials_with_measurement_problems=True
#             )

In [5]:
# Shared database handle; the cells below open it with `with db:` for each load.
db = DatabaseManager()

In [8]:
# Single-participant exploration: combine EEG features with physiological
# signals for one participant, keeping rows at the EEG timestamps.
PARTICIPANT_ID = 1  # one constant so all three queries select the same person
features = ["heart_rate", "eda_raw", "pupil"]
# Columns shared by both frames; used both to trim `data` and as join keys.
merge_keys = ["participant_id", "trial_id", "trial_number", "timestamp", "marker"]
participant_filter = f"where participant_id = {PARTICIPANT_ID}"

with db:
    # marker=1 tags EEG rows and marker=0 tags physiological rows. Because
    # `marker` is part of the join key, no row of one frame matches a row of
    # the other, so the full join keeps every row of both frames.
    eeg = (
        db.execute(f"from feature_eeg select * {participant_filter}")
        .pl()
        .with_columns(marker=1)
    )
    data = (
        db.execute(f"from merged_and_labeled_data select * {participant_filter}")
        .pl()
        .with_columns(marker=0)
        .select(merge_keys + features)
    )
    trials = db.execute(f"from trials select * {participant_filter}").pl()

df = merge_dfs([eeg, data], on=merge_keys)
# The features are null on EEG rows after the join; interpolate_and_fill_nulls
# presumably fills them from neighbouring physiological rows (TODO confirm it
# does not interpolate across trial boundaries). Then keep only EEG rows.
df = interpolate_and_fill_nulls(df, features).filter(marker=1)

df = add_normalized_timestamp(df)
df = add_labels(df, trials)
df

trial_id,trial_number,participant_id,timestamp,f3,f4,c3,cz,c4,p3,p4,oz,heart_rate,eda_raw,pupil,normalized_timestamp,stimulus_seed,skin_patch,decreasing_intervals,major_decreasing_intervals,increasing_intervals,strictly_increasing_intervals,strictly_increasing_intervals_without_plateaus,plateau_intervals,prolonged_minima_intervals
u16,u8,u8,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,u16,u8,u16,u16,u16,u16,u16,u16,u16
1,1,1,216940.3005,-2.0428e-14,2.5480e-12,1.4277e-12,1.8789e-12,2.3834e-12,-7.4474e-13,3.5290e-12,-4.9960e-15,83.014546,32.228419,4.394018,0.0,243,1,0,0,1,0,0,0,0
1,1,1,216944.0628,36.215756,38.623169,36.793974,38.023269,38.381408,40.83865,35.842668,34.765243,83.044786,32.227853,4.394021,3.7623,243,1,0,0,1,0,0,0,0
1,1,1,216948.0696,82.124126,92.883143,83.587051,91.005893,91.710653,95.496317,87.574106,85.924382,83.076992,32.227251,4.394024,7.7691,243,1,0,0,1,0,0,0,0
1,1,1,216952.132,72.591659,82.607926,75.210979,83.518908,85.043701,87.604904,81.858174,77.411817,83.109645,32.22664,4.394027,11.8315,243,1,0,0,1,0,0,0,0
1,1,1,216955.9735,36.521192,38.429742,39.03288,41.224318,40.175841,42.038494,37.271644,37.422558,83.140522,32.226063,4.394029,15.673,243,1,0,0,1,0,0,0,0
1,1,1,216960.0794,3.835589,3.191793,7.085367,7.604406,6.01906,5.940223,4.730759,7.21841,83.173524,32.225445,4.394032,19.7789,243,1,0,0,1,0,0,0,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
12,12,1,2.8462e6,14.373961,17.410294,13.92084,17.261947,20.586672,14.635118,21.065382,15.797168,68.217047,21.734118,3.482662,179982.0205,806,6,60,36,0,0,0,0,0
12,12,1,2.8462e6,-36.435794,-34.966611,-37.587998,-35.796078,-31.392383,-40.472262,-31.092165,-35.634271,68.216962,21.734953,3.482721,179986.1539,806,6,60,36,0,0,0,0,0
12,12,1,2.8462e6,-110.547623,-106.954595,-112.331654,-109.489213,-108.875391,-114.838676,-109.349376,-111.384959,68.21688,21.735762,3.482778,179990.1552,806,6,60,36,0,0,0,0,0


In [None]:
# Interactive overview: the c4 EEG channel alongside the physiological
# features after `scale_min_max` (presumably per-column min-max scaling so the
# signals share one axis), with one selectable view per trial.
scale_min_max(df).hvplot(
    groupby="trial_id",
    x="timestamp",
    y=["c4", *features],
)

BokehModel(combine_events=True, render_bundle={'docs_json': {'233f82a3-4a53-4bc3-b774-56619709626b': {'version…

In [None]:
# Same EEG + physiological merge as the single-participant cell, but with no
# participant filter, using get_table with problematic trials excluded.
# NOTE(review): `pupil` is not in `features` here. This cell also rebinds
# features/eeg/data/trials/df, so cells reading them depend on execution order.
features = ["heart_rate", "eda_raw"]
key_columns = ["participant_id", "trial_id", "trial_number", "timestamp", "marker"]

with db:
    # marker: 1 = EEG row, 0 = physiological row (part of the join key below).
    eeg = db.get_table(
        "Feature_EEG", exclude_trials_with_measurement_problems=True
    ).with_columns(marker=1)
    data = db.get_table(
        "Merged_and_Labeled_Data", exclude_trials_with_measurement_problems=True
    ).with_columns(marker=0).select(key_columns + features)
    trials = db.get_table("Trials", exclude_trials_with_measurement_problems=True)

df = merge_dfs([eeg, data], on=key_columns)
df = interpolate_and_fill_nulls(df, features).filter(marker=1)
df = add_labels(add_normalized_timestamp(df), trials)
df

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))