# Example session

Scratch notebook for prototyping eye-tracking analyses on one example session before turning them into reusable functions.

In [None]:
# Automatically reload edited modules (e.g. v1_depth_analysis) before executing code.
%load_ext autoreload
%autoreload 2

In [None]:
# select session
import pickle
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import flexiznam as flz
from v1_depth_analysis.config import PROJECT, EYE_TRACKING_SESSIONS
import v1_depth_analysis as vda


# Pick the example session: each EYE_TRACKING_SESSIONS entry starts with
# (mouse, session); `camera` names the camera whose tracking is analysed.
sess_num = 0
session_entry = EYE_TRACKING_SESSIONS[sess_num]
mouse, session = session_entry[0], session_entry[1]
camera = "right_eye"

# Get data

In [None]:
from v1_depth_analysis.eye_tracking import analyse_eye

# Load the eye-tracking results for the selected session. The `camera` chosen in the
# config cell is passed through (rather than repeating the literal) so both stay in sync.
# `sampling` is used below as frames per second (time = frame index / sampling).
data, sampling, dlc_ds, dlc_res, camera_ds = analyse_eye.get_eye_tracking_data(
    mouse, session, project=PROJECT, camera=camera, verbose=True
)

In [None]:
# Histogram (log-scaled counts) of the DLC likelihood of the "reflection" marker,
# restricted to frames flagged valid in `data`.
dlc_data = dlc_res.xs("likelihood", axis="columns", level=2).droplevel(
    "scorer", axis="columns"
)
# NOTE(review): the last 3 DLC rows are dropped, presumably so the array has the same
# length as `data.valid` — confirm why the two differ.
reflection_likelihood = dlc_data["reflection"].iloc[:-3].values[data.valid]
fig, ax = plt.subplots(figsize=(7, 3))
ax.hist(reflection_likelihood, bins=1000)
ax.set_xlabel("DLC likelihood")
ax.semilogy()

In [None]:
# Look for moments where the reflection is wrongly detected: colour every valid pupil
# position by a property of the reflection so that outliers stand out.
fig = plt.figure(figsize=(10, 3))

# Add a column with the distance of the reflection to its median position.
data["reflection_distance"] = np.sqrt(
    (data["reflection_x"] - data["reflection_x"].median()) ** 2
    + (data["reflection_y"] - data["reflection_y"].median()) ** 2
)
elli = data[data.valid]
reflection_props = [
    "reflection_likelihood",
    "reflection_x",
    "reflection_y",
    "reflection_distance",
]
for i, w in enumerate(reflection_props):
    ax = fig.add_subplot(1, len(reflection_props), 1 + i)
    sc = ax.scatter(elli.pupil_x, elli.pupil_y, c=elli[w], s=1)
    cb = fig.colorbar(ax=ax, mappable=sc)
    cb.set_label(w)
    ax.set_xlabel("Pupil X (pixels)")
    if i == 0:
        ax.set_ylabel("Pupil Y (pixels)")

In [None]:
# Scratch check: number of degrees in one radian (ellipse angles are converted with
# np.rad2deg below).
np.rad2deg(1)

# Actual analysis

In [None]:
fig = plt.figure(figsize=(15, 4))
elli = data[data.valid]


def _binned_mean(x, y, values, bins):
    """Mean of ``values`` in each 2D (x, y) histogram bin.

    Returns ``(mean, x_edges, y_edges)`` as from ``np.histogram2d``; bins that
    contain no sample are NaN.
    """
    count, x_edges, y_edges = np.histogram2d(x, y, bins=bins)
    total, _, _ = np.histogram2d(x, y, weights=values, bins=(x_edges, y_edges))
    with np.errstate(divide="ignore", invalid="ignore"):
        mean = total / count
    mean[count < 1] = np.nan
    return mean, x_edges, y_edges


