In [None]:
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
import os
import seaborn as sns
import umap
import csv
import warnings
from collections import OrderedDict
from collections import defaultdict
from scipy import signal, stats
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d import axes3d
from matplotlib.patches import Patch

# Notebook-wide setup. Statement order matters: the %run line must execute before
# any later cell calls the helpers it provides.

# 0 disables matplotlib's "too many open figures" warning (this notebook opens one
# figure per bout / per angle).
plt.rcParams.update({'figure.max_open_warning': 0})
# NOTE(review): this silences *all* warnings, including deprecation warnings from
# pandas / numpy -- confirm that is intended.
warnings.simplefilter("ignore")

# Execute the helper notebook in this namespace. Helpers used below that are not
# defined in this notebook (some_contains, get_angle_names, data_per_fly, ...)
# presumably come from here -- TODO confirm.
%run grooming_functions.ipynb
%matplotlib inline

# for figure styling and saving
sns.set()
sns.set_style('ticks')
# Output directory for saved results (absolute, machine-specific path).
out_path = '/media/turritopsis/katie/grooming/t1-grooming'

# Unused per-fly color palette, kept for reference. It relies on `fly_names` and
# `adjust_color`, neither of which is defined in this cell.
#cmap = plt.get_cmap('Spectral')
#colors = [cmap(i/(len(fly_names)-1.9999)) for i in range(len(fly_names))]
#colors[4] = 'y'
#colors[5] = adjust_color('y', 1.7)
#colors[6] = 'g'
#colors[7] = adjust_color('g', 1.3)
#colors[-1] = adjust_color('#5636a7', 1.3)

In [None]:
# read the processed summary table for the chosen behavior
behavior = 't1_grooming'
prefix = '/media/turritopsis/katie/grooming/summaries'
parquet_name = 'lines-{}_processed.parquet'.format(behavior)
data_path = os.path.join(prefix, parquet_name)
data = pd.read_parquet(data_path, engine='fastparquet')

In [None]:
fps = 300.0 # know this for this dataset

# get the joints to analyze: every '<name>_error' column in the table names one
# tracked point, so the list is derived from the data itself.
# (A hard-coded array of the 30 leg points L1A..R3E used to be assigned here, but
# it was overwritten by this line before ever being read, so it has been dropped.)
bodyparts = [x.replace('_error', '') for x in data.columns if '_error' in x]
# one name per coordinate, e.g. '<bodypart>_x', '<bodypart>_y', '<bodypart>_z'
bodyparts_xyz = [bp + '_' + x for bp in bodyparts for x in ['x', 'y', 'z']]

# column-name suffixes that identify joint-angle columns
angle_types = np.array(['_BC', '_flex', '_rot', '_abduct'])
angle_names_t1 = get_angle_names(data, angle_types, only_t1 = True)
angle_names = get_angle_names(data, angle_types, only_t1 = False)
bout_length_dict = get_bout_lengths(data)

In [None]:
# columns whose name contains an angle / coordinate suffix, contains none of the
# excluded suffixes, and starts with 'L1'
angle_vars = []
for col in data.columns:
    if not some_contains(col, ['_flex', '_rot', '_x', '_y', '_z']):
        continue
    if some_contains(col, ['_d1', '_d2', '_freq', '_range']):
        continue
    if col[:2] == 'L1':
        angle_vars.append(col)

# lookups built by the helper functions from the bout numbers present in the table
bout_numbers = np.unique(np.asarray(data['behavior_bout']))
fly_dict = get_fly_id(data, bout_numbers)
videos = get_videos(bout_numbers, data)
fly_videos = fly_to_video(data)
dif_flies = np.unique([fly for fly in fly_dict.values()])
fly_data, fly_names_sorted = data_per_fly(data)

###### deciding the thresholds for grooming type

