# Example session

Scratch notebook for prototyping the eye-tracking analysis before turning it into reusable functions.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# select session
import matplotlib as mpl
import seaborn as sns
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
import flexiznam as flz
from v1_depth_analysis.config import PROJECT
import v1_depth_analysis as vda
from cottage_analysis.eye_tracking import analysis as analeyesis
from cottage_analysis.eye_tracking import eye_model_fitting as emf


In [None]:
# Resolve the raw / processed data roots from the flexiznam configuration and
# open a flexilims session for this project.
data_roots = flz.PARAMETERS["data_root"]
raw_path, processed_path = (Path(data_roots[kind]) for kind in ("raw", "processed"))
flm_sess = flz.get_flexilims_session(project_id=PROJECT)

# Recordings of the sphere protocol, and the eye-camera datasets within them.
recordings = vda.get_recordings(protocol="SpheresPermTubeReward", flm_sess=flm_sess)
datasets = vda.get_datasets(
    recordings,
    dataset_type="camera",
    dataset_name_contains="_eye",
    flm_sess=flm_sess,
)


In [None]:
# Choose which eye camera to analyse (alternatives kept for quick switching).
camera_full_name = "PZAG3.4f_S20220422_R130302_SpheresPermTubeReward_right_eye_camera"
# camera_full_name = 'PZAH6.4b_S20220419_R145152_SpheresPermTubeReward_left_eye_camera'
# camera_full_name = 'PZAH6.4b_S20220419_R145152_SpheresPermTubeReward_right_eye_camera'
matching_cameras = [ds for ds in datasets if ds.full_name == camera_full_name]
if not matching_cameras:
    # Fail with an explicit message instead of an opaque IndexError.
    raise ValueError(f"No camera dataset named {camera_full_name!r} found")
camera = matching_cameras[0]
print(f"Analysing {' from '.join(camera.genealogy[::-1])}")


# Get data

In [None]:
# Load DLC tracking and ellipse fits for this camera. Points and fits are kept
# only if they pass the quality thresholds given here.
dlc_res, ellipse = analeyesis.get_data(
    camera,
    flexilims_session=flm_sess,
    error_threshold=3,
    rsquare_threshold=0.99,
    likelihood_threshold=0.88,
)
# Merge behavioural variables into the eye-tracking table. Also returns the
# camera sampling rate (used later to convert frames to seconds).
data, sampling = analeyesis.add_behaviour(
    camera,
    dlc_res,
    ellipse,
    log_speeds=False,
    speed_threshold=0.01,
)
assert "valid" in data.columns
data.head()


In [None]:
# Plot a single frame with the DLC markers overlaid, or optionally render a
# short overlay movie.
camera_save_folder = processed_path / camera.path / camera.dataset_name
target_file = camera_save_folder / "eye_tracking_ellipse_overlay.mp4"
video_file = camera.path_full / camera.extra_attributes["video_file"]
# The DLC dataset name is built from the camera's genealogy.
dlc_ds_name = "_".join(
    list(camera.genealogy[:-1]) + ["dlc_tracking", camera.dataset_name, "data", "0"]
)
dlc_ds = flz.Dataset.from_flexilims(name=dlc_ds_name, flexilims_session=flm_sess)
# Crop box used by DLC: (x0, x1, y0, y1), as used for slicing below.
cropping = dlc_ds.extra_attributes["cropping"]

start_frame = 19286
vmin = 0
vmax = 150
crop_single_image = True
make_movie = False  # True renders an overlay movie to `target_file` instead

# Offset added to DLC coordinates to place them on the displayed image. It is
# defined before branching so it always exists: it is printed below and
# asserted on by later cells.
crop_shift = np.array([0, 0])

if make_movie:
    analeyesis.plot_movie(
        camera,
        target_file,
        start_frame=start_frame,
        duration=2,
        dlc_res=dlc_res,
        ellipse=ellipse,
        vmax=vmax,
        vmin=vmin,
        playback_speed=4,
    )
else:
    cam_data = cv2.VideoCapture(str(video_file))
    try:
        # NOTE(review): this reads 0-based frame `start_frame - 1`, while DLC
        # is indexed with `start_frame` below. Confirm the offset is intended.
        cam_data.set(cv2.CAP_PROP_POS_FRAMES, start_frame - 1)
        ret, frame = cam_data.read()
    finally:
        cam_data.release()
    if not ret:
        raise IOError(f"Could not read frame {start_frame - 1} from {video_file}")
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    if crop_single_image:
        assert cropping is not None
        gray = gray[cropping[2] : cropping[3], cropping[0] : cropping[1]]
    elif cropping is not None:
        crop_shift = np.array([cropping[0], cropping[2]])

    plt.imshow(gray, cmap="gray", vmin=vmin, vmax=vmax)
    plt.colorbar()
    # DLC markers for this frame, with the "scorer" level dropped.
    track = dlc_res.loc[start_frame]
    track.index = track.index.droplevel(["scorer"])
    eye_markers = [f"eye_{i}" for i in range(1, 13)]
    xdata = track.loc[[(m, "x") for m in eye_markers]]
    ydata = track.loc[[(m, "y") for m in eye_markers]]
    plt.scatter(xdata + crop_shift[0], ydata + crop_shift[1])
    plt.scatter(
        track.loc[("reflection", "x")] + crop_shift[0],
        track.loc[("reflection", "y")] + crop_shift[1],
    )

print(crop_shift)


In [None]:
# Distribution of the DLC likelihood of the `reflection` marker (log y-axis).
dlc_data = dlc_res.xs("likelihood", axis="columns", level=2).droplevel(
    "scorer", axis="columns"
)
ax = sns.displot(dlc_data["reflection"].values)
current_ax = plt.gca()
current_ax.set_xlabel("DLC likelihood")
plt.gcf().set_size_inches(5, 5)
current_ax.semilogy()


In [None]:
# Fit-quality diagnostics.
# 1) DLC likelihood of the `eye_*` markers only (reflection, eye-corner and
#    eyelid markers are dropped), with the likelihood threshold marked.
excluded_markers = [
    "reflection",
    "left_eye_corner",
    "right_eye_corner",
    "top_eye_lid",
    "bottom_eye_lid",
]
dlc_data = dlc_res.xs("likelihood", axis="columns", level=2).droplevel(
    "scorer", axis="columns"
)
ax = sns.displot(dlc_data.drop(columns=excluded_markers))
likelihood_threshold = 0.88
plt.gca().set_xlabel("DLC likelihood")
plt.gcf().set_size_inches(5, 5)
plt.axvline(likelihood_threshold, color="k")
plt.xlim(0.8, 1)

# 2) Ellipse-fit error vs. r-square and vs. mean DLC likelihood (valid fits).
valid_ellipses = ellipse[ellipse.valid]
sns.jointplot(data=valid_ellipses, x="error", y="rsquare")
sns.jointplot(data=valid_ellipses, x="error", y="dlc_avg_likelihood")


# Actual analysis

