In [None]:
import os, sys, re, warnings, logging, pickle, bz2
from os.path import join
from glob import glob
import numpy as np
from scipy.signal import find_peaks
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import defaultdict
sys.path.append('..')
import trilabtracker as tt

from importlib import reload
# Re-import every name listed in tt.__all__ (expected to be submodules of
# trilabtracker), then the package itself, so that edits to the library on
# disk are picked up without restarting the kernel.
# getattr() replaces the previous eval() of a formatted string: same effect,
# no code built from strings.
for m in tt.__all__:
    reload(getattr(tt, m))
reload(tt)

# Default matplotlib figure size and a figure resolution setting.
plt.rcParams['figure.figsize'] = (6, 4)
dpi = 150

# Root directory for all analysis output (one subdirectory per set of cuts).
analysis_dir = './analysis'
# Length scale attached to every trial and used as the upper bound of the
# distance-to-wall bins -- presumably the tank radius in cm (TODO confirm).
R_cm = 55.5

# Show the working directory: the relative paths used below depend on it.
os.getcwd()

# Prepare trial data.

Prepare a dictionary of trials to analyze with basic info for each (path to trial file, population, age, number of individuals, etc).

In [None]:
# Load list of trial names to use in the analysis (one name per line,
# surrounding whitespace stripped). The context manager makes sure the
# file handle is closed instead of being left to the garbage collector.
with open('settings/adult-trials-to-analyze.txt') as trial_list_file:
    valid_trials = [ fn.strip() for fn in trial_list_file ]

# Extract trial metadata from the trial's filename.
def parse_trial_file(trial_file, etho=False):
    """
    Extract trial metadata from the path of a trial file.

    The trial name is the name of the directory containing `trial_file`.
    It is lower-cased and split on '_': the second field is the population
    and the third field is one leading character followed by the number of
    individuals. If the name contains 'dark', '-dark' is appended to the
    population.

    Parameters
    ----------
    trial_file : str
        Path to the trial file; only its parent directory name is parsed.
    etho : bool
        Currently unused (the alternative filename parsing it used to select
        is disabled); kept so existing calls keep working.

    Returns
    -------
    dict or None
        None if the trial name is not listed in `valid_trials`; otherwise a
        dict with keys 'trial_file', 'trial_dir', 'trial_name', 'pop',
        'n_ind' and 'R_cm' (the module-level R_cm).

    Raises
    ------
    ValueError
        If a whitelisted trial name has fewer than three '_'-separated
        fields or its third field does not end in an integer.
    """
    trial_dir  = os.path.dirname(trial_file)
    trial_name = os.path.basename(trial_dir)

    # Check the whitelist before parsing, so that directories which are not
    # part of the analysis are skipped even if their name is malformed.
    if trial_name not in valid_trials:
        return None

    _,pop,n = trial_name.lower().split('_')[:3]
    n_ind   = int(n[1:])
    # Dark trials are kept and tagged rather than skipped. (The original note
    # considered skipping them until one-sided background subtraction is
    # implemented.)
    if 'dark' in trial_name.lower():
        pop = pop+'-dark'

    # Explicit dict instead of filtering locals(): same keys, same order.
    return dict(trial_file=trial_file, trial_dir=trial_dir,
                trial_name=trial_name, pop=pop, n_ind=n_ind, R_cm=R_cm)

# trilabtracker only for now
def load_trial(trial_file, **args):
    """
    Parse the metadata of `trial_file` with parse_trial_file() and hand it to
    tt.preprocess_trial(), forwarding any extra keyword arguments unchanged.
    Returns whatever tt.preprocess_trial() returns.
    """
    return tt.preprocess_trial(parse_trial_file(trial_file), **args)

# Select a set of trials to analyze.
trial_files = sorted(glob('tracking_output/*/trial.pik'))
# print(trial_files)

# Count trials of each type.
trials = [ parse_trial_file(f) for f in trial_files ]
trials = pd.DataFrame(trials, index=trial_files)
grouped_trials = trials.groupby(['pop','n_ind'])
count  = pd.DataFrame(grouped_trials['trial_dir'].count().rename('count'))
count = count.unstack(1)
count.columns = count.columns.droplevel()
count[pd.isna(count)] = 0
count = count.astype(int)
display(count)

# def matching_trials(pop=None, age=None, n_ind=None, df=trials):
#     I = pd.Series(data=True, index=df.index)
#     if not pop is None:
#         I = I & (df['pop']==pop)
#     if not age is None:
#         I = I & (df['age']==age)
#     if not n_ind is None:
#         I = I & (df['n_ind']==n_ind)
#     return df[I].index.tolist()

