In [1]:
%matplotlib notebook
from copy import deepcopy
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
from sklearn import linear_model

from disp import set_font_size


DATA_DIR = 'data/CL_360_LOWFLOW_ACV'

# Single trial examples

## Print trial list

In [2]:
TRIAL_PATHS = {}

for fly in os.listdir(DATA_DIR):
    for trial in os.listdir(os.path.join(DATA_DIR, fly)):
        
        trial_path = os.path.join(fly, trial)
        TRIAL_PATHS[trial] = trial_path
        
        print('\'{}\''.format(trial))
        
    print('')
    
del trial_path

'20180119.Fly9.2'

'20180201.Fly3.2'
'20180201.Fly3.4'

'20180131.Fly1.3'
'20180131.Fly1.2'

'20180120.Fly1.2'
'20180120.Fly1.1'
'20180120.Fly1.3'

'20170826.Fly4.1'

'20170828.Fly2.2'
'20170828.Fly2.6'

'20170822.Fly1.1'

'20170828.Fly4.3'

'20180119.Fly10.1'

'20180125.Fly6.1'

'20170826.Fly7.2'

'20180131.Fly9.3'

'20180124.Fly4.2'
'20180124.Fly4.3'

'20180131.Fly8.1'
'20180131.Fly8.2'
'20180131.Fly8.4'
'20180131.Fly8.3'

'20180119.Fly7.2'

'20180131.Fly7.2'
'20180131.Fly7.3'

'20170821.Fly1.1'

'20180131.Fly4.3'
'20180131.Fly4.2'

'20180125.Fly3.3'
'20180125.Fly3.4'



## List time series columns

In [3]:
COLUMNS = []

data = pd.read_csv(os.path.join(DATA_DIR, list(TRIAL_PATHS.values())[0], 'clean.csv'))

for column in data.columns:
    COLUMNS.append(column)
    print(column)
    
del column
del data

Time
xpos
ypos
ForVel
AngVel
LatVel
Motion
Heading
ForVel_Conv
AngVel_Conv
LatVel_Conv
Motion_Conv
G2_Red_R
G3_Red_R
G4_Red_R
G5_Red_R
G2_Red_L
G3_Red_L
G4_Red_L
G5_Red_L
G2_Green_R
G3_Green_R
G4_Green
G5_Green_R
G2_Green_L
G3_Green_L
G4_Green_L
G5_Green_L
G2_R
G3_R
G4_R
G5_R
G2_L
G3_L
G4_L
G5_L
G2_avg
G3_avg
G4_avg
G5_avg


## Plot example multivariate time series

### Define plot fn

In [4]:
def plot_vars(trial, t_lim, *vs, **kwargs):
    """Plot variables for single trial."""
    
    cols = [v[0] for v in vs]  # columns
    cs = [v[1] for v in vs]  # colors
    
    for col in cols:
        assert col in COLUMNS

    # load data
    data = pd.read_csv(os.path.join(DATA_DIR, TRIAL_PATHS[trial], 'clean.csv'))
    t = data['Time']
    
    # load odor
    odor = pd.read_csv(os.path.join(DATA_DIR, TRIAL_PATHS[trial], 'odor_times.csv'))
    
    fig_size = (9, 2*len(vs))
    fig, axs = plt.subplots(len(vs), 1, figsize=fig_size, tight_layout=True, sharex=True, squeeze=False)
    
    title = '{} - {}'.format(kwargs['title'], trial) if 'title' in kwargs else trial
    axs[0, 0].set_title(title)
    axs[-1, 0].set_xlabel('Time (s)')
    
    for col, c, ax in zip(cols, cs, axs[:, 0]):
        
        # time-series var
        ax.plot(t, data[col], color=c, lw=1)
        
        # odor
        for p_ctr in range(len(odor)):
            
            pulse = odor.iloc[p_ctr]
            
            p_start = pulse['Odor_On']
            p_end = pulse['Odor_Off']
            
            p_c = 'r' if pulse['Include'] else 'k'
            
            ax.axvspan(p_start, p_end, color=p_c, alpha=0.2)
            
        ax.set_xlim(t_lim)
        
        ax.set_ylabel(col)
        
        ax.grid()
        
        set_font_size(ax, 16)
        
    return fig, axs

In [5]:
plot_vars(
    '20180124.Fly4.3',
    (0, 300),
    ('ForVel', 'k'),
    ('Heading', 'r'),
    ('G4_avg', 'g'),
    title='example tracking',
);

<IPython.core.display.Javascript object>

In [6]:
plot_vars(
    '20180131.Fly8.2',
    (85, 115),
    ('ForVel', 'k'),
    ('Heading', 'r'),
    title='example tracking',
);

<IPython.core.display.Javascript object>

# Odor-triggered metrics

## Define metrics

