### Requirements
* #### You need to install module future, manual importing from \_\_future\_\_ is at your convenience
* #### For hdf data import you need pytables too which is not default installed with Anaconda

In [None]:
#from future.utils import PY3
import future
from __future__ import (absolute_import, division,
                        print_function, unicode_literals)
import pandas as pd
import numpy as np
import os
import pprint
#from IPython.display import display
import IPython.display as disp
display = disp.display
import matplotlib.pyplot as plt
import scipy.stats as stats
zscore, describe = stats.mstats.zscore, stats.describe
import warnings
import datetime
dt, td = datetime.datetime, datetime.timedelta
import imp

%matplotlib inline

In [None]:
import ca_lib as la
imp.reload(la)

## Load files

In [None]:
# Root folder of the database; each sub-folder is handled as one animal below
basedir = '../_share/Losonczi/'

# Display database folders
display(os.listdir(basedir))

# Select animal
# Each choice also fixes FPS, which later cells use to convert seconds to camera frames.
# Exactly one of the lines below should be left uncommented.
#animal = 'msa1215_1'; FPS = 30
#animal = 'msa0216_4'; FPS = 8
#animal = 'msa0316_1'; FPS = 8
#animal = 'msa0316_3'; FPS = 8
animal = 'msa0316ag_1'; FPS = 8

# List dir
mydir = os.path.join(basedir,animal)
os.listdir(mydir)

In [None]:
# Available trials and ROIs
# la.load_files comes from ca_lib (not shown in this notebook). The object it returns is used
# throughout the notebook via attributes such as raw, filtered, trials, rois, transients,
# behavior and experiment_traits.
data = la.load_files(mydir)
print (data.raw.shape, '\n', data.trials, '\n', data.rois)

In [None]:
# Post-Learning may repeat session_num therefore an additional index,
# day_num is created. See msa0316_1.
# It seems though that Pre-Learning and Learning treats session_num as documented.
# First show the head of the per-trial table, then only the rows whose 'day_leap' flag is set.
display(data.experiment_traits.head())
display(data.experiment_traits[data.experiment_traits['day_leap']])

## Experiment protocol configurations

In [None]:
# Count trials per protocol configuration, grouped by the columns listed in la.display_learning
et = data.experiment_traits.copy()
et = la.df_epoch(et.groupby(la.display_learning).size().to_frame(name='count'))
#et.to_clipboard()
disp.display(disp.HTML('<font color="red">ATTENTION, </font>for later conformity we store columns in a <b>different order</b>: %s !!!'%la.sort_learning))
# NOTE(review): la.df_epoch was already applied two lines above and is applied again here - confirm this is intended
display(la.df_epoch(et))

# Recompute the counts with the la.sort_learning column order; these are the tables used later:
#   etc - counts grouped by all but the first key of la.sort_learning (passed to la.plot_epochs)
#   et  - counts grouped by every key of la.sort_learning (its index drives the phases PDF)
et = data.experiment_traits.copy()
etc= et.groupby(la.sort_learning[1:]).size().to_frame(name='count')
et = la.df_epoch(et.groupby(la.sort_learning).size().to_frame(name='count'))

## Prepare data

In [None]:
# Short aliases for the filtered and the raw tables of the loaded data set
df_data, df_raw = data.filtered, data.raw

In [None]:
# See how many ROIs are available for which frames

# Per column (camera frame): number of non-null entries divided by the number of trials
avail_sum = (~data.filtered.isnull()).sum() / len(data.trials)
plt.plot(avail_sum)
plt.xlabel('Camera frame within experiment')
plt.ylabel('Available ROIs on average')

In [None]:
# See which ROI is available in which trial and for how many frames

# Per row: number of non-null frames; unstack() then moves the innermost index level to columns
avail = ((~data.filtered.isnull()).sum(axis=1)).to_frame('nFrames').unstack()

print(avail.shape)
display(avail.head())
display(avail.tail())

In [None]:
# Create boolean DataFrame which ROI is spiking in which camera frame
#
# Technique: put +1 at every transient's start frame and -1 at its stop frame, then take the
# cumulative sum along the frame axis, so frames in [start_frame, stop_frame) accumulate to 1.

# create empty structure for cumsum
df_template = pd.DataFrame(data=0,index=data.mirow,columns=data.icol)
df_spike = df_template.copy()

# select spike data
# only transients that are not flagged as being in a motion period are kept
spikes = data.transients.loc[data.transients['in_motion_period']==False,['start_frame','stop_frame']]
spikes['count']=1

# fill in spike start and stop points (rename column to keep columns.name in df_spike)
sp = spikes[['start_frame','count']].rename(columns={'start_frame':'frame'}).pivot(columns='frame').fillna(0)
df_spike = df_spike.add(sp['count'], fill_value=0)
sp = spikes[['stop_frame','count']].rename(columns={'stop_frame':'frame'}).pivot(columns='frame').fillna(0)
df_spike = df_spike.add(-sp['count'], fill_value=0)

# cumulate, conversion to int is not advised if using NaNs
df_spike = df_spike.cumsum(axis=1).astype(int)
# presumably time_roi_mask holds NaN where a ROI is unavailable and 0 elsewhere - TODO confirm in ca_lib
df_spike = df_spike + data.time_roi_mask

print('table shape', df_spike.shape, 'active frames*ROIs', df_spike.sum().sum())
display(df_spike.head(25))
display(df_spike.tail())

In [None]:
# Create boolean DataFrame whether licking happens in camera frame

# Check for valid data and calculate their frame
print('All entries', data.behavior.shape)
# a lick is kept only if its stop_time is strictly later than its start_time
df_lick = data.behavior[data.behavior.loc[:,'stop_time']>data.behavior.loc[:,'start_time']].copy()
print('Valid licks', df_lick.shape)
# frame of a lick = rounded (midpoint of start and stop time) * FPS
df_lick['frame'] = (FPS*(df_lick['start_time']+df_lick['stop_time'])/2).apply(np.round).astype(int)
display(df_lick.head())
display(df_lick.tail())
# Convert to a DataFrame like df_data or df_raw
df_lick = df_lick[['lick_idx','frame']].reset_index()
# count licks per (trial, frame) and spread the frames over the columns
df_lick = df_lick.groupby(['time','frame']).count().unstack(fill_value=0)
display(df_lick.head())
# align to the full trial list and frame axis; licks falling outside data.icol are dropped here
df_lick = df_lick['lick_idx'].reindex(index=data.mirow.levels[0],columns=data.icol,fill_value=0)
display(df_lick.head())
# Number of remaining licks
print('Remaining licks',df_lick.sum().sum())
# Smoothen
# Counts are multiplied by data.FPS and smoothed with a Gaussian of sigma = 0.25*data.FPS frames.
# NOTE(review): this cell mixes the notebook-level FPS with data.FPS - confirm they are equal.
# NOTE(review): newer SciPy exposes this function as scipy.ndimage.gaussian_filter - verify for the installed version.
from scipy.ndimage.filters import gaussian_filter
df_lick = df_lick.apply(lambda x: gaussian_filter(x.astype(float)*data.FPS, sigma=0.25*data.FPS), axis=1, raw=True)
display(df_lick.head())

## z-scoring

In [None]:
# z-score every table along the frame axis with the ca_lib helpers.
# NOTE(review): the meaning of the (FPS, -2*FPS) arguments is defined in ca_lib and not visible
# here; presumably they delimit the frames used for the z-score statistics - TODO confirm.
z_spike = la.pd_zscore_by_roi(df_spike, FPS, -2*FPS, axis=1)
z_data = la.pd_zscore_by_roi(df_data, FPS, -2*FPS, axis=1)
z_raw = la.pd_zscore_by_roi(df_raw, FPS, -2*FPS, axis=1)
# licking uses a different helper (pd_zscore_clip) than the per-ROI tables
z_lick = la.pd_zscore_clip(df_lick, FPS, -2*FPS, axis=1)

# sort the row index of the activity tables (later cells select rows by label)
z_data = z_data.sort_index()
z_raw = z_raw.sort_index()

### Triggers

In [None]:
def trigger(data, threshold, rising=True, hold_off=None):
    '''Mark the samples at which `data` crosses `threshold`.

    Parameters
    ----------
    data : array-like
        Samples ordered along the first axis.
    threshold : scalar
        Level to be crossed.
    rising : bool
        If True mark sample i when data[i-1] <= threshold < data[i];
        otherwise mark sample i when data[i-1] >= threshold > data[i].
    hold_off : optional
        Reserved; any truthy value raises ValueError.

    Returns
    -------
    numpy.ndarray of bool with the same shape as `data`; the first sample is
    always False because it has no predecessor.

    Raises
    ------
    ValueError
        If a hold off period is requested.
    '''
    if hold_off:
        raise ValueError('Hold off period not implemented yet.')
    data = np.asarray(data)
    # Pre-allocating with the input's shape keeps the output aligned with the input
    # even for empty input (the previous np.append based version returned one element).
    trig = np.zeros(data.shape, dtype=bool)
    if rising:
        trig[1:] = (data[1:]>threshold) & (data[:-1]<=threshold)
    else:
        trig[1:] = (data[1:]<threshold) & (data[:-1]>=threshold)
    return trig

