### Requirements
* #### You need to install the module `future`; importing from `__future__` manually is at your convenience
* #### For HDF data import you also need PyTables, which is not installed by default with Anaconda

### Batch execution
* #### ```batch_animal=msaxxyy_z jupyter nbconvert Stat.ipynb --to=html --execute --ExecutePreprocessor.timeout=-1 --output=xxyy_z_report.html```

In [None]:
from __future__ import (absolute_import, division,
                        print_function, unicode_literals)
#from future.utils import PY3
import future
import pandas as pd
import numpy as np
import time, os, warnings, imp, itertools
import IPython.display as disp
display = disp.display
import matplotlib as mpl, matplotlib.pyplot as plt
import scipy.stats as stats
zscore, describe = stats.mstats.zscore, stats.describe
import datetime
dt, td = datetime.datetime, datetime.timedelta

%matplotlib inline

In [None]:
import ca_lib as la
imp.reload(la)

In [None]:
from os import environ
# Animal selected for a batch run through the environment (see "Batch execution"); None otherwise
batch_animal = environ.get('batch_animal')

## Load files

In [None]:
basedir = '../_share/Losonczi/'
animals = [
    'msa0216_4', 'msa0316_1', 'msa0316_3', 'msa0316ag_1',
    'msa0915_1', 'msa0915_2', 'msa1215_1',
]

# List what the database folder contains
display(os.listdir(basedir))

In [None]:
# Load files
# Only the 'baydb_<animal>.h5' stores are read (into ``bayes``); the 'anidb_'
# load is commented out, so ``data`` stays an empty dict here.
# NOTE(review): later cells index data[animal] — re-enable that load (or fill
# ``data`` elsewhere) before running them.
data = {}
bayes = {}
for animal in animals:
    # NOTE(review): mydir is computed but never used; the file names passed to
    # la.read_from_hdf carry no directory, so unless read_from_hdf adds one they
    # are resolved against the working directory. Confirm this is intended.
    mydir = os.path.join(basedir,animal)
    #data[animal] = la.read_from_hdf('anidb_'+animal+'.h5', la.Bunch())
    bayes[animal] = la.read_from_hdf('baydb_'+animal+'.h5', la.Bunch())
    #print (animal, data[animal].raw.shape)

## Experiment protocol configurations

In [None]:
def settings_summary(data):
    '''Collect every animal's experiment-settings table ``db.et`` side by side.

    A display version (re-indexed by ``la.display_learning``, missing entries
    shown as '-') is rendered via ``la.df_epoch``; the tables in their stored
    order are concatenated column-wise and returned.

    data -- mapping of animal name -> record exposing an ``et`` attribute
    '''
    et_store, et_disp = [], []
    # .values() works on Python 2 and 3 alike (dict.iteritems is Python 2 only)
    for db in data.values():
        et = db.et
        et_disp.append(et.reset_index().set_index(la.display_learning))
        et_store.append(et)
    et_disp = pd.concat(et_disp,axis=1,names='animal').fillna('-')
    et_store = pd.concat(et_store,axis=1,names='animal')
    disp.display(disp.HTML('<font color="red">ATTENTION, </font>for later conformity we store columns in a <b>different order</b>: %s !!!'%la.sort_learning))
    display(la.df_epoch(et_disp))
    return et_store

In [None]:
# Per-animal experiment-settings table (a summary is displayed as a side effect).
# NOTE(review): ``data`` is empty unless the commented-out 'anidb_' load is re-enabled.
et = settings_summary(data)

In [None]:
def bayes_summary(data):
    '''Concatenate every animal's ``constellations`` table column-wise.

    data -- mapping of animal name -> record exposing ``constellations``
    Returns the combined DataFrame; rows are aligned on the union of the
    individual indices (missing cells become NaN).
    '''
    # .values() works on Python 2 and 3 alike (dict.iteritems is Python 2 only)
    return pd.concat([db.constellations for db in data.values()], axis=1, names='animal')

In [None]:
# All animals' constellation tables in one frame; absent combinations become 0
bt = bayes_summary(bayes).fillna(0).astype(int)

In [None]:
# Show the combined constellation table
bt

In [None]:
import re
# Build a readable label for every row of ``bt``: each index level whose value
# is truthy contributes its level name, the others contribute nothing.
lab = pd.DataFrame(bt.index.tolist(), columns=bt.index.names)
for col in lab.columns:
    lab[col] = lab[col].apply(lambda x: col if x else '')
# Join only the non-empty names; collapsing blanks with replace('  ', ' ')
# misses runs of three or more unset levels and leaves double spaces.
lab = lab.apply(lambda row: ' '.join(name for name in row if name), axis=1)
lab

In [None]:
# Same table as ``bt`` but with the readable row labels from ``lab``
bt2 = bt.copy()
bt2.index=lab.values
bt2

## Prepare data

### Averaging (integrating)
Spiking is "True" within the half-open intervals [start, end) given in transients_data.hc5

In [None]:
# Aggregation callables (unbound DataFrame methods), e.g. handed to la.pd_aggr_col
mymean, mystd = pd.DataFrame.mean, pd.DataFrame.std

# Plot

In [None]:
from matplotlib.backends.backend_pdf import PdfPages

class helpmultipage(object):
    '''Small wrapper around ``PdfPages`` that tolerates repeated open/close calls.

    Pages are appended with ``savefig``; calls made while the wrapper is closed
    (or was created with an empty filename) are silently ignored. The instance
    can also be used as a context manager, which closes the PDF on exit.
    '''

    def __init__(self, filename):
        self.filename = filename
        self.isopen = False
        self.open()

    def __del__(self):
        # __init__ may have failed before ``isopen`` was assigned
        if getattr(self, 'isopen', False):
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def savefig(self, dpi=None, figure=None):
        '''Append ``figure`` (default: the current figure) as a new page, if open.'''
        if self.isopen:
            self.pp.savefig(figure, dpi=dpi)

    def open(self):
        '''Open the PDF unless it is already open or no filename was given.'''
        # Must be ``not``: bitwise ``~`` on a bool yields -1 / -2, both truthy,
        # which would re-open (and leak) an already open PdfPages.
        if not self.isopen and len(self.filename):
            self.pp = PdfPages(self.filename)
            self.isopen = True

    def close(self):
        '''Finalize and close the PDF; safe to call more than once.'''
        if self.isopen:
            self.pp.close()
        self.isopen = False

