In [None]:
%matplotlib widget

import flammkuchen as fl
import lotr.plotting as pltltr
import numpy as np
import pandas as pd
from lotr import DATASET_LOCATION
from lotr.em.core import EmNeuron
from lotr.em.loading import load_skeletons_from_zip
from matplotlib import pyplot as plt

# Root folder for the anatomy data; the traced-neuron files loaded below live in
# its "annotated_traced_neurons" subfolder.
anatomy_location = DATASET_LOCATION.parent / "anatomy"

# Quantify projections in the IPN

In [None]:
habaxons_zip = (
    anatomy_location / "annotated_traced_neurons" / "all_habaxons_p000-p040.k.zip"
)
skeletons = load_skeletons_from_zip(habaxons_zip)

# The second " - "-separated field of each skeleton comment encodes the
# annotation: its 1st character is read as the side, its 6th as the projection.
annotations = [skel.comments.split(" - ")[1] for skel in skeletons]
projection_df = pd.DataFrame(
    {
        "side": [ann[0] for ann in annotations],
        "projection": [ann[5] for ann in annotations],
    }
)
# A "[" in the projection position marks axons that are not fully traced.
projection_df = projection_df[projection_df["projection"] != "["]

In [None]:
# Axon counts per projection category (rows) for each side (columns "l", "r").
# Building the frame aligns the two value_counts Series on their union index,
# so a category present on only one side would come out as NaN (and turn the
# whole column into floats); fill those gaps with a true count of 0.
counts_df = (
    pd.DataFrame(
        {
            s: projection_df.loc[projection_df["side"] == s, "projection"].value_counts()
            for s in ["l", "r"]
        }
    )
    .fillna(0)
    .astype(int)
)

In [None]:
# Show the raw axon counts per projection category and side.
counts_df

In [None]:
# Percentage of each side's axons falling into each projection category
# (the divisor Series is aligned on the "l"/"r" columns).
counts_df.div(projection_df["side"].value_counts(), axis="columns").mul(100).round(1)

# Plot all traced habenular axons in the IPN

In [None]:
%load_ext autoreload

In [None]:
%autoreload
from bg_atlasapi import BrainGlobeAtlas
from lotr.plotting import AtlasPlotter

In [None]:
# Axis limits (min, max) for each of the three dimensions.
bs = {"frontal": (30, 180), "vertical": (40, 180), "sagittal": (-20, 120)}

# Every projection is bounded by a pair of the limits above.
ipn_bounds = {
    "frontal": [bs["vertical"], bs["frontal"]],
    "horizontal": [bs["sagittal"], bs["frontal"]],
    "sagittal": [bs["vertical"], bs["sagittal"]],
}

plotter = pltltr.AtlasPlotter(
    structures=["ipn", "dipn"],
    mask_slices={"frontal": slice(0, 120)},
    bounds_dict=ipn_bounds,
)

In [None]:
# Traced skeletons keyed by id, stored with flammkuchen in an .h5 file.
all_skeletons_path = anatomy_location / "annotated_traced_neurons" / "all_skeletons.h5"
neurons_dict = fl.load(all_skeletons_path)

In [None]:
# Ids of the skeletons whose comment contains the "rhab" / "lhab" tag.
r_hab_ids, l_hab_ids = (
    [neuron.id for neuron in neurons_dict.values() if tag in neuron.comments]
    for tag in ("rhab", "lhab")
)

In [None]:
from scipy.stats import gaussian_kde


def _ipn_density(hab_ids):
    """Fit a 1D Gaussian KDE to the IPN coordinates of the given neurons.

    Points whose first coordinate is <= 30 are dropped, the remaining points
    are scaled by 2, and the KDE is fit on their second coordinate.
    """
    ipn_points = np.concatenate([neurons_dict[cid].coords_ipn for cid in hab_ids])
    ipn_points = ipn_points[ipn_points[:, 0] > 30] * 2
    return gaussian_kde(ipn_points[:, 1])


# One KDE per id list, in order: r_hab_ids first, then l_hab_ids.
kde_list = [_ipn_density(hab_ids) for hab_ids in (r_hab_ids, l_hab_ids)]

In [None]:
# Rasterize the (many) axon traces when saving to keep vector output small.
raster = True

gridspec_kw = dict(left=0.05, right=1, top=1, bottom=0.1, hspace=0.01)
figsize = (4.2, 1.3)

f, axs = plt.subplots(1, 3, figsize=figsize, gridspec_kw=gridspec_kw)
plotter.generate_projection_plots(axs, labels=True)

# NOTE(review): r_hab_ids (comments tagged "rhab") is labelled "left Hb" and
# l_hab_ids ("lhab") "right Hb". Confirm this swap is intentional (e.g. a
# mirrored reference frame) before publishing the legend.
for cids, col, lab in zip(
    [r_hab_ids, l_hab_ids],
    [pltltr.COLS["fish_cols"][3], pltltr.COLS["fish_cols"][5]],
    ["left Hb", "right Hb"],
):
    # plot_neurons receives the full axes array, so each neuron is drawn once
    # per call (a previous unused inner loop drew every neuron twice).
    for cid in cids:
        plotter.plot_neurons(
            axs,
            neurons_dict[cid],
            color=col,
            linewidth=0.4,
            rasterized=raster,
            alpha=0.5,
            # Labels starting with "_" are excluded from matplotlib legends.
            label="__nolegend__",
        )
    # Invisible proxy artist: one coloured legend entry per id list.
    axs[2].plot([], [], label=lab, color=col, linewidth=0)

axs[2].legend(frameon=False, labelcolor="linecolor", handlelength=0)
axs[0].text(75, 80, "dIPN")
axs[0].text(130, 180, "vIPN")

pltltr.savefig("hab_projections")

