In [1]:
%%capture
from pathlib import Path

if Path.cwd().stem == "notebooks":
    %cd ..
    %load_ext autoreload
    %autoreload 2

In [2]:
import logging
from pathlib import Path

import altair as alt
import holoviews as hv
import hvplot.polars  # noqa
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from polars import col

from src.data.database_manager import DatabaseManager
from src.features.labels import add_labels
from src.features.scaling import scale_min_max
from src.features.transforming import interpolate_and_fill_nulls, map_trials, merge_dfs
from src.features.utils import to_describe
from src.log_config import configure_logging
from src.plots.confidence_intervals import plot_confidence_intervals
from src.plots.correlations import (
    aggregate_correlations_fisher_z,
    calculate_correlations_by_trial,
    plot_correlations_by_participant,
    plot_correlations_by_trial,
)
from src.plots.utils import prepare_multiline_hvplot

# Logger named after the last dotted component of __name__ (in a notebook
# kernel __name__ has no dots, so the full name is used).
logger = logging.getLogger(__name__.split(".")[-1])
# Project logging: DEBUG on the stream handler, chatty third-party libs muted.
configure_logging(
    stream_level=logging.DEBUG,
    ignore_libs=["matplotlib", "Comm", "bokeh", "tornado", "param", "numba"],
)

# Display defaults for tables and interactive plots.
hv.output(widget_location="bottom", size=130)
pl.Config.set_tbl_rows(12)  # enough rows to list all 12 trials

In [3]:
# Handle to the project database; it is used as a context manager
# (`with db:`) when the feature tables are read below.
db = DatabaseManager()

In [4]:
# Read the per-modality feature tables and the trial metadata while the
# database connection is open.
with db:
    pupil = db.get_table("Feature_Pupil")
    eda = db.get_table("Feature_EDA")
    ppg = db.get_table("Feature_PPG")
    stimulus = db.get_table("Feature_Stimulus")
    trials = db.get_table("Trials")


# Merge the modalities, attach trial metadata on the trial keys, fill gaps,
# and expose `pupil_mean` under the shorter alias `pupil`.
df = (
    merge_dfs(
        dfs=[merge_dfs([pupil, stimulus, eda, ppg]), trials],
        on=["trial_id", "participant_id", "trial_number"],
    )
    .pipe(interpolate_and_fill_nulls)
    .with_columns(col("pupil_mean").alias("pupil"))
)
df

trial_id,trial_number,participant_id,rownumber,timestamp,pupil_l_raw,pupil_r_raw,pupil_r,pupil_l,pupil_mean,temperature,rating,samplenumber,eda_raw,eda_tonic,eda_phasic,ppg_raw,ppg_heartrate,ppg_ibi,ppg_clean,ppg_rate,ppg_quality,ppg_peaks,stimulus_seed,skin_area,timestamp_start,timestamp_end,duration,pupil
u16,u8,u8,u32,f64,f64,f64,f64,f64,f64,f64,f64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,u16,u8,f64,f64,f64,f64
1,1,1,17631,294198.9762,5.73628,6.245389,5.640989,5.183107,5.412048,0.0,0.425,57892,0.743774,0.743503,0.000271,1416.012559,-1.0,-1.0,-25.366688,65.896546,0.975548,0.0,396,1,294197.3945,474206.7098,180009.3153,5.412048
1,1,1,17631,294200.0,5.735654,6.245111,5.640912,5.183059,5.411985,0.0,0.425,57892,0.743774,0.743503,0.000271,1416.012559,-1.0,-1.0,-25.366688,65.896546,0.975548,0.0,396,1,294197.3945,474206.7098,180009.3153,5.411985
1,1,1,17631,294210.3603,5.729314,6.242299,5.640126,5.18257,5.411348,0.0,0.425,57892,0.743774,0.743503,0.000271,1411.918496,-1.0,-1.0,-29.145591,65.896546,0.975548,0.0,396,1,294197.3945,474206.7098,180009.3153,5.411348
1,1,1,17632,294215.605,5.726105,6.240875,5.639728,5.182322,5.411025,0.0,0.425,57892,0.743825,0.743504,0.000321,1409.845957,-1.0,-1.0,-31.058587,65.896546,0.975548,0.0,396,1,294197.3945,474206.7098,180009.3153,5.411025
1,1,1,17632,294224.331,5.718118,6.236929,5.639069,5.181911,5.41049,0.0,0.425,57892,0.743911,0.743505,0.000405,1406.397718,-1.0,-1.0,-34.241382,65.896546,0.975548,0.0,396,1,294197.3945,474206.7098,180009.3153,5.41049
1,1,1,17633,294232.4178,5.710716,6.233272,5.638458,5.18153,5.409994,0.000004,0.425,57892,0.74399,0.743506,0.000483,1403.202071,-1.0,-1.0,-37.19103,65.896546,0.975548,0.0,396,1,294197.3945,474206.7098,180009.3153,5.409994
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
332,12,28,166439,2.7771e6,4.113826,4.007632,3.981218,4.025649,4.003433,0.155438,0.85,467067,13.52329,13.423457,-0.012743,1784.639006,72.0,-1.0,44.046749,65.934066,0.987849,0.0,133,1,2.5971e6,2.7771e6,180026.123,4.003433
332,12,28,166439,2.7771e6,4.106207,3.997951,3.981182,4.025596,4.003389,0.155438,0.85,467067,13.524636,13.423457,-0.011957,1796.336996,72.0,-1.0,41.542434,65.934066,0.987849,0.0,133,1,2.5971e6,2.7771e6,180026.123,4.003389
332,12,28,166440,2.7771e6,4.102623,3.993397,3.981166,4.025572,4.003369,0.155438,0.85,467067,13.525269,13.423457,-0.011587,1796.336996,72.0,-1.0,41.542434,65.934066,0.987849,0.0,133,1,2.5971e6,2.7771e6,180026.123,4.003369