#### Explanatory figure

In [None]:
def explain_figures(data):
    '''Draw an annotated schematic explaining the layout of the trial figures.

    The lower axes carries four illustrative curves (shifted copies of the mean
    of ``data.z_spike``) whose only purpose is to show the legend format
    "(Category, Condition): #trials". Vertical lines mark ``data.event_frames``,
    the second and fourth phases are shaded, and each phase is labelled with
    rotated text. The upper subplot is left empty to make room for the legend.

    NOTE(review): ``data.event_frames`` must have at least 5 entries
    (``width[3]`` is indexed below). Returns the figure.
    '''
    import matplotlib.patches as mpatches
    from matplotlib.collections import PatchCollection
    # midpoints, start frames and lengths of consecutive event intervals
    center = (data.event_frames[1:]+data.event_frames[:-1]) /2
    left = data.event_frames
    width = data.event_frames[1:]-data.event_frames[:-1]
    vcenter = 0.0
    vstart = -0.5

    # Closure over ``ax``, which is assigned below; only called after that.
    def label90(x,y,text):
        ax.text(x, y, text, ha="center", va="center", family='sans-serif', size=14, rotation=90)

    fig, (empty, ax) = plt.subplots(2,1,figsize=(6,8))
    fig.suptitle('Explanatory figure',fontsize=16)
    fig.tight_layout(pad=3)
    empty.axis('off')
    
    ax.set_xlabel('Camera frame')
    ax.set_ylabel('z-scored activity')
    ax.set_ylim(vstart,vstart+1)
    # dummy traces, used only to populate the example legend
    ax.plot(data.z_spike.mean(axis=0)+0.00, label="(CategoryA, True): #trials", c=(1,1,0))
    ax.plot(data.z_spike.mean(axis=0)+0.02, label="(CategoryB, True): #trials", c=(.5,1,.5))
    ax.plot(-data.z_spike.mean(axis=0)+0.00, label="(CategoryA, False): #trials", c=(1,.8,1))
    ax.plot(-data.z_spike.mean(axis=0)+0.02, label="(CategoryB, False): #trials", c=(.5,1,1))
    patches = []
    # mark first phase (water shown)
    label90(center[0], vcenter, 'excitation by\nshowing water')
    # mark CS
    rect = mpatches.Rectangle((left[1],vstart), width[1], 1, ec="none")
    patches.append(rect)
    label90(center[1], vcenter, 'CS± if tone\n"Baseline" otherwise')
    # mark delay
    label90(center[2], vcenter, 'trace = delay')
    # mark UC
    rect = mpatches.Rectangle((left[3],vstart), width[3], 1, ec="none")
    patches.append(rect)
    label90(center[3], vcenter, 'UC if any')
    # mark water: double arrow box spanning roughly the first three phases
    ax.text((left[0]+left[3])/2, vstart, "water port present\niff allowed to lick",
            ha="center", va="bottom", family='sans-serif', size=14, bbox=dict(boxstyle="DArrow", pad=0.0, fc='c'))

    # show event boundaries
    for sep in data.event_frames[:-1]:
        ax.axvline(x=sep, ymin=0.0, ymax = 1.0, linewidth=1, color='k')
    # give each shaded phase its own hue from the hsv colormap
    colors = np.linspace(0, 1, len(patches))
    collection = PatchCollection(patches, cmap=plt.cm.hsv, alpha=0.1)
    collection.set_array(np.array(colors))
    ax.add_collection(collection)

    # align legend: anchored above the axes, i.e. in the empty upper subplot area
    leg = ax.legend(loc='lower center', title="Category name, Condition name",
                   bbox_to_anchor=(0.5, 1.1))
    leg.get_title().set_fontsize('large')
    leg.get_title().set_fontweight('bold')
    # presumably silences the backend UserWarning when fig.show() cannot open a window
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', UserWarning)
        fig.show()
    return fig

### Learning progress

In [None]:
def learning_chart(data):
    '''Draw one row of bar charts per trial summarising the licking behaviour.

    For every trial in ``data.trials`` the row shows the licking-rate mean
    (red, with std error bars; both divided by ``2*data.lick_threshold``), the
    number of lick-rise and population-rise triggers (green, scaled by 0.2)
    and the licking-time mean (blue). ``la.draw_conditions`` adds the trial's
    experimental conditions, with 'Baseline' abbreviated to 'B.L.'.

    Returns the created figure.
    '''
    n_trials = len(data.trials)
    # squeeze=False keeps ``ax`` indexable even when there is a single trial
    fig, ax = plt.subplots(n_trials, 1, figsize=(10, 0.6*n_trials),
                           sharex=True, sharey=True, squeeze=False)
    ax = ax[:, 0]
    fig.tight_layout(h_pad=0.1)
    ind = np.arange(0, 5)
    width, height, spacing = 1, 1.2, 10
    label_df = data.experiment_traits.replace('Baseline', 'B.L.')
    # loop-invariant normalisation of the licking rate
    unit = 2.0*data.lick_threshold
    for i, trial in enumerate(data.trials):
        # need to use .values because integer column indices cause confusion
        mea = data.lick_rate_mean.loc[trial].values/unit
        err = data.lick_rate_std.loc[trial].values/unit
        ax[i].bar(ind+2*spacing, mea, width, color='r', yerr=err)
        ax[i].bar(ind+3*spacing, data.lick_time_mean.loc[trial].values, width, color='b')
        counts = np.array([len(data.lick_triggers_rise[trial]), len(data.spike_triggers_rise[trial])])
        ax[i].bar(ind[0:2]+2.65*spacing, 0.2*counts, width, color='g')
        ax[i].set_xlim(xmin=0)
        ax[i].set_ylim(ymin=0, ymax=height)
        ax[i].set_yticks([0,0.5,1])
        la.draw_conditions(ax[i],label_df,trial,data.FPS,loc='lower left',screen_width=0.5, height=height, cw=[0.25, 0.15, 0.15, 0.15, 0.15, 0.15],fontsize=12)
    ax[-1].set_xticks([spacing, 2.2*spacing, 2.75*spacing, 3.2*spacing])
    ax[-1].set_xticklabels(['Conditions', 'Licking rate', 'Lick rise, Pop. rise', 'Licking time'])
    return fig