# Compute or load analyzed data.

The first subsection iterates over the list of trials, loads the full tracking output, computes kinematic quantities, performs cuts, computes statistical properties (distribution of speed, angular speed, pairwise distance-angle, etc.), and saves them.

The second subsection loads precomputed statistical properties made with the first cell.

### Define cuts.

In [None]:
# [min, max] range kept for each quantity; passed to the trial preprocessing
# as `cut_ranges` and reused as the range of the speed / angular speed bins.
cut_ranges    = { 'v':[0,100], 'v_ang':[-30,30], 't':[600,1800] }
# This is used to name the output directory and make figure titles,
# e.g. "0=v=100, -30=v_ang=30, 600=t=1800".
# The "=" should be "<=" however ntfs doesn't allow "<" in filenames.
cut_label     = ', '.join([ f'{v[0]:g}={k}={v[1]:g}' for k,v in cut_ranges.items() ])

# Create the output directory for this set of cuts. makedirs() also creates
# analysis_dir itself when it is missing (os.mkdir would raise), and
# exist_ok avoids the separate existence check.
cut_dir = os.path.join(analysis_dir, cut_label)
os.makedirs(cut_dir, exist_ok=True)

### Analyze trials.

In [None]:
from importlib import reload
# Reload every name listed in tt.__all__ (expected to be submodules of
# trilabtracker) and then the package itself, so the analysis below runs
# against the current state of the library on disk.
# getattr() replaces the previous eval() of a formatted string.
for m in tt.__all__:
    reload(getattr(tt, m))
reload(tt)

# Trajectories of each analyzed trial, keyed by trial name.
traj_data = {}

# Histogram bin edges shared by all trials. They are stored in stat_data so
# that they get saved next to the per-trial results.
stat_data = {
    'bins_dWall'    : np.linspace(0, R_cm, 100),
    'bins_v'        : np.linspace(cut_ranges['v'][0], cut_ranges['v'][1], 100),
    'bins_vAng'     : np.linspace(cut_ranges['v_ang'][0], cut_ranges['v_ang'][1], 100),
    'bins_pairDist' : np.linspace(0, 2*R_cm, 60),
    'bins_pairAng'  : np.linspace(0, np.pi, 30),
}
# Expose the bin arrays as global variables (bins_dWall, bins_v, ...).
globals().update(stat_data)

# Per-trial statistics, keyed by trial file.
results = {}

# Main analysis loop: for each trial, load and preprocess the tracking output,
# compute the histograms defined by the bins above, and store them in
# `results` (keyed by trial file) and the trajectories in `traj_data`
# (keyed by trial name).
for trial_file in tqdm(trials.index):
    
    trial = load_trial(trial_file, load_timestamps=False, cut_ranges=cut_ranges)
    # Expose every entry of the trial dict as a global variable. The code
    # below relies on this for trial_name, n_ind, pos, vel, v and d_wall
    # (presumably all provided by tt.preprocess_trial -- TODO confirm).
    # NOTE(review): values from a previous trial stay in the namespace if
    # the current trial does not define them.
    globals().update(trial)
    
    # [?] Replace orientations (data[:,:,2]) with displacement-based ones.
    
    # Save trajectories (first two components of pos for every individual).
    traj_data[trial_name] = pos[:,:,:2]
    
    # [!] Handle overlaps.
    
    # Distribution of distance to the wall (NaN entries are excluded).
    bins, vals = bins_dWall, d_wall.flatten()
    hist_dWall = np.histogram(vals[~np.isnan(vals)], bins=bins)
    
    # Speed distribution.
    bins, vals = bins_v, v.flatten()
    hist_v = np.histogram(vals[~np.isnan(vals)], bins=bins)
    
    # Angular speed distribution (third component of vel).
    bins, vals = bins_vAng, vel[:,:,2].flatten()
    hist_vAng = np.histogram(vals[~np.isnan(vals)], bins=bins)
    
    # Joint distribution of pair distance and pair angle,
    # only defined when there is more than one individual.
    if n_ind>1:
        bins_d  = bins_pairDist
        bins_a  = bins_pairAng
        # Index pairs (J1[k], J2[k]) of all distinct pairs of individuals.
        J1,J2 = np.triu_indices(n_ind,1)
        d     = np.hypot(pos[:,J1,0]-pos[:,J2,0],pos[:,J1,1]-pos[:,J2,1]).flatten()
        a     = (pos[:,J1,2]-pos[:,J2,2]).flatten()
        # Wrap the angle difference to [-pi, pi].
        a     = a - 2*np.pi*np.rint(a/(2*np.pi))
        # Keep only samples where both distance and angle are defined.
        I     = np.logical_not(np.logical_or(np.isnan(d),np.isnan(a)))
        d     = d[I]
        a     = np.absolute(a[I])
        # 2D histogram of pairwise distance and angle.
        hist_distAng = np.histogram2d(d, a, bins=(bins_d,bins_a), density=True)
        # Pairwise polar alignment parameter vs pair distance.
        # K==0 / K==len(bins_d) mark distances below / above the bin range.
        K = np.digitize(d,bins_d)
        with warnings.catch_warnings():
            # nanmean of an empty bin warns and yields NaN; that is expected.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            p = np.array([ np.nanmean(np.cos(a[K==i])) for i in range(len(bins_d)+1) ])
        # Drop the two out-of-range bins so p has one value per bin of bins_d.
        polar = p[1:-1],bins_d
    else:
        hist_distAng,polar = None,None

    # Save output. At module level locals() is the global namespace, so this
    # picks up any of the listed names that currently exist as globals.
    # NOTE(review): hist_area and hist_aspect are not computed in this loop;
    # they are only saved if something else defined them -- confirm intent.
    results[trial_file] = { k:v for k,v in locals().items() if k in 
                               [ 'valid_fraction', 'hist_area', 'hist_aspect', 'hist_dWall', 
                                 'hist_v', 'hist_vAng', 'hist_distAng', 'polar' ] }
