In [11]:
%load_ext autoreload
%autoreload 2
from matplotlib import animation
import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import os
import base
import anz
import viz
import plot_tools
# Global figure defaults: white figure background and 16 pt text.
plt.rcParams.update({'figure.facecolor': 'white', 'font.size': 16})

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
from pathlib import Path

# The project root is the parent of the notebook's working directory;
# figures are written beneath <project>/_FIGURES (kept as a str path).
path = Path.cwd()
project_path = path.parent.absolute()
figure_path = str(project_path / '_FIGURES')

In [13]:
# Body parts used by the different analyses below.
BPS_TRACK_LOCATION = ['r2_in', 'r3_in']
BPS_TRACK_PAW = ['r1_in', 'r1_out',
                 'r2_in', 'r2_out',
                 'r3_in', 'r3_out',
                 'r4_in', 'r4_out']
BPS_TRACK_PELLET = ['pellet']

mouse = 'M9'
date = '2021.03.11'
# Root of the analyzed-data tree. Set the ANALYZED_ROOT environment variable
# to point elsewhere instead of editing this machine-specific default.
ANALYZED_ROOT = os.environ.get('ANALYZED_ROOT',
                               r'C:\Users\Peter\Desktop\ANALYZED')
analyzed_dir = os.path.join(ANALYZED_ROOT, mouse, date)
scheme = [['r1_in', 'r1_out'],
          ['r2_in', 'r2_out'],
          ['r3_in', 'r3_out'],
          ['r4_in', 'r4_out'],
          ['r1_in', 'r2_in', 'r3_in', 'r4_in'],
          ['pellet'],
          ['insured pellet']]
# NOTE(review): save_dir (<cwd>/_FIGURES) is not used in this notebook and
# differs from figure_path (<project>/_FIGURES), where animations are saved.
save_dir = os.path.join(os.getcwd(), '_FIGURES')

# Load pose markers (POSE_2D) and behaviour labels (LABELS).
# np.unique returns the body-part names sorted and de-duplicated; bp_dict maps
# each name to its row index in the marker arrays.
bps_to_include = np.unique([x for y in scheme for x in y])
bp_dict = {bp: i for i, bp in enumerate(bps_to_include)}
marker_xys_per_video = anz._get_markers(
    os.path.join(analyzed_dir, 'POSE_2D'),
    bps_to_include,
    base.CAMERA_NAMES)
label_regions_per_video, label_names = anz._get_labels(
    os.path.join(analyzed_dir, 'LABELS'))

In [14]:
# Estimate the pellet's resting 3-D position, averaged over videos.
# x is read from camera 0 (marker_xys[0]); y and z from camera 1
# (marker_xys[1]). Each coordinate is filtered with the confidence score
# (column 2) of the camera it came from. This matches how the trajectory
# cell pairs scores_x / scores_y with its coordinates.
pellet_ix = bp_dict['pellet']
xs = [marker_xys[0][pellet_ix, :, 0] for marker_xys in marker_xys_per_video]
ys = [marker_xys[1][pellet_ix, :, 0] for marker_xys in marker_xys_per_video]
zs = [marker_xys[1][pellet_ix, :, 1] for marker_xys in marker_xys_per_video]
scores_cam0 = [marker_xys[0][pellet_ix, :, 2]
               for marker_xys in marker_xys_per_video]
scores_cam1 = [marker_xys[1][pellet_ix, :, 2]
               for marker_xys in marker_xys_per_video]
ss = scores_cam0  # kept for backward compatibility with the old name
pellet_x = anz.get_pellet_location(xs, scores_cam0, plot=False).mean()
pellet_y = anz.get_pellet_location(ys, scores_cam1, plot=False).mean()
pellet_z = anz.get_pellet_location(zs, scores_cam1, plot=False).mean()

In [15]:
from scipy.ndimage import median_filter

# Extract one smoothed, upsampled 3-D paw trajectory for every grab whose
# outcome is SNATCHED, across all videos of the session.
grab_bps_ix = [bp_dict[bp] for bp in BPS_TRACK_LOCATION]
x_grab_offset = 0
y_grab_offset = 0
z_grab_offset = 0

