In [None]:
import os, sys, datetime, re, pickle, warnings
from os.path import join
from glob import glob
import numpy as np
from scipy.signal import find_peaks, savgol_filter
from scipy.stats import shapiro, ttest_1samp, wilcoxon
from scipy.stats import norm as normal
from statsmodels.api import qqplot
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, CenteredNorm, SymLogNorm
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import itertools as itt
from tqdm import tqdm
from numba import jit, njit

sys.path.append('trilab-tracker-0.1.2')
from trilabtracker import preprocess_trial

# Make sure the folders for cached figure data and rendered figures exist.
for folder in ('figure-assets', 'figures'):
    if not os.path.exists(folder):
        os.mkdir(folder)

# Notebook-wide matplotlib defaults.
plt.rcParams.update({'figure.figsize': (9,6), 'figure.facecolor': 'w'})
dpi = 150

# Prepare trial data.

Prepare a dictionary of trials to analyze with basic info for each (path to trial file, population, age, number of individuals, etc).

In [None]:
# Tank diameter for each supported age; parse_trial_file halves it into R_cm.
tank_diameter_vs_age = dict(zip([  7,   14,   21,   28,   42,   56,   70,   84],
                                [9.6, 10.4, 12.8, 17.7, 24.2, 33.8, 33.8, 33.8]))

# Name correction atlas: maps a video name to the corrected folder name.
naming_atlas = pd.read_csv('data/naming_atlas.csv', index_col='video-name')

# Quality-check verdict of every video, as a plain {video: verdict} dictionary.
tracking_progress = pd.read_excel('data/tracking_progress.xlsx', usecols=['video','qual_check'])
valid = tracking_progress.set_index('video')['qual_check'].to_dict()

# Extract trial metadata from the trial's filename.
def parse_trial_file(trial_file):
    """Extract trial metadata from the path of a trial file.

    The name of the directory containing `trial_file` is translated through
    `naming_atlas` and the corrected name is split on underscores into
    population, day, age, group and number of individuals.

    Returns a dictionary with keys 'trial_file', 'trial_dir', 'trial_name',
    'pop', 'age', 'group', 'n_ind' and 'R_cm', or an empty dictionary if the
    corrected name cannot be parsed. A directory name missing from
    `naming_atlas` still raises, since that lookup happens before parsing.
    """
    trial_dir  = os.path.dirname(trial_file)
    trial_name = os.path.basename(trial_dir)
    trial_name = naming_atlas.loc[trial_name,'folder-name']
    try:
        pop,_day,age,group,n_ind = trial_name.split('_')[:5]
        pop   = {'sf':'SF', 'pa':'Pa', 'rc':'RC'}[pop.lower()]
        # Drop the last three characters of the age field (presumably a 'dpf' suffix).
        age   = int(age[:-3])
        # Ages 43 and 71 are folded into the 42 and 70 groups.
        age   = {43:42, 71:70}.get(age, age)
        n_ind = int(re.findall(r'\d+', n_ind)[0])
        R_cm  = tank_diameter_vs_age[age]/2
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
        # Unparsable name (too few fields, unknown population or age, no digits,
        # or a non-string atlas entry): signal it with an empty dictionary.
        return {}
    return { 'trial_file':trial_file, 'trial_dir':trial_dir, 'trial_name':trial_name,
             'pop':pop, 'age':age, 'group':group, 'n_ind':n_ind, 'R_cm':R_cm }

def apply_smoothing(trial, n=5, method='savgol'):
    """Smooth the time series stored in `trial` in place along the first axis.

    Parameters
    ----------
    trial  : dict that may contain arrays under the keys 'data', 'pos', 'vel',
             'acc', 'v' and 'd_wall'; keys that are absent are skipped.
    n      : smoothing window length; even values are bumped to the next odd one.
    method : only 'savgol' (Savitzky-Golay filter) is supported. Any other
             value emits a warning and leaves `trial` untouched.

    Returns the same `trial` dictionary.
    """
    if method!='savgol':
        warnings.warn(f'Unknown smoothing method: {method}')
        return trial
    n = n-n%2+1 # n must be odd
    for k in ['data','pos','vel','acc','v','d_wall']:
        if k in trial.keys():
            # data contains NaN in the area and aspect ratio columns which crash
            # savgol_filter in Windows, though for some reason not in Ubuntu.
            # Workaround: exclude area and aspect ratio from the smoothing.
            I = (slice(None),slice(None),slice(3)) if k=='data' else slice(None)
            trial[k][I] = savgol_filter(trial[k][I], window_length=n,
                                        polyorder=min(2,n-1), axis=0)
    return trial

def load_trial(trial_file, n_smooth=9, **args):
    """Parse, preprocess and smooth one trial.

    Extra keyword arguments are forwarded to `preprocess_trial`; `n_smooth`
    is the smoothing window handed to `apply_smoothing`.
    """
    metadata  = parse_trial_file(trial_file)
    processed = preprocess_trial(metadata, **args)
    return apply_smoothing(processed, n=n_smooth)

# Candidate trials: every tracked trial file found on disk.
trial_files = sorted(glob('data/tracks/*/trial.pik'))
print()

# Parse the metadata of every candidate, then keep only trials that
#   - could be parsed,
#   - passed the tracking quality check,
#   - are groups of 5,
#   - belong to one of the relevant ages.
parsed = { f:parse_trial_file(f) for f in trial_files }
trials = { f:t for f,t in parsed.items()
           if t!={}
           and valid[os.path.basename(os.path.dirname(f))]=='yes'
           and t['n_ind']==5
           and t['age'] in [7,28,42,70] }

# One row per trial, one column per metadata field.
trials = pd.DataFrame(trials).T.dropna()
grouped_trials = trials.groupby(['pop','age','n_ind'])

# Number of trials of each type: rows are (pop, n_ind), columns are ages.
count = grouped_trials['trial_dir'].count().rename('count').to_frame()
count = count.unstack(1)
count.columns = count.columns.droplevel()
count = count.fillna(0).astype(int)
display(count)

def matching_trials(pop=None, age=None, n_ind=None, df=trials):
    """List the index labels of the rows of `df` matching the given criteria.

    Each of `pop`, `age` and `n_ind` is ignored when left as None; otherwise
    the column of the same name must equal the given value.
    """
    mask = pd.Series(data=True, index=df.index)
    for column,wanted in (('pop',pop), ('age',age), ('n_ind',n_ind)):
        if wanted is not None:
            mask = mask & (df[column]==wanted)
    return df[mask].index.tolist()

# Compute polar force maps

