# Outline

What happens in this notebook:

- Read in tracking results (or angles directly)
- Transform tracking results to fish-centric angles
- Apply PCA to angle data
- Inspect features of PCA (e.g. eigenmodes, variance explained, mode distributions)
- Animate fish representation from mode values using matplotlib widgets
- Import and inspect different data

# Imports

In [None]:
import numpy as np
np.set_printoptions(suppress=True)
import h5py
import pandas as pd
import math
import os
import cv2
import base64
import warnings
import joypy
from scipy.interpolate import interp1d
import pickle
import glob
from joblib import dump, load
from pathlib import Path
import cv2
import pykalman
import sklearn
import scipy.io as sio
from multiprocessing import Pool, RawArray
from sklearn.linear_model import LinearRegression
from scipy.optimize import linear_sum_assignment
from scipy.signal import savgol_filter as sgf
import scipy.ndimage
from scipy.interpolate import interp1d
from scipy import interpolate
import time
%load_ext autoreload
import os
%autoreload 2
import sys

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import seaborn as sns
#plt.style.use('seaborn-white')
%matplotlib widget
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import display
from ipywidgets import Video, Image, VBox, Text
from sidecar import Sidecar
from IPython.display import display
from ipywidgets import interact, interactive, fixed, interact_manual, FloatSlider, Layout, AppLayout
import ipywidgets as widgets
import sklearn
from pathlib import Path

import csv
import scipy.stats as stats

# Functions to transform raw data into angles

In [None]:
def calc_angle_xpos(vector):
    '''
    Return the angle of a 2D vector measured from the positive x-axis, in degrees.

    Only vector[0] (x) and vector[1] (y) are read, so longer vectors are accepted.
    The result lies in [-180, 180] (math.atan2 convention).
    '''
    dx, dy = vector[0], vector[1]
    radians = math.atan2(dy, dx)
    return radians * 180 / math.pi

def positions_to_angles_spine(positions):
    '''
    Convert 2D spine point coordinates into cumulative bend-angle profiles.

    positions = [xs, ys]; xs[i] and ys[i] are the x and y coordinates of the
                spine points in frame i
    returns   = list with one 1D array per frame: the running sum of the
                turning angles (degrees) at the interior points, starting at 0

    The y differences are sign-flipped before taking angles (presumably
    because y grows downwards in image coordinates).
    '''
    profiles = []
    for frame_idx, xs in enumerate(positions[0]):
        ys = positions[1][frame_idx]
        turns = [0]
        for j in range(1, len(xs) - 1):
            # heading of the vector to the previous point and to the next point
            heading_back = math.atan2(-ys[j - 1] + ys[j], xs[j - 1] - xs[j]) * 180 / math.pi
            heading_fwd = math.atan2(-ys[j + 1] + ys[j], xs[j + 1] - xs[j]) * 180 / math.pi
            delta = heading_fwd - heading_back
            # a straight spine (delta == +/-180) maps to a turn of 0
            turns.append(delta + 180 if delta < 0 else delta - 180)
        profiles.append(np.cumsum(turns))
    return profiles

def remove_average(angles_all):
    '''
    Subtract each row's mean from that row.

    angles_all = sequence of equal-length angle profiles (e.g. one per frame)
    returns    = 2D float array of the same shape with every row mean-centred

    The mean is computed once per row rather than once per element.
    NaNs are not skipped: a row containing NaN becomes all-NaN (np.mean).
    '''
    return np.array([np.asarray(angles) - np.mean(angles) for angles in angles_all])


# functions for 3d
def angles_spherical(v):
    '''
    Return [phi, theta] in degrees for a 3D vector v = (x, y, z).

    phi   = angle of the (x, y) projection from the positive x-axis
            (math.atan2 convention, [-180, 180])
    theta = atan(sqrt(x**2 + y**2) / z); atan (not atan2) is used, so theta is
            in [-90, 90], non-negative for z > 0 and non-positive for z < 0.
            The division happens on a numpy float, so z == 0 yields +/-inf
            (theta = +/-90) or nan with a RuntimeWarning instead of raising.
    '''
    phi = math.atan2(v[1], v[0]) * 180 / math.pi
    horizontal_len = np.sqrt(v[0] * v[0] + v[1] * v[1])
    theta = math.atan(horizontal_len / (v[2])) * 180 / math.pi
    return [phi, theta]

def spherical_coords(data):
    '''
    Convert one 3D spine into two mean-centred cumulative angle profiles.

    data    = sequence of 3D points, data[j][0..2] = x, y, z of point j
    returns = [phi_profile, theta_profile], each the running sum of per-point
              turns (starting at 0, len(data) - 1 entries for len(data) >= 2)
              with its nan-mean subtracted

    For every interior point j, the vectors to the previous and next point
    (y component sign-flipped, as in the 2D code) are converted with
    angles_spherical. The phi turn is the wrapped azimuth difference (same
    mapping as positions_to_angles_spine); the theta turn is built from the
    theta of both vectors through the sign-dependent branches below.
    '''
    phi_turns = [0]
    theta_turns = [0]
    for j in range(1, len(data) - 1):
        prev_pt, here, next_pt = data[j - 1], data[j], data[j + 1]
        to_prev = [(prev_pt[0] - here[0]), (-prev_pt[1] + here[1]), (prev_pt[2] - here[2])]
        to_next = [(next_pt[0] - here[0]), (-next_pt[1] + here[1]), (next_pt[2] - here[2])]

        phi_next, theta_next = angles_spherical(to_next)
        phi_prev, theta_prev = angles_spherical(to_prev)

        # azimuthal turn; a straight continuation (delta == +/-180) maps to 0
        delta = phi_next - phi_prev
        phi_turns.append(180 + delta if delta < 0 else delta - 180)

        # elevation turn (NaN thetas fall through to the else-branches)
        start = 90 - theta_next if theta_next > 0 else -(90 + theta_next)
        if theta_prev > 0:
            theta_turns.append(start + 90 - theta_prev)
        else:
            theta_turns.append(start - (90 + theta_prev))

    phi_cum = np.cumsum(phi_turns)
    theta_cum = np.cumsum(theta_turns)
    return [phi_cum - np.nanmean(phi_cum), theta_cum - np.nanmean(theta_cum)]

def positions_to_angles_spine_3d(positions):
    '''
    Convert a series of 3D spines into a 2D array of angle rows.

    positions = iterable whose items hold the 3D spine points at index [0]
    returns   = array with one row per item: the phi profile followed by the
                theta profile returned by spherical_coords
    '''
    rows = []
    for frame in positions:
        phi_profile, theta_profile = spherical_coords(frame[0])
        rows.append([*phi_profile, *theta_profile])
    return np.array(rows)

# PCA class