# Panel 1: every valid frame at its pupil position, coloured by ellipse angle.
# The y axis is inverted to match image coordinates.
angles = np.rad2deg(elli.angle)
# NOTE(review): the lower colour limit uses the 0.1% quantile while the ratio panel
# uses 1% — confirm the asymmetry is intended.
anglim = np.quantile(angles, [0.001, 0.99])
ax = fig.add_subplot(1, 3, 1)
sc = ax.scatter(elli.pupil_x, elli.pupil_y, c=angles, vmax=anglim[1], vmin=anglim[0])
cb = fig.colorbar(ax=ax, mappable=sc)
cb.set_label("Ellipse angle (degrees)")
ax.set_xlabel("Ellipse pupil X (pixels)")
ax.set_ylabel("Ellipse pupil Y (pixels)")
ax.set_aspect("equal")
ax.invert_yaxis()

# Panels 2-3: per-bin means on a 70 x 70 grid. With imshow's default origin="upper",
# row 0 of the image (smallest y bin) is placed at the *top* extent value, so y is
# given as (by[-1], by[0]). The y tick labels then match the data, and y increases
# downwards as in panel 1.
ax = fig.add_subplot(1, 3, 2)
mean_angle, bx, by = _binned_mean(elli.pupil_x, elli.pupil_y, angles, bins=(70, 70))
img = ax.imshow(
    mean_angle.T,
    extent=(bx[0], bx[-1], by[-1], by[0]),
    vmax=anglim[1],
    vmin=anglim[0],
)
cb = fig.colorbar(mappable=img, ax=ax)
cb.set_label("Ellipse angle (degrees)")
ax.set_xlabel("Ellipse pupil X (pixels)")
ax.set_ylabel("Ellipse pupil Y (pixels)")

ax = fig.add_subplot(1, 3, 3)
ratio = elli.major_radius / elli.minor_radius
ralim = np.quantile(ratio, [0.01, 0.99])
# Reuse the panel-2 bin edges so both maps share the same grid.
mean_ratio, bx, by = _binned_mean(elli.pupil_x, elli.pupil_y, ratio, bins=(bx, by))
img = ax.imshow(
    mean_ratio.T,
    extent=(bx[0], bx[-1], by[-1], by[0]),
    vmin=ralim[0],
    vmax=ralim[1],
)
cb = fig.colorbar(mappable=img, ax=ax)
cb.set_label("Ellipse axes ratio")
ax.set_xlabel("Ellipse pupil X (pixels)")
ax.set_ylabel("Ellipse pupil Y (pixels)")

fig.subplots_adjust(wspace=0.4)

In [None]:
# Inspect which columns the eye-tracking DataFrame provides.
data.columns

In [None]:
# Time course of the pupil centre (x and y) on valid frames, minus its median.
# Rows 0-1: raw centre; rows 2-3: centre relative to the reflection.
# Rows 0 and 2 show the whole session; rows 1 and 3 show a zoomed window.
ZOOM_START_S = 3000  # start of the zoomed window (s)
ZOOM_DURATION_S = 60 * 2  # length of the zoomed window (s)
ZOOM_YLIM = (-15, 15)

fig, axes = plt.subplots(4, 1)
fig.set_size_inches((10, 10))

valid = data.valid
time = np.arange(len(data)) / sampling
reflection = data[[f"reflection_{w}" for w in ["x", "y"]]]
reflection.columns = ["x", "y"]
for iax, ax in enumerate(axes):
    relative = iax > 1
    ax.set_ylabel("Relative to reflection" if relative else "Raw")
    for w in ["x", "y"]:
        if relative:
            trace = (data[f"centre_{w}"] - reflection[w])[valid]
        else:
            trace = data[f"centre_{w}"][valid]
        ax.plot(time[valid], trace - np.nanmedian(trace), label=rf"$\Delta${w}", lw=1)
for i in [0, 2]:
    axes[i].set_xlim(time[0], time[-1])
for i in [1, 3]:
    axes[i].set_xlim(ZOOM_START_S, ZOOM_START_S + ZOOM_DURATION_S)
    axes[i].set_ylim(*ZOOM_YLIM)
