In [None]:
# Reload edited modules automatically before each cell runs, so changes to
# imported analysis code are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
# %%
import flexiznam as flz
from v1_depth_analysis.v1_manuscript_2023 import get_session_list
from v1_depth_analysis.eye_tracking.analysis import get_data
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

# %%
# Flexilims project holding all recordings analysed here
PROJECT = "hey2_3d-vision_foodres_20220101"
# Only sessions whose mouse (session-name prefix) is listed here are analysed
valid_mice = ["PZAH6.4b", "PZAG3.4f"]
# Session used for the example panels of the figures
example_session = "PZAG3.4f_S20220421"

In [None]:
# Sessions removed from the analysis, mapped to the reason for exclusion
sess_to_exclude = dict()
sess_to_exclude["PZAH6.4b_S20220516"] = "Mouse squint too much, cropping issue"
sess_to_exclude["PZAH6.4b_S20220429"] = "Two eye position. Maybe reflection ill detected"

In [None]:
# Build the list of sessions to analyse: V1 sessions without open-loop
# recordings, minus the manually excluded ones, restricted to `valid_mice`
flexilims_session = flz.get_flexilims_session(project_id=PROJECT)
candidate_sessions = get_session_list.get_sessions(
    flexilims_session=flexilims_session,
    exclude_sessions=(),
    exclude_openloop=True,
    exclude_pure_closedloop=False,
    v1_only=True,
)
sessions = [name for name in candidate_sessions if name not in sess_to_exclude]
print(f"Found {len(sessions)} sessions for closed loop only")
# the mouse id is the part of the session name before the first underscore
sessions = [name for name in sessions if name.split('_')[0] in valid_mice]
print(f"Found {len(sessions)} sessions for valid mice")

In [None]:
# Load gaze data for every session, either from the on-disk cache or by running
# `get_data` on each session's SpheresPermTubeReward recording (then caching it).
reload_from_disk = False

target_folder = flz.get_data_root(which='processed', project=PROJECT) / PROJECT / 'Analysis' / 'eye_tracking'
target_folder.mkdir(exist_ok=True, parents=True)
if reload_from_disk:
    all_data = pd.read_pickle(target_folder / 'all_data.pkl')
    example_dlc_res = pd.read_pickle(target_folder / 'example_dlc_res.pkl')
else:
    project_recordings = flz.get_entities(datatype='recording', flexilims_session=flexilims_session)
    sessions_df = pd.DataFrame([flz.get_entity(name=s, flexilims_session=flexilims_session) for s in sessions])
    all_data = {}
    example_dlc_res = None
    problematic_sessions = []
    for _, sess_df in sessions_df.iterrows():
        sess_name = sess_df['name']
        print(sess_name)
        try:
            # a session without the expected recording is logged, not fatal
            recording = project_recordings[project_recordings.origin_id == sess_df['id']]
            recording = recording[recording.protocol == 'SpheresPermTubeReward'].iloc[0]
            gaze_data, dlc_res, dlc_ds = get_data(
                project=PROJECT,
                mouse=sess_df.genealogy[0],
                session=sess_df.genealogy[-1],
                recording=recording.genealogy[-1],
                filt_window=3,
                verbose=False,)
        except Exception as e:
            print(f"Problem with {sess_name}: {e}")
            problematic_sessions.append(sess_name)
            # skip: dlc_res would otherwise be stale from a previous session
            continue
        gaze_data['session'] = sess_name
        all_data[sess_name] = gaze_data
        # match on the name column so this does not depend on how sessions_df is indexed
        if sess_name == example_session:
            example_dlc_res = dlc_res
    if not all_data:
        raise RuntimeError(f"No session could be loaded; problematic: {problematic_sessions}")
    # ignore_index gives a unique RangeIndex, relied on by later `.loc[sess_df.index]` writes
    all_data = pd.concat(all_data, names=['session'], ignore_index=True)
    all_data.to_pickle(target_folder / 'all_data.pkl')
    if example_dlc_res is not None:
        example_dlc_res.to_pickle(target_folder / 'example_dlc_res.pkl')
    else:
        print(f"Example session {example_session} was not loaded; example_dlc_res not saved")



