In [57]:
%%capture
from pathlib import Path

if Path.cwd().stem == "notebooks":
    %cd ..
    %load_ext autoreload
    %autoreload 2

In [58]:
import logging

import holoviews as hv
import hvplot.polars  # noqa
import polars as pl
from icecream import ic
from polars import col

from src.data.database_manager import DatabaseManager
from src.features.quality_checks import check_sample_rate
from src.features.resampling import add_timestamp_μs_column
from src.features.scaling import scale_min_max
from src.features.transforming import map_trials, merge_data_dfs
from src.log_config import configure_logging
from src.plots.utils import prepare_multiline_hvplot

# Stream DEBUG-level records; the listed libraries are handed to
# configure_logging as the ones to ignore.
configure_logging(
    ignore_libs=("Comm", "bokeh", "tornado", "matplotlib"),
    stream_level=logging.DEBUG,
)
# Use only the last dotted component of the module name as the logger name.
logger = logging.getLogger(__name__.rpartition(".")[-1])

# Display settings: show up to 12 table rows (one per trial) and place
# HoloViews widgets below the plot.
pl.Config.set_tbl_rows(12)
hv.output(size=130, widget_location="bottom")

## Question: What is a good bin size?

In [72]:
# Read two tables from the database.
# NOTE(review): the frame is named `stimulus` but the table requested is
# "feature_eda"; the aggregation further down only picks up "rating" and
# "temperature" columns. Confirm this is the intended table - the cell
# execution counts are out of order, so saved outputs may come from an
# earlier state of this cell.
with DatabaseManager() as db:
    stimulus = db.get_table("feature_eda")
    trials = db.get_table("trials")  # get trials for stimulus seeds


# Merge both frames on the shared trial keys and drop columns that are not
# needed here; strict=False tolerates listed columns that are absent.
df = merge_data_dfs(
    [stimulus, trials],
    merge_on=["participant_id", "trial_id", "trial_number"],
).drop("duration", "skin_area", "timestamp_start", "timestamp_end", strict=False)
df

trial_id,trial_number,participant_id,rownumber,timestamp,samplenumber,eda_raw,eda_tonic,eda_phasic,stimulus_seed
u16,u8,u8,f64,f64,f64,f64,f64,f64,u16
1,1,1,37664.142857,294200.0,57896.142857,0.753564,0.752119,0.001445,396
1,1,1,37676.5,294300.0,57908.5,0.753506,0.752138,0.001368,396
1,1,1,37690.875,294400.0,57922.875,0.753413,0.752158,0.001256,396
1,1,1,37704.25,294500.0,57936.25,0.753432,0.752177,0.001255,396
1,1,1,37714.8,294600.0,57946.8,0.75378,0.752193,0.001586,396
1,1,1,37727.642857,294700.0,57959.642857,0.753564,0.752212,0.001352,396
…,…,…,…,…,…,…,…,…,…
332,12,28,355422.0,2.7766e6,467019.0,13.60289,13.578916,-0.065532,133
332,12,28,355433.555556,2.7767e6,467030.555556,13.624445,13.578906,-0.048941,133
332,12,28,355446.1,2.7768e6,467043.1,13.644242,13.5789,-0.034242,133


In [61]:
# Signal columns that belong to each modality, keyed by lower-case modality name.
modality_map = dict(stimulus=["rating", "temperature"])


def zero_based_timestamps(df: pl.DataFrame) -> pl.DataFrame:
    """Add a ``zeroed_timestamp`` column that restarts at zero for every trial.

    The smallest ``timestamp`` within each ``trial_id`` is subtracted from all
    timestamps of that trial. All existing columns are kept as they are.
    """
    trial_start = col("timestamp").min().over("trial_id")
    elapsed = col("timestamp") - trial_start
    return df.with_columns(zeroed_timestamp=elapsed)


