In [None]:
import os, sys, datetime, re, pickle, warnings
from os.path import join
from glob import glob
import numpy as np
from scipy.signal import find_peaks, savgol_filter
from scipy.stats import gaussian_kde
from scipy.stats import shapiro, linregress, ttest_1samp, f_oneway, kruskal, wilcoxon
from statsmodels.stats.multicomp import pairwise_tukeyhsd
from statsmodels.stats.multitest import multipletests
# from pingouin.pairwise import pairwise_gameshowell
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, CenteredNorm, SymLogNorm
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import itertools as itt
from tqdm import tqdm
from numba import jit, njit

sys.path.append('trilab-tracker-0.1.2')
from trilabtracker import preprocess_trial

print(os.getcwd())

# Notebook-wide figure defaults.
plt.rcParams.update({'figure.figsize': (9, 6), 'figure.facecolor': 'w'})
# plt.rcParams['figure.dpi'] = 150
dpi = 150

# Prepare trial data.

Build a table of the trials to analyze, with basic information for each (path to the trial file, population, age, number of individuals, etc.).

In [None]:
# Tank diameter (cm) for each age (dpf); trial parsing halves it to get the radius R_cm.
tank_diameter_vs_age = {
     7:  9.6,
    14: 10.4,
    21: 12.8,
    28: 17.7,
    42: 24.2,
    56: 33.8,
    70: 33.8,
    84: 33.8,
}

# Name-correction atlas: maps each video/folder name (index 'video-name') to the
# corrected trial name stored in the 'folder-name' column.
naming_atlas = pd.read_csv('naming_atlas.csv', index_col='video-name')

# Extract trial metadata from the trial's filename.
def parse_trial_file(trial_file):
    trial_dir  = os.path.dirname(trial_file)
    trial_name = os.path.basename(trial_dir)
    trial_name = naming_atlas.loc[trial_name,'folder-name']
    try:
        pop,day,age,group,n_ind = trial_name.split('_')[:5]
        pop        = {'sf':'SF', 'pa':'Pa', 'rc':'RC'}[pop.lower()]
        age        = int(age[:-3])
        age = 42 if age==43 else (70 if age==71 else age)
        n_ind      = int(re.findall('\d+',n_ind)[0])
        R_cm       = tank_diameter_vs_age[age]/2
        if n_ind not in [1,2,5,10] or pop not in ['SF','Pa','RC']:
            raise Exception
        trial      = { k:v for k,v in locals().items() if k in ['trial_file', 'trial_dir', 
                                     'trial_name', 'pop', 'age', 'group', 'n_ind', 'R_cm'] }
        return trial
    except:
        print(trial_dir)
        return {}

def apply_smoothing(trial, n=5, method='savgol'):
    """Smooth the time series stored in a trial dict, in place, along the time axis (axis 0).

    Only the keys 'data', 'pos', 'vel', 'acc', 'v' and 'd_wall' that are present are smoothed.

    Parameters
    ----------
    trial : dict
        Trial data; modified in place.
    n : int
        Window length; an even value is rounded up to the next odd number.
    method : str
        Only 'savgol' (Savitzky-Golay, polynomial order min(2, n-1)) is supported. Any other
        value emits a warning and leaves the data unsmoothed.

    Returns
    -------
    dict
        The same `trial` object.
    """
    if method != 'savgol':
        warnings.warn(f'Unknown smoothing method: {method}')
        return trial
    n = n - n % 2 + 1  # savgol_filter needs an odd window length
    polyorder = min(2, n - 1)
    for k in ('data', 'pos', 'vel', 'acc', 'v', 'd_wall'):
        if k in trial:
            trial[k] = savgol_filter(trial[k], window_length=n, polyorder=polyorder, axis=0)
    return trial

def load_trial(trial_file, **kwargs):
    """Parse a trial's metadata, run the tracker preprocessing, then smooth (window 9).

    Extra keyword arguments are forwarded unchanged to `preprocess_trial`.
    """
    metadata = parse_trial_file(trial_file)
    preprocessed = preprocess_trial(metadata, **kwargs)
    return apply_smoothing(preprocessed, n=9)

# Select a set of trials to analyze.
trial_files = sorted(glob('tracking-data/*/trial.pik'))

# One metadata row per trial file. Files whose name could not be parsed come back as {}
# (an all-NaN row) and are removed by dropna().
trials = pd.DataFrame([parse_trial_file(path) for path in trial_files], index=trial_files).dropna()
grouped_trials = trials.groupby(['pop','age','n_ind'])

# Number of trials per (population, group size) [rows] and age [columns];
# combinations without any trial show 0.
count = (grouped_trials['trial_dir']
         .count()
         .unstack('age', fill_value=0)
         .astype(int))
display(count)

def matching_trials(pop=None, age=None, n_ind=None, df=trials):
    """Return the index labels (trial file paths) of the rows of `df` matching all given filters.

    Filters left as None are ignored. The default `df` is the trial table as it existed
    when this function was defined.
    """
    criteria = {'pop': pop, 'age': age, 'n_ind': n_ind}
    keep = pd.Series(data=True, index=df.index)
    for column, wanted in criteria.items():
        if wanted is not None:
            keep = keep & (df[column] == wanted)
    return df[keep].index.tolist()

# Compute polar force maps

In [None]:
# Angular bin edges: 8 equal 45° sectors covering [0, 2π].
binsθ = np.linspace(0, 2*np.pi, 9)
# Distance bin edges, in units of the tank radius (distances are divided by R_cm below):
# 0.05-wide bins up to 0.2, then 0.1-wide bins.
# NOTE(review): np.arange with a float step may include or exclude an edge near its stop
# value depending on rounding; confirm the intended last edge.
binsD = np.concatenate([np.arange(0,0.2,0.05), np.arange(0.2,2.1,0.1)])
print(binsD)

