In [None]:
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
import os
import seaborn as sns
import umap
import csv
import warnings
from collections import OrderedDict
from collections import defaultdict
from scipy import signal, stats
from scipy.linalg import norm
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d import axes3d
from matplotlib.patches import Patch
from sklearn.preprocessing import StandardScaler
from pydmd import MrDMD, DMD, OptDMD

# Many cells below open one figure per angle / per rank; don't warn about it.
plt.rcParams.update({'figure.max_open_warning': 0})
# NOTE(review): this silences *all* warnings for the session, including numpy's
# ComplexWarning when complex-valued DMD outputs (eigs, amplitudes, modes,
# reconstructions may be complex) are cast to float for plotting or array
# assignment. Consider narrowing this while debugging.
warnings.simplefilter("ignore")

# Helpers used below but not defined in this notebook (adjust_bout_numbers,
# remove_short_bouts, some_contains, correct_angles, compute_grooming_scores,
# get_bout_lengths, ...) presumably come from this sibling notebook.
# NOTE(review): absolute path ties the notebook to one machine.
%run /media/turritopsis/katie/grooming/t1-grooming/grooming_functions.ipynb
%matplotlib inline

sns.set()
sns.set_style('ticks')

In [None]:
# delay embedding
def time_embed(points, stacks=8, stride=1, ixs=None):
    """Delay-embed a time series by placing lagged copies side by side.

    Row i of the result concatenates points[ixs[i] + lag * stride] for
    lag = 0 .. stacks - 1, so the un-delayed (lag-0) values come first.

    Parameters
    ----------
    points : ndarray, shape (n_time,) or (n_time, n_features)
    stacks : int
        Number of lagged copies.
    stride : int
        Lag (in samples) between consecutive copies.
    ixs : array of int, optional
        Window start indices. Defaults to 0 .. len(points) - stacks*stride - 2.
        NOTE(review): that default leaves stride + 1 more trailing samples unused
        than strictly necessary; kept so existing results stay comparable.

    Returns
    -------
    ndarray, shape (len(ixs), stacks) for 1-D input, or
    (len(ixs), stacks * n_features) for 2-D input.
    """
    if ixs is None:
        ixs = np.arange(len(points) - stacks * stride - 1)
    lagged = [points[ixs + lag * stride] for lag in range(stacks)]
    if points.ndim == 1:
        # each 1-D lagged copy becomes one column
        return np.column_stack(lagged)
    # 2-D copies are concatenated feature-block by feature-block
    return np.concatenate(lagged, axis=1)

def rmse(predictions, targets):
    """Root-mean-square error between two arrays.

    Squares the *magnitude* of the residual, so the result is a real,
    non-negative number even when an input is complex (DMD reconstructions
    may be complex). With complex residuals, (p - t)**2 would give a complex
    value that is not an error magnitude. For real inputs this equals
    sqrt(mean((p - t)**2)).

    Parameters
    ----------
    predictions, targets : array-like of the same shape

    Returns
    -------
    float
    """
    residual = predictions - targets
    return np.sqrt(np.mean(np.abs(residual) ** 2))

def reconstruction_error(dmd, angle_vars):
    """Per-feature RMSE between a fitted DMD's input snapshots and its reconstruction.

    Side effect: sets ``dmd.original_time['dt'] = 1`` on the model before reading
    ``reconstructed_data`` (presumably so the reconstruction uses unit, i.e. frame,
    time steps).

    Parameters
    ----------
    dmd : fitted pydmd model
    angle_vars : list of str
        Names for the first len(angle_vars) columns of the (transposed) snapshot
        matrix; column j is reported under angle_vars[j].

    Returns
    -------
    dict
        angle name -> rmse(reconstruction column, snapshot column)
    """
    dmd.original_time['dt'] = 1
    original = dmd.snapshots.T
    reconstructed = dmd.reconstructed_data.T
    return {name: rmse(reconstructed[:, col], original[:, col])
            for col, name in enumerate(angle_vars)}

def get_psd(data, angle_vars, bout_number, plot = False, fs = 300, nperseg = 2048):
    """Peak frequency of each angle's Welch power spectrum within one bout.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs a 'behavior_bout' column and every column named in angle_vars.
    angle_vars : list of str
        Columns to analyse.
    bout_number : number
        'behavior_bout' value selecting the bout.
    plot : bool
        If True, show one PSD figure per angle.
    fs : float
        Sampling rate (Hz) passed to scipy.signal.welch. The default of 300
        matches the fps used elsewhere in this notebook.
    nperseg : int
        Welch segment length. scipy shortens it (with a warning) for shorter series.

    Returns
    -------
    dict
        angle name -> frequency (Hz) of the largest PSD value. NaN when the angle
        has no finite samples in the bout.
    """
    freq_dict = dict()
    bout = data[data.behavior_bout == bout_number]

    for angle in angle_vars:
        series = bout[angle]
        series = series[np.isfinite(series)]
        if len(series) == 0:
            # An empty series gives an empty spectrum, and np.argmax would raise.
            freq_dict[angle] = np.nan
            continue

        f, pxx = signal.welch(series, fs=fs, nperseg=nperseg)
        # Mean-centering only shifts the plotted curve; it doesn't move the argmax.
        pxx = pxx - np.mean(pxx)
        freq_dict[angle] = f[np.argmax(pxx)]

        if plot:
            plt.figure(figsize = (8,4))
            plt.title('power spectrum of ' + angle + ' angles (bout ' + str(int(bout_number)) + ')', fontsize = 14)
            plt.xlabel('frequency (Hz)', fontsize = 14)
            plt.ylabel('power spectral density', fontsize = 14)
            plt.plot(f, pxx, label = 'bout ' + str(int(bout_number)), c = 'k')
            sns.despine()
            plt.show()

    return freq_dict

