In [None]:
%matplotlib widget

In [None]:
from pathlib import Path
import numpy as np
from bouter import EmbeddedExperiment
import pandas as pd
from tqdm import tqdm
import flammkuchen as fl

from matplotlib import pyplot as plt
import seaborn as sns
sns.set(palette="deep", style="ticks")
cols = sns.color_palette()

from scipy.stats import ranksums, ttest_ind, kstest

In [None]:
import os

# Folder holding the pooled response files. Defaults to the original
# machine-specific location; set POOLED_DATA_DIR to run elsewhere.
DATA_DIR = Path(os.environ.get("POOLED_DATA_DIR", "/Users/luigipetrucco/Desktop"))
pooled_dicts = fl.load(str(DATA_DIR / "pooled_resps.h5"))

In [None]:
from numba import njit, prange

@njit
def roll_matrix(input_mat, indexes):
    """Circularly roll every column of a 2D array by its own shift.

    Column ``j`` of the result equals ``np.roll(input_mat[:, j], indexes[j])``.
    The decorator is plain ``@njit`` (no ``parallel=True``), so ``prange``
    runs as an ordinary serial ``range`` here.
    """
    rolled = np.empty_like(input_mat)
    n_cols = input_mat.shape[1]
    for col in prange(n_cols):
        rolled[:, col] = np.roll(input_mat[:, col], indexes[col])
    return rolled

def center_on_peak(input_mat):
    """Recenter each column along the 1st dimension so its maximum sits at row n // 2.

    Parameters
    ----------
    input_mat : 2D array
        Each column is circularly rolled independently along axis 0.

    Returns
    -------
    Array of the same shape in which, for every column, the (first) argmax is
    at row ``input_mat.shape[0] // 2``.
    """
    n_rows = input_mat.shape[0]
    # np.roll moves element j to (j + shift) % n, so shift = n // 2 - argmax
    # lands the peak exactly on row n // 2. The former "-argmax - n // 2" only
    # coincides with this when n is even (it is off by one for odd n).
    idxs = n_rows // 2 - np.argmax(input_mat, 0)

    return roll_matrix(input_mat, idxs)

In [None]:
REL_SCORE_THR = 0.5  # a cell is "responsive" if its max reliability exceeds this
# Fish left out of the analysis. TODO: document why this fish is excluded.
EXCLUDED_FIDS = {"210611_f14"}


def genotype_from_fid(fid):
    """Map a fish id "<date>_f<n>" to its group: odd n -> "cnt", even n -> "abl"."""
    return ["abl", "cnt"][int(fid.split("_f")[1]) % 2]


exp_records = []
df_list = []
for k in tqdm(list(pooled_dicts.keys())):
    if k in EXCLUDED_FIDS:
        continue
    gen = genotype_from_fid(k)
    rel_scores = pooled_dicts[k]["rel_scores"]
    amp_scores = pooled_dicts[k]["amp_scores"]
    # scores are indexed (stimulus, cell)
    n_stims, n_cells = rel_scores.shape

    # One row per cell: raw and peak-centred reliability / amplitude per stimulus.
    score_blocks = {
        "rel": rel_scores,
        "amp": amp_scores,
        "rel_reord": center_on_peak(rel_scores),
        "amp_reord": center_on_peak(amp_scores),
    }
    df = pd.DataFrame(
        np.concatenate(list(score_blocks.values()), 0).T,
        columns=[f"{prefix}_{i}" for prefix in score_blocks for i in range(n_stims)],
    )
    df["cid"] = [f"{k}_{i:05.0f}" for i in range(n_cells)]
    df["gen"] = gen
    df["fid"] = k

    max_rel = np.max(rel_scores, 0)
    df["max_rel"] = max_rel
    df["max_rel_i"] = np.argmax(rel_scores, 0)
    df["max_amp"] = np.max(amp_scores, 0)
    df["max_amp_i"] = np.argmax(amp_scores, 0)

    exp_records.append(dict(fid=k,
                            gen=gen,
                            n_cells=n_cells,
                            above_rel_thr=np.sum(max_rel > REL_SCORE_THR)))
    df_list.append(df)