def trigger_horizonal_pd(df, threshold, axis=1, hold_off=None):
    '''Find rising and falling threshold crossings in a DataFrame.

    `trigger` is applied to every row (axis=1) or column (axis=0) of `df`,
    once for rising and once for falling crossings.

    Returns a tuple (triggers_rise, triggers_fall); both results are named 'weight'.
    `hold_off` is accepted but not used by this function.
    '''
    triggers_rise = df.apply(lambda x: trigger(x,threshold, True), axis=axis)
    # replace "no trigger" entries by NaN, presumably so that stack() below drops them - TODO confirm
    triggers_rise[triggers_rise==0]=np.nan
    triggers_fall = df.apply(lambda x: trigger(x,threshold, False), axis=axis)
    triggers_fall[triggers_fall==0]=np.nan
    
    # NOTE(review): the code assumes df.apply returns a DataFrame shaped like df; whether it does
    # when the applied function returns an ndarray depends on the pandas version - verify.
    if axis==1:
        triggers_rise = triggers_rise.stack()
        triggers_fall = triggers_fall.stack()
    elif axis==0:
        triggers_rise = triggers_rise.T.stack().T
        triggers_fall = triggers_fall.T.stack().T
    else:
        # any other axis: only warn and return the unstacked results
        warnings.warn('Axis reduction not implemented for axis.')
    triggers_rise.name='weight'
    triggers_fall.name='weight'
    return triggers_rise, triggers_fall

def trigger_event_pd(df, start, stop):
    '''Build fixed-frame triggers for every row label of `df`.

    Returns three Series named 'weight', all filled with 1.0 and indexed by a
    ('time', 'frame') MultiIndex:
      * one entry per row label at frame `start`,
      * one entry per row label at frame `stop`,
      * one entry per row label for every frame in range(start, stop).
    '''
    trial_labels = df.index.values

    def _unit_weights(frames):
        # Cartesian product of the trial labels with the requested frames
        index = pd.MultiIndex.from_product((trial_labels, frames), names=['time', 'frame'])
        return pd.Series(1.0, index=index, name='weight')

    at_start = _unit_weights([start])
    at_stop = _unit_weights([stop])
    within = _unit_weights(list(range(start,stop)))
    return at_start, at_stop, within


In [None]:
# Threshold for the population-averaged z-scored spiking: 5 divided by sqrt(number of ROIs)
z_spike_threshold = 5.0/np.sqrt(len(data.rois))

# Histogram of the smoothed lick rate with unit-wide bins covering [0, max_lik_rate)
max_lik_rate = 20
c,b = np.histogram(df_lick.values.ravel(),range=(0,max_lik_rate),bins=max_lik_rate)
# Skip the first bin, take the centre of the most populated remaining bin
# (index i of c[1:] is the bin [i+1, i+2), centre i+1.5) and use half of it as threshold
lick_threshold = (np.argmax(c[1:])+1.5)/2
plt.hist(df_lick.values.ravel(),log=True,range=(0,max_lik_rate),bins=max_lik_rate)
# mark the chosen threshold with a yellow star
plt.plot(lick_threshold,2,'y*',ms=15)
print(lick_threshold)

In [None]:
# The histogram shape justifies putting the threshold at the half maximum
# Rising / falling crossings of the smoothed lick rate, searched along each row of df_lick
lick_triggers_rise, lick_triggers_fall = trigger_horizonal_pd(df_lick, lick_threshold)
print (lick_triggers_rise.shape,lick_triggers_fall.shape)
# number of trials whose 'port' column equals 'W+'
print ('Port was present in %d trials.'%data.experiment_traits[data.experiment_traits['port']=='W+'].shape[0])

In [None]:
# Define the boundary of a p<0.005 set
# z_spike is first averaged over index level 0, then crossings of z_spike_threshold are searched
# NOTE(review): DataFrame.mean(level=...) is not available in recent pandas releases - verify for the installed version
spike_triggers_rise, spike_triggers_fall = trigger_horizonal_pd(z_spike.mean(level=0), z_spike_threshold)
print (spike_triggers_rise.shape,spike_triggers_fall.shape)

In [None]:
# Fixed-frame triggers derived from the protocol: for a subset of trials, trigger_event_pd
# returns (start, stop, every frame in between) for the interval between two la.events entries.
# NOTE(review): trigger_event_pd calls range(start, stop), so la.events[k]*data.FPS must be
# integers - confirm the dtype of la.events and data.FPS.

# csp: trials with port == 'W+', interval la.events[1] .. la.events[2]
csp_triggers_rise, csp_triggers_fall, csp_triggers_allow = trigger_event_pd(
    data.experiment_traits[data.experiment_traits['port']=='W+'],
    la.events[1]*data.FPS, la.events[2]*data.FPS)
# csm: trials with port == 'W-', same interval
csm_triggers_rise, csm_triggers_fall, csm_triggers_allow = trigger_event_pd(
    data.experiment_traits[data.experiment_traits['port']=='W-'],
    la.events[1]*data.FPS, la.events[2]*data.FPS)

# us: trials with puffed == 'A+', interval la.events[3] .. la.events[4]
us_triggers_rise, us_triggers_fall, us_triggers_allow = trigger_event_pd(
    data.experiment_traits[data.experiment_traits['puffed']=='A+'],
    la.events[3]*data.FPS, la.events[4]*data.FPS)

# tra: trials with context == 'CS+', interval la.events[2] .. la.events[3]
tra_triggers_rise, tra_triggers_fall, tra_triggers_allow = trigger_event_pd(
    data.experiment_traits[data.experiment_traits['context']=='CS+'],
    la.events[2]*data.FPS, la.events[3]*data.FPS)

# Plot

In [None]:
from matplotlib.backends.backend_pdf import PdfPages

class helpmultipage(object):
    '''Small wrapper around PdfPages that makes saving figures optional.

    With an empty filename no file is opened and savefig()/close() do nothing,
    so plotting cells can call them unconditionally.

    Attributes
    ----------
    filename : name of the PDF file; empty means "do not save".
    isopen   : True while a PdfPages object is held in `pp`.
    pp       : the PdfPages object (only present after a successful open).
    '''
    def __init__(self, filename):
        self.filename = filename
        self.isopen = False
        self.open()

    def __del__(self):
        # make sure the PDF is finalised when the wrapper is garbage collected
        self.close()

    def savefig(self, dpi=None):
        '''Append the current figure as a new page if a file is open.'''
        if self.isopen:
            self.pp.savefig(dpi=dpi)

    def open(self):
        '''Open the PDF unless it is already open or the filename is empty.'''
        # `not` instead of `~`: bitwise inversion of a bool gives -1 or -2, which are both
        # truthy, so the old test re-created (and leaked) the PdfPages on every call.
        if (not self.isopen) and len(self.filename):
            self.pp = PdfPages(self.filename)
            self.isopen = True

    def close(self):
        '''Close the PDF if it is open; safe to call repeatedly.'''
        # getattr guards against __del__ running on a half-constructed instance
        if getattr(self, 'isopen', False):
            self.pp.close()
        self.isopen = False

#### Explanatory figure

In [None]:
# Open the PDF that collects the explanatory figure
pp = helpmultipage('explanatory.pdf')

In [None]:
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection
# Geometry of the protocol phases in camera frames, derived from la.events:
# centre, left edge and width of every interval between consecutive events
center = FPS * (la.events[:-1]+la.events[1:]) /2
left = FPS * la.events
width = FPS * (la.events[1:]-la.events[:-1])
# vertical centre for the rotated labels and lower edge of the y range
vcenter = 0.0
vstart = -0.5

def label90(x,y,text):
    '''Write `text` rotated by 90 degrees, centred at (x, y), on the module-level axes `ax`.'''
    ax.text(x, y, text, ha="center", va="center", family='sans-serif', size=14, rotation=90)

# Two stacked axes: the upper one is switched off, the lower one holds the drawing
fig, (empty, ax) = plt.subplots(2,1,figsize=(6,8))
fig.tight_layout(pad=3)
empty.axis('off')
#ax = fig.gca()
fig.suptitle('Explanatory figure')
ax.set_xlabel('Camera frame')
ax.set_ylabel('z-scored activity')
ax.set_ylim(vstart,vstart+1)
# Four illustrative curves (the mean of z_spike, mirrored and slightly shifted) that only
# serve to demonstrate the legend layout
ax.plot(z_spike.mean(axis=0)+0.00, label="(CategoryA, True): #trials", c=(1,1,0))
ax.plot(z_spike.mean(axis=0)+0.02, label="(CategoryB, True): #trials", c=(.5,1,.5))
ax.plot(-z_spike.mean(axis=0)+0.00, label="(CategoryA, False): #trials", c=(1,.8,1))
ax.plot(-z_spike.mean(axis=0)+0.02, label="(CategoryB, False): #trials", c=(.5,1,1))
patches = []
# mark delay
label90(center[0], vcenter, 'excitation by\nshowing water')
# mark CS
rect = mpatches.Rectangle((left[1],vstart), width[1], 1, ec="none")
patches.append(rect)
label90(center[1], vcenter, 'CS± if tone\n"Baseline" otherwise')
# mark delay
label90(center[2], vcenter, 'delay')
# mark UC
rect = mpatches.Rectangle((left[3],vstart), width[3], 1, ec="none")
patches.append(rect)
label90(center[3], vcenter, 'UC if any')
# mark water
ax.text((left[0]+left[3])/2, vstart, "water source present\niff allowed to lick",
        ha="center", va="bottom", family='sans-serif', size=14, bbox=dict(boxstyle="DArrow", pad=0.0, fc='c'))

# vertical line at every protocol event
for i in range(0,len(la.events)):
    ax.axvline(x=la.events[i]*FPS, ymin=0.0, ymax = 1.0, linewidth=1, color='k')
# shade the collected rectangles, coloured by their position in the list
colors = np.linspace(0, 1, len(patches))
collection = PatchCollection(patches, cmap=plt.cm.hsv, alpha=0.1)
collection.set_array(np.array(colors))
ax.add_collection(collection)

# legend is placed above the axes, in the area of the switched-off upper subplot
leg = ax.legend(loc='lower center', title="Category name, Condition name",
               bbox_to_anchor=(0.5, 1.1))
leg.get_title().set_fontsize('large')
leg.get_title().set_fontweight('bold')
# fig.show() may emit a UserWarning (suppressed here)
with warnings.catch_warnings():
    warnings.simplefilter('ignore', UserWarning)
    fig.show()
pp.savefig()

