In [None]:
import os
import re
import json
import fnmatch
import warnings
import numpy as np
import pandas as pd 
import seaborn as sns
import matplotlib.pyplot as plt

from scipy import signal, stats
from collections import defaultdict

# Notebook-wide configuration.
# A max_open_warning of 0 turns off matplotlib's "too many open figures" warning;
# several cells below create one figure per loop iteration.
plt.rcParams.update({'figure.max_open_warning': 0})
# NOTE: this silences every Python warning for the rest of the session,
# including pandas' SettingWithCopyWarning and deprecation notices.
warnings.simplefilter("ignore")

# Executes another notebook inside this namespace. Presumably this is where
# helpers such as some_contains (used below but not defined in this notebook)
# come from -- TODO confirm. The path is absolute and machine-specific.
%run /media/turritopsis/katie/grooming/t1-grooming/grooming_functions.ipynb
%matplotlib inline

# seaborn defaults, then the 'ticks' axes style
sns.set()
sns.set_style('ticks')

In [None]:
# Display names used in figure titles, keyed by the part of an angle column
# name that follows the two-character leg prefix (e.g. 'L1A_flex' -> 'A_flex').
titles = dict(
    A_flex='coxa flexion',
    A_abduct='body-coxa abduction',
    A_rot='coxa rotation',
    B_flex='coxa-femur flexion',
    B_rot='femur rotation',
    C_flex='femur-tibia flexion',
    C_rot='tibia rotation',
    D_flex='tibia-tarsus flexion',
    _BC='body-coxa abduction',
)

# Legend text for each temperature-treatment code found in data.temp.
# NOTE(review): the ranges look like temperatures, units not stated here -- confirm.
temp_legend = dict(zip(('0', '1', '2'), ('70-79', '80-84', '85-89')))

In [None]:
# Parquet file read by the loading cell below (absolute, machine-specific path).
data_path = '/media/turritopsis/katie/apviz/classifiers/2021_04_26/temperature/training_data.parquet'

In [None]:
# Load the training table and keep only the rows labelled as t1_grooming.
d = pd.read_parquet(data_path, engine='fastparquet')
# text before the first space of flyid (presumably the recording date -- TODO confirm)
d['date'] = d.flyid.str.partition(' ')[0]
# .copy() makes `data` an independent frame, so the column assignments below
# (and in later cells) are not writes into a filtered slice of `d`.
data = d[d.behavior == 't1_grooming'].copy()
# text after the first underscore of flyid; its values are the keys of temp_legend
data['temp'] = data.flyid.str.partition('_')[2]

In [None]:
# add velocity and acceleration columns to data
# For every joint-angle column X this writes X_d1 (first derivative) and
# X_d2 (second derivative), computed bout by bout with a Savitzky-Golay filter
# so that the smoothing window never spans two different bouts.
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range', 'fictrac'])]

fps = 300.0
dt = 1/fps
s = 1.0/dt            # per-frame -> per-second scaling for the first derivative
s2 = 1.0 / (dt * dt)  # per-frame^2 -> per-second^2 scaling for the second derivative
bout_numbers = np.unique(np.array(data.behavior_bout))

savgol_window = 5  # samples per smoothing window
savgol_order = 3   # polynomial order fitted inside each window

for bout_number in bout_numbers:
    mask = data.behavior_bout == bout_number
    bout_df = data.loc[mask]
    # savgol_filter (default mode='interp') raises ValueError when the signal is
    # shorter than the window, so such bouts get NaN derivatives instead.
    too_short = len(bout_df) < savgol_window
    for ang in angle_vars:
        if too_short:
            data.loc[mask, ang + '_d1'] = np.nan
            data.loc[mask, ang + '_d2'] = np.nan
            continue
        bout = np.array(bout_df[ang])
        data.loc[mask, ang + '_d1'] = signal.savgol_filter(
            bout, savgol_window, savgol_order, deriv=1) * s
        data.loc[mask, ang + '_d2'] = signal.savgol_filter(
            bout, savgol_window, savgol_order, deriv=2) * s2

