In [None]:
%load_ext autoreload
# mode 2: re-import all (edited) modules before every cell execution, so changes to
# the project packages imported below are picked up without restarting the kernel
%autoreload 2

In [None]:
# %%
import flexiznam as flz
from cottage_analysis.analysis import common_utils
from cottage_analysis.plotting import basic_vis_plots
from v1_depth_map.figure_utils import get_session_list
from v1_depth_map.batch_analysis.eye_tracking.analysis import get_data, get_saccades
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42 # for pdfs

# %%
# --- Configuration ----------------------------------------------------------
PROJECT = "hey2_3d-vision_foodres_20220101"  # flexilims project id
# only sessions recorded from these mice are analysed
valid_mice = ["PZAH6.4b", "PZAG3.4f"]
# session used for the example frame / gaze-trace panels
example_session = "PZAG3.4f_S20220421"

# sessions dropped from every analysis, mapped to the reason for exclusion
sess_to_exclude = dict([
    ("PZAH6.4b_S20220516", "Mouse squint too much, cropping issue"),
    ("PZAH6.4b_S20220429", "Two eye position. Maybe reflection ill detected"),
])

In [None]:
flexilims_session = flz.get_flexilims_session(project_id=PROJECT)
# Candidate V1 sessions; judging by the flag names, open-loop sessions are left out
# (TODO confirm against get_session_list.get_sessions).
sessions = get_session_list.get_sessions(
    flexilims_session=flexilims_session,
    exclude_sessions=(),
    exclude_openloop=True,
    exclude_pure_closedloop=False,
    v1_only=True,
)
# first drop the manually excluded sessions ...
sessions = list(filter(lambda name: name not in sess_to_exclude, sessions))
print(f"Found {len(sessions)} sessions for closed loop only")
# ... then keep only sessions whose mouse (text before the first '_') is valid
sessions = [name for name in sessions if name.partition('_')[0] in valid_mice]
print(f"Found {len(sessions)} sessions for valid mice")

In [None]:
# Load the per-frame gaze data of every selected session.
# With reload_from_disk=False the data is rebuilt from the DLC output via get_data
# and the pickled cache in target_folder is refreshed.
reload_from_disk = True

target_folder = flz.get_data_root(which='processed', project=PROJECT) / PROJECT / 'Analysis' / 'eye_tracking'
target_folder.mkdir(exist_ok=True, parents=True)
if reload_from_disk:
    # NOTE: only unpickle caches written by this notebook (pickle can execute code)
    all_data = pd.read_pickle(target_folder / 'all_data.pkl')
    example_dlc_res = pd.read_pickle(target_folder / 'example_dlc_res.pkl')
else:
    project_recordings = flz.get_entities(datatype='recording', flexilims_session=flexilims_session)
    sessions_df = pd.DataFrame([flz.get_entity(name=s, flexilims_session=flexilims_session) for s in sessions])
    all_data = {}
    example_dlc_res = None
    problematic_sessions = []
    for _, sess_df in sessions_df.iterrows():
        sess_name = sess_df['name']
        print(sess_name)
        try:
            # closed-loop recording of this session
            recording = project_recordings[project_recordings.origin_id == sess_df['id']]
            recording = recording[recording.protocol == 'SpheresPermTubeReward'].iloc[0]
            gaze_data, dlc_res, dlc_ds = get_data(
                project=PROJECT,
                mouse=sess_df.genealogy[0],
                session=sess_df.genealogy[-1],
                recording=recording.genealogy[-1],
                filt_window=3,
                verbose=False,
            )
        except Exception as e:
            # best effort: log and move on, never reuse dlc_res from a previous session
            print(f"Problem with {sess_name}: {e}")
            problematic_sessions.append(sess_name)
            continue
        gaze_data['session'] = sess_name
        all_data[sess_name] = gaze_data
        # compare the explicit 'name' field rather than the iterrows row label
        if sess_name == example_session:
            example_dlc_res = dlc_res
    if problematic_sessions:
        print(f"Skipped {len(problematic_sessions)} session(s): {problematic_sessions}")
    # the 'session' column already identifies each block, so keys are not needed
    all_data = pd.concat(all_data.values(), ignore_index=True)
    all_data.to_pickle(target_folder / 'all_data.pkl')
    if example_dlc_res is None:
        print(f"No DLC results for example session {example_session}: "
              "example_dlc_res.pkl was not written")
    else:
        example_dlc_res.to_pickle(target_folder / 'example_dlc_res.pkl')