# Histograms below are allocated with one slot per *edge* (len(bins)), i.e. one extra
# slot beyond the last real bin; the plotting cells drop it with [:-1].
nθ    = len(binsθ)
nD    = len(binsD)

# Groups of 5 fish, excluding 56 dpf.
groups = [ g for g in grouped_trials.groups if g[2]==5 and g[1]!=56 ]
# groups = [('SF',70,5)]

# Cache files for the two histogram cells below; each cell only recomputes when its file is missing.
real_data_file = 'figure-assets/fig345_real.pkl'
mock_data_file = 'figure-assets/fig345_mock.pkl'

## Actual groups

In [None]:
if not os.path.exists(real_data_file):

    @njit
    def compute_3d_histogram(d, θ, φ, fx, fy, fr, ω, dω, binsθ, binsD, N, Fx, Fy, Fr, Px, Ω, dΩ):
        """Accumulate per-sample values into (θ, φ, d) histogram bins, in place.

        Samples whose fx, fy, fr or dω is non-finite are skipped. Bin indices come from
        np.digitize on the bin edges. The output arrays have one slot per edge, so values at or
        beyond the last edge fall into the final (overflow) slot. If an index still falls
        outside the arrays, a message is printed and accumulation stops early.
        N counts samples, Fx/Fy/Fr/Ω/dΩ accumulate sums, and Px counts samples with fx > 0.
        """
        for i in range(len(d)):
            if np.isfinite(fx[i]) and np.isfinite(fy[i]) and np.isfinite(fr[i]) and np.isfinite(dω[i]):
                kθ = np.digitize([θ[i]], binsθ).item() - 1
                kφ = np.digitize([φ[i]], binsθ).item() - 1
                kd = np.digitize([d[i]], binsD).item() - 1
                if kθ<0 or kφ<0 or kd<0 or kθ>=N.shape[0] or kφ>=N.shape[1] or kd>=N.shape[2]:
                    print('Histogram index out of bounds', N.shape, kθ, kφ, kd)
                    return
                N[kθ,kφ,kd]  += 1
                Fy[kθ,kφ,kd] += fy[i]
                Fx[kθ,kφ,kd] += fx[i]
                Fr[kθ,kφ,kd] += fr[i]
                Ω[kθ,kφ,kd]  += ω[i]
                dΩ[kθ,kφ,kd] += dω[i]
                if fx[i]>0:
                    Px[kθ,kφ,kd] += 1  
        return

    # Per-trial histograms for real groups, over every ordered pair of distinct fish in a trial.
    # polar_maps[pop, age, n_ind][quantity] is a list with one (θ, φ, d) array per trial.
    polar_maps = {}
    for pop,age,n_ind in tqdm(groups):
        group_files = grouped_trials.groups[pop,age,n_ind]
        polar_maps[pop,age,n_ind] = dict(N=[], Fx=[], Fy=[], Fr=[], Px=[], Ω=[], dΩ=[])
        for trial_file in group_files:
            trial = load_trial(trial_file)
            # Unpack only the fields used below; pushing the whole trial dict into globals()
            # would silently overwrite unrelated notebook-level names.
            pos, vel, acc = trial['pos'], trial['vel'], trial['acc']
            R_cm   = trial['R_cm']
            n_fish = trial.get('n_ind', n_ind)

            # Pairs of fish. Order matters: fish 1 is focal fish; fish 2 is neighbor.
            J     = np.indices((n_fish,n_fish)).reshape(2,-1)
            J1,J2 = J[:,J[0]!=J[1]]   # Drop self-pairs.
            # Compute kinematic quantities for every pair.
            u     = np.cos(pos[:,J1,2]),np.sin(pos[:,J1,2]) # Director of fish 1.
            xy    = pos[:,J2,:2] - pos[:,J1,:2]             # Position of fish 2 relative to fish 1.
            u,xy  = np.reshape(u,(2,-1)), np.reshape(xy,(-1,2)).T # Merge time and fish axes.
            d     = np.hypot(*xy)
            u_r   = xy/d[None,:]
            θ     = np.arctan2(xy[1],xy[0]) - pos[:,J1,2].flatten() # Polar angle of fish 2 in Frenet basis of fish 1.
            θ     = θ % (2*np.pi)
            φ     = (pos[:,J2,2] - pos[:,J1,2]).flatten()   # Relative orientation.
            φ     = φ % (2*np.pi)
            ω     = vel[:,J1,2].flatten()
            dω    = acc[:,J1,2].flatten()
            f     = acc[:,J1,:2].reshape((-1,2)).T          # Acceleration.
            # Normalize by tank size.
            xy   /= R_cm
            d    /= R_cm
            # Acceleration in coordinate system based on director of focal fish.
            fy    = u[0]*f[0] + u[1]*f[1]  # Speeding force (acceleration of focal fish along itself).
            fx    = u[1]*f[0] - u[0]*f[1]  # Turning force (acceleration of focal fish perpendicular to itself).
            # Acceleration in coordinate system based on position of 2nd fish relative to focal fish.
            fr    = u_r[0]*f[0] + u_r[1]*f[1] # Attraction force (acc. of focal fish along axis towards 2nd fish).

            N  = np.zeros((nθ,nθ,nD)) # Density.
            Fx = np.zeros((nθ,nθ,nD)) # Turning force.
            Fy = np.zeros((nθ,nθ,nD)) # Speeding force.
            Fr = np.zeros((nθ,nθ,nD)) # Radial force (avoidance/repulsion).
            Px = np.zeros((nθ,nθ,nD)) # Probability of turning right.
            Ω  = np.zeros((nθ,nθ,nD)) # Angular velocity.
            dΩ = np.zeros((nθ,nθ,nD)) # Angular acceleration.
            compute_3d_histogram(d, θ, φ, fx, fy, fr, ω, dω, binsθ, binsD, N, Fx, Fy, Fr, Px, Ω, dΩ)
            tmp = dict( N=N, Fx=Fx, Fy=Fy, Fr=Fr, Px=Px, Ω=Ω, dΩ=dΩ)
            for k in tmp:
                polar_maps[pop,age,n_ind][k].append(tmp[k].copy())
        # R is the tank radius of the last trial loaded for this group.
        polar_maps[pop,age,n_ind].update(binsD=binsD, binsθ=binsθ, R=R_cm)

    with open(real_data_file, 'wb') as fh:
        pickle.dump(polar_maps, fh)

