In [None]:
# Interactive (ipympl widget) matplotlib figures, plus the autoreload extension so that edits
# to imported project modules are picked up (it is switched on with %autoreload 2 in the next cell).
%matplotlib widget
%load_ext autoreload

In [None]:
%autoreload 2
from pathlib import Path
import numpy as np
from bouter import EmbeddedExperiment
import pandas as pd
from tqdm import tqdm
import flammkuchen as fl

from matplotlib import pyplot as plt
import seaborn as sns
# Global seaborn theme; `cols` keeps the default palette colours that the figure panels reuse.
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()

from scipy.signal import detrend

from bouter.utilities import crop, reliability
from scipy.stats import ranksums, kstest, ttest_ind
# from utilities import *
from xiao_et_al_utils.behavior_and_stimuli import stimulus_df_from_exp0070
from xiao_et_al_utils.plotting import add_fish, despine

In [None]:
# Root folder of the flashing-stimulus dataset (one sub-folder per fish).
master_path = Path("/Volumes/Shared/experiments/E0070_receptive_field/v04_flashing_rad_simple")

# Every fish folder that contains a suite2p export.
# NOTE(review): glob order is filesystem-dependent. full_df rows later follow this order, and the
# precomputed fit parameters are assigned to those rows by position, so this presumably has to
# match the order used when fit.h5 was produced. Do not re-sort without checking that.
path_list = []
for suite2p_file in master_path.glob("*/data_from_suite2p_unfiltered.h5"):
    path_list.append(suite2p_file.parent)

In [None]:
# Per-fish ROI coordinates, each shifted by that fish's manual alignment offset.
# The offsets are loaded here too, so this cell no longer depends on a later cell having run
# (it previously raised NameError on a fresh kernel).
all_offsets = fl.load(master_path / "manual_alignment_offsets.h5")

all_coords = []
for p in path_list:
    coords = fl.load(p / "data_from_suite2p_unfiltered.h5", "/coords")
    # Out-of-place subtraction: an in-place `-=` raises a casting error when coords has an
    # integer dtype and the offsets are floats.
    coords = coords - all_offsets[p.name]
    all_coords.append(coords)

In [None]:
# Anatomy mask of every fish (same order as path_list; presumably the optic-tectum outline),
# and the manual alignment offsets keyed by fish folder name.
all_masks = []
for fish_dir in path_list:
    mask_file = fish_dir / "anatomy.mask"
    all_masks.append(fl.load(mask_file, "/mask"))

offsets_file = master_path / "manual_alignment_offsets.h5"
all_offsets = fl.load(offsets_file)


In [None]:
# Pooled per-cell tables from all fish; the next cell keeps only "all_cells_df".
pooled_file = master_path / "pooled_dfs.h5"
all_df = fl.load(pooled_file)

In [None]:
# Keep only the pooled per-cell DataFrame. The guard makes the cell idempotent: re-running it
# no longer indexes the already-extracted DataFrame with "all_cells_df" (a KeyError).
if not isinstance(all_df, pd.DataFrame):
    all_df = all_df["all_cells_df"]

In [None]:
# Positions of reliable cells (max reliability > 0.5): raw x/y on the left, transformed
# coordinates on the right, both coloured by max reliability.
sel = all_df["max_rel"] > 0.5
reliable_cells = all_df.loc[sel]

plt.figure()
for panel_i, (x_col, y_col) in enumerate([("x", "y"), ("x_trasf", "y_trasf")]):
    plt.subplot(1, 2, panel_i + 1)
    plt.scatter(reliable_cells[x_col], reliable_cells[y_col], s=0.1, alpha=0.1,
                c=reliable_cells["max_rel"])
    plt.axis("equal")

In [None]:
# Top view (coordinate columns 1 and 2) of all fish's ROIs, shown three ways:
# raw, centred on each fish's mask, and shifted by the manual alignment offsets.
# NOTE(review): all_coords already has all_offsets subtracted (aligned-coordinates cell), so
# the third panel subtracts the (x, y) offsets a second time. Confirm this is intended.
plt.figure(figsize=(9, 2.5))

plt.subplot(131)
for c in all_coords:
    plt.scatter(c[:, 1], c[:, 2], s=0.1, alpha=0.01, c="k")
plt.axis("off")

plt.subplot(132)
for m, c in zip(all_masks, all_coords):
    # Centre each fish on the middle of its mask's max-projection along axis 0.
    prof = m.max(0)
    x = prof.shape[0] // 2
    y = prof.shape[1] // 2
    plt.scatter(c[:, 1] - x, c[:, 2] - y, s=0.1, alpha=0.01, c="k")
plt.axis("off")

plt.subplot(133)
for c, p in zip(all_coords, path_list):
    x, y = all_offsets[p.name][1:]
    plt.scatter(c[:, 1] - x, c[:, 2] - y, s=0.1, alpha=0.01, c="k")
plt.axis("off")

In [None]:
# Side view (coordinate column 1 against column 0, the imaging plane) of all fish's ROIs:
# raw, centred on each fish's mask, and shifted by the manual offsets. Each point gets up to
# Z_JITTER of random vertical jitter to spread out the discrete planes. The generator is seeded
# so the figure is reproducible.
Z_JITTER = 15
side_jitter_rng = np.random.default_rng(0)

