In [None]:
%matplotlib widget

In [None]:
from pathlib import Path
import numpy as np
from bouter import EmbeddedExperiment
import pandas as pd
from tqdm import tqdm
import flammkuchen as fl

from matplotlib import pyplot as plt
import seaborn as sns
sns.set(palette="deep", style="ticks")
cols = sns.color_palette()

from scipy.stats import ranksums

In [None]:
# Folder holding the pooled HDF5 files. Edit this single line to run the
# notebook on another machine.
DATA_DIR = Path("/Users/luigipetrucco/Desktop")

# Both containers are indexed by fish id in the cells below.
pooled_cc = fl.load(str(DATA_DIR / "pooled_cc.h5"))
pooled_dicts = fl.load(str(DATA_DIR / "pooled_resps.h5"))

In [None]:
# Fish ids, in the order in which they are stored in the pooled responses.
fids = list(pooled_dicts)

In [None]:
from numba import njit, prange

@njit
def roll_matrix(input_mat, indexes):
    """Circularly shift every column of a matrix by its own offset.

    Parameters
    ----------
    input_mat : array indexed as ``input_mat[:, col]``; it is not modified.
    indexes : sequence with one integer shift per column, passed to
        ``np.roll`` as the shift of the matching column.

    Returns
    -------
    A new array created with ``np.empty_like(input_mat)`` and filled with the
    shifted columns.
    """
    n_cols = input_mat.shape[1]
    rolled = np.empty_like(input_mat)

    # NOTE(review): the decorator does not request parallel=True, so prange
    # is expected to behave like a plain range here - confirm against numba docs.
    for col in prange(n_cols):
        rolled[:, col] = np.roll(input_mat[:, col], indexes[col])

    return rolled

def center_on_peak(input_mat):
    """Recenter along the 1st dimension.

    Every column is circularly shifted so that its maximum ends up on the
    middle row, ``input_mat.shape[0] // 2``.

    Parameters
    ----------
    input_mat : array whose columns are shifted independently (see
        ``roll_matrix``); it is not modified.

    Returns
    -------
    The shifted copy returned by ``roll_matrix``.
    """
    n_rows = input_mat.shape[0]

    # np.roll moves the element at row j to row (j + shift) % n_rows, so a
    # peak at row p lands on the middle row with shift = n_rows // 2 - p.
    # (A shift of -p - n_rows // 2 gives the same result only when n_rows is
    # even; for odd n_rows it leaves the peak one row past the middle.)
    idxs = n_rows // 2 - np.argmax(input_mat, 0)

    return roll_matrix(input_mat, idxs)

In [None]:
REL_SCORE_THR = 0.5
EXCLUDED_FID = "210611_f14"  # fish left out of the tables (reason not recorded here)
SCORE_PREFIXES = ("rel", "amp", "rel_reord", "amp_reord")

exp_records = []  # one summary dict per fish
df_list = []  # one per-cell table per fish
for k in tqdm(list(pooled_dicts)):
    if k == EXCLUDED_FID:
        continue

    # Genotype is read off the number after "_f": even -> "abl", odd -> "cnt".
    fish_number = int(k.split("_f")[1])
    gen = "cnt" if fish_number % 2 else "abl"

    rel_scores = pooled_dicts[k]["rel_scores"]
    amp_scores = pooled_dicts[k]["amp_scores"]

    # Scores are laid out as (stimuli, cells).
    n_stims, n_cells = rel_scores.shape

    reord_rel = center_on_peak(rel_scores)
    reord_amp = center_on_peak(amp_scores)

    max_rel = np.max(rel_scores, 0)

    # One row per cell; the four score blocks sit side by side.
    stacked = np.concatenate([rel_scores, amp_scores, reord_rel, reord_amp], 0)
    col_names = [f"{prefix}_{i}" for prefix in SCORE_PREFIXES for i in range(n_stims)]
    df = pd.DataFrame(stacked.T, columns=col_names).assign(
        cid=[f"{k}_{i:05.0f}" for i in range(n_cells)],
        gen=gen,
        fid=k,
        max_rel=max_rel,
        max_rel_i=np.argmax(rel_scores, 0),
        max_amp=np.max(amp_scores, 0),
        max_amp_i=np.argmax(amp_scores, 0),
    )
    df_list.append(df)

    exp_records.append(
        {
            "fid": k,
            "gen": gen,
            "n_cells": n_cells,
            "above_rel_thr": np.sum(max_rel > REL_SCORE_THR),
        }
    )

