In [None]:
# Load the autoreload extension; the explicit `%autoreload` calls further down re-import edited project modules.
%load_ext autoreload
# Interactive (ipympl) matplotlib backend.
%matplotlib widget

In [None]:
import numpy as np
from pathlib import Path
import pandas as pd
import flammkuchen as fl
import os

from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, Normalize
import seaborn as sns
from itertools import product
from skimage.filters import threshold_otsu
from scipy.cluster.hierarchy import dendrogram, linkage

from scipy.stats import gaussian_kde

# Shared manuscript style sheet, given as a relative file path (resolved against the working directory).
plt.style.use("figures.mplstyle")
# Default seaborn colour palette. NOTE(review): `cols` is not referenced by the visible cells.
cols = sns.color_palette()

In [None]:
fig_fold = Path(r"C:\Users\otprat\Documents\figures\luminance\manuscript_figures\fig5")

# Create the output folder together with any missing parent folders.
# exist_ok makes re-running the cell a no-op and avoids a check-then-create race.
fig_fold.mkdir(parents=True, exist_ok=True)

In [None]:
%autoreload
from luminance_analysis import traces_stim_from_path, PooledData
from luminance_analysis.utilities import deconv_resamp_norm_trace, reliability, \
    nanzscore, get_kernel, pearson_regressors, get_mn_and_error, train_test_split
from luminance_analysis.plotting import shade_plot, make_bar, get_yg_custom_cmap, add_offset_axes
from luminance_analysis.clustering import find_trunc_dendro_clusters

# Load data:

In [None]:
# Root of the experiment data on the lab network share (Windows UNC path).
# The loading cell reads sub-folders as master_path / <protocol> / <brain_region>.
master_path = Path(r"\\FUNES2\legacy\experiments\E0032_luminance\neat_exps")

In [None]:
tau_6f = 5
ker_len = 30
delay = 3
n_clust_list = [4, 5, 4] * 2
normalization = "zscore"

protocol = "flashes"
brain_regions_list = ["GC", "IO", "PC"]

# Stimulus luminance levels that are valid; anything else is an interpolation artefact.
valid_lum_levels = [0, 1, 0.2, 0.05]

# One entry per region, keyed "<region>_<protocol>" (e.g. "GC_flashes").
data_dict = {"{}_{}".format(r, protocol): {} for r in brain_regions_list}

for brain_region, n_clust in zip(brain_regions_list, n_clust_list):
    path = master_path / protocol / brain_region
    stim, traces, _ = traces_stim_from_path(path)

    # First pass: reliability index of every ROI, only used to pick a threshold (Otsu).
    rel_idxs = reliability(traces)
    rel_thr = threshold_otsu(rel_idxs[~np.isnan(rel_idxs)])

    # Second pass: reload keeping only ROIs above the reliability threshold.
    _, traces, meanresps = traces_stim_from_path(path, resp_threshold=rel_thr, nanfraction_thr=1)
    rel_idxs = reliability(traces)

    # Replace interpolated stimulus values (between luminance levels) with the PRECEDING value.
    # Walking forward sequentially means a run of consecutive invalid samples is filled with
    # the last valid level rather than with another invalid one.
    # NOTE: an invalid first sample would copy the last sample (index -1).
    invalid_idxs = ~np.isin(stim[:, 1], valid_lum_levels)
    for i in np.flatnonzero(invalid_idxs):
        stim[i, 1] = stim[i - 1, 1]

    # Ward hierarchical clustering of the mean responses (needed for the sorted plots).
    linked = linkage(meanresps, "ward")

    # Truncated tree to get cluster ids in the correct leaf sequence.
    # no_plot=True computes the same dendrogram dict without drawing a throw-away figure.
    dendro = dendrogram(linked, p=n_clust, truncate_mode="lastp", no_plot=True)
    cluster_ids = dendro["leaves"]
    labels = find_trunc_dendro_clusters(linked, dendro)

    # Store everything for this region:
    key = "{}_{}".format(brain_region, protocol)
    data_dict[key].update(
        raw_traces=traces,
        rel_idxs=rel_idxs,
        mean_traces=meanresps,
        stim=stim,
        clust_labels=labels,
        pooled=PooledData(path),
    )

# Center of mass plot

In [None]:
# Flash timing in seconds: onset of each of the three flashes, and the length of the
# post-onset analysis window that follows it.
step_t_sec_list, post_int_s_list = [2, 12, 26], [3, 7, 21]