plt.figure(figsize=(9, 2.5))

plt.subplot(131)
for c in all_coords:
    offsets_z = side_jitter_rng.random(c.shape[0]) * Z_JITTER
    plt.scatter(c[:, 1], c[:, 0] + offsets_z, s=0.1, alpha=0.01, c="k")
plt.axis("off")

plt.subplot(132)
for m, c in zip(all_masks, all_coords):
    # Centre on the middle of the mask's max-projection along axis 2.
    prof = m.max(2)
    offsets_z = side_jitter_rng.random(c.shape[0]) * Z_JITTER
    x = prof.shape[0] // 2
    y = prof.shape[1] // 2
    plt.scatter(c[:, 1] - y, c[:, 0] - x + offsets_z, s=0.1, alpha=0.01, c="k")
plt.axis("off")

plt.subplot(133)
for c, p in zip(all_coords, path_list):
    offsets_z = side_jitter_rng.random(c.shape[0]) * Z_JITTER
    x, y = all_offsets[p.name][:-1]
    plt.scatter(c[:, 1] - y, c[:, 0] - x + offsets_z, s=0.1, alpha=0.01, c="k")
plt.axis("off")

In [None]:
# Stack the per-fish masks into a single array (presumably they all share one shape;
# ragged lists fail on recent NumPy).
all_masks = np.asarray(all_masks)

In [None]:
# Load traces and anatomy of the single example fish used in panels A and B.
path = master_path / "210611_f5"
fish_h5 = path / "data_from_suite2p_unfiltered.h5"

# Stored as (cells, time); transpose to (time, cells).
traces = fl.load(fish_h5, "/traces").T
# Remove each cell's linear trend in one vectorised call along time (axis 0), then z-score
# every cell. This replaces the per-column loop, which also cast the detrended values back
# into `traces`' dtype.
traces = detrend(traces, axis=0)
traces = (traces - np.nanmean(traces, 0)) / np.nanstd(traces, 0)

rois = fl.load(fish_h5, "/rois_stack")
anatomy = fl.load(fish_h5, "/anatomy_stack")
ot_mask = fl.load(path / "anatomy.mask", "/mask")

# Data for single cell

In [None]:
# Per-cell summary table (e.g. per-position reliability scores) of the example fish.
cells_df_file = path / "cell_df.h5"
cells_df = fl.load(cells_df_file)

In [None]:
# Stored per-position reliability scores of the example fish, as an (n_positions, n_cells) array.
# The rel_<i> columns are picked by name (numerically ordered), so this cell no longer depends
# on `pos`, which is only defined in a later cell.
# Assumes cells_df holds exactly one rel_<i> column per stimulus position.
rel_cols = sorted(
    (col for col in cells_df.columns
     if isinstance(col, str) and col.startswith("rel_") and col[4:].isdigit()),
    key=lambda col: int(col[4:]),
)
cells_df.loc[:, rel_cols].values.T

In [None]:
# Sanity check: stored reliability scores (cells_df) against the `rel_scores` array.
# FIXME(review): this cell uses `pos` and `rel_scores`, which are only defined in a later
# data cell, so it fails on a fresh kernel. Also, `rel_scores` is currently read from these
# same cells_df columns (its recomputation is commented out), so the scatter is an identity line.
plt.figure()
plt.scatter(cells_df.loc[:, [f"rel_{i}" for i in range(len(pos))]].values.T.flatten(), rel_scores.flatten())

In [None]:
coords = fl.load(path / "data_from_suite2p_unfiltered.h5", "/coords")

exp = EmbeddedExperiment(path)

# Imaging rate, read from the lightsheet z-scanning frequency.
# NOTE(review): int() truncates a non-integer rate (e.g. 2.5 -> 2). Confirm the rate is integer.
fs = int(exp["imaging"]["microscope_config"]["lightsheet"]["scanning"]["z"]["frequency"])

stim_df = stimulus_df_from_exp0070(exp)

# Crop every cell's trace around each stimulus onset, from PRE_INT_S before to POST_INT_S after.
# Result is indexed (time, trial, cell).
PRE_INT_S = 2
POST_INT_S = 5
pre_frames = int(PRE_INT_S * fs)
cropped = crop(traces, stim_df["t"] * fs, pre_int=pre_frames, post_int=int(POST_INT_S * fs))
# Baseline-subtract each trial with its mean over the pre-stimulus window.
cropped = cropped - cropped[:pre_frames, :, :].mean(0)

# Sorted unique stimulus positions.
pos = sorted(stim_df.loc[:, "pos_start"].unique())

# Per-position reliability scores, shape (len(pos), n_cells), read from the precomputed cell
# table. They were originally computed here as `reliability(cropped[:, trials, :])` over the
# trials at each position.
rel_scores = cells_df.loc[:, [f"rel_{i}" for i in range(len(pos))]].values.T

# Data for topology

