### Requirements
* #### You need to install the `future` module; importing from \_\_future\_\_ manually is optional
* #### For HDF data import you also need PyTables, which is not installed by default with Anaconda

### Batch execution
* #### ```batch_animal=msaxxyy_z jupyter nbconvert Stat.ipynb --to=html --execute --ExecutePreprocessor.timeout=-1 --output=xxyy_z_report.html```

In [None]:
#from future.utils import PY3
import future
from __future__ import (absolute_import, division,
                        print_function, unicode_literals)
import pandas as pd
import numpy as np
import time, os, warnings, imp, itertools
import IPython.display as disp
display = disp.display
import matplotlib as mpl, matplotlib.pyplot as plt
import scipy.stats as stats
zscore, describe = stats.mstats.zscore, stats.describe
import datetime
dt, td = datetime.datetime, datetime.timedelta

%matplotlib inline

In [None]:
import ca_lib as la
imp.reload(la)

In [None]:
from os import environ
batch_animal = environ.get('batch_animal', None)

## Load files

In [None]:
basedir = '../_share/Losonczi/'

# Display database folders
display(os.listdir(basedir))

# Select animal
if batch_animal is None:
    animal = 'msa1215_1'; FPS = 30
    #animal = 'msa0216_4'; FPS = 8
    #animal = 'msa0316_1'; FPS = 8
    #animal = 'msa0316_3'; FPS = 8
    #animal = 'msa0316ag_1'; FPS = 8
else:
    FPS = None
    animal = batch_animal

print ('selecting',animal)

# List dir
mydir = os.path.join(basedir,animal)
os.listdir(mydir)

In [None]:
# Available trials and ROIs
data = la.load_files(mydir)
if (FPS is not None) and (data.FPS != FPS):
    warnings.warn('FPS indication might be wrong.')
print (data.raw.shape, '\n', data.trials, '\n', data.rois)

In [None]:
# Post-Learning may repeat session_num therefore an additional index,
# day_num is created. See msa0316_1.
# It seems, though, that Pre-Learning and Learning treat session_num as documented.
display(data.experiment_traits.head())
display(data.experiment_traits[data.experiment_traits['day_leap']])

## Experiment protocol configurations

In [None]:
et = data.experiment_traits.copy()
et = la.df_epoch(et.groupby(la.display_learning).size().to_frame(name='count'))
#et.to_clipboard()
disp.display(disp.HTML('<font color="red">ATTENTION, </font>for later conformity we store columns in a <b>different order</b>: %s !!!'%la.sort_learning))
display(la.df_epoch(et))

et = data.experiment_traits.copy()
etc= et.groupby(la.sort_learning[1:]).size().to_frame(name='count')
et = la.df_epoch(et.groupby(la.sort_learning).size().to_frame(name='count'))

## Prepare data

In [None]:
df_data = data.filtered
df_raw = data.raw

In [None]:
# See how many ROIs are available for which frames

avail_sum = (~data.filtered.isnull()).sum() / len(data.trials)
plt.plot(avail_sum)
plt.xlabel('Camera frame within experiment')
plt.ylabel('Available ROIs on average')

In [None]:
# See which ROI is available in which trial and for how many frames

avail = ((~data.filtered.isnull()).sum(axis=1)).to_frame('nFrames').unstack()

print(avail.shape)
display(avail.head())
display(avail.tail())

In [None]:
# Create boolean DataFrame which ROI is spiking in which camera frame

# create empty structure for cumsum
df_template = pd.DataFrame(data=0,index=data.mirow,columns=data.icol)
df_spike = df_template.copy()

# select spike data
spikes = data.transients.loc[data.transients['in_motion_period']==False,['start_frame','stop_frame']]
spikes['count']=1

# fill in spike start and stop points (rename column to keep columns.name in df_spike)
sp = spikes[['start_frame','count']].rename(columns={'start_frame':'frame'}).pivot(columns='frame').fillna(0)
df_spike = df_spike.add(sp['count'], fill_value=0)
sp = spikes[['stop_frame','count']].rename(columns={'stop_frame':'frame'}).pivot(columns='frame').fillna(0)
df_spike = df_spike.add(-sp['count'], fill_value=0)

# cumulate, conversion to int is not advised if using NaNs
df_spike = df_spike.cumsum(axis=1).astype(int)
df_spike = df_spike + data.time_roi_mask

print('table shape', df_spike.shape, 'active frames*ROIs', df_spike.sum().sum())
display(df_spike.head(25))
display(df_spike.tail())

In [None]:
# Build a per-trial, per-camera-frame licking table: lick counts per frame,
# later scaled by FPS and Gaussian-smoothed (so the final values are rates,
# not booleans).

# Keep only entries whose stop_time is after their start_time and compute
# the camera frame of the lick midpoint.
print('All entries', data.behavior.shape)
df_lick = data.behavior[data.behavior.loc[:,'stop_time']>data.behavior.loc[:,'start_time']].copy()
print('Valid licks', df_lick.shape)
df_lick['frame'] = (data.FPS*(df_lick['start_time']+df_lick['stop_time'])/2).apply(np.round).astype(int)
display(df_lick.head())
display(df_lick.tail())
# Convert to a DataFrame like df_data or df_raw
# NOTE(review): assumes the index of data.behavior is named 'time' so that
# reset_index() produces a 'time' column -- confirm against la.load_files.
df_lick = df_lick[['lick_idx','frame']].reset_index()
df_lick = df_lick.groupby(['time','frame']).count().unstack(fill_value=0)
display(df_lick.head())
# Align to the trials (first level of data.mirow) and frames (data.icol) used
# elsewhere; combinations without licks are filled with 0.
df_lick = df_lick['lick_idx'].reindex(index=data.mirow.levels[0],columns=data.icol,fill_value=0)
display(df_lick.head())
# Number of remaining licks
print('Remaining licks',df_lick.sum().sum())
# Smoothen each trial (row) with a Gaussian of sigma = 0.25*FPS frames after
# scaling the per-frame counts by FPS.
# gaussian_filter is imported from scipy.ndimage directly: the
# scipy.ndimage.filters namespace is deprecated in recent SciPy releases.
from scipy.ndimage import gaussian_filter
df_lick = df_lick.apply(lambda x: gaussian_filter(x.astype(float)*data.FPS, sigma=0.25*data.FPS), axis=1, raw=True)
display(df_lick.head())

## z-scoring

In [None]:
z_spike = la.pd_zscore_by_roi(df_spike, data.FPS, -2*data.FPS, axis=1)
z_data = la.pd_zscore_by_roi(df_data, data.FPS, -2*data.FPS, axis=1)
z_raw = la.pd_zscore_by_roi(df_raw, data.FPS, -2*data.FPS, axis=1)
z_lick = la.pd_zscore_clip(df_lick, data.FPS, -2*data.FPS, axis=1)

z_data = z_data.sort_index()
z_raw = z_raw.sort_index()

### Triggers

In [None]:
def trigger(data, threshold, rising=True, hold_off=None):
    '''Flag the samples where `data` crosses `threshold` along the first axis.

    A rising crossing is marked at position i when the previous sample is
    <= threshold and sample i is > threshold; a falling crossing when the
    previous sample is >= threshold and sample i is < threshold.
    The first position is never flagged because it has no predecessor.

    Parameters: data is anything np.array accepts; threshold is the level to
    cross; rising selects the direction; hold_off is reserved and any truthy
    value raises ValueError.
    Returns a boolean array with the same shape as `data`.
    '''
    samples = np.array(data)
    crossings = np.zeros(samples.shape, dtype=bool)
    if hold_off:
        raise ValueError('Hold off period not implemented yet.')
    previous, current = samples[:-1], samples[1:]
    if rising:
        crossings[1:] = (previous <= threshold) & (current > threshold)
    else:
        crossings[1:] = (previous >= threshold) & (current < threshold)
    return crossings

def trigger_find_pd(df, threshold, axis=1, hold_off=None):
    '''Locate rising and falling threshold crossings in a DataFrame.

    `trigger` is applied along `axis` once per direction, entries equal to 0
    (no crossing) are replaced by NaN, and for axis 0 or 1 the result is
    stacked so that only the crossings remain. Any other axis only emits a
    warning and leaves the tables unstacked.
    `hold_off` is accepted but not used by this function.

    Returns the pair (rising, falling); both get the name 'weight'.
    '''
    def find(rising):
        # Mark crossings in one direction and blank out the non-crossings
        found = df.apply(lambda x: trigger(x,threshold, rising), axis=axis)
        found[found==0]=np.nan
        return found

    triggers_rise = find(True)
    triggers_fall = find(False)

    if axis==1:
        reduce = lambda found: found.stack()
    elif axis==0:
        reduce = lambda found: found.T.stack().T
    else:
        warnings.warn('Axis reduction not implemented for axis.')
        reduce = lambda found: found
    triggers_rise = reduce(triggers_rise)
    triggers_fall = reduce(triggers_fall)

    for found in (triggers_rise, triggers_fall):
        found.name='weight'
    return triggers_rise, triggers_fall

def trigger_enable_pd(df, start, stop):
    '''Build start, stop and "enabled" trigger series for every row label of `df`.

    Each returned Series is named 'weight', holds 1.0 everywhere and is indexed
    by a ('time', 'frame') MultiIndex formed from the index values of `df`
    combined with, respectively, the single frame `start`, the single frame
    `stop`, and every frame in range(start, stop).

    Returns the tuple (triggers_start, triggers_stop, triggers_allow).
    '''
    row_labels = df.index.values

    def unit_weights(frames):
        # One entry of weight 1.0 per (row label, frame) combination
        index = pd.MultiIndex.from_product((row_labels, frames), names=['time', 'frame'])
        return pd.Series(1.0, index=index, name='weight')

    triggers_start = unit_weights([start])
    triggers_stop = unit_weights([stop])
    triggers_allow = unit_weights(list(range(start,stop)))
    return triggers_start, triggers_stop, triggers_allow


In [None]:
z_spike_threshold = 5.0/np.sqrt(len(data.rois))