# Remove excluded sessions: they can still be present when the data was read
# from a cache built before they were added to sess_to_exclude.
all_data = all_data[~all_data.session.isin(list(sess_to_exclude))]

In [None]:
# Saccades per session, then a per-trial summary table.
filter_window = 5  # passed to get_saccades
threshold = 70  # passed to get_saccades (units as expected by get_saccades — TODO confirm)
saccades_by_sess = {
    sess_name: get_saccades(sess_gaze, threshold=threshold, filter_window=filter_window)
    for sess_name, sess_gaze in all_data.groupby('session')
}

# One row per trial: timing, depth (cm) and number of saccades starting within the trial
sacc_by_trials = []
for sess_name, sess_df in all_data.groupby('session'):
    sacc_start = saccades_by_sess[sess_name]['start_time']
    for trial in sess_df['trial'].dropna().unique():
        trial_df = sess_df[sess_df['trial'] == trial]
        trial_depths = trial_df.depth.dropna()
        if trial_depths.empty:
            # a trial without any depth value cannot be assigned to a depth bin
            continue
        trial_start = trial_df.harptime.iloc[0]
        trial_end = trial_df.harptime.iloc[-1]
        sacc_by_trials.append(dict(
            trial=trial,
            session=sess_name,
            # median is robust to a stray sample at the trial edges and works for
            # one-sample trials; int() truncation matches the astype(int) cm conversion
            depth=int(trial_depths.median() * 100),
            trial_start=trial_start,
            trial_end=trial_end,
            nsaccades=int(((sacc_start >= trial_start) & (sacc_start <= trial_end)).sum()),
        ))
sacc_by_trials = pd.DataFrame(sacc_by_trials)
sacc_by_trials['trial_duration'] = sacc_by_trials['trial_end'] - sacc_by_trials['trial_start']
sacc_by_trials['saccade_rate'] = sacc_by_trials['nsaccades'] / sacc_by_trials['trial_duration']

# float64 columns that can be averaged (trial id and depth are grouping keys, not measures)
dtypes = all_data.dtypes
is_num = dtypes[dtypes == 'float64'].index.drop(['trial', 'depth'])

# make a "position relative to median" column: filtered gaze minus the session median
# computed on valid frames that belong to a depth trial
all_data['az_rel2med'] = np.nan
all_data['el_rel2med'] = np.nan
for sess, sess_df in all_data.groupby('session'):
    valid_part = sess_df[(sess_df.valid) & (~np.isnan(sess_df.depth))]
    med_pos = np.nanmedian(valid_part[['azimuth_filt', 'elevation_filt']].values, axis=0)
    all_data.loc[sess_df.index, 'az_rel2med'] = sess_df['azimuth_filt'].values - med_pos[0]
    all_data.loc[sess_df.index, 'el_rel2med'] = sess_df['elevation_filt'].values - med_pos[1]

# add binned version of running speed: (-inf, 0], (0, 0.2], ..., (0.8, inf) -> bin labels 0..5
bins = np.hstack([-np.inf, np.arange(0, 1, 0.2), np.inf])
all_data['rs_bin'] = pd.cut(all_data['RS'], bins=bins, labels=False)

In [None]:
# get example session data
# Get example frame
import cv2
from wayla import eye_io

# Example frame. The figure cell uses it as a 0-based row index into dlc_res and
# gaze_data, so the video is seeked to the same index.
start_frame = 45344
project_recordings = flz.get_entities(datatype='recording', flexilims_session=flexilims_session)
sess_df = flz.get_entity(name=example_session, datatype='session', flexilims_session=flexilims_session)
# closed-loop recording of the example session
recording = project_recordings[project_recordings.origin_id == sess_df['id']]
recording = recording[recording.protocol == 'SpheresPermTubeReward'].iloc[0]

camera = flz.Dataset.from_flexilims(name=f"{recording.name}_right_eye_camera", flexilims_session=flexilims_session)
gaze_data, dlc_res, dlc_ds = get_data(
    project=PROJECT,
    mouse=sess_df.genealogy[0],
    session=sess_df.genealogy[-1],
    recording=recording.genealogy[-1],
    filt_window=3,
    verbose=False,)
eye_params = eye_io.get_eye_parameters(camera, flexilims_session)
# remove the scorer multiindex column
dlc_res.columns = dlc_res.columns.droplevel('scorer')

video_file = camera.path_full / camera.extra_attributes["video_file"]