axes[-1].legend(loc="upper right")
axes[-1].set_xlabel("Time (s)")

In [None]:
from cottage_analysis.eye_tracking import diagnostics
from cottage_analysis.eye_tracking import eye_model_fitting as emf
from cottage_analysis.eye_tracking import eye_io

# flexilims session for this project, used below to look up the DLC tracking datasets.
flm_sess = flz.get_flexilims_session(project_id=PROJECT)

# Bins need more than this many frames to be kept in `enough_frames`; the same cutoff
# is passed to the diagnostic plot.
min_frame_cutoff = 10
# Bin valid frames by pupil position on a 25 x 25 grid. Returns a per-bin table (with
# an `n_frames_in_bin` column) and, presumably, the x/y bin edges — confirm.
binned_ellipses, bedg_x, bedg_y = emf.bin_ellipse_by_position(
    data[data.valid], nbins=(25, 25)
)
# NOTE(review): `enough_frames` and `dlc_tracks` are not used anywhere below.
enough_frames = binned_ellipses[binned_ellipses.n_frames_in_bin > min_frame_cutoff]
dlc_tracks = eye_io.get_tracking_datasets(camera_ds, flexilims_session=flm_sess)
# Cropping stored on the DLC dataset; passed to the diagnostic plot.
cropping = dlc_ds.extra_attributes["cropping"]
# Output folder: the camera dataset's folder, relative to the raw root, re-rooted
# under the processed root.
processed = flz.get_data_root("processed", project=PROJECT)
raw = flz.get_data_root("raw", project=PROJECT)
save_folder = processed / camera_ds.path_full.parent.relative_to(raw)

# Diagnostic figure of binned ellipse parameters (angle, azimuth, elevation) by pupil
# position; `save_folder` is passed, presumably so the figure is saved there.
fig = diagnostics.plot_binned_ellipse_params(
    binned_ellipses,
    binned_ellipses["n_frames_in_bin"],
    save_folder=save_folder,
    min_frame_cutoff=min_frame_cutoff,
    fig_title=camera_ds.full_name,
    camera_ds=camera_ds,
    cropping=cropping,
    bin_edges_y=bedg_y,
    bin_edges_x=bedg_x,
    col2plot=["angle", "azimuth", "elevation"],
)

In [None]:
# Re-check the available columns (e.g. that azimuth/elevation exist before they are
# used below).
data.columns

In [None]:
# Add the angular speed of gaze: the norm of the frame-to-frame change in
# (azimuth, elevation), times the sampling rate. The first frame has no predecessor
# and stays NaN.
# NOTE(review): units follow azimuth/elevation (deg/s if those are in degrees) — confirm.
# A positional diff is used rather than label-based `data.loc[1:, ...]`, which only
# lines up when the index is a 0-based RangeIndex.
gaze_change = data[["azimuth", "elevation"]].diff().to_numpy()
data["angular_velocity"] = np.sqrt(np.sum(gaze_change ** 2, axis=1)) * sampling

fig, ax = plt.subplots()
ax.hist(data.angular_velocity[data.valid], bins=100)
ax.set_ylim(0, 100)
ax.set_xlabel("Angular velocity (deg/s)")
ax.set_ylabel("Number of frames");

In [None]:
# Small summary figure for one session:
# - distributions of azimuth and elevation
# - median eye position at each depth (error bars: std / sqrt(n), i.e. SEM)
# - distribution of angular velocity at each depth
from cottage_analysis.plotting.basic_vis_plots import get_depth_color

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle(f"{mouse} {session} {camera}")

# Azimuth and elevation distributions
for i, w in enumerate(["azimuth", "elevation"]):
    ax = axes[0, i]
    ax.hist(data[w], bins=100, color=[0.1, 0.1, 0.1])
    ax.set_xlabel(w)
    ax.set_ylabel("Count")
    ax.set_title(f"{w} distribution")

