# Take snapshots of the time-series data for each trial and save them to disk

In [1]:
%reset -f
import os
import matplotlib.pyplot as plt
import numpy as np

from data import DataLoader, find_segs
from db import make_session, d_models
from plot import shade, set_font_size

import CONFIG as C


# Root directory under which every snapshot PNG is written.
SAVE_DIR = 'data_snapshots'
# Output-file labels; zipped in order with the variable groups below.
LABELS = ('behav', 'gcamp')
# Data-loader attributes drawn in the "behav" snapshot.
BEHAV_VARIABLES = ['speed', 'ball_speed', 'v_ang', 'air_tube']
# Data-loader attributes drawn in the "gcamp" snapshot.
GCAMP_VARIABLES = [
    'G2S', 'G3S', 'G4S', 'G5S',
    'G2D', 'G3D', 'G4D', 'G5D',
]

### Define a function that saves a trial's time-series data to PNG

In [2]:
def trial_to_png(
    save_file, trial, variables, width=30, ax_height=3,
    lw=2, colors=None, y_lims=None, y_tick_spacings=None,
    shading=None):
    """
    Save the complete time-series data for several quantities for
    a trial to a single PNG image for efficient viewing.

    One subplot per variable is stacked vertically; every trace is plotted
    against the data loader's ``timestamp_gcamp``.

    :param save_file: path of the PNG to write; its parent directory is
        created if it does not exist yet
    :param trial: trial object; if it already has a ``dl`` attribute that
        loader is reused, otherwise ``DataLoader(trial, vel_filt=None)``
        is built
    :param variables: names of data-loader attributes to plot
    :param width: figure width passed to matplotlib's ``figsize``
    :param ax_height: per-subplot height passed to matplotlib's ``figsize``
    :param lw: line width of the traces
    :param colors: dict mapping variable name to line color; variables not
        present are drawn in black ('k'). The dict is not modified.
    :param y_lims: dict mapping variable name to limits for ``ax.set_ylim``
    :param y_tick_spacings: dict mapping variable name to the spacing
        between y-ticks
    :param shading: shading to apply on top of time-series traces; dict
        where keys are rgbas and vals are segments to shade
    :return: ``save_file``
    :raises KeyError: if a requested variable is not an attribute of the
        trial's data loader
    """
    colors = {} if colors is None else colors
    y_lims = {} if y_lims is None else y_lims
    y_tick_spacings = {} if y_tick_spacings is None else y_tick_spacings
    shading = {} if shading is None else shading

    # make sure save directory exists; dirname is '' for a bare filename,
    # and os.makedirs('') would raise, so only create real directories
    save_dir = os.path.dirname(save_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # reuse a loader the caller attached to the trial, if any
    try:
        dl = trial.dl
    except AttributeError:
        dl = DataLoader(trial, vel_filt=None)

    # get quantities from trial
    ts = dl.timestamp_gcamp

    data = {}
    for v in variables:
        try:
            data[v] = getattr(dl, v)
        except AttributeError as err:
            raise KeyError(
                'Variable "{}" not found for trial.'.format(v)) from err

    # arrange figure; squeeze=False keeps axs 2-D even for one variable
    fig_size = (width, ax_height*len(variables))
    fig, axs = plt.subplots(
        len(variables), 1, figsize=fig_size, tight_layout=True, squeeze=False)

    try:
        for ax, v in zip(axs[:, 0], variables):
            ax.plot(ts, data[v], color=colors.get(v, 'k'), lw=lw)
            ax.grid()

            ax.set_xlabel('t (s)')
            ax.set_ylabel(v)

            ax.set_xlim(ts[0], ts[-1])

            if v in y_lims:
                ax.set_ylim(y_lims[v])

            # ticks start at the bottom limit; np.arange excludes the top
            if v in y_tick_spacings:
                ax.set_yticks(np.arange(
                    *ax.get_ylim(), y_tick_spacings[v]))

            # add shadings if desired
            for rgba, segs in shading.items():
                shade(ax, segs, rgba)

            set_font_size(ax, 16)

        # save png
        fig.savefig(save_file)
    finally:
        # release the figure even if plotting or saving fails
        plt.close(fig)

    return save_file

### Basic behavioral and GCaMP variables

In [None]:
# Query all trials, closing the session even if the query fails.
session = make_session()
try:
    trials = session.query(d_models.Trial).all()
finally:
    session.close()

for trial in trials:
    print('Saving data snapshot for trial "{}"...'.format(trial.name))
    save_dir = os.path.join(
        SAVE_DIR, trial.fly,
        '{} ({})'.format(trial.name, trial.expt))

    # Load the trial's data once and share it between both snapshots:
    # trial_to_png reuses ``trial.dl`` when present instead of building
    # a new DataLoader on every call.
    trial.dl = DataLoader(trial, vel_filt=None)

    for label, vs in zip(LABELS, [BEHAV_VARIABLES, GCAMP_VARIABLES]):

        save_path = os.path.join(save_dir, '{}.png'.format(label))
        trial_to_png(save_path, trial, vs)

        print('Snapshot saved at "{}".'.format(save_path))

    # Drop the loader so this trial's data is not kept alive by ``trials``.
    del trial.dl

### Close-up of walking speed

In [None]:
# Query all trials, closing the session even if the query fails.
session = make_session()
try:
    trials = session.query(d_models.Trial).all()
finally:
    session.close()

for trial in trials:
    # Attach the loader to the trial so trial_to_png reuses it.
    trial.dl = DataLoader(trial, vel_filt=None)

    # Shade state segments only for trials with a walking threshold set.
    # find_segs output is scaled by C.DT, presumably converting sample
    # indices to seconds to match the time axis -- TODO confirm.
    if trial.walking_threshold is not None:
        states = trial.dl.states
        shading = {
            (0, 0, 0, 0.2): find_segs(states == 'A')*C.DT,
            (1, 0, 0, 0.2): find_segs(states == 'P')*C.DT,
            (0, 1, 0, 0.2): find_segs(states == 'W')*C.DT,
        }
    else:
        shading = None

    print('Saving data snapshot for trial "{}"...'.format(trial.name))
    save_dir = os.path.join(
        SAVE_DIR, trial.fly,
        '{} ({})'.format(trial.name, trial.expt))

    save_path = os.path.join(save_dir, 'speed.png')
    trial_to_png(
        save_path, trial, ['speed'], width=60, ax_height=8,
        y_lims={'speed': [0, None]}, y_tick_spacings={'speed': 0.005},
        shading=shading)

    # Drop the loader so loaded data does not accumulate across all trials
    # (``trials`` keeps every trial object alive for the whole loop).
    del trial.dl

    print('Snapshot saved at "{}".'.format(save_path))

Saving data snapshot for trial "20160904.Fly4.2"...
Snapshot saved at "data_snapshots/20160904.Fly4/20160904.Fly4.2 (closed_loop)/speed.png".
Saving data snapshot for trial "20160904.Fly4.3"...
Snapshot saved at "data_snapshots/20160904.Fly4/20160904.Fly4.3 (closed_loop)/speed.png".
Saving data snapshot for trial "20160904.Fly4.6"...
Snapshot saved at "data_snapshots/20160904.Fly4/20160904.Fly4.6 (closed_loop)/speed.png".
Saving data snapshot for trial "20160904.Fly4.7"...
Snapshot saved at "data_snapshots/20160904.Fly4/20160904.Fly4.7 (closed_loop)/speed.png".
Saving data snapshot for trial "20160904.Fly5.3"...
Snapshot saved at "data_snapshots/20160904.Fly5/20160904.Fly5.3 (closed_loop)/speed.png".
Saving data snapshot for trial "20160904.Fly5.4"...
Snapshot saved at "data_snapshots/20160904.Fly5/20160904.Fly5.4 (closed_loop)/speed.png".
Saving data snapshot for trial "20160904.Fly6.1"...
Snapshot saved at "data_snapshots/20160904.Fly6/20160904.Fly6.1 (closed_loop)/speed.png".
Saving