In [None]:
%matplotlib widget
%load_ext autoreload

In [None]:
%autoreload 2
from pathlib import Path
import numpy as np
from bouter import EmbeddedExperiment
import pandas as pd
from tqdm import tqdm
import flammkuchen as fl

from matplotlib import pyplot as plt
import seaborn as sns
# Global seaborn theme applied to every figure created in this notebook.
sns.set(palette="deep", style="ticks")
# Colours of the active palette; indexed further down (cols[0], cols[3]) for
# the stimulus bar, the mean traces and the highlighted ROI.
cols = sns.color_palette()

from scipy.signal import detrend

from bouter.utilities import crop, reliability
from utilities import stimulus_df_from_exp, despine

In [None]:
# Root folder of the experiment (absolute path on a mounted volume).
master_path = Path("/Volumes/Shared/experiments/E0070_receptive_field/v04_flashing_rad_simple")

# Collect every direct sub-folder that contains a suite2p output file.
path_list = []
for data_file in master_path.glob("*/data_from_suite2p_unfiltered.h5"):
    path_list.append(data_file.parent)

In [None]:
# Folder of the example fish. It has to be assigned here: this cell runs before
# the loading cell (which assigns the same value to `path` again), so reading
# `path` without defining it raises a NameError on a fresh kernel.
path = master_path / "210611_f5"
ot_mask = fl.load(path / "anatomy.mask", "/mask")

In [None]:
# Load the data of one example fish, normalise the traces, crop them around
# each stimulus and compute one reliability score per (position, cell).

# Load traces and experiment metadata:
path = master_path / "210611_f5"
# The loaded traces are transposed; from here on axis 0 is used as the sample
# axis (see `samp_n`) and axis 1 as the cell axis (see `n_cells`).
traces = fl.load(path / "data_from_suite2p_unfiltered.h5", "/traces").T
coords = fl.load(path / "data_from_suite2p_unfiltered.h5", "/coords")
rois = fl.load(path / "data_from_suite2p_unfiltered.h5", "/rois_stack")
ot_mask = fl.load(path / "anatomy.mask", "/mask")

exp = EmbeddedExperiment(path)

# detrend the traces (one column, i.e. one cell, at a time), then Z-score
# every column with NaN-aware mean and standard deviation:
for i in tqdm(range(traces.shape[1])):
    traces[:, i] = detrend(traces[:, i])
traces = (traces - np.nanmean(traces, 0)) / np.nanstd(traces, 0)

# Read original frequency (truncated to an integer by the int() cast):
fs = int(exp["imaging"]["microscope_config"]["lightsheet"]["scanning"]["z"]["frequency"])
samp_n = traces.shape[0]
t_orig = np.arange(traces.shape[0]) / fs  # time of every sample, in units of 1/fs


stim_df = stimulus_df_from_exp(exp)

# Crop around stimuli. The "t" column is multiplied by fs to turn it into
# sample units; the window spans PRE_INT_S before to POST_INT_S after.
# NOTE(review): `cropped` is indexed below as (time, stimulus, cell) --
# confirm against the documentation of bouter.utilities.crop.
n_cells = traces.shape[1]
PRE_INT_S = 2
POST_INT_S = 5
cropped = crop(traces, stim_df["t"]*fs, 
                     pre_int=int(PRE_INT_S*fs), post_int=int(POST_INT_S*fs))
# Baseline subtraction: remove the mean of the pre-stimulus samples.
cropped = cropped - cropped[:int(PRE_INT_S*fs), :, :].mean(0)

# Find unique positions in the stimulus:
pos = sorted(stim_df.loc[:, "pos_start"].unique())

# Loop over positions and compute reliability scores:
rel_scores = np.zeros((len(pos), n_cells))

# NOTE: the loop variable `i` stays defined after this cell (its last value is
# len(pos) - 1) and is visible to every later cell.
for i, p in tqdm(list(enumerate(pos))):
    # Index labels of the rows shown at position p. They are used as positions
    # along axis 1 of `cropped`, which assumes stim_df has a default 0..N-1
    # index -- TODO confirm.
    resps_idxs = stim_df[stim_df["pos_start"] == p].index

    rel_scores[i] = reliability(cropped[:, resps_idxs, :])