In [None]:
# Finalise explanatory.pdf
pp.close()

In [None]:
# Open the PDF that collects the population-level figures of the selected animal
pp = helpmultipage(animal+'_pop.pdf')

In [None]:
# Overview plot of the spike, filtered and lick tables (not z-scored); saved as a page of the open PDF
la.plot_data(df_spike, df_data, df_lick, data.experiment_traits, FPS)
pp.savefig()

## Z-scored spiking
Spiking is "True" in the [intervals) given in transients_data.hc5

In [None]:
# Regular binning: section borders every 5 s from 0 to 55 s, expressed in camera frames,
# and the centre of each bin
bsections = np.arange(0,60,5)*FPS
bcenters = (bsections[1:]+bsections[:-1])/2
#mybfun = lambda x: la.func_over_intervals(np.nanmean, bsections, np.array(x))
mybfun = pd.DataFrame.mean

# Aggregate the columns of every table over the bins; the bin centres (as strings) are passed
# as the last argument - presumably the new column labels, see la.pd_aggr_col
zb_spike = la.pd_aggr_col(z_spike, mybfun, bsections, bcenters.astype(str))
zb_data = la.pd_aggr_col(z_data, mybfun, bsections, bcenters.astype(str))
zb_raw = la.pd_aggr_col(z_raw, mybfun, bsections, bcenters.astype(str))
zb_lick = la.pd_aggr_col(z_lick, mybfun, bsections, bcenters.astype(str))
b_lick = la.pd_aggr_col(df_lick, mybfun, bsections, bcenters.astype(str))

In [None]:
# Event-based binning: section borders are the protocol events plus a closing border at 60 s,
# expressed in camera frames, and the centre of each section
asections = np.append(la.events,[60])*FPS
acenters = (asections[1:]+asections[:-1])/2
#myafun = lambda x: la.func_over_intervals(np.nanmean, asections, np.array(x))
myafun = pd.DataFrame.mean

# Same aggregation as for the regular bins, now over the protocol phases
za_spike = la.pd_aggr_col(z_spike, myafun, asections, acenters.astype(str))
za_data = la.pd_aggr_col(z_data, myafun, asections, acenters.astype(str))
za_raw = la.pd_aggr_col(z_raw, myafun, asections, acenters.astype(str))
za_lick = la.pd_aggr_col(z_lick, myafun, asections, acenters.astype(str))
a_lick = la.pd_aggr_col(df_lick, myafun, asections, acenters.astype(str))

### Single criterion
* comments

In [None]:
# One grouping per single criterion; each la.plot_data call is saved as a page of the open PDF:
# full frame resolution, regular bins, and protocol phases
grp = [['context'],['learning_epoch'],['port'],['puffed']]
la.plot_data(z_spike, z_data, df_lick, data.experiment_traits, FPS, grp, title='Population activity')
pp.savefig()
la.plot_data(zb_spike, zb_data, b_lick, data.experiment_traits, FPS, grp, title='Population activity binned', div=bcenters)
pp.savefig()
la.plot_data(za_spike, za_data, a_lick, data.experiment_traits, FPS, grp, title='Population activity averaged over events', div=acenters)
pp.savefig()

### Two criteria
* comments

### Three criteria
* comments

### All criteria
* There is no increased population activity for CS+ without puffing. (For mouse 0216_4 the 1 trial with port displays increase during the trace period - why?)
* During learning mouse 0216_4 shows increased activity during the UC phase for CS-

In [None]:
# A single grouping by all three criteria; la.plot_epochs additionally receives the count
# table `etc` built in the protocol-configuration cell. Same three resolutions as above.
grp = [['context','port','puffed']]
la.plot_epochs(z_spike, z_data, df_lick, data.experiment_traits, etc, FPS, grp, title='Population activity')
pp.savefig()
la.plot_epochs(zb_spike, zb_data, b_lick, data.experiment_traits, etc, FPS, grp, title='Population activity binned', div=bcenters)
pp.savefig()
la.plot_epochs(za_spike, za_data, a_lick, data.experiment_traits, etc, FPS, grp, title='Population activity averaged over events', div=acenters)
pp.savefig()

### Activities conditional on epoch

In [None]:
def plot_by_epoch(epoch):
    '''Plot pairwise-grouped population activity for the trials of one learning epoch.

    `epoch` is compared against the 'learning_epoch' column of data.experiment_traits and is
    also used as the figure title. Two figures are drawn (frame resolution and averaged over
    the protocol phases); each is saved as a page of the module-level PDF `pp`.
    Relies on the module-level tables z_spike, z_data, z_raw, df_lick and on myafun,
    asections, acenters.
    '''
    # trials belonging to the requested epoch
    experiment_c = data.experiment_traits[data.experiment_traits.loc[:,'learning_epoch']==epoch]
    # restrict every table to these trials
    spike_c = z_spike.reindex(experiment_c.index, level='time')
    data_c = z_data.reindex(experiment_c.index, level='time')
    raw_c = z_raw.reindex(experiment_c.index, level='time')
    lick_c = df_lick.reindex(experiment_c.index)
    print (experiment_c.shape, z_spike.shape)
    # aggregate the restricted tables over the protocol phases
    spike_ca = la.pd_aggr_col(spike_c, myafun, asections, acenters.astype(str))
    data_ca = la.pd_aggr_col(data_c, myafun, asections, acenters.astype(str))
    # NOTE(review): raw_c / raw_ca are computed but not used below
    raw_ca = la.pd_aggr_col(raw_c, myafun, asections, acenters.astype(str))
    lick_ca = la.pd_aggr_col(lick_c, myafun, asections, acenters.astype(str))
    print (spike_c.shape, spike_ca.shape)

    # NOTE(review): the full data.experiment_traits (not experiment_c) is passed to la.plot_data - confirm intended
    grp = [['context','port'],['context','puffed'],['port','puffed']]
    la.plot_data(spike_c, data_c, lick_c, data.experiment_traits, FPS, grp, title=epoch)
    pp.savefig()
    la.plot_data(spike_ca, data_ca, lick_ca, data.experiment_traits, FPS, grp, title=epoch+' averaged over events', div=acenters)
    pp.savefig()

#### Pre-learning

In [None]:
# Pages for the trials whose learning_epoch is 'Pre-Learning'
plot_by_epoch('Pre-Learning')

#### Learning

In [None]:
# Pages for the trials whose learning_epoch is 'Learning'
plot_by_epoch('Learning')

#### Post-Learning

In [None]:
# Pages for the trials whose learning_epoch is 'Post-Learning'
plot_by_epoch('Post-Learning')

In [None]:
# Finalise the population PDF of the selected animal
pp.close()

In [None]:
# One PDF page per protocol phase; each page holds one (unit x trial) matrix per protocol configuration
pp = helpmultipage(animal+'_phases.pdf')

# experiment traits re-indexed by the configuration keys, used to look up the trials of a configuration
etmp = data.experiment_traits.reset_index(drop=True).set_index(la.sort_learning)

for p,aggr in enumerate(za_data.columns):
    # grid of subplots: 12 per row, as many rows as needed for all configurations in `et`
    nplot = len(et.index)
    ncol = 12
    nrow = int(np.ceil(len(et.index)/float(ncol)))
    fig, ax = plt.subplots(nrow,ncol,figsize=(2*ncol,1+10*nrow),squeeze=False,sharey=True)
    fig.tight_layout(pad=3, h_pad=3, rect=[0,0,1,0.8])
    fig.suptitle('Phase: %s'%la.phases[p],fontsize=16)
    for i, cond in enumerate(et.index):
        col = i%ncol
        row = int((i-col)/ncol)
        # 'timestr' values of the trials recorded under this configuration
        sel = etmp.loc[cond,'timestr']
        # activity of this phase, reshaped so that the 'time' level is spread over the columns
        tmp = za_data.loc[sel.tolist(),aggr].unstack('time')
        ax[row,col].matshow(tmp.values,origin='lower')
        ax[row,col].xaxis.set_ticks_position('bottom')
        ax[row,col].set_title('\n'.join(cond))
        ax[row,col].set_ylabel('Unit ID')
        ax[row,col].set_xlabel('Trial')
    pp.savefig()    
pp.close()

### An example of spiking
The first 1 second of the recording seems to be missing

In [None]:
def plot_transients(ax, transients, experiment_id, rois=None):
    '''Plot transients with colored line and put a tic at the maxima

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    transients : DataFrame with columns 'start_frame', 'stop_frame' and 'max_frame'; its rows
        for `experiment_id` are joined with the module-level `roi_df` on 'roi_id'.
    experiment_id : label of the trial to draw; also appended to the axes title.
    rois : unused, kept so that existing calls keep working.

    Each transient is drawn as a horizontal segment from start_frame to stop_frame at the
    height of its unit index; a black tick marks max_frame. Drawing is best effort: if the
    trial cannot be looked up only the event lines, title and labels are drawn.
    '''
    import matplotlib.collections
    ncolors = 10
    try:
        firing = transients.loc[experiment_id,['start_frame', 'stop_frame', 'max_frame']].join(roi_df.set_index(['roi_id']), how='left')
        if len(firing):
            # Reshape things so that we have a sequence of:
            # [[(x0,y0),(x1,y1)],[(x0,y0),(x1,y1)],...]
            # based on http://stackoverflow.com/questions/17240694/python-how-to-plot-one-line-in-different-colors
            segments = firing[['start_frame', 'idx', 'stop_frame', 'idx']].values.reshape(-1,2,2)
            coll = matplotlib.collections.LineCollection(segments, cmap=plt.cm.rainbow)
            # colour cycles with the unit index
            coll.set_array(firing['idx']%ncolors)
            ax.add_collection(coll)
            ax.autoscale_view()
            ax.plot(firing['max_frame'].T,firing['idx'].T,'|',ms=5,c='k')
    except Exception:
        # Best effort: a trial without transients must not abort a whole plotting loop.
        # Only Exception is caught so that KeyboardInterrupt still stops a long run.
        pass
    # vertical line at every protocol event
    for i in range(0,len(la.events)):
        ax.axvline(x=la.events[i]*FPS, ymin=0.0, ymax = 1.0, linewidth=1, color='k')
    ax.set_title('Transient peaks and durations ExID: '+experiment_id)
    ax.set_xlabel('Camera frame')
    ax.set_ylabel('Unit ID')

