In [None]:
from importlib import reload
import platform, os, sys, datetime, re, itertools, warnings, pickle, bz2
import os.path as osp
from glob import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict

from importlib import reload
import trilabtracker
reload(trilabtracker)
import trilabtracker.utils as utils

# Resolution is not set globally in rcParams; `dpi` is instead passed
# explicitly to every savefig call in this notebook.
dpi = 150
plt.rcParams['figure.figsize'] = (9, 6)

# Root directory holding one sub-directory per previously computed cut.
analysis_dir = './analysis'

# Load analysis results.

In [None]:
# List previously computed cuts. os.listdir returns entries in arbitrary,
# filesystem-dependent order, so sort them to make the default choice below
# reproducible; keep only directories since each cut is a sub-directory.
cut_dirs = sorted( d for d in os.listdir(analysis_dir)
                   if os.path.isdir(os.path.join(analysis_dir,d)) )
print(cut_dirs)
if not cut_dirs:
    raise FileNotFoundError(f'No analysis cuts found in {analysis_dir!r}')

# Load a specific cut (the first one alphabetically unless changed here).
cut_label = cut_dirs[0]
f    = os.path.join(analysis_dir,cut_label,'trial_data')
# Context manager closes the file handle as soon as loading is done.
# NOTE: pickle.load can execute arbitrary code -- only load files produced
# by a trusted analysis run.
with open(f+'.pik','rb') as fh:
    data = pickle.load(fh)
# Expose every key of the pickled dict as a global name; `trial_data`, used
# by the cells below, is expected to come from here.
globals().update(data)

In [None]:
# Tank diameter for each fish age (keys in dpf). Units are presumably cm,
# given the `R_cm` field derived from it below.
tank_diameter_vs_age = { 7:9.6, 14:10.4, 21:12.8, 28:17.7, 42:24.2, 56:33.8, 70:33.8, 84:33.8 }

# Extract trial metadata from the trial's filename.
def parse_trial_file(trial_file, etho=False):
    """Extract trial metadata from the path of a file inside a trial directory.

    The trial name is the name of the directory containing `trial_file`. Its
    first five underscore-separated fields are read as
    ``pop_day_age_group_nind``, e.g. ``SF_d1_7dpf_g1_n5``.

    Parameters
    ----------
    trial_file : str
        Path of a file located directly inside the trial directory.
    etho : bool, optional
        If True, print the trial name.

    Returns
    -------
    dict
        Keys, in this order:
        - ``trial_dir``, ``trial_name``
        - ``pop``: 'SF', 'Pa' or 'RC'
        - ``age``: int, dpf; 43 and 71 are folded into 42 and 70
        - ``group``: str
        - ``n_ind``: int, first integer found in the 5th field
        - ``R_cm``: half of the tank diameter for that age

    Raises
    ------
    KeyError
        Unknown population code, or an age with no tank diameter.
    ValueError
        Fewer than five fields, or a non-numeric age field.
    """
    trial_dir  = os.path.dirname(trial_file)
    trial_name = os.path.basename(trial_dir)
    if etho:
        print(trial_name)
    pop, _day, age, group, n_ind = trial_name.split('_')[:5]
    pop   = {'sf':'SF', 'pa':'Pa', 'rc':'RC'}[pop.lower()]
    age   = int(age[:-3])                 # drop the 3-character suffix (e.g. 'dpf')
    age   = {43:42, 71:70}.get(age, age)  # fold these ages into the neighbouring age bins
    # Raw string: '\d' is an invalid escape sequence in a plain string literal.
    n_ind = int(re.findall(r'\d+', n_ind)[0])
    # Explicit dict (same key order as before) instead of filtering locals().
    return {
        'trial_dir'  : trial_dir,
        'trial_name' : trial_name,
        'pop'        : pop,
        'age'        : age,
        'group'      : group,
        'n_ind'      : n_ind,
        'R_cm'       : tank_diameter_vs_age[age]/2,
    }


# Select a set of trials to analyze (currently every trial in the loaded cut).
trial_files = list(trial_data.keys())

# Parse the metadata of every trial into a table indexed by trial file.
trials = pd.DataFrame([ parse_trial_file(f) for f in trial_files ], index=trial_files)
grouped_trials = trials.groupby(['pop','age','n_ind'])

# Number of trials of each type: rows are (pop, n_ind), columns are ages.
# fill_value=0 fills absent combinations directly, keeping an integer dtype.
count = ( grouped_trials['trial_dir'].count()
          .unstack('age', fill_value=0)
          .astype(int) )
count

# Group by population and age

### Area and aspect ratio

