In [1]:
%load_ext autoreload
%autoreload
%matplotlib widget

In [2]:
import flammkuchen as fl
from matplotlib import pyplot as plt
import numpy as np
from scipy import io
import seaborn as sns
import matplotlib.gridspec as gridspec
from matplotlib.ticker import FormatStrFormatter
from pathlib import Path
import pandas as pd
import os
from luminance_analysis import PooledData, traces_stim_from_path

# Apply the project's shared matplotlib style sheet to every figure below.
# NOTE(review): the relative file name presumably resolves against the notebook's
# working directory — TODO confirm when running the notebook from another folder.
plt.style.use("figures.mplstyle")

In [3]:
# Output folder for the main-figure PDFs (set fig_fold = None to skip saving).
fig_fold = Path(r"C:\Users\otprat\Documents\figures\luminance\manuscript_figures\fig4")

# parents=True also creates any missing parent folders (os.mkdir would fail there);
# exist_ok=True makes the cell safe to re-run without a separate isdir check.
fig_fold.mkdir(parents=True, exist_ok=True)

## Load all data

In [4]:
# Root folder of the preprocessed ("neat") experiments. Defaults to the shared lab drive;
# set the LUMINANCE_DATA_PATH environment variable to use another location
# (e.g. J:\_Shared\GC_IO_luminance\data\neat_exps) without editing the notebook.
master_path = Path(os.environ.get("LUMINANCE_DATA_PATH",
                                  r"\\FUNES\Shared\experiments\E0032_luminance\neat_exps"))

In [5]:
from luminance_analysis.utilities import deconv_resamp_norm_trace, reliability, nanzscore
from skimage.filters import threshold_otsu
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree, to_tree, set_link_color_palette
from luminance_analysis.plotting import plot_clusters_dendro, shade_plot
from luminance_analysis.clustering import cluster_id_search, find_trunc_dendro_clusters

In [6]:
# Deconvolution time constants (presumably GCaMP6f / GCaMP6s kinetics — TODO confirm units).
tau_6f = 5
tau_6s = 8
ker_len = 20
smooth_wnd = 4  # smoothing window passed to deconv_resamp_norm_trace
normalization = "zscore"
protocol = 'flashes'

brain_regions_list = ["GC", "IO", "PC"]
tau_list = [tau_6f, tau_6f, tau_6s]
n_cluster_list = [4, 5, 4]
nan_thr_list = [0, 1, 1]

data_dict = {k: {} for k in brain_regions_list}

# Load the GC stimulus and use it as the reference time array and stimulus array for all regions:
stim_ref = PooledData(path=master_path / protocol / "GC").stimarray_rep