def aggregate_over_seeds(
    df: pl.DataFrame,
    modality: str = "stimulus",
    bin_size: float = 1,
) -> pl.DataFrame:
    """Average the signals of a modality per stimulus seed and time bin.

    Timestamps are first made relative to the start of their trial, then all
    samples that share a ``stimulus_seed`` are pooled and binned with
    ``group_by_dynamic``.

    Without ``group_by_dynamic`` this would be roughly equivalent to adding
    ``(col("zeroed_timestamp") // 1000).cast(pl.Int32).alias("time_bin")`` and
    grouping by ``["stimulus_seed", "time_bin"]``.

    Args:
        df: Frame with ``timestamp``, ``trial_id`` and ``stimulus_seed``
            columns plus the signal columns of the modality.
        modality: Key of ``modality_map`` (case-insensitive). Signals that are
            not present in ``df`` are skipped.
        bin_size: Width of a time bin (seconds, assuming the helper column is
            in microseconds). Fractional values are allowed.

    Returns:
        One row per ``stimulus_seed`` and ``time_bin`` with ``avg_<signal>``,
        ``std_<signal>`` and ``sample_size`` columns, sorted by seed and bin.

    Raises:
        ValueError: If ``bin_size`` is smaller than one microsecond (this
            includes zero and negative values).
        KeyError: If ``modality`` is not a key of ``modality_map``.
    """
    # group_by_dynamic takes an integer window ("<n>i") on an integer index
    # column. round() instead of int() so floating-point error in a fractional
    # bin size cannot truncate the window by one microsecond.
    bin_width_us = round(bin_size * 1_000_000)
    if bin_width_us < 1:
        raise ValueError(f"bin_size must be at least 1e-6, got {bin_size!r}")

    # Select signals for the given modality
    modality = modality.lower()
    signals = [signal for signal in modality_map[modality] if signal in df.columns]
    if not signals:
        logger.warning(
            "None of the %s signals %s are columns of the input; "
            "only sample sizes will be aggregated.",
            modality,
            modality_map[modality],
        )

    # Zero-based timestamp per trial (milliseconds per the original note)
    df = zero_based_timestamps(df)
    # Add microsecond timestamp column for better precision as group_by_dynamic uses int
    df = add_timestamp_μs_column(df, "zeroed_timestamp")
    # NOTE(review): this literal must match the column name emitted by the
    # helper exactly - "µ" and "μ" are distinct code points in strings even
    # though Python treats them as equal in identifiers.
    time_col = "zeroed_timestamp_µs"

    aggregations = (
        # Average and standard deviation for each signal
        [col(signal).mean().alias(f"avg_{signal.lower()}") for signal in signals]
        + [col(signal).std().alias(f"std_{signal.lower()}") for signal in signals]
        # Sample size for each bin
        # TODO find out why the sample sizes are not constant
        + [pl.len().alias("sample_size")]
    )

    return (
        df.sort(time_col)
        .group_by_dynamic(
            time_col,
            every=f"{bin_width_us}i",
            group_by=["stimulus_seed"],
        )
        .agg(aggregations)
        .with_columns((col(time_col) / 1_000_000).alias("time_bin"))
        .sort("stimulus_seed", "time_bin")
        # remove measures at exactly 180s so that they don't get their own bin
        .filter(col("time_bin") < 180)
        .drop(time_col)
    )


