In [None]:
import os, sys, datetime, re, pickle, warnings
from os.path import join
from glob import glob
import numpy as np
from scipy.signal import find_peaks, savgol_filter
from scipy.stats import shapiro, linregress, ttest_1samp, f_oneway, kruskal, \
                        wilcoxon, ttest_ind, mannwhitneyu, kstest, spearmanr
from statsmodels.stats.multicomp import pairwise_tukeyhsd
from statsmodels.stats.multitest import multipletests
# from pingouin.pairwise import pairwise_gameshowell
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, CenteredNorm, SymLogNorm, LinearSegmentedColormap
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import itertools as itt
from tqdm import tqdm
from numba import jit, njit
from copy import deepcopy

sys.path.append('trilab-tracker-0.1.2')
from trilabtracker import preprocess_trial

# Create the output directory for saved figures. makedirs(exist_ok=True) is
# idempotent and avoids the check-then-create race of exists() + mkdir().
os.makedirs('figures', exist_ok=True)

# print(os.getcwd())

# Default figure size (inches) and a white figure background.
plt.rcParams['figure.figsize'] = 9,6
plt.rcParams['figure.facecolor'] = 'w'
# plt.rcParams['figure.dpi'] = 150
# NOTE(review): `dpi` does not appear to be read anywhere in this notebook
# (savefig() passes its own dpi) — confirm before relying on it.
dpi = 150

# Set the random number generator seed so the notebook is fully reproducible.
np.random.seed(1)

# Prepare trial data.

Prepare a dictionary of trials to analyze with basic info for each (path to trial file, population, age, number of individuals, etc).

In [None]:
# Arena diameter for each age (keys are ages in dpf). Units are presumably cm,
# since parse_trial_file stores half of this value as 'R_cm' — TODO confirm.
tank_diameter_vs_age = { 7:9.6, 14:10.4, 21:12.8, 28:17.7, 42:24.2, 56:33.8, 70:33.8, 84:33.8 }

# Load name correction atlas.
# Indexed by 'video-name'; parse_trial_file reads its 'folder-name' column.
naming_atlas = pd.read_csv('naming_atlas.csv', index_col='video-name')
# naming_atlas[naming_atlas]

# Load list of valid trials.
# Only the 'video' and 'qual_check' columns are read; the result is turned
# into a plain dict mapping video -> qual_check value.
valid = pd.read_excel('tracking_progress.xlsx', usecols = ['video','qual_check']).set_index('video')
# print(valid['qual_check'].unique())
# display(valid)
valid = valid['qual_check'].to_dict()

# Extract trial metadata from the trial's filename.
def parse_trial_file(trial_file):
    """Extract trial metadata from the name of the folder holding a trial file.

    The folder name is first translated through `naming_atlas` (column
    'folder-name'); the corrected name is then split on underscores into
    population, day, age, group and number of individuals.

    Parameters
    ----------
    trial_file : str
        Path to a trial file. Only the name of its parent directory is parsed.

    Returns
    -------
    dict
        Keys 'trial_file', 'trial_dir', 'trial_name', 'pop', 'age', 'group',
        'n_ind' and 'R_cm' on success, or an empty dict if the corrected name
        cannot be parsed.

    Raises
    ------
    KeyError
        If the folder name has no entry in `naming_atlas`. This lookup is
        deliberately left outside the best-effort block, as before.
    """
    trial_dir  = os.path.dirname(trial_file)
    trial_name = os.path.basename(trial_dir)
    trial_name = naming_atlas.loc[trial_name,'folder-name']
    try:
        # The "day" field (second token) is not used.
        pop,_,age,group,n_ind = trial_name.split('_')[:5]
        pop   = {'sf':'SF', 'pa':'Pa', 'rc':'RC'}[pop.lower()]
        # Drop the last three characters of the age token before converting
        # (presumably a 'dpf' suffix — TODO confirm).
        age   = int(age[:-3])
        # Ages 43 and 71 are pooled with 42 and 70 respectively.
        age   = 42 if age==43 else (70 if age==71 else age)
        # Raw string: '\d' in a normal string literal is an invalid escape.
        n_ind = int(re.findall(r'\d+',n_ind)[0])
        R_cm  = tank_diameter_vs_age[age]/2
    except Exception:
        # Best effort: unparsable names yield an empty dict, which the caller
        # filters out. `Exception` (rather than a bare except) lets
        # KeyboardInterrupt and SystemExit propagate.
#         print(trial_dir)
        return {}
    # Return parsed data as a dictionary.
    return dict(trial_file=trial_file, trial_dir=trial_dir,
                trial_name=trial_name, pop=pop, age=age, group=group,
                n_ind=n_ind, R_cm=R_cm)