max_lik_rate = 20
c,b = np.histogram(df_lick.values.ravel(),range=(0,max_lik_rate),bins=max_lik_rate)
lick_threshold = (np.argmax(c[1:])+1.5)/2
plt.hist(df_lick.values.ravel(),log=True,range=(0,max_lik_rate),bins=max_lik_rate)
plt.plot(lick_threshold,2,'y*',ms=15)
print(lick_threshold)

In [None]:
# The histogram shape justifies putting the threshold at the half maximum
lick_triggers_rise, lick_triggers_fall = trigger_find_pd(df_lick, lick_threshold)
print (lick_triggers_rise.shape,lick_triggers_fall.shape)
print ('Port was present in %d trials.'%data.experiment_traits[data.experiment_traits['port']=='W+'].shape[0])

In [None]:
# Define the boundary of a p<0.005 set
# Average over the first index level with groupby: DataFrame.mean(level=...)
# was removed in pandas 2.0. sort=False keeps the groups in order of first
# appearance, as the level-based mean did.
spike_triggers_rise, spike_triggers_fall = trigger_find_pd(z_spike.groupby(level=0, sort=False).mean(), z_spike_threshold)
print (spike_triggers_rise.shape,spike_triggers_fall.shape)

In [None]:
csp_triggers_rise, csp_triggers_fall, csp_triggers_allow = trigger_enable_pd(
    data.experiment_traits[data.experiment_traits['port']=='W+'],
    la.events[1]*data.FPS, la.events[2]*data.FPS)
csm_triggers_rise, csm_triggers_fall, csm_triggers_allow = trigger_enable_pd(
    data.experiment_traits[data.experiment_traits['port']=='W-'],
    la.events[1]*data.FPS, la.events[2]*data.FPS)

us_triggers_rise, us_triggers_fall, us_triggers_allow = trigger_enable_pd(
    data.experiment_traits[data.experiment_traits['puffed']=='A+'],
    la.events[3]*data.FPS, la.events[4]*data.FPS)

tra_triggers_rise, tra_triggers_fall, tra_triggers_allow = trigger_enable_pd(
    data.experiment_traits[data.experiment_traits['context']=='CS+'],
    la.events[2]*data.FPS, la.events[3]*data.FPS)

# Plot

In [None]:
from matplotlib.backends.backend_pdf import PdfPages

class helpmultipage(object):
    '''Small convenience wrapper around PdfPages.

    The PDF is opened on construction when `filename` is non-empty; with an
    empty filename every method is a silent no-op. The object can be closed
    and re-opened, and can also be used as a context manager.
    '''
    def __init__(self, filename):
        self.filename = filename
        self.isopen = False
        self.open()

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def savefig(self, dpi=None):
        '''Forward to PdfPages.savefig(dpi=dpi) if the file is open.'''
        if self.isopen:
            self.pp.savefig(dpi=dpi)

    def open(self):
        '''Open the PDF unless it is already open or no filename was given.'''
        # `not` is required here: `~self.isopen` is a bitwise NOT on a bool
        # (-1 or -2), which is always truthy, so an already open file would be
        # replaced by a fresh PdfPages and the previous one leaked.
        if (not self.isopen) and self.filename:
            self.pp = PdfPages(self.filename)
            self.isopen = True

    def close(self):
        '''Close the PDF if it is open; safe to call repeatedly.'''
        if self.isopen:
            try:
                self.pp.close()
            finally:
                # Never retry closing the same handle (e.g. from __del__)
                self.isopen = False
        self.isopen = False

#### Explanatory figure

In [None]:
pp = helpmultipage(animal+'_explanatory.pdf')

In [None]:
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection
center = data.FPS * (la.events[:-1]+la.events[1:]) /2
left = data.FPS * la.events
width = data.FPS * (la.events[1:]-la.events[:-1])
vcenter = 0.0
vstart = -0.5

def label90(x,y,text):
    '''Write `text` rotated by 90 degrees and centred on (x, y).

    Draws on the module-level axes `ax`, which is looked up when the function
    is called (it is assigned after this definition in the same cell).
    '''
    ax.text(x, y, text, ha="center", va="center", family='sans-serif', size=14, rotation=90)

fig, (empty, ax) = plt.subplots(2,1,figsize=(6,8))
fig.tight_layout(pad=3)
empty.axis('off')
#ax = fig.gca()
fig.suptitle('Explanatory figure',fontsize=16)
ax.set_xlabel('Camera frame')
ax.set_ylabel('z-scored activity')
ax.set_ylim(vstart,vstart+1)
ax.plot(z_spike.mean(axis=0)+0.00, label="(CategoryA, True): #trials", c=(1,1,0))
ax.plot(z_spike.mean(axis=0)+0.02, label="(CategoryB, True): #trials", c=(.5,1,.5))
ax.plot(-z_spike.mean(axis=0)+0.00, label="(CategoryA, False): #trials", c=(1,.8,1))
ax.plot(-z_spike.mean(axis=0)+0.02, label="(CategoryB, False): #trials", c=(.5,1,1))
patches = []
# mark delay
label90(center[0], vcenter, 'excitation by\nshowing water')
# mark CS
rect = mpatches.Rectangle((left[1],vstart), width[1], 1, ec="none")
patches.append(rect)
label90(center[1], vcenter, 'CS± if tone\n"Baseline" otherwise')
# mark delay
label90(center[2], vcenter, 'delay')
# mark UC
rect = mpatches.Rectangle((left[3],vstart), width[3], 1, ec="none")
patches.append(rect)
label90(center[3], vcenter, 'UC if any')
# mark water
ax.text((left[0]+left[3])/2, vstart, "water source present\niff allowed to lick",
        ha="center", va="bottom", family='sans-serif', size=14, bbox=dict(boxstyle="DArrow", pad=0.0, fc='c'))

for i in range(0,len(la.events)):
    ax.axvline(x=la.events[i]*data.FPS, ymin=0.0, ymax = 1.0, linewidth=1, color='k')
colors = np.linspace(0, 1, len(patches))
collection = PatchCollection(patches, cmap=plt.cm.hsv, alpha=0.1)
collection.set_array(np.array(colors))
ax.add_collection(collection)

leg = ax.legend(loc='lower center', title="Category name, Condition name",
               bbox_to_anchor=(0.5, 1.1))
leg.get_title().set_fontsize('large')
leg.get_title().set_fontweight('bold')
with warnings.catch_warnings():
    warnings.simplefilter('ignore', UserWarning)
    fig.show()
pp.savefig()

In [None]:
pp.close()

## Z-scored spiking
Spiking is "True" in the [intervals) given in transients_data.hc5

In [None]:
mymean = pd.DataFrame.mean
mystd = pd.DataFrame.std

### Averaging in 5" bins

In [None]:
bsections = np.arange(0,60,5)*data.FPS
bcenters = (bsections[1:]+bsections[:-1])/2

In [None]:
zb_spike = la.pd_aggr_col(z_spike, mymean, bsections, bcenters.astype(str))
zb_data = la.pd_aggr_col(z_data, mymean, bsections, bcenters.astype(str))
zb_raw = la.pd_aggr_col(z_raw, mymean, bsections, bcenters.astype(str))
zb_lick = la.pd_aggr_col(z_lick, mymean, bsections, bcenters.astype(str))

zb_data = zb_data.sort_index()
zb_raw = zb_raw.sort_index()

In [None]:
b_spike = la.pd_aggr_col(df_spike, mymean, bsections, bcenters.astype(str))
b_data = la.pd_aggr_col(df_data, mymean, bsections, bcenters.astype(str))
b_raw = la.pd_aggr_col(df_raw, mymean, bsections, bcenters.astype(str))
b_lick = la.pd_aggr_col(df_lick, mymean, bsections, bcenters.astype(str))

b_data = b_data.sort_index()
b_raw = b_raw.sort_index()

### Averaging within phases

In [None]:
asections = np.append(la.events,[60])*data.FPS
acenters = (asections[1:]+asections[:-1])/2

In [None]:
za_spike = la.pd_aggr_col(z_spike, mymean, asections, acenters.astype(str))
za_data = la.pd_aggr_col(z_data, mymean, asections, acenters.astype(str))
za_raw = la.pd_aggr_col(z_raw, mymean, asections, acenters.astype(str))
za_lick = la.pd_aggr_col(z_lick, mymean, asections, acenters.astype(str))

za_data = za_data.sort_index()
za_raw = za_raw.sort_index()

In [None]:
a_spike = la.pd_aggr_col(df_spike, mymean, asections, acenters.astype(str))
a_data = la.pd_aggr_col(df_data, mymean, asections, acenters.astype(str))
a_raw = la.pd_aggr_col(df_raw, mymean, asections, acenters.astype(str))
a_lick = la.pd_aggr_col(df_lick, mymean, asections, acenters.astype(str))

a_data = a_data.sort_index()
a_raw = a_raw.sort_index()

## Licking statistics

In [None]:
lick_rate_mean = la.pd_aggr_col(df_lick, mymean, asections, acenters.astype(str))
lick_rate_std = la.pd_aggr_col(df_lick, mystd, asections, acenters.astype(str))
lick_time_mean = la.pd_aggr_col((df_lick>lick_threshold).astype(float), mymean,
                                asections, acenters.astype(str))
lick_time_std = la.pd_aggr_col((df_lick>lick_threshold).astype(float), mystd,
                               asections, acenters.astype(str))

### Learning progress

In [None]:
pp = helpmultipage(animal+'_protocol.pdf')
fig, ax = plt.subplots(len(data.trials),1,figsize=(10,0.6*len(data.trials)), sharex=True, sharey=True)
fig.tight_layout(h_pad=0.1)
ind = np.arange(0,5)
width = 1
height = 1.2
spacing = 10
i = 0
label_df = data.experiment_traits.replace('Baseline','B.L.')
for i in range(0,len(data.trials)):
    trial = data.trials[i]
    sc = 2.0*lick_threshold
    rects1 = ax[i].bar(ind+2*spacing, lick_rate_mean.loc[trial]/sc, width, color='r', yerr=lick_rate_std.loc[trial]/sc)
    rects2 = ax[i].bar(ind+3*spacing, lick_time_mean.loc[trial], width, color='b') #, yerr=lick_time_std.loc[trial])
    #label = label_df.loc[trial,la.sort_learning]
    #ax[i].text(1, 0, ', '.join(label.values.astype(str)),
    #    horizontalalignment='left',verticalalignment='bottom')
    ax[i].set_xlim(xmin=0)
    ax[i].set_ylim(ymin=0, ymax=height)
    ax[i].set_yticks([0,0.5,1])
    la.draw_conditions(ax[i],label_df,trial,data.FPS,loc='lower left',screen_width=0.5, height=height, cw=[0.25, 0.15, 0.15, 0.15, 0.15, 0.15],fontsize=12)
