In [None]:
%matplotlib widget

In [None]:
from pathlib import Path
import numpy as np
from bouter import EmbeddedExperiment
import pandas as pd
from tqdm import tqdm
import flammkuchen as fl

from matplotlib import pyplot as plt
import seaborn as sns
sns.set(palette="deep", style="ticks")

from scipy.signal import detrend

from bouter.utilities import crop, reliability

In [None]:
# Experiment folder to analyse.
# NOTE(review): absolute, machine-specific mount point — adjust for your setup.
EXPERIMENTS_ROOT = Path("/Volumes/Shared/experiments")
path = EXPERIMENTS_ROOT / "E0070_receptive_field" / "v01_sliding_bars" / "210603_f1"

In [None]:
# Load ROI traces / coordinates exported from suite2p and the experiment metadata.
# Traces are transposed so that axis 0 is time and axis 1 is cells (all code below
# treats axis 0 as samples).
traces = fl.load(path / "data_from_suite2p_unfiltered.h5", "/traces").T
coords = fl.load(path / "data_from_suite2p_unfiltered.h5", "/coords")

exp = EmbeddedExperiment(path)

# Drop ROIs whose trace is identically zero.
sel = ~(traces == 0).all(0)
traces = traces[:, sel]
coords = coords[sel, :]

# Remove the linear trend from each cell's trace. Done column by column to keep
# peak memory low on large recordings.
for i in tqdm(range(traces.shape[1])):
    traces[:, i] = detrend(traces[:, i])


# Mask (set to NaN) the samples around large swimming bouts.
# Window around each large bout onset, in seconds before / after the onset:
NAN_WND_PRE_S = 1
NAN_WND_POST_S = 15
# Bouts whose peak vigour exceeds this value are masked.
LARGE_BOUT_VIG_THR = 1.5

# Imaging volume rate (Hz), read from the lightsheet z-scanning config.
# NOTE(review): int() truncates non-integer rates — confirm the rate is a whole number.
fs = int(exp["imaging"]["microscope_config"]["lightsheet"]["scanning"]["z"]["frequency"])
samp_n = traces.shape[0]
t_orig = np.arange(samp_n) / fs

nan_wnd_pre = int(NAN_WND_PRE_S * fs)
nan_wnd_post = int(NAN_WND_POST_S * fs)
bouts_df = exp.get_bout_properties()
large_bouts_t = bouts_df.loc[bouts_df["peak_vig"] > LARGE_BOUT_VIG_THR, "t_start"].values
# np.int was removed in NumPy 1.24; the builtin int is the equivalent dtype.
large_bouts_idxs = (large_bouts_t * fs).astype(int)

# NaN the traces. Windows are clipped at the recording edges so that bouts close to
# the start or end are still masked (instead of being skipped entirely).
traces_nanned = traces.copy()
for idx in tqdm(large_bouts_idxs):
    start = max(idx - nan_wnd_pre, 0)
    stop = min(idx + nan_wnd_post, samp_n)
    if start < stop:
        traces_nanned[start:stop, :] = np.nan

In [None]:
# Create dataframe of stimulus features, one row per bar presentation.
stim_logs = exp["stimulus"]["log"]  # logs of individual stimuli

# NOTE(review): the log is assumed to repeat as triplets with the bar at offset 1 and
# its pre-pause at offset 0 — confirm against the stimulus protocol.
cond_dict = []
for i in range(1, len(stim_logs), 3):
    entry = stim_logs[i]
    pre_pause = stim_logs[i - 1]

    cond_dict.append(dict(t_start=round(pre_pause["t_start"]),
                          lum=entry["color_2"][0],
                          # pi is added to theta when the bar's x < 15
                          # (presumably bars entering from the opposite side — confirm).
                          theta=entry["theta"] + int(entry["x"] < 15) * np.pi,
                          # Velocity inferred from the bar onset time within an 8 s
                          # cycle. TODO confirm against the stimulus protocol.
                          vel=10 if entry["t_start"] % 8 > 2 else 5,
                          size=entry["bar_size"]))

stim_df = pd.DataFrame(cond_dict)  # convert to dataframe


# Reshape the (time, cells) trace matrix into one window per stimulus presentation.
n_samp = traces.shape[0]
n_cells = traces.shape[1]
first_onset = int(stim_df.loc[0, "t_start"] * fs)
stim_dur = stim_df.loc[1, "t_start"] - stim_df.loc[0, "t_start"]
# Samples per presentation; the shuffled crops below use exactly the same length.
n_samp_stim = int(stim_dur * fs)
n_windows = (n_samp - first_onset) // n_samp_stim
assert n_windows >= len(stim_df), (
    f"recording holds {n_windows} full stimulus windows but the log has {len(stim_df)} stimuli"
)
n_reps = len(stim_df)
# Drop samples before the first onset and after the last logged presentation so the
# reshape is exact.
stim_traces = traces_nanned[first_onset:first_onset + n_reps * n_samp_stim]
reshaped = stim_traces.T.reshape(n_cells, n_reps, n_samp_stim)
reshaped = reshaped.swapaxes(0, 2)  # -> (timepoints, presentations, cells)

