In [1]:
# Jupyter settings and Imports

# %load_ext autoreload
# %autoreload 2
# %flow mode reactive

from datetime import date
import ipdb
from itertools import product
from pathlib import Path

from dotmap import DotMap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import seaborn as sns

import aeon.io.api as api
from aeon.io import reader
from aeon.schema.dataset import exp02, exp01
from aeon.analysis.utils import visits, distancetravelled

In [2]:
# Get sessions
pd.set_option('display.max_columns', 100)
pd.set_option('display.max_rows', 100)

roots = [Path("/ceph/aeon/aeon/data/raw/AEON3/presocial0.1"), Path("/ceph/aeon/aeon/data/raw/AEON2/presocial0.1")]
if not np.all([path.exists() for path in roots]):
    print("Cannot find root paths. Check path names or connection.")
    
subject_events = api.load(roots, exp02.ExperimentalMetadata.SubjectState)
sessions = visits(subject_events[subject_events.id.str.startswith("BAA-11030")])
sessions = sessions.loc[sessions["id"] != "BAA-1103048"]

In [3]:
# Prettify sessions

pd.options.mode.chained_assignment = None  # turn off "SettingWithCopy" warning for this cell

sessions = sessions[sessions.enter.dt.date >= date(2023, 3, 10)]
sessions = sessions[sessions.enter.dt.date != date(2023, 3, 29)]
sessions = sessions[sessions.duration > pd.Timedelta("1 hour")]
sessions.loc[:, ("weight_enter")] = sessions["weight_enter"].astype(float).round(1)
sessions.loc[:, ("weight_exit")] = sessions["weight_exit"].astype(float).round(1)
sessions.loc[:, ("enter")] = sessions["enter"].dt.floor("1s")
sessions.loc[:, ("exit")] = sessions["exit"].dt.ceil("1s")
sessions.loc[:, ("duration")] = sessions["duration"].round("1s")
sessions = sessions[["id", "enter", "exit", "duration", "weight_enter", "weight_exit"]]
sessions = sessions.sort_values(by="enter")
sessions = sessions.reset_index()
sessions = sessions.drop(columns=["index"])
pd.options.mode.chained_assignment = "warn"
display(sessions)

Unnamed: 0,id,enter,exit,duration,weight_enter,weight_exit
0,BAA-1103050,2023-03-10 09:41:48,2023-03-10 12:55:19,0 days 03:13:30,23.2,23.9
1,BAA-1103045,2023-03-10 12:12:45,2023-03-10 15:22:14,0 days 03:09:28,23.0,23.7
2,BAA-1103047,2023-03-10 15:27:05,2023-03-10 19:10:44,0 days 03:43:38,19.8,21.0
3,BAA-1103049,2023-03-10 16:22:29,2023-03-10 19:21:50,0 days 02:59:20,20.9,22.6
4,BAA-1103044,2023-03-17 14:44:00,2023-03-17 19:15:43,0 days 04:31:42,25.0,23.0
5,BAA-1103045,2023-03-23 10:16:38,2023-03-23 13:19:43,0 days 03:03:03,23.6,25.1
6,BAA-1103049,2023-03-23 11:15:29,2023-03-23 14:23:41,0 days 03:08:11,22.0,24.4
7,BAA-1103047,2023-03-23 13:29:36,2023-03-23 16:32:51,0 days 03:03:14,23.2,22.4
8,BAA-1103050,2023-03-23 14:30:26,2023-03-23 17:29:18,0 days 02:58:51,23.5,24.9
9,BAA-1103050,2023-03-24 09:11:06,2023-03-24 10:52:58,0 days 01:41:51,23.9,25.4


In [4]:
# Get bad sessions based on Get 'DispenserBroken' and 'Annotation' messages
message_log_aeon3 = api.load(str(roots[0]), exp02.ExperimentalMetadata.MessageLog)
print(f"Aeon3 messages:\n")
display(message_log_aeon3[np.logical_or(message_log_aeon3.type == "DispenserBroken", message_log_aeon3.type == "Annotation")])
print(f"\n\n")
message_log_aeon2 = api.load(str(roots[0]), exp02.ExperimentalMetadata.MessageLog)
print(f"Aeon2 messages:\n")
display(message_log_aeon2[np.logical_or(message_log_aeon2.type == "DispenserBroken", message_log_aeon2.type == "Annotation")])

Aeon3 messages:



Unnamed: 0_level_0,priority,type,message
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2023-02-27 12:37:32.836480141,Alert,DispenserBroken,Patch2
2023-02-27 12:38:31.825503826,Alert,DispenserBroken,Patch2
2023-02-27 12:43:57.997504234,Alert,DispenserBroken,Patch2
2023-02-27 12:51:15.807487965,Alert,DispenserBroken,Patch2
2023-02-27 12:55:49.355487823,Alert,DispenserBroken,Patch1
...,...,...,...
2023-06-21 14:40:11.161503792,Alert,DispenserBroken,Patch1
2023-06-21 14:42:05.131487845,Alert,DispenserBroken,Patch1
2023-06-21 14:47:10.425504208,Alert,DispenserBroken,Patch1
2023-06-21 14:56:13.303487778,Alert,DispenserBroken,Patch1





Aeon2 messages:



Unnamed: 0_level_0,priority,type,message
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2023-02-27 12:37:32.836480141,Alert,DispenserBroken,Patch2
2023-02-27 12:38:31.825503826,Alert,DispenserBroken,Patch2
2023-02-27 12:43:57.997504234,Alert,DispenserBroken,Patch2
2023-02-27 12:51:15.807487965,Alert,DispenserBroken,Patch2
2023-02-27 12:55:49.355487823,Alert,DispenserBroken,Patch1
...,...,...,...
2023-06-21 14:40:11.161503792,Alert,DispenserBroken,Patch1
2023-06-21 14:42:05.131487845,Alert,DispenserBroken,Patch1
2023-06-21 14:47:10.425504208,Alert,DispenserBroken,Patch1
2023-06-21 14:56:13.303487778,Alert,DispenserBroken,Patch1


In [5]:
# Based on above, manually decide which are bad sessions, and drop these from `sessions`

bad_sessions = DotMap()
bad_sessions.ids = (
    "BAA-1103048", 
    "BAA-1103044",
    "BAA-1103050",
    "BAA-1103048",
    "BAA-1103050",
    "BAA-1103049",
    "BAA-1103049",
    "BAA-1103045",
    "BAA-1103050",
    "BAA-1103049",
    "BAA-1103050",
    "BAA-1103050",
)
bad_sessions.dates = (
    date(2023, 3, 15),  # bugs in workflow
    date(2023, 3, 17),  # rfid session
    date(2023, 3, 24),  # only stayed on one patch from beginning
    date(2023, 3, 24),  # poop stuck on wheel
    date(2023, 6, 6),   # stayed on one patch (patch2) the entire time
    date(2023, 6, 7),   # stayed on one patch (patch1) the entire time
    date(2023, 6, 12),  # stayed on one patch (patch1) the entire time
    date(2023, 6, 12),  # stayed on one patch (patch1) the entire time
    date(2023, 6, 12),  # bonsai crashed
    date(2023, 6, 15),  # bonsai crashed
    date(2023, 6, 20),  # bonsai crashed
    date(2023, 6, 21),  # bonsai crashed
)

for i in range(len(bad_sessions.ids)):
    i_bad_sesh = np.where(np.logical_and(
        sessions.id == bad_sessions.ids[i], sessions.enter.dt.date == bad_sessions.dates[i]))[0]
    sessions.drop(index=sessions.iloc[i_bad_sesh].index, inplace=True)
sessions = sessions.sort_values(by="enter")
sessions = sessions.reset_index()
sessions = sessions.drop(columns=["index"])

In [6]:
# Declare some set-up variables to help with analysis

# Specify which animals in which room
in_b2_210 = ("48", "49", "50")
in_465 = ("45", "47")

# Columns to add to table
new_cols = (
    "post_thresh_dur", 
    "post_thresh_both_p_sampled_dur",
    "pre_sampling_both_p_dur", 
    "easy_patch", 
    "hard_patch", 
    "post_easy_rate", 
    "post_hard_rate", 
    "pre_easy_n_pel", 
    "pre_hard_n_pel", 
    "post_easy_n_pel", 
    "post_hard_n_pel", 
    "pre_easy_wheel_dist", 
    "pre_hard_wheel_dist",
    "post_easy_wheel_dist", 
    "post_hard_wheel_dist", 
    "pre_easy_pref", 
    "post_easy_pref",
    "pre_hard_pref", 
    "post_hard_pref", 
    "post_pre_easy_pref", 
    "post_easy_pel_thresh", 
    "post_easy_pel_thresh_idx", 
    "post_hard_pel_thresh", 
    "post_hard_pel_thresh_idx", 
    "init_pref_by_pel_ct", 
    "epoch_thresh_change_idx", 
    "easy_pref_epoch_cum", 
    "easy_pref_epoch",
    "cont_patch_pref"
)
for col in new_cols: sessions[col] = np.nan