ax[i].set_xticks([spacing, 2.2*spacing, 3.2*spacing])
ax[i].set_xticklabels(['Conditions', 'Licking rate', 'Licking time'])
pp.savefig()
pp.close()

## Population averages

In [None]:
pp = helpmultipage(animal+'_pop.pdf')

In [None]:
la.plot_data(df_spike, df_data, df_lick, data.experiment_traits, data.FPS)
pp.savefig()

### Single criterion
* comments

In [None]:
grp = [['context'],['learning_epoch'],['port'],['puffed']]
la.plot_data(z_spike, z_data, df_lick, data.experiment_traits, data.FPS, grp, title='Population activity')
pp.savefig()
la.plot_data(zb_spike, zb_data, b_lick, data.experiment_traits, data.FPS, grp, title='Population activity binned', div=bcenters)
pp.savefig()
la.plot_data(za_spike, za_data, a_lick, data.experiment_traits, data.FPS, grp, title='Population activity averaged over events', div=acenters)
pp.savefig()

### Two criteria
* comments

### Three criteria
* comments

### All criteria
* There is no increased population activity for CS+ without puffing. (For mouse 0216_4 the 1 trial with port displays increase during the trace period - why?)
* During learning mouse 0216_4 shows increased activity during the UC phase for CS-

In [None]:
grp = [['context','port','puffed']]
la.plot_epochs(z_spike, z_data, df_lick, data.experiment_traits, etc, data.FPS, grp, title='Population activity')
pp.savefig()
la.plot_epochs(zb_spike, zb_data, b_lick, data.experiment_traits, etc, data.FPS, grp, title='Population activity binned', div=bcenters)
pp.savefig()
la.plot_epochs(za_spike, za_data, a_lick, data.experiment_traits, etc, data.FPS, grp, title='Population activity averaged over events', div=acenters)
pp.savefig()

### Activities conditional on epoch

In [None]:
def plot_by_epoch(epoch):
    '''Plot the activity of the trials belonging to one learning epoch.

    Selects the rows of data.experiment_traits whose 'learning_epoch' equals
    `epoch`, restricts z_spike, z_data, z_raw and df_lick to those trials,
    averages each over the sections in `asections`, and saves two pages to the
    module-level `pp`: the frame-resolved plot and the section-averaged plot.

    The raw tables are restricted and averaged but not passed to the plots,
    and la.plot_data receives the full data.experiment_traits table.
    '''
    in_epoch = data.experiment_traits.loc[:,'learning_epoch']==epoch
    experiment_c = data.experiment_traits[in_epoch]
    trial_ids = experiment_c.index

    # Restrict every table to the trials of this epoch
    spike_c = z_spike.reindex(trial_ids, level='time')
    data_c = z_data.reindex(trial_ids, level='time')
    raw_c = z_raw.reindex(trial_ids, level='time')
    lick_c = df_lick.reindex(trial_ids)
    print (experiment_c.shape, z_spike.shape)

    def section_mean(table):
        # Average the frames of each section, labelled by the section centre
        return la.pd_aggr_col(table, mymean, asections, acenters.astype(str))

    spike_ca = section_mean(spike_c)
    data_ca = section_mean(data_c)
    raw_ca = section_mean(raw_c)
    lick_ca = section_mean(lick_c)
    print (spike_c.shape, spike_ca.shape)

    grp = [['context','port'],['context','puffed'],['port','puffed']]
    la.plot_data(spike_c, data_c, lick_c, data.experiment_traits, data.FPS, grp, title=epoch)
    pp.savefig()
    averaged_title = epoch+' averaged over events'
    la.plot_data(spike_ca, data_ca, lick_ca, data.experiment_traits, data.FPS, grp, title=averaged_title, div=acenters)
    pp.savefig()

#### Pre-learning

In [None]:
plot_by_epoch('Pre-Learning')

#### Learning

In [None]:
plot_by_epoch('Learning')

#### Post-Learning

In [None]:
plot_by_epoch('Post-Learning')

In [None]:
pp.close()

## Activity vector by phases

In [None]:
pp = helpmultipage(animal+'_phases_sp.pdf')

etmp = data.experiment_traits.reset_index(drop=True).set_index(la.sort_learning)

for p,aggr in enumerate(za_data.columns):
    nplot = len(et.index)
    ncol = 14
    nrow = int(np.ceil(len(et.index)/float(ncol)))
    fig, ax = plt.subplots(nrow,ncol,figsize=(2*ncol,1+10*nrow),squeeze=False,sharey=True)
    fig.tight_layout(pad=3, h_pad=3, rect=[0,0,1,0.8])
    fig.suptitle('Spikes, Phase: %s'%la.phases[p],fontsize=16)
    cax = None
    for i, cond in enumerate(et.index):
        col = i%ncol
        row = int((i-col)/ncol)
        sel = etmp.loc[cond,'timestr']
        tmp = a_spike.loc[sel.tolist(),aggr].unstack('time')
        img = ax[row,col].matshow(tmp.values,origin='lower',vmin=0,vmax=1)
        ax[row,col].xaxis.set_ticks_position('bottom')
        ax[row,col].set_title('\n'.join(cond))
        ax[row,col].set_ylabel('Unit ID')
        ax[row,col].set_xlabel('Trial')
    #cax,kw = mpl.colorbar.make_axes([axis for axis in ax.flat])
    cax = ax[-1,-1]
    plt.colorbar(img,ax=cax)#ax=cax,**kw)
    pp.savefig()    
pp.close()

In [None]:
pp = helpmultipage(animal+'_phases_ca.pdf')

etmp = data.experiment_traits.reset_index(drop=True).set_index(la.sort_learning)

for p,aggr in enumerate(za_data.columns):
    nplot = len(et.index)
    ncol = 14
    nrow = int(np.ceil(len(et.index)/float(ncol)))
    fig, ax = plt.subplots(nrow,ncol,figsize=(2*ncol,1+10*nrow),squeeze=False,sharey=True)
    fig.tight_layout(pad=3, h_pad=3, rect=[0,0,1,0.8])
    fig.suptitle('Ca-Signal, Phase: %s'%la.phases[p],fontsize=16)
    cax = None
    for i, cond in enumerate(et.index):
        col = i%ncol
        row = int((i-col)/ncol)
        sel = etmp.loc[cond,'timestr']
        tmp = za_data.loc[sel.tolist(),aggr].unstack('time')
        img = ax[row,col].matshow(tmp.values,origin='lower',vmin=-3,vmax=3)
        ax[row,col].xaxis.set_ticks_position('bottom')
        ax[row,col].set_title('\n'.join(cond))
        ax[row,col].set_ylabel('Unit ID')
        ax[row,col].set_xlabel('Trial')
    #cax,kw = mpl.colorbar.make_axes([axis for axis in ax.flat])
    cax = ax[-1,-1]
    plt.colorbar(img,ax=cax)#ax=cax,**kw)
    pp.savefig()    
pp.close()

### An example of spiking
The first second of the recording seems to be missing

In [None]:
# Order experiments by settings (deprecated)
et3 = data.experiment_traits.copy().reset_index(drop=True)
et3 = et3.sort_values(la.sort_learning+[str('session_num')]).set_index(la.sort_learning)

In [None]:
# Triggers
trig_list_data = [lick_triggers_rise, lick_triggers_fall, spike_triggers_rise, spike_triggers_fall]
trig_list_sign = ['o', 's', '^', 'v']
trig_list_color = ['b', 'y', 'r', 'g']

In [None]:
# Show an example
idx = data.experiment_traits.index[9]
fig, ax = plt.subplots(1,1,figsize=(16,10))
la.draw_levels(ax, z_data, idx, data.FPS, data.roi_df)
ax.set_ylim(ymin=-60,ymax=120)
#plot_spiking_nan(ax, df_spike, idx, data.rois.values)
#la.draw_behavior(ax, data.behavior, idx, data.FPS)
la.draw_population(ax, z_data, idx, pos=-20, c='y', label='population Ca-signal')
la.draw_population(ax, z_spike, idx, pos=-20, threshold=z_spike_threshold, label='population z-spike count')
la.draw_licking(ax, df_lick, idx, pos=-40, threshold=lick_threshold, label='licking')
la.draw_triggers(ax, trig_list_data, idx, -5, trig_list_sign, c=trig_list_color)
la.draw_conditions(ax, data.experiment_traits, idx, data.FPS, height=20)

In [None]:
# Show an example
idx = data.experiment_traits.index[9]
fig, ax = plt.subplots(1,1,figsize=(16,10))
la.draw_transients(ax, data.transients, idx, data.FPS, data.roi_df)
ax.set_ylim(ymin=-60,ymax=len(data.rois)+1)
la.draw_spiking_nan(ax, df_spike, idx, data.roi_df)
#la.draw_behavior(ax, data.behavior, idx, dta.data.FPS)
la.draw_population(ax, z_data, idx, pos=-20, c='y', label='population Ca-signal')
la.draw_population(ax, z_spike, idx, pos=-20, threshold=z_spike_threshold, label='population z-spike count')
la.draw_licking(ax, df_lick, idx, pos=-40, threshold=lick_threshold, label='licking')
la.draw_triggers(ax, trig_list_data, idx, -5, trig_list_sign, c=trig_list_color)
la.draw_conditions(ax, data.experiment_traits, idx, data.FPS, height=20)
ax.legend()

