In [None]:
# Interactive (ipympl) matplotlib backend for the figures in this notebook.
%matplotlib widget

In [None]:
import os
from pathlib import Path

import flammkuchen as fl
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from scipy.signal import detrend
from tqdm import tqdm

from bouter import EmbeddedExperiment
from bouter.utilities import crop, reliability

sns.set(palette="deep", style="ticks")

In [None]:
# Root of the E0070 receptive-field data. Set the E0070_DATA_ROOT environment
# variable when the data live elsewhere; the default is the original
# shared-volume location, so behaviour is unchanged when it is unset.
DATA_ROOT = Path(os.environ.get("E0070_DATA_ROOT", "/Volumes/Shared/experiments/E0070_receptive_field"))
path = DATA_ROOT / "v01_sliding_bars" / "210603_f0"

In [None]:
# Load the suite2p traces, transposed so that axis 0 is time and axis 1 is
# cells, together with the matching ROI coordinates.
traces = fl.load(path / "data_from_suite2p_unfiltered.h5", "/traces").T
coords = fl.load(path / "data_from_suite2p_unfiltered.h5", "/coords")

exp = EmbeddedExperiment(path)

# Drop cells whose trace is identically zero over the whole recording.
sel = ~(traces == 0).all(0)
traces = traces[:, sel]
coords = coords[sel, :]

# Remove the linear trend of every cell in one vectorised call along time.
# Unlike column-wise assignment back into the original array, this cannot
# silently truncate the detrended values if the traces had an integer dtype.
traces = detrend(traces, axis=0)

# NaN out the frames around large swim bouts (presumably to exclude
# motor-related artefacts from the fluorescence — confirm).
NAN_WND_PRE_S = 1  # seconds masked before each bout onset
NAN_WND_POST_S = 15  # seconds masked after each bout onset
LARGE_BOUT_VIG_THR = 1.5  # bouts with "peak_vig" above this are masked

# Imaging rate from the lightsheet configuration.
# NOTE(review): int() truncates a non-integer volume rate — confirm it is integral.
fs = int(exp["imaging"]["microscope_config"]["lightsheet"]["scanning"]["z"]["frequency"])
samp_n = traces.shape[0]
t_orig = np.arange(samp_n) / fs

nan_wnd_pre = int(NAN_WND_PRE_S * fs)
nan_wnd_post = int(NAN_WND_POST_S * fs)
bouts_df = exp.get_bout_properties()
large_bouts_t = bouts_df.loc[bouts_df["peak_vig"] > LARGE_BOUT_VIG_THR, "t_start"].values
# np.int was removed in NumPy 1.24; the builtin int is the drop-in replacement.
large_bouts_idxs = (large_bouts_t * fs).astype(int)

traces_nanned = traces.copy()
for idx in tqdm(large_bouts_idxs):
    # Clip the window at the recording edges instead of skipping bouts that
    # happen close to the start or the end of the acquisition.
    start = max(idx - nan_wnd_pre, 0)
    stop = min(idx + nan_wnd_post, samp_n)
    if start < stop:
        traces_nanned[start:stop, :] = np.nan

In [None]:
# One row per bar presentation. The stimulus log is walked in steps of 3:
# entry n - 1 provides the (pause) onset time and entry n the bar parameters.
stim_logs = exp["stimulus"]["log"]


def _bar_condition(pause_log, bar_log):
    """Summarise one bar presentation as a dict of its stimulus features."""
    # theta is rotated by pi when the bar's x is below 15
    # (presumably bars entering from the opposite side — confirm).
    flipped = int(bar_log["x"] < 15)
    # Velocity is inferred from the bar onset modulo 8 s.
    # NOTE(review): heuristic — confirm against the protocol definition.
    speed = 10 if bar_log["t_start"] % 8 > 2 else 5
    return dict(t_start=round(pause_log["t_start"]),
                lum=bar_log["color_2"][0],
                theta=bar_log["theta"] + flipped * np.pi,
                vel=speed,
                size=bar_log["bar_size"])


cond_dict = [_bar_condition(stim_logs[n - 1], stim_logs[n])
             for n in range(1, len(stim_logs), 3)]