In [None]:
# Render every animal's learning chart as one page of all_protocol.pdf.
# NOTE(review): requires ``data`` to be populated (its load is commented out above).
pp = helpmultipage('all_protocol.pdf')
try:
    for animal in animals:
        fig = learning_chart(data[animal])
        fig.suptitle(animal)
        pp.savefig()
        plt.close(fig)
finally:
    # finalize the PDF even if one chart fails
    pp.close()

## Population averages

### Single criterion
* comments

### Two criteria
* comments

### Three criteria
* comments

### All criteria
* There is no increased population activity for CS+ without puffing. (For mouse 0216_4, the single trial with the water port shows an increase during the trace period — why?)
* During learning, mouse 0216_4 shows increased activity during the UC phase for CS-

### Activities conditional on epoch

In [None]:
def plot_by_epoch(epoch):
    '''Plot the signals of all experiments in one learning epoch and save them to ``pp``.

    Two figures are produced with ``la.plot_data`` — the time courses and their
    ``mymean`` aggregation over ``asections`` (titled "averaged over events") —
    and each is appended to the notebook-global PDF ``pp``.

    NOTE(review): this reads the notebook globals ``data`` (used here as a
    single animal record with experiment_traits / z_spike / z_filtered / z_raw /
    lick), ``asections``, ``acenters`` and ``pp``. Elsewhere in this notebook
    ``data`` is a dict keyed by animal, and ``asections``/``acenters`` are not
    defined in the visible cells — confirm before calling.
    '''
    # experiments belonging to the requested epoch, and the matching signal rows
    experiment_c = data.experiment_traits[data.experiment_traits.loc[:,'learning_epoch']==epoch]
    spike_c = data.z_spike.reindex(experiment_c.index, level='time')
    data_c = data.z_filtered.reindex(experiment_c.index, level='time')
    raw_c = data.z_raw.reindex(experiment_c.index, level='time')
    lick_c = data.lick.reindex(experiment_c.index)
    print (experiment_c.shape, spike_c.shape)
    spike_ca = la.pd_aggr_col(spike_c, mymean, asections, acenters)
    data_ca = la.pd_aggr_col(data_c, mymean, asections, acenters)
    # NOTE(review): raw_ca is computed but not plotted below
    raw_ca = la.pd_aggr_col(raw_c, mymean, asections, acenters)
    lick_ca = la.pd_aggr_col(lick_c, mymean, asections, acenters)
    print (spike_c.shape, spike_ca.shape)

    grp = [['context','port'],['context','puffed'],['port','puffed']]
    la.plot_data(data, [spike_c, data_c, lick_c],
                 ['z-scored Spiking', 'z-scored Ca-levels', 'Licking'],
                 grp, title=epoch)
    pp.savefig()
    la.plot_data(data, [spike_ca, data_ca, lick_ca],
                 ['z-scored Spiking', 'z-scored Ca-levels', 'Licking'],
                 grp, title=epoch+' averaged over events', div=acenters)
    pp.savefig()

#### Pre-learning

#### Learning

#### Post-Learning

## Activity vector by phases

pp = helpmultipage(animal+'_activation_ca.pdf')

for p,aggr in enumerate(data.za_filtered.columns):
    title = 'Ca-Signal, Phase: %s'%la.phases[p]
    plot_activity_vectors(data, data.za_filtered, title, vmin=-3, vmax=3)
    pp.savefig()    
pp.close()

### An example of spiking
The first second of the recording seems to be missing

In [None]:
def draw_firing(ax, data, idx, settings, seismic=False, show_nan=False, pos=-20):
    '''Draw the per-ROI activity of experiment ``idx`` on ``ax`` and title its figure.

    settings -- mapping providing 'session_num' and 'day_num' for the title
    seismic  -- draw z-scored Ca levels (``la.draw_levels``) instead of the
                transients (``la.draw_transients``)
    show_nan -- additionally call ``la.draw_spiking_nan`` on ``data.spike``
    pos      -- unused; kept so the signature matches ``draw_signals``
    '''
    # Title the figure that owns ``ax`` instead of relying on a global ``fig``
    # (which was undefined or belonged to whatever figure was created last).
    fig = getattr(ax, 'figure', None)
    if fig is None:
        fig = plt.gcf()
    fig.suptitle('%s: session %s, day %s\n'%(idx,settings['session_num'],settings['day_num'])+
                 ', '.join(la.sort_learning)+': #context in epoch, #day',fontsize=16)
    if seismic:
        la.draw_levels(ax, data.z_filtered, idx, data.FPS, data.roi_df)
    else:
        la.draw_transients(ax, data.transients, idx, data.FPS, data.roi_df)
    if show_nan:
        la.draw_spiking_nan(ax, data.spike, idx, data.rois.values)

def draw_signals(ax, data, idx, settings, seismic=False, show_nan=False, pos=-20):
    '''Overlay population traces, licking, triggers and trial conditions on ``ax``.

    Draws the population Ca-signal and the population z-spike count at vertical
    offset ``pos``, licking at ``pos-20``, trigger markers at -5 and the
    conditions of experiment ``settings['timestr']``. Returns ``ax``.

    ``seismic`` and ``show_nan`` are accepted but unused here; the signature
    mirrors ``draw_firing``. NOTE(review): triggers come from the notebook
    globals ``trig_list_data``, ``trig_list_sign`` and ``trig_list_color``,
    which are not defined in the visible cells — confirm before calling.
    '''
    experiment_id = settings['timestr']
    la.draw_population(ax, data.z_filtered, idx, pos=pos, c='y', label='population Ca-signal')
    la.draw_population(ax, data.z_spike, idx, pos=pos, threshold=data.z_spike_threshold, label='population z-spike count')
    la.draw_licking(ax, data.lick, idx, pos=pos-20, threshold=data.lick_threshold, label='licking')
    la.draw_triggers(ax, trig_list_data, idx, -5, trig_list_sign, c=trig_list_color)
    la.draw_conditions(ax, data.experiment_traits, experiment_id, data.FPS, height=20)
    return ax

### Pattern matching

## Peri-event plots

In [None]:
import matlab_tools as mt
imp.reload(la)

