In [1]:
%%capture
from pathlib import Path

if Path.cwd().stem == "notebooks":
    %cd ..
    %load_ext autoreload
    %autoreload 2

In [2]:
import logging
from pathlib import Path

import holoviews as hv
import hvplot.polars  # noqa
import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from polars import col

from src.data.database_manager import DatabaseManager
from src.features.resampling import (
    add_normalized_timestamp,
)
from src.features.scaling import scale_min_max
from src.log_config import configure_logging
from src.models.data_loader import transform_sample_df_to_arrays
from src.models.sample_creation import create_samples, make_sample_set_balanced
from src.plots.utils import prepare_multiline_hvplot

# Project logging at DEBUG level; the listed libraries are passed as
# `ignore_libs` (presumably kept out of the stream handler — see src.log_config).
configure_logging(
    ignore_libs=["matplotlib", "Comm", "bokeh", "tornado"],
    stream_level=logging.DEBUG,
)

# Table previews show 12 rows, matching the 12 trials per participant.
pl.Config.set_tbl_rows(12)
# HoloViews: larger figures, groupby widgets placed below the plot.
hv.output(size=130, widget_location="bottom")

# Shared database handle used by the loading cells below.
db = DatabaseManager()

In [3]:
# Trial-level metadata. `db` is the DatabaseManager created in the setup cell;
# creating a second instance here was redundant.
# exclude_problematic uses a bool, consistent with the Model_Data load below.
with db:
    trials = db.get_trials("Trials", exclude_problematic=True)
    inv_trials = db.get_table("Invalid_Trials")

trials


trial_id,trial_number,participant_id,stimulus_seed,skin_patch,timestamp_start,timestamp_end,duration
u16,u8,u8,u16,u8,f64,f64,f64
1,1,1,870,1,200176.946,380187.1239,180010.1779
2,2,1,681,2,412801.8887,592813.1696,180011.2809
3,3,1,265,3,626262.7175,806289.7444,180027.0269
4,4,1,396,4,834066.1436,1.0141e6,180012.663
5,5,1,743,5,1.0434e6,1.2234e6,180028.4845
6,6,1,658,6,1.2588e6,1.4388e6,180027.9697
…,…,…,…,…,…,…,…
491,11,41,243,5,2.3274e6,2.5074e6,180010.7594
492,12,41,133,6,2.5373e6,2.7174e6,180012.6872
493,1,42,243,6,266684.9433,446695.4535,180010.5102


In [4]:
# First row of every trial in Model_Data. The frame is loaded in this cell because
# this cell runs before the sample-creation cell; without this load it raises
# NameError on a fresh kernel (Restart & Run All).
with db:
    df = db.get_trials("Model_Data", exclude_problematic=True)

# group_by does not guarantee row order, so sort for a reproducible table.
df.group_by("trial_id").first().sort("trial_id")

In [None]:
# Load the modelling dataset with problematic trials excluded.
with db:
    df = db.get_trials("Model_Data", exclude_problematic=True)

# Sample type -> interval column in `df` that marks where such samples come from.
intervals = dict(
    increases="strictly_increasing_intervals",
    plateaus="plateau_intervals",
    decreases="major_decreasing_intervals",
)
# Binary target: decreases are class 1, increases and plateaus class 0.
label_mapping = dict(increases=0, plateaus=0, decreases=1)
# Per-type offset in milliseconds (how create_samples applies it: see src.models.sample_creation).
offsets_ms = dict(increases=0, decreases=1000, plateaus=5000)

# Length of every sample window in milliseconds.
sample_duration_ms = 7000

samples = create_samples(
    df,
    intervals,
    label_mapping,
    sample_duration_ms,
    offsets_ms,
).select(
    [
        "sample_id",
        "trial_id",
        "participant_id",
        "normalized_timestamp",
        "timestamp",
        "rating",
        "temperature",
        "eda_raw",
        "pupil",
        "cheek_raise",
        "label",
    ]
)

# Number of samples per label (one label taken from each sample's first row).
(
    samples.group_by("sample_id")
    .agg(pl.col("label").first())
    .group_by("label")
    .len()
    .sort("label")
)