# plotting
def plot_dynamics(dynamics, together = True, ylim = None, xlim = None):
    """Plot the time evolution of DMD modes.

    Parameters
    ----------
    dynamics : 2-D array, shape (n_modes, n_time)
        E.g. ``dmd.dynamics``; row i is mode i over time.
    together : bool
        When False, first draw each mode in its own figure. The overlaid figure of
        all modes is drawn in every case.
    ylim : float, optional
        Symmetric y-limits (-ylim, ylim) for the overlaid figure.
    xlim : float, optional
        x-limits (0, xlim) for the overlaid figure.
    """
    if not together:
        for mode_idx, trace in enumerate(dynamics):
            plt.figure()
            plt.title('n = {}'.format(mode_idx + 1), fontsize = 14)
            plt.xlabel('time', fontsize = 14)
            plt.ylabel('amplitude', fontsize = 14)
            plt.plot(range(len(trace)), trace)
            plt.show()

    # time on rows so each mode becomes one line
    overlaid = dynamics.T
    plt.figure(figsize = (12, 3))
    plt.title('time evolution of dynamics', fontsize = 14)
    plt.xlabel('time (frames)', fontsize = 14)
    plt.ylabel('amplitude', fontsize = 14)
    if ylim is not None:
        plt.ylim([-ylim, ylim])
    if xlim is not None:
        plt.xlim([0, xlim])
    plt.plot(range(len(overlaid)), overlaid)
    plt.show()
            
def plot_modes(modes, together = True, angles = None):
    """Plot DMD mode vectors (features on the x-axis, one line per mode).

    Parameters
    ----------
    modes : 2-D array, shape (n_features, n_modes)
        E.g. ``dmd.modes``; column j is mode j.
    together : bool
        When False, first draw each mode in its own figure. The combined figure of
        all modes is drawn in every case.
    angles : list of str, optional
        If given, also draw a figure restricted to the first len(angles) rows,
        labelled with these names.
    """
    if not together:
        for j in range(modes.shape[1]):
            plt.figure(figsize = (12, 3))
            plt.title('DMD modes (n = ' + str(j+1) + ')', fontsize = 14)
            plt.xlabel('feature', fontsize = 14)
            plt.ylabel('amplitude', fontsize = 14)
            plt.plot(range(len(modes[:, j])), modes[:, j])
            plt.show()

    if angles is not None:
        angle_modes = modes[:len(angles), :]
        # One tick per plotted row. Using len(modes) ticks with only len(angles)
        # labels makes matplotlib raise ValueError (tick/label count mismatch).
        n_rows = len(angle_modes)
        plt.figure(figsize = (12, 3))
        plt.title('DMD modes', fontsize = 14)
        plt.xlabel('feature', fontsize = 14)
        plt.xticks(np.arange(n_rows), labels = list(angles)[:n_rows], rotation = 90)
        plt.ylabel('amplitude', fontsize = 14)
        plt.plot(range(n_rows), angle_modes)
        plt.show()

    plt.figure(figsize = (12, 3))
    plt.title('DMD modes', fontsize = 14)
    plt.xlabel('feature', fontsize = 14)
    plt.xticks(np.arange(len(modes)), labels = [])
    plt.ylabel('amplitude', fontsize = 14)
    plt.plot(range(len(modes)), modes)
    plt.show()

def plot_eigenvalues(dmd, show_unit_circle=True, figsize=(5, 5), title=r'eigenvalues of $\widetilde{A}$'):
    """Plot DMD eigenvalues in the complex plane; marker size is |amplitude|.

    Parameters
    ----------
    dmd : fitted pydmd model
        Only ``dmd.eigs`` and ``dmd.amplitudes`` are read.
    show_unit_circle : bool
        Draw the dashed unit circle (|lambda| = 1) and list it in the legend.
    figsize : tuple
        Passed to ``plt.figure``.
    title : str
        Figure title.
    """
    sns.set()
    sns.set_style('ticks')
    plt.figure(figsize=figsize)
    plt.title(title, fontsize = 14)
    ax = plt.gca()

    eigs = np.asarray(dmd.eigs)
    # Amplitudes can be complex, and markersize needs a non-negative real number.
    # Use the magnitude, as the DMD power-spectrum cell does (np.abs(dmd.amplitudes)).
    sizes = np.abs(np.asarray(dmd.amplitudes))
    for eig, size in zip(eigs, sizes):
        ax.plot(eig.real, eig.imag, 'o', markerfacecolor = 'w', markeredgecolor ='k', markersize = size)

    # Symmetric limits just past the largest |eigenvalue| (rounded up).
    limit = np.max(np.ceil(np.absolute(eigs))) if eigs.size else 1.0
    ax.set_xlim((-limit - 0.2, limit + 0.2))
    ax.set_ylim((-limit - 0.2, limit + 0.2))

    plt.ylabel('imaginary part', fontsize = 14)
    plt.xlabel('real part', fontsize = 14)

    # Proxy marker for the legend. A list of per-point plot() results is not a
    # legend handle matplotlib can draw, so that entry would be dropped.
    handles = [plt.Line2D([], [], linestyle='none', marker='o',
                          markerfacecolor='w', markeredgecolor='k')]
    labels = ['eigenvalues']
    if show_unit_circle:
        unit_circle = plt.Circle(
            (0., 0.),
            1.,
            color='green',
            fill=False,
            label='unit circle',
            linestyle='--')
        ax.add_artist(unit_circle)
        handles.append(unit_circle)
        labels.append('unit circle')

    gridlines = ax.get_xgridlines() + ax.get_ygridlines()
    for line in gridlines:
        line.set_linestyle('-.')
    ax.grid(True)

    ax.set_aspect('equal')
    # Axis lines through the origin spanning the whole axes (xmin/xmax and
    # ymin/ymax are axes fractions in [0, 1], so the defaults are used).
    ax.axhline(y = 0, linestyle = ':', color = 'k')
    ax.axvline(x = 0, linestyle = ':', color = 'k')

    ax.legend(handles, labels, loc=1)

    sns.despine()
    plt.show()