In [None]:
# Example-cell figures. For each selected ROI one figure is created with:
#   - a ring of small axes, one per stimulus position, showing the single
#     repetitions (grey) and their mean (coloured);
#   - a central polar bar plot of the reliability score per position;
#   - an inset with the ROI highlighted in its imaging plane.

# Quantities that do not depend on the cell:
x_time = np.arange(0, cropped.shape[0])/fs - PRE_INT_S  # time axis of the cropped window
offset = np.median(np.array(pos)[1:] - np.array(pos)[:-1])  # median spacing between positions

# Layout, in figure-fraction units:
ax_w = 0.12          # side of each small axes on the ring
ax_c = (0.5, 0.5)    # centre of the ring
ax_r = 0.25
r = ax_r + 0.12      # radius at which the small axes are placed
p_w = 0.22           # half side of the central polar axes

for i_cell in [184, 10818]:
    cell_plane = int(coords[i_cell, 0])

    # Percentiles of THIS cell's responses, used for the y-limits below.
    # (Previously indexed with a stale loop variable `i` left over from an
    # earlier cell, so the scale was taken from an unrelated cell.)
    y_low, y_high = [np.nanpercentile(cropped[:, :, i_cell], p) for p in [1, 99.89]]

    # RGB image of the plane: white background, light grey where the ROI stack
    # is >= 0, palette colour 3 on the pixels of the current ROI.
    plane_rois = rois[cell_plane, :, :]
    rois_image = np.ones(rois.shape[1:] + (3,))
    rois_image[plane_rois >= 0] = 0.8
    rois_image[plane_rois == i_cell] = cols[3]
    rois_image = rois_image.swapaxes(0, 1)

    f = plt.figure(figsize=(6, 6))

    all_axs = []
    for i, th in enumerate(pos):
        resps_idxs = stim_df[stim_df["pos_start"] == th].index

        # Angle of this small axes on the ring: 36 equally spaced steps.
        plot_th = (np.pi*2 / 36) * i - np.pi + (10*np.pi/2)/360
        x = np.cos(plot_th)*r
        y = -np.sin(plot_th)*r

        ax = f.add_axes((ax_c[0] - ax_w/2 + x, ax_c[1] - ax_w/2 + y, ax_w, ax_w))

        ax.plot(x_time, cropped[:, resps_idxs, i_cell], lw=0.3, c=(0.4,)*3)
        ax.plot(x_time, cropped[:, resps_idxs, i_cell].mean(1), lw=1, c=cols[3])
        ax.plot([0, 4], [-.8, -.8], c=cols[0], alpha=0.6, lw=3)  # horizontal bar from 0 to 4 s
        ax.set_xlim(-2, 7)

        # Symmetric limits based on the upper percentile only.
        ax.set_ylim(-y_high, y_high)
        ax.axis("off")
        all_axs.append(ax)

    # Only the first small axes keeps visible axes, serving as a scale reference.
    sel_ax = all_axs[0]
    sel_ax.axis("on")
    sel_ax.set_ylabel("dF (Z sc.)", fontsize=10)
    sel_ax.set_xlabel("time (s)", fontsize=10)
    sel_ax.set(xticks=[-2, 0, 2, 4], yticks=[-1, 2], yticklabels=[])
    sel_ax.set_xticklabels([-2, 0, 2, 4], fontsize=9)
    for spine in sel_ax.spines.values():
        spine.set_linewidth(0.5)
    sel_ax.tick_params(length=2, width=0.5)

    # `trim` is a boolean flag; the string "True" only worked because it is truthy.
    sns.despine(trim=True)

    # Central polar bar plot of the reliability score at every position.
    ax = f.add_axes((ax_c[0] - p_w, ax_c[1] - p_w, p_w*2, p_w*2), polar=True)
    ax.bar(-np.array(pos), rel_scores[:, i_cell], lw=0, width=np.pi/18)
    ax.set_thetagrids([])
    ax.spines['polar'].set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_rgrids([0.3, 0.6, 0.9], angle=0, fontsize=9)
    ax.text(-0.1, 0.68,
            'Reliability', rotation=0, ha='center', va='center', fontsize=10)

    # Inset: ROI image with the contour of the mask at level 1.
    ax = f.add_axes((0.01, 0.78, 0.3, 0.3))
    ax.imshow(rois_image, origin="lower")
    ax.contour(ot_mask[cell_plane, :, :], origin="lower", levels=[1], linewidths=0.5, colors=[(0.5,)*3])
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set(xticks=[], yticks=[])