full_df = pd.concat(df_list, axis=0)
full_df = full_df.set_index(full_df["cid"])
assert full_df.index.is_unique, "cell ids must be unique across fish"

exp_df = pd.DataFrame(exp_records).set_index("fid")

## Mutual info calculation

In [None]:
import os

# Same folder as pooled_resps.h5; overridable through POOLED_DATA_DIR.
pooled_all_dicts = fl.load(str(
    Path(os.environ.get("POOLED_DATA_DIR", "/Users/luigipetrucco/Desktop"))
    / "pooled_all_resps.h5"
))

In [None]:
# Example fish used to illustrate the binning / mutual-information pipeline.
example_fish = pooled_all_dicts["210611_f1"]
all_amps, rel_scores, amp_scores = (
    example_fish[key] for key in ("all_amps", "rel_scores", "amp_scores")
)
max_rels = np.max(rel_scores, axis=0)
max_amps = np.max(amp_scores, axis=0)

# argsort is ascending: [-2] picks the second most reliable cell
max_idx = np.argsort(max_rels)[-2]

In [None]:
fig, ax = plt.subplots()
ax.plot(rel_scores[:, max_idx], label="reliability score")
ax.plot(amp_scores[:, max_idx], label="mean amplitude")
# thin black lines: every individual response of this cell
ax.plot(all_amps[:, :, max_idx].T, c="k", lw=0.2)
ax.set(xlabel="Stimulus index", ylabel="Score / amplitude", title=f"Example cell {max_idx}")
ax.legend(frameon=False)

In [None]:
@njit
def stim_resp_binning(cell_resps, v_abs=1.5, binning=0.25):
    """Joint (response bin x stimulus) probability table for one cell.

    Parameters
    ----------
    cell_resps : 2D array, columns are stimuli (rows presumably repetitions —
        TODO confirm). It is copied, never modified in place.
    v_abs : responses are binned on [-v_abs, v_abs]; values outside that range
        are clipped into the outermost bins.
    binning : bin width.

    Returns
    -------
    all_counts : array of shape (n_bins, n_stims), normalised so the whole
        table sums to 1, with bin edges ``np.arange(-v_abs, v_abs + binning, binning)``.
    """
    cell_resps = cell_resps.copy()
    # Size from the argument itself: the old code read the notebook global
    # `all_amps`, which numba freezes as a constant at compile time.
    n_rows, n_stims = cell_resps.shape

    # Clip outliers to the centres of the outermost bins. Previously low
    # outliers were set to -v_abs + binning, i.e. into the *second* bin,
    # while high outliers went to the last bin (asymmetric).
    upper = v_abs - binning / 2
    lower = -v_abs + binning / 2
    for i in range(n_rows):
        for j in range(n_stims):
            if cell_resps[i, j] > v_abs:
                cell_resps[i, j] = upper
            elif cell_resps[i, j] < -v_abs:
                cell_resps[i, j] = lower

    base_vect = np.arange(-v_abs, v_abs + binning, binning)
    all_counts = np.empty((len(base_vect) - 1, n_stims))
    for c in range(n_stims):
        hist, _ = np.histogram(cell_resps[:, c], base_vect)
        all_counts[:, c] = hist
    all_counts = all_counts / np.sum(all_counts)

    return all_counts

@njit
def mutual_info(all_counts):
    """Mutual information, in bits, of a 2D joint probability table.

    ``all_counts`` is expected to sum to 1 (stim_resp_binning normalises it
    that way); rows are response bins and columns stimuli. Entries equal to 0
    are skipped, following the 0 * log 0 = 0 convention.
    """
    p_stim = all_counts.sum(0)   # column marginal
    p_resp = all_counts.sum(1)   # row marginal

    info = 0
    for s in range(len(p_stim)):
        for r in range(len(p_resp)):
            p_joint = all_counts[r, s]
            if p_joint > 0:
                info += p_joint * np.log2(p_joint / (p_stim[s] * p_resp[r]))
    return info

