In [1]:
%load_ext autoreload

In [2]:
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
import mpld3

import pandas as pd
from pathlib import Path
from utilities import *

import seaborn as sns
# Plot styling: seaborn "ticks" theme with the "deep" palette, then the
# 'v_arial' matplotlib style sheet applied on top (presumably a user-installed
# style — TODO confirm it is available on other machines).
sns.set(palette="deep", style="ticks")
# Current palette colours; NOTE(review): not referenced in the cells visible here.
cols = sns.color_palette()
plt.style.use('v_arial')

Master folder containing the Stytra example datasets:

In [3]:
# Raw string: this Windows path contains backslashes that must not be parsed as
# escape sequences (e.g. "\_", "\s", "\d" are invalid escapes in a normal string).
sample_data_path = Path(r"J:\_Shared\stytra\manuscript\datasets")

# Load all experiments

Note: to reduce file size, the individual tail angles have been removed from these datasets; the analysis below uses only the `tail_sum` column of the behavior log.

In [4]:
master_path = sample_data_path / "replication_portugues2011"
# Path.glob yields entries in arbitrary filesystem order; sort so that the
# experiment order (and hence the fish numbering in the plots) is reproducible.
paths = sorted(master_path.glob('*f*'))
exps = [Experiment(path) for path in paths]

In [5]:
def get_exp_stats(exp, trials_per_block=18):
    """Compute per-trial bout statistics for a single experiment.

    Bouts are detected by thresholding a vigor trace (rolling standard
    deviation of ``tail_sum``). Trials are delimited by the rising and falling
    edges of the ``closed loop 1D_gain`` stimulus parameter.

    :param exp: experiment whose ``behavior_log`` has ``t`` and ``tail_sum``
        columns and whose ``stimulus_param_log`` has ``t`` and
        ``closed loop 1D_gain`` columns.
    :param trials_per_block: number of consecutive trials sharing one
        ``trial_idx`` value (18, as in the original analysis).
    :return: DataFrame with one row per trial and columns ``start``, ``end``,
        ``gain``, ``trial_idx``, ``bout_n``, ``first_bout_latency``,
        ``bout_duration`` and ``inter_bout_t``. The bout statistics stay NaN
        for trials without bouts. ``inter_bout_t`` is only filled for trials
        with more than 3 bouts.
    :raises ValueError: if the numbers of detected trial starts and ends differ.
    """
    tail_log_df = exp.behavior_log.set_index("t")  # tail trace indexed by time
    stim_log_df = exp.stimulus_param_log.set_index("t")

    tail_dt = np.diff(tail_log_df.index).mean()  # mean sampling interval of the tail trace

    # Recalculate vigor as the centered rolling std of the tail sum over a
    # window of 0.05 time units (50 ms if `t` is in seconds — TODO confirm).
    vigor = (tail_log_df["tail_sum"]
             .rolling(int(0.05 / tail_dt), center=True)
             .std()
             .values)  # `.as_matrix()` was removed in pandas 1.0

    # Extract bouts: segments with vigor above 0.6 lasting at least 0.1 time
    # units; the padding is presumably in samples (opaque helper — verify).
    bouts_idxs, _ = extract_segments_above_thresh(vigor, 0.6, min_duration=int(0.1 / tail_dt),
                                                  pad_before=5, pad_after=5, skip_nan=True)
    # Convert bout start/end sample indices into times.
    tail_t = tail_log_df.index.values
    bout_starts = np.array([tail_t[b[0]] for b in bouts_idxs])
    bout_ends = np.array([tail_t[b[1]] for b in bouts_idxs])

    # Trials start on gain increases and end on gain decreases. The first edge
    # of each kind is dropped ([1:]); an earlier note described this as
    # excluding a trial of gain 1 — NOTE(review): confirm which trial this is.
    gain = stim_log_df["closed loop 1D_gain"].values
    gain_step = np.ediff1d(gain, to_begin=0)
    rises, falls = gain_step > 0, gain_step < 0
    stim_t = stim_log_df.index.values
    trial_s = stim_t[rises][1:]
    trial_e = stim_t[falls][1:]
    trial_g = gain[rises][1:]
    if len(trial_s) != len(trial_e):
        raise ValueError("Found {} trial starts but {} trial ends".format(
            len(trial_s), len(trial_e)))

    n_trials = len(trial_s)
    trial_df = pd.DataFrame(dict(start=trial_s, end=trial_e, gain=trial_g, inter_bout_t=np.nan,
                                 # Block label: consecutive runs of `trials_per_block` trials.
                                 trial_idx=np.arange(n_trials) // trials_per_block,
                                 bout_n=np.nan, bout_duration=np.nan, first_bout_latency=np.nan),
                            index=np.arange(n_trials))

    for i, (start, end) in enumerate(zip(trial_s, trial_e)):
        # Only bouts lying entirely inside the trial window are counted.
        in_trial = np.flatnonzero((bout_starts > start) & (bout_ends < end))
        if len(in_trial) == 0:
            continue
        starts, ends = bout_starts[in_trial], bout_ends[in_trial]
        trial_df.loc[i, "bout_n"] = len(in_trial)
        trial_df.loc[i, "first_bout_latency"] = starts[0] - start
        trial_df.loc[i, "bout_duration"] = (ends - starts).mean()
        # Mean gap between the end of one bout and the start of the next.
        if len(in_trial) > 3:
            trial_df.loc[i, "inter_bout_t"] = (starts[1:] - ends[:-1]).mean()

    return trial_df
                