In [None]:
from svgpath2mpl import parse_path    
from matplotlib import collections

# SVG path string (relative move/curve commands). It is parsed with
# `parse_path` and drawn as a grey filled shape on the stimulus schematics;
# going by the name, presumably the outline of a fish.
path_fish = 'm0 0c-13.119 71.131-12.078 130.72-12.078 138.78-5.372 8.506-3.932 18.626-3.264 23.963-6.671 1.112-2.891 4.002-2.891 5.114s-2.224 8.005.445 9.116c-.223 3.113.222 0 0 1.557-.223 1.556-3.558 3.558-2.891 8.227.667 4.67 3.558 10.228 6.226 9.784 2.224 4.892 5.559 4.669 7.56 4.447 2.001-.223 8.672-.445 10.228-6.004 5.115-1.556 5.562-4.002 5.559-6.67-.003-3.341.223-8.45-3.113-12.008 3.336-4.224.667-13.786-3.335-13.786 1.59-8.161-2.446-13.786-3.558-20.679-2.223-34.909-.298-102.74 1.112-141.84'


In [None]:
# Stimulus schematic: five square panels, each with a red background, an
# optional semi-transparent dark clip mask and the grey SVG shape.
plt.close("all")
# Every second entry of the stimulus log, starting from the second one.
# NOTE(review): presumably these are the entries that carry a clip mask -- confirm.
clip_masks = [s["clip_mask"] for s in exp["stimulus"]["log"][1::2]]
titles = ["4 s", "2 s", "4 s", "2 s", "4 s"]
stimuli = [0, None, 20, None, 10]  # index into clip_masks; None = panel without mask

f, axs = plt.subplots(1, 5, figsize=(8, 1.5), sharey=True)

for ax, title, stim_n in zip(axs, titles, stimuli):
    ax.fill([0, 1, 1, 0], [0, 0, 1, 1], fc="r")

    if stim_n is not None:
        mask = clip_masks[stim_n]
        ax.fill([m[0] for m in mask], [m[1] for m in mask], fc="k", lw=0, alpha=0.6)
    ax.set(ylim=(0.1, 0.9), xlim=(0.1, 0.9), xticks=[], yticks=[], title=title)
    ax.set_aspect('equal', adjustable='box')

    # Parse a fresh copy for every panel because the vertices are edited in place.
    # Dedicated names are used on purpose: the previous code stored the parsed
    # shape in `path` and the scaling factor in `f`, silently overwriting the
    # experiment folder and the figure handle of this notebook.
    fish_path = parse_path(path_fish)
    fish_path.vertices -= np.min(fish_path.vertices, 0)  # smallest x and y become 0
    fish_norm = np.abs(fish_path.vertices[:, 1]).max()*(30/4)
    fish_path.vertices[:, 0] = fish_path.vertices[:, 0] / fish_norm
    fish_path.vertices[:, 1] = fish_path.vertices[:, 1] / fish_norm

    fish_path.vertices += np.array([0.49, 0.38])

    collection = collections.PathCollection([fish_path],
                                            linewidths=0,
                                            facecolors=["#909090"])
    ax.add_artist(collection)

    despine(ax, sides="all")
    

# Big figure

In [None]:
# Table read from "new_pooled.h5" in the experiment root. The following cells
# index it with .loc and read the columns 0..34, "z", "x", "y", "in_tectum"
# and "gen" from it.
pooled_data = fl.load(master_path / "new_pooled.h5")

