In [1]:
%matplotlib widget

In [2]:
import flammkuchen as fl
from matplotlib import pyplot as plt
import numpy as np
from scipy import io
import seaborn as sns
import matplotlib.gridspec as gridspec
from matplotlib.ticker import FormatStrFormatter
from pathlib import Path
import pandas as pd
from luminance_analysis import PooledData, traces_stim_from_path
from ipywidgets import interact, fixed, interactive
import ipywidgets as widgets

plt.style.use("figures.mplstyle")
cols = sns.color_palette()

In [459]:
fig_fold = Path(r"C:\Users\otprat\Documents\figures\luminance\manuscript_figures\fig7_v3\all")

# Load data

In [4]:
master_path = Path(r"\\FUNES\Shared\experiments\E0032_luminance\neat_exps")
# master_path = Path(r"J:\_Shared\GC_IO_luminance\data\neat_exps")
# master_path = Path(r"/Users/luigipetrucco/Desktop/data_dictionaries/")

In [5]:
from luminance_analysis.utilities import deconv_resamp_norm_trace, reliability, nanzscore, get_kernel
from skimage.filters import threshold_otsu
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree, to_tree, set_link_color_palette
from luminance_analysis.plotting import plot_clusters_dendro, shade_plot, add_offset_axes, make_bar
from luminance_analysis.clustering import cluster_id_search, find_trunc_dendro_clusters

In [6]:
tau_6f = 5
tau_6s = 8
ker_len = 20
normalization = "zscore"
protocol = 'steps'

brain_regions_list = ["GC", "IO", "PC"]
tau_list = [tau_6f, tau_6f, tau_6s]
n_cluster_list = [8, 6, 8]
nan_thr_list = [0, 1, 1]

data_dict = {k:{} for k in brain_regions_list}

#load stimulus of GCs and use it as a the reference for time array and stimulus array:
stim_ref = PooledData(path = master_path / protocol / "GC").stimarray_rep

for brain_region, tau, n_cluster, nan_thr in zip(brain_regions_list, tau_list, 
                                                 n_cluster_list, nan_thr_list):
    #Load data :
    path = master_path / protocol / brain_region
    stim, traces, meanresps = traces_stim_from_path(path)

    # Mean traces, calculate reliability index :
    rel_idxs = reliability(traces)
    
    # Find threshold from reliability histogram...
    rel_thr = threshold_otsu(rel_idxs[~np.isnan(rel_idxs)])

    # ...and load again filtering with the threshold:
    _, traces, meanresps = traces_stim_from_path(path, resp_threshold=rel_thr, nanfraction_thr=nan_thr)

    # Hierarchical clustering:
    linked = linkage(meanresps, 'ward')
    
    # Truncate dendrogram at n_cluster level:
    plt.figure(figsize=(0.1, 0.1))  
    dendro = dendrogram(linked, n_cluster, truncate_mode ="lastp")
    plt.close()
    cluster_ids = dendro["leaves"]
    labels = find_trunc_dendro_clusters(linked, dendro) 
    
    # Deconvolution, resampling / normalization:
    deconv_meanresps = np.empty((meanresps.shape[0], stim_ref.shape[0]))
    resamp_meanresps = np.empty((meanresps.shape[0], stim_ref.shape[0]))
    for roi_i in range(deconv_meanresps.shape[0]):
        deconv_meanresps[roi_i, :] = deconv_resamp_norm_trace(meanresps[roi_i, :], stim[:, 0],
                                                                stim_ref[:, 0], tau, ker_len,
                                                                smooth_wnd=4,
                                                                normalization=normalization)
        resamp_meanresps[roi_i, :] = deconv_resamp_norm_trace(meanresps[roi_i, :], stim[:, 0],
                                                                stim_ref[:, 0], None, ker_len,
                                                                smooth_wnd=4,
                                                                normalization=normalization)
    
    cluster_resps = np.empty((n_cluster, stim_ref.shape[0]))
    for clust_i in range(n_cluster):
        cluster_resp = np.nanmean(deconv_meanresps[labels==clust_i, :], 0)  # average cluster responses
        cluster_resps[clust_i, :] = nanzscore(cluster_resp)  # normalize


    # Add everything to dictionary:
    data_dict[brain_region]["linkage_mat"] = linked
    data_dict[brain_region]["clust_labels"] = labels
    #data_dict[brain_region]["raw_mn_resps"] = meanresps
    data_dict[brain_region]["traces"] = traces
    #data_dict[brain_region]["deconv_mn_resps"] = deconv_meanresps
    #data_dict[brain_region]["resamp_mn_resps"] = resamp_meanresps
    data_dict[brain_region]["rel_idxs"] = rel_idxs[rel_idxs > rel_thr]
    data_dict[brain_region]["rel_thr"] = rel_thr
    data_dict[brain_region]["clust_resps"] = cluster_resps

[<luminance_analysis.FishData object at 0x0000013F1364FD48>, <luminance_analysis.FishData object at 0x0000013F1364FDC8>, <luminance_analysis.FishData object at 0x0000013F1365B948>, <luminance_analysis.FishData object at 0x0000013F13665148>, <luminance_analysis.FishData object at 0x0000013F13676908>]
[<luminance_analysis.FishData object at 0x0000013F13653B88>, <luminance_analysis.FishData object at 0x0000013F13653B08>, <luminance_analysis.FishData object at 0x0000013F136612C8>, <luminance_analysis.FishData object at 0x0000013F13665C88>, <luminance_analysis.FishData object at 0x0000013F13676848>]


  c /= stddev[:, None]
  c /= stddev[None, :]


[<luminance_analysis.FishData object at 0x0000013F14341088>, <luminance_analysis.FishData object at 0x0000013F14341108>, <luminance_analysis.FishData object at 0x0000013F1434D6C8>, <luminance_analysis.FishData object at 0x0000013F1435CC88>, <luminance_analysis.FishData object at 0x0000013F14345288>]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<luminance_analysis.FishData object at 0x0000013F144FB748>, <luminance_analysis.FishData object at 0x0000013F144FB7C8>, <luminance_analysis.FishData object at 0x0000013F144FCE48>, <luminance_analysis.FishData object at 0x0000013F145046C8>, <luminance_analysis.FishData object at 0x0000013F14515F08>]
[<luminance_analysis.FishData object at 0x0000013F14501D88>, <luminance_analysis.FishData object at 0x0000013F14501D08>, <luminance_analysis.FishData object at 0x0000013F14504E88>, <luminance_analysis.FishData object at 0x0000013F14342808>, <luminance_analysis.FishData object at 0x0000013F1450F908>]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<luminance_analysis.FishData object at 0x0000013F13F24188>, <luminance_analysis.FishData object at 0x0000013F13F270C8>, <luminance_analysis.FishData object at 0x0000013F13F29248>, <luminance_analysis.FishData object at 0x0000013F13F2E3C8>, <luminance_analysis.FishData object at 0x0000013F13F32708>]
[<luminance_analysis.FishData object at 0x0000013F13F27EC8>, <luminance_analysis.FishData object at 0x0000013F13F27D48>, <luminance_analysis.FishData object at 0x0000013F13F2E148>, <luminance_analysis.FishData object at 0x0000013F1451C1C8>, <luminance_analysis.FishData object at 0x0000013F13F42B08>]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Modelling

### Description of the model


The goal of the model is to reconstruct the activity of PCs based on the activity observed for GCs and IONs.
The function that we will use to model a PC is the following:
$$ trace_{PC}^i = o^i + clusters_{GC} * w^i_{GC} + clusters_{IO} * w^i_{IO}$$


 - $trace_{PC}^i$ is the $i^{th}$ PC cell trace;
 - $o^i$ is an offset term:
 - $clusters_{GC}$, $clusters_{IO}$ are matrices with the average activation of all GC&IO clusters;
 - $w^i_{GC}$, $w^i_{IO}$ are weights vectors for each of the GC and IO cluster. We will allow positive and negative $w^i_{GC}$, but only positive $w^i_{IO}$. This is a quite safe assumption considering the known PC physiology. 