In [None]:
def _center_of_mass(means, pooled, t_start, t_stop):
    """Centre of mass, in frames from ``t_start``, of each row of ``means`` within a window.

    Rows are cropped with ``slice(*pooled.frame_from_t([t_start, t_stop]))``, shifted so their
    minimum is 0 and scaled to unit sum; the centre of mass is the weight-averaged frame index.
    Rows that are constant in the window (zero sum) or contain NaN give NaN.
    """
    crop = means[:, slice(*pooled.frame_from_t([t_start, t_stop]))]
    crop = crop - crop.min(axis=1, keepdims=True)
    crop = crop / crop.sum(axis=1, keepdims=True)
    return crop @ np.arange(crop.shape[1])


# Padding (s) around stimuli for plotting. Kept as module-level globals because later
# cells may read them.
pad_pre = 1
pad_post = 2

# The indexes of the "luminance ON" clusters are manually specified according to results in fig. 3:
for cell_type, excited_clusters in zip(["GC", "IO", "PC"], [[1, 2, 3], [1, 3, 4], [2, 3]]):
    entry = data_dict[cell_type + "_flashes"]
    traces = entry["raw_traces"]
    stim = entry["stim"]
    pooled_data = entry["pooled"]
    labels = entry["clust_labels"]

    # Keep only ROIs belonging to the excitatory clusters (clustering as in fig. 3):
    sel_idxs = np.flatnonzero(np.isin(labels, excited_clusters))
    sel_traces = traces[sel_idxs, :, :]
    sel_rel_idxs = entry["rel_idxs"][sel_idxs]

    # Cross-validation: sort using the mean of one half of the repetitions and display the
    # mean of the other half.
    # NOTE(review): if train_test_split picks repetitions at random, seed it for reproducibility.
    order_traces, test_traces = train_test_split(sel_traces)
    order_means, test_means = np.nanmean(order_traces, 2), np.nanmean(test_traces, 2)

    # Analysis window: the longest (third) stimulus.
    stim_i = 2
    step_t, post_int = step_t_sec_list[stim_i], post_int_s_list[stim_i]
    t_stop = step_t + post_int

    # Centre of mass on the sorting half (defines the order) and on the held-out half:
    centerofmass = _center_of_mass(order_means, pooled_data, step_t, t_stop)
    ordered_idx = np.argsort(centerofmass)
    test_centerofmass = _center_of_mass(test_means, pooled_data, step_t, t_stop)

    # Held-out means, smoothed (centred 4-sample rolling mean) and z-scored per ROI, for plotting:
    test_mean_sm = pd.DataFrame(test_means.T).rolling(4, center=True).mean().values.T
    test_mean_sm = ((test_mean_sm.T - np.nanmean(test_mean_sm, 1)) / np.nanstd(test_mean_sm, 1)).T

    # Frame (counted from stimulus onset) of the maximal smoothed held-out response:
    responses = test_mean_sm[:, slice(*pooled_data.frame_from_t([step_t, t_stop]))]
    peak_times = np.argmax(responses, 1)

    entry.update(
        ordered_idx=ordered_idx,
        test_mean_sm=test_mean_sm,
        peak_times=peak_times,
        rel_idxs_sel=sel_rel_idxs,
        sel_idxs=sel_idxs,
        order_coms=centerofmass,
        test_coms=test_centerofmass,
    )