## Proof that mock pairs are equivalent to mock groups

In [None]:
# n_ind = 5 # number of fish per group
# n_trials = 9 # typical number of trials of a given age and population

# # Construct the set of all mock pairs, i.e., pairs of fish from two different trials.
# pairs = []
# # Loop over all possible pairs of distinct trials. Order matters because the focal fish will 
# # be taken from the first trial, whereas the neighbor fish will be taken from the second trial.
# for i1,i2 in itt.permutations(range(n_trials),2):
#     # Loop over pairs of fish made up of a fish from trial i1 and a fish from trial i2.
#     for j1,j2 in itt.product(range(n_ind),range(n_ind)):
#         pairs.append(((i1,j1),(i2,j2)))
# print('Does the list contain duplicates?', len(pairs_set)<len(pairs))
# print('Number of unique fish pairs:', len(pairs))

In [None]:
# # Make a list of all possible mock groups, i.e., groups of 5 fish taken 
# # from 5 different trials.
# groups = []
# for I in itt.combinations(range(n_g),n_ind):
#     for J in itt.product(*[range(n_ind)]*n_ind):
#         groups.append(tuple(zip(I,J)))

# # Make a list of all possible fish pairs by looping over mock groups and 
# # listing all possible fish pairs within that mock group.
# pairs2 = []
# for group in groups:
#     for k1,k2 in itt.permutations(group,2):
#         pairs2.append((k1,k2))

# # print('Number of unique fish pairs:', len(set(pairs2)))
# print('Are these the same pairs as before?', set(pairs2)==set(pairs))

# # This time there are lots of duplicates. To make sure averages over the mock pairs 
# # are equivalent to averages over the pairs from mock groups, we need to make sure 
# # every pair has the same number of duplicates.
# from collections import Counter
# c = Counter(pairs2) # number of instances of each pair
# n = set(c.values()) # distinct numbers of instances
# print('Number of duplicates of each pair:', n)
# # n has only one element, therefore every pair has the same number of duplicates.

## Mock groups