In [7]:
def calc_metrics(trial, data, pulse):
    """Calculate metrics for each odor pulse."""
    
    p_on = pulse['Odor_On']
    p_off = pulse['Odor_Off']
    
    rslt = {}
    
    rslt['Trial'] = trial
    rslt['On'] = p_on
    rslt['Off'] = p_off
    rslt['Include'] = 'Include' if pulse['Include'] else 'Exclude'
    
    # Preliminary
    t = np.array(data['Time'])
    
    ###### METRIC ######
    # dF/F G4_avg t-avg'd from 1 to 4 s after odor onset
    t_mask_0 = (p_on-15 <= t) & (t < p_on)
    f_0 = np.nanmean(np.array(data['G4_avg'][t_mask_0]))
    
    t_mask = (p_on+1 <= t) & (t < p_on+4)
    rslt['d_g4_1_4'] = (np.nanmean(np.array(data['G4_avg'])[t_mask]) - f_0) / f_0
    
    del t_mask_0, f_0, t_mask
    
    ###### METRIC ######
    # abs heading t-avg'd from 5 to 0 s before odor pulse
    t_mask = (p_on-5 <= t) & (t < p_on)
    rslt['h_neg5_0'] = np.nanmean(np.abs(data['Heading'][t_mask]))
    
    del t_mask
    
    ###### METRIC ######
    # change in abs heading from h_neg5_0 t-avgd from 1 to 4 s after odor pulse
    t_mask = (p_on+1 <= t) & (t < p_on+4)
    rslt['d_h_1_4'] = np.nanmean(np.abs(data['Heading'][t_mask])) - rslt['h_neg5_0']
    
    del t_mask
    
    ###### METRIC ######
    # abs heading t-avgd from 6 to 9 s after odor pulse
    t_mask = (p_on+6 <= t) & (t < p_on+9)
    rslt['d_h_6_9'] = np.nanmean(np.abs(data['Heading'][t_mask])) - rslt['h_neg5_0']
    
    del t_mask
    
    return rslt

# Nested model analysis

## Compute all metrics

In [8]:
metrics_dicts = []

for trial, trial_path in TRIAL_PATHS.items():
    
    data = pd.read_csv(os.path.join(DATA_DIR, TRIAL_PATHS[trial], 'clean.csv'))
    odor = pd.read_csv(os.path.join(DATA_DIR, TRIAL_PATHS[trial], 'odor_times.csv'))
    
    # loop over odor pulses
    for p_ctr in range(len(odor)):
        pulse = odor.iloc[p_ctr]
        
        metrics_dicts.append(calc_metrics(trial, data, pulse))

metrics = pd.DataFrame.from_records(metrics_dicts)
metrics.index.name = 'Pulse'

# re-order columns
cols = list(metrics.columns)
cols = [cols.pop(cols.index(col)) for col in ['Trial', 'On', 'Off', 'Include']] + cols
metrics = metrics[cols]

# add control vars to test stats analysis
x_c_1, x_c_2, x_c_3, x_c_4 = np.random.normal(0, 1, (4, len(metrics)))
y_c = x_c_1 + x_c_2 + x_c_3 + np.random.normal(0, 1, len(metrics))

metrics['x_c_1'] = x_c_1
metrics['x_c_2'] = x_c_2
metrics['x_c_3'] = x_c_3
metrics['x_c_4'] = x_c_4
metrics['y_c'] = y_c

metrics



Unnamed: 0_level_0,Trial,On,Off,Include,d_g4_1_4,d_h_1_4,d_h_6_9,h_neg5_0,x_c_1,x_c_2,x_c_3,x_c_4,y_c
Pulse,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,20180119.Fly9.2,60,70,Include,0.469526,-1.046928,0.414178,1.500978,-1.844741,-0.858629,0.356305,0.509369,-1.894292
1,20180119.Fly9.2,100,110,Include,0.071209,0.345387,0.256369,1.833332,0.133960,1.150264,-2.918795,1.218657,-0.652574
2,20180119.Fly9.2,140,150,Exclude,0.087130,0.666131,-0.422483,1.690513,0.961047,0.847334,0.870729,1.081745,3.094463
3,20180119.Fly9.2,180,190,Exclude,0.349216,-0.522086,0.377496,1.345519,-0.749247,-0.464221,-1.390984,-0.455758,-2.121499
4,20180119.Fly9.2,220,230,Include,0.119535,-0.189434,0.054875,1.534802,0.133537,-0.518112,-1.658388,-1.486008,-2.267963
5,20180201.Fly3.2,60,70,Exclude,0.642802,0.146048,-1.569773,2.573375,-0.204353,-0.776904,1.168909,0.502296,1.119294
6,20180201.Fly3.2,100,110,Include,0.645862,-0.753570,-0.588931,1.694850,1.000808,-0.421768,0.296028,-1.382557,0.893828
7,20180201.Fly3.2,140,150,Exclude,0.522623,0.335778,0.409868,1.117447,1.717207,-0.189610,-0.184568,-0.981201,1.995097
8,20180201.Fly3.2,180,190,Include,0.034480,-0.604111,-1.483817,2.252089,-0.462758,-0.064607,-1.075220,1.561349,-1.425166
9,20180201.Fly3.2,220,230,Include,0.333569,1.316419,0.086489,0.734761,-0.103148,0.798635,-0.192488,0.257536,-0.125392


