In [None]:
%matplotlib widget
%load_ext autoreload

In [None]:
%autoreload 2
from pathlib import Path
import numpy as np
from bouter import EmbeddedExperiment
import pandas as pd
from tqdm import tqdm
import flammkuchen as fl

from matplotlib import pyplot as plt
import seaborn as sns
sns.set(palette="deep", style="ticks")
cols = sns.color_palette()

from scipy.stats import ranksums, kstest, ttest_ind

In [None]:
# NOTE(review): absolute, user-specific path; `pooled_dicts` is not read by any visible cell.
POOLED_RESPS_FILE = "/Users/luigipetrucco/Desktop/pooled_resps.h5"
pooled_dicts = fl.load(POOLED_RESPS_FILE)

In [None]:
# Root folder of the experiment series; every sub-folder below it is one fish.
master_path = Path("/Volumes/Shared/experiments/E0070_receptive_field/v04_flashing_rad_simple")
pooled_file = master_path / "new_pooled.h5"
pooled_data = fl.load(pooled_file)

In [None]:


# One experiment folder per fish: every sub-folder of master_path containing a suite2p export.
# Sorted so that the fish order (hence path_list[0] and the row order of full_df) does not
# depend on the filesystem's directory-listing order.
# NOTE: a fit.h5 cached from an unsorted run has rows in the old order; recompute it.
path_list = sorted(f.parent for f in master_path.glob("*/data_from_suite2p_unfiltered.h5"))

In [None]:
# Peek at the per-cell response scores of the first experiment.
first_path = path_list[0]
data = fl.load(first_path / "cell_resps.h5")

In [None]:
# `path` was previously only bound by the loop in a *later* cell, so this line raised
# NameError on a fresh kernel. Bind it explicitly to the same experiment that loop inspects.
path = path_list[0]
data = fl.load(path / "cell_resps.h5")

In [None]:
# Open the first experiment only and derive its short genotype code from the folder name:
# the number after "_f" is even -> "abl", odd -> "cnt".
for path in tqdm(path_list[:1]):
    k = path.name
    exp = EmbeddedExperiment(path)
    fish_n = int(k.split("_f")[1])
    gen = ["abl", "cnt"][fish_n % 2]

In [None]:
# Full (long) genotype string from the experiment metadata, replacing the short "abl"/"cnt" code.
animal_meta = exp["general"]["animal"]
gen = animal_meta["genotype"]

In [None]:
# Precomputed per-cell table of the same experiment.
cell_df_path = path / "cell_df.h5"
df = fl.load(cell_df_path)

In [None]:
# Genotype label stored for the row with index 0.
df["gen"].loc[0]

In [None]:
REL_SCORE_THR = 0.5  # a cell counts as responsive if its max reliability score exceeds this

exp_records = []  # one summary dict per fish -> exp_df
df_list = []  # one per-cell DataFrame per fish -> full_df
for path in tqdm(path_list):
    k = path.name  # fish id (fid)
    data = fl.load(path / "cell_resps.h5")
    # Genotype from the parity of the number after "_f" in the folder name:
    # even -> "abl", odd -> "cnt".
    gen = ["abl", "cnt"][(int(k.split("_f")[1]) % 2)]
    rel_scores = data["rel_scores"]
    amp_scores = data["amp_scores"]
    # Force a boolean per-cell mask: an integer 0/1 array would instead select
    # columns 0 and 1 in rel_scores[:, in_tectum] below.
    in_tectum = np.asarray(data["in_tectum"], dtype=bool)

    n_cells = rel_scores.shape[1]
    n_stims = rel_scores.shape[0]

    # NOTE(review): center_on_peak is neither defined nor imported in this notebook (hidden
    # state from a deleted cell or an autoreloaded module?). The RF cell later normalizes
    # by index 18, i.e. it assumes each profile is re-centered on its peak — confirm.
    reord_rel = center_on_peak(rel_scores)
    reord_amp = center_on_peak(amp_scores)

    fish_df = pd.DataFrame(np.concatenate([rel_scores, amp_scores, reord_rel, reord_amp], 0).T,
                           columns=[f"rel_{i}" for i in range(n_stims)] +
                                   [f"amp_{i}" for i in range(n_stims)] +
                                   [f"rel_reord_{i}" for i in range(n_stims)] +
                                   [f"amp_reord_{i}" for i in range(n_stims)])
    fish_df["cid"] = [f"{k}_{i:05.0f}" for i in range(n_cells)]
    fish_df["gen"] = gen
    fish_df["fid"] = k
    fish_df["in_tectum"] = in_tectum

    # NaN-aware peaks: argmax with NaN mapped to -inf, so *_i indexes the value returned by
    # np.nanmax (plain np.argmax returns the position of the first NaN instead).
    fish_df["max_rel"] = np.nanmax(rel_scores, 0)
    fish_df["max_rel_i"] = np.argmax(np.where(np.isnan(rel_scores), -np.inf, rel_scores), 0)
    fish_df["max_amp"] = np.nanmax(amp_scores, 0)
    fish_df["max_amp_i"] = np.argmax(np.where(np.isnan(amp_scores), -np.inf, amp_scores), 0)

    # nanmax here as well, so the per-fish counts agree with the per-cell max_rel column.
    exp_records.append(dict(fid=k,
                            gen=gen,
                            n_cells=n_cells,
                            n_cells_tectum=np.sum(in_tectum),
                            above_rel_thr=np.sum(np.nanmax(rel_scores[:, in_tectum], 0) > REL_SCORE_THR),
                            above_rel_thr_all=np.sum(np.nanmax(rel_scores, 0) > REL_SCORE_THR)))
    df_list.append(fish_df)