In [None]:
# Pooled per-cell data from all fish, for the topology panel.
pooled_data = fl.load(master_path / "pooled_dfs.h5", "/all_cells_df")
# One reliability column per stimulus position, consistent with `pos` / `stim_thetas` below.
# This was hard-coded to 35 columns, which silently drops positions when there are more
# (the rest of the notebook works with 36).
all_responses = pooled_data.loc[:, [f"rel_{i}" for i in range(len(pos))]].values.T
# NOTE: this rebinds `all_coords` (earlier a per-fish list) to an (n_cells, 3) array of z, x, y.
all_coords = pooled_data.loc[:, ["z", "x", "y"]].values
all_in_tectum = pooled_data["in_tectum"].values

responsive = all_responses.max(0) > 0.5
all_peaks = np.argmax(all_responses, 0)  # index (into pos) of each cell's best position
stim_thetas = np.array(pos)

In [None]:
# Number of responsive cells (max reliability > 0.5) across all fish.
responsive.sum()

In [None]:
# Number of pooled cells flagged as inside the tectum.
all_in_tectum.sum()

# Data for histograms (per-cell and per-fish tables)

In [None]:
# NOTE(review): `fix_fid` and `center_on_peak` are not defined or imported anywhere in this
# notebook (presumably they came from the commented-out `from utilities import *`), so this
# cell fails on a fresh kernel until they are imported.
pooled_data["fid"] = pooled_data["cid"].apply(fix_fid)

REL_SCORE_THR = 0.5

exp_df = []
df_list = []
# The loop variable is deliberately not `path`, which holds the example fish used by the
# single-cell panels; reusing it made re-running those cells silently load another fish.
for fish_path in tqdm(path_list):
    k = fish_path.name
    data = fl.load(fish_path / "cell_resps.h5")
    # Group from fish-number parity: even -> "abl", odd -> "cnt".
    gen = ["abl", "cnt"][(int(k.split("_f")[1]) % 2)]
    rel_scores_fish = data["rel_scores"]  # (n_stims, n_cells)
    amp_scores = data["amp_scores"]
    in_tectum = data["in_tectum"]

    n_stims, n_cells = rel_scores_fish.shape

    reord_rel = center_on_peak(rel_scores_fish)
    reord_amp = center_on_peak(amp_scores)

    df = pd.DataFrame(np.concatenate([rel_scores_fish, amp_scores, reord_rel, reord_amp], 0).T,
                      columns=[f"{prefix}_{i}"
                               for prefix in ("rel", "amp", "rel_reord", "amp_reord")
                               for i in range(n_stims)])
    df["cid"] = [f"{k}_{i:05.0f}" for i in range(n_cells)]
    df["gen"] = gen
    df["fid"] = k
    df["in_tectum"] = in_tectum

    df["max_rel"] = np.nanmax(rel_scores_fish, 0)
    df["max_rel_i"] = np.argmax(rel_scores_fish, 0)
    df["max_amp"] = np.nanmax(amp_scores, 0)
    df["max_amp_i"] = np.argmax(amp_scores, 0)

    exp_df.append(dict(fid=k,
                       gen=gen,
                       n_cells=n_cells,
                       n_cells_tectum=sum(in_tectum),
                       above_rel_thr=np.sum(np.max(rel_scores_fish[:, in_tectum], 0) > REL_SCORE_THR),
                       above_rel_thr_all=np.sum(np.max(rel_scores_fish, 0) > REL_SCORE_THR)))
    df_list.append(df)

full_df = pd.concat(df_list, axis=0)
full_df = full_df.set_index(full_df["cid"])

# Per-fish summary table (no fish are excluded at this point).
exp_df = pd.DataFrame(exp_df).set_index("fid")

# Mean peak amplitude of responsive tectal cells, per fish. The comparison must be
# parenthesised: `&` binds tighter than `>`, so the previous form evaluated
# `(in_tectum & max_rel) > REL_SCORE_THR`. Only "max_amp" is averaged, so the string
# columns (cid, gen) are never passed to mean(), which pandas >= 2 rejects.
responsive_tectal = full_df["in_tectum"] & (full_df["max_rel"] > REL_SCORE_THR)
exp_df["mn_amplitude"] = full_df.loc[responsive_tectal].groupby("fid")["max_amp"].mean()

# Data for receptive field

In [None]:
from scipy.optimize import curve_fit

# Example receptive field: centred reliability profile of the first control ("cnt") cell whose
# max reliability exceeds 0.9, over the 36 re-ordered stimulus positions.
x = np.arange(36)
example_mask = (full_df["gen"] == "cnt") & (full_df["max_rel"] > 0.9)
reord_cols = [f"rel_reord_{i}" for i in range(36)]
y_single_neuron = full_df.loc[example_mask, reord_cols].values[0, :]

# Weighted mean and standard deviation of the profile, used as initial guesses for the fit.
weight_total = sum(y_single_neuron)
mean = sum(x * y_single_neuron) / weight_total
sigma = np.sqrt(sum(y_single_neuron * (x - mean)**2) / weight_total)