def apply_smoothing(trial, n=5, method='savgol'):
    """Smooth the time series stored in a trial dictionary, in place.

    Parameters
    ----------
    trial : dict
        Trial data. Any of the keys 'data', 'pos', 'vel', 'acc', 'v' and
        'd_wall' that are present are smoothed along axis 0.
    n : int
        Smoothing window length. Even values are bumped to the next odd value.
    method : str
        Smoothing method. Only 'savgol' (Savitzky-Golay filter) is supported.

    Returns
    -------
    dict
        The same `trial` object. For an unknown `method` a warning is issued
        and the trial is returned unchanged.
    """
    if method!='savgol':
        # The original branch called `logging.warning`, but `logging` is never
        # imported, so it raised NameError. `warnings` is already imported.
        warnings.warn(f'Unknown smoothing method: {method}')
        return trial
    n = n-n%2+1 # n must be odd
    for k in ['data','pos','vel','acc','v','d_wall']:
        if k in trial:
            # data contains NaN in the area and aspect ratio columns which crash
            # savgol_filter in Windows, though for some reason not in Ubuntu.
            # Workaround: exclude area and aspect ratio from the smoothing
            # by only touching the first three entries of the last axis.
            I = (slice(None),slice(None),slice(3)) if k=='data' else slice(None)
            trial[k][I] = savgol_filter(trial[k][I], window_length=n,
                                        polyorder=min(2,n-1), axis=0)
    return trial

def load_trial(trial_file, n_smooth=9, **args):
    """Parse, preprocess and smooth one trial.

    `trial_file` is parsed with parse_trial_file, the resulting dictionary is
    handed to preprocess_trial together with any extra keyword arguments, and
    the preprocessed trial is smoothed with a window of `n_smooth` frames.
    Returns the trial dictionary produced by apply_smoothing.
    """
    parsed    = parse_trial_file(trial_file)
    processed = preprocess_trial(parsed, **args)
    return apply_smoothing(processed, n=n_smooth)

# Select a set of trials to analyze.
# One trial file is expected per sub-folder of 'tracks'.
trial_files = sorted(glob('tracks/*/trial.pik'))
# print(trial_files)
print()

# Create a dataframe with basic trial info.
# Restrict to relevant trials.
# Count trials of each type.
# The filters below are applied one after the other; each step only sees the
# trials that survived the previous step.
trials = { f:parse_trial_file(f) for f in trial_files }
trials = { f:t for f,t in trials.items() if t!={} } # Remove unparsable trials.
# The lookup key is the raw folder name of the trial (not the atlas-corrected
# name). A folder that is absent from `valid` raises KeyError here.
trials = { f:t for f,t in trials.items() if 
           valid[os.path.basename(os.path.split(f)[0])]=='yes' } # Remove mistracked trials.
trials = { f:t for f,t in trials.items() if t['n_ind']==5 } # Restrict to groups of 5.
trials = { f:t for f,t in trials.items() if t['age'] in [7,28,42,70] } # Restrict to relevant ages.
# trials = { f:t for f,t in trials.items() if t['pop'] in ['SF','Pa'] } # Restrict to relevant populations.
# One row per trial file, one column per metadata key.
trials = pd.DataFrame(trials).T #, index=trial_files)
trials = trials.dropna()
# Group trials by (population, age, group size); reused by later cells.
grouped_trials = trials.groupby(['pop','age','n_ind'])
# Table of trial counts: the second index level (age) becomes the columns,
# and combinations with no trials are shown as 0.
count  = pd.DataFrame(grouped_trials['trial_dir'].count().rename('count'))
count = count.unstack(1)
count.columns = count.columns.droplevel()
count[pd.isna(count)] = 0
count = count.astype(int)
# print('Number of trials of each type:')
display(count)

def matching_trials(pop=None, age=None, n_ind=None, df=trials):
    """List the index labels of the trials in `df` that match the given criteria.

    Each of `pop`, `age` and `n_ind` is only applied when it is not None; with
    no criteria every index label of `df` is returned. Note that the default
    `df` is the `trials` table as it existed when this function was defined.
    """
    mask = pd.Series(data=True, index=df.index)
    for column,wanted in (('pop',pop), ('age',age), ('n_ind',n_ind)):
        if wanted is not None:
            mask = mask & (df[column]==wanted)
    return df[mask].index.tolist()

In [None]:
# Body length spreadsheet: the first six columns are read, the first row is
# skipped, and the first column becomes the index. Rows are looked up below
# by trial folder name.
BL = pd.read_excel('dataset2_bodylengths.xlsx', usecols=range(6), 
                   header=None, skiprows=[0], index_col=0)

# body_length[pop,age]          : one scalar per group (NaN-ignoring mean over
#                                 every value of every trial in the group).
# body_length_by_trial[pop,age] : array with one row-mean per trial.
body_length = {}
body_length_by_trial = {}
# body_length_by_fish = {}
for pop,age,n_ind in tqdm(grouped_trials.groups):
    # NOTE: this rebinds the notebook-level name `trial_files`.
    trial_files = grouped_trials.groups[pop,age,n_ind]
    # Keep only the folder name (second-to-last path component) of each file.
    trial_files = [f.split(os.path.sep)[-2] for f in trial_files]
    # The 0.1 factor is presumably a mm -> cm conversion (Figure S5 labels the
    # equivalent quantity 'Body Length (cm)') — TODO confirm spreadsheet units.
    body_length[pop,age] = 0.1*np.nanmean(BL.loc[trial_files])
    body_length_by_trial[pop,age] = 0.1*BL.loc[trial_files].mean(axis=1).values
