In [None]:
# Enable IPython's autoreload so edits to imported project modules are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
# %%
import flexiznam as flz
from v1_depth_analysis.v1_manuscript_2023 import get_session_list
from v1_depth_analysis.eye_tracking.analysis import get_data
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

# %%
# Flexilims project the sessions are read from
PROJECT = "hey2_3d-vision_foodres_20220101"
# Mice kept for analysis; a session belongs to a mouse when its name starts with "<mouse>_"
valid_mice = [
    "PZAH6.4b",
    "PZAG3.4f",
]
# Session whose raw DLC result is kept aside as an example
example_session = "PZAH6.4b_S20220419"

In [None]:
flexilims_session = flz.get_flexilims_session(project_id=PROJECT)
# Filters forwarded to get_sessions: drop open-loop sessions, keep pure closed-loop ones, V1 only
session_filters = dict(
    exclude_sessions=(),
    exclude_openloop=True,
    exclude_pure_closedloop=False,
    v1_only=True,
)
sessions = get_session_list.get_sessions(flexilims_session=flexilims_session, **session_filters)
print(f"Found {len(sessions)} sessions for closed loop only")


In [None]:

# Keep only sessions whose mouse (the prefix before the first '_') is in valid_mice
sessions = list(filter(lambda sess_name: sess_name.split('_')[0] in valid_mice, sessions))
print(f"Found {len(sessions)} sessions for valid mice")


In [None]:
# True: load the cached pickles written by the else-branch below.
# False: recompute gaze data for every session via get_data and refresh the cache.
reload_from_disk = True

target_folder = flz.get_data_root(which='processed', project=PROJECT) / PROJECT / 'Analysis' / 'eye_tracking'
target_folder.mkdir(exist_ok=True, parents=True)
if reload_from_disk:
    # pd.read_pickle can execute arbitrary code: only load caches this notebook wrote itself
    all_data = pd.read_pickle(target_folder / 'all_data.pkl')
    example_dlc_res = pd.read_pickle(target_folder / 'example_dlc_res.pkl')
else:
    project_recordings = flz.get_entities(datatype='recording', flexilims_session=flexilims_session)
    sessions_df = pd.DataFrame([flz.get_entity(name=s, flexilims_session=flexilims_session) for s in sessions])
    all_data = {}
    example_dlc_res = None
    problematic_sessions = []
    # iterrows() yields the integer row index, not the session name, so use sess_df['name']
    for _, sess_df in sessions_df.iterrows():
        sess_name = sess_df['name']
        print(sess_name)
        try:
            # The recording lookup is inside the try so a session without a matching
            # recording is reported as problematic instead of aborting the whole loop
            recording = project_recordings[project_recordings.origin_id == sess_df['id']]
            recording = recording[recording.protocol == 'SpheresPermTubeReward'].iloc[0]
            gaze_data, dlc_res = get_data(
                project=PROJECT,
                mouse=sess_df.genealogy[0],
                session=sess_df.genealogy[-1],
                recording=recording.genealogy[-1],
                filt_window=3,
                verbose=False,)
            gaze_data['session'] = sess_name
            all_data[sess_name] = gaze_data
        except Exception as e:
            print(f"Problem with {sess_name}: {e}")
            problematic_sessions.append(sess_name)
            # Skip the example check: dlc_res would be stale (previous session) or undefined
            continue
        if sess_name == example_session:
            example_dlc_res = dlc_res
    if problematic_sessions:
        print(f"Skipped {len(problematic_sessions)} sessions: {problematic_sessions}")
    all_data = pd.concat(all_data, names=['session'], ignore_index=True)
    all_data.to_pickle(target_folder / 'all_data.pkl')
    if example_dlc_res is None:
        print(f"Example session {example_session} was not loaded; example_dlc_res.pkl not written")
    else:
        example_dlc_res.to_pickle(target_folder / 'example_dlc_res.pkl')



In [None]:
from v1_depth_analysis.eye_tracking.analysis import get_saccades

# Saccade-detection parameters, also reused by the single-session diagnostic cells further down
filter_window = 5
threshold = 70
# One saccade table per session, keyed by session name
saccades_by_sess = {
    sess_name: get_saccades(sess_gaze, threshold=threshold, filter_window=filter_window)
    for sess_name, sess_gaze in all_data.groupby('session')
}


