In [None]:
%matplotlib widget
from pathlib import Path

import flammkuchen as fl
import matplotlib as mpl
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch
from bg_atlasapi import BrainGlobeAtlas
from lotr import DATASET_LOCATION
from lotr import plotting as pltltr
from matplotlib import pyplot as plt
from sklearn.metrics.pairwise import euclidean_distances
from tqdm import tqdm

# Repeat the qualitative palette 100 times so that the per-neuron `cols[j]`
# lookups further down have plenty of entries to draw from.
# NOTE(review): presumably COLS["qualitative"] is a list, so `* 100` tiles it.
base_palette = pltltr.COLS["qualitative"]
cols = base_palette * 100

# Install the same colours as matplotlib's default property cycle.
color_cycle = mpl.cycler(color=cols)
mpl.rcParams["axes.prop_cycle"] = color_cycle

# Load the "ipn_zfish_0.5um" BrainGlobe atlas (not referenced again in the
# cells visible below).
atlas = BrainGlobeAtlas("ipn_zfish_0.5um")

In [None]:
# Skeletons of the traced neurons are stored in a single file under
# <dataset>/anatomy/annotated_traced_neurons.
anatomy_location = DATASET_LOCATION / "anatomy"
skeletons_file = anatomy_location / "annotated_traced_neurons" / "all_skeletons.h5"
neurons = fl.load(skeletons_file)

# Flat list with every loaded neuron, in the mapping's iteration order.
all_neurons_list = list(neurons.values())

In [None]:
def _has_valid_ipn_annotation(neuron):
    """Tell whether a neuron's ``comments`` string passes the IPN selection.

    A neuron is kept when its comment either starts with "p" and the second
    " - "-separated field starts with "i", or starts with "n"; in both cases
    the comment must not contain "??".
    NOTE(review): the meaning of the "p" / "n" / "i" codes is not visible
    here - confirm against the annotation convention.

    Raises IndexError for an empty comment, or for a "p" comment without a
    non-empty second " - " field.
    """
    comments = neuron.comments
    if comments[0] == "p":
        prefix_ok = comments.split(" - ")[1][0] == "i"
    else:
        prefix_ok = comments[0] == "n"
    return prefix_ok and "??" not in comments


valid_ipn_neurons = [n for n in neurons.values() if _has_valid_ipn_annotation(n)]

In [None]:
# Axis limits along each anatomical dimension; the projections below pick
# their pair of limits from this dictionary.
bs = dict(frontal=(30, 180), vertical=(20, 180), sagittal=(-20, 120))

# One pair of limits per projection panel.
projection_bounds = {
    "frontal": [bs["vertical"], bs["frontal"]],
    "horizontal": [bs["sagittal"], bs["frontal"]],
    "sagittal": [bs["vertical"], bs["sagittal"]],
}

plotter = pltltr.AtlasPlotter(
    structures=["ipn", "dipn"],
    mask_slices={"frontal": slice(0, 120)},
    bounds_dict=projection_bounds,
)

In [None]:
# Overview figure for the IPN neurons: a 3x3 grid of axes with one row per
# content type (everything / axons / dendrites, see the row labels added at
# the end of this cell) and one column per projection panel drawn by `plotter`.
# NOTE(review): the same drawing passes are repeated in later cells; moving
# them into a shared helper in a module would remove the duplication.
f, all_axs = plt.subplots(3, 3, figsize=(6.2, 7.2))
for row_axs in all_axs:
    # The return value is deliberately ignored so that `f` stays bound to the
    # figure created above, which is the one saved at the end of the cell.
    # NOTE(review): assumes the panels are drawn in place on the axes passed in.
    plotter.generate_projection_plots(row_axs, labels=True)

a = 0.7  # opacity of the neurite traces
raster = True  # forwarded as `rasterized` to every plot_neurons call
shared_kw = dict(rasterized=raster, label="__nolegend__")

