In [177]:
%%capture
from pathlib import Path

if Path.cwd().stem == "features":
    %cd ../..
    %load_ext autoreload
    %autoreload 2

In [178]:
import logging

import holoviews as hv
import hvplot.pandas  # noqa
import polars as pl
from icecream import ic
from polars import col

from src.data.database_manager import DatabaseManager
from src.data.quality_checks import check_sample_rate
from src.experiments.measurement.stimulus_generator import StimulusGenerator
from src.features.resampling import add_time_column, downsample
from src.features.scaling import scale_min_max
from src.features.transforming import map_trials
from src.log_config import configure_logging
from src.plots.plot_stimulus import plot_stimulus_with_shapes

configure_logging(
    stream_level=logging.DEBUG,
    ignore_libs=["matplotlib", "Comm", "bokeh", "tornado"],
)

pl.Config.set_tbl_rows(12)  # for the 12 trials
hv.output(widget_location="bottom", size=130)

In [179]:
stim = StimulusGenerator(seed=9)
plot_stimulus_with_shapes(stim)

In [180]:
stim.labels

{'decreasing_intervals': [(9000, 20000),
  (26000, 46000),
  (76000, 93000),
  (120000, 140000),
  (160000, 180000)],
 'major_decreasing_intervals': [(26000, 46000),
  (120000, 140000),
  (160000, 180000)],
 'increasing_intervals': [(0, 9000),
  (20000, 26000),
  (46000, 76000),
  (93000, 120000),
  (145000, 160000)],
 'plateau_intervals': [(54600, 69600), (97000, 112000)],
 'prolonged_minima_intervals': [(139900, 144900)]}

In [181]:
db = DatabaseManager()

In [182]:
with db:
    stimulus = db.get_table("Feature_Stimulus")
    trials = db.get_table("Trials")


In [183]:
# Get label intervals for all stimulus seeds
seeds = trials.get_column("stimulus_seed").unique()
labels = {seed: StimulusGenerator(seed=seed).labels for seed in seeds}

labels