for brain_region, tau, n_cluster, nan_thr in zip(brain_regions_list, tau_list,
                                                 n_cluster_list, nan_thr_list):
    # Load data:
    path = master_path / protocol / brain_region
    stim, traces, meanresps = traces_stim_from_path(path)

    # Reliability index of every ROI:
    rel_idxs = reliability(traces)

    # Otsu threshold on the non-NaN reliability values...
    rel_thr = threshold_otsu(rel_idxs[~np.isnan(rel_idxs)])

    # ...then reload, filtering ROIs with that threshold and the NaN-fraction threshold:
    _, traces, meanresps, pooled_data = traces_stim_from_path(path, resp_threshold=rel_thr,
                                                              nanfraction_thr=nan_thr,
                                                              return_pooled_data=True)

    # Hierarchical (Ward) clustering of the mean responses:
    linked = linkage(meanresps, 'ward')

    # Truncate the dendrogram at n_cluster leaves. no_plot=True returns the same dendrogram
    # dict without drawing, so no throwaway figure (or stray widget) is created.
    dendro = dendrogram(linked, n_cluster, truncate_mode="lastp", no_plot=True)
    labels = find_trunc_dendro_clusters(linked, dendro)

    # Resample every mean response onto the reference stimulus time base and normalize it,
    # once with deconvolution (tau) and once with tau=None (presumably resampling only — TODO confirm):
    deconv_meanresps = np.empty((meanresps.shape[0], stim_ref.shape[0]))
    resamp_meanresps = np.empty((meanresps.shape[0], stim_ref.shape[0]))
    for roi_i in range(meanresps.shape[0]):
        deconv_meanresps[roi_i, :] = deconv_resamp_norm_trace(meanresps[roi_i, :], stim[:, 0],
                                                              stim_ref[:, 0], tau, ker_len,
                                                              smooth_wnd=smooth_wnd,
                                                              normalization=normalization)
        resamp_meanresps[roi_i, :] = deconv_resamp_norm_trace(meanresps[roi_i, :], stim[:, 0],
                                                              stim_ref[:, 0], None, ker_len,
                                                              smooth_wnd=smooth_wnd,
                                                              normalization=normalization)

    # Average the deconvolved responses within each cluster, then z-score each average:
    cluster_resps = np.empty((n_cluster, stim_ref.shape[0]))
    for clust_i in range(n_cluster):
        cluster_resp = np.nanmean(deconv_meanresps[labels == clust_i, :], 0)
        cluster_resps[clust_i, :] = nanzscore(cluster_resp)

    # Add everything to the dictionary:
    data_dict[brain_region].update({
        "linkage_mat": linked,
        "clust_labels": labels,
        "pooled_data": pooled_data,
        "raw_mn_resps": meanresps,
        "deconv_mn_resps": deconv_meanresps,
        "resamp_mn_resps": resamp_meanresps,
        "rel_idxs": rel_idxs,
        "rel_thr": rel_thr,
        "clust_resps": cluster_resps,
    })

[<luminance_analysis.FishData object at 0x000002517D69C188>, <luminance_analysis.FishData object at 0x000002517D69C208>, <luminance_analysis.FishData object at 0x000002517D6B7B88>, <luminance_analysis.FishData object at 0x000002517D6CE348>, <luminance_analysis.FishData object at 0x000002517D6ADAC8>]
[<luminance_analysis.FishData object at 0x000002517D69CC48>, <luminance_analysis.FishData object at 0x000002517D69CBC8>, <luminance_analysis.FishData object at 0x000002517D6B7C08>, <luminance_analysis.FishData object at 0x000002517D6CEF48>, <luminance_analysis.FishData object at 0x000002517D6B2EC8>]


  meanresps = np.nanmean(traces, 2)
  reliability[i] = np.nanmean(corr)


[<luminance_analysis.FishData object at 0x000002517D6B7848>, <luminance_analysis.FishData object at 0x000002517D6B7888>, <luminance_analysis.FishData object at 0x000002517D6CE6C8>, <luminance_analysis.FishData object at 0x000002517D6AD988>, <luminance_analysis.FishData object at 0x000002517D6B2B48>]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<luminance_analysis.FishData object at 0x0000025104FF9308>, <luminance_analysis.FishData object at 0x0000025104FF9388>, <luminance_analysis.FishData object at 0x0000025104FF0988>, <luminance_analysis.FishData object at 0x0000025104FF6EC8>, <luminance_analysis.FishData object at 0x0000025104FF3448>]
[<luminance_analysis.FishData object at 0x0000025104FF6C08>, <luminance_analysis.FishData object at 0x0000025104FF6908>, <luminance_analysis.FishData object at 0x0000025104FF5748>, <luminance_analysis.FishData object at 0x0000025104FF3A08>, <luminance_analysis.FishData object at 0x0000025104FF2B08>]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<luminance_analysis.FishData object at 0x0000025100E4CD08>, <luminance_analysis.FishData object at 0x0000025100E4CD88>, <luminance_analysis.FishData object at 0x0000025100E4EF48>, <luminance_analysis.FishData object at 0x0000025100E530C8>, <luminance_analysis.FishData object at 0x0000025100E57488>]


  c /= stddev[:, None]
  c /= stddev[None, :]