### Approach

Here is a summary of the modelling approach:
- **Create a panel of regressors**:
    - Calculate clusters of GC and IO responses; 
    - Deconvolve average response of each cluster with 6fe05 kernel;
    - Reconvolve it with 6s kernel;
    - Normalize it to be  > 0, and with integral = 1.


- **Clean up PC traces**:
    - For each cell, take raw fluorescence if valid trials, and zscore them on a trial-to-trial base (for changes in offset F across planes);
    - concatenate repetitions;
    - high-pass filter them with very low cutoff freq (1/80 Hz) to remove slow fluctuations;
    - smooth them with a 3 pts mean boxcar rolling window;
    
    
- **Split fit and test data**:
    - Randomly pick from each cell:
        - 2 repetitions that will be left out for the analysis (**test traces**);
        - 4 repetitions (or more, if there are more planes) that will be used for finding the regularization term and for the actual fitting (**fit traces**);


- **Define boundaries, cost function and regularization function**:
    - Parameters to be optimized are:
        - *offset*: a constant term, bound to be between -5 and 5;
        - *coefs_GC*: coefficients for GC regressors, bound to be between -1000 and 1000 (the large difference comes from the different normalizations applied on regressors -norm- and on trace - Z scoring)
        - *coefs_IO*: coefficients for GC regressors, bound to be between 0 and 1000, as we have good reasons to postulate that IO contributions are strictly positive
        - cost function: L2 distance to target trace;
        - regularization function: L1 (sum of absolute value of parameters)
        
        
- **Find regularization parameter**:
    - Find regularization parameter (leave one out cross-validation):
    - For each lambda parameter, train the model on all the fit traces but one (**train traces**);
    - Then, calculate the cost of the fit on the one trace that was left out (**validation trace**);
    - Do this iterating over all possible combinations of n-1 and 1 traces;
    - Calculate average cross over all combinations, over all cells;


- **Fit the trace**:
    - Use the resulting regularization term for fitting the fit repetitions, and use the obtained coefficients for plots / further analyses on the test repetitions

## Create a panel of regressors:

Create regressor panel, making traces non-0 and with integral equal to one:

In [7]:
n_gc_clust = data_dict["GC"]["clust_resps"].shape[0]
n_io_clust = data_dict["IO"]["clust_resps"].shape[0]
regressors_mat = np.concatenate([data_dict["GC"]["clust_resps"], data_dict["IO"]["clust_resps"]])

# Reconvolve and normalize regressors:
for i in range(regressors_mat.shape[0]):
    reconvolved = np.convolve(regressors_mat[i, :], get_kernel(ker_len=100, tau=tau_6s))[:regressors_mat.shape[1]]
    
    # Make strictly positive and with integral == 1:
    reconvolved -= np.min(reconvolved)  # offset at 0
    regressors_mat[i, :] = reconvolved / np.sum(reconvolved)

# Arbitrary cluster names (do we need them?):
gc_cluster_names = ['ON 1', 'ON 2', 'ON abs', 'ON inter.1', 'ON inter.2', 'OFF 1', 'OFF 2', 'OFF inter.']
io_cluster_names = ['Onset', 'ON max', 'Offset 1', 'Offset 2', 'On inter.', 'Offset 3']

Plot the regressors panel:

In [8]:
# Plot the regressor panel:
def reg_panel_plot(regressors_mat, figure=None, ax=None, frame=None):
    if figure is None:
        figure = plt.figure(figsize=(3, 3))

    offset = 0.01
    if ax is None:
        ax = add_offset_axes(figure, (0., 0., 1, 1), frameon=False, frame=frame)
    cols = sns.color_palette()
    for i, col in enumerate([cols[0], ]*n_gc_clust + [cols[1], ]*n_io_clust):
        ax.fill_between(stim[:, 0], np.zeros(stim[:, 0].shape) - i*offset, 
                        regressors_mat[i, :] - i*offset, color=col)
        
    ax.axis("off")

In [9]:
reg_panel_plot(regressors_mat)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Clean up PC traces:

In [10]:
from luminance_analysis.utilities import smooth_traces, butter_highpass_filter

In [11]:
def filter_cell_rep_block(cellmat, cutoff=1/80, smooth_wnd=3):
    """ Filter traces from the raw traces block.
    Return a repetitions block containing only the valid repetitions.
    """
    cutoff = 1 / 80
    dt = stim[1, 0]
    
    # Select entries with valid numbers in the repetition matrix:
    cellmat = cellmat[:, ~np.isnan(cellmat).all(0)].copy()
    
    # zscore repetition-wise, important for ROIs spanning more than one plane
    cellmat = (cellmat - np.nanmean(cellmat, 0)) / np.nanstd(cellmat, 0)
    
    # concatenate, highpass filter with very low cutoff, and smooth:
    trace = np.concatenate(cellmat.T, 0)
    filtered = butter_highpass_filter(trace, cutoff, 1 / dt)  # filt trace
    filtered = smooth_traces(filtered[np.newaxis, :], win=3, method="mean")[0, :]  # smooth
    filtered[np.isnan(filtered)] = 0
    
    # reshape in original form, and zscore again after filtering and smoothing:
    reshaped = filtered.reshape(cellmat.T.shape).T  
    reshaped = (reshaped - np.nanmean(reshaped, 0))/np.nanstd(reshaped, 0)
    
    return reshaped

In [12]:
# Cleanup parameters:
cutoff_hz = 1 / 80  # long cutoff for highpass filter - remove long fluctuations in PC signal
smooth_wnd = 3  # smoothing window to reduce noise. Our data is sampled ar around 0.25 seconds

raw_traces = data_dict["PC"]["traces"]  # raw PC fluorescences
rel_idxs = data_dict["PC"]["rel_idxs"]  # reliability indexes for each PC cell

n_rois = raw_traces.shape[0]  # number of ROIs
n_rep_timepts = raw_traces.shape[1]  # timepoints per repetition
n_reps_max = raw_traces.shape[2]  # maximum number of repetitions in a cell

# Find number of valid repetitions for each cell:
n_valid_reps = (~np.isnan(raw_traces).all(1)).sum(1)

# Clean up:
clean_traces = np.full(raw_traces.shape, np.nan)
for i_roi in range(n_rois):
    clean_block = filter_cell_rep_block(raw_traces[i_roi, :, :], cutoff=cutoff_hz, smooth_wnd=smooth_wnd)
    clean_traces[i_roi, :, :n_valid_reps[i_roi]] = clean_block

In [13]:
# Plot a trace panel:
def little_trace_plot(clean_traces, i=0, figure=None, ax=None, frame=None):
    if figure is None:
        figure = plt.figure(figsize=(3, 1))

    if ax is None:
        ax = add_offset_axes(figure, (0., 0., 1, 1), frameon=False, frame=frame)
    offset = 0.01
    plt.plot(clean_traces[i, :, :1].T.flatten())
    ax.axis("off")

In [14]:
little_trace_plot(clean_traces, i=18)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Separate testing and training traces:

In [15]:
from random import shuffle, seed

In [16]:
# Fix randomness for reproducibility:
np.random.seed(572704)
seed(572704)

# Generate list of 2 indexes for each cell which will be used to keep traces out for the testing part.
# Generate randomly indexes for test and train set of traces:
n_test_reps = 3