def gaussian(x, a, x0, sigma):
    """Unnormalised Gaussian ``a * exp(-(x - x0)**2 / (2 * sigma**2))``.

    Used as the receptive-field model for ``curve_fit``: ``a`` is the peak height,
    ``x0`` the centre and ``sigma`` the width, all in the units of ``x``.
    """
    sq_dist = (x - x0) ** 2
    two_var = 2 * sigma ** 2
    return a * np.exp(-sq_dist / two_var)

# Fit the example profile with a Gaussian, starting from its peak value and moment estimates.
initial_guess = [max(y_single_neuron), mean, sigma]
popt_singleneuron, pcov = curve_fit(gaussian, x, y_single_neuron, p0=initial_guess)

# Precomputed Gaussian fit parameters (amplitude, centre, sigma) of every cell, stored in
# fit.h5 and assigned to full_df rows by position.
popt = fl.load("fit.h5", "/popt")
fit_params = np.array(popt)
for col_i, par_name in enumerate(["fit_amp", "fit_mn", "fit_sigma"]):
    full_df[par_name] = fit_params[:, col_i]

# sigma only enters the model squared, so a fitted sigma may come out negative; keep its magnitude.
full_df["fit_sigma"] = np.abs(full_df["fit_sigma"])

# Mean RF width of responsive tectal cells, per fish (NaN for fish without such cells).
sigma_sel = (full_df["max_rel"] > REL_SCORE_THR) & full_df["in_tectum"]
exp_df["mean_sigma"] = [
    np.nanmean(full_df.loc[(full_df["fid"] == fish_id) & sigma_sel, "fit_sigma"])
    for fish_id in exp_df.index
]

# Main figure: shared plot style

In [None]:
# Shared matplotlib style for all figure panels: thin axes and ticks, 8 pt text.
plt.rcParams.update({
    "axes.linewidth": 0.5,
    "axes.labelsize": 8,
    "legend.fontsize": 8,
    "axes.titlesize": 8,
})
for axis_letter in "xy":
    plt.rcParams.update({
        f"{axis_letter}tick.major.size": 3,
        f"{axis_letter}tick.labelsize": 8,
        f"{axis_letter}tick.major.width": 0.5,
    })

# Panel A: stimulus description

In [None]:
# Panel A: stimulus sequence (row of small squares) and anatomy of the example fish.
# Clip masks from every second stimulus-log entry (starting at the second one).
clip_masks = [s["clip_mask"] for s in exp["stimulus"]["log"][1::2]]
titles = ["4 s", "2 s", "4 s", "2 s", "4 s"]
stimuli = [0, None, 20, None, 10]  # clip-mask index drawn in each square (None: no mask)

fig_a = plt.figure(figsize=(2.5, 3))

# Row of five square axes describing the sequence.
xpos, ypos, side = 0.1, 0.6, 0.16
axs = [fig_a.add_axes((xpos + side*1.1*i, ypos, side, side)) for i in range(5)]

for ax, title, stim_n in zip(axs, titles, stimuli):
    ax.fill([0, 1, 1, 0], [0, 0, 1, 1], fc="r")
    if stim_n is not None:
        mask = clip_masks[stim_n]
        ax.fill([m[0] for m in mask], [m[1] for m in mask], fc="k", lw=0, alpha=0.6)
    ax.set(ylim=(0.1, 0.9), xlim=(0.1, 0.9), xticks=[], yticks=[], title=title)
    ax.set_aspect('equal', adjustable='box')
    add_fish(ax, offset=[0.45, 0.08], scale=(30/15))
    despine(ax, sides="all")

# Anatomy: two planes of the example fish side by side, with the ot_mask outline on top.
planes = [2, 6]
pad = 10

an_ax = fig_a.add_axes((0, 0, 1, 0.5))
an_ax.imshow(np.concatenate([anatomy[i, pad:-pad, pad:-pad] for i in planes]).T,
             cmap="gray_r", origin="lower", vmax=100, vmin=0)
# NOTE(review): the anatomy is concatenated along axis 0 and transposed, while the mask is
# concatenated along axis 1 without transposing. These only overlay correctly if the mask's
# last two axes are swapped relative to the anatomy's. Confirm.
an_ax.contour(np.concatenate([ot_mask[i, pad:-pad, pad:-pad] for i in planes], axis=1),
              origin="lower", levels=[1], linewidths=0.5, colors=[cols[3]])

# Orientation bar: vertical arm labelled caud-rost, horizontal arm labelled l-r.
b_len = 100
bar_pos_x = rois.shape[1]
bar_pos_y = 400
labels = ["caud-rost", "l-r"]
an_ax.plot([bar_pos_x, bar_pos_x, bar_pos_x + b_len],
           [bar_pos_y - b_len, bar_pos_y, bar_pos_y], lw=0.5, c=(0.3,)*3)
an_ax.text(bar_pos_x, bar_pos_y - b_len/2, labels[0], ha="right", va="center",
           rotation='vertical', fontsize=8)
an_ax.text(bar_pos_x + b_len/2, bar_pos_y + 10, labels[1], ha="center", va="bottom", fontsize=8)
an_ax.text(rois.shape[1]*2 - pad*5, rois.shape[1] - pad*7,
           "Huc:H2B-GCaMP6s", fontsize=7, ha="right", va="top", c=(0.5,)*3)