# Per-cell table indexed by cell id; the "cid" column is kept as well.
full_df = pd.concat(df_list, axis=0).set_index("cid", drop=False)

# Per-fish summary table indexed by fish id.
exp_df = pd.DataFrame(exp_records).set_index("fid")

In [None]:
# Peek at the fields stored for the last fish in `fids`. Indexing through
# `fids` avoids relying on a loop variable left over from another cell.
pooled_cc[fids[-1]].keys()

In [None]:
genotypes = exp_df["gen"].unique()

# For each genotype, stack the "cc_dist" entries of its fish (first axis
# indexes fish, in exp_df order).
cc_dist_mats_dict = {
    gen_name: np.array(
        [pooled_cc[fid]["cc_dist"] for fid in exp_df.loc[exp_df["gen"] == gen_name].index]
    )
    for gen_name in genotypes
}

In [None]:
fig, ax = plt.subplots()
for i, g in enumerate(genotypes):
    per_fish = cc_dist_mats_dict[g]
    # Thin lines: the individual entries stacked along the first axis.
    ax.plot(per_fish.T, c=cols[i], lw=0.3)
    # Thick line: their mean over the first axis.
    ax.plot(per_fish.mean(0), c=cols[i], lw=2)

ax.set_yscale("log")

In [None]:
# Peak-reliability cutoff applied when selecting cells in the next cell.
rel_thr = 0.4
# First fish id in exp_df; the name `f` is reused as a loop variable below.
f = exp_df.index[0]

In [None]:
# Per genotype: the selected "cc_abs_avg" values of each fish, and their mean.
all_scores = {gen_name: [] for gen_name in ["cnt", "abl"]}
all_means = {gen_name: [] for gen_name in ["cnt", "abl"]}

for f in exp_df.index:
    # Boolean mask over this fish's cells: peak reliability above threshold.
    is_reliable = full_df.loc[full_df["fid"] == f, "max_rel"] > rel_thr
    ccs = pooled_cc[f]["cc_abs_avg"][is_reliable]

    fish_gen = exp_df.loc[f, "gen"]
    all_scores[fish_gen].append(ccs)
    all_means[fish_gen].append(np.mean(ccs))

# Turn the per-fish means into arrays (values are replaced, keys unchanged).
for g in all_means:
    all_means[g] = np.array(all_means[g])

In [None]:
from scipy.stats import ranksums

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))
for i, g in enumerate(["cnt", "abl"]):
    fish_means = all_means[g]
    # Small horizontal jitter so that points of one genotype do not overlap.
    jitter = np.random.randn(len(fish_means)) * 0.05
    ax.scatter(np.ones(len(fish_means)) * i + jitter, fish_means)

# Rank-sum test between the per-fish means of the two genotypes.
diff_p = ranksums(all_means["cnt"], all_means["abl"])
ax.text(0.5, 0.25, f"p={diff_p.pvalue:0.4f}")

sns.despine()

In [None]:
# Load first, then inspect: on a fresh kernel the key lookup would raise a
# NameError if it ran before the load.
pooled_all_dicts = fl.load("/Users/luigipetrucco/Desktop/pooled_all_resps.h5")
pooled_all_dicts["210611_f1"].keys()

In [None]:
rel_istograms = dict()

# Shared bin edges and bin centres for every per-fish histogram.
bins = np.arange(0, 1, 0.02)
x_bins = (bins[1:] + bins[:-1]) / 2

fig, ax = plt.subplots(figsize=(4, 3))
ax.set_title("Reliability score")
for i, g in enumerate(["cnt", "abl"]):
    sel_fids = full_df.loc[full_df["gen"] == g, "fid"].unique()
    print(sel_fids)

    # One density-normalised histogram of max_rel per fish (rows = fish).
    rel_istograms[g] = np.array(
        [
            np.histogram(full_df.loc[full_df["fid"] == f, "max_rel"], bins, density=True)[0]
            for f in sel_fids
        ]
    )

    # Thin lines: single fish; thick line: mean across fish.
    ax.plot(x_bins, rel_istograms[g].T, c=cols[i], lw=0.2)
    ax.plot(x_bins, rel_istograms[g].mean(0), c=cols[i], lw=2, label=g)