trajs = []     # one [xtraj, ytraj, ztraj, video_ix] entry per kept grab
grab_ixs = []  # grab onset of each entry, in upsampled sample units
upsample = 5   # spline samples per tracked frame
for video_ix, (label_regions, marker_xys) in enumerate(
        zip(label_regions_per_video, marker_xys_per_video)):
    dropped_regions = label_regions['dropped']
    grabbed_regions = label_regions['grab']
    chew_regions = label_regions['chew']
    # a chew region only counts if it spans more than 50 frames
    chew_regions = [x for x in chew_regions if (x[1] - x[0]) > 50]
    # some insurance pellets drop at the start of a trial; ignore drops that
    # begin within the first 30 frames
    dropped_regions = [x for x in dropped_regions if x[0] > 30]
    trial_outcome = anz.outcome_truth_table(dropped_regions,
                                            chew_regions,
                                            grabbed_regions)
    grab_outcomes = anz.grab_truth_table(trial_outcome,
                                         grabbed_regions,
                                         chew_regions,
                                         dropped_regions)

    grabs, extends, both = anz.anneal_labels(label_regions['extend'],
                                             label_regions['grab'],
                                             window=10)

    # x and its scores come from camera 0; y, z and their scores from camera 1
    scores_x = marker_xys[0][grab_bps_ix, :, 2]
    scores_y = marker_xys[1][grab_bps_ix, :, 2]
    xp = marker_xys[0][grab_bps_ix, :, 0] + x_grab_offset
    yp = marker_xys[1][grab_bps_ix, :, 0] + y_grab_offset
    zp = marker_xys[1][grab_bps_ix, :, 1] + z_grab_offset
    # Mean over the tracked body parts. This does not depend on the grab, so
    # it is computed once per video instead of once per grab.
    x_mean = np.mean(xp, axis=0)
    y_mean = np.mean(yp, axis=0)
    z_mean = np.mean(zp, axis=0)

    for grab, extend, grab_extend, grab_outcome in zip(grabs, extends, both,
                                                       grab_outcomes):
        if grab_outcome != base.GRABTYPES.SNATCHED:
            continue

        ixs_x = anz._marker_label_intersects(scores_x,
                                             grab_extend,
                                             base.SCORE_THRESHOLD,
                                             base.CRITERIA_CONTIGUOUS_FRAMES)
        ixs_yz = anz._marker_label_intersects(scores_y,
                                              grab_extend,
                                              base.SCORE_THRESHOLD,
                                              base.CRITERIA_CONTIGUOUS_FRAMES)
        if len(ixs_x) == 0 or len(ixs_yz) == 0:
            # Nothing passed the score criteria for this grab. Indexing [0]
            # below would raise IndexError, so skip the grab.
            continue

        # Linear interpolation. Pass copies in case contiguous_interp
        # modifies its input, because the per-video means are reused
        # across grabs.
        xtraj = anz.contiguous_interp(ixs_x, x_mean.copy())
        ytraj = anz.contiguous_interp(ixs_yz, y_mean.copy())
        ztraj = anz.contiguous_interp(ixs_yz, z_mean.copy())

        # median filter
        xtraj = median_filter(xtraj, size=5)
        ytraj = median_filter(ytraj, size=5)
        ztraj = median_filter(ztraj, size=5)

        # keep only the overlap of the x span and the y/z span
        s = max(ixs_x[0], ixs_yz[0])
        e = min(ixs_x[-1], ixs_yz[-1])
        # NOTE(review): [s:e] excludes frame e itself — confirm intended.
        xtraj, ytraj, ztraj = xtraj[s:e], ytraj[s:e], ztraj[s:e]
        offset = s - extend[0]
        assert offset >= 0, (
            f'trajectory starts at frame {s}, before extend onset '
            f'{extend[0]}')

        # spline-smooth and upsample to `upsample` points per frame
        xtraj, ytraj, ztraj = anz.interpolate_polyline(
            np.array([xtraj, ytraj, ztraj]).T,
            len(xtraj) * upsample,
            s=10)

        trajs.append([xtraj, ytraj, ztraj, video_ix])
        # Grab onset relative to the first kept frame, in upsampled units.
        # NOTE(review): the +4 frame shift is not explained here; TODO document.
        grab_ixs.append(upsample * (4 + grab[0] - extend[0] - offset))

# Rotating 3-D view of the extracted trajectories: each is red up to the grab
# onset and green afterwards. Saved as an mp4 under figure_path.
azimuth = -60
elevation = -170
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection='3d')
# use the same view angles the rotation below starts from
viz.threedimstyle(fig, ax, elevation=elevation, azimuth=azimuth)

first_traj, last_traj = 0, 40  # plot trajectories [first_traj, last_traj)
ax.plot3D(pellet_x, pellet_y, pellet_z, 'o', alpha=0.2, markersize=100,
          color='cyan')
lines = []  # every plotted trajectory artist (empty if there are no trajs)
for ix, traj in zip(grab_ixs[first_traj:last_traj],
                    trajs[first_traj:last_traj]):
    # NOTE(review): ix can be <= 0 if the grab starts before the first kept
    # frame, in which case these slices wrap around — confirm.
    lines += ax.plot3D(traj[0][:ix], traj[1][:ix], traj[2][:ix], c='r',
                       label=traj[3])
    lines += ax.plot3D(traj[0][ix-1:], traj[1][ix-1:], traj[2][ix-1:], c='g')


