In [2]:
%%capture
from pathlib import Path

if Path.cwd().stem == "models":
    %cd ../..
    %load_ext autoreload
    %autoreload 2

In [3]:
import logging
from pathlib import Path

import holoviews as hv
import hvplot.polars  # noqa
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from icecream import ic
from polars import col
from sklearn.model_selection import train_test_split

from src.data.database_manager import DatabaseManager
from src.features.utils import to_describe
from src.log_config import configure_logging

configure_logging(
    stream_level=logging.DEBUG,
    ignore_libs=["matplotlib", "Comm", "bokeh", "tornado"],
)

pl.Config.set_tbl_rows(12)  # for the 12 trials
hv.output(widget_location="bottom", size=130)

db = DatabaseManager()

In [4]:
with db:
    labels = db.get_table("Labels")
    eda = db.get_table("Feature_EDA")

In [5]:
eda.hvplot(x="timestamp", y=["eda_raw"], groupby="trial_id", width=800, height=400)

BokehModel(combine_events=True, render_bundle={'docs_json': {'6a52465e-d1d3-4ff9-bad9-3d1f6ea63b9f': {'version…

In [6]:
labels = labels.with_columns(
    # Add time counter for decreases and strictly increases
    (
        pl.when(col("strictly_increasing_intervals") != 0)
        .then(
            col("timestamp")
            - col("timestamp").min().over("strictly_increasing_intervals")
        )
        .otherwise(None)
    ).alias("normalized_timestamp_increases"),
    (
        pl.when(col("decreasing_intervals") != 0)
        .then(col("timestamp") - col("timestamp").min().over("decreasing_intervals"))
        .otherwise(None)
    ).alias("normalized_timestamp_decreases"),
    # Only keep the first 5 seconds
).filter(
    (col("normalized_timestamp_increases") < 5000)
    | (col("normalized_timestamp_decreases") < 5000)
)
labels

trial_id,trial_number,participant_id,rownumber,timestamp,temperature,rating,stimulus_seed,skin_area,normalized_timestamp,decreasing_intervals,major_decreasing_intervals,increasing_intervals,plateau_intervals,prolonged_minima_intervals,strictly_increasing_intervals,normalized_timestamp_increases,normalized_timestamp_decreases
u16,u8,u8,u32,f64,f64,f64,u16,u8,f64,u16,u16,u16,u16,u16,u16,f64,f64
1,1,1,320,326250.8398,0.80056,0.83375,396,1,32026.5088,1,0,0,0,0,0,,0.0
1,1,1,321,326351.5703,0.800524,0.83375,396,1,32127.2393,1,0,0,0,0,0,,100.7305
1,1,1,322,326453.3964,0.800415,0.8375,396,1,32229.0654,1,0,0,0,0,0,,202.5566
1,1,1,323,326551.3829,0.800232,0.83875,396,1,32327.0519,1,0,0,0,0,0,,300.5431
1,1,1,324,326651.1161,0.799977,0.83875,396,1,32426.7851,1,0,0,0,0,0,,400.2763
1,1,1,325,326751.5133,0.799649,0.83875,396,1,32527.1823,1,0,0,0,0,0,,500.6735
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
332,12,28,21586,2.7746e6,0.25889,0.80625,133,1,177496.8039,1660,0,0,0,0,0,,4403.8851
332,12,28,21587,2.7747e6,0.251239,0.80625,133,1,177596.5374,1660,0,0,0,0,0,,4503.6186
332,12,28,21588,2.7748e6,0.243786,0.81375,133,1,177698.2663,1660,0,0,0,0,0,,4605.3475


In [7]:
# Split data into decreasing and increasing intervals to add labels and sample ids
decreases = labels.filter(
    col("normalized_timestamp_decreases").is_not_null()
).with_columns(
    pl.lit(1).alias("label").cast(pl.UInt8),
    col("decreasing_intervals").alias("sample_id"),
)

In [9]:
increases = labels.filter(
    col("normalized_timestamp_increases").is_not_null()
).with_columns(
    pl.lit(0).alias("label").cast(pl.UInt8),
    (
        col("strictly_increasing_intervals")
        + (decreases.select(pl.last("decreasing_intervals")))  # continue from decreases
    ).alias("sample_id"),
)

In [10]:
# Join the two tables
labels = decreases.vstack(increases).sort("sample_id", "timestamp")

trial_id,trial_number,participant_id,rownumber,timestamp,temperature,rating,stimulus_seed,skin_area,normalized_timestamp,decreasing_intervals,major_decreasing_intervals,increasing_intervals,plateau_intervals,prolonged_minima_intervals,strictly_increasing_intervals,normalized_timestamp_increases,normalized_timestamp_decreases,label,sample_id
u16,u8,u8,u32,f64,f64,f64,u16,u8,f64,u16,u16,u16,u16,u16,u16,f64,f64,u8,u16
1,1,1,320,326250.8398,0.80056,0.83375,396,1,32026.5088,1,0,0,0,0,0,,0.0,1,1
1,1,1,321,326351.5703,0.800524,0.83375,396,1,32127.2393,1,0,0,0,0,0,,100.7305,1,1
1,1,1,322,326453.3964,0.800415,0.8375,396,1,32229.0654,1,0,0,0,0,0,,202.5566,1,1
1,1,1,323,326551.3829,0.800232,0.83875,396,1,32327.0519,1,0,0,0,0,0,,300.5431,1,1
1,1,1,324,326651.1161,0.799977,0.83875,396,1,32426.7851,1,0,0,0,0,0,,400.2763,1,1
1,1,1,325,326751.5133,0.799649,0.83875,396,1,32527.1823,1,0,0,0,0,0,,500.6735,1,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
332,12,28,21526,2.7686e6,0.476868,0.66375,133,1,171491.0208,0,0,1660,0,0,996,4403.2538,,0,2656
332,12,28,21527,2.7687e6,0.485034,0.6675,133,1,171591.7536,0,0,1660,0,0,996,4503.9866,,0,2656
332,12,28,21528,2.7688e6,0.492691,0.6675,133,1,171691.484,0,0,1660,0,0,996,4603.717,,0,2656


In [13]:
# Normalize the data
ROWS_PER_SAMPLE = 50

labels = (
    labels.sort(["sample_id"])  # Sort within each group if needed
    .group_by("sample_id", maintain_order=True)
    .agg(pl.all().head(ROWS_PER_SAMPLE))
    .explode(pl.all().exclude("sample_id"))  # Explode the result back into rows
)

labels.select(pl.last("sample_id")).item() * ROWS_PER_SAMPLE, labels.height

(132800, 132796)

In [26]:
# Sanity check
if not labels.height == labels.select(pl.last("sample_id")).item() * ROWS_PER_SAMPLE:
    affected_samples = []
    for sample_id, group in labels.group_by("sample_id", maintain_order=True):
        if group.height < ROWS_PER_SAMPLE or group.height > ROWS_PER_SAMPLE:
            affected_samples.append(sample_id[0])  # sample_id is a tuple
logging.debug(
    f"Normalizing to equal {ROWS_PER_SAMPLE} rows per sample was not successful for the following samples: {affected_samples}"
)
labels = labels.filter(~col("sample_id").is_in(affected_samples))

[1310, 2143, 2428, 2641]

In [146]:
train, test = train_test_split(labels, test_size=0.2, random_state=42)
train

trial_id,trial_number,participant_id,rownumber,timestamp,temperature,rating,stimulus_seed,skin_area,normalized_timestamp,decreasing_intervals,major_decreasing_intervals,increasing_intervals,plateau_intervals,prolonged_minima_intervals,strictly_increasing_intervals,normalized_timestamp_increases,normalized_timestamp_decreases
u16,u8,u8,u32,f64,f64,f64,u16,u8,f64,i64,i64,i64,i64,i64,u32,f64,f64
37,1,4,378,304543.7522,0.21797,0.05375,243,6,37830.8841,0,0,182,0,0,109,800.5433,
197,9,17,14948,2.0773e6,0.395006,0.135,396,3,53994.742,0,0,982,0,0,589,3901.7761,
139,11,12,19603,2.4931e6,0.949331,0.37375,841,2,159298.8712,0,0,695,0,0,417,4201.6864,
18,6,2,10741,1.5562e6,0.515107,0.1275,133,1,174111.4672,90,0,0,0,0,0,,1101.7058
209,9,18,15013,2.1266e6,0.949337,0.85625,658,4,60500.0553,1042,626,0,0,0,0,,2499.9926
308,12,26,20091,2.7847e6,0.17883,0.0,265,1,27970.6565,0,0,1537,0,0,923,900.7303,
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
137,9,12,15741,2.0576e6,0.969615,1.0,243,4,133277.7672,684,411,0,0,0,0,,2202.6687
275,3,24,4581,829854.0288,0.867462,1.0,681,4,97898.3963,1373,824,0,0,0,0,,4804.4061
299,3,26,4614,782041.3907,0.972046,1.0,396,4,101184.5357,1493,896,0,0,0,0,,2101.5981