In [5]:
col1, col2 = "pupil", "rating"
# Name of the correlation column, built with the same f-string pattern the
# original cell passed to every helper.
corr_col = f"{col1}_{col2}_corr"

# Correlations per trial, then Fisher-z aggregated per participant_id (with CI).
corr_by_trial = calculate_correlations_by_trial(df, col1, col2)
corr_by_participant = aggregate_correlations_fisher_z(
    corr_by_trial, corr_col, "participant_id", include_ci=True
)
plot_correlations_by_trial(corr_by_trial, corr_col)
# Alternative view: plot_correlations_by_participant(corr_by_participant, corr_col)

In [None]:
# Visual check of one trial: overlay pupil, rating and tonic EDA after
# min-max scaling so their time courses share a y-axis.
# Filter *before* scaling. Scaling the whole multi-participant frame only to
# keep one trial wastes work. It also makes the plot depend on whether
# scale_min_max scales globally and whether it rescales `trial_id` itself
# (which would break the equality filter).
# NOTE(review): assumes scale_min_max scales column-wise over the frame it is
# given — confirm in src.features.scaling.
EXAMPLE_TRIAL_ID = 259  # arbitrary trial chosen for inspection

scale_min_max(df.filter(col("trial_id") == EXAMPLE_TRIAL_ID)).hvplot(
    x="timestamp", y=["pupil", "rating", "eda_tonic"]
)

In [None]:
col1, col2 = "pupil", "temperature"
# Name of the correlation column, built with the same f-string pattern the
# original cell passed to every helper.
corr_col = f"{col1}_{col2}_corr"

# Correlations per trial, then Fisher-z aggregated per participant_id (with CI).
corr_by_trial = calculate_correlations_by_trial(df, col1, col2)
corr_by_participant = aggregate_correlations_fisher_z(
    corr_by_trial, corr_col, "participant_id", include_ci=True
)
# Alternative view: plot_correlations_by_trial(corr_by_trial, corr_col)
plot_correlations_by_participant(corr_by_participant, corr_col)

In [None]:
col1, col2 = "eda_tonic", "rating"
# Name of the correlation column, built with the same f-string pattern the
# original cell passed to every helper.
corr_col = f"{col1}_{col2}_corr"

# Correlations per trial, then Fisher-z aggregated per participant_id (with CI).
corr_by_trial = calculate_correlations_by_trial(df, col1, col2)
corr_by_participant = aggregate_correlations_fisher_z(
    corr_by_trial, corr_col, "participant_id", include_ci=True
)
plot_correlations_by_trial(corr_by_trial, corr_col)
# Alternative view: plot_correlations_by_participant(corr_by_participant, corr_col)

In [None]:
col1, col2 = "eda_phasic", "rating"
# Name of the correlation column, built with the same f-string pattern the
# original cell passed to every helper.
corr_col = f"{col1}_{col2}_corr"

# Correlations per trial, then Fisher-z aggregated per participant_id (with CI).
corr_by_trial = calculate_correlations_by_trial(df, col1, col2)
corr_by_participant = aggregate_correlations_fisher_z(
    corr_by_trial, corr_col, "participant_id", include_ci=True
)
plot_correlations_by_trial(corr_by_trial, corr_col)
# Alternative view: plot_correlations_by_participant(corr_by_participant, corr_col)