In [None]:
def list_peri_3a(df, title=None):
    '''Plot collection: CS+ US

    Returns one [df, rois, trig, allow, disable, title] entry per view: lick
    rise/fall (unrestricted, with the CS+ allow mask, with the US allow mask)
    plus CS+ and US start/end. ``title`` is unused. ``data.rois`` and every
    trigger series are read from notebook globals.
    '''
    views = (
        (lick_triggers_rise, None, 'Lick rise'),
        (lick_triggers_fall, None, 'Lick fall'),
        (lick_triggers_rise, csp_triggers_allow, 'Lick rise CS+'),
        (lick_triggers_fall, csp_triggers_allow, 'Lick fall CS+'),
        (csp_triggers_rise, None, 'CS+ start'),
        (csp_triggers_fall, None, 'CS+ end'),
        (lick_triggers_rise, us_triggers_allow, 'Lick rise US'),
        (lick_triggers_fall, us_triggers_allow, 'Lick fall US'),
        (us_triggers_rise, None, 'US start'),
        (us_triggers_fall, None, 'US end'),
    )
    return [[df, data.rois, trigger, allow, None, label] for trigger, allow, label in views]

In [None]:
def list_peri_3b(df, title=None):
    '''Plot collection: CS+ CS-

    Returns one [df, rois, trig, allow, disable, title] entry per view: lick
    rise/fall (unrestricted, with the CS+ allow mask, with the CS- allow mask)
    plus CS+ and CS- start/end. ``title`` is unused. ``data.rois`` and every
    trigger series are read from notebook globals.
    '''
    views = (
        (lick_triggers_rise, None, 'Lick rise'),
        (lick_triggers_fall, None, 'Lick fall'),
        (lick_triggers_rise, csp_triggers_allow, 'Lick rise CS+'),
        (lick_triggers_fall, csp_triggers_allow, 'Lick fall CS+'),
        (csp_triggers_rise, None, 'CS+ start'),
        (csp_triggers_fall, None, 'CS+ end'),
        (lick_triggers_rise, csm_triggers_allow, 'Lick rise CS-'),
        (lick_triggers_fall, csm_triggers_allow, 'Lick fall CS-'),
        (csm_triggers_rise, None, 'CS- start'),
        (csm_triggers_fall, None, 'CS- end'),
    )
    return [[df, data.rois, trigger, allow, None, label] for trigger, allow, label in views]

In [None]:
def _list_peri_per_animal(data, df, rise_attr, fall_attr, allow_attr, rise_label, fall_label):
    '''Build peri-event entries for every animal in the global ``animals``.

    All "rise" entries (one per animal) come first, then all "fall" entries.
    Each entry is [signal, rois, trig, allow, disable, title]: ``signal`` is
    ``data[animal][df]``, ``trig``/``allow`` are the record attributes named by
    ``rise_attr``/``fall_attr`` and ``allow_attr`` (``allow`` is None when
    ``allow_attr`` is None), ``disable`` is None and ``title`` is the animal
    name and the label separated by a newline.
    '''
    ret = []
    for trig_attr, label in ((rise_attr, rise_label), (fall_attr, fall_label)):
        for animal in animals:
            db = data[animal]
            ret.append([db[df], db.rois, getattr(db, trig_attr),
                        None if allow_attr is None else getattr(db, allow_attr),
                        None, '%s\n%s' % (animal, label)])
    return ret

def list_peri_1a(data, df, title=None):
    '''Plot collection: Lick (``title`` is unused)'''
    return _list_peri_per_animal(data, df, 'lick_triggers_rise', 'lick_triggers_fall', None,
                                 'Lick rise', 'Lick fall')
def list_peri_1b(data, df, title=None):
    '''Plot collection: Lick CS+ (``title`` is unused)'''
    return _list_peri_per_animal(data, df, 'lick_triggers_rise', 'lick_triggers_fall', 'csp_triggers_allow',
                                 'Lick rise CS+', 'Lick fall CS+')
def list_peri_1c(data, df, title=None):
    '''Plot collection: Lick CS- (``title`` is unused)'''
    return _list_peri_per_animal(data, df, 'lick_triggers_rise', 'lick_triggers_fall', 'csm_triggers_allow',
                                 'Lick rise CS-', 'Lick fall CS-')
def list_peri_1d(data, df, title=None):
    '''Plot collection: CS+ (``title`` is unused)'''
    return _list_peri_per_animal(data, df, 'csp_triggers_rise', 'csp_triggers_fall', None,
                                 'CS+ start', 'CS+ end')
def list_peri_1e(data, df, title=None):
    '''Plot collection: CS- (``title`` is unused)'''
    return _list_peri_per_animal(data, df, 'csm_triggers_rise', 'csm_triggers_fall', None,
                                 'CS- start', 'CS- end')
def list_peri_1f(data, df, title=None):
    '''Plot collection: US (``title`` is unused)'''
    return _list_peri_per_animal(data, df, 'us_triggers_rise', 'us_triggers_fall', None,
                                 'US start', 'US end')

In [None]:
# Peri-event plot settings: y-axis upper limit (number of ROI rows shown) and
# the window handed to la.plot_peri_collection (presumably frames around each trigger)
num_rois, peri_range = 250, (-16, 48)

In [None]:
imp.reload(la)

In [None]:
# Show an example: spiking around the lick rise/fall triggers of every animal,
# panels kept separate (combine=False), y axis capped at num_rois
fig=la.plot_peri_collection(list_peri_1a(data, 'spike'),'Spiking',peri_range,combine=False)
fig.gca().set_ylim(ymax=num_rois)

In [None]:
# One page per (signal, event) combination, panels kept separate (combine=False).
pp = helpmultipage('all_peri1.pdf')
peri_lists = [(list_peri_1a, 'Lick'), (list_peri_1b, 'Lick CS+'), (list_peri_1c, 'Lick CS-'),
              (list_peri_1d, 'CS+'), (list_peri_1e, 'CS-'), (list_peri_1f, 'US')]
try:
    for signal_key, signal_name in (('spike', 'Spiking'), ('z_filtered', 'z-scored Ca-level')):
        for make_list, event_name in peri_lists:
            fig = la.plot_peri_collection(make_list(data, signal_key),
                                          '%s on %s' % (signal_name, event_name),
                                          peri_range, combine=False)
            fig.gca().set_ylim(ymax=num_rois)
            pp.savefig()
            plt.close(fig)
finally:
    # finalize the PDF even if one plot fails
    pp.close()

