# Take snapshots of the time-series for each trial and save them to disk as PNG images

In [None]:
%reset -f
import json
import os

import matplotlib.pyplot as plt
import numpy as np

from data import DataLoader, find_segs
from db import make_session, d_models
from plot import shade, set_font_size

import CONFIG as C


# Root directory for all PNG snapshots (one sub-directory per fly, then per trial).
SAVE_DIR = 'data_snapshots'

# Base file names of the two standard snapshot images written per trial.
LABELS = ('behav', 'gcamp')

# Each inner list is one subplot; every name is read from the trial's DataLoader.
BEHAV_VARIABLES = [[name] for name in ('speed', 'ball', 'v_ang', 'air')]
GCAMP_VARIABLES = [
    [name] for name in (
        'g2s', 'g3s', 'g4s', 'g5s',
        'g2d', 'g3d', 'g4d', 'g5d',
    )
]

In [None]:
# Fetch every trial from the database, then attach a DataLoader to each one.
session = make_session()
trials = session.query(d_models.Trial).all()
session.close()

n_trials = len(trials)
for idx, trial in enumerate(trials, start=1):
    print('Loading data for trial {}/{} ({})'.format(idx, n_trials, trial.name))
    trial.dl = DataLoader(trial, sfx='0', vel_filt=None)

### Define a function that saves a trial's time-series to PNG

In [None]:
def _as_var_list(vs):
    """Return `vs` as a list of variable names (a bare name becomes [name])."""
    # str also defines __iter__, so test for it explicitly; otherwise a bare
    # name like 'speed' would be split into its characters
    if isinstance(vs, str) or not hasattr(vs, '__iter__'):
        return [vs]
    return list(vs)


def _draw_trace(ax, ts, values, var, color, lw, y_lims, y_tick_spacings):
    """Plot one variable on `ax` and apply its y-label, x/y limits and y-ticks."""
    ax.plot(ts, values, color=color, lw=lw)
    ax.set_ylabel(var, color=color)
    ax.set_xlim(ts[0], ts[-1])

    if var in y_lims:
        ax.set_ylim(y_lims[var])

    if var in y_tick_spacings:
        # ticks run from the current lower limit up to (excluding) the upper
        # limit, so they respect any limit set just above
        ax.set_yticks(np.arange(*ax.get_ylim(), y_tick_spacings[var]))