In [None]:
def sorted_traces_plot(data_dict, cell_type="GC", bars=True, figure=None, frame=None, cplot_h=1):
    """Plot centre-of-mass-sorted, cross-validated responses of one cell type to the three flashes.

    For each flash (``step_t_sec_list`` / ``post_int_s_list``) draws a colour matrix of the
    smoothed, z-scored held-out means (rows in ``ordered_idx`` order). Above it, it draws the
    mean trace of ``num_groups`` consecutive groups of sorted rows, coloured from early to late.

    Parameters
    ----------
    data_dict : dict
        Must contain "<cell_type>_flashes" with "ordered_idx", "test_mean_sm", "stim" and "pooled".
    cell_type : str
        Region to plot.
    bars : bool
        If True, add a 4 s time bar on the last matrix and a colour bar for the matrices.
    figure : matplotlib Figure or None
        Target figure; a new 7 x 3 in figure is created if None.
    frame :
        Placement of the panel inside ``figure``, forwarded to ``add_offset_axes``.
    cplot_h : float
        Height scale of the colour matrices.
    """
    l = 2  # colour-map range: matrices span [-l, l]
    pad_pre = 1  # seconds shown before each flash onset
    pad_post = 2  # seconds shown after each flash window
    num_groups = 11  # number of sorted row groups averaged for the traces

    # Horizontal placement of the panels:
    offset_val = 0.04
    offset = 0.1

    entry = data_dict[cell_type + "_flashes"]
    ordered_idx = entry["ordered_idx"]
    test_mean_sm = entry["test_mean_sm"]
    # Use this cell type's own stimulus and frame timing, not module-level variables that
    # were left over from whichever region was processed last.
    stim = entry["stim"]
    pooled_data = entry["pooled"]

    custom_cm = get_yg_custom_cmap(n=100)
    group_cm = custom_cm.reversed()

    if figure is None:
        figure = plt.figure(figsize=(7, 3))

    total_post = sum(post_int_s_list)
    for step_t, post_int in zip(step_t_sec_list, post_int_s_list):  # loop over flashes
        w = 0.7 * (post_int / total_post)  # panel width, proportional to the flash window
        crop = slice(*pooled_data.frame_from_t([step_t - pad_pre, step_t + post_int + pad_post]))

        # Crop stimulus array for the shade plot:
        stim_cropped = stim[crop, :]

        # Colour matrix of the sorted responses:
        ax_mat = add_offset_axes(figure, (offset, 0.2, w, 0.5 * cplot_h), frame=frame)
        to_plot = test_mean_sm[ordered_idx, crop]
        im = ax_mat.imshow(to_plot, aspect="auto", cmap="RdBu_r", vmin=-l, vmax=+l, interpolation='none')

        # Mean traces of consecutive groups of sorted rows. np.array_split keeps every row
        # (including the remainder when the row count is not a multiple of num_groups);
        # empty groups only occur with fewer rows than groups and are skipped.
        ax_traces = add_offset_axes(figure, (offset, 0.21 + 0.5 * cplot_h, w, 0.25), frame=frame)
        for group, group_rows in enumerate(np.array_split(to_plot, num_groups, axis=0)):
            if group_rows.shape[0] == 0:
                continue
            y = np.nanmean(group_rows, 0)
            ax_traces.plot(stim_cropped[:, 0], y - y.min(), c=group_cm(group / num_groups), linewidth=1.5)
        ax_traces.set_xlim(stim_cropped[0, 0], stim_cropped[-1, 0])

        # shade_plot draws on the current axes (presumably ax_traces, as added last).
        shade_plot(stim_cropped)

        ax_mat.axis("off")
        ax_traces.axis("off")

        offset += w + offset_val  # x position of the next panel

    if bars:
        # Time bar, only on the last matrix. The bar is drawn in matrix columns (frames);
        # this assumes stim samples share the imaging frame period.
        dt = stim[1, 0]
        barlength = 4 / dt
        n_frames = to_plot.shape[1]
        ax_mat.axis("on")
        ax_mat.axes.spines["left"].set_visible(False)
        ax_mat.set_yticks([])
        make_bar(ax_mat, [n_frames - barlength, n_frames - 1], label="{} s".format(round(barlength * dt)), lw=2)

        axcolor = add_offset_axes(figure, (0.07, 0.2, 0.015, 0.15), frame=frame)
        cbar1 = figure.colorbar(im, cax=axcolor, orientation="vertical")
        cbar1.set_ticks([-l, l])
        cbar1.ax.tick_params(length=3)
        axcolor.yaxis.set_ticks_position('left')
        axcolor.yaxis.set_label_position('left')
        # NOTE(review): the matrices show z-scored traces; confirm the "dF/F" label is intended.
        axcolor.set_ylabel("dF/F")

    # Colour bar for the row groups (early/late response position):
    sm = plt.cm.ScalarMappable(cmap=custom_cm)
    sm.set_array([])

    bar_cmap = add_offset_axes(figure, (offset - offset_val + 0.003, 0.2, 0.015, 0.5 * cplot_h), frame=frame)
    cbar = figure.colorbar(sm, cax=bar_cmap, orientation="vertical", ticks=[.18, .82])
    cbar.ax.set_yticklabels(["Late", "Early"], rotation=90, va="center")
    cbar.ax.tick_params(length=0)
    cbar.outline.set_visible(False)

In [None]:
# Quick preview of the GC panel (pass cell_type="IO" or "PC" to inspect the other regions).
sorted_traces_plot(data_dict=data_dict, cell_type="GC", cplot_h=0.5)

## KDE of peak times