full_df = pd.concat(df_list, axis=0)
full_df = full_df.set_index(full_df["cid"])

exp_df = pd.DataFrame(exp_records).set_index("fid")

In [None]:
# Number of fish (non-null entries per column) in each genotype group.
exp_df.groupby("gen").agg("count")

In [None]:
rel_histograms = dict()
# Reliability scores are presumably bounded to [0, 1]. linspace includes 1.0 as the (closed)
# last edge, whereas np.arange(0, 1, 0.02) stopped at 0.98 and silently dropped the most
# reliable cells from every histogram.
rel_bins = np.linspace(0, 1, 51)
x_bins = (rel_bins[1:] + rel_bins[:-1]) / 2

plt.figure(figsize=(4, 3))
plt.title("Reliability score")
for i, g in enumerate(["cnt", "abl"]):
    # One density histogram of max reliability (tectal cells only) per fish of genotype g.
    sel_fids = full_df.loc[full_df["gen"] == g, "fid"].unique()
    print(sel_fids)
    all_hists = []
    for f in sel_fids:
        h, _ = np.histogram(full_df.loc[(full_df["fid"] == f) & full_df["in_tectum"], "max_rel"],
                            rel_bins, density=True)
        all_hists.append(h)
    rel_histograms[g] = np.array(all_hists)

    plt.plot(x_bins, rel_histograms[g].T, c=cols[i], lw=0.2)  # individual fish
    # nanmean: a fish with no tectal cells yields an all-NaN density histogram.
    plt.plot(x_bins, np.nanmean(rel_histograms[g], 0), c=cols[i], lw=2, label=g)

plt.legend(frameon=False)
plt.yscale("log")
plt.xlabel("Reliability score")
plt.ylabel("Probability density")  # axis is log-scaled; the values are p, not log(p)
sns.despine()
plt.tight_layout()

In [None]:
amp_histograms = dict()  # genotype -> (n_fish, n_bins) array of density histograms
amp_bins = np.arange(0, 6, 0.2)
x_bins = (amp_bins[1:] + amp_bins[:-1]) / 2

plt.figure(figsize=(4, 3))
plt.title("Response amplitude")

for i, g in enumerate(["cnt", "abl"]):
    sel_fids = full_df.loc[(full_df["gen"] == g), "fid"].unique()
    print(sel_fids)
    all_hists = []
    for f in sel_fids:
        h, _ = np.histogram(full_df.loc[(full_df["fid"] == f) & full_df["in_tectum"], "max_amp"],
                            amp_bins, density=True)
        all_hists.append(h)
    # Store the stacked histograms per genotype. Previously this array was written into
    # all_hists[i], overwriting the i-th fish's histogram (IndexError with a single fish).
    amp_histograms[g] = np.array(all_hists)

    plt.plot(x_bins, amp_histograms[g].T, c=cols[i], lw=0.2)  # individual fish
    plt.plot(x_bins, np.nanmean(amp_histograms[g], 0), c=cols[i], lw=2, label=g)