# Median eye position by depth. Only the gaze columns are aggregated: .median() on the
# whole frame raises on non-numeric columns in pandas >= 2.0.
depths = data.depth.dropna().unique()
gaze_by_depth = data.groupby("depth")[["azimuth", "elevation"]]
avg_by_depth = gaze_by_depth.median()
std_by_depth = gaze_by_depth.std()
# Non-NaN samples per depth and column, so frames without a gaze estimate do not
# shrink the SEM.
n_by_depth = gaze_by_depth.count()

# `depth` (not `d`) as loop variable so the global DataFrame `d` is not overwritten.
ax = axes[1, 0]
for depth in sorted(depths):
    color = get_depth_color(depth, depths)
    az = avg_by_depth.loc[depth, "azimuth"]
    el = avg_by_depth.loc[depth, "elevation"]
    az_sem = std_by_depth.loc[depth, "azimuth"] / np.sqrt(n_by_depth.loc[depth, "azimuth"])
    el_sem = std_by_depth.loc[depth, "elevation"] / np.sqrt(
        n_by_depth.loc[depth, "elevation"]
    )
    # Mark this depth's median on the distribution panels above.
    axes[0, 0].axvline(az, color=color, linestyle="--", alpha=0.7, lw=0.5)
    axes[0, 1].axvline(el, color=color, linestyle="--", alpha=0.7, lw=0.5)
    ax.errorbar(
        az,
        el,
        xerr=az_sem,
        yerr=el_sem,
        color=color,
        marker="o",
        label=f"{depth:.2f}m",
        lw=2,
    )
ax.set_xlabel("Azimuth")
ax.set_ylabel("Elevation")
ax.set_aspect("equal")
ax.set_title("Median eye position by depth")
ax.legend()

# Angular velocity by depth; labels use 2 decimals so sub-metre depths stay distinct.
ax = axes[1, 1]
for depth in sorted(depths):
    speed = data[data.depth == depth].angular_velocity
    ax.hist(
        speed,
        bins=np.arange(-5, 201, 15),
        color=get_depth_color(depth, depths),
        alpha=0.5,
        label=f"{depth:.2f}m",
        histtype="step",
        lw=2,
        density=True,
    )
ax.set_xlabel("Angular velocity (deg/s)")
ax.set_ylabel("Density")
ax.set_title("Angular velocity by depth")
ax.legend()

In [None]:
# Folder passed to the binned-ellipse diagnostic plot (processed-data mirror of the
# camera dataset folder).
save_folder

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Draft figure: (1) azimuth of each pupil-position bin overlaid on an eye image and
# (2) an example time course of pupil position with its whole-session distribution.
# NOTE(review): `ref`, `fitted_model`, `eye_centre`, `azimuth`, `elevation` and `gray`
# are not defined anywhere in this notebook, so this cell fails on a fresh kernel.
# Create them (e.g. from the eye-model fit) in a cell above.
fig = plt.figure(figsize=(17.7165, 7.87402))  # 450 x 200 mm

# Eye tracking video example
ax = plt.subplot2grid((5, 6), (0, 0), rowspan=2, colspan=2)
v = binned_ellipses[["pupil_x", "pupil_y"]].values
# Bounds of the binned pupil positions, shifted by `ref` (presumably the crop offset
# into full-frame coordinates — confirm).
lims = np.vstack([np.nanmin(v, axis=0), np.nanmax(v, axis=0)]) + ref
circ_coord = fitted_model.predict_xy(np.arange(0, 2 * np.pi, 0.1)) + ref.reshape(1, 2)
# Grid indexed by the two index levels of `binned_ellipses`, holding
# [azimuth, elevation]; bins without data stay NaN.
mat = np.full(
    (len(binned_ellipses.index.levels[0]), len(binned_ellipses.index.levels[1]), 2),
    np.nan,
)
for i_pos, (pos, _) in enumerate(binned_ellipses.iterrows()):
    mat[pos[0], pos[1]] = [azimuth[i_pos], elevation[i_pos]]

