# EM dataset

In [None]:
%matplotlib widget

In [None]:
from pathlib import Path
import numpy as np

from lotr import DATASET_LOCATION
from lotr import plotting as pltltr
from lotr.analysis.pool_cell_info import get_pooled_cell_info
from lotr.em.core import MIDLINES
from lotr.em.loading import load_skeletons_from_zip
from lotr.utils import linear_regression
from matplotlib import cm
from matplotlib import pyplot as plt

# `scipy.ndimage.morphology` is a deprecated namespace (SciPy >= 1.8);
# the public location of `binary_dilation` is `scipy.ndimage`.
from scipy.ndimage import binary_dilation
from tqdm import tqdm

from lotr.result_logging import ResultsLogger

# Results logger instance for this notebook.
logger = ResultsLogger()

# Project colour palette, shortcut for the plotting cells.
COLS = pltltr.COLS

## Data loading

In [None]:
# Reference anatomy: load the IPN atlas, its gad1b reference stack, and
# dilated versions of the IPN and dorsal-IPN annotation masks.

from bg_atlasapi.core import Atlas

atlas = Atlas(DATASET_LOCATION.parent / "anatomy" / "ipn_zfish_0.5um_v1.6")
gad1b_stack = atlas.additional_references["gad1b_gal4"]

# Slightly enlarge the IPN masks (better masks will be drawn in the future).
# The structuring element only has the middle index of axis 1 set, so the
# dilation grows the masks along axes 0 and 2 only.
mask = np.zeros((3, 3, 3), dtype=bool)
mask[:, 1, :] = True
dilation_kwargs = dict(iterations=7, structure=mask)
annotation = binary_dilation(atlas.get_structure_mask("ipn") > 0, **dilation_kwargs)
annotation_dipn = binary_dilation(
    atlas.get_structure_mask("dors_ipn") > 0, **dilation_kwargs
)

In [None]:
# Load the EM skeletons and set `mirror_right` on every neuron
# (presumably mirrors right-side neurons onto one side — see lotr.em):
em_path = DATASET_LOCATION.parent / "anatomy" / "all_ahb.k.zip"

neurons_list = load_skeletons_from_zip(em_path)

for n in neurons_list:
    n.mirror_right = True

# Select neurons with projections internal to the dIPN:
#  - "bilat" anywhere in the comments (substring match, which also covers
#    "bilatdend");
#  - no "proj" in the comments;
#  - "dipn" as one of the "_"-separated comment tokens.
sel_neurons = [
    n
    for n in neurons_list
    if "bilat" in n.comments
    and "proj" not in n.comments
    and "dipn" in n.comments.split("_")
]

## Projections with line anatomy

In [None]:
# Re-bind the plotting helpers and palette so this cell can be re-run on its
# own; same bindings as in the imports cell.
from lotr import plotting as pltltr

COLS = pltltr.COLS

In [None]:
# Axes limits along each dimension:
bs = dict(frontal=(30, 180), vertical=(-4, 170), sagittal=(-50, 120))

# Limits for each projection plane, built from the per-dimension ranges above.
plane_bounds = {
    "frontal": [bs["vertical"], bs["frontal"]],
    "horizontal": [bs["sagittal"], bs["frontal"]],
    "sagittal": [bs["vertical"], bs["sagittal"]],
}
plotter = pltltr.AtlasPlotter(
    atlas=atlas,
    structures=["ipn", "dors_ipn"],
    mask_slices={"frontal": slice(0, 120)},
    bounds_dict=plane_bounds,
)

## Calculate centroids

In [None]:
# Bifurcation centroid of each selected neuron, for the dendrites (index 0)
# and the axon (index 1):
centroids = np.array(
    [
        [
            neuron.find_centroid_bifurcation(select=part)
            for part in ["dendrites", "axon"]
        ]
        for neuron in sel_neurons
    ]
)

