In [1]:
# Interactive matplotlib figures (ipympl "widget" backend).
%matplotlib widget

In [2]:
import flammkuchen as fl
from matplotlib import pyplot as plt
import numpy as np
from scipy import io
import seaborn as sns
import matplotlib.gridspec as gridspec
from matplotlib.ticker import FormatStrFormatter
from pathlib import Path
import pandas as pd
from luminance_analysis import PooledData, traces_stim_from_path
import os

# Matplotlib style sheet, resolved relative to the current working directory.
plt.style.use("figures.mplstyle")
# Default seaborn colour palette (reg_panel_plot creates its own local copy).
cols = sns.color_palette()


In [3]:
# Output folder for the exported figures; set fig_fold = None to skip saving figures.
# NOTE: machine-specific absolute path — adjust it before running on another computer.
# fig_fold = None # Path(r"J:\_Shared\GC_IO_luminance\figures\fig7\src")

fig_fold = Path(r"C:\Users\otprat\Documents\figures\luminance\manuscript_figures\fig7_v2")
if fig_fold is not None:
    # Create missing parent folders as well, and tolerate the folder already existing (re-runs).
    fig_fold.mkdir(parents=True, exist_ok=True)

# Load data

In [4]:
# Root folder of the pooled experiment data (network share). The commented lines are
# alternative locations of the same data on other machines.
master_path = Path(r"\\FUNES\Shared\experiments\E0032_luminance\neat_exps")
# master_path = Path(r"J:\_Shared\GC_IO_luminance\data\neat_exps")
# master_path = Path(r"/Users/luigipetrucco/Desktop/data_dictionaries/")

In [5]:
from luminance_analysis.utilities import deconv_resamp_norm_trace, reliability, nanzscore, get_kernel
from skimage.filters import threshold_otsu
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree, to_tree, set_link_color_palette
from luminance_analysis.plotting import plot_clusters_dendro, shade_plot, add_offset_axes, make_bar
from luminance_analysis.clustering import cluster_id_search, find_trunc_dendro_clusters

In [6]:
# Deconvolution / clustering parameters.
tau_6f = 5  # kernel decay constant used for GC and IO (see tau_list)
tau_6s = 8  # kernel decay constant used for PC, and for re-convolving the regressors later
ker_len = 20  # kernel length passed to deconv_resamp_norm_trace
normalization = "zscore"
protocol = 'steps'

# Per-region settings; each list is aligned with brain_regions_list.
brain_regions_list = ["GC", "IO", "PC"]
tau_list = [tau_6f, tau_6f, tau_6s]
n_cluster_list = [8, 6, 8]  # number of clusters at which each region's dendrogram is cut
nan_thr_list = [0, 1, 1]  # passed as nanfraction_thr to traces_stim_from_path

data_dict = {k:{} for k in brain_regions_list}

# Load the stimulus of the GCs and use it as the reference time base / stimulus array
# onto which all regions are resampled:
stim_ref = PooledData(path = master_path / protocol / "GC").stimarray_rep

# For every region: load, filter ROIs by reliability, cluster them hierarchically,
# deconvolve/resample the mean responses and average them per cluster.
# NOTE(review): `stim` and `traces` leak out of this loop with the values of the last
# region ("PC"); later cells rely on that.
for brain_region, tau, n_cluster, nan_thr in zip(brain_regions_list, tau_list, 
                                                 n_cluster_list, nan_thr_list):
    #Load data :
    path = master_path / protocol / brain_region
    stim, traces, meanresps = traces_stim_from_path(path)

    # Mean traces, calculate reliability index (one value per ROI) :
    rel_idxs = reliability(traces)
    
    # Find threshold from reliability histogram (Otsu threshold on the non-NaN indexes)...
    rel_thr = threshold_otsu(rel_idxs[~np.isnan(rel_idxs)])

    # ...and load again filtering with the threshold:
    _, traces, meanresps = traces_stim_from_path(path, resp_threshold=rel_thr, nanfraction_thr=nan_thr)

    # Hierarchical clustering:
    linked = linkage(meanresps, 'ward')
    
    # Truncate dendrogram at n_cluster level. dendrogram() draws by default, hence the
    # throw-away tiny figure (NOTE(review): dendrogram(..., no_plot=True) would avoid it):
    plt.figure(figsize=(0.1, 0.1))  
    dendro = dendrogram(linked, n_cluster, truncate_mode ="lastp")
    plt.close()
    cluster_ids = dendro["leaves"]  # NOTE(review): not used below
    labels = find_trunc_dendro_clusters(linked, dendro) 
    
    # Deconvolution, resampling / normalization onto the GC time base (stim_ref).
    # The second call passes tau=None — presumably resampling/normalization without
    # deconvolution; its result is not stored (storage line commented out below).
    deconv_meanresps = np.empty((meanresps.shape[0], stim_ref.shape[0]))
    resamp_meanresps = np.empty((meanresps.shape[0], stim_ref.shape[0]))
    for roi_i in range(deconv_meanresps.shape[0]):
        deconv_meanresps[roi_i, :] = deconv_resamp_norm_trace(meanresps[roi_i, :], stim[:, 0],
                                                                stim_ref[:, 0], tau, ker_len,
                                                                smooth_wnd=4,
                                                                normalization=normalization)
        resamp_meanresps[roi_i, :] = deconv_resamp_norm_trace(meanresps[roi_i, :], stim[:, 0],
                                                                stim_ref[:, 0], None, ker_len,
                                                                smooth_wnd=4,
                                                                normalization=normalization)
    
    # One row per cluster: the z-scored average of the deconvolved responses of its ROIs.
    cluster_resps = np.empty((n_cluster, stim_ref.shape[0]))
    for clust_i in range(n_cluster):
        cluster_resp = np.nanmean(deconv_meanresps[labels==clust_i, :], 0)  # average cluster responses
        cluster_resps[clust_i, :] = nanzscore(cluster_resp)  # normalize


    # Add everything to dictionary:
    data_dict[brain_region]["linkage_mat"] = linked
    data_dict[brain_region]["clust_labels"] = labels
    #data_dict[brain_region]["raw_mn_resps"] = meanresps
    data_dict[brain_region]["traces"] = traces
    #data_dict[brain_region]["deconv_mn_resps"] = deconv_meanresps
    #data_dict[brain_region]["resamp_mn_resps"] = resamp_meanresps
    # NOTE(review): assumes traces_stim_from_path kept exactly the ROIs with rel_idx > rel_thr,
    # in the same order; the nanfraction_thr filter could drop more — TODO confirm alignment.
    data_dict[brain_region]["rel_idxs"] = rel_idxs[rel_idxs > rel_thr]
    data_dict[brain_region]["rel_thr"] = rel_thr
    data_dict[brain_region]["clust_resps"] = cluster_resps

[<luminance_analysis.FishData object at 0x0000017158A15C48>, <luminance_analysis.FishData object at 0x0000017158A15CC8>, <luminance_analysis.FishData object at 0x0000017158A26688>, <luminance_analysis.FishData object at 0x0000017158A2AE48>, <luminance_analysis.FishData object at 0x0000017158A33648>]
[<luminance_analysis.FishData object at 0x0000017158A1AC48>, <luminance_analysis.FishData object at 0x0000017158A1ABC8>, <luminance_analysis.FishData object at 0x0000017158A26708>, <luminance_analysis.FishData object at 0x0000017158A2A8C8>, <luminance_analysis.FishData object at 0x0000017158A33648>]


  c /= stddev[:, None]
  c /= stddev[None, :]


[<luminance_analysis.FishData object at 0x0000017100C9CF88>, <luminance_analysis.FishData object at 0x0000017100CA8048>, <luminance_analysis.FishData object at 0x0000017100CB5608>, <luminance_analysis.FishData object at 0x0000017100CC3BC8>, <luminance_analysis.FishData object at 0x0000017100CAE1C8>]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<luminance_analysis.FishData object at 0x0000017100E69988>, <luminance_analysis.FishData object at 0x0000017100E69A08>, <luminance_analysis.FishData object at 0x0000017100E6E088>, <luminance_analysis.FishData object at 0x0000017100E728C8>, <luminance_analysis.FishData object at 0x0000017100E77148>]
[<luminance_analysis.FishData object at 0x0000017100E6ABC8>, <luminance_analysis.FishData object at 0x0000017100E6AB48>, <luminance_analysis.FishData object at 0x0000017100E6E5C8>, <luminance_analysis.FishData object at 0x0000017100CC8808>, <luminance_analysis.FishData object at 0x0000017100E7F3C8>]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<luminance_analysis.FishData object at 0x00000171008BBEC8>, <luminance_analysis.FishData object at 0x00000171008BD208>, <luminance_analysis.FishData object at 0x00000171008C2388>, <luminance_analysis.FishData object at 0x00000171008C6508>, <luminance_analysis.FishData object at 0x00000171008CB848>]
[<luminance_analysis.FishData object at 0x00000171008BDF08>, <luminance_analysis.FishData object at 0x00000171008BDD88>, <luminance_analysis.FishData object at 0x00000171008C6348>, <luminance_analysis.FishData object at 0x00000171008D7088>, <luminance_analysis.FishData object at 0x00000171008CB588>]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Modelling

### Description of the model


The goal of the model is to reconstruct the activity of PCs based on the activity observed for GCs and IONs.
The function that we will use to model a PC is the following:
$$ trace_{PC}^i = o^i + clusters_{GC} * w^i_{GC} + clusters_{IO} * w^i_{IO}$$


 - $trace_{PC}^i$ is the $i^{th}$ PC cell trace;
 - $o^i$ is an offset term;
 - $clusters_{GC}$, $clusters_{IO}$ are matrices with the average activation of all GC & IO clusters;
 - $w^i_{GC}$, $w^i_{IO}$ are weight vectors, with one weight for each GC and IO cluster. We will allow positive and negative $w^i_{GC}$, but only positive $w^i_{IO}$. This is a fairly safe assumption considering the known PC physiology. 


### Approach

Here is a summary of the modelling approach:
- **Create a panel of regressors**:
    - Calculate clusters of GC and IO responses; 
    - Deconvolve average response of each cluster with 6fe05 kernel;
    - Reconvolve it with 6s kernel;
    - Shift it to be non-negative (minimum at 0) and normalize it to have integral = 1.


- **Clean up PC traces**:
    - For each cell, take the raw fluorescence of the valid trials and z-score it trial by trial (to correct for changes in baseline F across planes);
    - concatenate repetitions;
    - high-pass filter them with very low cutoff freq (1/80 Hz) to remove slow fluctuations;
    - smooth them with a 3 pts mean boxcar rolling window;
    
    
- **Split fit and test data**:
    - Randomly pick from each cell:
        - `n_test_reps` repetitions (currently 3) that will be left out for the analysis (**test traces**);
        - the remaining repetitions (more if the cell spans more planes) that will be used for finding the regularization term and for the actual fitting (**fit traces**);


- **Define boundaries, cost function and regularization function**:
    - Parameters to be optimized are:
        - *offset*: a constant term, bound to be between -5 and 5 (`off_bound`);
        - *coefs_GC*: coefficients for GC regressors, bound to be between -`w_bound` and `w_bound` (100 in the code below; the large range comes from the different normalizations applied to the regressors - unit integral - and to the traces - z-scoring);
        - *coefs_IO*: coefficients for IO regressors, bound to be between 0 and `w_bound`, as we have good reasons to postulate that IO contributions are strictly positive;
        - cost function: L2 distance to target trace;
        - regularization function: L1 (sum of absolute value of parameters)
        
        
- **Find regularization parameter**:
    - Find regularization parameter (leave one out cross-validation):
    - For each lambda parameter, train the model on all the fit traces but one (**train traces**);
    - Then, calculate the cost of the fit on the one trace that was left out (**validation trace**);
    - Do this iterating over all possible combinations of n-1 and 1 traces;
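    - Average the validation cost over all combinations; then, for each cell, move from the lambda with the lowest mean cost towards stronger regularization as long as the mean cost stays within one standard error of that minimum (a cell is discarded if the selected cost is not below 1, the cost of an all-zero fit on a z-scored trace);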


- **Fit the trace**:
    - Use the resulting regularization term for fitting the fit repetitions, and use the obtained coefficients for plots / further analyses on the test repetitions

## Create a panel of regressors:

Create the regressor panel, making traces non-negative and with integral equal to one:

In [7]:
# Number of clusters (= regressors) contributed by each region.
n_gc_clust = data_dict["GC"]["clust_resps"].shape[0]
n_io_clust = data_dict["IO"]["clust_resps"].shape[0]

# Regressor panel: GC cluster responses in the first n_gc_clust rows, IO responses after them.
regressors_mat = np.concatenate([data_dict["GC"]["clust_resps"], data_dict["IO"]["clust_resps"]])

# Re-convolve each (deconvolved) cluster response with the tau_6s kernel and truncate it to the
# original length; then shift it to a minimum of 0 and scale it to unit sum.
# Iterating over the 2D array yields row views, so the rows are rewritten in place.
n_reg_timepts = regressors_mat.shape[1]
reconv_kernel = get_kernel(ker_len=100, tau=tau_6s)
for reg_row in regressors_mat:
    reconvolved = np.convolve(reg_row, reconv_kernel)[:n_reg_timepts]
    reconvolved = reconvolved - reconvolved.min()  # non-negative, minimum at 0
    reg_row[:] = reconvolved / reconvolved.sum()  # integral == 1

# Arbitrary cluster names (presumably in cluster-index order — TODO confirm):
gc_cluster_names = ['ON 1', 'ON 2', 'ON abs', 'ON inter.1', 'ON inter.2', 'OFF 1', 'OFF 2', 'OFF inter.']
io_cluster_names = ['Onset', 'ON max', 'Offset 1', 'Offset 2', 'On inter.', 'Offset 3']

Plot the regressors panel:

In [8]:
# Plot the regressor panel:
def reg_panel_plot(regressors_mat, figure=None, ax=None, frame=None):
    """Draw every regressor as a filled trace, stacked from top to bottom.

    The first n_gc_clust rows (GC) use palette colour 0 and the next n_io_clust rows (IO)
    use palette colour 1. Each row is shifted down by 0.01 relative to the previous one.
    The time axis is the first column of the global ``stim`` array.

    Args:
        regressors_mat: regressor matrix, one regressor per row.
        figure: figure to draw into; a new 3x3 figure is created if None.
        ax: axes to draw into; created with add_offset_axes if None.
        frame: forwarded to add_offset_axes.
    """
    if figure is None:
        figure = plt.figure(figsize=(3, 3))
    if ax is None:
        ax = add_offset_axes(figure, (0., 0., 1, 1), frameon=False, frame=frame)

    palette = sns.color_palette()
    row_colors = [palette[0]] * n_gc_clust + [palette[1]] * n_io_clust
    time = stim[:, 0]
    row_step = 0.01
    for row, row_color in enumerate(row_colors):
        shift = row * row_step
        baseline = np.zeros(time.shape) - shift
        ax.fill_between(time, baseline, regressors_mat[row, :] - shift, color=row_color)

    ax.axis("off")

In [9]:
# Show the regressor panel (GC rows in colour 0, IO rows in colour 1).
reg_panel_plot(regressors_mat)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Perform PCA on regressors

In [10]:
from sklearn.decomposition import PCA

In [11]:
#Perform PCA
# Fit a PCA with as many components as there are regressors (the maximum possible here),
# to decide how many components are worth keeping.
n_max_components = regressors_mat.shape[0]
pca = PCA(n_components=n_max_components)
pca.fit(regressors_mat)

#Plot the cumulative explained variance as a function of the number of retained PCs.
n_retained = np.arange(1, n_max_components + 1)
expl_var = np.cumsum(pca.explained_variance_ratio_)
fig = plt.figure(figsize=(3,3))
plt.plot(n_retained, expl_var)
plt.xlabel('Number of principal components')
plt.ylabel('Cumulative explained variance')
plt.grid()
plt.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [12]:
#Define number of principal components based on the explained variance per PC above
n_components = 6
# Re-fit keeping only the leading n_components. regressors_pca holds the coordinates of every
# regressor (rows: GC regressors first, then IO) in that reduced space.
pca=PCA(n_components=n_components)
regressors_pca=pca.fit_transform(regressors_mat)

In [13]:
# Time course of each principal component, obtained by inverse-transforming a unit vector
# along that component.
# NOTE(review): sklearn's PCA.inverse_transform is documented to add back pca.mean_; if so,
# every row is (mean regressor + component k) rather than the bare pca.components_[k].
# pcs_mat is reused in reg_pc_corr and in the contribution maps — confirm this is intended.
pcs_mat = np.full((n_components, regressors_mat.shape[1]), np.nan)

for pc in range(n_components):
    zeros = np.zeros(n_components)
    zeros[pc] = 1
    pcs_mat[pc, :] = pca.inverse_transform(zeros)

In [14]:
# Plot the PC time courses, each shifted down by 0.25 so they do not overlap.
pcs_fig, pcs_ax = plt.subplots()
pc_color = sns.color_palette()[2]
for pc_i, pc_trace in enumerate(pcs_mat):
    pcs_ax.plot(pc_trace - .25 * pc_i, c=pc_color)

pcs_ax.set_yticks([])
pcs_ax.set_ylabel('PCs')
pcs_ax.set_xlabel('Timepoints')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 0, 'Timepoints')

In [15]:
# Export the PC figure only when an output folder is configured.
if fig_fold is not None:
    pcs_fig.savefig(fig_fold / "principal_components.png")

In [16]:
#Plot PC coefficients for regressors
# Row indexes of the GC and IO regressors in regressors_mat / regressors_pca (GC rows first).
gc_regs = np.arange(n_gc_clust)
io_regs = np.arange(n_io_clust) + n_gc_clust

# Lower-triangular grid of pairwise PC planes. Every regressor is drawn as a segment from the
# origin to its (PC col, PC row) coordinates: GC in palette colour 0, IO in colour 1.
reg_pcs_fig, axes = plt.subplots(n_components, n_components, figsize=(7, 7), sharex=True, sharey=True)
reg_palette = sns.color_palette()
reg_groups = ((gc_regs, reg_palette[0]), (io_regs, reg_palette[1]))

for row in range(n_components):
    for col in range(n_components):
        pair_ax = axes[row, col]
        if col >= row:
            pair_ax.axis('off')  # diagonal and upper triangle stay empty
            continue
        for reg_rows, reg_color in reg_groups:
            for reg in reg_rows:
                pair_ax.plot([0, regressors_pca[reg, col]], [0, regressors_pca[reg, row]], c=reg_color)
        pair_ax.set_xticks([])
        pair_ax.set_yticks([])
        pair_ax.axvline(0, ls='--', c='white', alpha=.2)
        pair_ax.axhline(0, ls='--', c='white', alpha=.2)

# Label the outer column / row of the grid.
for k in range(n_components):
    axes[k, 0].set_ylabel('PC{}'.format(k + 1))
    axes[-1, k].set_xlabel('PC{}'.format(k + 1))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
# Export the regressor-PC grid only when an output folder is configured.
if fig_fold is not None:
    reg_pcs_fig.savefig(fig_fold / "regressor_PCs.png")

## Clean up PC traces:

In [18]:
from luminance_analysis.utilities import smooth_traces, butter_highpass_filter

In [19]:
def filter_cell_rep_block(cellmat, cutoff=1/80, smooth_wnd=3, dt=None):
    """ Filter traces from the raw traces block.
    Return a repetitions block containing only the valid repetitions.

    Repetitions are z-scored individually, concatenated in time, high-pass filtered,
    smoothed, split back into repetitions and z-scored again.

    Parameters
    ----------
    cellmat : 2D array, (timepoints, repetitions)
        Raw fluorescence of one cell; repetitions that are entirely NaN are dropped.
    cutoff : float
        Cutoff frequency (Hz) of the high-pass filter.
    smooth_wnd : int
        Window (in samples) of the mean smoothing applied after filtering.
    dt : float, optional
        Sampling interval (s). Defaults to the spacing of the time column of the
        global ``stim`` array.

    Returns
    -------
    2D array, (timepoints, n_valid_repetitions)
        Cleaned repetitions; empty along axis 1 if the cell has no valid repetition.
    """
    if dt is None:
        dt = stim[1, 0] - stim[0, 0]

    # Select entries with valid numbers in the repetition matrix (drop all-NaN repetitions):
    cellmat = cellmat[:, ~np.isnan(cellmat).all(0)].copy()
    if cellmat.shape[1] == 0:
        return cellmat  # nothing to filter; np.concatenate would fail on an empty sequence

    # zscore repetition-wise, important for ROIs spanning more than one plane
    cellmat = (cellmat - np.nanmean(cellmat, 0)) / np.nanstd(cellmat, 0)

    # concatenate repetitions in time, highpass filter with very low cutoff, and smooth:
    trace = np.concatenate(cellmat.T, 0)
    filtered = butter_highpass_filter(trace, cutoff, 1 / dt)  # filt trace
    filtered = smooth_traces(filtered[np.newaxis, :], win=smooth_wnd, method="mean")[0, :]  # smooth
    filtered[np.isnan(filtered)] = 0

    # reshape in original (timepoints, repetitions) form, and zscore again after filtering and smoothing:
    reshaped = filtered.reshape(cellmat.T.shape).T
    reshaped = (reshaped - np.nanmean(reshaped, 0)) / np.nanstd(reshaped, 0)

    return reshaped

In [20]:
# Cleanup parameters:
cutoff_hz = 1 / 80  # long cutoff for highpass filter - remove long fluctuations in PC signal
smooth_wnd = 3  # smoothing window to reduce noise. Our data is sampled at around 0.25 seconds

raw_traces = data_dict["PC"]["traces"]  # raw PC fluorescences, (ROI, timepoint, repetition)
rel_idxs = data_dict["PC"]["rel_idxs"]  # reliability indexes for each PC cell

n_rois = raw_traces.shape[0]  # number of ROIs
n_rep_timepts = raw_traces.shape[1]  # timepoints per repetition
n_reps_max = raw_traces.shape[2]  # maximum number of repetitions in a cell

# Find number of valid (not entirely NaN) repetitions for each cell:
n_valid_reps = (~np.isnan(raw_traces).all(1)).sum(1)

# Clean up. The valid repetitions of each cell are stored contiguously from column 0, with
# NaN padding after them.
# NOTE(review): filter_cell_rep_block drops all-NaN columns wherever they are, so if NaN
# repetitions are not at the end in raw_traces, repetition column positions shift here.
clean_traces = np.full(raw_traces.shape, np.nan)
for i_roi in range(n_rois):
    clean_block = filter_cell_rep_block(raw_traces[i_roi, :, :], cutoff=cutoff_hz, smooth_wnd=smooth_wnd)
    clean_traces[i_roi, :, :n_valid_reps[i_roi]] = clean_block

In [21]:
# Plot a trace panel:
def little_trace_plot(clean_traces, i=0, figure=None, ax=None, frame=None):
    """Plot repetition column 0 of ROI ``i`` as a small, axis-less trace panel.

    Args:
        clean_traces: array (ROI, timepoint, repetition) of cleaned traces.
        i: index of the ROI to plot.
        figure: figure to draw into; a new 3x1 figure is created if None.
        ax: axes to draw into; created with add_offset_axes if None. The trace is drawn
            on this axes (not on whatever pyplot considers the current axes).
        frame: forwarded to add_offset_axes.
    """
    if figure is None:
        figure = plt.figure(figsize=(3, 1))

    if ax is None:
        ax = add_offset_axes(figure, (0., 0., 1, 1), frameon=False, frame=frame)
    ax.plot(clean_traces[i, :, :1].T.flatten())
    ax.axis("off")

In [22]:
# Example cleaned trace (ROI 18).
little_trace_plot(clean_traces, i=18)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [23]:
#Transform Purkinje cell traces to the regressor PC space
# traces_pca has shape (ROI, PC, repetition). Repetitions that do not exist for a cell are
# NaN-padded in clean_traces; they are skipped explicitly (PCA.transform rejects non-finite
# input) and remain NaN here. Other errors from the transform are no longer swallowed.
n_pc_rois, n_pc_timepts, n_pc_reps = clean_traces.shape

# One row per (ROI, repetition) pair: (ROI * repetition, timepoint).
rep_rows = clean_traces.transpose(0, 2, 1).reshape(-1, n_pc_timepts)
finite_rows = np.isfinite(rep_rows).all(axis=1)

rep_rows_pca = np.full((rep_rows.shape[0], n_components), np.nan)
if finite_rows.any():
    rep_rows_pca[finite_rows] = pca.transform(rep_rows[finite_rows])

traces_pca = rep_rows_pca.reshape(n_pc_rois, n_pc_reps, n_components).transpose(0, 2, 1)

traces_pca.shape

(672, 6, 48)

### PCA analysis

First we will see how much GC and IO clusters contribute to individual Purkinje cell responses without the need to fit anything.

To do so, we will start by calculating the average response for each Purkinje cell trace, normalizing it in the same way we normalized the GC and IO regressors, and projecting the traces into the same principal-component space.

In [24]:
# Average Purkinje cell traces over repetitions (NaN-padded repetitions are ignored).
mean_traces = np.nanmean(clean_traces, 2)

# Normalize like the regressors: shift each trace to a minimum of 0 and scale it to unit sum.
# Iterating over the 2D array yields row views, so every row is normalized in place.
for roi_trace in mean_traces:
    roi_trace -= np.min(roi_trace)
    roi_trace /= np.sum(roi_trace)

# Project the normalized traces into the regressor PC space.
mean_traces_pca = pca.transform(mean_traces)

In [25]:
def calculate_pc_similarity(reg_pcs, traces_pcs):
    """Similarity of each trace to a group of regressors, measured in PC space.

    For every PC, the regressor group is summarized by its mean and by its std relative
    to |mean|. A trace's distance is the sum over PCs of |trace - group mean|, weighted by
    the inverse of that relative std, so consistent PCs weigh more. The similarity is the
    inverse of that distance.

    Args:
        reg_pcs: (n_regressors, n_pcs) PC coordinates of the regressor group.
        traces_pcs: (n_traces, n_pcs) PC coordinates of the traces.

    Returns:
        (n_traces,) array of similarity values (larger = closer to the group).
    """
    centre = np.nanmean(reg_pcs, 0)
    spread = np.nanstd(reg_pcs, 0)
    relative_spread = spread / np.abs(centre)
    pc_weights = 1 / relative_spread

    weighted_distance = np.abs(traces_pcs - centre) * pc_weights
    like_idx = weighted_distance.sum(axis=1)
    return 1 / like_idx

In [26]:
# Similarity of every mean PC trace to the GC regressor group and to the IO regressor group.
gc_like_index = calculate_pc_similarity(regressors_pca[gc_regs], mean_traces_pca)
io_like_index = calculate_pc_similarity(regressors_pca[io_regs], mean_traces_pca)

In [27]:
# Normalized contrast: > 0 when a cell is more similar to GC regressors, < 0 when more similar
# to IO regressors (within [-1, 1] for positive similarity values).
gc_io_idx = (gc_like_index-io_like_index)/(gc_like_index+io_like_index)

In [28]:
# Distribution of the GC/IO index across Purkinje cells (0 = equally similar to both groups).
idx_distr_fig, idx_distr_ax = plt.subplots()
idx_distr_ax.hist(gc_io_idx, bins=50, color=sns.color_palette()[2])
idx_distr_ax.axvline(0, ls='--', c='black')
idx_distr_ax.set_xlabel('GC/IO index')
idx_distr_ax.set_ylabel('Counts');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0, 0.5, 'Counts')

In [29]:
# Export the GC/IO index histogram only when an output folder is configured.
if fig_fold is not None:
    idx_distr_fig.savefig(fig_fold / "gc_io_idx_hist.png")

In [30]:
#Create a IO-GC colormap
# Three-stop map: low end in the IO colour (palette[1]), middle grey, high end in the GC colour
# (palette[0]), interpolated over n_bins levels.
from matplotlib.colors import LinearSegmentedColormap
colors = [sns.color_palette()[1], (.72,)*3, sns.color_palette()[0]]
n_bins = 200
cmap_name = 'contributions_cmap'
custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=n_bins)

In [31]:
# Reliability index of each PC ROI (the same array as `rel_idxs` defined in the cleanup cell).
rel_idx = data_dict["PC"]["rel_idxs"]

In [32]:
# GC/IO index vs. reliability, each point coloured by its GC/IO index on a fixed [-1, 1] scale.
gc_io_idx_fig = plt.figure()
plt.scatter(gc_io_idx, rel_idx, c=gc_io_idx, cmap=custom_cmap, vmin=-1, vmax=1)
plt.axvline(0, c='black', zorder=-100, ls='--')

plt.ylabel('Reliability index')
plt.xlabel('GC/IO index')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 0, 'GC/IO index')

In [33]:
# Export the index-vs-reliability scatter only when an output folder is configured.
if fig_fold is not None:
    gc_io_idx_fig.savefig(fig_fold / "gc_io_idx_rel.png")

In [34]:
# Pearson correlation between every regressor (rows) and every PC time course in pcs_mat (columns).
reg_pc_corr = np.full((regressors_mat.shape[0], n_components), np.nan)

for reg in range(regressors_mat.shape[0]):
    for pc in range(pcs_mat.shape[0]):
        reg_pc_corr[reg, pc] = np.corrcoef(regressors_mat[reg, :], pcs_mat[pc, :])[0, 1]

In [35]:
# Per PC: summed |correlation| with all GC regressors and with all IO regressors.
# NOTE(review): these are sums, not means, so the group with more regressors (GC: n_gc_clust vs
# IO: n_io_clust) gets a larger value by construction — confirm this bias is intended.
pc_gc_likeness = np.abs(reg_pc_corr)[gc_regs].sum(0)
pc_io_likeness = np.abs(reg_pc_corr)[io_regs].sum(0)

In [36]:
# Fraction of each PC's total |correlation| that comes from GC vs. IO regressors;
# the two fractions add up to 1 for every PC.
pc_total_likeness = pc_gc_likeness + pc_io_likeness
pc_gc_contr = pc_gc_likeness / pc_total_likeness
pc_io_contr = pc_io_likeness / pc_total_likeness

In [37]:
# For each cell and timepoint: magnitude of the GC-attributed and IO-attributed parts of its
# PC reconstruction. Each PC loading is split by pc_gc_contr / pc_io_contr, the PC time courses
# are weighted accordingly, and absolute values are summed over PCs.
gc_contribution = np.full_like(mean_traces, np.nan)
io_contribution = np.full_like(mean_traces, np.nan)

for roi in range(mean_traces_pca.shape[0]):
    
    gc_coef = mean_traces_pca[roi, :] * pc_gc_contr
    io_coef = mean_traces_pca[roi, :] * pc_io_contr
    
    gc_weighted_pcs = pcs_mat * gc_coef[:, None]
    io_weighted_pcs = pcs_mat * io_coef[:, None]
    
    gc_contribution[roi, :] = np.sum(np.abs(gc_weighted_pcs), 0)
    io_contribution[roi, :] = np.sum(np.abs(io_weighted_pcs), 0)  

In [38]:
# Per-timepoint contrast in [-1, 1] (both contributions are non-negative): > 0 where GC dominates.
time_contributions_mat = (gc_contribution-io_contribution) / (gc_contribution+io_contribution)

In [39]:
# Median (line) and interquartile range (shaded) across cells of the GC/IO contribution ratio
# over time, with the stimulus shaded on top.
time_contr_fig, ax = plt.subplots(figsize=(10, 3))

# Find quartiles:
median_contr = np.nanmedian(time_contributions_mat, 0)
low_quart_contr = np.nanquantile(time_contributions_mat, 0.25, axis=0)
high_quart_contr = np.nanquantile(time_contributions_mat, 0.75, axis=0)

# Plot trace and filling:
i_col=5  # palette colour index for this plot
ax.plot(stim[:, 0], median_contr, color=sns.color_palette()[i_col])
ax.fill_between(stim[:,0], low_quart_contr, high_quart_contr, facecolor=sns.color_palette()[i_col], 
                     alpha=.3, zorder=100, edgecolor=None)
shade_plot(stim, ax, shade_range=(0.75, 0.98))
ax.axhline(0, c = (0.3,)*3, zorder=1, ls='--')
ax.set_ylim(-.5, .5)
ax.set_xlim(0, stim[-1, 0])

# 10 s scale bar at the end of the time axis; labels above / below the axes mark the sign meaning.
make_bar(ax, (stim[-1, 0]-10, stim[-1, 0]), label="10 s")
ax.set_ylabel('Contribution ratio')
ax.text(0, 1.03, 'GC contributions dominates', fontsize=7, color=sns.color_palette()[0], transform=ax.transAxes, va='center')
ax.text(0, -.03, 'IO contributions dominates', fontsize=7, color=sns.color_palette()[1], transform=ax.transAxes, va='center')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0, -0.03, 'IO contributions dominates')

In [40]:
# Export the contribution-over-time figure only when an output folder is configured.
if fig_fold is not None:
    time_contr_fig.savefig(fig_fold / "time_contribution.png")

## Separate testing and training traces:

In [None]:
from random import shuffle, seed

In [None]:
# Fix randomness for reproducibility (both numpy's global generator and the stdlib `random` one):
np.random.seed(572704)
seed(572704)

# For each cell, randomly split its valid repetition indexes into a test set of n_test_reps
# repetitions (never used for fitting) and a fit set with all the others.
# NOTE(review): cells with fewer than n_test_reps + 2 valid repetitions end up with < 2 fit
# repetitions, which leave-one-out validation cannot split.
n_test_reps = 3

test_idxs = []
fit_idxs = []
for n_resps in n_valid_reps:
    idxs = np.arange(n_resps)  # possible repetitions indexes
    shuffle(idxs)  # shuffle index array in place (stdlib random.shuffle)
    test_idxs.append(idxs[:n_test_reps])  # test idxs are the first n_test_reps
    fit_idxs.append(idxs[n_test_reps:])  # the rest goes for training

## Define boundaries, function & regularisation/cost functions

In [None]:
from scipy import optimize

In [None]:
## Functions for the regression.

# This is the the main function that we actually use to describe PC activity:
def offset_cluster_combine(coefs, regressors):
    """Model trace: a constant offset plus a weighted sum of the regressors.

    Args:
        coefs: array whose element 0 is the offset and whose remaining elements are the
            weights of the concatenated GC and IO regressors.
        regressors: (timepoints, n_regressors) regressor matrix.

    Returns:
        (timepoints,) model trace.
    """
    baseline = coefs[0]
    weights = coefs[1:]
    weighted = weights * regressors  # weight broadcast across timepoints
    return baseline + weighted.sum(axis=1)

def cost_func(fit_coefs, regressors, trace2fit, model):
    """ Cost function: mean squared error between the target trace and the model.

    The sum of squared residuals is divided by the number of timepoints, so costs of
    traces with different lengths (different numbers of repetitions) are comparable.

    Args:
        fit_coefs: coefficients passed to ``model``.
        regressors: regressor matrix passed to ``model``.
        trace2fit: 1D target trace.
        model: callable ``model(fit_coefs, regressors)`` returning a trace like ``trace2fit``.

    Returns:
        Mean squared residual (float).
    """
    diff = trace2fit - model(fit_coefs, regressors)
    return np.sum(diff**2) / trace2fit.shape[0]

def reg_func(fit_coefs, reg_coef=0):
    """L1 penalty: sum of absolute values of all coefficients except the offset.

    The offset (element 0) is not regularized. ``reg_coef`` is accepted but unused.
    """
    weights = fit_coefs[1:]  # skip the offset term
    return np.abs(weights).sum()

def minimization_func(fit_coefs, regressors, trace2fit, model, cost_func, reg_func, reg_lamda=0):
    """Objective to minimize: data cost plus ``reg_lamda`` times the regularization penalty.

    Args:
        fit_coefs: coefficients being optimized.
        regressors, trace2fit, model: forwarded to ``cost_func``.
        cost_func: callable ``cost_func(fit_coefs, regressors, trace2fit, model)``.
        reg_func: callable ``reg_func(fit_coefs)`` returning the penalty.
        reg_lamda: regularization strength (0 disables the penalty).
    """
    data_term = cost_func(fit_coefs, regressors, trace2fit, model)
    penalty = reg_lamda * reg_func(fit_coefs)
    return data_term + penalty

In [None]:
# Set starting values and bounds for the offset (the +1 below) and the coefficients.
# Coefficient layout: [offset, GC weights (n_gc_clust), IO weights (n_io_clust)].

# Initial guesses: zero offset and zero weights (inside all bounds below).
coefs_init_guess = np.zeros(regressors_mat.shape[0] + 1)

w_bound = 100  # bound for regressors weights (high b/c of normalization differences)
off_bound = 5  # bound for the offset

# GC weights may be positive or negative. We know that IO input can only positively
# contribute to PC activity, so IO weights are bounded below at 0.
# NOTE: changing these bounds invalidates any cached "regularization_costs.h5".
coefs_bounds = ([(-off_bound, off_bound)]
                + [(-w_bound, w_bound) for _ in range(n_gc_clust)]
                + [(0, w_bound) for _ in range(n_io_clust)])

## Validate regularization parameter

Here we test for the optimal regularization lambda. As this parameter search can take quite long, you can just skip the section and execute from the next block.

In [None]:
from sklearn.model_selection import LeaveOneOut

In [None]:
# Prepare a concatenation of regressors long enough to fit the longest possible trace:
# the regressor panel tiled (n_reps_max - n_test_reps) times in time, shape (time, n_regressors).
regressors_concat = np.concatenate([regressors_mat,]*(n_reps_max - n_test_reps), 1).T

# Regularization grid: lambda = 0 plus 15 log-spaced values from 1e-7 up to about 10**-2.1
# (exponent step 0.35). For reference, with all-zero coefficients the cost (MSE) of a
# z-scored trace is 1.
reg_lambda_arr = np.insert(10**np.arange(-7., -2, 0.35), 0, 0)

# Initialise empty matrix for the validation costs of all leave-one-out fits:
# (ROI, lambda, left-out fold); entries stay NaN where a fold or ROI is not computed.
n_lambdas = reg_lambda_arr.shape[0]
costs = np.full((n_rois, n_lambdas, n_reps_max - n_test_reps), np.nan)

In [None]:
#%%time
# Leave-one-out grid search over the regularization strength for every ROI.
# The result is cached in "regularization_costs.h5" (working directory) and re-loaded if present.
# NOTE(review): the cache is not keyed on reg_lambda_arr, coefs_bounds or the train/test split;
# delete the file after changing any of them, otherwise stale costs are silently re-used.
# Use scikit learn leave-one-out iterator:
loo = LeaveOneOut()

n_downsample = 1  # process only every n-th ROI (quick tests); set to 1 for the full run

if n_downsample == 1 and (Path("") / "regularization_costs.h5").exists():
    costs = fl.load("regularization_costs.h5")["costs"]
else:
    for i_roi in range(0, n_rois, n_downsample):
        if np.mod(i_roi, 50) == 0:
            print(i_roi)
        roi_fit_idxs = fit_idxs[i_roi]

        # Hyperparameter grid search:
        for i_lambda, reg_lambda in enumerate(reg_lambda_arr):

            # Leave-one-out validation: fit (with regularization) on all fit repetitions but one,
            # then store the unregularized cost on the held-out repetition.
            for i_loo, (idxs_train, idxs_valid) in enumerate(loo.split(roi_fit_idxs)):
                roi_train_trace = clean_traces[i_roi, :, roi_fit_idxs[idxs_train]].flatten()
                roi_valid_trace = clean_traces[i_roi, :, roi_fit_idxs[idxs_valid[0]]]
                res = optimize.minimize(minimization_func, method='SLSQP', 
                                        args=(regressors_concat[:roi_train_trace.shape[0], :], 
                                              roi_train_trace, offset_cluster_combine, 
                                              cost_func, reg_func, reg_lambda),
                                        x0=coefs_init_guess, bounds=coefs_bounds)

                costs[i_roi, i_lambda, i_loo] = cost_func(res.x, regressors_concat[:roi_valid_trace.shape[0], :], 
                                                          roi_valid_trace, offset_cluster_combine)
        #         print("%s %s" % (train, test))
        
    # Only cache complete (non-downsampled) runs.
    if n_downsample == 1:
        fl.save("regularization_costs.h5", dict(costs=costs))

In [None]:
def max_within_std_err(cell_costs, max_cost=1.):
    """Select the strongest regularization whose validation cost is within one standard error of the best.

    Starting from the lambda index with the lowest mean validation cost, move towards higher
    indexes (stronger regularization for an ascending lambda grid) as long as the next mean
    cost stays below (minimum mean cost + its standard error).

    Parameters
    ----------
    cell_costs : 2D array, (n_lambdas, n_folds)
        Leave-one-out validation costs of one cell; NaN marks folds (or whole lambdas)
        that were not computed.
    max_cost : float, default 1.
        The selected lambda is rejected unless its mean cost is below this value
        (1 is the cost of an all-zero fit on a z-scored trace).

    Returns
    -------
    int or float
        Index into the lambda grid, or ``np.nan`` if no lambda qualifies or no cost is available.
    """
    import warnings

    cell_costs = np.asarray(cell_costs, dtype=float)
    n_folds = np.sum(~np.isnan(cell_costs), 1)  # available folds, per lambda
    if not n_folds.any():
        return np.nan  # nothing computed for this cell (e.g. skipped while downsampling)

    # Lambdas with zero or one fold give NaN statistics; silence the associated warnings.
    with warnings.catch_warnings(), np.errstate(divide="ignore", invalid="ignore"):
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean_cost = np.nanmean(cell_costs, 1)
        # nanstd (ddof=0) / sqrt(n - 1) == sample std (ddof=1) / sqrt(n): standard error of the mean.
        std_err_cost = np.nanstd(cell_costs, 1) / np.sqrt(n_folds - 1)

    # nanargmin ignores lambdas without costs (plain argmin would pick the first NaN entry).
    min_idx = int(np.nanargmin(mean_cost))
    cost_limit = mean_cost[min_idx] + std_err_cost[min_idx]

    i = min_idx
    while i < len(mean_cost) - 1 and mean_cost[i + 1] < cost_limit:
        i += 1

    return i if mean_cost[i] < max_cost else np.nan

In [None]:
# Inspect the leave-one-out validation curve of one example cell.
example_roi = 21 * 30  # arbitrary example ROI index
example_costs = costs[example_roi, :, :]  # (n_lambdas, n_folds)

plt.figure(figsize=(4,3))
mean_err = np.nanmean(example_costs, 1)
std_err = np.nanstd(example_costs, 1) / np.sqrt(np.sum(~np.isnan(example_costs[0, :])) - 1)
print(mean_err + std_err)

# Index of the selected lambda (NaN if none qualifies). The original line lacked the
# `threshold` assignment target and was a syntax error.
threshold = max_within_std_err(example_costs)
if not np.isnan(threshold):
    plt.axvline(threshold, c=(0.3,)*3)
    plt.axhline(mean_err[threshold],  c=(0.3,)*3)
# Mean validation cost +/- standard error, against the lambda index (not the lambda value).
plt.fill_between(np.arange(len(reg_lambda_arr)), mean_err-std_err, mean_err+std_err, facecolor=sns.color_palette()[0], 
                 alpha=0.4, edgecolor="None")
plt.plot(mean_err, color=sns.color_palette()[0])

plt.xlabel("Regularization coeff. (index in reg_lambda_arr)")
plt.ylabel("Mean MSE")
plt.tight_layout()


In [None]:
# Single fixed lambda (index 11 of the grid).
# NOTE(review): the final-fit cell re-assigns final_lambda_reg for every ROI from its own
# cross-validation costs, so this value is overwritten there.
final_lambda_reg = reg_lambda_arr[11]  # reg_coefs[np.argmin(np.nanmean(costs, 1))]
# final_lambda_reg = 0.00031  # from long fit

In [None]:
# Row indexes of GC and IO regressors (re-definition of the identical arrays built for the
# regressor-PC figure).
gc_regs = np.arange(n_gc_clust)
io_regs = np.arange(n_io_clust)+n_gc_clust

In [None]:
# NOTE(review): duplicate of the earlier regressor-PC pair grid (same content, default figure
# size, not saved). Consider deleting this cell.
fig, axes = plt.subplots(n_components, n_components, sharex=True, sharey=True)

for i in range(n_components):
    for j in range(n_components):
        if j>=i:
            axes[i, j].axis('off')
        else:
            for gc_reg in gc_regs:
                axes[i, j].plot([0, regressors_pca[gc_reg, j]], [0, regressors_pca[gc_reg, i]], c=sns.color_palette()[0])
            for io_reg in io_regs:
                axes[i, j].plot([0, regressors_pca[io_reg, j]], [0, regressors_pca[io_reg, i]], c=sns.color_palette()[1])
            axes[i, j].set_xticks([])
            axes[i, j].set_yticks([])
            axes[i, j].axvline(0, ls='--', c='white', alpha=.2)
            axes[i, j].axhline(0, ls='--', c='white', alpha=.2)
        
for i in range(n_components):
    axes[i, 0].set_ylabel('PC{}'.format(i+1))
    
for j in range(n_components):
    axes[-1, j].set_xlabel('PC{}'.format(j+1))

In [None]:
# Inspect the PC-space coordinates of the GC regressors.
regressors_pca[gc_regs]

In [None]:
# Per-PC mean, std and |mean|-relative std of the GC regressors (same statistics as computed
# inside calculate_pc_similarity), kept for inspection.
gc_mean = np.nanmean(regressors_pca[gc_regs], 0)
gc_std = np.nanstd(regressors_pca[gc_regs], 0)
gc_std_norm = gc_std/np.abs(gc_mean)

In [None]:
# Inspect the PC-space coordinates of the IO regressors.
regressors_pca[io_regs]

In [None]:
# Per-PC mean, std and |mean|-relative std of the IO regressors (same statistics as computed
# inside calculate_pc_similarity), kept for inspection.
io_mean = np.nanmean(regressors_pca[io_regs], 0)
io_std = np.nanstd(regressors_pca[io_regs], 0)
io_std_norm = io_std/np.abs(io_mean)

In [None]:
# (ROI, PC, repetition)
traces_pca.shape

In [None]:
# Inspect the relative spread of the IO regressors along each PC.
io_std_norm

In [None]:
# Inspect the mean PC coordinates of the IO regressors.
io_mean

In [None]:
# Sanity check: regressor 3 (a GC row, as GC rows come first) vs. its reconstruction from the
# retained n_components PCs.
plt.figure()
plt.plot(regressors_mat[3, :])

plt.plot(pca.inverse_transform(regressors_pca[3, :]))


## Fit the traces

Fit the traces with the specified lambda:

In [None]:
costs_final = np.full(n_rois, np.nan)  # Costs of our final fit (MSE on the test repetitions)
coefs_final = np.full((n_rois, regressors_mat.shape[0] + 1), np.nan)  # Coefs from our final fit: [offset, GC weights, IO weights]


valid_roi_idxs = []  # ROIs for which a regularization strength could be selected
for i_roi in range(n_rois):
    if np.mod(i_roi, 100) == 0:
        print(i_roi)

    # Per-cell regularization strength, chosen from its leave-one-out validation costs:
    cost_idx = max_within_std_err(costs[i_roi, :, :])
    if np.isnan(cost_idx):
        continue  # no lambda reached an acceptable validation cost: skip this ROI
    final_lambda_reg = reg_lambda_arr[cost_idx]

    # Get test and fit indexes and traces (repetitions concatenated in time):
    roi_test_idxs = test_idxs[i_roi]
    roi_test_trace = clean_traces[i_roi, :,roi_test_idxs].flatten()

    roi_fit_idxs = fit_idxs[i_roi]
    roi_fit_trace = clean_traces[i_roi, :, roi_fit_idxs].flatten()

    # Fit all fit repetitions with the selected regularization strength
    # (the selected lambda is what must be passed here, not a hard-coded 0):
    res = optimize.minimize(minimization_func, method='SLSQP',
                            args=(regressors_concat[:roi_fit_trace.shape[0], :],
                                  roi_fit_trace, offset_cluster_combine,
                                  cost_func, reg_func, final_lambda_reg),
                            x0=coefs_init_guess, bounds=coefs_bounds)

    # Calculate cost on the test set (data term only, no regularization):
    costs_final[i_roi] = cost_func(res.x, regressors_concat[:roi_test_trace.shape[0], :],
                                   roi_test_trace, offset_cluster_combine)
    # Save the coefficients:
    coefs_final[i_roi, :] = res.x

    valid_roi_idxs.append(i_roi)

valid_roi_idxs = np.array(valid_roi_idxs)
n_valid_rois = len(valid_roi_idxs)

Keep only properly fit rois data:

In [None]:
# Restrict every per-ROI quantity to the ROIs that were fitted (valid_roi_idxs).
# NOTE(review): rel_idx_sel relies on data_dict["PC"]["rel_idxs"] being aligned with the
# ROI axis of the traces — see the note in the loading loop.
coefs_final_sel = coefs_final[valid_roi_idxs, :]
costs_final_sel = costs_final[valid_roi_idxs]
clean_traces_sel = clean_traces[valid_roi_idxs, :, :]

test_idxs_sel = [test_idxs[i] for i in valid_roi_idxs]
fit_idxs_sel = [fit_idxs[i] for i in valid_roi_idxs]
rel_idx_sel = data_dict["PC"]["rel_idxs"][valid_roi_idxs]
clust_lab_sel = data_dict["PC"]["clust_labels"][valid_roi_idxs]

Fit with shuffled weights to decide what is a "random" fit:

In [None]:
# Null distribution of test costs: shuffle every coefficient column independently across ROIs
# (breaking the link between a ROI and its fitted weights), then evaluate each ROI's test cost.
# NOTE(review): n_valid_rois is overwritten at the end of this cell, so the cell is not
# safely re-runnable once the next cell has also filtered the *_sel arrays.
coefs_shuf = np.empty_like(coefs_final_sel)

for i in range(coefs_final_sel.shape[1]):
    shuf_idx = np.random.permutation(coefs_final_sel.shape[0])
    coefs_shuf[:, i] = coefs_final_sel[shuf_idx, i]
    
costs_shuf = np.full(n_valid_rois, np.nan)  # Costs from shuffled weights

for i_roi in range(n_valid_rois):
    
    # Get test and fit indexes and traces:
    roi_test_idxs = test_idxs_sel[i_roi]
    roi_test_trace = clean_traces_sel[i_roi, :,roi_test_idxs].flatten()
        
    # Calculate cost on the test set with shuffled weights:
    costs_shuf[i_roi] = cost_func(coefs_shuf[i_roi, :], regressors_concat[:roi_test_trace.shape[0], :], 
                                              roi_test_trace, offset_cluster_combine)
    
# Keep only fits whose test cost is below the 5th percentile of the shuffled costs.
cost_threshold = np.percentile(costs_shuf, 5)
sel_fit = np.argwhere(costs_final_sel < cost_threshold)[:, 0]
n_valid_rois = len(sel_fit)

In [None]:
# Keep only the ROIs whose fit beats the shuffle threshold.
# NOTE(review): not idempotent — each *_sel array is filtered in place by name, so running
# this cell twice filters twice (or raises IndexError).
coefs_final_sel = coefs_final_sel[sel_fit, :]
costs_final_sel = costs_final_sel[sel_fit]
clean_traces_sel = clean_traces_sel[sel_fit, :, :]

test_idxs_sel = [test_idxs_sel[i] for i in sel_fit]
fit_idxs_sel = [fit_idxs_sel[i] for i in sel_fit]
rel_idx_sel = rel_idx_sel[sel_fit]
clust_lab_sel = clust_lab_sel[sel_fit]

In [None]:
def cost_figure(costs_final, costs_shuf, cost_threshold, figure=None, frame=None, bins=None):
    """Histogram of test-set fit costs compared with the shuffled-weights null.

    Parameters
    ----------
    costs_final : array
        Test-set cost of each fitted ROI. Drawn as an outline; the part below
        `cost_threshold` is additionally filled.
    costs_shuf : array
        Test-set costs obtained with shuffled weights, drawn filled.
    cost_threshold : float
        Selection threshold, drawn as a vertical line.
    figure, frame :
        Target figure (a new 3x2 in. figure if None) and frame forwarded to
        `add_offset_axes`.
    bins : array, optional
        Histogram bin edges. Defaults to ``np.arange(0.01, 1.6, 0.05)``; costs
        outside the edges are not counted.
    """
    bin_array = np.arange(0.01, 1.6, 0.05) if bins is None else np.asarray(bins)
    costs_final = np.asarray(costs_final)
    costs_shuf_hist, _ = np.histogram(costs_shuf, bin_array)
    costs_final_hist, _ = np.histogram(costs_final, bin_array)
    costs_final_sel_hist, _ = np.histogram(costs_final[costs_final < cost_threshold], bin_array)

    if figure is None:
        figure = plt.figure(figsize=(3., 2))

    a = 0.7
    ax = add_offset_axes(figure, (0.2, 0.2, 0.7, 0.7), frame=frame)

    # Bin centres, used with step="mid" / where="mid".
    x = (bin_array[1:] + bin_array[:-1]) / 2
    ax.fill_between(x, costs_shuf_hist, step="mid", alpha=a, edgecolor=None, facecolor=sns.color_palette()[1])
    ax.fill_between(x, costs_final_sel_hist, step="mid", alpha=a, edgecolor=None, facecolor=sns.color_palette()[0])
    ax.step(x, costs_final_hist, where="mid", alpha=a, color=sns.color_palette()[0])

    ax.axvline(cost_threshold, c=(0.4,)*3)
    ax.set_xlabel("Cost (on test)")
    ax.set_ylabel("Count")

In [None]:
cost_figure(costs_final, costs_shuf, cost_threshold=cost_threshold);

In [None]:
# NOTE(review): this cell repeats the previous one verbatim; it can be removed.
cost_figure(costs_final=costs_final, costs_shuf=costs_shuf, cost_threshold=cost_threshold);

In [None]:
# Inspect the ROI with the lowest test-set cost: its fit-set (top) and
# test-set (bottom) repetitions in grey, their mean, and the z-scored model.
idx = np.nanargmin(costs_final)
coefs = coefs_final[idx, :]
ylim = 2.5

# The model prediction does not depend on which repetitions are shown.
model_trace = nanzscore(offset_cluster_combine(coefs, regressors_mat.T))

fit = clean_traces[idx, :, fit_idxs[idx]]
test = clean_traces[idx, :, test_idxs[idx]]

fig_best, axes_best = plt.subplots(2, 1, figsize=(8, 5))
for ax, reps, set_label in zip(axes_best, [fit, test], ["Train set", "Test set"]):
    ax.plot(reps.T, color=(0.6,)*3, linewidth=0.6)
    ax.plot(np.mean(reps, 0))
    ax.plot(model_trace)
    ax.set_ylim(-ylim, ylim)
    ax.set_ylabel(set_label)
plt.show()

# Analyse coefficients

In [None]:
def coefs_plot(clust_lab, coefs_final, figure=None, frame=None):
    """Heatmap of the fitted weights of each ROI, with ROIs grouped by cluster.

    Columns are ROIs sorted by `clust_lab`, with vertical lines at cluster
    boundaries. Rows are the coefficients without column 0 (presumably the
    offset term). The row layout (separator at 7.5, ticks at 3 and 10, ylim
    13.5) assumes 8 GC weights followed by 6 IO weights.

    Parameters
    ----------
    clust_lab : array of int
        Cluster label of each ROI.
    coefs_final : 2D array
        One row of coefficients per ROI.
    figure, frame :
        Target figure (a new 4x3 in. figure if None) and frame forwarded to
        `add_offset_axes`.
    """
    if figure is None:
        figure = plt.figure(figsize=(4, 3))

    clust_lab = np.asarray(clust_lab)
    idxs_sort = np.argsort(clust_lab)
    sorted_lab = clust_lab[idxs_sort]

    c_lim = 100
    ax_coefs = add_offset_axes(figure, (0.05, 0.2, 0.75, 0.8), frame=frame)
    im = ax_coefs.imshow(coefs_final[idxs_sort, 1:].T, vmin=0, vmax=c_lim, aspect="auto", cmap="Reds")
    ax_coefs.set_xlabel("Roi n.")
    ax_coefs.axhline(7.5, c=(0.3,)*3)  # GC / IO boundary between rows 7 and 8
    ax_coefs.set_yticks([3., 10.])
    ax_coefs.set_ylim([13.5, -0.5])
    ax_coefs.tick_params(length=0)
    ax_coefs.set_yticklabels(["$w_{GC}$", "$w_{ION}$"], rotation=90)
    for side in ["left", "right", "top", "bottom"]:
        ax_coefs.spines[side].set_visible(False)
    for ytick, color in zip(ax_coefs.get_yticklabels(), sns.color_palette()[:2]):
        ytick.set_color(color)

    # Cluster boundaries: imshow centres column k on x=k, so the edge between
    # columns k and k+1 is at k + 0.5 (like the horizontal line at 7.5).
    # Works for empty input and for any label values.
    for boundary in np.flatnonzero(np.diff(sorted_lab)) + 0.5:
        ax_coefs.axvline(boundary, c=(0.3,)*3)

    # Colorbar on the target figure. No tight_layout: it acted on pyplot's
    # current figure and conflicts with the manual, frame-based placement.
    axcolor = add_offset_axes(figure, [0.86, 0.2, 0.02, 0.12], frame=frame)
    cbar = figure.colorbar(im, cax=axcolor, orientation="vertical")
    cbar.set_ticks([0, c_lim])
    cbar.ax.tick_params(length=3)

In [None]:
coefs_plot(clust_lab_sel, coefs_final_sel);

### Check error histogram and relationship with cell reliability

Check the fit error as a function of the cell reliability index. We expect a negative relationship (more reliable cells should be fit with a lower error):

In [None]:
def error_vs_reliabil(rel_idx, costs_final, figure=None, frame=None):
    """Scatter each ROI's fit error against its reliability index.

    A new 3x2 in. figure is created when `figure` is None; `frame` is
    forwarded to `add_offset_axes` to place the axes inside a larger panel.
    """
    target_fig = plt.figure(figsize=(3, 2)) if figure is None else figure
    scatter_ax = add_offset_axes(target_fig, (0.2, 0.2, 0.7, 0.7), frame=frame)
    scatter_ax.scatter(rel_idx, costs_final, s=5)
    scatter_ax.set_xlabel("Reliability idx")
    scatter_ax.set_ylabel("Fit error")

In [None]:
error_vs_reliabil(rel_idx=rel_idx_sel, costs_final=costs_final_sel);

### Look at the distribution of IO and GC coefficient weights

In [None]:
# Drop ROIs whose coefficients are all exactly zero, then normalise each
# remaining row by its total absolute weight. A coefficient carrying more than
# 10% of its ROI's total weight is flagged as "non-zero".
has_weights = (coefs_final_sel != 0).any(axis=1)
cleaned_coefs = coefs_final_sel[has_weights]
coefs_sum = np.nansum(np.abs(cleaned_coefs), axis=1)
norm_coefs = cleaned_coefs / coefs_sum[:, np.newaxis]
non_zero_coefs = np.abs(norm_coefs) > 0.1

In [None]:
def gc_io_weights_hist(cleaned_coefs, figure=None, frame=None):
    """Overlay histograms of each ROI's summed GC weights (columns 1-8) and
    summed IO weights (columns 9 onward), using bins of width 20 with edges
    from -400 to 380.
    """
    target_fig = figure if figure is not None else plt.figure(figsize=(3,2))
    ax = add_offset_axes(target_fig, (0.2, 0.2, 0.7, 0.7), frame=frame)
    half_width = 400
    for label, columns in (("GC", slice(1, 9)), ("IO", slice(9, None))):
        summed_weights = cleaned_coefs[:, columns].sum(1)
        ax.hist(summed_weights, np.arange(-half_width, half_width, 20), alpha=0.6, label=label)
    ax.legend()
    ax.set_xlabel("Weight of coefs")
    ax.set_ylabel("Count")

In [None]:
gc_io_weights_hist(cleaned_coefs=cleaned_coefs);

### Look at the distribution of the number of non-zero coefficients

In [None]:
def gc_io_nonzero_hist(non_zero_coefs, figure=None, frame=None):
    """Overlay histograms of how many GC (columns 1-8) and IO (columns 9 onward)
    coefficients each ROI has flagged in `non_zero_coefs`; the legend reports
    the mean count of each group.
    """
    if figure is None:
        figure = plt.figure(figsize=(3,2))
    ax = add_offset_axes(figure, (0.2, 0.2, 0.7, 0.7), frame=frame)

    groups = (("GC (mn={:2.1f})", slice(1, 9)),
              ("IO (mn={:2.1f})", slice(9, None)))
    for label_template, columns in groups:
        n_per_roi = non_zero_coefs[:, columns].sum(1)
        ax.hist(n_per_roi, np.arange(0, 15),
                label=label_template.format(np.nanmean(n_per_roi)), alpha=0.6)
    ax.legend()
    ax.set_xlabel("N. of coefs")
    ax.set_ylabel("Count")

In [None]:
gc_io_nonzero_hist(non_zero_coefs=non_zero_coefs);

### GC-IO index

In [None]:
from scipy import stats

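Create a "GC vs IO-ness" index from the mean absolute coefficients of the GC clusters ($w_{GC}$) and of the IO clusters ($w_{IO}$): $(w_{GC} - w_{IO}) / (w_{GC} + w_{IO})$, which ranges from -1 (IO weights only) to +1 (GC weights only):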

In [None]:
# Mean absolute weight of the GC-cluster and IO-cluster regressors of each ROI
# (column 0 is skipped), combined into a normalised difference in [-1, 1]:
# +1 = only GC weights, -1 = only IO weights (NaN if both are zero).
gc_cols = slice(1, n_gc_clust + 1)
io_cols = slice(n_gc_clust + 1, None)
w_gc = np.mean(np.abs(coefs_final_sel[:, gc_cols]), axis=1)
w_io = np.mean(np.abs(coefs_final_sel[:, io_cols]), axis=1)
gc_io_idx = (w_gc - w_io) / (w_gc + w_io)

In [None]:
#Create a IO-GC colormap
from matplotlib.colors import LinearSegmentedColormap
# Diverging map for the GC/IO index: IO colour (low end) -> grey -> GC colour (high end).
n_bins = 200
cmap_name = 'contributions_cmap'
io_colour, gc_colour = sns.color_palette()[1], sns.color_palette()[0]
colors = [io_colour, (.72,)*3, gc_colour]
custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=n_bins)

In [None]:
def indexes_plot(gc_io_idx, rel_idxs, figure=None, frame=None):
    """Scatter reliability index against GC/IO index, annotated with Spearman's rho.

    Points are coloured by GC/IO index with `custom_cmap` (range -1..1). The
    correlation is computed only on ROIs where both indices are non-NaN.

    Parameters
    ----------
    gc_io_idx, rel_idxs : array
        Per-ROI GC/IO index and reliability index (same length).
    figure, frame :
        Target figure (a new 3x2 in. figure if None) and frame forwarded to
        `add_offset_axes`.
    """
    gc_io_idx = np.asarray(gc_io_idx)
    rel_idxs = np.asarray(rel_idxs)
    finite = ~(np.isnan(gc_io_idx) | np.isnan(rel_idxs))
    # Tuple unpacking works with both the older namedtuple result and the newer
    # SignificanceResult returned by scipy.stats.spearmanr.
    rho, pval = stats.spearmanr(gc_io_idx[finite], rel_idxs[finite])

    if figure is None:
        figure = plt.figure(figsize=(3., 2.))

    ax = add_offset_axes(figure, (0.2, 0.2, 0.7, 0.7), frame=frame)
    ax.axvline(0, c = (0.6, )*3, zorder=-100)
    ax.scatter(gc_io_idx, rel_idxs, c=gc_io_idx, cmap=custom_cmap, s=6, vmin=-1, vmax=1,
               edgecolor=(0.3,)*3, linewidths=0.1)
    ax.set_xlim(-1.1, 1.1)
    ax.set_ylabel("Reliability index")
    ax.set_xlabel("GC/IO idx")
    ax.text(-1, 0.7, "$\\rho$" + "={:1.2f} \np={:1.2}".format(rho, pval), fontsize=7, color=(0.3,)*3)


In [None]:
indexes_plot(gc_io_idx, rel_idxs=rel_idx_sel);

### Single cell example fit

In [None]:
# NOTE(review): scratch cell listing candidate example ROI indices noted during
# exploration; the tuples are evaluated but not stored or used in the visible code.
11, 17, 32, 33, (44), 71

37, 60, 68

In [None]:
def make_single_cell_plot(traces, all_coefs, roi_test_idxs, idx,
                          legend=True, bar=True, figure=None, frame=None):
    """Plot one PC ROI's repetitions with its model fit, plus GC / IO contribution insets.

    Parameters
    ----------
    traces : array
        Per-ROI traces indexed as ``traces[roi, :, reps]`` (the last axis
        selects repetitions, as with ``clean_traces`` elsewhere).
    all_coefs : 2D array
        Fitted coefficients, one row per ROI.
    roi_test_idxs : sequence
        Per-ROI lists of repetition indices to plot; entry ``idx`` is used.
        Must be aligned with `traces` / `all_coefs` (e.g. ``test_idxs_sel``
        for the ``*_sel`` arrays).
    idx : int
        ROI index into `traces`, `all_coefs` and `roi_test_idxs`.
    legend : bool
        Draw the trace legend and the inset labels.
    bar : bool
        Draw a 10 s scale bar; otherwise hide the x axis.
    figure, frame :
        Target figure (a new 7x2 in. figure if None) and frame forwarded to
        `add_offset_axes`.

    Uses the notebook-level `stim`, `regressors_mat` and `offset_cluster_combine`.
    """
    if figure is None:
        figure = plt.figure(figsize=(7, 2))
    coefs = all_coefs[idx, :]
    # Use the per-ROI index list that was passed in, aligned with `traces`.
    # The global `test_idxs` is indexed over *all* ROIs, so indexing it with a
    # position in the filtered arrays picked another ROI's repetitions.
    reshaped = traces[idx, :, roi_test_idxs[idx]]
    time = stim[:, 0]
    trace_col = sns.color_palette()[2]

    ylim = 3.7
    axtrace = add_offset_axes(figure, (0.0, 0.1, 0.58, 0.92), frame=frame)
    axtrace.plot(time, reshaped.T, color=trace_col, linewidth=0.5)
    axtrace.plot(time, np.nanmean(reshaped, 0), color=trace_col, linewidth=2, label="PC trace")
    axtrace.plot(time, nanzscore(offset_cluster_combine(coefs, regressors_mat.T)), color="k", label="Fit")
    shade_plot(stim, axtrace, shade_range=(0.75, 0.98))

    axtrace.set_ylim(-ylim, ylim)
    axtrace.set_xlim(0, time[-1])
    axtrace.spines["left"].set_visible(False)
    axtrace.set_yticks([])
    if bar:
        make_bar(axtrace, (time[-1]-10, time[-1]), label="10 s")
    else:
        axtrace.spines["bottom"].set_visible(False)
        axtrace.set_xticks([])

    # Legend on the trace axes explicitly (not pyplot's current axes), keeping
    # the first handle for each label.
    if legend:
        handles, labels = axtrace.get_legend_handles_labels()
        unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]]
        axtrace.legend(*zip(*unique), loc="lower left", fontsize=7)

    # Insets: the model rebuilt from GC-only (columns 1-8) and IO-only
    # (columns 9 onward) coefficients, all other coefficients set to zero.
    ylim = 1.5
    y_off = -0.5
    coefs_gc = np.zeros(coefs.shape)
    coefs_io = np.zeros(coefs.shape)
    coefs_gc[1:9] = coefs[1:9]
    coefs_io[9:] = coefs[9:]
    inset_specs = [((0.6, 0.6, 0.4, 0.4), "GC", coefs_gc),
                   ((0.6, 0.1, 0.4, 0.4), "IO", coefs_io)]
    for n, (ax_pos, lab, c_coefs) in enumerate(inset_specs):
        col = sns.color_palette()[n]
        ax_little = add_offset_axes(figure, ax_pos, frame=frame)
        ax_little.plot(time, offset_cluster_combine(c_coefs, regressors_mat.T), color=col)
        shade_plot(stim, ax_little, shade_range=(0.75, 0.98))
        ax_little.set_ylim(-ylim-y_off, ylim-y_off)
        ax_little.axis("off")

        if legend:
            ax_little.text(2, -ylim-y_off + 0.2, lab + " contribution", color=col, fontsize=7)

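Example cells (indices into the selected-ROI arrays, chosen by looking at ROIs sorted by GC/IO index):
- 97, 254: IO examples
- 197: combined GC + IO example
- 89, 115, 60: combined GC + IO examples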

In [None]:
# ROI indices (into the *_sel arrays) used as single-cell examples in the
# assembled figure below; this cell only evaluates the tuple.
97, 197, 89

In [None]:
# Pick an example among well-fit ROIs (test cost < 0.9): the one with the
# 11th-highest GC/IO index within that subset. `np.argsort` returns positions
# inside the subset, so map them back to indices into the full *_sel arrays.
sel = costs_final_sel < 0.9
sel_positions = np.flatnonzero(sel)
idx = sel_positions[np.argsort(gc_io_idx[sel])[-11]]
make_single_cell_plot(clean_traces_sel, coefs_final_sel, test_idxs_sel, idx, figure=None, frame=None)

### Time contribution plot

In [None]:
# For every time point, compare the magnitude of the GC-only and IO-only parts
# of each ROI's model: +1 = only GC contributes, -1 = only IO contributes.
# Rows are time points, columns are the selected ROIs.
n_sel_rois = coefs_final_sel.shape[0]
time_contributions_mat = np.zeros((n_rep_timepts, n_sel_rois))
coefs_gc = np.zeros(coefs_final_sel.shape)
coefs_io = np.zeros(coefs_final_sel.shape)
coefs_gc[:, 1:9] = coefs_final_sel[:, 1:9]
coefs_io[:, 9:] = coefs_final_sel[:, 9:]

for i_roi in range(n_sel_rois):
    gc_magnitude = np.abs(offset_cluster_combine(coefs_gc[i_roi, :], regressors_mat.T))
    io_magnitude = np.abs(offset_cluster_combine(coefs_io[i_roi, :], regressors_mat.T))

    # Where both parts are exactly zero (possible when a ROI's weights are all
    # zero) the ratio is 0/0 = NaN. Those NaNs are skipped by the nan-aware
    # statistics used to summarise this matrix, so the warning is silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        time_contributions_mat[:, i_roi] = (gc_magnitude - io_magnitude) / \
            (gc_magnitude + io_magnitude)

In [None]:
def time_contr_plot(time_contributions_mat, stim, figure=None, frame=None):
    """Median and interquartile range across ROIs of the GC/IO contribution ratio over time.

    Parameters
    ----------
    time_contributions_mat : 2D array
        Contribution ratio with one row per time point (aligned with the rows
        of `stim`) and one column per ROI; NaNs are ignored.
    stim : 2D array
        Stimulus array; column 0 is used as the time axis.
    figure, frame :
        Target figure (a new 7x1.5 in. figure if None) and frame forwarded to
        `add_offset_axes`.
    """
    if figure is None:
        figure = plt.figure(figsize=(7, 1.5))

    ax = add_offset_axes(figure, (0.1, 0.1, 0.9, 0.9), frame=frame)
    time = stim[:, 0]

    # Median and quartiles across ROIs:
    median_contr = np.nanmedian(time_contributions_mat, 1)
    low_quart_contr, high_quart_contr = np.nanquantile(time_contributions_mat, [0.25, 0.75], axis=1)

    # Plot trace and filling:
    contr_col = sns.color_palette()[5]
    ax.plot(time, median_contr, color=contr_col)
    ax.fill_between(time, low_quart_contr, high_quart_contr, facecolor=contr_col,
                    alpha=.3, zorder=100, edgecolor=None)
    shade_plot(stim, ax, shade_range=(0.75, 0.98))
    ax.axhline(0, c=(0.3,)*3, zorder=1)
    ax.set_ylim(-1., 1.)
    ax.set_xlim(0, time[-1])

    make_bar(ax, (time[-1]-10, time[-1]), label="10 s")
    ax.set_ylabel('Contribution ratio')
    ax.text(0.5, 1.05, 'GC contribution dominates', fontsize=7, color=sns.color_palette()[0])
    ax.text(0.5, -1.15, 'IO contribution dominates', fontsize=7, color=sns.color_palette()[1])

In [None]:
time_contr_plot(time_contributions_mat, stim=stim);

# Assemble final panel

In [None]:
import matplotlib.patches as mpatches
import matplotlib.lines as lines

In [None]:
# Weight labels and colours for the schema: IO clusters first, then GC clusters.
# NOTE: this rebinds the notebook-level `cols` (set to the full seaborn palette
# at the top of the notebook); `schema_panel` below reads this new value.
clust_names = ([f"IO{k}" for k in range(1, n_io_clust + 1)]
               + [f"GC{k}" for k in range(1, n_gc_clust + 1)])
cols = ([sns.color_palette()[1] for _ in range(n_io_clust)]
        + [sns.color_palette()[0] for _ in range(n_gc_clust)])

In [None]:
# Plot the regressor panel:
def reg_panel_plot(regressors_mat, figure=None, ax=None, frame=None):
    """Stacked, filled regressor traces, each shifted down by a small offset.

    The first `n_gc_clust` rows are coloured with palette colour 0 and the next
    `n_io_clust` rows with colour 1 (assumes GC regressors come first in
    `regressors_mat`). Uses the notebook-level `stim` for the time axis.

    When `ax` is given the traces are drawn into it directly; otherwise new
    frameless axes are added to `figure` (a new 3x3 in. figure if None) at `frame`.
    """
    # Only create a figure when we actually need to add axes to it; creating it
    # unconditionally left an empty figure behind whenever `ax` was supplied.
    if ax is None:
        if figure is None:
            figure = plt.figure(figsize=(3, 3))
        ax = add_offset_axes(figure, (0., 0., 1, 1), frameon=False, frame=frame)

    offset = 0.01
    palette = sns.color_palette()
    time = stim[:, 0]
    for i, col in enumerate([palette[0], ]*n_gc_clust + [palette[1], ]*n_io_clust):
        ax.fill_between(time, np.zeros(time.shape) - i*offset,
                        regressors_mat[i, :] - i*offset, color=col)

    ax.axis("off")
    
# Plot a trace panel:
def little_trace_plot(clean_traces, i=0, figure=None, ax=None, frame=None):
    """Plot ``clean_traces[i, :, :1]`` (ROI `i`, first entry of the last axis), without axes.

    When `ax` is given the trace is drawn into it directly; otherwise new
    frameless axes are added to `figure` (a new 3x1 in. figure if None) at `frame`.
    """
    # Only create a figure when new axes are needed (no stray empty figure
    # when `ax` is supplied).
    if ax is None:
        if figure is None:
            figure = plt.figure(figsize=(3, 1))
        ax = add_offset_axes(figure, (0., 0., 1, 1), frameon=False, frame=frame)
    ax.plot(clean_traces[i, :, :1].T.flatten(), c=sns.color_palette()[2])
    ax.axis("off")

In [None]:
def schema_panel(clean_traces, regressors_mat, i=9, figure=None, frame=None):
    """Draw the model schema: regressors (left) -> weighted sum (centre) -> PC trace (right).

    Parameters
    ----------
    clean_traces : array
        Passed to `little_trace_plot`; ROI `i` is shown as the example PC trace.
    regressors_mat : array
        Passed to `reg_panel_plot` for the stacked regressors.
    i : int
        Index of the example ROI shown on the right.
    figure, frame :
        Target figure (a new 7x3 in. figure if None) and frame forwarded to
        `add_offset_axes`.

    Reads the notebook-level `clust_names` and `cols` for the weight labels.
    """
    if figure is None:
        figure = plt.figure(figsize=(7, 3))

    h_cent = 0.5
    trace_h = 0.2
    reg_h = 0.9
    schema_h = 0.9
    ax_trace = add_offset_axes(figure, [0.7, h_cent - trace_h/2, 0.3, trace_h], frameon=False, frame=frame)
    ax_regressors = add_offset_axes(figure, [0.02, h_cent - reg_h/2 - 0.01, 0.3, reg_h], frameon=False, frame=frame)
    ax_schema = add_offset_axes(figure, [0.31, h_cent - schema_h/2, 0.4, schema_h], frameon=False, aspect=1., frame=frame)
    little_trace_plot(clean_traces, i=i, figure=figure, ax=ax_trace)

    reg_panel_plot(regressors_mat, figure=figure, ax=ax_regressors)

    ax_schema.xaxis.set_visible(False)
    ax_schema.yaxis.set_visible(False)

    # One input line per cluster weight plus one for the offset, evenly spaced
    # over n_clust + 1 slots centred at (k + 1/2) / (n_clust + 1).
    n_clust = len(clust_names)
    n_slots = n_clust + 1
    c_cent = (0.6, 0.5)  # x, y of the summation node
    w_pos_x = 0.0
    line_displ_x = 0.2

    def _input_line(y, text, color):
        """Draw one labelled input line from height `y` to the summation node."""
        ax_schema.add_line(lines.Line2D([w_pos_x + line_displ_x, c_cent[0]], [y, c_cent[1]],
                                        c="k", zorder=-100))
        ax_schema.text(w_pos_x, y, text, ha="left", va="center", fontsize=7, color=color)

    for i_clust, (label, col) in enumerate(zip(clust_names, cols)):
        _input_line(1/(2*n_slots) + i_clust/n_slots, r"$\cdot w_i^{" + label + "}$", col)

    # Offset input in the last slot, with the same spacing formula as the weights.
    _input_line(1/(2*n_slots) + n_clust/n_slots, "$offset_i$", "k")

    ax_schema.add_patch(mpatches.Circle(c_cent, 0.1, edgecolor="k", facecolor="w"))

    # Output line from the summation node towards the PC trace.
    ax_schema.add_line(lines.Line2D([c_cent[0], 1], [c_cent[1], c_cent[1]], c="k", zorder=-100))

    # Text (raw strings: "\c" and "\s" are not valid Python escape sequences)
    ax_schema.text(*c_cent, r"$\sum$", ha="center", va="center", fontsize=9)
    ax_schema.text(0.55, 0.9, r"$PC_i = regr^{GC} \cdot w_i^{GC} + regr^{IO} \cdot w_i^{IO} + offset_i$",
                   ha="left", va="center", fontsize=7.5)
    ax_trace.text(0, -3, "$PC_i$", ha="left", va="center", fontsize=9, color=sns.color_palette()[2])

In [None]:
fig = plt.figure(figsize=(7, 9))
# ROI 18 is the example PC trace shown in the schema.
schema_panel(clean_traces, regressors_mat, i=18, figure=fig, frame=[0.25, 0.77, 0.6, 0.2])

# Example single-cell fits stacked bottom-to-top: only the bottom one gets the
# scale bar and only the top one the legend. The per-ROI test index lists are
# the ones aligned with the *_sel arrays.
offset_y = 0.13
start_y = 0.36
for i, idx in enumerate([97, 197, 89]):
    y = start_y + i*offset_y
    make_single_cell_plot(clean_traces_sel, coefs_final_sel, test_idxs_sel, idx, figure=fig,
                          frame=[0., y, 0.53, 0.12], legend=(i == 2), bar=(i == 0))

coefs_plot(clust_lab_sel, coefs_final_sel, figure=fig, frame=[0.58, 0.56, 0.4, 0.18])
indexes_plot(gc_io_idx, rel_idx_sel, figure=fig, frame=[0.58, 0.35, 0.35, 0.2])
time_contr_plot(time_contributions_mat, stim, figure=fig, frame=[0.1, 0.17, 0.65, 0.15])

In [None]:
# Save the assembled figure (skipped when no output folder is configured).
# `format` (not the misspelled `forma`, which savefig rejects) selects PDF output.
if fig_fold is not None:
    fig.savefig(fig_fold / "fig7_panel.pdf", format="pdf")

## Supplementary

In [None]:
# Supplementary panel: cost histograms and error vs reliability on the top row,
# number of non-zero coefficients centred on the bottom row.
figsupp = plt.figure(figsize=(6, 4))
s = 0.45  # side length of each sub-panel frame
cost_figure(costs_final, costs_shuf, cost_threshold,
            figure=figsupp, frame=[0., 0.5, s, s])
error_vs_reliabil(rel_idx_sel, costs_final_sel,
                  figure=figsupp, frame=[0.5, 0.5, s, s])
gc_io_nonzero_hist(non_zero_coefs,
                   figure=figsupp, frame=[0.25, 0., s, s])

In [None]:
if fig_fold is not None:
    # Supplementary output folder: "<parent name>supp/src" next to fig_fold's parent.
    # NOTE(review): this layout assumes fig_fold looks like ".../figN/src". With the
    # fig_fold set at the top of the notebook it resolves to
    # ".../luminance/manuscript_figuressupp/src"; confirm this is the intended location.
    supp_fold = fig_fold.parent.parent / (fig_fold.parent.name + "supp/src")
    supp_fold.mkdir(parents=True, exist_ok=True)  # savefig does not create missing folders
    figsupp.savefig(supp_fold / "supp_panel.pdf", format="pdf")