ax.imshow(gray, cmap="gray", vmin=20, vmax=100, zorder=-1)
img = ax.imshow(
    np.rad2deg(mat[..., 0]),
    extent=np.hstack([lims[:, 0], lims[::-1, 1]]),
    cmap="RdBu_r",
    vmin=-20,
    vmax=20,
    zorder=10,
)
ax.plot(
    circ_coord[:, 0],
    circ_coord[:, 1],
    label="Reprojection",
    color="lightblue",
    alpha=0.5,
    zorder=5,
)
pupil_c = np.array(fitted_model.params[:2])
ax.plot(*(eye_centre + ref), marker="o", ms=5, mfc="k", mec="none", zorder=1)
ax.plot(*(pupil_c + ref), marker="o", ms=5, mfc="none", color="lightblue", alpha=0.5)
ax.plot(
    *[(np.array([eye_centre[i], pupil_c[i]]) + ref[i]) for i in range(2)],
    color="lightblue",
    zorder=2,
    alpha=0.5
)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cb = fig.colorbar(img, cax=cax)
cb.set_label("Azimuth (degrees)")
ax.set_xlim([gray.shape[1], 0])
ax.set_ylim([gray.shape[0] - 80, 50])
ax.text(x=0.55, y=0.8, s="Dorsal", color="white", transform=ax.transAxes)
ax.text(x=0.65, y=0.1, s="Nasal", color="white", transform=ax.transAxes)
ax.axis("off")

# Example time course of pupil position relative to its session median.
ax_timecourse = plt.subplot2grid((5, 6), (0, 2), rowspan=1, colspan=2)
n = 6000  # window length (frames)
b = 16000  # first frame of the window
colors = ["purple", "orange"]
x_window = data.pupil_x.iloc[b : b + n] - np.nanmedian(data.pupil_x)
y_window = data.pupil_y.iloc[b : b + n] - np.nanmedian(data.pupil_y)
# Size the time axis from the actual window so that a session shorter than b + n
# frames does not cause an x/y length mismatch.
time = np.arange(len(x_window)) / sampling
ax_timecourse.plot(time, x_window, color=colors[0], label="x")
ax_timecourse.plot(time, y_window, color=colors[1], label="y")
ax_timecourse.set_xlabel("Time (s)")
# pupil_x / pupil_y are in pixels, not degrees.
ax_timecourse.set_ylabel(r"$\Delta$position (pixels)")
ax_timecourse.legend(loc="upper right")
ax_timecourse.set_xlim(0, time.max())
ax_timecourse.set_xticks(np.arange(0, time.max(), 60))
ax_timecourse.spines["top"].set_visible(False)
ax_timecourse.spines["right"].set_visible(False)

# Whole-session distribution of the same quantities, sharing the y axis scale.
divider = make_axes_locatable(ax_timecourse)
ax = divider.append_axes("right", size="10%", pad=0.05)
bins = np.arange(-25, 25)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.hist(
    data.pupil_x - np.nanmedian(data.pupil_x),
    orientation="horizontal",
    density=False,
    bins=bins,
    color=colors[0],
    histtype="step",
    lw=2,
)
ax.hist(
    data.pupil_y - np.nanmedian(data.pupil_y),
    orientation="horizontal",
    density=False,
    bins=bins,
    color=colors[1],
    histtype="step",
    lw=2,
)
ax.set_xlabel("# frames")
ax.set_yticks([])

# Depth analysis

In [None]:
import matplotlib as mpl
from matplotlib import cm

# Per-frame pupil displacement (pixels) and distance to the session-median position.
data["dx"] = data.pupil_x.diff()
data["dy"] = data.pupil_y.diff()
# Distance (pixels) of the pupil to its median position, as named in the plot label
# (this was a constant-0 placeholder).
data["d_med"] = np.sqrt(
    (data.pupil_x - data.pupil_x.median()) ** 2
    + (data.pupil_y - data.pupil_y.median()) ** 2
)

