In [None]:
import os, sys, re, warnings, logging
import pickle
from os.path import join
from glob import glob
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
from tqdm import tqdm
sys.path.append('..')
import trilabtracker as tt

from importlib import reload
# Reload every name listed in tt.__all__, then the package itself, so that
# edits to the library on disk are picked up without restarting the kernel.
for m in tt.__all__:
    # Plain attribute lookup; no need to build source text and eval() it.
    reload(getattr(tt, m))
reload(tt)

# Default size, in inches, of every matplotlib figure created in this notebook.
plt.rcParams['figure.figsize'] = 6,4
# NOTE(review): presumably the resolution for saved figures; no live cell in view reads it.
dpi = 150

# Root directory of the analysis output. The cut-definition and result-loading
# cells further down read this name, so it must be defined for the notebook to
# run top to bottom on a fresh kernel.
analysis_dir = './analysis'

# Display the working directory: every relative path in the notebook resolves against it.
os.getcwd()

## Prepare trial data.

Prepare a table (a pandas DataFrame) of the trials to analyze, with basic info for each one (path to the trial file, population, age, number of individuals, etc.).

In [None]:
# Tank diameter (cm, judging by the R_cm name derived from it) for each age
# that appears in the trial names.
tank_diameter_vs_age = { 7:9.6, 14:10.4, 21:12.8, 28:17.7, 42:24.2, 56:33.8, 70:33.8, 84:33.8 }

# Extract trial metadata from the trial's filename.
def parse_trial_file(trial_file, etho=False):
    """Build a metadata dict for one trial from the naming convention of its path.

    The trial name must start with five underscore-separated fields,
    ``<pop>_<day>_<age>_<group>_<n_ind>`` (e.g. ``Pa_Sun_21dpf_GroupC_n2a``);
    anything after the fifth field is ignored.

    Parameters
    ----------
    trial_file : str
        Path to the trial file. With ``etho=False`` the trial name is the name
        of the directory containing the file. With ``etho=True`` it is the
        second '-'-separated piece of the file's base name.
    etho : bool
        Select the ``etho=True`` naming scheme described above; ``trial_dir``
        is then None.

    Returns
    -------
    dict
        Keys, in this order: 'trial_file', 'trial_dir', 'trial_name', 'pop',
        'age', 'group', 'n_ind', 'R_cm'. 'R_cm' is half the tank diameter
        listed for the age in ``tank_diameter_vs_age``.

    Raises
    ------
    KeyError
        If the population code or the age is not a known one.
    ValueError
        If the name has fewer than five fields or the age field is not an
        integer followed by a three-character suffix.
    IndexError
        If the n_ind field contains no digits (or, with ``etho=True``, the
        base name contains no '-').
    """
    if etho:
        trial_dir  = None
        trial_name = os.path.basename(trial_file).split('-')[1]
    else:
        trial_dir  = os.path.dirname(trial_file)
        trial_name = os.path.basename(trial_dir)

    # The second field (day) is part of the naming scheme but not reported.
    pop, _day, age, group, n_ind = trial_name.split('_')[:5]

    # Normalize the population code to its canonical capitalization.
    pop   = {'sf':'SF', 'pa':'Pa', 'rc':'RC'}[pop.lower()]
    # Drop the three-character suffix ('dpf' in the file names seen here).
    age   = int(age[:-3])
    # Ages 43 and 71 are filed under 42 and 70, which have a tank size listed.
    age   = {43:42, 71:70}.get(age, age)
    # The field looks like 'n2a': keep the first run of digits only.
    n_ind = int(re.findall(r'\d+', n_ind)[0])
    R_cm  = tank_diameter_vs_age[age]/2

    # Spell the result out explicitly (same keys and order as before) rather
    # than filtering locals(), whose contents depend on the interpreter.
    return { 'trial_file': trial_file, 'trial_dir': trial_dir,
             'trial_name': trial_name, 'pop': pop, 'age': age,
             'group': group, 'n_ind': n_ind, 'R_cm': R_cm }

# trilabtracker only for now
def load_trial(trial_file, **args):
    """Parse the metadata of `trial_file` and hand it to `tt.preprocess_trial`.

    Extra keyword arguments are forwarded unchanged to `tt.preprocess_trial`,
    whose return value is returned as is.
    """
    trial_info = parse_trial_file(trial_file)
    processed  = tt.preprocess_trial(trial_info, **args)
    return processed

# Select a set of trials to analyze.
trial_files = sorted(glob('../dataset1/tracking/21-07-12_full/*/trial.pik'))
# trial_files = sorted(glob('../dataset2/tracking/*/trial.pik'))
# print(trial_files)