sessions["post_easy_pel_thresh"] = sessions["post_easy_pel_thresh"].astype(object)
sessions["post_hard_pel_thresh"] = sessions["post_hard_pel_thresh"].astype(object)
sessions["post_easy_pel_thresh_idx"] = sessions["post_easy_pel_thresh_idx"].astype(object)
sessions["post_hard_pel_thresh_idx"] = sessions["post_hard_pel_thresh_idx"].astype(object)
sessions["init_pref_by_pel_ct"] = sessions["init_pref_by_pel_ct"].astype(object)
sessions["easy_pref_epoch_cum"] = sessions["easy_pref_epoch_cum"].astype(object)
sessions["easy_pref_epoch"] = sessions["easy_pref_epoch"].astype(object)
sessions["cont_patch_pref"] = sessions["cont_patch_pref"].astype(object)
display(sessions)

Unnamed: 0,id,enter,exit,duration,weight_enter,weight_exit,post_thresh_dur,post_thresh_both_p_sampled_dur,pre_sampling_both_p_dur,easy_patch,hard_patch,post_easy_rate,post_hard_rate,pre_easy_n_pel,pre_hard_n_pel,post_easy_n_pel,post_hard_n_pel,pre_easy_wheel_dist,pre_hard_wheel_dist,post_easy_wheel_dist,post_hard_wheel_dist,pre_easy_pref,post_easy_pref,pre_hard_pref,post_hard_pref,post_pre_easy_pref,post_easy_pel_thresh,post_easy_pel_thresh_idx,post_hard_pel_thresh,post_hard_pel_thresh_idx,init_pref_by_pel_ct,epoch_thresh_change_idx,easy_pref_epoch_cum,easy_pref_epoch,cont_patch_pref
0,BAA-1103050,2023-03-10 09:41:48,2023-03-10 12:55:19,0 days 03:13:30,23.2,23.9,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
1,BAA-1103045,2023-03-10 12:12:45,2023-03-10 15:22:14,0 days 03:09:28,23.0,23.7,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
2,BAA-1103047,2023-03-10 15:27:05,2023-03-10 19:10:44,0 days 03:43:38,19.8,21.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
3,BAA-1103049,2023-03-10 16:22:29,2023-03-10 19:21:50,0 days 02:59:20,20.9,22.6,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
4,BAA-1103045,2023-03-23 10:16:38,2023-03-23 13:19:43,0 days 03:03:03,23.6,25.1,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
5,BAA-1103049,2023-03-23 11:15:29,2023-03-23 14:23:41,0 days 03:08:11,22.0,24.4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
6,BAA-1103047,2023-03-23 13:29:36,2023-03-23 16:32:51,0 days 03:03:14,23.2,22.4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
7,BAA-1103050,2023-03-23 14:30:26,2023-03-23 17:29:18,0 days 02:58:51,23.5,24.9,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
8,BAA-1103047,2023-03-24 12:08:58,2023-03-24 15:12:50,0 days 03:03:51,22.2,22.3,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
9,BAA-1103049,2023-03-24 14:22:48,2023-03-24 18:31:16,0 days 04:08:27,21.7,22.3,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [7]:
w_chunk_t = 500         # time chunks in samples (1 sample = 2 ms)
w_chunk_dist = 500      # distance chunks in cm
pref_window = 3000      # distance in cm
pref_thresh = 0.80      # preference threshold to check over `pref_window`
n_distris = 100         # number of diff distributions to generate
#for s in sessions.itertuples():
for s in list(sessions.itertuples())[64:]:
    root = str(roots[0]) if np.any([s.id.endswith(sid) for sid in in_b2_210]) else str(roots[1])  # get root for current session
    root = str(roots[1]) if s.enter.date() == date(2023, 6, 22) else root
    harp_reader = reader.Harp(pattern="Patch1_35*", columns=["TriggerPellet"])
    
    
    new_pellet_trig_bitmask = api.load(root, harp_reader, start=s.enter, end=s.exit).iloc[0, 0]
    new_pellet_trig_reader_p1 = reader.BitmaskEvent("Patch1_35*", new_pellet_trig_bitmask, "TriggerPellet")
    new_pellet_trig_reader_p2 = reader.BitmaskEvent("Patch2_35*", new_pellet_trig_bitmask, "TriggerPellet")
    
    
    p1 = api.load(root, new_pellet_trig_reader_p1, start=s.enter, end=s.exit)
    p2 = api.load(root, new_pellet_trig_reader_p2, start=s.enter, end=s.exit)
    
    
    pstate1 = api.load(root, exp02.Patch1.DepletionState, start=s.enter, end=s.exit)
    pstate2 = api.load(root, exp02.Patch2.DepletionState, start=s.enter, end=s.exit)
    
    
    encoder1 = api.load(root, exp02.Patch1.Encoder, start=s.enter, end=s.exit)
    w1 = -distancetravelled(encoder1.angle)
    
    encoder2 = api.load(root, exp02.Patch2.Encoder, start=s.enter, end=s.exit)
    w2 = -distancetravelled(encoder2.angle)
    
    
    # PelletTrig cleaning: remove repeated deliveries (events <1.5 s apart) and manual deliveries (201)
    p1 = p1.drop(p1.index[np.where(np.diff(p1.index).astype("float64") < 1.5e9)[0]])
    p2 = p2.drop(p2.index[np.where(np.diff(p2.index).astype("float64") < 1.5e9)[0]])
    harp_reader = reader.Harp(pattern="Patch1_201", columns=["ExperimenterDeliveries"])
    user_p1 = api.load(root, harp_reader, start=s.enter, end=s.exit)
    harp_reader = reader.Harp(pattern="Patch2_201", columns=["ExperimenterDeliveries"])
    user_p2 = api.load(root, harp_reader, start=s.enter, end=s.exit)
    if not user_p1.empty:
        user_p1_idxs = np.abs(np.subtract.outer(user_p1.index, p1.index)).argmin(axis=1)
        p1.drop(p1.index[user_p1_idxs])
    if not user_p2.empty:
        user_p2_idxs = np.abs(np.subtract.outer(user_p2.index, p2.index)).argmin(axis=1)
        p2.drop(p2.index[user_p2_idxs])
    both_pellet_data = pd.concat([p1, p2]).sort_index()
    # PatchState cleaning: remove NaNs; remove updates <1.5s apart (bug updates)
    pstate1.dropna(inplace=True)
    good_indxs = np.concatenate((np.diff(pstate1.index).astype("float64") > 1.5e9, [True]))
    pstate1 = pstate1[good_indxs]
    pstate2.dropna(inplace=True)
    good_indxs = np.concatenate((np.diff(pstate2.index).astype("float64") > 1.5e9, [True]))
    pstate2 = pstate2[good_indxs]
    # Clean known issues in particular sessions
    if s.enter == pd.Timestamp("2023-03-24 14:22:48"):  # last threshold update of 75 for some reason
        pstate1 = pstate1.drop(pstate1.index[-1])
        pstate2 = pstate2.drop(pstate2.index[-1])
    if s.enter == pd.Timestamp("2023-03-10 13:08:24"):  # TriggerPellet at very end of session for some reason
        p2 = p2.drop(p2.index[-1])
    if s.enter == pd.Timestamp("2023-06-15 13:33:30"):  # @todo check this
        p2 = p2.drop(p2.index[-4:])
    if s.enter == pd.Timestamp("2023-06-21 09:49:02"):  # @todo check this
        pstate1 = pstate1.drop(pstate1.index[-2:])
        
    # Check lengths of PelletTrigger and PatchState events
    if ((len(pstate1) - len(p1)) not in (0, 1, 2)) or ((len(pstate2) - len(p2)) not in (0, 1, 2)):
        raise Exception(
            f"PelletTrigger-PatchState mismatch: \n"
            f"len(p1) = {len(p1)} \n"
            f"len(p2) = {len(p2)} \n"
            f"len(pstate1) = {len(pstate1)} \n"
            f"len(pstate2) = {len(pstate2)} \n"
        )
    both_state_data = pd.concat([pstate1, pstate2]).sort_index()
    
    if len(w1) > len(w2):  # ensure same num samples for each wheel
        w1 = w1[:len(w2)]
    else:
        w2 = w2[:len(w1)]
    if len(w1) % w_chunk_t == 0:
        w1 = w1[:-1]
        w2 = w2[:-1]
    if w1.index[-1] != w2.index[-1]:
        print(f"WARNING: sync issues seen in wheel data for {s.id} {s.enter}. Automatically reindexing to continue")
        if w1.index.is_monotonic_increasing and not np.any(w1.index.duplicated()):
            w2.index = w1.index
        else:
            w1.index = w2.index
    wboth = w1 + w2
    
    # Find threshold-change ts
    thresh_change_idx = np.where(np.abs(np.diff(both_state_data.threshold)) > 1)[0][0] + 1
    safe_change_ts = change_ts = both_state_data.index[thresh_change_idx]
    sessions.loc[s.Index, "post_thresh_dur"] = post_thresh_dur = (s.exit - change_ts).round("1s")
    # if (len(p2[p2.index > change_ts]) > 0) and (len(p1[p1.index > change_ts]) > 0):
    #     safe_change_ts = pd.Series((p1[p1.index > change_ts].index[0], p2[p2.index > change_ts].index[0])).max()
    #     sessions.loc[s.Index, "post_thresh_both_p_sampled_dur"] = post_thresh_both_p_sampled_dur = (s.exit - safe_change_ts).round("1s")
    # else:
    #     safe_change_ts = change_ts
    
    # Find both-patches-sampled ts
    both_patches_sampled_ts = pd.Series((p1.index[0], p2.index[0])).max()
    sessions.loc[s.Index, "pre_sampling_both_p_dur"] = pre_sampling_b_patches_dur = (both_patches_sampled_ts - s.enter).round("1s")
    
    if (np.any(p1.index > safe_change_ts) and np.any(p2.index > safe_change_ts)):
        both_patches_sampled_ts_post = (
            pd.Series((p1.index[p1.index > safe_change_ts][0], 
                        p2.index[p2.index > safe_change_ts][0])).max()
        )
        sessions.loc[s.Index, "post_thresh_both_p_sampled_dur"] = (
            (both_patches_sampled_ts_post - safe_change_ts).round("1s")
        )
        
    if s.enter == pd.Timestamp("2023-06-05 14:30:00"):
        pstate1["delta"][-1] = 0.002
        pstate2["delta"][-1] = 0.0033
    sessions.loc[s.Index, "hard_patch"] = hard_patch = 1 if (pstate1["delta"][-1] < pstate2["delta"][-1]) else 2
    sessions.loc[s.Index, "easy_patch"] = easy_patch = 1 if (hard_patch == 2) else 2
    sessions.loc[s.Index, "post_hard_rate"] = post_hard_rate = pstate1["delta"][-1] if (hard_patch == 1) else pstate2["delta"][-1]
    sessions.loc[s.Index, "post_easy_rate"] = post_easy_rate = pstate1["delta"][-1] if (hard_patch == 2) else pstate2["delta"][-1]
    
    whard = w1 if (hard_patch == 1) else w2
    weasy = w1 if (easy_patch == 1) else w2
    p1_pre_n_pel = len(p1[p1.index <= (safe_change_ts + pd.Timedelta("1s"))])  # ensure we don't count last pellet in pre as first pellet in post
    p1_post_n_pel = len(p1[p1.index > (safe_change_ts + pd.Timedelta("1s"))])
    p2_pre_n_pel = len(p2[p2.index <= (safe_change_ts + pd.Timedelta("1s"))])
    p2_post_n_pel = len(p2[p2.index > (safe_change_ts + pd.Timedelta("1s"))])
    sessions.loc[s.Index, "pre_easy_n_pel"] = pre_easy_n_pel = p1_pre_n_pel if (easy_patch == 1) else p2_pre_n_pel
    sessions.loc[s.Index, "pre_hard_n_pel"] = pre_hard_n_pel = p1_pre_n_pel if (hard_patch == 1) else p2_pre_n_pel
    
    p1_pre_wheel_dist = w1[w1.index > safe_change_ts][0] - w1[0]
    p2_pre_wheel_dist = w2[w2.index > safe_change_ts][0] - w2[0]
    sessions.loc[s.Index, "pre_easy_wheel_dist"] = pre_easy_wheel_dist = p1_pre_wheel_dist if (easy_patch == 1) else p2_pre_wheel_dist
    sessions.loc[s.Index, "pre_hard_wheel_dist"] = pre_hard_wheel_dist = p1_pre_wheel_dist if (hard_patch == 1) else p2_pre_wheel_dist
    sessions.loc[s.Index, "post_easy_n_pel"] = post_easy_n_pel = p1_post_n_pel if (easy_patch == 1) else p2_post_n_pel
    sessions.loc[s.Index, "post_hard_n_pel"] = post_hard_n_pel = p1_post_n_pel if (hard_patch == 1) else p2_post_n_pel
    
    p1_post_wheel_dist = w1[-1] - p1_pre_wheel_dist
    p2_post_wheel_dist = w2[-1] - p2_pre_wheel_dist
    sessions.loc[s.Index, "post_easy_wheel_dist"] = post_easy_wheel_dist = p1_post_wheel_dist if (easy_patch == 1) else p2_post_wheel_dist
    sessions.loc[s.Index, "post_hard_wheel_dist"] = post_hard_wheel_dist = p1_post_wheel_dist if (hard_patch == 1) else p2_post_wheel_dist
    
    sessions.loc[s.Index, "pre_easy_pref"] = pre_easy_pref = pre_easy_wheel_dist / (pre_easy_wheel_dist + pre_hard_wheel_dist)
    sessions.loc[s.Index, "post_easy_pref"] = post_easy_pref = post_easy_wheel_dist / (post_easy_wheel_dist + post_hard_wheel_dist)
    sessions.loc[s.Index, "pre_hard_pref"] = post_hard_pref = 1 - pre_easy_pref
    sessions.loc[s.Index, "post_hard_pref"] = post_hard_pref = 1 - post_easy_pref
    sessions.loc[s.Index, "post_pre_easy_pref"] = post_pre_easy_pref = post_easy_pref / pre_easy_pref
    
    # Find each pstate update prior to each pellet threshold crossing
    p1_post_pel_thresh = pstate1[pstate1.index >= safe_change_ts].threshold[:-1]
    #p1_post_pel_thresh = np.nan if p1_post_pel_thresh.empty else p1_post_pel_thresh
    p2_post_pel_thresh = pstate2[pstate2.index >= safe_change_ts].threshold[:-1]
    #p2_post_pel_thresh = np.nan if p2_post_pel_thresh.empty else p2_post_pel_thresh
    post_easy_pel_thresh = p1_post_pel_thresh if (easy_patch == 1) else p2_post_pel_thresh
    post_hard_pel_thresh = p1_post_pel_thresh if (hard_patch == 1) else p2_post_pel_thresh
    sessions.at[s.Index, "post_easy_pel_thresh"] = post_easy_pel_thresh.values.round(3)
    sessions.at[s.Index, "post_hard_pel_thresh"] = post_hard_pel_thresh.values.round(3)
    sessions.at[s.Index, "post_easy_pel_thresh_idx"] = np.array(post_easy_pel_thresh.index.round("1s"))
    sessions.at[s.Index, "post_hard_pel_thresh_idx"] = np.array(post_hard_pel_thresh.index.round("1s"))
    
    init_pref_by_pel_ct = np.ones((10,)) * np.nan
    for i, pel_ct in enumerate(range(8,18)):
        cur_pel_ct_ts = both_pellet_data.index[pel_ct]
        if cur_pel_ct_ts > (safe_change_ts + pd.Timedelta("1s")):
            break
        cur_whard_dist = whard[whard.index > cur_pel_ct_ts][0] - whard[0]
        cur_weasy_dist = weasy[weasy.index > cur_pel_ct_ts][0] - weasy[0] 
        init_pref_by_pel_ct[i] = cur_whard_dist / (cur_whard_dist + cur_weasy_dist)
    sessions.at[s.Index, "init_pref_by_pel_ct"] = init_pref_by_pel_ct
    
    wboth_quantized = np.linspace(0, wboth[-1], 11)
    easy_pref_epoch_cum = np.zeros((10,))
    easy_pref_epoch = np.zeros((10,))
    epoch_thresh_change_idx = 0
    epoch_ts_pre = wboth.index[0]
    for i in range(1, 10):
        epoch_ts_post = wboth[wboth > wboth_quantized[i]].index[0] - pd.Timedelta("1s")
        if (epoch_ts_post > safe_change_ts) and not epoch_thresh_change_idx:
            epoch_thresh_change_idx = i
        weasy_post = weasy[weasy.index > epoch_ts_post][0]
        whard_post = whard[whard.index > epoch_ts_post][0]
        weasy_pre = weasy[weasy.index > epoch_ts_pre][0]
        whard_pre = whard[whard.index > epoch_ts_pre][0]
        weasy_diff = weasy_post - weasy_pre
        whard_diff = whard_post - whard_pre
        easy_pref_epoch_cum[i] = weasy_post / (weasy_post + whard_post)
        easy_pref_epoch[i] = weasy_diff / (weasy_diff + whard_diff)
        epoch_ts_pre = epoch_ts_post
    sessions.at[s.Index, "easy_pref_epoch_cum"] = easy_pref_epoch_cum
    sessions.at[s.Index, "easy_pref_epoch"] = easy_pref_epoch
    sessions.loc[s.Index, "epoch_thresh_change_idx"] = epoch_thresh_change_idx
    
    # <s Get chunked patch pref compared to synthetic data
    # <ss Chunk (downsample) wheel data
    weasy_chnkd = np.abs((weasy[(w_chunk_t - 1)::w_chunk_t]).values - (weasy[::w_chunk_t][:-1]).values)
    #weasy_chnkd = np.abs((weasy[(w_chunk_t - 1)::w_chunk_t]).values - (weasy[::w_chunk_t]).values)
    weasy_chnkd_cumsum = weasy_chnkd.cumsum()
    whard_chnkd = np.abs((whard[(w_chunk_t - 1)::w_chunk_t]).values - (whard[::w_chunk_t][:-1]).values)
    #whard_chnkd = np.abs((whard[(w_chunk_t - 1)::w_chunk_t]).values - (whard[::w_chunk_t]).values)
    whard_chnkd_cumsum = whard_chnkd.cumsum()
    w_all_chnkd_cumsum = weasy_chnkd_cumsum + whard_chnkd_cumsum
    n_samples = len(weasy_chnkd)
    pref_first_idx = np.where(w_all_chnkd_cumsum > w_chunk_dist)[0][0]
    end_idxs = np.arange(pref_first_idx, n_samples, 1).astype(int)
    start_idxs = np.zeros((len(end_idxs),)).astype(int)
    for i, idx in enumerate(end_idxs):
        start_idxs[i] = np.where((w_all_chnkd_cumsum[0:idx] + w_chunk_dist) > w_all_chnkd_cumsum[idx])[0][0]
    # /ss>
    # <ss Get true chunked patch pref
    weasy_diff = weasy_chnkd_cumsum[end_idxs] - weasy_chnkd_cumsum[start_idxs]
    whard_diff = whard_chnkd_cumsum[end_idxs] - whard_chnkd_cumsum[start_idxs]
    weasy_pref = weasy_diff / (weasy_diff + whard_diff)
    # /ss>
    # <ss Generate individual wheel null distributions
    w_all_chnkd = np.concatenate((weasy_chnkd, whard_chnkd))
    syn_chunk_pref_dists = np.zeros((n_distris, len(weasy_pref)))
    for distri_n in range(n_distris):
        # Create synthetic distributions
        weasy_chnkd_gen = np.random.choice(w_all_chnkd, size=n_samples, replace=False)
        whard_chnkd_gen = np.random.choice(w_all_chnkd, size=n_samples, replace=False)
        impossible_idxs = np.where(np.logical_and(weasy_chnkd_gen > 0.1, whard_chnkd_gen > 0.1))[0]
        for ii in impossible_idxs:
            if weasy_chnkd_gen[ii] > whard_chnkd_gen[ii]:
                whard_chnkd_gen[ii] = 0
            else:
                weasy_chnkd_gen[ii] = 0
        weasy_chnkd_gen_cumsum = weasy_chnkd_gen.cumsum()
        whard_chnkd_gen_cumsum = whard_chnkd_gen.cumsum()
        w_all_chnkd_gen_cumsum = weasy_chnkd_gen_cumsum + whard_chnkd_gen_cumsum
        # Get synthetic patch pref
        end_idxs = np.arange(pref_first_idx, n_samples, 1).astype(int)
        start_idxs = np.zeros((len(end_idxs),)).astype(int)
        for i, idx in enumerate(end_idxs):
            start_idxs[i] = np.where(
                (w_all_chnkd_gen_cumsum[0:idx] + w_chunk_dist) 
                > w_all_chnkd_gen_cumsum[idx]
            )[0][0]
        weasy_diff_gen = weasy_chnkd_gen_cumsum[end_idxs] - weasy_chnkd_gen_cumsum[start_idxs]
        whard_diff_gen = whard_chnkd_gen_cumsum[end_idxs] - whard_chnkd_gen_cumsum[start_idxs]
        weasy_pref_gen = weasy_diff_gen / (weasy_diff_gen + whard_diff_gen)
        syn_chunk_pref_dists[distri_n, :] = weasy_pref_gen
    # /ss>
    # <ss Get the 2.5th and 97.5th percentiles of the null distributions
    syn_chunk_pref_dists = np.sort(syn_chunk_pref_dists, axis=0)
    low_bound = syn_chunk_pref_dists[3, :]
    high_bound = syn_chunk_pref_dists[96, :]
    # /ss>
    # <ss Check if learning criteria is met
    learned_start_idx = None
    learned_end_idx = None
    pref_idxs = np.where(weasy_pref > high_bound)[0]
    # For each pref_idx, find the first earlier idx with `pref_window` less
    # cum distance, then see if pref over this window is > `pref_thresh`
    for pref_start_idx in pref_idxs:
        pref_end_idx = np.where(
            w_all_chnkd_cumsum[pref_start_idx:] 
            > (w_all_chnkd_cumsum[pref_start_idx] + pref_window)
        )[0]
        if pref_end_idx.size > 0:
            pref_end_idx = pref_end_idx[0] + pref_start_idx
            pref_p = np.sum(
                weasy_pref[pref_start_idx : pref_end_idx] 
                > high_bound[pref_start_idx : pref_end_idx]
            ) / (pref_end_idx - pref_start_idx)
            if pref_p > pref_thresh:
                learned_start_idx = pref_start_idx
                learned_end_idx = pref_end_idx
                break
    # /ss>
    cont_patch_pref = DotMap(
        w_all_chnkd_cumsum=w_all_chnkd_cumsum.astype('float32'),
        weasy_pref=weasy_pref.astype('float32'),
        low_bound=low_bound.astype('float32'),
        high_bound=high_bound.astype('float32'),
        learned_start_idx=learned_start_idx,
        learned_end_idx=learned_end_idx,
        thresh_change_idx=(safe_change_ts - s.enter).seconds
    )
    sessions.at[s.Index, "cont_patch_pref"] = cont_patch_pref
    
    if learned_start_idx:
        print(f"Learned: {s.id} {s.enter} ... {post_easy_rate} {post_hard_rate}")
    # /s>    

    