In [None]:
# Plot the L1C_flex trace of every bout, print each bout's length in frames,
# and accumulate the total frame count in nframes.
# `bouts` was not defined anywhere in this notebook; derive it from the data so
# the cell runs on a fresh kernel.
bouts = np.unique(data.behavior_bout)
nframes = 0
for bout in bouts:
    b = data[data.behavior_bout == bout]
    fig = plt.figure()
    plt.plot(np.arange(len(b)), b['L1C_flex'])
    plt.title('bout ' + str(bout))
    plt.xlabel('frame')
    plt.ylabel('L1C_flex')
    plt.show()
    print(len(b))
    nframes += len(b)

In [None]:
# plot angle distributions for each temperature treatment
# One figure per L1/R1 joint angle, with one Gaussian-KDE curve per treatment,
# followed by a separate figure that holds only the shared legend.
sns.set_style('ticks')
temps = np.unique(data.temp)
legs = ['L1', 'R1']
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range', 'fictrac'])
              and v[:2] in legs]
# one color per treatment; the same for every figure, so build it once
colors = sns.color_palette('gist_heat_r', len(temps))

# label -> line handle, collected over all figures so the shared legend lists
# every treatment that was drawn at least once (not just those in the last figure)
legend_entries = {}

for ang in angle_vars:

    fig = plt.figure(figsize = (9,2))
    plt.title(ang[:2] + ' ' + titles[ang[2:]])
    plt.xlabel('angle (deg)')

    for temp, color in zip(temps, colors):

        values = data.loc[data['temp'] == temp, ang]
        values = values[np.isfinite(values)] # ignores nans
        if len(values) <= 1:
            continue  # a KDE needs at least two samples

        kernel = stats.gaussian_kde(values)
        # evaluate the density between the 2nd and 98th percentiles only
        grid = np.linspace(np.percentile(values, 2), np.percentile(values, 98), 500)
        plt.plot(grid, kernel.pdf(grid), label = temp_legend[temp], color = color)

    ax = plt.gca()
    for handle, label in zip(*ax.get_legend_handles_labels()):
        legend_entries.setdefault(label, handle)
    sns.despine()
    plt.show()

# legend entries in treatment order, restricted to treatments actually plotted
ls = [label for label in (temp_legend.get(t) for t in temps) if label in legend_entries]
hs = [legend_entries[label] for label in ls]

fig = plt.figure(figsize = (3,3))
plt.legend(handles = hs, labels = ls, loc = 'center', fontsize = 12)
plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# plot angular velocity distributions for each temperature treatment
# One figure per L1/R1 '_d1' column, with one Gaussian-KDE curve per treatment,
# followed by a separate figure that holds only the shared legend.
sns.set_style('ticks')
temps = np.unique(data.temp)
legs = ['L1', 'R1']
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and some_contains(v, ['_d1'])
              and v[:2] in legs]
# one color per treatment; the same for every figure, so build it once
colors = sns.color_palette('gist_heat_r', len(temps))

# label -> line handle, collected over all figures so the shared legend lists
# every treatment that was drawn at least once (not just those in the last figure)
legend_entries = {}

for ang in angle_vars:

    fig = plt.figure(figsize = (9,2))
    # [2:-3] strips the leg prefix and the '_d1' suffix to get the titles key
    plt.title(ang[:2] + ' ' + titles[ang[2:-3]])
    plt.xlabel('velocity (deg/$s$)')

    for temp, color in zip(temps, colors):

        values = data.loc[data['temp'] == temp, ang]
        values = values[np.isfinite(values)] # ignores nans
        if len(values) <= 1:
            continue  # a KDE needs at least two samples

        kernel = stats.gaussian_kde(values)
        # evaluate the density between the 1st and 99th percentiles only
        grid = np.linspace(np.percentile(values, 1), np.percentile(values, 99), 500)
        plt.plot(grid, kernel.pdf(grid), label = temp_legend[temp], color = color)

    ax = plt.gca()
    for handle, label in zip(*ax.get_legend_handles_labels()):
        legend_entries.setdefault(label, handle)
    sns.despine()
    plt.show()