In [None]:
def _plot_binned_mean(ax, x, y, values, counts, bx, by, vmin, vmax):
    """Show the per-bin mean of `values` over a 2D grid of (x, y) bins.

    Args:
        ax: matplotlib Axes to draw on.
        x, y: sample coordinates.
        values: per-sample values to average.
        counts: sample count per bin (from np.histogram2d with the same edges).
        bx, by: bin edges along x and y.
        vmin, vmax: colour limits.

    Returns:
        The AxesImage, for use with a colorbar.
    """
    sums, _, _ = np.histogram2d(x, y, weights=values, bins=(bx, by))
    # Empty bins stay NaN (blank), with no divide-by-zero warnings.
    means = np.full_like(sums, np.nan)
    np.divide(sums, counts, out=means, where=counts > 0)
    # With the default origin="upper", row 0 (smallest y) is drawn at `top`.
    # Passing top=by[0] makes the tick labels match the rows. The old extent
    # (by[0], by[-1]) labelled the y-axis upside down.
    return ax.imshow(
        means.T, extent=(bx[0], bx[-1], by[-1], by[0]), vmin=vmin, vmax=vmax
    )


# Ellipse angle and axes ratio as a function of pupil position (valid frames).
fig = plt.figure(figsize=(15, 4))
elli = data[data.valid]
angles = np.rad2deg(elli.angle)
anglim = np.quantile(angles, [0.01, 0.99])

# Panel 1: raw scatter, coloured by angle (image convention: y points down).
ax = fig.add_subplot(1, 3, 1)
sc = ax.scatter(elli.pupil_x, elli.pupil_y, c=angles, vmax=anglim[1], vmin=anglim[0])
cb = fig.colorbar(ax=ax, mappable=sc)
cb.set_label("Ellipse angle (degrees)")
ax.set_xlabel("Ellipse pupil X (pixels)")
ax.set_ylabel("Ellipse pupil Y (pixels)")
ax.set_aspect("equal")
ax.invert_yaxis()

# Panel 2: mean angle per position bin.
ax = fig.add_subplot(1, 3, 2)
count, bx, by = np.histogram2d(elli.pupil_x, elli.pupil_y, bins=(70, 70))
img = _plot_binned_mean(
    ax, elli.pupil_x, elli.pupil_y, angles, count, bx, by, anglim[0], anglim[1]
)
cb = fig.colorbar(mappable=img, ax=ax)
cb.set_label("Ellipse angle (degrees)")
ax.set_xlabel("Ellipse pupil X (pixels)")
ax.set_ylabel("Ellipse pupil Y (pixels)")

# Panel 3: mean major/minor axes ratio per position bin.
ax = fig.add_subplot(1, 3, 3)
ratio = elli.major_radius / elli.minor_radius
ralim = np.quantile(ratio, [0.01, 0.99])
img = _plot_binned_mean(
    ax, elli.pupil_x, elli.pupil_y, ratio, count, bx, by, ralim[0], ralim[1]
)
cb = fig.colorbar(mappable=img, ax=ax)
cb.set_label("Ellipse axes ratio")
ax.set_xlabel("Ellipse pupil X (pixels)")
ax.set_ylabel("Ellipse pupil Y (pixels)")

fig.subplots_adjust(wspace=0.4)


## Check ellipse orientation

Check which direction corresponds to 0 degrees, and in which direction the angle increases.

In [None]:
from skimage.measure import EllipseModel

# Draw ellipses at 0/45/90/135 degrees to check skimage's angle convention.
# Solid line: direction of `theta`. Dashed line: theta + 90 degrees.
ellipse_model = EllipseModel()
fig, ax = plt.subplots(1, 1)
fig.set_size_inches(3, 3)
sample_t = np.arange(0, 2 * np.pi, 0.1)
for angle, colour in zip(range(0, 180, 45), "rgbk"):
    theta_rad = np.deg2rad(angle)
    # EllipseModel params are (xc, yc, a, b, theta).
    ellipse_model.params = (0, 0, 2, 1, theta_rad)
    outline = ellipse_model.predict_xy(sample_t)
    ax.plot(outline[:, 0], outline[:, 1], label=angle, color=colour)
    along = (np.cos(theta_rad), np.sin(theta_rad))
    ax.plot((0, along[0]), (0, along[1]), color=colour)
    across = (np.cos(theta_rad + np.pi / 2), np.sin(theta_rad + np.pi / 2))
    ax.plot((0, across[0]), (0, across[1]), color=colour, ls="--")
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
ax.set_aspect("equal")


# Kerr calibration

Eye-model calibration as in Wallace et al.

## Smooth data

To make the fit easier, we do not work on individual frames. Instead, frames with a similar pupil-centre position are grouped, and each group is summarised by its median.

In [None]:
# Bin valid frames by pupil-centre position on a square grid. Take the median
# of every column within each bin, and keep bins with enough frames.
n_position_bins = 25
min_frames_per_bin = 10

elli = pd.DataFrame(data[data.valid], copy=True)
count, bin_edges_x, bin_edges_y = np.histogram2d(
    elli.pupil_x, elli.pupil_y, bins=(n_position_bins, n_position_bins)
)
# Assign bins exactly like np.histogram2d: left-closed bins, with the last bin
# also closed, so ids run 0..n_position_bins-1. Plain `searchsorted` put the
# minimum frame alone in bin 0 and every other frame in bins 1..n.
elli["bin_id_x"] = np.digitize(elli.pupil_x.values, bin_edges_x[1:-1])
elli["bin_id_y"] = np.digitize(elli.pupil_y.values, bin_edges_y[1:-1])

binned_ellipses = elli.groupby(["bin_id_x", "bin_id_y"])
ns = binned_ellipses.valid.aggregate(len)  # frames per bin
binned_ellipses = binned_ellipses.aggregate(np.nanmedian)
enough_frames = binned_ellipses[ns > min_frames_per_bin]


In [None]:
# Per-bin median ellipse parameters shown as images (rows = x bin, columns =
# y bin). The grid is sized from the largest bin id rather than the number of
# distinct ids: that number is smaller whenever a bin is empty, and indexing
# the matrix with the id would then raise IndexError.
n_bins_x = int(ns.index.get_level_values(0).max()) + 1
n_bins_y = int(ns.index.get_level_values(1).max()) + 1
bin_ix = enough_frames.index.get_level_values(0).to_numpy()
bin_iy = enough_frames.index.get_level_values(1).to_numpy()

fig = plt.figure(figsize=(15, 5))
for ip, p in enumerate(["angle", "minor_radius", "major_radius"]):
    mat = np.full((n_bins_x, n_bins_y), np.nan)
    mat[bin_ix, bin_iy] = enough_frames[p].to_numpy()
    lim = np.nanquantile(mat, [0.01, 0.99])
    plt.subplot(1, 3, ip + 1)
    plt.imshow(mat, vmin=lim[0], vmax=lim[1])
    plt.colorbar()
    plt.title(p)


## Find eye centre

Defined as the least-squares intersection point of the minor axes of all ellipses.

### Using all data

In [None]:

# Disabled exploratory cell, kept for reference: estimate the eye centre from
# every valid frame instead of the binned medians used in the next cell.
if False:
    d = data[data.valid]
    # Pupil centres, one column per frame.
    p = np.vstack([d[f"pupil_{a}"].values for a in "xy"])
    # Unit vectors along each ellipse's `angle`. NOTE(review): presumably
    # emf.pts_intersection treats `n` as line normals, so these describe the
    # minor-axis lines. Confirm against eye_model_fitting.
    n = np.vstack([np.cos(d.angle), np.sin(d.angle)])
    intercept_minor = emf.pts_intersection(p, n)

    # Same with the perpendicular direction (result not used further).
    n = np.vstack([np.cos(d.angle + np.pi / 2), np.sin(d.angle + np.pi / 2)])
    intercept_major = emf.pts_intersection(p, n)
    axes_ratio = d.minor_radius.values / d.major_radius.values
    eye_centre_all = intercept_minor.flatten()

    # Closed-form least-squares solution of
    #   |pupil - eye_centre| = (f/z0) * sqrt(1 - (minor / major) ** 2)
    delta_pts = np.vstack([d.pupil_x, d.pupil_y]) - eye_centre_all[:, np.newaxis]
    sum_sqrt_ratio = np.sum(
        np.sqrt((1 - axes_ratio**2)) * np.linalg.norm(delta_pts, axis=0)
    )
    sum_sq_ratio = np.sum(1 - axes_ratio**2)
    f_z0_all = sum_sqrt_ratio / sum_sq_ratio
    print(rf"Eye centre: {eye_centre_all}. f/z0: {f_z0_all}")

    # plot it: the reference frame, cropped like the DLC input
    cam_data = cv2.VideoCapture(str(video_file))
    cam_data.set(cv2.CAP_PROP_POS_FRAMES, start_frame - 1)
    ret, frame = cam_data.read()
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    gray = gray[cropping[2] : cropping[3], cropping[0] : cropping[1]]
    cam_data.release()
    assert np.sum(crop_shift) == 0
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(10, 10)
    img = ax.imshow(gray, cmap="gray", vmin=vmin, vmax=vmax)
    fig.colorbar(img, ax=ax)
    # `track` is computed but not used below.
    track = dlc_res.loc[start_frame]
    track.index = track.index.droplevel(["scorer"])


    # Every 100th valid frame: draw a line through the pupil centre along
    # angle + 90 degrees. Pupil coordinates are shifted by that frame's
    # reflection position to land on the image.
    for i, series in data[data.valid].iloc[::100].iterrows():
        origin = np.array([series.pupil_x, series.pupil_y])
        ref = np.array([series.reflection_x, series.reflection_y])
        n_v = np.array([np.cos(series.angle + np.pi / 2), np.sin(series.angle + np.pi / 2)])
        rng = np.array([-200, 200])
        ax.plot(
            *[(origin[a] + ref[a] + n_v[a] * rng) for a in range(2)],
            color="purple",
            alpha=0.05,
            lw=1,
        )

    # Reflection in the displayed frame; the eye centre and the f/z0 circle
    # are drawn relative to it.
    ref = data.loc[start_frame, ["reflection_x", "reflection_y"]].values
    ax.scatter(*ref, color="blue", marker="+")
    ax.plot(*(eye_centre_all + ref), color="g", marker="o")
    import matplotlib as mpl

    eye_all = mpl.patches.Circle(
        xy=eye_centre_all + ref, radius=f_z0_all, facecolor="none", edgecolor="g"
    )
    ax.add_artist(eye_all)
    ax.set_xlim(0, gray.shape[1])
    _ = ax.set_ylim(gray.shape[0], 0)


### Using smoothed data


In [None]:
# Eye centre from the binned (smoothed) ellipses: least-squares intersection
# (emf.pts_intersection) of the per-bin lines through the pupil centres.
p = np.vstack([enough_frames[f"pupil_{a}"].values for a in "xy"])
n = np.vstack([np.cos(enough_frames.angle.values), np.sin(enough_frames.angle.values)])
intercept_minor = emf.pts_intersection(p, n)
axes_ratio = enough_frames.minor_radius.values / enough_frames.major_radius.values
eye_centre_binned = intercept_minor.flatten()

# Closed-form least-squares solution of
#   |pupil - eye_centre| = (f/z0) * sqrt(1 - (minor / major) ** 2)
delta_pts = (
    np.vstack([enough_frames.pupil_x, enough_frames.pupil_y])
    - eye_centre_binned[:, np.newaxis]
)
sum_sqrt_ratio = np.sum(np.sqrt(1 - axes_ratio**2) * np.linalg.norm(delta_pts, axis=0))
sum_sq_ratio = np.sum(1 - axes_ratio**2)
f_z0_binned = sum_sqrt_ratio / sum_sq_ratio
print(rf"Eye centre: {eye_centre_binned}. f/z0: {f_z0_binned}")

# Plot on the reference frame, cropped like the DLC input.
cam_data = cv2.VideoCapture(str(video_file))
try:
    cam_data.set(cv2.CAP_PROP_POS_FRAMES, start_frame - 1)
    ret, frame = cam_data.read()
finally:
    cam_data.release()
assert ret, f"Could not read frame {start_frame - 1} from {video_file}"
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
gray = gray[cropping[2] : cropping[3], cropping[0] : cropping[1]]
assert np.sum(crop_shift) == 0
fig, ax = plt.subplots(1, 1)
fig.set_size_inches(10, 10)
img = ax.imshow(gray, cmap="gray", vmin=vmin, vmax=vmax)
fig.colorbar(img, ax=ax)

# One line per bin through the pupil centre along angle + 90 degrees. Pupil
# coordinates are shifted by that bin's median reflection to land on the image.
rng = np.array([-200, 200])
for _, series in enough_frames.iterrows():
    origin = np.array([series.pupil_x, series.pupil_y])
    bin_ref = np.array([series.reflection_x, series.reflection_y])
    n_v = np.array([np.cos(series.angle + np.pi / 2), np.sin(series.angle + np.pi / 2)])
    ax.plot(
        *[(origin[a] + bin_ref[a] + n_v[a] * rng) for a in range(2)],
        color="purple",
        alpha=0.1,
        lw=1,
    )

# Reflection position in the displayed frame. It places the eye centre here
# and in the following plotting cells. It is taken from `start_frame`
# explicitly rather than from whichever bin the loop above visited last.
ref = data.loc[start_frame, ["reflection_x", "reflection_y"]].values.astype(float)
ax.scatter(*ref, color="blue", marker="+")
ax.plot(*(eye_centre_binned + ref), color="g", marker="o")
eye_binned = mpl.patches.Circle(
    xy=(eye_centre_binned + ref),
    radius=f_z0_binned,
    facecolor="none",
    edgecolor="g",
)
ax.add_artist(eye_binned)
ax.set_xlim(0, gray.shape[1])
_ = ax.set_ylim(gray.shape[0], 0)


# Reprojection