In [None]:
# One page per (signal, event) combination, using plot_peri_collection's
# default ``combine`` setting.
pp = helpmultipage('all_peri2.pdf')
peri_lists = [(list_peri_1a, 'Lick'), (list_peri_1b, 'Lick CS+'), (list_peri_1c, 'Lick CS-'),
              (list_peri_1d, 'CS+'), (list_peri_1e, 'CS-'), (list_peri_1f, 'US')]
try:
    for signal_key, signal_name in (('spike', 'Spiking'), ('z_filtered', 'z-scored Ca-level')):
        for make_list, event_name in peri_lists:
            fig = la.plot_peri_collection(make_list(data, signal_key),
                                          '%s on %s' % (signal_name, event_name),
                                          peri_range)
            fig.gca().set_ylim(ymax=num_rois)
            pp.savefig()
            plt.close(fig)
finally:
    # finalize the PDF even if one plot fails
    pp.close()

## Individual ROIs
* since there are many of them, save figure to pdf
* THIS WILL <font color="red">TAKE A WHILE</font>, consider testing with a small range

In [None]:
def plot_roi(filename, data, dfs, names, grp, title_template, by_epoch=False, div=None, fill=None):
    '''Write one PDF page per ROI, plotting every DataFrame in ``dfs`` restricted to that ROI.

    filename       -- output PDF path
    dfs, names     -- DataFrames whose second row-index level is the ROI id, and their names
    grp            -- grouping passed to the plotting helper
    title_template -- %-format string filled with (ROI position, ROI id)
    by_epoch       -- plot with ``la.plot_epochs`` instead of ``la.plot_data``
    div, fill      -- passed through to the plotting helper

    The PDF is closed even if the (long) loop is interrupted, so the pages
    written so far stay readable.
    '''
    plot = la.plot_epochs if by_epoch else la.plot_data
    pp = PdfPages(filename)
    try:
        for i in range(len(data.rois)):
            roi = data.rois[i]
            dfc = [df.loc[(slice(None), roi), :] for df in dfs]
            fig = plot(data, dfc, names, grp, title=title_template % (i, roi), div=div, fill=fill)
            pp.savefig()
            plt.close(fig)
    finally:
        pp.close()

#### Raw data

### Averaging over intervals

#### Intervals aligned to events

#### Averaging over bins

## Correlations

In [None]:
# Whether to average trials (assuming identical timing) or rather concatenate them
# True -> use the concatenated table 'ord_correl'; False -> the trial-averaged 'ord_tr_avg'
concatenate = True

In [None]:
def concat_for_correlation(df, data):
    '''Arrange ``df`` for correlation analysis by concatenating sessions.

    ``df`` is reindexed to ``data.mirow`` rows and ``data.icol`` columns, joined
    with the learning traits, ``day_num`` and ``session_num`` of experiments
    having ``session_num >= 0``, and reshaped so that rows are
    (la.sort_learning..., roi_id) and columns are (original column, session_num).
    Days that repeat an already-seen (traits, session_num) combination would
    make the unstack ambiguous and are filtered first.

    Returns (ordered, calendar); ``calendar`` holds the day_num per row and
    session_num, 0 where absent.
    '''
    # Combine information
    ordered = df.reindex(index=data.mirow, columns=data.icol)
    et1 = data.experiment_traits[data.experiment_traits.loc[:,'session_num']>=0]
    et1 = et1.loc[:,la.sort_learning+['day_num','session_num']]
    ordered = ordered.join(et1, how='inner').reset_index().drop('time', axis=1) # keep roi_id
    ordered = ordered.set_index(la.sort_learning+['roi_id', 'session_num']).sort_index()
    ordered.columns.name='Spike'

    # Search for days that contain experiments with same traits and session_num
    # These entries would jeopardize unstacking
    et2 = et1.reset_index(drop=True).set_index(la.sort_learning+['session_num']).sort_index()
    second_occur = et2.index.duplicated()
    set_second = et2.loc[second_occur,'day_num'].unique()
    # equivalent of the removed Index.get_duplicates(): every label occurring more than once
    all_occur = et2.index[second_occur].unique()
    set_all = et2.loc[all_occur,'day_num'].unique()
    set_first = np.array(list(set(set_all)-set(set_second)))
    print('Days repeating settings: %s, all conflicted: %s, to be kept: %s'%
          (set_second,set_all,set_first))

    # NOTE(review): the message above calls ``set_first`` the days "to be kept",
    # yet the filter below removes rows whose day is in ``set_first``. Behaviour
    # is left unchanged — confirm which days are meant to be dropped.
    if len(set_first):
        ordered = ordered[~ordered['day_num'].isin(set_first)]
    print('Filtered data:',ordered.shape)

    # Reshape for correlation analysis
    # integer values get converted to float if needed to hold NaN-s
    calendar = ordered['day_num'].unstack(fill_value=0)
    ordered = ordered.drop(['day_num'], axis=1).unstack()
    print('Concatenated data:',ordered.shape)
    return ordered, calendar

In [None]:
def average_for_correlation(df, data):
    '''Arrange ``df`` for correlation analysis by averaging over sessions.

    Builds the same (la.sort_learning..., roi_id, session_num) table as
    ``concat_for_correlation`` and then averages over ``session_num``. The
    columns get a dummy ``session_num`` level (value 1) so the result has the
    same column layout as the concatenated variant.

    Returns (ordered, calendar); ``calendar`` holds the day_num per row and
    session_num, 0 where absent.
    '''
    # Combine information
    ordered = df.reindex(index=data.mirow, columns=data.icol)
    et1 = data.experiment_traits[data.experiment_traits.loc[:,'session_num']>=0]
    et1 = et1.loc[:,la.sort_learning+['day_num','session_num']]
    ordered = ordered.join(et1, how='inner').reset_index().drop('time', axis=1) # keep roi_id
    ordered = ordered.set_index(la.sort_learning+['roi_id', 'session_num']).sort_index()
    ordered.columns.name='Spike'

    # Reshape for correlation analysis
    # integer values get converted to float if needed to hold NaN-s
    calendar = ordered['day_num'].unstack(fill_value=0)
    avg_labels = list(ordered.index.names)
    avg_labels.remove('session_num')
    print (avg_labels)
    # groupby(level=...).mean() replaces the removed DataFrame.mean(level=...)
    ordered = ordered.drop(['day_num'], axis=1).groupby(level=avg_labels).mean()
    compatibility_index_level = pd.Index([1], name='session_num')
    ordered.columns = pd.MultiIndex.from_product((ordered.columns, compatibility_index_level),
                                              names=(ordered.columns.name,compatibility_index_level.name))
    print('Averaged data:',ordered.shape)
    return ordered, calendar