In [None]:
# Angle and distance bins.
# binsθ holds 9 edges, i.e. 8 equal angular sectors covering [0, 2π].
binsθ = np.linspace(0, 2*np.pi, 9)
# Distance edges: steps of 0.05 below 0.2, steps of 0.1 from 0.2 upward.
# Distances are divided by the tank radius before being binned.
binsD = np.concatenate([np.arange(0,0.2,0.05), np.arange(0.2,2.1,0.1)])
# nθ and nD count bin EDGES, so histograms allocated with these sizes have one
# more slot per axis than there are bins; that last slot is dropped with
# [:-1,:-1] before averaging or plotting.
nθ    = len(binsθ)
nD    = len(binsD)
print(binsD)

# Where to save/load figure data.
# The histogram cells only recompute when the corresponding file is missing.
real_data_file = 'figure-assets/fig345_real.pkl'
mock_data_file = 'figure-assets/fig345_mock.pkl'

def load_polar_maps(real_data_file=real_data_file, mock_data_file=mock_data_file):
    """Load the cached polar maps of the actual and the mock groups.

    Returns a tuple (polar_maps, polar_maps_mock) unpickled from
    `real_data_file` and `mock_data_file` respectively.
    """
    loaded = []
    for path in (real_data_file, mock_data_file):
        with open(path, 'rb') as f:
            loaded.append(pickle.load(f))
    return tuple(loaded)

# Distance ranges for attractive forces and aligning torque, given as ranges of
# distance-bin indices (they select columns of the (angle, distance) maps).
Id1   = range(3,6) # attraction zone
Id2   = range(1,6) # alignment zone

## Actual groups

In [None]:
if not os.path.exists(real_data_file):
    
    @njit
    def compute_3d_histogram(d, θ, φ, fx, fy, dω, binsθ, binsD, N, Fx, Fy, dΩ):
        # Accumulate samples into the (θ, φ, d) histograms N, Fx, Fy and dΩ in place.
        # Samples whose fx, fy or dω is not finite are skipped. If a bin index
        # falls outside the histogram, a message is printed and the function
        # returns immediately, leaving the remaining samples unprocessed.
        for i in range(len(d)):
            if np.isfinite(fx[i]) and np.isfinite(fy[i]) and np.isfinite(dω[i]):
                kθ = np.digitize([θ[i]], binsθ).item() - 1
                kφ = np.digitize([φ[i]], binsθ).item() - 1
                kd = np.digitize([d[i]], binsD).item() - 1
                if kθ<0 or kφ<0 or kd<0 or kθ>=N.shape[0] or kφ>=N.shape[1] or kd>=N.shape[2]:
                    print('Histogram index out of bounds', N.shape, kθ, kφ, kd)
                    return
                N[kθ,kφ,kd]  += 1
                Fy[kθ,kφ,kd] += fy[i]
                Fx[kθ,kφ,kd] += fx[i]
                dΩ[kθ,kφ,kd] += dω[i]
        return

    # Weight of each (angle, distance) bin: Δangle × Δ(d²). It is the same for
    # every trial, so it is computed once.
    A = (binsθ[1:,None]-binsθ[:-1,None])*(binsD[None,1:]**2-binsD[None,:-1]**2)

    polar_maps = {}
    for pop,age,n_ind in tqdm(grouped_trials.groups):
        group_files = grouped_trials.groups[pop,age,n_ind]
        polar_maps[pop,age,n_ind] = dict(N=[], Fx=[], Fy=[], dΩ=[], Fx_avg=[], Fy_avg=[], dΩ_avg=[])
        v0 = [0,0]
        for trial_file in group_files:
            trial = load_trial(trial_file)
            # Read the needed arrays explicitly instead of dumping the whole trial
            # dictionary into the notebook namespace, which would overwrite the
            # loop variables pop, age and n_ind.
            pos,acc,v = trial['pos'],trial['acc'],trial['v']
            R_cm      = trial['R_cm']
            
            # Prepare to compute mean speed over entire age and population group.
            v0[0] += np.sum(np.isfinite(v))
            v0[1] += np.nansum(v)
            
            # Pairs of fish. Order matters: fish 1 is focal fish; fish 2 is neighbor.
            J     = np.indices((n_ind,n_ind)).reshape(2,-1)
            J1,J2 = J[:,J[0]!=J[1]]
            # Compute kinematic quantities for every pair.
            u     = np.cos(pos[:,J1,2]),np.sin(pos[:,J1,2]) # Director of fish 1.
            xy    = pos[:,J2,:2] - pos[:,J1,:2]             # Position of fish 2 relative to fish 1.
            u,xy  = np.reshape(u,(2,-1)), np.reshape(xy,(-1,2)).T # Merge time and fish axes.
            d     = np.hypot(*xy)
            θ     = np.arctan2(xy[1],xy[0]) - pos[:,J1,2].flatten() # Polar angle of fish 2 in Frenet basis of fish 1.
            θ     = θ % (2*np.pi)
            φ     = (pos[:,J2,2] - pos[:,J1,2]).flatten()   # Relative orientation.
            φ     = φ % (2*np.pi)
            dω    = acc[:,J1,2].flatten()
            f     = acc[:,J1,:2].reshape((-1,2)).T          # Acceleration.
            # Normalize by tank size.
            d    /= R_cm
            # Acceleration in coordinate system based on director of focal fish.
            fy    = u[0]*f[0] + u[1]*f[1]  # Speeding force (acceleration of focal fish along itself).
            fx    = u[1]*f[0] - u[0]*f[1]  # Turning force (acceleration of focal fish perpendicular to itself).
            
            # Bin force and torque data.
            N  = np.zeros((nθ,nθ,nD)) # Density.
            Fx = np.zeros((nθ,nθ,nD)) # Turning force.
            Fy = np.zeros((nθ,nθ,nD)) # Speeding force.
            dΩ = np.zeros((nθ,nθ,nD)) # Angular acceleration.
            compute_3d_histogram(d, θ, φ, fx, fy, dω, binsθ, binsD, N, Fx, Fy, dΩ)
            
            # Compute average attractive forces and aligning torque.
            # Integrate over irrelevant angle (φ for Fx and Fy, θ for dΩ), then average within each bin.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=RuntimeWarning)
                Fx_avg = np.nansum(Fx, axis=1)[:-1,:-1]/np.nansum(N, axis=1)[:-1,:-1]
                Fy_avg = np.nansum(Fy, axis=1)[:-1,:-1]/np.nansum(N, axis=1)[:-1,:-1]
                dΩ_avg = np.nansum(dΩ, axis=0)[:-1,:-1]/np.nansum(N, axis=0)[:-1,:-1]
            # Flip signs to capture attraction or alignment.
            Fx_avg[:4]  *= -1 # Invert left side => positive means attraction.
            Fy_avg[2:6] *= -1 # Invert rear side => positive means attraction.
            dΩ_avg[4:]  *= -1 # Invert right side => positive means alignment.
            # Average over attraction or alignment zone.
            # The average is weighted by area in (θ,d) or (φ,d) space, i.e., all configurations in
            # the relevant zone are treated equally, independently of their likeliness.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=RuntimeWarning)
                Fx_avg = np.nansum((Fx_avg*A)[:,Id1])/np.sum(A[:,Id1])
                Fy_avg = np.nansum((Fy_avg*A)[:,Id1])/np.sum(A[:,Id1])
                dΩ_avg = np.nansum((dΩ_avg*A)[:,Id2])/np.sum(A[:,Id2])
            tmp = dict( N=N, Fx=Fx, Fy=Fy, dΩ=dΩ, Fx_avg=Fx_avg, Fy_avg=Fy_avg, dΩ_avg=dΩ_avg )
            for k in tmp:
                polar_maps[pop,age,n_ind][k].append(tmp[k].copy())

        # R is the tank radius of the last trial of the group; v0 is the mean
        # of all finite speed samples of the group.
        polar_maps[pop,age,n_ind].update(binsD=binsD, binsθ=binsθ, R=R_cm, v0=v0[1]/v0[0])

    with open(real_data_file, 'wb') as f:
        pickle.dump(polar_maps, f)

