In [1]:
%%capture
# %%capture hides this cell's output (e.g. the directory printed by %cd).
from pathlib import Path

# If the kernel was started inside notebooks/, move to the parent directory,
# presumably so that the `src.*` imports in the next cell resolve, and enable
# autoreload so edits to imported modules are picked up without a restart.
# NOTE(review): if the kernel starts in any other directory, autoreload is not enabled.
if Path.cwd().stem == "notebooks":
    %cd ..
    %load_ext autoreload
    %autoreload 2

In [2]:
import logging

import holoviews as hv
import hvplot.polars  # noqa
import polars as pl
from icecream import ic
from polars import col

from src.data.database_manager import DatabaseManager
from src.data.quality_checks import check_sample_rate
from src.features.resampling import add_timestamp_μs_column
from src.features.scaling import scale_min_max
from src.features.transforming import map_trials, merge_data_dfs
from src.log_config import configure_logging
from src.plots.confidence_intervals import plot_confidence_intervals
from src.plots.utils import prepare_multiline_hvplot

# Notebook-wide logging: DEBUG level on the stream handler; the listed
# libraries are handed to configure_logging via ignore_libs.
configure_logging(
    ignore_libs=("Comm", "bokeh", "tornado", "matplotlib"),
    stream_level=logging.DEBUG,
)
# Logger named after the last dotted component of the module name.
logger = logging.getLogger(__name__.split(".")[-1])

# Show 12 rows in polars tables (one per trial).
pl.Config.set_tbl_rows(12)
hv.output(widget_location="bottom", size=150)

In [9]:
# Overlay the confidence-interval plots of the stimulus and PPG modalities.
# Further overlays tried during exploration:
#   plot_confidence_intervals("pupil")
#   plot_confidence_intervals("eda")  # optionally with signals=["eda_tonic"]
# The PPG plot can also be restricted with signals=["ppg_rate"].
# TODO: add bin_size parameter, signal parameter, etc.
# NOTE: pupil looks weird because of outliers left in the data; improve the pipeline.
stimulus_ci_plot = plot_confidence_intervals("stimulus")
ppg_ci_plot = plot_confidence_intervals("ppg")
stimulus_ci_plot * ppg_ci_plot

