In [2]:
%%capture
# %%capture silences this cell's output (e.g. the directory echoed by %cd).
from pathlib import Path

# Only when the kernel's working directory is named "features": move two
# directory levels up (presumably to the project root — confirm) and enable
# autoreload so edits to imported modules are picked up without a restart.
# NOTE(review): autoreload is only loaded inside this branch, so it stays off
# when the notebook is started from any other directory.
if Path.cwd().stem == "features":
    %cd ../..
    %load_ext autoreload
    %autoreload 2

In [3]:
import logging
import operator
from functools import reduce

import holoviews as hv
import hvplot.pandas  # noqa
import polars as pl
from icecream import ic
from polars import col

from src.data.database_manager import DatabaseManager
from src.data.quality_checks import check_sample_rate
from src.experiments.measurement.stimulus_generator import StimulusGenerator
from src.features.resampling import add_time_column, downsample
from src.features.scaling import scale_min_max
from src.features.transforming import map_trials, merge_data_dfs
from src.log_config import configure_logging
from src.plots.plot_stimulus import plot_stimulus_with_shapes

# Configure project logging with stream_level=DEBUG. The listed library names
# are passed as `ignore_libs` (presumably to mute their loggers — confirm in
# src.log_config).
configure_logging(
    stream_level=logging.DEBUG,
    ignore_libs=["matplotlib", "Comm", "bokeh", "tornado"],
)

# Let Polars print up to 12 table rows, i.e. one row per trial.
pl.Config.set_tbl_rows(12)  # for the 12 trials
# HoloViews output settings: widget location "bottom" and size 130.
hv.output(widget_location="bottom", size=130)

In [4]:
# Build a stimulus from a fixed seed and hand it to the project's plotting
# helper; the plot is this cell's output.
stim = StimulusGenerator(seed=9)
plot_stimulus_with_shapes(stim)

In [5]:
# Database handle; it is used as a context manager (`with db:`) when the
# tables are read.
db = DatabaseManager()

In [12]:
# Read the two tables needed for labelling while the database context is open.
with db:
    stimulus = db.get_table("Feature_Stimulus")
    trials = db.get_table("Trials")

In [17]:
def create_labels_df(stimulus: pl.DataFrame, trials: pl.DataFrame) -> pl.DataFrame:
    """Merge stimulus and trial data and attach the stimulus label columns.

    The two frames are merged on trial_id, participant_id and trial_number.
    A `normalized_timestamp` column (timestamp minus the smallest timestamp of
    the same trial) is added, the columns duration, timestamp_end and
    timestamp_start are dropped, and the result is handed to `process_labels`.
    """
    join_keys = ["trial_id", "participant_id", "trial_number"]
    merged = merge_data_dfs([stimulus, trials], merge_on=join_keys)

    # Shift every trial so that its earliest timestamp becomes zero.
    trial_start = col("timestamp").min().over("trial_id")
    elapsed = (col("timestamp") - trial_start).alias("normalized_timestamp")

    trimmed = merged.with_columns(elapsed).drop(
        "duration", "timestamp_end", "timestamp_start"
    )
    return process_labels(trimmed)


def process_labels(df: pl.DataFrame) -> pl.DataFrame:
    """Add the per-seed label columns, number the intervals and sort the rows.

    Label intervals are looked up once for all stimulus seeds in `df`, applied
    to each `stimulus_seed` group, numbered, and the result is sorted by
    trial_id and timestamp.
    """
    seed_labels = _get_label_intervals(df)

    def _label_group(seed_group: pl.DataFrame) -> pl.DataFrame:
        # Every group holds the rows of exactly one stimulus seed.
        return label_intervals(seed_group, seed_labels)

    by_seed = df.group_by("stimulus_seed", maintain_order=True)
    labelled = by_seed.map_groups(_label_group)
    numbered = number_intervals(labelled, seed_labels)
    return numbered.sort("trial_id", "timestamp")


def _get_label_intervals(df):
    """Map every distinct stimulus seed in `df` to the labels of its generator."""
    labels_by_seed = {}
    for seed in df.get_column("stimulus_seed").unique():
        # One generator per seed; only its `labels` attribute is kept.
        labels_by_seed[seed] = StimulusGenerator(seed=seed).labels
    return labels_by_seed


def _get_mask(
    group: pl.DataFrame,
    labels: dict[int, dict[str, list[tuple[float, float]]]],
    label_name: str,
) -> pl.Series:
    """Return a boolean mask of the rows that lie inside any interval of a label.

    Args:
        group: Rows sharing a single `stimulus_seed`; must contain the columns
            `stimulus_seed` and `normalized_timestamp`.
        labels: Mapping of stimulus seed -> label name -> list of
            (start, end) intervals.
        label_name: Name of the label whose intervals are evaluated.

    Returns:
        A boolean Series with one entry per row of `group`. An entry is True
        when `normalized_timestamp` falls into at least one interval (bounds
        are inclusive, the `is_between` default). If the label has no
        intervals, the mask is all False.

    Raises:
        ValueError: Raised by `.item()` when `group` does not contain exactly
            one distinct stimulus seed.
    """
    seed = group.get_column("stimulus_seed").unique().item()
    timestamps = group["normalized_timestamp"]
    interval_masks = [
        timestamps.is_between(start, end) for start, end in labels[seed][label_name]
    ]
    if not interval_masks:
        # reduce() without an initial value raises TypeError on an empty
        # sequence, so a label without intervals gets an explicit all-False
        # mask instead of crashing the whole labelling step.
        return pl.Series(timestamps.name, [False] * group.height, dtype=pl.Boolean)
    # OR the per-interval masks together: a row is labelled if it is in any one.
    return reduce(operator.or_, interval_masks)


