In [None]:
import os, sys, datetime, re, pickle, bz2
from os.path import join
from glob import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import itertools as itt
from tqdm import tqdm
sys.path.append('..')
import trilabtracker as tt
from trilabtracker.gui_classes import Track

# import trilabtracker
# from importlib import reload
# reload(trilabtracker)
# for m in trilabtracker.__all__:
#     eval(f'reload(trilabtracker.{m})')
# import trilabtracker as tt
# from trilabtracker.gui_classes import Track

# Prepare trial data.

Build a table (a pandas DataFrame indexed by the path of each trial file) of the trials to analyze, with basic metadata for each: trial directory, population, age, number of individuals, tank radius, etc.

In [None]:
# Tank diameter (cm) used at each age (days), keyed by the integer age parsed
# from the trial name.
tank_diameter_vs_age = { 7:9.6, 14:10.4, 21:12.8, 28:17.7, 42:33.8,
                         56:33.8, 70:33.8,
                         # NOTE(review): was 3.8 -- almost certainly a dropped digit,
                         # since the tank stays at 33.8 cm from 42 days on. TODO confirm.
                         84:33.8 }

def parse_trial_file(trial_file, etho=False):
    """Extract trial metadata from a trial's file name.

    The trial name is made of underscore-separated fields
    ``<pop>_<day>_<age>_<group>_<n_ind>[_...]``, for example (illustrative)
    ``SF_d1_28dpf_g3_5fish``.

    Parameters
    ----------
    trial_file : str
        Path to a tracker trial file; the trial name is the name of its parent
        directory. With ``etho=True`` the trial name is instead the second
        '-'-separated field of the file's basename.
    etho : bool
        Select the alternative naming scheme described above.

    Returns
    -------
    dict
        Keys, in this order:
        trial_dir  -- parent directory of `trial_file` (None when etho=True)
        trial_name -- the parsed name
        pop        -- 'SF', 'Pa' or 'RC'
        age        -- int; the last 3 characters of the age field are dropped,
                      and 43 / 71 are merged into 42 / 70
        group      -- group field, as a string
        n_ind      -- number of individuals (first integer in its field)
        R_cm       -- tank radius in cm, from `tank_diameter_vs_age`

    Raises
    ------
    KeyError
        Unknown population code, or an age missing from the tank table.
    ValueError
        Fewer than 5 fields in the name, or a non-numeric age.
    """
    if etho:
        trial_dir = None
        trial_name = os.path.basename(trial_file).split('-')[1]
    else:
        trial_dir = os.path.dirname(trial_file)
        trial_name = os.path.basename(trial_dir)

    pop, _day, age, group, n_ind = trial_name.split('_')[:5]
    pop = {'sf': 'SF', 'pa': 'Pa', 'rc': 'RC'}[pop.lower()]
    age = int(age[:-3])                 # drop the 3-character unit suffix
    # Ages 43 and 71 are pooled with the 42 / 70 groups (presumably recordings
    # made one day off schedule).
    age = {43: 42, 71: 70}.get(age, age)
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    n_ind = int(re.findall(r'\d+', n_ind)[0])
    R_cm = tank_diameter_vs_age[age] / 2

    return {'trial_dir': trial_dir, 'trial_name': trial_name, 'pop': pop,
            'age': age, 'group': group, 'n_ind': n_ind, 'R_cm': R_cm}


# Select a set of trials to analyze: every tracked trial.pik in the dataset.
trial_files = sorted(glob('../../dataset2/tracking/*/trial.pik'))

# One row of metadata per trial, indexed by the path of its trial file.
trials = pd.DataFrame([parse_trial_file(path) for path in trial_files],
                      index=trial_files)
grouped_trials = trials.groupby(['pop','age','n_ind'])

# Number of trials of each type: rows are (pop, n_ind), columns are ages;
# combinations without any trial show 0.
count = (grouped_trials['trial_dir']
         .count()
         .unstack('age')
         .fillna(0)
         .astype(int))
count

# Define analysis functions.