Find where the pupil ellipse appears in the image for given eye-rotation angles $\theta$ and $\phi$.

In [None]:
# Fit (phi, theta, radius) so that the reprojected pupil matches the median
# binned ellipse, using emf's grid-search minimiser.
ellipse_params_med = enough_frames.loc[
    :, ["pupil_x", "pupil_y", "major_radius", "minor_radius", "angle"]
].median(axis=0)

p0 = (0, 0, 1)
params_med, _, _ = emf.minimise_reprojection_error(
    ellipse_params_med,
    p0,
    eye_centre_binned,
    f_z0_binned,
    p_range=(np.pi, np.pi, 0.5),
    grid_size=20,
    niter=5,
    reduction_factor=5,
    verbose=True,
)
phi, theta, radius = params_med

# Plot the fit of the median position over the reference frame.
cam_data = cv2.VideoCapture(str(video_file))
try:
    cam_data.set(cv2.CAP_PROP_POS_FRAMES, start_frame - 1)
    ret, frame = cam_data.read()
finally:
    cam_data.release()
assert ret, f"Could not read frame {start_frame - 1} from {video_file}"
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
gray = gray[cropping[2] : cropping[3], cropping[0] : cropping[1]]

ax = plt.subplot(1, 1, 1)
ax.axis("off")
ax.imshow(gray, cmap="gray")

sample_t = np.arange(0, 2 * np.pi, 0.1)
source_model = EllipseModel()
source_model.params = ellipse_params_med
circ_coord = source_model.predict_xy(sample_t) + ref.reshape(1, 2)
ax.plot(circ_coord[:, 0], circ_coord[:, 1], label="DLC fit", color="lightblue")
fitted_model = emf.reproj_ellipse(
    phi=phi, theta=theta, r=radius, eye_centre=eye_centre_binned, f_z0=f_z0_binned
)
circ_coord = fitted_model.predict_xy(sample_t) + ref.reshape(1, 2)
ax.plot(circ_coord[:, 0], circ_coord[:, 1], label="Reprojected fit", color="purple")
ax.legend(loc="upper right")
plt.tight_layout()


## Initial gaze estimate

We fit the eye rotation for every bin of the binned ellipse matrix.

In [None]:
# Per-bin gaze estimate: fit (phi, theta, radius) for each binned ellipse. The
# search starts from the median-ellipse solution and covers a narrower range.
grid_angles = np.deg2rad(np.arange(0, 360, 5))
grid_radius = np.arange(0.8, 1.2, 0.1)
ellipse_columns = ["pupil_x", "pupil_y", "major_radius", "minor_radius", "angle"]


def _fit_binned_gaze(bin_row):
    """Return the (phi, theta, radius) fit for one row of `enough_frames`."""
    fitted, _, _ = emf.minimise_reprojection_error(
        bin_row[ellipse_columns],
        p0=params_med,
        eye_centre=eye_centre_binned,
        f_z0=f_z0_binned,
        p_range=(np.pi / 4, np.pi / 4, 0.5),
        grid_size=10,
        niter=5,
        reduction_factor=5,
        verbose=False,
    )
    return fitted


eye_rotation_initial = np.array(
    [_fit_binned_gaze(row) for _, row in enough_frames.iterrows()], dtype=float
).reshape(-1, 3)


In [None]:
# Map the per-bin gaze estimates back onto the (x bin, y bin) grid. The grid is
# sized from the largest bin id (not the number of distinct ids) so that
# indexing by id cannot go out of bounds when a bin is empty.
all_bins = binned_ellipses.index
mat = np.full(
    (
        int(all_bins.get_level_values(0).max()) + 1,
        int(all_bins.get_level_values(1).max()) + 1,
        3,
    ),
    np.nan,
)
bin_ix = enough_frames.index.get_level_values(0).to_numpy()
bin_iy = enough_frames.index.get_level_values(1).to_numpy()
# Rows of eye_rotation_initial follow the row order of enough_frames.
mat[bin_ix, bin_iy] = eye_rotation_initial

fig = plt.figure(figsize=(15, 4))
labels = ["phi", "theta", "radius"]
for i in range(3):
    plt.subplot(1, 3, 1 + i)
    # phi and theta are in radians: display them in degrees.
    d = np.rad2deg(mat[..., i]) if i < 2 else mat[..., i]
    plt.imshow(d)
    plt.title(labels[i])
    plt.colorbar()

## Optimize eye parameters

In [None]:
# Refine the eye centre and f_z0 jointly, using a subset of the binned ellipses
# and their initial gaze estimates. Ellipses and gazes MUST be subsampled with
# the same step so that row k of each refers to the same bin. Previously they
# used ::20 and ::10, which mismatched the pairs.
subsample_step = 20
source_ellipses = enough_frames.loc[
    ::subsample_step, ["pupil_x", "pupil_y", "major_radius", "minor_radius", "angle"]
].values
gazes = eye_rotation_initial[::subsample_step]
assert len(source_ellipses) == len(gazes)
(x, y, f_z0), ind, err = emf.optimise_eye_parameters(
    ellipses=source_ellipses,
    gazes=gazes,
    p0=(*eye_centre_binned, f_z0_binned),
    p_range=(70, 70, 50),
    grid_size=7,
    niter=3,
    reduction_factor=3,
    verbose=True,
)
eye_centre = np.array([x, y])


In [None]:
# Replot the median eye position with the optimised eye model (purple) next to
# the initial binned eye model (orange).
cam_data = cv2.VideoCapture(str(video_file))
try:
    cam_data.set(cv2.CAP_PROP_POS_FRAMES, start_frame - 1)
    ret, frame = cam_data.read()
finally:
    cam_data.release()
assert ret, f"Could not read frame {start_frame - 1} from {video_file}"
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
gray = gray[cropping[2] : cropping[3], cropping[0] : cropping[1]]

ax = plt.subplot(1, 1, 1)
ax.axis("off")
ax.imshow(gray, cmap="gray")

sample_t = np.arange(0, 2 * np.pi, 0.1)
source_model = EllipseModel()
source_model.params = ellipse_params_med
circ_coord = source_model.predict_xy(sample_t) + ref.reshape(1, 2)
ax.scatter(*(eye_centre + ref), color="purple", label="Optimised eye centre")
ax.scatter(*(eye_centre_binned + ref), color="orange", label="Binned eye centre")
ax.plot(circ_coord[:, 0], circ_coord[:, 1], label="DLC fit", color="lightblue")
# NOTE: phi/theta/radius were fitted with the binned eye model and are not
# refitted for the optimised model here.
for centre, scale, colour, label in (
    (eye_centre_binned, f_z0_binned, "orange", "Reprojection (binned eye)"),
    (eye_centre, f_z0, "purple", "Reprojection (optimised eye)"),
):
    fitted_model = emf.reproj_ellipse(
        phi=phi, theta=theta, r=radius, eye_centre=centre, f_z0=scale
    )
    circ_coord = fitted_model.predict_xy(sample_t) + ref.reshape(1, 2)
    ax.plot(circ_coord[:, 0], circ_coord[:, 1], label=label, color=colour)