test_idxs = []
fit_idxs = []
for n_resps in n_valid_reps:
    idxs = np.arange(n_resps)  # possible repetitions indexes
    shuffle(idxs)  # shuffle index list
    test_idxs.append(idxs[:n_test_reps])  # test idxs will be the first 2
    fit_idxs.append(idxs[n_test_reps:])  # the rest goes for training

## Define boundaries, function & regularisation/cost functions

In [17]:
from scipy import optimize

In [18]:
## Functions for the regression.

# This is the the main function that we actually use to describe PC activity:
def offset_cluster_combine(coefs, regressors):
    """ Compute trace from offset/coefficients and regressors.
    It assumes coefs and regressors for GC and IO are all concatenated,
    and first element of coefs array is the offset.
    """
    return coefs[0] + np.sum(coefs[1:] * regressors, 1)  # first term is baseline

def cost_func(fit_coefs, regressors, trace2fit, model):
    """ Cost function: sum of squares.
    """
    diff = trace2fit - model(fit_coefs, regressors)
    return np.sum(diff**2) / trace2fit.shape[0]

def reg_func(fit_coefs, reg_coef=0):
    """ Regularization function: sum of absolute coefs values.
    Does not regularize the offset, so first term is excluded:
    """
    return np.sum(np.abs(fit_coefs[1:]))

def minimization_func(fit_coefs, regressors, trace2fit, model, cost_func, reg_func, reg_lamda=0):
    """Full function to minimize, including cost and regularization terms.
    """
    return cost_func(fit_coefs, regressors, trace2fit, model) + reg_func(fit_coefs) * reg_lamda

In [19]:
# set starting values and bounds for the offset (the +1 below) and the coefficients.

# Initial guesses:
# coefs_init_guess = np.zeros(regressors_mat.shape[0] + 1) 

# We know that IO input can only positively contribute to PC activity, 
# so we set IO coefficients to be positive:
# w_bound = 1000  # bound for regressors weights (high b/c of normalization differences) 
# off_bound = 5  # bound for the offset
# coefs_bounds =[(-off_bound, off_bound)] + \
#               [(0, w_bound) for _ in range(n_gc_clust)] + \
#               [(0, w_bound) for _ in range(n_io_clust)]

## Validate regularization parameter

Here we test for the optimal regularization lambda. As this parameter search can take quite long, you can just skip the section and execute from the next block.

In [20]:
from sklearn.model_selection import LeaveOneOut

In [21]:
cell_types = ['GC', 'IO']

regressors_mat_dict = {'GC': regressors_mat[:n_cluster_list[0], :],
                       'IO': regressors_mat[-n_cluster_list[1]:, :]}

In [22]:
# as the cost of a fit with all 0 coefficients is 1 - the std of the trace (which is zscored),
# this will be our estimation of the cost that we use to decide the regularization lambda range 
# (2 orders of mag below and above the expected cost):
reg_lambda_arr = np.insert(10**np.arange(-7., -2, 0.35), 0, 0)

# Initialise empty matrices for storing the costs and the lambda parameters for all left-one-out fits:
n_lambdas = reg_lambda_arr.shape[0]

In [23]:
costs = {cell_type:np.full((n_rois, n_lambdas, n_reps_max - n_test_reps), np.nan) for cell_type in cell_types}

for cell_type in cell_types:
    
    coefs_init_guess = np.zeros(regressors_mat_dict[cell_type].shape[0] + 1) 
    
    # Prepare a concatenation of regressors long enough to fit the longest possible trace:
    regressors_concat = np.concatenate([regressors_mat_dict[cell_type],]*(n_reps_max - n_test_reps), 1).T
    
    # Use scikit learn leave-one-out iterator:
    loo = LeaveOneOut()

    n_downsample = 1  # skip cells if we are testing. Otherwise, set to 1

    for i_roi in range(0, n_rois, n_downsample):
        if np.mod(i_roi, 50) == 0:
            print(i_roi)
        roi_fit_idxs = fit_idxs[i_roi]

        # Hyperparameter grid search:
        for i_lambda, reg_lambda in enumerate(reg_lambda_arr):

            # Leave-one-out validation:
            for i_loo, (idxs_train, idxs_valid) in enumerate(loo.split(roi_fit_idxs)):
                roi_train_trace = clean_traces[i_roi, :, roi_fit_idxs[idxs_train]].flatten()
                roi_valid_trace = clean_traces[i_roi, :, roi_fit_idxs[idxs_valid[0]]]
                res = optimize.minimize(minimization_func, method='SLSQP', 
                                        args=(regressors_concat[:roi_train_trace.shape[0], :], 
                                              roi_train_trace, offset_cluster_combine, 
                                              cost_func, reg_func, reg_lambda),
                                        x0=coefs_init_guess, bounds=None)

                costs[cell_type][i_roi, i_lambda, i_loo] = cost_func(res.x, regressors_concat[:roi_valid_trace.shape[0], :], 
                                                          roi_valid_trace, offset_cluster_combine)

0
50
100
150
200
250
300
350
400
450
500
550
600
650
0
50
100
150
200
250
300
350
400
450
500
550
600
650


In [24]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=True)

for ax, cell_type, col in zip(axes, cell_types, cols[:2]):
    cost_arr_mean = np.nanmean(costs[cell_type], 2)
    
    ax.plot(range(n_lambdas), np.nanmean(cost_arr_mean, 0), c=col)
    ax.fill_between(range(n_lambdas), 
                    np.nanmean(cost_arr_mean, 0)-np.nanstd(cost_arr_mean, 0),
                    np.nanmean(cost_arr_mean, 0)+np.nanstd(cost_arr_mean, 0), color=col, alpha=.2, edgecolor="None")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [25]:
fig.savefig(fig_fold/'reg_opti_curves.png')

In [26]:
def max_within_std_err(cell_costs):
    mean_cost = np.nanmean(cell_costs, 1)
    std_err_cost = np.nanstd(cell_costs, 1) / np.sqrt(np.sum(~np.isnan(cell_costs[0, :])) - 1)

    min_idx = np.argmin(mean_cost)
    std_err_min = std_err_cost[min_idx]

    i = min_idx
    while i < len(mean_cost) - 1 and mean_cost[i + 1] < mean_cost[min_idx] + std_err_min:
        i += 1
    
    if i != len(mean_cost) and mean_cost[i] < 1:
        return i
    else:
        return np.nan

### Fitting

In [27]:
costs_final = {cell_type: np.full(n_rois, np.nan) for cell_type in cell_types}  # Costs of our final fit
coefs_final = {cell_type: np.full((n_rois, regressors_mat_dict[cell_type].shape[0] + 1), np.nan) for cell_type in cell_types}  # Coefs from our final fit
valid_roi_idxs = {cell_type:[] for cell_type in cell_types}


for cell_type in cell_types:
    
    #Set initial guesses
    coefs_init_guess = np.zeros(regressors_mat_dict[cell_type].shape[0] + 1) 
    
    #Prepare a concatenation of regressors long enough to fit the longest possible trace:
    regressors_concat = np.concatenate([regressors_mat_dict[cell_type],]*(n_reps_max - n_test_reps), 1).T

    
    for i_roi in range(n_rois):
        if np.mod(i_roi, 100) == 0:
            print(i_roi)
        cost_idx = max_within_std_err(costs[cell_type][i_roi, :, :])

        if not np.isnan(cost_idx):
            final_lambda_reg = reg_lambda_arr[cost_idx]

            # Get test and fit indexes and traces:
            roi_test_idxs = test_idxs[i_roi]
            roi_test_trace = clean_traces[i_roi, :, roi_test_idxs].flatten()

            roi_fit_idxs = fit_idxs[i_roi]
            roi_fit_trace = clean_traces[i_roi, :, roi_fit_idxs].flatten()  

            # Fit the train set:
            res = optimize.minimize(minimization_func, method='SLSQP', 
                                    args=(regressors_concat[:roi_fit_trace.shape[0], :], 
                                          roi_fit_trace, offset_cluster_combine, 
                                          cost_func, reg_func, final_lambda_reg),
                                    x0=coefs_init_guess, bounds=None)

            # Calculate cost on the test set:
            costs_final[cell_type][i_roi] = cost_func(res.x, regressors_concat[:roi_test_trace.shape[0], :], 
                                                      roi_test_trace, offset_cluster_combine)
            # Save the coefficients:
            coefs_final[cell_type][i_roi, :] = res.x

            valid_roi_idxs[cell_type].append(i_roi)
            