plt.legend(frameon=False)
sns.despine()
plt.yscale("log")
plt.xlabel("Response amplitude")
plt.ylabel("Probability density")
plt.tight_layout()
plt.show()

In [None]:
# Mean peak amplitude of responsive tectal cells, per fish.
# Parentheses are required: `a & b > thr` parses as `(a & b) > thr`.
amp_sel = full_df["in_tectum"] & (full_df["max_rel"] > REL_SCORE_THR)
# Select the column before aggregating so string columns (cid, gen) are not averaged.
exp_df["mn_amplitude"] = full_df.loc[amp_sel].groupby("fid")["max_amp"].mean()

plt.figure(figsize=(3, 3))
sns.swarmplot(data=exp_df, x="gen", y="mn_amplitude", order=["cnt", "abl"])
# Fish without responsive cells have NaN; drop them so NaN does not enter the rank-sum test.
diff_p = ranksums(*[exp_df.loc[exp_df["gen"] == g, "mn_amplitude"].dropna() for g in ["cnt", "abl"]])
plt.xlabel("Genotype")
plt.ylabel("Mean response amplitude")

# Axes coordinates keep the p-value inside the panel regardless of the data range.
plt.text(0.5, 0.95, f"p={diff_p.pvalue:0.4f}", ha="center", va="top", fontsize=10,
         transform=plt.gca().transAxes)
sns.despine()
plt.tight_layout()  # after labels are set, so they are taken into account

## Number of responsive tectal cells per fish

In [None]:
# NOTE: despite its name, "fraction_resp" holds the *number* of tectal cells whose max
# reliability exceeds REL_SCORE_THR (copied from "above_rel_thr"), matching the "(n)" label.
exp_df["fraction_resp"] = exp_df["above_rel_thr"]
plt.figure(figsize=(3, 3))
sns.swarmplot(data=exp_df, x="gen", y="fraction_resp", order=["cnt", "abl"])
diff_p = ranksums(*[exp_df.loc[exp_df["gen"] == g, "fraction_resp"].dropna() for g in ["cnt", "abl"]])
plt.xlabel("Genotype")
plt.ylabel("Responsive cells (n)")

# Axes coordinates instead of the hard-coded data y=800, which depended on the cell counts.
plt.text(0.5, 0.95, f"p={diff_p.pvalue:0.4f}", ha="center", va="top",
         transform=plt.gca().transAxes)
sns.despine()
plt.tight_layout()

In [None]:
# Distribution of per-cell max reliability by genotype. Unlike the other panels this uses
# all cells, not only in_tectum ones.
# NOTE(review): cells of the same fish are not independent samples, so this cell-level
# p-value is optimistic compared with the per-fish tests.
plt.figure(figsize=(3, 3))
sns.violinplot(data=full_df, x="gen", y="max_rel", order=["cnt", "abl"])
diff_p = ranksums(*[full_df.loc[full_df["gen"] == g, "max_rel"].dropna() for g in ["cnt", "abl"]])
plt.xlabel("Genotype")
plt.ylabel("Max reliability score")  # was mislabelled "Responsive cells (%)"

# Axes coordinates: the old data y=6 lay far outside the reliability range.
plt.text(0.5, 0.95, f"p={diff_p.pvalue:0.4f}", ha="center", va="top",
         transform=plt.gca().transAxes)
sns.despine()
plt.tight_layout()

## Receptive field (RF) profiles

In [None]:
plt.close("all")
n_pos = 36  # number of peak-centered rel_reord_* columns
center = n_pos // 2  # index 18; each profile is normalized by its value here
rf_cols = [f"rel_reord_{j}" for j in range(n_pos)]
x_range = (np.arange(n_pos) - center) * 10  # 10° per stimulus position, 0 = preferred location
gen_colors = {"cnt": cols[0], "abl": cols[1]}  # same genotype -> color mapping as the other figures
gen_labels = {"cnt": "Control", "abl": "Ablated"}


def _shade(color, delta):
    """Return RGB `color` shifted by `delta` on every channel, clipped to the valid [0, 1] range."""
    return tuple(np.clip(np.asarray(color, dtype=float) + delta, 0, 1))


def _normalized_rf_profiles(gen):
    """Return (positions x cells) peak-centered reliability profiles of responsive tectal
    cells of genotype `gen`, each divided by its value at the center position."""
    sel = (full_df["gen"] == gen) & (full_df["max_rel"] > REL_SCORE_THR) & full_df["in_tectum"]
    profiles = full_df.loc[sel, rf_cols].values.T
    return profiles / profiles[center, :]