Learned: BAA-1103047 2023-06-21 13:24:35 ... 0.0033 0.0014
Learned: BAA-1103049 2023-06-22 09:28:03 ... 0.0033 0.0014
Learned: BAA-1103050 2023-06-22 12:41:54 ... 0.0033 0.0014


In [8]:
s

Pandas(Index=66, id='BAA-1103050', enter=Timestamp('2023-06-22 12:41:54'), exit=Timestamp('2023-06-22 15:42:40'), duration=Timedelta('0 days 03:00:45'), weight_enter=25.8, weight_exit=26.5, post_thresh_dur=nan, post_thresh_both_p_sampled_dur=nan, pre_sampling_both_p_dur=nan, easy_patch=nan, hard_patch=nan, post_easy_rate=nan, post_hard_rate=nan, pre_easy_n_pel=nan, pre_hard_n_pel=nan, post_easy_n_pel=nan, post_hard_n_pel=nan, pre_easy_wheel_dist=nan, pre_hard_wheel_dist=nan, post_easy_wheel_dist=nan, post_hard_wheel_dist=nan, pre_easy_pref=nan, post_easy_pref=nan, pre_hard_pref=nan, post_hard_pref=nan, post_pre_easy_pref=nan, post_easy_pel_thresh=nan, post_easy_pel_thresh_idx=nan, post_hard_pel_thresh=nan, post_hard_pel_thresh_idx=nan, init_pref_by_pel_ct=nan, epoch_thresh_change_idx=nan, easy_pref_epoch_cum=nan, easy_pref_epoch=nan, cont_patch_pref=nan)