valid_roi_idxs_dict = {cell_type: np.array(valid_roi_idxs[cell_type]) for cell_type in cell_types}
n_valid_rois = {cell_type: len(valid_roi_idxs[cell_type]) for cell_type in cell_types}

0
100
200
300
400
500
600
0
100
200
300
400
500
600


In [28]:
#Keep data only from properly fit ROIs
coefs_final_sel = {} 
costs_final_sel = {} 
clean_traces_sel = {} 
test_idxs_sel = {}
fit_idxs_sel = {}
rel_idx_sel = {}
clust_lab_sel = {}

for cell_type in cell_types:
    coefs_final_sel[cell_type] = coefs_final[cell_type][valid_roi_idxs[cell_type], :]
    costs_final_sel[cell_type] = costs_final[cell_type][valid_roi_idxs[cell_type]]
    clean_traces_sel[cell_type] = clean_traces[valid_roi_idxs[cell_type], :, :]
    test_idxs_sel[cell_type] = [test_idxs[i] for i in valid_roi_idxs[cell_type]]
    fit_idxs_sel[cell_type] = [fit_idxs[i] for i in valid_roi_idxs[cell_type]]
    rel_idx_sel[cell_type] = data_dict["PC"]["rel_idxs"][valid_roi_idxs[cell_type]]
    clust_lab_sel[cell_type] = data_dict["PC"]["clust_labels"][valid_roi_idxs[cell_type]]

### Fitting with shuffled coefficients to set a random fit baseline

In [29]:
coefs_shuf = {cell_type: np.full_like(coefs_final_sel[cell_type], np.nan) for cell_type in cell_types}
costs_shuf = {cell_type: np.full(n_valid_rois[cell_type], np.nan) for cell_type in cell_types}  # Costs from shuffled weights

for cell_type in cell_types:
    
    #Prepare a concatenation of regressors long enough to fit the longest possible trace:
    regressors_concat = np.concatenate([regressors_mat_dict[cell_type],]*(n_reps_max - n_test_reps), 1).T
    
    for i in range(coefs_final_sel[cell_type].shape[1]):
        shuf_idx = np.random.permutation(coefs_final_sel[cell_type].shape[0])
        coefs_shuf[cell_type][:, i] = coefs_final_sel[cell_type][shuf_idx, i]
    
    for i_roi in range(n_valid_rois[cell_type]):
    
        # Get test and fit indexes and traces:
        roi_test_idxs = test_idxs_sel[cell_type][i_roi]
        roi_test_trace = clean_traces_sel[cell_type][i_roi, :,roi_test_idxs].flatten()
        
        # Calculate cost on the test set with shuffled weights:
        costs_shuf[cell_type][i_roi] = cost_func(coefs_shuf[cell_type][i_roi, :], regressors_concat[:roi_test_trace.shape[0], :], 
                                                 roi_test_trace, offset_cluster_combine)

cost_threshold = {cell_type: np.percentile(costs_shuf[cell_type], 5) for cell_type in cell_types}
sel_fit = {cell_type: np.argwhere(costs_final_sel[cell_type] < cost_threshold[cell_type])[:, 0] for cell_type in cell_types}
n_valid_rois = {cell_type: len(sel_fit[cell_type]) for cell_type in cell_types}

In [177]:
#Retrieve original indexes of ROIs with a good fitting
inclusion_mask = {cell_type: costs_final_sel[cell_type] < cost_threshold[cell_type] for cell_type in cell_types}
final_sel_idxs = {cell_type:{} for cell_type in cell_types}

for cell_type in cell_types:
    final_sel_idxs[cell_type] = np.array(valid_roi_idxs[cell_type])[inclusion_mask[cell_type]]
    
#And find ROIs with a good fitting with both GC and IO regressors
total_fit_idxs = np.intersect1d(final_sel_idxs['GC'], final_sel_idxs['IO'])
total_fit_idxs.shape

(236,)

## Plots

In [460]:
def cost_figure(costs_final, costs_shuf, cost_threshold, ax_coefs=None, frame=None):
    
    bin_array = np.arange(0.01, 1.6, 0.05)
    costs_shuf_hist, b = np.histogram(costs_shuf, bin_array)
    costs_final_hist, b = np.histogram(costs_final, bin_array)
    costs_final_sel_hist, b = np.histogram(costs_final[costs_final < cost_threshold], bin_array)
    
    sel_fract = np.nansum(costs_final < cost_threshold) / np.sum(~np.isnan(costs_final))
    
    if ax_coefs is None:
        figure = plt.figure(figsize=(3.,2))
        ax_coefs = add_offset_axes(figure, (0.2, 0.2, 0.7, 0.7), frame=frame)
            
    a = 0.7
    x = (bin_array[1:] + bin_array[:-1]) / 2
    ax_coefs.fill_between(x, costs_shuf_hist, step="mid", alpha=a, edgecolor=None, facecolor=sns.color_palette()[4])
    ax_coefs.fill_between(x, costs_final_sel_hist, step="mid", alpha=a, edgecolor=None, facecolor=sns.color_palette()[3])
    ax_coefs.step(x, costs_final_hist, where="mid", alpha=a, color=sns.color_palette()[3])
    
    ax_coefs.axvline(cost_threshold, c=(0.4,)*3)
    ax_coefs.text(0.05, .95, '{:.2f}% PCs'.format(sel_fract*100), transform=ax_coefs.transAxes, c=sns.color_palette()[3], ha='left', fontsize=8.5)
    ax_coefs.set_xlabel("Cost (on test)")
    ax_coefs.set_ylabel("Count")

In [461]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=True)

for ax, cell_type in zip(axes, cell_types):
    cost_figure(costs_final[cell_type], costs_shuf[cell_type], cost_threshold[cell_type], ax_coefs=ax)
    ax.set_title('Fitting cost with {} regressors'.format(cell_type), c=(0.4,)*3)
    
axes[1].text(1, .5, 'Regressor fits', transform=axes[1].transAxes, c=sns.color_palette()[3], ha='left', va='top', fontsize=8.5)
axes[1].text(1, .45, 'Shuffled fits', transform=axes[1].transAxes, c=sns.color_palette()[4], ha='left', va='top', fontsize=8.5)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(1, 0.45, 'Shuffled fits')

In [462]:
fig.savefig(fig_fold/'fitting_costs.png')

### ROI fit explorer