In [None]:
# Response matrix built from the table columns labelled 0..34 and transposed,
# so that axis 0 runs over those columns and axis 1 over the table rows.
# NOTE(review): 35 columns are selected here while `pos` is used as the list of
# stimulus positions for `stim_thetas` -- confirm both have the same length.
response_columns = list(range(35))
responses = pooled_data.loc[:, response_columns].values.T

# NOTE(review): this rebinds `coords`, a name the loading cell also uses for
# the coordinates of the single example fish -- confirm that later cells
# expect the pooled table here.
coords = pooled_data.loc[:, ["z", "x", "y"]].values
in_tectum = pooled_data["in_tectum"].values

# Index of the strongest response of every row, and whether that maximum
# exceeds the 0.5 threshold.
all_peaks = responses.argmax(axis=0)
responsive = responses.max(axis=0) > 0.5
stim_thetas = np.array(pos)

In [None]:
# Global matplotlib defaults for the figures below: thin axes lines, small fonts.
# plt.rcParams['figure.constrained_layout.use'] = True
plt.rcParams.update({
    'axes.linewidth': 0.5,
    'axes.labelsize': 10,
    "legend.fontsize": 8,
    "axes.titlesize": 8,
})
# Identical tick geometry on the x and the y axis.
for t in ["x", "y"]:
    plt.rcParams.update({
        f"{t}tick.major.size": 3,
        f"{t}tick.labelsize": 8,
        f"{t}tick.major.width": 0.5,
    })

In [None]:
def add_fish(ax, offset=(0, 0), scale=1):
    """Draw a grey shape parsed from a hard-coded SVG path onto `ax`.

    The parsed vertices are shifted so that their smallest x and y become 0,
    divided by (largest y value * `scale`) and finally translated by `offset`.
    A larger `scale` therefore produces a smaller shape.

    Parameters
    ----------
    ax : matplotlib axes
        Axes that receives the shape through ``ax.add_artist``.
    offset : sequence of two numbers
        Added to the (x, y) vertices after the rescaling.
    scale : number
        Multiplies the normalisation divisor.

    Returns
    -------
    None
    """
    outline_svg = 'm0 0c-13.119 71.131-12.078 130.72-12.078 138.78-5.372 8.506-3.932 18.626-3.264 23.963-6.671 1.112-2.891 4.002-2.891 5.114s-2.224 8.005.445 9.116c-.223 3.113.222 0 0 1.557-.223 1.556-3.558 3.558-2.891 8.227.667 4.67 3.558 10.228 6.226 9.784 2.224 4.892 5.559 4.669 7.56 4.447 2.001-.223 8.672-.445 10.228-6.004 5.115-1.556 5.562-4.002 5.559-6.67-.003-3.341.223-8.45-3.113-12.008 3.336-4.224.667-13.786-3.335-13.786 1.59-8.161-2.446-13.786-3.558-20.679-2.223-34.909-.298-102.74 1.112-141.84'
    fish = parse_path(outline_svg)

    # Shift the outline so that its smallest x and y coordinates are 0.
    fish.vertices -= fish.vertices.min(axis=0)

    # Divide both coordinates by the same number, which keeps the proportions.
    divisor = np.abs(fish.vertices[:, 1]).max() * scale
    for column in (0, 1):
        fish.vertices[:, column] = fish.vertices[:, column] / divisor

    fish.vertices += np.array(offset)

    ax.add_artist(
        collections.PathCollection([fish], linewidths=0, facecolors=["#909090"])
    )

In [None]:
plt.close("all")

# Stimulus description shared by the panel below. `[1::2]` keeps every second
# entry of the stimulus log, starting from the second one.
clip_masks = [entry["clip_mask"] for entry in exp["stimulus"]["log"][1::2]]
titles = ["4 s", "2 s", "4 s", "2 s", "4 s"]
stimuli = [0, None, 20, None, 10]  # index into clip_masks; None = no mask drawn
letter_size = 12

fig = plt.figure(figsize=(8, 12))


####################
## Panel A: stimulus