# legend entries in treatment order, restricted to treatments actually plotted
ls = [label for label in (temp_legend.get(t) for t in temps) if label in legend_entries]
hs = [legend_entries[label] for label in ls]

fig = plt.figure(figsize = (3,3))
plt.legend(handles = hs, labels = ls, loc = 'center', fontsize = 12)
plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# plot angular acceleration distributions for each temperature treatment
# One figure per L1/R1 '_d2' column, with one Gaussian-KDE curve per treatment,
# followed by a separate figure that holds only the shared legend.
sns.set_style('ticks')
temps = np.unique(data.temp)
legs = ['L1', 'R1']
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and some_contains(v, ['_d2'])
              and v[:2] in legs]
# one color per treatment; the same for every figure, so build it once
colors = sns.color_palette('gist_heat_r', len(temps))

# label -> line handle, collected over all figures so the shared legend lists
# every treatment that was drawn at least once (not just those in the last figure)
legend_entries = {}

for ang in angle_vars:

    fig = plt.figure(figsize = (9,2))
    # [2:-3] strips the leg prefix and the '_d2' suffix to get the titles key
    plt.title(ang[:2] + ' ' + titles[ang[2:-3]])
    plt.xlabel('acceleration (deg/$s^2$)')

    for temp, color in zip(temps, colors):

        values = data.loc[data['temp'] == temp, ang]
        values = values[np.isfinite(values)] # ignores nans
        if len(values) <= 1:
            continue  # a KDE needs at least two samples

        kernel = stats.gaussian_kde(values)
        # evaluate the density between the 1st and 99th percentiles only
        grid = np.linspace(np.percentile(values, 1), np.percentile(values, 99), 500)
        plt.plot(grid, kernel.pdf(grid), label = temp_legend[temp], color = color)

    ax = plt.gca()
    for handle, label in zip(*ax.get_legend_handles_labels()):
        legend_entries.setdefault(label, handle)
    sns.despine()
    plt.show()

# legend entries in treatment order, restricted to treatments actually plotted
ls = [label for label in (temp_legend.get(t) for t in temps) if label in legend_entries]
hs = [legend_entries[label] for label in ls]

fig = plt.figure(figsize = (3,3))
plt.legend(handles = hs, labels = ls, loc = 'center', fontsize = 12)
plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
def get_freqs(data, fps, thresh = None, dist = None):
    """Per-sample oscillation frequency (Hz) of a 1-D signal, from peak spacing.

    Peaks are located with ``scipy.signal.find_peaks`` on the finite samples of
    ``data``. Each sample is then given the frequency ``1 / interval`` of the
    peak-to-peak interval it lies in. Samples before the first peak take the
    first interval's frequency, and samples from the last peak onwards take the
    last interval's frequency.

    Parameters
    ----------
    data : array-like
        1-D signal sampled at ``fps``; may contain NaN/inf.
    fps : float
        Sampling rate in frames per second.
    thresh : optional
        Passed to ``find_peaks`` as ``height``.
    dist : optional
        Passed to ``find_peaks`` as ``distance`` (in samples).

    Returns
    -------
    numpy.ndarray
        Float array with the same length as ``data``. If fewer than two peaks
        are found, every sample gets ``fps / len(data)``, i.e. one cycle over
        the whole signal, in Hz. If ``data`` has no finite samples, every
        sample is NaN.
    """
    values = np.asarray(data, dtype = float)
    n = len(values)

    # Non-finite samples are excluded from peak detection, but peak positions
    # are mapped back to the original sample indices so the result lines up
    # one-to-one with the input.
    finite_pos = np.flatnonzero(np.isfinite(values))
    if finite_pos.size == 0:
        return np.full(n, np.nan)

    peaks, _ = signal.find_peaks(values[finite_pos], height = thresh, distance = dist)
    idxs = finite_pos[peaks]

    if len(idxs) <= 1:
        # no interval to measure: treat the whole signal as a single period
        return np.full(n, fps / n)

    intervals = np.diff(idxs) / fps # in seconds
    freqs = 1 / intervals

    # Index of the peak-to-peak interval each sample belongs to; samples outside
    # [first peak, last peak) are clipped onto the first / last interval.
    segment = np.searchsorted(idxs, np.arange(n), side = 'right') - 1
    segment = np.clip(segment, 0, len(freqs) - 1)
    return freqs[segment]

