# Example session

Scratch notebook for prototyping eye-tracking analyses before they are turned into library functions.

In [None]:
# Automatically re-import edited modules (e.g. cottage_analysis, v1_depth_analysis) before each cell runs
%load_ext autoreload
%autoreload 2

In [None]:
# select session
import pickle
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage.measure import EllipseModel
import flexiznam as flz
from v1_depth_analysis.config import PROJECT
import v1_depth_analysis as vda
from cottage_analysis.eye_tracking import analysis as analeyesis

In [None]:
# Data roots configured in flexiznam, and a flexilims session for the project
data_roots = flz.PARAMETERS["data_root"]
raw_path, processed_path = (Path(data_roots[key]) for key in ("raw", "processed"))
flm_sess = flz.get_flexilims_session(project_id=PROJECT)

# Camera datasets whose name contains "_eye", from SpheresPermTubeReward recordings
recordings = vda.get_recordings(protocol="SpheresPermTubeReward", flm_sess=flm_sess)
datasets = vda.get_datasets(
    recordings,
    dataset_type="camera",
    dataset_name_contains="_eye",
    flm_sess=flm_sess,
)

In [None]:
# Select one right-eye camera dataset from the two sessions of interest
target_sessions = ("S20220419", "S20220421")
right_eye_cameras = [
    ds
    for ds in datasets
    if ds.genealogy[1] in target_sessions and "right" in ds.dataset_name
]
# NOTE(review): hard-coded pick of the 4th match — confirm which recording is intended
camera = right_eye_cameras[3]
print(f"Analysing {' from '.join(camera.genealogy[::-1])}")

In [None]:
# Load the tracking results (presumably DeepLabCut output) and per-frame ellipse
# fits for the selected camera. The exact meaning of each threshold is defined
# by analeyesis.get_data, which is not visible here.
fit_thresholds = dict(
    likelihood_threshold=0.88,
    rsquare_threshold=0.99,
    error_threshold=3,
)
dlc_res, ellipse = analeyesis.get_data(camera, flexilims_session=flm_sess, **fit_thresholds)
ellipse.head()


In [None]:
# Either render an overlay movie of the ellipse fit (slow, disabled by default)
# or show the first frame of the eye video as a quick sanity check.
MAKE_OVERLAY_MOVIE = False

camera_save_folder = camera.path_full / camera.dataset_name
target_file = camera_save_folder / "eye_tracking_ellipse_overlay.mp4"
video_file = camera.path_full / camera.extra_attributes['video_file']
if MAKE_OVERLAY_MOVIE:
    # NOTE(review): dlc_res/ellipse are passed as None although they were loaded
    # above — confirm plot_movie reloads them itself in that case
    analeyesis.plot_movie(
        camera,
        target_file,
        start_frame=0,
        duration=6,
        dlc_res=None,
        ellipse=None,
        vmax=None,
        vmin=None,
        playback_speed=4,
    )
else:
    cam_data = cv2.VideoCapture(str(video_file))
    try:
        ret, frame = cam_data.read()
    finally:
        # always free the capture handle, even if reading fails
        cam_data.release()
    if not ret:
        # read() returns (False, None) when the file cannot be opened/decoded;
        # fail here with a clear message instead of inside cvtColor
        raise OSError(f"Could not read a frame from {video_file}")
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    plt.imshow(gray, cmap='gray')


In [None]:
# plot of fit quality

import seaborn as sns

# Distribution of DLC likelihoods for the tracked points, excluding the
# reflection, eye-corner and eyelid markers
likelihoods = dlc_res.xs('likelihood', axis='columns', level=2)
likelihoods.columns = likelihoods.columns.droplevel('scorer')
excluded_points = [
    "reflection",
    "left_eye_corner",
    "right_eye_corner",
    "top_eye_lid",
    "bottom_eye_lid",
]
grid = sns.displot(likelihoods.drop(axis="columns", labels=excluded_points))
grid.ax.set_xlabel("DLC likelihood")
plt.gcf().set_size_inches(5, 5)
# same value as passed to analeyesis.get_data above
likelihood_threshold = 0.88
grid.ax.axvline(likelihood_threshold, color='k')
grid.ax.set_xlim(0.8, 1)