14:43:10 | [36mDEBUG   [0m| sample_creation | Removed 749 samples with less than 69 data points


label,len
u8,u32
0,2235
1,1413


In [19]:
samples.group_by("sample_id").agg(pl.col("label").first()).group_by("label").len().sort("label")

label,len
u8,u32
0,1822
1,1368


In [47]:
# Inspect the loaded Model_Data frame (polars truncates the rendered table).
df

trial_id,trial_number,participant_id,timestamp,temperature,rating,eda_raw,ppg_raw,heart_rate,ibi,pupil_l_raw,pupil_r_raw,pupil_r,pupil_l,pupil,brow_furrow,cheek_raise,mouth_open,upper_lip_raise,nose_wrinkle,normalized_timestamp,stimulus_seed,skin_patch,decreasing_intervals,major_decreasing_intervals,increasing_intervals,strictly_increasing_intervals,strictly_increasing_intervals_without_plateaus,plateau_intervals,prolonged_minima_intervals
u16,u8,u8,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,u16,u8,u16,u16,u16,u16,u16,u16,u16
1,1,1,216743.5436,0.0,0.51125,33.476543,1392.210959,106.043289,-4.649092,5.280281,5.411819,5.313233,5.161567,5.2374,0.000002,0.001009,0.344061,6.2573e-9,0.000004,0.0,265,1,0,0,1,1,1,0,0
1,1,1,216843.5436,0.000144,0.503768,33.482077,1390.172786,105.734251,0.809401,5.285799,5.357385,5.358735,5.193335,5.276035,0.000002,0.000993,0.338073,6.1435e-9,0.000004,100.0,265,1,0,0,1,1,1,0,0
1,1,1,216943.5436,0.000627,0.499372,33.51028,1371.116608,104.402571,27.412256,5.251862,5.318561,5.378839,5.214746,5.296793,0.000002,0.000955,0.324101,5.8773e-9,0.000005,200.0,265,1,0,0,1,1,1,0,0
1,1,1,217043.5436,0.00149,0.497817,33.540616,1276.424473,105.238931,34.487759,5.206454,5.299811,5.366948,5.230644,5.298796,0.000001,0.00089,0.30134,5.4432e-9,0.000005,300.0,265,1,0,0,1,1,1,0,0
1,1,1,217143.5436,0.002739,0.4975,33.564648,1323.643234,103.603785,-11.190426,5.167074,5.293698,5.370986,5.230511,5.300749,0.000001,0.000828,0.279653,5.0370e-9,0.000006,400.0,265,1,0,0,1,1,1,0,0
1,1,1,217243.5436,0.004337,0.4975,33.608385,1549.396125,104.577336,-5.950501,5.166472,5.289672,5.370156,5.226913,5.298535,0.000001,0.000773,0.26018,4.6649e-9,0.000007,500.0,265,1,0,0,1,1,1,0,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
500,12,42,2.9169e6,0.343664,0.0,5.63217,1303.314885,69.363974,-0.566383,2.874716,3.138802,3.105452,2.865239,2.985345,0.540668,0.000837,0.075903,0.002281,0.168981,179500.0,243,6,2405,0,0,0,0,0,0
500,12,42,2.9170e6,0.343051,0.0,5.630226,1261.295081,68.311165,-1.120452,2.740845,3.053513,3.125596,2.877898,3.001747,0.552525,0.000909,0.075434,0.002196,0.179887,179600.0,243,6,2405,0,0,0,0,0,0
500,12,42,2.9171e6,0.342598,0.0,5.628944,1289.453577,67.471458,-1.083037,2.385958,2.843083,3.12297,2.871397,2.997183,0.569679,0.000929,0.074396,0.002602,0.187344,179700.0,243,6,2405,0,0,0,0,0,0


In [49]:
# Raw signals of one trial at a time (slider over trial_id), on a shared y-axis.
df.hvplot.line(
    x="normalized_timestamp",
    y=["rating", "temperature", "eda_raw", "pupil", "cheek_raise"],
    groupby="trial_id",
)

BokehModel(combine_events=True, render_bundle={'docs_json': {'08e95d58-0d3c-45e0-8f82-b85e6219cfe7': {'version…

In [50]:
# Number of rows (time points) in every sample.
samples.group_by("sample_id").agg(pl.len().alias("count"))

sample_id,count
u16,u32
2,70
4,70
5,70
9,70
10,70
12,70
…,…
3843,70
3844,70
3845,70


In [51]:
# Balance the label distribution of the sample set (resulting counts shown below).
# NOTE: this rebinds `samples`, so the counts displayed by earlier cells are stale.
samples = make_sample_set_balanced(samples)

# One row per sample: its id and the label of its first row. Only `label` is
# aggregated instead of every column, since the rest would be discarded anyway.
sample_ids = samples.group_by("sample_id").agg(pl.col("label").first())
# value_counts gives no guaranteed row order; sort so the table is reproducible.
sample_ids_count = sample_ids.get_column("label").value_counts().sort("label")
sample_ids_count

label,count
u8,u32
0,1244
1,1244


In [38]:
# Full temperature curve per trial, with the time points that ended up in a
# sample overlaid as red dots.
trial_curve = df.hvplot.line(
    x="normalized_timestamp",
    y="temperature",
    groupby="trial_id",
    height=300,
)
sample_points = samples.hvplot.scatter(
    x="normalized_timestamp",
    y="temperature",
    groupby="trial_id",
    height=300,
    color="red",
)
trial_curve * sample_points

BokehModel(combine_events=True, render_bundle={'docs_json': {'44850aea-4106-4e79-9cae-74f2185afc40': {'version…

In [52]:
# Min-max scaled signals, one sample at a time (slider over sample_id).
scale_min_max(samples).hvplot.line(
    x="normalized_timestamp",
    y=["rating", "eda_raw", "pupil"],
    groupby="sample_id",
)

BokehModel(combine_events=True, render_bundle={'docs_json': {'b0352983-7c48-448e-b54c-af3bb35639cf': {'version…

In [53]:
# Inspect the balanced sample set (one row per time point, keyed by sample_id).
samples

sample_id,trial_id,participant_id,normalized_timestamp,timestamp,rating,temperature,eda_raw,pupil,cheek_raise,label
u16,u16,u8,f64,f64,f64,f64,f64,f64,f64,u8
2,1,1,27000.0,243743.5436,0.0,0.17205,37.453392,4.593487,0.000147,0
2,1,1,27100.0,243843.5436,0.0,0.172108,37.424159,4.594653,0.000144,0
2,1,1,27200.0,243943.5436,0.0,0.172292,37.400388,4.594157,0.000142,0
2,1,1,27300.0,244043.5436,0.0,0.172613,37.380407,4.594548,0.00014,0
2,1,1,27400.0,244143.5436,0.0,0.173067,37.338049,4.594175,0.000137,0
2,1,1,27500.0,244243.5436,0.0,0.173661,37.310459,4.595849,0.000137,0
…,…,…,…,…,…,…,…,…,…,…
3848,500,42,138400.0,2.8758e6,0.569916,0.716249,6.049321,3.742309,0.0004,1
3848,500,42,138500.0,2.8759e6,0.539784,0.709564,6.048434,3.722801,0.000398,1
3848,500,42,138600.0,2.8760e6,0.48749,0.702579,6.046803,3.690881,0.00041,1


In [54]:
# Convert samples to model arrays; "temperature" is included only for visualization.
X, y, groups = transform_sample_df_to_arrays(
    samples,
    ["temperature", "rating", "eda_raw"],
)


In [56]:
# Features converted to arrays; "temperature" is included only for visualization.
feature_names = [
    "temperature",
    "rating",
    "eda_raw",
]
X, y, groups = transform_sample_df_to_arrays(samples, feature_names)

# plot_sample indexes X as X[sample, time, channel] and labels each channel with
# feature_names, so the channel axis must line up with the requested features.
assert X.shape[2] == len(feature_names), (X.shape, feature_names)


@ipywidgets.interact(sample=(0, X.shape[0] - 1))
def plot_sample(sample):
    """Plot all feature channels of one sample, with a legend.

    `sample` is a row index into X / y / groups, not the `sample_id` column.
    """
    fig, ax = plt.subplots()
    for i, name in enumerate(feature_names):
        ax.plot(X[sample, :, i], label=name)
    # NOTE(review): assumes y holds the class label and groups the grouping id
    # (e.g. participant) per sample; confirm against transform_sample_df_to_arrays.
    ax.set_title(f"Sample {sample} - {groups[sample]} - label {y[sample]}")
    ax.set_xlabel("time step (index)")
    ax.legend()
    # plt.ylim(0, 1.05)

interactive(children=(IntSlider(value=1243, description='sample', max=2487), Output()), _dom_classes=('widget-…

In [42]:
# Overlay all samples of one label (selected with the widget). Timestamps are
# re-normalized per sample_id so every sample starts at the same origin, and the
# frame is then passed through prepare_multiline_hvplot (presumably so that each
# sample is drawn as a separate line; see src.plots.utils).
prepare_multiline_hvplot(
    add_normalized_timestamp(
        samples,
        time_column="normalized_timestamp",
        trial_column="sample_id",
    ),
    time_column="normalized_timestamp",
    trial_column="sample_id",
).hvplot(
    x="normalized_timestamp",
    y=[
        "rating",
        "eda_raw",
        "temperature",
    ],
    groupby="label",
    height=300,
    ylim=(0, 1.05),
    # One color per y column; with only two colors the cycle repeated and
    # "temperature" was drawn in the same color as "rating".
    color=["blue", "orange", "green"],
)

BokehModel(combine_events=True, render_bundle={'docs_json': {'0ac30dab-c5b5-4cab-b085-e72567ba3141': {'version…

In [43]:
# Same per-label overlay as above, but showing pupil instead of rating.
prepare_multiline_hvplot(
    add_normalized_timestamp(
        samples,
        time_column="normalized_timestamp",
        trial_column="sample_id",
    ),
    time_column="normalized_timestamp",
    trial_column="sample_id",
).hvplot.line(
    x="normalized_timestamp",
    y=["eda_raw", "pupil", "temperature"],
    groupby="label",
    height=300,
    ylim=(0, 1.05),
)

BokehModel(combine_events=True, render_bundle={'docs_json': {'039e5e64-0e11-4fba-9296-ea9a5e3eba8b': {'version…