#     body_length_by_fish[pop,age] = 0.1*BL.loc[trial_files].values.flatten()

In [None]:
# Named matplotlib colors per population ('Mock' is used for the mock groups
# in the real-vs-mock comparison): fill color, edge color and line color.
pop_colors = {'SF':   {'face':'pink', 
                       'edge':'palevioletred', 
                       'lines':'mediumvioletred'},
              'Pa':   {'face':'thistle', 
                       'edge':'mediumorchid', 
                       'lines':'purple'},
              'Mock': {'face':'lightgray',
                       'edge':'k',
                       'lines':'k'}}

# Per-population colormaps: 0-255 RGB triplets are rescaled to 0-1 and turned
# into linear colormaps running from white to that color.
color_maps = { 'Pa':[171,93,171], 'SF':[217,103,175] }
color_maps = { k:np.array(v)/255 for k,v in color_maps.items() }
color_maps = { k:LinearSegmentedColormap.from_list(name=k, colors=['w',v])
               for k,v in color_maps.items() }

# Histogram bin edges (200 edges -> 199 bins each): pair distance as a
# fraction of the arena diameter, and pair angle in degrees.
binsD = np.linspace(0, 1, 200)
binsθ = np.linspace(0, 180, 200)

def significant_symbol(p):
    """Return the star annotation for a p-value.

    '***' for p<0.001, '**' for p<0.01, '*' for p<0.05 and '' otherwise
    (including when none of the comparisons hold, e.g. for NaN).
    """
    star = '*' # '$\star\!$'
    for threshold,n_stars in ((0.001,3), (0.01,2), (0.05,1)):
        if p<threshold:
            return star*n_stars
    return ''

def savefig(fn):
    """Save the current matplotlib figure as 'figures/<fn>.png' at 200 dpi."""
    target = f'figures/{fn}.png'
    plt.savefig(target, dpi=200)

## Figure 1: Pair Orientation/Distance Joint Distribution Plots

In [None]:
# Joint histogram of pairwise distance and pairwise heading difference for
# every (population, age) group. The result is cached as a pickle and only
# recomputed when the cache file is missing.
fn = 'figure-assets/fig12_joint.pkl'

if not os.path.exists(fn):
# if True: 
    # Make sure the cache directory exists *before* the long computation, so
    # the final write cannot fail on a missing folder.
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    joint = {}
    for pop,age,n_ind in tqdm(grouped_trials.groups):
        trial_files = grouped_trials.groups[pop,age,n_ind]
        H = np.zeros((len(binsD)-1, len(binsθ)-1))
        # Make list of unique fish pairs (identical for every trial in the group).
        J1,J2 = np.triu_indices(n_ind,1)
        for trial_file in trial_files:
            trial = load_trial(trial_file)
            # Read the one array that is needed explicitly rather than dumping
            # the whole trial dictionary into the notebook namespace, which
            # used to overwrite the loop variables pop/age/n_ind as well.
            pos   = trial['pos']
            # Compute distance between the fish in each pair in each frame,
            # expressed as a fraction of the arena diameter.
            d     = np.hypot(pos[:,J1,0]-pos[:,J2,0],pos[:,J1,1]-pos[:,J2,1]).flatten()
            d    /= tank_diameter_vs_age[age]
            # Compute angle between the fish in each pair in each frame:
            # wrap the difference to [-pi, pi], then take its magnitude in degrees.
            θ     = (pos[:,J1,2]-pos[:,J2,2]).flatten()
            θ     = θ - 2*np.pi*np.rint(θ/(2*np.pi))
            θ     = np.absolute(θ)*180/np.pi
            # Compute 2D histogram of pairwise distance and angle (raw counts,
            # accumulated over all trials of the group).
            H    += np.histogram2d(d, θ, bins=(binsD, binsθ), density=False)[0]
        # H is stored transposed: rows = angle bins, columns = distance bins.
        joint[pop,age] = dict(binsD=binsD, binsθ=binsθ, H=H.T)

    with open(fn, 'wb') as f:
        pickle.dump(joint, f)
else:
    # NOTE: pickle.load can execute arbitrary code; only load cache files
    # that were written by this notebook.
    with open(fn, 'rb') as f:
        joint = pickle.load(f)

In [None]:
# NOTE: pickle.load can execute arbitrary code; only load cache files that
# were written by this notebook.
with open('figure-assets/fig12_joint.pkl', 'rb') as f:
    joint = pickle.load(f)

gsk = dict(bottom=0.1, top=0.9, left=0.06, right=0.99,
           wspace=0.3, hspace=0.3)