## Proof that mock pairs are equivalent to mock groups

In [None]:
# n_ind = 5 # number of fish per group
# n_trials = 9 # typical number of trials of a given age and population

# # Construct the set of all mock pairs, i.e., pairs of fish from two different trials.
# pairs = []
# # Loop over all possible pairs of distinct trials. Order matters because the focal fish will 
# # be taken from the first trial whereas the neighbor fish will be taken from the second trial.
# for i1,i2 in itt.permutations(range(n_trials),2):
#     # Loop over pairs of fish made up of a fish from trial i1 and a fish from trial i2.
#     for j1,j2 in itt.product(range(n_ind),range(n_ind)):
#         pairs.append(((i1,j1),(i2,j2)))
# print('Does the list contain duplicates?', len(set(pairs))<len(pairs))
# print('Number of unique fish pairs:', len(pairs))

In [None]:
# # Make a list of all possible mock groups, i.e., groups of 5 fish taken 
# # from 5 different trials.
# groups = []
# for I in itt.combinations(range(n_trials),n_ind):
#     for J in itt.product(*[range(n_ind)]*n_ind):
#         groups.append(tuple(zip(I,J)))

# # Make a list of all possible fish pairs by looping over mock groups and 
# # listing all possible fish pairs within that mock group.
# pairs2 = []
# for group in groups:
#     for k1,k2 in itt.permutations(group,2):
#         pairs2.append((k1,k2))

# # print('Number of unique fish pairs:', len(set(pairs2)))
# print('Are these the same pairs as before?', set(pairs2)==set(pairs))

# # This time there are lots of duplicates. To make sure averages over the mock pairs 
# # are equivalent to averages over the pairs from mock groups, we need to make sure 
# # every pair has the same number of duplicates.
# from collections import Counter
# c = Counter(pairs2) # number of instances of each pair
# n = set(c.values()) # distinct numbers of instances
# print('Number of duplicates of each pair:', n)
# # n has only one element, therefore every pair has the same number of duplicates.

## Mock groups

In [None]:
if not os.path.exists(mock_data_file):

    @njit
    def compute_3d_histogram(d, θ, φ, fx, fy, dω, binsθ, binsD, N, Fx, Fy, dΩ):
        # Accumulate samples into the (θ, φ, d) histograms N, Fx, Fy and dΩ in place.
        # Samples whose fx, fy or dω is not finite are skipped. If a bin index
        # falls outside the histogram, a message is printed and the function
        # returns immediately, leaving the remaining samples unprocessed.
        for i in range(len(d)):
            if np.isfinite(fx[i]) and np.isfinite(fy[i]) and np.isfinite(dω[i]):
                kθ = np.digitize([θ[i]], binsθ).item() - 1
                kφ = np.digitize([φ[i]], binsθ).item() - 1
                kd = np.digitize([d[i]], binsD).item() - 1
                if kθ<0 or kφ<0 or kd<0 or kθ>=N.shape[0] or kφ>=N.shape[1] or kd>=N.shape[2]:
                    print('Histogram index out of bounds', N.shape, kθ, kφ, kd)
                    return
                N[kθ,kφ,kd]  += 1
                Fy[kθ,kφ,kd] += fy[i]
                Fx[kθ,kφ,kd] += fx[i]
                dΩ[kθ,kφ,kd] += dω[i]
        return

    # Weight of each (angle, distance) bin: Δangle × Δ(d²). It does not depend
    # on the data, so it is computed once.
    A = (binsθ[1:,None]-binsθ[:-1,None])*(binsD[None,1:]**2-binsD[None,:-1]**2)

    polar_maps = {}
    for pop,age,n_ind in tqdm(grouped_trials.groups):
        # Load every trial of the group. A dedicated name is used so that the
        # `trials` dataframe built earlier in the notebook is not overwritten.
        loaded_trials = []
        v0     = [0,0]
        for trial_file in grouped_trials.groups[pop,age,n_ind]:
            trial = load_trial(trial_file)
            loaded_trials.append(trial)
            v0[0] += np.sum(np.isfinite(trial['v']))
            v0[1] += np.nansum(trial['v'])
        v0     = v0[1]/v0[0] # Mean of all finite speed samples of the group.
        
        N      = np.zeros((nθ,nθ,nD)) # Density.
        Fx     = np.zeros((nθ,nθ,nD)) # Turning force.
        Fy     = np.zeros((nθ,nθ,nD)) # Speeding force.
        dΩ     = np.zeros((nθ,nθ,nD)) # Angular acceleration.
        R_cm   = loaded_trials[0]['R_cm']
        polar_maps[pop,age,n_ind] = dict(binsD=binsD, binsθ=binsθ, R=R_cm, v0=v0,
                                         Fx_avg=[], Fy_avg=[], dΩ_avg=[])

        # Loop over ordered pairs of distinct trials: the focal fish comes from
        # trial i1, the neighbor from trial i2.
        n      = len(loaded_trials)
        for i1,i2 in itt.permutations(range(n),2):
            N_    = np.zeros((nθ,nθ,nD)) # Density.
            Fx_   = np.zeros((nθ,nθ,nD)) # Turning force.
            Fy_   = np.zeros((nθ,nθ,nD)) # Speeding force.
            dΩ_   = np.zeros((nθ,nθ,nD)) # Angular acceleration.
            for j1,j2 in itt.product(range(n_ind),range(n_ind)):
                pos   = loaded_trials[i1]['pos'][:,j1,:3]
                acc   = loaded_trials[i1]['acc'][:,j1,:3]
                pos2  = loaded_trials[i2]['pos'][:,j2,:3]
                # Ignore frames beyond end of shorter trial.
                m     = min(pos.shape[0],pos2.shape[0])
                pos,pos2 = pos[:m],pos2[:m]
                acc      = acc[:m]
                # Compute kinematic quantities.
                xy    = (pos2[:,:2] - pos[:,:2]).T
                u     = np.cos(pos[:,2]),np.sin(pos[:,2]) # Director of fish 1.
                d     = np.hypot(*xy)
                θ     = np.arctan2(xy[1],xy[0]) - pos[:,2]
                θ     = θ % (2*np.pi)
                φ     = pos2[:,2] - pos[:,2]              # Relative orientation.
                φ     = φ % (2*np.pi)
                dω    = acc[:,2]
                f     = acc[:,:2].T      # Acceleration.
                # Normalize by tank size.
                d    /= R_cm
                # Acceleration in coordinate system based on director of focal fish.
                fy    = u[0]*f[0] + u[1]*f[1]  # Speeding force (acceleration of focal fish along itself).
                fx    = u[1]*f[0] - u[0]*f[1]  # Turning force (acceleration of focal fish perpendicular to itself).
                
                compute_3d_histogram(d, θ, φ, fx, fy, dω, binsθ, binsD, N_, Fx_, Fy_, dΩ_)

            # Add this trial pair's histograms to the group totals exactly once.
            # N_, Fx_, Fy_ and dΩ_ accumulate over all fish pairs of the trial
            # pair, so adding them inside the fish-pair loop would count the
            # first fish pairs many times over.
            N    += N_
            Fx   += Fx_
            Fy   += Fy_
            dΩ   += dΩ_

            # Compute average attractive forces and aligning torque.
            # Integrate over irrelevant angle (φ for Fx and Fy, θ for dΩ), then average within each bin.
            # NOTE(review): these averages use the running group totals (Fx, Fy,
            # dΩ, N), not this trial pair's own histograms, unlike the per-trial
            # averages computed for the actual groups. Confirm this is intended.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=RuntimeWarning)
                Fx_  = np.nansum(Fx, axis=1)[:-1,:-1]/np.nansum(N, axis=1)[:-1,:-1]
                Fy_  = np.nansum(Fy, axis=1)[:-1,:-1]/np.nansum(N, axis=1)[:-1,:-1]
                dΩ_  = np.nansum(dΩ, axis=0)[:-1,:-1]/np.nansum(N, axis=0)[:-1,:-1]
            # Flip signs to capture attraction or alignment.
            Fx_[:4]  *= -1 # Invert left side => positive means attraction.
            Fy_[2:6] *= -1 # Invert rear side => positive means attraction.
            dΩ_[4:]  *= -1 # Invert right side => positive means alignment.
            # Average over attraction or alignment zone.
            # The average is weighted by area in (θ,d) or (φ,d) space, i.e., all configurations in
            # the relevant zone are treated equally, independently of their likeliness.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=RuntimeWarning)
                Fx_ = np.nansum((Fx_*A)[:,Id1])/np.sum(A[:,Id1])
                Fy_ = np.nansum((Fy_*A)[:,Id1])/np.sum(A[:,Id1])
                dΩ_ = np.nansum((dΩ_*A)[:,Id2])/np.sum(A[:,Id2])
            polar_maps[pop,age,n_ind]['Fx_avg'].append(Fx_)
            polar_maps[pop,age,n_ind]['Fy_avg'].append(Fy_)
            polar_maps[pop,age,n_ind]['dΩ_avg'].append(dΩ_)
                
        polar_maps[pop,age,n_ind].update(N=N, Fx=Fx, Fy=Fy, dΩ=dΩ)
        
    with open(mock_data_file, 'wb') as f:
        pickle.dump(polar_maps, f)