In [None]:
# load data
def adjust_rot_angles(angles, angle_names):
    """Subtract 360 deg from selected angle columns wherever they exceed a threshold.

    - Rotation columns ('_rot' in the name) are matched against each condition
      substring in turn, each with its own threshold. A column matching several
      conditions is adjusted once per match, in the listed order.
    - Abduction columns ('_abduct') and 'A_flex' columns use a threshold of 50.

    Parameters
    ----------
    angles : DataFrame (or dict of arrays)
        Modified in place (each adjusted column is reassigned) and returned.
    angle_names : list of str
        Candidate column names.

    Returns
    -------
    The same ``angles`` object.
    """
    conditions = ['2', '3', 'L1A', 'L1B', 'L1C', 'R1A', 'R1B', 'R1C']
    thresholds = np.array([-50, -20, 20, -70, 10, 20, 70, -30])
    for cond, cutoff in zip(conditions, thresholds):
        matching = [name for name in angle_names if '_rot' in name and cond in name]
        for name in matching:
            vals = np.array(angles[name])
            vals[vals > cutoff] -= 360
            angles[name] = vals

    wrapped = [name for name in angle_names if '_abduct' in name or 'A_flex' in name]
    for name in wrapped:
        vals = np.array(angles[name])
        vals[vals > 50] -= 360
        angles[name] = vals

    return angles

# Load the pooled T1-grooming table and derive bout / fly identifiers.
# NOTE(review): absolute path ties the notebook to one machine; consider a configurable DATA_DIR.
prefix = r'/media/turritopsis/pierre/gdrive/latest/behavior/T1_grooming'
data_path = os.path.join(prefix, 'T1_grooming_all.csv.gz')
data = pd.read_csv(data_path, compression = 'gzip')
data['behavior_bout'] = data['T1_grooming_bout_number']
# Fly label and recording date combined into one id string.
data['flyid'] = data['fly'].astype(str) + ' ' + data['date'].astype(str)
# Keep only rows that belong to a grooming bout.
data = data[~data.behavior_bout.isnull()]
# Helpers not defined in this notebook (presumably from the %run notebook);
# 60 is presumably a minimum bout length in frames — TODO confirm.
data = adjust_bout_numbers(data)
data = remove_short_bouts(data, 60)

# L1/R1 flexion, abduction and rotation angle columns.
angle_names = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
             and v[:2] in ['L1', 'R1']]
data = correct_angles(data, angle_names)
# Subtract 360 from selected angles above per-joint thresholds (adjust_rot_angles).
data = adjust_rot_angles(data, angle_names)

In [None]:
# Columns passed to compute_grooming_scores: L1 angles and x/y/z coordinates,
# excluding derived columns (_d1, _d2, _freq, _range).
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_rot', '_x', '_y', '_z'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range'])
              and v[:2] == 'L1']
features = ['L1B_rot_avg_range', 'L1A_flex_avg_range', 'L1E_z_avg_range', 'L1D_z', 'L1E_z']
# One flag per entry of `features` (presumably whether that feature is sign-flipped).
flip = [False, False, False, True, True]
data = compute_grooming_scores(data, angle_vars, features, flip = flip, dist=20, norm=False)
# Keep only frames with 1.6 < grooming_score < 8.25.
# NOTE(review): magic thresholds — record where they come from.
data = data[data.grooming_score < 8.25]
data = data[data.grooming_score > 1.6]

# Per-bout / per-fly bookkeeping (helpers presumably from grooming_functions.ipynb).
bout_numbers = np.unique(np.array(data.behavior_bout))
bout_length_dict = get_bout_lengths(data)
fly_dict = get_fly_id(data, bout_numbers)
videos = get_videos(bout_numbers, data)
fly_videos = fly_to_video(data)
dif_flies = np.unique(list(fly_dict.values()))
fly_data_sorted, fly_names_sorted = data_per_fly(data)

In [None]:
# Display the L1 flex/rot/x/y/z columns (same filter as `angle_vars` above),
# re-evaluated on the score-filtered `data`.
[v for v in data.columns
              if some_contains(v, ['_flex', '_rot', '_x', '_y', '_z'])
              and not some_contains(v, ['_d1', '_d2', '_freq', '_range'])
              and v[:2] == 'L1']

###### DMD on a single grooming bout

In [None]:
# Delay-embedding / DMD settings for the single-bout analysis:
# number of delayed copies, frames between copies, SVD truncation rank.
stacks, stride, rank = 10, 7, 40

In [None]:
# Show both the available bout ids and their lengths. A bare expression that is
# not the last line of a cell is evaluated and its value discarded, so use display().
display(np.unique(data['behavior_bout']))
bout_length_dict

In [None]:
# Fit a delay-embedded DMD to one bout using all L1/R1 joint angles.
bout_num = 299
check = data['behavior_bout'] == bout_num
d = data.loc[check]
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
             and v[:2] in ['L1', 'R1']]