@njit(parallel=True)
def mutual_info_allcells(all_amps):
    """Mutual information of every slice ``all_amps[:, :, i]`` (one per cell).

    Each slice is binned with stim_resp_binning's default parameters; the
    slices are processed in parallel with prange.
    """
    n_cells = all_amps.shape[2]
    info_array = np.empty(n_cells)
    for cell in prange(n_cells):
        joint = stim_resp_binning(all_amps[:, :, cell])
        info_array[cell] = mutual_info(joint)
    return info_array

In [None]:
# stim_resp_binning copies its input itself, so the slice can be passed directly.
all_counts = stim_resp_binning(all_amps[:, :, max_idx], v_abs=1.5, binning=0.25)

In [None]:
fig, ax = plt.subplots()
# transpose so that stimuli run along the y axis and response bins along x
ax.imshow(all_counts.T)
ax.set_xlabel("Bin")
ax.set_ylabel("Stimulus")

In [None]:
%%time
all_mutual_info = mutual_info_allcells(all_amps[:, :, :])

In [None]:
np.shape(all_amps)  # inspect array dimensions

In [None]:
fig, ax = plt.subplots()
ax.scatter(max_rels, max_amps, s=5, alpha=0.5)
ax.set(xlabel="Max reliability score", ylabel="Max response amplitude")

In [None]:
plt.close("all")
# Cell of the example fish carrying the most information about the stimulus.
max_idx = np.argsort(all_mutual_info)[-1]