# One row of filename-derived metadata per trial, indexed by the trial file path.
trials = pd.DataFrame([ parse_trial_file(f) for f in trial_files ], index=trial_files)
grouped_trials = trials.groupby(['pop','age','n_ind'])

# Count trials of each type: rows are (pop, n_ind) pairs, columns are ages,
# and combinations without any trial show 0.
count = ( grouped_trials['trial_dir'].count()
          .unstack(1)
          .fillna(0)
          .astype(int) )
display(count)

def matching_trials(pop=None, age=None, n_ind=None, df=trials):
    """List the index labels of the rows of `df` that satisfy every given criterion.

    Each of `pop`, `age` and `n_ind` is compared for equality with the column
    of the same name; a criterion left as None is not applied. With no
    criteria at all, every index label is returned.

    The default `df` is the `trials` object that existed when this definition
    was executed.
    """
    selected = pd.Series(data=True, index=df.index)
    for column, wanted in (('pop',pop), ('age',age), ('n_ind',n_ind)):
        if wanted is None:
            continue
        selected = selected & (df[column]==wanted)
    return df[selected].index.tolist()

# Test analysis functions.

### [in progress] Detect discrete turns.

In [None]:
# for i,trial_file in enumerate(trials.index):
#     trial = parse_trial_file(trial_file)
#     trial = tt.preprocess_trial(trial)
#     locals().update(trial)
#     print('\r'+' '*200+'\r'+f'{i}/{len(trial_files)}  {trial_name}',end='')
#     # Plot angle.
#     a = pos[:,0,2]
#     plt.plot(time,a,label='actual angle')
#     # Extract discrete turns.
#     w = int(0.1*fps)
#     step = np.concatenate([np.ones(w)/w,np.zeros(1),-np.ones(w)/w])
#     conv = np.convolve(a,step,mode='same')
#     P = find_peaks(np.absolute(conv),distance=w,height=np.pi/25)
#     # Reconstruct angle vs time using only the discrete turns.
#     da = np.zeros_like(a)
#     da[P[0]] = conv[P[0]]
#     a2 = a[0]+np.cumsum(da)
#     plt.plot(time,a2,label='reconstructed angle')
    
#     plt.plot(time,a-a2,label='actual angle - reconstructed angle')
    
# #     plt.xlim(0,1000)
# #     plt.ylim(2,15)
#     plt.legend()
    
#     break

In [None]:
# fig,ax = plt.subplots(1,2,figsize=(12,4))
# I = P[0]
# dt = time[I[1:]]-time[I[:-1]]
# ax[0].hist(dt,bins=50)
# ax[0].set_xlabel('time between two discrete turns')
# ax[0].set_ylabel('frequency')
# ax[0].set_yscale('log')
# ax[1].hist(np.absolute(conv[I]),bins=50)
# ax[1].set_xlabel('angle turned')
# ax[1].set_yscale('log')
# plt.suptitle('Discrete turns')
# plt.show()


# fig,ax = plt.subplots(1,2,figsize=(12,4))
# da = np.diff(a-a2)/np.sqrt(time[1:]-time[:-1])
# ax[0].hist(da,bins=100)
# ax[0].set_xlabel('da/sqrt(dt)')
# ax[0].set_ylabel('frequency')
# # dt = time[I[1:]]-time[I[:-1]]
# # ax[0].hist(dt,bins=20)
# # ax[1].hist(np.absolute(conv[I]),bins=50)
# # plt.xlim(-5,5)
# ax[0].set_yscale('log')
# ax[1].set_visible(False)
# plt.suptitle('Continuous turns')
# plt.show()

# Test loading trilab-tracker and ethovision data

### trilab-tracker

In [None]:
# Load a single, hard-coded trial through load_trial() and plot it as a check.
# trial_files = sorted(glob('../tracking/full_21-01-22/*/trial.pik'))
# f = trial_files[1]
f = '../dataset1/tracking/21-07-12_full/Pa_Sun_21dpf_GroupC_n2a_20200621_1510/trial.pik'

# trial = parse_trial_file(f)
# trial = tt.preprocess_trial(trial)
trial = load_trial(f)
# Copy every entry of `trial` into the notebook's global namespace. The plot
# below uses `n_ind` and `pos`, which are presumably injected here (confirm
# against tt.preprocess_trial).
# NOTE(review): this silently overwrites same-named globals from earlier cells.
globals().update(trial)

# One curve per individual: column 0 of pos[:, i, :] on the horizontal axis and
# column 1 on the vertical axis (presumably x and y positions; TODO confirm).
for i in range(n_ind):
    plt.plot(*pos[:,i,:2].T)
