In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import pathlib
import figurefirst
import style
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import analysisStaySwitchDecoding
import analysisTunings
from scipy.spatial.distance import squareform, pdist
from collections import defaultdict
from utils import readSessions, fancyViz
import cmocean
from itertools import product
from matplotlib.gridspec import GridSpec
from matplotlib.backends.backend_pdf import PdfPages
plt.ioff()  # matplotlib interactive mode off; figures below are saved to PDF and closed explicitly
style.set_context()  # project plotting defaults from the local `style` module

In [2]:
# Input data (HDF stores in ./data).
endoDataPath = pathlib.Path("data", "endoData_2019.hdf")
alignmentDataPath = pathlib.Path("data", "alignment_190227.hdf")

# Figure output, intermediate-result cache and layout templates (relative to the working dir).
outputFolder = pathlib.Path("svg")
cacheFolder = pathlib.Path("cache")
templateFolder = pathlib.Path("templates")

# Create the output and cache folders on first run (no-op when they already exist).
outputFolder.mkdir(exist_ok=True)
cacheFolder.mkdir(exist_ok=True)

In [3]:
# Stay/switch decoding AUCs: read from the cache, or compute once and cache.
cachedDataPath = cacheFolder / 'staySwitchAUC.pkl'
if not cachedDataPath.is_file():
    staySwitchAUC = analysisStaySwitchDecoding.getWStayLSwitchAUC(endoDataPath)
    staySwitchAUC.to_pickle(cachedDataPath)
else:
    staySwitchAUC = pd.read_pickle(cachedDataPath)

# Action values and logistic-regression results are cached together; if any of the
# three pickles is missing, all three are recomputed and rewritten.
cachedDataPaths = [cacheFolder / fname
                   for fname in ('actionValues.pkl', 'logRegCoefficients.pkl', 'logRegDF.pkl')]
if all(p.is_file() for p in cachedDataPaths):
    actionValues, logRegCoef, logRegDF = [pd.read_pickle(p) for p in cachedDataPaths]
else:
    actionValues, logRegCoef, logRegDF = analysisStaySwitchDecoding.getActionValues(endoDataPath)
    actionValues.to_pickle(cachedDataPaths[0])
    logRegCoef.to_pickle(cachedDataPaths[1])
    logRegDF.to_pickle(cachedDataPaths[2])

# Tuning data with significance flags: percentile above 99.5% (signp) / below 0.5% (signn).
tuningData = analysisTunings.getTuningData(endoDataPath)
tuningData = tuningData.assign(signp=tuningData['pct'] > .995,
                               signn=tuningData['pct'] < .005)

# Index by session (and action) so per-session rows can be selected with .loc.
actionValues = actionValues.set_index(['genotype','animal','date','actionNo']).sort_index()
staySwitchAUC = staySwitchAUC.set_index(['genotype','animal','date']).sort_index()

In [4]:
 def getActionMeans():
    actionMeans = pd.DataFrame()
    for s in readSessions.findSessions(endoDataPath, task='2choice'):
        #print(str(s))
        lfa = s.labelFrameActions(reward='fullTrial', switch=True, splitCenter=True)
        deconv = s.readDeconvolvedTraces(rScore=True).reset_index(drop=True)

        if not len(lfa) == len(deconv):
            print(str(s)+': more labeled frames than signal!')
            continue

        means = pd.DataFrame(deconv.groupby([lfa['label'], lfa['actionNo']]).mean().stack(),
                             columns=['trialMean'])
        means.index.names = means.index.names[:2] + ['neuron']
        
        #means['duration'] = lfa.groupby(['label','actionNo']).actionDuration.first()
        means.reset_index(inplace=True)
        means['action'] = means.label.str.slice(0,4)

        for k,v in [('date',s.meta.date),('animal',s.meta.animal),('genotype',s.meta.genotype)]:
            means.insert(0,k,v)

        actionMeans = actionMeans.append(means, ignore_index=True)
    return(actionMeans)