In [None]:
def compute_kinematics(trial,wall_distance=False):
    """Add positions, time, velocities and accelerations in tank coordinates to `trial`.

    Reads trial['tank'] ('xc', 'yc', 'R'), trial['R_cm'], trial['n_ind'],
    trial['data'], trial['frame_list'] and trial['fps']. Only the first n_ind
    objects and the first three channels of trial['data'] are used:
    channels 0-1 are centred on (xc, yc), scaled by R_cm / tank R and have
    their y axis flipped; channel 2 is treated as an angle and unwrapped.

    Parameters
    ----------
    trial : dict
        Trial record; updated in place (trial['data'] itself is not modified).
    wall_distance : bool
        Unused; kept for backward compatibility.

    Returns
    -------
    dict
        The same `trial`, with new keys:
        pos    -- (frame, individual, 3) converted positions and angle
        time   -- frame_list / fps
        vel    -- d(pos)/d(time) along frames (channel 2: angular velocity)
        acc    -- d(vel)/d(time) along frames
        v      -- planar speed, hypot of vel channels 0 and 1
        d_wall -- R_cm minus the distance to the tank centre
    """
    center = np.array([trial['tank']['xc'], trial['tank']['yc']])
    px2cm = trial['R_cm'] / trial['tank']['R']
    n = trial['n_ind']

    # Copy so trial['data'] stays untouched; extra tracked objects are dropped.
    pos = trial['data'][:, :n, :3].copy()
    pos[:, :, :2] = (pos[:, :, :2] - center[None, None, :]) * px2cm
    pos[:, :, 1] = -pos[:, :, 1]   # flip the y axis (presumably image y grows downwards)

    # Unwrap each individual's angle over its non-NaN samples only, because
    # np.unwrap would propagate a NaN into every later sample.
    for j in range(n):
        ok = ~np.isnan(pos[:, j, 2])
        pos[ok, j, 2] = np.unwrap(pos[ok, j, 2])

    time = np.asarray(trial['frame_list']) / trial['fps']
    vel = np.gradient(pos, time, axis=0)
    acc = np.gradient(vel, time, axis=0)
    v = np.hypot(vel[:, :, 0], vel[:, :, 1])
    # Distance to the wall of a circular tank of radius R_cm centred at the origin.
    d_wall = trial['R_cm'] - np.hypot(pos[:, :, 0], pos[:, :, 1])

    trial.update(pos=pos, time=time, vel=vel, acc=acc, v=v, d_wall=d_wall)
    return trial


def compute_cuts(trial,ranges):
    """Flag which (frame, individual) samples pass the quality cuts.

    Inputs are read from `trial` directly; unlike before, no notebook globals
    are created or overwritten.

    Parameters
    ----------
    trial : dict
        Must contain 'pos', 'vel', 'acc' (frame, individual, 3) and 'd_wall',
        'v' (frame, individual), as produced by compute_kinematics.
    ranges : dict
        Inclusive [min, max] bounds under 'd_wall', 'v' and 'v_ang'
        ('v_ang' is compared with vel[:, :, 2]).

    Returns
    -------
    dict
        `trial`, updated in place with:
        valid          -- bool array (frame, individual, 7); the last axis is
                          [pos has no NaN, vel has no NaN, acc has no NaN,
                           d_wall in range, v in range, v_ang in range,
                           all of the above]
        valid_fraction -- fraction of samples passing each check
    """
    pos, vel, acc = trial['pos'], trial['vel'], trial['acc']
    d_wall, v = trial['d_wall'], trial['v']

    def in_range(x, key):
        # NaN compares False, so missing values are always rejected.
        return np.logical_and(x >= ranges[key][0], x <= ranges[key][1])

    checks = [
        ~np.any(np.isnan(pos), axis=2),
        ~np.any(np.isnan(vel), axis=2),
        ~np.any(np.isnan(acc), axis=2),
        in_range(d_wall, 'd_wall'),
        in_range(v, 'v'),
        in_range(vel[:, :, 2], 'v_ang'),
    ]
    checks.append(np.logical_and.reduce(checks))   # combined flag
    valid = np.stack(checks, axis=2).astype(np.bool_)

    n_total = valid.shape[0] * valid.shape[1]
    n_valid = np.count_nonzero(valid, axis=(0, 1))
    # The d_wall fraction is relative to samples with finite positions; the
    # v and v_ang fractions are relative to samples with finite velocities.
    valid_fraction = { 'nan_pos' : n_valid[0]/n_total,
                       'nan_vel' : n_valid[1]/n_total,
                       'nan_acc' : n_valid[2]/n_total,
                       'd_wall'  : n_valid[3]/n_valid[0],
                       'v'       : n_valid[4]/n_valid[1],
                       'v_ang'   : n_valid[5]/n_valid[1],
                       'final'   : n_valid[6]/n_total     }

    trial.update({'valid': valid, 'valid_fraction': valid_fraction})
    return trial