# Same scale on both axes so the trajectories are not distorted.
plt.axis('equal')
plt.show()

# for i in range(n_ind):
#     plt.plot(pos[:200,i,2],'-')
# plt.show()

### ethovision

In [None]:
# Pick one raw-data spreadsheet to test with. Sorting makes the choice
# independent of the arbitrary order in which glob() lists the directory.
etho_files = sorted(glob('../dataset1/ethovision/Raw_Data/*.xlsx'))
f = etho_files[1]
print(f)

# Start from the filename-derived metadata, then merge in whatever
# tt.load_trial_ethovision returns (its entries win on key collisions).
trial = parse_trial_file(f, etho=True)
trial.update(tt.load_trial_ethovision(f))

In [None]:
# Copy every entry of `trial` (built in the previous cell) into the notebook's
# global namespace. The plots below use `n_ind` and `data`; `data` is
# presumably provided by tt.load_trial_ethovision (TODO confirm).
# NOTE(review): this silently overwrites same-named globals from earlier cells.
globals().update(trial)

# One curve per individual: column 0 of data[:, i, :] against column 1
# (presumably x and y positions; TODO confirm), drawn with equal axis scales.
for i in range(n_ind):
    plt.plot(*data[:,i,:2].T, lw=0.5)
plt.axis('equal')
plt.show()

# Column 2 of each individual over the first 200 rows (presumably the
# orientation angle versus frame number; TODO confirm).
for i in range(n_ind):
    plt.plot(data[:200,i,2], '-', lw=1)
plt.show()

# Compute or load analyzed data.

The first subsection iterates over the list of trials, loads the full tracking output, computes kinematic quantities, performs cuts, computes statistical properties (distributions of speed, angular speed, pairwise distance and angle, etc.), and saves them.

The second subsection loads the precomputed statistical properties produced by the first subsection.

### Define cuts.

In [None]:
# Allowed [min, max] range for each quantity; values outside are cut.
# Speed cut: 0.2*30 = travel 0.2 tank radius between 2 frames at 30fps.
# Angular speed cut: pi/2*30 = quarter turn between 2 frames at 30 fps.
cut_ranges    = { 'd_wall':[-np.inf,np.inf], 'v':[0,0.2*30], 
                  'v_ang':[-np.pi/2*30,np.pi/2*30] }
# This is used to name the output directory and make figure titles.
# The "=" should be "<=", however NTFS doesn't allow "<" in filenames.
cut_label     = ', '.join([ f'{v[0]:g}={k}={v[1]:g}' for k,v in cut_ranges.items() ])

# Output directory for this set of cuts. makedirs also creates analysis_dir
# itself when it is missing, and does not fail if the directory already exists.
cut_dir = os.path.join(analysis_dir,cut_label)
os.makedirs(cut_dir, exist_ok=True)

### Analyze trials.

In [None]:
# bins_area     = np.linspace(0,600,100)
# bins_aspect   = np.linspace(1,15,100)
# prebins_dWall = np.linspace(-0.1,1.1,100) # Multiply by R_cm before using.
# prebins_v     = np.linspace(*cut_ranges['v'],100) # Multiply by R_cm before using.
# bins_vAng     = np.linspace(*cut_ranges['v_ang'],100)
# prebins_pairDist = np.linspace(0,2,60) # Multiply by R_cm before using.
# bins_pairAng  = np.linspace(0,np.pi,30)

# trial_data    = {}

# for i,trial_file in enumerate(trials.index):
#     print('\r'+' '*200+'\r'+f'{i+1}/{len(trials)}',end='')
    
#     trial  = load_trial(trial_file, load_data=True)
#     ranges = { 'd_wall': [ x*trial['R_cm'] for x in cut_ranges['d_wall'] ], 
#                'v':      [ x*trial['R_cm'] for x in cut_ranges['v'] ], 
#                'v_ang':  cut_ranges['v_ang'] }
#     trial  = compute_kinematics(trial)
#     with warnings.catch_warnings():
#         warnings.simplefilter("ignore", category=RuntimeWarning)
#         trial = compute_cuts(trial, ranges)
#     globals().update(trial)
    
#     # Fish area and Fish aspect ratio histograms.
#     # Useful to tune the tracker's contour filters.
#     hist_area   = np.histogram(data[:,:n_ind,3], bins=bins_area )
#     hist_aspect = np.histogram(data[:,:n_ind,4], bins=bins_aspect )
    