an_ax.axis("off")

# Panel B: example cells

In [None]:
plt.close("all")
fig_b = plt.figure(figsize=(7, 3))

# Two side-by-side regions of the figure, one per example cell.
m_xpos, m_ypos, xside, yside = 0.15, 0., 0.43, 1
bounds_lims = [(m_xpos + xside*1.05*i, m_ypos, xside, yside) for i in range(2)]

x_time = np.arange(0, cropped.shape[0])/fs - PRE_INT_S  # time relative to stimulus onset (s)
# Layout of the ring of small response axes around the central polar plot
# (all in fractions of a cell's region):
ax_w = 0.12        # side of each small axes
ax_c = (0.5, 0.5)  # ring centre
ax_r = 0.37        # ring radius
n_stims = 36       # number of angular slots in the ring
scaling_percentiles = [1, 99.89]  # only the upper percentile sets the y-limits

p_w = 0.22  # half-side of the central polar plot

EXAMPLE_CELLS = [30, 10818]
for i_cell, (xpos, ypos, xside, yside) in zip(EXAMPLE_CELLS, bounds_lims):

    cell_plane = int(coords[i_cell, 0])  # plane in which the cell is found

    # Symmetric y-limits shared by all of this cell's response axes.
    y_high = np.nanpercentile(cropped[:, :, i_cell], scaling_percentiles[1])

    # Ring of single-trial (grey) and mean responses, one small axes per stimulus position.
    # NOTE(review): stim_df index labels are used as positions along cropped's trial axis,
    # which assumes stim_df has a default RangeIndex.
    for i, th in enumerate(pos):
        resps_idxs = stim_df[stim_df["pos_start"] == th].index

        plot_th = (np.pi*2 / n_stims) * i - np.pi + (10*np.pi/2) / 360
        x = np.cos(plot_th)*ax_r
        y = -np.sin(plot_th)*ax_r

        ax = fig_b.add_axes((xpos + (ax_c[0] - ax_w/2 + x)*xside,
                             ypos + (ax_c[1] - ax_w/2 + y)*yside,
                             ax_w*xside, ax_w*yside))

        ax.plot(x_time, cropped[:, resps_idxs, i_cell], lw=0.3, c=(0.4,)*3)
        ax.plot(x_time, cropped[:, resps_idxs, i_cell].mean(1), lw=1, c=cols[3])
        # Bar over 0-4 s (presumably the stimulus presentation window).
        ax.plot([0, 4], [-1.5, -1.5], c=cols[0], alpha=0.6, lw=1)
        ax.set(xlim=(-2, 7), ylim=(-y_high, y_high))
        ax.axis("off")

    # Scale bar at the bottom-left of the ring: vertical = response scale, horizontal = time.
    ax = fig_b.add_axes((xpos + (ax_c[0] - ax_w/2 - 1/np.sqrt(2)*ax_r)*xside - 0.05,
                         ypos + (ax_c[1] - ax_w/2 - 1/np.sqrt(2)*ax_r)*yside,
                         ax_w*xside, ax_w*yside))
    ax.set(xlim=(-2, 7), ylim=(-y_high, y_high))

    bar_pos_x, bar_pos_y = 0, 0
    bar_s_x, bar_s_y = (4, 4)  # horizontal (s) and vertical (z-scored dF/F) bar lengths
    ax.axis("off")
    ax.plot([bar_pos_x, bar_pos_x, bar_pos_x+bar_s_x],
            [bar_pos_y+bar_s_y, bar_pos_y, bar_pos_y], lw=0.5, c=(0.3,)*3)
    # Each arm is labelled with its own length. The two were swapped before, which was
    # invisible only because both lengths happen to be equal.
    ax.text(bar_pos_x-1, bar_pos_y + bar_s_y/2, f"{bar_s_y} dF/F (Z-sc.)", ha="right", va="center",
            rotation='vertical', fontsize=8)
    ax.text(bar_pos_x + bar_s_x/2, bar_pos_y-1, f"{bar_s_x} s", ha="center", va="top", fontsize=8)
    ax.patch.set_alpha(0.)

    # Central polar bar plot of the cell's reliability at each stimulus position (10° wide bars).
    ax = fig_b.add_axes((xpos + (ax_c[0]-p_w)*xside,
                         ypos + (ax_c[1] - p_w)*yside,
                         p_w*2*xside, p_w*2*yside), polar=True)
    ax.bar(-np.array(pos), rel_scores[:, i_cell], lw=0, width=np.pi/18)
    ax.set_thetagrids([])
    ax.spines['polar'].set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_rgrids([0.5, 1], angle=0, fontsize=7)
    ax.text(-0.15, 0.85,
            'Reliability', rotation=0, ha='center', va='center', fontsize=8)

    # Small inset: outline of this cell's ROI and of ot_mask in the cell's plane.
    anatomy_ax = fig_b.add_axes((xpos - 0.12*xside, ypos + 0.75*yside, 0.3*xside, 0.3*yside))
    anatomy_ax.set_xlim(-5, rois.shape[1])
    anatomy_ax.contour(rois[cell_plane, :, :] == i_cell, origin="lower", levels=[1],
                       linewidths=2, colors=[cols[3]])
    anatomy_ax.contour(ot_mask[cell_plane, :, :], origin="lower", levels=[1],
                       linewidths=0.5, colors=[(0.5,)*3])
    anatomy_ax.axis("off")