In [None]:
output_dir = os.path.join(analysis_dir,cut_label,'aggregated','hist_areaAspect')
# makedirs also creates the intermediate 'aggregated' directory, on which
# os.mkdir would fail for a freshly computed cut.
os.makedirs(output_dir, exist_ok=True)

# One figure per population: trial-averaged area and aspect-ratio histograms,
# one curve per age.
hist_names  = ['hist_area','hist_aspect']
hist_titles = ['Fish Area','Fish Aspect Ratio']
for pop,files1 in trials.groupby('pop').groups.items():
    fig,ax = plt.subplots(1,2,figsize=(12,3))
    fig.suptitle(pop)
    for age,files2 in trials.loc[files1].groupby('age').groups.items():
        for i,name in enumerate(hist_names):
            # Bin edges are taken from the first trial; assumes all trials of
            # this group were histogrammed with the same edges.
            b = trial_data[files2[0]][name][1]
            h = np.mean([trial_data[f][name][0] for f in files2],axis=0)
            # Plot at bin centres, like the other histogram cells.
            ax[i].plot((b[1:]+b[:-1])/2,h,label=f'{age}')
    for i,title in enumerate(hist_titles):
        ax[i].set_title(title)
        ax[i].legend(title='Age (dpf)')
    fig.savefig(os.path.join(output_dir,f'hist_areaAspect_{pop}.png'), dpi=dpi)
    plt.close(fig)

### Valid fraction

In [None]:
output_dir = os.path.join(analysis_dir,cut_label,'aggregated','validFraction')
# makedirs also creates the intermediate 'aggregated' directory if needed.
os.makedirs(output_dir, exist_ok=True)

# For each trial type (pop, age, n_ind), collect the valid fraction of every
# trial under every cut: grouped_valid_fractions[group][cut] -> list of values.
grouped_valid_fractions = {}
for g,files in grouped_trials.groups.items():
    grouped_valid_fractions[g] = defaultdict(list)
    for f in files:
        for cut,vf in trial_data[f]['valid_fraction'].items():
            grouped_valid_fractions[g][cut].append(vf)

# Boxplot of the fraction of valid frames for each type of trial, one figure
# per type of invalidity (cut). Cut names are taken from the first trial type;
# this assumes every trial type has the same cuts.
cut_types = next(iter(grouped_valid_fractions.values())).keys()
for cut_type in cut_types:
    vf_summary = { f'{k[0]}_{k[1]}dpf_n{k[2]}':vf[cut_type] for k,vf in grouped_valid_fractions.items() }
    fig,ax = plt.subplots(figsize=(12,4))
    ax.boxplot(list(vf_summary.values()))
    # Boxes sit at positions 1..N by default. Ticks are labelled explicitly
    # because boxplot's `labels` kwarg is deprecated in newer Matplotlib.
    ax.set_xticks(range(1,len(vf_summary)+1))
    ax.set_xticklabels(list(vf_summary.keys()), rotation=90)
    ax.set_ylabel('Valid fraction')
    ax.set_ylim(0,1.02)
    ax.set_title(f'{cut_type} cut')
    fig.savefig(os.path.join(output_dir,f'validFraction_{cut_type}-cut.png'), dpi=dpi)
    plt.close(fig)

In [None]:
# Identify trials with a higher fraction of invalid frames: report every trial
# whose 'final' valid fraction is below 0.8, with its valid fraction under
# each cut, followed by a blank line.
for trial_file, trial_name in trials['trial_name'].items():
    fractions = trial_data[trial_file]['valid_fraction']
    if not fractions['final'] < 0.8:
        continue
    summary = ', '.join(f'{k}:{v:.2g}' for k, v in fractions.items())
    print(trial_name, summary, '', sep='\n')

### Wall distance

In [None]:
output_dir = os.path.join(analysis_dir,cut_label,'aggregated','hist_dWall')
# makedirs also creates the intermediate 'aggregated' directory if needed.
os.makedirs(output_dir, exist_ok=True)

pops   = trials['pop'].unique()
ages   = np.sort(trials['age'].unique())
colors = dict(zip( ages, plt.cm.viridis(np.linspace(0,1,len(ages))) ))
group_sizes = [1,2,5]

