In [None]:
%matplotlib widget

In [None]:
from pathlib import Path
import numpy as np
from bouter import EmbeddedExperiment
import pandas as pd
from tqdm import tqdm
import flammkuchen as fl

from matplotlib import pyplot as plt
import seaborn as sns
sns.set(palette="deep", style="ticks")

from scipy.signal import detrend

from bouter.utilities import crop, reliability

In [None]:
# Root folder of the experiment. The fish number is read from the sub-folder
# names further down, so presumably there is one sub-folder per fish.
# NOTE(review): absolute path to a mounted share - adjust when running elsewhere.
master_path = Path("/Volumes/Shared/experiments/E0070_receptive_field/v04_flashing_rad_simple")

# Keep only the folders that contain suite2p output. glob() yields entries in
# arbitrary filesystem order; sorting makes `path_list[0]` / `path_list[1:]`
# (used when pooling the fish) deterministic across runs and machines.
path_list = sorted(f.parent for f in master_path.glob("*/data_from_suite2p_unfiltered.h5"))
path_list

In [None]:
# Metadata sanity check: the genotype stored with each experiment must agree
# with the parity of the fish number encoded in the folder name.
for path in path_list:
    exp = EmbeddedExperiment(path)

    # The fish number is whatever follows "_f" in the folder name.
    fish_number = int(path.name.split("_f")[1])
    is_odd_fish = fish_number % 2 != 0
    is_gcamp_fish = exp["general"]["animal"]["genotype"] == "Huc:H2B-GCaMP6s"

    # Odd-numbered fish are expected to carry this genotype, even-numbered
    # fish are expected not to; report any mismatch.
    if is_odd_fish != is_gcamp_fish:
        print("Check metadata of fish ", path.name)

In [None]:
# Per-fish preprocessing: load the suite2p traces, z-score them, parse the
# stimulus log, crop the responses around every stimulus and save, for each
# stimulus position, a reliability and an amplitude score per cell to
# `cell_resps.h5` in the fish folder.
#
# The variables of the LAST processed fish (traces, coords, stim_df, cropped,
# fs, ...) stay in the namespace and are read by later cells.

# Window cropped around each stimulus onset, in seconds:
PRE_INT_S = 2
POST_INT_S = 5

for path in path_list:
    data_file = path / "data_from_suite2p_unfiltered.h5"
    # Transposed so that axis 0 is time and axis 1 indexes the cells:
    traces = fl.load(data_file, "/traces").T
    coords = fl.load(data_file, "/coords")

    exp = EmbeddedExperiment(path)

    # Detrend every cell (scipy default: linear), then z-score:
    for i in tqdm(range(traces.shape[1])):
        traces[:, i] = detrend(traces[:, i])
    traces = (traces - np.nanmean(traces, 0)) / np.nanstd(traces, 0)

    # Read original frequency:
    fs = int(exp["imaging"]["microscope_config"]["lightsheet"]["scanning"]["z"]["frequency"])
    samp_n = traces.shape[0]
    t_orig = np.arange(samp_n) / fs

    # Only every second log entry (starting from the second one) is parsed.
    logs = exp["stimulus"]["log"][1::2]
    stim_dicts = []

    for log in logs:
        clip = log["clip_mask"]
        # Masks that include the centre point [0.5, 0.5] get orientation=1.
        orientation = 1 if ([0.5, 0.5] in clip) else 0

        if orientation:
            # Angles of the 2nd and 3rd mask vertices around the centre.
            pos_start = np.arctan2(clip[1][1] - 0.5, clip[1][0] - 0.5)
            pos_end = np.arctan2(clip[2][1] - 0.5, clip[2][0] - 0.5)
        else:
            pos_start = clip[0][0] - 0.5
            # No end position is defined for this mask type. Use NaN rather
            # than silently recording the value of the previous stimulus
            # (or failing with a NameError on the very first entry).
            # NOTE(review): fill in the intended mask vertex if it is needed.
            pos_end = np.nan

        stim_dicts.append(dict(t=int(log["t_start"]),
                               orientation=orientation,
                               pos_start=pos_start,
                               pos_end=pos_end))

    stim_df = pd.DataFrame(stim_dicts)

    # Crop around stimuli and subtract the pre-stimulus mean of every
    # cropped response:
    n_cells = traces.shape[1]
    pre_n = int(PRE_INT_S * fs)
    post_n = int(POST_INT_S * fs)
    cropped = crop(traces, stim_df["t"] * fs, pre_int=pre_n, post_int=post_n)
    cropped = cropped - cropped[:pre_n, :, :].mean(0)

    # Find unique positions in the stimulus (orientation == 1 only):
    pos_y = sorted(stim_df.loc[stim_df["orientation"] == 1, "pos_start"].unique())

    # Loop over positions and compute reliability and amplitude scores:
    rel_scores = np.zeros((len(pos_y), n_cells))
    amp_scores = np.zeros((len(pos_y), n_cells))

    for i, p in tqdm(list(enumerate(pos_y))):
        resps_idxs = stim_df[(stim_df["orientation"] == 1) &
                             (stim_df["pos_start"] == p)].index

        rel_scores[i] = reliability(cropped[:, resps_idxs, :])
        # Mean over samples 10-14 minus mean over samples 0-4, averaged over
        # repetitions.
        # NOTE(review): these sample indices are fixed and do not scale with
        # `fs` or PRE_INT_S - confirm they match the intended time windows.
        amp_scores[i] = (cropped[10:15, resps_idxs, :].mean(0)
                         - cropped[0:5, resps_idxs, :].mean(0)).mean(0)

    # NOTE(review): later cells read a `cell_resps` DataFrame (columns such as
    # "rel_scores_y_pk" and "score_max") that is not built anywhere in this
    # notebook; a commented-out draft that used to sit here derived such
    # columns from the max / argmax of the reliability scores.

    fl.save(path / "cell_resps.h5", dict(rel_scores=rel_scores, amp_scores=amp_scores))