# Row indices of the grid.
ROW_ALL, ROW_AXON, ROW_DENDR = 0, 1, 2
# Column whose axes receive the row labels. This used to be indexed with the
# leftover loop variable `i`, which ended up as 1 or 2 depending on whether
# the last plotted neuron had an axon; it is now fixed explicitly.
# NOTE(review): column 1 chosen because the label position (106, 0) falls
# inside that panel's limits - confirm this is the intended panel.
LABEL_COL = 1

# Every third neuron among the first 80 valid ones.
for j, n in enumerate(valid_ipn_neurons[:80:3]):
    dendr_color = pltltr.shift_lum(cols[j], -0.15)

    # Pass 1 - all rows: zero-width lines with soma_s=30 (presumably only the
    # soma marker ends up visible).
    for row in (ROW_ALL, ROW_AXON, ROW_DENDR):
        plotter.plot_neurons(
            all_axs[row, :],
            n,
            select="dendrites",
            color=dendr_color,
            linewidth=0.0,
            alpha=1,
            soma_s=30,
            **shared_kw,
        )

    # Pass 2 - dendrite traces on the "everything" and "dendrites" rows.
    for row in (ROW_ALL, ROW_DENDR):
        plotter.plot_neurons(
            all_axs[row, :],
            n,
            select="dendrites",
            color=dendr_color,
            linewidth=0.5,
            alpha=a,
            soma_s=0,
            **shared_kw,
        )

    if n.has_axon:
        axon_color = pltltr.shift_lum(cols[j], 0.15)

        # Pass 3 - axon traces on the "everything" and "axons" rows.
        for row in (ROW_ALL, ROW_AXON):
            plotter.plot_neurons(
                all_axs[row, :],
                n,
                select="axon",
                color=axon_color,
                linewidth=0.5,
                alpha=a,
                soma_s=0,
                **shared_kw,
            )

        # Pass 4 - fully transparent axon on the "dendrites" row: a dummy
        # drawn only so that this row keeps the same plot ratio as the others.
        plotter.plot_neurons(
            all_axs[ROW_DENDR, :],
            n,
            select="axon",
            color=axon_color,
            linewidth=0.0,
            alpha=0,
            soma_s=0,
            **shared_kw,
        )


# One label per row, written on the axes of the label column.
for ax, t in zip(all_axs[:, LABEL_COL], ["Neurons", "Axons", "Dendrites"]):
    ax.text(106, 0, t, fontsize=8, c=(0.3,) * 3, ha="center", va="top")

# NOTE(review): a later cell also saves a figure under the name "all_ipn_full";
# presumably one output overwrites the other when both cells run - confirm
# which one is wanted.
pltltr.savefig("all_ipn_full", fig=f)

In [None]:
# Axis limits along each anatomical dimension for the per-figure IPN plots.
frontal_lim, vertical_lim, sagittal_lim = (30, 180), (40, 180), (-20, 120)
bs = dict(frontal=frontal_lim, vertical=vertical_lim, sagittal=sagittal_lim)

# Pair of limits handed to each projection panel.
projection_bounds = dict(
    frontal=[vertical_lim, frontal_lim],
    horizontal=[sagittal_lim, frontal_lim],
    sagittal=[vertical_lim, sagittal_lim],
)

plotter = pltltr.AtlasPlotter(
    structures=["ipn", "dipn"],
    mask_slices=dict(frontal=slice(0, 120)),
    bounds_dict=projection_bounds,
)

In [None]:
# Same content as the 3x3 IPN overview, but every row lives in its own
# single-row figure so that it can be saved separately.
gridspec_kw = dict(left=0.05, right=1, top=1, bottom=0.1, hspace=0.01)
figsize = (4.2, 1.3)
fig_kw = dict(figsize=figsize, gridspec_kw=gridspec_kw)

f_all, axs_all = plt.subplots(1, 3, **fig_kw)
f_axon, axs_axon = plt.subplots(1, 3, **fig_kw)
f_dendr, axs_dendr = plt.subplots(1, 3, **fig_kw)

