In [60]:
from pathlib import Path

if Path.cwd().stem == "features":
    %cd ../..
    %load_ext autoreload
    %autoreload 2

In [61]:
import logging
import os
from dataclasses import dataclass
from functools import reduce, wraps
from pathlib import Path
from typing import Dict, List

import holoviews as hv
import hvplot.polars
import matplotlib.pyplot as plt
import neurokit2 as nk
import numpy as np
import pandas as pd
import panel as pn
import plotly.express as px
import polars as pl

from src.data.config_data import DataConfigBase
from src.data.config_data_raw import RAW_DICT, RAW_LIST, RawConfig
from src.data.config_participant import PARTICIPANT_LIST, ParticipantConfig
from src.data.make_dataset import load_dataset
from src.features.quality_checks import check_sample_rate
from src.features.transformations import (
    add_timedelta_column,
    interpolate,
    map_participant_datasets,
    map_trials,
    merge_dfs,
    scale_min_max,
    scale_standard,
)
from src.log_config import configure_logging
from src.visualization.plot_data import (
    plot_data_panel,
    plot_trial_matplotlib,
    plot_trial_plotly,
)

configure_logging(
    stream_level=logging.DEBUG,
    ignore_libs=["matplotlib", "Comm", "bokeh", "tornado", "param"],
)

hv.extension("plotly")

pl.Config.set_tbl_rows(7)  # don't print too many rows in the book
plt.rcParams["figure.figsize"] = [15, 5]  # default is [6, 4]

In [62]:
participant_number = 0
modality = "pupillometry"
data_config = RAW_DICT[modality]
sampling_rate = data_config.sampling_rate

pupillometry_raw = load_dataset(
    PARTICIPANT_LIST[participant_number], RAW_DICT[modality]
).dataset
stimulus = load_dataset(
    PARTICIPANT_LIST[participant_number], RAW_DICT["stimulus"]
).dataset


# pupillometry_raw = pupillometry_raw.unique('Timestamp').sort('Timestamp') # actually slightly faster than maintain_order=True
# logging.warning("Working with unique timestamps.")

check_sample_rate(pupillometry_raw)
check_sample_rate(pupillometry_raw, unique_timestamp=True)
check_sample_rate(stimulus, unique_timestamp=False)