# Row of five small square axes, placed in figure-fraction units.
xpos, ypos, side = 0.08, 0.92, 0.04
axs = [fig.add_axes((xpos+side*1.1*k, ypos, side, side)) for k in range(5)]

# NOTE(review): `add_letter` is neither defined nor imported in the cells
# visible here -- confirm where it comes from before running on a fresh kernel.
add_letter(axs[0], letter="A", xoff=0.03, yoff=0.0, fontsize=letter_size)

for ax, title, stim_n in zip(axs, titles, stimuli):
    # Red background square.
    ax.fill([0, 1, 1, 0], [0, 0, 1, 1], fc="r")

    # Semi-transparent dark polygon where a clip mask is selected.
    if stim_n is not None:
        mask = clip_masks[stim_n]
        mask_x = [vertex[0] for vertex in mask]
        mask_y = [vertex[1] for vertex in mask]
        ax.fill(mask_x, mask_y, fc="k", lw=0, alpha=0.6)

    ax.set(ylim=(0.1, 0.9), xlim=(0.1, 0.9), xticks=[], yticks=[], title=title)
    ax.set_aspect('equal', adjustable='box')

    add_fish(ax, offset=[0.45, 0.08], scale=(30/15))

    despine(ax, sides="all")
    
########################
## Panel B: example ROIs
# For each of two example cells: a ring of small axes with the responses to
# every stimulus position, a central polar plot of the reliability scores and
# an inset showing the ROI in its plane. Everything is laid out inside a
# bounding box (xpos, ypos, xside, yside) given in figure-fraction units.
m_xpos, m_ypos, xside, yside = 0.4, 0.75, 0.28, 0.2
bounds_lims = [(m_xpos+xside*1.2*i, m_ypos, xside, yside) for i in range(2)]

x_time = np.arange(0, cropped.shape[0])/fs - PRE_INT_S  # time array
offset = np.median(np.array(pos)[1:] - np.array(pos)[:-1])  # offset on angle
# parameters for circle of subplots (fractions of the bounding box):
ax_w = 0.12         # side of each small axes
ax_c = (0.5, 0.5)   # centre of the ring
ax_r = 0.37         # radius of the ring
n_stims = 36        # number of angular steps on the ring
scaling_percentiles = [1, 99.89]

p_w = 0.22  # side of the central histogram