{133: {'decreasing_intervals': [(10000, 30000),
   (65000, 85000),
   (99000, 116000),
   (147000, 167000),
   (173000, 180000)],
  'major_decreasing_intervals': [(10000, 30000),
   (65000, 85000),
   (147000, 167000)],
  'increasing_intervals': [(0, 10000),
   (35000, 65000),
   (85000, 99000),
   (116000, 132000),
   (167000, 173000)],
  'plateau_intervals': [(35200, 50200), (114300, 129300)],
  'prolonged_minima_intervals': [(29900, 34900)]},
 243: {'decreasing_intervals': [(20000, 37000),
   (47000, 67000),
   (99000, 119000),
   (131000, 151000),
   (170000, 180000)],
  'major_decreasing_intervals': [(47000, 67000),
   (99000, 119000),
   (131000, 151000)],
  'increasing_intervals': [(0, 20000),
   (37000, 47000),
   (67000, 99000),
   (124000, 131000),
   (151000, 170000)],
  'plateau_intervals': [(2600, 17600), (72800, 87800)],
  'prolonged_minima_intervals': [(119000, 124000)]},
 265: {'decreasing_intervals': [(9000, 27000),
   (44000, 64000),
   (92000, 112000),
   (147000, 16

In [184]:
from src.features.transforming import merge_data_dfs

df = merge_data_dfs(
    [stimulus, trials], merge_on=["trial_id", "participant_id", "trial_number"]
)
df

df = df.with_columns(
    [
        (col("timestamp") - col("timestamp").min().over("trial_id")).alias(
            "normalized_timestamp"
        )
    ]
).drop("duration", "timestamp", "timestamp_end", "timestamp_start")

df

trial_id,trial_number,participant_id,rownumber,temperature,rating,stimulus_seed,skin_area,normalized_timestamp
u16,u8,u8,u32,f64,f64,u16,u8,f64
1,1,1,0,0.0,0.425,396,1,0.0
1,1,1,1,0.000069,0.425,396,1,133.6335
1,1,1,2,0.000277,0.35375,396,1,233.6982
1,1,1,3,0.000622,0.14875,396,1,334.2696
1,1,1,4,0.001106,0.10125,396,1,434.0044
1,1,1,5,0.001728,0.2275,396,1,534.1647
…,…,…,…,…,…,…,…,…
332,12,28,21606,0.158607,0.85,133,1,179498.3054
332,12,28,21607,0.157223,0.85,133,1,179600.0331
332,12,28,21608,0.156232,0.85,133,1,179697.772


In [185]:
mdi = labels[396]["major_decreasing_intervals"][0]  # TODO: do this for all intervals

df.with_columns(
    pl.when(col("normalized_timestamp").is_between(mdi[0], mdi[1]))
    .then(1)
    .otherwise(0)
    .alias("is_mdi")
).filter(col("is_mdi") == 1)

trial_id,trial_number,participant_id,rownumber,temperature,rating,stimulus_seed,skin_area,normalized_timestamp,is_mdi
u16,u8,u8,u32,f64,f64,u16,u8,f64,i32
1,1,1,690,0.967948,1.0,396,1,69008.667,1
1,1,1,691,0.96789,1.0,396,1,69107.4036,1
1,1,1,692,0.967715,1.0,396,1,69208.1334,1
1,1,1,693,0.967423,1.0,396,1,69309.865,1
1,1,1,694,0.967015,1.0,396,1,69409.595,1
1,1,1,695,0.96649,1.0,396,1,69510.841,1
…,…,…,…,…,…,…,…,…,…
332,12,28,20696,0.142945,0.43625,133,1,88472.3373,1
332,12,28,20697,0.147973,0.43625,133,1,88572.0722,1
332,12,28,20698,0.15311,0.43625,133,1,88672.8021,1


In [186]:
import operator
from functools import reduce


def get_mask(
    group: pl.DataFrame,
    label_name: str,
) -> pl.Series:
    # Create a mask for each segment and combine them with an OR
    # without reduce and operator.or_ , one would have to start with an already
    # initialized neutral element, which is not possible in the lambda function
    return reduce(
        operator.or_,
        [
            group["normalized_timestamp"].is_between(start, end)
            for start, end in labels[group.get_column("stimulus_seed").unique().item()][
                label_name
            ]
        ],
    )


def label_intervals(
    group: pl.DataFrame,
    labels: dict[int, dict[str, list[tuple[float, float]]]],
) -> pl.DataFrame:
    stimulus_seed = group.get_column("stimulus_seed").unique().item()
    label_names = labels[stimulus_seed].keys()

    return group.with_columns(
        [
            pl.when(get_mask(group, label_name))
            .then(1)
            .otherwise(0)
            .alias(
                label_name,
            )
            for label_name in label_names
        ]
    )


df = df.group_by("stimulus_seed", maintain_order=True).map_groups(
    lambda group: label_intervals(group, labels)
)
df

trial_id,trial_number,participant_id,rownumber,temperature,rating,stimulus_seed,skin_area,normalized_timestamp,decreasing_intervals,major_decreasing_intervals,increasing_intervals,plateau_intervals,prolonged_minima_intervals
u16,u8,u8,u32,f64,f64,u16,u8,f64,i32,i32,i32,i32,i32
1,1,1,0,0.0,0.425,396,1,0.0,0,0,1,0,0
1,1,1,1,0.000069,0.425,396,1,133.6335,0,0,1,0,0
1,1,1,2,0.000277,0.35375,396,1,233.6982,0,0,1,0,0
1,1,1,3,0.000622,0.14875,396,1,334.2696,0,0,1,0,0
1,1,1,4,0.001106,0.10125,396,1,434.0044,0,0,1,0,0
1,1,1,5,0.001728,0.2275,396,1,534.1647,0,0,1,0,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…
324,4,28,7198,0.343688,0.82625,243,3,179481.6145,1,0,0,0,0
324,4,28,7199,0.343059,0.81125,243,3,179581.348,1,0,0,0,0
324,4,28,7200,0.342608,0.7875,243,3,179682.0784,1,0,0,0,0


In [191]:
result = (
    df.with_columns(
        [
            (
                pl.col(label_name)
                * (
                    (pl.col(label_name).diff().fill_null(0) == 1)
                    | (pl.col(label_name).cum_count() == 0)
                )
            )
            .cum_sum()
            .alias("temp_" + label_name)
            for label_name in labels[list(labels)[0]].keys()
        ]  # converting to list to get the first key
    )
    .with_columns(
        [
            pl.when(pl.col(label_name) == 1)
            .then(pl.col("temp_" + label_name))
            .otherwise(0)
            .alias(label_name)
            for label_name in labels[list(labels)[0]].keys()
        ]
    )
    .drop(pl.col(r"^temp_.*$"))
)

result.describe()

statistic,trial_id,trial_number,participant_id,rownumber,temperature,rating,stimulus_seed,skin_area,normalized_timestamp,decreasing_intervals,major_decreasing_intervals,increasing_intervals,plateau_intervals,prolonged_minima_intervals
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""count""",597860.0,597860.0,597860.0,597860.0,597860.0,597860.0,597860.0,597860.0,597860.0,597860.0,597860.0,597860.0,597860.0,597860.0
"""null_count""",0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",166.515554,6.451813,14.627821,10717.490066,0.471194,0.459627,587.292353,3.51204,89983.635147,383.942632,166.090499,395.838455,55.406249,4.624839
"""std""",95.834757,3.442407,8.042692,6220.8608,0.284453,0.371246,265.414013,1.710133,51984.231159,526.233148,287.672814,533.754567,146.548219,31.682827
"""min""",1.0,1.0,1.0,0.0,0.0,0.0,133.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
"""25%""",84.0,3.0,8.0,5338.0,0.267112,0.05875,396.0,2.0,44972.597,0.0,0.0,0.0,0.0,0.0
"""50%""",167.0,6.0,15.0,10676.0,0.436876,0.45625,658.0,4.0,89976.1188,0.0,0.0,0.0,0.0,0.0
"""75%""",250.0,9.0,22.0,16073.0,0.698026,0.82,806.0,5.0,134995.1967,756.0,249.0,807.0,0.0,0.0
"""max""",332.0,12.0,28.0,21611.0,1.0,1.0,952.0,6.0,180005.3386,1660.0,996.0,1659.0,664.0,332.0


In [82]:
labels

{133: {'decreasing_intervals': [(10000, 30000),
   (65000, 85000),
   (99000, 116000),
   (147000, 167000),
   (173000, 180000)],
  'major_decreasing_intervals': [(10000, 30000),
   (65000, 85000),
   (147000, 167000)],
  'increasing_intervals': [(0, 10000),
   (35000, 65000),
   (85000, 99000),
   (116000, 132000),
   (167000, 173000)],
  'plateau_intervals': [(35200, 50200), (114300, 129300)],
  'prolonged_minima_intervals': [(29900, 34900)]},
 243: {'decreasing_intervals': [(20000, 37000),
   (47000, 67000),
   (99000, 119000),
   (131000, 151000),
   (170000, 180000)],
  'major_decreasing_intervals': [(47000, 67000),
   (99000, 119000),
   (131000, 151000)],
  'increasing_intervals': [(0, 20000),
   (37000, 47000),
   (67000, 99000),
   (124000, 131000),
   (151000, 170000)],
  'plateau_intervals': [(2600, 17600), (72800, 87800)],
  'prolonged_minima_intervals': [(119000, 124000)]},
 265: {'decreasing_intervals': [(9000, 27000),
   (44000, 64000),
   (92000, 112000),
   (147000, 16

In [124]:
import polars as pl

# Assuming your data is in a DataFrame called 'df' with a column named 'intervals'
df = pl.DataFrame(
    {"intervals": [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0]}
)

# Step 1: Create a column that marks the start of each interval
df = df.with_columns(
    [(pl.col("intervals").diff().fill_null(0) == 1).alias("interval_start")]
)

# Step 2: Assign a cumulative sum to the interval starts
df = df.with_columns([pl.col("interval_start").cum_sum().alias("interval_number")])

# Step 3: Fill the interval numbers within each interval
df = df.with_columns(
    [
        pl.when(pl.col("intervals") == 1)
        .then(pl.col("interval_number"))
        .otherwise(0)
        .alias("result")
    ]
)

print(df)


shape: (19, 4)
┌───────────┬────────────────┬─────────────────┬────────┐
│ intervals ┆ interval_start ┆ interval_number ┆ result │
│ ---       ┆ ---            ┆ ---             ┆ ---    │
│ i64       ┆ bool           ┆ u32             ┆ u32    │
╞═══════════╪════════════════╪═════════════════╪════════╡
│ 0         ┆ false          ┆ 0               ┆ 0      │
│ 0         ┆ false          ┆ 0               ┆ 0      │
│ 0         ┆ false          ┆ 0               ┆ 0      │
│ 0         ┆ false          ┆ 0               ┆ 0      │
│ 1         ┆ true           ┆ 1               ┆ 1      │
│ 1         ┆ false          ┆ 1               ┆ 1      │
│ …         ┆ …              ┆ …               ┆ …      │
│ 0         ┆ false          ┆ 2               ┆ 0      │
│ 0         ┆ false          ┆ 2               ┆ 0      │
│ 0         ┆ false          ┆ 2               ┆ 0      │
│ 1         ┆ true           ┆ 3               ┆ 3      │
│ 1         ┆ false          ┆ 3               ┆ 3      │

In [127]:
import polars as pl

# Assuming your data is in a DataFrame called 'df' with a column named 'intervals'
df = pl.DataFrame(
    {"intervals": [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0]}
)

result = df.with_columns(
    [
        pl.when(pl.col("intervals") == 1)
        .then(pl.col("intervals").diff().fill_null(0).cum_sum())
        .otherwise(0)
        .alias("result")
    ]
)

print(result)


shape: (19, 2)
┌───────────┬────────┐
│ intervals ┆ result │
│ ---       ┆ ---    │
│ i64       ┆ i64    │
╞═══════════╪════════╡
│ 0         ┆ 0      │
│ 0         ┆ 0      │
│ 0         ┆ 0      │
│ 0         ┆ 0      │
│ 1         ┆ 1      │
│ 1         ┆ 1      │
│ …         ┆ …      │
│ 0         ┆ 0      │
│ 0         ┆ 0      │
│ 0         ┆ 0      │
│ 1         ┆ 1      │
│ 1         ┆ 1      │
│ 0         ┆ 0      │
└───────────┴────────┘


In [131]:
import polars as pl

# Assuming your data is in a DataFrame called 'df' with a column named 'intervals'
df = pl.DataFrame(
    {"intervals": [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0]}
)

result = df.with_columns(
    [
        (
            pl.col("intervals")
            * (
                (pl.col("intervals").diff().fill_null(0) == 1)
                | (pl.col("intervals").cum_count() == 0)
            )
        )
        .cum_sum()
        .alias("result")
    ]
)

print(result)


shape: (19, 2)
┌───────────┬────────┐
│ intervals ┆ result │
│ ---       ┆ ---    │
│ i64       ┆ i64    │
╞═══════════╪════════╡
│ 0         ┆ 0      │
│ 0         ┆ 0      │
│ 0         ┆ 0      │
│ 0         ┆ 0      │
│ 1         ┆ 1      │
│ 1         ┆ 1      │
│ …         ┆ …      │
│ 0         ┆ 2      │
│ 0         ┆ 2      │
│ 0         ┆ 2      │
│ 1         ┆ 3      │
│ 1         ┆ 3      │
│ 0         ┆ 3      │
└───────────┴────────┘


In [136]:
import polars as pl

# Assuming your data is in a DataFrame called 'df' with a column named 'intervals'
df = pl.DataFrame(
    {"intervals": [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0]}
)

result = df.with_columns(
    (
        pl.col("intervals")
        * (
            (pl.col("intervals").diff().fill_null(0) == 1)
            | (pl.col("intervals").cum_count() == 0)
        )
    )
    .cum_sum()
    .alias("temp")
).with_columns(
    pl.when(pl.col("intervals") == 1).then(pl.col("temp")).otherwise(0).alias("result")
)

print(result)


shape: (19, 3)
┌───────────┬──────┬────────┐
│ intervals ┆ temp ┆ result │
│ ---       ┆ ---  ┆ ---    │
│ i64       ┆ i64  ┆ i64    │
╞═══════════╪══════╪════════╡
│ 0         ┆ 0    ┆ 0      │
│ 0         ┆ 0    ┆ 0      │
│ 0         ┆ 0    ┆ 0      │
│ 0         ┆ 0    ┆ 0      │
│ 1         ┆ 1    ┆ 1      │
│ 1         ┆ 1    ┆ 1      │
│ …         ┆ …    ┆ …      │
│ 0         ┆ 2    ┆ 0      │
│ 0         ┆ 2    ┆ 0      │
│ 0         ┆ 2    ┆ 0      │
│ 1         ┆ 3    ┆ 3      │
│ 1         ┆ 3    ┆ 3      │
│ 0         ┆ 3    ┆ 0      │
└───────────┴──────┴────────┘