In [None]:
# Drop the manually excluded sessions. This is needed even though the session
# list was already filtered, because a cache read from disk may still contain them.
excluded_names = list(sess_to_exclude)
all_data = all_data.loc[~all_data['session'].isin(excluded_names)]

In [None]:
from v1_depth_analysis.eye_tracking.analysis import get_saccades

# Saccade detection parameters, passed unchanged to `get_saccades`
filter_window = 5
threshold = 70
# Detect saccades independently for each session, keyed by session name
saccades_by_sess = {
    name: get_saccades(frames, threshold=threshold, filter_window=filter_window)
    for name, frames in all_data.groupby('session')
}


In [None]:
# Make a dataframe of number of saccade per trial
sacc_by_trials = []
for sess_name, sess_df in all_data.groupby('session'):
    sacc_df = saccades_by_sess[sess_name]
    trials = sess_df['trial'].dropna().unique()
    for trial in trials:
        trial_df = sess_df[sess_df['trial']==trial]
        tdict = dict(trial=trial, session =sess_name, depth=int(trial_df.depth.iloc[1] * 100),
                     trial_start=trial_df.harptime.iloc[0], trial_end=trial_df.harptime.iloc[-1])
        tdict['nsaccades'] = sacc_df[(sacc_df['start_time']>=tdict['trial_start']) & (sacc_df['start_time']<=tdict['trial_end'])].shape[0]
        sacc_by_trials.append(tdict)
sacc_by_trials = pd.DataFrame(sacc_by_trials)
sacc_by_trials['trial_duration'] = sacc_by_trials['trial_end'] - sacc_by_trials['trial_start']
sacc_by_trials['saccade_rate'] = sacc_by_trials['nsaccades'] / sacc_by_trials['trial_duration']



In [None]:
# make a "position relative to median" column
dtypes = all_data.dtypes 
is_num = dtypes[dtypes== 'float64'].index.copy()
is_num = is_num.drop('trial')
is_num = is_num.drop('depth')

all_data['az_rel2med'] = np.nan
all_data['el_rel2med'] = np.nan
for sess, sess_df in all_data.groupby('session'):
    valid_part = sess_df[(sess_df.valid)&(~np.isnan(sess_df.depth))]
    med_pos = np.nanmedian(valid_part[['azimuth_filt', 'elevation_filt']].values, axis=0)
    all_data.loc[sess_df.index, 'az_rel2med'] = sess_df['azimuth_filt'].values - med_pos[0]
    all_data.loc[sess_df.index, 'el_rel2med'] = sess_df['elevation_filt'].values - med_pos[1]

# add binned version of running speed
bins = np.hstack([-np.inf, np.arange(0, 1, 0.2), np.inf])
all_data['rs_bin'] = pd.cut(all_data['RS'], bins=bins, labels=False)

In [None]:
# get example session data
# Get example frame
# Load the example session's gaze/DLC data and grab one cropped grayscale
# camera frame for the example panel.
import cv2
from wayla import eye_io

start_frame = 45344
project_recordings = flz.get_entities(datatype='recording', flexilims_session=flexilims_session)
sess_df = flz.get_entity(name=example_session, datatype='session', flexilims_session=flexilims_session)
recording = project_recordings[project_recordings.origin_id==sess_df['id']]
recording = recording[recording.protocol=='SpheresPermTubeReward'].iloc[0]

camera = flz.Dataset.from_flexilims(name=f"{recording.name}_right_eye_camera", flexilims_session=flexilims_session)
gaze_data, dlc_res, dlc_ds = get_data(
    project=PROJECT,
    mouse=sess_df.genealogy[0],
    session=sess_df.genealogy[-1],
    recording=recording.genealogy[-1],
    filt_window=3,
    verbose=False,)
eye_params = eye_io.get_eye_parameters(camera, flexilims_session)
# remove the scorer multiindex column
dlc_res.columns = dlc_res.columns.droplevel('scorer')

video_file = camera.path_full / camera.extra_attributes["video_file"]

cropping = dlc_ds.extra_attributes["cropping"]
cam_data = cv2.VideoCapture(str(video_file))
try:
    # NOTE(review): this reads video frame start_frame - 1 (0-based) while the
    # overlays use dlc_res/gaze_data.iloc[start_frame]; confirm the offset is intended.
    cam_data.set(cv2.CAP_PROP_POS_FRAMES, start_frame - 1)
    ret, frame = cam_data.read()