f, axs = plt.subplots(3, 1, figsize=(4, 6), sharex=True)

# Top two panels, one genotype each: thin lines = single cells, thick line = median.
# Panel labels are now derived from the genotype actually plotted (they were swapped before).
# The loop order and variable `k` are kept, so `k` is still left bound to "cnt" afterwards.
for i, k in enumerate(["abl", "cnt"]):
    rf_profiles = _normalized_rf_profiles(k)
    axs[i].plot(x_range, rf_profiles, lw=0.3, c=_shade(gen_colors[k], 0.1))
    axs[i].plot(x_range, np.nanmedian(rf_profiles, 1), lw=2, c=_shade(gen_colors[k], -0.1))
    axs[i].set_ylabel(k.capitalize())

# Bottom panel: median and interquartile range of both genotypes overlaid.
for k in ["abl", "cnt"]:
    rf_profiles = _normalized_rf_profiles(k)
    shade = _shade(gen_colors[k], -0.1)
    quart_1 = np.nanpercentile(rf_profiles, 25, axis=1)
    quart_3 = np.nanpercentile(rf_profiles, 75, axis=1)
    axs[2].fill_between(x_range, quart_1, quart_3, lw=0, alpha=0.5, fc=shade)
    axs[2].plot(x_range, np.nanmedian(rf_profiles, 1), lw=2, c=shade, label=gen_labels[k])

axs[2].legend(frameon=False)
axs[2].set_xlabel("Distance from preferred loc (°)")
axs[2].set_ylabel("Cnt vs Abl")
sns.despine(fig=f)  # before show(), so the displayed figure is despined
plt.tight_layout()
plt.show()

In [None]:
from scipy.optimize import curve_fit

# Example cell used to illustrate the Gaussian RF fit. The genotype is explicit; it used to
# come from the loop variable `k` leaked by the RF-plot cell (whose last value was "cnt").
EXAMPLE_GEN = "cnt"
EXAMPLE_MIN_REL = 0.9  # only very reliable cells give a clean example profile

x = np.arange(36)  # position index of the rel_reord_* columns
example_rows = full_df.loc[(full_df["gen"] == EXAMPLE_GEN) & (full_df["max_rel"] > EXAMPLE_MIN_REL),
                           [f"rel_reord_{i}" for i in range(36)]].values
assert len(example_rows) > 0, f"no {EXAMPLE_GEN} cell with max_rel > {EXAMPLE_MIN_REL}"
y = example_rows[0, :]

# Moment-based initial guesses (center and width) for the fit; later cells may reuse them.
mean = np.sum(x * y) / np.sum(y)
sigma = np.sqrt(np.sum(y * (x - mean) ** 2) / np.sum(y))

def gaussian(x, a, x0, sigma):
    """Unnormalized Gaussian bump with peak height `a` at `x0` and width `sigma`.

    Evaluates element-wise on numpy arrays, so it can be passed directly as the model
    function to `scipy.optimize.curve_fit`. The value is unchanged if `sigma` flips sign.
    """
    exponent = -(x - x0)**2 / (2 * sigma**2)
    return a * np.exp(exponent)

# Fit the example profile and overlay data and fitted curve against x_range.
popt, pcov = curve_fit(gaussian, x, y, p0=[max(y), mean, sigma])

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(x_range, y, 'o', label='data', lw=1)
ax.plot(x_range, gaussian(x, *popt), 'r-', label='fit')
ax.legend(frameon=False)
ax.set_xlabel("Distance from preferred loc (°)")
ax.set_ylabel("Reliability score")
fig.tight_layout()
sns.despine(ax=ax)
plt.show()

In [None]:
# Peak-centered reliability profiles of all cells: one row per full_df row, one column per position.
reord_cols = [f"rel_reord_{j}" for j in range(36)]
full_data_mat = full_df[reord_cols].to_numpy()

In [None]:
# Fit the Gaussian RF model to every cell's peak-centered reliability profile.
# Rows of the saved arrays follow the row order of full_df / full_data_mat.
p0 = [max(y), mean, sigma]  # shared initial guess from the example cell (loop-invariant)
nan_params = np.full(3, np.nan)
nan_cov = np.full((3, 3), np.nan)