# Lay the panels out explicitly (SF on the top row, Pa on the bottom row, one
# column per age) instead of deriving the position from the enumeration order
# of the cache keys: with any additional population or a missing age in the
# cache, the old index arithmetic drew panels on top of each other or shifted
# them into the wrong column.
row_pops = ['SF','Pa']
col_ages = sorted({age for pop,age in joint if pop in row_pops})
fig,axs = plt.subplots(len(row_pops), max(len(col_ages),1), figsize=(12,6),
                       gridspec_kw=gsk, squeeze=False)
for row,pop in enumerate(row_pops):
    for col,age in enumerate(col_ages):
        ax = axs[row,col]
        if (pop,age) not in joint:
            # No data for this combination: leave the slot empty.
            ax.set_visible(False)
            continue
        panel = joint[pop,age]
        # Column titles go on the top row only.
        if row==0:
            ax.set_title(f'{age}dpf', pad=15)
        ax.pcolormesh(panel['binsD'], panel['binsθ'], panel['H'],
                      cmap=color_maps[pop])
        ax.set_xticks(np.linspace(0,1,6))
        ax.set_yticks(np.linspace(0,180,7))
        ax.set_xlim(0,0.98)
        ax.set_xlabel('Pair Distance (Arena Diameter)')
        ax.set_ylabel('Pair Angle (Degrees)')
        for k in 'right','top':
            ax.spines[k].set_visible(False)

savefig('fig1_joint')

## Figure 2: Real vs mock IID, NND, and orientation comparison

In [None]:
fn = 'figure-assets/fig12_comparisons.pkl'

def get_iid_nnd(pos):
    """Median proximity metrics of one trial, in body lengths.

    pos = array of positions and angles of shape (n_frames,n_ind,3).
    Distances are divided by body_length[pop,age], where `pop` and `age` are
    read from the notebook namespace at call time.

    Returns (iid, nnd, nna): the median pairwise distance over all pairs and
    frames, the median nearest-neighbor distance over all fish and frames, and
    the median heading difference (degrees) between each fish and its nearest
    neighbor.
    """
    n_fish = pos.shape[1]
    frames = np.arange(pos.shape[0])
    # Index arrays of the unique fish pairs.
    first,second = np.triu_indices(n_fish,1)
    # Pairwise distance in every frame, in body lengths.
    dist  = np.hypot(pos[:,first,0]-pos[:,second,0],
                     pos[:,first,1]-pos[:,second,1])
    dist /= body_length[pop,age]
    iid   = np.median(dist)
    # Pairwise heading difference, wrapped to [-pi, pi], as degrees in [0, 180].
    angle = pos[:,first,2]-pos[:,second,2]
    angle = angle - 2*np.pi*np.rint(angle/(2*np.pi))
    angle = np.absolute(angle)*180/np.pi
    nn_dist,nn_angle = [],[]
    for fish in range(n_fish):
        # Pairs that involve this fish.
        pairs   = np.nonzero((first==fish)|(second==fish))[0]
        # Per frame, the pair made of this fish and its nearest neighbor.
        nearest = pairs[np.argmin(dist[:,pairs], axis=1)]
        nn_dist.append(dist[frames,nearest])
        nn_angle.append(angle[frames,nearest])
    nnd = np.median(np.concatenate(nn_dist))
    nna = np.median(np.concatenate(nn_angle))
    return iid,nnd,nna
    
# Compute the per-trial proximity metrics for real groups and for "mock"
# groups assembled from fish taken from different trials. Results are cached
# in `fn` and only recomputed when the cache file is missing.
if not os.path.exists(fn):
    # Make sure the cache directory exists before the long computation.
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    IID,IID_mock = {},{} # inter individual distance
    NND,NND_mock = {},{} # nearest neighbor distance
    NNA,NNA_mock = {},{} # nearest neighbor angle
    for pop,age,n_ind in tqdm(grouped_trials.groups):
        trial_files   = grouped_trials.groups[pop,age,n_ind]
        trial_objects = [load_trial(trial_file) for trial_file in trial_files]
        # Real trials.
        # get_iid_nnd reads `pop` and `age` from the notebook namespace, so it
        # must be called inside this loop.
        IID[pop,age]  = []
        NND[pop,age]  = []
        NNA[pop,age]  = []
        for trial in trial_objects:
            iid,nnd,nna = get_iid_nnd(trial['pos'])
            IID[pop,age].append(iid)
            NND[pop,age].append(nnd)
            NNA[pop,age].append(nna)
        # Mock trials.
        IID_mock[pop,age] = []
        NND_mock[pop,age] = []
        NNA_mock[pop,age] = []
        # Common number of frames: the shortest trial of the group. This does
        # not depend on the combination, so it is computed once per group.
        n = min(trial['pos'].shape[0] for trial in trial_objects)
        # Every combination of n_ind distinct trials provides one mock group
        # (n_ind replaces a hard-coded 5; trials are restricted to groups of 5).
        for T in itt.combinations(trial_objects,n_ind):
            # Pick one random fish from each of the selected trials.
            J   = np.random.randint(0, n_ind, size=n_ind)
            pos = np.empty((n,n_ind,3))
            for i,(trial,j) in enumerate(zip(T,J)):
                pos[:,i,:] = trial['pos'][:n,j,:]
            iid,nnd,nna = get_iid_nnd(pos)
            IID_mock[pop,age].append(iid)
            NND_mock[pop,age].append(nnd)
            NNA_mock[pop,age].append(nna)

    D = dict(IID=IID, NND=NND, NNA=NNA, IID_mock=IID_mock, 
             NND_mock=NND_mock, NNA_mock=NNA_mock)
    with open(fn, 'wb') as f:
        pickle.dump(D, f)