In [None]:
# Start the pooled table with the first fish: one row per cell, one column per
# stimulus position, plus fish id, unique cell id and genotype.
first_path = path_list[0]
exp = EmbeddedExperiment(first_path)
gen = exp["general"]["animal"]["genotype"]

all_cells = pd.DataFrame(fl.load(first_path / "cell_resps.h5", "/rel_scores").T)
# Label the rows with the fish that was actually loaded: `path` is a leftover
# loop variable that points at whichever fish was processed last.
all_cells["fid"] = first_path.name
all_cells["cid"] = [f"{first_path.name}_c{i}" for i in range(len(all_cells))]
all_cells["gen"] = gen

In [None]:
# Append the remaining fish to the pooled table.
# NOTE: this cell extends the existing `all_cells`; rebuild `all_cells` from
# the first fish before running it a second time, or fish get duplicated.
per_fish_tables = [all_cells]
for path in path_list[1:]:
    exp = EmbeddedExperiment(path)
    gen = exp["general"]["animal"]["genotype"]

    # Load the scores of THIS fish (loading path_list[0] here would give every
    # fish the scores of the first one).
    cells = pd.DataFrame(fl.load(path / "cell_resps.h5", "/rel_scores").T)
    cells["fid"] = path.name
    cells["cid"] = [f"{path.name}_c{i}" for i in range(len(cells))]
    cells["gen"] = gen

    per_fish_tables.append(cells)

# One concatenation instead of growing the frame inside the loop.
all_cells = pd.concat(per_fish_tables, axis=0)

# Global 0..N-1 row index; the per-fish row number is kept in an "index" column.
all_cells = all_cells.reset_index()

In [None]:
# Reliability matrix of the pooled cells: the 36 position columns of
# `all_cells`, transposed so that rows are positions and columns are cells.
position_cols = list(range(36))
rel_scores_y = all_cells.loc[:, position_cols].values.T

# Cells ranked by decreasing peak reliability.
sorted_idxs = np.argsort(-rel_scores_y.max(0))

# Of the ranked cells, keep those whose NaN-ignoring peak exceeds 0.5.
peak_of_ranked = np.nanmax(rel_scores_y, 0)[sorted_idxs]
cid = sorted_idxs[peak_of_ranked > 0.5]
print(len(cid))

In [None]:
# Shape check of the pooled reliability matrix: (positions, cells).
rel_scores_y.shape

In [None]:
# Shape check of the ranking: one entry per pooled cell.
sorted_idxs.shape

In [None]:
# NOTE(review): `scores_y` is not assigned in any cell of this notebook, so
# this raises NameError on a fresh kernel - probably a leftover; confirm and
# remove, or point it at the intended variable.
scores_y


