In this notebook we explore the high variability in predictive accuracy that we see across timepoints and trials. Specifically, we ask whether other DAN or behavioral features correlate with how accurate a given prediction is. For example, it may be that when G4S predicts A well it does not predict B, and when it predicts B well it does not predict A.

# G4S prediction of speed

We first investigate what factors determine how well G4S predicts the fly's current walking speed.

In [1]:
import matplotlib.pyplot as plt
import numpy as np

from data import DataLoader
from db import make_session, d_models
import lin_fit
import os
from plot import shade, set_font_size

import CONFIG as C
import LOCAL as L
import PARAMS as P


# --- output locations -------------------------------------------------------
# Root directory for saved figures; each trial gets its own sub-directory.
SAVE_DIR = 'data_snapshots'
# File name used for the 4-panel "G4S vs. speed" figure of every trial.
SAVE_FILE = 'g4s_vs_speed.png'
# Figure size handed to plt.subplots (matplotlib interprets it as inches).
FIG_SIZE = (60, 15)

# --- regression windows -----------------------------------------------------
# Each dict maps a predictor name to a (start, end) window that is passed to
# lin_fit.regress. NOTE(review): the window units/semantics are defined by
# lin_fit, not here -- confirm before changing.
_WINDOW = (-1, 2)

# Predictors used when regressing walking speed.
WINDOWS_SPEED = dict(g4s=_WINDOW)
# Predictors (the other DANs) used when regressing G4S itself.
WINDOWS_G4S = {dan: _WINDOW for dan in ('g2s', 'g3s', 'g5s')}

In [2]:
# get all trials with labeled states, i.e. whose walking_threshold is set
# NOTE: `!= None` (rather than `is not None`) is deliberate -- the comparison
# is applied to the model column to build a query filter (this looks like a
# SQLAlchemy model; confirm against db.d_models).
FILT = [d_models.Trial.walking_threshold != None]  # noqa: E711

session = make_session()
try:
    trials = session.query(d_models.Trial).filter(*FILT).all()
finally:
    # release the session even if the query raises
    session.close()

### Make plots of true vs. predicted speed and true vs. predicted G4S

In [3]:
# loop through trials: for each trial, regress speed on G4S and G4S on the
# other DANs, then save a 4-panel figure (two fits + their two residuals)
for trial in trials:
    print('Loading data from trial "{}"'.format(trial.name))

    trial.dl = DataLoader(trial, 0, None)
    # mask of timepoints whose state label is 'W'; only these are passed to the
    # regression as valid (presumably 'W' = walking -- TODO confirm)
    valid = trial.dl.state == 'W'

    # make plots
    fig, axs = plt.subplots(4, 1, figsize=FIG_SIZE, tight_layout=True)

    # true speed vs. speed predicted from G4S
    rslt_speed = lin_fit.regress(
        trial, targ='speed', preds=WINDOWS_SPEED.keys(),
        windows=WINDOWS_SPEED, valid=valid)

    axs[0].plot(trial.dl.t, rslt_speed.ys, color='k', lw=2)
    axs[0].plot(trial.dl.t, rslt_speed.ys_pred, color='b', lw=2)
    # legend is created before the gray trace is added, so it labels only the
    # first two lines
    axs[0].legend(['True', 'Predicted'])

    # blank out the valid samples so that only the excluded ones are drawn,
    # faintly, as context
    speed_invalid = trial.dl.speed.copy()
    speed_invalid[valid] = np.nan
    axs[0].plot(trial.dl.t, speed_invalid, color='gray', alpha=0.3, lw=2)

    axs[0].set_ylabel('speed')
    axs[0].set_title('True speed vs. G4S-predicted speed')

    # true G4S vs. G4S predicted from other DANs
    rslt_g4s = lin_fit.regress(
        trial, targ='g4s', preds=WINDOWS_G4S.keys(),
        windows=WINDOWS_G4S, valid=valid)

    axs[1].plot(trial.dl.t, rslt_g4s.ys, color='k', lw=2)
    axs[1].plot(trial.dl.t, rslt_g4s.ys_pred, color='b', lw=2)
    axs[1].legend(['True', 'Predicted'])

    # same trick as above: show only the excluded G4S samples in gray
    g4s_invalid = trial.dl.g4s.copy()
    g4s_invalid[valid] = np.nan
    axs[1].plot(trial.dl.t, g4s_invalid, color='gray', alpha=0.3, lw=2)

    axs[1].set_ylabel('g4s')
    axs[1].set_title('True G4S vs. G4s predicted from other DANs')

    # true speed minus predicted speed
    axs[2].plot(trial.dl.t, rslt_speed.ys-rslt_speed.ys_pred, color='k', lw=2)
    axs[2].axhline(0, color='gray', alpha=0.5, ls='--')
    axs[2].set_ylabel('diff')
    axs[2].set_title('True speed minus G4S-predicted speed')

    # true G4S minus predicted G4S
    axs[3].plot(trial.dl.t, rslt_g4s.ys-rslt_g4s.ys_pred, color='k', lw=2)
    axs[3].axhline(0, color='gray', alpha=0.5, ls='--')
    axs[3].set_ylabel('diff')
    axs[3].set_title('True G4S minus G4S predicted from other DANs')

    # shared formatting: every panel spans the full trial duration
    for ax in axs:
        ax.set_xlim(trial.dl.t[0], trial.dl.t[-1])
        ax.grid()

        ax.set_xlabel('Time (s)')
        set_font_size(ax, 16)

    # save figure under <SAVE_DIR>/<fly>/<trial name> (<expt>)/
    save_dir = os.path.join(SAVE_DIR, trial.fly, '{} ({})'.format(trial.name, trial.expt))
    # savefig does not create missing directories, so make sure the target
    # exists instead of failing on a fresh checkout
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, SAVE_FILE)

    fig.savefig(save_path)
    # close explicitly: one large figure per trial would otherwise accumulate
    plt.close(fig)