In [None]:
# def peaks_kde(data_dict, figure=None, frame=None):
#     if figure is None:
#         figure = plt.figure(figsize = (4,2))
#     ax = add_offset_axes(figure, (0.15, 0.25, 0.8, 0.6), frame=frame)
#     off = 0
#     for i, cell_type in enumerate(["GC", "IO"]):
#         ordered_idx = data_dict[cell_type + "_flashes"]["ordered_idx"]
#         test_mean_sm = data_dict[cell_type + "_flashes"]["test_mean_sm"]
#         peak_times = data_dict[cell_type + "_flashes"]["peak_times"]
#         x_arr = np.arange(-40, test_mean_sm.shape[1])
#         kde = gaussian_kde(peak_times)(x_arr)
#         ax.fill_between(x_arr * stim[1, 0], kde/np.max(kde) -np.ones(len(x_arr))*off*i, -i*off, alpha=0.6)
#         ax.text(20, 0.9-i*0.15, cell_type, color=sns.color_palette()[i], fontsize=7)
#     ax.set_xlim(-1, 23)
#     ax.set_xlabel("Time from stim. onset(s)")
#     ax.set_ylabel("Peak distr. (a.u.)")
#     ax.set_yticks([0, 0.4, 0.8])
#     ax.set_ylim(0, 1.1)

In [None]:
def peaks_kde(data_dict, celltypes=("GC", "IO", "PC"), figure=None, frame=None):
    """Plot a max-normalised KDE of the peak-response times of each cell type.

    Parameters
    ----------
    data_dict : dict
        Must contain "<cell_type>_flashes" entries with "peak_times" (frames from flash onset),
        "test_mean_sm" and "stim".
    celltypes : sequence of str
        Cell types to overlay; the i-th one is labelled with the i-th seaborn palette colour.
    figure : matplotlib Figure or None
        Target figure; a new 4 x 2 in figure is created if None.
    frame :
        Placement of the panel inside ``figure``, forwarded to ``add_offset_axes``.
    """
    if figure is None:
        figure = plt.figure(figsize=(4, 2))
    ax = add_offset_axes(figure, (0.15, 0.25, 0.8, 0.6), frame=frame)
    palette = sns.color_palette()

    for i, cell_type in enumerate(celltypes):
        entry = data_dict[cell_type + "_flashes"]
        peak_times = entry["peak_times"]
        # Convert frames to seconds with this cell type's own stimulus time step
        # (not a module-level `stim` left over from another region).
        dt = entry["stim"][1, 0]
        x_arr = np.arange(-40, entry["test_mean_sm"].shape[1])
        kde = gaussian_kde(peak_times)(x_arr)
        kde_norm = kde / kde.max()
        ax.plot(x_arr * dt, kde_norm, linewidth=2, label=cell_type)
        ax.fill_between(x_arr * dt, kde_norm, 0, alpha=0.3)
        ax.text(22, 0.9 - i * 0.15, cell_type, color=palette[i], fontsize=7)

    ax.set_xlim(-1, 23)
    ax.set_xlabel("Time from stim. onset(s)")
    ax.set_ylabel("Peak distr. (a.u.)")
    ax.set_yticks([0, 0.4, 0.8])
    ax.set_ylim(0, 1.1)

In [None]:
# Peak-time distributions for GCs and IONs only.
peaks_kde(data_dict, ["GC", "IO"])

# Individual cell plot

In [None]:
def single_cell_plot(data_dict, idxs, off=2, figure=None, frame=None, cell_type="GC"):
    """Plot every repetition and the mean response of selected ROIs, stacked vertically.

    Each ROI's (time x repetition) trace is smoothed with a centred 4-sample rolling mean and
    z-scored per repetition; ROI number ``n`` in ``idxs`` is shifted down by ``n * off``.

    Parameters
    ----------
    data_dict : dict
        Must contain "<cell_type>_flashes" with "raw_traces" (ROI x time x repetition) and "stim".
    idxs : sequence of int
        ROI indexes into "raw_traces".
    off : float
        Vertical offset between consecutive ROIs.
    figure : matplotlib Figure or None
        Target figure; a new 5 x 2 in figure is created if None.
    frame :
        Placement of the panel inside ``figure``, forwarded to ``add_offset_axes``.
    cell_type : str
        Region whose ROIs are plotted (default "GC").
    """
    if figure is None:
        figure = plt.figure(figsize=(5, 2))
    ax = add_offset_axes(figure, [0.1, 0.2, 0.9, 0.8], frame=frame)

    entry = data_dict[cell_type + "_flashes"]
    # The region's own stimulus, so the time axis matches its traces.
    stim = entry["stim"]

    for n, i in enumerate(idxs):
        cell_trace = pd.DataFrame(entry["raw_traces"][i, :, :]).rolling(4, center=True).mean().values  # smooth
        cell_trace = (cell_trace - np.nanmean(cell_trace, 0)) / np.nanstd(cell_trace, 0)  # z-score each repetition
        color = get_yg_custom_cmap(n=len(idxs)).reversed()(n / len(idxs))

        ax.plot(stim[:, 0], cell_trace - n * off, linewidth=0.3, color=color)
        ax.plot(stim[:, 0], np.nanmean(cell_trace, 1) - n * off, linewidth=2, color=color)

    shade_plot(stim)
    ax.set_xlim(0, stim[-1, 0])
    ax.set_yticks([])
    ax.spines["left"].set_visible(False)
    make_bar(ax, [stim[-1, 0] - 5, stim[-1, 0]], label="5 s")