In [None]:
def plot_levels(ax, data, experiment_id, rois=None, zoom=0.5, dist=1.0):
    '''Plot the activity trace of every unit of one trial, stacked vertically.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    data : DataFrame from which the rows of `experiment_id` are taken.
    experiment_id : label of the trial to draw; also appended to the axes title.
    rois : unused, kept so that existing calls keep working.
    zoom : factor applied to the trace values.
    dist : vertical distance between consecutive units; the trace of a unit is shifted
        by dist times its 'idx' from the module-level `roi_df`.

    Drawing is best effort: if the trial cannot be looked up only the event lines,
    title and labels are drawn.
    '''
    try:
        firing = (zoom*data.loc[experiment_id,:]).add(dist*roi_df.set_index(['roi_id']).loc[:,'idx'],axis=0)
        if len(firing):
            ax.plot(firing.T)
    except Exception:
        # Best effort; only Exception is caught so that KeyboardInterrupt still stops a long run.
        pass
    # vertical line at every protocol event
    for i in range(0,len(la.events)):
        ax.axvline(x=la.events[i]*FPS, ymin=0.0, ymax = 1.0, linewidth=1, color='k')
    # NOTE(review): the title text is shared with plot_transients - confirm it is intended here
    ax.set_title('Transient peaks and durations ExID: '+experiment_id)
    ax.set_xlabel('Camera frame')
    ax.set_ylabel('Unit ID')

In [None]:
def plot_spiking_nan(ax, spiking, experiment_id, rois=None):
    '''Mark unavailable data with a gray dot

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    spiking : DataFrame from which the rows of `experiment_id` are taken; they are joined
        with the module-level `roi_df` to obtain the unit index 'idx'.
    experiment_id : label of the trial to draw.
    rois : unused, kept so that existing calls keep working. (The previous default path
        referenced an undefined name and raised NameError when rois was None.)

    A light gray dot is drawn at (frame, idx) for every null entry. Drawing is best effort:
    nothing is drawn if the trial cannot be looked up.
    '''
    try:
        firing = spiking.loc[experiment_id,:].join(roi_df.set_index(['roi_id']), how='left')
        firing.columns.name='frame'
        # long format with one row per (idx, frame), keeping the null entries
        firing = firing.set_index('idx').stack(dropna=False)
        firing = firing[firing.isnull()]
        firing = firing.reset_index()
        if len(firing):
            #ax.scatter(firing.loc[:,'frame'],firing.loc[:,'idx'],s=1,c='lightgray',marker='.')
            ax.plot(firing.loc[:,'frame'],firing.loc[:,'idx'],'.',ms=1,c='lightgray')
    except Exception:
        # Best effort; only Exception is caught so that KeyboardInterrupt still stops a long run.
        pass

In [None]:
def plot_triggers(ax, triggers, experiment_id, pos=0, ls='x', c='b', ms=8):
    '''Plot trigger events

    Parameters
    ----------
    ax : axes to draw on.
    triggers : one trigger Series or a list of them; the index entries found under
        `experiment_id` are used as x positions.
    experiment_id : label of the trial to draw.
    pos : y position of the markers.
    ls, c, ms : format string, colour and marker size; each may be a single value or a
        list with one entry per trigger Series.

    Drawing is best effort: a trial missing from a trigger Series stops the drawing
    silently (markers drawn before that point are kept).
    '''
    try:
        if not isinstance(triggers, list):
            triggers=[triggers]
        for i, trig in enumerate(triggers):
            x = trig.loc[experiment_id].index.values
            x = x.reshape((1,-1))
            if x.shape[1]>0:
                ls1 = ls[i] if isinstance(ls, list) else ls
                c1 = c[i] if isinstance(c, list) else c
                ms1 = ms[i] if isinstance(ms, list) else ms
                ax.plot(x, pos, ls1, c=c1, ms=ms1)
    except Exception:
        # Best effort; only Exception is caught so that KeyboardInterrupt still stops a long run.
        pass
    
def plot_behavior(ax, licks, experiment_id):
    '''Plot individual licks

    Parameters
    ----------
    ax : axes to draw on.
    licks : DataFrame with 'start_time' and 'stop_time' columns; the rows of
        `experiment_id` are used and converted to frames with the module-level FPS.
    experiment_id : label of the trial to draw.

    Every lick is drawn at y = -5 as a blue segment from start to stop with a black dot
    at the midpoint. Drawing is best effort: nothing is drawn if the trial cannot be
    looked up.
    '''
    try:
        level=-5
        licking = np.array(licks.loc[experiment_id,['start_time', 'stop_time']])*FPS
        if len(licking):
            ax.plot(licking.T,level*np.ones_like(licking.T),c='b')
        licking = np.array(licks.loc[experiment_id,['start_time', 'stop_time']].mean(axis=1))*FPS
        if len(licking):
            ax.plot(licking,level*np.ones_like(licking),'o',ms=5,c='k')
    except Exception:
        # Best effort; only Exception is caught so that KeyboardInterrupt still stops a long run.
        pass

def plot_licking(ax, licking, experiment_id, pos=-20, zoom=1.0, c='b', threshold=None):
    '''Plot licking rate

    Parameters
    ----------
    ax : axes to draw on.
    licking : DataFrame holding one row of lick rate per trial.
    experiment_id : label of the trial to draw.
    pos : y position of the baseline of the curve.
    zoom : factor applied to the rate before it is shifted by `pos`.
    c : colour of the curve.
    threshold : if given, a light gray line is drawn at threshold*zoom+pos.

    Drawing is best effort: if the trial cannot be looked up only the baseline and the
    threshold line are drawn.
    '''
    try:
        ax.axhline(y=pos, xmin=0.0, xmax = 1.0, linewidth=1, color='k')
        if threshold is not None:
            ax.axhline(y=threshold*zoom+pos,c='lightgray')
        rate = licking.loc[experiment_id,:].values
        if len(rate):
            ax.plot(rate*zoom+pos,c=c)
    except Exception:
        # Best effort; only Exception is caught so that KeyboardInterrupt still stops a long run.
        pass

def plot_population(ax, data, experiment_id, pos=-20, zoom=10.0, c='r', threshold=None):
    '''Plot population activity

    Parameters
    ----------
    ax : axes to draw on.
    data : DataFrame from which the rows of `experiment_id` are taken and averaged
        over the rows (axis=0), giving one value per column.
    experiment_id : label of the trial to draw.
    pos : y position of the baseline of the curve.
    zoom : factor applied to the average before it is shifted by `pos`.
    c : colour of the curve.
    threshold : if given, a light gray line is drawn at threshold*zoom+pos.

    Drawing is best effort: if the trial cannot be looked up only the baseline and the
    threshold line are drawn.
    '''
    try:
        ax.axhline(y=pos, xmin=0.0, xmax = 1.0, linewidth=1, color='k')
        if threshold is not None:
            ax.axhline(y=threshold*zoom+pos,c='lightgray')
        average = data.loc[experiment_id,:].mean(axis=0)
        if len(average):
            ax.plot(average*zoom+pos,c=c)
    except Exception:
        # Best effort; only Exception is caught so that KeyboardInterrupt still stops a long run.
        pass

In [None]:
def plot_conditions(ax, conditions, experiment_id, height=20):
    '''Draw a table and write experimental conditions into it

    ax            : matplotlib axes; the table is placed at its bottom
    conditions    : DataFrame from which the row of `experiment_id` and the columns
                    learning_epoch, context, port, puffed, session_num, day_num are taken
    experiment_id : row label of the trial to describe
    height        : height of the table in data units of the y axis

    Reads the module-level `data.FPS` and `la` for the column widths.
    '''
    import matplotlib
    a = conditions.loc[[experiment_id],['learning_epoch','context','port','puffed','session_num','day_num']]
    # Column widths: the first four follow la.durations[1:] (in frames), the last two share
    # equally what is left of the x range after the last event
    cw = np.concatenate((la.durations[1:]*data.FPS,np.array([0.5,0.5])*(ax.get_xlim()[1]-la.events[-1]*data.FPS)))

    # Cell colours: the whole row is light blue if the port value is 'W+' or True, white
    # otherwise; individual cells are then overridden according to `replacement`
    c = a.copy()
    c.loc[:,:]='lightblue' if any(a['port'].isin(['W+',True])) else 'white'
    replacement = [('context', 'CS-', 'lightgreen'), ('context', 'CS+', 'lightcoral'),
                   ('context', 'Baseline', 'lightblue'),
                   ('port', 'W+', 'lightblue'), ('puffed', 'A+', 'yellow'),
                   ('port', True, 'lightblue'), ('puffed', True, 'yellow')]
    for label, value, color in replacement:
        c.loc[a[label]==value,label]=color
    
    ylim = ax.get_ylim()
    #tab = pd.tools.plotting.table(ax, a, loc='lower center', fontsize=24, colWidths=cw/np.sum(cw))
    # bbox is given in axes fractions, so the requested height is divided by the y range
    tab = matplotlib.table.table(ax, cellText=a.values,
                                   #rowLabels=rowLabels, colLabels=colLabels,
                           loc='lower center', fontsize=24, colWidths=cw/np.sum(cw), bbox=[0,0,1,height/(ylim[1]-ylim[0])], cellLoc='center', cellColours=c.values)
    # fontsize keyword is accepted but seems ineffective
    tab.set_fontsize(24)
    # remove the cell borders
    for key, cell in tab.get_celld().items():
        cell.set_linewidth(0)