In [None]:
pp = helpmultipage(animal+'_firing.pdf')

xmax = data.transients.loc[:,['stop_frame']].max().values

for idx, val in data.experiment_traits.iterrows(): #et3.iterrows():
    fig, ax = plt.subplots(1,1,figsize=(16,10))
    ax.set_xlim(xmax=xmax)
    experiment_id = val['timestr']
    #print (experiment_id)
    fig.suptitle('learning_epoch, context, port, puffed: #context in epoch, #day\n'+
        '%s: session %s, day %s'%(idx,val['session_num'],val['day_num']),fontsize=16)
    la.draw_transients(ax, data.transients, idx, data.FPS, data.roi_df)
    ax.set_ylim(ymin=-60,ymax=len(data.rois)+1)
    #plot_spiking_nan(ax, df_spike, idx, data.rois.values)
    la.draw_population(ax, z_data, idx, pos=-20, c='y')
    la.draw_population(ax, z_spike, idx, pos=-20, threshold=z_spike_threshold)
    la.draw_licking(ax, df_lick, idx, pos=-40, threshold=lick_threshold)
    la.draw_triggers(ax, trig_list_data, idx, -5, trig_list_sign, c=trig_list_color)
    la.draw_conditions(ax, data.experiment_traits, experiment_id, data.FPS, height=20)
    pp.savefig()
    plt.close(fig)
    
pp.close()

### Peri-event averages

In [None]:
def peri_event_avg(data, triggers, diameter=(-3*data.FPS, 3*data.FPS), allow=None, disable=None):
    '''Collect windows of `data` around the events of a trigger list.

    Parameters
    ----------
    data : DataFrame whose rows are selected with data.loc[experiment_id, :]
        and whose columns are frame numbers.
        NOTE(review): the selection must itself be a DataFrame (e.g. rows
        indexed by (experiment, unit)) because it is reindexed by columns.
    triggers : Series indexed by (experiment_id, frame) pairs; only the index
        is used, the values are ignored.
    diameter : (first, last) frame offsets relative to the event; the window
        is np.arange(first, last). The default is computed once, when the
        function is defined, from the module-level `data.FPS`.
    allow : optional container; events whose index is not in it are skipped.
    disable : optional container; events whose index is in it are skipped.

    Returns
    -------
    (windows, count): `windows` is the column-wise concatenation of all
    collected windows with ('id', 'frame') MultiIndex columns, where 'id'
    numbers the events and 'frame' is the offset within the window. It is
    None when no event qualified. `count` is the number of collected events.
    '''
    window = np.arange(diameter[0],diameter[1])
    count=0
    ret = []
    # Iterate over the index only: the weights are unused and this avoids
    # Series.iteritems(), which was removed in pandas 2.0.
    for idx in triggers.index:
        experiment_id, frame = idx
        if experiment_id not in data.index:
            continue
        if (allow is not None) and (idx not in allow):
            continue
        if (disable is not None) and (idx in disable):
            continue
        # Frames outside the recording become NaN columns through reindex
        tmp = data.loc[experiment_id,:].reindex(columns=frame+window)
        # from_product needs an iterable per level, hence [count]
        tmp.columns = pd.MultiIndex.from_product([[count],window],names=['id','frame'])
        ret.append(tmp)
        count += 1
    if ret:
        return pd.concat(ret,axis=1), count
    return None, count

In [None]:
def get_gauss_window(time_range, rate):
    '''Return Gaussian weights exp(-(rate*t)**2/2), normalised to sum to 1.

    `time_range` selects the time points t:
    * an integer n (Python or numpy): n points centred around zero,
      np.arange(0, n) - 0.5*n;
    * a sequence of exactly two items (first, last): np.arange(first, last);
    * anything else: used as the time points themselves (note that a
      two-element sequence is always read as (first, last)).
    `rate` is converted to float and scales the time points.
    '''
    rate = float(rate)
    # isinstance also accepts numpy integers, which `type(x) is int` rejected
    if isinstance(time_range, (int, np.integer)):
        time_points = np.arange(0,time_range)-0.5*time_range
    elif len(time_range)==2:
        time_points = np.arange(time_range[0],time_range[-1])
    elif isinstance(time_range, (list, tuple)):
        # a plain sequence cannot be multiplied by a float
        time_points = np.asarray(time_range)
    else:
        time_points = time_range
    decay = np.exp(-np.power(rate*time_points,2)/2.0)
    return decay/np.sum(decay)

In [None]:
def get_decay(time_range, rate):
    '''Return two-sided exponential weights exp(-|rate*t|), normalised to sum to 1.

    `time_range` selects the time points t:
    * an integer n (Python or numpy): n points centred around zero,
      np.arange(0, n) - 0.5*n;
    * a sequence of exactly two items (first, last): np.arange(first, last);
    * anything else: used as the time points themselves (note that a
      two-element sequence is always read as (first, last)).
    `rate` is converted to float and scales the time points.
    '''
    rate = float(rate)
    # isinstance also accepts numpy integers, which `type(x) is int` rejected
    if isinstance(time_range, (int, np.integer)):
        time_points = np.arange(0,time_range)-0.5*time_range
    elif len(time_range)==2:
        time_points = np.arange(time_range[0],time_range[-1])
    elif isinstance(time_range, (list, tuple)):
        # a plain sequence cannot be multiplied by a float
        time_points = np.asarray(time_range)
    else:
        time_points = time_range
    decay = np.exp(-np.abs(rate*time_points))
    return decay/np.sum(decay)

In [None]:
def rev_align(data, shape):
    '''Append length-1 axes so `data` broadcasts against `shape` from the first axis.

    numpy matches axes from the end when broadcasting; this helper pads the
    shape of `data` with trailing singleton axes until it has len(shape)
    dimensions, so its existing axes line up with the leading axes of `shape`.
    If `data` already has at least len(shape) dimensions it is returned
    unchanged.
    '''
    missing = len(shape) - np.ndim(data)
    if missing <= 0:
        return data
    # Add all missing axes at once. The former loop re-expanded the original
    # `data` on every pass, so only the last axis was ever added.
    values = np.asanyarray(data)
    return values.reshape(values.shape + (1,)*missing)

In [None]:
def rev_broadcast(data, shape):
    '''Broadcast `data` to `shape`, matching axes from the first one.

    This is the reverse of the numpy convention, which matches from the last
    axis; `data` is first aligned with rev_align and then broadcast.
    '''
    aligned = rev_align(data, shape)
    return np.broadcast_to(aligned, shape)

### Pattern matching

In [None]:
def match_pattern(data, pattern, std, decay, noise_level=0.01, detailed=False):
    '''Score how closely `data` follows `pattern`; larger (closer to 0) is better.

    `pattern`, `std` and `decay` are aligned to the leading axes of `data`
    with rev_align, so one-dimensional inputs run along the rows (time) and
    apply to every column. Each absolute deviation |data - pattern| is weighted
    by decay / (noise_level + std). `noise_level` is added with the ordinary
    numpy broadcasting rules and must be positive wherever std is 0.

    Returns the negated NaN-ignoring mean of the weighted deviations: one
    value per column if `detailed` is true, otherwise a single value over the
    whole array.
    '''
    target_shape = data.shape
    deviation = data - rev_align(pattern, target_shape)
    weight = rev_align(decay, target_shape) / (noise_level + rev_align(std, target_shape))
    reduce_axis = 0 if detailed else None
    return -np.nanmean(np.abs(deviation * weight), axis=reduce_axis)

In [None]:
def correlate_pattern(data, pattern, std, decay, noise_level=0.01, detailed=False):
    '''Score `data` by its weighted elementwise product with `pattern`.

    `pattern`, `std` and `decay` are aligned to the leading axes of `data`
    with rev_align, so one-dimensional inputs run along the rows (time) and
    apply to every column. Each product data*pattern is weighted by
    decay / (noise_level + std). `noise_level` is added with the ordinary
    numpy broadcasting rules and must be positive wherever std is 0.
    This is only a true correlation if the inputs are normalised beforehand.

    Returns the NaN-ignoring mean of the weighted products: one value per
    column if `detailed` is true, otherwise a single value over the whole
    array.
    '''
    target_shape = data.shape
    product = data * rev_align(pattern, target_shape)
    weight = rev_align(decay, target_shape) / (noise_level + rev_align(std, target_shape))
    reduce_axis = 0 if detailed else None
    return np.nanmean(product * weight, axis=reduce_axis)

In [None]:
def rolling2D(df, func, window, min_periods=None, center=True):
    '''Apply a 2D function to a window of rows sliding along a DataFrame.

    Parameters
    ----------
    df : DataFrame that is sliced along its index (rows).
    func : callable receiving a DataFrame slice of up to `window` rows. It may
        return a scalar or a sequence.
    window : number of rows per slice (converted to int, must be >= 1).
    min_periods : minimum number of rows a slice must contain (defaults to
        `window`, must be >= 1). Smaller values add truncated slices at both
        ends of the frame.
    center : if true the result is stored at the row window//2 after the first
        row of the (untruncated) window, otherwise at its first row.

    Returns
    -------
    DataFrame with the index of `df`. The columns are those of `df` when func
    returns one value per column of `df`, a 0..k-1 range for other sequences
    of length k, and a single column 0 for scalars. Rows without a result
    stay NaN.

    Raises ValueError if `window` or `min_periods` is smaller than 1.

    The first window is always evaluated because its result decides the
    output columns.
    '''
    window = int(window)
    if window<1:
        raise ValueError('window needs positive length')
    min_periods = window if min_periods is None else int(min_periods)
    if min_periods<1:
        raise ValueError('min_periods needs to be positive')

    n_rows = len(df)
    # First row of the first and of the last window. `start` is negative when
    # min_periods < window, i.e. the window hangs over the top of the frame.
    start = min_periods-window
    end = n_rows-min_periods
    shift = window//2 if center else 0

    def evaluate(first):
        # Clip at row 0: a negative bound passed to iloc would count from the
        # end of the frame and silently select the wrong rows.
        return func(df.iloc[max(first,0):first+window,:])

    first_result = evaluate(start)
    try:
        n_out = len(first_result)
    except TypeError:
        # scalar result
        columns = pd.Index([0])
    else:
        if n_out==len(df.columns):
            columns = df.columns
        else:
            columns = pd.Index(np.arange(0,n_out))
    ret = pd.DataFrame([], index=df.index, columns=columns)

    def store(first, value):
        # Drop results whose target row falls outside the frame instead of
        # wrapping around (negative) or raising (too large).
        target = first+shift
        if 0<=target<n_rows:
            ret.iloc[target,:] = value

    store(start, first_result)
    for first in range(start+1,end+1):
        store(first, evaluate(first))
    return ret