Find a good example of a ramping cell with high reliability:

In [None]:
rel_idxs, labels = (data_dict["GC_flashes"][field] for field in ("rel_idxs", "clust_labels"))

# Candidate ramping GCs: cluster 3 ROIs with reliability above 0.65.
idxs = np.flatnonzero(np.logical_and(labels == 3, rel_idxs > 0.65))
print(idxs[27])  # hand-picked high-quality ramping example

In [None]:
#single_cell_plot(data_dict, [idxs[27]])

In [None]:
# Preview of the three example GCs used in panel A.
single_cell_plot(data_dict, idxs=[313, 2164, 325], off=5)

### Decoding

In [None]:
# Decoding results live in a sibling "decoding" folder of the notebook's working directory.
decoding_dir = Path.cwd().resolve().parent.joinpath('decoding')

In [None]:
# Load the precomputed duration-decoding results.
duration_decoding_dict = fl.load(decoding_dir.joinpath('duration_decoding.h5'))

In [None]:
def plot_duration_decoding(decoding_dict, figure=None, frame=None):
    """Plot decoded against actual time since stimulus onset for GCs and IONs.

    Draws the mean +/- s.d. (over the first axis) of ``decoding_dict['GC']['full_pred']``, the
    first two GC-subset prediction arrays, the IO predictions and the identity line.
    """
    if figure is None:
        figure = plt.figure(figsize=(4, 4))

    ax = add_offset_axes(figure, (0.05, 0.05, 0.95, 0.95), frame=frame)
    gc_res, io_res = decoding_dict['GC'], decoding_dict['IO']
    palette = sns.color_palette()
    gc_time = gc_res['time']

    mean_pred = np.nanmean(gc_res['full_pred'], 0)
    sd_pred = np.nanstd(gc_res['full_pred'], 0)
    ax.fill_between(gc_time, mean_pred - sd_pred, mean_pred + sd_pred, color=(0.1, 0.1, 0.1, 0.07), linewidth=0)
    ax.plot(gc_time, mean_pred, label="GCs average", c=palette[0])

    # NOTE(review): the GC-subset predictions are drawn against the IO time axis; this assumes
    # the two concatenated GC subsets line up sample-by-sample with io_res['time'].
    gc_subset_preds = np.concatenate([gc_res['preds'][0], gc_res['preds'][1]])
    ax.scatter(io_res['time'], gc_subset_preds, color=palette[0], s=.4, label="GCs subset")
    ax.scatter(io_res['time'], io_res['preds'], color=palette[1], s=.4, label="IONs")

    # Identity line: perfect decoding.
    ax.plot(gc_time, gc_time, color=(0.3, 0.3, 0.3, 0.7))
    ax.set_xlabel("Actual time since onset (s)")
    ax.set_ylabel("Predicted time since onset (s)")
    ax.set_aspect(1)
    ax.legend()


In [None]:
# Standalone preview of the decoding panel.
plot_duration_decoding(decoding_dict=duration_decoding_dict)

## Assemble final figure

In [None]:
figure5 = plt.figure(figsize=(6, 9))

# A: example single GCs.
single_cell_plot(data_dict, [313, 2164, 325], off=4, figure=figure5, frame=(0.12, 0.775, 0.7, 0.2))
# B: sorted GC responses.
sorted_traces_plot(data_dict, cell_type="GC", figure=figure5, frame=(0.05, 0.5, .95, 0.3), bars=False)
# C: sorted IO responses, carrying the time and colour bars.
sorted_traces_plot(data_dict, cell_type="IO", figure=figure5, frame=(0.05, 0.3, .95, 0.35), cplot_h=0.4, bars=True)
# D: peak-time distributions.
peaks_kde(data_dict, celltypes=["GC", "IO"], figure=figure5, frame=(0.05, 0.075, .45, 0.25))
# E: duration decoding.
plot_duration_decoding(duration_decoding_dict, figure=figure5, frame=(0.45, 0.05, .6, 0.3))

# Panel letters, in figure coordinates:
for label_x, label_y, panel_letter in [(0.15, 0.98, 'A'), (0.1, 0.79, 'B'), (0.1, 0.53, 'C'),
                                       (0.025, 0.3, 'D'), (0.475, 0.34, 'E')]:
    figure5.text(label_x, label_y, panel_letter)