# Stack the three rows of axes into one (3, 3) array: row 0 = everything,
# row 1 = axons, row 2 = dendrites.
all_axs = np.stack([axs_all, axs_axon, axs_dendr], axis=0)
for row_axs in all_axs:
    plotter.generate_projection_plots(row_axs, labels=True)

trace_alpha = 0.7
common_kw = dict(rasterized=True, label="__nolegend__")
ROW_ALL, ROW_AXON, ROW_DENDR = 0, 1, 2

# Every third neuron among the first 80 valid ones.
for idx, neuron in enumerate(valid_ipn_neurons[:80:3]):
    darker = pltltr.shift_lum(cols[idx], -0.15)

    # Zero-width dendrite pass with soma_s=30 on all three rows.
    for row in (ROW_ALL, ROW_AXON, ROW_DENDR):
        plotter.plot_neurons(
            all_axs[row, :],
            neuron,
            select="dendrites",
            color=darker,
            linewidth=0.0,
            alpha=1,
            soma_s=30,
            **common_kw,
        )

    # Dendrite traces on the "everything" and "dendrites" rows.
    for row in (ROW_ALL, ROW_DENDR):
        plotter.plot_neurons(
            all_axs[row, :],
            neuron,
            select="dendrites",
            color=darker,
            linewidth=0.5,
            alpha=trace_alpha,
            soma_s=0,
            **common_kw,
        )

    if not neuron.has_axon:
        continue

    lighter = pltltr.shift_lum(cols[idx], 0.15)

    # Axon traces on the "everything" and "axons" rows.
    for row in (ROW_ALL, ROW_AXON):
        plotter.plot_neurons(
            all_axs[row, :],
            neuron,
            select="axon",
            color=lighter,
            linewidth=0.5,
            alpha=trace_alpha,
            soma_s=0,
            **common_kw,
        )

    # Dummy, fully transparent axon on the "dendrites" row, just for plot ratio.
    plotter.plot_neurons(
        all_axs[ROW_DENDR, :],
        neuron,
        select="axon",
        color=lighter,
        linewidth=0.0,
        alpha=0,
        soma_s=0,
        **common_kw,
    )

# Row-label text is not drawn here; each row is exported as its own figure.
pltltr.savefig("all_ipn_full", fig=f_all)
pltltr.savefig("all_ipn_axon", fig=f_axon)
pltltr.savefig("all_ipn_dendr", fig=f_dendr)

# aHB neurons

In [None]:
# Select neurons whose comment mentions "ahb" together with "vipn" or "dipn",
# does not contain "unknown", and that have an axon.
ahb_neurons = []
for neuron in neurons.values():
    notes = neuron.comments
    if "ahb" not in notes:
        continue
    if "vipn" not in notes and "dipn" not in notes:
        continue
    if "unknown" in notes:
        continue
    if neuron.has_axon:
        ahb_neurons.append(neuron)

In [None]:
# Axis limits along each anatomical dimension for the aHB plots.
bs = {
    "frontal": (15, 195),
    "vertical": (-20, 180),
    "sagittal": (-20, 120),
}

# Which two dimensions of `bs` bound each projection panel.
dims_per_projection = {
    "frontal": ("vertical", "frontal"),
    "horizontal": ("sagittal", "frontal"),
    "sagittal": ("vertical", "sagittal"),
}
projection_bounds = {
    view: [bs[first], bs[second]]
    for view, (first, second) in dims_per_projection.items()
}

plotter = pltltr.AtlasPlotter(
    structures=["ipn", "dipn"],
    mask_slices={"frontal": slice(0, 120)},
    bounds_dict=projection_bounds,
)

In [None]:
# aHB neurons, drawn on three single-row figures (everything / axons /
# dendrites) that are saved separately.
gridspec_kw = dict(left=0.05, right=1, top=1, bottom=0.1, hspace=0.01)
figsize = (4.2, 1.3)

f_all, axs_all = plt.subplots(1, 3, figsize=figsize, gridspec_kw=gridspec_kw)
f_axon, axs_axon = plt.subplots(1, 3, figsize=figsize, gridspec_kw=gridspec_kw)
f_dendr, axs_dendr = plt.subplots(1, 3, figsize=figsize, gridspec_kw=gridspec_kw)