# Make figures

In [None]:
# Display names of the binned quantities, keyed by the short codes used in the code.
quantity_name = { 'N':'Density', 'Fr':'Attraction Force', 'Fy':'Speeding Force',
                  'Fx':'Turning Force', 'Px':'Right Turn Probability',
                  'Ω':'Angular Velocity', 'dΩ':'Angular Acceleration'}
# Display names of the populations, keyed by the codes produced by parse_trial_file.
# 'RC' has no entry here.
population_name = { 'SF':'Surface', 'Pa':'Cave' }

def plot_polar_heatmap(H_, ax=None, n=10, binsθ=binsθ, binsD=binsD, 
                       cmap=plt.get_cmap('seismic'), norm=None):
    """Draw a 2D (angle, distance) array as a heat map on a polar axis.

    Every angular sector is split into `n` identical sub-sectors before
    plotting. The angular zero is placed at the top of the plot and all ticks
    are removed. Returns the mesh created by `pcolormesh`.
    """
    ax = plt.gca() if ax is None else ax
    # Angular edges refined n-fold, with matching row-wise repetition of the values.
    n_sectors   = len(binsθ)-1
    fine_edges  = np.linspace(binsθ[0], binsθ[-1], n*n_sectors+1)
    θ_grid,D_grid = np.meshgrid(fine_edges, binsD, indexing='ij')
    stacked     = np.stack([H_]*n, axis=1)
    H_fine      = stacked.reshape((-1,H_.shape[1]))
    ax.grid(False)
    ax.set_theta_zero_location('N')
    mesh = ax.pcolormesh(θ_grid, D_grid, H_fine, cmap=cmap, norm=norm)
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
    return mesh

# Colors used for each population: 'face' and 'edge' style the violin bodies
# drawn by plot_violin; 'lines' is used for points and mean/median lines.
pop_colors = {'SF': {'face':'pink', 'edge':'palevioletred', 'lines':'mediumvioletred'},
              'Pa': {'face':'thistle', 'edge':'mediumorchid', 'lines':'purple'}}


def plot_violin(ax, q, polar_maps, polar_maps_mock, pop, n_ind, ages, 
                average_wrapper, mode, quantity_name=quantity_name):
    """Draw one violin + strip plot per age for quantity `q` of population `pop`.

    `average_wrapper(maps_entry, q)` supplies the values to plot. `mode`
    selects what is shown: 'real' (actual groups), 'mock' (mock groups) or
    'difference' (actual values minus the first mock value).
    Returns the plotted values as a dictionary keyed by age.
    Relies on the notebook-level `fontsize` and `pop_colors`.
    """
    def group_values(maps, age):
        return average_wrapper(maps[pop,age,n_ind], q)

    if mode=='difference':
        dQ = {age:np.array(group_values(polar_maps,age))-group_values(polar_maps_mock,age)[0]
              for age in ages}
    elif mode=='real':
        dQ = {age:np.array(group_values(polar_maps,age)) for age in ages}
    elif mode=='mock':
        dQ = {age:np.array(group_values(polar_maps_mock,age)) for age in ages}

    # Long-format table: one row per plotted value.
    rows   = [(f'{age}dpf',value) for age in dQ for value in dQ[age]]
    df     = pd.DataFrame(rows, columns=['Age',q])
    shared = dict(x='Age', y=q, data=df, ax=ax)
    colors = pop_colors[pop]

    violin = sns.violinplot(inner=None, color=colors['face'], linewidth=1, **shared)
    for artist in violin.axes.collections:
        artist.set_edgecolor(colors['edge'])
    sns.stripplot(size=3, jitter=0.1, color=colors['lines'], **shared)

    ax.axhline(0, color='k', lw=0.5)
    ax.set_xlabel(None)
    ax.set_ylabel(quantity_name[q], fontsize=fontsize)
    ax.locator_params(axis='y', nbins=4)
    return dQ