n_orient = all_amps.shape[1]
# x axis centred on 0; assumes 10° between consecutive stimuli — TODO confirm
orientations = (np.arange(n_orient) - n_orient // 2) * 10

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(orientations, all_amps[:, :, max_idx].T, "o", c=(0.7,)*3, lw=0, alpha=1)
ax.plot(orientations, amp_scores[:, max_idx], label="mean", c=cols[3])
ax.set(xlabel="Orientation (°)", ylabel="Response amplitude")
ax.legend(frameon=False)
sns.despine(ax=ax)
fig.tight_layout()
plt.show()

In [None]:
# Mutual information of every cell, fish by fish. Local names are used so the
# example-fish globals (all_amps, all_mutual_info) are not overwritten.
# Assumes cells in pooled_all_dicts[fid] are in the same order as the rows
# of full_df for that fish — TODO confirm.
full_df["mutual_info"] = np.nan
for fid in tqdm(full_df["fid"].unique()):
    fish_amps = pooled_all_dicts[fid]["all_amps"]
    fish_info = mutual_info_allcells(fish_amps)
    fish_mask = full_df["fid"] == fid
    assert fish_mask.sum() == len(fish_info), f"cell count mismatch for {fid}"
    full_df.loc[fish_mask, "mutual_info"] = fish_info

In [None]:
MI_REL_THR = 0.6  # only cells with max reliability above this enter the fish average
# Average only the mutual_info column: a plain .mean() over the whole frame
# fails on the string columns (cid, gen) in pandas >= 2.
exp_df["mean_info"] = (full_df[full_df["max_rel"] > MI_REL_THR]
                       .groupby("fid")["mutual_info"].mean())

In [None]:
fig, ax = plt.subplots(figsize=(3, 3))
sns.swarmplot(data=exp_df, x="gen", y="mean_info", ax=ax)
# Fish-level test (one mean per fish). dropna: fish without any cell above the
# reliability threshold have no mean and would make the p-value NaN.
diff_p = ttest_ind(*[exp_df.loc[exp_df["gen"] == g, "mean_info"].dropna() for g in ["cnt", "abl"]])
ax.set(xlabel="Genotype", ylabel="Mean mutual info (bits)")
ax.text(0.5, 1.05, f"p={diff_p.pvalue:0.4f}")
sns.despine(ax=ax)
fig.tight_layout()  # after labels, so they are included in the layout

In [None]:
full_df.columns  # all available per-cell columns

In [None]:
fig, ax = plt.subplots(figsize=(3, 3))
sns.violinplot(data=full_df, x="gen", y="mutual_info", ax=ax)
# NOTE(review): cell-level test treats every cell as independent, so the
# p-value ignores the clustering of cells within fish.
diff_p = ttest_ind(*[full_df.loc[full_df["gen"] == g, "mutual_info"].dropna() for g in ["cnt", "abl"]])
# per-cell values, not a mean
ax.set(xlabel="Genotype", ylabel="Mutual info (bits)")
ax.text(0.5, 1.05, f"p={diff_p.pvalue:0.4e}")
sns.despine(ax=ax)
fig.tight_layout()

In [None]:
plt.close("all")
fig, ax = plt.subplots(figsize=(4, 3))
thr = 0.5
bw = 0.05
mi_edges = np.arange(0.25, 1.75, bw)
x_bins = (mi_edges[1:] + mi_edges[:-1]) / 2

histograms = dict()
for g in ["cnt", "abl"]:
    sel = full_df.loc[(full_df["max_rel"] > thr) & (full_df["gen"] == g), "mutual_info"].values
    histograms[g], _ = np.histogram(sel, mi_edges, density=True)
    # density * bin width = fraction of the selected cells falling in each bin
    ax.fill_between(x_bins, 0, histograms[g] * bw, alpha=0.4, lw=0, label=g)

# Cell-level t-test on the same selection. A kstest on the two histogram
# arrays would compare bin heights rather than samples, so it is not used.
# The name ks_diff is kept because a later cell reads ks_diff.pvalue.
ks_diff = ttest_ind(*[full_df.loc[(full_df["gen"] == g) & (full_df["max_rel"] > thr),
                                  "mutual_info"] for g in ["cnt", "abl"]])
ax.text(1.2, 0.1, f"p={ks_diff.pvalue:0.4f}")

ax.legend(frameon=False)
ax.set(xlabel="Mutual information with stimulus", ylabel="Fraction of cells")
sns.despine(ax=ax)
fig.tight_layout()
plt.show()

In [None]:
# Per-fish mean mutual information by genotype. Only a cell's last bare
# expression is displayed, so both groups are returned together.
{g: exp_df.loc[exp_df["gen"] == g, "mean_info"].values for g in ["cnt", "abl"]}

In [None]:
# NOTE: despite its name, ks_diff holds the ttest_ind result computed in the
# mutual-information histogram cell, so this is a t-test p-value.
ks_diff.pvalue

In [None]:
rel_histograms = dict()
rel_edges = np.arange(0, 1, 0.02)
x_bins = (rel_edges[1:] + rel_edges[:-1]) / 2

fig, ax = plt.subplots(figsize=(4, 3))
ax.set_title("Reliability score")
for i, g in enumerate(["cnt", "abl"]):
    sel_fids = full_df.loc[full_df["gen"] == g, "fid"].unique()
    # one density histogram of the cells' max reliability per fish
    rel_histograms[g] = np.array([
        np.histogram(full_df.loc[full_df["fid"] == f, "max_rel"], rel_edges, density=True)[0]
        for f in sel_fids
    ])
    ax.plot(x_bins, rel_histograms[g].T, c=cols[i], lw=0.2)
    # nanmean: a fish with no value inside the bin range can yield a NaN histogram
    ax.plot(x_bins, np.nanmean(rel_histograms[g], 0), c=cols[i], lw=2, label=g)

ax.legend(frameon=False)
ax.set(xlabel="Max reliability score", ylabel="Density", yscale="log")
sns.despine(ax=ax)

In [None]:
amp_histograms = dict()
amp_edges = np.arange(0, 6, 0.2)
x_bins = (amp_edges[1:] + amp_edges[:-1]) / 2

fig, ax = plt.subplots(figsize=(4, 3))
ax.set_title("Response amplitude")
for i, g in enumerate(["cnt", "abl"]):
    sel_fids = full_df.loc[full_df["gen"] == g, "fid"].unique()
    # one density histogram of the cells' max amplitude per fish
    amp_histograms[g] = np.array([
        np.histogram(full_df.loc[full_df["fid"] == f, "max_amp"], amp_edges, density=True)[0]
        for f in sel_fids
    ])
    ax.plot(x_bins, amp_histograms[g].T, c=cols[i], lw=0.2)
    ax.plot(x_bins, np.nanmean(amp_histograms[g], 0), c=cols[i], lw=2, label=g)

ax.legend(frameon=False)
ax.set(xlabel="Max response amplitude", ylabel="Density", yscale="log")
sns.despine(ax=ax)
fig.tight_layout()

## Fraction of responsive cells

In [None]:
# Percentage of cells per fish whose max reliability exceeds REL_SCORE_THR.
exp_df["fraction_resp"] = (exp_df["above_rel_thr"] / exp_df["n_cells"]) * 100
fig, ax = plt.subplots(figsize=(3, 3))
sns.swarmplot(data=exp_df, x="gen", y="fraction_resp", ax=ax)
diff_p = ranksums(*[exp_df.loc[exp_df["gen"] == g, "fraction_resp"] for g in ["cnt", "abl"]])
ax.set(xlabel="Genotype", ylabel="Responsive cells (%)")
ax.text(0.5, 6, f"p={diff_p.pvalue:0.4f}")
sns.despine(ax=ax)
fig.tight_layout()  # after labels, so they are included in the layout

In [None]:
plt.close("all")
# Same genotype -> colour mapping as the other figures (cnt = cols[0], abl = cols[1]).
gen_colors = {"cnt": np.array(cols[0]), "abl": np.array(cols[1])}
reord_cols = [f"rel_reord_{i}" for i in range(36)]

f, axs = plt.subplots(3, 1, figsize=(4, 6))
for i, k in enumerate(["abl", "cnt"]):
    # peak-centred reliability profiles of reliable cells, shape (n_stims, n_cells)
    data = full_df.loc[(full_df["gen"] == k) & (full_df["max_rel"] > 0.8), reord_cols].values.T
    base = gen_colors[k]
    light = np.clip(base + 0.1, 0, 1)  # clip: shifted colours must stay within [0, 1]
    dark = np.clip(base - 0.1, 0, 1)

    axs[i].plot(data, lw=0.3, c=light)
    axs[i].plot(np.nanmean(data, 1), lw=2, c=dark)
    axs[i].set_title(k)
    # bottom panel: both genotype averages overlaid
    axs[2].plot(np.nanmean(data, 1), lw=2, c=dark, label=k)

axs[2].legend(frameon=False)
axs[2].set_xlabel("Stimulus index (peak-centred)")
for panel in axs:
    panel.set_ylabel("Reliability score")
sns.despine(fig=f)
plt.show()

In [None]:
# Gaussian fit of every peak-centred reliability profile of genotype `k`
# (k is the last genotype of the plotting loop in the previous cell).
from scipy.optimize import curve_fit


def gaussian(x, a, x0, sigma):
    """Unnormalised Gaussian a * exp(-(x - x0)**2 / (2 * sigma**2))."""
    return a * np.exp(-(x - x0)**2 / (2 * sigma**2))


data = full_df.loc[(full_df["gen"]==k) & (full_df["max_rel"] > 0.8),
                       [f"rel_reord_{i}" for i in range(36)]].values.T
x = np.arange(data.shape[0])

popt, pcov = [], []
for i in range(data.shape[1]):
    y = data[:, i]  # fit each column, not one fixed profile
    # initial guess from the profile's own first and second moments
    mean = np.sum(x * y) / np.sum(y)
    sigma = np.sqrt(np.sum(y * (x - mean)**2) / np.sum(y))
    try:
        o, c = curve_fit(gaussian, x, y, p0=[np.max(y), mean, sigma])
    except (RuntimeError, ValueError):
        # no convergence or non-finite data: NaNs keep popt/pcov column-aligned
        o, c = np.full(3, np.nan), np.full((3, 3), np.nan)
    popt.append(o)
    pcov.append(c)

In [None]:
from scipy.optimize import curve_fit

# Example Gaussian fit on a single column of `data` (peak-centred reliability
# profiles from the previous cells); clamp the index in case fewer cells pass.
example_cell = min(500, data.shape[1] - 1)
x = np.arange(data.shape[0])
y = data[:, example_cell]

# moment-based initial guesses for centre and width
mean = sum(x * y) / sum(y)
sigma = np.sqrt(sum(y * (x - mean)**2) / sum(y))

def gaussian(x, a, x0, sigma):
    """Unnormalised Gaussian a * exp(-(x - x0)**2 / (2 * sigma**2))."""
    return a * np.exp(-(x - x0)**2 / (2 * sigma**2))

popt, pcov = curve_fit(gaussian, x, y, p0=[max(y), mean, sigma])

fig, ax = plt.subplots()
ax.plot(x, y, 'b+:', label='data')
ax.plot(x, gaussian(x, *popt), 'r-', label='fit')
ax.legend()
ax.set(title=f"Gaussian fit, cell {example_cell}",
       xlabel="Stimulus index (peak-centred)",
       ylabel="Reliability score")
plt.show()

In [None]:
# Peak-centred reliability profiles of all cells: one row per full_df row, 36 stimulus columns.
full_data_mat = full_df[[f"rel_reord_{i}" for i in range(36)]].to_numpy()

In [None]:
# Gaussian fit of every cell's profile. The starting point comes from the
# single-cell example fit (y, mean, sigma) and is identical for every cell,
# so it is built once outside the loop.
p0 = [max(y), mean, sigma]
popt = []
pcov = []
for i in tqdm(range(full_data_mat.shape[0])):
    try:
        p, c = curve_fit(gaussian, x, full_data_mat[i, :], p0=p0)
    except (RuntimeError, ValueError):
        # no convergence (RuntimeError) or non-finite profile (ValueError):
        # NaN placeholders keep popt/pcov row-aligned with full_df
        p, c = np.full(3, np.nan), np.full((3, 3), np.nan)
    popt.append(p)
    pcov.append(c)

In [None]:
# Persist per-cell fit parameters and covariances to fit.h5 in the working directory.
fl.save("fit.h5", {"popt": np.array(popt), "pcov": np.array(pcov)})

In [None]:
fit_params = np.array(popt)

# columns follow the gaussian() parameter order: amplitude, centre, width
for col_idx, col_name in enumerate(("fit_amp", "fit_mn", "fit_sigma")):
    full_df[col_name] = fit_params[:, col_idx]

# sigma only enters the Gaussian squared, so its sign is arbitrary
full_df["fit_sigma"] = full_df["fit_sigma"].abs()

In [None]:
fig, ax = plt.subplots()
# Width of the fitted Gaussian for reliable cells, per genotype.
sns.violinplot(data=full_df[full_df["max_rel"] > 0.7], x="gen", y="fit_sigma", ax=ax)
ax.set(xlabel="Genotype", ylabel="Fitted sigma (stimulus steps)")

In [None]:
plt.close("all")
fig, ax = plt.subplots()
thr = 0.7
bw = 0.2
sigma_edges = np.arange(0, 10, bw)
for g in ["cnt", "abl"]:
    ax.hist(full_df.loc[(full_df["max_rel"] > thr) & (full_df["gen"] == g), "fit_sigma"].values,
            sigma_edges, lw=0, alpha=0.4, density=True, label=g)
ax.legend(frameon=False)
ax.set(xlabel="Fitted sigma (stimulus steps)", ylabel="Density")
plt.show()