# Rows are frames, columns are angles.
X = np.array(d.loc[:, angle_vars])
# Standardize each angle within this bout, then append `stacks` copies delayed by
# multiples of `stride` frames (time_embed); the lag-0 columns come first.
X_scaled = StandardScaler().fit_transform(X)
X_scaled = time_embed(X_scaled, stacks=stacks, stride=stride)
dmd = DMD(svd_rank=rank, tlsq_rank=0, exact=True, opt=True)
# dmd = MrDMD(svd_rank=3, tlsq_rank=0, max_level=8, max_cycles=2)
# Transposed so that each frame is passed to fit() as a column.
dmd.fit(X_scaled.T)

In [None]:
cols = X.shape[1]
# One colour per angle, shared by the raw and normalized plots so traces match
# (the normalized plot used cmap(j % 16), which diverges past 16 angles).
cmap = plt.get_cmap('tab20_r', len(angle_vars))

fig, ax = plt.subplots(figsize = (8, 4))
ax.set_title('input angles (raw)', fontsize = 14)
ax.set_xlabel('time (frames)', fontsize = 14)
ax.set_ylabel('angle (deg)', fontsize = 14)
for j in range(cols):
    ax.plot(X[:, j], label = angle_vars[j], color = cmap(j))
handles, labels = ax.get_legend_handles_labels()
sns.despine()
plt.show()

# Legend on its own figure so it doesn't cover the traces.
fig = plt.figure()
plt.legend(handles = handles, labels = labels, ncol=2)
plt.axis('off')
plt.show()

# Only the first `cols` columns of the delay-embedded matrix (the lag-0 copy) are drawn.
fig, ax = plt.subplots(figsize = (8, 4))
ax.set_title('input angles (normalized)', fontsize = 14)
ax.set_xlabel('time (frames)', fontsize = 14)
ax.set_ylabel('amplitude (norm.)', fontsize = 14)
for j in range(cols):
    ax.plot(X_scaled[:, j], color = cmap(j))
sns.despine()
plt.show()

###### power spectra of angles

In [None]:
# Plot the power spectral density of each angle and find its peak frequency,
# for the same bout the single-bout DMD above was fit on.
bout_number = bout_num
freq_dict = get_psd(data, angle_vars, bout_number, plot = True)
freq_dict

###### eigendecomposition

In [None]:
fps = 300.0
# dmd.frequency depends on original_time['dt']; use the frame period so the axis is in Hz.
dmd.original_time['dt'] = 1/fps
frequencies = np.abs(dmd.frequency)
log_eigvals = np.log(dmd.eigs)
# markersize needs a non-negative real number; amplitudes may be complex.
marker_sizes = np.abs(dmd.amplitudes)

sns.set()
sns.set_style('ticks')
fig, ax = plt.subplots(figsize = (5, 5))
ax.set_xlabel(r'real(log($\lambda$))', fontsize = 14)
ax.set_ylabel('frequency (hz)', fontsize = 14)
# Full-height line at x = 0 (ymin/ymax are axes fractions, not data units).
ax.axvline(x = 0, linestyle = ':', color = 'k')
for log_eig, freq, size in zip(log_eigvals, frequencies, marker_sizes):
    ax.plot(log_eig.real, freq, 'o', markerfacecolor = 'w', markeredgecolor ='k', markersize = size)

sns.despine()
plt.show()

In [None]:
fps = 300.0
dmd.original_time['dt'] = 1/fps
frequencies = np.abs(dmd.frequency)
amplitudes = np.abs(dmd.amplitudes)

figure = plt.figure(figsize = (5,5))
sns.set()
sns.set_style('ticks')

plt.title('DMD power spectrum (bout {})'.format(bout_num), fontsize = 14)
plt.ylabel('mode amplitude', fontsize = 14)
plt.xlabel('frequency (hz)', fontsize = 14)
# Stem plot: a marker at (frequency, |amplitude|) plus a vertical line down to 0.
for freq, amp in zip(frequencies, amplitudes):
    plt.plot(freq, amp, 'o', markerfacecolor = 'w', markeredgecolor ='k', markersize = 12)
    plt.plot(np.full(100, freq), np.linspace(0, amp, 100), linestyle = '-', linewidth = 1, color = 'k')

sns.despine()
plt.show()

In [None]:
plot_eigenvalues(dmd)
# Show the mode amplitudes that set the marker sizes above. It must be the last
# expression to be displayed. The dt/frequency recomputation that followed was
# dead: nothing downstream reads it, and the next cell resets dt to 1.
dmd.amplitudes

In [None]:
# Mode dynamics and mode shapes, with the DMD time step set to one frame.
dmd.original_time['dt'] = 1
plot_dynamics(dmd.dynamics)
plot_modes(dmd.modes, angles = angle_vars)

###### signal reconstruction

In [None]:
# For each input angle, show the standardized signal next to its DMD reconstruction.
dmd.original_time['dt'] = 1
# Transposed back to (frames x features), the layout of X_scaled.
# NOTE(review): the reconstruction may be complex; matplotlib then plots only its real part.
X_recon = dmd.reconstructed_data.T

# Only the first X.shape[1] columns (the lag-0 copy of each angle) are plotted.
for j in range(X.shape[1]):
    fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize = (14, 3))
    ax = ax.T.flatten()
    ax[0].plot(X_scaled[:, j], 'xkcd:indigo blue')
    ax[0].set_xlabel('time (frames)', fontsize = 14)
    ax[0].set_ylabel('amplitude (norm.)', fontsize = 14)
    ax[0].set_title('{} angles (original)'.format(angle_vars[j]), fontsize = 14)
    
    ax[1].plot(X_recon[:, j], 'xkcd:wine red')
    ax[1].set_xlabel('time (frames)', fontsize = 14)
    ax[1].set_title('{} angles (reconstruction)'.format(angle_vars[j]), fontsize = 14)
    ax[1].set_ylabel('amplitude (norm.)', fontsize = 14)
    # sharey hides the right panel's y tick labels by default; turn them back on.
    ax[1].tick_params(labelleft = True)
    
    plt.subplots_adjust(wspace = 0.2)
    plt.show()