In [None]:
# Key of the per-animal table to correlate and the label used in the PDF name
ord1, corr_method = (('ord_correl', 'concatenated') if concatenate
                     else ('ord_tr_avg', 'trial-averaged'))

In [None]:
# Set a reference for ordering elements in the matrix

# For every animal, show side by side the row correlation of the reference
# selection (rows data[animal].key_ref, columns data[animal].col_ref of
# data[animal][ord1]) and the stored matrix data[animal].corr_np.
# NOTE(review): key_ref and time_ref leak into the next cell's figure titles;
# time_ref is informative only — the selected columns come from col_ref.
for animal in animals:
    key_ref = data[animal].key_ref # default ('Post-Learning','CS+','W+','A+')
    FPS = data[animal].FPS
    time_ref = np.array([15, 40])
    #data[animal].col_ref = slice(int(time_ref[0]*FPS),int(time_ref[1]*FPS))
    sel = data[animal][ord1].loc[data[animal].key_ref+(slice(None),),data[animal].col_ref]
    print(key_ref,time_ref,data[animal].col_ref,sel.shape)

    # Correlate (rows of sel against each other)
    corr_df = sel.T.corr()
    corr_np = data[animal].corr_np

    # Discard invalid series
    #data[animal].keep = (np.diag(corr_np) == 1.0)
    #data[animal].corr_np = corr_np[data[animal].keep,:][:,data[animal].keep]

    # Show
    fig, ax = plt.subplots(1,2, figsize=(12,4))
    img = ax[0].matshow(corr_df.values)
    img = ax[1].matshow(corr_np)
    fig.colorbar(img, ax=ax[1])

In [None]:
# Multi-page PDF collecting the correlation figures of the following cells.
# NOTE(review): no pp.close() for this file is visible in this section — confirm it is closed later.
pp = helpmultipage('all_correl_'+corr_method+'.pdf')

In [None]:
# Define an ordering
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform, pdist

# For every animal: average-linkage clustering on (1 - correlation) distances,
# shown as a dendrogram next to the reordered correlation matrix.
for animal in animals:
    db = data[animal]
    sq_dist = squareform(1.0-db.corr_np)
    corr_link = linkage(sq_dist, 'average')
    sel = db[ord1].loc[db.key_ref+(slice(None),), db.col_ref]

    fig, ax = plt.subplots(1,2, figsize=(18,8))
    # Use this animal's reference key; the global ``key_ref`` only holds the
    # value left over from the last animal of the previous cell.
    fig.suptitle('Reference for '+animal+' is presented here: '+(', '.join(np.array(db.key_ref)))+
                ' and time '+('..'.join(time_ref.astype(str)))+'s',fontsize=16)
    labels = sel.index.get_level_values(4).to_series().reset_index(drop=True)[db.keep]
    dendo = dendrogram(corr_link, ax=ax[1], labels=labels.values, leaf_font_size=2.5, orientation='left')
    ax[1].set_title('Distance of firing patterns')
    #data[animal].corr_order = dendo['leaves']
    # Show reordered
    img = ax[0].matshow(db.corr_np[db.corr_order,:][:,db.corr_order], origin='lower', vmin=-0.8, vmax=1)
    ax[0].xaxis.set_ticks_position('bottom')
    ax[0].set_title('Ordered correlation matrix', y=1.0)
    # attach the colorbar to the matrix it describes
    fig.colorbar(img, ax=ax[0])
    pp.savefig(dpi=600)

In [None]:
# Show under different conditions
# num_phases is also read by later cells (statistics and comparison helpers)
num_phases = 3
num_rows = len(et.index)
num_cols = num_phases

# For every animal: one correlation-matrix panel per condition (row of et)
# and phase (column), saved as one page of ``pp``.
for animal in animals:
    print (animal)
    # NOTE(review): phase_start / phase_end are not used in this cell
    phase_start = data[animal].event_frames+data[animal].FPS
    phase_end = data[animal].event_frames[1:]-data[animal].FPS

    fig, ax = plt.subplots(num_rows,num_cols, figsize=(5*num_cols,5*num_rows))
    fig.suptitle('Correlation structure under different conditions: learning_epoch, context, port, puffed\n'+
                 '(small number of trials might lead to larger percieved correlation)\n'+
                 '(in phases Ready, CS, Trace the conditions A+ and A- should be very similar)',fontsize=16)
    #ax = np.ravel(ax)

    # NOTE(review): keep and corr_order are not used below
    keep = data[animal].keep
    corr_order = data[animal].corr_order

    for (irow,key),(icol,phase) in itertools.product(enumerate(et.index),enumerate(la.phases[:num_cols])):
        # Value of et for this condition and animal (NaN -> 0), shown in the panel title;
        # presumably the number of trials
        count = np.nan_to_num(et.loc[key,animal])
        try:
            corr_tmp = data[animal].mx[key+(phase,)]
        except KeyError:
            corr_tmp = []

        # Leave the panel empty when no matrix exists for this condition and phase
        if len(corr_tmp):
            img = ax[irow,icol].matshow(corr_tmp, origin='lower', vmin=-0.8, vmax=1)
        ax[irow,icol].xaxis.set_ticks_position('bottom')
        ax[irow,icol].set_title('%s, %s: %d'%(key,phase,count))
    pp.savefig(dpi=600)
    plt.close(fig)

## Statistics of the correlation coefficients

In [None]:
num_rows = len(et.index)
num_cols = num_phases

# For every animal: histogram (log counts) of the correlation coefficients per
# condition (row of et) and phase (column), one page of ``pp`` per animal.
for animal in animals:
    print (animal)

    fig, ax = plt.subplots(num_rows,num_cols, figsize=(5*num_cols,5*num_rows))
    fig.suptitle('Distribution of the above correlation coefficients\n'+
                 'for '+animal+' (diagonals excluded)',fontsize=16)
    #ax = np.ravel(ax)

    for (irow,key),(icol,phase) in itertools.product(enumerate(et.index),enumerate(la.phases[:num_cols])):
        count = np.nan_to_num(et.loc[key,animal])
        try:
            corr_tmp = data[animal].mx[key+(phase,)]
            # adding a NaN diagonal turns the diagonal entries into NaN
            corr_tmp = corr_tmp+np.diag(np.nan*np.diag(corr_tmp))
        except KeyError:
            corr_tmp = []
            
        # plot only if a matrix exists and some coefficient exceeds -1 (NaN compares False)
        if len(corr_tmp) and np.sum(corr_tmp>-1.0):
            ax[irow,icol].hist(np.ravel(corr_tmp),range=(-1,1),bins=20)
            ax[irow,icol].set_yscale('log')
        ax[irow,icol].set_title('%s, %s: %d'%(key,phase,count))

    pp.savefig()
    plt.close(fig)