popt = []
pcov = []
for i in tqdm(range(full_data_mat.shape[0])):
    profile = full_data_mat[i, :]
    if not np.all(np.isfinite(profile)):
        # curve_fit rejects non-finite data with ValueError, which previously escaped the
        # `except RuntimeError` and aborted the whole loop; record the cell as unfit instead.
        p, c = nan_params.copy(), nan_cov.copy()
    else:
        try:
            p, c = curve_fit(gaussian, x, profile, p0=p0)
        except RuntimeError:  # optimizer did not converge
            p, c = nan_params.copy(), nan_cov.copy()
    popt.append(p)
    pcov.append(c)

fl.save("fit.h5", dict(popt=np.array(popt), pcov=np.array(pcov)))

In [None]:
# Reload only the fitted parameters (one row per cell) from the cached fit file.
fit_file = "fit.h5"
popt = fl.load(fit_file, "/popt")

In [None]:
fit_params = np.array(popt)
# fit.h5 stores one row per cell in full_df row order; a stale cache from a different data
# load would silently misalign parameters and cells, so at least check the shapes agree.
assert fit_params.shape == (len(full_df), 3), (
    f"fit params shape {fit_params.shape} does not match {len(full_df)} cells in full_df; "
    "re-run the fitting cell")

for i, par_name in enumerate(["fit_amp",
                              "fit_mn",
                              "fit_sigma"]):
    full_df[par_name] = fit_params[:, i]

# The model only uses sigma**2, so the optimizer may return a negative width.
full_df["fit_sigma"] = np.abs(full_df["fit_sigma"])

In [None]:
plt.close("all")
plt.figure(figsize=(4, 3))
thr = REL_SCORE_THR
bw = 0.3  # histogram bin width, in fit_sigma units
sigma_bins = np.arange(0, 10, bw)
histograms = dict()
sigma_values = dict()
for g in ["cnt", "abl"]:
    # Same selection for histogram and test: responsive tectal cells with a non-NaN fit_sigma.
    sel = (full_df["max_rel"] > thr) & (full_df["gen"] == g) & full_df["in_tectum"]
    sigma_values[g] = full_df.loc[sel, "fit_sigma"].dropna().values
    histograms[g], b = np.histogram(sigma_values[g], sigma_bins, density=True)

    x_bins = (b[1:] + b[:-1]) / 2
    # density * bin width = fraction of the selected cells falling in each bin
    plt.fill_between(x_bins, np.zeros(len(x_bins)), histograms[g]*bw, alpha=0.4, lw=0, label=g)

plt.legend(frameon=False)
# Rank-sum test on the per-cell sigma values. The previous kstest compared the two arrays of
# histogram bin heights (not the sigma distributions) and was immediately overwritten.
sigma_diff = ranksums(*[sigma_values[g] for g in ["cnt", "abl"]])
plt.text(0.95, 0.95, f"p={sigma_diff.pvalue:0.4f}", ha="right", va="top",
         transform=plt.gca().transAxes)

plt.xlabel("Sigma (stim. positions)")  # the fit uses x = position index 0..35
plt.ylabel("Fraction of cells")
plt.tight_layout()
sns.despine()
plt.show()

In [None]:
# Per-fish mean RF width over responsive tectal cells (NaN, with a warning, if a fish has none).
responsive_tectal = (full_df["max_rel"] > REL_SCORE_THR) & full_df["in_tectum"]
exp_df["mean_sigma"] = np.nan
for fid in exp_df.index:
    fish_sigmas = full_df.loc[(full_df["fid"] == fid) & responsive_tectal, "fit_sigma"]
    exp_df.loc[fid, "mean_sigma"] = np.nanmean(fish_sigmas)

In [None]:
plt.figure(figsize=(3, 3))
sns.swarmplot(data=exp_df, x="gen", y="mean_sigma", order=["cnt", "abl"])
# Fish without any responsive tectal cell have NaN mean_sigma; drop them so NaN does not
# enter the rank-sum test.
diff_p = ranksums(*[exp_df.loc[exp_df["gen"] == g, "mean_sigma"].dropna() for g in ["cnt", "abl"]])
plt.xlabel("Genotype")
plt.ylabel("Mean sigma")

# Axes coordinates: the old data y=2 could fall outside the plotted range.
plt.text(0.5, 0.95, f"p={diff_p.pvalue:0.4f}", ha="center", va="top",
         transform=plt.gca().transAxes)
sns.despine()
plt.tight_layout()