finally:
    # release the capture even if reading fails
    cam_data.release()
if not ret:
    raise IOError(f"Could not read frame {start_frame - 1} from {video_file}")
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
# cropping[2:4] bound the rows and cropping[0:2] the columns (presumably the DLC crop)
gray = gray[cropping[2] : cropping[3], cropping[0] : cropping[1]]

In [None]:
# use plotly to plot azimuth_filt 
# Interactive view of the filtered azimuth and elevation of the example session
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

example_data = all_data[all_data.session == example_session]
frame_idx = np.arange(len(example_data))
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.02)
panels = (('azimuth_filt', 'Azimuth'), ('elevation_filt', 'Elevation'))
for row, (column, label) in enumerate(panels, start=1):
    fig.add_trace(go.Scatter(x=frame_idx, y=example_data[column], mode='lines', name=label), row=row, col=1)
fig.update_layout(title_text="Azimuth and Elevation", showlegend=False)
fig.show()

In [None]:
# seaborn must be imported here: this cell runs before the figure cell that imports it
import seaborn as sns

# Preview the 5-colour depth palette used by the figures (presumably one colour per depth)
palette = sns.color_palette('cool_r', 5)
palette[1]

In [None]:
# Depths (cm) present across all loaded sessions; computed from all_data so the
# cell does not depend on a variable defined further down the notebook
sorted(np.round(all_data['depth'].dropna().unique() * 100).astype(int))

In [None]:
import seaborn as sns
from wayla.diagnostics import plot_ellipse_on_frame

# Summary figure: example eye frame, example azimuth trace coloured by depth,
# and per-session eye velocity and saccade rate as a function of depth.

# Mean of the per-trial saccade measures for each (depth, session) pair
bysess = sacc_by_trials.groupby(['depth', 'session']).aggregate('mean').reset_index()
fig = plt.figure(figsize=(6, 3))
palette = sns.color_palette('cool_r', 5)

# Bottom-left panel: mean eye velocity for each (session, depth) pair
ax = fig.add_subplot(2,2,3)
#bytrialbydepth = all_data.groupby(['session', 'trial', 'depth'])[is_num].aggregate('mean').reset_index()
#bytrialbydepth = bytrialbydepth.groupby(['session', 'depth']).aggregate('median').reset_index()
bytrialbydepth = all_data.groupby(['session', 'depth'])[is_num].aggregate('mean').reset_index()
# Round rather than truncate to integer cm (products such as 0.58 * 100 land just
# below the integer), matching how trial depths are converted further down
bytrialbydepth['depth'] = np.round(bytrialbydepth['depth'] * 100).astype(int)
sns.boxplot(ax=ax,data=bytrialbydepth, x='depth', y='velocity', hue='depth', palette=palette, saturation=1, legend=False,
showfliers=False)
sns.stripplot(ax=ax, data=bytrialbydepth, x='depth', y='velocity', hue='depth', palette=palette,  
              legend=False, edgecolor='k', linewidth=1, alpha=0.5)
ax.set_ylabel('Velocity (degrees/s)')
ax.set_xlabel('Depth (cm)')

# Bottom-right panel: saccade rate per session and depth; its legend serves the whole figure
ax = fig.add_subplot(2,2,4)
sns.boxplot(ax=ax,data=bysess, x='depth', y='saccade_rate', hue='depth',palette=palette,  saturation=1,
showfliers=False)
sns.stripplot(ax=ax, data=bysess, x='depth', y='saccade_rate', hue='depth', palette=palette, 
              legend=False, edgecolor='k', linewidth=1, alpha=0.5)
ax.set_ylabel('Saccade Rate (Hz)')
ax.set_xlabel('Depth (cm)')
ax.get_legend().remove()
ax.legend(loc='upper left', ncol=5, columnspacing=0.2, bbox_to_anchor=(-0.2,1.5,0,0),
          fontsize=8)