In [None]:
# plot all grooming angles across time: one figure per bout with three panels
# (abduction, flexion, rotation); bouts are visited in order of grooming_score
legs = ['L1', 'R1']
angle_types = ['_abduct', '_flex', '_rot']
angle_titles = ['abduction', 'flexion', 'rotation']
data_sorted = data.sort_values('grooming_score')
bout_numbers_sorted = np.array(data_sorted.drop_duplicates('behavior_bout').behavior_bout.astype(int))

# The columns drawn in each panel do not depend on the bout, so select them once
# instead of re-scanning data.columns for every bout.
panel_columns = [[v for v in data.columns
                  if some_contains(v, [angle_type])
                  and not some_contains(v, ['_d1', '_d2', '_freq', '_range'])
                  and v[:2] in legs]
                 for angle_type in angle_types]

for bout_number in bout_numbers_sorted:

    bout = data[data.behavior_bout == bout_number]
    fig, axs = plt.subplots(1, 3, sharex=True, sharey=True, figsize = (18, 3))
    ax = axs.T.flatten()
    print(bout.filename.iloc[0])
    print(bout.grooming_score.iloc[0])
    handles = []
    labels = []

    for panel, angle_title, columns in zip(ax, angle_titles, panel_columns):

        for column in columns:
            angle = np.array(bout[column])
            t = np.arange(len(angle)) / fps  # frame index -> seconds
            panel.plot(t, angle, label = column)

        # panel decoration is independent of the plotted column, so do it once
        panel.set_xlabel('time (seconds)', fontsize = 14)
        panel.set_ylabel('{} angle (deg)'.format(angle_title), fontsize = 14)
        panel.set_title('{}, bout {}'.format(bout.flyid.iloc[0], str(int(bout_number))), fontsize = 14)
        # sharey hides the y tick labels of the inner panels; turn them back on
        panel.tick_params(labelleft = True)

        hs, ls = panel.get_legend_handles_labels()
        handles.append(hs)
        labels.append(ls)  # previously appended `hs`, so `labels` held handles

    sns.despine()
    plt.show()

In [None]:
# display names for joint angles, keyed by the part of the column name that
# follows the two-character leg prefix (e.g. 'L1A_flex' -> 'A_flex')
titles = dict([
    # coxa
    ('A_flex', 'coxa flexion'),
    ('A_abduct', 'body-coxa abduction'),
    ('A_rot', 'coxa rotation'),
    # femur
    ('B_flex', 'coxa-femur flexion'),
    ('B_rot', 'femur rotation'),
    # tibia
    ('C_flex', 'femur-tibia flexion'),
    ('C_rot', 'tibia rotation'),
    # tarsus
    ('D_flex', 'tibia-tarsus flexion'),
    # body-coxa angle stored without a joint letter
    ('_BC', 'body-coxa abduction'),
])

In [None]:
# define functions 
def get_url(flyid, filename):
    """Build the video-browser URL for one file of one fly.

    `flyid` is split at its first underscore: the part before it is used as the
    session and the part after it as the fly folder. Spaces in the final URL are
    replaced by '%20'; no other character is escaped.
    """
    session, _sep, folder = flyid.partition('_')
    raw_url = '{base}/#{session}/Fly {folder}/{name}'.format(
        base='http://128.95.10.233:5000',
        session=session,
        folder=folder,
        name=filename,
    )
    return raw_url.replace(' ', '%20')