stim_df = pd.DataFrame(cond_dict)


# Cut the NaN-masked traces into one chunk per presentation, giving an array of
# shape (time within stimulus, presentation, cell). All presentations are
# assumed to last as long as the spacing between the first two onsets.
stim_onset = stim_df.loc[0, "t_start"]
stim_dur = stim_df.loc[1, "t_start"] - stim_onset
n_samp = traces.shape[0]
n_cells = traces.shape[1]
n_samp_stim = int(stim_dur * fs)  # samples per presentation
n_reps = len(stim_df)  # one chunk per row of stim_df
first_samp = int(stim_onset * fs)
assert first_samp + n_reps * n_samp_stim <= n_samp, "recording shorter than the stimulus protocol"

# Crop to exactly n_reps whole chunks so the reshape cannot fail on a
# trailing partial chunk or drift by one sample per presentation.
reshaped = (traces_nanned[first_samp:first_samp + n_reps * n_samp_stim]
            .T.reshape(n_cells, n_reps, n_samp_stim)
            .swapaxes(0, 2))

# Null model: the same number of windows of the same length, cropped at random
# onsets. Seeded so that re-running the notebook reproduces the same null.
# NOTE(review): cropped from `traces`, not `traces_nanned`, so the null still
# contains the bout periods that are masked in the real data — confirm intended.
SHUFFLE_SEED = 0
rng = np.random.default_rng(SHUFFLE_SEED)
rand_trig = rng.integers(n_samp_stim, n_samp - n_samp_stim, n_reps)
reshaped_shuf = crop(traces, rand_trig, pre_int=0, post_int=n_samp_stim)

thetas = sorted(stim_df["theta"].unique())
vels = sorted(stim_df["vel"].unique())
lums = sorted(stim_df["lum"].unique())
sizes = sorted(stim_df["size"].unique())

# Every feature combination must be present, and shown equally often, to fill
# the regular (theta, vel, lum, size, time, repetition, cell) block below.
reps_per_cond = stim_df.groupby(["theta", "vel", "lum", "size"]).size()
assert len(reps_per_cond) == len(thetas) * len(vels) * len(lums) * len(sizes), "missing stimulus combinations"
assert reps_per_cond.nunique() == 1, "unequal number of repetitions across conditions"
n_reps_per_cond = int(reps_per_cond.iloc[0])

block_shape = (len(thetas), len(vels), len(lums), len(sizes), n_samp_stim, n_reps_per_cond, n_cells)
resp_block = np.zeros(block_shape)
resp_block_shuf = np.zeros(block_shape)

# stim_id enumerates conditions in theta -> vel -> lum -> size nesting order.
stim_df["stim_id"] = 0
stim_types_n = 0
for j, theta in enumerate(thetas):
    for k, vel in enumerate(vels):
        for i, lum in enumerate(lums):
            for z, size in enumerate(sizes):
                select = ((stim_df["lum"] == lum)
                          & (stim_df["vel"] == vel)
                          & (stim_df["theta"] == theta)
                          & (stim_df["size"] == size)).to_numpy()
                stim_df.loc[select, "stim_id"] = stim_types_n

                resp_block[j, k, i, z] = reshaped[:, select, :]
                resp_block_shuf[j, k, i, z] = reshaped_shuf[:, select, :]

                stim_types_n += 1
                

In [None]:
# Baseline-subtracted mean response of one example cell to every condition:
# rows are velocities, columns are directions, blue/red are the first/second
# luminance, and more transparent lines are larger bars.
BASELINE_PTS = 3
cid = 566  # example cell; other candidates: 134, 283, 13190, 13090, 12100, 1020

plt.close("all")
f, axs = plt.subplots(len(vels), len(thetas), figsize=(8, 4), sharey=True, squeeze=False)
t_stim = np.arange(n_samp_stim) / fs
for th_i, theta in enumerate(thetas):
    axs[0, th_i].set_title(f"Theta: {int(180 * theta / np.pi)}°")
    for vel_i in range(len(vels)):
        for col, lum_i in zip(["b", "r"], range(len(lums))):
            for s_i, size in enumerate(sizes):
                t = resp_block[th_i, vel_i, lum_i, s_i, :, :, cid]
                t = t - np.nanmean(t[:BASELINE_PTS, :], 0)
                axs[vel_i, th_i].plot(t_stim, np.nanmean(t, 1),
                                      c=col, alpha=max(1 - 0.2 * s_i, 0.1), label=f"{size} mm")