# Per-neuron colour value from coordinate 2 of the axon centroid, rescaled into
# the open interval (0, 1). The +5 offset keeps the lowest value away from 0.
# NOTE(review): this uses the axon centroid (index 1), while the projection
# colorbar is titled "Dendrite position" — confirm which one is intended.
axon_pos = centroids[:, 1, 2]
shifted = axon_pos - axon_pos.min() + 5
norm_cm = shifted / (shifted.max() + 0.01)

In [None]:
# Plotter whose vertical range extends down to -20:
bs = dict(frontal=(30, 180), vertical=(-20, 170), sagittal=(-50, 120))

# (row dimension, column dimension) bounding each projection plane:
plane_dims = {
    "frontal": ("vertical", "frontal"),
    "horizontal": ("sagittal", "frontal"),
    "sagittal": ("vertical", "sagittal"),
}
plotter = pltltr.AtlasPlotter(
    atlas=atlas,
    structures=["ipn", "dors_ipn"],
    mask_slices=dict(frontal=slice(0, 120)),
    bounds_dict={
        plane: [bs[row_dim], bs[col_dim]]
        for plane, (row_dim, col_dim) in plane_dims.items()
    },
)

In [None]:
# Number of selected dIPN neurons (shown as the cell output).
len(sel_neurons)

In [None]:
# Projections of the selected dIPN neurons, each coloured by its norm_cm value:
f, axs = plotter.generate_projection_plots()

lw = 0.5
# `matplotlib.cm.get_cmap` is deprecated (removed in Matplotlib 3.9);
# `pyplot.get_cmap` is the supported lookup. Fetch the colormap once.
reds = plt.get_cmap("Reds")
for neuron, c in zip(sel_neurons, norm_cm):
    # Should always hold for `sel_neurons`; kept as a safeguard.
    if "bilat" in neuron.comments and "dipn" in neuron.comments.split("_"):
        plotter.plot_neurons(axs, neuron, c=reds(c), lw=lw, rasterized=True)

# Empty scatter used only as the colorbar mappable; its range is pinned to
# [0, 1] so the colorbar matches the norm_cm scale.
fake_im = axs[0].scatter([], [], c=[], cmap="Reds", vmin=0, vmax=1)
pltltr.add_cbar(
    fake_im,
    axs[1],
    inset_loc=(0.425, 0.8, 0.3, 0.04),
    ticks=[0.1, 0.9],
    ticklabels=["lat", "med"],
    orientation="horizontal",
    title="Dendrite position",
    titlesize=8,
)

# One anatomy scale bar per projection panel:
scalebar_positions = [(35, 165), (35, 110), (5, 165)]
for ax, proj, pos in zip(axs, plotter.space.sections, scalebar_positions):
    pltltr.add_anatomy_scalebar(
        ax, plane=proj, pos=pos, cartesian=True, equalize_axis=False, length=30,
    )
pltltr.savefig("dipn_neurons_colcoded")

In [None]:
# Absolute distance from the IPN midline (coordinate 2) of the dendrite
# (column 0) and axon (column 1) bifurcation centroids:
distances_from_midline = np.abs(centroids[:, :, 2] - MIDLINES["ipn"])
dendrite_dist = distances_from_midline[:, 0]

s = 8  # scatter marker size
f, axs = plt.subplots(
    1, 2, figsize=(4, 1.5), gridspec_kw=dict(bottom=0.3, left=0.3, wspace=0.5)
)
# x extent of the regression lines: mean +/- 1.2 standard deviations of the
# dendrite distances (`sign` is deliberately not named `s`, the marker size).
xline_lims = np.mean(dendrite_dist) + np.array(
    [np.std(dendrite_dist) * 1.2 * sign for sign in [-1, 1]]
)

# Coordinate 0 of each soma (labelled "Ant-post pos. soma" below):
soma_pos = np.array(
    [neuron.coords_ipn[neuron.soma_idx, 0] for neuron in sel_neurons], dtype=float
)