In [None]:
# Same comparison as above, with original and reconstruction overlaid on one axis.
dmd.original_time['dt'] = 1
X_recon = dmd.reconstructed_data.T
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
             and v[:2] in ['L1', 'R1']]

# Column j (lag-0 block) is labelled with angle_vars[j].
for j in range(len(angle_vars)):
    fig, ax = plt.subplots(1, 1, figsize = (8, 3))
    ax.plot(X_scaled[:, j], 'xkcd:indigo blue', label = 'original')
    ax.plot(X_recon[:, j], 'xkcd:wine red', label = 'reconstruction')
    ax.set_xlabel('time (frames)', fontsize = 14)
    ax.set_ylabel('amplitude (norm.)', fontsize = 14)
    ax.set_title('{} angles'.format(angle_vars[j]), fontsize = 14)    
    plt.subplots_adjust(wspace = 0.2)
    plt.legend(fontsize = 12)
    sns.despine()
    plt.show()

In [None]:
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
              and v[:2] in ['L1', 'R1']]
rmse_dict = reconstruction_error(dmd, angle_vars)

fig = plt.figure()
# Label with bout_num, the bout `dmd` was fit on (bout_number belongs to the PSD cell).
plt.title('RMSE of the reconstructed grooming signal (bout {})'.format(int(bout_num)), fontsize = 14)
plt.xlabel('joints', fontsize = 14)
plt.xticks(np.arange(len(angle_vars)), labels = angle_vars, rotation = 90)
plt.ylabel('RMSE', fontsize = 14)
# Stem plot: one marker per joint plus a dotted line down to 0.
for j, angle in enumerate(angle_vars):
    error = rmse_dict[angle]
    plt.plot(j, error, 'ko')
    plt.plot([j]*100, np.linspace(0, error, 100), linestyle = ':', linewidth = 1, color = 'k')

sns.despine()
plt.show()

###### RMSE as a function of SVD rank

In [None]:
# SVD-rank sweep on one bout: refit DMD at every rank and record the per-angle
# RMSE of its reconstruction.
# NOTE(review): this overwrites the globals d, X, X_scaled and dmd used by the
# single-bout cells above.
stacks = 10
stride = 7
bout_num = 299
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
             and v[:2] in ['L1', 'R1']]

check = data['behavior_bout'] == bout_num
d = data.loc[check]
X = np.array(d.loc[:, angle_vars])
X_scaled = StandardScaler().fit_transform(X)
X_scaled = time_embed(X_scaled, stacks=stacks, stride=stride)

# Ranks 1 .. n_features-1; errors[i, k] is the RMSE of angle_vars[k] at rank ranks[i].
ranks = np.arange(1, X_scaled.shape[1], 1)
errors = np.zeros([len(ranks), len(angle_vars)])
for j in range(len(ranks)):
    dmd = DMD(svd_rank=int(ranks[j]), tlsq_rank=0, exact=True, opt=True)
    dmd.fit(X_scaled.T)
    rmse_dict = reconstruction_error(dmd, angle_vars)
    # NOTE(review): `errors` is a float array; if the RMSE values are complex, the
    # imaginary part is dropped here (the warning is silenced by warnings.simplefilter).
    errors[j, :] = np.array(list(rmse_dict.values()))

In [None]:
# errors[i, :] holds the RMSE of the rank-ranks[i] fit, so plot against `ranks`
# (row 0 is rank 1, not rank 0). The sweep above was run on bout `bout_num`.
for j in range(len(angle_vars)):
    fig = plt.figure(figsize = (8, 3))
    plt.title('RMSE of reconstructed {} angles (bout {})'.format(angle_vars[j], int(bout_num)), fontsize = 14)
    plt.xlabel('svd rank', fontsize = 14)
    plt.ylabel('RMSE', fontsize = 14)
    plt.plot(ranks, errors[:, j], color='k', linewidth = 1)
    plt.scatter(ranks, errors[:, j], s = 10, facecolor = 'w', edgecolor = 'k')
    sns.despine()
    plt.show()

###### DMD on multiple grooming bouts

In [None]:
def get_all_bouts(data, angle_vars, delay_embed = False, stacks = None, stride = None):
    """Stack one-step snapshot pairs from every bout in ``data``.

    Each bout is standardized on its own and optionally delay-embedded with
    time_embed (``stacks``/``stride`` are then required). It is then split into
    consecutive-row pairs (row t -> row t+1), so pairs never cross a bout boundary.

    Returns
    -------
    all_X, all_Y : ndarray, shape (n_features, n_pairs)
        Column c of all_Y is the frame that follows column c of all_X.
    """
    inputs, targets = [], []
    for bout_id in np.unique(data.behavior_bout):
        frames = np.array(data[data.behavior_bout == bout_id].loc[:, angle_vars])
        frames = StandardScaler().fit_transform(frames)
        if delay_embed:
            frames = time_embed(frames, stacks=stacks, stride=stride)
        inputs.extend(frames[:-1])
        targets.extend(frames[1:])
    return np.array(inputs).T, np.array(targets).T