BokehModel(combine_events=True, render_bundle={'docs_json': {'9e43c613-26b1-4284-8cbc-059991cb2a11': {'version…

In [4]:
# Inspect the preprocessed EDA components (raw, tonic, phasic) per trial.
# A dedicated name is used instead of `df`: the aggregation cell further down
# expects `df` to hold the merged feature_eda table, so re-running this cell
# must not clobber it.
with DatabaseManager() as db:
    eda_preprocessed = db.get_table("preprocess_eda")

# Min-max scaling (for display) so the three components share a comparable range.
scale_min_max(eda_preprocessed).plot(
    x="timestamp",
    y=["eda_raw", "eda_tonic", "eda_phasic"],
    groupby="trial_id",
)

BokehModel(combine_events=True, render_bundle={'docs_json': {'713d5ee7-351b-4e44-9381-22275057c07f': {'version…

## Question: What is a good bin size?

In [5]:
# Signal columns to aggregate and plot, keyed by modality name.
modality_map = dict(
    stimulus=["rating", "temperature"],
    eda=["eda_tonic", "eda_phasic"],
)


def plot_confidence_interval_over_seeds(modality: str, bin_size: float = 1):
    """Plot per-stimulus-seed signal averages of a modality with CI bands.

    Loads ``feature_<modality>`` and the ``trials`` table, attaches each trial's
    ``stimulus_seed``, bins the signals over time across all trials sharing a
    seed, min-max scales the result and overlays the confidence intervals as
    area plots on top of the average lines.

    Note: relies on ``aggregate_over_seeds`` and ``add_confidence_interval``,
    which are defined in a later cell. Run that cell before calling this one,
    otherwise a NameError is raised.

    Args:
        modality: Key of ``modality_map`` (case-insensitive), e.g. "stimulus".
        bin_size: Width of the time bins in seconds (defaults to the previously
            hard-coded 1 s).

    Returns:
        A holoviews overlay of the average lines and CI areas, grouped by
        ``stimulus_seed``.
    """
    modality = modality.lower()  # same normalization as aggregate_over_seeds
    signals = modality_map[modality]

    # The per-seed grouping needs each trial's stimulus seed from the trials table.
    with DatabaseManager() as db:
        features = db.get_table("feature_" + modality)
        trials = db.get_table("trials")
    features = merge_data_dfs(
        [features, trials],
        merge_on=["participant_id", "trial_id", "trial_number"],
    ).drop("duration", "skin_area", "timestamp_start", "timestamp_end", strict=False)

    aggregated = aggregate_over_seeds(features, modality, bin_size=bin_size)
    # scale for better visualization, must come before adding confidence intervals
    aggregated = scale_min_max(
        aggregated, exclude_additional_columns=["time_bin", "rating", "temperature"]
    )
    aggregated = add_confidence_interval(aggregated, modality)

    # Average lines for every signal of the modality.
    plots = aggregated.plot(
        x="time_bin",
        y=[f"avg_{signal}" for signal in signals],
        groupby="stimulus_seed",
        kind="line",
        xlabel="Time (s)",
        ylabel="Normalized value",
        grid=True,
    )
    # One shaded band between the lower and upper CI bound per signal.
    for signal in signals:
        plots *= aggregated.hvplot.area(
            x="time_bin",
            y=f"ci_lower_{signal}",
            y2=f"ci_upper_{signal}",
            groupby="stimulus_seed",
            alpha=0.5,
            line_width=0,
            fill_color="lightblue",
            grid=True,
        )

    return plots


# Overlay the stimulus (rating, temperature) and EDA (tonic, phasic) plots.
# NOTE: on a fresh kernel this raises NameError until the later cell defining
# aggregate_over_seeds / add_confidence_interval has been executed.
stimulus_overlay = plot_confidence_interval_over_seeds("stimulus")
eda_overlay = plot_confidence_interval_over_seeds("eda")
stimulus_overlay * eda_overlay


NameError: name 'aggregate_over_seeds' is not defined

In [None]:
# Load the EDA features and attach each trial's stimulus seed from the trials table.
# (Previously stored in a variable misleadingly named `stimulus`.)
with DatabaseManager() as db:
    eda_features = db.get_table("feature_eda")
    trials = db.get_table("trials")  # get trials for stimulus seeds

# `df` is consumed by the EDA aggregation cell further down.
df = merge_data_dfs(
    [eda_features, trials],
    merge_on=["participant_id", "trial_id", "trial_number"],
).drop("duration", "skin_area", "timestamp_start", "timestamp_end", strict=False)
df

trial_id,trial_number,participant_id,rownumber,timestamp,samplenumber,eda_raw,eda_tonic,eda_phasic,stimulus_seed
u16,u8,u8,f64,f64,f64,f64,f64,f64,u16
1,1,1,37664.142857,294200.0,57896.142857,0.753564,0.752119,0.001445,396
1,1,1,37676.5,294300.0,57908.5,0.753506,0.752138,0.001368,396
1,1,1,37690.875,294400.0,57922.875,0.753413,0.752158,0.001256,396
1,1,1,37704.25,294500.0,57936.25,0.753432,0.752177,0.001255,396
1,1,1,37714.8,294600.0,57946.8,0.75378,0.752193,0.001586,396
1,1,1,37727.642857,294700.0,57959.642857,0.753564,0.752212,0.001352,396
…,…,…,…,…,…,…,…,…,…
332,12,28,355422.0,2.7766e6,467019.0,13.60289,13.578916,-0.065532,133
332,12,28,355433.555556,2.7767e6,467030.555556,13.624445,13.578906,-0.048941,133
332,12,28,355446.1,2.7768e6,467043.1,13.644242,13.5789,-0.034242,133


In [None]:
# Stimulus features per stimulus seed: merge seeds, bin over time, scale, add CIs.
# NOTE: needs aggregate_over_seeds / add_confidence_interval from a later cell.
with DatabaseManager() as db:
    stimulus = db.get_table("feature_stimulus")
    trials = db.get_table("trials")  # get trials for stimulus seeds

stim_merged = merge_data_dfs(
    [stimulus, trials],
    merge_on=["participant_id", "trial_id", "trial_number"],
).drop("duration", "skin_area", "timestamp_start", "timestamp_end", strict=False)
stim_binned = aggregate_over_seeds(stim_merged, bin_size=1, modality="stimulus")
# scale for better visualization, must come before adding confidence intervals
# NOTE(review): after aggregation the columns are avg_/std_-prefixed, so excluding
# "rating"/"temperature" may have no effect — confirm against scale_min_max.
stim_scaled = scale_min_max(
    stim_binned, exclude_additional_columns=["time_bin", "rating", "temperature"]
)
# `stim` (the final frame) is used by the plotting cells below.
stim = add_confidence_interval(stim_scaled, modality="stimulus")
stim



stimulus_seed,avg_rating,avg_temperature,std_rating,std_temperature,sample_size,time_bin,ci_lower_rating,ci_lower_temperature,ci_upper_rating,ci_upper_temperature
u16,f64,f64,f64,f64,u32,f64,f64,f64,f64,f64
133,0.436889,0.005348,0.693159,0.073292,282,0.0,0.355986,-0.003207,0.517792,0.013902
133,0.593096,0.051038,0.567941,0.218349,280,1.0,0.526572,0.025462,0.659621,0.076614
133,0.634111,0.140056,0.584458,0.345046,281,2.0,0.565774,0.099712,0.702448,0.1804
133,0.675038,0.264217,0.564405,0.441206,283,3.0,0.60928,0.212812,0.740797,0.315622
133,0.716654,0.412146,0.521555,0.498346,287,4.0,0.656312,0.35449,0.776995,0.469802
133,0.757338,0.567857,0.449593,0.493479,285,5.0,0.70514,0.510564,0.809536,0.62515
…,…,…,…,…,…,…,…,…,…,…
952,0.136547,0.211576,0.545738,0.17714,280,174.0,0.072624,0.190827,0.200471,0.232324
952,0.108094,0.160278,0.498057,0.150399,280,175.0,0.049756,0.142662,0.166433,0.177895
952,0.090878,0.117952,0.463657,0.119929,280,176.0,0.036568,0.103904,0.145187,0.131999


In [None]:
def _zero_based_timestamps(df: pl.DataFrame) -> pl.DataFrame:
    """Add ``zeroed_timestamp``: each row's timestamp minus the earliest timestamp of its trial."""
    trial_start = col("timestamp").min().over("trial_id")
    offset = col("timestamp") - trial_start
    return df.with_columns(offset.alias("zeroed_timestamp"))


def aggregate_over_seeds(
    df: pl.DataFrame,
    modality: str,
    bin_size: float = 1,
    trial_duration_s: float = 180,
) -> pl.DataFrame:
    """Bin a modality's signals over time, aggregating all trials per stimulus seed.

    Timestamps are re-based to the start of each trial, then ``group_by_dynamic``
    assigns every sample to a fixed-width time bin per ``stimulus_seed``. For each
    signal of ``modality`` (listed in ``modality_map`` and present in ``df``) the
    bin mean (``avg_<signal>``) and standard deviation (``std_<signal>``) are
    computed, plus the number of rows in the bin (``sample_size``).

    Without group_by_dynamic this would be roughly::

        df.with_columns(
            (col("zeroed_timestamp") // 1000).cast(pl.Int32).alias("time_bin")
        ).group_by(["stimulus_seed", "time_bin"])

    Args:
        df: Samples with ``timestamp`` (ms), ``trial_id`` and ``stimulus_seed``
            columns plus the modality's signal columns.
        modality: Key of ``modality_map`` (case-insensitive).
        bin_size: Bin width in seconds; fractional values are allowed.
        trial_duration_s: Bins starting at or after this time (s) are dropped,
            so that samples at exactly the trial end do not form their own bin.

    Returns:
        One row per (stimulus_seed, time_bin), sorted by both, where
        ``time_bin`` is the bin start in seconds.

    Raises:
        ValueError: If ``bin_size`` is not positive.
    """
    if bin_size <= 0:
        raise ValueError(f"bin_size must be positive, got {bin_size}")

    # Select signals for the given modality
    modality = modality.lower()
    signals = [signal for signal in modality_map[modality] if signal in df.columns]

    # Zero-based timestamp in milliseconds
    df = _zero_based_timestamps(df)
    # Integer microsecond column: an "<n>i" window in group_by_dynamic needs an int index
    df = add_timestamp_μs_column(df, "zeroed_timestamp")
    # Window width in µs; round() instead of int() so float error in bin_size
    # cannot truncate the width by one microsecond.
    every_us = round(bin_size * 1_000_000)

    aggregations = (
        [col(signal).mean().alias(f"avg_{signal.lower()}") for signal in signals]
        + [col(signal).std().alias(f"std_{signal.lower()}") for signal in signals]
        # Rows per bin, i.e. samples (not trials).
        # TODO find out why the sample sizes are not constant
        + [pl.len().alias("sample_size")]
    )

    return (
        df.sort("zeroed_timestamp_µs")
        .group_by_dynamic(
            "zeroed_timestamp_µs",
            every=f"{every_us}i",
            group_by=["stimulus_seed"],
        )
        .agg(aggregations)
        .with_columns((col("zeroed_timestamp_µs") / 1_000_000).alias("time_bin"))
        .sort("stimulus_seed", "time_bin")
        # remove measures at exactly the trial end so that they don't get their own bin
        .filter(col("time_bin") < trial_duration_s)
        .drop("zeroed_timestamp_µs")
    )


def add_confidence_interval(
    df: pl.DataFrame,
    modality: str,
    z: float = 1.96,
) -> pl.DataFrame:
    """Add normal-approximation confidence bounds for each averaged signal.

    For every signal of ``modality`` whose ``avg_<signal>`` column exists in
    ``df``, adds ``ci_lower_<signal>`` and ``ci_upper_<signal>`` computed as
    ``avg ± z * std / sqrt(sample_size)``. The default ``z = 1.96`` corresponds
    to a 95% interval. Signals that ``aggregate_over_seeds`` skipped (missing in
    its input) are skipped here as well instead of raising on a missing column.

    NOTE(review): ``sample_size`` from ``aggregate_over_seeds`` counts rows
    (samples) per bin, so correlated samples of one trial are treated as
    independent and the interval is likely too narrow. If the ``std_`` columns
    were min-max scaled independently of the ``avg_`` columns, the interval
    width is also not in the units of the mean — confirm the intended scaling.

    Args:
        df: Output of ``aggregate_over_seeds`` (optionally scaled), containing
            ``avg_*``, ``std_*``, ``sample_size``, ``stimulus_seed``, ``time_bin``.
        modality: Key of ``modality_map`` (case-insensitive).
        z: Standard-normal quantile for the desired confidence level.

    Returns:
        ``df`` with all lower bounds, then all upper bounds appended, sorted by
        ``stimulus_seed`` and ``time_bin``.
    """
    # Same naming as aggregate_over_seeds: lower-cased modality and signal names.
    names = [signal.lower() for signal in modality_map[modality.lower()]]
    names = [name for name in names if f"avg_{name}" in df.columns]

    half_width = {
        name: z * (col(f"std_{name}") / col("sample_size").sqrt()) for name in names
    }
    return df.with_columns(
        [
            (col(f"avg_{name}") - half_width[name]).alias(f"ci_lower_{name}")
            for name in names
        ]
        + [
            (col(f"avg_{name}") + half_width[name]).alias(f"ci_upper_{name}")
            for name in names
        ]
    ).sort("stimulus_seed", "time_bin")


# EDA: 1 s bins per stimulus seed, scaled for display, then confidence bounds.
result = df.pipe(aggregate_over_seeds, bin_size=1, modality="eda")
# scaling is for visualization only and must happen before the confidence intervals
result = result.pipe(scale_min_max, exclude_additional_columns=["time_bin"])
conf = result.pipe(add_confidence_interval, modality="eda")
conf



stimulus_seed,avg_eda_tonic,avg_eda_phasic,std_eda_tonic,std_eda_phasic,sample_size,time_bin,ci_lower_eda_tonic,ci_lower_eda_phasic,ci_upper_eda_tonic,ci_upper_eda_phasic
u16,f64,f64,f64,f64,u32,f64,f64,f64,f64,f64
133,0.811622,0.635856,0.609608,0.361577,280,0.0,0.740217,0.593504,0.883027,0.678209
133,0.811483,0.898172,0.601589,0.519109,280,1.0,0.741017,0.837368,0.881948,0.958977
133,0.808668,0.738668,0.591169,0.371447,280,2.0,0.739423,0.695159,0.877913,0.782176
133,0.803703,0.430211,0.579055,0.281552,280,3.0,0.735877,0.397232,0.871529,0.46319
133,0.798057,0.273332,0.567652,0.328337,280,4.0,0.731566,0.234873,0.864547,0.311791
133,0.79298,0.242776,0.558477,0.338503,280,5.0,0.727565,0.203126,0.858396,0.282426
…,…,…,…,…,…,…,…,…,…,…
952,0.272081,0.501009,0.635495,0.167498,280,174.0,0.197643,0.481389,0.346518,0.520628
952,0.261934,0.485828,0.620165,0.063436,280,175.0,0.189293,0.478398,0.334576,0.493259
952,0.254477,0.474188,0.608833,0.00555,280,176.0,0.183163,0.473538,0.325791,0.474838


In [None]:
# EDA averages per stimulus seed with their confidence-interval bands.
variables = modality_map["eda"]

# Axis labels match plot_confidence_interval_over_seeds; values are min-max scaled.
all_plots = conf.plot(
    x="time_bin",
    y=[f"avg_{var}" for var in variables],
    groupby="stimulus_seed",
    kind="line",
    xlabel="Time (s)",
    ylabel="Normalized value",
    grid=True,
)

# One shaded band between the lower and upper bound per EDA signal.
for var in variables:
    all_plots *= conf.hvplot.area(
        x="time_bin",
        y=f"ci_lower_{var}",
        y2=f"ci_upper_{var}",
        groupby="stimulus_seed",
        alpha=0.5,
        line_width=0,
        fill_color="lightblue",
        grid=True,
    )

all_plots


BokehModel(combine_events=True, render_bundle={'docs_json': {'2b5090ee-7305-4c02-8d2e-63a24852eeed': {'version…

In [None]:
# EDA averages and CI bands per stimulus seed, overlaid with the (scaled)
# pain rating average and its CI band for comparison.
variables = modality_map["eda"]

# Axis labels match plot_confidence_interval_over_seeds; values are min-max scaled.
all_plots = conf.plot(
    x="time_bin",
    y=[f"avg_{var}" for var in variables],
    groupby="stimulus_seed",
    kind="line",
    xlabel="Time (s)",
    ylabel="Normalized value",
    grid=True,
)

# One light-blue band per EDA signal.
for var in variables:
    all_plots *= conf.hvplot.area(
        x="time_bin",
        y=f"ci_lower_{var}",
        y2=f"ci_upper_{var}",
        groupby="stimulus_seed",
        alpha=0.5,
        line_width=0,
        fill_color="lightblue",
        grid=True,
    )

# Rating CI band in yellow to distinguish it from the EDA bands.
all_plots *= stim.hvplot.area(
    x="time_bin",
    y="ci_lower_rating",
    y2="ci_upper_rating",
    groupby="stimulus_seed",
    alpha=0.5,
    line_width=0,
    fill_color="yellow",
    grid=True,
)

all_plots * stim.plot(
    x="time_bin",
    y="avg_rating",
    groupby="stimulus_seed",
    kind="line",
    grid=True,
)


BokehModel(combine_events=True, render_bundle={'docs_json': {'3c51fcff-5143-4469-999a-0288ac6b2197': {'version…

In [None]:
# Display the final stimulus frame (avg_/std_/ci_ columns for rating and temperature).
# NOTE: the saved output below shows eda_* columns, so it is stale; re-run this cell.
stim

stimulus_seed,avg_eda_tonic,avg_eda_phasic,std_eda_tonic,std_eda_phasic,sample_size,time_bin,ci_lower_eda_tonic,ci_lower_eda_phasic,ci_upper_eda_tonic,ci_upper_eda_phasic
u16,f64,f64,f64,f64,u32,f64,f64,f64,f64,f64
133,0.811622,0.635856,0.609608,0.361577,280,0.0,0.740217,0.593504,0.883027,0.678209
133,0.811483,0.898172,0.601589,0.519109,280,1.0,0.741017,0.837368,0.881948,0.958977
133,0.808668,0.738668,0.591169,0.371447,280,2.0,0.739423,0.695159,0.877913,0.782176
133,0.803703,0.430211,0.579055,0.281552,280,3.0,0.735877,0.397232,0.871529,0.46319
133,0.798057,0.273332,0.567652,0.328337,280,4.0,0.731566,0.234873,0.864547,0.311791
133,0.79298,0.242776,0.558477,0.338503,280,5.0,0.727565,0.203126,0.858396,0.282426
…,…,…,…,…,…,…,…,…,…,…
952,0.272081,0.501009,0.635495,0.167498,280,174.0,0.197643,0.481389,0.346518,0.520628
952,0.261934,0.485828,0.620165,0.063436,280,175.0,0.189293,0.478398,0.334576,0.493259
952,0.254477,0.474188,0.608833,0.00555,280,176.0,0.183163,0.473538,0.325791,0.474838