In [9]:
root

'/ceph/aeon/aeon/data/raw/AEON2/presocial0.1'

In [10]:
api.load(root, harp_reader, start=s.enter, end=s.exit)

Unnamed: 0_level_0,ExperimenterDeliveries
time,Unnamed: 1_level_1


In [11]:
cols_to_round = [
    "pre_easy_wheel_dist",
    "pre_hard_wheel_dist",
    "post_easy_wheel_dist",
    "post_hard_wheel_dist",
    "pre_easy_pref",
    "post_easy_pref",
    "pre_hard_pref",
    "post_hard_pref",
]
for col in cols_to_round:
    sessions[col] = sessions[col].round(3)

In [12]:
sessions.loc[49, "post_easy_rate"] = 0.0033
sessions.loc[49, "post_hard_rate"] = 0.0025
sessions.loc[62, "post_easy_rate"] = 0.0033

In [13]:
sessions.to_pickle(Path(
    "/nfs/nhome/live/jbhagat/ProjectAeon/aeon_analysis/aeon_analysis/presocial/data"
    "/presocial_data.pkl"
    )
)

PermissionError: [Errno 13] Permission denied: '/nfs/nhome/live/jbhagat/ProjectAeon/aeon_analysis/aeon_analysis/presocial/data/presocial_data.pkl'

In [None]:
df = pd.read_pickle(
    Path(
        "/nfs/nhome/live/jbhagat/ProjectAeon/aeon_analysis/aeon_analysis/presocial/data"
        "/presocial_data.pkl"
    )
)

In [None]:
display(df)