#     # Distribution of distance to the wall.
#     bins = prebins_dWall * R_cm
#     vals = d_wall.flatten()
#     hist_dWall = np.histogram(vals[~np.isnan(vals)],bins=bins)
    
#     # Speed distribution.
#     bins = prebins_v * R_cm
#     vals = v.flatten()
#     hist_v = np.histogram(vals[~np.isnan(vals)],bins=bins)
    
#     # Angular speed distribution.
#     bins = bins_vAng
#     vals = vel[:,:,2].flatten()
#     hist_vAng = np.histogram(vals[~np.isnan(vals)],bins=bins)
    
#     # Joint distribution of pair distance and pair angle,
#     if n_ind>1:
#         bins_d  = prebins_pairDist * R_cm
#         bins_a  = bins_pairAng
#         J1,J2 = np.triu_indices(n_ind,1)
#         d     = np.hypot(pos[:,J1,0]-pos[:,J2,0],pos[:,J1,1]-pos[:,J2,1]).flatten()
#         a     = (pos[:,J1,2]-pos[:,J2,2]).flatten()
#         a     = a - 2*np.pi*np.rint(a/(2*np.pi))
#         I     = np.logical_not(np.logical_or(np.isnan(d),np.isnan(a)))
#         d     = d[I]
#         a     = np.absolute(a[I])
#         # 2D histogram of pairwise distance and angle.
#         hist_distAng = np.histogram2d(d, a, bins=(bins_d,bins_a), density=True)
#         # Pairwise polar alignment parameter vs pair distance.
#         K = np.digitize(d,bins_d)
#         with warnings.catch_warnings():
#             warnings.simplefilter("ignore", category=RuntimeWarning)
#             p = np.array([ np.nanmean(np.cos(a[K==i])) for i in range(len(bins_d)+1) ])
#         polar = p[1:-1],bins_d
#     else:
#         hist_distAng,polar = None,None

#     # Save output.
#     trial_data[trial_file] = { k:v for k,v in locals().items() if k in 
#                                [ 'valid_fraction', 'hist_area', 'hist_aspect', 'hist_dWall', 
#                                  'hist_v', 'hist_vAng', 'hist_distAng', 'polar' ] }
# #     break


# f = os.path.join(cut_dir,'trial_data')
# # pickle.dump( {'trial_data':trial_data}, bz2.BZ2File(f+'.bz2','w') )
# pickle.dump( {'trial_data':trial_data}, open(f+'.pik','wb') )

### Load precomputed analysis results.

In [None]:
# List previously computed cuts.
cut_dirs = os.listdir(analysis_dir)
print(cut_dirs)
# NOTE(review): os.listdir returns entries in arbitrary order, so when several
# cut directories exist the one selected here is not guaranteed to be the same
# from run to run.
cut_label = cut_dirs[0]

f    = os.path.join(analysis_dir,cut_label,'trial_data')
# NOTE: unpickling can execute arbitrary code; only load files produced by
# this notebook. The context manager closes the file once it has been read.
with open(f+'.pik','rb') as fh:
    data = pickle.load(fh)
# Expose the entries of the loaded dict as notebook globals. The commented-out
# analysis cell saves {'trial_data': ...}, so this presumably defines `trial_data`.
globals().update(data)

# Plot individual trial data

Distributions and other statistical quantities were precomputed in the previous section. Time-series and trajectory plots require reloading each trial's trial file, one at a time.

### Trajectories and time series

In [None]:
# fig_dir = { k:os.path.join(analysis_dir,cut_label,k) for k in 
#                ['trajectories', 'angle-vs-time'] }
# for d in fig_dir.values():
#     if not os.path.exists(d):
#         os.mkdir(d)

# # Creating a new figure for each trial creates a memory leak
# # I haven't been able to plug, so I'm creating one figure for
# # trajectories and one for angles and reusing them.
# fig_traj = plt.figure(figsize=(9,)*2)
# fig_ang  = plt.figure(figsize=(12,6))

# for i,trial_file in enumerate(trials.index):
#     print('\r'+' '*200+'\r'+f'{i+1}/{len(trials)}',end='')
    
#     trial = load_trial(trial_file, load_data=True)
#     trial = compute_kinematics(trial)
#     globals().update(trial)
    
#     # Trajectories.
#     fig,ax = fig_traj,fig_traj.gca()
#     ax.add_patch(plt.Circle(center, R_px, facecolor='None', 
#                                    edgecolor='k', lw=0.5))
#     ax.plot(*np.moveaxis(data[::5,:n_ind,:2],2,0),lw=0.5)
#     ax.axis('equal')
#     ax.yaxis.set_inverted(True)
#     fig.suptitle(trial_name)
#     fig.savefig(os.path.join(fig_dir['trajectories'],trial_name+'.png'))
#     fig.clf()
    