In [None]:
if fig_fold is not None:
    figure5.savefig(str(fig_fold.joinpath("sorted.pdf")))

### Supplementary figures

In [None]:
from scipy import stats

In [None]:
fig_fold = Path(r"C:\Users\otprat\Documents\figures\luminance\manuscript_figures\fig5supp")

# Create the supplementary output folder together with any missing parent folders.
# exist_ok makes re-running the cell a no-op and avoids a check-then-create race.
fig_fold.mkdir(parents=True, exist_ok=True)

#### COM cross-validation reliability

In [None]:
def plot_crossval_com_rel(data_dict, vmin, vmax, figure=None, frame=None,
                          brain_regions=('GC_flashes', 'IO_flashes'), pad_pre=1):
    """Scatter cross-validated response centres of mass (COM), coloured by reliability.

    For each region, plots the COM of the held-out ("plotted") half of the repetitions
    ("test_coms", x axis) against the COM of the sorting half ("order_coms", y axis). Both are
    converted from frames to seconds with that region's own frame period. The Pearson r and
    p value are annotated on the panel and printed.

    Parameters
    ----------
    data_dict : dict
        Per region: "order_coms", "test_coms", "rel_idxs_sel" and "pooled" (providing ``dt_im``).
    vmin, vmax : float
        Reliability range of the colour maps.
    figure : matplotlib Figure or None
        Target figure; a new 9 x 3 in figure is created if None.
    frame :
        Placement inside ``figure``, forwarded to ``add_offset_axes``.
    brain_regions : sequence of str
        Keys of ``data_dict`` to plot, one panel each.
    pad_pre : float
        Seconds subtracted from every COM (1, as used elsewhere in the notebook).

    Returns
    -------
    The figure.
    """
    if figure is None:
        figure = plt.figure(figsize=(9, 3))

    colors = sns.color_palette()
    last = len(brain_regions) - 1

    for i, (brain_region, color) in enumerate(zip(brain_regions, colors)):
        entry = data_dict[brain_region]
        ax_scatter = add_offset_axes(figure, (0.075 + .3 * i, 0.15, .25, .8), frame=frame)
        rel_cmap = LinearSegmentedColormap.from_list('rel_map', [[0.9, 0.9, 0.9], color], N=100)

        # Frames -> seconds with THIS region's frame period.
        # NOTE(review): COMs are measured from flash onset, so subtracting pad_pre shifts both
        # axes by a constant; r and p are unaffected, but confirm the intended axis meaning.
        dt = entry['pooled'].dt_im
        order_s = entry['order_coms'] * dt - pad_pre
        test_s = entry['test_coms'] * dt - pad_pre

        # x: held-out (plotted) half, y: sorting half, matching the axis labels.
        points = ax_scatter.scatter(test_s, order_s, c=entry["rel_idxs_sel"], cmap=rel_cmap,
                                    vmin=vmin, vmax=vmax)
        ax_scatter.set_xlabel('COM in plotted reps. [s.]')

        lims = [2, 16]
        ax_scatter.set_ylim([lims[0], lims[1]])
        ax_scatter.set_xlim([lims[0], lims[1]])
        ax_scatter.plot(lims, lims, ls='--', c='black', alpha=.35)

        # Colour bars sit side by side; the label goes on the first, tick values on the last.
        axcolor = add_offset_axes(figure, (.63 + 0.015 * i, 0.4, .01, .3), frame=frame)
        cbar = figure.colorbar(points, cax=axcolor)
        cbar.set_ticks([])

        if i == 0:
            ax_scatter.set_ylabel('COM in sorting reps. [s.]')
            cbar.set_label('Reliability coef.', labelpad=-15)
        if i == last:
            cbar.set_ticks([vmin, vmax])

        r_val, p_val = stats.pearsonr(order_s, test_s)
        ax_scatter.text(.1, .85, 'r = {:.4f} \np = {:.4f}'.format(r_val, p_val), transform=ax_scatter.transAxes,
                        fontsize=6)
        print(r_val, p_val)

    figure.tight_layout()
    return figure

In [None]:
# Standalone supplementary COM-reliability figure (reliability colour range 0 - 0.7).
COM_reliability_plot = plot_crossval_com_rel(data_dict, vmin=0, vmax=.7)

In [None]:
if fig_fold is not None:
    COM_reliability_plot.savefig(str(fig_fold.joinpath("COM_reliability.pdf")))

#### Temporal patterning index (TPI)

In [None]:
from scipy.stats import pearsonr