In [None]:
if not os.path.exists(mock_data_file):

    @njit
    def compute_3d_histogram(d, θ, φ, fx, fy, fr, ω, dω, binsθ, binsD, N, Fx, Fy, Fr, Px, Ω, dΩ):
        """Accumulate per-sample values into (θ, φ, d) histogram bins, in place.

        Samples whose fx, fy, fr or dω is non-finite are skipped. Bin indices come from
        np.digitize on the bin edges. The output arrays have one slot per edge, so values at or
        beyond the last edge fall into the final (overflow) slot. If an index still falls
        outside the arrays, a message is printed and accumulation stops early.
        N counts samples, Fx/Fy/Fr/Ω/dΩ accumulate sums, and Px counts samples with fx > 0.
        """
        for i in range(len(d)):
            if np.isfinite(fx[i]) and np.isfinite(fy[i]) and np.isfinite(fr[i]) and np.isfinite(dω[i]):
                kθ = np.digitize([θ[i]], binsθ).item() - 1
                kφ = np.digitize([φ[i]], binsθ).item() - 1
                kd = np.digitize([d[i]], binsD).item() - 1
                if kθ<0 or kφ<0 or kd<0 or kθ>=N.shape[0] or kφ>=N.shape[1] or kd>=N.shape[2]:
                    print('Histogram index out of bounds', N.shape, kθ, kφ, kd)
                    return
                N[kθ,kφ,kd]  += 1
                Fy[kθ,kφ,kd] += fy[i]
                Fx[kθ,kφ,kd] += fx[i]
                Fr[kθ,kφ,kd] += fr[i]
                Ω[kθ,kφ,kd]  += ω[i]
                dΩ[kθ,kφ,kd] += dω[i]
                if fx[i]>0:
                    Px[kθ,kφ,kd] += 1  
        return

    # Pooled histograms for "mock" groups: the focal fish and its neighbour always come from
    # two different trials of the same (population, age, group size). These maps are the
    # baseline subtracted in the 'difference' plots.
    polar_maps = {}
    for pop,age,n_ind in tqdm(groups):
        # Load every trial of the group. Uses a local name on purpose: assigning to `trials`
        # here would overwrite the notebook-level trial table.
        group_trials = []
        n_valid, v_sum = 0, 0
        for trial_file in grouped_trials.groups[pop,age,n_ind]:
            trial = load_trial(trial_file)
            group_trials.append(trial)
            n_valid += np.sum(np.isfinite(trial['v']))
            v_sum   += np.nansum(trial['v'])
        v0     = v_sum/n_valid   # Mean of trial['v'] over all finite samples of the group.

        N      = np.zeros((nθ,nθ,nD)) # Density.
        Fx     = np.zeros((nθ,nθ,nD)) # Turning force.
        Fy     = np.zeros((nθ,nθ,nD)) # Speeding force.
        Fr     = np.zeros((nθ,nθ,nD)) # Radial force (avoidance/repulsion).
        Px     = np.zeros((nθ,nθ,nD)) # Probability of turning right.
        Ω      = np.zeros((nθ,nθ,nD)) # Angular velocity.
        dΩ     = np.zeros((nθ,nθ,nD)) # Angular acceleration.
        R_cm   = group_trials[0]['R_cm']

        # Every ordered pair of distinct trials (i1 supplies focal fish j1, i2 supplies
        # neighbour j2), for every combination of fish indices.
        n_trials = len(group_trials)
        for i1,i2 in itt.permutations(range(n_trials),2):
            for j1,j2 in itt.product(range(n_ind),range(n_ind)):
                pos   = group_trials[i1]['pos'][:,j1,:3]
                vel   = group_trials[i1]['vel'][:,j1,:3]
                acc   = group_trials[i1]['acc'][:,j1,:3]
                pos2  = group_trials[i2]['pos'][:,j2,:3]
                # Ignore frames beyond end of shorter trial.
                m     = min(pos.shape[0],pos2.shape[0])
                pos,pos2 = pos[:m],pos2[:m]
                vel,acc  = vel[:m],acc[:m]
                # Compute kinematic quantities.
                xy    = (pos2[:,:2] - pos[:,:2]).T
                u     = np.cos(pos[:,2]),np.sin(pos[:,2]) # Director of fish 1.
                d     = np.hypot(*xy)
                u_r   = xy/d[None,:]
                θ     = np.arctan2(xy[1],xy[0]) - pos[:,2]
                θ     = θ % (2*np.pi)
                φ     = pos2[:,2] - pos[:,2]              # Relative orientation.
                φ     = φ % (2*np.pi)
                ω     = vel[:,2]
                dω    = acc[:,2]
                f     = acc[:,:2].T      # Acceleration.
                # Normalize by tank size.
                xy   /= R_cm
                d    /= R_cm
                # Acceleration in coordinate system based on director of focal fish.
                fy    = u[0]*f[0] + u[1]*f[1]  # Speeding force (acceleration of focal fish along itself).
                fx    = u[1]*f[0] - u[0]*f[1]  # Turning force (acceleration of focal fish perpendicular to itself).
                # Acceleration in coordinate system based on position of 2nd fish relative to focal fish.
                fr    = u_r[0]*f[0] + u_r[1]*f[1] # Attraction force (acc. of focal fish along axis towards 2nd fish).
                compute_3d_histogram(d, θ, φ, fx, fy, fr, ω, dω, binsθ, binsD, N, Fx, Fy, Fr, Px, Ω, dΩ)

        polar_maps[pop,age,n_ind] = dict(N=N, Fx=Fx, Fy=Fy, Fr=Fr, Px=Px, Ω=Ω, dΩ=dΩ, 
                                         binsD=binsD, binsθ=binsθ, R=R_cm, v0=v0)

    with open(mock_data_file, 'wb') as fh:
        pickle.dump(polar_maps, fh)

# Make figures

In [None]:
# Human-readable label for each histogram quantity (keys match the polar-map dicts).
quantity_name = {
    'N':  'Density',
    'Fr': 'Attraction Force',
    'Fy': 'Speeding Force',
    'Fx': 'Turning Force',
    'Px': 'Right Turn Probability',
    'Ω':  'Angular Velocity',
    'dΩ': 'Angular Acceleration',
}
# Display name of each population code.
population_name = {
    'SF': 'Surface',
    'Pa': 'Cave',
}

def plot_polar_heatmap(H_, ax=None, n=10, binsθ=binsθ, binsD=binsD, 
                       cmap=plt.get_cmap('seismic'), norm=None):
    """Draw a 2-D (angular bin, distance bin) array as a heatmap on polar axes.

    Each angular bin is split into `n` sub-wedges carrying the same value, so the bin
    borders render as arcs rather than straight chords. Grid lines and tick labels are
    hidden, and θ = 0 points up. Draws on `ax` (the current axes if None) and returns the
    QuadMesh, e.g. for a colorbar.
    """
    if ax is None:
        ax = plt.gca()
    n_sectors = len(binsθ) - 1
    fine_edges = np.linspace(binsθ[0], binsθ[-1], n*n_sectors+1)
    θ_grid, D_grid = np.meshgrid(fine_edges, binsD, indexing='ij')
    # Repeat every angular row n times, consecutively.
    H_fine = np.stack([H_]*n, axis=1).reshape((-1, H_.shape[1]))
    ax.grid(False)
    ax.set_theta_zero_location('N')
    quad_mesh = ax.pcolormesh(θ_grid, D_grid, H_fine, cmap=cmap, norm=norm)
    ax.set_xticks([])
    ax.set_yticks([])
    return quad_mesh

# Violin face/edge colours and strip-plot colour for each population.
pop_colors = {
    'SF': dict(face='pink',    edge='palevioletred', lines='mediumvioletred'),
    'Pa': dict(face='thistle', edge='mediumorchid',  lines='purple'),
}