# One figure per population, one panel per group size, one curve per age.
for pop in pops:
    fig,axs = plt.subplots(1,len(group_sizes),figsize=(15,4))
    for age in ages:
        R = tank_diameter_vs_age[age]/2    # tank radius for this age
        for ax,n_ind in zip(axs,group_sizes):
            key = (pop,age,n_ind)
            if key not in grouped_trials.groups:
                continue    # no trials of this type: leave the curve out
            files = grouped_trials.get_group(key).index
            # Bin edges from the first trial, rescaled to tank radii.
            b = trial_data[files[0]]['hist_dWall'][1]/R
            x = (b[1:]+b[:-1])/2
            # Pool the counts of all trials, then normalise to a probability.
            h = np.nansum([ trial_data[f]['hist_dWall'][0] for f in files ],axis=0)
            h = h/np.sum(h)
            # Divide by the annulus area 2*pi*r*dr, where r = 1-x is the distance
            # from the tank centre in radii, to get a density per unit area.
            h = h/(2*np.pi*(1-x)*np.diff(b))
            ax.plot(x, h, color=colors[age], label=f'{age}dpf')
    for ax,n_ind in zip(axs,group_sizes):
        ax.set_xlabel('distance to the wall (tank radii)')
        ax.set_ylabel('density per unit area (arbitrary units)')
        ax.set_yscale('log')
        ax.set_title(f'{pop}_n{n_ind}')
        ax.legend(ncol=2)
    fig.savefig(os.path.join(output_dir,f'hist_dWall_{pop}.png'), dpi=dpi)
    plt.close(fig)

### Speed distribution

In [None]:
output_dir = os.path.join(analysis_dir,cut_label,'aggregated','hist_speed')
# makedirs also creates the intermediate 'aggregated' directory if needed.
os.makedirs(output_dir, exist_ok=True)

pops   = trials['pop'].unique()
ages   = np.sort(trials['age'].unique())
colors = dict(zip( ages, plt.cm.viridis(np.linspace(0,1,len(ages))) ))
group_sizes = [1,2,5]

# One figure per population, one panel per group size, one curve per age.
for pop in pops:
    fig,axs = plt.subplots(1,len(group_sizes),figsize=(15,4))
    for age in ages:
        R = tank_diameter_vs_age[age]/2    # tank radius for this age
        for ax,n_ind in zip(axs,group_sizes):
            key = (pop,age,n_ind)
            if key not in grouped_trials.groups:
                continue    # no trials of this type: leave the curve out
            files = grouped_trials.get_group(key).index
            # Speed bin edges from the first trial, rescaled to tank radii/s
            # (assumes the stored edges are in tank-diameter units, presumably cm/s).
            b = trial_data[files[0]]['hist_v'][1]/R
            x = (b[1:]+b[:-1])/2
            # Pool the counts of all trials, then normalise to a probability.
            h = np.nansum([ trial_data[f]['hist_v'][0] for f in files ],axis=0)
            h = h/np.sum(h)
            ax.plot(x, h, color=colors[age], label=f'{age}dpf')
    for ax,n_ind in zip(axs,group_sizes):
        ax.set_xlabel('speed (tank radii/s)')
        ax.set_ylabel('frequency')
        ax.set_yscale('log')
        ax.set_title(f'{pop}_n{n_ind}')
        ax.legend()
    fig.savefig(os.path.join(output_dir,f'hist_speed_{pop}.png'), dpi=dpi)
    plt.close(fig)

### Angular speed distribution

In [None]:
output_dir = os.path.join(analysis_dir,cut_label,'aggregated','hist_vAng')
# makedirs also creates the intermediate 'aggregated' directory if needed.
os.makedirs(output_dir, exist_ok=True)

pops   = trials['pop'].unique()
ages   = np.sort(trials['age'].unique())
colors = dict(zip( ages, plt.cm.viridis(np.linspace(0,1,len(ages))) ))
group_sizes = [1,2,5]

# One figure per population, one panel per group size, one curve per age.
for pop in pops:
    fig,axs = plt.subplots(1,len(group_sizes),figsize=(15,4))
    for age in ages:
        for ax,n_ind in zip(axs,group_sizes):
            key = (pop,age,n_ind)
            if key not in grouped_trials.groups:
                continue    # no trials of this type: leave the curve out
            files = grouped_trials.get_group(key).index
            # Bin edges from the first trial (used as-is, no rescaling).
            b = trial_data[files[0]]['hist_vAng'][1]
            x = (b[1:]+b[:-1])/2
            # Pool the counts of all trials, then normalise to a probability.
            h = np.nansum([ trial_data[f]['hist_vAng'][0] for f in files ],axis=0)
            h = h/np.sum(h)
            ax.plot(x, h, color=colors[age], label=f'{age}dpf')
    for ax,n_ind in zip(axs,group_sizes):
        ax.set_xlabel('angular speed (rad/s)')
        ax.set_ylabel('frequency')
        ax.set_yscale('log')
        ax.set_title(f'{pop}_n{n_ind}')
        ax.legend()
    fig.savefig(os.path.join(output_dir,f'hist_vAng_{pop}.png'), dpi=dpi)
    plt.close(fig)

### Joint distribution of pair distance and pair angle