ax.legend(loc="upper right")
plt.tight_layout()


# Fit gaze

Fit the gaze for all frames.

In [None]:
# Fit (phi, theta, radius) for every valid frame with the optimised eye model.
# Invalid frames stay NaN. Rows are addressed by POSITION, so this works even
# if `data` does not have a 0..N-1 RangeIndex.
ellipse_columns = ["pupil_x", "pupil_y", "major_radius", "minor_radius", "angle"]
valid_mask = data.valid.to_numpy(dtype=bool)
eye_rotation = np.full((len(data), 3), np.nan)

for frame_pos, (_, series) in enumerate(data.iterrows()):
    if frame_pos % 1000 == 0:
        print(frame_pos)  # progress
    if not valid_mask[frame_pos]:
        continue
    pa, _, _ = emf.minimise_reprojection_error(
        series[ellipse_columns],
        p0=params_med,
        eye_centre=eye_centre,
        f_z0=f_z0,
        p_range=(np.pi / 2, np.pi / 2, 0.5),
        grid_size=10,
        niter=3,
        reduction_factor=5,
        verbose=False,
    )
    eye_rotation[frame_pos] = pa


In [None]:
# Distribution of the fitted eye-rotation angles: columns 0 and 1 of
# `eye_rotation` are phi and theta in radians, shown here in degrees.
fig, ax = plt.subplots()
for col, label in enumerate(["phi", "theta"]):
    ax.hist(np.rad2deg(eye_rotation[:, col]), histtype="step", label=label)
ax.set_xlabel("Angle (degrees)")
ax.set_ylabel("Frames")
ax.legend()

# Transform to world coordinates

In [None]:
# Load per-trial camera extrinsics (rvec, tvec) for both eye cameras.
calibration_folder = processed_path / PROJECT / "Calibrations"
calibration_date = "20220818"
calibration_board = "aruco5_5mm"

calib_data = dict()
for cam_name in ["RightEyeCam", "LeftEyeCam"]:
    calib_data[cam_name.lower()] = dict()
    # The extrinsics folder name is not consistently capitalised, hence the glob.
    candidates = list((calibration_folder / cam_name).glob("*xtrinsics_flat"))
    if not candidates:
        raise FileNotFoundError(
            f"No '*xtrinsics_flat' folder in {calibration_folder / cam_name}"
        )
    folder = candidates[0] / calibration_date / calibration_board
    assert folder.exists(), folder
    for trial in folder.glob("trial*"):
        fname = str(trial / "camera_extrinsics_flat.yml")
        fs = cv2.FileStorage(fname, cv2.FileStorage_READ)
        try:
            if not fs.isOpened():
                raise IOError(f"Could not open {fname}")
            rvec = fs.getNode("rvec").mat()
            tvec = fs.getNode("tvec").mat()
        finally:
            fs.release()  # the handle was previously never released
        calib_data[cam_name.lower()][trial.name] = dict(rvec=rvec, tvec=tvec)

In [None]:
# Median rvec / tvec across calibration trials, for each camera.
extrinsics = {
    cam_key: {
        w: np.median(np.vstack([d[w].flatten() for d in trials.values()]), axis=0)
        for w in ("rvec", "tvec")
    }
    for cam_key, trials in calib_data.items()
}
# Show the extrinsics of the last camera in calib_data.
cam = list(calib_data)[-1]
extrinsics[cam]


In [None]:
# Pick this acquisition's camera (dataset name with underscores removed and the
# last 3 characters dropped, e.g. presumably "right_eye_camera" ->
# "righteyecam"). Then build the 4x4 homogeneous transform [R | t; 0 0 0 1]
# from its median extrinsics.
cam_key = camera.dataset_name.replace("_", "")[:-3]
extrin = extrinsics[cam_key]
rmat, jac = cv2.Rodrigues(extrin["rvec"])
tform = np.eye(4)
tform[:3, :3] = rmat
tform[:3, 3] = extrin["tvec"]
print(np.round(tform, 2))


In [None]:
# Gaze direction vectors from (phi, theta), then rotated by the extrinsic
# rotation matrix.
# NOTE(review): with the OpenCV convention, extrinsics map world -> camera, so
# camera -> world would need rmat.T. Confirm the convention of these files.
gaze_vec = np.vstack([emf.get_gaze_vector(p[0], p[1]) for p in eye_rotation])
rotated_gaze = (rmat @ gaze_vec.T).T

# These are vector components, not angles, so they are plotted as-is
# (applying rad2deg to them was meaningless).
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, vec, title in zip(
    axes, (gaze_vec, rotated_gaze), ("Gaze vector", "Rotated gaze vector")
):
    ax.hist(vec, histtype="step", label=["x", "y", "z"])
    ax.set_title(title)
    ax.set_xlabel("Component value")
axes[-1].legend()


In [None]:
# Azimuth / elevation of the rotated gaze, zeroed on the median gaze.
gaze_x, gaze_y, gaze_z = rotated_gaze.T
azimuth = np.arctan2(gaze_y, gaze_x)
# Elevation is measured against the horizontal NORM sqrt(x**2 + y**2). The
# previous squared norm gave wrong angles, e.g. (0.6, 0, 0.8) -> 65.8 degrees
# instead of 53.1.
elevation = np.arctan2(gaze_z, np.hypot(gaze_x, gaze_y))
# zero the median pos
azimuth -= np.nanmedian(azimuth)
elevation -= np.nanmedian(elevation)

# put back in [-pi, pi)
azimuth = np.mod(azimuth + np.pi, 2 * np.pi) - np.pi
elevation = np.mod(elevation + np.pi, 2 * np.pi) - np.pi

# 15-degree bins that include the +pi edge. `arange` stopped at pi - pi/12 and
# silently dropped the last bin.
angle_bins = np.linspace(-np.pi, np.pi, 25)
fig, ax = plt.subplots()
ax.hist(azimuth, bins=angle_bins, alpha=0.6, label="azimuth")
ax.hist(elevation, bins=angle_bins, alpha=0.6, label="elevation")
ax.set_xlabel("Angle (radians)")
ax.legend()

In [None]:
# 3D view of the gaze vectors before (top row) and after (bottom row) rotation,
# coloured by elevation (left) or azimuth (right).
fig = plt.figure(figsize=(10, 10))
for iax, vec in enumerate([gaze_vec, rotated_gaze]):
    # NaN vectors (invalid frames) draw nothing, so skip them in the slow
    # per-vector line loop.
    finite_vec = vec[np.all(np.isfinite(vec), axis=1)]
    for ia in range(2):
        ax = fig.add_subplot(2, 2, 1 + iax * 2 + ia, projection="3d")
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        for g in finite_vec:
            ax.plot(*[(0, i) for i in g], color="k")
        c = azimuth if ia else elevation
        sc = ax.scatter(*vec.T, c=np.rad2deg(c))
        cb = fig.colorbar(sc, ax=ax)
        cb.set_label("Azimuth (degrees)" if ia else "Elevation (degrees)")
        