def plot_violin(ax, q, polar_maps, polar_maps_mock, pop, n_ind, ages, 
                average_wrapper, mode, quantity_name=quantity_name, label_fontsize=None):
    """Violin and strip plot of a scalar metric per age for one population.

    For each age, ``average_wrapper(maps[pop, age, n_ind], q)`` returns a list of scalar
    values: one per trial for real maps, a single pooled value for mock maps.

    Parameters
    ----------
    mode : {'real', 'mock', 'difference'}
        Plot the real values, the mock value(s), or each real value minus the mock value.
    quantity_name : dict
        Maps `q` to the y-axis label.
    label_fontsize : float, optional
        Y-label font size; defaults to the notebook-level `fontsize`.

    Returns
    -------
    dict
        Age -> numpy array of the plotted values.

    Raises
    ------
    ValueError
        If `mode` is not one of the three supported values.
    """
    if mode not in ('real', 'mock', 'difference'):
        raise ValueError(f"mode must be 'real', 'mock' or 'difference', not {mode!r}")
    if label_fontsize is None:
        label_fontsize = fontsize  # notebook-level setting from the plotting cells

    def values(maps, age):
        return np.array(average_wrapper(maps[pop, age, n_ind], q))

    if mode == 'difference':
        dQ = {age: values(polar_maps, age) - values(polar_maps_mock, age)[0] for age in ages}
    elif mode == 'real':
        dQ = {age: values(polar_maps, age) for age in ages}
    else:
        dQ = {age: values(polar_maps_mock, age) for age in ages}

    df   = pd.DataFrame([(f'{age}dpf',dq) for age in dQ for dq in dQ[age]], columns=['Age',q])
    data = dict(x='Age', y=q, data=df, ax=ax)
    violin = sns.violinplot(inner=None, color=pop_colors[pop]['face'], linewidth=1, **data)
    for c in violin.axes.collections:
        c.set_edgecolor(pop_colors[pop]['edge'])
    sns.stripplot(size=3, jitter=0.1, color=pop_colors[pop]['lines'], **data)
    ax.axhline(0, color='k', lw=0.5)
    ax.set_xlabel(None)
    ax.set_ylabel(quantity_name[q], fontsize=label_fontsize)
    ax.locator_params(axis='y', nbins=4)
    return dQ

def average(Q, N, q, v0, binsD, binsθ, Id, Iθ):
    """Area-weighted mean of the per-bin average of quantity `q` over a region of a polar histogram.

    Parameters
    ----------
    Q, N : 2-D arrays (angular bin, distance bin)
        Summed quantity and sample count of each bin. They are copied, not modified.
    q : str
        Quantity name; selects the sign convention applied below.
    v0 : float
        Normalization applied to the force quantities 'Fy', 'Fx' and 'Fr'.
    binsD, binsθ : 1-D arrays
        Distance and angular bin edges (one more entry than there are bins).
    Id, Iθ : numpy indices
        Selection of distance bins and angular bins.

    Returns
    -------
    float
        sum((Q/N) * area) / sum(area) over the selected bins. Empty bins (N == 0)
        contribute nothing to the numerator, but their area is still counted.
    """
    N,Q = N.copy(),Q.copy()
    # Sign conventions: the first axis is the angular axis kept by the caller.
    if q in ['Fy']:
        Q[2:6]  *= -1 # Invert rear side (bins 2-5) => positive means attraction.
    if q in ['Fx']:
        Q[:4]  *= -1 # Invert left side => positive means attraction.
    if q in ['Px']:
        # Invert left side => positive means attraction. Q holds right-turn *counts*, so
        # P(left) = 1 - Q/N is obtained as N - Q before the division by N below.
        Q[:4]  = N[:4] - Q[:4]
    if q in ['Fy','Fx','Fr']:
        Q /= v0
    if q in ['Ω','dΩ']:
        # Invert bins 4-7 (angles in [π, 2π)) so that positive means turning towards the
        # neighbour's heading; assumes the angular axis is the relative orientation φ.
        Q[4:]  *= -1
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        # Average weighted by area in (θ,d) space: a polar bin's area is proportional to
        # Δθ·(d2² − d1²). (A flat mean of Q/N over-weights sparse bins; sum(Q)/sum(N) would
        # instead weight bins by how often they are occupied.)
        A  = (binsθ[1:,None]-binsθ[:-1,None])*(binsD[None,1:]**2-binsD[None,:-1]**2)
        Q_ = np.nansum((Q*A/N)[Iθ,Id])/np.sum(A[Iθ,Id])
    return Q_

def load_polar_maps():
    """Load the cached (real, mock) polar-map dicts written by the histogram cells.

    Returns the tuple (polar_maps, polar_maps_mock). The files are pickles produced by this
    notebook; never point these paths at untrusted data, because unpickling can execute code.
    """
    loaded = []
    for path in (real_data_file, mock_data_file):
        with open(path, 'rb') as fh:
            loaded.append(pickle.load(fh))
    return tuple(loaded)

def significant_symbol(p):
    """Star annotation for a p-value: '***' < 0.001, '**' < 0.01, '*' < 0.05, otherwise ''."""
    star = '*' # '$\star\!$'
    for threshold, n_stars in ((0.001, 3), (0.01, 2), (0.05, 1)):
        if p < threshold:
            return star * n_stars
    return ''

## (d,θ) maps