In [None]:
# Fine-tuned version of method3
def search_pattern(df, triggers, trials, FPS, diam = (-3,3), decay_time=0.1, trigger_allow=None, trigger_disable=None, method='correlate'):
    '''Search every trial for the average activity pattern seen around trigger events.

    The pattern is the mean (and its spread the standard deviation) over all
    peri-event windows collected by peri_event_avg. It is then slid along each
    trial with rolling2D and scored frame by frame.

    Parameters
    ----------
    df : DataFrame passed to peri_event_avg and sliced with df.loc[trial, :].
    triggers : trigger Series passed to peri_event_avg.
    trials : iterable of trial labels to scan.
    FPS : frames per second, converts `diam` and `decay_time` to frames.
    diam : (first, last) window limits around the event, in seconds.
    decay_time : time constant in seconds of the two-sided exponential weight
        built with get_decay.
    trigger_allow, trigger_disable : forwarded to peri_event_avg as `allow`
        and `disable`.
    method : 'correlate' (mean-removed product, correlate_pattern) or
        'match' (negative absolute deviation, match_pattern).

    Returns
    -------
    Float DataFrame with one row per trial and one column per frame.

    Raises ValueError for an unknown method or when no trigger event matches
    the data, since no pattern can be built then.
    '''
    if method not in ('match', 'correlate'):
        raise ValueError('Unaccepted method')
    diam = int(FPS*diam[0]),int(FPS*diam[1])
    window = diam[1]-diam[0]
    decay = get_decay(window,1.0/(decay_time*FPS))
    dd, _ = peri_event_avg(df, triggers, diameter=diam, allow=trigger_allow, disable=trigger_disable)
    if dd is None:
        # peri_event_avg found no usable event
        raise ValueError('No trigger event matched the data, cannot build a pattern.')
    # Average over events for every frame offset (column level 1). groupby on
    # the transposed frame replaces mean/std(axis=1, level=1), which were
    # removed in pandas 2.0, and directly yields the frame-by-row layout.
    by_frame = dd.T.groupby(level=1, sort=False)
    p1 = by_frame.mean().values
    s1 = by_frame.std().values
    if method=='match':
        func = (lambda x: match_pattern(x.values,p1,s1,decay=decay))
    else:
        # remove the overall mean from pattern and data before multiplying
        p1 = p1 - np.nanmean(p1)
        func = (lambda x: correlate_pattern(x.values-x.mean().mean(),p1,s1,decay=decay))
    ret = []
    for trial in trials:
        tmp = rolling2D(df.loc[trial,:].T,func,window,center=True).T
        tmp.index=[trial]
        ret.append(tmp)
    return pd.concat(ret).astype(float)

In [None]:
prog_update = 1467367471
print ("%.0f"%time.time())

In [None]:
# Cache of pattern-search results, one HDF file per animal.
pattdb_file = 'pattdb_'+animal+'.h5'
# Close a store left open by a previous run of this cell
if 'pattdb' in locals():
    pattdb.close()
    del pattdb
# Rebuild the cache when it is missing/unreadable or older than prog_update.
# NOTE: files written before `method` was forwarded to search_pattern hold
# 'correlate' results under the 'match/...' keys as well; delete them or raise
# prog_update to force a rebuild.
if (not la.test_hdf(pattdb_file)) or (os.path.getmtime(pattdb_file)<prog_update):
    # (key suffix, triggers, trigger_allow) for every stored search
    searches = [
        ('lick_rise_csp', lick_triggers_rise, csp_triggers_allow),
        ('lick_fall_csp', lick_triggers_fall, csp_triggers_allow),
        ('csp_rise', csp_triggers_rise, csp_triggers_allow),
        ('us_rise', us_triggers_rise, None),
    ]
    with pd.HDFStore(pattdb_file, mode='w') as pattdb:
        for method,sel in itertools.product(['match','correlate'],['sp','ca']):
            print(method,sel)
            df = df_spike if sel == 'sp' else z_data.reindex(data.mirow)
            for name, triggers, allow in searches:
                key = '/'.join((method,sel,name))
                # method must be passed explicitly, otherwise every key is
                # computed with the default 'correlate'
                pattdb[key] = search_pattern(df, triggers, data.trials, data.FPS,
                                             trigger_allow=allow, method=method)
pattdb = pd.HDFStore(pattdb_file, mode='r')

In [None]:
z_patt = {}
for key in pattdb.keys():
    if key[0] == '/':
        key = key[1:]
    z_patt[key] = la.nan_zscore(pattdb[key])
pattdb

In [None]:
# Overview figure for a single experiment (row 13 of experiment_traits)
method = 'match' # 'match', 'correlate'
sel = 'ca' # 'ca', 'sp'
zoom = 3
idx = data.experiment_traits.index[13]
fig, ax = plt.subplots(1,1,figsize=(16,10))
la.draw_transients(ax, data.transients, idx, data.FPS, data.roi_df)
# room below y=0 is reserved for the overlay traces drawn at pos=-5..-60
ax.set_ylim(ymin=-80,ymax=len(data.rois)+1)
la.draw_spiking_nan(ax, df_spike, idx, data.rois.values)
#la.draw_behavior(ax, data.behavior, idx, data.FPS)
# Pattern-search traces, all drawn at pos=-20:
# (key suffix, colour, threshold, label template)
pattern_traces = (
    ('lick_rise_csp', 'g', [-3, 3], '%s: CS+ lick start'),
    ('lick_fall_csp', 'c', None, '%s: CS+ lick end'),
    ('csp_rise', 'orange', None, '%s: CS+ start'),
    ('us_rise', 'r', None, '%s: US start'),
)
for patt_name, patt_color, patt_threshold, patt_label in pattern_traces:
    la.draw_licking(ax, z_patt['/'.join((method,sel,patt_name))], idx, pos=-20, c=patt_color,
                    threshold=patt_threshold, zoom=zoom, label=patt_label%sel)
# Population signals share pos=-40, licking sits at pos=-60
la.draw_population(ax, z_data, idx, pos=-40, c='y', label='pop. Ca-signal')
la.draw_population(ax, z_spike, idx, pos=-40, threshold=z_spike_threshold, label='pop. z-spike count')
la.draw_licking(ax, df_lick, idx, pos=-60, threshold=lick_threshold, label='licking')
la.draw_triggers(ax, trig_list_data, idx, -5, trig_list_sign, c=trig_list_color)
la.draw_conditions(ax, data.experiment_traits, idx, data.FPS, height=20)
ax.legend()

In [None]:
# One PDF per (method, signal) combination, one page per experiment.
# Pattern-search traces, all drawn at pos=-20:
# (key suffix, colour, threshold, label template)
pattern_traces = (
    ('lick_rise_csp', 'g', [-3, 3], '%s: CS+ lick start'),
    ('lick_fall_csp', 'c', None, '%s: CS+ lick end'),
    ('csp_rise', 'orange', None, '%s: CS+ start'),
    ('us_rise', 'r', None, '%s: US start'),
)
for method, sel in itertools.product(['match','correlate'], ['sp','ca']):
    print(method, sel)

    pp = helpmultipage(animal+'_triggers_%s_%s.pdf'%(method,sel))
    zoom = 4

    # common right-hand x limit: the largest stop_frame of all transients
    xmax = data.transients.loc[:,['stop_frame']].max().values

    for idx, val in data.experiment_traits.iterrows(): #et3.iterrows():
        fig, ax = plt.subplots(1,1,figsize=(16,10))
        ax.set_xlim(xmax=xmax)
        experiment_id = val['timestr']
        page_title = ('learning_epoch, context, port, puffed: #context in epoch, #day\n'+
                      '%s: session %s, day %s'%(idx,val['session_num'],val['day_num']))
        fig.suptitle(page_title, fontsize=16)
        la.draw_transients(ax, data.transients, idx, data.FPS, data.roi_df)
        ax.set_ylim(ymin=-80,ymax=len(data.rois)+1)
        #la.draw_spiking_nan(ax, df_spike, idx, data.rois.values)
        for patt_name, patt_color, patt_threshold, patt_label in pattern_traces:
            la.draw_licking(ax, z_patt['/'.join((method,sel,patt_name))], idx, pos=-20,
                            c=patt_color, threshold=patt_threshold, zoom=zoom,
                            label=patt_label%sel)
        la.draw_population(ax, z_data, idx, pos=-40, c='y', label='population Ca-signal')
        la.draw_population(ax, z_spike, idx, pos=-40, threshold=z_spike_threshold, label='population z-spike count')
        la.draw_licking(ax, df_lick, idx, pos=-60, threshold=lick_threshold, label='licking')
        la.draw_triggers(ax, trig_list_data, idx, -5, trig_list_sign, c=trig_list_color)
        la.draw_conditions(ax, data.experiment_traits, idx, data.FPS, height=20)
        ax.legend()
        # the page is written from the current figure, which is then released
        pp.savefig()
        plt.close(fig)

    pp.close()

## Peri-event plots

In [None]:
import matlab_tools as mt