def label_intervals(
    group: pl.DataFrame,
    labels: dict[int, dict[str, list[tuple[float, float]]]],
) -> pl.DataFrame:
    """Add one 0/1 column per label: 1 where the row lies inside an interval.

    `group` must hold the rows of a single stimulus seed; the label names are
    taken from that seed's entry in `labels`.
    """
    seed = group.get_column("stimulus_seed").unique().item()

    indicator_columns = []
    for name in labels[seed].keys():
        inside = _get_mask(group, labels, name)
        indicator_columns.append(pl.when(inside).then(1).otherwise(0).alias(name))

    return group.with_columns(indicator_columns)


def number_intervals(df, labels):
    """Give each interval a unique, consecutive number.

    Every label column is expected to hold 0/1 indicators. Each run of
    consecutive 1s is replaced by its running index (1, 2, 3, ...) in the
    current row order of `df`; rows outside an interval stay 0.

    Args:
        df: Frame containing one 0/1 column per label name.
        labels: Mapping of stimulus seed -> label name -> intervals. Only the
            label names of the first entry are used, so all seeds are expected
            to share the same label names.

    Returns:
        `df` with every label column replaced by its interval numbers. `df` is
        returned unchanged when `labels` is empty.
    """
    if not labels:
        # No seeds means no label columns to number.
        return df
    label_names = list(next(iter(labels.values())))

    # NOTE(review): runs are counted over the frame as a whole, not per trial,
    # so a run that continues across a trial boundary is counted as a single
    # interval — confirm this is intended.
    numbered_columns = []
    for label_name in label_names:
        flag = pl.col(label_name)
        # A run starts where the flag rises from 0 to 1. The first row has no
        # predecessor (its diff is null), so it starts a run exactly when its
        # own flag is 1. This replaces the former `cum_count() == 0` test,
        # which only matches the first row on Polars versions whose cum_count
        # is zero-based.
        starts_run = flag.diff().fill_null(flag) == 1
        run_number = (flag * starts_run).cum_sum()
        numbered_columns.append(
            pl.when(flag == 1).then(run_number).otherwise(0).alias(label_name)
        )
    return df.with_columns(numbered_columns)


# Bind the result to `df` so the later cell that displays `df` works after
# Restart & Run All instead of relying on leftover kernel state.
df = create_labels_df(stimulus, trials)
df

ok


trial_id,trial_number,participant_id,rownumber,timestamp,temperature,rating,stimulus_seed,skin_area,normalized_timestamp,decreasing_intervals,major_decreasing_intervals,increasing_intervals,plateau_intervals,prolonged_minima_intervals
u16,u8,u8,u32,f64,f64,f64,u16,u8,f64,i32,i32,i32,i32,i32
1,1,1,0,294224.331,0.0,0.425,396,1,0.0,0,0,0,0,0
1,1,1,1,294357.9645,0.000069,0.425,396,1,133.6335,0,0,0,0,0
1,1,1,2,294458.0292,0.000277,0.35375,396,1,233.6982,0,0,0,0,0
1,1,1,3,294558.6006,0.000622,0.14875,396,1,334.2696,0,0,0,0,0
1,1,1,4,294658.3354,0.001106,0.10125,396,1,434.0044,0,0,0,0,0
1,1,1,5,294758.4957,0.001728,0.2275,396,1,534.1647,0,0,0,0,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
324,4,28,7198,1.0582e6,0.343688,0.82625,243,3,179481.6145,1660,0,0,0,0
324,4,28,7199,1.0583e6,0.343059,0.81125,243,3,179581.348,1660,0,0,0,0
324,4,28,7200,1.0584e6,0.342608,0.7875,243,3,179682.0784,1660,0,0,0,0


In [9]:
# Display the labelled frame; `df` is expected to hold the result of
# create_labels_df(stimulus, trials).
# NOTE(review): make sure `df` is assigned before this cell on a fresh kernel.
df

trial_id,trial_number,participant_id,rownumber,timestamp,temperature,rating,stimulus_seed,skin_area,normalized_timestamp,decreasing_intervals,major_decreasing_intervals,increasing_intervals,plateau_intervals,prolonged_minima_intervals
u16,u8,u8,u32,f64,f64,f64,u16,u8,f64,i32,i32,i32,i32,i32
1,1,1,0,294224.331,0.0,0.425,396,1,0.0,0,0,0,0,0
1,1,1,1,294357.9645,0.000069,0.425,396,1,133.6335,0,0,0,0,0
1,1,1,2,294458.0292,0.000277,0.35375,396,1,233.6982,0,0,0,0,0
1,1,1,3,294558.6006,0.000622,0.14875,396,1,334.2696,0,0,0,0,0
1,1,1,4,294658.3354,0.001106,0.10125,396,1,434.0044,0,0,0,0,0
1,1,1,5,294758.4957,0.001728,0.2275,396,1,534.1647,0,0,0,0,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
324,4,28,7198,1.0582e6,0.343688,0.82625,243,3,179481.6145,1660,0,0,0,0
324,4,28,7199,1.0583e6,0.343059,0.81125,243,3,179581.348,1660,0,0,0,0
324,4,28,7200,1.0584e6,0.342608,0.7875,243,3,179682.0784,1660,0,0,0,0