# Top-left panel: example camera frame with DLC points, pupil ellipse and gaze vector
ax = fig.add_axes([0, 0.5, 0.3, 0.4])
ax.imshow(gray, cmap='gray', vmin=0, vmax=150)
eye_pos = dlc_res.iloc[start_frame][[f'eye_{i}' for i in range(1, 13)]]
ref = gaze_data.iloc[start_frame][['reflection_x', 'reflection_y']].values
# eye centre is stored relative to the reflection position
eye_centre = eye_params['eye_centre'] + ref
pupil_center = gaze_data.iloc[start_frame][['centre_x', 'centre_y']].values
ax.plot([eye_centre[0], pupil_center[0]], [eye_centre[1], pupil_center[1]], c='dodgerblue')
ax.scatter(*eye_centre, s=5, c='k', zorder=5)
ax.scatter(eye_pos.xs('x', level=1), eye_pos.xs('y', level=1), s=1, c='indianred', zorder=11)
plot_ellipse_on_frame(
                ax,
                start_frame,
                gaze_data,
                origin='uncropped',
                dlc_res=dlc_res,
                reflection_fit=None,
                color='dodgerblue',
                alpha=1,
                lw=0.5,
                zorder=12,
            )
ax.set_ylim(450, 60)
ax.set_xlim(10, 330)
ax.axis('off')

# Top-right panel: azimuth around the example frame, trials shaded by depth
ax = plt.subplot2grid(fig=fig, shape=(2,4), loc=(0,1), colspan=3)
b = start_frame - 1000
e = start_frame + 1000
gd = gaze_data.iloc[b:e]
t0 = gaze_data.harptime.iloc[b]
med_pos = np.nanmedian(gaze_data[['azimuth_filt', 'elevation_filt']].values, axis=0)

ax.plot(gd.harptime - t0, gd.azimuth_filt - med_pos[0], c='k', label='Azimuth')
# Use the palette entry the box plots give each depth: seaborn maps the palette
# onto the hue levels in sorted order, so index into the sorted depths of the
# whole dataset rather than only the depths that occur in this time window.
depths = sorted(bytrialbydepth['depth'].unique())
for _, tdf in gd.groupby('trial'):
    t_start = tdf.harptime.iloc[0] - t0
    t_end = tdf.harptime.iloc[-1] - t0
    ax.axvline(t_start, c='gray', lw=0.5)
    ax.axvline(t_end, c='gray', lw=0.5)
    d = int(np.round(tdf.depth.iloc[0] * 100))
    ax.axvspan(t_start, t_end, alpha=0.5, color=palette[depths.index(d)])
ax.set_ylabel('Azimuth (degrees)')
ax.set_xlabel('Time (s)')
ax.set_xlim(0, 120)
ax.set_ylim(-10, 10)
# ax.plot(gaze_data.harptime.iloc[b:e] - t0, gaze_data.elevation_filt.iloc[b:e] - med_pos[1], c='indianred', label='Elevation')

for x in fig.axes:
    x.spines['top'].set_visible(False)
    x.spines['right'].set_visible(False)
fig.subplots_adjust(hspace=1, wspace=0.5)


In [None]:
# anova of bysess saccaade rate
# One-way ANOVA (type II) of saccade rate and eye velocity across depths
import statsmodels.api as sm
import statsmodels.formula.api as smf

# NOTE(review): each session contributes one value per depth (repeated measures),
# but this model treats rows as independent — consider adding C(session) or a mixed model.
tests = (
    ('saccade_rate ~ C(depth)', bysess),
    ('velocity ~ C(depth)', bytrialbydepth),
)
for formula, frame in tests:
    model = smf.ols(formula, data=frame)
    results = model.fit()
    aov_table = sm.stats.anova_lm(results, typ=2)
    print(aov_table)

In [None]:
# Distribution of raw azimuth for each session. No hue is assigned, so no palette
# is passed (seaborn would ignore it and warn); trailing `;` hides the FacetGrid repr.
sns.displot(all_data, x="azimuth", row='session', kind="kde", fill=True);

In [None]:

# KDE of azimuth relative to the session median, split by depth (hue),
# session (rows) and running-speed bin (columns), each density normalised separately
sns.displot(
    data=all_data,
    x='az_rel2med',
    hue='depth',
    row='session',
    col='rs_bin',
    kind='kde',
    fill=True,
    palette=sns.color_palette('cool_r', 5),
    common_norm=False,
)

