In [2]:
%%capture
# Notebook bootstrap. %%capture swallows whatever the magics below print.
# When the kernel's working directory is the "notebooks" folder, step one level
# up (presumably to the project root so that `src.*` imports resolve -- TODO
# confirm) and switch on autoreload.
# After `%cd ..` the guard is normally false, so re-running this cell does not
# climb a second directory level.
from pathlib import Path

if Path.cwd().stem == "notebooks":
    %cd ..
    # autoreload mode 2: re-import modified modules before each cell execution.
    %load_ext autoreload
    %autoreload 2

In [3]:
import logging

import holoviews as hv
import hvplot.polars  # noqa
import matplotlib.pyplot as plt
import polars as pl
from polars import col

from src.data.database_manager import DatabaseManager
from src.log_config import configure_logging
from src.visualization.utils import prepare_multiline_hvplot

# Configure project logging with a DEBUG stream level. The names in
# `ignore_libs` are passed through to the project's logging setup (presumably
# to mute those libraries -- verify in src.log_config).
configure_logging(
    stream_level=logging.DEBUG,
    ignore_libs=("Comm", "bokeh", "tornado"),
)
# Name the logger after the final dotted component of the module name only.
logger = logging.getLogger(__name__.rpartition(".")[2])

# Notebook-wide HoloViews output options: widgets below the plot, size 130.
hv.output(widget_location="bottom", size=130)

In [3]:
# Project database handle; it is entered as a context manager (`with db:`)
# around the query that loads the data.
db = DatabaseManager()

In [4]:
# Load every stimulus sample that lies inside one of a trial's
# "major decreasing intervals", using three SQL steps:
#   1. temp_joined: join Trials, Seeds and Raw_Stimulus, and express each sample
#      timestamp relative to the trial start (`normalized_timestamp`).
#      Participant 5 is excluded.
#   2. tmp: unnest `major_decreasing_intervals` into one row per (trial, interval),
#      number the intervals both per trial and across all trials, keep the
#      samples whose `normalized_timestamp` lies within [interval[1], interval[2]]
#      (presumably start/end in the same unit as the timestamps -- verify against
#      the Seeds table), and compute `normalized_time` as the offset from the
#      first sample of the interval.
#   3. Final SELECT: return `tmp` without the list column, ordered by
#      participant, trial and time.
# Only the result of the last statement is converted to a polars DataFrame.
query = """
-- Collect all relevant information in a temporary table
CREATE OR REPLACE TEMPORARY TABLE temp_joined AS
SELECT 
    t.trial_id,
    t.stimulus_seed,
    t.participant_id,
    s.major_decreasing_intervals,
    rs.timestamp-t.timestamp_start as normalized_timestamp,
    rs.temperature,
    rs.rating
FROM 
    Trials t
    JOIN Seeds s ON t.stimulus_seed = s.seed
    JOIN Raw_Stimulus rs ON t.trial_id = rs.trial_id
WHERE t.participant_id != 5 -- exclude participant 5  (incomplete data, TODO)
ORDER BY t.trial_id, normalized_timestamp;


-- Create a temporary table with interval IDs and normalized normalized_timestamp
CREATE OR REPLACE TEMPORARY TABLE tmp AS
WITH interval_ids AS (
  SELECT 
    *,
    ROW_NUMBER() OVER (PARTITION BY trial_id ORDER BY interval[1]) AS trial_specific_interval_id,
    ROW_NUMBER() OVER (ORDER BY trial_id, interval[1]) AS continuous_interval_id
  FROM (
    SELECT DISTINCT trial_id, unnest(major_decreasing_intervals) AS interval
    FROM temp_joined
  ) t
),
intervals_with_start_time AS (
  SELECT 
    tj.*,
    i.trial_specific_interval_id,
    i.continuous_interval_id,
    FIRST_VALUE(tj.normalized_timestamp) OVER (
      PARTITION BY tj.trial_id, i.continuous_interval_id 
      ORDER BY tj.normalized_timestamp
    ) AS interval_start_time
  FROM temp_joined tj
  JOIN interval_ids i ON 
    tj.trial_id = i.trial_id AND
    tj.normalized_timestamp >= i.interval[1] AND 
    tj.normalized_timestamp <= i.interval[2]
)
SELECT 
  *,
  normalized_timestamp - interval_start_time AS normalized_time  -- rename stuff TODO
FROM intervals_with_start_time
ORDER BY participant_id, trial_id, normalized_timestamp;

-- Query from the temporary table
SELECT * EXCLUDE (major_decreasing_intervals) FROM tmp
ORDER BY participant_id, trial_id, normalized_timestamp;
"""
with db:
    df = db.execute(query).pl()