Unnamed: 0,id,enter,exit,duration,weight_enter,weight_exit,post_thresh_dur,post_thresh_both_p_sampled_dur,pre_sampling_both_p_dur,easy_patch,hard_patch,post_easy_rate,post_hard_rate,pre_easy_n_pel,pre_hard_n_pel,post_easy_n_pel,post_hard_n_pel,pre_easy_wheel_dist,pre_hard_wheel_dist,post_easy_wheel_dist,post_hard_wheel_dist,pre_easy_pref,post_easy_pref,pre_hard_pref,post_hard_pref,post_pre_easy_pref,post_easy_pel_thresh,post_easy_pel_thresh_idx,post_hard_pel_thresh,post_hard_pel_thresh_idx,init_pref_by_pel_ct,epoch_thresh_change_idx,easy_pref_epoch_cum,easy_pref_epoch
0,BAA-1103050,2023-03-10 09:41:48,2023-03-10 12:55:19,0 days 03:13:30,23.2,23.9,0 days 02:22:53,0 days 00:12:53,0 days 00:17:40,2.0,1.0,0.01,0.0025,1.0,17.0,21.0,18.0,89.731,1275.491,3207.803,8580.92,0.066,0.272,0.934,0.728,4.140009,"[121.823, 239.646, 214.056, 109.924, 160.179, ...","[2023-03-10T10:32:26.000000000, 2023-03-10T10:...","[787.955, 453.691, 131.533, 306.643, 380.248, ...","[2023-03-10T10:32:26.000000000, 2023-03-10T10:...","[0.8815044790108325, 0.8833988643958598, 0.893...",2.0,"[0.0, 0.06909652367344479, 0.2668156136292335,...","[0.0, 0.06909424261898463, 0.4617004033497592,..."
1,BAA-1103045,2023-03-10 12:12:45,2023-03-10 15:22:14,0 days 03:09:28,23.0,23.7,0 days 02:15:22,0 days 00:32:38,0 days 00:07:38,2.0,1.0,0.01,0.0025,9.0,9.0,24.0,14.0,637.379,671.408,3490.806,7168.452,0.487,0.327,0.513,0.673,0.672465,"[129.145, 79.028, 191.036, 87.88, 123.836, 88....","[2023-03-10T13:06:52.000000000, 2023-03-10T13:...","[1549.34, 194.009, 117.725, 465.422, 329.508, ...","[2023-03-10T13:06:52.000000000, 2023-03-10T13:...","[0.40413871418162683, 0.3611177037993481, 0.32...",2.0,"[0.0, 0.5122037205671914, 0.268011025453626, 0...","[0.0, 0.5121998496158439, 0.02782743102960722,..."
2,BAA-1103048,2023-03-10 13:08:24,2023-03-10 16:16:31,0 days 03:08:06,22.5,24.6,0 days 01:28:08,0 days 00:01:32,0 days 01:39:59,2.0,1.0,0.01,0.0025,1.0,68.0,1.0,33.0,89.789,5115.58,271.892,13906.201,0.017,0.019,0.983,0.981,1.111744,[193.265],[2023-03-10T14:48:23.000000000],"[522.37, 274.443, 115.277, 253.721, 96.886, 95...","[2023-03-10T14:48:23.000000000, 2023-03-10T14:...","[0.9945817950119978, 0.9947954694754996, 0.995...",3.0,"[0.0, 0.009501453253348757, 0.0056203811282433...","[0.0, 0.00950388191824924, 0.00182721496743799..."
3,BAA-1103047,2023-03-10 15:27:05,2023-03-10 19:10:44,0 days 03:43:38,19.8,21.0,0 days 03:17:44,,0 days 00:06:35,1.0,2.0,0.01,0.0025,7.0,11.0,0.0,51.0,526.378,825.288,20.805,24885.277,0.389,0.001,0.611,0.999,0.002145,[],[],"[137.334, 514.599, 162.643, 254.303, 297.185, ...","[2023-03-10T15:53:00.000000000, 2023-03-10T15:...","[0.22612623407040333, 0.2995549748907851, 0.36...",1.0,"[0.0, 0.201008772572165, 0.10048828173618965, ...","[0.0, 0.20101029450617944, 1.7568496640593205e..."
4,BAA-1103049,2023-03-10 16:22:29,2023-03-10 19:21:50,0 days 02:59:20,20.9,22.6,0 days 02:34:42,0 days 02:17:23,0 days 00:03:20,1.0,2.0,0.01,0.0025,8.0,10.0,67.0,1.0,600.211,756.874,11424.823,2429.83,0.442,0.825,0.558,0.175,1.864476,"[93.197, 114.719, 123.426, 118.123, 163.986, 1...","[2023-03-10T16:47:08.000000000, 2023-03-10T16:...",[1578.329],[2023-03-10T16:47:08.000000000],"[0.6702724079895558, 0.6987034819223695, 0.722...",1.0,"[0.0, 0.40752862309470156, 0.45704937673077456...","[0.0, 0.40752980054290106, 0.5055215290913712,..."
5,BAA-1103048,2023-03-23 07:56:44,2023-03-23 11:08:34,0 days 03:11:49,24.8,24.3,0 days 02:01:38,,0 days 00:39:29,2.0,1.0,0.01,0.0033,1.0,12.0,0.0,35.0,217.358,2400.518,113.245,13620.375,0.083,0.008,0.917,0.992,0.099313,[],[],"[265.051, 436.068, 255.756, 334.965, 386.051, ...","[2023-03-23T09:06:56.000000000, 2023-03-23T09:...","[0.8857331583210366, 0.8955956240117702, 0.901...",2.0,"[0.0, 0.1266859884464668, 0.08837536432592485,...","[0.0, 0.12668434433999118, 0.04963030024361057..."
6,BAA-1103045,2023-03-23 10:16:38,2023-03-23 13:19:43,0 days 03:03:03,23.6,25.1,0 days 02:42:35,0 days 00:03:59,0 days 00:05:06,2.0,1.0,0.01,0.0033,6.0,7.0,67.0,4.0,1200.148,1406.659,12398.231,2206.132,0.46,0.849,0.54,0.151,1.843959,"[120.184, 128.872, 158.416, 225.21, 77.384, 10...","[2023-03-23T10:37:08.000000000, 2023-03-23T10:...","[350.323, 255.706, 440.952, 935.67]","[2023-03-23T10:37:08.000000000, 2023-03-23T10:...","[0.5467665117106025, 0.5915735648541067, 0.628...",2.0,"[0.0, 0.48718061979794414, 0.3546886074251283,...","[0.0, 0.48717738373193664, 0.2242810293976725,..."
7,BAA-1103049,2023-03-23 11:15:29,2023-03-23 14:23:41,0 days 03:08:11,22.0,24.4,0 days 02:46:11,0 days 02:30:49,0 days 00:12:57,2.0,1.0,0.01,0.0033,1.0,12.0,9.0,64.0,203.326,2400.495,1815.1,25128.523,0.078,0.067,0.922,0.933,0.862705,"[323.36, 89.956, 100.945, 281.542, 317.752, 26...","[2023-03-23T11:37:30.000000000, 2023-03-23T14:...","[287.998, 580.596, 724.668, 76.437, 272.179, 7...","[2023-03-23T11:37:30.000000000, 2023-03-23T11:...","[0.8874033724380087, 0.8986480071542763, 0.907...",1.0,"[0.0, 0.11194368517145097, 0.06307926689193422...","[0.0, 0.11194269347596233, 0.01479731539165762..."
8,BAA-1103047,2023-03-23 13:29:36,2023-03-23 16:32:51,0 days 03:03:14,23.2,22.4,0 days 01:46:30,0 days 00:09:22,0 days 00:22:44,1.0,2.0,0.01,0.0033,5.0,8.0,14.0,8.0,1006.42,1600.277,1808.106,3435.514,0.386,0.345,0.614,0.655,0.893108,"[80.512, 115.623, 92.681, 91.222, 124.364, 147...","[2023-03-23T14:46:21.000000000, 2023-03-23T14:...","[948.057, 450.369, 265.346, 354.936, 203.452, ...","[2023-03-23T14:46:21.000000000, 2023-03-23T15:...","[0.44317959750338587, 0.4987168650968184, 0.54...",4.0,"[0.0, 0.21524198573427852, 0.4900212583057259,...","[0.0, 0.21524795798719143, 0.761422088340779, ..."
9,BAA-1103050,2023-03-23 14:30:26,2023-03-23 17:29:18,0 days 02:58:51,23.5,24.9,0 days 02:42:59,0 days 00:00:24,0 days 00:08:34,2.0,1.0,0.01,0.0033,2.0,11.0,60.0,19.0,407.696,2200.467,10931.713,7902.336,0.156,0.58,0.844,0.42,3.713156,"[112.835, 109.153, 146.757, 106.563, 247.758, ...","[2023-03-23T14:46:19.000000000, 2023-03-23T14:...","[777.252, 470.784, 199.692, 117.888, 433.435, ...","[2023-03-23T14:46:19.000000000, 2023-03-23T14:...","[0.9999241752352911, 0.9003446268909796, 0.818...",2.0,"[0.0, 0.1463396112335383, 0.2863590759643062, ...","[0.0, 0.14633795309553957, 0.425921200220298, ..."


In [None]:
from pathlib import Path
from itertools import product

import dash
import dash_daq as daq
import ipdb
import json
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns

from dash import Dash, dash_table, dcc, html
from dash.dependencies import Input, Output, State, ClientsideFunction
from dash.development.base_component import ComponentRegistry
from dotmap import DotMap
from plotly.subplots import make_subplots

from aeon_analysis.presocial.presocial_dash import helpers

bg_col = "#050505"
txt_col = "#f2f2f2"
plt_bg_col = "#0d0d0d"
tab_bg_col = "#003399"
tab_txt_col = "#f2f2f2"
table_max_height = "400px"
table_min_width = "1200px"
mrkr_sz = 14
color_dict = {
    "BAA-1103045": "rgb(31, 119, 180)",
    "BAA-1103047": "rgb(214, 39, 40)",
    "BAA-1103048": "rgb(44, 160, 44)",
    "BAA-1103049": "rgb(148, 103, 189)",
    "BAA-1103050": "rgb(255, 127, 14)",
}

In [None]:
# Get unique ID-thresh sessions
df_uniq_id_thresh = df[["id", "post_easy_rate", "post_hard_rate"]].drop_duplicates()
uniq_id_thresh_tits = df_uniq_id_thresh.apply(
    lambda row: ' '.join(row.values.astype(str)), axis=1).tolist()
uniq_id_thresh_dropdown = [{'label': id_thresh, 'value': id_thresh} 
                           for id_thresh in uniq_id_thresh_tits]
sesh_pref_time_figs = DotMap()
sesh_pref_dist_figs = DotMap()
dtl_init = DotMap()  # distance-to-learned
dtl_avg = DotMap()
ttl_init = DotMap()  # time-to-learned
ttl_avg = DotMap()