[<luminance_analysis.FishData object at 0x0000025100EB4488>, <luminance_analysis.FishData object at 0x0000025100EB4508>, <luminance_analysis.FishData object at 0x0000025100EB5708>, <luminance_analysis.FishData object at 0x0000025100EC6848>, <luminance_analysis.FishData object at 0x0000025100EB8C08>]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Clustering overview

In [7]:
from luminance_analysis.plotting import plot_clusters_dendro, re_histogram

In [8]:
# Panel A (GC, top) and panel B (IO, bottom): reliability histogram plus clustered responses.
colors = sns.color_palette()[:2]
fig_clust = plt.figure(figsize=(7, 6))
for region, dendrolim, spacing, cbar, y_pos, color in zip(["GC", "IO"], [940, 82], [3, 5],
                                                          [False, True], [0.45, 0.0], colors):
    # Reliability histogram (given the Otsu threshold), placed 0.4 above the cluster panel:
    re_histogram(data_dict[region]["rel_idxs"], data_dict[region]["rel_thr"], fig_clust,
                 w=0.18, h=0.08, w_p=0.02, h_p=y_pos + 0.4, color=color)

    # Rolling-mean smoothing of the resampled (not deconvolved) responses, for display only.
    # A dedicated name avoids overwriting the `meanresps` left over from the loading loop.
    region_resps = data_dict[region]["resamp_mn_resps"]
    smooth_mean_resps = pd.DataFrame(region_resps.T).rolling(4, center=True).mean().values.T

    fig_clust = plot_clusters_dendro(smooth_mean_resps, stim_ref,
                                     data_dict[region]["linkage_mat"], data_dict[region]["clust_labels"],
                                     prefix=region, figure=fig_clust, w=1., h=0.65, w_p=0.1, h_p=y_pos,
                                     f_lim=2, dendrolims=(dendrolim, 0), gamma=0.4, spacing=spacing,
                                     colorbar=cbar)

fig_clust.text(0.005, 0.95, 'A')
fig_clust.text(0.005, 0.5, 'B');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  plt.tight_layout()


[ 445  874 1379 2570]


  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  plt.tight_layout()
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -
  np.nanmean(traces[labels == i, :], 0) -


[ 2  8 10 13 26]


Text(0.005, 0.5, 'B')

In [9]:
# Save the clustering overview (skipped when fig_fold is set to None).
if fig_fold is not None:
    fig_clust.savefig(os.fspath(fig_fold.joinpath("flashes_clusters.pdf")))

### Supplementary figures

In [10]:
# Output folder for the supplementary-figure PDFs.
# NOTE: this rebinds `fig_fold`; re-running the main-figure save cell after this one
# would write flashes_clusters.pdf into the supplementary folder.
fig_fold = Path(r"C:\Users\otprat\Documents\figures\luminance\manuscript_figures\fig4supp")

# parents=True also creates any missing parent folders; exist_ok=True makes re-runs safe.
fig_fold.mkdir(parents=True, exist_ok=True)

In [11]:
%load_ext autoreload
from luminance_analysis.plotting import make_bar, get_yg_custom_cmap, add_offset_axes, shade_plot, stim_plot
%autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Stimulus plot

In [12]:
# Full (non-repeat-averaged) stimulus array of the GC recordings, for the stimulus panel.
stim_arr = PooledData(path=master_path.joinpath(protocol, "GC")).stimarray

[<luminance_analysis.FishData object at 0x0000025109EEB708>, <luminance_analysis.FishData object at 0x0000025109EEB788>, <luminance_analysis.FishData object at 0x0000025109EF2D48>, <luminance_analysis.FishData object at 0x0000025109EFB248>, <luminance_analysis.FishData object at 0x0000025109F03888>]


