In [58]:
%matplotlib inline
%run /media/turritopsis/katie/grooming/t1-grooming/grooming_functions.ipynb

import os 
import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import gc
from scipy import signal, stats

warnings.filterwarnings('ignore')

In [59]:
# Behavior label, classifier directory holding training_data.parquet, and
# output directory for the curated summaries.
# Earlier classifier run: '/media/turritopsis/katie/apviz/classifiers/2021_01_05'
behavior, prefix, prefix_out = (
    't1_grooming',
    '/media/turritopsis/katie/apviz/classifiers/2021_02_19',
    '/media/turritopsis/katie/grooming/summaries/v3-b2',
)

In [60]:
# Load the labelled training data (angles, 3D pose) stored next to the classifiers.
data = pd.read_parquet(
    os.path.join(prefix, 'training_data.parquet'),
    engine='fastparquet',
)

In [61]:
# Keep only the labels for `behavior` and rebuild flyid as "<fly #> <date>",
# the same format built for the curated bout table below.
# NOTE(review): assumes the original flyid looks like "<date> <token> <fly #>"
# (first and third space-separated fields are used) -- confirm. Not idempotent:
# re-running this cell on the rebuilt flyid would split it differently.
# .copy() makes `data` an independent frame before columns are written into it.
data = data[data.behavior == behavior].copy()
fly_decomp = data['flyid'].str.split(' ', n=2, expand=True)
data['Date'] = fly_decomp[0]
data['Fly #'] = fly_decomp[2]
data['flyid'] = data['Fly #'].astype(str) + ' ' + data['Date'].astype(str)

In [62]:
# List the per-line summary files in the top level of `p`, sorted by name.
p = '/media/turritopsis/pierre/gdrive/Tuthill Lab Shared/Pierre/summaries/v3-b2/lines-t1_grooming'
root, dirs, files = next(os.walk(p))
files.sort()

In [63]:
# Pick the genotype groups to analyse, keeping group order:
#   Berlin-WT files whose names mention neither 'offball' nor 'headless',
#   the 39A11 Chrimson line, antennal-grooming lines, eye-grooming lines.
wt_flies = [f for f in files
            if 'Berlin-WT' in f and not ('offball' in f or 'headless' in f)]
hg_flies = [f for f in files if '39A11-gal4xUAS-10x-ChrimsonR' in f]
ag_flies = [f for f in files if 'AntennalGrooming' in f]
eg_flies = [f for f in files if 'EyeGrooming' in f]
files = wt_flies + hg_flies + ag_flies + eg_flies

In [64]:
# Summary file name -> genotype label written into the `line` column.
# All Berlin wild-type recordings share a single label.
line_dict = dict.fromkeys(
    ['evyn--Berlin-WT.pq',
     'sarah--rv1-Berlin-WT.pq',
     'sarah--rv3-Berlin-WT.pq',
     'sarah--rv4-Berlin-WT.pq'],
    'berlin wt',
)
line_dict.update([
    ('sarah--rv2-39A11-gal4xUAS-10x-ChrimsonR.pq', '39A11 (head)'),
    ('sarah--rv3-AntennalGrooming-w;25F11-AD;27H08-DBDxUAS-10x-ChrimsonR.pq',
     '25F11AD;27H08DBD (antennal)'),
    ('sarah--rv3-AntennalGrooming-w;VT005525-AD(100C03);27H08-DBDxUAS-10x-ChrimsonR.pq',
     'VT005525AD;27H08DBD (antennal)'),
    ('sarah--rv3-EyeGrooming-w;VT017251-LexA(3012796)xLexAop-Chrimson-tdTomato.pq',
     'VT017251 (eye)'),
])

In [65]:
# Gather every `behavior` bout of at least `thresh` rows from each summary file.
# Each kept bout gets a unique, contiguous `behavior_bout` id (starting at 1)
# and its genotype label from `line_dict` in the `line` column.
bout_num = 1   # id for the next kept bout
thresh = 50    # minimum bout length, in rows
session_datas = []

class_col = behavior + '_class'
bout_col = behavior + '_bout_number'

for file in files:
    print(file)
    path = os.path.join(p, file)
    file_data = pd.read_parquet(path, engine='fastparquet')
    if len(file_data) == 0:
        continue

    # Drop classifier outputs and QC/summary columns; append the two columns
    # added per bout below.
    cols_good = np.unique([v for v in file_data.columns
              if not some_contains(v, ['_range', '_score', '_error', '_ncams',
                                       '_prob', '_class', '_bout_number'])])
    cols_good = np.append(cols_good, ['behavior_bout', 'line'])

    for session in np.unique(file_data.Date):
        session_data = file_data[file_data['Date'] == session]
        # NOTE(review): assumes `class_col` is a boolean column -- confirm.
        dsub = session_data[session_data[class_col]]
        # One pass over the session: groupby visits bout numbers in ascending
        # order and skips NaN keys, like np.unique over the non-NaN bout numbers.
        for bout_id, bout in dsub.groupby(bout_col, sort=True):
            if len(bout) < thresh:
                continue
            # assign() returns a new frame instead of writing into a slice.
            bout = bout.assign(behavior_bout=bout_num, line=line_dict[file])
            session_datas.append(bout[cols_good])
            bout_num += 1