# Panel C: topology

In [None]:
# FIXME(review): leftover inspection cell. `filt` is only defined inside the topology-panel
# loop further down, so this raises NameError on a fresh kernel.
all_coords[filt, 1]

In [None]:
# Quick look at the pooled cell positions: coordinate column 1 against column 0.
coords_fig, coords_ax = plt.subplots()
coords_ax.scatter(all_coords[:, 1], all_coords[:, 0])

In [None]:
# Manually annotated (x, y) pair per fish (their meaning is not documented here — TODO confirm).
# Entries not annotated yet are None. The previous `(,)` placeholders are not valid
# Python and made this whole cell fail with a SyntaxError.
manual_reference_xy = {
    "210611_f2": (366, 202),
    "210611_f3": (363, 197),
    "210611_f4": (399, 299),
    "210611_f5": (313, 160),
    "210611_f15": (360, 177),
    "210621_f1": (370, 176),
    "210621_f2": (380, 162),
    "210621_f3": None,
    "210621_f4": None,
    "210621_f10": None,
    "210621_f11": None,
    "210621_f13": None,
    "210621_f14": None,
    "210621_f20": None,
    "210621_f21": None,
    "210621_f22": None,
    "210624_f0": None,
    "210624_f1": None,
    "210624_f3": None,
    "210624_f4": None,
    "210624_f5": None,
    "210624_f11": None,
    "210624_f14": None,
}
manual_reference_xy

In [None]:
def _shift_90_deg(array):
    out = array + np.pi/2
    out[out > np.pi] = -np.pi + np.mod(out[out > np.pi], np.pi)
    return out

In [None]:
# Panel C: topographic map. Responsive tectal cells of each group, coloured by the
# (90°-rotated) angle of their preferred stimulus position. The top view (coordinate columns
# 1 vs 2) sits above; a side view (column 1 vs column 0) is shifted down by `spacing`.
fig_c = plt.figure(figsize=(7.5, 3))
m_xpos, m_ypos, xside, yside = 0.2, 0.1, 0.5, 0.7
anat_scatt_size = 3
# Fixed colour limits, so the cyclic colormap encodes the same angle in every scatter call.
# Without them each call rescales the colours to its own data range. Assumes the stimulus
# angles are in radians within [-pi, pi], so that _shift_90_deg output lies in (-pi, pi].
THETA_LIMS = dict(vmin=-np.pi, vmax=np.pi)
topo_jitter_rng = np.random.default_rng(0)  # seeded vertical jitter -> reproducible figure

all_axs = [fig_c.add_axes((m_xpos + 0.65*i*xside, m_ypos,
                           xside, yside)) for i in range(2)]

spacing = 300
b_len = 100
bar_pos_x = -100
for g_i, g in enumerate(["MTZ-cnt", "OPC-abl"]):
    ax = all_axs[g_i]
    group_sel = responsive & all_in_tectum & (pooled_data["gen"] == g).values

    # Top view: one scatter call per imaging plane.
    for plane in sorted(np.unique(all_coords[:, 0])):
        filt = group_sel & (all_coords[:, 0] == plane)
        ax.scatter(all_coords[filt, 1], all_coords[filt, 2],
                   c=_shift_90_deg(stim_thetas[all_peaks[filt]]),
                   cmap="twilight_shifted", s=anat_scatt_size, **THETA_LIMS)
    ax.set_aspect('equal', adjustable='box')
    ax.set_title(g)
    despine(ax, sides="all")

    # Side view: cells binned into 50-unit slabs along column 2, drawn from the highest slab
    # down, with up to 14 units of random vertical jitter.
    for i in range(12, 0, -1):
        filt = group_sel & (all_coords[:, 2] > i*50) & (all_coords[:, 2] < (i+1)*50)
        ax.scatter(all_coords[filt, 1],
                   all_coords[filt, 0] - spacing + topo_jitter_rng.random(filt.sum())*14,
                   c=_shift_90_deg(stim_thetas[all_peaks[filt]]),
                   cmap="twilight_shifted", s=anat_scatt_size, **THETA_LIMS)

    # Orientation bars, drawn only on the first group's axes.
    if g_i == 0:
        for labels, bar_pos_y in zip([["caud-rost", "l-r"], ["vent-dors", "l-r"]],
                                     [200, 200 - spacing]):
            ax.plot([bar_pos_x, bar_pos_x, bar_pos_x+b_len],
                    [bar_pos_y-b_len, bar_pos_y, bar_pos_y], lw=0.5, c=(0.3,)*3)
            ax.text(bar_pos_x, bar_pos_y - b_len/2, labels[0], ha="right", va="center",
                    rotation='vertical', fontsize=8)
            ax.text(bar_pos_x + b_len/2, bar_pos_y + 10, labels[1], ha="center", va="bottom",
                    fontsize=8)
    ax.set(xlim=(-120, 600))