In [None]:
# Add an X_freq column for every joint-angle column X: the per-frame
# oscillation frequency returned by get_freqs, computed separately per bout.
dist = 15   # passed to get_freqs as dist (minimum peak spacing, in frames)
fps = 300
# positional index so that row positions and index labels coincide
data = data.reset_index(drop = True)
bout_numbers = np.unique(data.behavior_bout)
angle_vars = np.unique([v for v in data.columns
              if some_contains(v, ['_BC', '_flex', '_rot', '_abduct'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range', 'fictrac'])])

# Row positions of each bout do not depend on the angle, so find them once.
bout_ids = data.behavior_bout.to_numpy()
bout_rows = [rows for rows in (np.flatnonzero(bout_ids == b) for b in bout_numbers)
             if rows.size]

for angle in angle_vars:

    angle_values = data[angle].to_numpy()
    # fresh buffer per angle so values from a previous angle can never leak
    # into rows that are not covered by any bout
    frequency = np.zeros(len(data))

    for rows in bout_rows:
        frequency[rows] = get_freqs(angle_values[rows], fps, dist = dist)

    data[angle + '_freq'] = frequency
    

In [None]:
# plot the mean grooming frequency of each L1/R1 joint angle for every
# temperature treatment; error bars show the standard deviation across frames
angle_types_names = ['abduction', 'flexion', 'rotation']
legs = ['L1', 'R1']
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and some_contains(v, ['_freq'])
              and v[:2] in legs]
sns.set_style('ticks')
scat = True
temps = np.unique(data.temp)  # the column is 'temp'; data.temps does not exist
dists = np.array([15, 15, 15])
colors = sns.color_palette('gist_heat_r', len(temps))

for i in range(len(angle_vars)):

    # one tick label per treatment actually present in the data
    ang_names = [angle_vars[i] + ' (' + temp_legend[t] + ')' for t in temps]
    fig = plt.figure(figsize = (8,4))
    plt.title('Average grooming frequency of ' + angle_vars[i] + ' angles', fontsize = 14)
    plt.xlabel('Joint', fontsize = 14)
    plt.ylabel('Frequency (Hz)', fontsize = 14)
    ax = plt.gca()

    for j in range(len(temps)):

        temp_data = data[data.temp == temps[j]]
        mean = np.nanmean(temp_data[angle_vars[i]])
        # standard deviation over frames (not a standard error)
        std_freq = np.nanstd(temp_data[angle_vars[i]])
        # drawn at j + 0.5 so each marker sits on its tick
        ax.errorbar(j + 0.5, mean, yerr = std_freq, fmt = 'o', capsize = 5, color = colors[j])

    plt.xlim([0, len(temps)])
    plt.xticks(np.arange(0.5, (len(ang_names)), 1), ang_names)
    sns.despine()
    plt.show()

In [None]:
# Mean of every numeric column per temperature treatment. numeric_only=True
# skips the string columns (e.g. flyid, behavior, date), which otherwise make
# mean() raise TypeError on pandas >= 2.0.
data.groupby('temp').mean(numeric_only = True)