if not session_datas:
    raise ValueError('no %s bouts with >= %d rows found in %s' % (behavior, thresh, p))

df = pd.concat(session_datas)
df['flyid'] = df['Fly #'].astype(str) + ' ' + df['Date'].astype(str)

evyn--Berlin-WT.pq
sarah--rv1-Berlin-WT.pq
sarah--rv3-Berlin-WT.pq
sarah--rv4-Berlin-WT.pq
sarah--rv2-39A11-gal4xUAS-10x-ChrimsonR.pq
sarah--rv3-AntennalGrooming-w;25F11-AD;27H08-DBDxUAS-10x-ChrimsonR.pq
sarah--rv3-AntennalGrooming-w;VT005525-AD(100C03);27H08-DBDxUAS-10x-ChrimsonR.pq
sarah--rv3-EyeGrooming-w;VT017251-LexA(3012796)xLexAop-Chrimson-tdTomato.pq


In [66]:
# Distinct flyids in the relabelled training data; used to subset the bouts below.
fly_ids = np.unique(data['flyid'])
print(fly_ids)

['10_0 10.22.20' '1_0 10.22.20' '1_0 11.5.20' '1_0 11.6.20' '1_0 12.10.20'
 '1_0 12.16.20' '1_0 5.22.19' '1_0 5.24.19' '1_0 5.27.19' '1_0 6.10.20'
 '1_0 6.11.20' '1_0 6.15.20' '1_0 6.3.20' '1_0 6.4.20' '1_0 6.5.20'
 '1_0 7.10.19' '1_1 10.19.20' '1_1 10.6.20' '1_1 11.5.20' '1_1 11.6.20'
 '1_2 8.31.20' '1_2 9.1.20' '2_0 10.22.20' '2_0 11.5.20' '2_0 11.6.20'
 '2_0 12.10.20' '2_0 5.27.19' '2_0 6.10.20' '2_0 6.11.20' '2_0 6.15.20'
 '2_0 6.3.20' '2_0 6.4.20' '2_0 6.5.20' '2_0 8.31.20' '2_0 9.1.20'
 '2_1 11.5.20' '2_1 11.6.20' '2_1 12.16.20' '3_0 10.21.20' '3_0 10.22.20'
 '3_0 11.5.20' '3_0 5.22.19' '3_0 5.27.19' '3_0 6.10.20' '3_0 6.11.20'
 '3_0 6.3.20' '3_0 6.4.20' '3_0 6.5.20' '3_0 8.28.20' '3_0 9.1.20'
 '3_1 11.5.20' '3_1 8.28.20' '3_2 8.28.20' '3_3 8.28.20' '4_0 10.21.20'
 '4_0 10.22.20' '4_0 11.21.19' '4_0 12.11.20' '4_0 12.15.20' '4_0 5.22.19'
 '4_0 5.27.19' '4_0 6.10.20' '4_0 6.11.20' '4_0 6.15.20' '4_0 6.4.20'
 '4_0 6.5.20' '4_0 9.1.20' '5_0 10.21.20' '5_0 10.22.20' '5_0 5.22.19'
 '5

In [67]:
# Keep only bouts from flies that also appear in the training labels.
# .copy() makes ndata independent of df: the angle-correction cell writes
# columns back into ndata.
ndata = df[df['flyid'].isin(fly_ids)].copy()
np.unique(ndata.Date)

array(['10.21.20', '10.22.20', '11.5.20', '11.6.20', '12.10.20',
       '12.11.20', '12.15.20', '12.16.20', '5.22.19', '5.24.19',
       '5.27.19', '6.10.20', '6.11.20', '6.15.20', '6.3.20', '6.4.20',
       '6.5.20', '8.28.20', '8.31.20', '9.1.20'], dtype=object)

In [68]:
# adjust data
def adjust_rot_angles(angles, angle_names, rot_offsets=None, abduct_offset=50):
    """Shift rotation and abduction angles down by 360 above per-joint thresholds.

    For each ``(pattern, offset)`` in ``rot_offsets`` (in order), every name in
    ``angle_names`` containing both ``'_rot'`` and ``pattern`` has its values
    strictly greater than ``offset`` replaced by ``value - 360``. Names
    containing ``'_abduct'`` or ``'A_flex'`` get the same treatment with the
    single threshold ``abduct_offset``. NaNs compare False and are untouched.

    Matching is by substring, so the pattern ``'2'`` hits any rotation name
    containing the character 2. A name matching several patterns is shifted
    once per matching pattern, in pattern order.

    Parameters
    ----------
    angles : pandas.DataFrame or dict-like
        Columns are read with ``angles[name]`` and written back in place.
    angle_names : iterable of str
        Candidate column names; names matching no rule are ignored.
    rot_offsets : mapping of str -> number, optional
        Pattern -> threshold for rotation angles. ``None`` (default) uses the
        hand-tuned values below.
    abduct_offset : number, default 50
        Threshold for ``'_abduct'`` and ``'A_flex'`` angles.

    Returns
    -------
    The same ``angles`` object, modified in place.
    """
    if rot_offsets is None:
        # Hand-tuned thresholds. NOTE(review): '2'/'3' presumably select the
        # second/third leg pairs (L2/R2, L3/R3) -- confirm naming.
        rot_offsets = {'2': -50, '3': -20,
                       'L1A': 20, 'L1B': -70, 'L1C': 10,
                       'R1A': 20, 'R1B': 70, 'R1C': -30}
    # Materialize once: the names are scanned once per pattern below.
    names = list(angle_names)

    def _shift_above(name, threshold):
        # np.array copies the column, so the write-back replaces it explicitly.
        values = np.array(angles[name])
        above = values > threshold
        values[above] = values[above] - 360
        angles[name] = values

    for pattern, offset in dict(rot_offsets).items():
        for name in names:
            if '_rot' in name and pattern in name:
                _shift_above(name, offset)

    for name in names:
        if '_abduct' in name or 'A_flex' in name:
            _shift_above(name, abduct_offset)

    return angles

# Joint-angle columns (BC / flex / rot / abduct), excluding derivative,
# frequency and range summaries. Correct, wrap, then normalize ndata.
# NOTE(review): names are taken from `data` (training labels) but applied to
# `ndata`; confirm both frames carry the same angle columns.
angle_vars = np.unique(list(filter(
    lambda col: some_contains(col, ['_BC', '_flex', '_rot', '_abduct'])
    and not some_contains(col, ['_d1', '_d2', '_freq', '_range']),
    data.columns,
)))
ndata = normalize_data(
    adjust_rot_angles(correct_angles(ndata, angle_vars), angle_vars)
)

In [69]:
# Save the curated bouts; use this export when all labels are manual.
# NOTE(review): a later cell writes `df` to this same filename, so whichever
# export runs last wins.
path_out = os.path.join(prefix_out,
                        't1_grooming_subset_curated.parquet')
ndata.to_parquet(path_out,
                 compression='gzip')

In [57]:
ndata['line']

0              berlin_wt
1              berlin_wt
2              berlin_wt
3              berlin_wt
4              berlin_wt
               ...      
140018    VT017251 (eye)
140019    VT017251 (eye)
140020    VT017251 (eye)
140021    VT017251 (eye)
140022    VT017251 (eye)
Name: line, Length: 353739, dtype: object

In [67]:
# Remove head grooming from the t1_grooming training data: compute grooming
# scores from L1-leg features and keep rows whose grooming_score lies strictly
# between 1.6 and 8.25. Don't run this cell when all labels are manual.
# NOTE(review): this rebinds `df`, replacing the curated bout table built earlier.
features = [col for col in data.columns
            if some_contains(col, ['_flex', '_rot', '_x', '_y', '_z'])
            and not some_contains(col, ['_d1', '_d2', '_freq', '_range'])
            and col.startswith('L1')]
feature_names = ['L1B_rot_avg_range', 'L1A_flex_avg_range', 'L1E_z_avg_range',
                 'L1D_z', 'L1E_z']
flip = [False] * 3 + [True] * 2
data = compute_grooming_scores(data, features, feature_names,
                               flip=flip, dist=20, norm=False)
df0 = data.loc[data['grooming_score'] < 8.25]
df = df0.loc[df0['grooming_score'] > 1.6]

In [68]:
# Drop range / error / camera-count / classifier / bout-number columns from
# ndata and save it. Unlike the loading loop, '_score' columns are kept here.
# NOTE(review): intent was to save the grooming score before filtering bouts by
# score -- confirm ndata actually carries a grooming_score column.
# Earlier output name: 'lines-' + behavior + '_onball_processed_all_gs.parquet'
out = os.path.join(prefix_out, 'subset_t1_grooming_all_gs.parquet')
cols_good = np.unique([
    col for col in ndata.columns
    if not some_contains(col, ['_range', '_error', '_ncams', '_prob', '_class', '_bout_number'])
])
ndata = ndata.loc[:, cols_good]
ndata.to_parquet(out, compression='gzip')

In [69]:
# Drop range / error / camera-count / classifier / bout-number columns from the
# score-filtered frame.
cols_good = np.unique(list(filter(
    lambda col: not some_contains(col, ['_range', '_error', '_ncams', '_prob', '_class', '_bout_number']),
    df.columns,
)))
df = df.loc[:, cols_good]


In [70]:
# Save the score-filtered frame.
# NOTE(review): same filename as the earlier ndata export; running both cells
# overwrites one with the other.
# Earlier output name: 'lines-' + behavior + '_onball_processed.parquet'
path_out = os.path.join(prefix_out,
                        't1_grooming_subset_curated.parquet')
df.to_parquet(path_out,
              compression='gzip')

In [73]:
# Row count and number of distinct bouts: score-filtered df, then full data.
print(len(df), len(np.unique(df.behavior_bout)), sep='\n')
print(len(data), len(np.unique(data.behavior_bout)), sep='\n')

324631
964
359873
1120