#     # Angle vs time.
#     fig,ax = fig_traj,fig_traj.gca()
#     ax.plot(time[:,None],pos[:,:,2])
#     ax.set_xlabel('Time (s)')
#     ax.set_ylabel('Angle (rad)')
#     fig.suptitle(trial_name)
#     fig.savefig(os.path.join(fig_dir['angle-vs-time'],trial_name+'.png'))
#     fig.clf()
    
# plt.close('all')

### Statistics

In [None]:
# f    = os.path.join(analysis_dir,cut_label,'trial_data')
# data = pickle.load( open(f+'.pik','rb') )
# globals().update(data)

# fig_dir = { k:os.path.join(analysis_dir,cut_label,k) for k in [ #'valid_fraction', 
#             'hist_dWall', 'hist_v', 'hist_vAng', 'hist_distAng', 'polar' ] }
# for d in fig_dir.values():
#     if not os.path.exists(d):
#         os.mkdir(d)


# fig = plt.figure()
# ax  = fig.gca()

# for i,trial_file in enumerate(trials.index):
#     print('\r'+' '*200+'\r'+f'{i+1}/{len(trials)}',end='')
    
#     globals().update(trials.loc[trial_file].to_dict())
#     globals().update(trial_data[trial_file])
    
# #     # Valid fraction.
# #     bp = plt.bar(*zip(*valid_fraction.items()))
# #     for bar in bp:
# #         h,x,w = bar.get_height(),bar.get_x(),bar.get_width()
# #         plt.annotate(f'{h:.2f}', xy=(x+w/2,1.01), ha='center', va='bottom')
# #     plt.ylim(0,1.1)
# #     plt.ylabel('Valid fraction')
# #     plt.title(cut_label)
# #     plt.suptitle(trial_name)
# #     plt.savefig(os.path.join(fig_dir['valid_fraction'],trial_name+'.png'))
# #     plt.close()
    
#     # Distributions of distance to the wall, speed, and angular speed.
#     H = dict( hist_dWall='d_wall (cm)', hist_v='v (cm/s)', hist_vAng='v_ang (rad/s)' )
#     for name,label in H.items():
#         ax = fig.gca()
#         h,b = locals()[name]
#         ax.bar(b[:-1],h,width=b[1:]-b[:-1])
#         ax.set_yscale('log')
#         ax.set_xlabel(label)
#         ax.set_ylabel('frequency')
#         ax.set_title(cut_label)
#         fig.suptitle(trial_name)
#         fig.savefig(os.path.join(fig_dir[name],trial_name+'.png'))
#         fig.clf()
    
#     # Pairwise distance-angle distribution.
#     if not hist_distAng is None:
#         ax = fig.gca()
#         h,b1,b2 = hist_distAng
#         m = ax.pcolormesh(b1, b2*180/np.pi, h.T, cmap='Oranges')
#         ax.set_xlabel('pair distance (cm)')
#         ax.set_ylabel('pair angle (deg)')
#         fig.colorbar(m)
#         ax.set_title(cut_label)
#         fig.suptitle(trial_name)
#         fig.savefig(os.path.join(fig_dir['hist_distAng'],trial_name+'.png'))
#         fig.clf()
    
#     # Pairwise polar order parameter vs distance.
#     if not polar is None:
#         ax = fig.gca()
#         p,b = polar
#         ax.plot((b[1:]+b[:-1])/2,p,marker='o',mfc='None',ms=4)
#         ax.set_xlabel('pair distance (cm)')
#         ax.set_ylabel('mean cosine of pair angle')
#         ax.set_ylim(-1,1)
#         ax.set_title(cut_label)
#         fig.suptitle(trial_name)
#         fig.savefig(os.path.join(fig_dir['polar'],trial_name+'.png'))
#         fig.clf()
    
# #     if i==10:
# #         break

# plt.close('all')

In [None]:
# ''' Analyze instances of unusually high velocity. '''

# # At 30 fps, |v_ang|=30 (about where the rare peaks start) 
# # corresponds to about pi/3 in one frame.
# print('v_ang for pi/3 in (1/30) second:',np.pi/3*fps)

# print('Instances of unusually high v_ang:')
# for f in fish_list:
#     ang_diff  = df[f,'ang'].diff()
#     I = np.nonzero(np.absolute(ang_diff.values)>1)[0]
#     for i in I[:5]:
#         display(df[f,'ang'].iloc[i-1:i+2])