In [1]:
%load_ext autoreload

In [1]:
import matplotlib.pyplot as plt
import numpy as np
# %matplotlib inline
import plotly.graph_objs as go

import pandas as pd
from pathlib import Path
from utilities import *

import seaborn as sns
# Apply the seaborn "ticks" axes style and the "deep" colour palette globally.
sns.set(style="ticks", palette="deep")
# Current seaborn colour palette; not referenced in the visible cells.
cols = sns.color_palette()
# NOTE(review): 'v_arial' is not a built-in matplotlib style — presumably a custom
# .mplstyle file that must be installed on the machine running this notebook; confirm.
plt.style.use('v_arial')


Master folder with Stytra example data:

In [2]:
# Raw string: the Windows path contains backslashes that would otherwise be
# parsed as (invalid) escape sequences such as "\_" and "\s".
sample_data_path = Path(r"J:\_Shared\stytra\manuscript\datasets")

# Load all experiments:

Note that in these datasets the tail angles have been removed for compression.

In [3]:
# Folder with the replication datasets; entries whose name contains "f" are
# loaded (presumably one entry per fish — confirm against the dataset layout).
master_path = sample_data_path / "replication_portugues2011"
# Sort the matches: Path.glob yields entries in filesystem-dependent order, and
# the position in this list is what later identifies each fish.
paths = sorted(master_path.glob('*f*'))
exps = [Experiment(path) for path in paths]

In [4]:
def get_exp_stats(exp):
    """Compute per-trial bout statistics for a single experiment.

    Bouts are detected on the tail trace: the vigor is the rolling standard
    deviation of ``tail_sum`` over a window of 0.05 time units (presumably
    seconds — confirm), and bouts are the segments returned by
    ``extract_segments_above_thresh`` for a vigor threshold of 0.6 and a
    minimum duration of 0.1 time units. Trials are delimited by the steps of
    the ``"closed loop 1D_gain"`` stimulus parameter: an increase marks a
    trial start and a decrease marks a trial end. The first detected start
    and end are discarded.

    Parameters
    ----------
    exp : Experiment
        Object exposing ``behavior_log`` (with columns ``"t"`` and
        ``"tail_sum"``) and ``stimulus_param_log`` (with columns ``"t"`` and
        ``"closed loop 1D_gain"``) as DataFrames.

    Returns
    -------
    pandas.DataFrame
        One row per trial with columns ``start``, ``end``, ``gain``,
        ``inter_bout_t``, ``trial_idx``, ``bout_n``, ``bout_duration`` and
        ``first_bout_latency``. Bout statistics stay NaN for trials without
        any bout fully contained in them; ``inter_bout_t`` is only filled
        for trials with more than 3 such bouts.

    Notes
    -----
    ``trial_idx`` is built for exactly 6 x 18 = 108 trials, so the DataFrame
    construction fails if a different number of trials is detected.
    """
    tail_log_df = exp.behavior_log.set_index("t")  # DataFrame with the tail trace
    stim_log_df = exp.stimulus_param_log.set_index("t")
    tail_sum = tail_log_df["tail_sum"]

    tail_dt = np.diff(tail_log_df.index).mean()  # time step size in tail trace df

    # Recalculate vigor from tail trace (`.values` replaces `.as_matrix()`,
    # which was removed in pandas 1.0):
    vigor = tail_sum.rolling(int(0.05 / tail_dt), center=True).std().values

    # Extract bouts as pairs of positional indexes into the tail trace:
    bouts_idxs, _ = extract_segments_above_thresh(
        vigor, 0.6, min_duration=int(0.1 / tail_dt),
        pad_before=5, pad_after=5, skip_nan=True)
    # Calculate bouts start and end times:
    bout_starts = np.array([tail_sum.index[b[0]] for b in bouts_idxs])
    bout_ends = np.array([tail_sum.index[b[1]] for b in bouts_idxs])

    # Gain steps: positive where the gain rises (trial start), negative where
    # it falls (trial end). Computed once and reused for the three selections.
    gain = stim_log_df["closed loop 1D_gain"]
    gain_steps = np.ediff1d(gain, to_begin=0)

    # The first detected start/end pair is dropped with [1:].
    # NOTE(review): the original comment said "Exclude trailing trial of gain 1",
    # but the slice removes the *first* trial — confirm which one is intended.
    trial_s = stim_log_df.index[gain_steps > 0][1:]
    trial_e = stim_log_df.index[gain_steps < 0][1:]
    trial_g = gain[gain_steps > 0].values[1:]

    trial_df = pd.DataFrame(dict(start=trial_s, end=trial_e, gain=trial_g, inter_bout_t=np.nan,
                                 trial_idx=[i for i in range(6) for _ in range(18)],
                                 bout_n=np.nan, bout_duration=np.nan, first_bout_latency=np.nan),
                            index=np.arange(len(trial_s)))

    for i in range(len(trial_df)):
        # Bouts that start after the trial start and end before the trial end:
        bout_idxs = np.argwhere((bout_starts > trial_df.loc[i, "start"])
                                & (bout_ends < trial_df.loc[i, "end"]))[:, 0]
        if len(bout_idxs) > 0:
            trial_df.loc[i, "bout_n"] = len(bout_idxs)
            trial_df.loc[i, "first_bout_latency"] = bout_starts[bout_idxs[0]] - trial_df.loc[i, "start"]
            trial_df.loc[i, "bout_duration"] = (bout_ends[bout_idxs] - bout_starts[bout_idxs]).mean()

            if len(bout_idxs) > 3:
                # Mean pause between the end of one bout and the start of the next.
                trial_df.loc[i, "inter_bout_t"] = (bout_starts[bout_idxs[1:]] - bout_ends[bout_idxs[:-1]]).mean()

    return trial_df
                