def average(Q, N, q, v0, binsD, binsθ, Id, Iθ):
    """Area-weighted average of Q/N over a zone of the (angle, distance) plane.

    Parameters
    ----------
    Q     : 2D array (angle bins × distance bins) of the summed quantity.
    N     : 2D array of sample counts with the same shape as Q.
    q     : code of the quantity; selects the sign convention and normalization.
    v0    : speed used to normalize the forces 'Fx', 'Fy' and 'Fr'.
    binsD, binsθ : distance and angle bin edges.
    Id, Iθ : indices of the distance and angle bins to average over.

    Neither Q nor N is modified. Returns a scalar.
    """
    # Work on float copies so the caller's arrays stay untouched and integer
    # input does not break the in-place division below.
    N,Q = np.array(N, dtype=float),np.array(Q, dtype=float)
    if q in ['Fy']:
        Q[2:6]  *= -1 # Invert rear side => positive means attraction.
    if q in ['Fx']:
        Q[:4]  *= -1 # Invert left side => positive means attraction.
    if q in ['Px']:
        Q[:4]  = 1-Q[:4] # Invert left side => positive means attraction.
    if q in ['Fy','Fx','Fr']:
        Q /= v0
    if q in ['Ω','dΩ']:
        Q[4:]  *= -1 # Invert right side => positive means alignment.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        # Weight of each bin: Δangle × Δ(d²).
        A  = (binsθ[1:,None]-binsθ[:-1,None])*(binsD[None,1:]**2-binsD[None,:-1]**2)
        Q_ = np.nansum((Q*A/N)[Iθ,Id])/np.sum(A[Iθ,Id])
    return Q_

def significant_symbol(p):
    """Turn a p-value into significance stars.

    '***' below 0.001, '**' below 0.01, '*' below 0.05, '' otherwise.
    """
    star = '*' # '$\star\!$'
    for threshold,count in ((0.001,3), (0.01,2), (0.05,1)):
        if p<threshold:
            return star*count
    return ''

def savefig(fn):
    """Save the current matplotlib figure as figures/fig345_<fn>.png at 200 dpi."""
    target = f'figures/fig345_{fn}.png'
    plt.savefig(target, dpi=200)

## (d,θ) maps

In [None]:
# Draw one grid of polar heat maps per mode: actual groups ('real'), mock
# groups ('mock') and their difference. Rows are (quantity, population) pairs,
# columns are ages. Each grid is saved as heatmaps_<mode>.png.
for mode in ['real', 'mock', 'difference']:
    
    # Reload the cached maps for every mode: the loop below overwrites their
    # entries in place.
    polar_maps,polar_maps_mock = load_polar_maps()
    
    quantities  = ['N', 'Fy', 'Fx'] # , 'dΩ'
    populations = ['SF','Pa']
    ages        = [7,28,42,70]
    n_ind       = 5

    # Sum over real trials. Sum over the relative orientation angle φ.
    # The real maps hold a list of per-trial 3D histograms (hence the extra
    # leading axis); the mock maps hold a single 3D histogram per group.
    # The last row and column are then dropped: the histograms were allocated
    # with one slot per bin edge, i.e. one more than the number of bins.
    for k in polar_maps.keys():
        for q in quantities:
            polar_maps[k][q] = np.nansum(polar_maps[k][q], axis=(0,2))
            polar_maps_mock[k][q] = np.nansum(polar_maps_mock[k][q], axis=1)
            polar_maps[k][q] = polar_maps[k][q][:-1,:-1]
            polar_maps_mock[k][q] = polar_maps_mock[k][q][:-1,:-1]

    # Symmetric color scale for each quantity.
    norm = { 'N' : CenteredNorm(halfrange=1), 
             'Fy': CenteredNorm(halfrange=3), 
             'Fx': CenteredNorm(halfrange=3), 
             'Ω' : CenteredNorm(halfrange=0.25),
             'dΩ': CenteredNorm(halfrange=2) }
    if mode in ['real','mock']:
        norm['Px'] = CenteredNorm(vcenter=0.5, halfrange=0.25)

    nx,ny   = len(ages),len(quantities)*len(populations)
    gsk     = {'bottom':0.00, 'top':0.97, 'left':0.15, 'right':0.9,
               'hspace':0.05, 'wspace':0.05}
    fig,axs = plt.subplots(ny, nx, figsize=(2*nx,2*ny*0.85), 
                           subplot_kw={'projection':'polar'},
                           gridspec_kw=gsk, squeeze=False)
    # Lower every other row (the first population of each quantity) so that
    # the rows of a given quantity sit closer together.
    for ax in axs[::2].flatten():
        bounds = list(ax.get_position().bounds)
        bounds[1] -= 0.09/ny
        ax.set_position(bounds)

    fontsize = 18

    for i_q,q in enumerate(quantities):
        for i_p,pop in enumerate(populations):
            i_y = i_q*len(populations)+i_p # Row of this (quantity, population) pair.
            for i_a,age in enumerate(ages):
                pm  = polar_maps[pop,age,n_ind]
                pmm = polar_maps_mock[pop,age,n_ind]
                binsθ,binsD = pm['binsθ'],pm['binsD']
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    if q=='N':
                        # Density: fraction of all samples falling in each bin.
                        H1,H2 = pm['N']/np.sum(pm['N']), pmm['N']/np.sum(pmm['N'])
                    else:
                        # Other quantities: mean value per sample in each bin.
                        H1,H2 = pm[q]/pm['N'], pmm[q]/pmm['N']
                    H  = {'real':H1, 'mock':H2, 'difference':H1-H2}[mode]
                    if q in ['Fy','Fx','Fr']:
                        # Forces are normalized by the mean speed stored with the
                        # mock maps, whatever the mode.
                        H /= pmm['v0']
                    if q=='N':
                        # Divide the density by the bin weight Δθ × Δ(d²).
                        H /= (binsθ[1:,None]-binsθ[:-1,None])*(binsD[None,1:]**2-binsD[None,:-1]**2)

                ax   = axs[i_y,i_a]
                mesh = plot_polar_heatmap(H, ax=ax, n=10, norm=norm[q])
                ax.set_ylim(0,0.5)

                if i_y==0:
                    ax.set_title(f'{age} dpf', fontsize=fontsize, y=1.1)

            if i_p==0:
                # One color bar per quantity, placed to the right of the last
                # column and spanning the two population rows. It is built from
                # the last mesh drawn in the first population's row.
                x,y1,dx,dy = axs[i_q*len(populations),-1].get_position().bounds
                x,y2,dx,dy = axs[i_q*len(populations)+1,-1].get_position().bounds
                cax  = fig.add_axes([x+1.15*dx,(y1+dy+y2)/2-0.8*dy,0.1*dx,1.6*dy])
                fig.colorbar(mesh, cax=cax)
                cax.locator_params(axis='y', nbins=3)

            axs[i_y,0].set_ylabel(population_name[pop], fontsize=fontsize, rotation=90, labelpad=20, va='center')

        # Quantity label, anchored to the second row of the quantity (this
        # indexing assumes at least two populations).
        ax = axs[i_q*len(populations)+1,0]
        ax.text(-0.6, 1, quantity_name[q], transform=ax.transAxes, fontsize=fontsize, rotation=90, va='center')


    savefig(f'heatmaps_{mode}')
    plt.show()