In [None]:
# Order experiments by settings (deprecated)
# et3 is only referenced in a commented-out alternative of the firing-PDF loop below
et3 = data.experiment_traits.copy().reset_index(drop=True)
#et3.loc[:,'session_num'] = et3.loc[:,'session_num'].astype(int)
et3 = et3.sort_values(['learning_epoch','context','port','puffed','session_num']).set_index(['learning_epoch','context','port','puffed'])

In [None]:
# Enumerate ROIs
# Two columns: 'idx' (running number) and 'roi_id'; the plot_* helpers join on 'roi_id'
# and use 'idx' as the vertical position of a unit
roi_df = pd.DataFrame(data.rois, columns=['roi_id']).reset_index().rename(columns={'index':'idx'})

In [None]:
# Triggers
# Parallel lists consumed by plot_triggers: the trigger Series, their marker and their colour
trig_list_data = [lick_triggers_rise, lick_triggers_fall, spike_triggers_rise, spike_triggers_fall]
trig_list_sign = ['o', 's', '^', 'v']
trig_list_color = ['b', 'y', 'r', 'g']

In [None]:
# Show an example
# Trial at position 9 of the experiment table: stacked z-scored traces of all units
idx = data.experiment_traits.index[9]
fig, ax = plt.subplots(1,1,figsize=(16,10))
plot_levels(ax, z_data, idx, data.rois.values)
ax.set_ylim(ymin=-60,ymax=120)
#plot_spiking_nan(ax, df_spike, idx, data.rois.values)
#plot_behavior(ax, data.behavior, idx)
# population averages of the filtered (yellow) and the spike (default colour) tables share the baseline at -20
plot_population(ax, z_data, idx, pos=-20, c='y')
plot_population(ax, z_spike, idx, pos=-20, threshold=z_spike_threshold)
# lick rate with its threshold at baseline -40, trigger markers at -5
plot_licking(ax, df_lick, idx, pos=-40, threshold=lick_threshold)
plot_triggers(ax, trig_list_data, idx, -5, trig_list_sign, c=trig_list_color)
plot_conditions(ax, data.experiment_traits, idx, height=20)

In [None]:
# Show an example
# Same trial as above, now drawn as transient segments with unavailable frames marked
idx = data.experiment_traits.index[9]
fig, ax = plt.subplots(1,1,figsize=(16,10))
plot_transients(ax, data.transients, idx, data.rois.values)
ax.set_ylim(ymin=-60,ymax=len(data.rois)+1)
plot_spiking_nan(ax, df_spike, idx, data.rois.values)
#plot_behavior(ax, data.behavior, idx)
# population averages of the filtered (yellow) and the spike (default colour) tables share the baseline at -20
plot_population(ax, z_data, idx, pos=-20, c='y')
plot_population(ax, z_spike, idx, pos=-20, threshold=z_spike_threshold)
# lick rate with its threshold at baseline -40, trigger markers at -5
plot_licking(ax, df_lick, idx, pos=-40, threshold=lick_threshold)
plot_triggers(ax, trig_list_data, idx, -5, trig_list_sign, c=trig_list_color)
plot_conditions(ax, data.experiment_traits, idx, height=20)

In [None]:
# One PDF page per trial with the transient overview
pp = helpmultipage(animal+'_firing.pdf')

# common right x limit: the largest stop_frame over all transients (a one-element array)
xmax = data.transients.loc[:,['stop_frame']].max().values

for idx, val in data.experiment_traits.iterrows(): #et3.iterrows():
    fig, ax = plt.subplots(1,1,figsize=(16,10))
    ax.set_xlim(xmax=xmax)
    experiment_id = val['timestr']
    #print (experiment_id)
    fig.suptitle('learning_epoch, context, port, puffed: #context in epoch, #day\n'+
        '%s: session %s, day %s'%(idx,val['session_num'],val['day_num']))
    plot_transients(ax, data.transients, idx, data.rois.values)
    ax.set_ylim(ymin=-60,ymax=len(data.rois)+1)
    #plot_spiking_nan(ax, df_spike, idx, data.rois.values)
    plot_population(ax, z_data, idx, pos=-20, c='y')
    plot_population(ax, z_spike, idx, pos=-20, threshold=z_spike_threshold)
    plot_licking(ax, df_lick, idx, pos=-40, threshold=lick_threshold)
    plot_triggers(ax, trig_list_data, idx, -5, trig_list_sign, c=trig_list_color)
    # NOTE(review): this call looks the trial up by val['timestr'] while the calls above use the
    # row label idx - confirm that both identify the same row of data.experiment_traits
    plot_conditions(ax, data.experiment_traits, experiment_id, height=20)
    pp.savefig()
    # close the figure so that the loop does not accumulate open figures
    plt.close(fig)
    
pp.close()

### Peri-event averages

In [None]:
def peri_event_avg(data, triggers, diameter=(-3*FPS, 3*FPS), allow=None):
    '''Cut a window of frames around every trigger and collect the cuts side by side.

    Parameters
    ----------
    data : DataFrame whose rows are selected by trial label and whose columns are frames.
    triggers : Series indexed by (trial label, frame); only the index is used.
    diameter : (first, last) frame offsets relative to the trigger; `last` is exclusive.
    allow : optional container; a trigger is used only if its (trial, frame) index is
        `in allow`.

    Returns
    -------
    (result, count) : `result` is the column-wise concatenation of all cuts with a
        ('id', 'frame') column MultiIndex, where 'id' numbers the accepted triggers and
        'frame' is the offset to the trigger; it is None if no trigger was accepted.
        `count` is the number of accepted triggers.
    '''
    window = np.arange(diameter[0],diameter[1])
    count=0
    ret = []
    # Series.items() is available across pandas versions (iteritems was removed in pandas 2)
    for idx, _weight in triggers.items():
        experiment_id, frame = idx
        if (experiment_id in data.index) and ((allow is None) or (idx in allow)):
            # frames outside the recording become NaN columns through reindex
            tmp = data.loc[experiment_id,:].reindex(columns=frame+window)
            # from_product needs an iterable per level, hence [count]
            tmp.columns = pd.MultiIndex.from_product([[count],window],names=['id','frame'])
            ret.append(tmp)
            count += 1
    if len(ret):
        return pd.concat(ret,axis=1), count
    return None, count

In [None]:
def get_decay(time_range, rate):
    '''Return normalised Gaussian-shaped weights exp(-rate * t**2) over a set of time points.

    Parameters
    ----------
    time_range : int, pair, or sequence
        * integer n (Python or numpy): n points centred on zero, t = 0..n-1 minus n/2;
        * sequence of length 2: t = arange(first, last);
        * anything else: used as the time points directly.
    rate : number, converted to float.

    Returns
    -------
    numpy.ndarray of weights that sum to 1.
    '''
    rate = float(rate)
    # isinstance also accepts numpy integers, which previously fell through to len() and failed
    if isinstance(time_range, (int, np.integer)):
        time_points = np.arange(0,time_range)-0.5*time_range
    elif len(time_range)==2:
        time_points = np.arange(time_range[0],time_range[-1])
    else:
        time_points = time_range
    decay = np.exp(-rate*np.power(time_points,2))
    return decay/np.sum(decay)

In [None]:
def rev_align(data, shape):
    '''Align for broadcast to shape matching axes from the beginning (opposed to numpy convention)

    Trailing axes of length one are appended to `data` until it has as many dimensions as
    `shape`. If `data` already has at least that many dimensions it is returned unchanged.
    '''
    data_dim = np.ndim(data)
    req_dim = len(shape)
    if req_dim <= data_dim:
        return data
    # A single reshape appends all missing axes at once. The previous loop expanded the
    # original `data` on every pass and therefore failed when more than one axis was missing.
    return np.asanyarray(data).reshape(np.shape(data) + (1,)*(req_dim-data_dim))

In [None]:
def rev_broadcast(data, shape):
    '''Broadcast `data` to `shape`, matching its axes with the leading axes of `shape`
    (numpy itself matches the trailing axes).'''
    aligned = rev_align(data, shape)
    return np.broadcast_to(aligned, shape)

In [None]:
def match_pattern(data, pattern, std, decay, noise_level=0.1, detailed=False):
    '''Score how far `data` deviates from `pattern`, weighted along the time axis (rows).

    The difference data - pattern is multiplied by decay / (noise_level + std) and averaged
    ignoring NaNs: over everything by default, or over the rows only (one value per column)
    when `detailed` is set. Any number of columns may be used.
    Pattern and std may be one dimensional (time points given, all columns equal) or two
    dimensional (matrix given). Noise_level can be 0 to 2 dimensional; if it is one
    dimensional it is understood on the category axis (all rows equal). Noise_level must be
    positive if there is any std==0.
    '''
    target_shape = data.shape
    residual = data - rev_align(pattern, target_shape)
    weight = rev_align(decay, target_shape) / (noise_level + rev_align(std, target_shape))
    reduce_axis = 0 if detailed else None
    return np.nanmean(residual * weight, axis=reduce_axis)