else:
    # NOTE: pickle.load can execute arbitrary code; only load cache files
    # that were written by this notebook.
    with open(fn, 'rb') as f:
        D = pickle.load(f)
    # Bind the cached tables explicitly instead of injecting arbitrary keys
    # into the notebook namespace.
    IID,NND,NNA = D['IID'],D['NND'],D['NNA']
    IID_mock,NND_mock,NNA_mock = D['IID_mock'],D['NND_mock'],D['NNA_mock']

In [None]:
# Load the cached comparison tables. Every key of the pickled dictionary is
# injected into the notebook namespace; the plotting code in this cell reads
# NNA, NND, IID and their *_mock counterparts from there.
# NOTE: pickle.load can execute arbitrary code; only load cache files that
# were written by this notebook.
with open('figure-assets/fig12_comparisons.pkl', 'rb') as f:
    globals().update(pickle.load(f))

def boxplot(value_dict, ax, pop, Δx=0):
    """Draw one violin + mean/median marker + raw points per entry of `value_dict`.

    value_dict maps a tick label to a sequence of values; entry i is drawn at
    x = i + Δx. Colors come from pop_colors[pop], except for pop == 'Mock'
    which is drawn in black with a light gray fill. Only the first entry gets
    a legend label ('Mock' or 'Real'). Returns `ax`.
    """
    positions = np.arange(len(value_dict))+Δx
    base = pop_colors[pop]['lines']
    line_color,fill_color = ('k','lightgray') if pop=='Mock' else (base,base)

    # Box plot reduced to its mean (solid) and median (dashed) lines: box,
    # caps, fliers and whiskers are all switched off.
    mean_style   = dict(color = line_color, linewidth = 1.2, linestyle = '-')
    median_style = dict(color = line_color, linewidth = 1.2, linestyle = '--')
    ax.boxplot(value_dict.values(),
               positions=positions,
               labels=value_dict.keys(),
               showmeans = True, meanline = True,
               meanprops = mean_style,
               medianprops = median_style,
               zorder = 1, patch_artist = True, showfliers = False,
               showbox = False, showcaps = False,  whis = 0)

    violins = ax.violinplot(value_dict.values(),
                            positions=positions,
                            showextrema=False)
    for body in violins['bodies']:
        body.set_edgecolor(line_color)
        body.set_facecolor(fill_color)

    # Raw data points, one column of markers per entry.
    for idx,values in enumerate(value_dict.values()):
        if idx>0:
            label = None
        elif pop=='Mock':
            label = pop
        else:
            label = 'Real'
        ax.plot([positions[idx]]*len(values), values, lw=0, marker='o',
                ms=1, color=line_color, label=label)
    return ax


# Figure 2: real vs mock groups. One column per population, one row per
# metric (nearest neighbor angle, nearest neighbor distance, interindividual
# distance). For every age the real and mock samples are compared with a
# t-test when both pass a Shapiro-Wilk normality test at 0.05, and with a
# Mann-Whitney U test otherwise; the result is printed and starred on the plot.
gsk     = dict(hspace=0.35, wspace=0.35,
               left=0.1 , right=0.99, 
               bottom=0.08, top=0.98)
fig,axs = plt.subplots(3, 2, figsize=(9,9), gridspec_kw=gsk)
ylims   = [180, 9, 20]
labels  = ['Nearest Neighbor Pair\nAngle (Degrees)',
           'Nearest Neighbor\nDistance (Body Lengths)',
           'Interindividual\nDistance (Body Lengths)' ]
for j,pop in enumerate(['SF','Pa']):
    for i,(real,mock,ymax,lbl) in \
        enumerate(zip([NNA,NND,IID], 
                      [NNA_mock,NND_mock,IID_mock], 
                      ylims, labels)):
        ax = axs[i,j]
        # Values of this population only, keyed by age.
        R  = { k[1]:v for k,v in real.items() if k[0]==pop }
        M  = { k[1]:v for k,v in mock.items() if k[0]==pop }
        boxplot(R, ax, pop, -0.2)
        boxplot(M, ax, 'Mock', 0.2)
        # One tick per age, centered between the real and mock violins.
        ax.set_xticks(range(len(R)))
        ax.set_ylim(0, ymax)
        ax.set_xlabel('Age (dpf)')