In [None]:
# Shape of the per-cell shift used to align tuning curves on their peak.
# (The opening parenthesis was missing, which made this cell a SyntaxError.)
# NOTE(review): `cell_resps` is not built in any cell of this notebook.
(-cell_resps.loc[:, "rel_scores_y_pk"] - 18).shape

In [None]:

# Align every cell's tuning curve on its own peak position.
# The shift has to be the scalar peak index of cell `i`: passing the whole
# "rel_scores_y_pk" column to np.roll would roll every cell by the same
# combined shift instead of its own.
# NOTE(review): the fixed offset of 18 is half of the 36 positions, presumably
# to centre the peak - confirm. `cell_resps` is not built in this notebook.
mat = np.array([np.roll(rel_scores_y[:, i],
                        -cell_resps.loc[i, "rel_scores_y_pk"] - 18)
                for i in range(len(all_cells))]).T

In [None]:
# Tuning curves of the reliable cells (NaN-ignoring peak > 0.3):
# raw curves on top, curves rolled by each cell's peak index below.
peak_rel = np.nanmax(rel_scores_y, 0)
sorted_idxs = np.argsort(-rel_scores_y.max(0))
cid = sorted_idxs[peak_rel[sorted_idxs] > 0.3]

y_base = np.arange(len(rel_scores_y))

f, axs = plt.subplots(2, 1, figsize=(5, 5), sharey=True)
raw_ax, aligned_ax = axs

# Raw curves (position axis flipped).
raw_ax.plot(y_base, np.flip(rel_scores_y[:, cid], 0), lw=0.5, c="b", alpha=0.4)
raw_ax.set_xlabel("Post - Ant (mm)")

# Roll every cell's curve by its own peak index plus a fixed offset of 18.
aligned_curves = []
for i in range(len(all_cells)):
    shift = -cell_resps.loc[i, "rel_scores_y_pk"] - 18
    aligned_curves.append(np.roll(rel_scores_y[:, i], shift))
mat = np.array(aligned_curves).T

# Individual aligned curves, and their mean in black.
aligned_ax.plot(y_base, np.flip(mat[:, cid], 0), lw=0.2, c="b", alpha=0.3)
aligned_ax.plot(y_base, np.flip(mat[:, cid].mean(1), 0), lw=1, c="k")
aligned_ax.set_xlabel("Post - Ant (mm)")

plt.tight_layout()
sns.despine()
    






In [None]:
# Cell count left over from the per-fish loop when the notebook is run top to
# bottom, i.e. the number of cells of the fish processed last (not the pooled total).
n_cells

In [None]:
# Indices of (up to) the 1000 cells with the highest peak reliability, best first.
peak_per_cell = rel_scores_y.max(0)
ranking = np.argsort(-peak_per_cell)
cid = ranking[:1000]

In [None]:
plt.close("all")

# Inspect one of the best cells: its tuning curve (red), then one figure per
# stimulus position around its preferred position with the single responses.
# NOTE(review): `sel_c` indexes the pooled `rel_scores_y`, whereas `cropped`,
# `stim_df` and `pos_y` belong to the fish processed last - confirm both
# refer to the same cell. `cell_resps` is not built in this notebook.
sel_c = cid[2]
c = cell_resps.loc[sel_c, "rel_scores_y_pk"]

plt.figure(figsize=(3, 2))
plt.plot(y_base, rel_scores_y[:, sel_c], c="r")

n_pos = len(pos_y)
for p in tqdm(range(c - 3, c + 3)):
    plt.figure(figsize=(3, 2))
    # Wrap around at both ends of the position list: negative indices already
    # wrapped implicitly, indices past the end used to raise IndexError.
    pos = pos_y[p % n_pos]
    resps_idxs = stim_df[(stim_df["orientation"] == 1) &
                         (stim_df["pos_start"] == pos)].index

    plt.plot(cropped[:, resps_idxs, sel_c])

In [None]:
# Shape check of the peak-aligned matrix.
mat.shape


In [None]:
plt.close("all")

# Anatomical maps of the cells with score_max > 0.3 in two projections
# (top row: coords column 0 scaled by 10 vs column 2; bottom row: column 1 vs
# column 2). Left column: colour = score_max, right column: colour = peak index.
sel = cell_resps["score_max"] > 0.3
score_of_sel = cell_resps["score_max"][sel]
peak_of_sel = cell_resps["rel_scores_y_pk"][sel]