def plot_prediction_error(train_error, test_error, title_extension = '', ranks = None):
    """Plot normalized train and test DMD prediction error against SVD rank.

    Parameters
    ----------
    train_error, test_error : sequences of float
        One error per rank, e.g. from get_prediction_error.
    title_extension : str
        Appended to the title.
    ranks : sequence, optional
        SVD rank of each entry. Defaults to 1, 2, ..., len(error). The rank sweeps
        in this notebook start at 1 (np.arange(1, ...)), so entry i is rank i + 1,
        not rank i.
    """
    sns.set()
    sns.set_style('ticks')
    if ranks is None:
        train_ranks = np.arange(1, len(train_error) + 1)
        test_ranks = np.arange(1, len(test_error) + 1)
    else:
        train_ranks = test_ranks = ranks
    fig = plt.figure(figsize = (8, 4))
    plt.title('DMD prediction error {}'.format(title_extension), fontsize = 14)
    plt.xlabel('svd rank', fontsize = 14)
    plt.ylabel('normalized error', fontsize = 14)
    plt.plot(train_ranks, train_error, color = 'xkcd:indigo blue', linewidth = 1, label = 'train')
    plt.plot(test_ranks, test_error, color = 'xkcd:wine red', linewidth = 1, label = 'test')
    sns.despine()
    plt.legend(fontsize = 12)
    plt.show()
    
def get_bout_ensemble(data, angle_vars, bout_numbers, stacks, stride):
    """Build one-step-ahead snapshot pairs (X, Y) from the given bouts.

    Each bout is standardized on its own, delay-embedded with time_embed, and
    split into consecutive time steps (row t -> row t+1), so pairs never cross
    bout boundaries.

    Parameters
    ----------
    data : DataFrame with a 'behavior_bout' column and the angle_vars columns
    angle_vars : list of str
    bout_numbers : iterable of bout ids to include (in this order)
    stacks, stride : int
        Passed to time_embed.

    Returns
    -------
    all_X, all_Y : ndarray, shape (len(angle_vars) * stacks, n_pairs)
        Column c of all_Y is the time step after column c of all_X.
    """
    all_X = []
    all_Y = []
    for bout_number in bout_numbers:
        bout = data[data.behavior_bout == bout_number]
        X = np.array(bout.loc[:, angle_vars])
        X_scaled = StandardScaler().fit_transform(X)
        X_scaled = time_embed(X_scaled, stacks=stacks, stride=stride)
        # Rows of X_scaled are time steps, so shift along axis 0. Slicing [:, :-1]
        # would drop a feature column and pair each feature with its neighbour.
        all_X.extend(X_scaled[:-1])
        all_Y.extend(X_scaled[1:])
    all_X = np.array(all_X).T
    all_Y = np.array(all_Y).T

    return all_X, all_Y

def get_prediction_error(X_train, Y_train, X_test, Y_test):
    """Fit OptDMD at each SVD rank and score its one-step predictions.

    Inputs are (n_features, n_samples) snapshot matrices in which column c of Y_*
    follows column c of X_*. Ranks run from 1 to n_features - 1.

    Returns
    -------
    train_error, test_error : lists of float
        ||predict(X) - Y|| / ||Y|| for each rank, in rank order.
    """
    train_error, test_error = [], []
    for svd_rank in range(1, X_train.shape[0]):
        model = OptDMD(svd_rank=svd_rank, factorization = 'svd', opt = True)
        model.fit(X_train, Y_train)
        train_pred = model.predict(X_train)
        train_error.append(norm(train_pred - Y_train) / norm(Y_train))
        test_pred = model.predict(X_test)
        test_error.append(norm(test_pred - Y_test) / norm(Y_test))
    return train_error, test_error

def split_data(all_X, all_Y, training_fraction):
    """Split column-wise snapshot matrices into leading train / trailing test parts.

    Parameters
    ----------
    all_X, all_Y : ndarray, shape (n_features, n_samples)
        Samples are columns (the layout get_bout_ensemble returns).
    training_fraction : float
        Fraction of samples (columns) used for training.

    Returns
    -------
    X_train, Y_train, X_test, Y_test
    """
    # Split point counts samples (columns). len(all_X) would count feature rows.
    idx = int(all_X.shape[1] * training_fraction)
    X_train = all_X[:, :idx]
    Y_train = all_Y[:, :idx]
    X_test = all_X[:, idx:]
    Y_test = all_Y[:, idx:]
    return X_train, Y_train, X_test, Y_test

def plot_error_matrix(matrix, error_type = None, stack_values = None, stride_values = None):
    """Heat map of a DMD prediction-error matrix over delay-embedding settings.

    Parameters
    ----------
    matrix : 2-D array, shape (n_stacks, n_strides)
        matrix[i, j] is the error for the i-th stacks value and the j-th stride
        value. Rows are drawn on the y-axis, columns on the x-axis.
    error_type : str, optional
        'train' or 'test' (title reads '<type>ing dataset'). None means the matrix
        is a test - train difference.
    stack_values, stride_values : sequences, optional
        Tick labels for rows / columns. Default to 1..n, matching the
        np.arange(1, ...) sweeps used to fill the matrix.
    """
    matrix = np.asarray(matrix)
    n_stacks, n_strides = matrix.shape
    if stack_values is None:
        stack_values = np.arange(1, n_stacks + 1)
    if stride_values is None:
        stride_values = np.arange(1, n_strides + 1)

    fig = plt.figure()
    ax = plt.gca()
    if error_type is None:
        ax.set_title('DMD prediction error difference', fontsize = 14)
        error_type = 'test error - train'
    else:
        ax.set_title('DMD prediction error ({}ing dataset)'.format(error_type), fontsize = 14)
    ax.set_ylabel('number of stacks', fontsize = 14)
    ax.set_xlabel('number of strides', fontsize = 14)
    # imshow puts columns on x and rows on y, so x ticks follow shape[1] (strides)
    # and y ticks follow shape[0] (stacks). Swapping them mislabels non-square input.
    ax.set_xticks(np.arange(n_strides))
    ax.set_xticklabels(stride_values)
    ax.set_yticks(np.arange(n_stacks))
    ax.set_yticklabels(stack_values)
    im = ax.imshow(matrix)
    # Colour range starts at 0, or lower so negative differences aren't clipped.
    im.set_clim(vmin = min(0.0, np.amin(matrix)), vmax = np.amax(matrix))
    cbar = fig.colorbar(im, orientation = 'vertical')
    cbar.set_label('{} error (norm.)'.format(error_type), fontsize = 14)
    cbar.ax.tick_params(length = 3)
    plt.show()