def apply_cuts(trial, keys=('data', 'vel', 'acc', 'v')):
    """Set to NaN every sample rejected by the combined cut of compute_cuts.

    Uses trial['valid'][:, :, 6]. For each array named in `keys`, the entries
    at rejected (frame, individual) positions are set to NaN in place, across
    all trailing channels. Only the first ``trial['valid'].shape[1]``
    individuals of each array are touched. trial['data'] can hold extra
    tracked objects beyond n_ind, and masking it directly used to raise an
    IndexError.

    NOTE(review): the default keys cut the raw 'data' array but not the
    kinematic 'pos' / 'd_wall' arrays; pass `keys` explicitly if those should
    be cut too.

    Parameters
    ----------
    trial : dict
        Must contain 'valid' plus every array named in `keys` (float dtype,
        frames on axis 0, individuals on axis 1).
    keys : iterable of str
        Names of the arrays to cut.

    Returns
    -------
    dict
        The same `trial`, modified in place.
    """
    rejected = ~trial['valid'][:, :, 6]
    n_cols = rejected.shape[1]
    for k in keys:
        # Basic slicing returns a view, so the assignment writes through.
        view = trial[k][:, :n_cols]
        view[rejected] = np.nan
    return trial

# Default inclusive [min, max] cut ranges: no restriction on wall distance or
# angular velocity, and speed must be >= 0.
default_cut_ranges = dict( d_wall=[-np.inf,np.inf],
                           v=[0,np.inf],
                           v_ang=[-np.inf,np.inf] )

def preprocess_trial(trial_file, cut_ranges=None):
    """Load a trial, compute its kinematics and optionally apply quality cuts.

    Parameters
    ----------
    trial_file : str
        Path passed to parse_trial_file and tt.load_trial.
    cut_ranges : dict or None
        Overrides for `default_cut_ranges` (any subset of 'd_wall', 'v',
        'v_ang'). With None, no cuts are computed or applied.

    Returns
    -------
    dict
        Metadata from the file name, updated with the loaded trial, then the
        kinematics from compute_kinematics and, if requested, the cuts.
    """
    trial = parse_trial_file(trial_file)
    trial.update(tt.load_trial(trial_file))
    trial = compute_kinematics(trial)
    if cut_ranges is not None:
        # Build a fresh dict (the previous `copy(...)` was never imported) so
        # the module-level defaults are never modified.
        ranges = {**default_cut_ranges, **cut_ranges}
        trial = compute_cuts(trial, ranges)
        trial = apply_cuts(trial)
    return trial
    
def matching_trials(pop=None, age=None, n_ind=None, df=trials):
    """Return the index labels (trial files) of the rows of `df` that match.

    Each of `pop`, `age` and `n_ind` that is not None must equal the value in
    the column of the same name; criteria left as None are ignored. Note that
    the default `df` is the `trials` table as it was when this cell ran.
    """
    criteria = {'pop': pop, 'age': age, 'n_ind': n_ind}
    mask = pd.Series(data=True, index=df.index)
    for column, wanted in criteria.items():
        if wanted is None:
            continue
        mask = mask & (df[column] == wanted)
    selected = df[mask]
    return selected.index.tolist()
    
# t0    = datetime.datetime.now()
# trial = preprocess_trial(trial_files[0])
# print(datetime.datetime.now()-t0)

In [None]:
# Show raw and fixed trajectory around a fix.
# Loads the first age-28 trial, re-applies the manual GUI fixes stored in
# gui_fixes.pik to a fresh Track, and overlays the raw trajectories (thick,
# translucent) with the fixed ones (thin) over a hand-picked frame window.
# Assumes this trial has a gui_fixes.pik file.