for ax, y in zip(axs, [distances_from_midline[:, 1], soma_pos]):
    ax.scatter(dendrite_dist, y, s=s, c=norm_cm, cmap="Reds", vmin=0, vmax=1)
    # Linear fit, drawn with `a` as intercept and `b` as slope:
    a, b = linear_regression(dendrite_dist, y)
    ax.plot(xline_lims, a + xline_lims * b, lw=1, c=".4", zorder=-10)
    # Pearson correlation, annotated 10% below the top of the panel:
    ylims = ax.get_ylim()
    cc = np.corrcoef(dendrite_dist, y)[0, 1]
    ax.text(
        np.mean(xline_lims),
        ylims[1] - (ylims[1] - ylims[0]) * 0.1,
        f"R = {cc:0.2f}",
        ha="center",
    )

axs[0].set(xlabel="Lat-med pos. dendrite", ylabel="Lat-med pos. axon")
axs[1].set(xlabel="Lat-med pos. dendrite", ylabel="Ant-post pos. soma")

for ax in axs:
    pltltr.despine(ax)
pltltr.savefig("dipn_neurons_scatterplots")

In [None]:
# Show the results logger (cell output).
logger

In [None]:
#  Specify axes limits over all dimensions:
bs = dict(frontal=(50, 160), vertical=(-4, 170), sagittal=(-50, 120))

plotter = pltltr.AtlasPlotter(
    atlas=atlas,
    structures=["ipn", "dors_ipn"],
    mask_slices=dict(frontal=slice(0, 120)),
    bounds_dict=dict(
        frontal=[bs["vertical"], bs["frontal"]],
        horizontal=[bs["sagittal"], bs["frontal"]],
        sagittal=[bs["vertical"], bs["sagittal"]],
    ),
)

# Plot pairs

In [None]:
# Reload the skeletons, this time without setting `mirror_right`:
neurons_list_notmirr = load_skeletons_from_zip(em_path)


def _is_bilat_dipn(neuron):
    """Return True if "bilat" and "dipn" are both "_"-separated comment tokens."""
    tokens = neuron.comments.split("_")
    return "bilat" in tokens and "dipn" in tokens


# Select neurons with projections internal to the dIPN:
sel_neurons_notmirr = list(filter(_is_bilat_dipn, neurons_list_notmirr))

In [None]:
# Index the selected non-mirrored neurons by their `id` and show the available
# keys. The dict is built here so this cell does not depend on a later one.
sel_neurons_notmirr_dict = {n.id: n for n in sel_neurons_notmirr}
sel_neurons_notmirr_dict.keys()

In [None]:
# Plot four hand-picked dIPN neurons (by id) over the IPN / dIPN outlines.
sel_neurons_notmirr_dict = {n.id: n for n in sel_neurons_notmirr}

f, ax = plt.subplots(figsize=(4, 4))

# Outlines of the dilated IPN and dIPN masks, projected along axis 0:
for mask_stack in (annotation[:120, :, :], annotation_dipn[:130, :, :]):
    pltltr.plot_projection(
        mask_stack,
        0,
        ax=ax,
        smooth_wnd=15,
        linewidth=0.5,
        fill=False,
        edgecolor=".3",
    )

for i, neuron_id in enumerate(["c017", "p084", "c011", "p077"]):
    neuron = sel_neurons_notmirr_dict[neuron_id]
    tokens = neuron.comments.split("_")
    if "bilat" in tokens and "dipn" in tokens:
        # Evenly spaced phase colours (0.25, 0.5, 0.75, 1.0), one per neuron:
        main_c = pltltr.shift_lum(COLS["phase"](1 / 4 + i / 4), 0.0)
        # Dendrites and axon drawn as coordinate 2 (x) vs coordinate 1 (y):
        for part in ("dendrites", "axon"):
            lines = neuron.generate_plotlines_from_skeleton(space="ipn", select=part)
            ax.plot(lines[:, 2], lines[:, 1], lw=0.5, c=main_c)

        soma = neuron.coords_ipn[neuron.soma_idx, :]
        ax.scatter(soma[2], soma[1], s=30, color=main_c)