def get_cycles(data, angle_vars, align_to = None, get_all_cycles = False, min_length = 0, dist = 20, height = None, npeaks = 1):
    """Cut each bout's angle traces into peak-to-peak cycles.

    For every column in `angle_vars` and every bout in `data`, peaks are located
    with `scipy.signal.find_peaks` and the column is sliced from peak k to peak
    k + npeaks. Every slice is stored NaN-padded in an array of 600 samples.

    Parameters
    ----------
    data : DataFrame with a `behavior_bout` column plus the columns named in
        `angle_vars` (and `align_to`, when given).
    angle_vars : column names to cut into cycles.
    align_to : optional column name. When given, peaks are detected on this
        column for every entry of `angle_vars`; otherwise each column is cut at
        its own peaks.
    get_all_cycles : if False, keep only the middle cycle of each bout; if True,
        keep every cycle.
    min_length : bouts with fewer samples than this are skipped.
    dist, height : forwarded to `find_peaks` as `distance` and `height`.
    npeaks : number of peak intervals spanned by one cycle.

    Returns
    -------
    (cycle_dict, length_dict) : dicts keyed by column name. `cycle_dict[name]` is
        a list of 600-sample arrays and `length_dict[name]` the list of the
        corresponding un-padded lengths.

    Bouts with fewer than four peaks are skipped. A peak-to-peak span longer than
    600 samples cannot be stored and is skipped, and a bout that yields no cycle
    at all contributes nothing (both cases used to raise).
    """
    max_cycle_len = 600  # width of the NaN-padded buffer each cycle is stored in

    cycle_dict = dict()
    length_dict = dict()
    bout_numbers = np.unique(data.behavior_bout.astype(int))

    for angle_var in angle_vars:

        all_cycles = []
        all_lengths = []

        for bout_number in bout_numbers:

            bout = data[data.behavior_bout == bout_number]
            bout_angs = np.array(bout[angle_var])
            # peaks come from the reference column when one is given
            peak_source = bout_angs if align_to is None else np.array(bout[align_to])

            if len(peak_source) < min_length:
                continue

            idxs, _props = signal.find_peaks(peak_source, distance = dist, height = height)
            if len(idxs) < 4:
                continue

            cycles = []
            lengths = []
            for k in range(len(idxs) - npeaks):
                period = bout_angs[idxs[k]:idxs[k + npeaks]]
                if len(period) > max_cycle_len:
                    # does not fit the fixed-width buffer
                    continue
                cycle = np.full(max_cycle_len, np.nan)
                cycle[:len(period)] = period
                cycles.append(cycle)
                lengths.append(len(period))

            if not cycles:
                # npeaks >= number of peaks, or every span was too long
                continue

            if get_all_cycles:
                all_cycles.extend(cycles)
                all_lengths.extend(lengths)
            else:
                middle = len(cycles) // 2
                all_cycles.append(cycles[middle])
                all_lengths.append(lengths[middle])

        cycle_dict[angle_var] = all_cycles
        length_dict[angle_var] = all_lengths

    return cycle_dict, length_dict


def plot_cycles(cycle_dict, length_dict, angle_vars, align_start = True, title_extension = '', fps = 300.0, legend = False):
    """Overlay the cycles of each angle and draw their average, one figure per angle.

    Parameters
    ----------
    cycle_dict, length_dict : dicts keyed by column name, holding for each angle a
        list of NaN-padded cycle arrays and the list of their un-padded lengths
        (the pair returned by `get_cycles`). Neither dict is modified.
    angle_vars : column names to plot; angles with no cycles are skipped.
    align_start : if True, subtract each cycle's first sample so that all cycles
        start at 0 before plotting and averaging.
    title_extension : text appended to each figure title.
    fps : frames per second, used to convert sample index to seconds.
    legend : if True, draw a legend (only the average trace carries a label).

    The figure title is built from the first two characters of the column name
    and the entry of the module-level `titles` dict for the rest of the name; the
    raw suffix is shown when `titles` has no entry for it.
    """
    for angle_var in angle_vars:

        stored_cycles = cycle_dict[angle_var]
        lengths = length_dict[angle_var]

        if len(lengths) == 0:
            continue

        leg = angle_var.split('_')[0][:2]
        suffix = angle_var[2:]
        angle_title = titles.get(suffix, suffix)

        plt.figure(figsize = (8,4))
        plt.title('{} {} angles {}'.format(leg, angle_title, title_extension), fontsize = 14)
        plt.xlabel('time (seconds)', fontsize = 14)
        plt.ylabel('angle (deg)', fontsize = 14)

        # Collect the (optionally re-zeroed) cycles in a local list instead of
        # writing them back into cycle_dict, so the caller's data stays intact.
        cycles_to_average = []
        for stored, length in zip(stored_cycles, lengths):
            cycle = stored[:length]
            if align_start:
                cycle = cycle - cycle[0]
                padded = np.full(len(stored), np.nan)
                padded[:length] = cycle
            else:
                padded = stored
            cycles_to_average.append(padded)
            plt.plot(np.arange(len(cycle)) / fps, cycle, linewidth=1)

        # NaN padding keeps shorter cycles from dragging the mean towards zero
        avg_cycle = np.nanmean(cycles_to_average, axis = 0)
        avg_cycle = avg_cycle[:max(lengths)]
        plt.plot(np.arange(len(avg_cycle)) / fps, avg_cycle, label = 'average', color = 'k')
        # crop the x-axis at the 85th percentile of the cycle lengths
        plt.xlim([0, np.percentile(lengths, 85)/fps])

        if legend:
            plt.legend(fontsize = 12, loc=(1.02,0.2))

        sns.despine()
        plt.show()