# crop box stored on the DLC dataset, indexed as [col_start, col_end, row_start, row_end]
cropping = dlc_ds.extra_attributes["cropping"]
cam_data = cv2.VideoCapture(str(video_file))
try:
    # CAP_PROP_POS_FRAMES is the 0-based index of the next frame to decode: seeking to
    # start_frame (not start_frame - 1) yields the frame matching dlc_res.iloc[start_frame]
    cam_data.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    ret, frame = cam_data.read()
finally:
    cam_data.release()
if not ret:
    raise IOError(f"Could not read frame {start_frame} from {video_file}")
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
gray = gray[cropping[2]:cropping[3], cropping[0]:cropping[1]]

In [None]:
# Interactive view of the example session's filtered gaze (x axis: sample number)
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

example_gaze = all_data[all_data.session == example_session]
sample_idx = np.arange(len(example_gaze))
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.02)
trace_specs = [('azimuth_filt', 'Azimuth'), ('elevation_filt', 'Elevation')]
for row, (column, label) in enumerate(trace_specs, start=1):
    fig.add_trace(
        go.Scatter(x=sample_idx, y=example_gaze[column], mode='lines', name=label),
        row=row,
        col=1,
    )
fig.update_layout(title_text="Azimuth and Elevation", showlegend=False)
fig.show()

In [None]:
import seaborn as sns
from wayla.diagnostics import plot_ellipse_on_frame

# ---------------------------------------------------------------------------
# Supplementary eye-tracking figure
#   top:    example video frame with the eye fit, and example gaze traces
#   bottom: per-depth summaries across sessions (eye velocity, saccade rate)
# ---------------------------------------------------------------------------
fontsize_dict = {"title": 5, "label": 7, "tick": 5, "legend": 5}
# mean of each per-trial column for every (depth, session); depth is already in cm
bysess = sacc_by_trials.groupby(['depth', 'session']).aggregate('mean').reset_index()
fig = plt.figure(figsize=(18/2.54, 18/2.54))  # 18 x 18 cm
palette = sns.color_palette('cool_r', 5)
# mean of each float gaze column for every (session, depth); depth converted to whole cm
bytrialbydepth = all_data.groupby(['session', 'depth'])[is_num].aggregate('mean').reset_index()
bytrialbydepth['depth'] = (bytrialbydepth['depth'] * 100).astype(int)
jitter = 0.2
scatter_alpha = 0.3
scatter_markersize = 3
capsize = 5
capthick = 1.5
linewidth = 1.5
elinewidth = 1.5

# depths (cm) on the summary panels; also used to colour the example-trace trials
depth_list = np.sort(bytrialbydepth.depth.unique()).astype("float")

for icol, (data, col, ylabel, ylim) in enumerate(zip([bytrialbydepth, bysess],
                                                     ['velocity', 'saccade_rate'],
                                                     ['Velocity (degrees/s)', 'Saccade Rate (Hz)'],
                                                     [(0, 20), (-0.01, 0.2)])):
    ax = fig.add_axes([0.05 + 0.5 * icol, 0.1, 0.3, 0.15])
    for idepth, depth in enumerate(depth_list):
        color = basic_vis_plots.get_depth_color(depth, depth_list, cmap=cm.cool.reversed())
        # Select rows from THIS panel's table: bysess is depth-major and bytrialbydepth
        # is session-major, so a mask built on one cannot index the other.
        values = data.loc[data['depth'] == depth, col].values
        CI_low, CI_high = common_utils.get_bootstrap_ci(values.T, sig_level=0.05)
        mean_value = np.nanmean(values)

        # one dot per session, a bar at the mean, and the bootstrap CI as error bar
        sns.stripplot(
            x=np.ones(values.shape) * idepth,
            y=values,
            jitter=jitter,
            edgecolor="white",
            color=color,
            alpha=scatter_alpha,
            size=scatter_markersize,
            ax=ax,
        )
        ax.plot(
            [idepth - 0.3, idepth + 0.3],
            [mean_value, mean_value],
            linewidth=linewidth,
            color=color,
        )
        ax.errorbar(
            x=idepth,
            y=mean_value,
            yerr=np.array([mean_value - CI_low, CI_high - mean_value]).reshape(2, 1),
            capsize=capsize,
            elinewidth=elinewidth,
            ecolor=color,
            capthick=capthick,
        )
    ax.set_ylabel(ylabel, fontsize=fontsize_dict["label"])
    ax.set_xlabel("Depth (cm)", fontsize=fontsize_dict["label"])
    ax.set_xticks(np.arange(len(depth_list)))
    ax.set_xticklabels(depth_list.astype("int"), fontsize=fontsize_dict["tick"])
    ax.tick_params(axis="both", which="major", labelsize=fontsize_dict["tick"])
    ax.set_ylim(ylim)
    ax.set_yticks(np.linspace(0, ylim[1], 5))
    sns.despine(ax=ax)