# Equal aspect; ylim is given high-to-low, so the y axis is inverted.
ax.axis("equal")
ax.set(ylim=(170, -4), xlim=(50, 160))

pltltr.despine(ax, "all")
pltltr.add_scalebar(
    ax, ylen=-20, xlen=20, ypos=140, xpos=40, xlabel="R. - L.", ylabel="inf. sup."
)
pltltr.savefig("dipn_pairs")

## Coordinates in MPIN atlas space

In [None]:
# NOTE(review): this presumably saves the current figure (the pairs plot) a
# second time, now as "pairs" — confirm it is intended and not a leftover.
pltltr.savefig("pairs")
from bg_atlasapi import BrainGlobeAtlas
from tifffile import imread  # NOTE(review): not used in the visible cells

In [None]:
mpin_atlas = BrainGlobeAtlas("mpin_zfish_1um")
# Touch the annotation property, presumably so the lazily loaded `_annotation`
# array is populated before it is patched in place below.
mpin_atlas.annotation
# Annoying fixes required to deal with the Baier annotation mess. Will be cleaned
# in the next version of the atlas.
# Zero the outermost voxel layer on every face of the volume:
mpin_atlas._annotation[[0, -1], :, :] = 0
mpin_atlas._annotation[:, [0, -1], :] = 0
mpin_atlas._annotation[:, :, [0, -1]] = 0
# Relabel zero voxels. The mask must come from the MPIN annotation itself: the
# IPN `atlas` (0.5 um IPN atlas) is a different volume and its annotation does
# not index this one.
# NOTE(review): this runs after the border zeroing, so the border voxels are
# relabelled as well — confirm the intended order.
mpin_atlas._annotation[mpin_atlas._annotation == 0] = 15000
mpin_atlas.structures["root"]["id"] = 1

In [None]:
import flammkuchen as fl
from lotr.data_preprocessing.anatomy import transform_points

# Transformation matrices stored alongside the morphed anatomy files:
xform_path = "/Volumes/Shared/experiments/E0044_spontaneous/gad1b/2p_anatomy/dendra/morphed anatomy files/xform_mat.h5"
matrices = fl.load(xform_path)

# Compose inverse("to_ipn") followed by "to_mpin" (presumably mapping IPN-space
# points into MPIN space — confirm against the registration pipeline).
to_ipn_inverse = np.linalg.inv(matrices["to_ipn"])
mat = matrices["to_mpin"] @ to_ipn_inverse

In [None]:
# Stack the IPN-space skeleton coordinates of all non-mirrored dIPN neurons
# along the first axis:
coords = np.concatenate(
    [nrn.coords_ipn[:, :] for nrn in sel_neurons_notmirr], axis=0
)
# Transformation to MPIN space is not applied yet:
# trasf_coords = transform_points(coords, mat)

In [None]:
#  Specify axes limits over all dimensions:
bs = dict(
    frontal=(0, mpin_atlas.shape[2] + 10),
    vertical=(0, mpin_atlas.shape[1] + 10),
    sagittal=(0, mpin_atlas.shape[0] + 10),
)

plotter = pltltr.AtlasPlotter(
    atlas=mpin_atlas,
    structures=["dorsal habenula", "root", "retina", "interpeduncular nucleus"],
    bounds_dict=dict(
        frontal=[bs["vertical"], bs["frontal"]],
        horizontal=[bs["sagittal"], bs["frontal"]],
        sagittal=[bs["vertical"], bs["sagittal"]],
    ),
    smooth_wnd=50,
)

In [None]:
# Quick look at the stacked skeleton coordinates (coordinate 2 vs 1).
# NOTE(review): the MPIN transform above is commented out, so these are still
# IPN-space coordinates.
plt.figure()
plt.scatter(coords[:, 2], coords[:, 1])