###### aligning cycles by peaks

In [None]:
# average grooming cycle per angle, pooled over all flies: overlay one
# oscillation per bout (the middle one) and draw the mean on top
dist = 20
min_length = 0

angle_vars = []
for col in data.columns:
    if not some_contains(col, ['_flex', '_abduct', '_rot', '_BC']):
        continue
    if some_contains(col, ['_d1', '_d2', '_freq', '_range']):
        continue
    if col[:2] in ['L1', 'R1']:
        angle_vars.append(col)

angle_titles = ['flexion', 'abduction', 'rotation', 'body-coxa']
cycle_dict, length_dict = get_cycles(data, angle_vars,
                                     dist = dist,
                                     min_length = min_length,
                                     get_all_cycles = False)
plot_cycles(cycle_dict, length_dict, angle_vars, align_start = True)

In [None]:
# average grooming cycle per angle, pooled over all flies: overlay every
# oscillation of every bout and draw the mean on top
dist = 20
min_length = 0

angle_vars = []
for col in data.columns:
    if not some_contains(col, ['_flex', '_abduct', '_rot', '_BC']):
        continue
    if some_contains(col, ['_d1', '_d2', '_freq', '_range']):
        continue
    if col[:2] in ['L1', 'R1']:
        angle_vars.append(col)

angle_titles = ['flexion', 'abduction', 'rotation', 'body-coxa']
cycle_dict, length_dict = get_cycles(data, angle_vars,
                                     dist = dist,
                                     min_length = min_length,
                                     get_all_cycles = True)
plot_cycles(cycle_dict, length_dict, angle_vars, align_start = True)

In [None]:
# average grooming cycle per angle, one set of figures per fly: overlay one
# oscillation per bout (the middle one) and draw the mean on top
dist = 20
min_length = 0

angle_vars = []
for col in data.columns:
    if not some_contains(col, ['_flex', '_abduct', '_rot', '_BC']):
        continue
    if some_contains(col, ['_d1', '_d2', '_freq', '_range']):
        continue
    if col[:2] in ['L1', 'R1']:
        angle_vars.append(col)

fly_data, fly_names = data_per_fly(data)
for fly_name in fly_names:
    # rows belonging to this fly only
    fly_data = data[data['flyid'] == fly_name]
    fly_label = '(fly ' + fly_name + ')'
    cycle_dict, length_dict = get_cycles(fly_data, angle_vars,
                                         dist = dist,
                                         min_length = min_length,
                                         get_all_cycles = False)
    plot_cycles(cycle_dict, length_dict, angle_vars, align_start = True, title_extension = fly_label)

In [None]:
# overlay grooming cycles from individual bouts, then determine an average grooming
# cycle for each angle (individual bouts, uses all oscillations in all bouts)
dist = 20
min_length = 0
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range'])
              and v[:2] in ['L1']]