f, axs = plt.subplots(2, 2, figsize=(6, 6))

projections = [
    (coords[:, 0]*10, coords[sel, 0]*10),  # top row
    (coords[:, 1], coords[sel, 1]),        # bottom row
]

for row, (x_all, x_sel) in enumerate(projections):
    # All cells in grey as anatomical background.
    for ax in axs[row]:
        ax.scatter(x_all, coords[:, 2], c=(0.8,)*3, s=3)
        ax.axis("equal")
        ax.axis("off")

    # Selected cells on top.
    im_score = axs[row, 0].scatter(x_sel, coords[sel, 2],
                                   c=score_of_sel, s=3, vmin=0, vmax=1, cmap="viridis")
    im_peak = axs[row, 1].scatter(x_sel, coords[sel, 2],
                                  c=peak_of_sel, s=3, cmap="twilight")

# Colorbars refer to the bottom-row scatters.
cax = f.add_axes((0.25, 0.1, 0.08, 0.012))
cb = plt.colorbar(im_score, cax=cax, orientation="horizontal", label="Rel. score")

cax = f.add_axes((0.55, 0.1, 0.08, 0.012))
cb = plt.colorbar(im_peak, cax=cax, orientation="horizontal", label="Ant-Post")
cb.set_ticks([10, 19])

In [None]:
# For the 16 cells with the highest score_max, show the outer product of the
# first 20 entries of the cell's y profile and the first 25 entries of its
# x profile as a 20 x 25 image.
# NOTE(review): `rel_scores_x` and `cell_resps` are not built in this notebook.
f, axs = plt.subplots(4, 4, figsize=(6, 6))

# Ranking by decreasing score_max, computed once for all panels.
ranking = np.argsort(-cell_resps["score_max"])
y_rows = np.arange(20)
x_rows = np.arange(25)

# axs.flat runs row by row, i.e. panel n sits at [n // 4, n % 4].
for n, ax in enumerate(axs.flat):
    cid = ranking[n]
    mat = np.outer(rel_scores_y[y_rows, cid], rel_scores_x[x_rows, cid])

    ax.imshow(mat)
    ax.axis("off")

In [None]:
# Reshape the traces into one chunk per stimulus presentation, build a
# shuffled control, and sort both into a block indexed by stimulus parameters.
#
# NOTE(review): this cell reads `traces_nanned` and the stim_df columns
# "t_start", "theta", "vel", "lum" and "size", none of which are produced by
# the per-fish loop of this notebook (its stim_df only has t, orientation,
# pos_start, pos_end) - it presumably belongs to a different protocol.

# Reshape traces matrix to crop around stimuli:
stim_dur = stim_df.loc[1, "t_start"]
n_samp = traces.shape[0]
n_cells = traces.shape[1]
stim_samples = int(stim_dur*fs)
n_reps = int(n_samp / (stim_dur*fs))
n_samp_stim = int(n_samp / n_reps)
reshaped = traces_nanned.T.reshape(n_cells, n_reps, n_samp_stim)
reshaped = reshaped.swapaxes(0, 2)  # -> (timepoints, presentations, cells)

# Create shuffle reshaped matrix from random trigger points. A seeded
# generator keeps the shuffled control (and every threshold derived from it)
# reproducible between runs.
rng = np.random.default_rng(0)
rand_trig = rng.integers(stim_samples, n_samp - stim_samples, n_reps)

reshaped_shuf = crop(traces, rand_trig, pre_int=0, post_int=stim_samples)

# Create matrix of responses (thetas x vel x lums x sizes x timepoints x reps x n_cells)
thetas = sorted(stim_df["theta"].unique())
vels = sorted(stim_df["vel"].unique())
lums = sorted(stim_df["lum"].unique())
sizes = sorted(stim_df["size"].unique())

# Repetitions per stimulus type; the assignments below fail if a type was
# shown a different number of times.
n_reps_per_stim = 4
block_shape = (len(thetas), len(vels), len(lums), len(sizes),
               n_samp_stim, n_reps_per_stim, n_cells)
resp_block = np.zeros(block_shape)
resp_block_shuf = np.zeros(block_shape)