def cum_stat(stat_list, key, n_reps=6, n_trials=18, n_baseline=3):
    """Average a per-trial statistic over repetitions and normalise it per fish.

    For every DataFrame in ``stat_list`` the column ``key`` is reshaped to
    ``(n_reps, n_trials)`` and NaN-averaged over the first axis. The resulting
    per-fish curves are then divided by their own NaN-mean over the first
    ``n_baseline`` trials.

    Parameters
    ----------
    stat_list : list of pandas.DataFrame
        One DataFrame per fish, each with ``n_reps * n_trials`` rows.
    key : str
        Name of the column to summarise.
    n_reps : int, optional
        Number of rows of the reshaped array (default 6).
    n_trials : int, optional
        Number of columns of the reshaped array (default 18).
    n_baseline : int, optional
        Number of leading trials used as normalisation baseline (default 3).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_trials, len(stat_list))``.
    """
    # `.values` replaces `.as_matrix()`, which was removed in pandas 1.0.
    per_fish = [
        np.nanmean(trial_df[key].values.reshape(n_reps, n_trials), 0)
        for trial_df in stat_list
    ]
    stat = np.array(per_fish).T
    return stat / np.nanmean(stat[:n_baseline, :], 0)

In [5]:
# Per-trial statistics table for every loaded experiment, in the order of `exps`.
trial_stats = list(map(get_exp_stats, exps))

# Make figure

In [6]:
def back_color():
    """Shade four vertical bands on the current matplotlib axes.

    Two bands are drawn in light grey (0.9) and two in darker grey (0.7);
    presumably they mark the trials of two of the gain conditions (confirm
    against the protocol).
    """
    light_grey = (0.9,) * 3
    dark_grey = (0.7,) * 3
    bands = [
        (-0.5, 2.5, light_grey),
        (8.5, 11.5, light_grey),
        (5.5, 8.5, dark_grey),
        (11.5, 14.5, dark_grey),
    ]
    for x_from, x_to, shade in bands:
        plt.axvspan(x_from, x_to, color=shade)

In [7]:
# Column of the per-trial tables that is summarised and plotted below.
k = 'inter_bout_t'
# Normalised statistic, one column per experiment.
stat = cum_stat(trial_stats, k)
# One label per column of `stat`.
labels = list(map("fish {}".format, range(stat.shape[1])))


Mean of empty slice



In [8]:
# Plotly imports for the interactive version of the figure.
# NOTE(review): `plotly.plotly` was moved to the separate `chart_studio` package
# in plotly 4, so this import presumably fails on recent plotly versions; `py`
# and `tools` are not used in the visible cells — confirm before removing.
import plotly.graph_objs as go
from plotly import tools
import plotly.plotly as py
from plotly.graph_objs import Figure
from plotly.offline import init_notebook_mode, plot, iplot
# Set up plotly's offline mode so figures can be rendered inside the notebook.
init_notebook_mode(connected=True)

In [9]:

# Background bands as (x start, x end, RGB grey level) in x-axis data coordinates.
# NOTE(review): the first band starts at -1.5, left of the x range used in the
# layout (0.5 to 19.5) — confirm that this is intended.
patches_x_cols = [(-1.5, 3.5, (229,)*3),
                  (9.5, 12.5, (229,)*3),
                  (6.5, 9.5, (178,)*3),
                  (12.5, 15.5,(178,)*3)]

# Borderless rectangles spanning the full plot height, drawn below the traces.
shapes = [
    dict(x0=x_start, x1=x_end, fillcolor="rgb{}".format(rgb),
         xref="x", y0=0, y1=1, yref="paper", layer='below', line=dict(width=0))
    for x_start, x_end, rgb in patches_x_cols
]

# One "gain ..." text label every three trials, placed near the top of the plot.
annotations = [
    dict(x=block * 3 + 2,
         y=0.95,
         xref='x',
         yref='paper',
         showarrow=False,
         text='gain {}'.format(gain))
    for block, gain in enumerate([1, 0.5, 1.5, 1, 1.5, 0.5])
]

In [26]:
# Trial numbers on the x axis are 1-based.
trial_numbers = np.arange(stat.shape[0]) + 1

# One thin grey line per fish (one column of `stat` each), without hover info.
data = [go.Scatter(x=trial_numbers, y=stat[:, i],
                   mode="lines", line=dict(color="rgb(100, 100, 100)", width=0.5),
                   hoverinfo='none'
                   ) for i in range(stat.shape[1])]

# Standard error of the mean across fish. `stat` can contain NaNs, so the
# number of fish is counted per trial from the non-NaN entries instead of
# using the total number of columns. Without NaNs this equals the previous
# nanstd(stat) / sqrt(n_fish - 1).
n_valid = np.sum(~np.isnan(stat), axis=1)
sem = np.nanstd(stat, axis=1, ddof=1) / np.sqrt(n_valid)

# Thick blue line: mean across fish with SEM error bars.
data.append(go.Scatter(
        x=trial_numbers,
        y=np.nanmean(stat, 1),
        error_y=dict(
            type='data',
            array=sem,
            visible=True),
        line=dict(color="rgb(76, 114, 176)", width=2)))

# Two buttons meant to toggle the x-axis grid. The 'nogrid' button previously
# also set showgrid=True, so both buttons did the same thing.
# NOTE(review): `updatemenus` is not passed to `layout` in the visible cells, so
# these buttons are currently not shown in the figure.
updatemenus = [
    dict(type="buttons",
         buttons=[
             dict(label='nogrid',
                  method='relayout',
                  args=['xaxis1', dict(showgrid=False)]),
             dict(label='grid',
                  method='relayout',
                  args=['xaxis1', dict(showgrid=True)]),
         ])
]


# Figure layout: fixed 800x600 canvas, transparent backgrounds, no legend, and
# fixed axis ranges without axis lines, zero lines or grid.
layout = {
    "autosize": True,
    "font": {"family": "Droid Sans, sans-serif", "size": 14},
    "height": 600,
    "width": 800,
    "annotations": annotations,
    "shapes": shapes,
    "showlegend": False,
    "title": "Inter-bout time with varying gains",
    "xaxis": {
        "autorange": False,
        "range": [0.5, 19.5],
        "type": "linear",
        "title": "Trial n.",
        # Ticks every three trials, starting at trial 2.
        "tick0": 2,
        "dtick": 3,
        "showline": False,
        "zeroline": False,
        "showgrid": False,
    },
    "yaxis": {
        "autorange": False,
        "range": [0.5, 1.5],
        "type": "linear",
        "title": "Relative inter-bout time",
        "showline": False,
        "zeroline": False,
        "showgrid": False,
    },
    "paper_bgcolor": 'rgba(0,0,0,0)',
    "plot_bgcolor": 'rgba(0,0,0,0)',
}

In [27]:
# Assemble the traces and the layout into a figure and render it inline.
fig = go.Figure(data=data, layout=layout)
iplot(fig)

In [119]:
# Write the figure to an HTML file without embedding plotly.js.
figure_html = r"C:\Users\lpetrucco\code\stytra\docs\figures\portugues2011_replication.html"
plot(fig, filename=figure_html, include_plotlyjs=False)

'file://C:\\Users\\lpetrucco\\code\\stytra\\docs\\figures\\portugues2011_replication.html'