In [None]:
# for (pop,age,n_ind),files in grouped_trials.groups.items():
#     if n_ind==1:
#         continue
#     h,b1,b2 = trial_data[files[0]]['hist_distAng']
#     H = np.array([ trial_data[f]['hist_distAng'][0] for f in files ])
#     h = np.nanmean(H,axis=0)
#     plt.pcolormesh(bins_d, bins_a*180/np.pi, h.T, cmap='Oranges')
#     plt.xlabel('pair distance (cm)')
#     plt.ylabel('pair angle (deg)')
#     plt.colorbar()
#     plt.title(cut_label)
#     plt.suptitle(f'{pop}_{age}dpf_n{n_ind}')
#     plt.show()
# #     break

In [None]:
pops = trials['pop'].unique()
ages = np.sort(trials['age'].unique())

output_dir = os.path.join(analysis_dir,cut_label,'aggregated','hist_distAng')
# makedirs also creates the intermediate 'aggregated' directory if needed.
os.makedirs(output_dir, exist_ok=True)

# One figure per (population, age): joint histogram of pair distance and pair
# angle, averaged over trials, for groups of 2 and of 5.
for pop in pops:
    # Fixed colour-scale ceiling per population so maps are comparable across ages.
    vmax = 0.002 if pop=='Pa' else 0.006
    for age in ages:
        fig,axs = plt.subplots(1,2,figsize=(12,4))
        for ax,n_ind in zip(axs,[2,5]):
            key = (pop,age,n_ind)
            if key not in grouped_trials.groups:
                continue    # no trials of this type: leave the panel empty
            files = grouped_trials.get_group(key).index
            # Bin edges (distance, angle in radians) from the first trial.
            _,b1,b2 = trial_data[files[0]]['hist_distAng']
            H = np.array([ trial_data[f]['hist_distAng'][0] for f in files ])
            h = np.nanmean(H,axis=0)
            # nansum: a bin that is NaN in every trial must not turn the whole
            # normalised map into NaN.
            h = h/np.nansum(h)
            m = ax.pcolormesh( b1, b2*180/np.pi, h.T, cmap='Oranges',
                               vmin=0, vmax=vmax )
            ax.set_xlabel('pair distance (cm)')
            ax.set_ylabel('pair angle (deg)')
            fig.colorbar(m,ax=ax)
            ax.set_title(f'{pop}_{age}dpf_n{n_ind}')
        fig.savefig(os.path.join(output_dir,f'hist_distAng_{pop}_{age}dpf.png'), dpi=dpi)
        plt.close(fig)

plt.close('all')

In [None]:
pops   = trials['pop'].unique()
ages   = np.sort(trials['age'].unique())
colors = dict(zip( ages, plt.cm.viridis(np.linspace(0,1,len(ages))) ))

output_dir = os.path.join(analysis_dir,cut_label,'aggregated','polar')
# makedirs also creates the intermediate 'aggregated' directory if needed.
os.makedirs(output_dir, exist_ok=True)

# One figure per population: trial-averaged 'polar' profile versus pair
# distance (in tank radii), one panel per group size, one curve per age.
for pop in pops:
    # Manually placed axes leave room right of each panel for its legend.
    fig = plt.figure(figsize=(15,4))
    axs = [ fig.add_axes([i/2+0.05,0.12,0.3,0.8]) for i in range(2) ]
    for age in ages:
        R = tank_diameter_vs_age[age]/2    # tank radius for this age
        for ax,n_ind in zip(axs,[2,5]):
            key = (pop,age,n_ind)
            if key not in grouped_trials.groups:
                continue    # no trials of this type: leave the curve out
            files = grouped_trials.get_group(key).index
            b = trial_data[files[0]]['polar'][1]
            x = (b[1:]+b[:-1])/2
            P = np.array([ trial_data[f]['polar'][0] for f in files ])
            # Average over trials, ignoring NaN bins. catch_warnings alone does
            # not suppress anything; the simplefilter call silences the
            # "Mean of empty slice" warning for bins that are NaN in every trial.
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', category=RuntimeWarning)
                p = np.nanmean(P,axis=0)
            ax.plot( x/R, p, lw=2, color=colors[age],
                     label=f'{pop}_{age}dpf_n{n_ind}' )
    for ax in axs:
        ax.set_xlabel('pair distance/tank radius')
        ax.set_ylabel('mean cosine of pair angle')
        ax.set_ylim(-1,1)
        ax.set_title(cut_label)
        ax.legend(loc='center left', bbox_to_anchor=(1.01,0.5))
    fig.savefig(os.path.join(output_dir,f'polar_{pop}.png'), dpi=dpi)
    plt.close(fig)