### Polar grid showing focal fish. Attraction zone. Alignment zone.

In [None]:
# Empty polar grid with a picture of the focal fish at its center.
fig,ax = plt.subplots(1, 1, figsize=(3,3), subplot_kw={'projection':'polar'}, 
                      gridspec_kw={'bottom':0.15, 'top':0.85, 'left':0.15, 'right':0.85})

ax.set_theta_zero_location('N')
# One ring per distance bin edge; only multiples of 0.1 get a label.
# The multiple-of-0.1 test compares 10*d to the nearest integer because the
# floating-point remainder d%0.1 is unreliable (e.g. 0.5%0.1 is about 0.1, not 0).
ticks  = binsD[1:]
labels = [f'{d:.1g}' if np.isclose(10*d, np.round(10*d)) else '' for d in ticks]
ax.set_yticks(ticks, labels, x=-0.4, ha='center', va='center')
ax.set_ylim(0,0.5)

# Overlay the fish picture; pixels equal to 1 are masked out.
img = plt.imread('data/fish.png')
img = np.ma.masked_where(img==1, img)
ax2 = fig.add_axes([0.43,0.39,0.15,0.15])
ax2.imshow(img, alpha=0.7, cmap=plt.get_cmap('gray'), vmin=0, vmax=1)
ax2.axis('off')

savefig('heatmap-grid')
plt.show()

In [None]:
# Region of (θ,d) space used to compute the scalar metric (attraction zone).
# d1 and d2 are the inner and outer distance edges of the zone.
d1,d2 = binsD[Id1[0]],binsD[Id1[-1]+1]
print(d1,d2)
fig,ax = plt.subplots(1, 1, figsize=(3,3), subplot_kw={'projection':'polar'}, 
                      gridspec_kw={'bottom':0.15, 'top':0.85, 'left':0.15, 'right':0.85})
# NOTE(review): a bar is drawn and removed right away; presumably a workaround
# needed before adding the patch below. Confirm before removing.
ax.bar(0, 1).remove()
ax.set_theta_zero_location('N')
# One ring per distance bin edge; only multiples of 0.1 get a label.
# The multiple-of-0.1 test compares 10*d to the nearest integer because the
# floating-point remainder d%0.1 is unreliable (e.g. 0.5%0.1 is about 0.1, not 0).
ticks  = binsD[1:]
labels = [f'{d:.1g}' if np.isclose(10*d, np.round(10*d)) else '' for d in ticks]
ax.set_yticks(ticks, labels, x=-0.4, ha='center', va='center')
ax.set_ylim(0,0.5)
# Shade the full ring between d1 and d2.
ax.add_patch(Rectangle((0,d1), width=2*np.pi, height=d2-d1, color='C0'))
savefig('force-average')
plt.show()

In [None]:
# Region of (φ,d) space used to compute the scalar metric (alignment zone).
# d1 and d2 are the inner and outer distance edges of the zone.
d1,d2 = binsD[Id2[0]],binsD[Id2[-1]+1]
print(d1,d2)
fig,ax = plt.subplots(1, 1, figsize=(3,3), subplot_kw={'projection':'polar'}, 
                      gridspec_kw={'bottom':0.15, 'top':0.85, 'left':0.15, 'right':0.85})
# NOTE(review): a bar is drawn and removed right away; presumably a workaround
# needed before adding the patch below. Confirm before removing.
ax.bar(0, 1).remove()
ax.set_theta_zero_location('N')
# One ring per distance bin edge; only multiples of 0.1 get a label.
# The multiple-of-0.1 test compares 10*d to the nearest integer because the
# floating-point remainder d%0.1 is unreliable (e.g. 0.5%0.1 is about 0.1, not 0).
ticks  = binsD[1:]
labels = [f'{d:.1g}' if np.isclose(10*d, np.round(10*d)) else '' for d in ticks]
ax.set_yticks(ticks, labels, x=-0.4, ha='center', va='center')
ax.set_ylim(0,0.5)
# Shade the full ring between d1 and d2.
ax.add_patch(Rectangle((0,d1), width=2*np.pi, height=d2-d1, color='C0'))
savefig('angular-average')
plt.show()

### Scalar metric with statistical analysis

#### Violin plots for figures 4 and 5

In [None]:
# Violin plots of the scalar metrics (actual groups minus the mock baseline),
# one panel per (quantity, population), with a significance test per age.
polar_maps,polar_maps_mock = load_polar_maps()

quantities  = ['Fx','Fy','dΩ']
populations = ['SF','Pa']
ages        = [7,28,42,70]
n_ind       = 5

# Rows are quantities, columns are populations.
nx,ny   = len(populations),len(quantities)
gsk     = {'bottom':0.07, 'top':0.92, 'left':0.15, 'right':0.98,
           'hspace':0.4, 'wspace':0.4}
fig,axs = plt.subplots(ny, nx, figsize=(4*nx,2.5*ny), 
                       gridspec_kw=gsk, squeeze=False)
fontsize = 14
# Axis labels specific to this figure (the notebook-level quantity_name is
# still used for the printed statistics).
quantity_name_ = {'Fy':'Attractive\nSpeeding Force',
                  'Fx':'Attractive\nTurning Force',
                  'dΩ':'Aligning Angular\nAcceleration'}

