In [None]:
import os, sys, re, warnings, logging, pickle, bz2
from os.path import join
from glob import glob
import numpy as np
from scipy.signal import find_peaks
import pandas as pd
from pandas.api.types import is_list_like
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import defaultdict
sys.path.append('..')
import trilabtracker as tt

from importlib import reload
# Development helper: re-import trilabtracker so edits to the package are
# picked up without restarting the kernel. Every name in tt.__all__ is passed
# to reload(), so each must be a submodule; they are reloaded before the
# package itself.
for m in tt.__all__:
    reload(getattr(tt, m))
reload(tt)

os.getcwd()

## Prepare trial data.

Define notebook-wide quantities (tank radius, cut ranges) and build a table (DataFrame) of the trials to analyze, with basic info about each (path to the trial file, population, number of individuals, etc.).

In [None]:
# Load list of trial names to use in the analysis (one name per line).
# The context manager closes the file; blank lines are skipped.
with open('../settings/adult-trials-to-analyze.txt') as fh:
    valid_trials = [ line.strip() for line in fh if line.strip() ]

# Tank radius.
R_cm        = 55.5

# Extract trial metadata from the trial's filename.
def parse_trial_file(trial_file, etho=False):
    """Extract trial metadata from the path of a trial file.

    The trial name is the name of the directory containing `trial_file`.
    It is lower-cased and split on '_': the 2nd field is the population and
    the 3rd field, minus its first character, is the number of individuals
    (e.g. 'prefix_mo_n5_...' gives pop='mo', n_ind=5). Trials whose name
    contains 'dark' get '-dark' appended to their population.

    Parameters
    ----------
    trial_file : str
        Path to the trial file.
    etho : bool
        Currently unused; kept for backward compatibility.

    Returns
    -------
    dict
        Keys 'trial_file', 'trial_dir', 'trial_name', 'pop', 'n_ind', 'R_cm',
        or an empty dict if the trial name is not listed in `valid_trials`.

    Raises
    ------
    ValueError
        If a listed trial name does not follow the naming scheme above.
    """
    trial_dir  = os.path.dirname(trial_file)
    trial_name = os.path.basename(trial_dir)
    # Check validity before parsing so that unlisted directories with
    # non-standard names are skipped instead of raising.
    if trial_name not in valid_trials:
        return {}
    _, pop, n  = trial_name.lower().split('_')[:3]
    n_ind      = int(n[1:])
    if 'dark' in trial_name.lower():
        # NOTE: dark trials are NOT skipped; skipping them (pending one-sided
        # background subtraction) is currently disabled.
        pop = pop + '-dark'
    return { 'trial_file': trial_file,
             'trial_dir':  trial_dir,
             'trial_name': trial_name,
             'pop':        pop,
             'n_ind':      n_ind,
             'R_cm':       R_cm }

# trilabtracker only for now
def load_trial(trial_file, **args):
    """Parse a trial file's metadata and preprocess it with trilabtracker.

    Extra keyword arguments (e.g. `cut_ranges`) are forwarded to
    `tt.preprocess_trial`.

    Raises
    ------
    ValueError
        If the trial is not listed in `valid_trials` (parse_trial_file
        returned an empty dict).
    """
    # Done: replace orientations (data[:,:,2]) with displacement-based ones.
    # TODO: handle overlaps.
    trial = parse_trial_file(trial_file)
    if not trial:
        raise ValueError(f'{trial_file}: trial is not listed in valid_trials')
    return tt.preprocess_trial(trial, load_timestamps=False, orientation='body',
                               wall_distance=True, n_smooth=3, buffer_frames=0,
                               **args)

# Select a set of trials to analyze.
trial_files = sorted(glob('../tracking_output/*/trial.pik'))

# Parse every trial file. Trials not listed in valid_trials parse to {} and
# become all-NaN rows, which dropna() removes. Those NaN rows upcast n_ind to
# float, so cast it back to int.
trial_records = [ parse_trial_file(f) for f in trial_files ]
trials = pd.DataFrame(trial_records, index=trial_files)
trials = trials.dropna().astype({'n_ind': int})

# Count trials of each type (rows: population, columns: group size).
grouped_trials = trials.groupby(['pop','n_ind'])
count = grouped_trials['trial_dir'].count().unstack('n_ind', fill_value=0)
count.columns = count.columns.astype(int)
display(count)

# Return the list of trials matching some condition.
# Any column name of the trial dataframe is a valid keyword argument.
# Provide a value or list of values to match.
# Example: matching_trials(pop=['mo','ti'], n_ind=5)
def matching_trials(df=None, **args):
    """Return the index labels (trial files) of the rows matching all conditions.

    Parameters
    ----------
    df : pandas.DataFrame, optional
        Table of trials to filter. Defaults to the notebook-wide `trials`
        table, looked up at call time rather than frozen when this cell runs.
    **args
        column=value pairs. A list-like value matches any of its elements
        (`isin`); a scalar matches by equality; None means "no condition".

    Returns
    -------
    list
        Index labels of `df` whose rows satisfy every condition, in `df` order.
    """
    if df is None:
        df = trials
    mask = pd.Series(True, index=df.index)
    for column, value in args.items():
        if value is None:
            continue
        if is_list_like(value):
            mask &= df[column].isin(value)
        else:
            mask &= df[column] == value
    return df[mask].index.tolist()

# Two speed cuts are needed: cut_ranges0 keeps inactive fish (speed floor 0),
# cut_ranges1 raises the speed floor to v_inactive.
# NOTE(review): how a range is applied (NaN-ing or dropping out-of-range
# frames) is decided by tt.preprocess_trial -- confirm there.
v_inactive = 3
cut_ranges0 = dict(v=[0, 100], v_ang=[-30, 30], t=[600, 1800])
cut_ranges1 = dict(v=[v_inactive, 100], v_ang=[-30, 30], t=[600, 1800])


def is_active(v, v_inactive=v_inactive, buffer_frames=0):
    """Boolean activity mask: True where the speed exceeds `v_inactive`.

    NaN speeds compare False and therefore count as inactive.

    With `buffer_frames` > 0 the inactive regions are widened along the first
    axis: each pass marks a frame inactive if it or either neighbour is
    inactive. The first and last frames are left as they are.
    """
    active = v > v_inactive
    for _ in range(buffer_frames):
        widened = active[:-2] & active[1:-1] & active[2:]
        active[1:-1] = widened
    return active

# Figure 1

Distribution of the distance to the wall (area density) and nematic order parameter of the body orientation with respect to the wall, for solo trials (`n_ind=1`).

In [None]:
fig_data    = { k:defaultdict(list) for k in ['density', 'nematic', 'median'] }
cut_ranges  = cut_ranges0
# Bin edges in distance to the wall, chosen so that every annulus has the same
# area (edge i sits at radius R_cm*sqrt(1-i/19) from the center).
bins        = R_cm - np.sqrt(np.linspace(R_cm**2,0,20))
for trial_file in tqdm(matching_trials(n_ind=1)):
    trial = load_trial(trial_file, cut_ranges=cut_ranges)
    # NOTE(review): injects every key of `trial` (d_wall, pos, pop, ...) into
    # the notebook namespace; those names are read below.
    globals().update(trial)
    # Area density of the distance to the wall: normalized histogram / annulus area.
    vals       = d_wall.flatten()
    median     = np.nanmedian(vals)
    density,_  = np.histogram(vals, bins=bins)
    area       = np.pi*(R_cm-bins[:-1])**2 - np.pi*(R_cm-bins[1:])**2
    density    = density/np.sum(density)/area
    # Nematic order parameter <cos(2*thetaW)> in each distance bin
    # (+1: oriented radially, -1: oriented parallel to the wall).
    # The mask I indexes frames of individual 0 only, which is fine for n_ind=1.
    nematic    = []
    for i in range(len(bins)-1):
        I = (vals>=bins[i])&(vals<bins[i+1])
        thetaP = np.arctan2(pos[I,0,1],pos[I,0,0]) # position (polar) angle; assumes pos is centered on the tank
        thetaO = pos[I,0,2] # body orientation angle
        thetaW = thetaO-thetaP # orientation relative to the radial direction
        thetaW = thetaW - 2*np.pi*np.rint(thetaW/(2*np.pi)) # wrap to [-pi, pi]
        with warnings.catch_warnings():
            # Empty bins give the mean of an empty slice; keep the NaN silently.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            nematic.append(np.nanmean(np.cos(2*thetaW)))
    fig_data['density'][pop].append(density)
    fig_data['nematic'][pop].append(nematic)
    fig_data['median'][pop].append(median)
for k1,v1 in fig_data.items():
    for k2,v2 in v1.items():
        fig_data[k1][k2] = np.array(v2)
fig_data['bin_centers'] = (bins[1:]+bins[:-1])/2

# Close the output file deterministically and make sure the folder exists.
os.makedirs('data', exist_ok=True)
with open('data/figure1-data.pik','wb') as fh:
    pickle.dump(fig_data, fh)

# Figure 2

Fraction of the time active (all trials), plus the full speed distribution for `sf` trials.

In [None]:
fig_data    = { k:defaultdict(list) for k in ['speed_distribution', 'inactive_fraction'] }
cut_ranges  = cut_ranges0
bins        = np.linspace(*cut_ranges['v'],101)
for trial_file in tqdm(trials.index):
    trial = load_trial(trial_file, cut_ranges=cut_ranges)
    # NOTE(review): injects the trial's keys (v, pop, n_ind, ...) as globals.
    globals().update(trial)
    if pop=='sf':
        # Full speed distribution (inactive frames included), 'sf' trials only.
        h,_  = np.histogram(v.flatten(), bins=bins, density=True)
        fig_data['speed_distribution'][pop,n_ind].append(h)
    # NOTE(review): despite the key name, this is the fraction of finite-speed
    # frames that are *active* (matching this section's title). The key is kept
    # as-is since the saved pickle is presumably read by plotting code elsewhere.
    frac_active = np.count_nonzero(is_active(v))/np.count_nonzero(np.isfinite(v))
    fig_data['inactive_fraction'][pop,n_ind].append(frac_active)
for k,v in fig_data['speed_distribution'].items():
    fig_data['speed_distribution'][k] = np.array(v)
fig_data['speed_distribution']['bin_centers'] = (bins[1:]+bins[:-1])/2

# Close the output file deterministically and make sure the folder exists.
os.makedirs('data', exist_ok=True)
with open('data/figure2-data.pik','wb') as fh:
    pickle.dump(fig_data, fh)

# Figures 3 & S3

Distribution of speed when active, plus the mean speed of each trial.

In [None]:
from importlib import reload
# Development helper: pick up edits to trilabtracker without a kernel restart.
for m in tt.__all__:
    reload(getattr(tt, m))
reload(tt)

fig_data    = { k:defaultdict(list) for k in ['speed_distribution', 'mean_speed'] }
cut_ranges  = cut_ranges1
bins        = np.linspace(*cut_ranges['v'],101)
for trial_file in tqdm(trials.index):
    trial = load_trial(trial_file, cut_ranges=cut_ranges)
    # NOTE(review): injects the trial's keys (v, pop, n_ind, ...) as globals.
    globals().update(trial)
    # Speed distribution restricted to active frames.
    h,_   = np.histogram(v[is_active(v)].flatten(), bins=bins, density=True)
    fig_data['speed_distribution'][pop,n_ind].append(h)
    # NOTE(review): the mean is over all finite v left by cut_ranges1, not only
    # is_active(v) frames as for the histogram; the two agree only if the speed
    # cut already removes v <= v_inactive -- confirm in tt.preprocess_trial.
    fig_data['mean_speed'][pop,n_ind].append(np.nanmean(v))
for k,v in fig_data['speed_distribution'].items():
    fig_data['speed_distribution'][k] = np.array(v)
fig_data['speed_distribution']['bin_centers'] = (bins[1:]+bins[:-1])/2

# Close the output file deterministically and make sure the folder exists.
os.makedirs('data', exist_ok=True)
with open('data/figure3-data.pik','wb') as fh:
    pickle.dump(fig_data, fh)

# Figures 4 & S4

Distribution of angular speed when active.

In [None]:
from importlib import reload
# Development helper: pick up edits to trilabtracker without a kernel restart.
for m in tt.__all__:
    reload(getattr(tt, m))
reload(tt)

fig_data    = { k:defaultdict(list) for k in ['angSpeed_distribution'] }
cut_ranges  = cut_ranges1
# Quadratically spaced bin edges (i**2 scaling): finer resolution at low
# angular speed, coarser near the cut limit cut_ranges['v_ang'][1].
bins        = np.linspace(0,np.sqrt(cut_ranges['v_ang'][1]),61)**2
for trial_file in tqdm(trials.index):
    trial = load_trial(trial_file, cut_ranges=cut_ranges)
    # NOTE(review): injects the trial's keys (v, vel, pop, n_ind, ...) as globals.
    globals().update(trial)
    # |third component of vel| (presumably angular velocity) on active frames only.
    vals  = np.absolute(vel[is_active(v),2])
    h,_   = np.histogram(vals.flatten(), bins=bins, density=True)
    fig_data['angSpeed_distribution'][pop,n_ind].append(h)
for k,v in fig_data['angSpeed_distribution'].items():
    fig_data['angSpeed_distribution'][k] = np.array(v)
fig_data['angSpeed_distribution']['bin_centers'] = (bins[1:]+bins[:-1])/2

# Close the output file deterministically and make sure the folder exists.
os.makedirs('data', exist_ok=True)
with open('data/figure4-data.pik','wb') as fh:
    pickle.dump(fig_data, fh)

# Figures 5 & S5

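Joint distribution of pair distance and pair heading difference, plus the distribution and mean cosine of the heading difference at close range (pair distance below `d0`).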

In [None]:
from importlib import reload
# Development helper: pick up edits to trilabtracker without a kernel restart.
for m in tt.__all__:
    reload(getattr(tt, m))
reload(tt)

# Outer dict keyed by quantity name, inner dicts keyed by (pop, n_ind).
fig_data    = defaultdict(lambda: defaultdict(list))
cut_ranges  = cut_ranges0
bins_d      = np.arange(0,2*R_cm+1,1)
bins_a      = np.linspace(0,180,112)
d0          = 10 # distance threshold for close-range
for trial_file in tqdm(trials.index):
    # Pairwise quantities need at least two fish.
    if trials.loc[trial_file,'n_ind'] == 1:
        continue
    trial = load_trial(trial_file, cut_ranges=cut_ranges)
    # NOTE(review): injects the trial's keys (pos, pop, n_ind, ...) as globals.
    globals().update(trial)

    # All unordered pairs (j1 < j2) of individuals.
    J1,J2 = np.triu_indices(n_ind,1)
    d     = np.hypot(pos[:,J1,0]-pos[:,J2,0],pos[:,J1,1]-pos[:,J2,1]).flatten()
    a     = (pos[:,J1,2]-pos[:,J2,2]).flatten()
    a     = a - 2*np.pi*np.rint(a/(2*np.pi)) # wrap to [-pi, pi]
    # Drop missing data and pairs at (near-)zero distance.
    I     = np.isfinite(d) & np.isfinite(a) & (d>1e-6)
    d     = d[I]
    a     = np.absolute(a[I])*180/np.pi # absolute angle difference in degrees, [0, 180]
    # 2D histogram of pairwise distance and angle.
    distAng_heatmap = np.histogram2d(d, a, bins=(bins_d,bins_a), density=True)[0]
    # Distribution and mean cosine of angle at close range
    # (a is already non-negative, so no further abs is needed).
    a_    = a[d<d0]
    ang_distribution = np.histogram(a_, bins=bins_a, density=True)[0]
    polar = np.nanmean(np.cos(a_*np.pi/180))

    fig_data['distAng_heatmap'][pop,n_ind].append(distAng_heatmap)
    fig_data['ang_distribution'][pop,n_ind].append(ang_distribution)
    fig_data['polar'][pop,n_ind].append(polar)

# Heatmaps are averaged over trials; the other quantities are kept per trial.
for k,v in fig_data['distAng_heatmap'].items():
    fig_data['distAng_heatmap'][k] = np.mean(v, axis=0)
for q in ['ang_distribution','polar']:
    for k,v in fig_data[q].items():
        fig_data[q][k] = np.array(v)
fig_data['bins_d'] = bins_d
fig_data['bins_a'] = bins_a
# The outer defaultdict's lambda factory cannot be pickled, so convert it to a
# plain dict (the inner defaultdict(list) values pickle fine).
fig_data           = dict(fig_data)

# Close the output file deterministically and make sure the folder exists.
os.makedirs('data', exist_ok=True)
with open('data/figure5-data.pik','wb') as fh:
    pickle.dump(fig_data, fh)