In [1]:
%matplotlib widget

In [2]:
import flammkuchen as fl
from matplotlib import pyplot as plt
import numpy as np
from scipy import io
import os
import seaborn as sns
import matplotlib.gridspec as gridspec
from matplotlib.ticker import FormatStrFormatter
from pathlib import Path
import pandas as pd
from luminance_analysis import PooledData, traces_stim_from_path, get_meanresp_during_interval

plt.style.use("figures.mplstyle")

In [3]:
fig_fold = Path(r"C:\Users\otprat\Documents\figures\luminance\manuscript_figures\fig2")

# Create the output folder, including any missing parent folders; no-op if it already exists.
fig_fold.mkdir(parents=True, exist_ok=True)

## Load all data

In [4]:
# Root folder of the pooled experiments. Set the LUMINANCE_DATA_ROOT environment variable
# to use another location (e.g. a local copy); defaults to the lab share.
master_path = Path(os.environ.get("LUMINANCE_DATA_ROOT", r"\\FUNES\Shared\experiments\E0032_luminance\neat_exps"))

In [5]:
from luminance_analysis.utilities import deconv_resamp_norm_trace, reliability, nanzscore
from skimage.filters import threshold_otsu
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree, to_tree, set_link_color_palette
from luminance_analysis.plotting import plot_clusters_dendro, shade_plot, stim_plot, cluster_cols
from luminance_analysis.clustering import cluster_id_search, find_trunc_dendro_clusters

In [6]:
tau_6f = 5
tau_6s = 8
ker_len = 20
normalization = "zscore"
protocol = 'steps'

brain_regions_list = ["GC", "IO"]
# Per-region settings, matched to brain_regions_list by position. zip() in the loop
# below stops at the shortest list, so the third entries are currently unused.
tau_list = [tau_6f, tau_6f, tau_6s]
n_cluster_list = [8, 6, 8]
nan_thr_list = [0, 1, 1]

# Step-protocol timing [s]; identical for every region, so defined once outside the loop.
# NOTE(review): the meaning of these window arguments (offset after the event, window length)
# is assumed from their names; confirm against get_meanresp_during_interval.
start_after_stim = 2.5
post_int_s = 5
start_after_isi = 4.5
post_isi_s = 7
isi_lead_s = 7  # ISI windows are referenced this many seconds before each transition
# Times [s] of the upward transitions (each luminance approached from below)...
up_trans_t_sec = np.array([34, 46, 58])
# ...and of the downward transitions (approached from above), sorted to match the order
# of the upward transitions.
down_trans_t_sec = np.array([101, 89, 77])

data_dict = {k: {} for k in brain_regions_list}

# Load the stimulus of the GCs and use it as the reference time array and stimulus array:
stim_ref = PooledData(path=master_path / protocol / "GC").stimarray_rep

for brain_region, tau, n_cluster, nan_thr in zip(brain_regions_list, tau_list,
                                                 n_cluster_list, nan_thr_list):
    path = master_path / protocol / brain_region

    # First pass: load all ROIs to compute their reliability index.
    stim, traces, meanresps = traces_stim_from_path(path)
    rel_idxs = reliability(traces)

    # Otsu threshold on the (non-NaN) reliability values...
    rel_thr = threshold_otsu(rel_idxs[~np.isnan(rel_idxs)])

    # ...second pass: reload, filtering ROIs with the reliability and NaN-fraction thresholds.
    _, traces, meanresps, pooled_data = traces_stim_from_path(
        path, resp_threshold=rel_thr, nanfraction_thr=nan_thr, return_pooled_data=True)

    # Ward hierarchical clustering of the mean responses.
    linked = linkage(meanresps, 'ward')

    # Truncate the dendrogram at the n_cluster level and derive one cluster label per ROI.
    # no_plot=True returns the same dendrogram dict without drawing it, so no throwaway
    # figure (and no stray canvas widget under %matplotlib widget) is created.
    dendro = dendrogram(linked, n_cluster, truncate_mode="lastp", no_plot=True)
    labels = find_trunc_dendro_clusters(linked, dendro)

    # Resample each ROI's mean response onto the reference time base and normalize it,
    # once with deconvolution (tau) and once with tau=None (presumably no deconvolution).
    n_rois, n_ref = meanresps.shape[0], stim_ref.shape[0]
    deconv_meanresps = np.empty((n_rois, n_ref))
    resamp_meanresps = np.empty((n_rois, n_ref))
    for roi_i in range(n_rois):
        deconv_meanresps[roi_i, :] = deconv_resamp_norm_trace(meanresps[roi_i, :], stim[:, 0],
                                                              stim_ref[:, 0], tau, ker_len,
                                                              smooth_wnd=4,
                                                              normalization=normalization)
        resamp_meanresps[roi_i, :] = deconv_resamp_norm_trace(meanresps[roi_i, :], stim[:, 0],
                                                              stim_ref[:, 0], None, ker_len,
                                                              smooth_wnd=4,
                                                              normalization=normalization)

    # Average deconvolved response of each cluster, then z-score it.
    cluster_resps = np.empty((n_cluster, n_ref))
    for clust_i in range(n_cluster):
        cluster_resps[clust_i, :] = nanzscore(np.nanmean(deconv_meanresps[labels == clust_i, :], 0))

    # Mean response of every ROI to each luminance of the two step series, reached from
    # below (upward) or from above (downward), plus the matching ISI windows.
    interval_kwargs = dict(resp_threshold=rel_thr, nanfraction_thr=nan_thr)
    resp_upward = get_meanresp_during_interval(path, up_trans_t_sec, start_after_stim,
                                               post_int_s, **interval_kwargs)
    resp_upward_isi = get_meanresp_during_interval(path, up_trans_t_sec - isi_lead_s, start_after_isi,
                                                   post_isi_s, **interval_kwargs)
    resp_downward = get_meanresp_during_interval(path, down_trans_t_sec, start_after_stim,
                                                 post_int_s, **interval_kwargs)
    resp_downward_isi = get_meanresp_during_interval(path, down_trans_t_sec - isi_lead_s, start_after_isi,
                                                     post_isi_s, **interval_kwargs)

    # Store everything for the plotting cells below:
    data_dict[brain_region].update(
        linkage_mat=linked,
        clust_labels=labels,
        pooled_data=pooled_data,
        raw_mn_resps=meanresps,
        deconv_mn_resps=deconv_meanresps,
        resamp_mn_resps=resamp_meanresps,
        rel_idxs=rel_idxs,
        rel_thr=rel_thr,
        clust_resps=cluster_resps,
        resp_upward=resp_upward,
        resp_downward=resp_downward,
        resp_upward_isi=resp_upward_isi,
        resp_downward_isi=resp_downward_isi,
    )

[<luminance_analysis.FishData object at 0x000001A534F30A08>, <luminance_analysis.FishData object at 0x000001A534F30A88>, <luminance_analysis.FishData object at 0x000001A534F204C8>, <luminance_analysis.FishData object at 0x000001A534F40C88>, <luminance_analysis.FishData object at 0x000001A534F2B488>]
[<luminance_analysis.FishData object at 0x000001A534F11D08>, <luminance_analysis.FishData object at 0x000001A534F11C88>, <luminance_analysis.FishData object at 0x000001A534F20AC8>, <luminance_analysis.FishData object at 0x000001A534F40948>, <luminance_analysis.FishData object at 0x000001A534F2B708>]


  c /= stddev[:, None]
  c /= stddev[None, :]


[<luminance_analysis.FishData object at 0x000001A5353B79C8>, <luminance_analysis.FishData object at 0x000001A5353B7A48>, <luminance_analysis.FishData object at 0x000001A5353BE048>, <luminance_analysis.FishData object at 0x000001A5353C0608>, <luminance_analysis.FishData object at 0x000001A5353D1BC8>]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<luminance_analysis.FishData object at 0x000001A535590908>, <luminance_analysis.FishData object at 0x000001A535590988>, <luminance_analysis.FishData object at 0x000001A535593E48>, <luminance_analysis.FishData object at 0x000001A53559B3C8>, <luminance_analysis.FishData object at 0x000001A5355A7908>]
[<luminance_analysis.FishData object at 0x000001A535599188>, <luminance_analysis.FishData object at 0x000001A535599BC8>, <luminance_analysis.FishData object at 0x000001A53559BEC8>, <luminance_analysis.FishData object at 0x000001A5355A48C8>, <luminance_analysis.FishData object at 0x000001A5355A1E88>]
[<luminance_analysis.FishData object at 0x000001A535599108>, <luminance_analysis.FishData object at 0x000001A535599D08>, <luminance_analysis.FishData object at 0x000001A53559B248>, <luminance_analysis.FishData object at 0x000001A5355A4188>, <luminance_analysis.FishData object at 0x000001A5355A7588>]
[<luminance_analysis.FishData object at 0x000001A535475A48>, <luminance_analysis.FishData object 

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<luminance_analysis.FishData object at 0x000001A535292508>, <luminance_analysis.FishData object at 0x000001A535292588>, <luminance_analysis.FishData object at 0x000001A53529DA88>, <luminance_analysis.FishData object at 0x000001A535298308>, <luminance_analysis.FishData object at 0x000001A5352A6B48>]
[<luminance_analysis.FishData object at 0x000001A535295E08>, <luminance_analysis.FishData object at 0x000001A535295D88>, <luminance_analysis.FishData object at 0x000001A535298B08>, <luminance_analysis.FishData object at 0x000001A535494908>, <luminance_analysis.FishData object at 0x000001A5352929C8>]
[<luminance_analysis.FishData object at 0x000001A5352A02C8>, <luminance_analysis.FishData object at 0x000001A5352A0388>, <luminance_analysis.FishData object at 0x000001A5352986C8>, <luminance_analysis.FishData object at 0x000001A53548E6C8>, <luminance_analysis.FishData object at 0x000001A535292E48>]
[<luminance_analysis.FishData object at 0x000001A535494848>, <luminance_analysis.FishData object 

### Clustering overview

In [7]:
%load_ext autoreload
%autoreload

from luminance_analysis.plotting import plot_clusters_dendro, re_histogram

In [8]:
# One panel group per region: reliability histogram (with threshold) above the
# clustered, smoothed mean responses and their dendrogram.
colors = sns.color_palette()[:2]
fig_clust = plt.figure(figsize=(7, 6))
panel_params = zip(["GC", "IO"], [1040, 112], [3, 5], [False, False], [0.47, 0.], colors)
for region, dendrolim, spacing, cbar, y_offset, color in panel_params:
    # y_offset is the vertical position of the region's panel group within the figure.
    re_histogram(data_dict[region]["rel_idxs"], data_dict[region]["rel_thr"], fig_clust,
                 w=0.18, h=0.1, w_p=0.04, h_p=y_offset + 0.4125, color=color)

    meanresps = data_dict[region]["resamp_mn_resps"]  # alternative: data_dict[region]["deconv_mn_resps"]
    # Centered 4-sample rolling mean along time (rows are ROIs, hence the transposes).
    smooth_mean_resps = pd.DataFrame(meanresps.T).rolling(4, center=True).mean().values.T

    fig_clust = plot_clusters_dendro(smooth_mean_resps, stim_ref,
                                     data_dict[region]["linkage_mat"], data_dict[region]["clust_labels"],
                                     prefix=region, figure=fig_clust, w=1., h=0.65, w_p=0.1, h_p=y_offset,
                                     f_lim=2, dendrolims=(dendrolim, 0), gamma=0.4, spacing=spacing,
                                     colorbar=cbar)

fig_clust.text(.01, .98, 'A')
fig_clust.text(.01, .51, 'B');  # trailing ';' keeps the Text repr out of the cell output

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  plt.tight_layout()


[ 108  328  819  965 1023 1211 1595 2365]


  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  plt.tight_layout()
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -


[18 32 41 45 54 66]


Text(0.01, 0.51, 'B')

In [9]:
if fig_fold is not None:
    # Export the clustering overview figure.
    clustering_pdf = str(fig_fold / "Clustering.pdf")
    fig_clust.savefig(clustering_pdf)

### Supplementary figures

In [10]:
fig_fold = Path(r"C:\Users\otprat\Documents\figures\luminance\manuscript_figures\fig2supp")

# Create the supplementary-figure folder, including any missing parents; no-op if it exists.
fig_fold.mkdir(parents=True, exist_ok=True)

In [11]:
%load_ext autoreload
from luminance_analysis.plotting import make_bar, get_yg_custom_cmap, add_offset_axes, shade_plot, stim_plot
%autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Stimulus plot

In [12]:
# Stimulus array (``stimarray``) of the GC dataset, used for the protocol plots below.
gc_stim_source = master_path / protocol / "GC"
stim_arr = PooledData(path=gc_stim_source).stimarray

[<luminance_analysis.FishData object at 0x000001A546086708>, <luminance_analysis.FishData object at 0x000001A546086788>, <luminance_analysis.FishData object at 0x000001A54608FCC8>, <luminance_analysis.FishData object at 0x000001A546083388>, <luminance_analysis.FishData object at 0x000001A546091888>]


In [13]:
# Overview of the stimulation protocol over its first 108 s.
fig_stim = stim_plot(
    stim_arr,
    xlims=(0, 108),
    gamma=0.4,
    figure=None,
    frame=None,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
if fig_fold is not None:
    # Export the stimulation-protocol panel.
    protocol_pdf = str(fig_fold / "Stimulation_protocol.pdf")
    fig_stim.savefig(protocol_pdf)

#### Fish contributions to each cluster

In [15]:
def plot_fish_contribution(data_dict, figure=None, frame=None, n_fish=5):
    """Stacked bar chart of how many ROIs each fish contributes to every cluster.

    Draws one panel per region in the notebook-level ``brain_regions_list``. Each bar is a
    cluster (x = label + 1), and each colored segment is one fish. Fish membership is taken
    from row 0 of the region's ``pooled_data.roi_map``.

    Parameters
    ----------
    data_dict : dict
        Per-region dict holding 'clust_labels' and 'pooled_data'.
    figure : matplotlib Figure, optional
        Figure to draw into. A new 6x3 in figure is created when None.
    frame : optional
        Forwarded to ``add_offset_axes`` to place the panels inside a sub-region of ``figure``.
    n_fish : int, optional
        Number of fish to count (fish indices 0 .. n_fish-1 in ``roi_map[0]``). Default 5.

    Returns
    -------
    The figure that was drawn into.
    """
    if figure is None:
        figure = plt.figure(figsize=(6, 3))

    bar_width = 0.85
    # "deep" has 10 colors; asking for more makes seaborn cycle them instead of failing.
    fish_colors = sns.color_palette("deep", max(10, n_fish))

    for region_i, brain_region in enumerate(brain_regions_list):
        ax_hist = add_offset_axes(figure, (0.05 + 0.5 * region_i, 0.15, .4, .7), frame=frame)

        labels = data_dict[brain_region]['clust_labels']
        clusters = np.unique(labels)
        roi_fish = data_dict[brain_region]['pooled_data'].roi_map[0, :]

        # Stack one segment per fish on top of the previous ones. Counts are computed for
        # the same cluster values used as bar positions, so they stay aligned even if the
        # labels are not 0..n-1.
        bottom = np.zeros(len(clusters))
        for fish in range(n_fish):
            fish_labels = labels[roi_fish == fish]
            counts = np.array([np.sum(fish_labels == c) for c in clusters])
            ax_hist.bar(clusters + 1, counts, bottom=bottom, width=bar_width,
                        label='{} Fish {}'.format(brain_region, fish + 1), color=fish_colors[fish])
            bottom = bottom + counts

        # ax_hist.legend(bbox_to_anchor=(1, 1))  # legend currently disabled
        ax_hist.set_xlabel("Cluster #")
        ax_hist.set_ylabel("Number of ROIs")
        ax_hist.text(.5, 1, brain_region, ha='center', va='top', transform=ax_hist.transAxes, fontsize=8.5)

    plt.tight_layout()

    return figure

In [16]:
# Per-fish ROI counts in each cluster, one panel per region.
fig_fish_contrib = plot_fish_contribution(data_dict=data_dict, figure=None, frame=None)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



In [17]:
if fig_fold is not None:
    # Export the fish-contribution panel.
    contrib_pdf = str(fig_fold / "Fish_contributions.pdf")
    fig_fish_contrib.savefig(contrib_pdf)

#### Mean response to each luminance step, by cluster

In [18]:
from scipy import stats

In [19]:
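# NOTE: earlier draft of plot_fish_steps_resps, superseded by the definition in the next cell; kept for reference only.
# def plot_fish_steps_resps(data_dict, brain_region, figure=None, frame=None):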
    
#     if figure is None:
#         figure = plt.figure(figsize=(4, 5))
         
#     clusters = np.unique(data_dict[brain_region]['clust_labels'])
#     step_series = ['resp_upward', 'resp_downward']
#     series_titles = ['Up transitions', 'Down transitions']
#     x_ticks = [[0, 1, 2, 3], ['Dark', '5%', '20%', '100%']]
#     ylims = [-2, 30]
#     offset = 4
#     alpha = 0.01
#     n_tests = 48
#     alpha_corrected=alpha/n_tests
    
#     for i, (serie, title) in enumerate(zip(step_series, series_titles)):        
#         ax_hist = add_offset_axes(figure, (0.1+0.4*i, 0.15, .35, .8), frame=frame)
            
#         for cluster, color in zip(clusters, cluster_cols()):
#             cluster_resps = data_dict[brain_region][serie][:, data_dict[brain_region]['clust_labels'] == cluster]
#             isi_cluster_resps = data_dict[brain_region][serie+'_isi'][:, data_dict[brain_region]['clust_labels'] == cluster]

#             if serie == 'resp_upward':
#                 insert_idx = 0
#             elif serie == 'resp_downward':
#                 insert_idx = 3
#             cluster_resps = np.insert(cluster_resps, insert_idx, np.nanmean(isi_cluster_resps, 0), 0)

#             ax_hist.errorbar(x_ticks[0], np.nanmean(cluster_resps, 1) + offset*cluster, yerr=np.std(cluster_resps, 1), capsize=3, c=color, ls='none', marker='o', markersize=3)
#             ax_hist.axhline(offset*cluster, c='gray', alpha=.3, ls=':')
#             if serie == 'resp_upward':
#                 ax_hist.set_yticks(offset*clusters)
#                 ax_hist.set_yticklabels(np.zeros_like(clusters))
#                 ax_hist.set_ylabel('Average response during luminance step')
#             else:
#                 ax_hist.set_yticks([])
                
#             ax_hist.set_xlabel('Luminance')
#             ax_hist.set_xticks(x_ticks[0])
#             ax_hist.set_xticklabels(x_ticks[1])
#             ax_hist.set_ylim(ylims)
#             ax_hist.text(.5,1, title, ha='center', va='top', transform=ax_hist.transAxes, fontsize=7)
                        
# #             #T-tests
# #             x_pos = [(x_ticks[0][i]+x_ticks[0][i+1])/2 for i in [0, 1, 2]]
# #             for i in [0, 1, 2]:
# #                 data1 = cluster_resps[i, :]
# #                 data2 = cluster_resps[i+1, :]
                
# #                 d, pval = stats.ttest_ind(data1, data2)
# #                 y_pos = ((np.nanmean(data1)+np.nanmean(data2))/2)+cluster*offset
                
# #                 if pval < alpha:
# #                     ax_hist.text(x_pos[i], y_pos, '*', ha='center', color=color)
# #                 else:
# #                     ax_hist.text(x_pos[i], y_pos, 'n.s.', fontsize=7, ha='center', va='bottom', color=color)
                    
#     plt.tight_layout()
    
#     return(figure)

In [20]:
def plot_fish_steps_resps(data_dict, brain_region, figure=None, frame=None, stimulus_array=None):
    """Plot each cluster's mean ± s.d. response to every luminance level of the step protocol.

    There are two columns: levels reached from below ("Up transitions", ``resp_upward``) and
    from above ("Down transitions", ``resp_downward``). Both arrays are indexed [step, roi].
    Clusters are stacked vertically at ``offset * label``. The ISI response (``*_isi``),
    averaged over the transitions, is inserted as the first point (Dark) of the upward series
    and as the last point (100%) of the downward series. Above each column, an inset shows
    the stimulus trace with level annotations.

    Parameters
    ----------
    data_dict : dict
        Per-region dict with 'clust_labels', 'resp_upward', 'resp_downward' and their '_isi' versions.
    brain_region : str
        Key of ``data_dict`` to plot.
    figure : matplotlib Figure, optional
        Figure to draw into. A new 4x7 in figure is created when None.
    frame : optional
        Forwarded to ``add_offset_axes`` to place the panels inside a sub-region of ``figure``.
    stimulus_array : array, optional
        Stimulus for the insets. Row 0 is plotted on x and row 1 on y. Defaults to the
        notebook-level ``stim_arr``.

    Returns
    -------
    The figure that was drawn into.
    """
    if figure is None:
        figure = plt.figure(figsize=(4, 7))
    if stimulus_array is None:
        stimulus_array = stim_arr

    region_data = data_dict[brain_region]
    labels = region_data['clust_labels']
    clusters = np.unique(labels)
    step_series = ['resp_upward', 'resp_downward']
    # Row at which the ISI response is inserted: first (Dark) for the upward series,
    # last (100%) for the downward series.
    isi_insert_idx = {'resp_upward': 0, 'resp_downward': 3}
    x_ticks = [0, 1, 2, 3]
    x_ticklabels = ['Dark', '5%', '20%', '100%']
    ylims = [-2, 30]
    offset = 4  # vertical spacing between stacked clusters

    for i, serie in enumerate(step_series):
        ax_resp = add_offset_axes(figure, (0.1 + 0.4 * i, 0.15, .35, .75), frame=frame)

        for cluster, color in zip(clusters, cluster_cols()):
            in_cluster = labels == cluster
            cluster_resps = region_data[serie][:, in_cluster]
            isi_cluster_resps = region_data[serie + '_isi'][:, in_cluster]
            cluster_resps = np.insert(cluster_resps, isi_insert_idx[serie],
                                      np.nanmean(isi_cluster_resps, 0), 0)

            # NaN-aware mean AND s.d., so a single NaN ROI does not erase the error bar.
            ax_resp.errorbar(x_ticks, np.nanmean(cluster_resps, 1) + offset * cluster,
                             yerr=np.nanstd(cluster_resps, 1), capsize=3, c=color,
                             ls='none', marker='o', markersize=3)
            ax_resp.axhline(offset * cluster, c='gray', alpha=.3, ls=':')

        if serie == 'resp_upward':
            # One "0" tick at each cluster's baseline.
            ax_resp.set_yticks(offset * clusters)
            ax_resp.set_yticklabels(np.zeros_like(clusters))
            ax_resp.set_ylabel('Average response during luminance step')
        else:
            ax_resp.set_yticks([])
        ax_resp.set_xlabel('Luminance')
        ax_resp.set_xticks(x_ticks)
        ax_resp.set_xticklabels(x_ticklabels)
        ax_resp.set_ylim(ylims)

    # Stimulus insets above each column: x-range of the series and (x, y, text) level labels
    # in the inset's data coordinates.
    insets = [
        ([35.5, 71.5], [(41.5, 0.7, '5%'), (54, 0.85, '20%'), (65.5, 1.65, '100%'), (48, -0.2, 'Dark')]),
        ([74.8, 111], [(78.5, 1.5, '100%'), (85, 0.05, '20%'), (96.5, -0.1, '5%'), (108.5, -0.2, 'Dark')]),
    ]
    for i, (xlims, annotations) in enumerate(insets):
        ax_stim = add_offset_axes(figure, (0.1 + 0.4 * i, 0.925, .35, 1), frame=frame)
        ax_stim.plot(stimulus_array[0, :], stimulus_array[1, :])
        ax_stim.set_xlim(xlims)
        ax_stim.set_ylim(-.25, 25)
        ax_stim.axis('off')
        for x, y, text in annotations:
            ax_stim.text(x, y, text, ha='center', va='top', fontsize=7)

    return figure

In [21]:
# Step responses of the GC clusters.
fig_steps_resps = plot_fish_steps_resps(data_dict=data_dict, brain_region='GC')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [22]:
if fig_fold is not None:
    # Export the step-response panel.
    steps_pdf = str(fig_fold / "Average responses to luminance steps.pdf")
    fig_steps_resps.savefig(steps_pdf)

#### Plot individual traces

In [23]:
from luminance_analysis.roi_display import overimpose_shade, merge_anatomy_and_mask

In [24]:
# Indices of the GC ROIs shown as example traces (they index the pooled traces and cluster labels).
example_rois = [
    1777,
    29,
    430,
    1032,
]

In [25]:
def plot_roi_traces(data_dict, brain_region, rois, stimulus_array, figure=None, frame=None):
    """Plot z-scored single-repetition and mean traces, plus an anatomy crop, for example ROIs.

    Each ROI gets one row, stacked bottom-up. On the left, every repetition is drawn faintly
    and their mean solidly, in the ROI's cluster color, over the shaded stimulus. On the
    right is the image returned by ``merge_anatomy_and_mask`` for the ROI's anatomy and mask
    stacks (cropped with ``crop_around=25``). Scale bars are drawn on the first row only;
    the other rows have their trace axes hidden.

    Parameters
    ----------
    data_dict : dict
        Per-region dict holding 'pooled_data' and 'clust_labels'.
    brain_region : str
        Key of ``data_dict`` to plot.
    rois : sequence of int
        ROI indices into the pooled traces / labels.
    stimulus_array : array
        Column 0 is the time axis and column 1 the shaded stimulus values.
    figure, frame : optional
        Target figure (new 6.5x2.5 in figure if None) and placement forwarded to ``add_offset_axes``.

    Returns
    -------
    The figure that was drawn into.
    """
    if figure is None:
        figure = plt.figure(figsize=(6.5, 2.5))

    y_bar, x_bar = 4, 10
    pooled = data_dict[brain_region]['pooled_data']
    all_traces = pooled.traces
    time_s = stimulus_array[:, 0]
    stim_values = stimulus_array[:, 1]
    roi_clusters = data_dict[brain_region]['clust_labels'][rois]

    for row, roi in enumerate(rois):
        row_bottom = 0.1 + 0.2 * row
        ax_trace = add_offset_axes(figure, (0.1, row_bottom, .6, .2), frame=frame)
        ax_anato = add_offset_axes(figure, (0.75, row_bottom, .1, .2), frame=frame)

        roi_color = cluster_cols()[roi_clusters[row]]

        # z-score every repetition (last axis of the traces) independently.
        roi_traces = all_traces[roi, :, :]
        zscored = np.empty_like(roi_traces)
        for rep, rep_trace in enumerate(roi_traces.T):
            zscored[:, rep] = nanzscore(rep_trace)

        ax_trace.plot(time_s, zscored, c=roi_color, alpha=0.065)
        ax_trace.plot(time_s, np.nanmean(zscored, 1), c=roi_color)
        shade_plot((time_s, stim_values), ax=ax_trace)
        ax_trace.set_xlim([min(time_s), max(time_s)])
        ax_trace.set_ylim((-1.8, 5.2))

        if row != 0:
            ax_trace.axis('off')
        else:
            # Vertical then horizontal scale bar.
            make_bar(ax_trace, [0, y_bar], label="{} s.d. dF/F".format(y_bar), orientation='vertical', lw=1)
            make_bar(ax_trace, [2, 2 + x_bar], label="{} s".format(x_bar), lw=1)

        anatomy_stack, mask_stack = pooled.get_roi_anatomy_stacks(roi, crop_around=25)
        ax_anato.imshow(merge_anatomy_and_mask(anatomy_stack, mask_stack, roi_color, gamma=0.5))
        ax_anato.axis('off')

    plt.tight_layout()

    return figure

In [26]:
# Example GC ROI traces on the reference (GC) stimulus time base.
fig_traces = plot_roi_traces(data_dict, 'GC', rois=example_rois, stimulus_array=stim_ref)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  return (array - np.nanmean(array, axis=axis))/np.nanstd(array, axis=axis)
  keepdims=keepdims)
  return (array - np.nanmean(array, axis=axis))/np.nanstd(array, axis=axis)
  keepdims=keepdims)
  return (array - np.nanmean(array, axis=axis))/np.nanstd(array, axis=axis)
  keepdims=keepdims)
  return (array - np.nanmean(array, axis=axis))/np.nanstd(array, axis=axis)
  keepdims=keepdims)


In [27]:
if fig_fold is not None:
    # Export the example-traces panel.
    traces_pdf = str(fig_fold / "GC traces.pdf")
    fig_traces.savefig(traces_pdf)

#### Sensory history plot

In [28]:
def plot_sensory_history(data_dict, brain_region, figure=None, frame=None, align='h'):
    """Compare each cluster's response to the same luminance reached from below vs. from above.

    For the 5% and 20% levels, each cluster is drawn as a point at
    x = mean response after the low-to-high transition (``resp_upward``) and
    y = mean response after the high-to-low transition (``resp_downward``). Cross bars span
    ±s.d./2, and a y = x reference line is drawn.

    NOTE(review): assumes rows 0-1 of ``resp_upward`` and rows 1-2 of ``resp_downward`` both
    correspond to the 5% and 20% levels; confirm against the transition times used to build them.

    Parameters
    ----------
    data_dict : dict
        Per-region dict with 'clust_labels', 'resp_upward' and 'resp_downward' ([step, roi]).
    brain_region : str
        Key of ``data_dict`` to plot.
    figure, frame : optional
        Target figure (new 5x5 in figure if None) and placement forwarded to ``add_offset_axes``.
    align : {'h', 'v'}
        Place the two panels side by side ('h') or stacked vertically ('v').

    Returns
    -------
    The figure that was drawn into.

    Raises
    ------
    ValueError
        If ``align`` is not 'h' or 'v'.
    """
    if align not in ('h', 'v'):
        raise ValueError("align must be 'h' or 'v', got {!r}".format(align))

    if figure is None:
        figure = plt.figure(figsize=(5, 5))

    labels = data_dict[brain_region]['clust_labels']
    clusters = np.unique(labels)
    resp_upward = data_dict[brain_region]['resp_upward'][:2, :]
    resp_downward = data_dict[brain_region]['resp_downward'][1:, :]
    colors = cluster_cols()

    lims = [-1, 1.5]
    ticks = np.arange(-0.5, 1.6, 1)

    for i, title in enumerate(["5% luminance", "20% luminance"]):
        if align == 'h':
            rect = (0.1 + 0.45 * i, 0.15, .35, .35)
        else:
            rect = (0.1, 1 - 0.45 * (1 + i), .35, .35)
        ax_scatter = add_offset_axes(figure, rect, frame=frame)

        for j, cluster in enumerate(clusters):
            in_cluster = labels == cluster
            # NaN-aware statistics so single NaN ROIs do not drop the whole cluster point.
            mnx = np.nanmean(resp_upward[i, in_cluster])
            stdx = np.nanstd(resp_upward[i, in_cluster])
            mny = np.nanmean(resp_downward[i, in_cluster])
            stdy = np.nanstd(resp_downward[i, in_cluster])
            ax_scatter.scatter(mnx, mny, color=colors[j], s=6, alpha=1, edgecolors=None)
            ax_scatter.plot([mnx, mnx], [mny - stdy / 2, mny + stdy / 2], color=colors[j])
            ax_scatter.plot([mnx - stdx / 2, mnx + stdx / 2], [mny, mny], color=colors[j])

        ax_scatter.plot(lims, lims, color="k", linewidth=0.4, zorder=-10)  # y = x reference line
        ax_scatter.set_aspect('auto')
        ax_scatter.set_xticks(ticks)
        ax_scatter.set_yticks(ticks)
        ax_scatter.text(.5, 1.1, title, ha='center', va='top', transform=ax_scatter.transAxes, fontsize=7)

        # In both layouts x is the low-high (upward) response and y the high-low (downward)
        # response, so the labels must follow the data, not the panel arrangement.
        if align == 'h':
            ax_scatter.set_xlabel("Low-high transition")
            if i == 0:
                ax_scatter.set_ylabel("High-low transition")
            else:
                ax_scatter.set_yticklabels([])
        else:
            # Stacked panels: y-label on each panel, x-label only on the bottom one.
            ax_scatter.set_ylabel("High-low transition")
            if i == 1:
                ax_scatter.set_xlabel("Low-high transition")
            else:
                ax_scatter.set_xticklabels([])

    plt.tight_layout()
    return figure

In [29]:
# Sensory-history comparison for the GC clusters, panels stacked vertically.
fig_sens_hist = plot_sensory_history(data_dict, brain_region='GC', align='v')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



In [30]:
if fig_fold is not None:
    # Export the sensory-history panel.
    sens_hist_pdf = str(fig_fold / "Sensory_history_in_GCs.pdf")
    fig_sens_hist.savefig(sens_hist_pdf)

### Assemble figure

In [31]:
# OLD VERSION of the figure layout, superseded by the next cell; kept for reference only.
# figureS2 = plt.figure(figsize=(7, 9))

# #Stimulus plot:
# stim_panel = stim_plot(stim_arr, xlims = (0, 108), gamma=0.4, figure=figureS2, frame=(0.05, 0.8, 0.35, 0.2))

# #Cluster step reponses plot:
# steps_resps_panel = plot_fish_steps_resps(data_dict, 'GC', figure=figureS2, frame=(0.35, 0.485, 0.75, 0.5))

# #Traces plot
# traces_panel = plot_roi_traces(data_dict, 'GC', example_rois, stim_ref, figure=figureS2, frame=(0.05, 0.275, 1, .25))

# #Fish contributions plot
# fish_contrib_panel = plot_fish_contribution(data_dict, figure=figureS2, frame=(0.05, 0.05, 1, .25))

# #Sensory history figure
# sens_hist_panel = plot_sensory_history(data_dict, 'GC',  figure=figureS2, frame=(0.05, 0.535, .5, .275), align='v')

In [32]:
figureS2 = plt.figure(figsize=(7, 9))

# Stimulus plot:
stim_panel = stim_plot(stim_arr, xlims=(0, 108), gamma=0.4, figure=figureS2, frame=(0.05, 0.8, 0.4, 0.2))
figureS2.text(.05, .98, 'A')

# Sensory history figure:
sens_hist_panel = plot_sensory_history(data_dict, 'GC', figure=figureS2, frame=(0.1, 0.535, .45, .275), align='v')
figureS2.text(.05, .8, 'C')

# Cluster step responses plot:
steps_resps_panel = plot_fish_steps_resps(data_dict, 'GC', figure=figureS2, frame=(0.35, 0.485, 0.6, 0.5))
figureS2.text(.375, .98, 'B')

# Traces plot:
traces_panel = plot_roi_traces(data_dict, 'GC', example_rois, stim_ref, figure=figureS2, frame=(0.05, 0.275, .9, .25))
figureS2.text(.1, .5, 'D')

# Fish contributions plot:
fish_contrib_panel = plot_fish_contribution(data_dict, figure=figureS2, frame=(0.05, 0.05, .9, .25))
figureS2.text(.025, .275, 'E');  # trailing ';' keeps the Text repr out of the cell output


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  return (array - np.nanmean(array, axis=axis))/np.nanstd(array, axis=axis)
  keepdims=keepdims)
  return (array - np.nanmean(array, axis=axis))/np.nanstd(array, axis=axis)
  keepdims=keepdims)
  return (array - np.nanmean(array, axis=axis))/np.nanstd(array, axis=axis)
  keepdims=keepdims)
  return (array - np.nanmean(array, axis=axis))/np.nanstd(array, axis=axis)
  keepdims=keepdims)


Text(0.025, 0.275, 'E')

In [33]:
if fig_fold is not None:
    # Export the assembled supplementary figure.
    supp_pdf = str(fig_fold / "steps_supplementary.pdf")
    figureS2.savefig(supp_pdf)