# Label every velocity row (not only the last one).
for vel_i, vel in enumerate(vels):
    axs[vel_i, 0].set_ylabel(f"Vel: {vel} mm/s")

sns.despine()
plt.tight_layout()
plt.show()

In [None]:
# Shuffled (null) responses of the example cell, for all conditions at once.
resps_shuf = resp_block_shuf[..., cid]

In [None]:
# Null distribution: split the shuffled windows into stim_types_n groups of the
# same size as a real condition, take the per-sample median of each group,
# subtract its per-cell median baseline, and pool all samples of all groups.
BASELINE_PTS = 3

n_per_group = reshaped_shuf.shape[1] // stim_types_n
means_shuf = np.zeros((stim_types_n, n_samp_stim, n_cells))
for i in range(stim_types_n):
    group_median = np.nanmedian(reshaped_shuf[:, i * n_per_group:(i + 1) * n_per_group, :], 1)
    # subtract the per-cell median of the first BASELINE_PTS samples
    means_shuf[i] = group_median - np.nanmedian(group_median[:BASELINE_PTS], 0)

null_distr = means_shuf.reshape(stim_types_n * n_samp_stim, n_cells)

In [None]:
# Select one stimulus condition (indices into thetas / vels / lums / sizes)
# for the inspection cells below. Defaults: last direction, last velocity,
# first luminance, largest bar.
theta_i, vel_i, lum_i, size_i = len(thetas) - 1, len(vels) - 1, 0, len(sizes) - 1
resps = resp_block[theta_i, vel_i, lum_i, size_i]
resps_shuf = resp_block_shuf[theta_i, vel_i, lum_i, size_i]
                

In [None]:
# (time within stimulus, repetition, cell) for the selected condition
np.shape(resps)

In [None]:
# Median over repetitions of the selected condition, giving (time, cell).
# The baseline is subtracted per cell (axis 0); without an axis, nanmedian
# returned one scalar pooled across all cells.
BASELINE_PTS = 3
mn = np.nanmedian(resps, 1)
mn = mn - np.nanmedian(mn[:BASELINE_PTS], 0)


In [None]:
# Per-cell density of |null| values on [0, 40) in B_SIZE-wide bins.
# cum_p is the cumulative probability, with rows as bins and columns as cells.
B_SIZE = 0.01  # bin width of the null histogram
THR = 0.01  # p-value threshold applied in the next cell
abs_edges = np.arange(0, 40, B_SIZE)
hists = np.array(
    [np.histogram(np.abs(null_distr[:, cell]), abs_edges, density=True)[0]
     for cell in tqdm(range(n_cells))]
).T

cum_p = np.cumsum(hists, axis=0) * B_SIZE

In [None]:
# Empirical p-value of each sample of the selected condition's response,
# p = 1 - CDF_null(|response|), evaluated per cell.
# counts[c]: number of samples with p < THR; log_p_tot[c]: sum of log10 p.
# The bin index is clipped to the histogram range, and NaN responses map to
# bin 0, so the lookup can neither overflow nor wrap around. The old
# np.int / b_size references (removed alias, undefined name) are gone.
bin_idx = np.clip(np.nan_to_num(np.abs(mn), nan=0.0) / B_SIZE, 0, cum_p.shape[0] - 1).astype(int)
p_vals = np.abs(1 - cum_p[bin_idx, np.arange(n_cells)])
counts = np.count_nonzero(p_vals < THR, axis=0).astype(float)
log_p_tot = np.sum(np.log10(p_vals), axis=0)

In [None]:
# Diagnostic for the example cell: |response| of the selected condition, its
# log10 p-value against the null, and the pooled |null| values scattered at
# random sample positions for visual comparison.
bin_idx = np.clip(np.nan_to_num(np.abs(mn[:, cid]), nan=0.0) / B_SIZE, 0, cum_p.shape[0] - 1).astype(int)
jitter_rng = np.random.default_rng(0)