In [13]:
# Stand-alone stimulus-protocol figure.
fig_stim = stim_plot(
    stim_arr,
    xlims=(0, 54),
    gamma=0.4,
    figure=None,
    frame=None,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
# Save the stimulus-protocol figure (skipped when fig_fold is set to None).
if fig_fold is not None:
    fig_stim.savefig(os.fspath(fig_fold.joinpath("Stimulation_protocol.pdf")))

#### Fish contribution to each cluster

In [15]:
def plot_fish_contribution(data_dict, figure=None, frame=None, n_fish=5):
    """Plot, for GC and IO, stacked bars of how many ROIs each fish contributes to each cluster.

    Parameters
    ----------
    data_dict : dict
        Per-region results; for "GC" and "IO" this reads the "clust_labels" array and
        the "pooled_data" object, whose ``roi_map[0, :]`` gives the fish index of each ROI.
    figure : matplotlib.figure.Figure, optional
        Figure to draw into; a new 6x3 inch figure is created when None.
    frame : optional
        Forwarded to ``add_offset_axes`` to place the two panels inside a larger figure.
    n_fish : int, default 5
        Number of fish to stack; fish are matched by ``roi_map[0, :] == 0 .. n_fish - 1``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn on.
    """
    if figure is None:
        figure = plt.figure(figsize=(6, 3))

    bar_width = 0.85
    colors = sns.color_palette("deep", 10)
    fish_contribution = {brain_region: {} for brain_region in ["GC", "IO"]}

    for region_i, brain_region in enumerate(["GC", "IO"]):
        ax_hist = add_offset_axes(figure, (0.1 + 0.5 * region_i, 0.15, .35, .7), frame=frame)

        labels = data_dict[brain_region]['clust_labels']
        clusters = np.unique(labels)
        roi_map = data_dict[brain_region]['pooled_data'].roi_map

        # Count ROIs per cluster for each fish. Counting over the actual cluster ids keeps the
        # counts aligned with the bar positions (clusters + 1) even if ids are not contiguous.
        for fish in range(n_fish):
            fish_labels = labels[roi_map[0, :] == fish]
            fish_contribution[brain_region]['{} Fish {}'.format(brain_region, fish + 1)] = np.array(
                [np.sum(fish_labels == c) for c in clusters])

        contributions_df = pd.DataFrame(fish_contribution[brain_region])

        # Stack the fish: each fish's bars start where the previous fish's bars ended.
        bottom = np.zeros(len(clusters))
        for fish_i, column in enumerate(contributions_df.columns):
            counts = contributions_df[column].values
            ax_hist.bar(clusters + 1, counts, bottom=bottom,
                        width=bar_width, label=column, color=colors[fish_i])
            bottom = bottom + counts

        ax_hist.set_xlabel("Cluster #")
        ax_hist.set_title(brain_region)

        if brain_region == 'GC':
            ax_hist.set_ylabel("Number of ROIs")

    # Lay out the figure that was drawn on, not whatever figure pyplot considers current.
    figure.tight_layout()
    return figure

In [16]:
# Stand-alone figure of per-fish contributions to each GC / IO cluster.
fig_fish_contrib = plot_fish_contribution(data_dict=data_dict)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



In [17]:
# Save the fish-contribution figure (skipped when fig_fold is set to None).
if fig_fold is not None:
    fig_fish_contrib.savefig(os.fspath(fig_fold.joinpath("Fish_contributions.pdf")))

### Assemble figure

In [18]:
# Supplementary figure: stimulus protocol (panel A) and per-fish cluster contributions (panel B).
figureS4 = plt.figure(figsize=(8, 3))

# Panel A - stimulus plot (both helpers draw into figureS4, so their return values are not kept):
stim_plot(stim_arr, xlims=(0, 54), gamma=0.4, figure=figureS4, frame=(0.04, 0.275, 0.25, 0.5))
figureS4.text(.005, .75, 'A')

# Panel B - fish contributions to each cluster:
plot_fish_contribution(data_dict, figure=figureS4, frame=(0.3, 0.2, 0.7, 0.7))
figureS4.text(.275, .75, 'B');


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



Text(0.275, 0.75, 'B')

In [19]:
# Save the assembled supplementary figure (skipped when fig_fold is set to None).
if fig_fold is not None:
    figureS4.savefig(os.fspath(fig_fold.joinpath("flashes_supplementary.pdf")))