# (3, 3) array of axes: one row per figure, in the order created above.
all_axs = np.stack([axs_all, axs_axon, axs_dendr], axis=0)
for row_axs in all_axs:
    plotter.generate_projection_plots(row_axs, labels=True)

trace_alpha = 0.7

# Keyword sets for the individual drawing passes.
soma_pass = dict(select="dendrites", linewidth=0.0, alpha=1, soma_s=30)
dendrite_pass = dict(select="dendrites", linewidth=0.5, alpha=trace_alpha, soma_s=0)
axon_pass = dict(select="axon", linewidth=0.5, alpha=trace_alpha, soma_s=0)
# Fully transparent axon: dummy drawn just for plot ratio.
hidden_axon_pass = dict(select="axon", linewidth=0.0, alpha=0, soma_s=0)

for idx, neuron in enumerate(ahb_neurons):
    dark = pltltr.shift_lum(cols[idx], -0.15)

    # (rows to draw on, pass keywords, colour), executed in this order.
    steps = [
        ((0, 1, 2), soma_pass, dark),
        ((0, 2), dendrite_pass, dark),
    ]
    if neuron.has_axon:
        light = pltltr.shift_lum(cols[idx], 0.15)
        steps.append(((0, 1), axon_pass, light))
        steps.append(((2,), hidden_axon_pass, light))

    for rows, pass_kw, color in steps:
        for row in rows:
            plotter.plot_neurons(
                all_axs[row, :],
                neuron,
                color=color,
                rasterized=True,
                label="__nolegend__",
                **pass_kw,
            )

# Row-label text is not drawn here; each row is exported as its own figure.
pltltr.savefig("all_ahb_full", fig=f_all)
pltltr.savefig("all_ahb_axon", fig=f_axon)
pltltr.savefig("all_ahb_dendr", fig=f_dendr)

## All reconstructions plot

In [None]:
# Overview of every loaded reconstruction: a 3x3 grid of axes with one row per
# content type (everything / axons / dendrites, see the row labels added at
# the end of this cell) and one column per projection panel. The panels use
# whichever `plotter` was created most recently.
# NOTE(review): the same drawing passes are repeated in earlier cells; moving
# them into a shared helper in a module would remove the duplication.
f, all_axs = plt.subplots(3, 3, figsize=(6.2, 7.2))
for row_axs in all_axs:
    # The return value is ignored so that `f` keeps pointing at the figure
    # created above.
    # NOTE(review): assumes the panels are drawn in place on the axes passed in.
    plotter.generate_projection_plots(row_axs, labels=True)

a = 0.7  # opacity of the neurite traces
raster = True  # forwarded as `rasterized` to every plot_neurons call
shared_kw = dict(rasterized=raster, label="__nolegend__")

# Row indices of the grid.
ROW_ALL, ROW_AXON, ROW_DENDR = 0, 1, 2
# Column whose axes receive the row labels. This used to be indexed with the
# leftover loop variable `i`, which ended up as 1 or 2 depending on whether
# the last plotted neuron had an axon; it is now fixed explicitly.
# NOTE(review): column 1 chosen because the label position (106, 0) falls
# inside that panel's limits - confirm this is the intended panel.
LABEL_COL = 1