def trial_to_png(
    save_file, trial, variables, width=30, ax_height=3,
    lw=2, colors=None, y_lims=None, y_tick_spacings=None,
    shading=None):
    """
    Save the complete time-series data for several quantities for
    a trial to a single PNG image for efficient viewing.

    One subplot is drawn per element of `variables`; the first variable of
    an element goes on the left y-axis and an optional second variable on a
    twin right y-axis.

    :param save_file: path of the PNG to write; its parent directory is
        created if it does not exist
    :param trial: object whose `dl` attribute provides the time vector `t`
        and one attribute per variable name
    :param variables: list whose elements are either a single variable name
        or a list/tuple of 1 or 2 variable names (one element per subplot)
    :param width: figure width passed to matplotlib's figsize
    :param ax_height: per-subplot height passed to matplotlib's figsize
    :param lw: line width of the traces
    :param colors: dict mapping variable name -> color; unlisted variables
        are drawn in black. The caller's dict is not modified.
    :param y_lims: dict mapping variable name -> y-axis limits
    :param y_tick_spacings: dict mapping variable name -> y-tick spacing
    :param shading: shading to apply on top of time-series traces; dict
        where keys are rgbas and vals are segments to shade
    :return: save_file
    :raises ValueError: if `variables` is empty or any element has fewer
        than 1 or more than 2 variables
    """
    # copy so that filling in default colors never mutates the caller's dict
    colors = dict(colors) if colors else {}
    y_lims = y_lims or {}
    y_tick_spacings = y_tick_spacings or {}
    shading = shading or {}

    # make sure each element of variables is a list of variable names
    variables = [_as_var_list(vs) for vs in variables]

    if not variables:
        raise ValueError('At least one axis of variables is required.')
    if any(len(vs) > 2 for vs in variables):
        raise ValueError('At most 2 variables are allowed per axis.')
    if any(len(vs) == 0 for vs in variables):
        raise ValueError('At least 1 variable is required per axis.')

    # make sure save directory exists (dirname is '' for a bare file name)
    save_dir = os.path.dirname(save_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # get quantities from trial
    ts = trial.dl.t
    unique = {v for vs in variables for v in vs}
    for v in unique:
        colors.setdefault(v, 'k')
    data = {v: getattr(trial.dl, v) for v in unique}

    # arrange figure
    fig, axs = plt.subplots(
        len(variables), 1, figsize=(width, ax_height * len(variables)),
        tight_layout=True, squeeze=False)

    try:
        for ax, vs in zip(axs[:, 0], variables):
            ax.set_xlabel('t (s)')

            primary = vs[0]
            _draw_trace(
                ax, ts, data[primary], primary, colors[primary], lw,
                y_lims, y_tick_spacings)
            ax.grid()

            for rgba, segs in shading.items():
                shade(ax, segs, rgba)

            set_font_size(ax, 16)

            if len(vs) == 2:
                secondary = vs[1]
                ax_twin = ax.twinx()
                _draw_trace(
                    ax_twin, ts, data[secondary], secondary,
                    colors[secondary], lw, y_lims, y_tick_spacings)
                set_font_size(ax_twin, 16)

        # save png
        fig.savefig(save_file)
    finally:
        # always release the figure, even if drawing or saving fails
        plt.close(fig)

    return save_file

### Basic behavioral and gcamp variables

In [None]:
# For every trial, write one PNG per variable group (behavioral, GCaMP).
for trial in trials:
    print('Saving data snapshot for trial "{}"...'.format(trial.name))
    trial_dir = os.path.join(
        SAVE_DIR, trial.fly, '{} ({})'.format(trial.name, trial.expt))

    for label, var_groups in zip(LABELS, (BEHAV_VARIABLES, GCAMP_VARIABLES)):
        png_path = os.path.join(trial_dir, label + '.png')
        trial_to_png(png_path, trial, var_groups)
        print('Snapshot saved at "{}".'.format(png_path))

### Close-up of walking speed (shaded by behavioral state when a walking threshold is set)

In [None]:
session = make_session()
trials = session.query(d_models.Trial).all()
session.close()


def _state_shading(dl):
    """
    Build a `shading` dict for trial_to_png from a DataLoader's `state` array.

    Keys are translucent colors for state labels 'A' (black), 'P' (red) and
    'W' (green). Each value has the shape of find_segs' output; its columns
    0 and 1 hold the times at find_segs' first and second index columns
    (presumably segment start/end -- confirm against find_segs).
    """
    states = dl.state
    # extend the time vector by one step of C.DT; presumably so an end index
    # equal to len(dl.t) still maps to a time -- confirm against find_segs
    t_ext = np.concatenate([dl.t, [dl.t[-1] + C.DT]])

    state_colors = (
        ('A', (0, 0, 0, 0.2)),
        ('P', (1, 0, 0, 0.2)),
        ('W', (0, 1, 0, 0.2)),
    )

    result = {}
    for state, rgba in state_colors:
        seg_idx = find_segs(states == state)
        seg_times = np.full(seg_idx.shape, np.nan)
        seg_times[:, 0] = t_ext[seg_idx[:, 0]]
        seg_times[:, 1] = t_ext[seg_idx[:, 1]]
        result[rgba] = seg_times
    return result


for trial in trials:
    trial.dl = DataLoader(trial, sfx='0', vel_filt=None)

    # shade by behavioral state only when the trial has a walking threshold
    if trial.walking_threshold is None:
        shading = None
    else:
        shading = _state_shading(trial.dl)

    print('Saving data snapshot for trial "{}"...'.format(trial.name))
    save_path = os.path.join(
        SAVE_DIR, trial.fly, '{} ({})'.format(trial.name, trial.expt),
        'speed.png')
    trial_to_png(
        save_path, trial, [['speed']], width=60, ax_height=8,
        y_lims={'speed': [0, None]}, y_tick_spacings={'speed': 0.005},
        shading=shading)

    print('Snapshot saved at "{}".'.format(save_path))

### Speed overlaid with GS

In [None]:
# One subplot per G*S signal, each overlaid on walking speed (twin y-axis).
# NOTE(review): these names are upper-case ('G2S') while GCAMP_VARIABLES uses
# 'g2s'; trial_to_png looks names up with getattr, which is case-sensitive --
# confirm which spelling DataLoader exposes.
gs_names = ['G{}S'.format(i) for i in range(2, 6)]
vs = [['speed', name] for name in gs_names]

colors = {'speed': 'k'}
colors.update(dict.fromkeys(gs_names, 'm'))

for trial in trials:
    print('Saving data snapshot for trial "{}"...'.format(trial.name))
    out_path = os.path.join(
        SAVE_DIR, trial.fly, '{} ({})'.format(trial.name, trial.expt),
        'speed_vs_GS.png')
    trial_to_png(out_path, trial, variables=vs, colors=colors)

    print('Snapshot saved at "{}".'.format(out_path))

### Write a JS file listing every trial's snapshot directory and the plot (image) names, for the viewer

In [None]:
TRIAL_FILE = 'js_viewer/trials.js'


def _visible_subdirs(parent):
    """Names of non-hidden directories directly inside `parent`, in listdir order."""
    return [
        name for name in os.listdir(parent)
        if not name.startswith('.') and os.path.isdir(os.path.join(parent, name))
    ]


# SAVE_DIR/<fly>/<trial> -> parallel lists of trial directory names and paths
sub_dirs = [os.path.join(SAVE_DIR, name) for name in _visible_subdirs(SAVE_DIR)]

trial_keys = []
trial_paths = []
for fly_dir in sub_dirs:
    for key in _visible_subdirs(fly_dir):
        trial_keys.append(key)
        trial_paths.append(os.path.join(fly_dir, key))

In [None]:
# Write the viewer's index: a `trials` object mapping trial directory name ->
# snapshot directory, plus the list of plot (PNG base) names.
# json.dumps produces properly quoted and escaped string literals, so names
# containing quotes or backslashes still yield valid JavaScript; paths are
# normalized to '/' so they work as relative URLs on every OS.
trial_file_dir = os.path.dirname(TRIAL_FILE)
if trial_file_dir:
    os.makedirs(trial_file_dir, exist_ok=True)

with open(TRIAL_FILE, 'w', encoding='utf-8') as f:
    f.write('var trials = {};\n')

    for tk, tp in zip(trial_keys, trial_paths):
        f.write('trials[{}] = {};\n'.format(
            json.dumps(tk), json.dumps(tp.replace(os.sep, '/'))))

    f.write('\n')
    f.write('var plot_types = [\n')

    # NOTE(review): 'speed_vs_GD' is listed, but no cell in this notebook
    # writes speed_vs_GD.png, while speed.png is written but not listed --
    # confirm what the viewer expects.
    for plot_type in ('behav', 'gcamp', 'speed_vs_GS', 'speed_vs_GD'):
        f.write('  {},\n'.format(json.dumps(plot_type)))
    f.write(']\n')