# Shuffle control: windows of the same length cropped at random (seeded) onsets.
# NOTE(review): the shuffle crops from `traces`, not `traces_nanned`, so bout periods
# are not masked in the null distribution — confirm this is intended.
SHUFFLE_SEED = 0
rng = np.random.default_rng(SHUFFLE_SEED)
rand_trig = rng.integers(n_samp_stim, n_samp - n_samp_stim, n_reps)

reshaped_shuf = crop(traces, rand_trig, pre_int=0, post_int=n_samp_stim)

# Create matrix of responses (thetas x vel x lums x sizes x timepoints x reps x n_cells)
thetas = sorted(stim_df["theta"].unique())
vels = sorted(stim_df["vel"].unique())
lums = sorted(stim_df["lum"].unique())
sizes = sorted(stim_df["size"].unique())

# Repetitions per condition, derived from the log rather than hard-coded.
cond_counts = stim_df.groupby(["theta", "vel", "lum", "size"]).size()
assert len(cond_counts) == len(thetas) * len(vels) * len(lums) * len(sizes), \
    "not every stimulus condition combination was presented"
assert cond_counts.nunique() == 1, "conditions have unequal numbers of repetitions"
n_rep_cond = int(cond_counts.iloc[0])

resp_shape = (len(thetas), len(vels), len(lums), len(sizes), n_samp_stim, n_rep_cond, n_cells)
resp_block = np.zeros(resp_shape)
resp_block_shuf = np.zeros(resp_shape)

stim_df["stim_id"] = 0
stim_types_n = 0
for j, theta in enumerate(thetas):
    for k, vel in enumerate(vels):
        for i, lum in enumerate(lums):
            for z, size in enumerate(sizes):
                select = ((stim_df["lum"] == lum)
                          & (stim_df["vel"] == vel)
                          & (stim_df["theta"] == theta)
                          & (stim_df["size"] == size))
                stim_df.loc[select, "stim_id"] = stim_types_n

                resp_block[j, k, i, z] = reshaped[:, select.values, :]
                resp_block_shuf[j, k, i, z] = reshaped_shuf[:, select.values, :]

                stim_types_n += 1
                

In [None]:
# Reliability of every cell's responses across repetitions, for each
# (theta, vel, lum, size) condition — on real responses and on the shuffle control.
cond_grid_shape = (len(thetas), len(vels), len(lums), len(sizes))
rel_scores = np.zeros(cond_grid_shape + (n_cells,))
rel_scores_shuf = np.zeros(cond_grid_shape + (n_cells,))

for j in tqdm(range(len(thetas))):
    # Same C-order traversal as nested (vel, lum, size) loops.
    for k, i, z in np.ndindex(*cond_grid_shape[1:]):
        cond = (j, k, i, z)
        rel_scores[cond] = reliability(resp_block[cond])
        rel_scores_shuf[cond] = reliability(resp_block_shuf[cond])


In [None]:
# Per-cell resolution, expressed as an index into `sizes` (sorted ascending).
# A response counts as significant when its reliability exceeds the
# SIG_PERCENTILE-th percentile of that cell's shuffled reliability scores, pooled
# over all stimulus conditions.
#  - resolutions: smallest size index with a significant response in any
#    (theta, vel, lum) condition.
#  - strict_resolutions: smallest size index i such that a single (theta, vel, lum)
#    condition is significant for every size from i up to the largest one.
# Cells with no qualifying size get NaN.
SIG_PERCENTILE = 99

resolutions = np.full(n_cells, np.nan)
strict_resolutions = np.full(n_cells, np.nan)
for cid in tqdm(range(n_cells)):
    thr = np.percentile(rel_scores_shuf[:, :, :, :, cid].flatten(), SIG_PERCENTILE)
    cell_scores = rel_scores[:, :, :, :, cid]

    # Lenient: best score over all (theta, vel, lum) conditions, per size.
    ids = np.where(cell_scores.max((0, 1, 2)) > thr)[0]
    if len(ids) > 0:
        resolutions[cid] = np.min(ids)

    # Strict: rows are (theta, vel, lum) conditions, columns are sizes. All sizes are
    # tested, including the largest (previously range(2, -1, -1) skipped index 3).
    significant_resps = cell_scores.reshape(-1, len(sizes)) > thr
    for size_i in range(len(sizes)):
        if significant_resps[:, size_i:].all(1).any():
            strict_resolutions[cid] = size_i
            break

# Calculate maximum rel score for each cell:
max_rel_score = rel_scores.max((0, 1, 2, 3))

In [None]:
# Plot the most reliable cell. Top rows show trial-averaged, baseline-subtracted traces,
# one colour per luminance and fading alpha per bar size. Bottom rows show reliability
# vs. bar size against the shuffle threshold (grey band). Columns are directions;
# rows within each block are velocities.
plt.close("all")
cid = np.nanargmax(max_rel_score)  # NaN scores (if any) must not be picked as "best"

