In [None]:
import sleap

In [None]:
# Evaluation metrics of the centered-instance model on the validation split, plus
# the skeleton whose node order labels the per-landmark error columns below.
# (Available metric names can be listed with metrics.keys().)
model_dir = "../data/models/centered_instance"
skeleton_json = "../data/skeleton.json"
metrics = sleap.load_metrics(model_dir, split="val")
skeleton = sleap.skeleton.Skeleton.load_json(skeleton_json)
# Spatial calibration in pixels per millimetre (despite the name): pixel
# distances are divided by this value to obtain millimetres.
px_to_mm = 28.25

In [None]:
import pandas as pd 
# Per-landmark localization error on the validation set, converted from px to mm
# and melted to long format (columns "Landmark", "Error (mm)") for the FacetGrid.
dists = metrics["dist.dists"]
errors_mm_wide = pd.DataFrame(dists * (1 / px_to_mm), columns=skeleton.node_names)
res = errors_mm_wide.melt(
    value_vars=skeleton.node_names,
    var_name="Landmark",
    value_name="Error (mm)",
)


In [None]:

import seaborn as sns
import matplotlib.pyplot as plt

# Ridge plot: one row per landmark showing the KDE of its validation error (mm);
# rows are overlapped via a negative hspace at the end.
g = sns.FacetGrid(res, row="Landmark", hue="Landmark", aspect=7, height=0.6, palette="husl")
kde_kwargs = dict(clip=(0, 0.2), bw_adjust=0.1, fill=True, alpha=0.9, linewidth=0)
g.map(sns.kdeplot, "Error (mm)", **kde_kwargs)
g.map(plt.axhline, y=0, lw=1.0, c="k")


def label(x, color, label):
    """Write the facet's landmark name at the right edge of the current axes.

    Intended for ``FacetGrid.map``, which calls it once per facet with the facet's
    data column as ``x`` (unused here) and supplies ``color`` and ``label`` as
    keyword arguments. The axes background is made transparent so the
    overlapping rows remain visible.
    """
    facet_ax = plt.gca()
    facet_ax.patch.set_alpha(0)
    text_style = dict(fontweight="bold", color=color, ha="right", va="center")
    facet_ax.text(1, 0.3, label, transform=facet_ax.transAxes, **text_style)


g.map(label, "Landmark")

# Strip per-facet decorations; only the shared x-axis label remains.
g.set_titles("")
g.set(yticks=[])
g.despine(bottom=True, left=True)
g.fig.subplots_adjust(hspace=-0.5)
plt.xlabel("Error (mm)");

In [None]:
# 90th / 80th percentile of each landmark's validation error, in mm, and the same
# values converted back to pixels for drawing on a video frame. px_to_mm holds
# pixels per mm, so multiplying mm by it yields px. The calibration is reused
# from the config cell rather than repeating the 28.25 literal.
# Note: the resulting column keeps the name "Error (mm)" even in the *_px frames.
errors_by_landmark = res.groupby("Landmark")

errors_90th = errors_by_landmark.quantile(0.90)
errors_90th_px = errors_90th * px_to_mm

errors_80th = errors_by_landmark.quantile(0.80)
errors_80th_px = errors_80th * px_to_mm

In [None]:
# HDF5 analysis export read in the next cell (datasets "tracks" and "node_names").
# The name suggests camera 1, 2022-02-17, chunks 0 through 190 — presumably; confirm.
filename = "../data/cam1_20220217_0through190_cam1_20220217_0through190_1-tracked.analysis.h5"

In [None]:
import h5py
import numpy as np

# Read the tracked landmark positions and node names from the analysis file.
with h5py.File(filename, "r") as h5_file:
    dset_names = list(h5_file.keys())
    # Transposed so the frame index comes first; later cells index it as
    # locations[frame, node, coord, track] — presumably; confirm against the file.
    locations = h5_file["tracks"][:].T
    node_names = [raw_name.decode() for raw_name in h5_file["node_names"][:]]

# Echo a short summary of what was loaded.
summary_sections = (
    ("===filename===", filename),
    ("===HDF5 datasets===", dset_names),
    ("===locations data shape===", locations.shape),
)
for section_header, section_value in summary_sections:
    print(section_header)
    print(section_value)
    print()

print("===nodes===")
for node_idx, node_name in enumerate(node_names):
    print(f"{node_idx}: {node_name}")
print()

In [None]:
import cv2

# Frame shown in the overlay figure; used both to index the tracks and to seek the video.
frame_idx = 93932
# Landmark coordinates at that frame for the first track — presumably an
# (n_nodes, 2) array of xy positions (later cells treat pose[j] as an xy point); confirm.
pose = locations[frame_idx, :, :, 0]
# NOTE(review): absolute cluster path; consider deriving it from a configurable data dir.
videofile = "/Genomics/ayroleslab2/scott/long-timescale-behavior/data/organized_videos/20220217-lts-cam1/20220217-lts-cam1-0000.mp4"

cap = cv2.VideoCapture(videofile)
try:
    # Seek by 0-based frame index (CAP_PROP_POS_FRAMES, numerically 1).
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    success, img = cap.read()
finally:
    # Only the decoded frame is needed from here on; free the decoder/file handle.
    cap.release()