def cum_stat(stat_list, key, n_blocks=6, trials_per_block=18, n_baseline=3):
    """Average one per-trial statistic across blocks and normalise it per fish.

    For each DataFrame in ``stat_list``, the ``key`` column is reshaped into
    ``(n_blocks, trials_per_block)`` and NaN-averaged over blocks. Each fish's
    curve is then divided by its NaN-mean over the first ``n_baseline``
    trials.

    :param stat_list: per-fish DataFrames, each with
        ``n_blocks * trials_per_block`` rows.
    :param key: name of the column to summarise.
    :param n_blocks: number of repeated blocks of trials (default 6).
    :param trials_per_block: trials per block (default 18).
    :param n_baseline: number of leading trials used for normalisation (default 3).
    :return: float array of shape ``(trials_per_block, len(stat_list))``,
        with one column per fish.
    """
    per_fish = [
        np.nanmean(np.asarray(trial_df[key].values, dtype=float)
                   .reshape(n_blocks, trials_per_block), 0)
        for trial_df in stat_list
    ]
    # reshape keeps the (fish, trial) layout even when stat_list is empty.
    stat = np.array(per_fish, dtype=float).reshape(len(per_fish), trials_per_block).T
    return stat / np.nanmean(stat[:n_baseline, :], 0)

In [6]:
# One per-trial statistics DataFrame per experiment, in the same order as `exps`.
trial_stats = list(map(get_exp_stats, exps))

# Make figure: relative inter-bout time across trials

In [7]:
def back_color(ax=None):
    """Shade fixed x-ranges of an axes with grey background bands.

    Light grey (0.9) covers x in [-0.5, 2.5] and [8.5, 11.5]. Darker grey
    (0.7) covers [5.5, 8.5] and [11.5, 14.5]. The half-integer edges centre
    the bands on integer trial positions. The bands presumably mark groups of
    trials with a particular gain condition — TODO confirm against the
    protocol.

    :param ax: matplotlib Axes to draw on; defaults to the current axes.
    """
    if ax is None:
        ax = plt.gca()
    light, dark = (0.9,) * 3, (0.7,) * 3
    spans = [(-0.5, 2.5, light), (8.5, 11.5, light),
             (5.5, 8.5, dark), (11.5, 14.5, dark)]
    for x_start, x_end, shade in spans:
        ax.axvspan(x_start, x_end, color=shade)

In [12]:
# One label per experiment/fish, in the same order as `trial_stats`. Derived
# from `trial_stats` rather than `stat`, which is only defined in a later cell.
labels = ["fish {}".format(i) for i in range(len(trial_stats))]
labels

['fish 0',
 'fish 1',
 'fish 2',
 'fish 3',
 'fish 4',
 'fish 5',
 'fish 6',
 'fish 7',
 'fish 8',
 'fish 9',
 'fish 10',
 'fish 11',
 'fish 12',
 'fish 13',
 'fish 14',
 'fish 15',
 'fish 16',
 'fish 17',
 'fish 18',
 'fish 19',
 'fish 20',
 'fish 21',
 'fish 22',
 'fish 23']