#         ax.set_ylabel(lbl, x=-0.1)
        ax.text(-0.17, 0.5, lbl, transform=ax.transAxes, 
                ha='center', va='center', 
                rotation=90)
        # Build the legend from copies of the plotted handles with a larger
        # marker, so the tiny data markers remain readable in the legend.
        handles,names = ax.get_legend_handles_labels()
        proxies = []
        for handle in handles:
            proxy = Line2D([0],[0])
            proxy.update_from(handle)
            proxy.set_markersize(5)
            proxies.append(proxy)
        ax.legend(proxies, names, loc='upper right')
        
        # The y-limits were fixed above, so the axis span is constant here.
        y_lo,y_hi = ax.get_ylim()
        # Distinct loop names (k, y_star): the inner loops used to reuse and
        # overwrite the outer `i` and `y`.
        for k,age in enumerate(R):
            x1,x2 = R[age],M[age]
            p1 = shapiro(x1).pvalue
            p2 = shapiro(x2).pvalue
            print(f'{pop}, {age}dpf')
            print(f'    [Real: n={len(x1)}, median={np.median(x1):.3g}, p_shapiro={p1:.3g}]')
            print(f'    [Mock: n={len(x2)}, median={np.median(x2):.3g}, p_shapiro={p2:.3g}]')
            if p1<0.05 or p2<0.05:
                t = mannwhitneyu(x1,x2)
                print(f'    [Mann Whitney U: stat={t.statistic:.3g}, p={t.pvalue:.3g}]')
            else:
                t = ttest_ind(x1,x2)
                print(f'    [T-test: stat={t.statistic:.3g}, p={t.pvalue:.3g}]')
            # Place the stars slightly above the largest value of either
            # sample. max(max(x1),max(x2)) also works if the samples are
            # arrays, where `x1+x2` would add instead of concatenate.
            y_star = max(max(x1),max(x2)) + 0.05*(y_hi-y_lo)
            ax.text(k, y_star, significant_symbol(t.pvalue), transform=ax.transData, 
                    ha='center') #, fontsize=fontsize)
#         break
#     break
            
# Angle panels (top row): ticks every 30 degrees.
for j in range(2):
    axs[0,j].set_yticks(np.linspace(0,180,7))

savefig('fig2_comparison')
plt.show()

# Supplementary Figures

## Figure S4: Body length vs Arena diameter

In [None]:
# Figure S4: ratio of arena diameter to per-trial mean body length, one panel
# per population and one box per age.
gsk = dict(left=0.1, right=0.98, wspace=0.3, bottom=0.15, top=0.9)
fig,axs = plt.subplots(1, 2, figsize=(8,4), gridspec_kw=gsk)

for i,pop in enumerate(['SF','Pa']):
    # age -> array with one ratio per trial.
    x  = { age:tank_diameter_vs_age[age]/body_length_by_trial[pop,age] for age in [7,28,42,70]}
    c1 = pop_colors[pop]['lines']
    c2 = pop_colors[pop]['face']
    ax = axs[i]
    # Mean drawn as a solid line, median as a dashed line.
    ax.boxplot(x.values(), labels=x.keys(),
               showmeans=True, meanline=True, patch_artist=True, 
               showfliers=False, showbox = True, showcaps=True, 
               meanprops=dict(color=c1, linestyle='-', linewidth=1.2),
               medianprops=dict(color=c1, linewidth=1.2, linestyle='--'), 
               boxprops=dict(edgecolor=c1, facecolor=c2), 
               whiskerprops=dict(color=c1), capprops=dict(color=c1),
               flierprops=dict(mec=c1)
               )
    # Overlay the individual trials; boxes sit at x = 1, 2, ... by default.
    # NOTE: this inner loop reuses the name `i`; `ax` was already selected
    # above, so the panel being drawn is not affected.
    for i,x_ in enumerate(x.values()):
        ax.scatter([i+1]*len(x_), x_, s=3, zorder=2, color='k')
    # Horizontal reference line at a ratio of 22.
    # NOTE(review): the origin of the value 22 is not stated here — confirm.
    ax.axhline(22, color='gray', ls='--', lw=1)
    ax.set_ylim(0, 40)
    ax.set_xlabel('Age (dpf)') #, fontsize='x-large')
    ax.set_ylabel('Arena Diameter/Body Length') #, fontsize='x-large')

savefig('figS4_body-length')
plt.show()

## Figures S1 & S5: Proximity vs Speed & Body length

In [None]:
# Per-trial table of body length, median speed and proximity metrics, cached
# as a spreadsheet and only recomputed when the file is missing.
fn = 'figure-assets/fig123_body-length_speed_proximity.xlsx'