# Colour-wheel legend: angle of each point of a square grid, rotated like the data.
val_range = np.linspace(-np.pi, np.pi, 100)
wheel_ax = fig_c.add_axes((0., 0.35, 0.25, 0.25))
wheel_ax.imshow(_shift_90_deg(np.angle(val_range[:, None] + 1j*val_range[None, :]).T),
                cmap="twilight_shifted", extent=[0, 1, 0, 1], **THETA_LIMS)
wheel_ax.text(0.5, 1.05, "stim. θ", fontsize=7, va="bottom", ha="center")
wheel_ax.axis("equal")
wheel_ax.axis("off")
add_fish(wheel_ax, offset=[0.45, 0.0], scale=1.7)

# Panel D: reliability histogram

In [None]:
# Figure size of the reliability-histogram panel, and one colour per group
# (cnt: grey, abl: palette colour 2).
HIST_FIG_SIZE = (4.0, 3)
group_colors = [(0.4, 0.4, 0.4), cols[2]]

def hist_and_scatter(fig, hist_key, hist_range=None, hist_label=None, scatter_key=None,
                     scatter_coef=1, scatter_label=None):
    """Draw a per-fish histogram and a per-fish summary scatter comparing "cnt" vs "abl".

    Left axes: for every fish, the density histogram of ``full_df[hist_key]`` over its tectal
    cells (thin lines), plus the mean histogram of each group (thick line), on a log y-axis.
    Right axes: one point per fish of ``exp_df[scatter_key] * scatter_coef``, split by group,
    annotated with the significance of a Wilcoxon rank-sum test between the groups.

    Parameters
    ----------
    fig : matplotlib Figure to draw into.
    hist_key : column of the global ``full_df`` to histogram.
    hist_range : passed to ``np.histogram`` as its ``bins`` argument (e.g. an array of edges).
    hist_label : x-axis label of the histogram.
    scatter_key : column of the global ``exp_df`` to compare between groups.
    scatter_coef : factor applied to the plotted values only (the test uses raw values).
    scatter_label : y-axis label of the scatter.

    Returns
    -------
    (histogram axes, scatter axes)
    """
    SCAT_DISP = 50  # inverse scale of the horizontal jitter of the scatter points
    hist_box = (0.1, 0.25, 0.4, 0.5)
    scat_box = (0.75, 0.25, 0.2, 0.5)
    p_val_size = 8
    groups = ["cnt", "abl"]
    rng = np.random.default_rng(0)  # seeded jitter -> reproducible figure

    axs = fig.add_axes(hist_box)

    rel_histograms = dict()
    for i, g in enumerate(groups):
        sel_fids = full_df.loc[full_df["gen"] == g, "fid"].unique()
        all_hists = []
        for f in sel_fids:
            h, bins = np.histogram(full_df.loc[(full_df["fid"] == f) & full_df["in_tectum"], hist_key],
                                   hist_range, density=True)
            all_hists.append(h)
        if not all_hists:
            continue  # no fish in this group: nothing to draw (and `bins` would be undefined)
        rel_histograms[g] = np.array(all_hists)

        x_bins = (bins[1:] + bins[:-1]) / 2
        axs.plot(x_bins, rel_histograms[g].T, c=group_colors[i], lw=0.2)
        axs.plot(x_bins, rel_histograms[g].mean(0), c=group_colors[i], lw=2, label=g)
    # Legend on the histogram axes explicitly, not whatever pyplot considers current.
    axs.legend(frameon=False)
    axs.set(yscale="log", xlabel=hist_label)
    despine(axs)

    scat_axs = fig.add_axes(scat_box)
    for i, g in enumerate(groups):
        sel = exp_df.loc[exp_df["gen"] == g, scatter_key]
        scat_axs.scatter(rng.standard_normal(len(sel))/SCAT_DISP + i, sel*scatter_coef,
                         c=group_colors[i], s=8)
    # Fish without a value (NaN, e.g. no responsive tectal cells) are left out of the test;
    # otherwise NaNs propagate into (or distort) the rank-sum statistic.
    diff_p = ranksums(*[exp_df.loc[exp_df["gen"] == g, scatter_key].dropna() for g in groups])
    scat_axs.set(xlim=[-0.5, 1.5], xticks=[0, 1], xticklabels=groups,
                 ylabel=scatter_label)

    print(f"p={diff_p.pvalue:0.4f}")
    pval = "n.s."
    for p_thr, stars in [(0.05, "*"), (0.01, "**"), (0.001, "***")]:
        if diff_p.pvalue < p_thr:
            pval = stars
    scat_axs.text(0.5, 0.800, pval, fontsize=p_val_size, ha="center")
    despine(scat_axs)

    return axs, scat_axs
    
# Panel D: per-fish histograms of max reliability, and number of responsive tectal ROIs per fish.
# Bin edges span the full [0, 1] range. np.arange(0, 1, 0.02) stopped at 0.98 and silently
# dropped the most reliable cells from the histogram.
fig_d = plt.figure(figsize=HIST_FIG_SIZE)
axs, scat_axs = hist_and_scatter(fig_d, hist_key="max_rel", hist_range=np.linspace(0, 1, 51),
                                 hist_label="reliability score", scatter_key="above_rel_thr",
                                 scatter_coef=1/1000, scatter_label="responsive rois (thous.)",
                                 )