In [19]:
from mpld3 import plugins, utils

class HighlightLines(plugins.PluginBase):
    """mpld3 plugin that highlights a line while the mouse hovers over it.

    On mouseover a line fades to ``alpha_fg`` (1.0). On mouseout it fades
    back to ``alpha_bg``, which is the first line's alpha, or 1.0 if that
    line has no alpha set.

    :param lines: sequence of matplotlib ``Line2D`` objects (e.g. the list
        returned by ``plt.plot``). All lines share the first line's alpha as
        their resting opacity.
    """

    # Client-side implementation. `alpha_fg` and `alpha_bg` are declared in a
    # single `var` statement so that neither leaks into the global JS scope.
    JAVASCRIPT = """
    mpld3.register_plugin("linehighlight", LineHighlightPlugin);
    LineHighlightPlugin.prototype = Object.create(mpld3.Plugin.prototype);
    LineHighlightPlugin.prototype.constructor = LineHighlightPlugin;
    LineHighlightPlugin.prototype.requiredProps = ["line_ids"];
    LineHighlightPlugin.prototype.defaultProps = {alpha_bg:0.3, alpha_fg:1.0}
    function LineHighlightPlugin(fig, props){
        mpld3.Plugin.call(this, fig, props);
    };

    LineHighlightPlugin.prototype.draw = function(){
      for(var i=0; i<this.props.line_ids.length; i++){
         var obj = mpld3.get_element(this.props.line_ids[i], this.fig),
             alpha_fg = this.props.alpha_fg,
             alpha_bg = this.props.alpha_bg;
         obj.elements()
             .on("mouseover", function(d, i){
                            d3.select(this).transition().duration(50)
                              .style("stroke-opacity", alpha_fg); })
             .on("mouseout", function(d, i){
                            d3.select(this).transition().duration(200)
                              .style("stroke-opacity", alpha_bg); });
      }
    };
    """

    def __init__(self, lines):
        self.lines = lines
        # Resting opacity restored on mouseout. get_alpha() returns None when
        # no alpha was set, in which case the line is fully opaque.
        first_alpha = lines[0].get_alpha() if len(lines) > 0 else None
        self.dict_ = {"type": "linehighlight",
                      "line_ids": [utils.get_id(line) for line in lines],
                      "alpha_bg": 1.0 if first_alpha is None else first_alpha,
                      "alpha_fg": 1.0}

In [47]:
fig, ax = plt.subplots(figsize=(7, 6))
stat_key = 'inter_bout_t'
# Shape (trials per block, n_fish); each fish is normalised to its first trials.
stat = cum_stat(trial_stats, stat_key)

# Mean ± SEM across fish. Count only fish with a value for each trial, since
# `stat` may contain NaNs (e.g. trials without enough bouts). With no NaNs this
# equals the previous nanstd(ddof=0) / sqrt(n - 1).
n_valid = np.sum(~np.isnan(stat), 1)
sem = np.nanstd(stat, 1, ddof=1) / np.sqrt(n_valid)
trial_x = np.arange(stat.shape[0])
ax.errorbar(trial_x, np.nanmean(stat, 1), sem, linewidth=2)
# Individual fish as translucent black lines; the Line2D list feeds the hover plugin.
plot = ax.plot(stat, linewidth=1.5, c='k', alpha=0.2)

# Ticks every 3 trials, labelled with 1-based trial numbers.
ax.set_xticks(np.arange(1, 17, 3))
ax.set_xticklabels(np.arange(2, 18, 3))
ax.set_yticks(np.arange(0.4, 2.2, 0.4))
ax.tick_params(labelsize=14)
ax.set_ylabel('Relative inter-bout time', fontsize=18)
ax.set_xlabel('trial number', fontsize=18)

tooltip = HighlightLines(plot)
mpld3.plugins.connect(fig, tooltip)
# NOTE(review): absolute, user-specific output path — adjust before running elsewhere.
html_path = r"C:\Users\lpetrucco\code\stytra\docs\figures\portugues2011_replication.html"
mpld3.save_html(fig, html_path)
mpld3.display(fig)