def rotate_view(f):
    """Set the view for frame `f`: azimuth decreases by one degree per frame."""
    ax.view_init(elevation, azimuth - f)
    return lines


anim = animation.FuncAnimation(fig,
                               rotate_view,
                               frames=360)
sp = os.path.join(figure_path, 'reach_dynamics', 'still_animation.mp4')
os.makedirs(os.path.dirname(sp), exist_ok=True)  # save fails if dir missing
anim.save(sp,
          fps=20,
          extra_args=['-vcodec', 'libx264'])

In [9]:
# For each selected trajectory, build the list of `trailing`-point windows that
# the comet animation draws, one window per animation frame. Shorter
# trajectories are edge-padded to the longest so all have the same frame count.
i, j = 0, 40
selected_trajs = trajs[i:j]
maxlen = np.max([len(traj[0]) for traj in selected_trajs])
trailing = 6 * upsample


def _trailing_windows(xyz, length, width):
    """Edge-pad an (N, 3) array to `length` rows; return windows starting 0..length-width-1."""
    padded = np.pad(xyz, ((0, length - xyz.shape[0]), (0, 0)), 'edge')
    return [padded[start:start + width]
            for start in range(padded.shape[0] - width)]


plot_points_per_line = [
    _trailing_windows(np.array([traj[0], traj[1], traj[2]]).T, maxlen, trailing)
    for traj in selected_trajs
]

In [10]:
# Animated "comet" view: each trajectory is drawn as a fading trail of
# `trailing` segments. A segment is green once it starts at or after that
# trajectory's grab index, and red before it.
azimuth = -60
elevation = -170

# NOTE: plt.style.use is global and also affects later figures in this kernel.
plt.style.use('dark_background')
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection='3d')
viz.black_theme(fig, ax)
viz.threedimstyle(fig, ax, elevation=elevation, azimuth=azimuth)
ax.set_xlim(160, 230)
ax.set_ylim(140, 250)
ax.set_zlim(110, 150)
# presumably the tracking frame rate in Hz; the title converts frames to ms
frame_rate = 200

# NOTE(review): the pellet is drawn 10 units lower in z here than in the
# rotating still animation — confirm which is intended.
ax.plot3D(pellet_x, pellet_y, pellet_z-10, 'o', alpha=0.2, markersize=100,
          color='cyan')
lines = []  # lines[k][seg]: segment `seg` of trajectory k's trail
for plot_points in plot_points_per_line:
    first_window = plot_points[0]
    n_points = first_window.shape[0]
    little_lines = []
    for seg in range(n_points - 1):
        # older segments are more transparent, giving a fading trail
        line, = ax.plot3D(first_window[seg:seg+2, 0],
                          first_window[seg:seg+2, 1],
                          first_window[seg:seg+2, 2],
                          alpha=seg/n_points,
                          c='r')
        little_lines.append(line)
    lines.append(little_lines)

grab_markers_drawn = set()  # trajectories whose grab marker is already plotted


def animate_trails(f):
    """Draw frame `f`: move every trail to window `f` and recolour its segments.

    Colours are recomputed every call, so repeated or out-of-order calls
    (FuncAnimation calls frame 0 twice) never leave stale green segments.
    """
    if f % 10 == 0:
        print(f)
    # NOTE(review): grab_ixs is not sliced like trajs[i:j] when the windows are
    # built; the pairing below is only correct while that slice starts at 0.
    for traj_ix, (plot_points, grab_ix, little_lines) in enumerate(
            zip(plot_points_per_line, grab_ixs, lines)):
        window = plot_points[f]
        for seg, line in enumerate(little_lines):
            line.set_data(window[seg:seg+2, 0],
                          window[seg:seg+2, 1])
            line.set_3d_properties(window[seg:seg+2, 2])
            line.set_color('g' if f + seg >= grab_ix else 'r')
        if f + trailing == grab_ix and traj_ix not in grab_markers_drawn:
            # mark the grab point once, when the trail's head reaches it
            grab_markers_drawn.add(traj_ix)
            ax.scatter(window[-1, 0],
                       window[-1, 1],
                       window[-1, 2],
                       s=5,
                       facecolors='none',
                       edgecolors='g')
    ax.set_title(f'{1000/frame_rate * f/upsample} ms')
    return [x for y in lines for x in y]


anim = animation.FuncAnimation(fig,
                               animate_trails,
                               frames=maxlen-trailing,
                               blit=True)
sp = os.path.join(figure_path, 'reach_dynamics', 'animation.mp4')
os.makedirs(os.path.dirname(sp), exist_ok=True)  # save fails if dir missing
anim.save(sp,
          fps=5 * upsample,
          extra_args=['-vcodec', 'libx264'])

print('done')

0
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
done
