In [None]:
# %matplotlib notebook
%matplotlib inline
%run /media/turritopsis/katie/grooming/t1-grooming/grooming_functions.ipynb

import re
import random
import os.path
import fnmatch
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import umap
import csv

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from scipy import signal, stats
from matplotlib.patches import Patch
import matplotlib.cm as cm
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from collections import defaultdict
from sklearn.preprocessing import StandardScaler

warnings.filterwarnings('ignore')

In [None]:
# figure styling: seaborn "ticks" theme with a colorblind-safe palette
sns.set(style='ticks', palette='colorblind')
# root directory for saved UMAP figures; later cells save into subfolders of it
# (matplotlib's savefig does not create missing folders)
figure_path = '/media/turritopsis/katie/grooming/t1-grooming/revisited/figures/umap'
# optional global figure defaults:
# plt.rcParams['figure.figsize'] = (6,3)
# plt.rcParams['figure.dpi'] = 200

In [None]:
def some_contains(v, L):
    """Return True if any string in `L` occurs as a substring of `v`.

    Used throughout the notebook to select DataFrame columns by name fragment,
    e.g. some_contains('L1A_flex', ['_flex', '_rot']) -> True.
    """
    return any(fragment in v for fragment in L)

def get_url(flyid, filename):
    """Build the video-browser URL for one recording.

    `flyid` is split at its first underscore into a session part and a folder
    part (e.g. '4_0 5272019' -> session '4', folder '0 5272019'). The URL has the
    form <server>/#<session>/Fly <folder>/<filename>, with every space encoded
    as %20.
    """
    server = 'http://128.95.10.233:5000'
    session, _sep, fly_folder = flyid.partition('_')
    pieces = [server + '/#' + session, 'Fly ' + fly_folder, str(filename)]
    return '/'.join(pieces).replace(' ', '%20')

def get_bout_url(data):
    """Map each bout number in `data` to the video URL of the recording it came from.

    Parameters
    ----------
    data : DataFrame with 'behavior_bout', 'flyid' and 'filename' columns.

    Returns
    -------
    dict {integer bout number: URL string built by get_url}.
    """
    bout_url_dict = dict()
    bout_ids = data.behavior_bout.astype(int)
    for bout_number in np.unique(bout_ids):
        bout = data[bout_ids == bout_number]
        # Take flyid/filename from this bout's own first row. The former
        # data.flyid[0] / data.filename[0] read label 0 of the whole frame, so
        # every bout received the same URL (or a KeyError if label 0 was absent).
        bout_url_dict[bout_number] = get_url(bout['flyid'].iloc[0], bout['filename'].iloc[0])
    return bout_url_dict

def get_kde_vals(arr, xvals=None):
    """Evaluate a Gaussian kernel density estimate of `arr`, ignoring NaNs.

    The bandwidth follows Scott's rule. When `xvals` is None, the density is
    evaluated on 150 evenly spaced points from (min - 1) to (max + 1) of the
    non-NaN data.

    Returns
    -------
    (xvals, yvals) : evaluation points and the estimated density at each.
    """
    valid = arr[~np.isnan(arr)]
    density = stats.gaussian_kde(valid, bw_method='scott')
    grid = xvals if xvals is not None else np.linspace(np.min(valid) - 1, np.max(valid) + 1, num=150)
    return grid, density.evaluate(grid)

# parse fly summary spreadsheet
def parse_date(df):
    """Add a 'date' column built from the 'Date' column of the fly summary sheet.

    'Date' strings look like 'M.D.YY' (e.g. '5.27.19'). The month is kept as
    written, a one-digit day is zero-padded, and '20' is prefixed to the year:
    '5.27.19' -> '5272019', '12.3.19' -> '12032019'. `df` is modified in place
    and returned.
    """
    def _reformat(parts):
        month, day, year = parts[0], parts[1], parts[2]
        padded_day = '0' + day if len(day) == 1 else day
        return '{}{}20{}'.format(month, padded_day, year)

    df['date'] = [_reformat(parts) for parts in df.Date.str.split('.')]
    return df

def clean_summary(prefix, f_in, f_out):
    """Re-write the CSV `prefix/f_in` as `prefix/f_out`, dropping undecodable bytes.

    The input is decoded as UTF-8 with errors='ignore', so byte sequences that are
    not valid UTF-8 are silently removed. Every row is then written back out
    unchanged as UTF-8 CSV.
    """
    fname_in = os.path.join(prefix, f_in)
    fname_out = os.path.join(prefix, f_out)
    # The csv module requires newline='' on both files: it keeps quoted
    # embedded newlines intact on read and avoids '\r\r\n' line endings on Windows.
    # An explicit output encoding stops the result depending on the platform locale.
    with open(fname_in, 'r', encoding='utf-8', errors='ignore', newline='') as infile, \
            open(fname_out, 'w', encoding='utf-8', newline='') as outfile:
        csv.writer(outfile).writerows(csv.reader(infile))

def get_angle_names(angles, angle_types, only_t1):
    """Return the names of the angle columns of `angles` to analyze.

    For each fragment in `angle_types` (in the given order), collect the columns
    whose name contains it. With `only_t1`, the name must also contain '1'
    (presumably the T1 legs, e.g. 'L1A_flex'). Columns containing 'freqs' are
    dropped. A column matching several fragments is listed once per match.
    Note: use _BC instead of _abduct for now.

    Returns
    -------
    list of str
    """
    # Accumulate in a plain list: the previous np.append onto np.array([]) built
    # strings onto a float64 array and relied on number->string dtype promotion.
    angle_names = []
    for ang in angle_types:
        angle_names.extend(s for s in angles.columns
                           if ang in s and (not only_t1 or '1' in s))
    # angle_names = angle_names + ['fictrac_speed', 'fictrac_rot']
    return [x for x in angle_names if 'freqs' not in x]

def remove_short_bouts(data, min_frames):
    """Drop grooming bouts with fewer than `min_frames` rows (too short to analyze).

    Parameters
    ----------
    data : DataFrame with a 'behavior_bout' column, one row per frame.
    min_frames : minimum number of rows a bout needs in order to be kept.

    Returns
    -------
    The rows of `data` belonging to bouts with at least `min_frames` rows.
    """
    bout_lengths = data.groupby('behavior_bout').size()
    # Keep the bout *numbers* whose length passes. The old code passed the
    # lengths themselves to isin(), keeping bouts whose number equalled a length.
    kept_bouts = bout_lengths.index[bout_lengths >= min_frames]
    return data[data.behavior_bout.isin(kept_bouts)]

def get_fly_id(angles, bout_numbers):
    """Return a dict mapping every bout number in `bout_numbers` to its fly id.

    The fly id is read from the first row of `angles` whose 'behavior_bout'
    equals the bout number (IndexError if a bout has no rows).
    """
    return {bout: angles[angles.behavior_bout == bout].iloc[0].flyid
            for bout in bout_numbers}