# Relation between fit error and the other quality metrics, valid fits only
valid_fits = ellipse[ellipse.valid]
for quality_metric in ('rsquare', 'dlc_avg_likelihood'):
    sns.jointplot(data=valid_fits, x='error', y=quality_metric)

In [None]:
def _plot_binned_mean(ax, x, y, values, bins, vmin, vmax, label):
    """Draw the mean of `values` in 2D bins of (x, y) on `ax`, with a colorbar.

    Empty bins are shown as NaN. The y axis is inverted to match the image
    orientation of the scatter panel. Returns the (x_edges, y_edges) used, so
    further panels can reuse the same grid.
    """
    count, x_edges, y_edges = np.histogram2d(x, y, bins=bins)
    total, _, _ = np.histogram2d(x, y, weights=values, bins=(x_edges, y_edges))
    with np.errstate(invalid='ignore', divide='ignore'):
        mean = np.where(count > 0, total / count, np.nan)
    # origin='lower' maps row 0 (lowest y bin) to y_edges[0], so the data sit
    # at their true coordinates (the default 'upper' put them upside down
    # relative to the tick labels); invert_yaxis then matches image orientation.
    img = ax.imshow(
        mean.T,
        extent=(x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]),
        origin='lower',
        vmin=vmin,
        vmax=vmax,
    )
    ax.invert_yaxis()
    cb = ax.figure.colorbar(mappable=img, ax=ax)
    cb.set_label(label)
    ax.set_xlabel('Ellipse centre X (pixels)')
    ax.set_ylabel('Ellipse centre Y (pixels)')
    return x_edges, y_edges


elli = ellipse[ellipse.valid]
angle_deg = np.rad2deg(elli.angle)

fig = plt.figure(figsize=(15, 4))

# Panel 1: every valid ellipse centre, coloured by ellipse angle
ax = fig.add_subplot(1, 3, 1)
sc = ax.scatter(elli.centre_x, elli.centre_y, c=angle_deg, vmax=70, vmin=50)
cb = fig.colorbar(ax=ax, mappable=sc)
cb.set_label('Ellipse angle (degrees)')
ax.set_xlabel('Ellipse centre X (pixels)')
ax.set_ylabel('Ellipse centre Y (pixels)')
ax.set_aspect('equal')
ax.invert_yaxis()

# Panel 2: mean ellipse angle per 70x70 grid of centre positions
centre_bins = _plot_binned_mean(
    fig.add_subplot(1, 3, 2), elli.centre_x, elli.centre_y, angle_deg,
    bins=(70, 70), vmin=55, vmax=65, label='Ellipse angle (degrees)',
)

# Panel 3: mean major/minor axis ratio on the same grid
_plot_binned_mean(
    fig.add_subplot(1, 3, 3), elli.centre_x, elli.centre_y,
    elli.major_radius / elli.minor_radius,
    bins=centre_bins, vmin=1.1, vmax=1.4, label='Ellipse axes ratio',
)

fig.subplots_adjust(wspace=0.5)