###### prediction on a grooming bout (from a single long bout)

In [None]:
# Single-bout train/test: fit DMD on one bout and score one-step predictions on
# another, for every SVD rank.
training_bout = 105
testing_bout = 106
stacks = 1
stride = 2
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
             and v[:2] in ['L1', 'R1']]

# get training data (standardized within the bout; stacks = 1 -> a single, un-delayed copy)
train_mask = data['behavior_bout'] == training_bout
d_train = data.loc[train_mask]
X = np.array(d_train.loc[:, angle_vars])
X_train = StandardScaler().fit_transform(X)
X_train = time_embed(X_train, stacks=stacks, stride=stride)
# After the transpose, columns are frames: x_* holds frames 0..T-2, y_* frames 1..T-1.
x_train = X_train.T[:, :-1]
y_train = X_train.T[:, 1:]

# get testing data (standardized with its own statistics, not the training bout's)
test_mask = data['behavior_bout'] == testing_bout
d_test = data.loc[test_mask]
X = np.array(d_test.loc[:, angle_vars])
X_test = StandardScaler().fit_transform(X)
X_test = time_embed(X_test, stacks=stacks, stride=stride)
x_test = X_test.T[:, :-1]
y_test = X_test.T[:, 1:]

# Ranks 1 .. n_features-1 (n_features is the same for X_train and X_test).
ranks = np.arange(1, X_test.shape[1], 1)
train_error = []
test_error = []
for j in range(len(ranks)):
    
    dmd = DMD(svd_rank=int(ranks[j]), opt = True)
    # Fit on the whole training sequence; x_/y_train are only used for scoring.
    dmd.fit(X_train.T)
    
    # Relative (Frobenius-norm) one-step prediction error.
    y_predict = dmd.predict(x_train)
    train_error.append(norm(y_predict-y_train)/norm(y_train))
    
    y_predict = dmd.predict(x_test)
    test_error.append(norm(y_predict-y_test)/norm(y_test))


In [None]:
# Error curves for the single-bout train/test split above.
plot_prediction_error(train_error, test_error,
                      title_extension = '(bout {})'.format(testing_bout))

###### prediction on a grooming bout (from an ensemble of bouts)

In [None]:
stacks = 10
stride = 3
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
              and not some_contains(v, ['_range'])
              and v[:2] in ['L1', 'R1']]
# Set restrict_to_fly = True to use only `fly_name`'s bouts instead of the whole dataset.
fly_name = '5_0 5272019'
restrict_to_fly = False
fly_data = data[data.flyid == fly_name] if restrict_to_fly else data
bout_numbers = np.unique(fly_data.behavior_bout)

# Either a fixed train/test pair of bouts, or hold out the last two bouts.
# (The hold-out split was previously computed and then silently overwritten.)
use_fixed_bouts = True
if use_fixed_bouts:
    training_bouts = np.array([105])
    testing_bouts = np.array([106])
else:
    testing_bouts = np.array([bout_numbers[-1], bout_numbers[-2]])
    training_bouts = bout_numbers[~np.isin(bout_numbers, testing_bouts)]

all_X_train, all_Y_train = get_bout_ensemble(fly_data, angle_vars, training_bouts, stacks, stride)
all_X_test, all_Y_test = get_bout_ensemble(fly_data, angle_vars, testing_bouts, stacks, stride)
train_error, test_error = get_prediction_error(all_X_train, all_Y_train, all_X_test, all_Y_test)
plot_prediction_error(train_error, test_error)

In [None]:
# Grid search over delay embedding: mean DMD prediction error (over all SVD ranks)
# for every (stacks, stride) pair. Train on the first `training_fraction` of bouts,
# test on the rest.
fly_name = '4_0 5222019'
restrict_to_fly = False  # set True to use only fly_name's bouts
fly_data = data[data.flyid == fly_name] if restrict_to_fly else data
bout_numbers = np.unique(fly_data.behavior_bout)
training_fraction = 0.75
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
              and not some_contains(v, ['_range'])
              and v[:2] in ['L1', 'R1']]

stacks = np.arange(1, 10, 1)
strides = np.arange(1, 10, 1)
test_error_matrix = np.zeros([len(stacks), len(strides)])
train_error_matrix = np.zeros([len(stacks), len(strides)])

# Split by bout. A fixed 105/106 assignment that used to sit here was
# immediately overwritten by this split.
idx = int(training_fraction * len(bout_numbers))
training_bouts = bout_numbers[:idx]
testing_bouts = bout_numbers[idx:]

for i in range(len(stacks)):
    for j in range(len(strides)):
        X_train, Y_train = get_bout_ensemble(fly_data, angle_vars, training_bouts, stacks[i], strides[j])
        X_test, Y_test = get_bout_ensemble(fly_data, angle_vars, testing_bouts, stacks[i], strides[j])
        train_error, test_error = get_prediction_error(X_train, Y_train, X_test, Y_test)
        test_error_matrix[i, j] = np.mean(test_error)
        train_error_matrix[i, j] = np.mean(train_error)
        print('stacks = ' + str(stacks[i]) + ', strides = ' + str(strides[j]))