n_th, n_vel = len(thetas), len(vels)
lum_colors = ["b", "r"]  # one colour per luminance level (first two levels)
size_pos = np.arange(len(sizes))  # reliability is plotted at evenly spaced size indices
rel_thr = np.percentile(rel_scores_shuf[:, :, :, :, cid].flatten(), 99)
t_stim = np.arange(resp_block.shape[4]) / fs

f, axs = plt.subplots(2 * n_vel, n_th, figsize=(8, 8), squeeze=False)

# Trace panels: all share x and y limits with the top-left panel.
for th_i in range(n_th):
    for vel_i in range(n_vel):
        ax = axs[vel_i, th_i]
        if ax is not axs[0, 0]:
            ax.sharex(axs[0, 0])
            ax.sharey(axs[0, 0])
        for col, lum_i in zip(lum_colors, range(len(lums))):
            for s_i in range(len(sizes)):
                t = resp_block[th_i, vel_i, lum_i, s_i, :, :, cid]
                t = t - np.nanmean(t[:3, :], 0)  # baseline: first 3 samples
                ax.plot(t_stim, np.nanmean(t, 1),
                        c=col, alpha=1 - 0.2 * s_i, label=f"{sizes[s_i]} mm")
        axs[vel_i, 0].set_ylabel(f"Vel: {vels[vel_i]} mm/s")
    # round (not int) so e.g. 269.9999... is shown as 270
    axs[0, th_i].set_title(f"Theta: {round(float(np.degrees(thetas[th_i])))}°")

# Reliability panels.
for th_i in range(n_th):
    for vel_i in range(n_vel):
        ax = axs[n_vel + vel_i, th_i]
        ax.axhspan(-1, rel_thr, facecolor=(0.85,) * 3, lw=0)
        for col, lum_i in zip(lum_colors, range(len(lums))):
            vals = rel_scores[th_i, vel_i, lum_i, :, cid]
            ax.plot(size_pos, vals, c=col, label=f"lum {lums[lum_i]}")
        ax.set_ylim(-0.2, 1.1)
        ax.set_xlim(-0.5, len(sizes) - 0.5)
        ax.set_xticks(size_pos)
        # Only the bottom row carries size tick labels.
        if vel_i == n_vel - 1:
            ax.set_xticklabels([f"{s:g}" for s in sizes])
        else:
            ax.set_xticklabels([])
        axs[n_vel + vel_i, 0].set_ylabel(f"Vel: {vels[vel_i]} mm/s")
    axs[-1, th_i].set_xlabel("bar size (mm)")

plt.tight_layout()
sns.despine()
f.savefig(path / "best_cell.png")
plt.show()

In [None]:
# Anatomical maps of all ROIs (grey background). Panel 0 is coloured by maximum
# reliability; panels 1-2 by lenient / strict resolution (as size index 0..3, labelled
# 1 and 8 mm at the extremes). ROIs without a resolution are shown only in grey.
f, axs = plt.subplots(1, 3, figsize=(9, 4))

axs[0].scatter(coords[:, 1], coords[:, 2], c=(0.8,) * 3, s=3)
im = axs[0].scatter(coords[:, 1], coords[:, 2], c=max_rel_score, s=3, vmin=0, vmax=1)
cax = f.add_axes((0.3, 0.2, 0.05, 0.02))
cb = plt.colorbar(im, cax=cax, orientation="horizontal", label="Rel. score")
axs[0].axis("equal")
axs[0].axis("off")

resolution_panels = [
    (axs[1], resolutions, 0.55, "Resolution (mm)"),
    (axs[2], strict_resolutions, 0.8, "Resolution strict (mm)"),
]
for ax, res, cb_x, cb_label in resolution_panels:
    sel = ~np.isnan(res)
    ax.scatter(coords[:, 1], coords[:, 2], c=(0.8,) * 3, s=3)
    im = ax.scatter(coords[sel, 1], coords[sel, 2], c=res[sel],
                    s=3, vmin=0, vmax=3, cmap="Reds_r")
    cax = f.add_axes((cb_x, 0.2, 0.05, 0.02))
    cb = plt.colorbar(im, cax=cax, orientation="horizontal", label=cb_label)
    cb.set_ticks([0, 3])
    cb.set_ticklabels([1, 8])
    ax.axis("equal")
    ax.axis("off")

f.savefig(path / "maps.png")

In [None]:
# Bundle the analysis products into one h5 file inside the experiment folder.
# NOTE: "raw_traces" holds `traces_nanned` (detrended, with large-bout windows set to NaN).
export_data = {
    "coords": coords,
    "raw_traces": traces_nanned,
    "rel_scores": rel_scores,
    "rel_scores_shuf": rel_scores_shuf,
    "resp_block": resp_block,
    "resp_block_shuf": resp_block_shuf,
    "resp_block_axes": ["theta", "vel", "lum", "sizes", "pts", "reps", "cells"],
    "thetas": thetas,
    "lums": lums,
    "vels": vels,
    "sizes": sizes,
    "stim_df": stim_df,
}
fl.save(path / "exported_data.h5", export_data)