for c_n, (i_cell, (xpos, ypos, xside, yside)) in enumerate(zip([184, 10818], bounds_lims)):

    # Plane in which the cell is found.
    # NOTE(review): when the notebook runs top to bottom, `coords` has been
    # rebound to the pooled-table coordinates by the time this cell runs, while
    # `cropped`, `rois` and `rel_scores` come from the single example fish.
    # Confirm that row `i_cell` of `coords` describes the same ROI.
    cell_plane = int(coords[i_cell, 0])

    # Get percentiles of responses scaling (only the upper one is used below):
    y_low, y_high = [np.nanpercentile(cropped[:, :, i_cell], p) for p in scaling_percentiles]

    # RGB image of the plane: white background, light grey where the ROI stack
    # is >= 0, palette colour 3 on the pixels of the current ROI.
    plane_rois = rois[cell_plane, :, :]
    rois_image = np.ones(rois.shape[1:] + (3,))
    rois_image[plane_rois >= 0] = 0.8
    rois_image[plane_rois == i_cell] = cols[3]
    rois_image = rois_image.swapaxes(0, 1)

    # Loop over stimulus positions and plot reps and mean:
    for i, th in enumerate(pos):
        resps_idxs = stim_df[stim_df["pos_start"] == th].index

        # Position on the ring. The radius is taken from `ax_r` defined above;
        # the previous code read an `r` that this cell never assigns, so it
        # only ran when another cell had left that variable behind.
        plot_th = (np.pi*2 / n_stims) * i - np.pi + (10*np.pi/2) / 360
        x = np.cos(plot_th)*ax_r
        y = -np.sin(plot_th)*ax_r

        ax = fig.add_axes((xpos + (ax_c[0] - ax_w/2 + x)*xside,
                           ypos + (ax_c[1] - ax_w/2 + y)*yside,
                           ax_w*xside, ax_w*yside))

        ax.plot(x_time, cropped[:, resps_idxs, i_cell], lw=0.3, c=(0.4,)*3)
        ax.plot(x_time, cropped[:, resps_idxs, i_cell].mean(1), lw=1, c=cols[3])
        ax.plot([0, 4], [-.8, -.8], c=cols[0], alpha=0.6, lw=3)  # horizontal bar from 0 to 4 s
        ax.set(xlim=(-2, 7), ylim=(-y_high, y_high))
        ax.axis("off")

    # Central polar bar plot of the reliability score at every position.
    ax = fig.add_axes((xpos + (ax_c[0] - p_w)*xside,
                       ypos + (ax_c[1] - p_w)*yside,
                       p_w*2*xside, p_w*2*yside), polar=True)

    ax.bar(-np.array(pos), rel_scores[:, i_cell], lw=0, width=np.pi/18)

    ax.set_thetagrids([])

    ax.spines['polar'].set_visible(False)
    ax.set_ylim(0, 1)
    ax.set_rgrids([0.5, 1], angle=0, fontsize=7)
    ax.text(-0.15, 0.85,
            'Reliability', rotation=0, ha='center', va='center', fontsize=8)

    # Inset: ROI image with the contour of the mask at level 1.
    anatomy_ax = fig.add_axes((xpos - 0.05*xside, ypos + 0.78*yside, 0.3*xside, 0.3*yside))
    anatomy_ax.imshow(rois_image, origin="lower")
    anatomy_ax.contour(ot_mask[cell_plane, :, :], origin="lower", levels=[1],
                       linewidths=0.5, colors=[(0.5,)*3])
    anatomy_ax.axis("off")
    # The panel letter is attached to the inset of the first example only.
    # NOTE(review): `add_letter` is neither defined nor imported in the cells
    # visible here -- confirm where it comes from.
    if c_n == 0:
        add_letter(anatomy_ax, letter="B", xoff=0.03, yoff=-0.0, fontsize=letter_size)



###################
# Panel C: topology
# Two columns (one per genotype) of two views each: x/y scatter on top and
# x/z scatter below, every responsive in-tectum cell coloured by the stimulus
# angle of its strongest response.

m_xpos, m_ypos, xside, yside = 0.05, 0.5, 0.4, 0.4
anat_scatt_size=5

# all_axs[column][row]; row 0 is the upper axes (j runs 1 -> 0).
all_axs = [[fig.add_axes((m_xpos + (0.1+0.4*i)*xside, m_ypos + (0.1+0.2*j)*yside,
                          0.4*xside, 0.4*yside))
            for j in range(1, -1, -1)] for i in range(2)]

for g_i, (g, lab) in enumerate(zip(["Huc:H2B-GCaMP6s", "Huc:H2B-GCaMP6s;olig1:Ntr"],
                                   ["Control", "OPC ablated"])):
    axs = all_axs[g_i]

    # Selection shared by both views: responsive, inside the tectum, this genotype.
    group_sel = responsive & in_tectum & (pooled_data["gen"] == g).values

    # Top view, drawn one z value at a time (z values 0..8).
    for i in range(9):
        filt = group_sel & (coords[:, 0] == i)
        axs[0].scatter(coords[filt, 1], coords[filt, 2],
                       c=stim_thetas[all_peaks[filt]], cmap="twilight_shifted", s=anat_scatt_size)
    axs[0].set_aspect('equal', adjustable='box')
    axs[0].set(xlim=(-20, 760), ylim=(0, 460))
    axs[0].axis("off")
    axs[0].text(351, 450, lab, fontsize=8, va="bottom", ha="center")

    # Side view, drawn in 50-unit slabs of the third coordinate, from high to low.
    # NOTE(review): the strict inequalities and the range stopping at 1 leave out
    # cells with a value <= 50 or exactly on a multiple of 50 -- confirm intended.
    for i in range(12, 0, -1):
        filt = group_sel & (coords[:, 2] > i*50) & (coords[:, 2] < (i+1)*50)
        axs[1].scatter(coords[filt, 1], coords[filt, 0]*12,
                       c=stim_thetas[all_peaks[filt]], cmap="twilight_shifted", s=anat_scatt_size)
    axs[1].set_aspect('equal', adjustable='box')
    # Same x-range as the top view so that both views line up. The previous code
    # applied this to axs[0] a second time, leaving the side view autoscaled.
    axs[1].set(xlim=(-20, 760))
    axs[1].axis("off")