In [None]:
# Median azimuth / elevation of the frames in each pupil-position bin.
# `azimuth` / `elevation` hold one value per FRAME (aligned with `data`).
# Indexing them by bin number, as before, mixed up frames and bins.
gaze_angles = (
    pd.DataFrame({"azimuth": azimuth, "elevation": elevation}, index=data.index)
    .loc[elli.index]
    .assign(bin_id_x=elli["bin_id_x"], bin_id_y=elli["bin_id_y"])
)
binned_gaze = (
    gaze_angles.groupby(["bin_id_x", "bin_id_y"])[["azimuth", "elevation"]]
    .median()
    .reindex(enough_frames.index)
)

# Grid sized from the largest bin id so empty bins cannot cause IndexError.
all_bins = binned_ellipses.index
mat = np.full(
    (
        int(all_bins.get_level_values(0).max()) + 1,
        int(all_bins.get_level_values(1).max()) + 1,
        2,
    ),
    np.nan,
)
bin_ix = binned_gaze.index.get_level_values(0).to_numpy()
bin_iy = binned_gaze.index.get_level_values(1).to_numpy()
mat[bin_ix, bin_iy] = binned_gaze[["azimuth", "elevation"]].to_numpy()

fig = plt.figure(figsize=(15, 4))
labels = ["azimuth", "elevation"]
for i, label in enumerate(labels):
    plt.subplot(1, 2, 1 + i)
    plt.imshow(np.rad2deg(mat[..., i]))
    plt.title(f"{label} (degrees)")
    plt.colorbar()

In [None]:
# NOTE(review): this cell duplicates the previous one. TODO remove one of them.
# Median azimuth / elevation of the frames in each pupil-position bin.
# The angles are per FRAME (aligned with `data`), so they are grouped with the
# same bin ids as `elli` instead of being indexed by bin number.
gaze_angles = (
    pd.DataFrame({"azimuth": azimuth, "elevation": elevation}, index=data.index)
    .loc[elli.index]
    .assign(bin_id_x=elli["bin_id_x"], bin_id_y=elli["bin_id_y"])
)
binned_gaze = (
    gaze_angles.groupby(["bin_id_x", "bin_id_y"])[["azimuth", "elevation"]]
    .median()
    .reindex(enough_frames.index)
)

# Grid sized from the largest bin id so empty bins cannot cause IndexError.
all_bins = binned_ellipses.index
mat = np.full(
    (
        int(all_bins.get_level_values(0).max()) + 1,
        int(all_bins.get_level_values(1).max()) + 1,
        2,
    ),
    np.nan,
)
bin_ix = binned_gaze.index.get_level_values(0).to_numpy()
bin_iy = binned_gaze.index.get_level_values(1).to_numpy()
mat[bin_ix, bin_iy] = binned_gaze[["azimuth", "elevation"]].to_numpy()

fig = plt.figure(figsize=(15, 4))
labels = ["azimuth", "elevation"]
for i, label in enumerate(labels):
    plt.subplot(1, 2, 1 + i)
    plt.imshow(np.rad2deg(mat[..., i]))
    plt.title(f"{label} (degrees)")
    plt.colorbar()

In [None]:
# Replot the median eye position: binned eye model (green) vs. optimised eye
# model (purple). `x`, `y` and `f_z0` come from the optimise_eye_parameters
# cell. The former `f_z0 = f` referenced an undefined name and raised
# NameError.
eye_centre = np.array([x, y])

cam_data = cv2.VideoCapture(str(video_file))
try:
    cam_data.set(cv2.CAP_PROP_POS_FRAMES, start_frame - 1)
    ret, frame = cam_data.read()
finally:
    cam_data.release()
assert ret, f"Could not read frame {start_frame - 1} from {video_file}"
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
gray = gray[cropping[2] : cropping[3], cropping[0] : cropping[1]]

ax = plt.subplot(1, 1, 1)
ax.axis("off")
ax.imshow(gray, cmap="gray")

sample_t = np.arange(0, 2 * np.pi, 0.1)
source_model = EllipseModel()
source_model.params = ellipse_params_med
circ_coord = source_model.predict_xy(sample_t) + ref.reshape(1, 2)
ax.scatter(*(eye_centre + ref), color="purple", label="Optimised eye centre")
ax.scatter(*(eye_centre_binned + ref), color="g", label="Binned eye centre")
ax.plot(circ_coord[:, 0], circ_coord[:, 1], label="DLC fit", color="lightblue")
for centre, scale, colour, label in (
    (eye_centre_binned, f_z0_binned, "g", "Reprojection (binned eye)"),
    (eye_centre, f_z0, "purple", "Reprojection (optimised eye)"),
):
    fitted_model = emf.reproj_ellipse(
        phi=phi, theta=theta, r=radius, eye_centre=centre, f_z0=scale
    )
    circ_coord = fitted_model.predict_xy(sample_t) + ref.reshape(1, 2)
    ax.plot(circ_coord[:, 0], circ_coord[:, 1], label=label, color=colour)
ax.legend(loc="upper right")
plt.tight_layout()


In [None]:
from scipy import optimize
from cottage_analysis.eye_tracking import eye_model_fitting as emf


def fit_ellipse(source_ellipse, eye_centre, f_z0, p0=(0, 0, 1)):
    """Fit the reprojection parameters that best reproduce one observed ellipse.

    Args:
        source_ellipse: an ``EllipseModel``, or its params
            ``(xc, yc, a, b, theta)``.
        eye_centre: 2-element eye-centre position.
        f_z0: the eye model's ``f/z0`` scale factor.
        p0: initial guess, passed positionally to ``emf.reproj_ellipse``
            (presumably ``(phi, theta, radius)``).

    Returns:
        ``scipy.optimize.OptimizeResult``. ``.x`` holds the fitted parameters
        and ``.fun`` the final ``emf.ellipse_distance``.
    """
    if not isinstance(source_ellipse, EllipseModel):
        model1 = EllipseModel()
        model1.params = source_ellipse
    else:
        model1 = source_ellipse

    def cost(p):
        el = emf.reproj_ellipse(*p, eye_centre=eye_centre, f_z0=f_z0)
        return emf.ellipse_distance(model1, el)

    res = optimize.minimize(cost, x0=p0, tol=1e-3)
    return res


def fit_all_ellipses(source_ellipses, ellipses_p0s, eye_centre_init, f_z0_init):
    """Jointly optimise the eye centre and f/z0 over many ellipses.

    For each candidate ``(eye_centre_x, eye_centre_y, f_z0)``, every ellipse
    is refitted with :func:`fit_ellipse`, and the residual distances are
    summed.

    Args:
        source_ellipses: iterable of ellipses accepted by :func:`fit_ellipse`.
        ellipses_p0s: per-ellipse initial guesses, in the same order.
        eye_centre_init: initial 2-element eye centre.
        f_z0_init: initial f/z0.

    Returns:
        ``scipy.optimize.OptimizeResult`` whose ``.x`` is
        ``(eye_centre_x, eye_centre_y, f_z0)``.
    """
    p0 = (*eye_centre_init, f_z0_init)

    def cost(p):
        # p = (eye_centre_x, eye_centre_y, f_z0). The centre must be passed as
        # a 2-vector; previously x was passed as the centre and y as f_z0.
        candidate_centre = np.asarray(p[:2])
        candidate_f_z0 = p[2]
        errs = [
            fit_ellipse(s, candidate_centre, candidate_f_z0, p0=ep0).fun
            for s, ep0 in zip(source_ellipses, ellipses_p0s)
        ]
        return np.sum(errs)

    res = optimize.minimize(cost, x0=p0, tol=1e-3)
    return res