ax.legend(frameon=False)
ax.set_yscale("log")
sns.despine()

In [None]:
# NOTE: this reuses (and overwrites) the dictionary name from the
# reliability-score cell; here it holds amplitude histograms.
rel_istograms = dict()

# Shared bin edges and bin centres for every per-fish histogram.
bins = np.arange(0, 6, 0.2)
x_bins = (bins[1:] + bins[:-1]) / 2

fig, ax = plt.subplots(figsize=(4, 3))
ax.set_title("Response amplitude")

for i, g in enumerate(["cnt", "abl"]):
    sel_fids = full_df.loc[full_df["gen"] == g, "fid"].unique()
    print(sel_fids)

    # One density-normalised histogram of max_amp per fish (rows = fish).
    rel_istograms[g] = np.array(
        [
            np.histogram(full_df.loc[full_df["fid"] == f, "max_amp"], bins, density=True)[0]
            for f in sel_fids
        ]
    )

    # Thin lines: single fish; thick line: NaN-ignoring mean across fish.
    ax.plot(x_bins, rel_istograms[g].T, c=cols[i], lw=0.2)
    ax.plot(x_bins, np.nanmean(rel_istograms[g], 0), c=cols[i], lw=2, label=g)

ax.legend(frameon=False)
sns.despine()
ax.set_yscale("log")
fig.tight_layout()

## Fraction of responsive cells

In [None]:
# Percentage of each fish's cells whose peak reliability exceeded the threshold.
exp_df["fraction_resp"] = exp_df["above_rel_thr"].div(exp_df["n_cells"]).mul(100)

# Rank-sum test between the two genotypes, control group first.
resp_by_gen = [exp_df.loc[exp_df["gen"] == g, "fraction_resp"] for g in ["cnt", "abl"]]
diff_p = ranksums(*resp_by_gen)

fig, ax = plt.subplots(figsize=(3, 3))
sns.swarmplot(data=exp_df, x="gen", y="fraction_resp", ax=ax)
sns.despine()
fig.tight_layout()
ax.set_xlabel("Genotype")
ax.set_ylabel("Responsive cells (%)")

ax.text(0.5, 6, f"p={diff_p.pvalue:0.4f}")

In [None]:
plt.close("all")
f, axs = plt.subplots(3, 1, figsize=(4, 6))

reord_cols = [f"rel_reord_{j}" for j in range(36)]


def _shade(color, delta):
    """Return `color` with `delta` added to each channel, clamped to [0, 1]."""
    return [min(max(channel + delta, 0.0), 1.0) for channel in color]


# NOTE(review): here "abl" takes cols[0], while the cells enumerating
# ["cnt", "abl"] give cols[0] to "cnt" - confirm the intended colour mapping.
for i, k in enumerate(["abl", "cnt"]):
    # Recentred reliability profiles of the cells with max_rel > 0.8,
    # transposed to (position, cell). Selected once and reused for all panels.
    data = full_df.loc[(full_df["gen"] == k) & (full_df["max_rel"] > 0.8),
                       reord_cols].values.T
    mean_profile = np.nanmean(data, 1)

    # Own panel: every cell (thin, lighter) plus the mean (thick, darker).
    axs[i].plot(data, lw=0.3, c=_shade(cols[i], 0.1))
    axs[i].plot(mean_profile, lw=2, c=_shade(cols[i], -0.1))

    # Bottom panel: the means of both genotypes overlaid.
    axs[2].plot(mean_profile, lw=2, c=_shade(cols[i], -0.1))

# Style the axes before showing, so the result does not depend on the
# backend keeping the figure current after plt.show().
sns.despine(fig=f)
plt.show()

In [None]:
from scipy.optimize import curve_fit