In [None]:
class PCA_angles(object):
    '''
    PCA ("eigenfish") decomposition of a time series of body angles.

    angles_series = array-like with one row per frame and one column per angle
                    component, shape (n_frames, n_angles). NaNs are masked in
                    the covariance and ignored in the mean.
    dimension     = 2 or 3; within this class it only selects how
                    convert_rep_to_plotable_positions rebuilds a body:
                    2 -> each entry of a rep is a segment heading in degrees,
                    3 -> the first half of a rep holds the phi angles and the
                         second half the theta angles (the column layout that
                         positions_to_angles_spine_3d produces).

    Attributes set on construction:
    covmat             = (n_angles, n_angles) covariance matrix (masked array)
    mean_rep           = mean configuration, one value per angle component
    eigvals, eigvecs   = eigenvalues sorted high to low; eigvecs[:, k] is mode k
    cum_var_explained  = cumulative fraction of variance explained by modes 1..k
    pca_tseries        = (n_frames, numProjectionModes) mode weights per frame
    numProjectionModes = number of modes used when projecting (= n_angles)
    data_order         = row labels of the angle components (range(n_angles))
    '''

    def __init__(self, angles_series, dimension=2):

        # parse inputs
        self.dimension = dimension
        # A float ndarray, so list inputs (e.g. frames concatenated with
        # [*a, *b]) go through exactly the same arithmetic as array inputs.
        self.angles_series = np.asarray(angles_series, dtype=float)
        self.covmat = self._compute_covariance_matrix_of_rep()

        # Both are per angle component, not per frame: there is one mode per
        # angle, and data_order labels the rows of eigvecs / mean_rep.
        n_angles = self.covmat.shape[0]
        self.numProjectionModes = n_angles
        self.data_order = range(n_angles)

        self.mean_rep = self._compute_mean_rep()
        self.eigvals, self.eigvecs = self._eigendecompose_covariance_matrix()
        self.cum_var_explained = self._compute_variance_explained()
        self.pca_tseries = self._project_rep_onto_numProjectionModes()

    def convert_mode_weights_to_rep(self, mode_weights):
        ''' Convert mode weights back to an angle representation.

        mode_weights = weights of the first len(mode_weights) modes (the
                       remaining modes are treated as zero)
        returns      = mean_rep + sum_k mode_weights[k] * eigvecs[:, k]
        '''
        mode_weights = np.asarray(mode_weights)
        return mode_weights.dot(self.eigvecs[:, :len(mode_weights)].T) + self.mean_rep

    def convert_rep_to_mode_weights(self, single_frame_rep):
        ''' Project one angle representation (minus mean_rep) onto the
            first numProjectionModes modes.
        '''
        numModes = self.numProjectionModes
        centred = np.asarray(single_frame_rep) - self.mean_rep
        return np.dot(centred, self.eigvecs[:, :numModes])

    def convert_rep_to_plotable_positions(self, single_frame_rep):
        ''' Rebuild body-point positions from an angle representation.

        Starting at the origin, each angle (2D) or phi/theta pair (3D) adds
        one segment of length 2.
        2D: returns a list of [x, y] points (len(rep) + 1 of them).
        3D: returns an array of shape (len(rep) // 2 + 1, 3).
        Raises ValueError for any other dimension.
        '''
        length = 2

        if self.dimension == 2:
            point = [0, 0]
            frame_positions = [point]
            for angle in single_frame_rep:
                point = [point[0] + length * np.cos(angle/180*np.pi),
                         point[1] + length * np.sin(angle/180*np.pi)]
                frame_positions.append(point)
            return frame_positions

        if self.dimension == 3:
            half = len(single_frame_rep) // 2
            point = [0, 0, 0]
            frame_positions = [point]
            for i in range(half):
                phi = single_frame_rep[i]
                theta = single_frame_rep[i + half]
                point = [point[0] + length * np.cos(phi/180*np.pi) * np.sin((theta+90)/180*np.pi),
                         point[1] + length * np.sin(phi/180*np.pi) * np.sin((theta+90)/180*np.pi),
                         point[2] + length * np.cos((90-theta)/180*np.pi)]
                frame_positions.append(point)
            return np.vstack(frame_positions)

        raise ValueError('dimension must be 2 or 3, got {!r}'.format(self.dimension))

    def _compute_covariance_matrix_of_rep(self):
        ''' Compute the covariance matrix of the angle components
            (rows = frames), masking NaN/inf entries.
        '''
        return np.ma.cov(np.ma.masked_invalid(self.angles_series), rowvar=False)

    def _compute_mean_rep(self):
        ''' Compute the mean configuration, ignoring NaNs.
        '''
        return np.nanmean(self.angles_series, axis=0)

    def _eigendecompose_covariance_matrix(self):
        ''' Compute the eigenspectrum, with eigenvectors (columns) sorted by
            eigenvalue from highest to lowest.
        '''
        # The covariance matrix is symmetric, so eigh is the appropriate
        # solver: real eigenvalues and orthonormal eigenvectors are
        # guaranteed (eig may return complex values). getdata strips the
        # masked-array wrapper so both outputs are plain ndarrays.
        eig_vals, eig_vecs = np.linalg.eigh(np.ma.getdata(self.covmat))
        # eigh returns ascending order; flip to highest-to-lowest
        idxs = np.argsort(eig_vals)[::-1]
        return eig_vals[idxs], eig_vecs[:, idxs]

    def _compute_variance_explained(self):
        ''' Return the cumulative fraction of variance explained by successive modes.
        '''
        variance_explained = self.eigvals / np.sum(self.eigvals)
        return np.cumsum(variance_explained)

    def _project_rep_onto_numProjectionModes(self):
        ''' Project every frame (minus mean_rep) onto the first
            self.numProjectionModes modes.
        '''
        projection_eig_vecs = self.eigvecs[:, :self.numProjectionModes]
        mean_subtracted_tseries = self.angles_series - self.mean_rep
        return np.dot(mean_subtracted_tseries, projection_eig_vecs)

# Functions to plot and print PCA results

In [None]:
def variance_explained(handlers, pc_max):
    '''
    Bar-plot the cumulative variance explained per principal component for
    one or more PCA results, and print each cum_var_explained array.

    handlers = PCA_angles object, or a list of them (one bar series each,
               labelled by list index; a single handler also gets a red line)
    pc_max   = number of principal components labelled on the x-axis

    Raises ValueError if handlers is an empty list.
    '''
    # make code work for single and multiple (list) inputs
    if not isinstance(handlers, list):
        handlers = [handlers]
    if not handlers:
        raise ValueError('variance_explained needs at least one PCA_angles object')

    # tick 0 is left blank so that PC k sits at x = k; PC1..PC<pc_max> are labelled
    x_text_vals = [' '] + ['PC' + str(i) for i in range(1, pc_max + 1)]
    x_end = len(x_text_vals)

    # make the plot
    fig, ax = plt.subplots(1, 1, figsize=(8, 4))

    # bar width so that all handlers' bars fit side by side around each PC
    w = 0.9 / len(handlers)

    for i, handler in enumerate(handlers):
        cum_var = handler.cum_var_explained
        print(cum_var)
        # per-handler x values (non zero indexed), so handlers with different
        # numbers of components can share one plot
        x_vals = np.arange(1, len(cum_var) + 1)
        if len(handlers) == 1:
            ax.plot(x_vals, cum_var, color='red', linewidth=2)
        offset = (i - (len(handlers) - 1) / 2) * w
        ax.bar(x_vals + offset, cum_var, width=w, label=i)

    ax.set_ylabel('Variance Explained')
    ax.set_xlabel('Principal Components')
    ax.set_xlim(0, x_end)
    ax.set_xticks(np.arange(x_end))
    ax.set_xticklabels(x_text_vals, rotation='vertical', fontsize=10)
    ax.tick_params(length=10, width=1)

    # reference line at 100% spanning the whole visible x range
    ax.hlines(1, 0, x_end, colors='black', linestyles='dashed', linewidth=2)
    ax.legend(framealpha=1)
    fig.tight_layout()
    fig.show()