In [None]:
# Per-session frame rate, recording length and overall saccade rate.
# `rates` holds the frame rate (Hz) of each session, in `sessions` order.
sessions = sorted(all_data['session'].unique())
rates = np.zeros(len(sessions))
for isess, sess_name in enumerate(sessions):
    # local name avoids clobbering the example-session `gaze_data`
    sess_gaze = all_data[all_data['session'] == sess_name]
    frame_rate = 1/np.nanmedian(np.diff(sess_gaze['harptime']))
    rec_length = sess_gaze['harptime'].max() - sess_gaze['harptime'].min()
    # per-session saccade tables; `saccades` only holds the example session and
    # is not yet defined at this point on a fresh kernel
    saccade_rate = len(saccades_by_sess[sess_name]) / rec_length
    print(f"{sess_name}: {frame_rate:.2f} Hz, {rec_length:.2f} s, {saccade_rate:.2f} saccades/s")
    rates[isess] = frame_rate


In [None]:

# Recompute the saccade-detection intermediates for the example session, for plotting
gaze_data = all_data[all_data['session'] == example_session]
t0 = gaze_data['harptime'].min()
azi = gaze_data['azimuth'].interpolate()
filt_azi = azi.rolling(window=filter_window, center=True).median().interpolate()
# NOTE(review): unlike azimuth, elevation is not interpolated before the rolling
# median, so windows containing NaN yield NaN before the final interpolation
ele = gaze_data['elevation']
filt_ele = ele.rolling(window=filter_window, center=True).median().interpolate()

# frame-to-frame angular displacement and speed of the filtered gaze
displacement = np.sqrt(filt_azi.diff()**2 + filt_ele.diff()**2)
velocity = displacement / gaze_data['harptime'].diff()

saccades = get_saccades(gaze_data, threshold=threshold, filter_window=filter_window)


In [None]:
# Diagnostic plot of saccade detection on the example session:
# top = velocity with threshold and detected saccades, bottom = azimuth and displacement.
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Set True to overlay the stored `velocity` column and the recomputed velocity
show_velocity_comparison = False

t_rel = gaze_data['harptime'] - t0
fig = make_subplots(rows=2, cols=1, shared_xaxes=True)

if show_velocity_comparison:
    fig.add_trace(go.Scatter(x=t_rel, y=gaze_data['velocity'], name='stored velocity'), row=1, col=1)
    fig.add_trace(go.Scatter(x=t_rel, y=velocity, name='recomputed velocity'), row=1, col=1)

# Bottom panel: raw and filtered azimuth (both minus the filtered mean) and displacement
fig.add_trace(go.Scatter(x=t_rel, y=gaze_data['azimuth'] - filt_azi.mean(), name='azimuth'), row=2, col=1)
fig.add_trace(go.Scatter(x=t_rel, y=filt_azi - filt_azi.mean(), name='filtered azimuth'), row=2, col=1)
fig.add_trace(go.Scatter(x=t_rel, y=displacement, name='displacement'), row=2, col=1)

# Top panel: velocity and the detection threshold
fig.add_trace(go.Scatter(x=t_rel, y=velocity, name='velocity'), row=1, col=1)
fig.add_trace(go.Scatter(
    x=t_rel,
    y=np.ones_like(displacement) * threshold,
    name='threshold',
    line=dict(color='black', width=1),
), row=1, col=1)

# Detected saccades: at peak velocity on top, at a fixed height on the bottom panel;
# both share one legend entry
fig.add_scatter(x=saccades.start_time - t0,
                y=saccades.peak_velocity,
                mode='markers',
                marker=dict(size=10, color='red'),
                name='saccades',
                legendgroup='saccades',
                row=1, col=1)
fig.add_scatter(x=saccades.start_time - t0,
                y=np.ones(len(saccades)) * 10,
                mode='markers',
                marker=dict(size=10, color='red'),
                name='saccades',
                legendgroup='saccades',
                showlegend=False,
                row=2, col=1)

fig.update_layout(height=600, width=1200)
fig.update_yaxes(range=[-20, 20], row=2, col=1)
fig.update_yaxes(range=[0, threshold*3], row=1, col=1)
fig.update_xaxes(range=[900, 1100], row=2, col=1)
fig.show()

In [None]:
# Display the saccade table computed for the example session
saccades