# Give every parameter combination a running id and collect its responses.
stim_df["stim_id"] = 0
stim_types_n = 0
for j, theta in enumerate(thetas):
    for k, vel in enumerate(vels):
        for i, lum in enumerate(lums):
            for z, size in enumerate(sizes):
                select = (stim_df["lum"] == lum) & \
                         (stim_df["vel"] == vel) & \
                         (stim_df["theta"] == theta) & \
                         (stim_df["size"] == size)
                stim_df.loc[select, "stim_id"] = stim_types_n

                resp_block[j, k, i, z, :, :, :] = reshaped[:, select, :]
                resp_block_shuf[j, k, i, z, :, :, :] = reshaped_shuf[:, select, :]

                stim_types_n += 1
                

In [None]:
# Reliability of every cell for every stimulus condition, computed on the real
# responses and on the shuffled control.
# Note: this rebinds `rel_scores`, which the per-fish loop used for a
# (positions x cells) array.
scores_shape = (len(thetas), len(vels), len(lums), len(sizes), n_cells)
rel_scores = np.zeros(scores_shape)
rel_scores_shuf = np.zeros(scores_shape)

for j in tqdm(range(len(thetas))):
    # ndindex walks (vel, lum, size) with the last index varying fastest,
    # i.e. in the same order as three nested loops.
    for k, i, z in np.ndindex(len(vels), len(lums), len(sizes)):
        rel_scores[j, k, i, z, :] = reliability(resp_block[j, k, i, z, :, :, :])
        rel_scores_shuf[j, k, i, z, :] = reliability(resp_block_shuf[j, k, i, z, :, :, :])


In [None]:
# Resolution of each cell, expressed as an index into `sizes`.
#
# A response counts as significant when its reliability exceeds the 99th
# percentile of that cell's shuffled reliabilities.
#   resolutions:        smallest size index with a significant response in any
#                       condition (NaN if there is none).
#   strict_resolutions: smallest size index i such that, within one single
#                       condition, ALL sizes from i upwards are significant
#                       (NaN if there is none). The largest size alone does
#                       not qualify: the search starts at the second-largest.
n_sizes = len(sizes)
resolutions = np.full(n_cells, np.nan)
strict_resolutions = np.full(n_cells, np.nan)

for cid in tqdm(range(n_cells)):
    cell_scores = rel_scores[:, :, :, :, cid]
    # Per-cell threshold, computed once and shared by both definitions.
    thr = np.percentile(rel_scores_shuf[:, :, :, :, cid].flatten(), 99)

    # Lenient definition: best condition per size against the threshold.
    ids = np.where(cell_scores.max((0, 1, 2)) > thr)[0]
    if len(ids) > 0:
        resolutions[cid] = np.min(ids)

    # Strict definition: one row per (theta, vel, lum) condition.
    significant_resps = cell_scores.reshape(-1, n_sizes) > thr
    for i in range(n_sizes - 2, -1, -1):
        if significant_resps[:, i:].all(1).any(0):
            strict_resolutions[cid] = i
        else:
            # If no condition covers sizes i.., none can cover i-1.. either.
            break

# Calculate maximum rel score for each cell:
max_rel_score = rel_scores.max((0, 1, 2, 3))

In [None]:
plt.close("all")

# Responses and reliability tuning of the most reliable cell.
#   rows 0-1: trial-averaged traces, one row per velocity
#   rows 2-3: reliability as a function of bar size, one row per velocity
#   columns:  stimulus directions (theta); blue / red: the two luminances
cid = np.argmax(max_rel_score)
f, axs = plt.subplots(4, 4, figsize=(8, 8))

# Link the x and y axes of all trace panels. `get_shared_x_axes().join(...)`
# is no longer available in Matplotlib >= 3.8; `Axes.sharex/sharey` replace
# it. Sharing is set up before plotting so that autoscaling takes the data of
# every linked panel into account.
for th_i in range(len(thetas)):
    for vel_i in range(len(vels)):
        if (vel_i, th_i) != (0, 0):
            axs[vel_i, th_i].sharex(axs[0, 0])
            axs[vel_i, th_i].sharey(axs[0, 0])