## Specify and fit models

In [11]:
COLS_X_RED = ['h_neg5_0', 'd_h_1_4']
COLS_X_FULL = ['h_neg5_0', 'd_h_1_4', 'd_h_6_9']
#COLS_X_RED = ['h_neg5_0']
#COLS_X_FULL = ['h_neg5_0', 'd_h_6_9']
COL_Y = 'd_g4_1_4'

# UNCOMMENT TO RUN CONTROL ANALYSIS
#COLS_X_RED = ['x_c_1', 'x_c_2']
#COLS_X_FULL = ['x_c_1', 'x_c_2', 'x_c_3']
#COLS_X_FULL = ['x_c_1', 'x_c_2', 'x_c_4']
#COL_Y = 'y_c'

COLS_ALL = list(set(COLS_X_RED + COLS_X_FULL + [COL_Y]))

for col in COLS_ALL:
    assert col in metrics
assert COL_Y not in COLS_X_RED
assert COL_Y not in COLS_X_FULL

# helper function
def fit_line(x, y):
    slp, icpt = stats.linregress(x, y)[:2]
    x_line = np.array([x.min(), x.max()])
    y_line = slp*x_line + icpt
    
    return x_line, y_line

# filter data
metrics_filt = metrics[COLS_ALL][metrics['Include'] == 'Include'].dropna()

# fit nested models and plot predictions
fig, axs = plt.subplots(1, 2, figsize=(9, 4.5), sharex=True, sharey=True, tight_layout=True)

## targ
y = metrics_filt[COL_Y]

## reduced

### data
x_red = metrics_filt[COLS_X_RED]

### fit
m_red = linear_model.LinearRegression().fit(x_red, y)

### prediction and R2
y_hat_red = m_red.predict(x_red)
r2_red = m_red.score(x_red, y)

### plot
axs[0].scatter(y_hat_red, y)
axs[0].plot(*fit_line(y_hat_red, y), c='r', lw=2)
axs[0].set_title('Reduced (R = {0:.4f})'.format(r2_red**.5))
axs[0].set_xlabel('Pred: {}'.format(COL_Y))
axs[0].set_ylabel(COL_Y)

## full

### data
x_full = metrics_filt[COLS_X_FULL]

### fit

### prediction and R2
m_full = linear_model.LinearRegression().fit(x_full, y)

y_hat_full = m_full.predict(x_full)
r2_full = m_full.score(x_full, y)

### plot
axs[1].scatter(y_hat_full, y)
axs[1].plot(*fit_line(y_hat_full, y), c='r', lw=2)
axs[1].set_title('Full (R = {0:.4f})'.format(r2_full**.5))
axs[1].set_xlabel('Pred. {}'.format(COL_Y))
axs[1].set_ylabel(COL_Y)

for ax in axs:
    set_font_size(ax, 14)
    
# plot model coefs
fig, axs = plt.subplots(1, 2, figsize=(9, 6), sharey=True, tight_layout=True)

x_bar = np.arange(1 + len(m_full.coef_))

bar_red = [m_red.intercept_] + list(m_red.coef_) + [0]
axs[0].bar(x_bar, bar_red, align='center', width=0.8)
axs[0].set_xticks(x_bar)
axs[0].set_xticklabels(['icpt'] + COLS_X_RED + [''], rotation=90)
axs[0].set_title('Coefs (Reduced)\nTarg: {}'.format(COL_Y))

bar_full = [m_full.intercept_] + list(m_full.coef_)
axs[1].bar(x_bar, bar_full, align='center', width=0.8)
axs[1].set_xticks(x_bar)
axs[1].set_xticklabels(['icpt'] + COLS_X_FULL, rotation=90)
axs[1].set_title('Coefs (Full)\nTarg: {}'.format(COL_Y))

for ax in axs:
    ax.grid()
    set_font_size(ax, 14)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Compute statistics

### Calculate F statistic

In [12]:
# RSS
rss_red = np.sum((y - m_red.predict(x_red))**2)
rss_full = np.sum((y - m_full.predict(x_full))**2)

# params
p_red = len(COLS_X_RED) + 1
p_full = len(COLS_X_FULL) + 1

# num samps
n = len(y)

# degs freedom
dfn = p_full - p_red
dfd = n - p_full

# calc F
f = ((rss_red - rss_full) / (p_full - p_red)) / (rss_full / (n - p_full))

# calc p-val
p_val = stats.f.sf(f, dfn, dfd)

print('F = {}'.format(f))
print('P = {}'.format(p_val))

F = 4.913234312863378
P = 0.02932217482316981