# example session plot: video frame, DLC eye markers, fitted pupil ellipse and gaze line
ax = fig.add_axes([0.05, 0.5, 0.2, 0.2])
ax.imshow(gray, cmap='gray', vmin=0, vmax=150)
eye_pos = dlc_res.iloc[start_frame][[f'eye_{i}' for i in range(1, 13)]]
ref = gaze_data.iloc[start_frame][['reflection_x', 'reflection_y']].values
eye_centre = eye_params['eye_centre'] + ref
pupil_center = gaze_data.iloc[start_frame][['centre_x', 'centre_y']].values
ax.plot([eye_centre[0], pupil_center[0]], [eye_centre[1], pupil_center[1]], c='dodgerblue')
ax.scatter(*eye_centre, s=5, c='k', zorder=5)
ax.scatter(eye_pos.xs('x', level=1), eye_pos.xs('y', level=1), s=1, c='indianred', zorder=11)
plot_ellipse_on_frame(
    ax,
    start_frame,
    gaze_data,
    origin='uncropped',
    dlc_res=dlc_res,
    reflection_fit=None,
    color='dodgerblue',
    alpha=1,
    lw=0.5,
    zorder=12,
)
ax.set_ylim(450, 60)
ax.set_xlim(10, 330)
ax.axis('off')

# example traces: +/- 1000 samples around the example frame, relative to the session median
ax = fig.add_axes([0.35, 0.5, 0.4, 0.15])
b = start_frame - 1000
e = start_frame + 1000
gd = gaze_data.iloc[b:e]
t0 = gaze_data.harptime.iloc[b]
med_pos = np.nanmedian(gaze_data[['azimuth_filt', 'elevation_filt']].values, axis=0)

ax.plot(gd.harptime.values - t0, gd.elevation_filt.values - med_pos[1], c='grey', label='Elevation', zorder=9)
ax.plot(gd.harptime.values - t0, gd.azimuth_filt.values - med_pos[0], c='k', label='Azimuth', zorder=10)
for t, tdf in gd.groupby('trial'):
    trial_start = tdf.harptime.iloc[0] - t0
    trial_end = tdf.harptime.iloc[-1] - t0
    ax.axvline(trial_start, c='gray', lw=0.5)
    ax.axvline(trial_end, c='gray', lw=0.5)
    trial_depth = tdf.depth.iloc[0] * 100
    if np.isnan(trial_depth):
        continue
    # Colour by the closest depth of the summary panels. Indexing depth_list with a
    # position in the window-local depth list would mis-colour trials whenever the
    # window does not contain every depth.
    depth_index = int(np.argmin(np.abs(depth_list - trial_depth)))
    ax.axvspan(trial_start, trial_end, alpha=0.5,
               color=basic_vis_plots.get_depth_color(
                   depth_list[depth_index], depth_list, cmap=cm.cool.reversed()
               ))
ax.set_ylabel('Eye position (degrees)', fontsize=fontsize_dict['label'])
ax.set_xlabel('Time (s)', fontsize=fontsize_dict['label'])
ax.tick_params(axis='both', which='major', labelsize=fontsize_dict['tick'])
ax.set_xlim(0, 120)
ax.set_ylim(-10, 10)
leg = ax.legend(loc='upper right', fontsize=fontsize_dict['legend'], ncol=2, bbox_to_anchor=(1.05, 1.3, 0, 0), columnspacing=0.2, frameon=False)

for x in fig.axes:
    x.spines['top'].set_visible(False)
    x.spines['right'].set_visible(False)

from pathlib import Path
VERSION = 10
SAVE_ROOT = flz.get_data_root("processed", flexilims_session=flexilims_session) / "v1_manuscript_figures" / f"ver{VERSION}"
# savefig does not create missing directories
SAVE_ROOT.mkdir(parents=True, exist_ok=True)
fig.savefig(SAVE_ROOT / 'figsupp_eye.pdf', bbox_inches='tight')