In [None]:
def rolling2D(df, func, window, min_periods=None, center=True):
    '''Slice a DataFrame along index (rows) to apply 2D function

    df          : DataFrame that is cut into blocks of `window` consecutive rows
    func        : called with each block (a DataFrame); may return a scalar or a sequence
    window      : number of rows per block (converted to int, must be >= 1)
    min_periods : minimum number of rows a block must contain (defaults to `window`, must be >= 1)
    center      : if True a result is stored at the middle row of its block, otherwise at
                  the first row

    Returns a DataFrame with the index of `df`. Its columns are those of `df` if func
    returns as many values as df has columns, 0..k-1 if it returns k values, and a single
    column 0 if the result has no length. Rows that never receive a result keep the
    missing values of the initially empty frame.

    Raises ValueError for a non-positive window or min_periods.
    '''
    # This was a missing feature in pandas: one could previously correlate a single pattern
    # along a selected axis of a 2D DataFrame.
    window = int(window)
    if window<1:
        raise ValueError('window needs positive length')
    if min_periods is None:
        min_periods = window
    else:
        min_periods = int(min_periods)
    if min_periods<1:
        raise ValueError('min_periods needs to be positive')
    # NOTE(review): with min_periods < window `start` is negative and the iloc slices below use
    # negative positions, which count from the end of the frame - confirm this case is handled as intended
    start = min_periods-window # first point of first window is start, available points evaluate to [0, min_periods)
    end = len(df)-min_periods # first point of last window is end, available points evaluate to [len-min_periods, len)
    if center:
        shift = int(np.floor(window/2))
    else:
        shift = 0
    # Evaluate the first block separately: its result decides the layout of the output frame
    first = start
    data = df.iloc[first:first+window,:]
    tmp = func(data)
    try:
        if len(tmp)==len(df.columns):
            ret = pd.DataFrame([], index=df.index, columns=df.columns)
        else:
            ret = pd.DataFrame([], index=df.index, columns=pd.Index(np.arange(0,len(tmp))))
    except:
        # len() failed, i.e. func returned something without a length: use a single column
        ret = pd.DataFrame([], index=df.index, columns=pd.Index([0]))
    ret.iloc[first+shift,:]=tmp
    # Remaining blocks, moving one row at a time
    for first in range(start+1,end+1):
        data = df.iloc[first:first+window,:]
        tmp = func(data)
        ret.iloc[first+shift,:]=tmp
    return ret

### NOTE: it is interesting to compare performance
1. ROIs are inspected one-by-one using 'apply' on all trials then the results aggregated
2. ROIs are inspected one at a time using 'rolling2D' with scalar pattern then the results aggregated
3. Trials are inspected one at a time using 'rolling2D' with matrix pattern

In [None]:
import time, cProfile
# First call of the clock used for the timings in the method cells below
# NOTE(review): time.clock is not available in recent Python 3 releases - verify for the interpreter in use
time.clock()

In [None]:
# This seems to be extremely slow due to the massive amount of function calls
def method1():
    '''Variant 1: for every ROI run pandas' rolling(...).apply over the frames of all trials
    with match_pattern, then average the scores over the ROIs. Prints the result shape and
    the elapsed time and returns the averaged scores.'''
    ret1 = []
    # window of 3 s before to 3 s after a trigger, and decay weights of the same length
    diam = (-3*FPS,3*FPS)
    window = diam[1]-diam[0]
    decay = get_decay(window,3.0/FPS)
    # peri-event cuts around lick onsets; mean and std over the events give the pattern per frame offset
    dd, c = peri_event_avg(df_spike, lick_triggers_rise, diameter=diam)
    p1 = dd.mean(axis=1, level=1)
    s1 = dd.std(axis=1, level=1)
    t1 = time.clock()
    df_spike1 = df_spike.reset_index().set_index(['roi_id','time']).sort_index() # not much faster than .loc[(slice(None),roi),:]
    # the lambda reads `roi` from the loop below at call time
    func = (lambda x: match_pattern(x,p1.loc[roi].values,s1.loc[roi].values,decay=decay))
    for roi in data.rois: #[0:3]:
        tmp = df_spike1.loc[roi,:].rolling(window,window,center=True,axis=1).apply(func)
        ret1.append(tmp)
    ret1 = pd.concat(ret1).astype(float)
    ret1 = ret1.mean(axis=0, level=0)
    t1 = time.clock()-t1
    print(ret1.shape,t1)
    return ret1
# profile the run; cProfile.run executes the statement in the __main__ namespace, which is where ret1 is read from below
cProfile.run('ret1 = method1()')
plt.plot(ret1.T)
ret1.head()

In [None]:
# To be seen why this one is slow
def method2():
    '''Variant 2: for every ROI run rolling2D over the frames of all trials with a per-ROI
    pattern, then average the scores over the ROIs. Prints the result shape and the elapsed
    time and returns the averaged scores.'''
    ret2 = []
    # window of 3 s before to 3 s after a trigger, and decay weights of the same length
    diam = (-3*FPS,3*FPS)
    window = diam[1]-diam[0]
    decay = get_decay(window,3.0/FPS)
    # peri-event cuts around lick onsets; mean and std over the events give the pattern per frame offset
    dd, c = peri_event_avg(df_spike, lick_triggers_rise, diameter=diam)
    p1 = dd.mean(axis=1, level=1)
    s1 = dd.std(axis=1, level=1)
    t2 = time.clock()
    df_spike2 = df_spike.reset_index().set_index(['roi_id','time']).sort_index() # not much faster than .loc[(slice(None),roi),:]
    # NOTE(review): match_pattern_testing is not defined in the visible part of this notebook
    # (only match_pattern is) - confirm where it comes from
    # the lambda reads `roi` from the loop below at call time
    func = (lambda x: match_pattern_testing(x.values,p1.loc[roi].values,s1.loc[roi].values,detailed=True,decay=decay))
    for roi in data.rois: #[0:3]:
        # rolling2D slides over rows, hence the transposes
        tmp = rolling2D(df_spike2.loc[roi,:].T,func,window,center=True).T
        ret2.append(tmp)
    # Need to cast explicitly (seem not liking numpy's output and inferring 'object')
    ret2 = pd.concat(ret2).astype(float)
    ret2 = ret2.mean(axis=0, level=0)
    t2 = time.clock()-t2
    print(ret2.shape,t2)
    return ret2
# profile the run; cProfile.run executes the statement in the __main__ namespace, which is where ret2 is read from below
cProfile.run('ret2 = method2()')
plt.plot(ret2.T)
ret2.head()

In [None]:
# How can this be so fast even with largest memory requirement?
def method3():
    """Variant 3: score one trial at a time with ``rolling2D`` and a matrix pattern.

    The transposed mean/std frames of the peri-event data around
    ``lick_triggers_rise`` are passed whole to ``match_pattern_testing``; each
    trial's slice of ``df_spike`` is scanned by ``rolling2D`` and the result is
    re-indexed with the trial key before concatenation.

    Prints the result shape and the ``time.clock()`` time spent in the matching
    part, and returns the concatenated frame.
    """
    diam = (-3*FPS, 3*FPS)
    window = diam[1] - diam[0]
    decay = get_decay(window, 3.0/FPS)
    peri, _ = peri_event_avg(df_spike, lick_triggers_rise, diameter=diam)
    pattern_mean = peri.mean(axis=1, level=1)
    pattern_std = peri.std(axis=1, level=1)

    started = time.clock()
    matcher = (lambda x: match_pattern_testing(x.values, pattern_mean.T.values,
                                               pattern_std.T.values, decay=decay))

    def score_trial(trial):
        scored = rolling2D(df_spike.loc[trial,:].T, matcher, window, center=True).T
        scored.index = [trial]
        return scored

    result = pd.concat([score_trial(trial) for trial in data.trials])  # [0:3]
    elapsed = time.clock() - started
    print(result.shape, elapsed)
    return result
cProfile.run('ret3 = method3()')
plt.plot(ret3.T)
ret3.head()

In [None]:
def plot_peri_event(ax, df, title=None, pos=-15, zoom=10.0, vmin=None, vmax=None):
    """Draw a peri-event matrix as an image with its column-mean trace below it.

    Parameters
    ----------
    ax : axes object to draw on.
    df : DataFrame whose column labels are used as x coordinates (the x axis
        is labelled as a frame offset) and whose rows are stacked along y.
    title : optional axes title; nothing is set when None.
    pos : y position of the gray baseline around which the mean trace is drawn.
    zoom : factor applied to the column mean before it is offset by ``pos``;
        the y axis also extends ``zoom`` below ``pos``.
    vmin, vmax : color limits forwarded to ``ax.matshow``.

    Returns
    -------
    The object returned by ``ax.matshow``.
    """
    column_values = df.columns.values
    extent = np.min(column_values)-0.5, np.max(column_values)+0.5, -0.5, len(df)+0.5
    x_limits = extent[0:2]
    y_limits = (pos-zoom, extent[3])
    ax.set_xlim(x_limits)
    ax.set_ylim(y_limits)
    ret = ax.matshow(df, origin='lower', aspect='auto', extent=extent, vmin=vmin, vmax=vmax)
    ax.axhline(y=pos, xmin=0.0, xmax=1.0, c='gray')
    ax.plot(zoom*df.mean(axis=0)+pos)
    ax.xaxis.set_ticks_position('bottom')
    # Raw string: '\D' is not a valid escape sequence, so the non-raw literal
    # triggers an invalid-escape warning on recent Python versions.
    ax.set_xlabel(r'$\Delta$frame')
    ax.set_ylabel('Unit ID')
    # Limits are applied a second time after drawing
    # (presumably because the drawing calls above can change them -- TODO confirm).
    ax.set_xlim(x_limits)
    ax.set_ylim(y_limits)
    if title is not None:
        ax.set_title(title)
    return ret