In [None]:
# One figure per mode: a grid of (θ, d) polar heatmaps with rows = quantity × population and
# columns = age. 'difference' shows real groups minus mock groups.
for mode in ['real', 'mock', 'difference']:
    
    # Reload on every iteration, because the reduction below overwrites the maps in place.
    polar_maps,polar_maps_mock = load_polar_maps()
    
    quantities  = ['N', 'Fy','Fx'] # ,'Fr' ,'Ω' ,'dΩ' ,'Px'
    populations = ['SF','Pa']
    ages        = [7,28,42,70]
    n_ind       = 5

    # Sum over real trials. Sum over the relative orientation angle φ.
    # Real maps: list of per-trial (θ, φ, d) arrays -> sum over axes (trial, φ).
    # Mock maps: one pooled (θ, φ, d) array -> sum over axis φ.
    # Then drop the overflow slot at the end of the θ and d axes.
    for k in polar_maps.keys():
        for q in quantities:
            polar_maps[k][q] = np.nansum(polar_maps[k][q], axis=(0,2))
            polar_maps_mock[k][q] = np.nansum(polar_maps_mock[k][q], axis=1)
            polar_maps[k][q] = polar_maps[k][q][:-1,:-1]
            polar_maps_mock[k][q] = polar_maps_mock[k][q][:-1,:-1]

    # Colour scale for each quantity, symmetric about 0.
    norm = { 'N' : CenteredNorm(halfrange=1), 
             'Fy': CenteredNorm(halfrange=3), 
             'Fx': CenteredNorm(halfrange=3), 
             'Fr': CenteredNorm(halfrange=5),
             'Px': CenteredNorm(halfrange=0.25),
             'Ω' : CenteredNorm(halfrange=0.25),
             'dΩ': CenteredNorm(halfrange=2) }
    if mode in ['real','mock']:
        # A probability (not a difference of probabilities) is centred on 0.5.
        norm['Px'] = CenteredNorm(vcenter=0.5, halfrange=0.25)

    nx,ny   = len(ages),len(quantities)*len(populations)
    gsk     = {'bottom':0.00, 'top':0.97, 'left':0.15, 'right':0.9,
               'hspace':0.05, 'wspace':0.05}
    fig,axs = plt.subplots(ny, nx, figsize=(2*nx,2*ny*0.85), 
                           subplot_kw={'projection':'polar'},
                           gridspec_kw=gsk, squeeze=False)
    # Shift every other row (the first population of each quantity) down slightly.
    for ax in axs[::2].flatten():
        bounds = list(ax.get_position().bounds)
        bounds[1] -= 0.09/ny
        ax.set_position(bounds)

    fontsize = 18

    for i_q,q in enumerate(quantities):
        for i_p,pop in enumerate(populations):
            i_y = i_q*len(populations)+i_p
            for i_a,age in enumerate(ages):
                pm  = polar_maps[pop,age,n_ind]
                pmm = polar_maps_mock[pop,age,n_ind]
                # NOTE: rebinds the notebook-level binsθ/binsD to the cached edges; later cells use them.
                binsθ,binsD = pm['binsθ'],pm['binsD']
                with warnings.catch_warnings():
                    # Empty bins give 0/0 -> NaN; silence the resulting RuntimeWarnings.
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    if q=='N':
                        # Counts -> fraction of all samples.
                        H1,H2 = pm['N']/np.sum(pm['N']), pmm['N']/np.sum(pmm['N'])
                    else:
                        # Per-bin mean of the quantity.
                        H1,H2 = pm[q]/pm['N'], pmm[q]/pmm['N']
                    H  = {'real':H1, 'mock':H2, 'difference':H1-H2}[mode]
                    if q in ['Fy','Fx','Fr']:
                        # Forces are scaled by v0, which is stored only in the mock maps
                        # (mean of trial['v'] over the group).
                        H /= pmm['v0']
                    if q=='N':
                        # Divide by bin area (∝ Δθ·(d2² − d1²)) to get a density per unit area.
                        H /= (binsθ[1:,None]-binsθ[:-1,None])*(binsD[None,1:]**2-binsD[None,:-1]**2)

                ax   = axs[i_y,i_a]
                # plot_polar_heatmap uses its default bin edges (bound when it was defined).
                mesh = plot_polar_heatmap(H, ax=ax, n=10, norm=norm[q])
                ax.set_ylim(0,0.5)

                if i_y==0:
                    ax.set_title(f'{age} dpf', fontsize=fontsize, y=1.1)

            if i_p==0:
                # One colorbar per quantity, centred on its two population rows. It uses the
                # last mesh drawn; all panels of a quantity share norm[q].
                x,y1,dx,dy = axs[i_q*len(populations),-1].get_position().bounds
                x,y2,dx,dy = axs[i_q*len(populations)+1,-1].get_position().bounds
                cax  = fig.add_axes([x+1.15*dx,(y1+dy+y2)/2-0.8*dy,0.1*dx,1.6*dy])
                fig.colorbar(mesh, cax=cax)
                cax.locator_params(axis='y', nbins=3)

            axs[i_y,0].set_ylabel(population_name[pop], fontsize=fontsize, rotation=90, labelpad=20, va='center')
    #         axs[i_y,0].text(-0.2, 0.5, pop, transform=axs[i_y,0].transAxes, fontsize=14)

        # Quantity label to the left of its pair of rows.
        ax = axs[i_q*len(populations)+1,0]
        ax.text(-0.6, 1, quantity_name[q], transform=ax.transAxes, fontsize=fontsize, rotation=90, va='center')


    # plt.savefig(f'figures/d-θ-map_{"-".join(quantities)}_{mode}.png', dpi=200)
    plt.savefig(f'figures/heatmaps_{mode}.png', dpi=200)
    plt.show()

### Polar grid showing focal fish

In [None]:
fig,ax = plt.subplots(1, 1, figsize=(3,3), subplot_kw={'projection':'polar'}, 
                      gridspec_kw={'bottom':0.15, 'top':0.85, 'left':0.15, 'right':0.85})

# Reference grid for the heatmaps: rings at the distance bin edges (tank radii), θ = 0 up.
ax.set_theta_zero_location('N')
ticks  = binsD[1:]
# Label only rings at multiples of 0.1. Compare d/0.1 with the nearest integer instead of
# testing d % 0.1 < 1e-6, which misses edges stored just below a multiple (remainder ≈ 0.1).
labels = [f'{d:.1g}' if np.abs(d/0.1-np.round(d/0.1))<1e-6 else '' for d in ticks]
ax.set_yticks(ticks, labels, x=-0.4, ha='center', va='center')
ax.set_ylim(0,0.5)