# Mark the responsiveness threshold used throughout the analysis.
axs.axvline(REL_SCORE_THR, lw=0.5, c=(0.4,)*3)

# Panel E: amplitude histogram

In [None]:
# Panel E: per-fish histograms of each cell's peak response amplitude, and per-fish mean
# amplitude of responsive tectal cells.
amp_panel_kwargs = dict(hist_key="max_amp", hist_range=np.arange(0, 6, 0.2),
                        hist_label="resp. amplitude", scatter_key="mn_amplitude",
                        scatter_label="average amplitude")
fig_e = plt.figure(figsize=(4.5, 3))
axs, scat_axs = hist_and_scatter(fig_e, **amp_panel_kwargs)

In [None]:
def _conv_to_dist(val):
    return (val - 18) * 10
x_range = _conv_to_dist(np.arange(36))

# Panel F: receptive-field profiles. Bottom axes: example cell with its Gaussian fit and the
# mu ± sigma band. Top axes: profiles of responsive tectal cells normalised to their centre
# value (every 10th in light grey) and their median.
m_xpos, m_ypos, xside, yside = 0.3, 0.25, 0.6, 0.25
fig_f = plt.figure(figsize=(3, 3))
axs = [fig_f.add_axes((m_xpos, m_ypos + 1.2*i*yside,
                       xside, yside)) for i in range(2)]

data = full_df.loc[(full_df["max_rel"] > REL_SCORE_THR) & full_df["in_tectum"],
                   [f"rel_reord_{i}" for i in range(36)]].values.T
data = data / data[18, :]  # normalise each profile by its value at the centre position (18)
axs[1].plot(x_range, data[:, ::10], lw=0.3, c=(0.8,)*3)
axs[1].plot(x_range, np.nanmedian(data, 1), lw=2, c=(0.5,)*3)

axs[0].plot(x_range, y_single_neuron, 'o', label='data', lw=1, c=(0.5,)*3, markersize=3)
axs[0].plot(x_range, gaussian(np.arange(len(x_range)),
                              *popt_singleneuron), 'r-', label='fit', lw=1)
# Shade mu - sigma .. mu + sigma of the fit, i.e. a band of width 2 sigma.
sn_mu, sn_sigma = popt_singleneuron[1], popt_singleneuron[2]
axs[0].axvspan(_conv_to_dist(sn_mu - sn_sigma),
               _conv_to_dist(sn_mu + sn_sigma),
               fc=(0.9,)*3, lw=0)
# Raw string: "\c" and "\s" are not valid Python escapes (SyntaxWarning on Python >= 3.12).
# "\cdot" gives "2·σ"; the previous "\dot" is an accent that rendered a dotted sigma.
axs[0].text(_conv_to_dist(sn_mu), 1.05,
            r"$2 \cdot \sigma$", fontsize=8, c=(0.4,)*3, ha="center")
axs[0].set(xlabel="distance from peak (°)", ylabel="reliability", ylim=(-0.1, 1.2))
axs[1].set(xticklabels=[], ylabel="reliability")
# Despine this figure before showing it (it was previously applied after plt.show()).
sns.despine(fig=fig_f)
plt.show()

# Panel G: RF size

In [None]:
# Panel G: per-fish histograms of fitted RF width (sigma), and per-fish mean sigma.
# Raw strings: "\s" is not a valid Python escape (SyntaxWarning on Python >= 3.12);
# the resulting text is unchanged.
fig_g = plt.figure(figsize=(4.5, 3))
axs, scat_axs = hist_and_scatter(fig_g, hist_key="fit_sigma", hist_range=np.arange(0, 8, 0.2),
                                 hist_label=r"$\sigma$", scatter_key="mean_sigma",
                                 scatter_label=r"average $\sigma$",
                                 )

# Assemble the composite figure with trifle

In [None]:
# Figure-composition helpers from trifle. Only the names used below are imported (instead of
# `*`), so it is clear where they come from. %autoreload 2 is already enabled at the top.
from trifle import compose_figure, transfer_fig_list

In [None]:
# Panel layout for the composite figure (presumably one inner list per row), lettered automatically.
flist = [
    [fig_a, fig_b],
    [fig_c, fig_d],
    [fig_e, fig_f, fig_g],
]

compfig, axes = compose_figure(flist, fig_width=8, enumeration="letters")

In [None]:
# Copy the content of each panel figure in `flist` into its slot (`axes`) of the composite figure.
transfer_fig_list(flist, axes)

In [None]:
# Save the composite figure. The path is built from the home directory instead of a hard-coded
# /Users/<name>/ prefix, so the notebook runs on other machines (presumably the same file for
# the original author). The Desktop folder must exist.
OUTPUT_PATH = Path.home() / "Desktop" / "assembled_hr.png"
compfig.savefig(OUTPUT_PATH, dpi=200)