In [None]:
# Tighter axis limits used for the single-axon dIPN figures.
bs = {"frontal": (50, 160), "vertical": (40, 180), "sagittal": (0, 120)}

dipn_bounds = {
    "frontal": [bs["vertical"], bs["frontal"]],
    "horizontal": [bs["sagittal"], bs["frontal"]],
    "sagittal": [bs["vertical"], bs["sagittal"]],
}

dipn_plotter = pltltr.AtlasPlotter(
    structures=["ipn", "dipn"],
    mask_slices={"frontal": slice(0, 120)},
    bounds_dict=dipn_bounds,
)

In [None]:
gridspec_kw = dict(left=0.05, right=1, top=1, bottom=0.1, hspace=0.01)

f, axs = plt.subplots(2, 4, figsize=(4.16, 2), gridspec_kw=gridspec_kw)

# Two example axons per id list, picked by their position in the list.
lhab_ids_2plot = [l_hab_ids[idx] for idx in (3, 2)]
rhab_ids_2plot = [r_hab_ids[idx] for idx in (4, 3)]

# (first column, example ids, fish_cols index): columns 0-1 show the l_hab
# examples, columns 2-3 the r_hab examples.
example_groups = ((0, lhab_ids_2plot, 5), (2, rhab_ids_2plot, 3))

# Rows are projections; within a row, each example gets its own panel.
for row, projection in enumerate(("horizontal", "sagittal")):
    for example_idx in range(2):
        for col_offset, example_ids, color_idx in example_groups:
            ax = axs[row, example_idx + col_offset]
            dipn_plotter.plot_on_axis(
                ax, projection=projection, labels=False, title=False
            )
            dipn_plotter.plot_neuron_projection(
                ax,
                neurons_dict[example_ids[example_idx]],
                projection=projection,
                color=pltltr.COLS["fish_cols"][color_idx],
                lw=0.4,
                rasterized=True,
            )

pltltr.savefig("single_hab_projections")

In [None]:
# vIPN-only plotter; uses the `bs` limits most recently assigned (dIPN cell).
# Each projection is bounded by the limits of a pair of dimensions in `bs`.
vipn_plotter = pltltr.AtlasPlotter(
    structures=["vipn"],
    mask_slices={"frontal": slice(0, 120)},
    bounds_dict={
        projection: [bs[first_dim], bs[second_dim]]
        for projection, (first_dim, second_dim) in {
            "frontal": ("vertical", "frontal"),
            "horizontal": ("sagittal", "frontal"),
            "sagittal": ("vertical", "sagittal"),
        }.items()
    },
)

In [None]:
# Inspect fish_cols[3], the colour used for the r_hab_ids traces in the plots above.
pltltr.COLS["fish_cols"][3]

In [None]:
# Scratch check: atlas shape multiplied by 25, presumably the volume extent in
# µm given the "25um" atlas resolution in its name — TODO confirm.
# NOTE(review): instantiating BrainGlobeAtlas may download the atlas if it is not cached.
np.array(BrainGlobeAtlas("allen_mouse_25um").shape) * 25

In [None]:
# NOTE(review): unexplained literal values; this cell only displays the tuple.
# Document where these numbers come from or remove the cell.
7760, -31645, -5943

In [None]:
# Half of the plotter atlas shape per axis (presumably its centre in voxels).
np.array(plotter.atlas.shape) / 2

In [None]:
# Inspect the structures available in the atlas used by `plotter`.
plotter.atlas.structures

In [None]:
# Distinct label values present in the atlas annotation volume.
np.unique(plotter.atlas.annotation)

In [None]:
# Three neuron panels (r_hab_ids only, l_hab_ids only, both), each paired with
# a narrow panel showing that group's KDE along the image row axis.
# `all_axs` and `kde_axs` are created here: no other cell defines `all_axs`,
# and `kde_axs` was only ever an empty list, so this cell could not run.
# NOTE(review): assumes plotter.plot_neurons accepts a single Axes (the
# hab_projections cell passes the full axes array) — TODO confirm the layout.
f_kde, ax_grid = plt.subplots(
    1, 6, figsize=(4.2, 1.3), gridspec_kw=dict(width_ratios=[4, 1] * 3)
)
all_axs, kde_axs = ax_grid[0::2], ax_grid[1::2]

# Evaluate every KDE on the same grid; a shared x-limit keeps panels comparable.
x = np.arange(plotter.space.shape[1])
kde_max = max(k(x).max() for k in kde_list)

for i, (kde, cids, col) in enumerate(
    zip(kde_list, [r_hab_ids, l_hab_ids], [[0.8, 0.14, 0.25], [0.2, 0.39, 0.85]])
):
    # Each group goes on its own panel (i) and on the shared panel (2).
    for axs_i in [i, 2]:
        for cid in cids:
            plotter.plot_neurons(
                all_axs[axs_i], neurons_dict[cid], color=col, linewidth=0.2, rasterized=raster
            )

        kde_ax = kde_axs[axs_i]
        kde_ax.fill_betweenx(x, kde(x), alpha=0.8, facecolor=col)
        # Inverted y-limits so rows run top-to-bottom like an image.
        kde_ax.set_ylim(plotter.space.shape[1], 0)
        kde_ax.set_xlim(0, kde_max)
        kde_ax.spines["left"].set_linewidth(0.2)
        for spine in ["right", "top", "bottom"]:
            kde_ax.spines[spine].set_visible(False)
        kde_ax.set_xticks([])
        kde_ax.set_yticks([])

# NOTE(review): `k` is not used by any visible cell.
k = ["", "rast"][raster]

In [None]:
fliplabels = False
# Index pair: (1, 0) when labels are not flipped, (0, 1) when they are.
(int(not fliplabels), int(fliplabels))