In [None]:
# get data
def get_stim_info(rec, flexilims_session=None):
    """Load suite2p ops and the stimulus tables for the session of `rec`.

    Parameters
    ----------
    rec : flexilims entity (as returned by ``flz.get_entity``)
        Recording whose ``origin_id`` is used as parent id to list the sibling
        recordings.
    flexilims_session : optional
        Flexilims session to query. Defaults to the notebook-level ``flm_sess``.

    Returns
    -------
    dict
        ``'ops'``: suite2p ops (plane0) of the single ``suite2p_rois`` dataset.
        ``'closedloop'`` / ``'stim_dict_closedloop'``: unpickled
        ``img_VS.pickle`` / ``stim_dict.pickle`` of the SpheresPermTubeReward
        recording.
        ``'playback'`` / ``'stim_dict_playback'``: same for the
        SpheresPermTubeRewardPlayback recording, only present if it exists.

    Raises
    ------
    AssertionError
        If there is not exactly one closed-loop recording or exactly one
        ``suite2p_rois`` dataset.
    """
    if flexilims_session is None:
        flexilims_session = flm_sess

    sess_children = flz.get_children(
        parent_id=rec.origin_id,
        flexilims_session=flexilims_session,
        children_datatype='recording',
    )
    rec_closeloop = sess_children[sess_children.protocol == "SpheresPermTubeReward"]
    assert len(rec_closeloop) == 1, (
        f"Expected 1 SpheresPermTubeReward recording, found {len(rec_closeloop)}"
    )
    rec_closeloop = rec_closeloop.iloc[0]

    rec_playback = sess_children[sess_children.protocol == "SpheresPermTubeRewardPlayback"]
    if len(rec_playback):
        rec_playback = rec_playback.iloc[0]
        print(f"Analysing {rec_closeloop.name}\n     and {rec_playback.name}", flush=True)
    else:
        rec_playback = None
        print(f"Analysing {rec_closeloop.name} (no playback)", flush=True)

    sess_ds = flz.get_children(
        parent_id=rec_closeloop.origin_id,
        flexilims_session=flexilims_session,
        children_datatype='dataset',
    )
    suite_2p = sess_ds[sess_ds.dataset_type == 'suite2p_rois']
    assert len(suite_2p) == 1, f"Expected 1 suite2p_rois dataset, found {len(suite_2p)}"
    suite_2p = flz.Dataset.from_flexilims(
        data_series=suite_2p.iloc[0], flexilims_session=flexilims_session
    )

    ops = np.load(suite_2p.path_full / 'suite2p' / 'plane0' / 'ops.npy', allow_pickle=True).item()
    processed = Path(flz.PARAMETERS['data_root']['processed'])

    def _load_pickle(recording, filename):
        # pickle can execute arbitrary code: only use on trusted, locally produced files
        with open(processed / recording.path / filename, 'rb') as handle:
            return pickle.load(handle)

    out = dict(ops=ops)
    out['closedloop'] = _load_pickle(rec_closeloop, 'img_VS.pickle')
    out['stim_dict_closedloop'] = _load_pickle(rec_closeloop, 'stim_dict.pickle')
    if rec_playback is not None:
        out['playback'] = _load_pickle(rec_playback, 'img_VS.pickle')
        out['stim_dict_playback'] = _load_pickle(rec_playback, 'stim_dict.pickle')
    return out


recording = flz.get_entity(id=camera.origin_id, flexilims_session=flm_sess)
stim_params = get_stim_info(recording)
# suite2p frame rate; later cells also use it to convert eye-camera frame
# indices to seconds. NOTE(review): confirm the camera data are sampled at this rate.
sampling = stim_params['ops']['fs']


In [None]:
# plot time course of eye position: raw (panels 0-1) and relative to the
# reflection marker (panels 2-3); panels 1 and 3 zoom on a 2-minute window.
# Each trace is shown relative to its own median.
fig, axes = plt.subplots(4, 1)
fig.set_size_inches((10, 10))

valid = ellipse.valid
# NOTE(review): uses the suite2p rate `sampling` for camera frames — confirm
time = np.arange(len(dlc_res)) / sampling
reflection = dlc_res.xs(axis='columns', level=1, key='reflection')
reflection.columns = reflection.columns.droplevel('scorer')

traces_by_kind = {
    'Raw': {
        coord: ellipse[f'centre_{coord}'][valid] for coord in ('x', 'y')
    },
    'Relative to reflection': {
        coord: (ellipse[f'centre_{coord}'] - reflection[coord])[valid]
        for coord in ('x', 'y')
    },
}
for iax, ax in enumerate(axes):
    kind = 'Relative to reflection' if iax > 1 else 'Raw'
    ax.set_ylabel(kind)
    for coord, trace in traces_by_kind[kind].items():
        ax.plot(time[valid], trace - np.nanmedian(trace), label=fr"$\Delta${coord}", lw=1)

for full_view in (axes[0], axes[2]):
    full_view.set_xlim(time[0], time[-1])
for zoom_view in (axes[1], axes[3]):
    zoom_view.set_xlim(3000, 3000 + 60 * 2)
    zoom_view.set_ylim(-15, 15)
ax.legend(loc='upper right')
ax.set_xlabel('Time (s)')



In [None]:
# Combine eye tracking with behaviour into one per-frame table; later cells
# read columns such as mvt, dx, dy, d_med, depth, rs and valid from it.
behaviour_options = dict(speed_threshold=0.01, log_speeds=False)
data = analeyesis.add_behaviour(camera, dlc_res, ellipse, **behaviour_options)
data.head()