In [463]:
def plot_roi_fits(roi_idx):
    gc_fit_rank = ((np.argwhere(np.argsort(costs_final['GC']) == roi_idx)[0][0])/n_rois)*100
    io_fit_rank = ((np.argwhere(np.argsort(costs_final['IO']) == roi_idx)[0][0])/n_rois)*100

    fig, axes = plt.subplots(2, 2, figsize=(15, 5))

    #Find train-test trials
    roi_test_idxs = test_idxs[roi_idx]
    test_traces = clean_traces[roi_idx, :, roi_test_idxs]

    #Recover traces
    roi_fit_idxs = fit_idxs[roi_idx]
    fit_traces = clean_traces[roi_idx, :, roi_fit_idxs]

    #Recover fit coefficients
    gc_coefs = coefs_final['GC'][roi_idx, :]
    io_coefs = coefs_final['IO'][roi_idx, :]

    #Reconstruct fits
    gc_fit = offset_cluster_combine(gc_coefs, regressors_mat_dict['GC'].T)
    io_fit = offset_cluster_combine(io_coefs, regressors_mat_dict['IO'].T)

    for row, traces in zip(range(2), [fit_traces, test_traces]):
        for col, fit in zip(range(2), [gc_fit, io_fit]):
            for rep in range(traces.shape[0]):
                axes[row, col].plot(traces[rep, :], 'gray', alpha=.2)
            axes[row, col].plot(np.nanmean(traces, 0), c=cols[2])
            axes[row, col].plot(fit, c=cols[col])


        for row, label in zip(range(2), ['Training set', 'Test set']):
            axes[row, 0].set_ylabel(label, c='gray')

        for col, title, rank in zip(range(2), ['GC regressors', 'IO regressors'], [gc_fit_rank, io_fit_rank]):
            axes[0, col].set_title('{} (best {:.2f}%)'.format(title, rank), c='gray')

    plt.suptitle('ROI {} fits'.format(roi_idx), c='gray')

In [464]:
interact(plot_roi_fits,
         roi_idx=widgets.IntSlider(min=0, max=n_rois, step=1, continuous_update=False))

interactive(children=(IntSlider(value=0, continuous_update=False, description='roi_idx', max=672), Output()), …

<function __main__.plot_roi_fits(roi_idx)>

In [497]:
plot_roi_fits(np.argsort(costs_final['GC'])[2])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [34]:
fig.savefig(fig_fold/'best_GC_fit.png')

# Analyse coefficients

In [483]:
def coefs_plot(clust_lab, coefs_final, cell_type, col, ax_coefs=None, frame=None):
    
    if ax_coefs is None:
        figure = plt.figure(figsize=(4, 3))
        ax_coefs = add_offset_axes(figure, (0.05, 0.2, 0.75, 0.8), frame=frame)
    
    idxs_sort = np.argsort(clust_lab)

    c_lim = 125
    im = ax_coefs.imshow(coefs_final[idxs_sort, 1:].T, vmin=-c_lim, vmax=c_lim, aspect="auto", cmap="RdBu")
    ax_coefs.set_xlabel("Roi n.")

    ax_coefs.set_yticks([])
    ax_coefs.set_ylabel('{}'.format(cell_type), rotation=90, c=col)
    [ax_coefs.axes.spines[s].set_visible(False) for s in
         ["left", "right", "top", "bottom"]]
    for ytick, color in zip(ax_coefs.get_yticklabels(), sns.color_palette()[:2]):
        ytick.set_color(color)

    k = np.sum(clust_lab == 0)
    for n in range(1, clust_lab.max() + 1): 
        ax_coefs.axvline(k, c='black')
        k += np.sum(clust_lab == n)
        
#     axcolor = add_offset_axes(figure, [0.86, 0.2, 0.02, 0.12], frame=frame)
#     cbar = plt.colorbar(im, orientation="vertical")
#     cbar.set_ticks([-c_lim, c_lim])
#     cbar.ax.tick_params(length=3)

    plt.tight_layout()

    return im

In [484]:
figure, axes = plt.subplots(1, 2, figsize=(7, 3))

for ax, cell_type, c in zip(axes, cell_types, range(2)):
    im = coefs_plot(clust_lab_sel[cell_type], coefs_final_sel[cell_type], cell_type, col=cols[c], ax_coefs=ax)
    
cbar = figure.colorbar(im, ax=ax, orientation='vertical', shrink=.25)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [485]:
figure.savefig(fig_fold/'fitting_coefs.png')

### Check error histogram and relationship with cell reliability

Check error of the fit as a function of the cell reliability index. We expect a negative relationship:

In [468]:
fig, axes = plt.subplots(1, 2, figsize=(7, 3), sharey=True)

for ax, cell_type, c in zip(axes, cell_types, range(2)):
    ax.scatter(rel_idxs, costs_final[cell_type], c='none', edgecolors=cols[c], linewidths=.25, s=7)
    ax.scatter(rel_idxs[final_sel_idxs[cell_type]], costs_final[cell_type][final_sel_idxs[cell_type]], c=cols[c], s=7)
    
#     error_vs_reliabil(rel_idx_sel[cell_type], costs_final_sel[cell_type], ax, c=cols[c])
    
for ax, cell_type in zip(axes, cell_types):
    ax.set_title('Fits with {} regressors'.format(cell_type), c='gray', fontsize=10)
    ax.set_xlabel("Reliability idx", fontsize=8.5)
    ax.set_ylabel("Fit error", fontsize=8.5)
    
plt.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

*c* argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with *x* & *y*.  Please use the *color* keyword-argument or provide a 2D array with a single row if you intend to specify the same RGB or RGBA value for all points.
*c* argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with *x* & *y*.  Please use the *color* keyword-argument or provide a 2D array with a single row if you intend to specify the same RGB or RGBA value for all points.


In [469]:
# fig, axes = plt.subplots(1, 2, figsize=(7, 3), sharey=True)

# for ax, cell_type, c in zip(axes, cell_types, range(2)):
#     ax.scatter(rel_idx_sel[cell_type], costs_final_sel[cell_type], c=cols[c], s=5)
    
# #     error_vs_reliabil(rel_idx_sel[cell_type], costs_final_sel[cell_type], ax, c=cols[c])
    
# for ax, cell_type in zip(axes, cell_types):
#     ax.set_title('Fits with {} regressors'.format(cell_type), c='gray', fontsize=10)
#     ax.set_xlabel("Reliability idx", fontsize=8.5)
#     ax.set_ylabel("Fit error", fontsize=8.5)
    
# plt.tight_layout()

In [470]:
fig.savefig(fig_fold/'fitting_coefs_rel.png')

### GC-IO index

In [471]:
from scipy import stats

In [472]:
gc_io_idx = (1/(costs_final['GC']) - 1/(cost
                                        nal['IO']))

In [473]:
#Create a IO-GC colormap
from matplotlib.colors import LinearSegmentedColormap
colors = [sns.color_palette()[1], (.72,)*3, sns.color_palette()[0]]
n_bins = 200
cmap_name = 'contributions_cmap'
custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=n_bins)

In [474]:
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

axes[0].scatter(costs_final['GC'], costs_final['IO'], c='none', edgecolor=cols[2], linewidth=.25, s=7)
axes[0].scatter(costs_final['GC'][total_fit_idxs], costs_final['IO'][total_fit_idxs], c=cols[2], s=7)
axes[0].plot([0, 1], [0, 1], transform=axes[0].transAxes, c='gray', ls='--')

axes[0].set_ylabel('IO fitting costs', fontsize=8)
axes[0].set_xlabel('GC fitting costs', fontsize=8)


r = stats.spearmanr(gc_io_idx[total_fit_idxs], rel_idxs[total_fit_idxs])
axes[1].axvline(0, c = (0.6, )*3, zorder=-100)

edge_cols = custom_cmap(gc_io_idx)
axes[1].scatter(gc_io_idx, rel_idxs, facecolor='none', edgecolor='white', s=7, linewidths=.075)
axes[1].scatter(gc_io_idx[total_fit_idxs], rel_idxs[total_fit_idxs], c=gc_io_idx[total_fit_idxs], cmap=custom_cmap, s=6, vmin=-0.1, vmax=0.1, 
           edgecolor=(0.3,)*3, linewidths=0.1)

#     ax.set_xlim(-1.1, 1.1)
axes[1].set_ylabel("Reliability index")
axes[1].set_xlabel("GC-IO idx")
#     numbers = 
axes[1].text(.05, .95, "$\\rho$" + "={:1.2f} \np={:1.2}".format(r.correlation, r.pvalue), fontsize=7, color=(0.3,)*3, transform=axes[1].transAxes)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