# Trial-averaged traces:
for th_i in range(len(thetas)):
    for vel_i in range(len(vels)):
        ax = axs[vel_i, th_i]
        for col, lum_i in zip(["b", "r"], [0, 1]):
            for s_i in range(len(sizes)):
                t = resp_block[th_i, vel_i, lum_i, s_i, :, :, cid]
                # Baseline: mean of the first 3 samples of every repetition.
                t = t - np.nanmean(t[:3, :], 0)
                # Label with the size actually plotted (smaller index = more opaque).
                ax.plot(np.arange(t.shape[0]) / fs, np.nanmean(t, 1),
                        c=col, alpha=1-0.2*s_i, label=f"{sizes[s_i]} mm")

        axs[0, th_i].set_title(f"Theta: {int(180*thetas[th_i]/np.pi)}°")
        axs[vel_i, 0].set_ylabel(f"Vel: {vels[vel_i]} mm/s")

# Reliability vs bar size. The grey band marks everything below the 99th
# percentile of this cell's shuffled reliabilities.
shuf_thr = np.percentile(rel_scores_shuf[:, :, :, :, cid].flatten(), 99)
for th_i in range(len(thetas)):
    for vel_i in range(len(vels)):
        ax = axs[vel_i+2, th_i]
        ax.axhspan(-1, shuf_thr, facecolor=(0.85,)*3, lw=0)
        for col, lum_i in zip(["b", "r"], [0, 1]):
            vals = rel_scores[th_i, vel_i, lum_i, :, cid]
            ax.plot(sizes, vals, c=col, label=f"lum: {lums[lum_i]}")

        # These panels get identical, explicit limits, so they are comparable
        # without being linked.
        # NOTE(review): the x-limits (-0.5, 3.5) look like size *indices*
        # while the data and the ticks use the size values - confirm.
        ax.set_ylim(-0.2, 1.1)
        ax.set_xlim(-0.5, 3.5)
        ax.set_xticks([1, 2, 4, 8])
        if vel_i == 0:
            ax.set_xticklabels([])

        axs[3, th_i].set_xlabel("bar size (mm)")
        axs[vel_i+2, 0].set_ylabel(f"Vel: {vels[vel_i]} mm/s")

# Finish the layout before saving and showing the figure.
plt.tight_layout()
sns.despine()

f.savefig(path / "best_cell.png")
plt.show()

In [None]:
# Three anatomical maps (coords column 1 vs column 2) coloured by maximum
# reliability, resolution and strict resolution. All cells are drawn in grey
# first; for the two resolution maps only cells with a non-NaN value are
# coloured on top.
f, axs = plt.subplots(1, 3, figsize=(9, 4))

map_specs = [
    # (values, selection, colorbar left edge, colorbar label, scatter style, relabel ticks)
    (max_rel_score, slice(None), 0.3, "Rel. score",
     dict(vmin=0, vmax=1), False),
    (resolutions, ~np.isnan(resolutions), 0.55, "Resolution (mm)",
     dict(vmin=0, vmax=3, cmap="Reds_r"), True),
    (strict_resolutions, ~np.isnan(strict_resolutions), 0.8, "Resolution strict (mm)",
     dict(vmin=0, vmax=3, cmap="Reds_r"), True),
]

for ax, (values, sel, cax_left, cb_label, style, relabel) in zip(axs, map_specs):
    ax.scatter(coords[:, 1], coords[:, 2], c=(0.8,)*3, s=3)
    im = ax.scatter(coords[sel, 1], coords[sel, 2], c=values[sel], s=3, **style)

    cax = f.add_axes((cax_left, 0.2, 0.05, 0.02))
    cb = plt.colorbar(im, cax=cax, orientation="horizontal", label=cb_label)
    if relabel:
        # Ends of the index range 0..3, labelled 1 and 8.
        cb.set_ticks([0, 3])
        cb.set_ticklabels([1, 8])

    ax.axis("equal")
    ax.axis("off")

f.savefig(path / "maps.png")

In [None]:
# Collect the arrays and the stimulus table of this analysis into one
# dictionary and write it to `exported_data.h5` in the current fish folder.
# `resp_block_axes` names the axes of `resp_block` / `resp_block_shuf`.
# NOTE(review): `traces_nanned` is not assigned in any cell of this notebook.
exported = dict(
    coords=coords,
    raw_traces=traces_nanned,
    rel_scores=rel_scores,
    rel_scores_shuf=rel_scores_shuf,
    resp_block=resp_block,
    resp_block_shuf=resp_block_shuf,
    resp_block_axes=["theta", "vel", "lum", "sizes", "pts", "reps", "cells"],
    thetas=thetas,
    lums=lums,
    vels=vels,
    sizes=sizes,
    stim_df=stim_df,
)

fl.save(path / "exported_data.h5", exported)