def plot_eigenfish_2d(handlers, num_modes):
    '''
    Plot the first num_modes eigenvectors ("eigenfish") of one or more PCA results.

    handlers  = PCA_angles object, or a list of them (one line per handler,
                labelled by its list index)
    num_modes = number of modes to plot, on a two-column grid (one column
                when num_modes == 1); eigenvector entries are rounded to 3 decimals
    '''
    # make code work for single and multiple (list) inputs
    if not isinstance(handlers, list):
        handlers = [handlers]

    # create figure
    colcount = 2 if num_modes > 1 else 1
    # ceil, not round: round(5 / 2) == 2 would leave no axis for the 5th mode
    rowcount = math.ceil(num_modes / colcount)
    # squeeze=False keeps `ax` 2D for every grid size, including a single axis
    fig, ax = plt.subplots(nrows=rowcount, ncols=colcount, squeeze=False)

    for i in range(num_modes):
        obj = ax[i // colcount, i % colcount]
        for j, k in enumerate(handlers):
            # column i of eigvecs is mode i
            obj.plot(np.round(k.eigvecs[:, i], 3), linewidth=2, label=j)
        obj.set_title("Eigenfish " + str(i+1))
        obj.set_xlabel('S')
        obj.set_ylabel('Angle [degrees]')
        obj.set_ylim(-1, 1)
        obj.legend()

    # hide the spare axis left over when num_modes is odd
    for spare in ax.flat[num_modes:]:
        spare.set_visible(False)

    fig.tight_layout()
    fig.show()


def plot_mode_dist(handler, num_modes):
    '''
    Plot normalised histograms of the first num_modes mode weights of one PCA result.

    handler   = PCA_angles object
    num_modes = number of modes to plot, on a two-column grid (one column
                when num_modes == 1)

    NaN weights are dropped; bins are 0.5 wide and span the data range.
    Modes with no finite weights get an empty axis.
    '''
    # create figure
    colcount = 2 if num_modes > 1 else 1
    # ceil, not round: round(5 / 2) == 2 would leave no axis for the 5th mode
    rowcount = math.ceil(num_modes / colcount)
    # squeeze=False keeps `ax` 2D for every grid size, including a single axis
    fig, ax = plt.subplots(nrows=rowcount, ncols=colcount, squeeze=False)

    pca_tseries = handler.pca_tseries

    for i in range(num_modes):
        obj = ax[i // colcount, i % colcount]
        column = pca_tseries[:, i]
        pc = column[~np.isnan(column)]
        obj.set_title("mode " + str(i + 1))
        if pc.size == 0:
            continue
        obj.hist(pc, bins=np.arange(np.min(pc) - 0.1, np.max(pc) + 0.1, 0.5), density=True)

    # hide the spare axis left over when num_modes is odd
    for spare in ax.flat[num_modes:]:
        spare.set_visible(False)

    fig.tight_layout()
    fig.show()


def data_stuff(handler):
    '''
    Print the eigenvector matrix and the mean configuration of one PCA result.

    handler = PCA_angles object (reads eigvecs, mean_rep and, if it has one
              entry per angle component, data_order as row labels)

    Both tables are rounded to 3 decimals. Returns None.
    '''
    n_angles = handler.eigvecs.shape[0]
    # Row labels need one entry per angle component; fall back to
    # 0..n_angles-1 when data_order is missing or has another length
    # (e.g. one entry per frame).
    row_labels = getattr(handler, 'data_order', None)
    if row_labels is None or len(row_labels) != n_angles:
        row_labels = range(n_angles)

    # The eigenvectors, one column per mode
    eigvec_df = pd.DataFrame(data=np.round(handler.eigvecs, 3),
                             index=row_labels,
                             columns=['mode {0}'.format(i+1) for i in range(handler.eigvecs.shape[1])])

    print(eigvec_df)

    # the mean configuration
    mean_rep_df = pd.DataFrame(data=np.round(handler.mean_rep, 3),
                               index=row_labels,
                               columns=['mean representation'])

    print(mean_rep_df)


def pca_histdata(coord_handler, data_range, all_data=False, n_modes=2):
    '''
    Return the weights of the first n_modes modes over a range of frames.

    coord_handler = PCA_angles object
    data_range    = [start, stop], applied as the frame slice start:stop
                    (stop=None includes the last frame; stop=-1 drops it)
    all_data      = if True keep NaN weights, otherwise drop NaNs per mode
    n_modes       = number of modes to return (default 2)

    returns       = list of n_modes 1D arrays, one per mode; they are copies,
                    so modifying them does not touch coord_handler
    '''
    window = np.array(coord_handler.pca_tseries[data_range[0]:data_range[1]])
    modes = []
    for mode_idx in range(n_modes):
        weights = window[:, mode_idx]
        if not all_data:
            weights = weights[~np.isnan(weights)]
        modes.append(weights)
    return modes

# Load in data

### angles from h5

#### 2d

In [None]:
# Pre-computed single-fish angle series, stored under the 'angles20' key.
with h5py.File('./singlefish_angles20.h5', mode='r') as f:
    singlefish_angles = np.array(f['angles20'])

In [None]:
# Pre-computed angle series of the two tracked fish, one dataset per fish.
with h5py.File('./twofish_angles20.h5', mode='r') as f:
    fish1_angles, fish2_angles = np.array(f['fish1']), np.array(f['fish2'])

#### 3d

In [None]:
file_path2 = './angles3d.h5'

# Pre-computed 3D angle rows, stored under the "angles" key.
with h5py.File(file_path2, mode='r') as f:
    angles_3D = np.array(f["angles"])

### Import spine data and convert to angles

#### 2d

In [None]:
exp_path = './Project/LargeCrop/Code/two_fish_skeleton_data/spine_output/'

# every .h5 spine file in the folder, processed in sorted filename order
file_paths = sorted(name for name in os.listdir(exp_path) if name.endswith(".h5"))
t0 = time.time()

fish1_angles = []
fish2_angles = []

for i in file_paths:
    with h5py.File(os.path.join(exp_path, i), 'r') as f:
        pos_fish1 = [np.copy(f['x1']), np.copy(f['y1'])]
        pos_fish2 = [np.copy(f['x2']), np.copy(f['y2'])]
    # extend in place: rebuilding the list as [*old, *new] for every file is
    # quadratic in the total number of frames
    fish1_angles.extend(positions_to_angles_spine(pos_fish1))
    fish2_angles.extend(positions_to_angles_spine(pos_fish2))

# mean-centre each frame's angle profile
fish1_angles = remove_average(fish1_angles)
fish2_angles = remove_average(fish2_angles)

print(time.time() - t0)

#### 3d

In [None]:
file_path = './spinedata3D_new.h5'

with h5py.File(file_path, 'r') as f:
    spinedata3D = np.copy(f['spinedata3D'])

# Convert the array loaded above (there is no `data_3D` in this notebook).
# positions_to_angles_spine_3d takes item[0] of every element as the list of
# 3D spine points.
# NOTE(review): the leading [0] index is kept from the original call - confirm
# it matches the layout of the 'spinedata3D' dataset.
angles_3D = positions_to_angles_spine_3d(spinedata3D[0])

### Example of saving angle data as h5 file

In [None]:
# Write both fishes' angle series under the keys the two-fish loading cell reads.
with h5py.File('./twofish_angles20.h5', mode='w') as f:
    for key, values in (('fish1', fish1_angles), ('fish2', fish2_angles)):
        f.create_dataset(key, data=values)

# Create PCA objects

In [None]:
# Shows how it is possible to easily take subsets when desired.
# Frames [fight_start, fight_end) form the "fight" subset; all remaining
# frames form the "nofight" subset. Change the window in one place here.
fight_start, fight_end = 24000, 234000

coord_1 = PCA_angles(fish1_angles)
coord_1_fight = PCA_angles(fish1_angles[fight_start:fight_end])
coord_1_nofight = PCA_angles([*fish1_angles[:fight_start], *fish1_angles[fight_end:]])
coord_2 = PCA_angles(fish2_angles)
coord_2_fight = PCA_angles(fish2_angles[fight_start:fight_end])
coord_2_nofight = PCA_angles([*fish2_angles[:fight_start], *fish2_angles[fight_end:]])
coord_single = PCA_angles(singlefish_angles)

In [None]:
coord_3d = PCA_angles(angles_3D, dimension=3)  # dimension=3: rows hold 3D (phi/theta) angles

# The playground!

In [None]:
variance_explained(handlers=[coord_single, coord_1, coord_2], pc_max=5)

In [None]:
plot_eigenfish_2d(handlers=[coord_single, coord_1, coord_2], num_modes=4)

In [None]:
plot_mode_dist(handler=coord_single, num_modes=4)

In [None]:
data_stuff(handler=coord_single)

# Some code I used to investigate modes
## heatmap/scatter/moving hist

In [None]:
# Labels and PCA objects in matching order (names[k] describes handlers[k]);
# choices presumably lists the indices to pick from - it is not used below.
names = ['fish1','fish2','singlefish','fish1_fight','fish2_fight', 'fish1_nofight', 'fish2_nofight']
handlers = [coord_1, coord_2, coord_single, coord_1_fight, coord_2_fight, coord_1_nofight, coord_2_nofight]
choices = list(range(len(names)))

In [None]:
# Mode 1 vs mode 2 weights during the fight: handlers[3] is fish1_fight and
# handlers[4] is fish2_fight. [0, None] keeps every frame ([0, -1] would
# silently drop the last one).
f1 = pca_histdata(handlers[3], [0, None])
f2 = pca_histdata(handlers[4], [0, None])

heatmap, xedges, yedges = np.histogram2d(f1[0], f1[1], range=[[-400, 400], [-300, 300]], bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
plt.figure()
# histogram2d puts x along the rows, hence the transpose; origin='lower'
# draws the first (lowest-y) bin at the bottom so the image matches `extent`
plt.imshow(heatmap.T, extent=extent, origin='lower', cmap='coolwarm')
plt.colorbar()
plt.show()

In [None]:
colors = ['red', 'blue', 'green']

plt.figure()

# mode 1 vs mode 2 weights of handlers[3] (fish1_fight) and handlers[4] (fish2_fight)
for i in [0, 1]:
    # [0, None] keeps every frame ([0, -1] would silently drop the last one)
    mode_series = pca_histdata(handlers[i + 3], [0, None])
    plt.scatter(mode_series[0], mode_series[1], c=colors[i], alpha=0.5, label=i)

plt.axvline(x=0, c='black', linewidth=0.5)
plt.axhline(y=0, c='black', linewidth=0.5)
plt.legend()
plt.show()

In [None]:
#histogram movie
# One video frame per 1000-frame step: a normalised step histogram of the
# mode-2 weights (data[1]) over a window of `window_len` frames starting at t.
# NOTE(review): `absolutespeed` is not defined in this part of the notebook;
# presumably a per-frame speed array aligned with the PCA time series - confirm.

Writer = animation.writers['ffmpeg']
writer = Writer(fps=10, metadata=dict(artist='LOS'), bitrate=3500)
fig = plt.figure()

colors = ['blue', 'red']
window_len = 30000  # number of frames per histogram window
i = 2  # index into names / handlers: 2 is 'singlefish'

with writer.saving(fig, './2fish_images/histogram_timechange_mode2_single_speed.mp4', 100):

    for t in np.arange(0, 500000, 1000):
        fig.clf()

        plt.title('Mode 2 [{}:{}], speed = {}'.format(t, t + window_len, np.nanmean(absolutespeed[t:t + window_len])))
        data = pca_histdata(handlers[i], [t, t + window_len])
        plt.hist(data[1], bins=200, histtype=u'step', density=True, label=names[i], linewidth=2)
        plt.axvline(x=0, c='black')
        plt.xlim(-100, 100)
        plt.ylim(0, 0.05)
        plt.legend()

        # add frame to video
        writer.grab_frame()
        

# Widgets: interactive fish representation driven by PCA mode weights
### (sorry for it being messy)

In [None]:
# choose one PCA_angles object to use here
# (the widget cells below branch on coord_handler.dimension: 2 or anything else -> 3D)
coord_handler = coord_3d

In [None]:
# check if we're dealing with 2d or 3d
# Builds the display widgets (as notebook globals) for the mean configuration
# of coord_handler:
#   pca_widget_list          - one FloatText per mode (mode weights of mean_rep)
#   rep_widget_list          - one FloatText per angle component (mean_rep)
#   plottable_3D_widget_list - 2D: the raw [x, y] body positions (numbers,
#                              not widgets); 3D: one FloatText per body point
#                              and coordinate
#   mean_rep_modes           - mode weights of mean_rep
#   data_widget              - VBox combining positions, mode weights and rep
if coord_handler.dimension == 2:

    ### Prepare the list of widgets ###

    # number of angle components (= number of modes)
    prmds = coord_handler.covmat.shape[0]
    # layout parameters
    number_layout = widgets.Layout(width='70px')
    text_layout = widgets.Layout(width='120px')

    # preallocate list structures to hold the numbers
    pca_widget_list = [ [] for i in range(prmds)]
    rep_widget_list = [ [] for i in range(prmds)]

    # get the plottable positions for the frame
    # NOTE: in 2D this holds plain [x, y] positions (prmds + 1 points), not
    # FloatText widgets; the position boxes below are built from a snapshot
    plottable_3D_widget_list = coord_handler.convert_rep_to_plotable_positions(coord_handler.mean_rep)

    # grab the pca values
    mean_rep_modes = coord_handler.convert_rep_to_mode_weights(coord_handler.mean_rep)
    for pcIdx in range(prmds):
        wid_val = np.round(np.copy(mean_rep_modes[pcIdx]), 2)
        val_widge_3D = widgets.FloatText(value=wid_val,layout=number_layout)
        pca_widget_list[pcIdx] = val_widge_3D

    # grab the rep values
    for pcIdx in range(prmds):
        wid_val = np.round(np.copy(coord_handler.mean_rep[pcIdx]), 2)
        val_widge_3D = widgets.FloatText(value=wid_val,layout=number_layout)
        rep_widget_list[pcIdx] = val_widge_3D


    ### Make widgets to visualize the plottable positions

    # layout parameters (only affect widgets created from here on)
    number_layout = widgets.Layout(width='80px')
    text_layout = widgets.Layout(width='120px')

    # widget lists (placeholders; reassigned below)
    f_bp_XYZ_boxes = [None]
    f_3D_wid_titles = [None]
    f_3D_wids = [None]

    # one HBox of [x, y] FloatTexts per body point (prmds angles -> prmds + 1 points)
    f_bp_box_list = []
    for bpIdx in range(prmds+1):
        f_bp_XYZ_boxes = widgets.HBox([widgets.FloatText(value=np.round(i,2),layout=number_layout) for i in plottable_3D_widget_list[bpIdx]], layout=widgets.Layout(padding=('0px 30px 0 0')))
        f_bp_box_list.append(f_bp_XYZ_boxes)
    f_3D_wid_titles = widgets.Text('Fish FC bps', layout=text_layout)
    f_3D_wids = widgets.VBox([f_3D_wid_titles]+f_bp_box_list)

    # Final widget
    plottable_frame_widget = f_3D_wids


    ### Make widgets to visualize the pca_data

    # layout parameters
    number_layout = widgets.Layout(width='80px')
    text_layout = widgets.Layout(width='120px')
    wid_title = widgets.Text('mode weights', layout=text_layout)

    # get the values of the widget in a hbox
    mode_weight_vals_wid = widgets.HBox(pca_widget_list, layout=widgets.Layout(padding=('0px 0px 00px 000px')))

    # make the final widget with title
    mode_frame_widget = widgets.VBox([wid_title, mode_weight_vals_wid], layout=widgets.Layout(padding=('0px 0px 30px 200px')))


    ### Make widgets to visualize the rep_data

    # layout parameters
    number_layout = widgets.Layout(width='80px')
    text_layout = widgets.Layout(width='120px')
    # NOTE(review): the '14D rep' title is hard-coded; the rep has prmds entries
    wid_title = widgets.Text('14D rep', layout=text_layout)

    # get the values of the widget in a hbox
    rep_vals_wid = widgets.HBox(rep_widget_list, layout=widgets.Layout(padding=('0px 0px 00px 000px')))

    # make the final widget with title
    rep_frame_widget = widgets.VBox([wid_title, rep_vals_wid], layout=widgets.Layout(padding=('0px 0px 30px 200px')))

    # final widget


    # NOTE(review): "flex-start" is not a CSS `display` value (it belongs to
    # align/justify properties), so this layout setting is probably a no-op
    data_widget = widgets.VBox([plottable_frame_widget, mode_frame_widget, rep_frame_widget],
                                        layout=widgets.Layout(display="flex-start"))

    #data_widget
else:
    ### Prepare the list of widgets ###
    # number of angle components: PRMODES // 2 phi angles followed by PRMODES // 2 theta angles
    PRMODES = coord_handler.covmat.shape[0]
    # layout parameters
    number_layout = widgets.Layout(width='70px')
    text_layout = widgets.Layout(width='120px')

    # preallocate list structures to hold the numbers
    # one [x, y, z] slot triple per body point: PRMODES // 2 segments + the origin
    plottable_3D_widget_list = [[[] for _ in range(3)] for _ in range(PRMODES//2+1)]
    pca_widget_list = [ [] for i in range(PRMODES)]
    rep_widget_list = [ [] for i in range(PRMODES)]

    # get the plottable positions for the frame
    # plottable_3D_widget_list = coord_handler.convert_rep_to_plotable_positions(coord_handler.mean_rep)
    plottable_positions = coord_handler.convert_rep_to_plotable_positions(coord_handler.mean_rep)

    # grab the plottable values (here the list holds FloatText widgets)
    for bpIdx in range(PRMODES//2+1):
        for dimIdx in range(3):
            # plottable data
            wid_val = np.round(np.copy(plottable_positions[bpIdx][dimIdx]), 2)
            val_widge_3D = widgets.FloatText(value=wid_val,layout=number_layout)
            plottable_3D_widget_list[bpIdx][dimIdx] = val_widge_3D

    # grab the pca values
    mean_rep_modes = coord_handler.convert_rep_to_mode_weights(coord_handler.mean_rep)
    for pcIdx in range(PRMODES):
        wid_val = np.round(np.copy(mean_rep_modes[pcIdx]), 2)
        val_widge_3D = widgets.FloatText(value=wid_val,layout=number_layout)
        pca_widget_list[pcIdx] = val_widge_3D

    # grab the rep values
    for pcIdx in range(PRMODES):
        wid_val = np.round(np.copy(coord_handler.mean_rep[pcIdx]), 2)
        val_widge_3D = widgets.FloatText(value=wid_val,layout=number_layout)
        rep_widget_list[pcIdx] = val_widge_3D
        
    ### Make widgets to visualize the plottable positions

    # layout parameters (only affect widgets created from here on)
    number_layout = widgets.Layout(width='80px')
    text_layout = widgets.Layout(width='120px')

    # widget lists (placeholders; reassigned below)
    f_bp_XYZ_boxes = [None]
    f_3D_wid_titles = [None]
    f_3D_wids = [None]

    # one HBox of the x/y/z FloatTexts per body point
    f_bp_box_list = []
    for bpIdx in range(PRMODES//2+1):
        f_bp_XYZ_boxes = widgets.HBox(plottable_3D_widget_list[bpIdx][:], layout=widgets.Layout(padding=('0px 30px 0 0')))
        f_bp_box_list.append(f_bp_XYZ_boxes)
    f_3D_wid_titles = widgets.Text('Fish FC bps', layout=text_layout)
    f_3D_wids = widgets.VBox([f_3D_wid_titles]+f_bp_box_list)

    # Final widget
    plottable_frame_widget = f_3D_wids
    
    ### Make widgets to visualize the pca_data

    # layout parameters
    number_layout = widgets.Layout(width='80px')
    text_layout = widgets.Layout(width='120px')
    wid_title = widgets.Text('mode weights', layout=text_layout)

    # get the values of the widget in a hbox
    mode_weight_vals_wid = widgets.HBox(pca_widget_list, layout=widgets.Layout(padding=('0px 0px 00px 000px')))

    # make the final widget with title
    mode_frame_widget = widgets.VBox([wid_title, mode_weight_vals_wid], layout=widgets.Layout(padding=('0px 0px 30px 200px')))


    ### Make widgets to visualize the rep_data

    # layout parameters
    number_layout = widgets.Layout(width='80px')
    text_layout = widgets.Layout(width='120px')
    # NOTE(review): the '14D rep' title is hard-coded; the rep has PRMODES entries
    wid_title = widgets.Text('14D rep', layout=text_layout)

    # get the values of the widget in a hbox
    rep_vals_wid = widgets.HBox(rep_widget_list, layout=widgets.Layout(padding=('0px 0px 00px 000px')))

    # make the final widget with title
    rep_frame_widget = widgets.VBox([wid_title, rep_vals_wid], layout=widgets.Layout(padding=('0px 0px 30px 200px')))
    
    # final widget


    # NOTE(review): "flex-start" is not a CSS `display` value (it belongs to
    # align/justify properties), so this layout setting is probably a no-op
    data_widget = widgets.VBox([plottable_frame_widget, mode_frame_widget, rep_frame_widget],
                                        layout=widgets.Layout(display="flex-start"))

    #data_widget

In [None]:
# check if we're dealing with 2d or 3d
if coord_handler.dimension == 2:

    #  ----------------- make a 2D plot of the positions ----------------------- #

    plt.ioff()


    rep_data = np.copy(coord_handler.mean_rep)
    pca_data = coord_handler.convert_rep_to_mode_weights(rep_data)
    frame_positions = coord_handler.convert_rep_to_plotable_positions(rep_data)

    numFish, numBodyPoints = 1, prmds+1
    fish_colors = ['red']

    # Attaching 3D axis to the figure
    fig = plt.figure()
    fig.tight_layout()
    fig.canvas.header_visible = False
    #fig.canvas.layout.min_height = '400px'
    fig.canvas.layout.height = '100%'
    fig.canvas.layout.width = '100%'
    ax = fig.add_subplot(111)  
    fig.suptitle('Investigating PCA mode action about mean configuration')
    #ax = p3.Axes3D(fig, [0.,0.,1.,1.])

    # set the limits
    xmin, xmax = -10, 50
    ymin, ymax = -40, 40
    zmin, zmax = -5, 5

    # Setting the axes properties
    ax.set_xlim([xmin, xmax])
    # ax.set_xlabel('X')
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])
    ax.set_ylim([ymin, ymax])
    # ax.set_ylabel('Y')

    # ax.set_zlim3d([zmin, zmax])
    # ax.set_zlabel('Z')

    # Initialize scatters (list over fish, list over bps)
    symbols = ['o', 's', 'x']

    # pixel-width values from straight fish threshold image
    sizes=[28.0, 30.0, 30.0, 30.0, 31.0, 33.0, 31.0, 27.0, 27.0, 26.0, 24.0, 21.0, 21.0, 21.0, 21.0, 15.0, 15.0, 14.0, 9.0, 10.0]
    sizes = [250,260,280,280,270,260,250,240,230,220,200,180,150,120,100,75,60,35,35,35]
    sizes = [i*3 for i in sizes]
    # Main scatters,
    scatters = []
    col = fish_colors[0]
    fish_scatters = []
    for bpIdx in range(numBodyPoints):
        fish_scatters.append(ax.scatter(frame_positions[bpIdx][0], 
                                        frame_positions[bpIdx][1], 
                                        c=col, s=sizes[bpIdx])) 
    scatters.append(fish_scatters)

    # Main lines
    lines = []
    col = fish_colors[0]
    line = ax.plot([i[0] for i in frame_positions], [i[1] for i in frame_positions], c='black')[0]
    lines.append(line)

    # scatterplot the origin
    ax.scatter(0, 0, c='black', s=14)

    sizes=5

    # make the widget
    layout = widgets.Layout(width='100%')
    plot_widget = VBox([fig.canvas], layout=layout)


    # make the 8 sliders
    numModeSliders = 8

    mode_sliders = []

    for ii in range(numModeSliders):

        slider_min = np.nanmin(pca_tseries[:,ii])
        slider_max = np.nanmax(pca_tseries[:,ii])
        step=0.01
        init_val = pca_widget_list[ii].value

        # make the slider
        mode_slider = widgets.FloatSlider(value=init_val, min=slider_min, max=slider_max, step=step, 
                                             continuous_update=True, description='m{0}: {1}'.format(ii+1, np.round(mean_rep_modes[ii],2)), 
                                              layout=Layout(width='100%', height='50px'))

        # link the slider to the widget display values
        mode_val_slider_link = widgets.link((mode_slider, 'value'), (pca_widget_list[ii], 'value'))

        # keep the mode slider
        mode_sliders.append(mode_slider)




    # ------- On changes to the mode sliders, update the plottable positions ------ #

    def update_plottable_from_mode_sliders(change):
        numCams=3
        numFish=1
        numBodyPoints=prmds+1

        # get the current values of the mode weights
        pca_data = []
        for ii in range(numModeSliders):
            pca_weight = pca_widget_list[ii].value
            pca_data.append(pca_weight)
        frame_pca_data = np.array(pca_data)

        # get the representation data from the pca weights
        frame_rep_data = coord_handler.convert_mode_weights_to_rep(frame_pca_data)

        # get the plottable positions
        frame_positions = coord_handler.convert_rep_to_plotable_positions(frame_rep_data)

        # update the rep data
        for ii in range(prmds):
            rep_widget_list[ii].value = np.round(frame_rep_data[ii],3)

        # update the plottable positions (lab can't be updated)
        for bpIdx in range(numBodyPoints):
            for dimIdx in range(2):
                plottable_3D_widget_list[bpIdx][dimIdx] = np.round(frame_positions[bpIdx][dimIdx],3)


    for ii in range(numModeSliders):
        mode_sliders[ii].observe(update_plottable_from_mode_sliders, names='value')


    # all sliders
    all_sliders = widgets.VBox([widgets.VBox(mode_sliders)])



    # ----- Update the plot with mode sliders ----- #


    def update_plot_mode_change(change):
        numFish=1
        numBodyPoints=prmds+1
        fig.suptitle('Investigating PCA mode action about mean configuration')

        #------ Gather the current frame positions from the widget --- #
        frame_positions = []

        for bpIdx in range(numBodyPoints):
            bodypoint_pos = []
            for dimIdx in range(2):
                val = plottable_3D_widget_list[bpIdx][dimIdx]
                bodypoint_pos.append(val)
            bodypoint_pos = np.array(bodypoint_pos)
            frame_positions.append(bodypoint_pos)
        frame_positions = np.stack(frame_positions)
        #-------------------------------------#


        #------- Update the real fish ---- #
        # update the scatters
        for bpIdx in range(numBodyPoints):
            ax = scatters[0][bpIdx]
            ax.set_offsets([frame_positions[bpIdx][0], frame_positions[bpIdx][1]])

        # update the lines
        for fishIdx,line in enumerate(lines):
                line.set_data(frame_positions[:, :2].swapaxes(0,1))
        #--------------------------------------#


        fig.canvas.draw()
        fig.canvas.flush_events()


    for ii in range(numModeSliders):
        ms = mode_sliders[ii]
        ms.observe(update_plot_mode_change)


    # display Vbox
    plt.ioff()
    sc = Sidecar(title='Sidecar Output')
    with sc:
        display(all_sliders)
        #display(mean_rep_modes)
    
else:
    #  ----------------- make a 3D plot of the positions ----------------------- #

    plt.ioff()


    rep_data = np.copy(coord_handler.mean_rep)

    # print(rep_data)
    pca_data = coord_handler.convert_rep_to_mode_weights(rep_data)
    frame_positions = coord_handler.convert_rep_to_plotable_positions(rep_data)

    # print(frame_positions)

    numFish, numBodyPoints = 1, PRMODES//2+1
    fish_colors = ['red']

    # Attaching 3D axis to the figure
    fig = plt.figure()
    fig.tight_layout()
    fig.canvas.header_visible = False
    # fig.canvas.layout.min_height = '400px'
    fig.canvas.layout.height = '100%'
    fig.canvas.layout.width = '100%'
    ax = fig.add_subplot(111, projection='3d')  
    fig.suptitle('Investigating PCA mode action about mean configuration')
    # ax = p3.Axes3D(fig, [0.,0.,1.,1.])

    # set the limits
    xmin, xmax = -60, 60
    ymin, ymax = -60, 60
    zmin, zmax = -60, 60

    # Setting the axes properties
    ax.set_xlim3d([xmin, xmax])
    ax.set_xlabel('X')

    ax.set_ylim3d([ymin, ymax])
    ax.set_ylabel('Y')

    ax.set_zlim3d([zmin, zmax])
    ax.set_zlabel('Z')

    # Initialize scatters (list over fish, list over bps)
    symbols = ['o', 's', 'x']
    sizes=12

    # Main scatters
    scatters = []
    col = fish_colors[0]
    fish_scatters = []
    for bpIdx in range(numBodyPoints):
        fish_scatters.append(ax.scatter(frame_positions[bpIdx][0], 
                                        frame_positions[bpIdx][1], 
                                        frame_positions[bpIdx][2],
                                        c=col, s=sizes)) 
    scatters.append(fish_scatters)

    # Main lines
    lines = []
    col = fish_colors[0]
    line = ax.plot([i[0] for i in frame_positions], [i[1] for i in frame_positions], [i[2] for i in frame_positions], c=col)[0]
    lines.append(line)

    # scatterplot the origin
    ax.scatter(0, 0, 0, c='black', s=14)

    sizes=5
    
    # make the widget
    layout = widgets.Layout(width='100%')
    plot_widget = VBox([fig.canvas], layout=layout)
    
    # make the 8 sliders
    numModeSliders = 8

    mode_sliders = []

    for ii in range(numModeSliders):

        slider_min = np.nanmin(pca_tseries[:,ii])
        slider_max = np.nanmax(pca_tseries[:,ii])
        step=0.01
        init_val = pca_widget_list[ii].value

        # make the slider
        mode_slider = widgets.FloatSlider(value=init_val, min=slider_min, max=slider_max, step=step, 
                                             continuous_update=True, description='m{0}: {1}'.format(ii+1, np.round(mean_rep_modes[ii],2)), 
                                              layout=Layout(width='100%', height='30px'))

        # link the slider to the widget display values
        mode_val_slider_link = widgets.link((mode_slider, 'value'), (pca_widget_list[ii], 'value'))

        # keep the mode slider
        mode_sliders.append(mode_slider)




    # ------- On changes to the mode sliders, update the plottable positions ------ #

    def update_plottable_from_mode_sliders(change):
        """Propagate the current PCA mode weights to the representation and 3D-position widgets.

        Registered with ``names='value'`` on each mode slider. Reads the first
        ``numModeSliders`` weights from ``pca_widget_list``, converts them with
        ``coord_handler.convert_mode_weights_to_rep`` and
        ``coord_handler.convert_rep_to_plotable_positions``, and writes the values,
        rounded to 3 decimals, into ``rep_widget_list`` (first ``PRMODES`` entries) and
        ``plottable_3D_widget_list`` (one x/y/z triple per body point).

        change: traitlets change dict; unused, the weights are re-read from the widgets.
        """
        numBodyPoints = PRMODES // 2 + 1

        # current mode weights, in slider order
        frame_pca_data = np.array([pca_widget_list[ii].value for ii in range(numModeSliders)])

        # mode weights -> representation -> plottable 3D body-point positions
        frame_rep_data = coord_handler.convert_mode_weights_to_rep(frame_pca_data)
        frame_positions = coord_handler.convert_rep_to_plotable_positions(frame_rep_data)

        # update the representation display widgets
        for ii in range(PRMODES):
            rep_widget_list[ii].value = np.round(frame_rep_data[ii], 3)

        # update the plottable 3D widgets; update_plot_mode_change reads its positions from these
        for bpIdx in range(numBodyPoints):
            for dimIdx in range(3):
                plottable_3D_widget_list[bpIdx][dimIdx].value = np.round(frame_positions[bpIdx][dimIdx], 3)



    for ii in range(numModeSliders):
        mode_sliders[ii].observe(update_plottable_from_mode_sliders, names='value')


    # all sliders
    all_sliders = widgets.VBox([widgets.VBox(mode_sliders)])



    # ----- Update the plot with mode sliders ----- #


    def update_plot_mode_change(change):
        """Redraw the 3D fish from the positions held in ``plottable_3D_widget_list``.

        Registered on every mode slider after ``update_plottable_from_mode_sliders``,
        which refreshes those widgets from the mode weights. This handler moves each
        one-point body-point scatter and the body line to the widget positions, then
        redraws ``fig``.

        change: traitlets change dict. The handler is registered without a ``names=``
            filter, so it is also notified of changes to other slider traits. Those
            changes cannot move the fish, so they are ignored instead of triggering a
            redundant redraw.
        """
        if isinstance(change, dict) and change.get('name', 'value') != 'value':
            return

        numBodyPoints = PRMODES // 2 + 1
        fig.suptitle('Investigating PCA mode action about mean configuration')

        # (numBodyPoints, 3) array of the current x/y/z widget values
        frame_positions = np.array([
            [plottable_3D_widget_list[bpIdx][dimIdx].value for dimIdx in range(3)]
            for bpIdx in range(numBodyPoints)
        ])

        # Each body point is its own one-point 3D scatter. Overwrite its single stored
        # coordinate in place (``_offsets3d`` is a private matplotlib attribute).
        for bpIdx in range(numBodyPoints):
            bp_scatter = scatters[0][bpIdx]
            for dimIdx in range(3):
                bp_scatter._offsets3d[dimIdx][0] = frame_positions[bpIdx, dimIdx]

        # move the body line through all body points
        for line in lines:
            line.set_data(frame_positions[:, :2].swapaxes(0, 1))
            line.set_3d_properties(frame_positions[:, 2])

        fig.canvas.draw()
        fig.canvas.flush_events()


    for ii in range(numModeSliders):
        ms = mode_sliders[ii]
        ms.observe(update_plot_mode_change)
        
    # display Vbox
    plt.ioff()
    sc = Sidecar(title='Sidecar Output')
    with sc:
        display(all_sliders)
    #     display(mean_rep_modes)

In [None]:
# Display plot_widget as the cell output. In the 3D branch above it is VBox([fig.canvas]).
plot_widget

### Play back recorded mode values over a frame range

In [None]:
# Step the first 10 mode widgets through the recorded PCA weights, one frame at a
# time, over frames [range_turn[0], range_turn[1]); afterwards zero them again.
range_turn = [114270, 114300]
for frame in range(*range_turn):
    frame_weights = pca_tseries[frame]
    for mode in range(10):
        pca_widget_list[mode].value = frame_weights[mode]

# reset values
for mode in range(10):
    pca_widget_list[mode].value = 0

In [None]:
# Make a video: drive mode weights 1 and 2 through the recorded PCA time series for
# frames [ranges[0], ranges[1]) and grab `fig` after each step. `fig` is presumably
# the 3D mode figure built above, which the slider callbacks redraw.

Writer = animation.writers['ffmpeg']
writer = Writer(fps=10, metadata=dict(artist='LOS'), bitrate=3500)
ranges = [114270, 114300]


with writer.saving(fig, './2fish_images/singlefish/singlefish_turn.mp4', 100):
    try:
        for t in range(ranges[0], ranges[1]):
            pca_widget_list[0].value = pca_tseries[t, 0]
            pca_widget_list[1].value = pca_tseries[t, 1]

            # add frame to video
            writer.grab_frame()
    finally:
        # Zero the two modes even if writing stops part-way, so the widgets are
        # not left at an arbitrary frame of the sequence.
        pca_widget_list[1].value = 0
        pca_widget_list[0].value = 0

# Explore the tracked fish data

In [None]:
# The cells below need the data produced by the fish analysis

In [None]:
# Load the tracking / kinematics results of every recording listed below.
file_paths = ["./Project/LargeCrop/predictions/20201119/bottomup/pred_FishTank20200416_160648/results/FishTank20200416_160648/FishTank20200416_160648.h5",
              "./Project/LargeCrop/predictions/20201119/bottomup/pred_FishTank20200413_154621/results/FishTank20200413_154621/FishTank20200413_154621.h5",
              "./Project/LargeCrop/predictions/20201120/pred_FishTank20200414_154100/results/FishTank20200414_154100/FishTank20200414_154100.h5",
              "./Project/LargeCrop/predictions/20201120/pred_FishTank20200415_154234/results/FishTank20200415_154234/FishTank20200415_154234.h5",
              "./Project/LargeCrop/predictions/20201120/pred_FishTank20200417_154139/results/FishTank20200417_154139/FishTank20200417_154139.h5",
              "./Project/LargeCrop/predictions/20201120/pred_FishTank20200418_160144/results/FishTank20200418_160144/FishTank20200418_160144.h5",
              "./Project/LargeCrop/predictions/20201120/pred_FishTank20200419_151352/results/FishTank20200419_151352/FishTank20200419_151352.h5",
              "./Project/LargeCrop/predictions/20201120/pred_FishTank20200419_173651/results/FishTank20200419_173651/FishTank20200419_173651.h5",
              "./Project/LargeCrop/predictions/20201120/pred_FishTank20200420_153056/results/FishTank20200420_153056/FishTank20200420_153056.h5",
              "./Project/LargeCrop/predictions/20201120/pred_FishTank20200420_175237/results/FishTank20200420_175237/FishTank20200420_175237.h5",
             ]

# One list per quantity; entry k holds the dataset read from file_paths[k].
speeds_frametoframe = []
data_3D_filtered = []
angles_upperbody_360 = []
angles_upperbody_180 = []
upperbody_angles_xy = []
upperbody_angles_z = []
tail_angles_xy = []
tail_angles_z = []
speed_xyz = []
speed_upperbody_angles_xy = []
speed_upperbody_angles_z = []
speed_tail_angles_xy = []
speed_tail_angles_z = []
data_3D = []
original_instances = []

# (destination list, HDF5 dataset name), read in this order from every file
_h5_fields = [
    (original_instances, 'original_instances'),
    (data_3D, 'tracks_3D'),
    (speeds_frametoframe, 'speed_frametoframe'),
    (data_3D_filtered, 'tracks_3D_no_outliers_60'),
    (angles_upperbody_360, 'upperbody_xy_angle_0to360'),
    (angles_upperbody_180, 'upperbody_xy_angle_180'),
    (upperbody_angles_xy, 'upperbody_xy_angle_additive'),
    (upperbody_angles_z, 'upperbody_z_angle'),
    (tail_angles_xy, 'tail_xy_angle'),
    (tail_angles_z, 'tail_z_angle'),
    (speed_xyz, 'speed_xyz'),
    (speed_upperbody_angles_xy, 'speed_upperbody_angles_xy'),
    (speed_upperbody_angles_z, 'speed_upperbody_angles_z'),
    (speed_tail_angles_xy, 'speed_tail_angles_xy'),
    (speed_tail_angles_z, 'speed_tail_angles_z'),
]

for file_path in file_paths:
    with h5py.File(file_path, 'r') as f:
        for destination, dataset_name in _h5_fields:
            destination.append(np.copy(f[dataset_name]))

# Fish-pair label for each recording, taken from Tatsuo's fish-pair image;
# presumably in the same order as file_paths (both have 10 entries).
fishpairs = ["9_top",
             "7_down",
             "7_up",
             "9_down",
             "10_down",
             "10_up",
             "8_down",
             "6_down",
             "8_up",
             "6_up"
            ]


def abs_speed(data, bdyp):
    '''
    Return, for every row of data[:, bdyp], the Euclidean norm of its first three components.

    data: NumPy array indexed data[frame, bodypoint, component, ...]; presumably one of
          the per-recording `speed_xyz` arrays (frame x bodypoint x xyz) - confirm
    bdyp: body-point index (make_plots uses 0 for the head and 1 for the pec)

    Returns a list with one entry per frame (first axis of data). The computation is
    vectorised with NumPy rather than looping over frames in Python, which matters
    because whole recordings are passed in.
    '''
    selected = data[:, bdyp]
    # row k of selected[:, c] is selected[k][c], matching the per-frame formula exactly
    norms = np.sqrt(selected[:, 0]**2 + selected[:, 1]**2 + selected[:, 2]**2)
    return list(norms)

In [None]:
# Quick look at one dataset: additive upper-body xy angle of the first recording, per frame.

angle_fig, angle_ax = plt.subplots()
angle_ax.plot(upperbody_angles_xy[0])
angle_ax.set_xlabel('Frame')
angle_ax.set_ylabel('Additive angle [Degrees]')
angle_fig.tight_layout()
plt.show()

In [None]:
def make_plots(handler, range_data, ranges):
    '''
    Plot (angular) speed, angle, mode 1, mode 2 and mode 1 vs mode 2 for a frame range.

    The left column uses recording 0 of the loaded data (speed_xyz[0],
    speed_upperbody_angles_xy[0], upperbody_angles_xy[0]). The right column uses the
    mode amplitudes returned by pca_histdata(handler, range_data, all_data=True).

    handler: PCA_angles object
    range_data: [start, stop] frame range
    ranges: list of [start, stop] sub-ranges for the mode1-vs-mode2 panel, indexing
        the pca_histdata output for range_data. A single sub-range is drawn as a red
        scatter whose opacity increases with time. Several sub-ranges are drawn as
        lines in the default colour cycle.
    '''
    start, stop = range_data[0], range_data[1]

    fig, axes = plt.subplots(3, 2, figsize=(8, 6))

    # abs_speed maps frame by frame, so slicing its input gives the same values as
    # slicing its output, without processing the whole recording
    speed_window = speed_xyz[0][start:stop]
    axes[0, 0].set_title("Absolute speed [cm/s]")
    axes[0, 0].plot(abs_speed(speed_window, 0), label="head")
    axes[0, 0].plot(abs_speed(speed_window, 1), label="pec")
    axes[0, 0].legend(prop={'size': 8})

    axes[1, 0].set_title("Upperbody angular speed [degrees/s]")
    axes[1, 0].plot(speed_upperbody_angles_xy[0][start:stop])

    axes[2, 0].set_title("Relative upperbody angle [degrees]")
    axes[2, 0].plot(upperbody_angles_xy[0][start:stop] - upperbody_angles_xy[0][start])
    axes[2, 0].set_xlabel('Frame')

    f_data = pca_histdata(handler, [start, stop], all_data=True)

    axes[0, 1].set_title("Mode 1 Amplitude")
    axes[0, 1].plot(f_data[0])

    axes[1, 1].set_title("Mode 2 Amplitude")
    axes[1, 1].plot(f_data[1])
    axes[1, 1].set_xlabel('Frame')

    axes[2, 1].set_title("Modespace")
    for sub_start, sub_stop in ranges:
        label = "[{},{}]".format(sub_start, sub_stop)
        mode1 = f_data[0][sub_start:sub_stop]
        mode2 = f_data[1][sub_start:sub_stop]
        if len(ranges) == 1:
            n_points = sub_stop - sub_start
            rgba_colors = np.zeros((n_points, 4))
            rgba_colors[:, 0] = 1.0  # red
            # opacity ramps from 1/n_points up to 1, so later frames are more opaque
            rgba_colors[:, 3] = np.arange(1, n_points + 1) / n_points
            axes[2, 1].scatter(mode1, mode2, color=rgba_colors, s=5, label=label)
        else:
            axes[2, 1].plot(mode1, mode2, label=label)

    axes[2, 1].set_xlabel('Mode 1')
    axes[2, 1].set_ylabel('Mode 2')
    axes[2, 1].axvline(x=0, c='black')
    axes[2, 1].axhline(y=0, c='black')
    axes[2, 1].set_xlim(-400, 400)
    axes[2, 1].set_ylim(-200, 200)
    leg = axes[2, 1].legend(prop={'size': 8})
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles, and 3.9 removed
    # the old name; use whichever this installation provides.
    legend_handles = getattr(leg, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = leg.legendHandles
    # show legend markers fully opaque even though the scatter fades in
    for lh in legend_handles:
        lh.set_alpha(1)
    fig.tight_layout()
    plt.show()

In [None]:
plt.close('all')

# Named frame ranges, reused by later cells (range_swim is passed to make_mode_plots).
range_swim = [432474, 432535]
range_turn = [114270, 114320]

# (frame range, mode-space sub-ranges): one make_plots figure per entry, in this order
example_plots = [
    (range_swim, [[15, 45]]),  # one swim
    (range_turn, [[5, 35]]),  # turn1
    ([158855, 158900], [[5, 30]]),  # turn2
    ([11700, 11800], [[20, 80]]),  # full turn
    ([157950, 158250], [[0, 35]]),  # L(+swim)RRR; sub-ranges [65,95], [140,175], [220,250] left out
    ([12100, 12300], [[20, 50], [50, 75], [80, 120], [150, 175], [175, 200]]),  # left, swim, swim, left, right
    ([13800, 14000], [[20, 50], [60, 125], [160, 200]]),  # R, Lswim swim swim, R
]
for frame_range, sub_ranges in example_plots:
    make_plots(coord_single, frame_range, sub_ranges)

In [None]:
def make_mode_plots(handler, range_data):
    '''
    Plot the time courses of PCA modes 1-6 over a frame range.

    handler: object with a 2-D ``pca_tseries`` indexed [frame, mode]
    range_data: [start, stop] frame range

    The modes are laid out row by row on a 3x2 grid: mode 1 top-left, mode 2 top-right,
    and so on down to mode 6 bottom-right.
    '''
    # Only the requested frames are taken; plotting does not modify the data, so the
    # full time series does not need to be copied.
    tseries = handler.pca_tseries[range_data[0]:range_data[1]]

    fig, axes = plt.subplots(3, 2, figsize=(8, 6))

    for mode_idx, ax in enumerate(axes.flat):
        ax.set_title("Mode {}".format(mode_idx + 1))
        ax.plot(tseries[:, mode_idx])

    fig.tight_layout()
    plt.show()

In [None]:
# Time courses of PCA modes 1-6 of the coord_3d handler over the range_swim frames.
make_mode_plots(coord_3d, range_swim)

In [None]:
# Make a video of the raw xy camera recording for global frames [range_vid[0], range_vid[1]).
# The recording is read from split files of 6000 frames each. The split-file number is
# (global frame // 6000) + 7.
# NOTE(review): the 6000-frame split size and the +7 offset come from the code as
# written - confirm against the OneFish_20200416 split files.

Writer = animation.writers['ffmpeg']
writer = Writer(fps=100, metadata=dict(artist='LOS'), bitrate=3500)
t0 = time.time()
# Use a dedicated name: rebinding the global `fig` here would make the mode-slider
# callbacks, which draw through the module-level `fig`, redraw this figure instead.
raw_video_fig = plt.figure()
range_vid = [13800, 14000]  # [start, stop); stop matches the output file name
frames_per_split = 6000

split_idx = range_vid[0] // frames_per_split
split_start = split_idx * frames_per_split
first_local = range_vid[0] - split_start  # first frame to write, within the split file
stop_local = range_vid[1] - split_start   # exclusive end, within the split file
if stop_local > frames_per_split:
    warnings.warn('range_vid crosses a split-file boundary; only frames up to the end '
                  'of split file {} are written'.format(split_idx + 7))

with writer.saving(raw_video_fig, './2fish_images/singlefish/swim13800_14000.mp4', 100):

    video_file = '/home/thomasreus/Documents/zebrafish_labeling_GUI/OneFish_20200416/E_xy/splitdata{}.mp4'.format(str(split_idx + 7).zfill(4))
    vidcap = cv2.VideoCapture(video_file)
    try:
        if not vidcap.isOpened():
            warnings.warn('could not open {}'.format(video_file))
        success, image = vidcap.read()
        count = 0

        # Frames are decoded sequentially from the start of the split file;
        # those before first_local are read but not written.
        while success and count < stop_local:
            if count >= first_local:
                # normal image
                raw_video_fig.gca().imshow(image[:, :, 0], cmap='gray')

                # add frame to video
                writer.grab_frame()

                raw_video_fig.clear(keep_observers=True)

            count += 1
            success, image = vidcap.read()
    finally:
        vidcap.release()

print(time.time() - t0)