from luminance_analysis.plotting import TPI_plot
%autoreload

In [None]:
# Regions analysed in the supplementary panels:
brain_regions = ['GC_flashes', 'IO_flashes']

# Padding (s) before/after each flash window when cropping:
pad_pre = 1
pad_post = 2

# Exponential calcium kernel used to convolve the stimulus:
filter_t = np.arange(0, 1, 0.005)
tau = 0.01
ca_kernel = np.exp(-filter_t / tau)

# np.trapz was renamed np.trapezoid in NumPy 2.0 (old name deprecated); use whichever exists.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


def _normalize_responses(resps, n_baseline=5):
    """Baseline-subtract and integral-normalise every (ROI, repetition) trace.

    ``resps`` is indexed (ROI, time, repetition). For each ROI and repetition, the mean of the
    first ``n_baseline`` time samples is subtracted (baseline to 0). The trace is then divided
    by the trapezoidal integral of its absolute value over time.
    NOTE(review): 5 samples correspond to ~2 s only at a 0.4 s frame period — confirm.
    """
    baseline = np.nanmean(resps[:, :n_baseline, :], axis=1, keepdims=True)
    centered = resps - baseline
    integral = _trapezoid(np.abs(centered), axis=1)
    return centered / integral[:, np.newaxis, :]


# Crop stimulus and responses around each flash, normalise, and store per region and flash:
flashes_dict = {}

for brain_region in brain_regions:
    region = data_dict[brain_region]
    frame_from_t = region['pooled'].frame_from_t
    sel_traces = region['raw_traces'][region['sel_idxs']]
    flashes_dict_br = {}

    for flash, (step_t, post_int) in enumerate(zip(step_t_sec_list, post_int_s_list)):
        crop = slice(*frame_from_t([step_t - pad_pre, step_t + post_int + pad_post]))

        # Stimulus, and its convolution with the Ca kernel normalised to unit integral:
        stim_cropped = region['stim'][crop, :]
        conv_stim = np.convolve(stim_cropped[:, 1], ca_kernel)[:stim_cropped.shape[0]]
        conv_stim_norm = conv_stim / _trapezoid(conv_stim)

        # Responses of the selected ROIs, in centre-of-mass order:
        resps = sel_traces[region['ordered_idx'], crop]

        flashes_dict_br[flash] = {
            'stim': stim_cropped,
            'conv_stim': np.column_stack((stim_cropped[:, 0], conv_stim_norm)),
            'resps': _normalize_responses(resps),
        }

    flashes_dict[brain_region] = flashes_dict_br

In [None]:
def _rep_tpis(roi_resps, regressor):
    """TPI (1 - |Pearson r| with ``regressor``) of each NaN-free repetition (column) of ``roi_resps``."""
    valid = roi_resps[:, ~np.isnan(roi_resps).any(axis=0)]
    return [1 - np.abs(pearsonr(valid[:, rep], regressor)[0]) for rep in range(valid.shape[1])]


TPI_dict = {brain_region: {} for brain_region in brain_regions}

for brain_region in brain_regions:
    for flash in range(3):
        flash_data = flashes_dict[brain_region][flash]
        regressor = flash_data['conv_stim'][:, 1]

        # TPI of every repetition of every ROI (iterating the ROI axis of 'resps'):
        flash_tpis = [_rep_tpis(roi_resps, regressor) for roi_resps in flash_data['resps']]

        # ROIs can have different numbers of valid repetitions: store them in a NaN-padded
        # (ROI x repetition) matrix. default=0 keeps this working when a region has no ROIs.
        n_cols = max((len(roi_tpis) for roi_tpis in flash_tpis), default=0)
        tpis_mat = np.full((len(flash_tpis), n_cols), np.nan)
        for roi, roi_tpis in enumerate(flash_tpis):
            tpis_mat[roi, :len(roi_tpis)] = roi_tpis

        TPI_dict[brain_region][flash] = tpis_mat

In [None]:
# Reliability of every ROI, computed separately on each flash crop:
reliability_dict = {
    brain_region: {
        flash: {'reliability': reliability(flashes_dict[brain_region][flash]['resps'])}
        for flash in range(3)
    }
    for brain_region in brain_regions
}