bout_numbers = np.unique(data.behavior_bout.astype(int))
for bout_number in bout_numbers:
    bout_data = data[data['behavior_bout'] == bout_number]
    # The fly is read from the bout's own rows. The title used to take
    # fly_names[j] with j being the bout position, which labels the figure with
    # an unrelated fly and fails once there are more bouts than flies.
    fly_id = bout_data.flyid.iloc[0]
    cycle_dict, length_dict = get_cycles(bout_data, angle_vars, get_all_cycles = True, min_length = min_length, dist = dist)
    title_extension = '(fly ' + fly_id + ', bout ' + str(bout_number) + ')'
    plot_cycles(cycle_dict, length_dict, angle_vars, align_start = True, title_extension = title_extension)

###### aligning cycles in time

In [None]:
# average grooming cycle per angle, pooled over all flies, one oscillation per
# bout; peaks are detected on the reference column `align_to` and every angle is
# cut at those peak positions
dist = 20
min_length = 0
align_to = 'L1B_flex'

angle_vars = []
for col in data.columns:
    if not some_contains(col, ['_flex', '_abduct', '_rot', '_BC']):
        continue
    if some_contains(col, ['_d1', '_d2', '_freq', '_range']):
        continue
    if col[:2] in ['L1']:
        angle_vars.append(col)

angle_titles = ['flexion', 'abduction', 'rotation', 'body-coxa']
cycle_dict, length_dict = get_cycles(data, angle_vars,
                                     align_to = align_to,
                                     dist = dist,
                                     min_length = min_length,
                                     get_all_cycles = False)
plot_cycles(cycle_dict, length_dict, angle_vars, align_start = True)

In [None]:
# average grooming cycle per angle, one set of figures per fly: overlay every
# oscillation of every bout and draw the mean on top
# NOTE(review): no align_to is passed, so each angle is cut at its own peaks --
# confirm this is intended under the "aligning cycles in time" heading
dist = 20
min_length = 0

angle_vars = []
for col in data.columns:
    if not some_contains(col, ['_flex', '_abduct', '_rot', '_BC']):
        continue
    if some_contains(col, ['_d1', '_d2', '_freq', '_range']):
        continue
    if col[:2] in ['L1']:
        angle_vars.append(col)

fly_data, fly_names = data_per_fly(data)
for fly_name in fly_names:
    # rows belonging to this fly only
    fly_data = data[data['flyid'] == fly_name]
    fly_label = '(fly ' + fly_name + ')'
    cycle_dict, length_dict = get_cycles(fly_data, angle_vars,
                                         dist = dist,
                                         min_length = min_length,
                                         get_all_cycles = True)
    plot_cycles(cycle_dict, length_dict, angle_vars, align_start = True, title_extension = fly_label)

In [None]:
# average grooming cycle per angle, one set of figures per fly: overlay one
# oscillation per bout (the middle one) and draw the mean on top
# NOTE(review): no align_to is passed, so each angle is cut at its own peaks --
# confirm this is intended under the "aligning cycles in time" heading
dist = 20
min_length = 0

angle_vars = []
for col in data.columns:
    if not some_contains(col, ['_flex', '_abduct', '_rot', '_BC']):
        continue
    if some_contains(col, ['_d1', '_d2', '_freq', '_range']):
        continue
    if col[:2] in ['L1', 'R1']:
        angle_vars.append(col)

fly_data, fly_names = data_per_fly(data)
for fly_name in fly_names:
    # rows belonging to this fly only
    fly_data = data[data['flyid'] == fly_name]
    fly_label = '(fly ' + fly_name + ')'
    cycle_dict, length_dict = get_cycles(fly_data, angle_vars,
                                         dist = dist,
                                         min_length = min_length,
                                         get_all_cycles = False)
    plot_cycles(cycle_dict, length_dict, angle_vars, align_start = True, title_extension = fly_label)

In [None]:
# average grooming cycle per angle, one set of figures per bout (all oscillations
# of the bout), restricted to the bouts of a single fly
# NOTE(review): no align_to is passed, so each angle is cut at its own peaks --
# confirm this is intended under the "aligning cycles in time" heading
dist = 20
min_length = 0
target_fly = '5_0 5272019'