def norm(X):
    """Z-score every feature of `X` across samples (axis 0).

    Each position along the trailing axes is one feature. `X` is viewed as
    (n_samples, n_features), centered and divided by each feature's standard
    deviation, then reshaped back to `X.shape`. A feature with zero standard
    deviation is only centered (divided by 1) rather than turned into NaN/inf;
    sklearn's StandardScaler handles constant features the same way.
    """
    X = np.asarray(X)
    X_flat = X.reshape(X.shape[0], -1)
    std = X_flat.std(axis=0)
    std[std == 0] = 1.0
    X_norm = (X_flat - X_flat.mean(axis=0)) / std
    return X_norm.reshape(X.shape)

def data_per_fly(data):
    """Total the rows of grooming data per fly and rank flies by that total.

    Returns
    -------
    fly_data : dict {fly id: number of rows summed over that fly's bouts}
    fly_names_sorted : fly ids from most to fewest rows (ties keep first-seen order)
    """
    bouts = np.unique(data.behavior_bout)
    bout_to_fly = get_fly_id(data, bouts)
    fly_data = dict()
    for bout in bouts:
        fly = bout_to_fly[bout]
        n_rows = len(data[data.behavior_bout == bout])
        fly_data[fly] = fly_data.get(fly, 0) + n_rows
    fly_names_sorted = sorted(fly_data, key=fly_data.get, reverse=True)
    return fly_data, fly_names_sorted

def adjust_bout_numbers(data):
    """Renumber 'behavior_bout' so bout numbers are unique across recording dates.

    Bout numbering restarts on each experiment day, so raw numbers collide
    between dates. Dates are visited in sorted order. Each date's distinct bout
    numbers, in sorted order, receive consecutive new numbers that continue from
    the previous date, starting at 1.

    Returns
    -------
    A new DataFrame with rows grouped by date (sorted) and the original index kept.
    `data` itself is not modified.
    """
    next_bout = 1
    pieces = []
    for date in np.unique(data.date):
        subset = data[data['date'] == date]
        old_numbers = np.unique(subset.behavior_bout)
        mapping = dict(zip(old_numbers, range(next_bout, next_bout + len(old_numbers))))
        next_bout += len(old_numbers)
        # Remap every number in one pass. Replacing one value at a time can chain
        # (1->3 then 3->4 merges two bouts), and the former chained in-place replace
        # on a filtered copy is not guaranteed to modify `subset` at all.
        pieces.append(subset.assign(behavior_bout=subset['behavior_bout'].map(mapping)))
    if not pieces:
        return data.iloc[0:0].copy()
    # single concat instead of growing a frame inside the loop
    return pd.concat(pieces)


In [None]:
# define dimensionality-reduction specific functions
def get_data_range(d):
    """Return a per-column summary of `d`, used as the 'range' window feature.

    Float columns get their absolute range |max - min| (NaN if the column is all
    NaN). Every other column (ids, integer counters, strings) is not ranged. Its
    max is reported instead, the same value the 'max' window feature would give.
    For a column that is constant within `d` (e.g. flyid within one bout), that
    is simply the column's value.

    Returns
    -------
    Series indexed by the columns of `d`.
    """
    df_max = d.max()
    df_min = d.min()
    summary = {}
    for col in d.columns:
        if pd.api.types.is_float_dtype(d[col]):
            summary[col] = abs(df_max[col] - df_min[col])
        else:
            # Previously the whole column (d[col]) was stored into this single slot,
            # and positional d.dtypes[j] lookup is deprecated in pandas.
            summary[col] = df_max[col]
    return pd.Series(summary, index=d.columns)