In [None]:
def plot_peri_us(df_spike, title='Spiking'):
    """Plot peri-event mean/std panels for lick and US triggers in one figure.

    Builds a 1x12 figure.  For each of six event selections, in order (lick
    rise, lick fall, lick rise/fall restricted by ``us_triggers_allow``, US
    rise, US fall), ``peri_event_avg`` is evaluated on ``df_spike``; when the
    returned count is non-zero, the mean over column level 1 is drawn in the
    even panel (color range -1..1) and the std in the following panel
    (color range 0..2).  Panels of selections with a zero count stay empty.

    Returns the figure.
    """
    fig, ax = plt.subplots(1, 12, figsize=(24,12), sharey=True)
    fig.tight_layout(rect=(0,0,1,0.9), w_pad=2)
    fig.suptitle(title)
    # (title format, triggers, extra keyword arguments for peri_event_avg)
    selections = (
        ('Lick rise: %d', lick_triggers_rise, dict()),
        ('Lick fall: %d', lick_triggers_fall, dict()),
        ('Lick rise US: %d', lick_triggers_rise, dict(allow=us_triggers_allow)),
        ('Lick fall US: %d', lick_triggers_fall, dict(allow=us_triggers_allow)),
        ('US start: %d', us_triggers_rise, dict()),
        ('US end: %d', us_triggers_fall, dict()),
    )
    for slot, (title_fmt, triggers, extra) in enumerate(selections):
        dd, c = peri_event_avg(df_spike, triggers, **extra)
        if not c:
            continue
        plot_peri_event(ax[2*slot], dd.mean(axis=1, level=1), title_fmt%c, vmin=-1, vmax=1)
        plot_peri_event(ax[2*slot+1], dd.std(axis=1, level=1), '(sigma)', vmin=0, vmax=2)
    return fig

In [None]:
def plot_peri_csp(df_spike, title='Spiking'):
    """Plot peri-event mean/std panels for lick and CS+ triggers in one figure.

    Builds a 1x12 figure.  For each of six event selections, in order (lick
    rise, lick fall, lick rise/fall restricted by ``csp_triggers_allow``, CS+
    rise, CS+ fall), ``peri_event_avg`` is evaluated on ``df_spike``; when the
    returned count is non-zero, the mean over column level 1 is drawn in the
    even panel (color range -1..1) and the std in the following panel
    (color range 0..2).  Panels of selections with a zero count stay empty.

    Returns the figure.
    """
    fig, ax = plt.subplots(1, 12, figsize=(24,12), sharey=True)
    fig.tight_layout(rect=(0,0,1,0.9), w_pad=2)
    fig.suptitle(title)
    # (title format, triggers, extra keyword arguments for peri_event_avg)
    selections = (
        ('Lick rise: %d', lick_triggers_rise, dict()),
        ('Lick fall: %d', lick_triggers_fall, dict()),
        ('Lick rise CS+: %d', lick_triggers_rise, dict(allow=csp_triggers_allow)),
        ('Lick fall CS+: %d', lick_triggers_fall, dict(allow=csp_triggers_allow)),
        ('CS+ start: %d', csp_triggers_rise, dict()),
        ('CS+ end: %d', csp_triggers_fall, dict()),
    )
    for slot, (title_fmt, triggers, extra) in enumerate(selections):
        dd, c = peri_event_avg(df_spike, triggers, **extra)
        if not c:
            continue
        plot_peri_event(ax[2*slot], dd.mean(axis=1, level=1), title_fmt%c, vmin=-1, vmax=1)
        plot_peri_event(ax[2*slot+1], dd.std(axis=1, level=1), '(sigma)', vmin=0, vmax=2)
    return fig

In [None]:
pp = helpmultipage(animal+'_peri.pdf')
# Two passes over the epochs so that the PDF holds all spiking pages first
# and all Ca-level pages after them; within an epoch the US page precedes
# the CS+ page.
for epoch in la.epochs.values:
    in_epoch = data.experiment_traits.loc[:,'learning_epoch']==epoch
    experiment_c = data.experiment_traits[in_epoch]
    spike_c = df_spike.reindex(experiment_c.index, level='time')
    data_c = z_data.reindex(experiment_c.index, level='time')
    for plotter, title_fmt in ((plot_peri_us, '%s Spiking on US'),
                               (plot_peri_csp, '%s Spiking on CS+')):
        fig = plotter(spike_c, title_fmt%epoch)
        pp.savefig()
        plt.close(fig)
for epoch in la.epochs.values:
    in_epoch = data.experiment_traits.loc[:,'learning_epoch']==epoch
    experiment_c = data.experiment_traits[in_epoch]
    spike_c = df_spike.reindex(experiment_c.index, level='time')
    data_c = z_data.reindex(experiment_c.index, level='time')
    for plotter, title_fmt in ((plot_peri_us, '%s Ca-level on US'),
                               (plot_peri_csp, '%s Ca-level on CS+')):
        fig = plotter(data_c, title_fmt%epoch)
        pp.savefig()
        plt.close(fig)
pp.close()

In [None]:
# Whole-recording overview: US panels first, then CS+ panels,
# each for the spiking data and for the z-scored Ca-level.
for _plotter in (plot_peri_us, plot_peri_csp):
    for _frame, _label in ((df_spike, 'Spiking'), (z_data, 'Ca-level')):
        fig = _plotter(_frame, _label)

## Individual ROIs
* since there are many of them, save figure to pdf
* THIS WILL <font color="red">TAKE A WHILE</font>, consider testing with a small range

In [None]:
def plot_roi(df_spike, df_data, filaname, grp, title_template, by_epoch=False, div=None):
    """Write a multi-page PDF with one page per ROI in ``data.rois``.

    Parameters
    ----------
    df_spike, df_data : frames indexed so that ``.loc[(slice(None), roi), :]``
        selects the rows of one ROI.
    filaname : path of the PDF to create.
    grp : grouping specification forwarded to the ``la`` plotting function.
    title_template : format string receiving ``(position, roi)``.
    by_epoch : when true the page is drawn by ``la.plot_epochs`` (which also
        gets the global ``etc``), otherwise by ``la.plot_data``.
    div : forwarded to the ``la`` plotting function.

    The PDF is closed even when a plotting call raises, so the file handle is
    not leaked and the pages written so far are kept.
    """
    pp = PdfPages(filaname)
    try:
        for i in range(0, len(data.rois)):
            roi = data.rois[i]
            spike_c = df_spike.loc[(slice(None),roi),:]
            data_c = df_data.loc[(slice(None),roi),:]
            page_title = title_template%(i,roi)
            if by_epoch:
                fig = la.plot_epochs(spike_c, data_c, None, data.experiment_traits, etc, FPS, grp, title=page_title, div=div)
            else:
                fig = la.plot_data(spike_c, data_c, None, data.experiment_traits, FPS, grp, title=page_title, div=div)
            # savefig() without argument stores the current figure
            pp.savefig()
            plt.close(fig)
    finally:
        pp.close()

In [None]:
# Deliberate stop: raising here interrupts a "Run All" so that the slow
# per-ROI PDF exports in the cells below only run when started by hand.
raise ValueError("You don't want to run this automatically")

In [None]:
import ca_lib as la
imp.reload(la)

In [None]:
# One page per ROI, grouped by each trait on its own
plot_roi(
    df_spike, df_data, animal+'_roi1crit.pdf',
    grp=[['context'],['learning_epoch'],['port'],['puffed']],
    title_template='ROI %d:\n%s',
)

In [None]:
# One page per ROI, grouped by the combination of traits, drawn per epoch
plot_roi(
    df_spike, df_data, animal+'_roiAcrit.pdf',
    grp=[['context','port','puffed']],
    title_template='ROI %d:\n%s',
    by_epoch=True,
)

### Averaging over intervals

#### Intervals aligned to events

In [None]:
# Aggregate the columns of every data set over the event-aligned sections;
# the resulting columns are labelled with the section centers (as strings).
a_spike, a_data, a_raw, a_lick = (
    la.pd_aggr_col(_frame, myafun, asections, acenters.astype(str))
    for _frame in (df_spike, df_data, df_raw, df_lick)
)

In [None]:
# Re-sort the row index of the aggregated data and raw frames
a_data, a_raw = a_data.sort_index(), a_raw.sort_index()

In [None]:
# Event-aligned averages: one page per ROI, grouped by each trait on its own
plot_roi(
    a_spike, a_data, animal+'_avg1crit.pdf',
    grp=[['context'],['learning_epoch'],['port'],['puffed']],
    title_template='ROI %d:\n%s',
    div=acenters,
)

In [None]:
# Event-aligned averages: one page per ROI, traits combined, drawn per epoch
plot_roi(
    a_spike, a_data, animal+'_avgAcrit.pdf',
    grp=[['context','port','puffed']],
    title_template='ROI %d:\n%s',
    by_epoch=True,
    div=acenters,
)

#### Averaging over bins

In [None]:
# Aggregate the columns of every data set over the bins;
# the resulting columns are labelled with the bin centers (as strings).
b_spike, b_data, b_raw, b_lick = (
    la.pd_aggr_col(_frame, mybfun, bsections, bcenters.astype(str))
    for _frame in (df_spike, df_data, df_raw, df_lick)
)

In [None]:
# Re-sort the row index of the binned data and raw frames
b_data, b_raw = b_data.sort_index(), b_raw.sort_index()

In [None]:
# Binned averages: one page per ROI, grouped by each trait on its own
plot_roi(
    b_spike, b_data, animal+'_bin1crit.pdf',
    grp=[['context'],['learning_epoch'],['port'],['puffed']],
    title_template='ROI %d:\n%s',
    div=bcenters,
)

In [None]:
# Binned averages: one page per ROI, traits combined, drawn per epoch.
# Uses the b_* frames and bcenters: the previous call passed the event-aligned
# a_* data, so '_binAcrit.pdf' duplicated the content of '_avgAcrit.pdf'.
plot_roi(b_spike, b_data, animal+'_binAcrit.pdf',[['context','port','puffed']],'ROI %d:\n%s',True,div=bcenters)

## Correlations

In [None]:
# Convert to ordinal, here we use indices that way
# NOTE(review): `experiment_traits` is used bare here while the other cells
# read `data.experiment_traits` -- confirm the name is defined earlier.
et1 = experiment_traits.copy().drop('time', axis=1)
et1[['session_num', 'day_num']] = et1[['session_num', 'day_num']].astype(int)
# Bring z_data onto the row/column layout of df_template and turn the column
# labels into integers so they can be sliced numerically later
ord1 = z_data.reindex(df_template.index, df_template.columns)
ord1.columns = pd.Index(ord1.columns.values.astype(int), name=ord1.columns.name)