# Fish icon at the centre; pixel values equal to 1 are masked so they are not drawn.
img = plt.imread('figure-assets/fish.png')
img = np.ma.masked_where(img==1, img)
ax2 = fig.add_axes([0.43,0.39,0.15,0.15])
ax2.imshow(img, alpha=0.7, cmap=plt.get_cmap('gray'), vmin=0, vmax=1)
ax2.axis('off')

plt.savefig(f'figures/heatmap-grid.png', dpi=200)
plt.show()

### Scalar metric with statistical analysis

In [None]:
polar_maps,polar_maps_mock = load_polar_maps()

# Settings for the force violin plots.
mode        = 'difference'   # one of 'real', 'mock', 'difference'
quantities  = ['Fx','Fy'] # ,'Fr' ,'N'] # ,'Ω' ,'dΩ' ,'Px'
populations = ['SF','Pa']
ages        = [7,28,42,70]
n_ind       = 5

# Sum over the relative orientation angle φ and drop the overflow slot of the θ and d axes.
# Real maps are per-trial stacks (trial, θ, φ, d); mock maps are pooled (θ, φ, d).
for key, pm in polar_maps.items():
    pmm = polar_maps_mock[key]
    pm['v0'] = pmm['v0']   # Real-group forces are scaled by the mock group's v0.
    for name in quantities + ['N']:
        pm[name]  = np.nansum(pm[name], axis=2)[:, :-1, :-1]
        pmm[name] = np.nansum(pmm[name], axis=1)[:-1, :-1]

# Bins entering the scalar metric: all angular bins, distance bins 3-5.
Iφ = slice(None)
# Id = range(0,2) # repulsion zone
Id = range(3,6) # attraction zone (when attraction exists)
def average_wrapper(pm, q, Id=Id, Iθ=Iφ):
    """Scalar metric(s) of quantity `q` for one group's reduced polar maps.

    Returns a one-element list for a pooled 2-D map (mock groups), or one value per trial
    for a stack of 2-D maps (real groups). `Id` and `Iθ` select the distance bins and the
    angular bins. Raises ValueError for any other dimensionality.
    """
    kw = dict(q=q, v0=pm['v0'], binsD=pm['binsD'], binsθ=pm['binsθ'], Id=Id, Iθ=Iθ)
    ndim = np.ndim(pm['N'])
    if ndim == 2:
        return [average(pm[q], pm['N'], **kw)]
    if ndim == 3: # separate histogram for each trial
        return [average(Q, N, **kw) for Q,N in zip(pm[q],pm['N'])]
    raise ValueError(f'expected 2-D or 3-D histograms for {q!r}, got {ndim}-D')


# Layout: one column per population, one row per quantity.
nx,ny   = len(populations),len(quantities)
gsk     = dict(bottom=0.07, top=0.92, left=0.15, right=0.98, hspace=0.4, wspace=0.5)
fig,axs = plt.subplots(ny, nx, figsize=(3*nx,2.5*ny), gridspec_kw=gsk, squeeze=False)
fontsize = 14
quantity_name_ = {'Fy':'Attractive\nSpeeding Force',
                  'Fx':'Attractive\nTurning Force'}

for i_p,pop in enumerate(populations):
    for i_q,q in enumerate(quantities):
        ax = axs[i_q,i_p]
        dQ = plot_violin(ax, q, polar_maps, polar_maps_mock, pop, n_ind, ages,
                         average_wrapper, mode, quantity_name=quantity_name_)
        ax.set_ylim(-0.6,1)
        ax.yaxis.labelpad = -5
        ymin, ymax = ax.get_ylim()
        for i_a,age in enumerate(dQ):
            x  = dQ[age]
            # One-sample test against 0: Wilcoxon signed-rank if Shapiro-Wilk rejects
            # normality, otherwise a t-test.
            p0 = shapiro(x).pvalue
            if p0 < 0.05:
                t = wilcoxon(x)
            else:
                t = ttest_1samp(x, 0)
            print(f'{pop}, {quantity_name[q]}, {age}dpf, median={np.median(x):.3g}, '
                  f'p_shapiro={p0:.3g}, stat={t.statistic:.3g}, pval={t.pvalue:.3g}')
            # Stars just above the top of this age's violin.
            violin_top = ax.collections[i_a].get_datalim(ax.transData).y1
            ax.text(i_a, violin_top + 0.03*(ymax-ymin), significant_symbol(t.pvalue),
                    ha='center', fontsize=fontsize)

for pop,ax in zip(populations,axs[0,:]):
    ax.set_title(population_name[pop], fontsize=fontsize, pad=10)

plt.savefig(f'figures/force-violins_{mode}.png', dpi=200)
plt.show()

#------------------------------------------------

# Region of (θ,d) space used to compute the scalar metric: every angular bin, distance bins Id.
d1,d2 = binsD[Id[0]],binsD[Id[-1]+1]
print(d1,d2)
fig,ax = plt.subplots(1, 1, figsize=(3,3), subplot_kw={'projection':'polar'}, 
                      gridspec_kw={'bottom':0.15, 'top':0.85, 'left':0.15, 'right':0.85})
ax.bar(0, 1).remove()  # NOTE(review): presumably primes the polar axes; confirm it is still needed.
ax.set_theta_zero_location('N')
ticks  = binsD[1:]
# Label only rings at multiples of 0.1 (robust to edges stored just below a multiple).
labels = [f'{d:.1g}' if np.abs(d/0.1-np.round(d/0.1))<1e-6 else '' for d in ticks]
ax.set_yticks(ticks, labels, x=-0.4, ha='center', va='center')
ax.set_ylim(0,0.5)
ax.add_patch(Rectangle((0,d1), width=2*np.pi, height=d2-d1, color='C0'))
plt.savefig(f'figures/force-average.png', dpi=200)
plt.show()