# One colour per depth, log-spaced on the reversed `cool` colormap. NaN depths are
# excluded so they neither get a colour nor enter the log normalisation.
depth_list = np.sort(data.depth.dropna().unique())
cmap = cm.cool.reversed()
norm = mpl.colors.Normalize(vmin=np.log(depth_list.min()), vmax=np.log(depth_list.max()))
line_colors = []
col_dict = dict()
for depth in depth_list:
    rgba_color = tuple(it / 255 for it in cmap(norm(np.log(depth)), bytes=True))
    line_colors.append(rgba_color)
    col_dict[depth] = rgba_color

# NOTE(review): `sns` (seaborn) is used but never imported in this notebook; add
# `import seaborn as sns` to the imports cell.
# NOTE(review): dx/dy are frame-to-frame differences, not positions; if positions
# relative to the median were intended, compute them with `- median()` instead.
fig, axes = plt.subplots(3, 1)
fig.set_size_inches(6, 10)
labels = [
    r"$\Delta$X per frame (pixels)",
    r"$\Delta$Y per frame (pixels)",
    "Distance to median position (pixels)",
]
d = data[(~np.isnan(data.dx)) & (~np.isnan(data.depth))]
for iw, w in enumerate(["dx", "dy", "d_med"]):
    sns.violinplot(data=d, x="depth", y=w, palette=line_colors, ax=axes[iw])
    axes[iw].set_ylabel(labels[iw])

In [None]:
import matplotlib as mpl
from matplotlib import cm

SACCADE_THRESHOLD = 5  # `mvt` above this marks a saccade frame
SMALL_MVT_MAX = 2  # `mvt` below this is shown in the "Small movements" panel

# One colour per depth, log-spaced on the reversed `cool` colormap (NaN depths excluded).
depth_list = np.sort(data.depth.dropna().unique())
cmap = cm.cool.reversed()
norm = mpl.colors.Normalize(vmin=np.log(depth_list.min()), vmax=np.log(depth_list.max()))
line_colors = []
col_dict = dict()
for depth in depth_list:
    rgba_color = tuple(it / 255 for it in cmap(norm(np.log(depth)), bytes=True))
    line_colors.append(rgba_color)
    col_dict[depth] = rgba_color

# NOTE(review): `sns` (seaborn) is used but never imported in this notebook, and the
# `mvt` column is not created here (presumably provided by get_eye_tracking_data).
fig, axes = plt.subplots(3, 1)
fig.set_size_inches(5, 7)
labels = ["Motion", "Small movements", "Saccades"]
data["saccade"] = data.mvt > SACCADE_THRESHOLD
# A saccade can span several frames: count only its first frame so the bar plot shows
# saccades (not saccade frames) per second.
data["saccade_onset"] = data.saccade & ~data.saccade.shift(1, fill_value=False)
d = data[(~np.isnan(data.dx)) & (~np.isnan(data.depth))]
sns.violinplot(data=d, x="depth", y="mvt", palette=line_colors, ax=axes[0])
sns.violinplot(
    data=d[d.mvt < SMALL_MVT_MAX], x="depth", y="mvt", palette=line_colors, ax=axes[1]
)
axes[0].set_ylabel(labels[0])
axes[1].set_ylabel(labels[1])

by_depth = d.groupby("depth")
sac_per_depth = by_depth.saccade_onset.sum()
sample_per_depth = by_depth.size()

# Onsets per frame times frames per second = saccades per second.
axes[2].bar(
    x=np.arange(len(sac_per_depth)),
    height=sac_per_depth / sample_per_depth * sampling,
    color=[col_dict[depth] for depth in sac_per_depth.index],
)
axes[2].set_xticks(np.arange(len(sac_per_depth)))
axes[2].set_xticklabels(sac_per_depth.index)
axes[2].set_ylabel("Saccades per second")

In [None]:
import matplotlib as mpl
from matplotlib import cm

RUNNING_BIN_WIDTH = 10  # bin width, in the units of `data.rs`
SACCADE_THRESHOLD = 5  # `mvt` above this marks a saccade frame
SMALL_MVT_MAX = 2  # `mvt` below this is shown in the "Small movements" panel