Loading data from trial "20170310.Fly2.6"
Loading clean data from file "clean_0.csv"...


  if valid == 'all':
  elif valid == 'none':


Loading data from trial "20170310.Fly3.1"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170310.Fly3.2"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly1.1"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly1.2"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly1.3"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly1.4"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly1.8"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly3.1"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly3.2"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly3.3"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170330.Fly1.1"
Loading clean data from file "clean_0.csv"...
Loading data from trial "2017033

### Make plots of true vs. predicted speed, and absolute air tube angle

In [7]:
# loop through trials: for each trial, regress speed on G4S and save a 3-panel
# figure (fit, residual, absolute air tube angle)
for trial in trials:
    print('Loading data from trial "{}"'.format(trial.name))

    trial.dl = DataLoader(trial, 0, None)
    # mask of timepoints whose state label is 'W'; only these are passed to the
    # regression as valid (presumably 'W' = walking -- TODO confirm)
    valid = trial.dl.state == 'W'

    # make plots
    fig, axs = plt.subplots(3, 1, figsize=(60, 12), tight_layout=True)

    # true speed vs. speed predicted from G4S
    rslt_speed = lin_fit.regress(
        trial, targ='speed', preds=WINDOWS_SPEED.keys(),
        windows=WINDOWS_SPEED, valid=valid)

    axs[0].plot(trial.dl.t, rslt_speed.ys, color='k', lw=2)
    axs[0].plot(trial.dl.t, rslt_speed.ys_pred, color='b', lw=2)
    # legend is created before the gray trace is added, so it labels only the
    # first two lines
    axs[0].legend(['True', 'Predicted'])

    # blank out the valid samples so that only the excluded ones are drawn,
    # faintly, as context
    speed_invalid = trial.dl.speed.copy()
    speed_invalid[valid] = np.nan
    axs[0].plot(trial.dl.t, speed_invalid, color='gray', alpha=0.3, lw=2)

    axs[0].set_ylabel('speed')
    axs[0].set_title('True speed vs. G4S-predicted speed')

    # true speed minus predicted speed
    axs[1].plot(trial.dl.t, rslt_speed.ys-rslt_speed.ys_pred, color='k', lw=2)
    axs[1].axhline(0, color='gray', alpha=0.5, ls='--')
    axs[1].set_ylabel('diff')
    axs[1].set_title('True speed minus G4S-predicted speed')

    # air tube: split the trace into valid / invalid parts so they can be drawn
    # in different colors (each copy has the other part set to NaN)
    air_valid = trial.dl.air.copy()
    air_valid[~valid] = np.nan

    air_invalid = trial.dl.air.copy()
    air_invalid[valid] = np.nan

    axs[2].plot(trial.dl.t, np.abs(air_valid), color='k', lw=2)
    axs[2].plot(trial.dl.t, np.abs(air_invalid), color='gray', alpha=0.3, lw=2)
    axs[2].set_ylabel('angle (deg)')
    axs[2].set_title('Absolute air tube angle')

    # shared formatting: every panel spans the full trial duration
    for ax in axs:
        ax.set_xlim(trial.dl.t[0], trial.dl.t[-1])
        ax.grid()

        ax.set_xlabel('Time (s)')
        set_font_size(ax, 16)

    # save figure under <SAVE_DIR>/<fly>/<trial name> (<expt>)/
    save_dir = os.path.join(SAVE_DIR, trial.fly, '{} ({})'.format(trial.name, trial.expt))
    # savefig does not create missing directories, so make sure the target
    # exists instead of failing on a fresh checkout
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, 'g4s_vs_speed_1.png')

    fig.savefig(save_path)
    # close explicitly: one large figure per trial would otherwise accumulate
    plt.close(fig)

Loading data from trial "20170310.Fly2.6"
Loading clean data from file "clean_0.csv"...


  if valid == 'all':
  elif valid == 'none':


Loading data from trial "20170310.Fly3.1"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170310.Fly3.2"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly1.1"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly1.2"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly1.3"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly1.4"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly1.8"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly3.1"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly3.2"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170313.Fly3.3"
Loading clean data from file "clean_0.csv"...
Loading data from trial "20170330.Fly1.1"
Loading clean data from file "clean_0.csv"...
Loading data from trial "2017033