In [None]:
def getActionWindows(win_size=(10, 9)):
    """Peri-action activity windows for every neuron and action of all 2choice sessions.

    For each action, the deconvolved (rScore) activity is taken from
    ``win_size[0]`` frames before the action's first frame to ``win_size[1]``
    frames after it (both inclusive). Frames that belong to a later action, or
    to any action before the immediately preceding one, are set to NaN. Sessions
    whose frame-label table and traces differ in length are reported and skipped.

    Args:
        win_size: (frames before onset, frames after onset).

    Returns:
        pd.DataFrame with one row per (session, action, neuron). The columns are:
        genotype, animal, date, actionNo, label and neuron (two-level column
        labels with an empty second level); the window values as
        ('frameNo', 0 .. sum(win_size)); and action (first four characters of
        label). Action onset is always frame ``win_size[0]``. Frames that fall
        outside the session are NaN. An empty DataFrame is returned when no
        session could be processed.
    """
    nFrames = win_size[0] + win_size[1] + 1
    frameCols = pd.Index(np.arange(nFrames), name='frameNo')

    sessionWindows = []
    for s in readSessions.findSessions(endoDataPath, task='2choice'):
        print(str(s))
        lfa = s.labelFrameActions(reward='fullTrial', switch=True, splitCenter=True)
        deconv = s.readDeconvolvedTraces(rScore=True).reset_index(drop=True)

        if len(lfa) != len(deconv):
            print(str(s)+': more labeled frames than signal!')
            continue

        lfa['index'] = deconv.index
        deconv = deconv.set_index([lfa.label, lfa.actionNo], append=True)
        deconv.index.names = ['index', 'label', 'actionNo']
        deconv.columns.name = 'neuron'
        neurons = deconv.columns

        # First frame (row of deconv) and label of every action, ordered by actionNo.
        firstFrames = lfa.groupby('actionNo')[['index', 'label']].first()

        actionWins = []
        for actionNo, idx, label in firstFrames.itertuples(name=None):
            win = deconv.loc[idx-win_size[0]:idx+win_size[1]].reset_index()
            win.loc[win.actionNo > actionNo, neurons] = np.nan
            win.loc[win.actionNo < actionNo-1, neurons] = np.nan
            # Number frames relative to onset, so onset stays at frame win_size[0]
            # even when the window is clipped at the start of the session.
            win['frameNo'] = win['index'] - idx + win_size[0]
            win['label'] = label
            win['actionNo'] = actionNo
            win = win.set_index(['actionNo', 'label', 'frameNo'])[neurons]
            # Rows: (action, neuron); columns: the full, ordered frame range
            # (missing frames become NaN so every window has identical columns).
            win = win.unstack('frameNo').stack('neuron').reindex(columns=frameCols)
            win.columns = pd.MultiIndex.from_product([['frameNo'], win.columns])
            actionWins.append(win.reset_index())

        if not actionWins:
            continue
        sessionFrame = pd.concat(actionWins, ignore_index=True)

        for k, v in [('date', s.meta.date), ('animal', s.meta.animal), ('genotype', s.meta.genotype)]:
            sessionFrame.insert(0, k, v)

        sessionWindows.append(sessionFrame)

    if not sessionWindows:
        return pd.DataFrame()
    windows = pd.concat(sessionWindows, ignore_index=True)
    windows['action'] = windows.label.str.slice(0, 4)
    return windows

In [151]:
# Recompute per-action mean activity over all 2choice sessions (slow; prints skipped sessions).
actionMeans = getActionMeans()

a2a_3244_180410: more labeled frames than signal!


In [26]:
# Cache the result in the working directory; it is read back with pd.read_pickle below.
actionMeans.to_pickle('actionMeans.pkl')

In [4]:
# Load per-action means and peri-action windows from the working-directory pickles.
# If a pickle is missing it is recomputed and written. Without this, Restart & Run All
# fails on a fresh checkout: no visible cell writes actionWindows.pkl.
if pathlib.Path('actionMeans.pkl').is_file():
    actionMeans = pd.read_pickle('actionMeans.pkl')
else:
    actionMeans = getActionMeans()
    actionMeans.to_pickle('actionMeans.pkl')

if pathlib.Path('actionWindows.pkl').is_file():
    actionWindows = pd.read_pickle('actionWindows.pkl')
else:
    actionWindows = getActionWindows()
    actionWindows.to_pickle('actionWindows.pkl')

In [5]:
# Index both tables by session, neuron and action so one neuron's rows can be
# selected with .loc[(genotype, animal, date, neuron)].
actionMeans, actionWindows = (
    table.set_index(['genotype', 'animal', 'date', 'neuron', 'actionNo']).sort_index()
    for table in (actionMeans, actionWindows)
)