In [None]:
# True: correlation summaries use the full four-criterion condition set (la.legal_*);
# False: the reduced three-criterion set (la.short_*)
oraculum = True

### Compare correlation coefficient distributions

In [None]:
def describe_correlation(data, title, has_oraculum):
    '''Render ``describe()`` statistics of correlation coefficients as a table figure.

    data         -- correlation coefficients; when ``has_oraculum`` is False,
                    column level 3 is stacked into the rows before describing
    title        -- figure title
    has_oraculum -- selects the full (True) or reduced (False) condition layout

    Returns (stat, fig): ``stat`` is the transposed ``describe()`` table rounded
    to 4 decimals and sorted by index; the figure shows it reordered by
    ``la.df_epoch``.
    '''
    fig, ax = plt.subplots(1,1,figsize=(12,16))
    fig.suptitle(title,fontsize=16)
    ax.axis('off')
    # honour the argument instead of the notebook-global ``oraculum`` flag
    if has_oraculum:
        stat = np.round(data.describe(),4).T
    else:
        stat = np.round(data.stack(level=3).describe(),4).T
    ordered = la.df_epoch(stat)
    stat = stat.sort_index()
    cw = np.ones((len(ordered.columns),))
    mpl.table.table(ax, cellText=ordered.values,
             rowLabels=[', '.join(x) for x in ordered.index.values],
             colLabels=ordered.columns.values.astype(str),
             loc='upper right', fontsize=20, colWidths=0.6*cw/np.sum(cw),
             bbox=[0.3,0,0.7,1], cellLoc='center')
    return stat, fig

In [None]:
def compare_correlation(stat, title, has_oraculum):
    '''Tabulate normalised differences of mean correlation between learning epochs.

    For each epoch pair (Learning vs Pre, Post vs Pre, Post vs Learning) the
    value is (mean1 - mean2) / ((std1 + std2) / 2), rounded to 4 decimals;
    cells > 0.5 are coloured light red, cells < -0.4 light blue. Epoch pairs
    missing from ``stat`` give an all-NaN column. Rows are (condition, phase)
    for the first ``num_phases`` phases.

    has_oraculum -- use la.legal_* (True) or la.short_* (False) conditions/colours

    Returns (diff, fig).
    '''
    conditions = la.legal_conditions if has_oraculum else la.short_conditions
    colors = la.legal_colors if has_oraculum else la.short_colors
    lmi = pd.DataFrame([], index=la.phases[0:num_phases],
            columns=conditions).unstack().index

    fig, ax = plt.subplots(1,1,figsize=(12,16))
    fig.suptitle(title,fontsize=16)
    ax.axis('off')
    cellcolor = np.vectorize(lambda x: 'lightcoral' if x>0.5 else (
                            'lightblue' if x<-0.4 else 'white'))
    diff = []
    for epoch1, epoch2 in [('Learning','Pre-Learning'),
                           ('Post-Learning','Pre-Learning'),('Post-Learning','Learning')]:
        try:
            d = (stat.loc[epoch1,'mean']-stat.loc[epoch2,'mean'])/(
                 stat.loc[epoch1,'std']+stat.loc[epoch2,'std'])*2
        except KeyError:
            # explicit float dtype: avoids pandas' empty-Series warning and keeps the column numeric
            d = pd.Series([], dtype=float).reindex(index=lmi)
        d = d.to_frame(name='  -  '.join((epoch1,epoch2)).replace('-Learning','-L'))
        diff.append(d)
    diff = np.round(pd.concat(diff,axis=1),4).reindex(lmi)
    cw = np.ones((3,))
    tab = mpl.table.table(ax, cellText=diff.values,
             cellColours=cellcolor(diff.values),
             rowLabels=[', '.join(x) for x in diff.index.values],
             rowColours=np.repeat(colors,num_phases),
             colLabels=diff.columns.values.astype(str),
             loc='upper right', fontsize=32, colWidths=0.6*cw/np.sum(cw),
             bbox=[0.3,0,0.7,1], cellLoc='center')
    tab.set_fontsize(32)
    return diff, fig