def get_window_features(d, index, window_feature, window_size):
    """Return the attribute row (column -> value) describing one window of a bout.

    window_feature selects the summary:
      'min'          -> column-wise min over all of `d`
      'max'          -> column-wise max over all of `d`
      'range'        -> get_data_range(d)
      anything else  -> the row at the window center, d.iloc[index + window_size // 2]
    The first three summarize the whole frame passed in (get_embedding passes the
    entire bout), not only the window.
    """
    whole_frame_summaries = {
        'min': lambda: d.min(),
        'max': lambda: d.max(),
        'range': lambda: get_data_range(d),
    }
    summarize = whole_frame_summaries.get(window_feature)
    if summarize is None:
        return dict(d.iloc[index + window_size // 2])
    return dict(summarize())

def get_embedding(data, angle_vars, attr = 'flyid', vector_type='time', window_feature = None, window_size=32, noverlap=8):
    """Cut every bout into sliding windows and build one feature vector per window.

    Bouts are visited in sorted bout-number order. Within a bout, windows of
    `window_size` frames start every `noverlap` frames (despite its name,
    `noverlap` is the step between window starts). A trailing partial window is
    dropped. Each window yields:
      * time features: the raw window, (n_angles, window_size), flattened angle by angle;
      * fft features: log |rFFT| of the Hann-tapered window per angle,
        (n_angles, window_size // 2 + 1), flattened. A zero bin gives -inf.

    Parameters
    ----------
    data : DataFrame with a 'behavior_bout' column and the `angle_vars` columns.
    angle_vars : columns used as the signal.
    attr : attributes column re-encoded as integer codes 0..k-1.
    vector_type : 'time', 'fft', or anything else for both (time block first).
    window_feature : passed to get_window_features to choose each window's attribute
        row (None -> the row at the window center).
    window_size, noverlap : window length and step between window starts, in frames.

    Returns
    -------
    X_out : array (n_windows, n_features)
    attributes : DataFrame with one row per window, plus 'bout_len' (rows in the bout).
        Column `attr` is replaced by integer codes.
    feature_unique : sorted unique original values of `attr`; code i is feature_unique[i].

    Raises
    ------
    ValueError : if no bout has at least `window_size` rows.
    """
    bout_numbers = np.unique(data['behavior_bout'].astype(int))
    hanning = signal.windows.hann(window_size)

    X_list = []
    X_list_stft = []
    rows = []

    for bout_number in bout_numbers:
        d = data.loc[data['behavior_bout'] == bout_number]
        X_in = np.array(d.loc[:, angle_vars])
        for i in range(0, X_in.shape[0]-window_size+1, noverlap):
            row = get_window_features(d, i, window_feature, window_size) # takes value from middle of the window by default
            row['bout_len'] = len(X_in)
            vals = X_in[i:i+window_size].T
            X_list.append(vals)
            X_list_stft.append(np.log(np.abs(np.fft.rfft(vals * hanning))))
            rows.append(row)

    if not rows:
        # Without this check the failure surfaces as an opaque reshape error below.
        raise ValueError('get_embedding: no bout has at least window_size={} rows'.format(window_size))

    X_times = np.array(X_list)
    X_stft = np.array(X_list_stft)
    X_times_flat = X_times.reshape(X_times.shape[0], -1)
    X_stft_flat = X_stft.reshape(X_stft.shape[0], -1)

    if vector_type == 'time':
        X_out = X_times_flat
    elif vector_type == 'fft':
        X_out = X_stft_flat
    else:
        X_out = np.hstack([X_times_flat, X_stft_flat])

    attributes = pd.DataFrame.from_dict(rows)
    feature = np.array(attributes[attr])
    feature_unique = np.unique(feature)
    # integer-code `attr` so it can be fed straight to a colormap / legend
    feature_map = dict(zip(feature_unique, range(len(feature_unique))))
    attributes[attr] = np.array([feature_map[f] for f in feature])

    return X_out, attributes, feature_unique
                   
def _fit_transform_scaled(model, X_in):
    """Standardize each column of X_in (zero mean, unit variance), then fit_transform with `model`."""
    standardized = StandardScaler().fit_transform(X_in)
    return model.fit_transform(standardized)

def run_umap(X_in, **params):
    """UMAP embedding of the standardized features; **params are passed to umap.UMAP."""
    return _fit_transform_scaled(umap.UMAP(**params), X_in)

def run_pca(X_in, **params):
    """PCA projection of the standardized features; **params are passed to sklearn PCA."""
    return _fit_transform_scaled(PCA(**params), X_in)

def run_tsne(X_in, **params):
    """t-SNE embedding of the standardized features; **params are passed to sklearn TSNE."""
    return _fit_transform_scaled(TSNE(**params), X_in)

def plot_embedding(X_map, attributes, flyid_unique, method, val='flyid', 
                   cmap='tab20', norm=False, legend=True, cbar=False, figsize=(8,4), title='', background = 'white'):
    """Scatter-plot an embedding with every point colored by one attribute column.

    Draws component 1 vs 2. If `X_map` has more than two columns, it also draws
    3 vs 2 and 1 vs 3, each in its own figure. Optionally adds a colorbar to each
    scatter figure and a separate small legend figure.

    Parameters
    ----------
    X_map : array (n_points, n_components), e.g. the output of run_umap.
    attributes : DataFrame with one row per point (as returned by get_embedding).
    flyid_unique : legend labels; entry i labels integer code i in attributes[val].
        The legend only makes sense when `val` is that integer-coded column.
    method : axis-label prefix, e.g. 'UMAP' -> 'UMAP1', 'UMAP2'.
    val : column of `attributes` used for coloring.
    cmap : colormap name (or Colormap).
    norm : if True, map the 5th..95th percentile of `val` linearly onto the
        colormap (values outside get the end colors). Otherwise pass raw values to
        the colormap: floats as positions in [0, 1], integers as lookup-table indices.
    legend, cbar : draw the legend figure / a colorbar on each scatter figure.
        With norm=False the colorbar spans the min..max of the raw values.
    figsize, title : size and title of each scatter figure.
    background : xkcd color name for the axes background.
    """
    values = np.array(attributes[val])
    cmap = plt.get_cmap(cmap)
    normalizer = None
    if norm:
        # percentile bounds keep a few outliers from compressing the color range
        lo, hi = np.percentile(values[~np.isnan(values)], [5, 95])
        normalizer = plt.Normalize(vmin=lo, vmax=hi)
        colors = cmap(normalizer(values))
    else:
        colors = cmap(values)

    def _value_color(v):
        """Color a single value exactly the way the scatter colors it."""
        return cmap(normalizer(v)) if normalizer is not None else cmap(v)

    def _draw_panel(x_col, y_col):
        """One figure: component x_col vs y_col (0-based), with an optional colorbar."""
        fig, ax = plt.subplots(figsize=figsize)
        ax.set_facecolor('xkcd:' + background)
        ax.set_title(title, fontsize = 14)
        ax.scatter(X_map[:, x_col], X_map[:, y_col], s=2, c=colors, alpha = 0.5)
        ax.set_xlabel('{}{}'.format(method, x_col + 1), fontsize = 14)
        ax.set_ylabel('{}{}'.format(method, y_col + 1), fontsize = 14)
        if cbar:
            # Give the mappable the same normalizer as the points, so the ticks are in
            # data units. Colorbar.set_clim is not available on recent matplotlib, and a
            # free-standing ScalarMappable needs an explicit ax=.
            if normalizer is not None:
                cbar_norm = normalizer
            else:
                cbar_norm = plt.Normalize(vmin=np.nanmin(values), vmax=np.nanmax(values))
            mappable = cm.ScalarMappable(norm=cbar_norm, cmap=cmap)
            mappable.set_array([])
            fig.colorbar(mappable, ax=ax).set_label(val)

    _draw_panel(0, 1)
    if X_map.shape[1] > 2:
        _draw_panel(2, 1)
        _draw_panel(0, 2)

    if legend:
        # legend swatch i uses the same mapping as points whose `val` code is i
        handles = [Patch(color=_value_color(i), label=label) for i, label in enumerate(flyid_unique)]
        plt.figure(figsize=(3, 3))
        plt.legend(handles=handles, ncol=2, loc='center')
        plt.axis('off')
        
def plot_embedding_3d(X_map, attributes, flyid_unique, method, val='flyid', 
                   cmap='tab20', norm=False, legend=True, cbar=False, figsize=(8,4), title='', background = 'white',
                   interactive=True):
    """3-D scatter of embedding components 1-3, with points colored by attributes[val].

    val, cmap, norm, legend, cbar, figsize and title behave as in plot_embedding.

    Parameters
    ----------
    background : accepted for signature parity with plot_embedding; not used here.
    interactive : if True (the default, matching the previous behavior), switch
        IPython's matplotlib backend to 'notebook' so the 3-D axes can be rotated.
        This changes the backend for the rest of the session; pass False to keep
        the current backend. The switch is skipped outside IPython.
    """
    # Kept separate from the percentile bounds, which used to overwrite it when norm=True.
    point_alpha = 0.6
    values = np.array(attributes[val])
    cmap = plt.get_cmap(cmap)
    normalizer = None
    if norm:
        lo, hi = np.percentile(values[~np.isnan(values)], [5, 95])
        normalizer = plt.Normalize(vmin=lo, vmax=hi)
        colors = cmap(normalizer(values))
    else:
        colors = cmap(values)

    if interactive:
        try:
            # same effect as the `%matplotlib notebook` magic, but valid Python
            get_ipython().run_line_magic('matplotlib', 'notebook')
        except NameError:
            pass  # not running under IPython: keep the current backend
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title(title, fontsize = 14)
    ax.scatter(X_map[:,0], X_map[:, 1], X_map[:, 2], c=colors, alpha = point_alpha)
    ax.set_xlabel(method + '1', fontsize = 14)
    ax.set_ylabel(method + '2', fontsize = 14)
    ax.set_zlabel(method + '3', fontsize = 14)

    if cbar:
        # Attach to the 3-D axes explicitly. The colorbar used to be created after the
        # legend figure, so it was added to the legend figure instead.
        if normalizer is not None:
            cbar_norm = normalizer
        else:
            cbar_norm = plt.Normalize(vmin=np.nanmin(values), vmax=np.nanmax(values))
        mappable = cm.ScalarMappable(norm=cbar_norm, cmap=cmap)
        mappable.set_array([])
        fig.colorbar(mappable, ax=ax).set_label(val)

    if legend:
        # legend swatch i uses the same mapping as points whose `val` code is i
        handles = [Patch(color=cmap(normalizer(i)) if normalizer is not None else cmap(i), label=label)
                   for i, label in enumerate(flyid_unique)]
        plt.figure(figsize=(3, 3))
        plt.legend(handles=handles, ncol=2, loc='center')
        plt.axis('off')

In [None]:
def get_embedding_range(data, angle_vars, val, attr = 'flyid', vector_type='time', window_feature = None, window_size=32, noverlap=8):
    """Like get_embedding, but each window's `val` attribute is its bout's range of `val`.

    Windows and feature vectors are built as in get_embedding: windows of
    `window_size` frames start every `noverlap` frames within each bout, and the
    features are the 'time' block, the 'fft' block, or both. A window's attribute
    row is the row at the window center, with two changes: column `val` is replaced
    by |max - min| of `val` over the whole bout, and 'bout_len' (rows in the bout)
    is added.

    Parameters
    ----------
    data, angle_vars, attr, vector_type, window_size, noverlap : as in get_embedding.
    val : column whose per-bout range is stored in the attribute rows.
    window_feature : accepted for signature compatibility; ignored.

    Returns
    -------
    X_out, attributes, feature_unique : as in get_embedding.

    Raises
    ------
    ValueError : if no bout has at least `window_size` rows.
    """
    bout_numbers = np.unique(data['behavior_bout'].astype(int))
    hanning = signal.windows.hann(window_size)

    X_list = []
    X_list_stft = []
    rows = []

    for bout_number in bout_numbers:
        d = data.loc[data['behavior_bout'] == bout_number]
        X_in = np.array(d.loc[:, angle_vars])
        # identical for every window of this bout, so compute it once per bout
        bout_range = abs(np.max(d[val]) - np.min(d[val]))
        for i in range(0, X_in.shape[0]-window_size+1, noverlap):
            row = dict(d.iloc[i+window_size//2]) # takes value from middle of the window by default
            row[val] = bout_range
            row['bout_len'] = len(X_in)
            vals = X_in[i:i+window_size].T
            X_list.append(vals)
            X_list_stft.append(np.log(np.abs(np.fft.rfft(vals * hanning))))
            rows.append(row)

    if not rows:
        # Without this check the failure surfaces as an opaque reshape error below.
        raise ValueError('get_embedding_range: no bout has at least window_size={} rows'.format(window_size))

    X_times = np.array(X_list)
    X_stft = np.array(X_list_stft)
    X_times_flat = X_times.reshape(X_times.shape[0], -1)
    X_stft_flat = X_stft.reshape(X_stft.shape[0], -1)

    if vector_type == 'time':
        X_out = X_times_flat
    elif vector_type == 'fft':
        X_out = X_stft_flat
    else:
        X_out = np.hstack([X_times_flat, X_stft_flat])

    attributes = pd.DataFrame.from_dict(rows)
    feature = np.array(attributes[attr])
    feature_unique = np.unique(feature)
    # integer-code `attr` so it can be fed straight to a colormap / legend
    feature_map = dict(zip(feature_unique, range(len(feature_unique))))
    attributes[attr] = np.array([feature_map[f] for f in feature])

    return X_out, attributes, feature_unique

In [None]:
# --- Load the curated T1-grooming dataset ---
behavior = 't1_grooming'
prefix = '/media/turritopsis/katie/grooming/summaries/v3-b2'
data_path = os.path.join(prefix, 't1_grooming_subset_curated.parquet')
data = pd.read_parquet(data_path, engine='fastparquet')
# Earlier variant: load the full processed dataset and keep only rows with
# 1.6 < grooming_score < 8.25.
# data_path = os.path.join(prefix, 'lines-' + behavior + '_onball_processed_all_gs.parquet')
# df = pd.read_parquet(data_path, engine='fastparquet')
#df2 = df[df.grooming_score < 8.25]
#data = df2[df2.grooming_score > 1.6]

fps = 300.0 # know this for this dataset
# NOTE(review): keeps the raw dtype of behavior_bout; later cells recompute it with .astype(int).
bout_numbers = np.unique(np.array(data.behavior_bout))
# joint-angle columns (_BC/_flex/_rot/_abduct), excluding derived columns (derivatives,
# frequencies, ranges) and fictrac columns (presumably ball tracking)
angle_vars = np.unique([v for v in data.columns
              if some_contains(v, ['_BC', '_flex', '_rot', '_abduct'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range', 'fictrac'])])

In [None]:
# Quick look at grooming_score for every row (x axis = DataFrame index).
plt.plot(data.grooming_score, 'o')

In [None]:
# L1 columns only: joint angles (_flex, _rot) plus columns containing _x/_y/_z
# (presumably 3-D keypoint coordinates), excluding derived and fictrac columns.
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_rot', '_x', '_y', '_z'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range', 'fictrac'])
              and v[:2] == 'L1']

# Per-bout / per-fly lookup tables. get_videos, fly_to_video and get_bout_lengths are
# not defined in this notebook; presumably they come from grooming_functions.ipynb (%run above).
fly_dict = get_fly_id(data, bout_numbers)
videos = get_videos(bout_numbers, data)
fly_videos = fly_to_video(data)
dif_flies = np.unique(list(fly_dict.values()))  # distinct fly ids
fly_data, fly_names_sorted = data_per_fly(data)  # rows per fly, flies ranked by row count
bout_length_dict = get_bout_lengths(data)

In [None]:
# how much data do we have for each fly?  
# fly_data, flyids_sorted = data_per_fly(data)
#data = data[data.flyid.isin(flyids_sorted[:10])]
# bout_numbers = np.unique(data['behavior_bout'].astype(int))

In [None]:
# How much data per fly? For each fly (sorted by id) print its position, its id,
# and its number of rows.
bout_numbers = np.unique(data['behavior_bout'].astype(int))
flies = np.unique(data.flyid)
for n, fly in enumerate(flies):
    print(n)
    print(fly)
    print(len(data[data['flyid'] == fly]))

In [None]:
# add velocity (_d1) and acceleration (_d2) columns for every joint angle, per bout
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range', 'fictrac'])]

dt = 1/fps
s = 1.0/dt            # per-frame derivative -> per second
s2 = 1.0 / (dt * dt)  # per-frame^2 second derivative -> per second^2

# Savitzky-Golay derivatives (5-frame window, cubic fit), computed within each bout
# so that no derivative straddles a bout boundary.
SG_WINDOW = 5
for j in range(len(bout_numbers)):
    mask = data.behavior_bout == bout_numbers[j]
    bout_df = data.loc[mask]
    if len(bout_df) < SG_WINDOW:
        # savgol_filter raises when the window is longer than the signal;
        # this bout's _d1/_d2 entries are simply left as NaN.
        continue
    for ang in angle_vars:
        bout = np.array(bout_df[ang])
        data.loc[mask, ang + '_d1'] = signal.savgol_filter(bout, SG_WINDOW, 3, deriv=1) * s
        data.loc[mask, ang + '_d2'] = signal.savgol_filter(bout, SG_WINDOW, 3, deriv=2) * s2

In [None]:
# determine grooming frequency for each bout: for every angle, the frequency at which
# its Welch power spectral density peaks
fly_ids = []
bout_freq_dict = dict()
angle_vars = np.unique([v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range'])])
cols = [x + '_freq' for x in angle_vars]
cols = ['behavior_bout'] + cols

for k in range(len(bout_numbers)):

    bout = data[data.behavior_bout == bout_numbers[k]]
    # exactly one fly id per bout keeps fly_ids aligned with the rows of bout_freq_df
    fly_ids.append(bout.flyid.iloc[0])
    freqs = [bout_numbers[k]]

    for ang in angle_vars:
        t1 = bout[ang]
        t1 = t1[np.isfinite(t1)]
        # sampling rate from the notebook-level fps; for bouts shorter than nperseg,
        # welch falls back to one segment spanning the whole bout (with a warning)
        f, pxx = signal.welch(t1, fs=fps, nperseg=1024)
        # (subtracting mean(pxx), as before, cannot move the argmax, so it is omitted)
        freqs.append(f[np.argmax(pxx)])

    bout_freq_dict[bout_numbers[k]] = freqs

bout_freq_df = pd.DataFrame(bout_freq_dict, index = cols).T
bout_freq_df['flyid'] = fly_ids
path = '/media/turritopsis/katie/grooming/t1-grooming'
out_fname = os.path.join(path, 'grooming_bout_freqs.csv')
bout_freq_df.to_csv(out_fname)

# append bout frequency to data. Drop any *_freq columns already present (from the loaded
# file or an earlier run of this cell); otherwise the merge renames them with _x/_y
# suffixes and later lookups of '<angle>_freq' fail.
data = data.drop(columns=[c for c in cols[1:] if c in data.columns])
data = data.merge(bout_freq_df.loc[:, bout_freq_df.columns != 'flyid'], on='behavior_bout', how='left')

###### look at fly videos

In [None]:
# For one fly, list each of its videos with its row count, followed by the
# peak-to-peak range of angle `ang` within that video.
flyid = '4_0 5272019'
ang = 'L1C_rot'
fly_videos = fly_to_video(data)  # presumably {flyid: [video filenames]}; helper not defined in this notebook
videos = fly_videos[flyid]
for j in range(len(videos)):
    bout = data[data.filename == videos[j]]
    print(videos[j] + ' (' + str(len(bout)) + ')')
    print(np.ptp(bout[ang]))

In [None]:
# UMAP embedding of each fly separately (L1 angles, time + fft features) to check
# for outliers; points are colored by `feature`.
feature = 'L1A_abduct'
legs = ['L1']
cmap = 'plasma'
# NOTE(review): unlike most cells, _d1/_d2 columns are not excluded here, so derivative
# columns (if the derivative cell has run) are part of the embedding.
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_freq', '_range'])
              and v[:2] in legs]

flies = np.unique(data.flyid)
for fly in flies:
    # NOTE(review): rebinds the global `fly_data` (the per-fly row-count dict from
    # data_per_fly) to a DataFrame.
    fly_data = data[data.flyid == fly]
    X_in, attributes, flyids = get_embedding(fly_data, angle_vars, attr = 'flyid', vector_type = 'both', window_feature = None)
    X_map = run_umap(X_in, n_components=2, min_dist=0.0, n_neighbors=30, random_state=5)
    title = '{} embedding ({})'.format(legs[0], fly)
    # NOTE(review): legend=True draws flyid swatches, but points are colored by `feature`,
    # so the legend does not describe the scatter colors.
    plot_embedding(X_map, attributes, flyids, val=feature, cmap=cmap, norm=True, legend=True, cbar=True, method = 'UMAP', title=title, background='black')
    # grooming_score range for this fly
    print(max(fly_data.grooming_score))
    print(min(fly_data.grooming_score))

###### embed angles by fly id 

In [None]:
# Embed all flies' L1 windows (16-frame windows, time + fft features, 3 UMAP components).
# attr is the 'line' column, so `flyids` below holds the unique line values.
feature = 'L1B_flex'
feature = 'line'  # overrides the previous line
legs = ['L1']
cmap = 'Spectral'
# NOTE(review): _d1/_d2 columns are not excluded here.
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_freq', '_range'])
              and v[:2] in legs]
X_in, attributes, flyids = get_embedding(data, angle_vars, attr = feature, vector_type = 'both', window_size = 16, window_feature = None)
X_map = run_umap(X_in, n_components=3, min_dist=0.0, n_neighbors=30, random_state=5)


In [None]:
# grooming possibly parametrized by L1A_abduct, L1B_rot, L1B_flex, L1C_flex
# Plot the embedding from the previous cell colored by the (integer-coded) 'line' attribute.
cmap = 'Spectral'
feature = 'L1A_abduct'
feature = 'line'  # overrides the previous line
title = '{} embedding ({})'.format(legs[0], feature)
plot_embedding(X_map, attributes, flyids, val=feature, cmap=cmap, norm=True, legend=True, cbar=True, method = 'UMAP', title=title, background='black')

In [None]:
# Manual variant of plot_embedding for the current (X_map, attributes): points colored
# by `val` (5th-95th percentile scaling), one figure per component pair, with colorbars.
val = 'flyid'
# val = feature
cmap = 'Spectral'
method = 'UMAP'
# renamed from `norm`, which shadowed the norm() helper defined earlier in the notebook
normalize_colors = True
values = np.array(attributes[val])
title =  'T1 embedding'
if normalize_colors:
    a, b = np.percentile(values[~np.isnan(values)], [5, 95])
    normalizer = plt.Normalize(vmin=a, vmax=b)
else:
    # only used for the colorbar; points then get the raw values
    normalizer = plt.Normalize(vmin=np.nanmin(values), vmax=np.nanmax(values))

cmap = plt.get_cmap(cmap)
colors = cmap(normalizer(values)) if normalize_colors else cmap(values)


def _plot_embedding_panel(x_col, y_col, facecolor=None):
    """One scatter of component x_col vs y_col (0-based) with a colorbar in units of `val`."""
    fig, ax = plt.subplots(figsize=(9,4))
    ax.set_title(title, fontsize = 14)
    ax.scatter(X_map[:, x_col], X_map[:, y_col], s=1, c=colors)
    ax.set_xlabel('{}{}'.format(method, x_col + 1), fontsize = 14)
    ax.set_ylabel('{}{}'.format(method, y_col + 1), fontsize = 14)
    if facecolor is not None:
        ax.set_facecolor(facecolor)
    # Colorbar.set_clim is not available on recent matplotlib; give the mappable the
    # points' normalizer instead, which also labels the ticks in data units.
    mappable = cm.ScalarMappable(norm=normalizer, cmap=cmap)
    mappable.set_array([])
    fig.colorbar(mappable, ax=ax).set_label(val)


_plot_embedding_panel(0, 1)
if X_map.shape[1] > 2:
    _plot_embedding_panel(2, 1, facecolor='xkcd:black')
    _plot_embedding_panel(0, 2, facecolor='xkcd:black')

###### embedding angles for one leg, coloring by a single joint

In [None]:
# Embed L1 joint angles (time + fft features) and color the points by `feature`.
legs = ['L1']
feature = 'L1C_flex'
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range'])
              and v[:2] in legs]

X_in, attributes, flyids = get_embedding(data, angle_vars, attr = 'flyid', vector_type = 'both', window_feature = None)
# Alternative: color by each bout's range of `feature` instead of the window-center value.
# X_in, attributes, flyids = get_embedding_range(data, angle_vars, feature, attr = 'flyid', vector_type = 'both', window_feature = 'range')
X_map = run_umap(X_in, n_components=2, min_dist=0.0, n_neighbors=30, random_state=3)
#title = '{} embedding ({})'.format(legs[0] + ' and ' + legs[1], feature)
title = '{} embedding ({})'.format(legs[0], feature)
plot_embedding(X_map, attributes, flyids, val=feature, cmap='plasma', norm=True, legend=False, cbar=True, method = 'UMAP', title=title, figsize = (9,4))

###### color embeddings by velocity

In [None]:
# Velocity: for each leg, embed its _d1 (velocity) columns (time + fft features) and color
# by <leg>B_flex_d1. No random_state is passed, so the UMAP layout can change between runs.
legs = ['L1']
for leg in legs:
    feature = leg + 'B_flex_d1'
    angle_vars = [v for v in data.columns if some_contains(v, ['_d1']) and v[:2] == leg]
    X_in, attributes, flyids = get_embedding(data, angle_vars, attr = 'flyid', vector_type = 'both', window_feature = None)
    X_map = run_umap(X_in, n_components=2, min_dist=0.0, n_neighbors=30)
    title = '{} embedding ({})'.format(leg, feature)
    plot_embedding(X_map, attributes, flyids, val=feature, cmap='plasma', norm=True, legend=False, cbar=True, method = 'UMAP', title = title)

###### color embeddings by acceleration

In [None]:
# Acceleration: for each leg, embed its _d2 (acceleration) columns (time features)
# and color by <leg>C_flex_d2.
legs = ['L1', 'R1']
for leg in legs:
    feature = leg + 'C_flex_d2'
    angle_vars = [v for v in data.columns if some_contains(v, ['_d2']) and v[:2] == leg]
    # Build this leg's embedding here. With these two lines commented out, every panel
    # re-plotted whatever X_map/attributes an earlier cell left behind, under this leg's title.
    X_in, attributes, flyids = get_embedding(data, angle_vars, attr = 'flyid', vector_type = 'time', window_feature = None)
    X_map = run_umap(X_in, n_components=2, min_dist=0.0, n_neighbors=30)
    title = '{} embedding ({})'.format(leg, feature)
    plot_embedding(X_map, attributes, flyids, val=feature, cmap='plasma', norm=True, legend=False, cbar=True, method = 'UMAP', title = title)

###### check embeddings of all legs: position

In [None]:
# Position embeddings (joint angles, time + fft features) for each leg in `legs`, stored
# as (X_map, X_in, attributes, flyids) in leg_embeddings for the plotting cells below.
legs = ['L1', 'L2', 'L3', 'R1', 'R2', 'R3']
legs = ['L1', 'R1']  # only L1 and R1 are embedded (overrides the full list above)
leg_embeddings = []
for leg in legs:
    angle_vars = [v for v in data.columns
                  if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
                  and not some_contains(v, ['_d1', '_d2', '_freq', '_range']) and v[:2] == leg]

    X_in, attributes, flyids = get_embedding(data, angle_vars, attr = 'flyid', vector_type='both', window_feature = None)
    # no random_state: the UMAP layout may differ between runs
    X_map = run_umap(X_in, n_components=2, min_dist=0.0, n_neighbors=30)
    # X_map = run_tsne(X_in, n_components=2)
    leg_embeddings.append((X_map, X_in, attributes, flyids))

In [None]:
# Plot each leg's position embedding colored by fly id on a black background.
feature = 'flyid'
for embeddings, leg in zip(leg_embeddings, legs):
    X_map, X_in, attributes, flyids = embeddings
    # these embeddings were built with vector_type='both' (time + fft); the title said 'time'
    title = '{} embedding ({}, {})'.format(leg, feature, 'both')
    # background='black' colors every panel, instead of only the current axes afterwards
    plot_embedding(X_map, attributes, flyids, val=feature, cmap='Spectral', norm=True, legend=False, cbar=False, figsize=(7, 4), method = 'UMAP', title=title, background='black')
    # plt.tight_layout()
    # plt.show()
    # plt.savefig(os.path.join(figure_path, 'position_flyid_both', 'UMAP ' + title))


In [None]:
# Position embeddings colored by each joint angle: one figure per (angle, leg), saved
# under figure_path/position_both.
# Stripped angle names (e.g. 'A_abduct') come only from columns of the legs in `legs`.
# Stripping two characters from a non-leg column (e.g. a 'fictrac_rot' column would become
# 'ctrac_rot') gave names with no matching '<leg><name>' attribute.
angle_vars = np.unique([v[2:] for v in data.columns
              if v[:2] in legs
              and some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range'])])
out_dir = os.path.join(figure_path, 'position_both')
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing directories

for ang in angle_vars:
    for embeddings, leg in zip(leg_embeddings, legs):
        feature = leg + ang
        X_map, X_in, attributes, flyids = embeddings
        title = '{} embedding ({}, {})'.format(leg, feature, 'both')
        plot_embedding(X_map, attributes, flyids, val=feature, cmap='plasma', norm=True, legend=False, cbar=True, figsize=(7, 4), method = 'UMAP', title=title)
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, 'UMAP ' + title))
        plt.show()

In [None]:
# ANGLE RANGE: for each L1/R1 joint angle, embed the leg's windows and color every
# window by that angle's range over its whole bout (get_embedding_range).
legs = ['L1', 'R1']
# Kept separate from `leg_embeddings`, which later cells zip with `legs` and expect to
# hold exactly one position embedding per leg.
leg_embeddings_range = []
out_dir = os.path.join(figure_path, 'position_both_range')
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing directories

for leg in legs:
    angle_vars = [v for v in data.columns
                  if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
                  and not some_contains(v, ['_d1', '_d2', '_freq', '_range']) and v[:2] == leg]

    for ang in angle_vars:
        X_in, attributes, flyids = get_embedding_range(data, angle_vars, ang, attr = 'flyid', vector_type='both', window_feature = 'range')
        X_map = run_umap(X_in, n_components=2, min_dist=0.0, n_neighbors=30)
        leg_embeddings_range.append((X_map, X_in, attributes, flyids))

        # Plot the embedding just computed. A stray `X_map, ... = embeddings` unpack used
        # to replace it with an embedding left over from an earlier cell.
        title = '{} embedding ({}, {})'.format(leg, ang, 'both')
        plot_embedding(X_map, attributes, flyids, val=ang, cmap='plasma', norm=True, legend=False, cbar=True, figsize=(7, 4), method = 'UMAP', title=title)
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, 'UMAP ' + title))
        plt.show()

###### check embeddings of all legs: velocity

In [None]:
# Velocity embeddings: for each leg, time-domain windows of its _d1 (velocity) angle columns,
# stored as (X_map_v, X_in_v, attributes, flyids) in leg_embeddings_v.
legs = ['L1', 'L2', 'L3', 'R1', 'R2', 'R3']
legs = ['L1', 'R1']  # only L1 and R1 (overrides the full list above)
leg_embeddings_v = []
for leg in legs:
    angle_vars = [v for v in data.columns
                  if some_contains(v, ['_flex_d1', '_abduct_d1', '_rot_d1', '_BC_d1'])
                  and v[:2] == leg]

    X_in_v, attributes, flyids = get_embedding(data, angle_vars, attr = 'flyid', vector_type='time', window_feature = None)
    X_map_v = run_umap(X_in_v, n_components=2, min_dist=0.0, n_neighbors=30)
    # X_map_v = run_tsne(X_in_v, n_components=2)
    leg_embeddings_v.append((X_map_v, X_in_v, attributes, flyids))

In [None]:
# Velocity embeddings colored by each angle's velocity (<leg><angle>_d1).
# Stripped angle names come only from columns of the legs in `legs`, so a non-leg column
# (e.g. 'fictrac_rot' -> 'ctrac_rot') cannot produce a feature missing from `attributes`.
angle_vars = np.unique([v[2:] for v in data.columns
              if v[:2] in legs
              and some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range'])])

for ang in angle_vars:
    for embeddings, leg in zip(leg_embeddings_v, legs):
        feature = leg + ang + '_d1'
        X_map, X_in, attributes, flyids = embeddings
        title = '{} embedding ({}, {})'.format(leg, feature, 'time')
        plot_embedding(X_map, attributes, flyids, val=feature, cmap='plasma', norm=True, legend=False, cbar=True, figsize=(7, 4), method = 'UMAP', title=title)
        plt.tight_layout()
        # plt.savefig(os.path.join(figure_path, 'velocity_time', 'UMAP ' + title))
        plt.show()

###### check embeddings of all legs: acceleration

In [None]:
# Acceleration embeddings: for each leg, fft features of its _d2 (acceleration) angle columns,
# stored as (X_map_a, X_in_a, attributes, flyids) in leg_embeddings_a.
legs = ['L1', 'L2', 'L3', 'R1', 'R2', 'R3']
legs = ['L1', 'R1']  # only L1 and R1 (overrides the full list above)
leg_embeddings_a = []
for leg in legs:
    # _d2 = second derivative (acceleration). This used to select the _d1 (velocity)
    # columns, so the "acceleration" section silently re-embedded velocities.
    angle_vars = [v for v in data.columns
                  if some_contains(v, ['_flex_d2', '_abduct_d2', '_rot_d2', '_BC_d2'])
                  and v[:2] == leg]

    X_in_a, attributes, flyids = get_embedding(data, angle_vars, attr = 'flyid', vector_type='fft', window_feature = None)
    X_map_a = run_umap(X_in_a, n_components=2, min_dist=0.0, n_neighbors=30)
    leg_embeddings_a.append((X_map_a, X_in_a, attributes, flyids))

In [None]:
# Acceleration embeddings colored by each angle's acceleration (<leg><angle>_d2),
# saved under figure_path/acceleration_fft.
# Stripped angle names come only from columns of the legs in `legs`, so a non-leg column
# (e.g. 'fictrac_rot' -> 'ctrac_rot') cannot produce a feature missing from `attributes`.
angle_vars = np.unique([v[2:] for v in data.columns
              if v[:2] in legs
              and some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range'])])
out_dir = os.path.join(figure_path, 'acceleration_fft')
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing directories

for ang in angle_vars:
    for embeddings, leg in zip(leg_embeddings_a, legs):
        feature = leg + ang + '_d2'
        X_map_a, X_in_a, attributes, flyids = embeddings
        title = '{} embedding ({}, {})'.format(leg, feature, 'fft')
        plot_embedding(X_map_a, attributes, flyids, val=feature, cmap='plasma', norm=True, legend=False, cbar=True, figsize=(7, 4), method = 'UMAP', title=title)
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, 'UMAP ' + title))
        plt.show()

###### check embeddings of all legs: velocity and acceleration

In [None]:
# Velocity + acceleration embeddings: all _d1 and _d2 columns of a leg, time + fft
# features, 3 UMAP components; stored in leg_embeddings_va.
legs = ['L1']
leg_embeddings_va = []
for leg in legs:
    # this leg's columns only; `v[:2] in legs` took every leg's columns on every
    # iteration as soon as `legs` held more than one leg
    angle_vars = [v for v in data.columns
                  if some_contains(v, ['_d1', '_d2']) and v[:2] == leg]

    X_in_va, attributes, flyids = get_embedding(data, angle_vars, attr = 'flyid', vector_type='both', window_feature = None)
    X_map_va = run_umap(X_in_va, n_components=3, min_dist=0.0, n_neighbors=30)
    leg_embeddings_va.append((X_map_va, X_in_va, attributes, flyids))

In [None]:
# Velocity+acceleration embeddings colored by each angle's grooming frequency
# (<leg><angle>_freq, the per-bout PSD peak).
# Filter on the leg prefix *before* stripping it: filtering the stripped names on
# v[:2] in legs selected nothing, so this cell never plotted anything.
angle_vars = np.unique([v[2:] for v in data.columns
              if v[:2] in legs
              and some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range'])])

for ang in angle_vars:
    for embeddings, leg in zip(leg_embeddings_va, legs):
        # each angle's own frequency (a leftover hard-coded 'L1B_flex_freq' override
        # made every figure identical)
        feature = leg + ang + '_freq'
        X_map_va, X_in_va, attributes, flyids = embeddings
        title = '{} embedding ({}, {})'.format(leg, feature, 'both')
        plot_embedding(X_map_va, attributes, flyids, val=feature, cmap='plasma', norm=True, legend=False, cbar=True, figsize=(8, 4), method = 'UMAP', title=title)
        # plot_embedding_3d(X_map_va, attributes, flyids, val=feature, cmap='plasma', norm=True, legend=False, cbar=True, figsize=(7, 4), method = 'UMAP', title=title)
        # plt.tight_layout()
        # plt.savefig(os.path.join(figure_path, 'derivatives_freq_both', 'UMAP ' + title))
        plt.show()

###### color by grooming frequency

In [None]:
# Position embeddings (leg_embeddings) colored by each angle's grooming frequency
# (<leg><angle>_freq), saved under figure_path/position_freq_both.
# Stripped angle names come only from columns of the legs in `legs`, so a non-leg column
# (e.g. 'fictrac_rot' -> 'ctrac_rot') cannot produce a feature missing from `attributes`.
angle_vars = np.unique([v[2:] for v in data.columns
              if v[:2] in legs
              and some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range'])])
out_dir = os.path.join(figure_path, 'position_freq_both')
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing directories

for ang in angle_vars:
    for embeddings, leg in zip(leg_embeddings, legs):
        feature = leg + ang + '_freq'
        X_map, X_in, attributes, flyids = embeddings
        title = '{} embedding ({}, {})'.format(leg, feature, 'both')
        plot_embedding(X_map, attributes, flyids, val=feature, cmap='plasma', norm=True, legend=False, cbar=True, figsize=(7, 4), method = 'UMAP', title=title)
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, 'UMAP ' + title))
        plt.show()

###### color according to how close bouts occurred in time

In [None]:
# Per-fly embeddings of L1 joint angles, colored by trial order in the next cell.
leg = 'L1'
fly_ids = np.unique(data.flyid)
# Sequential trial index: every condition of repetition r precedes those of repetition r + 1.
# NOTE(review): assumes 'condnum' runs 1..n_conditions. The previous formula,
# (rep-1)*condnum + condnum, reduces to rep*condnum, which is not an ordering
# (rep=2/cond=1 and rep=1/cond=2 both gave 2).
n_conditions = data['condnum'].max()
data['trial_num'] = (data['rep']-1)*n_conditions + data['condnum']
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range']) and v[:2] == leg]

fly_embeddings = []
for fly in fly_ids:
    fly_data = data[data.flyid == fly]
    X_in, attributes, flyids = get_embedding(fly_data, angle_vars, attr = 'flyid', vector_type='both', window_feature = None)
    X_map = run_umap(X_in, n_components=2, min_dist=0.0, n_neighbors=30)
    fly_embeddings.append((X_map, X_in, attributes, flyids))

In [None]:
# NOTE(review): this angle_vars is not used by the loop below.
angle_vars = np.unique([v for v in data.columns.str[2:]
              if some_contains(v, ['_flex', '_abduct', '_rot', '_BC'])
              and not some_contains(v, ['_d1', '_d2', '_freq'])])
feature = 'trial_num'

# One figure per fly: its own L1 embedding colored by trial number; the title also
# reports how many distinct bouts the fly contributed.
for embeddings, fly in zip(fly_embeddings, fly_ids):
    X_map, X_in, attributes, flyids = embeddings
    num_bouts = len(np.unique(attributes['behavior_bout']))
    title = '{} embedding ({}, {}, {} bouts)'.format(leg, feature, fly, num_bouts)
    plot_embedding(X_map, attributes, flyids, val=feature, cmap='plasma', norm=True, legend=False, cbar=True, figsize=(7, 4), method = 'UMAP', title=title)
    plt.tight_layout()

In [None]:
def adjust_rot_angles(angles, angle_names):
    """Shift rotation/abduction angle values above a cutoff down by 360 degrees.

    For each (name fragment, cutoff) pair below, every '_rot' column in
    `angle_names` whose name contains the fragment has its values > cutoff
    replaced by value - 360. Fragments '2' and '3' match any column containing
    that digit (presumably legs 2 and 3); the others name single T1 joints.
    Then every '_abduct' or 'A_flex' column has values > 50 shifted down by 360.
    NaNs are left as they are. `angles` is modified in place and returned.
    """
    rot_cutoffs = [('2', -50), ('3', -20), ('L1A', 20), ('L1B', -70),
                   ('L1C', 10), ('R1A', 20), ('R1B', 70), ('R1C', -30)]
    for fragment, cutoff in rot_cutoffs:
        for col in [c for c in angle_names if '_rot' in c and fragment in c]:
            shifted = np.array(angles[col])
            shifted[shifted > cutoff] -= 360
            angles[col] = shifted

    for col in [c for c in angle_names if '_abduct' in c or 'A_flex' in c]:
        shifted = np.array(angles[col])
        shifted[shifted > 50] -= 360
        angles[col] = shifted

    return angles

# load data: second dataset (all T1 grooming bouts, gzipped CSV)
prefix = r'/media/turritopsis/pierre/gdrive/latest/behavior/T1_grooming'
data_path = os.path.join(prefix, 'T1_grooming_all.csv.gz')
data = pd.read_csv(data_path, compression = 'gzip')
data['behavior_bout'] = data['T1_grooming_bout_number']
data['flyid'] = data['fly'].astype(str) + ' ' + data['date'].astype(str)
data = data[~data.behavior_bout.isnull()]  # keep only rows assigned to a grooming bout
data = adjust_bout_numbers(data)  # bout numbers restart per date; make them unique
data = remove_short_bouts(data, 50)  # drop bouts shorter than 50 frames
# Tabulate bout lengths *after* filtering, so the lookup matches the bouts actually kept
# (it used to be computed before remove_short_bouts).
bout_length_dict = get_bout_lengths(data)

fps = 300.0 # know this for this dataset
# NOTE(review): only_t1 and normalize are not read by any cell in this notebook section.
only_t1 = False
normalize = True

angle_vars = np.unique([v for v in data.columns
              if some_contains(v, ['_BC', '_flex', '_rot', '_abduct'])
              and not some_contains(v, ['_d1', '_d2', '_freq'])])
# correct_angles is not defined in this notebook (presumably from grooming_functions.ipynb)
data = correct_angles(data, angle_vars)
data = adjust_rot_angles(data, angle_vars)
bout_numbers = np.unique(data.behavior_bout)


In [None]:
# For each flexion angle, overlay every bout's trace on one figure, with time measured
# in seconds from the bout's first frame.
angle_vars = np.unique([v for v in data.columns
              if some_contains(v, ['_flex'])
              and not some_contains(v, ['_d1', '_d2', '_freq'])])

for ang in angle_vars:
    fig, ax = plt.subplots(figsize = (8,4))
    ax.set_title('{} angles'.format(ang), fontsize = 14)
    for j in range(len(bout_numbers)):
        bout = data[data.behavior_bout == bout_numbers[j]]
        # frame index / fps: the axis is labelled in seconds but used to show frame numbers
        ax.plot(np.arange(len(bout)) / fps, bout[ang])
    ax.set_xlabel('time (s)', fontsize = 14)
    ax.set_ylabel('{} angle (deg)'.format(ang), fontsize = 14)
    plt.show()