In [None]:
%load_ext autoreload
%matplotlib widget

In [None]:
%autoreload
from pathlib import Path

import numpy as np
from lotr import DATASET_LOCATION
from lotr import plotting as pltltr
from lotr.em.loading import load_skeletons_from_zip
from matplotlib import cm
from matplotlib import pyplot as plt
from scipy.ndimage.morphology import binary_dilation
from tqdm import tqdm

In [None]:
# Skeleton archive stored in the "anatomy" folder next to the dataset directory.
anatomy_dir = DATASET_LOCATION.parent / "anatomy"
em_path = anatomy_dir / "all_ahb.k.zip"
neurons_list = load_skeletons_from_zip(em_path)

# Switch on the `mirror_right` flag of every loaded neuron.
# NOTE(review): presumably this folds right-side cells onto one hemisphere in
# the coordinates used below - confirm against the lotr.em neuron class.
for n in neurons_list:
    n.mirror_right = True

In [None]:
# Overview of all loaded skeletons in "ipn" space, one panel per selection
# (whole skeleton, dendrites only, axon only). Column 2 of the coordinates is
# drawn on x and the negated column 1 on y.
f, axs = plt.subplots(1, 3, figsize=(9, 3))
for ax, sel in zip(axs, ["all", "dendrites", "axon"]):
    n_skipped = 0
    for neuron in neurons_list:
        try:
            lines = neuron.generate_plotlines_from_skeleton(space="ipn", select=sel)
            soma = neuron.coords_ipn[neuron.soma_idx, :]
            ax.plot(lines[:, 2], -lines[:, 1], lw=0.5)
            ax.scatter(soma[2], -soma[1])
        except ValueError:
            # Best effort: a neuron that cannot be drawn for this selection is
            # left out of the panel, but counted so the omission is visible.
            n_skipped += 1
    if n_skipped:
        print(f"{sel}: skipped {n_skipped}/{len(neurons_list)} neurons (ValueError)")
    ax.set(title=sel, xlim=(-50, 150))

In [None]:
# Keep the neurons whose underscore-separated comment string contains both tags.
required_tags = {"bilat", "dipn"}
sel_neurons = [
    neuron
    for neuron in neurons_list
    if required_tags.issubset(neuron.comments.split("_"))
]

In [None]:
# Bifurcation centroid of each selected neuron, computed separately for the
# two compartments; along axis 1 the order is (dendrites, axon).
compartments = ("dendrites", "axon")
centroids = np.array(
    [
        [neuron.find_centroid_bifurcation(select=part) for part in compartments]
        for neuron in sel_neurons
    ]
)

# Scatter of the third centroid coordinate (labelled L-R below):
# dendrites on x against axon on y.
f, ax = plt.subplots(figsize=(2.0, 2.0), gridspec_kw={"bottom": 0.2, "left": 0.2})
ax.scatter(centroids[:, 0, 2], centroids[:, 1, 2], s=5)
pltltr.despine(ax)
ax.set(xlabel="L-R position dendrite", ylabel="L-R position axon")

In [None]:
# Colormap value per neuron, derived from coordinate 2 of the centroid at
# index 1 on axis 1 (the "axon" entry when centroids are built as
# (dendrites, axon)).
# NOTE(review): the colorbar drawn with these values is labelled
# "Dendrite pos" - confirm whether index 0 or the label is the intended one.
lr_from_min = centroids[:, 1, 2] - centroids[:, 1, 2].min()
# The +10 offset keeps the smallest value above 0; the extra 0.01 in the
# denominator keeps the largest value just below 1.
lr_padded = lr_from_min + 10
norm_cm = lr_padded / (lr_padded.max() + 0.01)

In [None]:
# Selected neurons drawn in "ipn" space, each coloured on the "Reds" colormap
# by its value in `norm_cm`.
f, ax = plt.subplots(figsize=(3, 3))

# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which is no longer
# available in recent Matplotlib releases.
reds = plt.get_cmap("Reds")

# Empty mappable that only serves as the handle for the colorbar; created once,
# before the loop, so it exists even when no neuron is selected.
fake_im = ax.scatter([], [], c=[], cmap="Reds")

for neuron, c in zip(sel_neurons, norm_cm):
    lines = neuron.generate_plotlines_from_skeleton(space="ipn")
    soma = neuron.coords_ipn[neuron.soma_idx, :]
    col = reds(c)
    # Column 2 on x, negated column 1 on y (see the axis labels below).
    ax.plot(lines[:, 2], -lines[:, 1], lw=0.5, c=col)
    ax.scatter(soma[2], -soma[1], lw=0, fc=col)