def add_confidence_interval(
    df: pl.DataFrame,
    modality: str = "stimulus",
) -> pl.DataFrame:
    """Append 95 % confidence bounds for every averaged signal of a modality.

    For each signal the bounds are ``avg ± 1.96 * std / sqrt(sample_size)``
    (normal approximation of the standard error of the mean).

    Args:
        df: Aggregated frame containing ``avg_<signal>``, ``std_<signal>``,
            ``sample_size``, ``stimulus_seed`` and ``time_bin`` columns.
        modality: Key of ``modality_map`` (case-insensitive).

    Returns:
        ``df`` with ``ci_lower_<signal>`` and ``ci_upper_<signal>`` columns
        added (all lower bounds first, then all upper bounds), sorted by
        ``stimulus_seed`` and ``time_bin``.
    """
    z_95 = 1.96  # two-sided 95 % quantile of the standard normal distribution
    # Lower-case names to match the column aliases used when aggregating.
    signals = [signal.lower() for signal in modality_map[modality.lower()]]

    def margin(signal: str) -> pl.Expr:
        # Half-width of the interval: z * standard error of the mean.
        return z_95 * (col(f"std_{signal}") / col("sample_size").sqrt())

    lower_bounds = [
        (col(f"avg_{signal}") - margin(signal)).alias(f"ci_lower_{signal}")
        for signal in signals
    ]
    upper_bounds = [
        (col(f"avg_{signal}") + margin(signal)).alias(f"ci_upper_{signal}")
        for signal in signals
    ]
    # Operate on the frame passed in, not on a notebook-global variable.
    return df.with_columns(lower_bounds + upper_bounds).sort(
        "stimulus_seed", "time_bin"
    )


# Aggregate the merged frame with bin_size=1, then attach confidence bounds.
result = aggregate_over_seeds(df, bin_size=1)
# NOTE(review): this bare expression is not the last statement of the cell,
# so with default notebook settings it displays nothing.
result
conf = add_confidence_interval(result)
# Last expression of the cell: rendered as the output table.
conf

stimulus_seed,avg_rating,avg_temperature,std_rating,std_temperature,sample_size,time_bin,ci_lower_rating,ci_lower_temperature,ci_upper_rating,ci_upper_temperature
u16,f64,f64,f64,f64,u32,f64,f64,f64,f64,f64
133,0.432017,0.007027,0.260652,0.006643,282,0.0,0.401594,0.006252,0.462439,0.007803
133,0.580223,0.052573,0.219553,0.01979,280,1.0,0.554506,0.050255,0.60594,0.054891
133,0.619137,0.141308,0.224974,0.031273,281,2.0,0.592832,0.137652,0.645442,0.144965
133,0.657968,0.265075,0.218392,0.039988,283,3.0,0.632523,0.260416,0.683413,0.269734
133,0.697452,0.412534,0.204328,0.045167,287,4.0,0.673812,0.407309,0.721092,0.41776
133,0.736053,0.567752,0.180708,0.044726,285,5.0,0.715072,0.562559,0.757033,0.572944
…,…,…,…,…,…,…,…,…,…,…
952,0.147058,0.212601,0.212265,0.016055,280,174.0,0.122195,0.21072,0.171921,0.214481
952,0.120062,0.161466,0.196615,0.013631,280,175.0,0.097033,0.15987,0.143092,0.163063
952,0.103728,0.119274,0.185324,0.01087,280,176.0,0.08202,0.118001,0.125435,0.120547


In [70]:
variables = ["rating", "temperature"]

# Mean curves, one frame per stimulus seed (selected with a widget).
# The `.hvplot` accessor registered by `import hvplot.polars` is used for both
# layers: polars' built-in `DataFrame.plot` is not backed by hvPlot in recent
# releases, so mixing it with `.hvplot.area` below is fragile.
all_plots = conf.hvplot(
    x="time_bin",
    y=[f"avg_{var}" for var in variables],
    groupby="stimulus_seed",
    kind="line",
    grid=True,
)

# Overlay the confidence band of every variable on top of the mean curves.
for var in variables:
    all_plots *= conf.hvplot.area(
        x="time_bin",
        y=f"ci_lower_{var}",
        y2=f"ci_upper_{var}",
        groupby="stimulus_seed",
        alpha=0.5,
        line_width=0,
        fill_color="lightblue",
        grid=True,
        ylim=(0, 1),
    )

all_plots


BokehModel(combine_events=True, render_bundle={'docs_json': {'2b89e9a5-f732-4b2c-884d-a704f3c7b5f7': {'version…