*c* argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with *x* & *y*.  Please use the *color* keyword-argument or provide a 2D array with a single row if you intend to specify the same RGB or RGBA value for all points.


Text(0.05, 0.95, '$\\rho$=0.29 \np=5.5e-06')

In [475]:
fig.savefig(fig_fold/'gc_io_idx.png')

In [476]:
fit_diffs = {cell_type: np.full((n_rois, clean_traces.shape[1]), np.nan) for cell_type in cell_types}

for roi_idx in range(n_rois):
    
    #Find train-test trials
    roi_test_idxs = test_idxs[roi_idx]
    test_traces = clean_traces[roi_idx, :, roi_test_idxs]

    #Recover traces
    roi_fit_idxs = fit_idxs[roi_idx]
    fit_traces = clean_traces[roi_idx, :, roi_fit_idxs]

    #Recover fit coefficients
    gc_coefs = coefs_final['GC'][roi_idx, :]
    io_coefs = coefs_final['IO'][roi_idx, :]

    #Reconstruct fits
    gc_fit = offset_cluster_combine(gc_coefs, regressors_mat_dict['GC'].T)
    io_fit = offset_cluster_combine(io_coefs, regressors_mat_dict['IO'].T)

    #Calculate error between fit and average test trace   
    fit_diffs['GC'][roi_idx, :] = (np.nanmean(test_traces, 0) - gc_fit)**2
    fit_diffs['IO'][roi_idx, :] = (np.nanmean(test_traces, 0) - io_fit)**2

In [477]:
time_contributions_mat = np.full((n_rois, clean_traces.shape[1]), np.nan)

for roi in range(n_rois):
    time_contributions_mat[roi, :] = fit_diffs['GC'][roi, :] - fit_diffs['IO'][roi, :]

In [478]:
time_contributions_mat.shape

(672, 426)

In [479]:
# a = np.argsort(costs_final['IO'])[:20]

In [480]:
figure = plt.figure(figsize=(14,4))

ax = add_offset_axes(figure, (0.1, 0.1, 0.9, 0.9))

# Find quartiles:
median_contr = np.nanmedian(time_contributions_mat[total_fit_idxs], 0)
low_quart_contr = np.nanquantile(time_contributions_mat[total_fit_idxs], 0.25, axis=0)
high_quart_contr = np.nanquantile(time_contributions_mat[total_fit_idxs], 0.75, axis=0)

# Plot trace and filling:
i_col=5
ax.plot(stim[:, 0], median_contr, color=sns.color_palette()[i_col])
ax.fill_between(stim[:,0], low_quart_contr, high_quart_contr, facecolor=sns.color_palette()[i_col], 
                     alpha=.3, zorder=100, edgecolor=None)
shade_plot(stim, ax, shade_range=(0.75, 0.98))
ax.axhline(0, c = (0.3,)*3, zorder=1)
ax.set_xlim(0, stim[-1, 0])
#     for y_line in [-1, 1]:
#         ax.axhline(y_line, color=(0.3,)*3, linewidth=0.5)

make_bar(ax, (stim[-1, 0]-10, stim[-1, 0]), label="10 s")
ax.set_ylabel('Fitting error')
ax.text(0.025, .97, 'Fits w/ GC', fontsize=7, color=sns.color_palette()[0], transform=ax.transAxes, va='center')
ax.text(0.025, .03, 'Fits w/ IO', fontsize=7, color=sns.color_palette()[1], transform=ax.transAxes, va='center')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.025, 0.03, 'Fits w/ IO')

In [482]:
figure.savefig(fig_fold/'time_contr.png')

In [None]:
plt.figure(figsize=(7, 1.5))
    
# Find quartiles:
median_contr = np.nanmedian(time_contributions_mat, 0)
low_quart_contr = np.nanquantile(time_contributions_mat, 0.25, axis=0)
high_quart_contr = np.nanquantile(time_contributions_mat, 0.75, axis=0)

plt.plot(stim[:, 0], median_contr, color=sns.color_palette()[i_col])
plt.fill_between(stim[:,0], low_quart_contr, high_quart_contr, facecolor=sns.color_palette()[i_col])


In [None]:
time_contributions_mat

In [None]:
# Plot trace and filling:
i_col=5
ax.plot(stim[:, 0], median_contr, color=sns.color_palette()[i_col])
ax.fill_between(stim[:,0], low_quart_contr, high_quart_contr, facecolor=sns.color_palette()[i_col], 
                     alpha=.3, zorder=100, edgecolor=None)
shade_plot(stim, ax, shade_range=(0.75, 0.98))
ax.axhline(0, c = (0.3,)*3, zorder=1)
ax.set_xlim(0, stim[-1, 0])
#     for y_line in [-1, 1]:
#         ax.axhline(y_line, color=(0.3,)*3, linewidth=0.5)

make_bar(ax, (stim[-1, 0]-10, stim[-1, 0]), label="10 s")
ax.set_ylabel('Contribution ratio')
ax.text(0.5, 1.05, 'GC contributions dominates', fontsize=7, color=sns.color_palette()[0])
ax.text(0.5, -1.15, 'IO contributions dominates', fontsize=7, color=sns.color_palette()[1])

In [None]:
plt.figure()
plt.fill_between(stim[:,0], low_quart_contr, high_quart_contr)


In [None]:
high_quart_contr

In [None]:
roi_idx = np.argsort(costs_final['GC'])[2]

fig, axes = plt.subplots(2, 2, figsize=(15, 5))

#Find train-test trials
roi_test_idxs = test_idxs[roi_idx]
test_traces = clean_traces[roi_idx, :, roi_test_idxs]

#Recover traces
roi_fit_idxs = fit_idxs[roi_idx]
fit_traces = clean_traces[roi_idx, :, roi_fit_idxs]

#Recover fit coefficients
gc_coefs = coefs_final['GC'][roi_idx, :]
io_coefs = coefs_final['IO'][roi_idx, :]

#Reconstruct fits
gc_fit = offset_cluster_combine(gc_coefs, regressors_mat_dict['GC'].T)
io_fit = offset_cluster_combine(io_coefs, regressors_mat_dict['IO'].T)



for row, traces in zip(range(2), [fit_traces, test_traces]):
    for col, fit in zip(range(2), [gc_fit, io_fit]):
        for rep in range(traces.shape[0]):
            axes[row, col].plot(traces[rep, :], 'gray', alpha=.2)
        axes[row, col].plot(np.nanmean(traces, 0), c=cols[2])
        axes[row, col].plot(fit, c=cols[col])
        

    for row, label in zip(range(2), ['Training set', 'Test set']):
        axes[row, 0].set_ylabel(label, c='gray')
        
    for col, title in zip(range(2), ['GC regressors', 'IO regressors']):
        axes[0, col].set_title(title, c='gray')


### Look at distribution of IO and GC coefficients weights

In [None]:
cleaned_coefs = coefs_final_sel.copy()
cleaned_coefs = cleaned_coefs[~(cleaned_coefs == 0).all(1), :]
coefs_sum = np.nansum(np.abs(cleaned_coefs), 1)
norm_coefs = (cleaned_coefs.T / coefs_sum).T
non_zero_coefs = np.abs(norm_coefs) > 0.1

In [None]:
def gc_io_weights_hist(cleaned_coefs, figure=None, frame=None):
    if figure is None:
        figure = plt.figure(figsize=(3,2))
        
    ax = add_offset_axes(figure, (0.2, 0.2, 0.7, 0.7), frame=frame)
    l=400
    ax.hist(cleaned_coefs[:, 1:9].sum(1), np.arange(-l, l, 20), alpha=0.6, label="GC")
    ax.hist(cleaned_coefs[:, 9:].sum(1), np.arange(-l, l, 20), alpha=0.6, label="IO")
    ax.legend()
    ax.set_xlabel("Weight of coefs")
    ax.set_ylabel("Count")