# Colour legend: image of the angle of (x + i*y) over a square grid, drawn with
# the same colormap as the scatter plots, plus the grey shape from add_fish.
val_range = np.linspace(-np.pi, np.pi, 100)
ax = fig.add_axes((m_xpos - 0.09*xside, m_ypos + 0.5*yside, 0.15*xside, 0.15*yside))
ax.imshow(np.angle(val_range[:,None] + 1j*val_range[None,:]).T,
          cmap="twilight_shifted", extent=[0, 1, 0, 1])
ax.text(0.5,1.1, "stim. θ", fontsize=7, va="bottom", ha="center")
ax.axis("equal")
ax.axis("off")
add_fish(ax, offset=[0.45, 0.0], scale=1.7)

In [None]:
# Stand-alone version of the angle colour legend in its own figure.
# Relies on `val_range` from the big-figure cell; note that `fig` and `ax`
# are rebound here.
fig, ax = plt.subplots()
# ax = fig.add_axes((m_xpos, m_ypos + 0.5*yside, 0.1*xside, 0.1*yside))

# Angle of (x + i*y) over the square grid, transposed for display.
angle_wheel = np.angle(val_range[:, None] + 1j*val_range[None, :]).T
ax.imshow(angle_wheel, cmap="twilight_shifted", extent=[0, 1, 0, 1])
ax.text(0, 1.1, "stim. θ", fontsize=7)

# Equal data scaling first, then hide the axes decorations.
for axis_mode in ("equal", "off"):
    ax.axis(axis_mode)

add_fish(ax, offset=[0.45, 0.08], scale=1.7)

In [None]:
import matplotlib.pyplot as plt
def move_axes(ax, fig, subplot_spec=111):
    """Move an Axes object from a figure to a new pyplot managed Figure in
    the specified subplot.
    From https://gist.github.com/salotz/8b4542d7fe9ea3e2eacc1a2eef2532c5

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to move; it is removed from the figure it currently belongs to.
    fig : matplotlib Figure
        Figure that receives the Axes.
    subplot_spec : int
        Passed unchanged to ``fig.add_subplot`` to create the temporary Axes
        whose position is copied onto `ax`.

    Returns
    -------
    None

    Notes
    -----
    The figure `ax` originally belonged to is closed with ``plt.close`` at the
    end, so it cannot be shown again afterwards. The order of the statements
    below matters; do not rearrange them.
    """

    # get a reference to the old figure context so we can release it
    old_fig = ax.figure

    # remove the Axes from its original Figure context
    ax.remove()

    # set the pointer from the Axes to the new figure
    ax.figure = fig

    # add the Axes to the registry of axes for the figure
    fig.add_axes(ax) # alternative kept from the original gist: fig.axes.append(ax)
    # (a leftover remark from the gist says this has to be done "twice"; only
    # one call is made here)

    # then to actually show the Axes in the new figure we have to make
    # a subplot with the positions etc for the Axes to go, so make a
    # subplot which will have a dummy Axes
    dummy_ax = fig.add_subplot(subplot_spec)

    # then copy the relevant data from the dummy to the ax
    ax.set_position(dummy_ax.get_position())

    # then remove the dummy
    dummy_ax.remove()

    # close the figure the original axis was bound to
    plt.close(old_fig)

In [None]:
# Move `ax` (whichever axes that name was last bound to; the plt.subplots()
# axes of the colour-wheel cell when run top to bottom) into a fresh figure.
# move_axes also closes the figure `ax` came from.
new_fig = plt.figure()
move_axes(ax, new_fig, subplot_spec=111)