In [None]:
def TPI_plot(TPI_dict, reliability_dict, v_min, v_max, figure=None, frame=None,
             brain_regions=('GC_flashes', 'IO_flashes')):
    """Plot the mean TPI of each ROI, one row per region and one column per flash.

    Points are coloured by the ROI's reliability for that flash. The top row gets "Flash N"
    titles and the bottom row gets the x label. Each row has its own reliability colour bar.
    NOTE: this definition shadows the ``TPI_plot`` imported from ``luminance_analysis.plotting``.

    Parameters
    ----------
    TPI_dict : dict
        ``TPI_dict[region][flash]`` is an (ROI x repetition) TPI matrix, NaN-padded.
    reliability_dict : dict
        ``reliability_dict[region][flash]['reliability']``: one value per ROI.
    v_min, v_max : float
        Reliability range of the colour maps.
    figure : matplotlib Figure or None
        Target figure; a new 6 x 6 in figure is created if None.
    frame :
        Placement inside ``figure``, forwarded to ``add_offset_axes``.
    brain_regions : sequence of str
        Regions to plot, one row each.

    Returns
    -------
    The figure.
    """
    if figure is None:
        figure = plt.figure(figsize=(6, 6))

    colors = sns.color_palette()
    last = len(brain_regions) - 1

    for i, brain_region in enumerate(brain_regions):
        rel_cmap = LinearSegmentedColormap.from_list('rel_map', [[0.9, 0.9, 0.9], colors[i]], N=100)

        for flash in range(3):
            ax_hist = add_offset_axes(figure, (0.075 + .275 * flash, 0.7 - .3 * i, .25, .25), frame=frame)

            tpis = TPI_dict[brain_region][flash]
            n_rois = tpis.shape[0]
            mean_tpi = np.nanmean(tpis, 1)
            points = ax_hist.scatter(range(n_rois), mean_tpi, c=reliability_dict[brain_region][flash]['reliability'],
                                     cmap=rel_cmap, vmin=v_min, vmax=v_max)

            ax_hist.set_xlim(0, n_rois)
            ax_hist.set_ylim(0, 1)

            if i == 0:
                ax_hist.text(.5, 1.05, 'Flash {}'.format(flash + 1), ha='center', va='top',
                             transform=ax_hist.transAxes, fontsize=8.5)
            if i == last:
                ax_hist.set_xlabel('ROI (sorted by C.O.M. time)')

            if flash == 0:
                ax_hist.set_ylabel('TPI')
            else:
                ax_hist.set_yticklabels([])

        # Colour bar for this region's row:
        axcolor = add_offset_axes(figure, (.925, 0.75 - .3 * i, .015, .15), frame=frame)
        cbar = figure.colorbar(points, cax=axcolor)
        cbar.set_ticks([v_min, v_max])
        cbar.set_label('Reliability coef.', labelpad=-35)

    return figure

In [None]:
# Standalone TPI panel (reliability colour range 0 - 0.7).
tpi_plot = TPI_plot(TPI_dict, reliability_dict, v_min=0, v_max=.7)

In [None]:
if fig_fold is not None:
    tpi_plot.savefig(str(fig_fold.joinpath("TPI_panel.pdf")))

In [None]:
# Display the IO decoding R^2 (drawn as the vertical reference line by plot_rsquared).
duration_decoding_dict['IO']['rsquared']

In [None]:
def plot_rsquared(decoding_dict, figure=None, frame=None):
    """Histogram of the GC-subset decoding R^2 values, with the IO R^2 as a vertical line."""
    if figure is None:
        figure = plt.figure(figsize=(4, 4))

    ax = add_offset_axes(figure, (0.1, 0.1, 0.8, 0.8), frame=frame)
    gc_color, io_color = sns.color_palette()[:2]

    ax.hist(decoding_dict['GC']['rsquared'], bins=20, label="GCs subset", color=gc_color)
    ax.axvline(decoding_dict['IO']['rsquared'], color=io_color, label="IONs")
    ax.set_xlabel("$R^2$")
    ax.set_ylabel('Counts')
    ax.legend()

In [None]:
# Standalone preview of the R^2 panel.
plot_rsquared(decoding_dict=duration_decoding_dict)

### Assemble supplementary figure

In [None]:
figureS5 = plt.figure(figsize=(9, 3))

# Panel A: cross-validated COM, coloured by reliability.
COM_rel_panel = plot_crossval_com_rel(data_dict, 0, .7, figure=figureS5, frame=(0.05, 0.05, .9, .9))
figureS5.text(0.05, 0.9, 'A')

# Panel B: decoding R^2 distribution. plot_rsquared returns nothing, so its result is not
# assigned (previously this overwrote COM_rel_panel with None).
plot_rsquared(duration_decoding_dict, figure=figureS5, frame=(0.685, 0.05, 0.3, .9))
figureS5.text(0.65, 0.9, 'B')

In [None]:
if fig_fold is not None:
    figureS5.savefig(str(fig_fold.joinpath("temporal_patterning_supplementary.pdf")))