In [None]:
# Make a dataframe of number of saccades per trial. Each row is one (session, trial) with:
# - the trial depth in cm;
# - the first and last harptime of the trial;
# - the number of saccades starting inside the trial;
# - the derived trial duration and saccade rate.
sacc_by_trials = []
for sess_name, sess_df in all_data.groupby('session'):
    sacc_df = saccades_by_sess[sess_name]
    trials = sess_df['trial'].dropna().unique()
    for trial in trials:
        trial_df = sess_df[sess_df['trial'] == trial]
        trial_start = trial_df.harptime.iloc[0]
        trial_end = trial_df.harptime.iloc[-1]
        # Most frequent depth of the trial. This works for single-sample trials (the former
        # `.iloc[1]` raised IndexError there) and is robust to a minority of samples with a
        # different depth. Depth is in metres: round rather than truncate, so float error
        # (e.g. 0.29 * 100 == 28.999...) cannot shift it down by 1 cm.
        depth_cm = int(np.round(trial_df.depth.mode().iloc[0] * 100))
        # Saccades whose onset falls inside [trial_start, trial_end], bounds inclusive
        nsaccades = int(sacc_df['start_time'].between(trial_start, trial_end).sum())
        sacc_by_trials.append(dict(trial=trial, session=sess_name, depth=depth_cm,
                                   trial_start=trial_start, trial_end=trial_end,
                                   nsaccades=nsaccades))
sacc_by_trials = pd.DataFrame(sacc_by_trials)
sacc_by_trials['trial_duration'] = sacc_by_trials['trial_end'] - sacc_by_trials['trial_start']
sacc_by_trials['saccade_rate'] = sacc_by_trials['nsaccades'] / sacc_by_trials['trial_duration']



In [None]:
# Inspect column dtypes; the float64 columns are the gaze variables aggregated per trial further down
all_data.dtypes

In [None]:
# Median of every float64 column per (session, trial, depth), group keys restored as columns.
# is_num is computed here so this cell does not depend on the later plotting cell having run.
is_num = all_data.select_dtypes('float64').columns.drop(['trial', 'depth'])
bytrialbydepth = all_data.groupby(['session', 'trial', 'depth'])[list(is_num)].aggregate('median').reset_index()
bytrialbydepth

In [None]:
# Median of every float64 column per (session, trial, depth), keeping the MultiIndex.
# is_num is computed here so this cell does not depend on the later plotting cell having run.
is_num = all_data.select_dtypes('float64').columns.drop(['trial', 'depth'])
bytrialbydepth = all_data.groupby(['session', 'trial', 'depth'])[list(is_num)].aggregate('median')
bytrialbydepth

In [None]:
# Shape of the current bytrialbydepth after reset_index (depends on which of the cells above ran last)
bytrialbydepth.reset_index().shape

In [None]:
# Shape of the per-(depth, session) mean saccade table. It is computed inline because `bysess`
# itself is only defined in the plotting cell below; this is the same aggregation as there.
sacc_by_trials.groupby(['depth', 'session']).aggregate('mean').reset_index().shape

In [None]:
import seaborn as sns

# Per-session mean of the per-trial saccade table, one row per (depth, session)
bysess = sacc_by_trials.groupby(['depth', 'session']).aggregate('mean').reset_index()

# Median of every float64 gaze variable per trial, then mean across trials per (session, depth)
is_num = all_data.select_dtypes('float64').columns.drop(['trial', 'depth'])
bytrialbydepth = all_data.groupby(['session', 'trial', 'depth'])[list(is_num)].aggregate('median').reset_index()
bytrialbydepth = bytrialbydepth.groupby(['session', 'depth']).aggregate('mean').reset_index()
# Depth is stored in metres: round (not truncate) to cm so float error cannot drop it by 1 cm
bytrialbydepth['depth'] = np.round(bytrialbydepth['depth'] * 100).astype(int)

# One colour per depth, sized from the data instead of assuming exactly 5 depths
velocity_palette = sns.color_palette('cool_r', bytrialbydepth['depth'].nunique())
n_rate_depths = bysess['depth'].nunique()
rate_palette = sns.color_palette('cool_r', n_rate_depths)

fig = plt.figure(figsize=(6, 6))

# Bottom-right panel: per-session average of per-trial median velocity, by depth
ax = fig.add_subplot(2, 2, 4)
sns.boxplot(ax=ax, data=bytrialbydepth, x='depth', y='velocity', hue='depth', palette=velocity_palette,
            saturation=1, legend=False, showfliers=False)
sns.stripplot(ax=ax, data=bytrialbydepth, x='depth', y='velocity', hue='depth', palette=velocity_palette,
              legend=False, edgecolor='k', linewidth=1, alpha=0.5)
ax.set_ylabel('Velocity (degrees/s)')
ax.set_xlabel('Depth (cm)')

# Bottom-left panel: per-session saccade rate, by depth
ax = fig.add_subplot(2, 2, 3)
sns.boxplot(ax=ax, data=bysess, x='depth', y='saccade_rate', hue='depth', palette=rate_palette, saturation=1)
sns.stripplot(ax=ax, data=bysess, x='depth', y='saccade_rate', hue='depth', palette=rate_palette,
              legend=False, edgecolor='k', linewidth=1, alpha=0.5)