In [None]:
# Eye movement over a 30 s window: full range on top, clipped to 0-5 below
time = np.arange(len(data)) / sampling
window = slice(*time.searchsorted([3000, 3030]))
for row, y_limits in enumerate([None, (0, 5)], start=1):
    plt.subplot(2, 1, row)
    plt.plot(time[window], data.mvt[window])
    plt.ylabel('Eye movement')
    if y_limits is not None:
        plt.ylim(*y_limits)
plt.xlabel('Time (s)')

In [None]:
# Distribution of frame-to-frame eye movement on a log axis; zero-movement
# frames are dropped because they cannot be placed on a log scale
movement = data.mvt.dropna()
grid = sns.displot(movement[movement > 0], log_scale=True)
grid.ax.set_xlabel('Eye movement (pixels between 2 frames)')
# NOTE(review): marker at 3 px, while later cells flag saccades as mvt > 5 — confirm intended threshold
grid.ax.axvline(3, color='k')

In [None]:
import matplotlib as mpl
from matplotlib import cm

# One colour per depth, on a log scale of the reversed "cool" colormap.
# NaN depths are excluded: np.unique keeps NaN, which previously added a
# meaningless extra colour.
depth_list = np.unique(data.depth.dropna())
cmap = cm.cool.reversed()
norm = mpl.colors.Normalize(vmin=np.log(depth_list.min()), vmax=np.log(depth_list.max()))
col_dict = {
    depth: tuple(channel / 255 for channel in cmap(norm(np.log(depth)), bytes=True))
    for depth in depth_list
}
line_colors = list(col_dict.values())

fig, axes = plt.subplots(3, 1)
fig.set_size_inches(6, 10)
labels = ['X position', 'Y position', 'Distance to median position']
d = data[(~np.isnan(data.dx)) & (~np.isnan(data.depth))]
# violinplot orders the x categories by the sorted depths present in `d`; build
# the palette from those depths so colours stay aligned if some are missing
palette = [col_dict[depth] for depth in np.unique(d.depth)]
for iw, w in enumerate(['dx', "dy", "d_med"]):
    sns.violinplot(data=d, x='depth', y=w, palette=palette, ax=axes[iw])
    axes[iw].set_ylabel(labels[iw])



In [None]:
import matplotlib as mpl
from matplotlib import cm

# One colour per depth (log scale, reversed "cool"); NaN depths excluded
depth_list = np.unique(data.depth.dropna())
cmap = cm.cool.reversed()
norm = mpl.colors.Normalize(vmin=np.log(depth_list.min()), vmax=np.log(depth_list.max()))
col_dict = {
    depth: tuple(channel / 255 for channel in cmap(norm(np.log(depth)), bytes=True))
    for depth in depth_list
}
line_colors = list(col_dict.values())

# Frames whose frame-to-frame movement exceeds this many pixels count as saccades
saccade_threshold = 5
data['saccade'] = data.mvt > saccade_threshold

fig, axes = plt.subplots(3, 1)
fig.set_size_inches(5, 7)
labels = ['Motion', 'Small movements', 'Saccades']
d = data[(~np.isnan(data.dx)) & (~np.isnan(data.depth))]
# top: all movements; middle: movements below 2 px. Each palette is built from
# the depths present in the plotted subset so colours stay aligned.
for iax, subset in enumerate([d, d[d.mvt < 2]]):
    palette = [col_dict[depth] for depth in np.unique(subset.depth)]
    sns.violinplot(data=subset, x='depth', y='mvt', palette=palette, ax=axes[iax])
    axes[iax].set_ylabel(labels[iax])

# fraction of frames flagged as saccade, times frames per second
# NOTE(review): assumes rows of `data` are sampled at `sampling` (suite2p fs)
saccade_rate = d.groupby('depth').saccade.mean() * sampling
bar_positions = np.arange(len(saccade_rate))
axes[2].bar(
    x=bar_positions,
    height=saccade_rate,
    color=[col_dict[depth] for depth in saccade_rate.index],
)
axes[2].set_xticks(bar_positions)
axes[2].set_xticklabels(saccade_rate.index)
axes[2].set_ylabel('Saccades per second')
    



In [None]:
import matplotlib as mpl
from matplotlib import cm