ax.axis("equal")
ax.set(xlabel="left-right", ylabel="ventral - dorsal")
pltltr.despine(ax, "all")
# NOTE(review): `norm_cm` is computed from centroids[:, 1, 2]; if axis 1 is
# ordered (dendrites, axon) that is the axon position, not the dendrite one
# named in this label - confirm which is intended.
pltltr.add_cbar(
    fake_im, ax, inset_loc=(0.9, 0.6, 0.035, 0.2), ticks=[], label="Dendrite pos"
)

In [None]:
from bg_atlasapi.core import Atlas

In [None]:
# The atlas folder sits in "anatomy", next to the dataset directory.
atlas_dir = DATASET_LOCATION.parent / "anatomy" / "ipn_zfish_0.5_um_v1.6"
atlas = Atlas(atlas_dir)

In [None]:
# Extra reference stack shipped with the atlas under the "gad1b_gal4" key.
gad1b_stack = atlas.additional_references["gad1b_gal4"]

# 3x3x3 structuring element that is True only on its central slab along
# axis 1, so the dilation grows the masks along axes 0 and 2 but not axis 1.
mask = np.zeros((3, 3, 3), dtype=bool)
mask[:, 1, :] = True

# Binary masks of the two structures, each dilated 7 times with that element.
annotation, annotation_dipn = (
    binary_dilation(
        atlas.get_structure_mask(structure_name) > 0, iterations=7, structure=mask
    )
    for structure_name in ("ipn", "dors_ipn")
)

In [None]:
# Qualitative palette used to colour individual neurons: seven 8-bit RGB
# triplets rescaled to the 0-1 range and rounded to two decimals.
# The result is bound to `cols` so that cells indexing `cols[...]` work after
# a kernel restart instead of relying on leftover kernel state.
# NOTE(review): the triplets look like the ColorBrewer "Dark2" set - confirm.
palette_rgb_255 = np.array(
    [
        [27, 158, 119],
        [217, 95, 2],
        [117, 112, 179],
        [231, 41, 138],
        [102, 166, 30],
        [230, 171, 2],
        [166, 118, 29],
    ]
)
cols = [list(rgb) for rgb in np.round(palette_rgb_255 / 255, 2)]
cols

In [None]:
# Overlay of the selected neurons on a projection of the gad1b reference
# stack, with the outlines of the two dilated structure masks.
# `cols` is expected to hold a list of RGB colours defined before this cell runs.
off_x = 0  # not used in this cell
off_y = 0  # not used in this cell
s = 1.0  # scale factor applied to the image extent
f, ax = plt.subplots(figsize=(4, 4))

# Maximum projection along axis 0, leaving out the first 20 planes. The extent
# is half the pixel count on each side.
# NOTE(review): presumably converts 0.5 um voxels to microns - confirm.
ax.imshow(
    gad1b_stack[20:, :, :].max(0),
    origin="lower",
    cmap="gray_r",
    vmin=450,
    vmax=15000,
    extent=(0, (gad1b_stack.shape[2] / 2) * s, 0, (gad1b_stack.shape[1] / 2) * s),
)

# Outline of each dilated mask, projected along axis 0 after cropping the stack.
for region_mask in (annotation[:120, :, :], annotation_dipn[:130, :, :]):
    pltltr.plot_projection(
        region_mask,
        0,
        ax=ax,
        smooth_wnd=15,
        linewidth=0.5,
        fill=False,
        edgecolor=".3",
    )

# `sel_neurons` already contains only the "bilat" + "dipn" neurons, so the
# comment tags are not checked again here.
for i, neuron in enumerate(sel_neurons):
    main_c = cols[i % len(cols)]  # cycle through the palette

    # Dendrites are drawn with luminance shifted by -0.15 and the axon by
    # +0.15 around the neuron's base colour.
    for compartment, lum_shift in (("dendrites", -0.15), ("axon", 0.15)):
        lines = neuron.generate_plotlines_from_skeleton(
            space="ipn", select=compartment
        )
        ax.plot(
            lines[:, 2], lines[:, 1], lw=0.5, c=pltltr.shift_lum(main_c, lum_shift)
        )

    soma = neuron.coords_ipn[neuron.soma_idx, :]
    ax.scatter(soma[2], soma[1], s=30, color=pltltr.shift_lum(main_c, -0.15))

# Axis setup does not depend on the neuron, so it is applied once after the
# loop; this also frames the figure when no neuron is selected.
# The y limits are given high-to-low, which flips the vertical axis.
ax.axis("equal")
ax.set(ylim=(170, -4), xlim=(50, 160))

pltltr.despine(ax, "all")
pltltr.add_scalebar(
    ax, ylen=-20, xlen=20, ypos=140, xpos=40, xlabel="R. - L.", ylabel="inf. sup."
)