In [6]:
def getPhaseLabels(phase):
    """Expand a side-generic phase name into its six outcome-specific labels.

    Every 'S' in ``phase`` is replaced by 'L' for the left action and by 'R'
    for the right action. The result is the left action with the suffixes
    'r.', 'o.', 'o!', followed by the right action with the same suffixes in
    reverse order ('o!', 'o.', 'r.').
    """
    left, right = (phase.replace('S', side) for side in 'LR')
    suffixes = ('r.', 'o.', 'o!')
    return [left + sfx for sfx in suffixes] + [right + sfx for sfx in reversed(suffixes)]

def getPlotData(genotype, animal, date, neuron):
    """Collect one neuron's plotting data from the notebook-level tables.

    Returns:
        (av, wins, means): the session's action values, the neuron's peri-action
        windows and its per-action means. All three are copies indexed by
        actionNo. ``wins`` and ``means`` gain a 'value' column holding the
        negated action value, aligned on that index.
    """
    session = (genotype, animal, date)
    av = actionValues.loc[session].copy()
    negatedValue = av['value'].mul(-1)

    wins = actionWindows.loc[session + (neuron,)].copy()
    wins['value'] = negatedValue

    means = actionMeans.loc[session + (neuron,)].copy()
    means['value'] = negatedValue

    return av, wins, means

def avRegPlot(genotype, animal, date, neuron, phase='pC2S', ax=None):
    """Plot a neuron's per-action mean activity against negated action value.

    For each label of ``phase`` (see getPhaseLabels), this draws a binned
    scatter (4 value bins) with a linear fit. It also draws the label's overall
    mean with SEM error bars in both dimensions. Labels containing 'L' use '<'
    markers and the others use '>'. A black quadratic fit over all included
    labels is drawn underneath.

    x/y are passed to seaborn as keywords: seaborn >= 0.12 no longer accepts
    them positionally.

    Args:
        genotype, animal, date, neuron: select the session and neuron (see getPlotData).
        phase: side-generic phase name.
        ax: target axes; defaults to the current axes.
    """
    inclLabels = getPhaseLabels(phase)
    _, _, means = getPlotData(genotype, animal, date, neuron)
    phaseMeans = means.loc[means.label.isin(inclLabels)]

    if ax is None:
        ax = plt.gca()
    for label, ldata in phaseMeans.groupby('label'):
        color = style.getColor(label[-2:])
        marker = '<' if 'L' in label else '>'
        sns.regplot(x='value', y='trialMean', data=ldata, ax=ax, fit_reg=True, ci=None,
                    color=color, marker=marker, scatter_kws={'alpha':.5, 's':10},
                    truncate=True, x_bins=4)
        ax.errorbar(ldata.value.mean(), ldata.trialMean.mean(),
                    xerr=ldata.value.sem(), yerr=ldata.trialMean.sem(),
                    color=color, marker=marker, ms=5)
    sns.regplot(x='value', y='trialMean', data=phaseMeans,
                ax=ax, ci=None, scatter=False, color='k', truncate=True, order=2,
                line_kws={'alpha':1, 'zorder':-99, 'lw':.8})
    