15:15:15 | [36mDEBUG   [0m| make_dataset | Dataset 'pupillometry' for participant 0 loaded from data/raw/0/0_pupillometry.csv
15:15:15 | [36mDEBUG   [0m| make_dataset | Dataset 'stimulus' for participant 0 loaded from data/raw/0/0_stimulus.csv
15:15:15 | [36mDEBUG   [0m| quality_checks | Sample rate per trial: [59.91 59.94 59.93 59.93 59.94 59.93 59.92 59.93 59.93 59.93 59.92 59.93]
15:15:15 | [92mINFO    [0m| quality_checks | The mean sample rate is 59.93.
15:15:15 | [92mINFO    [0m| quality_checks | Checking sample rate for unique timestamps.
15:15:15 | [36mDEBUG   [0m| quality_checks | Sample rate per trial: [59.91 59.94 59.93 59.93 59.94 59.93 59.92 59.93 59.93 59.93 59.92 59.93]
15:15:15 | [92mINFO    [0m| quality_checks | The mean sample rate is 59.93.
15:15:15 | [36mDEBUG   [0m| quality_checks | Sample rate per trial: [10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10.]
15:15:15 | [92mINFO    [0m| quality_checks | The mean sample rate is 10.00.


In [63]:
pupillometry_raw

Timestamp,Pupillometry_L,Pupillometry_R,Pupillometry_L_Distance,Pupillometry_R_Distance,Trial
f64,f64,f64,f64,f64,f64
214220.8889,3.738112,3.884729,611.875427,617.908997,0.0
214240.9759,3.733135,3.88679,611.932007,617.877991,0.0
214257.6095,3.735349,3.880681,612.07782,617.924622,0.0
214274.3519,3.738612,3.882757,611.967957,617.791199,0.0
…,…,…,…,…,…
3.3631e6,2.987041,2.960266,638.944824,640.338257,11.0
3.3631e6,2.981889,2.961481,638.932922,640.33728,11.0
3.3631e6,2.984752,2.963059,638.932922,640.33728,11.0


In [64]:
features = ["Pupillometry_L", "Pupillometry_R"]
pupillometry_raw.hvplot(
    x="Timestamp", y=features, groupby="Trial", kind="line", width=800, height=400
)

BokehModel(combine_events=True, render_bundle={'docs_json': {'877e6034-a890-4176-b41a-7765ccac0f00': {'version…

In [65]:
pupillometry_raw_trials = [
    group.select("Pupillometry_L", "Pupillometry_R").to_numpy().flatten()
    for _, group in pupillometry_raw.group_by(["Trial"])
]
pupillometry_raw_trial = pupillometry_raw_trials[0]

"""Same as:
# Get each trial as a separate np.array
groups = pupillometry_raw.group_by(["Trial"], maintain_order=True)
pupillometry_raw_trials = []
for _, group in groups:
    pupillometry_raw_trials.append(group.select('pupillometry_RAW').to_numpy().flatten())
""";

In [66]:
plot_trial_plotly(
    pupillometry_raw.drop("Pupillometry_L_Distance", "Pupillometry_R_Distance"), trial=1
)

In [67]:
pupillometry = pupillometry_raw.with_columns(
    [
        pl.when(pl.col("Pupillometry_L") == -1)
        .then(None)
        .otherwise(pl.col("Pupillometry_L"))
        .alias("Pupillometry_L"),
        pl.when(pl.col("Pupillometry_R") == -1)
        .then(None)
        .otherwise(pl.col("Pupillometry_R"))
        .alias("Pupillometry_R"),
    ]
)

In [68]:
pupillometry.drop("Pupillometry_L_Distance", "Pupillometry_R_Distance").plot(
    x="Timestamp", y=["Pupillometry_L", "Pupillometry_R"]
)

In [69]:
plot_trial_plotly(
    pupillometry.drop("Pupillometry_L_Distance", "Pupillometry_R_Distance"), 4
)

## Plot

In [70]:
merged = merge_dfs(pupillometry_raw, stimulus)
merged = scale_min_max(merged)
merged = interpolate(merged)
merged

AttributeError: 'Series' object has no attribute 'join'

In [None]:
plot_data_panel(merged.drop("Pupillometry_L_Distance", "Pupillometry_R_Distance"))

INFO:bokeh.server.server:Starting Bokeh server version 3.3.3 (running on Tornado 6.3.3)
INFO:bokeh.server.tornado:User authentication hooks NOT provided (default user enabled)
DEBUG:bokeh.server.tornado:These host origins can connect to the websocket: ['localhost:15550']
DEBUG:bokeh.server.tornado:Patterns are:
DEBUG:bokeh.server.tornado:  [('/favicon.ico',
DEBUG:bokeh.server.tornado:    <class 'bokeh.server.views.ico_handler.IcoHandler'>,
DEBUG:bokeh.server.tornado:    {'app': <bokeh.server.tornado.BokehTornado object at 0x141bf1950>}),
DEBUG:bokeh.server.tornado:   ('/?',
DEBUG:bokeh.server.tornado:    <class 'panel.io.server.DocHandler'>,
DEBUG:bokeh.server.tornado:    {'application_context': <bokeh.server.contexts.ApplicationContext object at 0x132de8810>,
DEBUG:bokeh.server.tornado:     'bokeh_websocket_path': '/ws'}),
DEBUG:bokeh.server.tornado:   ('/ws',
DEBUG:bokeh.server.tornado:    <class 'bokeh.server.views.ws.WSHandler'>,
DEBUG:bokeh.server.tornado:    {'application_context

Launching server at http://localhost:15550


INFO:tornado.access:200 GET / (127.0.0.1) 194.35ms
INFO:tornado.access:200 GET /static/extensions/panel/bundled/jquery/jquery.slim.min.js (127.0.0.1) 3.09ms
INFO:tornado.access:200 GET /static/extensions/panel/bundled/plotlyplot/plotly-2.18.0.min.js (127.0.0.1) 9.71ms
INFO:tornado.access:200 GET /static/js/bokeh.min.js?v=f43c49e86dc38c1a13b9f41aad15fb57c3b2f70844817e5559b32d9e0a177c319416281f7bac18181198884ceb3998420b37b2b0199e0d0dc6485e34fc0a28dc (127.0.0.1) 10.42ms
INFO:tornado.access:200 GET /static/js/bokeh-gl.min.js?v=bf37f0b457d54fefb6ca8423c37db6ae69479153907d223a22f57d090b957998e75abda056bf5b0916a24f99930fa6df3b242a1a3a0986b549fbc966c1e04416 (127.0.0.1) 10.63ms
INFO:tornado.access:200 GET /static/js/bokeh-widgets.min.js?v=3c2dbaf226dc96c10bf3dfbcde30557363d2c16ec86bf2a10fb615e53d3971cbcf801e5051aa500292ec49f54812deae2aec9aaad0d97331534c89fe18ede89a (127.0.0.1) 11.52ms
INFO:tornado.access:200 GET /static/js/bokeh-tables.min.js?v=7849f2320ea741465a49857765873105e961ae71f15b481c5c