if not os.path.exists(fn):
    # Make sure the output directory exists before the long computation, so
    # the final write cannot fail on a missing folder.
    os.makedirs(os.path.dirname(fn), exist_ok=True)

    # Private name: this variant does NOT divide distances by body length, and
    # must not silently replace the notebook-level get_iid_nnd that does.
    def _get_iid_nnd_unscaled(pos):
        """Median proximity metrics of one trial, in the units of `pos`.

        pos = array of positions and angles of shape (n_frames,n_ind,3).
        Returns (iid, nnd, nna): median pairwise distance, median nearest
        neighbor distance, and median heading difference (degrees) between
        each fish and its nearest neighbor.
        """
        n_ind = pos.shape[1]
        J1,J2 = np.triu_indices(n_ind,1)
        d     = np.hypot(pos[:,J1,0]-pos[:,J2,0],pos[:,J1,1]-pos[:,J2,1])
        iid   = np.median(d)
        # Heading difference wrapped to [-pi, pi], as degrees in [0, 180].
        θ     = pos[:,J1,2]-pos[:,J2,2]
        θ     = θ - 2*np.pi*np.rint(θ/(2*np.pi))
        θ     = np.absolute(θ)*180/np.pi
        nnd   = []
        nna   = []
        for k in range(n_ind):
            # Indices of pairs involving fish k.
            K   = np.nonzero((J1==k)|(J2==k))[0]
            # Index of the pair containing fish k and its nearest neighbor in frame.
            K   = K[np.argmin(d[:,K], axis=1)]
            # Nearest neighbor distance.
            nnd.append(d[np.arange(d.shape[0]),K])
            # Nearest neighbor pair angle.
            nna.append(θ[np.arange(d.shape[0]),K])
        nnd   = np.median(np.concatenate(nnd))
        nna   = np.median(np.concatenate(nna))
        return iid,nnd,nna

    BL = pd.read_excel('dataset2_bodylengths.xlsx', usecols=range(6), 
                       header=None, skiprows=[0], index_col=0)
    # Collect one tuple per trial, then build the DataFrame once.
    rows = []
    for pop,age,n_ind in tqdm(grouped_trials.groups):
        trial_files   = grouped_trials.groups[pop,age,n_ind]
        for trial_file in trial_files:
            trial = load_trial(trial_file)
            v     = np.median(trial['v'])
            # Mean body length of the trial, looked up by folder name; the 0.1
            # factor is presumably a mm -> cm conversion — TODO confirm.
            bl    = 0.1*np.nanmean(BL.loc[trial_file.split(os.path.sep)[-2]])
            iid,nnd,nna = _get_iid_nnd_unscaled(trial['pos'])
            rows.append((trial_file,pop,age,bl,v,iid,nnd,nna))

    df = pd.DataFrame(rows, columns=['trial_file', 'pop', 'age', 'body length', 
                                     'speed', 'iid', 'nnd', 'nna'])
    df = df.set_index('trial_file')
    # Express distances in units of the trial's own body length.
    df['iid'] /= df['body length']
    df['nnd'] /= df['body length']
    df.to_excel(fn)

### Figure S5: Proximity vs Body length

In [None]:
# Figure S5: per-trial proximity (in body lengths) against body length, one
# row per population; left column = interindividual distance, right column =
# nearest neighbor distance.
fn  = 'figure-assets/fig123_body-length_speed_proximity.xlsx'
df  = pd.read_excel(fn, index_col=0)

gsk = dict(left=0.1, right=0.98, wspace=0.25, 
           bottom=0.1, top=0.95, hspace=0.25)
fig,axs = plt.subplots(2, 2, figsize=(8,6), gridspec_kw=gsk)
panels  = [('iid', 'Interindividual Distance (BL)'),
           ('nnd', 'Nearest Neighbor Distance (BL)')]
for row_axes,pop in zip(axs, ['SF','Pa']):
    line_color = pop_colors[pop]['lines']
    pop_rows   = df[df['pop']==pop]
    for ax,(column,ylabel) in zip(row_axes, panels):
        pop_rows.plot.scatter(x='body length', y=column, ax=ax, color=line_color)
        ax.set_ylabel(ylabel)
# Shared axis formatting for all four panels.
for ax in axs.flat:
    ax.set_ylim(0,20)
    ax.locator_params(axis='y', nbins=5)
    ax.set_xlabel('Body Length (cm)')
savefig('figS5_body-length_proximity')
plt.show()

### Figure S1: Proximity vs Speed

In [None]:
# '''
# This looks at the correlation between proximity metrics and speed 
# with one data point per trial. What Aly actually did is keep each 
# frame as a separate data point (averaging over fish or pairs within 
# that frame).
# '''

# fn = 'figure-assets/fig123_body-length_speed_proximity.xlsx'
# df = pd.read_excel(fn, index_col=0)
# df = pd.melt(df, id_vars=['pop','age','speed'], 
#              value_vars=['iid','nnd','nna'])
# # display(df)

# gb    = df.groupby(['pop','age','variable'])
# gb    = gb[['speed','value']]
# stats = {}
# for (pop,age,var),df_ in gb:
#     print(pop,age,var)
#     display(df_)
#     display(df_.corr(method='pearson'))
#     display(df_.corr(method='spearman'))
#     df_.plot.scatter(x='speed', y='value', figsize=(4,3))
#     print(spearmanr(df_))
#     break