# Combine information: join the traits onto the data rows (inner join), then
# re-key the rows by ROI, the four condition traits and session_num
ord1 = ord1.join(et1, how='inner').reset_index().drop('time', axis=1).set_index(['roi_id','learning_epoch','context','port','puffed','session_num']).sort_index()
ord1.columns.name='Spike'
print(ord1.shape)
display(ord1.head())

# Search for days that contain experiments with same traits and session_num
# These entries would jeopardize unstacking
et2 = et1.reset_index(drop=True).set_index(['learning_epoch','context','port','puffed','session_num']).sort_index()
# duplicated() flags every repeat of a key after its first appearance;
# set1 = days of those flagged repeats
second_occur = et2.index.duplicated()
set1 = et2.loc[second_occur,'day_num'].unique()
# get_duplicates() lists the keys that occur more than once;
# set_all = days of all rows carrying such a key
# NOTE(review): Index.get_duplicates() is deprecated in newer pandas releases
all_occur = et2.index.get_duplicates()
set_all = et2.loc[all_occur,'day_num'].unique()
# set2 = days that hold rows with a repeated key but none of the flagged repeats
set2 = np.array(list(set(set_all)-set(set1)))
print(set1,set2)

# Filter out second occurrences stored in set2
# NOTE(review): as built above, set2 excludes the days of the flagged repeats
# (those are in set1) -- confirm which of the two sets is meant to be dropped.
if len(set2):
    ord1 = ord1[ord1.loc[:,'day_num'].apply(lambda x: x not in set2)]
print(ord1.shape)

# Reshape for correlation analysis
# some value get converted to float to be able to hold nan-s
# comp: day_num with the innermost row level (session_num) moved to columns
comp = ord1['day_num'].unstack().sort_index(axis=1)
# ord1: the data columns, likewise unstacked over session_num, then re-keyed
# with roi_id as the innermost row level
ord1 = ord1.drop(['day_num'], axis=1).unstack()
ord1 = ord1.reset_index().set_index(['learning_epoch','context','port','puffed','roi_id']).sort_index()
display(comp.head())
display(ord1.head(10))

In [None]:
# Find the pre-learning structure, without airpuff
key_ref = ('Pre-Learning','CS+',True,False)
time_ref = np.array([15, 40])
col_ref = slice(int(time_ref[0]*FPS),int(time_ref[1]*FPS))
sel = ord1.loc[key_ref+(slice(None),),col_ref]
print(sel.shape)

# Correlate
corr_df = sel.T.corr()
corr_np = corr_df.fillna(0).values

# Discard invalid series
keep = (np.diag(corr_np) == 1.0)
corr_np = corr_np[keep,:][:,keep]

# Show
fig, ax = plt.subplots(1,2, figsize=(10,5))
ax[0].matshow(corr_df.values)
ax[1].matshow(corr_np)

In [None]:
# Open the multi-page PDF that the following correlation cells save into
pp = helpmultipage(animal+'_correl.pdf')

In [None]:
# Define an ordering
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform, pdist
sq_dist = squareform(1.0-corr_np)
corr_link = linkage(sq_dist, 'average')
fig, ax = plt.subplots(1,2, figsize=(18,8))
fig.suptitle('Reference is presented here: '+(', '.join(np.array(key_ref)))+
            ' and time '+('..'.join(time_ref.astype(str)))+'s')
labels = sel.index.get_level_values(4).to_series().reset_index(drop=True)[keep]
dendo = dendrogram(corr_link, ax=ax[1], labels=labels.values, leaf_font_size=2.5, orientation='left')
ax[1].set_title('Distance of firing patterns')
corr_order = dendo['leaves']
# Show reordered
cax = ax[0].matshow(corr_np[corr_order,:][:,corr_order], origin='lower', vmin=-0.8, vmax=1)
ax[0].xaxis.set_ticks_position('bottom')
ax[0].set_title('Ordered correlation matrix', y=1.0)
fig.colorbar(cax)
pp.savefig(dpi=600)

In [None]:
# One correlation matrix per condition (row of `et`), three panels per figure row
num_plots = len(et.index)
num_rows = int(np.ceil(num_plots/3.0))
fig, ax = plt.subplots(num_rows,3, figsize=(12,4.5*num_rows))
fig.suptitle('Correlation structure under different conditions: learning_epoch, context, port, puffed\n'+
             '(small number of trials might lead to larger percieved correlation)')
ax = np.ravel(ax)
mx = {}                              # condition key -> reordered correlation matrix
ds = pd.DataFrame(columns=et.index)  # one column of flattened coefficients per condition

for idx in range(0,num_plots):
    # Select the rows of this condition within the reference column window
    key = et.index[idx]
    sel = ord1.loc[key+(slice(None),),col_ref]
    print(key,ord1.shape,sel.shape)

    # Correlate; NaN coefficients are replaced by 0
    corr_tmp = sel.T.corr()
    corr_tmp = corr_tmp.fillna(0).values

    # Apply the reference mask and the reference leaf order, so every panel
    # shows the same rows in the same order as the reference matrix
    #if len(corr_tmp):
    corr_tmp = corr_tmp[keep,:][:,keep][corr_order,:][:,corr_order]
    cax = ax[idx].matshow(corr_tmp, origin='lower', vmin=-0.8, vmax=1)
    ax[idx].xaxis.set_ticks_position('bottom')

    mx[key] = corr_tmp
    # Diagonal set to NaN so self-correlations are excluded from the statistics
    ds[key] = np.ravel(corr_tmp+np.diag(np.nan*np.diag(corr_tmp)))
    # Positional lookup with .iloc: the .ix indexer is deprecated and was
    # removed from pandas; for this tuple-keyed index both select by position
    ax[idx].set_title('%s: %d'%(key,et.iloc[idx]))
pp.savefig(dpi=600)

In [None]:
# Histogram of the off-diagonal correlation coefficients of every condition
fig, ax = plt.subplots(num_rows,3, figsize=(12,4.5*num_rows))
fig.suptitle('Distribution of the above correlation coefficients\n(diagonals excluded)')
ax = np.ravel(ax)

for idx in range(0,num_plots):
    key = et.index[idx]
    corr_tmp = mx[key]
    # NaN on the diagonal removes the self-correlations from the matrix
    corr_tmp = corr_tmp+np.diag(np.nan*np.diag(corr_tmp))
    ax[idx].hist(np.ravel(corr_tmp),range=(-1,1),bins=20)
    ax[idx].set_yscale('log')
    # Positional lookup with .iloc: the .ix indexer is deprecated and was
    # removed from pandas; for this tuple-keyed index both select by position
    ax[idx].set_title('%s: %d'%(key,et.iloc[idx]))
    

In [None]:
# Summary statistics of the correlation coefficients, rendered as a table page
# FIXME index column toooo wide
fig, ax = plt.subplots(1,1)
fig.suptitle('Statistics on the correlation coefficients')
ax.axis('off')
ax.set_position([.5, 0.2, 0.5, 0.6])
# NOTE(review): df_epoch is called bare here while other cells call
# la.df_epoch -- confirm the bare name is defined in the notebook.
a = df_epoch(np.round(ds.describe(),4).T)
cw = np.ones((len(a.columns),))
# table() lives in pandas.plotting on newer pandas and in
# pandas.tools.plotting on older releases; use whichever is available
_pd_table = pd.plotting.table if hasattr(pd, 'plotting') else pd.tools.plotting.table
t = _pd_table(ax, a, loc='upper right', fontsize=12, colWidths=cw/np.sum(cw))
pp.savefig()
plt.close(fig)
a

In [None]:
# Summary statistics of the absolute correlation coefficients, as a table page
# FIXME index column toooo wide
fig, ax = plt.subplots(1,1)
fig.suptitle('Statistics on the absolute value of correlation coefficients')
ax.axis('off')
ax.set_position([.5, 0.2, 0.5, 0.6])
# NOTE(review): df_epoch is called bare here while other cells call
# la.df_epoch -- confirm the bare name is defined in the notebook.
a = df_epoch(np.round(np.abs(ds).describe(),4).T)
cw = np.ones((len(a.columns),))
# table() lives in pandas.plotting on newer pandas and in
# pandas.tools.plotting on older releases; use whichever is available
_pd_table = pd.plotting.table if hasattr(pd, 'plotting') else pd.tools.plotting.table
t = _pd_table(ax, a, loc='upper right', fontsize=12, colWidths=cw/np.sum(cw))
pp.savefig()
plt.close(fig)
a

In [None]:
# Pairwise distance between the correlation matrices of all conditions
change = np.zeros((num_plots,num_plots))
for idx1 in range(0,num_plots):
    for idx2 in range(0,num_plots):
        diff = mx[et.index[idx1]]-mx[et.index[idx2]]
        # Root-mean-square of the element-wise difference, as the title states.
        # The earlier norm(diff/size) equals norm(diff)/size, which is smaller
        # than the RMS by a factor of sqrt(size).
        change[idx1,idx2] = np.sqrt(np.mean(np.square(diff)))


fig, ax = plt.subplots(1,1, figsize=(6,6))
fig.tight_layout(rect=[0.4,0,0.95,0.55])
#fig = plt.figure()
#ax = fig.gca()
# Diagonal (distance of a condition to itself) is blanked with NaN
cax = ax.matshow(change+np.diag(np.nan*np.diag(change)), cmap=plt.get_cmap('rainbow'))

fig.suptitle('Difference between test cases (RMS distance)\n'+', '.join(et.index.names))
tick_positions = np.array(range(0,len(et.index)))
tick_labels = et.index.values.tolist()
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels,rotation=90)
ax.set_yticks(tick_positions)
ax.set_yticklabels(tick_labels)

# without set_yticks
# ax.set_yticklabels([tuple()]+et.index.values.tolist())
fig.colorbar(cax)
pp.savefig()

In [None]:
# Finalize the correlation PDF opened with helpmultipage above
pp.close()