for i_p,pop in enumerate(populations):
    for i_q,q in enumerate(quantities):
        ax   = axs[i_q,i_p]
        
        # Prepare data.
        # Q1: per-trial values of the actual groups. Q2: values of the mock groups.
        # Q: actual values minus the mean mock value, for each age.
        Q1   = { age:polar_maps[pop,age,n_ind][q+'_avg'] for age in ages }
        Q2   = { age:polar_maps_mock[pop,age,n_ind][q+'_avg'] for age in ages }
        Q = { age:Q1[age]-np.mean(Q2[age]) for age in ages }
        if q in ['Fx','Fy']:
            # Forces are normalized by the mean speed of the group.
            Q = { age:Q[age]/polar_maps_mock[pop,age,n_ind]['v0'] for age in ages }
        
        # Make boxplots.
        # Boxes, caps, whiskers and fliers are all hidden: only the mean
        # (solid) and median (dashed) lines are drawn.
        x = np.arange(len(ages))
        ax.boxplot(Q.values(), positions=x, 
                   showmeans = True, meanline = True, 
                   meanprops = dict(color = pop_colors[pop]['lines'], 
                                    linewidth = 1.2, linestyle = '-'), 
                   zorder = 1, patch_artist = True, showfliers = False, 
                   showbox = False, showcaps = False,  whis = 0, 
                   medianprops = dict(color = pop_colors[pop]['lines'], 
                                      linewidth = 1.2, linestyle = '--'))
        
        # Make violin plots.
        p = ax.violinplot(Q.values(), positions=x, showextrema=False)
        for b in p['bodies']:
            b.set_edgecolor(pop_colors[pop]['lines'])
            b.set_facecolor(pop_colors[pop]['lines'])
        
        # Make scatter plots.
        xy = [(x[i_a],y) for i_a,age in enumerate(ages) for y in Q[age]]
        ax.plot(*zip(*xy), lw=0, marker='o', ms=2, c=pop_colors[pop]['lines'], label='Actual')
        
        
        # Polish plot.
        ax.axhline(0, color='k', lw=0.5, ls='--', zorder=-10)
        ax.set_xlabel(None)
        # The tick positions and x-limits below are hard-coded for four ages.
        ax.set_xticks(range(0,4,1), [f'{a}dpf' for a in ages])
        ax.tick_params('x', length=0, pad=10)
        ax.set_ylabel(quantity_name_[q], fontsize=fontsize)
        ax.locator_params(axis='y', nbins=4)
        ax.set_xlim(-0.5,3.5)
        ylim = [-0.5,2.5] if q=='dΩ' else [-0.2,0.8]
        ax.set_ylim(*ylim)
        if q=='dΩ': ax.yaxis.labelpad += 10
        
        # Statistical analysis.
        # Note: x and p are reused here with new meanings (sample values and
        # Shapiro p-value); both are reassigned at the start of the next panel.
        for i_a,age in enumerate(Q):
            x = Q[age]
            print(f'{pop}, {quantity_name[q]}, {age}dpf') #, end=' ')
            p = shapiro(x).pvalue
            print(f'    [Shapiro (real): n={len(x)}, median={np.median(x):.3g}, p={p:.3g}]') #, end=' ')
            t1 = ttest_1samp(x, 0)
            t2 = wilcoxon(x)
            print(f'    [1 sample T-test: stat={t1.statistic:.3g}, p={t1.pvalue:.3g}]')
            print(f'    [Wilcoxon: stat={t2.statistic:.3g}, p={t2.pvalue:.3g}]')
            # Use the Wilcoxon test when Shapiro rejects normality (p<0.05),
            # the one-sample t-test otherwise.
            t  = t2 if p<0.05 else t1
            # Place the significance stars slightly above the largest value.
            l1  = ax.get_ylim()
            l2  = max(Q[age])
            y   = l2 + 0.05*(l1[1]-l1[0])
            ax.text(i_a, y, significant_symbol(t.pvalue), ha='center', fontsize=fontsize)

# Column titles: one population per column.
for pop,ax in zip(populations,axs[0,:]):
    ax.set_title(population_name[pop], fontsize=fontsize, pad=10) #, weight='bold')

savefig(f'force-violins')
plt.show()

#### Figure S6

In [None]:
# Figure S6: distribution of the '<quantity>_avg' samples for each population
# and age, compared against the mean of the corresponding mock samples.
polar_maps,polar_maps_mock = load_polar_maps()

quantities  = ['Fx','Fy','dΩ']
populations = ['SF','Pa']
ages        = [7,28,42,70]
n_ind       = 5

# One column per population, one row per quantity.
nx,ny   = len(populations),len(quantities)
gsk     = {'bottom':0.07, 'top':0.92, 'left':0.15, 'right':0.98,
           'hspace':0.4, 'wspace':0.4}
fig,axs = plt.subplots(ny, nx, figsize=(4*nx,2.5*ny), 
                       gridspec_kw=gsk, squeeze=False)
fontsize = 14
quantity_name_ = {'Fy':'Attractive\nSpeeding Force',
                  'Fx':'Attractive\nTurning Force',
                  'dΩ':'Aligning Angular\nAcceleration'}