# Bin running speed (data.rs) to the nearest multiple of 10, in rs's own units
data['running_bin'] = np.round(data.rs / 10) * 10
# Frames whose frame-to-frame movement exceeds this many pixels count as saccades
saccade_threshold = 5
data['saccade'] = data.mvt > saccade_threshold

d = data[(~np.isnan(data.dx)) & (~np.isnan(data.depth))]

# Colours from viridis, normalised by bin value. They are built from the bins
# present in `d` (the frames actually plotted) so that violin and bar colours
# match the bins they represent.
running_bins = np.unique(d.running_bin.dropna())
cmap = cm.viridis
norm = mpl.colors.Normalize(vmin=running_bins.min(), vmax=running_bins.max())
bin_colors = {
    rb: tuple(channel / 255 for channel in cmap(norm(rb), bytes=True))
    for rb in running_bins
}
rs_colors = list(bin_colors.values())

fig, axes = plt.subplots(3, 1)
fig.set_size_inches(5, 7)
labels = ['Motion', 'Small movements', 'Saccades']
# top: all movements; middle: movements below 2 px
for iax, subset in enumerate([d, d[d.mvt < 2]]):
    palette = [bin_colors[rb] for rb in np.unique(subset.running_bin.dropna())]
    sns.violinplot(data=subset, x='running_bin', y='mvt', palette=palette, ax=axes[iax])
    axes[iax].set_ylabel(labels[iax])

# fraction of frames flagged as saccade, times frames per second
# NOTE(review): assumes rows of `data` are sampled at `sampling` (suite2p fs)
saccade_rate = d.groupby('running_bin').saccade.mean() * sampling
bar_positions = np.arange(len(saccade_rate))
axes[2].bar(
    x=bar_positions,
    height=saccade_rate,
    color=[bin_colors[rb] for rb in saccade_rate.index],
)
axes[2].set_xticks(bar_positions)
axes[2].set_xticklabels(np.array(saccade_rate.index, dtype=int))
axes[2].set_ylabel('Saccades per second')
    



In [None]:
from scipy.stats import mannwhitneyu
from matplotlib.colors import TwoSlopeNorm

# Pairwise Mann-Whitney U tests between depths, for each eye-position measure.
# Uses `d` from the cells above (frames with non-NaN dx and depth).
# NOTE(review): `d` is whatever the most recently executed cell left behind.
depth_list = np.unique(d.depth)
props = ['dx', "dy", "d_med"]
n_depths = len(depth_list)
# NaN on the diagonal (and for untestable pairs) rather than 0, so a depth
# compared with itself is not drawn as "highly significant"
pval_mat = np.full((len(props), n_depths, n_depths), np.nan)

for ip, prop in enumerate(props):
    # `d` is only NaN-filtered on dx; drop NaN per measure so the test does
    # not return NaN for dy / d_med
    samples = [d.loc[d.depth == depth, prop].dropna().values for depth in depth_list]
    for ia in range(n_depths):
        for ib in range(ia + 1, n_depths):
            if len(samples[ia]) == 0 or len(samples[ib]) == 0:
                continue
            pvalue = mannwhitneyu(samples[ia], samples[ib]).pvalue
            # the p-value does not depend on the order of the two samples
            pval_mat[ip, ia, ib] = pval_mat[ip, ib, ia] = pvalue

# centre the diverging colormap on the 0.05 significance threshold
pval_norm = TwoSlopeNorm(vmin=0, vcenter=0.05, vmax=1)
tick_labels = [f"{depth:g}" for depth in depth_list]
fig, axes = plt.subplots(1, len(props))
for ip, prop in enumerate(props):
    img = axes[ip].imshow(pval_mat[ip], cmap='RdBu', norm=pval_norm, origin='lower')
    axes[ip].set_title(prop)
    axes[ip].set_xticks(np.arange(n_depths))
    axes[ip].set_xticklabels(tick_labels, rotation=90)
    axes[ip].set_yticks(np.arange(n_depths))
    axes[ip].set_yticklabels(tick_labels)
fig.colorbar(img, ax=axes, label='Mann-Whitney U p-value')