#     break

# Store the per-trial results next to the bin definitions and write both
# pickles into the directory of the current cuts. The context managers make
# sure each file is flushed and closed before the next cell reads it back.
stat_data['results'] = results
with open(join(cut_dir,'stats.pik'), 'wb') as stats_file:
    pickle.dump(stat_data, stats_file)
with open(join(cut_dir,'trajectories.pik'), 'wb') as traj_file:
    pickle.dump(traj_data, traj_file)

In [None]:
# Reload the saved trajectories. The file is produced by this notebook;
# do not point pickle.load at files from untrusted sources.
with open(join(cut_dir,'trajectories.pik'), 'rb') as traj_file:
    traj_data = pickle.load(traj_file)

# Plot the trajectories of the first trial only, one line per individual.
for trial_name,pos in list(traj_data.items())[:1]:
    plt.figure(figsize=(6,)*2)
    plt.axis('off')
    labels = [f'fish {i+1}' for i in range(pos.shape[1])]
    plt.plot(pos[:,:,0], pos[:,:,1], lw=0.2, label=labels)
    plt.title(trial_name)
    plt.legend(loc=(1,0.1))
    plt.show()

### Load precomputed analysis results.

In [None]:
# List previously computed cuts.
cut_dirs  = os.listdir(analysis_dir)
print(cut_dirs)

# Pick one and load data.
cut_label = cut_dirs[0]
cut_dir   = os.path.join(analysis_dir, cut_label)
stat_data = pickle.load(open(join(cut_dir,'stats.pik'), 'rb'))
globals().update(stat_data)

# Plot individual trial data

Distributions and other statistical quantities were precomputed in the previous section. Time-series and trajectory plots require reloading each trial's file, one at a time.

### Trajectories and time series

In [None]:
# fig_dir = { k:os.path.join(analysis_dir,cut_label,k) for k in 
#                ['trajectories', 'angle-vs-time'] }
# for d in fig_dir.values():
#     if not os.path.exists(d):
#         os.mkdir(d)

# # Creating a new figure for each trial creates a memory leak
# # I haven't been able to plug, so I'm creating one figure for
# # trajectories and one for angles and reusing them.
# fig_traj = plt.figure(figsize=(9,)*2)
# fig_ang  = plt.figure(figsize=(12,6))

# for i,trial_file in enumerate(trials.index):
#     print('\r'+' '*200+'\r'+f'{i+1}/{len(trials)}',end='')
    
#     trial = load_trial(trial_file, load_data=True)
#     trial = compute_kinematics(trial)
#     globals().update(trial)
    
#     # Trajectories.
#     fig,ax = fig_traj,fig_traj.gca()
#     ax.add_patch(plt.Circle(center, R_px, facecolor='None', 
#                                    edgecolor='k', lw=0.5))
#     ax.plot(*np.moveaxis(data[::5,:n_ind,:2],2,0),lw=0.5)
#     ax.axis('equal')
#     ax.yaxis.set_inverted(True)
#     fig.suptitle(trial_name)
#     fig.savefig(os.path.join(fig_dir['trajectories'],trial_name+'.png'))
#     fig.clf()
    