In [None]:
gc_io_weights_hist(cleaned_coefs)

### Look at distribution of number of coefficients

In [None]:
def gc_io_nonzero_hist(non_zero_coefs, figure=None, frame=None):
    if figure is None:
        figure = plt.figure(figsize=(3,2))
        
    ax = add_offset_axes(figure, (0.2, 0.2, 0.7, 0.7), frame=frame)
    ax.hist(non_zero_coefs[:, 1:9].sum(1), np.arange(0, 15), 
            label="GC (mn={:2.1f})".format(np.nanmean(non_zero_coefs[:, 1:9].sum(1))), alpha=0.6)
    ax.hist(non_zero_coefs[:, 9:].sum(1), np.arange(0, 15), 
            label="IO (mn={:2.1f})".format(np.nanmean(non_zero_coefs[:, 9:].sum(1))), alpha=0.6)
    ax.legend()
    ax.set_xlabel("N. of coefs")
    ax.set_ylabel("Count")

In [None]:
gc_io_nonzero_hist(non_zero_coefs)

### GC-IO index

In [None]:
from scipy import stats

Create "GC vs IO-ness index", looking at the ratio of coefficients of GC clusters and IO clusters:

In [None]:
w_gc = np.abs(coefs_final_sel[:, 1:n_gc_clust+1]).mean(1)
w_io = np.abs(coefs_final_sel[:, n_gc_clust+1:]).mean(1)
gc_io_idx = (w_gc - w_io) / (w_gc + w_io)

In [None]:
#Create a IO-GC colormap
from matplotlib.colors import LinearSegmentedColormap
colors = [sns.color_palette()[1], (.72,)*3, sns.color_palette()[0]]
n_bins = 200
cmap_name = 'contributions_cmap'
custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=n_bins)

In [None]:
def indexes_plot(gc_io_idx, rel_idxs, figure=None, frame=None):
    r = stats.spearmanr(gc_io_idx[~np.isnan(gc_io_idx)], rel_idxs[~np.isnan(gc_io_idx)])
    
    
    if figure is None:
        figure = plt.figure(figsize=(3., 2.))
    
    ax = add_offset_axes(figure, (0.2, 0.2, 0.7, 0.7), frame=frame)
    ax.axvline(0, c = (0.6, )*3, zorder=-100)
    ax.scatter(gc_io_idx, rel_idxs, c=gc_io_idx, cmap=custom_cmap, s=6, vmin=-1, vmax=1, 
               edgecolor=(0.3,)*3, linewidths=0.1)
    ax.set_xlim(-1.1, 1.1)
    ax.set_ylabel("Reliability index")
    ax.set_xlabel("GC/IO idx")
#     numbers = 
    ax.text(-1, 0.7, "$\\rho$" + "={:1.2f} \np={:1.2}".format(r.correlation, r.pvalue), fontsize=7, color=(0.3,)*3)


In [None]:
indexes_plot(gc_io_idx, rel_idx_sel)

### Single cell example fit

In [None]:
11, 17, 32, 33, (44), 71

37, 60, 68

In [None]:
def make_single_cell_plot(traces, all_coefs, roi_test_idxs, idx, 
                          legend=True, bar=True, figure=None, frame=None):
    if figure is None:
        figure = plt.figure(figsize=(7, 2))
    coefs = all_coefs[idx, :]
    roi_test_idxs = test_idxs[idx]
    reshaped = traces[idx, :, roi_test_idxs]
    
    ylim = 3.7
    axtrace = add_offset_axes(figure, (0.0, 0.1, 0.58, 0.92), frame=frame)
    axtrace.plot(stim[:, 0], reshaped.T, color=sns.color_palette()[2], linewidth=0.5)
    axtrace.plot(stim[:, 0], np.nanmean(reshaped, 0), color=sns.color_palette()[2], linewidth=2, label="PC trace")
    axtrace.plot(stim[:, 0], nanzscore(offset_cluster_combine(coefs, regressors_mat.T)), color="k", label="Fit")
    shade_plot(stim, axtrace, shade_range=(0.75, 0.98))
    
    axtrace.set_ylim(-ylim, ylim)
    axtrace.set_xlim(0, stim[-1, 0])
    axtrace.spines["left"].set_visible(False)
    axtrace.set_yticks([])
    if bar:
        make_bar(axtrace, (stim[-1, 0]-10, stim[-1, 0]), label="10 s")
    else:
        axtrace.spines["bottom"].set_visible(False)
        axtrace.set_xticks([])
    
    # Legend
    handles, labels = axtrace.get_legend_handles_labels()
    unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]]
    if legend:
        plt.legend(*zip(*unique), loc="lower left", fontsize=7)
    
    # Inserts
    ylim = 1.5
    y_off = -0.5
    coefs_gc = np.zeros(coefs.shape)
    coefs_io = np.zeros(coefs.shape)
    coefs_gc[1:9] = coefs[1:9]
    coefs_io[9:] = coefs[9:]
    for n, (ax_pos, lab, c_coefs) in enumerate(zip([(0.6, 0.6, 0.4, 0.4), (0.6, 0.1, 0.4, 0.4)], 
                                                   ["GC", "IO"],
                                                   [coefs_gc, coefs_io])):
        col = sns.color_palette()[n]
        ax_little = add_offset_axes(figure, ax_pos, frame=frame)
        ax_little.plot(stim[:, 0], offset_cluster_combine(c_coefs, regressors_mat.T), color=col)
        shade_plot(stim, ax_little, shade_range=(0.75, 0.98))
        ax_little.set_ylim(-ylim-y_off, ylim-y_off)
        ax_little.axis("off")
        
        if legend:
            ax_little.text(2, -ylim-y_off + 0.2, lab + " contribution", color=col, fontsize=7)

cells  # (gc_io idx sort): 
- 97, 254 example IO
- 197, example comb
- 89, 115, 60 example comb

In [None]:
97, 197, 89

In [None]:
sel = costs_final_sel < 0.9
idx = np.argsort(gc_io_idx[sel])[-11]
make_single_cell_plot(clean_traces_sel, coefs_final_sel, fit_idxs_sel, idx, figure=None, frame=None)
# for idx in [98]:
#     make_single_cell_plot(clean_traces_sel, coefs_final_sel, fit_idxs_sel, idx, figure=None, frame=None)

### Time contribution plot

In [None]:
time_contributions_mat = np.zeros((n_rep_timepts, n_valid_rois))
coefs_gc = np.zeros(coefs_final_sel.shape)
coefs_io = np.zeros(coefs_final_sel.shape)
coefs_gc[:, 1:9] = coefs_final_sel[:, 1:9]
coefs_io[:, 9:] = coefs_final_sel[:, 9:]

for i_roi in range(n_valid_rois):
       
    gc_contribution = offset_cluster_combine(coefs_gc[i_roi, :], regressors_mat.T)
    io_contribution = offset_cluster_combine(coefs_io[i_roi, :], regressors_mat.T)
    
    contribution_ratio = (np.abs(gc_contribution) - np.abs(io_contribution)) / \
                (np.abs(gc_contribution) + np.abs(io_contribution))
    
    time_contributions_mat[:, i_roi] = contribution_ratio