# Fail here with a clear message instead of later inside imshow(None).
if not success:
    raise RuntimeError(f"Could not read frame {frame_idx} from {videofile}")


In [None]:
%matplotlib inline

import pandas as pd
import numpy as np
from scipy.io import loadmat

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.style
import matplotlib as mpl
from matplotlib.patches import Circle
from matplotlib import patches

# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and 3.8
# removed the old names, so use whichever name this installation provides.
# The style is applied first so it cannot override the explicit rcParams below.
_deep_style = "seaborn-v0_8-deep" if "seaborn-v0_8-deep" in mpl.style.available else "seaborn-deep"
mpl.style.use(_deep_style)

mpl.rcParams["figure.facecolor"] = "w"
mpl.rcParams["figure.dpi"] = 150
mpl.rcParams["savefig.dpi"] = 600
mpl.rcParams["savefig.transparent"] = True
mpl.rcParams["font.size"] = 15
mpl.rcParams["font.family"] = "sans-serif"
mpl.rcParams["font.sans-serif"] = ["Arial", "DejaVu Sans"]
mpl.rcParams["axes.titlesize"] = "xx-large"  # medium, large, x-large, xx-large

In [None]:
# Overlay per-landmark error percentiles on one video frame. Each predicted
# landmark gets a dot, plus rings whose radius is its 90th / 80th percentile
# validation error in px. A scale bar is drawn in the corner.
px_per_mm = px_to_mm  # same calibration (28.25 px per mm) used for the error conversion
ctr = np.nanmean(pose, axis=0)  # mean (x, y) of the landmarks, ignoring NaN (missing) nodes
window = np.array([-50, 50])  # crop extent around ctr, in px
scale_bar_mm = 0.5

# errors_*_px come out of groupby sorted by landmark name. Select them in the
# HDF5 node order so entry j lines up with pose[j].
# - `.loc` raises KeyError if a node is missing, instead of silently shifting every
#   later ring onto the wrong landmark.
# - Selecting the single column yields scalar radii rather than 1-element arrays.
# (The column is still named "Error (mm)" but holds px after scaling.)
ring_specs = [
    ("90%", errors_90th_px.loc[node_names, "Error (mm)"].to_numpy()),
    ("80%", errors_80th_px.loc[node_names, "Error (mm)"].to_numpy()),
]

fig = plt.figure(figsize=(6, 6))
ax = fig.add_axes([0, 0, 1, 1], frameon=False)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.autoscale(tight=True)
# NOTE(review): cv2 decodes frames as BGR and cmap is ignored for 3-channel images,
# so a colour video would show swapped channels; fine if the footage is greyscale.
ax.imshow(img, cmap="gray")
ax.set_xlim(ctr[0] + window)
ax.set_ylim(ctr[1] + window[::-1])  # reversed so image row 0 stays at the top

colors = sns.color_palette("husl", n_colors=len(node_names))
# Landmark dots are drawn once (not once per ring set) so their alpha does not compound.
for xy, color in zip(pose, colors):
    ax.add_patch(Circle(xy=xy, radius=1, fill=True, lw=0, alpha=0.7, fc=color))
for ring_label, radii in ring_specs:
    for j, (xy, radius, color) in enumerate(zip(pose, radii, colors)):
        # Only the first ring of each set is labelled, giving one legend entry per percentile.
        ax.add_patch(
            Circle(
                xy=xy, radius=radius, fill=False, lw=1.5, alpha=0.7, ec=color,
                label=ring_label if j == 0 else None,
            )
        )

# Scale bar; the fmt string carries no colour so it cannot conflict with color=.
bar_x, bar_y = ctr[0] + 30, ctr[1] + 40
bar_len_px = px_per_mm * scale_bar_mm
ax.plot([bar_x, bar_x + bar_len_px], [bar_y, bar_y], "-", color="black", lw=10)
ax.text(bar_x + bar_len_px / 2, bar_y + 5, f"{scale_bar_mm} mm", fontweight="bold",
        color="black", ha="center", va="top", fontsize=18);


In [None]:
import sys

# Make the repository root importable so `analysis.utils` resolves. The path is
# relative to the kernel's working directory (assumed to be this notebook's folder).
# Guarded so re-running the cell does not append duplicate sys.path entries.
if "../" not in sys.path:
    sys.path.append("../")
import analysis.utils.trx_utils as trx_utils

In [None]:
import importlib

# Pick up edits to trx_utils without restarting the kernel.
trx_utils = importlib.reload(trx_utils)

# Run the raw tracks through trx_utils' processing steps in order (by name:
# missing-value filling, then median and Gaussian smoothing).
filtered_locations = locations
for processing_step in (
    trx_utils.fill_missing_np,
    trx_utils.smooth_median,
    trx_utils.smooth_gaussian,
):
    filtered_locations = processing_step(filtered_locations)


In [None]:
# 150-frame window to plot. 99.96 is presumably the recording frame rate
# (frames/s), making 2091 a time offset in seconds — confirm. An earlier
# window started at 2058 instead.
frame_rate = 99.96
start = int(2091 * frame_rate)
end = start + 150
trx_utils.plot_trx(
    filtered_locations,
    videofile,
    frame_start=start,
    frame_end=end,
    trail_length=100,
)