In [None]:
def plot_peri_event1(ax, df, title=None, pos=-15, zoom=10.0, vmin=None, vmax=None):
    '''Plot df using matshow, with the column-wise mean drawn underneath.

    ax         -- axes to draw on
    df         -- DataFrame; its column labels give the x positions, its rows
                  are stacked along y
    title      -- optional axes title
    pos        -- y position of the baseline of the mean curve
    zoom       -- vertical scale of the mean curve; also the margin kept
                  below `pos`
    vmin, vmax -- colour limits handed to matshow

    Returns the image object created by ax.matshow.
    '''
    columns = df.columns.values
    # x: half a cell around the smallest/largest column label;
    # y: from -0.5 up to len(df)+0.5
    extent = np.min(columns)-0.5, np.max(columns)+0.5, -0.5, len(df)+0.5
    xlim = extent[0:2]
    ylim = (pos-zoom, extent[3])
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ret = ax.matshow(df, origin='lower', aspect='auto', extent=extent, vmin=vmin, vmax=vmax)
    # guide lines: baseline of the mean curve and x=0
    ax.axhline(y=pos,xmin=0.0,xmax=1.0,c='gray')
    ax.axvline(x=0,ymin=0.0,ymax=1.0,c='gray')
    ax.plot(zoom*df.mean(axis=0)+pos)
    ax.xaxis.set_ticks_position('bottom')
    # raw string: '\D' is not a valid escape sequence in a normal literal
    ax.set_xlabel(r'$\Delta$frame')
    ax.set_ylabel('Unit ID')
    # limits are applied again after drawing, in case the calls above moved them
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    if title is not None:
        ax.set_title(title)
    return ret

In [None]:
def plot_peri_event2(ax, df_mean, df_std, title=None, pos=-15, zoom=10.0, vmin=None, vmax=None):
    '''Plot mean and std using color and lightness-encoding

    ax         -- axes to draw on
    df_mean    -- DataFrame of means; its column labels give the x positions.
                  Feeds the first argument of mt.hls_matrix.
    df_std     -- DataFrame of standard deviations, same layout as df_mean.
                  Feeds the second argument of mt.hls_matrix.
    title      -- optional axes title
    pos        -- y position of the baseline of the mean curve
    zoom       -- vertical scale of the mean curve; also the margin kept
                  below `pos`
    vmin, vmax -- accepted so the call signature matches plot_peri_event1;
                  not used here

    Returns the image object created by ax.imshow.
    '''
    columns = df_mean.columns.values
    # x: half a cell around the smallest/largest column label;
    # y: from -0.5 up to len(df_mean)+0.5
    extent = np.min(columns)-0.5, np.max(columns)+0.5, -0.5, len(df_mean)+0.5
    xlim = extent[0:2]
    ylim = (pos-zoom, extent[3])
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    #img = mt.hls_matrix(mt.crop_series(0.4-0.5*df_mean.T.values,(0,0.8)),mt.crop_series(0.5-0.5*df_std.T.values,(0,1)),0.6)
    # Both inputs are rescaled linearly and cropped to (0,0.8) / (0,1) before
    # the conversion; the constant 0.6 is presumably the saturation -- TODO
    # confirm against matlab_tools.hls_matrix.
    img = mt.hls_matrix(mt.crop_series(0.2-0.25*df_mean.T.values,(0,0.8)),mt.crop_series(0.5-0.25*df_std.T.values,(0,1)),0.6)
    ret = ax.imshow(img,interpolation='none',origin='lower',aspect='auto', extent=extent)
    # guide lines: baseline of the mean curve and x=0
    ax.axhline(y=pos,xmin=0.0,xmax=1.0,c='gray')
    ax.axvline(x=0,ymin=0.0,ymax=1.0,c='gray')
    ax.plot(zoom*df_mean.mean(axis=0)+pos)
    ax.xaxis.set_ticks_position('bottom')
    # raw string: '\D' is not a valid escape sequence in a normal literal
    ax.set_xlabel(r'$\Delta$frame')
    ax.set_ylabel('Unit ID')
    # limits are applied again after drawing, in case the calls above moved them
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    if title is not None:
        ax.set_title(title)
    return ret

In [None]:
def plot_peri_collection(collection, title=None, combine=True):
    '''Draw a grid of peri-event panels, one entry of `collection` at a time.

    collection -- sequence of [df, trig, allow, disable, title] entries; each
                  is passed to peri_event_avg(df, trig, allow=..., disable=...)
    title      -- figure title
    combine    -- True: one panel per entry, mean and std encoded together
                  (plot_peri_event2); False: two panels per entry, mean and
                  stdev side by side (plot_peri_event1)

    Two narrow axes on the right show the colour scales. Entries for which
    peri_event_avg returns a falsy second value leave their panels empty.
    Returns the figure.
    '''
    max_cols = 10
    num_plots = len(collection) * (1 if combine else 2)
    num_rows = int(np.ceil(num_plots/float(max_cols)))
    num_cols = max_cols if num_rows>1 else num_plots
    fig, ax = plt.subplots(num_rows, num_cols, figsize=(2*num_cols+2,12*num_rows), sharex=True, sharey=True, squeeze=False)
    ax = np.ravel(ax)
    fig.tight_layout(rect=(0,0,0.9,0.9), w_pad=2, h_pad=8)

    # Colour scales: a 256-step ramp from -1 to 1 drawn with the same
    # plotting function as the panels
    gradient = np.linspace(-1, 1, 256)
    gradient = pd.DataFrame(np.vstack((gradient, gradient)).T, columns=[2,3])
    cax1 = fig.add_axes([0.92, 0.1, 0.02, 0.8])
    cax2 = fig.add_axes([0.96, 0.1, 0.02, 0.8])
    if combine:
        plot_peri_event2(cax1, gradient, gradient*0, 'Mean', vmin=-1, vmax=1)
        plot_peri_event2(cax2, gradient*0, gradient+1, 'Std', vmin=-1, vmax=1)
    else:
        # was plot_peri_event(...), a name this file does not define; the
        # panels below use plot_peri_event1 for the same purpose
        plot_peri_event1(cax1, gradient, 'Mean', vmin=-1, vmax=1)
        plot_peri_event1(cax2, gradient+1, 'Stdev', vmin=0, vmax=2)
    cax1.set_ylabel(''); cax1.set_yticks([0,128,256]); cax1.set_yticklabels([-1,0,1])
    cax2.set_ylabel(''); cax2.set_yticks([0,128,256]); cax2.set_yticklabels([0,1,2])
    cax1.set_xlabel(''); cax1.set_xticks([])
    cax2.set_xlabel(''); cax2.set_xticks([])
    fig.suptitle(title,fontsize=16)

    # `label` instead of `title`, so the function argument is not shadowed
    for i, (df, trig, allow, disable, label) in enumerate(collection):
        # c is printed as an integer in the panel title; presumably the
        # number of events that were averaged -- TODO confirm
        dd, c = peri_event_avg(df, trig, allow=allow, disable=disable)
        if c:
            dd = dd.reindex(data.rois)
            if combine:
                plot_peri_event2(ax[i], dd.mean(axis=1, level=1), dd.std(axis=1, level=1), '%s: %d'%(label,c), vmin=-1, vmax=1)
            else:
                plot_peri_event1(ax[2*i], dd.mean(axis=1, level=1), '%s: %d'%(label,c), vmin=-1, vmax=1)
                plot_peri_event1(ax[2*i+1], dd.std(axis=1, level=1), 'Stdev', vmin=0, vmax=2)

    return fig

In [None]:
def list_peri_3a(df, title=None):
    '''Plot collection: licking, CS+ and US triggers.

    Returns a list of [df, trig, allow, disable, title] entries in the form
    expected by plot_peri_collection. The `title` argument is not used.
    '''
    return [
        # licking, unrestricted
        [df, lick_triggers_rise, None, None, 'Lick rise'],
        [df, lick_triggers_fall, None, None, 'Lick fall'],
        # licking restricted by csp_triggers_allow, then the CS+ triggers
        [df, lick_triggers_rise, csp_triggers_allow, None, 'Lick rise CS+'],
        [df, lick_triggers_fall, csp_triggers_allow, None, 'Lick fall CS+'],
        [df, csp_triggers_rise, None, None, 'CS+ start'],
        [df, csp_triggers_fall, None, None, 'CS+ end'],
        # licking restricted by us_triggers_allow, then the US triggers
        [df, lick_triggers_rise, us_triggers_allow, None, 'Lick rise US'],
        [df, lick_triggers_fall, us_triggers_allow, None, 'Lick fall US'],
        [df, us_triggers_rise, None, None, 'US start'],
        [df, us_triggers_fall, None, None, 'US end'],
    ]

In [None]:
def list_peri_3b(df, title=None):
    '''Plot collection: licking, CS+ and CS- triggers.

    Returns a list of [df, trig, allow, disable, title] entries in the form
    expected by plot_peri_collection. The `title` argument is not used.
    '''
    return [
        # licking, unrestricted
        [df, lick_triggers_rise, None, None, 'Lick rise'],
        [df, lick_triggers_fall, None, None, 'Lick fall'],
        # licking restricted by csp_triggers_allow, then the CS+ triggers
        [df, lick_triggers_rise, csp_triggers_allow, None, 'Lick rise CS+'],
        [df, lick_triggers_fall, csp_triggers_allow, None, 'Lick fall CS+'],
        [df, csp_triggers_rise, None, None, 'CS+ start'],
        [df, csp_triggers_fall, None, None, 'CS+ end'],
        # licking restricted by csm_triggers_allow, then the CS- triggers
        [df, lick_triggers_rise, csm_triggers_allow, None, 'Lick rise CS-'],
        [df, lick_triggers_fall, csm_triggers_allow, None, 'Lick fall CS-'],
        [df, csm_triggers_rise, None, None, 'CS- start'],
        [df, csm_triggers_fall, None, None, 'CS- end'],
    ]

In [None]:
# Peri-event report with separate mean / stdev panels (combine=False).
# Page order: for every epoch the two spiking pages, then for every epoch
# the two Ca-level pages.
pp = helpmultipage(animal+'_peri1.pdf')
for use_spikes, captions in (
        (True, ('%s Spiking on US', '%s Spiking on CS+/-')),
        (False, ('%s Ca-level on US', '%s Ca-level on CS+/-'))):
    for epoch in la.epochs.values:
        # experiments of this learning epoch and the matching signal rows
        experiment_c = data.experiment_traits[data.experiment_traits.loc[:,'learning_epoch']==epoch]
        spike_c = df_spike.reindex(experiment_c.index, level='time')
        data_c = z_data.reindex(experiment_c.index, level='time')
        signal_c = spike_c if use_spikes else data_c
        for make_list, caption in zip((list_peri_3a, list_peri_3b), captions):
            fig = plot_peri_collection(make_list(signal_c), caption%epoch, combine=False)
            pp.savefig()
            plt.close(fig)