ax.set_ylabel('Saccade Rate (Hz)')
ax.set_xlabel('Depth (cm)')
# Replace seaborn's default legend (if any) with a single row of depths above the panel
default_legend = ax.get_legend()
if default_legend is not None:
    default_legend.remove()
ax.legend(loc='upper left', ncol=n_rate_depths, columnspacing=0.2, bbox_to_anchor=(-0.2, 1.2, 0, 0))
fig.subplots_adjust(wspace=0.3)


In [None]:
# NOTE(review): `trial_df` leaks from the per-trial saccade-count loop (last trial of the last
# session), so this only works after that cell has run; it shows that trial's first sample.
trial_df.iloc[0]

In [None]:
# For each session: sampling rate of the gaze data (from harptime steps), recording length
# and overall saccade rate. A new name is used so the filtered `sessions` list defined at
# the top of the notebook is not overwritten.
session_names = sorted(all_data['session'].unique())
rates = np.zeros(len(session_names))  # frame rate (Hz) per session, in session_names order
for isess, sess_name in enumerate(session_names):
    gaze_data = all_data[all_data['session'] == sess_name]
    frame_rate = 1 / np.nanmedian(np.diff(gaze_data['harptime']))
    rec_length = gaze_data['harptime'].max() - gaze_data['harptime'].min()
    # Per-session saccade tables live in saccades_by_sess. `saccades` is a single-session
    # DataFrame created in a later cell, so indexing it by session name was wrong.
    saccade_rate = len(saccades_by_sess[sess_name]) / rec_length
    print(f"{sess_name}: {frame_rate:.2f} Hz, {rec_length:.2f} s, {saccade_rate:.2f} saccades/s")
    rates[isess] = frame_rate


In [None]:

# Single-session diagnostic: recompute median-filtered gaze angles and their speed, and
# detect saccades with the same threshold/filter_window as the per-session analysis above
DIAGNOSTIC_SESSION = 'PZAG3.4f_S20220421'
gaze_data = all_data[all_data['session'] == DIAGNOSTIC_SESSION]
t0 = gaze_data['harptime'].min()

# Both angles are interpolated before the rolling median. With the default min_periods,
# a single NaN blanks every window containing it, so elevation is now treated like azimuth.
azi = gaze_data['azimuth'].interpolate()
filt_azi = azi.rolling(window=filter_window, center=True).median().interpolate()
ele = gaze_data['elevation'].interpolate()
filt_ele = ele.rolling(window=filter_window, center=True).median().interpolate()

# Sample-to-sample angular displacement of the filtered gaze, divided by the harptime step
displacement = np.sqrt(filt_azi.diff() ** 2 + filt_ele.diff() ** 2)
velocity = displacement / gaze_data['harptime'].diff()

saccades = get_saccades(gaze_data, threshold=threshold, filter_window=filter_window)


In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Interactive check of saccade detection for the diagnostic session:
#   row 1: recomputed velocity, detection threshold and detected saccade peaks
#   row 2: raw and filtered azimuth (centred on the filtered mean), displacement, saccade onsets
t_rel = gaze_data['harptime'] - t0
azi_offset = filt_azi.mean()

fig = make_subplots(rows=2, cols=1, shared_xaxes=True)

fig.add_trace(go.Scatter(x=t_rel, y=gaze_data['azimuth'] - azi_offset, name='azimuth'), row=2, col=1)
fig.add_trace(go.Scatter(x=t_rel, y=filt_azi - azi_offset, name='filtered azimuth'), row=2, col=1)
fig.add_trace(go.Scatter(x=t_rel, y=displacement, name='displacement'), row=2, col=1)
fig.add_trace(go.Scatter(x=t_rel, y=velocity, name='velocity'), row=1, col=1)
fig.add_trace(go.Scatter(
    x=t_rel,
    y=np.ones_like(displacement) * threshold,
    name='threshold',
    line=dict(color='black', width=1),
), row=1, col=1)

# Both saccade traces share a single legend entry that toggles them together
saccade_marker = dict(size=10, color='red')
fig.add_scatter(x=saccades.start_time - t0,
                y=saccades.peak_velocity,
                mode='markers',
                marker=saccade_marker,
                name='saccades',
                legendgroup='saccades',
                row=1, col=1)
fig.add_scatter(x=saccades.start_time - t0,
                y=np.ones(len(saccades)) * 10,
                mode='markers',
                marker=saccade_marker,
                name='saccades',
                legendgroup='saccades',
                showlegend=False,
                row=2, col=1)

fig.update_layout(height=600, width=1200)
fig.update_yaxes(range=[-20, 20], row=2, col=1)
fig.update_yaxes(range=[0, threshold * 3], row=1, col=1)
# Zoom on a 200 s window; the x axes are shared, so both rows follow
fig.update_xaxes(range=[900, 1100], row=2, col=1)
fig.show()

In [None]:
# Saccade table returned by get_saccades for the single diagnostic session
saccades