def avAvgTracePlot(genotype, animal, date, neuron, phase='pC2S', compression=40, ax=None,
                   minCount=10):
    """Plot a neuron's average peri-action trace per label, placed on the action-value axis.

    For each label of ``phase`` (see getPhaseLabels), the mean ± SEM of the
    neuron's window activity is drawn. The frame axis is scaled by
    ``1/compression`` and shifted so that frame ``len(window)//2`` sits at the
    label's mean negated action value. That position is marked with a dotted
    vertical line.

    Args:
        genotype, animal, date, neuron: select the session and neuron (see getPlotData).
        phase: side-generic phase name.
        compression: number of frames per unit of the value axis.
        ax: target axes; defaults to the current axes.
        minCount: frames with fewer than this many non-NaN actions are not drawn.
    """
    inclLabels = getPhaseLabels(phase)
    _, wins, _ = getPlotData(genotype, animal, date, neuron)

    if ax is None:
        ax = plt.gca()
    for label, ldata in wins.loc[wins.label.isin(inclLabels)].groupby('label'):
        color = style.getColor(label[-2:])
        frames = ldata['frameNo']
        meanValue = ldata['value'].mean()
        x = np.array(frames.columns.values/compression + meanValue, dtype='float')
        # Shift so the window's middle frame lands on the label's mean value.
        x = x - (len(x)//2)/compression
        # Blank out poorly supported frames. np.where builds new arrays; the arrays
        # pandas returns may be read-only under Copy-on-Write, so they are never
        # written to directly.
        tooFew = (frames.notna().sum(axis=0) < minCount).to_numpy()
        y = np.where(tooFew, np.nan, frames.mean().to_numpy(dtype=float))
        sem = np.where(tooFew, np.nan, frames.sem().to_numpy(dtype=float))
        ax.fill_between(x, y-sem, y+sem, color=color, lw=0, alpha=.3)
        ax.plot(x, y, color=color)
        ax.axvline(meanValue, ls=':', color='k')

def wstLswFigEight(genotype, animal, date, neuron, saturation='auto', axs=None):
    """Draw one neuron's activity on the task schematic, split by outcome label.

    Uses fancyViz.SchematicIntensityPlot on the neuron's deconvolved (rScore)
    trace three times, masking frames by label suffix: 'r.' into axs[0], 'o.'
    into axs[1] and 'o!' into axs[2].

    Args:
        genotype, animal, date: identify the 2choice session in ``endoDataPath``.
        neuron: column label of the neuron in the deconvolved traces.
        saturation: colour saturation passed to the plot. 'auto' means the
            neuron's largest per-label mean over the pS2C/mS2C/pC2S/mC2S/dS2C
            labels present in the session, rounded up to a multiple of 0.5 and
            at least 0.5.
        axs: sequence of at least three axes.

    Returns:
        -1 without drawing anything if ``axs`` is empty or None, otherwise None.

    Raises:
        StopIteration: if no matching session is found.
    """
    if not axs:
        return -1
    s = next(readSessions.findSessions(endoDataPath, task='2choice',
                                       genotype=genotype, animal=animal, date=date))
    lfa = s.labelFrameActions(switch=True, reward='fullTrial', splitCenter=True)
    deconv = s.readDeconvolvedTraces(rScore=True).reset_index(drop=True)
    trace = deconv[neuron]

    if saturation == 'auto':
        inclLabels = np.concatenate([getPhaseLabels(p) for p in ['pS2C','mS2C','pC2S','mC2S','dS2C']])
        # reindex instead of .loc: labels missing from this session are ignored
        # rather than raising KeyError.
        labelMeans = trace.groupby(lfa['label']).mean().reindex(inclLabels)
        saturation = np.ceil(labelMeans.max()*2) / 2
        # Without any positive label mean the rule gives saturation <= 0 (or NaN),
        # presumably breaking the `im / saturation` colour scaling (see the warnings
        # in this notebook's output). 0.5 is the smallest value the rounding yields.
        if not saturation > 0:
            saturation = .5

    fv = fancyViz.SchematicIntensityPlot(s, splitReturns=False,
                                         linewidth=mpl.rcParams['axes.linewidth'],
                                         smoothing=8, saturation=saturation)
    # Raw strings: '\.' in a normal literal is an invalid escape (SyntaxWarning in 3.12+).
    for i, pattern in enumerate([r'r\.$', r'o\.$', r'o!$']):
        plt.sca(axs[i])
        fv.setMask(lfa.label.str.contains(pattern))
        fv.draw(trace)
    
def durationPlot(genotype, animal, date, phase='pC2S', ax=None, framesPerSecond=20):
    """Box plot of action durations for each label of ``phase`` in one session.

    Each action contributes the ``actionDuration`` of its first labelled frame,
    divided by ``framesPerSecond``. The default of 20 is presumably the imaging
    frame rate, which gives durations in seconds; confirm this. Boxes follow
    the order of getPhaseLabels(phase) and are coloured by outcome suffix.

    Args:
        genotype, animal, date: identify the 2choice session in ``endoDataPath``.
        phase: side-generic phase name.
        ax: target axes; defaults to the current axes.
        framesPerSecond: divisor converting actionDuration to the plotted unit.
    """
    if ax is None:
        ax = plt.gca()
    s = next(readSessions.findSessions(endoDataPath, task='2choice',
                                       genotype=genotype, animal=animal, date=date))
    lfa = s.labelFrameActions(switch=True, reward='fullTrial', splitCenter=True)
    inclLabels = getPhaseLabels(phase)
    durations = (lfa.loc[lfa['label'].isin(inclLabels)]
                 .groupby('actionNo')[['label', 'actionDuration']].first())
    durations['actionDuration'] /= framesPerSecond
    # Keyword x/y and orient='v': seaborn >= 0.12 rejects positional data variables,
    # and 'v' is the orientation spelling accepted by every seaborn version.
    sns.boxplot(x='label', y='actionDuration', orient='v', order=inclLabels, data=durations, ax=ax,
                palette={l:style.getColor(l[-2:]) for l in inclLabels},
                showcaps=False, showfliers=False, boxprops={'alpha':0.5, 'lw':0, 'zorder':-99},
                width=.55, whiskerprops={'c':'k','zorder':99}, medianprops={'c':'k','zorder':99})

In [8]:
# For every concrete phase, export one multi-page PDF containing the `no` most extreme
# significant neurons (AUC percentile below 0.5% or above 99.5%). Neurons are ranked by
# stay/switch AUC: ascending for the 'switch' files, descending for the 'stay' files.
no = 200  # neurons (pages) per PDF
phases = ['pL2C','pR2C','mL2C','mR2C','pC2L','pC2R','mC2L','mC2R','dL2C','dR2C']
for switch in (True, False):
    # The ranking depends only on `switch`, so build it once per direction.
    auc = staySwitchAUC.reset_index().set_index(['action','auc']).sort_index(ascending=switch)
    auc = auc.loc[(auc.pct < .005) | (auc.pct > .995)]

    for phase in phases:
        genericPhase = phase.replace('L','S').replace('R','S')
        pdfPath = outputFolder / "top{}_{}_{}_sortAUC.pdf".format(no, phase, 'switch' if switch else 'stay')
        with PdfPages(pdfPath) as pdf:
            for g,a,d,n in auc.loc[phase, ['genotype','animal','date','neuron']].values[:no]:
                fig = plt.figure(figsize=(7.5,3))
                gs = GridSpec(6, 3, hspace=.2, wspace=0.2)
                regAx = fig.add_subplot(gs[:3,1])
                # Share the value axis with the regression panel. This replaces
                # get_shared_x_axes().join, which was removed in matplotlib 3.8.
                avgAx = fig.add_subplot(gs[3:,1], sharex=regAx)
                wstAx = fig.add_subplot(gs[:2,0])
                lstAx = fig.add_subplot(gs[2:4,0])
                lswAx = fig.add_subplot(gs[4:,0])
                durAx = fig.add_subplot(gs[:,2])

                avRegPlot(g,a,d,n,phase=genericPhase,ax=regAx)
                avAvgTracePlot(g,a,d,n,phase=genericPhase,compression=20,ax=avgAx)
                regAx.set_title('{} {} {} #{}'.format(g,a,d,int(n)), pad=0)
                # Hide x ticks on the upper panel only. set_xticks(()) would also clear
                # avgAx, because shared axes share their tick locator.
                regAx.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
                regAx.set_xlabel('')
                regAx.set_ylabel('sd')
                avgAx.set_ylabel('sd')
                sns.despine(bottom=True, ax=regAx, trim=True)
                sns.despine(ax=avgAx)

                wstLswFigEight(g,a,d,n,axs=[wstAx,lstAx,lswAx], saturation='auto')

                durationPlot(g,a,d,phase=genericPhase,ax=durAx)
                durAx.set_xticklabels(durAx.get_xticklabels(), rotation=90, ha='center')
                durAx.set_ylabel('action duration (s)')
                sns.despine(ax=durAx)

                pdf.savefig(fig, bbox_inches='tight', pad_inches=0)
                plt.close('all')

  im = np.clip(im / saturation, -1, 1)
  im = np.clip(im / saturation, -1, 1)
  xa[xa < 0] = -1
  im = np.clip(im / saturation, -1, 1)
  im = np.clip(im / saturation, -1, 1)
  xa[xa < 0] = -1
  im = np.clip(im / saturation, -1, 1)
  im = np.clip(im / saturation, -1, 1)
  xa[xa < 0] = -1
  im = np.clip(im / saturation, -1, 1)
  im = np.clip(im / saturation, -1, 1)
  xa[xa < 0] = -1
  im = np.clip(im / saturation, -1, 1)
  im = np.clip(im / saturation, -1, 1)
  xa[xa < 0] = -1
  im = np.clip(im / saturation, -1, 1)
  im = np.clip(im / saturation, -1, 1)
  xa[xa < 0] = -1
  im = np.clip(im / saturation, -1, 1)
  im = np.clip(im / saturation, -1, 1)
  xa[xa < 0] = -1
  im = np.clip(im / saturation, -1, 1)
  im = np.clip(im / saturation, -1, 1)
  xa[xa < 0] = -1
  im = np.clip(im / saturation, -1, 1)
  im = np.clip(im / saturation, -1, 1)
  xa[xa < 0] = -1
  im = np.clip(im / saturation, -1, 1)
  im = np.clip(im / saturation, -1, 1)
  xa[xa < 0] = -1
  im = np.clip(im / saturation, -1, 1)
 