In [None]:
def time_contr_plot(time_contributions_mat, stim, figure=None, frame=None):
    
    if figure is None:
        figure = plt.figure(figsize=(7, 1.5))
    
    ax = add_offset_axes(figure, (0.1, 0.1, 0.9, 0.9), frame=frame)

    # Find quartiles:
    median_contr = np.nanmedian(time_contributions_mat, 1)
    low_quart_contr = np.nanquantile(time_contributions_mat, 0.25, axis=1)
    high_quart_contr = np.nanquantile(time_contributions_mat, 0.75, axis=1)
    
    # Plot trace and filling:
    i_col=5
    ax.plot(stim[:, 0], median_contr, color=sns.color_palette()[i_col])
    ax.fill_between(stim[:,0], low_quart_contr, high_quart_contr, facecolor=sns.color_palette()[i_col], 
                         alpha=.3, zorder=100, edgecolor=None)
    shade_plot(stim, ax, shade_range=(0.75, 0.98))
    ax.axhline(0, c = (0.3,)*3, zorder=1)
    ax.set_ylim(-1., 1.)
    ax.set_xlim(0, stim[-1, 0])
#     for y_line in [-1, 1]:
#         ax.axhline(y_line, color=(0.3,)*3, linewidth=0.5)
    
    make_bar(ax, (stim[-1, 0]-10, stim[-1, 0]), label="10 s")
    ax.set_ylabel('Contribution ratio')
    ax.text(0.5, 1.05, 'GC contributions dominates', fontsize=7, color=sns.color_palette()[0])
    ax.text(0.5, -1.15, 'IO contributions dominates', fontsize=7, color=sns.color_palette()[1])

In [None]:
time_contr_plot(time_contributions_mat, stim)

# Assemble final panel

In [None]:
import matplotlib.patches as mpatches
import matplotlib.lines as lines

In [None]:
clust_names = ["IO{}".format(i+1) for i in range(n_io_clust)] + ["GC{}".format(i+1) for i in range(n_gc_clust)]
cols = [sns.color_palette()[1]]*n_io_clust + [sns.color_palette()[0]]*n_gc_clust

In [None]:
# Plot the regressor panel:
def reg_panel_plot(regressors_mat, figure=None, ax=None, frame=None):
    if figure is None:
        figure = plt.figure(figsize=(3, 3))

    offset = 0.01
    if ax is None:
        ax = add_offset_axes(figure, (0., 0., 1, 1), frameon=False, frame=frame)
    cols = sns.color_palette()
    for i, col in enumerate([cols[0], ]*n_gc_clust + [cols[1], ]*n_io_clust):
        ax.fill_between(stim[:, 0], np.zeros(stim[:, 0].shape) - i*offset, 
                        regressors_mat[i, :] - i*offset, color=col)
        
    ax.axis("off")
    
# Plot a trace panel:
def little_trace_plot(clean_traces, i=0, figure=None, ax=None, frame=None):
    if figure is None:
        figure = plt.figure(figsize=(3, 1))

    if ax is None:
        ax = add_offset_axes(figure, (0., 0., 1, 1), frameon=False, frame=frame)
    offset = 0.01
    ax.plot(clean_traces[i, :, :1].T.flatten(), c=sns.color_palette()[2])
    ax.axis("off")

In [None]:
def schema_panel(clean_traces, regressors_mat, i=9, figure=None, frame=None):
    if figure is None:        
        figure = plt.figure(figsize=(7, 3))

    schema_y = 0.3

    h_cent = 0.5
    trace_h = 0.2
    reg_h = 0.9
    schema_h = 0.9
    ax_trace = add_offset_axes(figure, [0.7, h_cent - trace_h/2, 0.3, trace_h], frameon=False, frame=frame)
    ax_regressors = add_offset_axes(figure, [0.02, h_cent - reg_h/2 - 0.01, 0.3, reg_h], frameon=False, frame=frame)
    ax_schema = add_offset_axes(figure, [0.31, h_cent - schema_h/2, 0.4, schema_h], frameon=False, aspect=1., frame=frame)
    little_trace_plot(clean_traces, i=18, figure=figure, ax=ax_trace)

    reg_panel_plot(regressors_mat, figure=figure, ax=ax_regressors)

    ax_schema.xaxis.set_visible(False)
    ax_schema.yaxis.set_visible(False)

    n_clust = 14
    n_gc_clust
    n_io_clust
    c_cent = (0.6, 0.5)  # x, y
    p = mpatches.Circle(c_cent, 0.1, edgecolor="k", facecolor="w")

    w_pos_x = 0.0
    line_displ_x = 0.2
    for i_clust, (label, col) in enumerate(zip(clust_names, cols)):
        c = 1/(2*(n_clust + 1)) + i_clust/(n_clust + 1)
        l = lines.Line2D([w_pos_x + line_displ_x, c_cent[0]], [c, c_cent[1]], c="k", zorder=-100)
        ax_schema.add_line(l)
        ax_schema.text(w_pos_x, c, "$\cdot w_i^{" + label + "}$", ha="left", va="center", fontsize=7, color=col)

    c = 1/(2*n_clust) + (i_clust + 1)/(n_clust + 1)
    l = lines.Line2D([w_pos_x + line_displ_x, c_cent[0]], [c, c_cent[1]], c="k", zorder=-100)
    ax_schema.add_line(l)
    ax_schema.text(w_pos_x, c, "$offset_i$", ha="left", va="center", fontsize=7, color="k")

    ax_schema.add_patch(p)

    l = lines.Line2D([c_cent[0], 1], [c_cent[1], c_cent[1]], c="k", zorder=-100)
    ax_schema.add_line(l)

    # Text
    ax_schema.text(*c_cent, "$\sum$", ha="center", va="center", fontsize=9)
    ax_schema.text(0.55, 0.9, "$PC_i = regr^{GC} \cdot w_i^{GC} + regr^{IO} \cdot w_i^{IO} + offset_i$", 
                   ha="left", va="center", fontsize=7.5)
    ax_trace.text(0, -3, "$PC_i$", ha="left", va="center", fontsize=9, color=sns.color_palette()[2])

In [None]:
fig = plt.figure(figsize=(7, 9))
schema_panel(clean_traces, regressors_mat, i=9, figure=fig, frame=[0.25, 0.77, 0.6, 0.2])

offset_y = 0.13
start_y = 0.36
for i, idx in enumerate([97, 197, 89]):
    y = start_y + i*offset_y
    
    make_single_cell_plot(clean_traces_sel, coefs_final_sel, fit_idxs_sel, idx, figure=fig, frame=[0., y, 0.53, 0.12], legend=(i==2), bar=(i==0))

coefs_plot(clust_lab_sel, coefs_final_sel, figure=fig, frame=[0.58, 0.56, 0.4, 0.18])
indexes_plot(gc_io_idx, rel_idx_sel, figure=fig, frame=[0.58, 0.35, 0.35, 0.2])
time_contr_plot(time_contributions_mat, stim, figure= fig, frame=[0.1, 0.17, 0.65, 0.15])

In [None]:
if fig_fold is not None:
    fig.savefig(fig_fold / "fig7_panel.pdf", forma="pdf")

## Supplementary

In [None]:
figsupp = plt.figure(figsize=(6, 4))
s = 0.45
cost_figure(costs_final, costs_shuf, cost_threshold, figure=figsupp, frame=[0., 0.5, s, s])
error_vs_reliabil(rel_idx_sel, costs_final_sel, figure=figsupp, frame=[0.5, 0.5, s, s])
# gc_io_weights_hist(cleaned_coefs, figure=figsupp, frame=[0., 0., s, s])
gc_io_nonzero_hist(non_zero_coefs, figure=figsupp, frame=[0.25, 0., s, s])
# sorted_examples_plot(coefs_final, gc_io_idx, clean_traces, test_idxs, figure=figure6supp, frame=[0.5, 0.1, 0.5, 0.8])

In [None]:
if fig_fold is not None:
    figsupp.savefig(fig_fold.parent.parent / (fig_fold.parent.name + "supp/src") / "supp_panel.pdf", format="pdf")