plot_error_matrix(train_error_matrix, error_type='train')
plot_error_matrix(test_error_matrix, error_type='test')
plot_error_matrix(test_error_matrix-train_error_matrix, error_type=None)

In [None]:
# Same matrices without the last row (largest stacks value).
plot_error_matrix(train_error_matrix[:-1, :], error_type='train')
plot_error_matrix(test_error_matrix[:-1, :], error_type='test')
# Subtract the matching train rows. train_error_matrix[-1, :] would broadcast the
# last (excluded) train row across every test row.
plot_error_matrix(test_error_matrix[:-1, :] - train_error_matrix[:-1, :], error_type=None)

###### prediction on a grooming bout (trained on 3/4 of the entire dataset)

In [None]:
# Pool snapshot pairs from every bout in the dataset, then split the pooled columns
# with split_data: the leading `training_fraction` for training, the rest for testing.
# NOTE(review): columns are ordered by bout id, so the split point can fall inside a bout.
stacks = 8
stride = 1
training_fraction = 0.75
bout_numbers = np.unique(data.behavior_bout)
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
              and not some_contains(v, ['_range'])
              and v[:2] in ['L1', 'R1']]

all_X, all_Y = get_bout_ensemble(data, angle_vars, bout_numbers, stacks, stride)
X_train, Y_train, X_test, Y_test = split_data(all_X, all_Y, training_fraction)
train_error, test_error = get_prediction_error(X_train, Y_train, X_test, Y_test) 
plot_prediction_error(train_error, test_error)

###### prediction on a grooming bout (trained on 3/4 of a fly's data)

In [None]:
# Per fly: pool that fly's bouts, train on the leading fraction of its snapshot
# pairs and test on the rest. Flies with a single bout are skipped.
stacks = 8
stride = 1
training_fraction = 0.75
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
              and not some_contains(v, ['_range'])
              and v[:2] in ['L1', 'R1']]

# Sorted, reproducible fly order. Iteration order of a set of strings changes
# between kernel sessions because of hash randomization.
fly_names = np.unique(data.flyid)
for j in range(len(fly_names)):
    fly_data = data[data.flyid == fly_names[j]]
    bout_numbers = np.unique(fly_data.behavior_bout)
    if len(bout_numbers) <= 1:
        continue
    # Select from this fly's rows only, so bouts can't pull in other flies' frames.
    all_X, all_Y = get_bout_ensemble(fly_data, angle_vars, bout_numbers, stacks, stride)
    X_train, Y_train, X_test, Y_test = split_data(all_X, all_Y, training_fraction)
    train_error, test_error = get_prediction_error(X_train, Y_train, X_test, Y_test)
    plot_prediction_error(train_error, test_error, title_extension = '(fly {}, {} bouts)'.format(fly_names[j], str(len(bout_numbers))))

###### check test error and visualize predictions

In [None]:
# Fit OptDMD at every SVD rank (train: first `training_fraction` of bouts, test:
# the rest), report train/test one-step error, and plot a window of the test
# predictions for each angle.
fly_name = '4_0 5222019'
restrict_to_fly = False  # set True to use only fly_name's bouts
fly_data = data[data.flyid == fly_name] if restrict_to_fly else data
bout_numbers = np.unique(fly_data.behavior_bout)
training_fraction = 0.75
angle_vars = [v for v in data.columns
              if some_contains(v, ['_abduct', '_rot', '_flex'])
              and not some_contains(v, ['_range'])
              and v[:2] in ['L1', 'R1']]

stacks = 6
stride = 1
# Test-sample columns shown in the prediction plots.
plot_window = slice(2300, 2700)

idx = int(training_fraction * len(bout_numbers))
training_bouts = bout_numbers[:idx]
testing_bouts = bout_numbers[idx:]

X_train, Y_train = get_bout_ensemble(fly_data, angle_vars, training_bouts, stacks, stride)
X_test, Y_test = get_bout_ensemble(fly_data, angle_vars, testing_bouts, stacks, stride)

ranks = np.arange(1, X_train.shape[0], 1)
train_error = []
test_error = []

for j in range(len(ranks)):
    dmd = OptDMD(svd_rank=int(ranks[j]), factorization = 'svd', opt = True)
    dmd.fit(X_train, Y_train)

    y_predict_train = dmd.predict(X_train)
    train_error.append(norm(y_predict_train-Y_train)/norm(Y_train))

    y_predict_test = dmd.predict(X_test)
    test_error.append(norm(y_predict_test-Y_test)/norm(Y_test))

    # One header per rank (previously repeated before every angle's figure).
    print('rank = ' + str(ranks[j]))
    for k in range(len(angle_vars)):
        fig, ax = plt.subplots(1, 1, figsize = (8, 3))
        ax.plot(Y_test[k, plot_window], 'xkcd:indigo blue', label = 'original')
        ax.plot(y_predict_test[k, plot_window], 'xkcd:wine red', label = 'prediction')
        ax.set_xlabel('time (frames)', fontsize = 14)
        ax.set_ylabel('amplitude (norm.)', fontsize = 14)
        ax.set_title('{} angles'.format(angle_vars[k]), fontsize = 14)
        ax.legend(fontsize = 12)
        sns.despine()
        plt.show()

# Test error is measured on the held-out bouts, not on a single bout.
plot_prediction_error(train_error, test_error, title_extension = '(held-out bouts)')


In [None]:
# Show the L1/R1 angle columns used throughout and their values in `data`.
angle_vars = [v for v in data.columns
              if some_contains(v, ['_flex', '_abduct', '_rot'])
             and v[:2] in ['L1', 'R1']]
print(angle_vars)
data.loc[:, angle_vars]