# Eye movements as a function of running speed `rs`, rounded to the nearest multiple
# of RUNNING_BIN_WIDTH.
data["running_bin"] = np.round(data.rs / RUNNING_BIN_WIDTH) * RUNNING_BIN_WIDTH

# NOTE(review): `sns` (seaborn) is used but never imported in this notebook, and the
# `mvt` column is not created here (presumably provided by get_eye_tracking_data).
fig, axes = plt.subplots(3, 1)
fig.set_size_inches(5, 7)
labels = ["Motion", "Small movements", "Saccades"]
data["saccade"] = data.mvt > SACCADE_THRESHOLD
# A saccade can span several frames: count only its first frame so the bar plot shows
# saccades (not saccade frames) per second.
data["saccade_onset"] = data.saccade & ~data.saccade.shift(1, fill_value=False)
d = data[(~np.isnan(data.dx)) & (~np.isnan(data.depth))]
sns.violinplot(data=d, x="running_bin", y="mvt", palette="viridis", ax=axes[0])
sns.violinplot(
    data=d[d.mvt < SMALL_MVT_MAX], x="running_bin", y="mvt", palette="viridis", ax=axes[1]
)
axes[0].set_ylabel(labels[0])
axes[1].set_ylabel(labels[1])

by_bin = d.groupby("running_bin")
sac_per_bin = by_bin.saccade_onset.sum()
sample_per_bin = by_bin.size()

# Colour each bar by its running bin on viridis. Colours are built from the bins that
# are actually plotted so that they line up one-to-one with the bars.
rs_norm = mpl.colors.Normalize(vmin=sac_per_bin.index.min(), vmax=sac_per_bin.index.max())
rs_colors = [
    tuple(it / 255 for it in cm.viridis(rs_norm(rb), bytes=True))
    for rb in sac_per_bin.index
]

axes[2].bar(
    x=np.arange(len(sac_per_bin)),
    height=sac_per_bin / sample_per_bin * sampling,
    color=rs_colors,
)
axes[2].set_xticks(np.arange(len(sac_per_bin)))
axes[2].set_xticklabels(np.array(sac_per_bin.index, dtype=int))
axes[2].set_ylabel("Saccades per second")
axes[2].set_xlabel("Running speed bin")

In [None]:
from matplotlib.colors import TwoSlopeNorm
from scipy.stats import mannwhitneyu

# Pairwise two-sided Mann-Whitney U tests between depths, for each property of `d`.
# The diagonal (a depth against itself) is left NaN rather than 0, which would be
# drawn as highly significant.
# NOTE(review): p-values are not corrected for multiple comparisons.
depth_list = np.unique(d.depth)
props = ["dx", "dy", "d_med"]
pval_mat = np.full([len(props)] + [len(depth_list)] * 2, np.nan)

frames_by_depth = {depth: d[d.depth == depth] for depth in depth_list}
for ix, depth_x in enumerate(depth_list):
    # The two-sided p-value does not depend on argument order: test each pair once
    # and mirror it.
    for iy in range(ix + 1, len(depth_list)):
        depth_y = depth_list[iy]
        for ip, p in enumerate(props):
            pval = mannwhitneyu(
                frames_by_depth[depth_x][p].values, frames_by_depth[depth_y][p].values
            ).pvalue
            pval_mat[ip, ix, iy] = pval_mat[ip, iy, ix] = pval

# Diverging colours centred on p = 0.05 (red below, blue above).
pval_norm = TwoSlopeNorm(vmin=0, vcenter=0.05, vmax=1)
tick_labels = [f"{depth:g}" for depth in depth_list]
ticks = np.arange(len(depth_list))
fig, axes = plt.subplots(1, 3)
for ip, p in enumerate(props):
    img = axes[ip].imshow(pval_mat[ip], cmap="RdBu", norm=pval_norm, origin="lower")
    axes[ip].set_title(p)
    axes[ip].set_xticks(ticks)
    axes[ip].set_xticklabels(tick_labels, rotation=90)
    axes[ip].set_yticks(ticks)
    axes[ip].set_yticklabels(tick_labels)
    axes[ip].set_xlabel("Depth")