def gaussian(x, a, x0, sigma):
    """Evaluate ``a * exp(-(x - x0)**2 / (2 * sigma**2))``.

    Parameters
    ----------
    x : points at which to evaluate the curve.
    a : peak amplitude.
    x0 : position of the peak.
    sigma : width parameter (only its square is used).
    """
    return a * np.exp(-(x - x0)**2 / (2 * sigma**2))


# Recentred reliability profiles of the "cnt" cells with max_rel > 0.8,
# transposed to (position, cell). The genotype is spelled out instead of being
# read from a loop variable left over from the plotting cell.
data = full_df.loc[(full_df["gen"] == "cnt") & (full_df["max_rel"] > 0.8),
                   [f"rel_reord_{i}" for i in range(36)]].values.T

# --- Single example cell -----------------------------------------------------
# x, y, mean and sigma stay in the namespace: the fit over all cells further
# down reuses them for its initial guess.
x = np.arange(36)
y = data[:, 500]

# Moment-based starting point: weighted mean and spread of the profile.
mean = sum(x * y) / sum(y)
sigma = np.sqrt(sum(y * (x - mean)**2) / sum(y))

popt, pcov = curve_fit(gaussian, x, y, p0=[max(y), mean, sigma])

plt.figure()
plt.plot(x, y, 'b+:', label='data')
plt.plot(x, gaussian(x, *popt), 'r-', label='fit')
plt.legend()
plt.title('Gaussian fit to one example cell')
plt.xlabel('Stimulus index (after recentring on peak)')
plt.ylabel('Reliability score')
plt.show()

# --- Every selected cell -----------------------------------------------------
# Each column is fitted on its own data, starting from the example cell's
# guess; fits that do not converge are stored as NaN.
popt, pcov = [], []
for i in range(data.shape[1]):
    try:
        o, c = curve_fit(gaussian, x, data[:, i], p0=[max(y), mean, sigma])
    except RuntimeError:
        o, c = np.full(3, np.nan), np.full((3, 3), np.nan)
    popt.append(o)
    pcov.append(c)

In [None]:
# Recentred reliability profiles of every cell, as a (cell, position) array.
reord_columns = [f"rel_reord_{i}" for i in range(36)]
full_data_mat = full_df[reord_columns].to_numpy()

In [None]:
# Same starting point for every cell, taken from the example cell fitted above.
initial_guess = [max(y), mean, sigma]

popt = []
pcov = []
for cell_profile in tqdm(full_data_mat):
    try:
        p, c = curve_fit(gaussian, x, cell_profile, p0=initial_guess)
    except RuntimeError:
        # Fit did not converge: keep the row, filled with NaN, so that the
        # results stay aligned with the rows of full_data_mat.
        p, c = np.full(3, np.nan), np.full((3, 3), np.nan)
    popt.append(p)
    pcov.append(c)

In [None]:
# Store the fitted parameters and covariances as arrays (one entry per cell).
fit_results = dict(popt=np.array(popt), pcov=np.array(pcov))
fl.save("fit.h5", fit_results)

In [None]:
fit_params = np.array(popt)

# Columns of fit_params follow the argument order of gaussian: a, x0, sigma.
for par_name, par_values in zip(["fit_amp", "fit_mn", "fit_sigma"], fit_params.T):
    full_df[par_name] = par_values

# sigma only enters the model squared, so its sign carries no information.
full_df["fit_sigma"] = full_df["fit_sigma"].abs()

In [None]:
# Fitted width per genotype, for the cells with max_rel > 0.7.
# (No bare plt.hist() here: called without data it raises a TypeError.)
plt.figure()
sns.violinplot(data=full_df[full_df["max_rel"] > 0.7], x="gen", y="fit_sigma")

In [None]:
plt.close("all")
plt.figure()
thr = 0.7  # minimum peak reliability for a cell to be included
bw = 0.2  # histogram bin width
sigma_bins = np.arange(0, 10, bw)

# Overlaid density histograms of the fitted width, "cnt" first, then "abl".
for gen_name in ["cnt", "abl"]:
    selected = (full_df["max_rel"] > thr) & (full_df["gen"] == gen_name)
    plt.hist(full_df.loc[selected, "fit_sigma"].values,
             sigma_bins, lw=0, alpha=0.4, density=True)
plt.show()