pp.close()

In [None]:
# Peri-event report with mean and std combined in one panel per trigger
# (plot_peri_collection default). Page order: for every epoch the two
# spiking pages, then for every epoch the two Ca-level pages.
pp = helpmultipage(animal+'_peri2.pdf')
for use_spikes, captions in (
        (True, ('%s Spiking on US', '%s Spiking on CS+/-')),
        (False, ('%s Ca-level on US', '%s Ca-level on CS+/-'))):
    for epoch in la.epochs.values:
        # experiments of this learning epoch and the matching signal rows
        experiment_c = data.experiment_traits[data.experiment_traits.loc[:,'learning_epoch']==epoch]
        spike_c = df_spike.reindex(experiment_c.index, level='time')
        data_c = z_data.reindex(experiment_c.index, level='time')
        signal_c = spike_c if use_spikes else data_c
        for make_list, caption in zip((list_peri_3a, list_peri_3b), captions):
            fig = plot_peri_collection(make_list(signal_c), caption%epoch)
            pp.savefig()
            plt.close(fig)
pp.close()

In [None]:
# Inline preview on the unfiltered df_spike: first with separate mean / stdev
# panels, then with both combined in one panel per trigger.
for combine_panels in (False, True):
    fig = plot_peri_collection(list_peri_3a(df_spike), 'Spiking', combine=combine_panels)

## Individual ROIs
* Since there are many ROIs, the figures are saved to PDF files
* THIS WILL <font color="red">TAKE A WHILE</font>, consider testing with a small range

In [None]:
def plot_roi(df_spike, df_data, filaname, grp, title_template, by_epoch=False, div=None, fill=None):
    '''Write one PDF page per ROI to `filaname`.

    df_spike, df_data -- frames whose second index level holds the ROI id
    filaname          -- path of the PDF to create
    grp               -- grouping specification handed on to the la plotter
    title_template    -- %-template filled with (ROI position, ROI id)
    by_epoch          -- True: la.plot_epochs, False: la.plot_data
    div, fill         -- handed on unchanged to the la plotter
    '''
    pp = PdfPages(filaname)
    for i in range(len(data.rois)):
        roi = data.rois[i]
        # every row of this ROI, whatever the first index level
        roi_rows = (slice(None), roi)
        spike_c = df_spike.loc[roi_rows,:]
        data_c = df_data.loc[roi_rows,:]
        #raw_c = df_raw.loc[roi_rows,:]
        heading = title_template%(i,roi)
        if not by_epoch:
            fig = la.plot_data(spike_c, data_c, None, data.experiment_traits, data.FPS, grp, title=heading, div=div, fill=fill)
        else:
            fig = la.plot_epochs(spike_c, data_c, None, data.experiment_traits, etc, data.FPS, grp, title=heading, div=div, fill=fill)
        # the page is written from the current figure, which is then released
        pp.savefig()
        plt.close(fig)
    pp.close()

In [None]:
# Execution stops here with an exception unless an animal was supplied
# through the `batch_animal` environment variable.
# NOTE(review): presumably meant to keep an interactive "Run All" from
# starting the slow per-ROI PDF cells below -- confirm.
if batch_animal is None:
    raise ValueError("You don't want to run this automatically")

#### Raw data

In [None]:
# Per-ROI pages from df_spike / df_data, with four single-trait groupings
plot_roi(
    df_spike, df_data, animal+'_roi1crit.pdf',
    grp=[['context'],['learning_epoch'],['port'],['puffed']],
    title_template='ROI %d:\n%s'
)

In [None]:
# Per-ROI pages from df_spike / df_data, drawn per epoch with one combined
# context/port/puffed grouping
plot_roi(
    df_spike, df_data, animal+'_roiAcrit.pdf',
    grp=[['context','port','puffed']],
    title_template='ROI %d:\n%s',
    by_epoch=True
)

### Averaging over intervals

#### Intervals aligned to events

In [None]:
# Per-ROI pages from the a_spike / a_data averages, with four single-trait
# groupings; acenters is passed as `div`
plot_roi(
    a_spike, a_data, animal+'_avg1crit.pdf',
    grp=[['context'],['learning_epoch'],['port'],['puffed']],
    title_template='ROI %d:\n%s',
    div=acenters
)

In [None]:
# Per-ROI pages from the a_spike / a_data averages, drawn per epoch with one
# combined grouping; acenters is passed as `div`, fill mode is 'err'
plot_roi(
    a_spike, a_data, animal+'_avgAcrit.pdf',
    grp=[['context','port','puffed']],
    title_template='ROI %d:\n%s',
    by_epoch=True,
    div=acenters,
    fill='err'
)

#### Averaging over bins

In [None]:
# Per-ROI pages from the b_spike / b_data bin averages, with four
# single-trait groupings; bcenters is passed as `div`
plot_roi(
    b_spike, b_data, animal+'_bin1crit.pdf',
    grp=[['context'],['learning_epoch'],['port'],['puffed']],
    title_template='ROI %d:\n%s',
    div=bcenters
)

In [None]:
# Per-ROI pages from the b_spike / b_data bin averages, drawn per epoch with
# one combined grouping. This cell used a_spike / a_data / acenters, which
# made '_binAcrit.pdf' a copy of '_avgAcrit.pdf'; it now uses the binned
# data like the '_bin1crit.pdf' cell of this section.
plot_roi(b_spike, b_data, animal+'_binAcrit.pdf',[['context','port','puffed']],'ROI %d:\n%s',True,div=bcenters, fill='err')

## Correlations

In [None]:
# Combine information
ord1 = z_data.reindex(data.mirow, data.icol)
et1 = data.experiment_traits.copy().loc[:,la.sort_learning+['day_num','session_num']]
ord1 = ord1.join(et1, how='inner').reset_index().drop('time', axis=1).set_index(la.sort_learning+['roi_id', 'session_num']).sort_index()
ord1.columns.name='Spike'
print(ord1.shape)
display(ord1.head())

# Search for days that contain experiments with same traits and session_num
# These entries would jeopardize unstacking
et2 = et1.reset_index(drop=True).set_index(la.sort_learning+['session_num']).sort_index()
second_occur = et2.index.duplicated()
set1 = et2.loc[second_occur,'day_num'].unique()
all_occur = et2.index.get_duplicates()
set_all = et2.loc[all_occur,'day_num'].unique()
set2 = np.array(list(set(set_all)-set(set1)))
print(set1,set2)

# Filter out second occurrences stored in set2
if len(set2):
    ord1 = ord1[ord1.loc[:,'day_num'].apply(lambda x: x not in set2)]
print(ord1.shape)

# Reshape for correlation analysis
# integer values get converted to float if needed to hold NaN-s
calendar = ord1['day_num'].unstack(fill_value=0)
ord1 = ord1.drop(['day_num'], axis=1).unstack()
display(calendar.head())
display(ord1.head(10))

In [None]:
# Find the pre-learning structure, without airpuff
key_ref = ('Post-Learning','CS+','W+','A+')
time_ref = np.array([15, 40])
col_ref = slice(int(time_ref[0]*data.FPS),int(time_ref[1]*data.FPS))
sel = ord1.loc[key_ref+(slice(None),),col_ref]
print(key_ref,time_ref,col_ref,sel.shape)

# Correlate
corr_df = sel.T.corr()
corr_np = corr_df.fillna(0).values

# Discard invalid series
keep = (np.diag(corr_np) == 1.0)
corr_np = corr_np[keep,:][:,keep]

# Show
fig, ax = plt.subplots(1,2, figsize=(12,4))
img = ax[0].matshow(corr_df.values)
img = ax[1].matshow(corr_np)
fig.colorbar(img, ax=ax[1])

In [None]:
# Multi-page PDF that collects the correlation figures of the cells below
pp = helpmultipage('%s_correl.pdf' % animal)

In [None]:
# Define an ordering
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform, pdist
sq_dist = squareform(1.0-corr_np)
corr_link = linkage(sq_dist, 'average')
fig, ax = plt.subplots(1,2, figsize=(18,8))
fig.suptitle('Reference is presented here: '+(', '.join(np.array(key_ref)))+
            ' and time '+('..'.join(time_ref.astype(str)))+'s',fontsize=16)
labels = sel.index.get_level_values(4).to_series().reset_index(drop=True)[keep]
dendo = dendrogram(corr_link, ax=ax[1], labels=labels.values, leaf_font_size=2.5, orientation='left')
ax[1].set_title('Distance of firing patterns')
corr_order = dendo['leaves']
# Show reordered
img = ax[0].matshow(corr_np[corr_order,:][:,corr_order], origin='lower', vmin=-0.8, vmax=1)
ax[0].xaxis.set_ticks_position('bottom')
ax[0].set_title('Ordered correlation matrix', y=1.0)
fig.colorbar(img)
pp.savefig(dpi=600)

In [None]:
# Column window of each phase: it starts FPS frames after an event frame and
# ends FPS frames before the following one.
# NOTE(review): assumes phase_end[col] addresses the event that follows
# phase_start[col] -- confirm the type / index of data.event_frames.
phase_start = data.event_frames+data.FPS
phase_end = data.event_frames[1:]-data.FPS

# Grid of panels: one row per entry of et.index, one column per phase
num_phases = 3
num_rows = len(et.index)
num_cols = num_phases
fig, ax = plt.subplots(num_rows,num_cols, figsize=(5*num_cols,5*num_rows))
fig.suptitle('Correlation structure under different conditions: learning_epoch, context, port, puffed\n'+
             '(small number of trials might lead to larger percieved correlation)\n'+
             '(in phases Ready, CS, Trace the conditions A+ and A- should be very similar)',fontsize=16)