for j, n in enumerate(all_neurons_list):
    dendr_color = pltltr.shift_lum(cols[j], -0.15)

    # Pass 1 - all rows: zero-width lines with soma_s=30 (presumably only the
    # soma marker ends up visible).
    for row in (ROW_ALL, ROW_AXON, ROW_DENDR):
        plotter.plot_neurons(
            all_axs[row, :],
            n,
            select="dendrites",
            color=dendr_color,
            linewidth=0.0,
            alpha=1,
            soma_s=30,
            **shared_kw,
        )

    # Pass 2 - dendrite traces on the "everything" and "dendrites" rows.
    for row in (ROW_ALL, ROW_DENDR):
        plotter.plot_neurons(
            all_axs[row, :],
            n,
            select="dendrites",
            color=dendr_color,
            linewidth=0.5,
            alpha=a,
            soma_s=0,
            **shared_kw,
        )

    if n.has_axon:
        axon_color = pltltr.shift_lum(cols[j], 0.15)

        # Pass 3 - axon traces on the "everything" and "axons" rows.
        for row in (ROW_ALL, ROW_AXON):
            plotter.plot_neurons(
                all_axs[row, :],
                n,
                select="axon",
                color=axon_color,
                linewidth=0.5,
                alpha=a,
                soma_s=0,
                **shared_kw,
            )

        # Pass 4 - fully transparent axon on the "dendrites" row: a dummy
        # drawn only so that this row keeps the same plot ratio as the others.
        plotter.plot_neurons(
            all_axs[ROW_DENDR, :],
            n,
            select="axon",
            color=axon_color,
            linewidth=0.0,
            alpha=0,
            soma_s=0,
            **shared_kw,
        )


# One label per row, written on the axes of the label column.
for ax, t in zip(all_axs[:, LABEL_COL], ["Neurons", "Axons", "Dendrites"]):
    ax.text(106, 0, t, fontsize=8, c=(0.3,) * 3, ha="center", va="top")

In [None]:
# NOTE(review): `neurons_dict` is not defined in any cell visible in this
# notebook, so this cell presumably raises NameError on a fresh kernel -
# confirm whether it is leftover debugging or should display another object.
neurons_dict

In [None]:
# Scatter plots of the neurons belonging to one cluster.
# NOTE(review): `clusters`, `neurons_list`, `plot_bg_mod` and `bg_dict` are not
# defined in any cell visible in this notebook - confirm where they come from.
marker_size = 3  # base scatter marker size, rescaled for each layer below

# The range below contains a single value, so only the highest cluster label
# is plotted.
top_label = clusters.max()
for clust_id in range(top_label, top_label + 1):
    member_idxs = np.argwhere(clusters == clust_id)[:, 0]
    # Clusters with fewer than three members are skipped.
    if member_idxs.shape[0] <= 2:
        continue

    f, axs = plot_bg_mod(bg_dict)

    for member_rank, neuron_idx in enumerate(member_idxs):
        neuron = neurons_list[neuron_idx]
        member_color = cols[member_rank % len(cols)]

        # Dropping one coordinate column at a time leaves the two columns
        # plotted in panel `drop_axis`.
        for drop_axis in range(3):
            # Top row: dendrite points from `coords_unilat`, faint and in the
            # cluster's colour.
            unilat_2d = np.delete(neuron.coords_unilat, drop_axis, axis=1)
            axs[0, drop_axis].scatter(
                unilat_2d[neuron.dendr_idxs, 1],
                unilat_2d[neuron.dendr_idxs, 0],
                s=marker_size,
                color=cols[clust_id],
                alpha=0.02,
            )

            # Rows 1 and 2: dendrite and axon points from `coords_mpin`, one
            # colour per cluster member; the soma is drawn on both rows.
            mpin_2d = np.delete(neuron.coords_mpin, drop_axis, axis=1)
            soma = slice(neuron.soma_idx, neuron.soma_idx + 1)
            selections = [neuron.dendr_idxs, neuron.axon_idxs]
            for row_offset, selection in enumerate(selections):
                panel = axs[1 + row_offset, drop_axis]
                panel.scatter(
                    mpin_2d[selection, 1],
                    mpin_2d[selection, 0],
                    s=marker_size / 20,
                    color=member_color,
                )
                panel.scatter(
                    mpin_2d[soma, 1],
                    mpin_2d[soma, 0],
                    s=marker_size * 10,
                    color=member_color,
                )

    axs[0, 2].set_title(f"(Clust. {clust_id} {member_idxs})")

    # Saving is currently disabled:
    # f.savefig(fig_folder / f"clust_{clust_id}.png")