In [None]:
# Attempt with scipy.optimize.minimize. Disabled because it fails to converge.
# `params_all` is only created by a later cell.
if False:

    def _to_ellipse_model(row):
        """Build an EllipseModel from one row of `enough_frames`."""
        model = EllipseModel()
        model.params = row[["pupil_x", "pupil_y", "major_radius", "minor_radius", "angle"]]
        return model

    ellipses_to_fit = [_to_ellipse_model(row) for _, row in enough_frames.iterrows()]
    res = fit_all_ellipses(
        ellipses_to_fit,
        ellipses_p0s=params_all,
        eye_centre_init=eye_centre_binned,
        f_z0_init=f_z0_binned,
    )


In [None]:
# `res` is only created by the disabled minimize attempt (or by a later cell).
# Guard the display so Restart & Run All does not stop on a NameError.
globals().get("res", "`res` not defined: the minimize attempt above is disabled")


In [None]:
# Refit every binned ellipse with scipy.optimize under the binned eye model,
# starting from the solution for the median ellipse.
ellipse_columns = ["pupil_x", "pupil_y", "major_radius", "minor_radius", "angle"]
ellipse_params_med = enough_frames.loc[:, ellipse_columns].median(axis=0)
p0 = fit_ellipse(
    ellipse_params_med.values, eye_centre=eye_centre_binned, f_z0=f_z0_binned
).x
params_all = np.zeros((len(enough_frames), len(p0)))
# enough_frames is indexed by (bin_id_x, bin_id_y) tuples. Using that tuple as
# a numpy row index wrote the same result into two rows, so rows are filled by
# position instead.
for row_pos, (_, series) in enumerate(enough_frames.iterrows()):
    res = fit_ellipse(
        series.loc[ellipse_columns],
        eye_centre=eye_centre_binned,
        f_z0=f_z0_binned,
        p0=p0,  # previously computed but never used as the initial guess
    )
    params_all[row_pos, :] = res.x


In [None]:
# Time course of the ellipse centre (median-subtracted, valid frames only):
# raw in the top two rows, relative to the reflection in the bottom two.
# Rows 1 and 3 zoom in on two minutes.
fig, axes = plt.subplots(4, 1)
fig.set_size_inches((10, 10))

valid = ellipse.valid
time = np.arange(len(dlc_res)) / sampling
reflection = dlc_res.xs(axis="columns", level=1, key="reflection").droplevel(
    "scorer", axis="columns"
)
for iax, ax in enumerate(axes):
    relative = iax > 1
    ax.set_ylabel("Relative to reflection" if relative else "Raw")
    for w in ("x", "y"):
        centre = ellipse[f"centre_{w}"]
        d = ((centre - reflection[w]) if relative else centre)[valid]
        ax.plot(
            time[valid], d - np.nanmedian(d), label=rf"$\Delta${w.split('_')[0]}", lw=1
        )
for full_row in (0, 2):
    axes[full_row].set_xlim(time[0], time[-1])
for zoom_row in (1, 3):
    axes[zoom_row].set_xlim(3000, 3000 + 60 * 2)
    axes[zoom_row].set_ylim(-15, 15)
ax.legend(loc="upper right")
ax.set_xlabel("Time (s)")


In [None]:
import matplotlib as mpl
from matplotlib import cm

# One colour per depth from the reversed "cool" colormap, on a log scale. NaN
# depths are excluded so they do not get a colour or palette entry of their own.
depth_list = np.unique(data.depth.dropna())
cmap = cm.cool.reversed()
norm = mpl.colors.Normalize(vmin=np.log(min(depth_list)), vmax=np.log(max(depth_list)))
line_colors = []
col_dict = dict()
for depth in depth_list:
    rgba_color = tuple(it / 255 for it in cmap(norm(np.log(depth)), bytes=True))
    line_colors.append(rgba_color)
    col_dict[depth] = rgba_color

# Distribution of eye position per depth.
fig, axes = plt.subplots(3, 1)
fig.set_size_inches(6, 10)
labels = ["X position", "Y position", "Distance to median position"]
d = data[(~np.isnan(data.dx)) & (~np.isnan(data.depth))]
for iw, w in enumerate(["dx", "dy", "d_med"]):
    sns.violinplot(data=d, x="depth", y=w, palette=line_colors, ax=axes[iw])
    axes[iw].set_ylabel(labels[iw])
    axes[iw].set_ylabel(labels[iw])


In [None]:
import matplotlib as mpl
from matplotlib import cm

# One colour per depth (reversed "cool" colormap, log scale). NaN depths are
# excluded so colours line up with the depth groups plotted below.
depth_list = np.unique(data.depth.dropna())
cmap = cm.cool.reversed()
norm = mpl.colors.Normalize(vmin=np.log(min(depth_list)), vmax=np.log(max(depth_list)))
line_colors = []
col_dict = dict()
for depth in depth_list:
    rgba_color = tuple(it / 255 for it in cmap(norm(np.log(depth)), bytes=True))
    line_colors.append(rgba_color)
    col_dict[depth] = rgba_color

# Frames whose movement `mvt` exceeds this value are flagged as saccades.
saccade_threshold = 5

fig, axes = plt.subplots(3, 1)
fig.set_size_inches(5, 7)
labels = ["Motion", "Small movements", "Saccades"]
data["saccade"] = data.mvt > saccade_threshold
d = data[(~np.isnan(data.dx)) & (~np.isnan(data.depth))]
for iw in range(2):
    # Row 0: all movements. Row 1: small movements only (mvt < 2).
    subset = d if iw == 0 else d[d.mvt < 2]
    sns.violinplot(data=subset, x="depth", y="mvt", palette=line_colors, ax=axes[iw])
    axes[iw].set_ylabel(labels[iw])

# NOTE(review): this counts FRAMES above threshold per second, so one saccade
# spanning several frames is counted several times.
sac_per_depth = d.groupby("depth").saccade.aggregate(np.nansum)
sample_per_depth = d.groupby("depth").saccade.aggregate(len)

axes[2].bar(
    x=np.arange(len(sac_per_depth)),
    height=sac_per_depth / sample_per_depth * sampling,
    color=line_colors,
)
axes[2].set_xticks(np.arange(len(sac_per_depth)))
axes[2].set_xticklabels(sac_per_depth.index)
axes[2].set_ylabel("Saccades per second")


In [None]:
import matplotlib as mpl
from matplotlib import cm