#ax = np.ravel(ax)
# mx: ordered correlation matrix per condition key + phase
mx = {}
# ds: the same matrices flattened, one column per condition key + phase
mi = pd.DataFrame([], index=pd.Index(la.phases[0:num_phases],name='phase'), columns = et.index).unstack().index
ds = pd.DataFrame(columns=mi)

for row,col in itertools.product(range(0,num_rows),range(0,num_cols)):
    # Select the ord1 rows of this condition and the columns of this phase
    key = et.index[row]
    phase = la.phases[col]
    # NOTE(review): .ix is deprecated and absent from pandas >= 1.0
    count = et.ix[row]
    col_sel = slice(int(phase_start[col]),int(phase_end[col]))
    sel = ord1.loc[key+(slice(None),),col_sel]
    print(key,phase,ord1.shape,sel.shape)
    
    # Correlate; undefined coefficients become 0
    corr_tmp = sel.T.corr()
    corr_tmp = corr_tmp.fillna(0).values

    # Keep the rows that were valid in the reference (keep) and arrange them
    # in the reference dendrogram order (corr_order)
    #if len(corr_tmp):
    corr_tmp = corr_tmp[keep,:][:,keep][corr_order,:][:,corr_order]
    img = ax[row,col].matshow(corr_tmp, origin='lower', vmin=-0.8, vmax=1)
    ax[row,col].xaxis.set_ticks_position('bottom')
        
    mx[key+(phase,)] = corr_tmp
    # the diagonal is replaced by NaN before flattening
    ds[key+(phase,)] = np.ravel(corr_tmp+np.diag(np.nan*np.diag(corr_tmp)))
    ax[row,col].set_title('%s, %s: %d'%(key,phase,count))
pp.savefig(dpi=600)

## Statistics of the correlation coefficients

In [None]:
# Histogram grid with the same layout as the correlation matrices:
# one row per entry of et.index, one column per phase
num_rows = len(et.index)
num_cols = num_phases

fig, ax = plt.subplots(num_rows,num_cols, figsize=(5*num_cols,5*num_rows))
fig.suptitle('Distribution of the above correlation coefficients\n(diagonals excluded)',fontsize=16)
#ax = np.ravel(ax)

for row in range(num_rows):
    for col in range(num_cols):
        key = et.index[row]
        phase = la.phases[col]
        count = et.ix[row]
        col_sel = slice(int(phase_start[col]),int(phase_end[col]))
        sel = ord1.loc[key+(slice(None),),col_sel]
        print(key,phase,ord1.shape,sel.shape)

        # stored matrix of this panel, diagonal replaced by NaN
        corr_tmp = mx[key+(phase,)]
        corr_tmp = corr_tmp+np.diag(np.nan*np.diag(corr_tmp))
        panel = ax[row,col]
        panel.hist(np.ravel(corr_tmp),range=(-1,1),bins=20)
        panel.set_yscale('log')
        panel.set_title('%s, %s: %d'%(key,phase,count))
    

### Real value

In [None]:
#help(pd.tools.plotting.table)
# FIXME index column toooo wide
fig, ax = plt.subplots(1,1,figsize=(12,16))
fig.suptitle('Statistics on the correlation coefficients',fontsize=16)
ax.axis('off')
#ax.set_position([.5, 0.2, 0.5, 0.6])
a = la.df_epoch(np.round(ds.describe(),4).T)
cw = np.ones((len(a.columns),))
#t = pd.tools.plotting.table(ax, a, loc='upper right', fontsize=12, colWidths=0.6*cw/np.sum(cw))
tab = mpl.table.table(ax, cellText=a.values,
                             rowLabels=[', '.join(x) for x in a.index.values], colLabels=a.columns.values.astype(str),
                             loc='upper right', fontsize=20, colWidths=0.6*cw/np.sum(cw), bbox=[0.3,0,0.7,1], cellLoc='center')
pp.savefig()
plt.close(fig)
a

In [None]:
# Bar charts of the statistics table `a`: one row of panels per phase, one
# column per learning epoch, one bar per legal condition
num_rows = num_phases
num_cols = len(la.epochs)
fig, ax = plt.subplots(num_rows,num_cols, figsize=(5*num_cols,5*num_rows), sharex=True, sharey=True)
fig.suptitle('Distribution of the above correlation coefficients\n(diagonals excluded)',fontsize=16)
mi = pd.MultiIndex.from_tuples(la.legal_conditions,names=['context','port','puffed'])
color = ['y', 'magenta','purple','red','maroon','cyan','lime']
cat = len(mi)
b = a.reset_index().set_index(['phase']+la.sort_learning)
for irow, row in enumerate(la.phases):
    for icol, col in enumerate(la.epochs):
        try:
            bar = b.loc[(row, col),:].reindex(mi)
            panel = ax[irow,icol]
            panel.set_title(col)
            panel.set_ylabel(row)
            # grey band: 25%..75% quantiles (missing conditions count as 0)
            low = [0]+bar['25%'].fillna(0).tolist()
            high = [0]+bar['75%'].fillna(0).tolist()
            panel.fill_between(np.arange(0,cat+1), low, high, alpha=0.1, interpolate=False, color='grey', edgecolor=None, step='pre')
            # bars: mean with the std as error bar
            panel.bar(range(0,cat),bar['mean'],1,yerr=bar['std'],color=color)
            panel.set_xticks(np.arange(0,cat)+0.5)
            panel.set_xticklabels(mi.values, rotation='vertical')
        except KeyError:
            # a (phase, epoch) pair that raises KeyError keeps an empty panel
            pass
pp.savefig()

### Absolute value

In [None]:
#help(pd.tools.plotting.table)
# FIXME index column toooo wide
fig, ax = plt.subplots(1,1,figsize=(12,16))
fig.suptitle('Statistics on the absolute value of correlation coefficients',fontsize=16)
ax.axis('off')
#ax.set_position([.5, 0.2, 0.5, 0.6])
a = la.df_epoch(np.round(np.abs(ds).describe(),4).T)
cw = np.ones((len(a.columns),))
#t = pd.tools.plotting.table(ax, a, loc='upper right', fontsize=12, colWidths=cw/np.sum(cw))
tab = mpl.table.table(ax, cellText=a.values,
                             rowLabels=[', '.join(x) for x in a.index.values], colLabels=a.columns.values.astype(str),
                             loc='upper right', fontsize=20, colWidths=0.6*cw/np.sum(cw), bbox=[0.3,0,0.7,1], cellLoc='center')
pp.savefig()
plt.close(fig)
a

In [None]:
# Bar charts of the absolute-value statistics table `a`: one row of panels
# per phase, one column per learning epoch, one bar per legal condition
num_rows = num_phases
num_cols = len(la.epochs)
fig, ax = plt.subplots(num_rows,num_cols, figsize=(5*num_cols,5*num_rows), sharex=True, sharey=True)
fig.suptitle('Distribution of the above absolute value of the correlation coefficients\n(diagonals excluded)',fontsize=16)
mi = pd.MultiIndex.from_tuples(la.legal_conditions,names=['context','port','puffed'])
color = ['y', 'magenta','purple','red','maroon','cyan','lime']
cat = len(mi)
b = a.reset_index().set_index(['phase']+la.sort_learning)
for irow, row in enumerate(la.phases):
    for icol, col in enumerate(la.epochs):
        try:
            bar = b.loc[(row, col),:].reindex(mi)
            panel = ax[irow,icol]
            panel.set_title(col)
            panel.set_ylabel(row)
            # grey band: 25%..75% quantiles (missing conditions count as 0)
            low = [0]+bar['25%'].fillna(0).tolist()
            high = [0]+bar['75%'].fillna(0).tolist()
            panel.fill_between(np.arange(0,cat+1), low, high, alpha=0.1, interpolate=False, color='grey', edgecolor=None, step='pre')
            # bars: mean with the std as error bar
            panel.bar(range(0,cat),bar['mean'],1,yerr=bar['std'],color=color)
            panel.set_xticks(np.arange(0,cat)+0.5)
            panel.set_xticklabels(mi.values, rotation='vertical')
        except KeyError:
            # a (phase, epoch) pair that raises KeyError keeps an empty panel
            pass
pp.savefig()

## Similarity of correlation matrices

In [None]:
# Pairwise distance between the correlation matrices of all conditions,
# computed separately for every phase
num_rows = len(et.index)
num_cols = num_phases

change = np.zeros((num_cols,num_rows,num_rows))
for col, row1, row2 in itertools.product(range(num_cols), range(num_rows), range(num_rows)):
    phase = la.phases[col]
    key1, key2 = et.index[row1], et.index[row2]
    count1, count2 = et.ix[row1], et.ix[row2]
    m1, m2 = mx[key1+(phase,)], mx[key2+(phase,)]
    # NOTE(review): this is the 2-norm of the difference divided by the
    # number of elements N; an RMS distance, as the figure title calls it,
    # would divide by sqrt(N) -- confirm which one is intended.
    change[col,row1,row2] = np.linalg.norm(np.ravel(m1-m2)/np.size(m2),2)

# One page per phase; the diagonal (distance to itself) is replaced by NaN
for col in range(num_cols):
    fig, ax = plt.subplots(1,1, figsize=(8,6))
    fig.tight_layout(rect=[0.4,0,0.95,0.55])
    #fig = plt.figure()
    #ax = fig.gca()
    distances = change[col]+np.diag(np.nan*np.diag(change[col]))
    img = ax.matshow(distances, cmap=plt.get_cmap('rainbow'))

    fig.suptitle('Difference between test cases (RMS distance)\n'+', '.join(et.index.names),fontsize=16)
    tick_pos = np.array(range(0,len(et.index)))
    tick_labels = et.index.values.tolist()
    ax.set_xticks(tick_pos)
    ax.set_xticklabels(tick_labels,rotation=90)
    ax.set_yticks(tick_pos)
    ax.set_yticklabels(tick_labels)

    # without set_yticks
    # ax.set_yticklabels([tuple()]+et.index.values.tolist())
    fig.colorbar(img)
    pp.savefig()

In [None]:
# Close the multi-page object `pp` that the correlation cells above saved
# their figures to.
pp.close()