In [None]:
# Eye position per depth: left, median +/- std; right, mean +/- SEM.
# Uses col_dict (depth -> colour) built in the depth-colour cells above.
fig, axes = plt.subplots(1, 2)
# loop variable is `depth`, not `d`, so the filtered DataFrame `d` used by
# other cells is not clobbered
for depth, ddf in data.groupby('depth'):
    # `:g` keeps fractional depths distinct (int() would truncate them)
    depth_label = f"{depth:g}"
    n_dx = np.sum(~np.isnan(ddf.dx))
    n_dy = np.sum(~np.isnan(ddf.dy))
    axes[0].errorbar(
        x=np.nanmedian(ddf.dx), y=np.nanmedian(ddf.dy),
        xerr=np.nanstd(ddf.dx), yerr=np.nanstd(ddf.dy),
        label=depth_label, marker='o', color=col_dict[depth],
    )
    axes[1].errorbar(
        x=np.nanmean(ddf.dx), y=np.nanmean(ddf.dy),
        xerr=np.nanstd(ddf.dx) / np.sqrt(n_dx),
        yerr=np.nanstd(ddf.dy) / np.sqrt(n_dy),
        label=depth_label, marker='.', lw=3, color=col_dict[depth],
    )

for ax in axes:
    ax.set_aspect('equal')
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
axes[0].set_title('Eye position (median +/- std)')
# right panel error bars are std / sqrt(n), i.e. the standard error of the mean
axes[1].set_title('Eye position (mean +/- sem)')

In [None]:

# FIXME(review): this cell and the next look copied from the preprocessing
# pipeline. Neither `img_VS` nor `mousez_logger` is defined anywhere in this
# notebook, so the cell raises NameError on a fresh kernel.
# For each row of img_VS, attach the last mousez_logger row with HarpTime at or
# before it (backward as-of merge).
img_VS = pd.merge_asof(
    img_VS,
    mousez_logger,
    on="HarpTime",
    allow_exact_matches=True,
    direction="backward",
)

# NOTE(review): these conversions act on whatever img_VS currently holds, so
# the cell is not safe to re-run.
img_VS.EyeZ = img_VS.EyeZ / 100  # Convert cm to m
img_VS.MouseZ = img_VS.MouseZ / 100  # Convert cm to m
img_VS.Depth = img_VS.Depth / 100  # Convert cm to m
img_VS.Z0 = img_VS.Z0 / 100  # Convert cm to m

# Sorted list of distinct depths (rounded to 2 decimals, NaN removed).
depth_list = img_VS["Depth"].unique()
depth_list = np.round(depth_list, 2)
depth_list = depth_list[~np.isnan(depth_list)].tolist()
# -99.99 is presumably a sentinel (-9999 before the /100 conversion) for
# periods without a stimulus — confirm. list.remove raises ValueError if absent.
depth_list.remove(-99.99)
depth_list.sort()


In [None]:
import pickle
# print(img_VS[:20], flush=True)
# FIXME(review): `protocol_folder` is not defined in this notebook. This cell
# also OVERWRITES img_VS.pickle on disk; do not run it against processed data
# unless that is intended.
# Save img_VS
with open(protocol_folder / "img_VS.pickle", "wb") as handle:
    pickle.dump(img_VS, handle, protocol=pickle.HIGHEST_PROTOCOL)
print("Timestamps aligned and saved.", flush=True)
print("---STEP 3 FINISHED.---", "\n", flush=True)

# -----STEP4: Get the visual stimulation structure and Save (find the imaging frames for visual stimulation)-----
print("---START STEP 4---", "\n", "Get vis-stim structure...", flush=True)
# NOTE(review): reloads the file written just above; redundant while img_VS is in memory
with open(protocol_folder / "img_VS.pickle", "rb") as handle:
    img_VS = pickle.load(handle)
from cottage_analysis.stimulus_structure import sphere_structure as vis_stim_structure
# Build the stimulus dictionary from img_VS for every depth in depth_list
stim_dict = vis_stim_structure.create_stim_dict(
    depth_list=depth_list, img_VS=img_VS, choose_trials=None
)


In [None]:
# Quick look at the first rows of img_VS
img_VS.head()

In [None]:
# Number of rows/columns of img_VS
img_VS.shape

In [None]:
# Number of rows/columns of the tracking output, for comparison with img_VS
# (presumably a frame-count alignment check — confirm)
dlc_res.shape