# Motion / saccade summary split by running speed (`rs` rounded to the
# nearest 10).
saccade_threshold = 5
data["running_bin"] = np.round(data.rs / 10) * 10
data["saccade"] = data.mvt > saccade_threshold
d = data[(~np.isnan(data.dx)) & (~np.isnan(data.depth))]

# NOTE(review): counts FRAMES above threshold, not distinct saccade events.
sac_per_bin = d.groupby("running_bin").saccade.aggregate(np.nansum)
sample_per_bin = d.groupby("running_bin").saccade.aggregate(len)

# Colours are built from the same bins as the bars. They were previously built
# from a different frame selection, which could misalign colours and bars.
running_bins = sac_per_bin.index.to_numpy()
cmap = cm.viridis
norm = mpl.colors.Normalize(vmin=min(running_bins), vmax=max(running_bins))
rs_colors = [tuple(it / 255 for it in cmap(norm(rb), bytes=True)) for rb in running_bins]

fig, axes = plt.subplots(3, 1)
fig.set_size_inches(5, 7)
labels = ["Motion", "Small movements", "Saccades"]
for iw in range(2):
    # Row 0: all movements. Row 1: small movements only (mvt < 2).
    subset = d if iw == 0 else d[d.mvt < 2]
    sns.violinplot(data=subset, x="running_bin", y="mvt", palette="viridis", ax=axes[iw])
    axes[iw].set_ylabel(labels[iw])

axes[2].bar(
    x=np.arange(len(sac_per_bin)),
    height=sac_per_bin / sample_per_bin * sampling,
    color=rs_colors,
)
axes[2].set_xticks(np.arange(len(sac_per_bin)))
axes[2].set_xticklabels(np.array(sac_per_bin.index, dtype=int))
axes[2].set_ylabel("Saccades per second")


In [None]:
from scipy.stats import mannwhitneyu

# Pairwise two-sided Mann-Whitney U tests between depths for each eye-position
# measure. The diagonal (a depth against itself) is left NaN; it used to be
# set to p = 0, which displayed as "highly significant".
depth_list = np.unique(d.depth)
props = ["dx", "dy", "d_med"]
n_depths = len(depth_list)
pval_mat = np.full((len(props), n_depths, n_depths), np.nan)

for ia, depth_a in enumerate(depth_list):
    adf = d[d.depth == depth_a]
    for ib in range(ia + 1, n_depths):
        bdf = d[d.depth == depth_list[ib]]
        for ip, prop in enumerate(props):
            a = adf[prop].to_numpy(dtype=float)
            b = bdf[prop].to_numpy(dtype=float)
            # Drop NaNs explicitly so they do not propagate into the test.
            pvalue = mannwhitneyu(a[~np.isnan(a)], b[~np.isnan(b)]).pvalue
            pval_mat[ip, ia, ib] = pval_mat[ip, ib, ia] = pvalue

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ip, prop in enumerate(props):
    # Centre the diverging colormap on p = 0.05.
    delta = pval_mat[ip] - 0.05
    lim = np.nanmax(np.abs(delta)) if np.isfinite(delta).any() else 1.0
    img = axes[ip].imshow(delta, cmap="RdBu", origin="lower", vmin=-lim, vmax=lim)
    axes[ip].set_title(prop)
    axes[ip].set_xticks(range(n_depths))
    axes[ip].set_xticklabels(depth_list)
    axes[ip].set_yticks(range(n_depths))
    axes[ip].set_yticklabels(depth_list)
    fig.colorbar(img, ax=axes[ip], label="p - 0.05")


In [None]:
# Eye position per depth: median +/- std (left) and mean +/- SEM (right).
# The loop variable is `depth`; it used to be `d`, which overwrote the
# filtered DataFrame `d` used by other cells.
fig, axes = plt.subplots(1, 2)
for depth, ddf in data.groupby("depth"):
    n_dx = np.sum(~np.isnan(ddf.dx))
    n_dy = np.sum(~np.isnan(ddf.dy))
    axes[0].errorbar(
        x=np.nanmedian(ddf.dx),
        y=np.nanmedian(ddf.dy),
        xerr=np.nanstd(ddf.dx),
        yerr=np.nanstd(ddf.dy),
        label=int(depth),
        marker="o",
        color=col_dict[depth],
    )
    axes[1].errorbar(
        x=np.nanmean(ddf.dx),
        y=np.nanmean(ddf.dy),
        xerr=np.nanstd(ddf.dx) / np.sqrt(n_dx),
        yerr=np.nanstd(ddf.dy) / np.sqrt(n_dy),
        label=int(depth),
        marker=".",
        lw=3,
        color=col_dict[depth],
    )

for ax in axes:
    ax.set_aspect("equal")
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
axes[0].set_title("Eye position (median +/- std)")
axes[1].set_title("Eye position (mean +/- SEM)")


In [None]:
# NOTE(review): `img_VS` and `mousez_logger` are not defined anywhere in this
# notebook. This looks like leftover code from the imaging pipeline and raises
# NameError on a fresh kernel.
img_VS = pd.merge_asof(
    img_VS,
    mousez_logger,
    on="HarpTime",
    allow_exact_matches=True,
    direction="backward",
)

# Convert cm to m. NOTE: these in-place conversions are not idempotent;
# re-running the cell divides by 100 again.
for column in ["EyeZ", "MouseZ", "Depth", "Z0"]:
    img_VS[column] = img_VS[column] / 100

depth_list = img_VS["Depth"].unique()
depth_list = np.round(depth_list, 2)
depth_list = depth_list[~np.isnan(depth_list)].tolist()
# -99.99 looks like a placeholder depth (presumably -9999 cm before
# conversion). Drop it only if present, since list.remove raises otherwise.
if -99.99 in depth_list:
    depth_list.remove(-99.99)
depth_list.sort()


In [None]:
import pickle

# --- STEP 3 (end): persist the aligned table ---------------------------------
img_vs_file = protocol_folder / "img_VS.pickle"
with open(img_vs_file, "wb") as handle:
    pickle.dump(img_VS, handle, protocol=pickle.HIGHEST_PROTOCOL)
print("Timestamps aligned and saved.", flush=True)
print("---STEP 3 FINISHED.---", "\n", flush=True)

# --- STEP 4: visual-stimulation structure --------------------------------------
print("---START STEP 4---", "\n", "Get vis-stim structure...", flush=True)
# NOTE: pickle.load can execute arbitrary code. Only load files this pipeline
# wrote itself.
with open(img_vs_file, "rb") as handle:
    img_VS = pickle.load(handle)
from cottage_analysis.stimulus_structure import sphere_structure as vis_stim_structure

stim_dict = vis_stim_structure.create_stim_dict(
    depth_list=depth_list,
    img_VS=img_VS,
    choose_trials=None,
)


In [None]:
# Peek at the first rows of img_VS.
img_VS.head()


In [None]:
# (rows, columns) of img_VS.
img_VS.shape


In [None]:
# (rows, columns) of the DLC tracking table.
dlc_res.shape