df

AttributeError: 'DatabaseManager' object has no attribute 'conn'

In [None]:
# Reshape `df` for multi-line plotting with "normalized_time" on the x axis and
# one line per "continuous_interval_id" (presumably the helper separates the
# intervals so they are not drawn as one connected line -- verify in
# src.visualization.utils).
# NOTE(review): this rebinds `df` to the helper's output, so the cell is not
# idempotent; re-run the query cell before running this one again.
df = prepare_multiline_hvplot(df, "normalized_time", "continuous_interval_id")

In [None]:
# Step plot of `rating` over `normalized_time`, grouped by participant; the
# scrubber widget below the plot selects the participant_id to show.
df.hvplot(
    x="normalized_time",
    y=["rating"],
    ylim=(0, 100),  # fixed y range so all participants share the same 0-100 axis
    groupby="participant_id",
    kind="step",
    width=800,
    height=400,
    widget_type="scrubber",
    widget_location="bottom",
)

BokehModel(combine_events=True, render_bundle={'docs_json': {'8b62c744-c134-42d1-aafc-297167a8911d': {'version…

In [None]:
# Step plot of `rating` over `normalized_time`, grouped by stimulus seed and by
# the interval's position within its trial, so one widget selection shows the
# rows that share a seed and a per-trial interval number.
df.hvplot(
    x="normalized_time",
    y=["rating"],
    ylim=(0, 100),  # fixed y range, 0-100
    groupby=["stimulus_seed", "trial_specific_interval_id"],
    kind="step",
)

BokehModel(combine_events=True, render_bundle={'docs_json': {'20569c8e-82e2-4503-ae53-4933a5c34517': {'version…

In [None]:
# Display the current contents of `df`; the rendered table is truncated to the
# first and last rows.
df

trial_id,stimulus_seed,participant_id,normalized_timestamp,temperature,rating,trial_specific_interval_id,continuous_interval_id,interval_start_time,normalized_time
u16,u16,u16,f64,f64,f64,i64,i64,f64,f64
1,396,1,69035.6035,47.184339,100.0,1,1,69035.6035,0.0
1,396,1,69134.3401,47.184253,100.0,1,1,69035.6035,98.7366
1,396,1,69235.0699,47.183993,100.0,1,1,69035.6035,199.4664
1,396,1,69336.8015,47.183561,100.0,1,1,69035.6035,301.198
1,396,1,69436.5315,47.182956,100.0,1,1,69035.6035,400.928
…,…,…,…,…,…,…,…,…,…
332,133,28,166615.0306,45.459782,60.0,3,960,147032.6267,19582.4039
332,133,28,166714.7639,45.45935,60.0,3,960,147032.6267,19682.1372
332,133,28,166815.4953,45.459091,60.0,3,960,147032.6267,19782.8686
332,133,28,166915.2289,45.459004,60.0,3,960,147032.6267,19882.6022


In [None]:
# Keep only the intervals whose ratings cover the whole scale, i.e. whose
# minimum rating is 0 and whose maximum rating is 100.
# The per-interval min/max is computed and filtered in a single chain, so
# `agg_ratings` is bound once, directly to the filtered result.
agg_ratings = (
    df.group_by("continuous_interval_id", maintain_order=True)
    .agg(
        min_rating=col("rating").min(),
        max_rating=col("rating").max(),
    )
    .filter((col("min_rating") == 0) & (col("max_rating") == 100))
)
agg_ratings

continuous_interval_id,min_rating,max_rating
i64,f64,f64
1,0.0,100.0
2,0.0,100.0
4,0.0,100.0
5,0.0,100.0
6,0.0,100.0
…,…,…
888,0.0,100.0
919,0.0,100.0
939,0.0,100.0
945,0.0,100.0