## (d,φ) maps

### Scalar metric with statistical analysis

In [None]:
polar_maps,polar_maps_mock = load_polar_maps()

# Settings for the angular-acceleration violin plot.
mode        = 'difference'   # one of 'real', 'mock', 'difference'
quantities  = ['dΩ']
populations = ['SF','Pa']
ages        = [7,28,42,70]
n_ind       = 5

# Sum over the neighbour's polar angle θ (keeping the relative orientation φ), then drop
# the overflow slot of the φ and d axes. Real maps are per-trial stacks (trial, θ, φ, d);
# mock maps are pooled (θ, φ, d).
for key, pm in polar_maps.items():
    pmm = polar_maps_mock[key]
    pm['v0'] = pmm['v0']
    for name in quantities + ['N']:
        pm[name]  = np.nansum(pm[name], axis=1)[:, :-1, :-1]
        pmm[name] = np.nansum(pmm[name], axis=0)[:-1, :-1]

# Bins entering the scalar metric: all φ bins, distance bins 1-5.
# Iφ = [-2,-1,0,1]
Iφ = slice(None)
# Id = range(0,2) # repulsion zone
Id = range(1,6) # attraction zone (when attraction exists)
def average_wrapper(pm, q, Id=Id, Iθ=Iφ):
    """Scalar metric(s) of quantity `q` for one group's reduced (φ, d) maps.

    Returns a one-element list for a pooled 2-D map (mock groups), or one value per trial
    for a stack of 2-D maps (real groups). `Id` and `Iθ` select the distance bins and the
    angular bins. Raises ValueError for any other dimensionality.
    """
    kw = dict(q=q, v0=pm['v0'], binsD=pm['binsD'], binsθ=pm['binsθ'], Id=Id, Iθ=Iθ)
    ndim = np.ndim(pm['N'])
    if ndim == 2:
        return [average(pm[q], pm['N'], **kw)]
    if ndim == 3: # separate histogram for each trial
        return [average(Q, N, **kw) for Q,N in zip(pm[q],pm['N'])]
    raise ValueError(f'expected 2-D or 3-D histograms for {q!r}, got {ndim}-D')

# Layout: one column per population, one row per quantity.
nx,ny   = len(populations),len(quantities)
gsk     = dict(bottom=0.1, top=0.87, left=0.15, right=0.98, hspace=0.4, wspace=0.5)
fig,axs = plt.subplots(ny, nx, figsize=(3*nx,2.5*ny), gridspec_kw=gsk, squeeze=False)
fontsize = 14
quantity_name_ = {'dΩ':'Aligning Angular\nAcceleration'}

for i_p,pop in enumerate(populations):
    for i_q,q in enumerate(quantities):
        ax = axs[i_q,i_p]
        dQ = plot_violin(ax, q, polar_maps, polar_maps_mock, pop, n_ind, ages,
                         average_wrapper, mode, quantity_name=quantity_name_)
        ax.set_ylim(-1.5,3.6)
        ymin, ymax = ax.get_ylim()
        for i_a,age in enumerate(dQ):
            x  = dQ[age]
            # One-sample test against 0: Wilcoxon signed-rank if Shapiro-Wilk rejects
            # normality, otherwise a t-test.
            p0 = shapiro(x).pvalue
            if p0 < 0.05:
                t = wilcoxon(x)
            else:
                t = ttest_1samp(x, 0)
            print(f'{pop}, {quantity_name[q]}, {age}dpf, median={np.median(x):.3g}, '
                  f'p_shapiro={p0:.3g}, stat={t.statistic:.3g}, pval={t.pvalue:.3g}')
            # Stars just above the top of this age's violin.
            violin_top = ax.collections[i_a].get_datalim(ax.transData).y1
            ax.text(i_a, violin_top + 0.03*(ymax-ymin), significant_symbol(t.pvalue),
                    ha='center', fontsize=fontsize)

for pop,ax in zip(populations,axs[0,:]):
    ax.set_title(population_name[pop], fontsize=fontsize, pad=10)

plt.savefig(f'figures/angular-violins_{mode}.png', dpi=200)
plt.show()

#------------------------------------------------

# Region of (φ,d) space used to compute the scalar metric.
# Iφ selects histogram *bins* (not edges). The wedge therefore runs from the lower edge of the
# first selected bin to the upper edge of the last one, wrapping through 2π if needed.
sel_bins = np.arange(len(binsθ)-1)[Iφ]
φ1 = binsθ[sel_bins[0]]
dφ = (binsθ[sel_bins[-1]+1]-φ1)%(2*np.pi)
dφ = 2*np.pi if dφ==0 else dφ   # A zero span means the full circle is selected.
d1,d2 = binsD[Id[0]],binsD[Id[-1]+1]
print(d1,d2)
fig,ax = plt.subplots(1, 1, figsize=(3,3), subplot_kw={'projection':'polar'}, 
                      gridspec_kw={'bottom':0.15, 'top':0.85, 'left':0.15, 'right':0.85})
ax.bar(0, 1).remove()  # NOTE(review): presumably primes the polar axes; confirm it is still needed.
ax.set_theta_zero_location('N')
ticks  = binsD[1:]
# Label only rings at multiples of 0.1 (robust to edges stored just below a multiple).
labels = [f'{d:.1g}' if np.abs(d/0.1-np.round(d/0.1))<1e-6 else '' for d in ticks]
ax.set_yticks(ticks, labels, x=-0.4, ha='center', va='center')
ax.set_ylim(0,0.5)
ax.add_patch(Rectangle((φ1,d1), width=dφ, height=d2-d1, color='C0'))
plt.savefig(f'figures/angular-average.png', dpi=200)
plt.show()