trial_file = matching_trials(age=28)[0]
print(trial_file)
trial = tt.load_trial(trial_file)
# Exposes every key of the loaded trial as a notebook-level variable; the lines
# below rely on `data`, `trial_dir` and `n_ind` coming from here.
globals().update(trial)

data0 = data.copy()   # raw (unfixed) tracks, kept for comparison
print(data0.shape)
# NOTE(review): Track is already imported in the first cell; this re-import is redundant.
from trilabtracker.gui_classes import Track
track = Track(trial_dir)
fixes = tt.load_pik(join(trial_dir,'gui_fixes.pik'))['fixes']
for fix in fixes:
    track.fix(*fix, recompute_bad_frames=False)
data = track.tracks[:,:n_ind,:]   # fixed tracks, first n_ind objects only

plt.figure(figsize=(6,6))
# Frame window and individuals to inspect (hand-picked for this trial).
I,J = slice(6050,6150), [0,2,4]
for j in J:
    c = f'C{j}'
    # Raw track as a wide translucent band underneath...
    plt.plot(*data0[I,j,:2].T, lw=8, alpha=0.2, color=c, zorder=-5)
    # ...and the fixed track as a thin line on top, in the same colour.
    plt.plot(*data[I,j,:2].T, lw=1, color=c)
# plt.axis('equal')
plt.show()

In [None]:
# Export every trial's (fixed) trajectories to ../excel/<trial_name>.xlsx, one
# sheet per fish. Trials that already have an Excel file are skipped, so the
# cell can be re-run to resume an interrupted export.
for trial_file in tqdm(trials.index):
    trial = tt.load_trial(trial_file)
    globals().update(trial)
    trial_name = os.path.split(trial['trial_dir'])[1]

    fx = f'../excel/{trial_name}.xlsx'
    if os.path.exists(fx):
        continue

    try:
        # Re-apply the manual GUI fixes, when this trial has any.
        fixes_file = join(trial_dir, 'gui_fixes.pik')
        if os.path.exists(fixes_file):
            track = Track(trial_dir)
            fixes = tt.load_pik(fixes_file)['fixes']
            for fix in fixes:
                track.fix(*fix, recompute_bad_frames=False)
            data = track.tracks[:,:n_ind,:]

        # raw.txt holds a file name (presumably the raw recording); the
        # timestamps are read from the .txt file with the same stem, resolved
        # relative to the trial directory.
        with open(join(trial_dir, 'raw.txt')) as f:
            raw_name = f.read()
        time_file = os.path.splitext(raw_name)[0] + '.txt'
        time_file = os.path.normpath(join(os.path.relpath(trial_dir, os.getcwd()), time_file))
        time = pd.read_csv(time_file, index_col=0)
        time.columns = ['Time']
        # Undo wrap-around: after every backwards jump, add 128 to all later samples.
        for i in np.nonzero((time.diff()<0).values)[0]:
            time.iloc[i:] += 128
        time = time[time.index<len(data)]
        time -= time.iloc[0]

        # NOTE(review): compute_kinematics treats trial data as pixels (it
        # scales by R_cm / tank R), but no conversion is applied here although
        # the columns say cm -- confirm the intended units.
        # The openpyxl engine is required by the cell-formatting code below;
        # the context manager saves the workbook (ExcelWriter.save() is gone in pandas 2).
        with pd.ExcelWriter(fx, engine='openpyxl') as writer:
            for i in range(n_ind):
                df = pd.DataFrame(data[:,i,:3], columns=['X (cm)', 'Y (cm)', 'Direction (rad)'])
                df = pd.concat([time,df], axis=1)
                df.to_excel(writer, sheet_name=f'fish {i+1}', index=False)
            for sheet in writer.sheets.values():
                for col in sheet.columns:
                    for cell in col:
                        if isinstance(cell.value,float):
                            cell.number_format = '0.000'
                    sheet.column_dimensions[col[0].column_letter].width = 15

    except Exception as err:
        # Remove any partial workbook: otherwise the skip check above would
        # treat this trial as exported on the next run.
        if os.path.exists(fx):
            os.remove(fx)
        print(f'{trial_name}: {err!r}')