axes[0].set_ylabel("Depth")
fig.colorbar(img, ax=axes, label="Mann-Whitney p-value")

In [None]:
# Per-depth summary of frame-to-frame pupil displacement (dx, dy):
# left = median +/- std, right = mean +/- SEM.
fig, axes = plt.subplots(1, 2)
# The loop variable is `depth`, not `d`, so the filtered DataFrame `d` used by the
# statistics cells is not overwritten with a float.
for depth, ddf in data.groupby("depth"):
    axes[0].errorbar(
        x=np.nanmedian(ddf.dx),
        y=np.nanmedian(ddf.dy),
        xerr=np.nanstd(ddf.dx),
        yerr=np.nanstd(ddf.dy),
        # ":g" keeps fractional depths distinct (int() would truncate them).
        label=f"{depth:g}",
        marker="o",
        color=col_dict[depth],
    )
    axes[1].errorbar(
        x=np.nanmean(ddf.dx),
        y=np.nanmean(ddf.dy),
        xerr=np.nanstd(ddf.dx) / np.sqrt(np.sum(~np.isnan(ddf.dx))),
        yerr=np.nanstd(ddf.dy) / np.sqrt(np.sum(~np.isnan(ddf.dy))),
        label=f"{depth:g}",
        marker=".",
        lw=3,
        color=col_dict[depth],
    )

for ax in axes:
    ax.set_aspect("equal")
    ax.set_xlabel("dx (pixels)")
axes[0].set_ylabel("dy (pixels)")
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
axes[0].set_title("Displacement (median +/- std)")
axes[1].set_title("Displacement (mean +/- SEM)")

In [None]:
# Align each img_VS row with the most recent mouse-position log entry at or before
# its HarpTime.
# NOTE(review): `img_VS` and `mousez_logger` are not defined in this notebook (this
# block was copied from a processing pipeline). The cell is not idempotent: running
# it twice merges twice and divides by 100 twice.
img_VS = pd.merge_asof(
    img_VS,
    mousez_logger,
    on="HarpTime",
    allow_exact_matches=True,
    direction="backward",
)

distance_columns = ["EyeZ", "MouseZ", "Depth", "Z0"]
img_VS[distance_columns] = img_VS[distance_columns] / 100  # Convert cm to m

# Sorted depths (m), rounded to 2 decimals, without NaN or the -99.99 sentinel.
# Filtering, instead of list.remove, does not raise when the sentinel is absent.
depth_list = np.round(img_VS["Depth"].unique(), 2)
depth_list = sorted(
    depth for depth in depth_list[~np.isnan(depth_list)].tolist() if depth != -99.99
)

In [None]:
import pickle

from cottage_analysis.stimulus_structure import sphere_structure as vis_stim_structure

# Save the aligned img_VS table.
# NOTE(review): `protocol_folder` is not defined in this notebook.
with open(protocol_folder / "img_VS.pickle", "wb") as handle:
    pickle.dump(img_VS, handle, protocol=pickle.HIGHEST_PROTOCOL)
print("Timestamps aligned and saved.", flush=True)
print("---STEP 3 FINISHED.---", "\n", flush=True)

# -----STEP4: Get the visual stimulation structure and Save (find the imaging frames for visual stimulation)-----
print("---START STEP 4---", "\n", "Get vis-stim structure...", flush=True)
# The in-memory `img_VS` is exactly what was just written, so it is used directly
# instead of reading the pickle straight back from disk.
stim_dict = vis_stim_structure.create_stim_dict(
    depth_list=depth_list, img_VS=img_VS, choose_trials=None
)

In [None]:
# Preview the first rows of the aligned table.
img_VS.head()

In [None]:
# (rows, columns) of the aligned table.
img_VS.shape

In [None]:
# Shape of the DLC output, e.g. to compare its frame count with `data` and `img_VS`.
dlc_res.shape