In [None]:
def get_iid_nnd(pos):
    """Per-frame proximity metrics of one trial, in the units of `pos`.

    pos = array of positions and angles of shape (n_frames,n_ind,3).

    Returns three arrays of length n_frames: the mean pairwise distance over
    all pairs, the mean nearest-neighbor distance over all fish, and the mean
    heading difference (degrees) between each fish and its nearest neighbor.
    """
    n_fish = pos.shape[1]
    frames = np.arange(pos.shape[0])
    # Index arrays of the unique fish pairs.
    first,second = np.triu_indices(n_fish,1)
    dist  = np.hypot(pos[:,first,0]-pos[:,second,0],
                     pos[:,first,1]-pos[:,second,1])
    # Heading difference wrapped to [-pi, pi], as degrees in [0, 180].
    angle = pos[:,first,2]-pos[:,second,2]
    angle = angle - 2*np.pi*np.rint(angle/(2*np.pi))
    angle = np.absolute(angle)*180/np.pi
    nn_dist,nn_angle = [],[]
    for fish in range(n_fish):
        # Pairs that involve this fish.
        pairs   = np.nonzero((first==fish)|(second==fish))[0]
        # Per frame, the pair made of this fish and its nearest neighbor.
        nearest = pairs[np.argmin(dist[:,pairs], axis=1)]
        nn_dist.append(dist[frames,nearest])
        nn_angle.append(angle[frames,nearest])
    # Average over pairs (iid) or over fish (nnd, nna) within each frame.
    return np.mean(dist, axis=1), np.mean(nn_dist, axis=0), np.mean(nn_angle, axis=0)

# Frame-level Spearman correlation between mean speed and each proximity
# metric, pooled over all trials of a (population, age) group. One diagnostic
# figure is shown per group and the statistics are written to a spreadsheet.

# Make sure the output directory exists before the long computation, so the
# final write cannot fail on a missing folder.
os.makedirs('figure-assets', exist_ok=True)

# Collect one record per (group, metric), then build the DataFrame once.
records = []
for pop,age,n_ind in tqdm(grouped_trials.groups):
    trial_files   = grouped_trials.groups[pop,age,n_ind]
    V = []
    D = {'iid':[], 'nnd':[], 'nna':[]}
    for trial_file in trial_files:
        trial = load_trial(trial_file)
        # Mean over axis 1 of the speed array, giving one value per frame
        # (presumably the mean over fish — TODO confirm the layout of 'v').
        v     = np.mean(trial['v'], axis=1)
        iid,nnd,nna = get_iid_nnd(trial['pos'])
        V.append(v)
        D['iid'].append(iid)
        D['nnd'].append(nnd)
        D['nna'].append(nna)
    V = np.concatenate(V)
    gsk = dict(wspace=0.35)
    fig,axs = plt.subplots(1, 3, figsize=(9,2.5), gridspec_kw=gsk)
    for ax,q in zip(axs, ['nnd','iid','nna']):
        Q = np.concatenate(D[q])
        corr,p = spearmanr(V, Q)
        records.append(dict(pop=pop, age=age, var=q, corr=corr, p=p))
        ax.scatter(V, Q, s=1, color=pop_colors[pop]['lines'])
        # The displayed R^2 is the square of Spearman's rho.
        tx = f'$\\rho={corr:.3f}$\n$p={p:.3f}$\n$R^2={corr**2:.3f}$'
        ax.text(0.99, 0.99, tx, transform=ax.transAxes, ha='right', va='top')
        ax.set_xlabel('Speed')
        ax.set_ylabel({'iid':'Interindividual Distance',
                       'nnd':'Nearest Neighbor Distance',
                       'nna':'Nearest Neighbor Pair Angle'}[q])
    plt.suptitle(f'{pop}, {age}dpf')
    plt.show()
    # Release the figure so figures do not accumulate across groups.
    plt.close(fig)
df = pd.DataFrame(records)
df['R2'] = df['corr']**2
df.to_excel('figure-assets/fig123_speed_proximity_long.xlsx', index=False)

In [None]:
# Figure S1 table: reshape the long correlation table so that rows are
# (population, age, statistic) and columns are the proximity metrics, in a
# fixed presentation order, with values rounded to 3 decimals.
row_order = (['SF','Pa'], [7,28,42,70], ['corr','R2','p'])
col_order = ['nna','nnd','iid']
df = (
    pd.read_excel('figure-assets/fig123_speed_proximity_long.xlsx')
      .round(3)
      .melt(id_vars=['pop','age','var'], value_vars=['corr','p','R2'], var_name='stat')
      .pivot_table(index=['pop','age','stat'], columns='var', values='value')
      .loc[row_order, col_order]
)
display(df)
df.to_excel('figures/figS1_speed-proximity.xlsx')