for i_p,pop in enumerate(populations):
    for i_q,q in enumerate(quantities):
        ax   = axs[i_q,i_p]
        
        # Prepare data: Q1 holds the real samples, Q2 the mock samples, keyed by age.
        Q1   = { age:polar_maps[pop,age,n_ind][q+'_avg'] for age in ages }
        Q2   = { age:polar_maps_mock[pop,age,n_ind][q+'_avg'] for age in ages }
        if q in ['Fx','Fy']:
            # Both real and mock forces are scaled by the 'v0' entry of the mock map.
            # NOTE(review): presumably a characteristic speed -- confirm.
            Q1 = { age:Q1[age]/polar_maps_mock[pop,age,n_ind]['v0'] for age in ages }
            Q2 = { age:Q2[age]/polar_maps_mock[pop,age,n_ind]['v0'] for age in ages }
        
        # Build the per-age sample lists explicitly in the order of `ages`, so that
        # they are guaranteed to line up with the x positions below.
        x       = np.arange(len(ages))
        Q1_list = [Q1[age] for age in ages]
        Q2_list = [Q2[age] for age in ages]
        
        # Make boxplots. Box, caps, whiskers and fliers are all disabled, so only
        # the mean (solid) and median (dashed) lines of the real data are drawn.
        ax.boxplot(Q1_list, positions=x, 
                   showmeans = True, meanline = True, 
                   meanprops = dict(color = pop_colors[pop]['lines'], 
                                    linewidth = 1.2, linestyle = '-'), 
                   zorder = 1, patch_artist = True, showfliers = False, 
                   showbox = False, showcaps = False,  whis = 0, 
                   medianprops = dict(color = pop_colors[pop]['lines'], 
                                      linewidth = 1.2, linestyle = '--'))
        # Same trick for the mock data: only the mean is shown (black solid line),
        # the median line is hidden with a zero line width.
        ax.boxplot(Q2_list, positions=x,
                   showmeans = True, meanline = True, 
                   meanprops = dict(color = 'black', 
                                    linestyle = '-', 
                                    linewidth = 1.2), 
                   zorder = 1, patch_artist = True, showfliers = False, 
                   showbox = False, showcaps = False,  whis = 0, 
                   medianprops = dict(lw=0))
        
        # Make violin plots (real data only).
        p = ax.violinplot(Q1_list, positions=x, showextrema=False)
        for b in p['bodies']:
            b.set_edgecolor(pop_colors[pop]['lines'])
            b.set_facecolor(pop_colors[pop]['lines'])
        
        # Make scatter plots: one marker per real sample, at its age's x position.
        xy = [(x[i_a],q1) for i_a,age in enumerate(ages) for q1 in Q1[age]]
        ax.plot(*zip(*xy), lw=0, marker='o', ms=2, c=pop_colors[pop]['lines'], label='Actual')
        
        # Polish plot. Ticks and x-limits follow `ages` instead of a hard-coded count.
        ax.axhline(0, color='k', lw=0.5, ls='--', zorder=-10)
        ax.set_xlabel(None)
        ax.set_xticks(x, [f'{a}dpf' for a in ages])
        ax.tick_params('x', length=0, pad=10)
        ax.set_ylabel(quantity_name_[q], fontsize=fontsize)
        ax.locator_params(axis='y', nbins=4)
        ax.set_xlim(-0.5, len(ages)-0.5)
        ylim = [-0.5,2.5] if q=='dΩ' else [-0.2,0.8]
        ax.set_ylim(*ylim)
        if q=='dΩ': ax.yaxis.labelpad += 10
        
        # Statistical analysis: test whether the real samples differ from the
        # mock mean. A Shapiro-Wilk test on the real samples selects the test:
        # Wilcoxon signed-rank if normality is rejected at 0.05, otherwise a
        # one-sample t-test.
        for i_a,age in enumerate(ages):
            # Convert to arrays so the subtraction below is valid for any sequence type.
            x1,x2 = np.asarray(Q1[age]),np.asarray(Q2[age])
            print(f'{pop}, {quantity_name[q]}, {age}dpf') #, end=' ')
            p1 = shapiro(x1).pvalue
            print(f'    [Real: n={len(x1)}, mean={np.mean(x1):.3g}, median={np.median(x1):.3g}, p_shapiro={p1:.3g}]') #, end=' ')
            print(f'    [Mock: n={len(x2)}, mean={np.mean(x2):.3g}]') #, end=' ')
            if p1<0.05:
                t = wilcoxon(x1-np.mean(x2))
                print(f'    [Wilcoxon: stat={t.statistic:.3g}, p={t.pvalue:.3g}]')
            else:
                t = ttest_1samp(x1, np.mean(x2))
                print(f'    1 sample T-test: stat={t.statistic:.3g}, p={t.pvalue:.3g}')
            # Place the significance marker slightly above the largest sample
            # (real or mock), offset by 5% of the y-axis span.
            l1  = ax.get_ylim()
            l2  = max(max(x1),max(x2))
            y   = l2 + 0.05*(l1[1]-l1[0])
            ax.text(i_a, y, significant_symbol(t.pvalue), ha='center', fontsize=fontsize)

# Title each column with its population's display name (top row only).
for pop,ax in zip(populations,axs[0,:]):
    ax.set_title(population_name[pop], fontsize=fontsize, pad=10) #, weight='bold')

savefig(f'force-violins_supplementary')
plt.show()

#### QQplots

In [None]:
# Q-Q plots of the real samples (offset by the mock mean) against a fitted
# normal distribution, one panel per (quantity, population, age), annotated
# with the p-values of the normality and location tests.
polar_maps,polar_maps_mock = load_polar_maps()

quantities  = ['Fx','Fy','dΩ']
populations = ['SF','Pa']
ages        = [7,28,42,70]
n_ind       = 5

# One column per age, one row per (quantity, population) pair.
nx,ny   = len(ages),len(populations)*len(quantities)
gsk     = {'bottom':0.07, 'top':0.92, 'left':0.15, 'right':0.98,
           'hspace':0.4, 'wspace':0.4}
fig,axs = plt.subplots(ny, nx, figsize=(4*nx,2.5*ny), 
                       gridspec_kw=gsk, squeeze=False)
fontsize = 'x-large'
quantity_name_ = {'Fy':'Attractive\nSpeeding Force',
                  'Fx':'Attractive\nTurning Force',
                  'dΩ':'Aligning Angular\nAcceleration'}

for i_p,pop in enumerate(populations):
    for i_q,q in enumerate(quantities):
        # Rows are grouped by quantity, with the populations interleaved inside
        # each group.
        i_row = i_q*len(populations)+i_p
        
        # Iterate over the ages declared in this cell. The grid width and the
        # column titles are both derived from `ages`, so the panels must be too
        # (the loop previously read its keys from a dict `Q` that this cell
        # never defines).
        for i_a,age in enumerate(ages):
            ax   = axs[i_row,i_a]

            # Prepare data: real samples minus the mean of the mock samples.
            x1   = polar_maps[pop,age,n_ind][q+'_avg']
            x2   = polar_maps_mock[pop,age,n_ind][q+'_avg']
            x    = np.array(x1)-np.mean(x2)
            if q in ['Fx','Fy']:
                # Forces are scaled by the 'v0' entry of the mock map.
                # NOTE(review): presumably a characteristic speed -- confirm.
                x = x/polar_maps_mock[pop,age,n_ind]['v0']
        
            # Make qqplot against a normal distribution fitted to the data.
            c   = 'C0'
            qqplot(x, normal, fit=True, line='45', ax=ax, ms=3, 
                   markerfacecolor=c, markeredgecolor=c)
            
            # Annotate with Shapiro-Wilk (normality), one-sample t-test and
            # Wilcoxon signed-rank p-values, the latter two against zero.
            p0 = shapiro(x).pvalue
            p1 = ttest_1samp(x, 0).pvalue
            p2 = wilcoxon(x).pvalue
            ax.text(0.02, 0.98, 
                    f'p_shapiro={p0:.3f}\np_t-test={p1:.3f}\np_wilcox={p2:.3f}',
                    transform=ax.transAxes, ha='left', va='top')
            
        
        # Row label, written vertically to the left of the first panel of the row.
        ax = axs[i_row,0]
        ax.text(-0.5, 0.5, f'{quantity_name[q]}\n{population_name[pop]}', 
                transform=ax.transAxes, fontsize=fontsize, 
                rotation=90, va='center', ha='center')

# Title each column with its age (top row only).
for age,ax in zip(ages,axs[0,:]):
    ax.set_title(f'{age}dpf', fontsize=fontsize, pad=30)

savefig(f'force-qqplots')
plt.show()