# Iterate over each unique ID-thresh sesh, get all corresponding sessions, create figure
# with corresponding number of axes, iterate over seshes, plot each ax.
for j, id_thresh in enumerate(df_uniq_id_thresh.itertuples()):
    cur_df = df[(
        df["id"] == id_thresh.id)
        & (df["post_easy_rate"] == id_thresh.post_easy_rate)
        & (df["post_hard_rate"] == id_thresh.post_hard_rate)
    ]
    ncols = len(cur_df) if len(cur_df) < 3 else 3
    nrows = int(np.ceil(len(cur_df) / ncols))
    fig_pref_time = make_subplots(rows=nrows, cols=ncols)
    fig_pref_dist = make_subplots(rows=nrows, cols=ncols)
    id_thresh_cum_dist = 0
    id_thresh_cum_time = 0
    cur_dtl_init = 0
    cur_ttl_init = 0
    learned_ctr = 0
    learned_flag = False
    for i, sesh in enumerate(cur_df.itertuples()):
        r, c = (i // ncols + 1, i % ncols + 1)
        cum_dist, lo, hi, tru, v_start, v_end, thresh_change = (
            sesh.cont_patch_pref.w_all_chnkd_cumsum,
            sesh.cont_patch_pref.low_bound,
            sesh.cont_patch_pref.high_bound,
            sesh.cont_patch_pref.weasy_pref,
            sesh.cont_patch_pref.learned_start_idx,
            sesh.cont_patch_pref.learned_end_idx,
            sesh.cont_patch_pref.thresh_change_idx
        )
        # fig_pref_time.add_trace(
        #     go.Scatter(
        #         y=lo, mode='lines', line=dict(color='darkslategray', dash='dash'), name='low_bound'
        #     ), 
        #     row=r, col=c
        # )
        # fig_pref_sesh.add_trace(
        #     go.Scatter(
        #         x=cum_dist, y=lo, mode='lines', line=dict(color='darkslategray', dash='dash'), name='low_bound'
        #     ), 
        #     row=r, col=c
        # )
        fig_pref_time.add_trace(
            go.Scatter(
                y=hi, mode='lines', line=dict(color='darkslategray', dash='dash'), name='high_bound'
            ), 
            row=r, col=c
        )
        fig_pref_dist.add_trace(
            go.Scatter(
                x=cum_dist, y=hi, mode='lines', line=dict(color='darkslategray', dash='dash'), name='high_bound'
            ),
            row=r, col=c
        )
        fig_pref_time.add_trace(
            go.Scatter(
                y=tru, mode='lines', line=dict(color=color_dict[sesh.id]), name='true'
            ),
            row=r, col=c
        )
        fig_pref_dist.add_trace(
            go.Scatter(
                x=cum_dist, y=tru, mode='lines', line=dict(color=color_dict[sesh.id]), name='true'
            ),
            row=r, col=c
        )
        fig_pref_time.add_shape(
                type="line", x0=thresh_change, x1=thresh_change, y0=0, y1=1.1,
                line=dict(
                    color="lightslategray",
                    dash="dash",
                    width=3,
                ),
                name="thresh_change",
                row=r, col=c
            )
        fig_pref_dist.add_shape(
                type="line", x0=cum_dist[thresh_change], x1=cum_dist[thresh_change], y0=0, y1=1.1,
                line=dict(
                    color="lightslategray",
                    dash="dash",
                    width=3,
                ),
                row=r, col=c
            )
        if v_start:
            id_thresh_cum_dist += cum_dist[v_start]
            id_thresh_cum_time += v_start
            learned_ctr += 1
            if not learned_flag:
                cur_dtl_init += cum_dist[v_start]
                cur_ttl_init += v_start
            learned_flag = True
            fig_pref_time.add_shape(
                type="line", x0=v_start, x1=v_start, y0=0, y1=1.1,
                line=dict(
                    color="deeppink",
                    dash="dash",
                    width=3,
                ),
                row=r, col=c
            )
            fig_pref_time.add_shape(
                type="line", x0=v_end, x1=v_end, y0=0, y1=1.1,
                line=dict(
                    color="deeppink",
                    dash="dash",
                    width=3,
                ),
                row=r, col=c
            )
            fig_pref_dist.add_shape(
                type="line", x0=cum_dist[v_start], x1=cum_dist[v_start], y0=0, y1=1.1,
                line=dict(
                    color="deeppink",
                    dash="dash",
                    width=3,
                ),
                row=r, col=c
            )
            fig_pref_dist.add_shape(
                type="line", x0=cum_dist[v_end], x1=cum_dist[v_end], y0=0, y1=1.1,
                line=dict(
                    color="deeppink",
                    dash="dash",
                    width=3,
                ),
                row=r, col=c
            )
        else:
            id_thresh_cum_dist += cum_dist[-1]
            id_thresh_cum_time += len(tru)
        fig_pref_time.update_layout(
            title_text=(f"{sesh.id}  easy_rate: {sesh.post_easy_rate}  hard_rate: {sesh.post_hard_rate}"),
            paper_bgcolor=bg_col,
            plot_bgcolor=plt_bg_col,
            font={"color": txt_col},
        )
        fig_pref_dist.update_layout(
            title_text=(f"{sesh.id}  easy_rate: {sesh.post_easy_rate}  hard_rate: {sesh.post_hard_rate}"),
            paper_bgcolor=bg_col,
            plot_bgcolor=plt_bg_col,
            font={"color": txt_col},
        )
        fig_pref_time.update_xaxes(title_text="Time (s)", row=r, col=c)
        fig_pref_time.update_yaxes(title_text=f"{str(sesh.enter.date())}", row=r, col=c)
        fig_pref_dist.update_xaxes(title_text="Distance (cm)", row=r, col=c)
        fig_pref_dist.update_yaxes(title_text=f"{str(sesh.enter.date())}", row=r, col=c)
        if not learned_flag:
            cur_dtl_init += cum_dist[-1]
            cur_ttl_init += len(tru)

    sesh_pref_time_figs[uniq_id_thresh_tits[j]] = fig_pref_time
    sesh_pref_dist_figs[uniq_id_thresh_tits[j]] = fig_pref_dist
    if learned_flag:
        dtl_avg[uniq_id_thresh_tits[j]] = id_thresh_cum_dist / learned_ctr
        ttl_avg[uniq_id_thresh_tits[j]] = id_thresh_cum_time / learned_ctr
        dtl_init[uniq_id_thresh_tits[j]] = cur_dtl_init
        ttl_init[uniq_id_thresh_tits[j]] = cur_ttl_init
    else:
        dtl_avg[uniq_id_thresh_tits[j]] = np.nan
        ttl_avg[uniq_id_thresh_tits[j]] = np.nan
        dtl_init[uniq_id_thresh_tits[j]] = np.nan
        ttl_init[uniq_id_thresh_tits[j]] = np.nan

In [None]:
dict(dtl_avg)

In [None]:
subj_pref_df = pd.DataFrame(
    {"dtl_avg": dtl_avg.toDict(), 
    "dtl_init": dtl_init.toDict(), 
    "ttl_avg": ttl_avg.toDict(), 
    "ttl_init": ttl_init.toDict()}
)
subj_pref_df = subj_pref_df.drop(subj_pref_df.tail(1).index)

In [None]:
group_indxs

In [None]:
# Subject pref plots
subj_pref_df = pd.DataFrame(
    {"pre_pref_avg_dist": dtl_avg.toDict(), 
    "pre_pref_init_dist": dtl_init.toDict(), 
    "pre_pref_avg_time": ttl_avg.toDict(), 
    "pre_pref_init_time": ttl_init.toDict()}
)
subj_pref_df = subj_pref_df.drop(subj_pref_df.tail(1).index)
subj_pref_df = subj_pref_df.sort_index()
groups = subj_pref_df.index.str.split(' ', n=1).str[-1].unique()

dist_cols = ["pre_pref_init_dist", "pre_pref_avg_dist"]
time_cols = ["pre_pref_init_time", "pre_pref_avg_time"]
dist_fig = make_subplots(rows=1, cols=2)
time_fig = make_subplots(rows=1, cols=2)
for i, col in enumerate(dist_cols, start=1):
    for group in groups:
        group_indxs = subj_pref_df.index.str.endswith(group)
        dist_fig.add_trace(
            go.Bar(
                x=subj_pref_df[group_indxs].index, 
                y=subj_pref_df.loc[group_indxs, col], 
                name=f'{col} {group}', 
                marker_color='bisque'
            ),
            row=1, 
            col=i
        )
    dist_fig.update_xaxes(title_text=col, row=1, col=i)
for i, col in enumerate(time_cols, start=1):
    for group in groups:
        group_indxs = subj_pref_df.index.str.endswith(group)
        time_fig.add_trace(
            go.Bar(
                x=subj_pref_df[group_indxs].index, 
                y=subj_pref_df.loc[group_indxs, col], 
                name=f'{col} {group}', 
                marker_color='bisque'
            ),
            row=1, 
            col=i
        )
    time_fig.update_xaxes(title_text=col, row=1, col=i)

# Add group boxplots
for i, col in enumerate(dist_cols, start=1):
    grp_name ='init' if i == 1 else 'avg'
    for group in groups:
        group_indices = subj_pref_df.index.str.endswith(group)
        dist_fig.add_trace(
            go.Box(y=subj_pref_df.loc[group_indices, col], name=f'{group} {grp_name} box',
            boxpoints='all', marker_color='gainsboro'),
            row=1, col=i
        )
for i, col in enumerate(time_cols, start=1):
    grp_name ='init' if i == 1 else 'avg'
    for group in groups:
        group_indices = subj_pref_df.index.str.endswith(group)
        time_fig.add_trace(
            go.Box(y=subj_pref_df.loc[group_indices, col], name=f'{group} {grp_name} box',
            boxpoints='all', marker_color='gainsboro'),
            row=1, col=i,
        )
dist_fig.update_layout(
    title='Pre Easy Preference Distance', 
    showlegend=True,
    paper_bgcolor=bg_col,
    plot_bgcolor=plt_bg_col,
    font={"color": txt_col},
)
time_fig.update_layout(
    title='Pre Easy Preference Time', 
    showlegend=True,
    paper_bgcolor=bg_col,
    plot_bgcolor=plt_bg_col,
    font={"color": txt_col},
)
dist_fig.show()
time_fig.show()

In [None]:
groups = subj_pref_df.index.str.split(' ', n=1).str[-1].unique()
groups

In [None]:
dtl_avg.pprint(), dtl_init.pprint(), ttl_avg.pprint(), ttl_init.pprint()

In [None]:
j, uniq_id_thresh_tits[j], id_thresh_cum_time

In [None]:
id_thresh_cum_dist

In [None]:
0 / 0

In [None]:
z = [i + 1 for i in range(10)]
z

In [None]:
uniq_id_thresh_tits

In [None]:
[{'label': id_thresh, 'value': id_thresh} for id_thresh in uniq_id_thresh_tits]

In [None]:
uniq_id_thresh_dropdown = [{'label': id_thresh, 'value': id_thresh} for id_thresh in uniq_id_thresh_tits]
uniq_id_thresh_dropdown

In [None]:
df_uniq_id_thresh = df[["id", "post_easy_rate", "post_hard_rate"]].drop_duplicates()
uniq_id_thresh_tits = df_uniq_id_thresh.apply(
    lambda row: ' '.join(row.values.astype(str)), axis=1).tolist()
uniq_id_thresh_tits

In [None]:
df

In [None]:
mrkr_sz = 14
color_dict = {
    "BAA-1103045": "rgb(31, 119, 180)",
    "BAA-1103047": "rgb(255, 127, 14)",
    "BAA-1103048": "rgb(44, 160, 44)",
    "BAA-1103049": "rgb(214, 39, 40)",
    "BAA-1103050": "rgb(148, 103, 189)",
}

In [None]:
patch_pref_epoch_session = go.Figure()
sesh_subj_counter = DotMap(
    {
        "BAA-1103045": 0,
        "BAA-1103047": 0,
        "BAA-1103048": 0,
        "BAA-1103049": 0,
        "BAA-1103050": 0,
    }
)
for i in df.index:
    uid = df["id"][i]
    y = df["easy_pref_epoch"][i]
    sesh_subj_counter[uid] += 1
    patch_pref_epoch_session.add_trace(
        go.Scatter(
            y=y,
            name=f"{uid}: {sesh_subj_counter[uid]}",
            mode="lines+markers",
            marker={"size": mrkr_sz},
            line=dict(color=color_dict[uid]),
        )
    )

In [None]:
sesh_subj_counter

In [None]:
patch_pref_epoch_session.show()

In [None]:
cum_patch_pref_epoch_session = go.Figure()
sesh_subj_counter = DotMap(
    {
        "BAA-1103045": 0,
        "BAA-1103047": 0,
        "BAA-1103048": 0,
        "BAA-1103049": 0,
        "BAA-1103050": 0,
    }
)
for i in df.index:
    uid = df["id"][i]
    y = df["easy_pref_epoch_cum"][i]
    sesh_subj_counter[uid] += 1
    cum_patch_pref_epoch_session.add_trace(
        go.Scatter(
            y=y,
            name=f"{uid}: {sesh_subj_counter[uid]}",
            mode="lines+markers",
            marker={"size": mrkr_sz},
            line=dict(color=color_dict[uid]),
        )
    )
    zidx = int(df["epoch_thresh_change_idx"][i])
    cum_patch_pref_epoch_session.add_trace(
        go.Scatter(
            x=np.array((zidx, zidx + 1)),
            y=y[zidx : zidx + 2],
            mode="lines+markers",
            marker={"size": mrkr_sz},
            line=dict(color="black"),
            name=f"{uid}: {sesh_subj_counter[uid]}: thresh change",
        )
    )
cum_patch_pref_epoch_session.update_layout(
    title="Cumulative Patch Preference by Quantile within Session",
    xaxis_title="Quantile",
    yaxis_title="Easy Patch Preference",
    legend_title="Session",
)
cum_patch_pref_epoch_session.show()

In [None]:
np.array((zidx, zidx + 1))

In [None]:
zidx

In [None]:
import dash
from dash import Dash, dash_table, dcc, html

bg_col = "#050505"
txt_col = "#f2f2f2"
plt_bg_col = "#0d0d0d"
tab_bg_col = "#003399"
tab_txt_col = "#f2f2f2"
table_max_height = "400px"
table_min_width = "1200px"
mrkr_sz = 14
color_dict = {
    "BAA-1103045": "rgb(31, 119, 180)",
    "BAA-1103047": "rgb(255, 127, 14)",
    "BAA-1103048": "rgb(44, 160, 44)",
    "BAA-1103049": "rgb(214, 39, 40)",
    "BAA-1103050": "rgb(148, 103, 189)",
}
# Set all relevant app.layout children names (for future color theme updates)
fig_names = [
    "weight_enter_session",
    "weight_diff_session",
    "weight_enter_subject",
    "weight_diff_subject",
    "duration_session",
    "post_thresh_dur_session",
    "pre_sampling_both_p_dur_session",
    "duration_subject",
    "post_thresh_dur_subject",
    "pre_sampling_both_p_dur_subject",
    "hard_patch_session",
    "hard_patch_subject",
    "wheel_session_abs",
    "wheel_session_norm",
    "wheel_subject_abs",
    "wheel_subject_norm",
    "pellet_session_abs",
    "pellet_session_norm",
    "pellet_subject_abs",
    "pellet_subject_norm",
    "prob_pels_session",
    "prob_pels_subject",
]
tab_names = []

In [None]:


data_table = dash_table.DataTable(
    id="data_table",
    data=df.to_dict("records"),
    columns=[{"name": i, "id": i} for i in df.columns],
    style_table={
        "overflowX": "auto",
        "overflowY": "auto",
        "maxHeight": table_max_height,
        "minWidth": table_min_width,
    },
    fixed_columns={"headers": True, "data": 2},
    fixed_rows={"headers": True},
    style_header={"fontWeight": "bold", "backgroundColor": plt_bg_col},
    style_cell={
        "backgroundColor": plt_bg_col,
        "color": txt_col,
        "textAlign": "left",
        "whiteSpace": "normal",
        "height": "auto",
        "minWidth": 60,
    },
)

In [None]:
display(data_table)

In [None]:
import dash

In [None]:
dash

In [None]:
dash.dash

In [None]:
from dash.dependencies import Input, Output, State, ClientsideFunction


In [None]:
import seaborn as sns


In [None]:
clear dash

In [None]:
s = list(sessions.itertuples())[-1]

In [None]:
s

In [None]:
    root = str(roots[0]) if np.any([s.id.endswith(sid) for sid in in_b2_210]) else str(roots[1])  # get root for current session
    harp_reader = reader.Harp(pattern="Patch1_35", columns=["TriggerPellet"])
    new_pellet_trig_bitmask = api.load(root, harp_reader, start=s.enter, end=s.exit).iloc[0, 0]
    new_pellet_trig_reader_p1 = reader.BitmaskEvent("Patch1_35", new_pellet_trig_bitmask, "TriggerPellet")
    new_pellet_trig_reader_p2 = reader.BitmaskEvent("Patch2_35", new_pellet_trig_bitmask, "TriggerPellet")
    p1 = api.load(root, new_pellet_trig_reader_p1, start=s.enter, end=s.exit)
    p2 = api.load(root, new_pellet_trig_reader_p2, start=s.enter, end=s.exit)

In [None]:
from pathlib import Path
from itertools import product

import dash
import dash_daq as daq
import ipdb
import json
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns

from dash import Dash, dash_table, dcc, html
from dash.dependencies import Input, Output, State, ClientsideFunction
from dash.development.base_component import ComponentRegistry
from dotmap import DotMap
from plotly.subplots import make_subplots

from aeon_analysis.presocial.presocial_dash import helpers

In [None]:
# Set some constants relating to the initialized colors/plotting
bg_col = "#050505"
txt_col = "#f2f2f2"
plt_bg_col = "#0d0d0d"
tab_bg_col = "#003399"
tab_txt_col = "#f2f2f2"
table_max_height = "400px"
table_min_width = "1200px"
mrkr_sz = 14
color_dict = {
    "BAA-1103045": "rgb(31, 119, 180)",
    "BAA-1103047": "rgb(214, 39, 40)",
    "BAA-1103048": "rgb(44, 160, 44)",
    "BAA-1103049": "rgb(148, 103, 189)",
    "BAA-1103050": "rgb(255, 127, 14)",
}
# Set all relevant app.layout children names (for future color theme updates)
fig_names = [
    "weight_enter_session",
    "weight_diff_session",
    "weight_enter_subject",
    "weight_diff_subject",
    "duration_session",
    "post_thresh_dur_session",
    "pre_sampling_both_p_dur_session",
    "duration_subject",
    "post_thresh_dur_subject",
    "pre_sampling_both_p_dur_subject",
    "hard_patch_session",
    "hard_patch_subject",
    "wheel_session_abs",
    "wheel_session_norm",
    "wheel_subject_abs",
    "wheel_subject_norm",
    "pellet_session_abs",
    "pellet_session_norm",
    "pellet_subject_abs",
    "pellet_subject_norm",
    "prob_pels_session",
    "prob_pels_subject",
    "patch_pref_epoch_session",
    "cum_patch_pref_epoch_session",
]

In [None]:
df_uniq_id_thresh = df[["id", "post_easy_rate", "post_hard_rate"]].drop_duplicates()
uniq_id_thresh_tits = df_uniq_id_thresh.apply(
    lambda row: ' '.join(row.values.astype(str)), axis=1).tolist()
uniq_id_thresh_figs = DotMap()

# Iterate over each unique ID-thresh sesh, get all corresponding sessions, create figure
# with corresponding number of axes, iterate over seshes, plot each ax.
for j, id_thresh in enumerate(df_uniq_id_thresh.itertuples()):
    cur_df = df[np.logical_and(
        df["id"] == id_thresh.id,
        df["post_easy_rate"] == id_thresh.post_easy_rate, 
        df["post_hard_rate"] == id_thresh.post_hard_rate
        )
    ]
    ncols = 3
    nrows = int(np.ceil(len(cur_df) / ncols))
    fig = make_subplots(rows=nrows, cols=ncols)
    for i, sesh in enumerate(cur_df.itertuples()):
        r, c = (i // ncols + 1, i % ncols + 1)
        lo, hi, tru, v_start, v_end = (
            sesh.cont_patch_pref.low_bound, 
            sesh.cont_patch_pref.high_bound, 
            sesh.cont_patch_pref.weasy_pref,
            sesh.cont_patch_pref.learned_start_idx,
            sesh.cont_patch_pref.learned_end_idx
        )
        fig.add_trace(
            go.Scatter(
                y=lo, mode='lines', line=dict(color='darkslategray', dash='dash'), name='low_bound'
            ), 
            row=r, col=c
        )
        fig.add_trace(
            go.Scatter(
                y=hi, mode='lines', line=dict(color='darkslategray', dash='dash'), name='high_bound'
            ), 
            row=r, col=c
        )
        fig.add_trace(
            go.Scatter(
                y=tru, mode='lines', line=dict(color=color_dict[sesh.id]), name='true'
            ),
            row=r, col=c
        )
        if v_start:
            fig.add_shape(
                type="line", x0=v_start, x1=v_start, y0=0, y1=1.1,
                line=dict(
                    color="deeppink",
                    dash="dash",
                    width=2,
                ),
                row=r, col=c
            )
            fig.add_shape(
                type="line", x0=v_end, x1=v_end, y0=0, y1=1.1,
                line=dict(
                    color="deeppink",
                    dash="dash",
                ),
                row=r, col=c
            )
        fig.update_layout(
            title_text=(f"{sesh.id}  easy_rate: {sesh.post_easy_rate}  hard_rate: {sesh.post_hard_rate}"),
        )
        fig.update_xaxes(title_text="Time (s)", row=r, col=c)
        fig.update_yaxes(title_text=f"{str(sesh.enter.date())}", row=r, col=c)
    uniq_id_thresh_figs[uniq_id_thresh_tits[j]] = fig

In [None]:
cur_df = df[np.logical_and(
        df["id"] == id_thresh.id,
        df["post_easy_rate"] == id_thresh.post_easy_rate, 
        df["post_hard_rate"] == id_thresh.post_hard_rate
        )
    ]

cur_df = df[(df["id"] == id_thresh.id) & (df["post_easy_rate"] == id_thresh.post_easy_rate) & (df["post_hard_rate"] == id_thresh.post_hard_rate)]

In [None]:
id_thresh.post_easy_rate, id_thresh.post_hard_rate

In [None]:
cur_df = df[(df["id"] == id_thresh.id) & (df["post_easy_rate"] == id_thresh.post_easy_rate) & (df["post_hard_rate"] == id_thresh.post_hard_rate)]
cur_df

In [None]:
cur_df = df[np.logical_and(
        df["id"] == id_thresh.id,
        df["post_hard_rate"] == id_thresh.post_hard_rate
        )
    ]
cur_df

In [None]:
cur_df.id, cur_df.enter, cur_df.post_easy_rate, cur_df.post_hard_rate

In [None]:
df_uniq_id_thresh

In [None]:
uniq_id_thresh_figs[uniq_id_thresh_tits[0]]

In [None]:
uniq_id_thresh_figs[uniq_id_thresh_tits[9]]

In [None]:
df_uniq_id_thresh

In [None]:
cur_df

In [None]:
i = 0
cur_df.cont_patch_pref.values[i].learned_start_idx
cur_df.cont_patch_pref.values[i].learned_end_idx

In [None]:
sesh.cont_patch_pref

In [None]:
sesh.cont_patch_pref.learned_start_idx

In [None]:
i = 0
sesh = cur_df.iloc[i]
fig = make_subplots(rows=nrows, cols=2)
r, c = (i // 2 + 1, i % 2 + 1)
lo, hi, tru, v_start, v_end = (
    sesh.cont_patch_pref.low_bound, 
    sesh.cont_patch_pref.high_bound, 
    sesh.cont_patch_pref.weasy_pref,
    sesh.cont_patch_pref.learned_start_idx,
    sesh.cont_patch_pref.learned_end_idx
)
fig.add_trace(
    go.Scatter(
        y=lo, mode='lines', line=dict(color='darkslategray', dash='dash'), name='low_bound'
    ), 
    row=r, col=c
)
fig.add_trace(
    go.Scatter(
        y=hi, mode='lines', line=dict(color='darkslategray', dash='dash'), name='high_bound'
    ), 
    row=r, col=c
)
fig.add_trace(
    go.Scatter(
        y=tru, mode='lines', line=dict(color=color_dict[sesh.id]), name='true'
    ),
    row=r, col=c
)
if v_start:
    fig.add_shape(
        type="line", x0=v_start, x1=v_start, y0=0, y1=1,
        line=dict(
            color="deeppink",
            dash="dash",
        ),
        row=r, col=c
    )
    fig.add_shape(
        type="line", x0=v_end, x1=v_end, y0=0, y1=1,
        line=dict(
            color="deeppink",
            dash="dash",
        ),
        row=r, col=c
    )
fig.update_layout(
    title_text=(f"{sesh.id}  easy_rate: {sesh.post_easy_rate}  hard_rate: {sesh.post_hard_rate}"),
)
fig.update_xaxes(title_text="Time (s)", row=r, col=c)
fig.update_yaxes(title_text=f"{str(sesh.enter.date())}", row=r, col=c)

In [None]:
v_end

In [None]:
cur_df.iloc[i]

In [None]:
df_uniq_id_thresh = df[["id", "post_easy_rate", "post_hard_rate"]].drop_duplicates()
uniq_id_thresh_tits = df_uniq_id_thresh.apply(
    lambda row: ' '.join(row.values.astype(str)), axis=1).tolist()
uniq_id_thresh_figs = DotMap()
uniq_id_thresh_tits

In [None]:
for j, id_thresh in enumerate(df_uniq_id_thresh.itertuples()):
    cur_df = df[np.logical_and(
        df["id"] == id_thresh.id, 
        df["post_easy_rate"] == id_thresh.post_easy_rate, 
        df["post_hard_rate"] == id_thresh.post_hard_rate
        )
    ]
    nrows = int(np.ceil(len(cur_df) / 2))
    fig = make_subplots(rows=nrows, cols=2)
    for i, sesh in enumerate(cur_df.itertuples()):
        r, c = (i // 2 + 1, i % 2 + 1)
        lo, hi, tru, v_start, v_end = (
            sesh.cont_patch_pref.low_bound, 
            sesh.cont_patch_pref.high_bound, 
            sesh.cont_patch_pref.weasy_pref,
            sesh.cont_patch_pref.learned_start_idx,
            sesh.cont_patch_pref.learned_end_idx
        )
        fig.add_trace(
            go.Scatter(
                y=lo, mode='lines', line=dict(color='darkslategray', dash='dash'), name='low_bound'
            ), 
            row=r, col=c
        )
        fig.add_trace(
            go.Scatter(
                y=hi, mode='lines', line=dict(color='darkslategray', dash='dash'), name='high_bound'
            ), 
            row=r, col=c
        )
        fig.add_trace(
            go.Scatter(
                y=tru, mode='lines', line=dict(color=color_dict[sesh.id]), name='true'
            ),
            row=r, col=c
        )
        if v_start:
            fig.add_shape(
                type="line", x0=v_start, x1=v_start, y0=0, y1=1.1,
                line=dict(
                    color="deeppink",
                    dash="dash",
                ),
                row=r, col=c
            )
            fig.add_shape(
                type="line", x0=v_end, x1=v_end, y0=0, y1=1.1,
                line=dict(
                    color="deeppink",
                    dash="dash",
                ),
                row=r, col=c
            )
        fig.update_layout(
            title_text=(f"{sesh.id}  easy_rate: {sesh.post_easy_rate}  hard_rate: {sesh.post_hard_rate}"),
        )
        fig.update_xaxes(title_text="Time (s)", row=r, col=c)
        fig.update_yaxes(title_text=f"{str(sesh.enter.date())}", row=r, col=c)
    uniq_id_thresh_figs[uniq_id_thresh_tits[j]] = fig

In [None]:
fig

In [None]:
eval("uniq_id_thresh_figs['BAA-1103050 0.01 0.0025']")

In [None]:
df.drop("cont_patch_pref", axis=1)