fig, ax = plt.subplots()
ax.plot(np.abs(mn[:, cid]), label="|response|")
ax.plot(np.log10(np.abs(1 - cum_p[bin_idx, cid])), label="log10 p")
ax.scatter(jitter_rng.integers(0, mn.shape[0], null_distr.shape[0]), np.abs(null_distr[:, cid]),
           s=4, alpha=0.5, label="|null|")
ax.set(xlabel="sample within stimulus", title=f"cell {cid}")
ax.legend(frameon=False)
plt.show()

In [None]:
# (time within stimulus, cell)
np.shape(mn)

In [None]:
# Reliability of every cell's responses across repetitions, for every stimulus
# condition, on the real data and on the shuffled (null) data.
rel_scores = np.zeros((len(thetas), len(vels), len(lums), len(sizes), n_cells))
rel_scores_shuf = np.zeros_like(rel_scores)

for j in range(len(thetas)):
    for k in range(len(vels)):
        for i in range(len(lums)):
            for z in range(len(sizes)):
                # NOTE(review): assumes bouter's reliability takes a
                # (time, repetition, cell) block and returns one score per cell.
                rel_scores[j, k, i, z] = reliability(resp_block[j, k, i, z])
                rel_scores_shuf[j, k, i, z] = reliability(resp_block_shuf[j, k, i, z])

# Flatten to (cell, stim_id). stim_id was assigned in the same
# theta -> vel -> lum -> size C-order, so row-major flattening matches it.
rels_df = rel_scores.reshape(-1, n_cells).T
rels_df_shuf = rel_scores_shuf.reshape(-1, n_cells).T


            

In [None]:
# Distribution of reliability scores over all cells and conditions, real vs
# shuffled. The edges span [-1, 1] inclusive; np.arange(-1, 1, 0.05) stopped
# at 0.95 and silently dropped the most reliable scores.
rel_bins = np.linspace(-1, 1, 41)  # 0.05-wide bins

fig, ax = plt.subplots(figsize=(4, 3))
ax.hist(rels_df.flatten(), rel_bins, alpha=0.4, density=True, label="data")
ax.hist(rels_df_shuf.flatten(), rel_bins, alpha=0.4, density=True, label="shuffled")
ax.set_xlabel("reliability score")
ax.legend(frameon=False)
sns.despine()
plt.show()

In [None]:
# Reliability vs bar size for one example cell: rows are velocities, columns
# are directions, blue/red are the first/second luminance. The grey band spans
# up to the 95th percentile of the cell's shuffled reliability.
cid = 12100  # example cell; other candidates: 13190, 13090, 1020


def _stim_id(df, lum, theta, vel, size):
    """Return the stim_id of the condition with the given feature values."""
    match = ((df["lum"] == lum) & (df["theta"] == theta)
             & (df["vel"] == vel) & (df["size"] == size))
    return df.loc[match, "stim_id"].values[0]


plt.close("all")
shuf_thr = np.nanpercentile(rels_df_shuf[cid], 95)
f, axs = plt.subplots(len(vels), len(thetas), figsize=(8, 4), sharey=True, sharex=True, squeeze=False)
for j, vel in enumerate(vels):
    for i, th in enumerate(thetas):
        ax = axs[j, i]
        ax.axhspan(-1, shuf_thr, facecolor=(0.85,) * 3, lw=0)
        for col, lum in zip(["b", "r"], lums):
            ks = [_stim_id(stim_df, lum, th, vel, s) for s in sizes]
            ax.plot(sizes, rels_df[cid, ks], c=col, label=f"lum {lum}")
        ax.set_ylim(-0.2, 1.1)
        ax.set_xlim(0, 10)

        axs[0, i].set_title(f"Theta: {int(180 * th / np.pi)}°")
        axs[-1, i].set_xlabel("bar size (mm)")
    axs[j, 0].set_ylabel(f"Vel: {vel} mm/s")

sns.despine()
plt.tight_layout()
plt.show()