In this notebook we examine whether the G3 compartment represents only the walking/paused state of the fly, or whether it contains information about the fine-scale walking speed of the fly as well.

# Classify walking/paused states using G3

Here we fit a logistic regression classifier to predict the walking ('W') or paused ('P') states for each trial with labeled states.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LogisticRegression

from aux import make_extended_predictor_matrix
from data import DataLoader


# (start, end) window bounds. Presumably offsets relative to the time point of
# prediction, in the same form as the per-DAN tuples in the `windows` argument
# of classify_states -- TODO confirm; not referenced in the visible code.
WINDOW = (-2, 2)

def classify_states(trial, dans, windows):
    """
    Fit a classifier predicting a trial's paused vs. walking states
    using activity from a set of neural compartments.

    The classifier is only fit when exactly two distinct states remain
    among the valid time points. Otherwise the returned result carries
    ``None`` for the predictions, the coefficients and the classifier.

    :param trial: trial object; its ``dl`` attribute is used as the data
        loader when present, otherwise a DataLoader is built from the trial
    :param dans: list of DAN compartments to use for prediction
    :param windows: dict of windows to use for each predictive dan
        (keys are DAN names, vals are tuples of (start, end) time
            points relative to time point of prediction)

    :return: ClassifierResult
    """

    # reuse the data loader attached to the trial if there is one; only a
    # missing attribute triggers the fallback, other errors propagate
    try:
        dl = trial.dl
    except AttributeError:
        dl = DataLoader(trial, vel_filt=None)

    states = dl.states
    vs = {dan: getattr(dl, dan) for dan in dans}

    # make extended dan predictor matrix
    vs_extd = make_extended_predictor_matrix(vs, windows, order=dans)

    # make valid mask (not nan predictors and not ambiguous state)
    valid = np.all(~np.isnan(vs_extd), axis=1) & (states != 'A')

    if len(set(states[valid])) == 2:
        # exactly two distinct states among valid points, so a binary
        # classifier can be fit

        # fit classifier
        clf = LogisticRegression(n_jobs=-1)
        clf.fit(vs_extd[valid], states[valid])

        # make state predictions; invalid time points keep the empty label.
        # The buffer uses the dtype of the true labels so that predicted
        # labels are never truncated to a single character.
        states_pred = np.full(
            len(states), '', dtype=np.asarray(states).dtype)
        states_pred[valid] = clf.predict(vs_extd[valid])

        # get coefficients: cut the flat coefficient vector into one chunk
        # per DAN, in the order given by `dans`
        window_lens = [windows[d][1] - windows[d][0] for d in dans]
        splits = np.split(clf.coef_[0], np.cumsum(window_lens)[:-1])

        coefs = dict(zip(dans, splits))

    else:
        # not exactly two states: nothing to classify
        clf = None
        states_pred = None
        coefs = None

    result = ClassifierResult(
        trial_id=trial.id, dans=dans, windows=windows, valid=valid,
        states_pred=states_pred, states=states, coefs=coefs, clf=clf)

    return result


class ClassifierResult(object):
    """
    Container for the outcome of fitting a state classifier to one trial.

    Stores the constructor arguments as attributes and derives three
    summary values from them:

    * ``n_valid``: number of True entries in ``valid``
    * ``walk_frac``: fraction of valid time points whose state is 'W'
    * ``acc``: fraction of valid time points where ``states_pred`` equals
      ``states``, or None when ``states_pred`` is None

    ``walk_frac`` (and ``acc``, when predictions are given) is NaN if there
    are no valid time points.

    :param trial_id: identifier of the trial
    :param dans: list of DAN compartments used for prediction
    :param windows: dict of windows used for each predictive dan; stored
        as both ``window`` (original attribute name) and ``windows``
    :param valid: boolean array marking the time points used
    :param states_pred: array of predicted states, or None
    :param states: array of true states
    :param coefs: dict of coefficients per DAN, or None
    :param clf: fitted classifier, or None
    """

    def __init__(self, trial_id, dans, windows, valid, states_pred, states, coefs, clf):
        # store the id itself (a stray trailing comma used to wrap it
        # in a 1-tuple)
        self.trial_id = trial_id
        self.dans = dans
        # `window` is kept for existing callers; `windows` matches the
        # constructor argument name
        self.window = windows
        self.windows = windows
        self.valid = valid
        self.states_pred = states_pred
        self.states = states
        self.coefs = coefs
        self.clf = clf

        # get number of valid time points
        self.n_valid = valid.sum()

        # get fraction of valid time points in walking state
        if self.n_valid > 0:
            self.walk_frac = np.sum(states[valid] == 'W') / self.n_valid
        else:
            # no valid points: the fraction is undefined
            self.walk_frac = np.nan

        # get accuracy (fraction of correctly labeled valid states)
        if states_pred is None:
            self.acc = None
        elif self.n_valid > 0:
            self.acc = np.sum(states_pred[valid] == states[valid]) / self.n_valid
        else:
            self.acc = np.nan

# Predict walking speed during walking states using G3