#     # Angle vs time.
#     fig,ax = fig_traj,fig_traj.gca()
#     ax.plot(time[:,None],pos[:,:,2])
#     ax.set_xlabel('Time (s)')
#     ax.set_ylabel('Angle (rad)')
#     fig.suptitle(trial_name)
#     fig.savefig(os.path.join(fig_dir['angle-vs-time'],trial_name+'.png'))
#     fig.clf()
    
# plt.close('all')

### Statistics

In [None]:
# One output directory per type of figure, inside the directory of the
# current cuts. makedirs(exist_ok=True) creates missing parents as well and
# does not fail when the directory already exists.
fig_dir = { k:join(cut_dir,k) for k in [ 'valid_fraction', 
            'hist_dWall', 'hist_v', 'hist_vAng', 'hist_distAng', 'polar' ] }
for d in fig_dir.values():
    os.makedirs(d, exist_ok=True)


fig = plt.figure()
# ax  = fig.gca()

# For every trial, save a bar chart of its valid fractions.
for trial_file in tqdm(trials.index):
    
    # Expose the trial's metadata (trial_name, ...) and its precomputed
    # statistics (valid_fraction, ...) as global variables.
    # The statistics live in `results` (saved in / loaded from stats.pik);
    # the name `trial_data` used here before is not defined anywhere.
    globals().update(trials.loc[trial_file].to_dict())
    globals().update(results[trial_file])
    
    # Valid fraction: one bar per entry of the valid_fraction mapping,
    # each annotated with its value just above y=1.
    bp = plt.bar(*zip(*valid_fraction.items()))
    for bar in bp:
        h,x,w = bar.get_height(),bar.get_x(),bar.get_width()
        plt.annotate(f'{h:.2f}', xy=(x+w/2,1.01), ha='center', va='bottom')
    plt.ylim(0,1.1)
    plt.ylabel('Valid fraction')
    plt.title(cut_label)
    plt.suptitle(trial_name)
    plt.savefig(os.path.join(fig_dir['valid_fraction'],trial_name+'.png'))
    # Close the figure so figures do not accumulate across trials.
    plt.close()
    
#     # Distributions of distance to the wall, speed, and angular speed.
#     H = dict( hist_dWall='d_wall (cm)', hist_v='v (cm/s)', hist_vAng='v_ang (rad/s)' )
#     for name,label in H.items():
#         ax = fig.gca()
#         h,b = locals()[name]
#         ax.bar(b[:-1],h,width=b[1:]-b[:-1])
#         ax.set_yscale('log')
#         ax.set_xlabel(label)
#         ax.set_ylabel('frequency')
#         ax.set_title(cut_label)
#         fig.suptitle(trial_name)
#         fig.savefig(os.path.join(fig_dir[name],trial_name+'.png'))
#         fig.clf()
    
#     # Pairwise distance-angle distribution.
#     if not hist_distAng is None:
#         ax = fig.gca()
#         h,b1,b2 = hist_distAng
#         m = ax.pcolormesh(b1, b2*180/np.pi, h.T, cmap='Oranges')
#         ax.set_xlabel('pair distance (cm)')
#         ax.set_ylabel('pair angle (deg)')
#         fig.colorbar(m)
#         ax.set_title(cut_label)
#         fig.suptitle(trial_name)
#         fig.savefig(os.path.join(fig_dir['hist_distAng'],trial_name+'.png'))
#         fig.clf()
    
#     # Pairwise polar order parameter vs distance.
#     if not polar is None:
#         ax = fig.gca()
#         p,b = polar
#         ax.plot((b[1:]+b[:-1])/2,p,marker='o',mfc='None',ms=4)
#         ax.set_xlabel('pair distance (cm)')
#         ax.set_ylabel('mean cosine of pair angle')
#         ax.set_ylim(-1,1)
#         ax.set_title(cut_label)
#         fig.suptitle(trial_name)
#         fig.savefig(os.path.join(fig_dir['polar'],trial_name+'.png'))
#         fig.clf()
    
#     break

# Close any figure that is still open once all trials have been processed.
plt.close('all')

In [None]:
# ''' Analyze instances of unusually high velocity. '''

# # At 30 fps, |v_ang|=30 (about where the rare peaks start) 
# # corresponds to about pi/3 in one frame.
# print('v_ang for pi/3 in (1/30) second:',np.pi/3*fps)

# print('Instances of unusually high v_ang:')
# for f in fish_list:
#     ang_diff  = df[f,'ang'].diff()
#     I = np.nonzero(np.absolute(ang_diff.values)>1)[0]
#     for i in I[:5]:
#         display(df[f,'ang'].iloc[i-1:i+2])