In [None]:
def plot_correlation_bars(stat, title, et, has_oraculum):
    """Draw per-condition bar charts of correlation statistics.

    One subplot per (phase, epoch): bars show the mean with std error bars,
    a grey step band shows the 25%-75% range, and tick labels show the
    per-condition counts taken from ``et`` (with the condition name only on
    the bottom row). Panels whose (phase, epoch) data are missing stay empty.

    Parameters
    ----------
    stat : pandas.DataFrame
        After ``reset_index()`` it must contain a 'phase' column, the first 4
        (with oraculum) or 3 (without) ``la.sort_learning`` columns, and
        'mean', 'std', '25%', '75%' columns.
    title : str
        Figure title.
    et : pandas.Series
        Indexed first by epoch (``et.loc[epoch]``); its values are used as
        per-condition counts. Without oraculum the counts are summed per
        ('context', 'port') pair.
    has_oraculum : bool
        True selects ``la.legal_conditions`` / ``la.legal_colors``, False the
        reduced ``la.short_conditions`` / ``la.short_colors``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    conditions = la.legal_conditions if has_oraculum else la.short_conditions
    colors = la.legal_colors if has_oraculum else la.short_colors
    num_rows = num_phases
    num_cols = len(la.epochs)
    cat = len(conditions)
    # We don't use sharex on purpose: we want to set different tick labels in the subplot columns.
    # squeeze=False keeps ``ax`` two-dimensional even for a single phase or epoch.
    fig, ax = plt.subplots(num_rows, num_cols, figsize=(5*num_cols, 5*num_rows),
                           sharey=True, squeeze=False)
    fig.suptitle(title, fontsize=16)
    bars = stat.reset_index().set_index(
        ['phase'] + la.sort_learning[0:(4 if has_oraculum else 3)])
    grid = itertools.product(enumerate(la.phases[:num_phases]), enumerate(la.epochs))
    for (irow, phase), (icol, epoch) in grid:
        try:
            bar = bars.loc[(phase, epoch), :].reindex(conditions)
            if has_oraculum:
                lab = et.loc[epoch].reindex(conditions, fill_value=0)
            else:
                # Sum counts per (context, port) to match the short condition set.
                lab = (et.loc[epoch].groupby(level=['context', 'port']).sum()
                       .reindex(conditions, fill_value=0))
        except KeyError:
            # Best effort: leave the panel empty if this (phase, epoch) is missing.
            continue
        sub = ax[irow, icol]
        sub.set_title(epoch)
        sub.set_ylabel(phase)
        # Interquartile band as a step function: with step='pre' and a leading 0
        # at x=0, the interval (i, i+1] carries the value of condition i.
        low = [0] + bar['25%'].fillna(0).tolist()
        high = [0] + bar['75%'].fillna(0).tolist()
        sub.fill_between(np.arange(0, cat + 1), low, high, alpha=0.1, interpolate=False,
                         color='grey', edgecolor=None, step='pre')
        # align='edge' places bar i on [i, i+1], matching the step band and the
        # +0.5 tick positions (matplotlib >= 2.0 defaults to align='center').
        sub.bar(range(0, cat), bar['mean'], 1, yerr=bar['std'], color=colors, align='edge')
        sub.set_xticks(np.arange(0, cat) + 0.5)
        if irow + 1 == num_rows:
            labels = [('%s: %d' % (', '.join(idx), np.nan_to_num(count)))
                      for idx, count in lab.items()]
        else:
            labels = ['%d' % np.nan_to_num(count) for count in lab.values]
        sub.set_xticklabels(labels, rotation='vertical')
    return fig

### Real value

In [None]:
# Real-valued coefficients: one figure per animal, appended to the open
# multi-page output ``pp`` and then closed.
for animal in animals:
    fig = plot_correlation_bars(data[animal].stat, 'Distribution of the above coefficients\n'
                                'for '+animal+' (diagonals excluded)', et.loc[:,animal], oraculum)
    pp.savefig()
    # Close the figure that was just created and saved.
    plt.close(fig)

### Absolute value

In [None]:
# Absolute-value coefficients: one figure per animal, appended to ``pp``.
# NOTE(review): the figures are not closed here, so with ``%matplotlib inline``
# they presumably also appear in the notebook/HTML output; confirm intent.
for animal in animals:
    plot_correlation_bars(data[animal].astat, 'Distribution of the absolute value of the correlation coefficients\n'
                          'for '+animal+' (diagonals excluded)', et.loc[:,animal], oraculum)
    pp.savefig()

In [None]:
# Condition tuples compared in plot_correlation_comparison; each tuple fills the
# index levels between the leading level and 'phase' of the statistics frames.
list_conditions = [('CS+','W+','A+'),('CS-','W+','A-')]
num_conditions = len(list_conditions)
# Statistic kinds stored per animal (real vs absolute correlation) and their display names.
list_kinds = ['stat', 'astat']
name_kinds = ['real value', 'absolute value']
num_kinds = len(list_kinds)
def plot_correlation_comparison(data, phase, title, has_oraculum):
    """Compare mean +/- std of correlation statistics across animals and epochs.

    One row per condition in ``list_conditions`` and one column per statistic
    kind in ``list_kinds``. Each animal in ``animals`` is drawn as one errorbar
    line over ``la.epochs``, shifted right by a small per-animal offset so the
    lines do not overlap. Animals without data for a condition/phase are skipped.

    Parameters
    ----------
    data : mapping
        ``data[animal][kind]`` is a DataFrame with 'mean' and 'std' columns,
        indexed by (<leading level>, context, port, puffed, phase); the leading
        level is presumably the epoch, since it is reindexed to ``la.epochs``.
    phase : str
        Value of the 'phase' index level to plot.
    title : str
        Figure title.
    has_oraculum : bool
        Currently unused; kept for signature compatibility.

    Returns
    -------
    matplotlib.figure.Figure
    """
    num_rows = len(list_conditions)
    num_cols = len(list_kinds)
    epochs = list(la.epochs)
    # x positions 1..n for the epochs; animal i is shifted right by i*offset.
    xs = np.arange(1, len(epochs) + 1)
    offset = 0.1
    fig, ax = plt.subplots(num_rows, num_cols, figsize=(5*num_cols, 5*num_rows),
                           sharex=True, sharey=True, squeeze=False)
    fig.suptitle(title, fontsize=16)
    for (irow, cond), (icol, kind) in itertools.product(enumerate(list_conditions),
                                                        enumerate(list_kinds)):
        sub = ax[irow, icol]
        sub.set_title(name_kinds[icol])
        sub.set_ylabel(', '.join(cond))
        for iani, animal in enumerate(animals):
            try:
                sel = data[animal][kind].loc[(slice(None),) + cond + (phase,), ['mean', 'std']]
                sel = sel.reset_index(['context', 'port', 'puffed', 'phase'],
                                      drop=True).reindex(la.epochs)
            except KeyError:
                # No data for this animal/condition/phase: skip it.
                continue
            sub.errorbar(xs + offset*iani, sel['mean'], yerr=sel['std'], label=animal)
        sub.set_xlim(xs[0], xs[-1] + offset*len(animals))
        # Centre each tick in the middle of the group of per-animal offsets.
        sub.set_xticks(xs + 0.5*offset*len(animals))
        sub.set_xticklabels(epochs)
    handles, labels = ax[0, 0].get_legend_handles_labels()
    # Pass loc by keyword: positional loc in legend() is deprecated in recent matplotlib.
    fig.legend(handles, labels, loc='lower center', mode='expand',
               ncol=max(1, (len(animals) + 1) // 2))
    return fig

In [None]:
# Cross-animal comparison, one page per trial phase, written to all_compare.pdf.
pp = helpmultipage('all_compare.pdf')
try:
    for phase in ['Ready', 'CS', 'Trace']:
        fig = plot_correlation_comparison(data, phase, 'Phase ' + phase, True)
        pp.savefig()
finally:
    # Always finalise the multi-page file, even if a plot fails part way through.
    pp.close()

## Similarity of correlation matrices