angle_vars = []
for col in data.columns:
    if not some_contains(col, ['_flex', '_abduct', '_rot', '_BC']):
        continue
    if some_contains(col, ['_d1', '_d2', '_freq', '_range']):
        continue
    if col[:2] in ['L1']:
        angle_vars.append(col)

bout_numbers = np.unique(data.behavior_bout.astype(int))
for bout_number in bout_numbers:
    bout_data = data[data['behavior_bout'] == bout_number]
    fly_id = bout_data.flyid.iloc[0]
    if fly_id != target_fly:
        continue
    bout_label = '(fly ' + fly_id + ', bout ' + str(bout_number) + ')'
    cycle_dict, length_dict = get_cycles(bout_data, angle_vars,
                                         dist = dist,
                                         min_length = min_length,
                                         get_all_cycles = True)
    plot_cycles(cycle_dict, length_dict, angle_vars, align_start = True, title_extension = bout_label)

In [None]:
# display the entry stored in fly_videos for one fly id (raises KeyError if the
# id is not a key of fly_videos)
fly_videos['4_0 5222019']

###### find videos for a fly

In [None]:
# print each file name listed in fly_videos for one fly, followed by the number
# of rows of `data` that belong to that file
flyid = '1_0 6182019'
videos = fly_videos[flyid]
for fname in videos:
    n_rows = data.loc[data['filename'] == fname].shape[0]
    print(fname + ' ({})'.format(n_rows))
    #print(get_url(flyid, fname))

###### calculate grooming frequencies

In [None]:
# compute and save average grooming frequencies for each fly and session.
# Two estimates are stored per fly, each averaged over the fly's bouts:
#   - grooming_freq_welch:  frequency at the maximum of the Welch power spectrum
#   - grooming_freq_period: 1 / mean interval between peaks

# choose one angle to determine the grooming frequency from
# (avg freq of this joint across all bouts associated with the fly)

angle = 'R1C_flex' # 'R1C_flex'
cols = ['date', 'fly', 'grooming_freq_welch', 'grooming_freq_period', 'grooming_cycles', 'grooming_bouts']
fly_ids = np.unique(list(fly_dict.values()))
dist = 25

# Rows are collected in a list and turned into a DataFrame once at the end;
# DataFrame.append no longer exists in pandas >= 2.0.
rows = []

for fly_id in fly_ids:

    fly_data = data[data['flyid'] == fly_id]
    bout_nums = np.unique(fly_data.behavior_bout)
    fly_freqs_welch = []
    fly_freqs_period = []
    grooming_cycles = 0

    for bout_num in bout_nums:

        # select from this fly's rows so a bout number can never pull in rows
        # of another fly
        bout = fly_data[fly_data.behavior_bout == bout_num]
        t1 = bout[angle]
        t1 = np.array(t1[np.isfinite(t1)])
        if len(t1) <= 1:
            continue

        # use the dataset's frame rate rather than a second hard-coded 300
        f, pxx = signal.welch(t1, fs=fps, nperseg=1024)
        # (the mean of pxx used to be subtracted here; a constant shift does
        # not move the argmax, so the step was dropped)
        fly_freqs_welch.append(f[np.argmax(pxx)])

        peaks, props = signal.find_peaks(t1, height = None, distance = dist)
        mean_int, stderr_int, intervals = mean_peak_interval(t1, fps, dist = dist)
        fly_freqs_period.append(1/mean_int)
        grooming_cycles += len(peaks)

    # fly ids have the form '<fly> <date>'
    id_parts = fly_id.split()
    rows.append({
        'date': id_parts[1],
        'fly': id_parts[0],
        'grooming_freq_welch': np.nanmean(fly_freqs_welch),
        'grooming_freq_period': np.nanmean(fly_freqs_period),
        'grooming_cycles': grooming_cycles,
        # number of bouts of THIS fly (previously the notebook-wide bout count)
        'grooming_bouts': len(bout_nums),
    })

freq_df = pd.DataFrame(rows, columns = cols)

csv_name = os.path.join(out_path, 'grooming_freqs_' + angle + '.csv')
freq_df.to_csv(csv_name, index = False)