# Define clusters for IPN neurons

This is a finicky notebook, written in the rush of finishing a thesis; be careful when running it.

In [None]:
%matplotlib widget
from pathlib import Path

import flammkuchen as fl
import matplotlib as mpl
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch
from bg_atlasapi import BrainGlobeAtlas
from lotr import DATASET_LOCATION
from lotr import plotting as pltltr
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter
from tqdm import tqdm

# Colour palettes from lotr.plotting, repeated 5 times (assumes the COLS
# entries are lists, so `* 5` concatenates copies — presumably so larger
# colour indices do not run off the end).
cols = pltltr.COLS["qualitative"] * 5
cols2 = pltltr.COLS["fish_cols"] * 5
# Use the qualitative palette as matplotlib's default colour cycle.
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=cols)

# BrainGlobe IPN atlas (the name suggests 0.5 um voxels).
atlas = BrainGlobeAtlas("ipn_zfish_0.5um")

In [None]:
anatomy_location = DATASET_LOCATION / "anatomy"

# Traced neuron skeletons; later cells look them up by neuron id
# (e.g. neurons[nid], neurons["p041"]).
neurons = fl.load(anatomy_location / "annotated_traced_neurons" / "all_skeletons.h5")

In [None]:
# Morphological cluster label of each IPN neuron, keyed by neuron id
# (looked up as neurons_clust[neuron.id] in the rosette figure).
neurons_clust = fl.load(anatomy_location / "ipn_morph_clust.h5")

In [None]:
# Morphological clusters pooled into the "rosette" population.
sel_clusters = [4, 5, 6, 10]
neurons_rosette = [
    neurons[neuron_id]
    for neuron_id, cluster_label in neurons_clust.items()
    if cluster_label in sel_clusters
]

In [None]:
# Rasterise the dendrite coordinates of the rosette neurons into a uint8
# volume on the atlas grid (255 where any dendrite point falls).
# NOTE(review): the `* 2` assumes coords_ipn is in um and the atlas has
# 0.5 um voxels, as its name suggests — confirm.
rosette_vol = np.zeros(atlas.shape, dtype=np.uint8)
neurons_coords = np.concatenate(
    [n.coords_ipn[n.dendr_idxs, :] for n in neurons_rosette], 0
)
n_coords_idxs = (neurons_coords * 2).astype(int)

# Keep only points inside the volume: negative indices would otherwise wrap
# around silently to the opposite edge, and too-large ones would raise.
in_bounds = np.all(
    (n_coords_idxs >= 0) & (n_coords_idxs < np.asarray(atlas.shape)), axis=1
)
n_coords_idxs = n_coords_idxs[in_bounds]

rosette_vol[n_coords_idxs[:, 0], n_coords_idxs[:, 1], n_coords_idxs[:, 2]] = 255

In [None]:
def atlas_bounds(limits):
    """Build an AtlasPlotter ``bounds_dict`` from per-axis limits.

    ``limits`` maps "frontal", "vertical" and "sagittal" to (min, max) pairs.
    Each projection gets the limits of the two axes listed for it (presumably
    the two axes it displays).
    """
    return dict(
        frontal=[limits["vertical"], limits["frontal"]],
        horizontal=[limits["sagittal"], limits["frontal"]],
        sagittal=[limits["vertical"], limits["sagittal"]],
    )


#  Specify axes limits over all dimensions:
lw = 0.4  # outline width of the atlas structure fills

# IPN + glomeruli plotter.
bs = dict(frontal=(40, 170), vertical=(20, 180), sagittal=(20, 120))
plotter = pltltr.AtlasPlotter(
    structures=["ipn", "glomeruli"],
    mask_slices=dict(frontal=slice(0, 400)),
    bounds_dict=atlas_bounds(bs),
)
plotter.fill_kw["linewidth"] = lw

# IPN-only plotter, with a slightly narrower frontal range.
bs = dict(frontal=(50, 160), vertical=(20, 180), sagittal=(20, 120))
plotter_noglom = pltltr.AtlasPlotter(
    structures=["ipn"],
    mask_slices=dict(frontal=slice(0, 400)),
    bounds_dict=atlas_bounds(bs),
)
plotter_noglom.fill_kw["linewidth"] = lw

In [None]:
# Colour index per cluster drawn individually in the left panel. Cluster 10
# is in `sel_clusters` but not here, so its neurons only appear in the
# dendrite scatter of the right panel.
clust_cols = dict(c4=0, c6=1, c5=2)
f, axs = plt.subplots(1, 2, figsize=(4.16, 2), gridspec_kw=dict(top=1, left=0, right=1))
axs[1].axis("off")
axs[1].axis("equal")

# Left panel: individual neurons of clusters 4/5/6 over the IPN outline.
ax = axs[0]
plotter_noglom.plot_on_axis(ax, projection="horizontal", labels=True, title=False)
all_coords = np.concatenate(
    [neuron.coords_ipn[neuron.dendr_idxs, :] for neuron in neurons_rosette], 0
)
for cl_id, col_id in clust_cols.items():
    cluster = int(cl_id[1:])  # "c4" -> 4
    for neuron in neurons_rosette:
        if neurons_clust[neuron.id] != cluster:
            continue
        # NOTE(review): for non-left neurons with col_id 0 this index is -1,
        # i.e. the last entry of cols2 — confirm the offset is intended.
        col = cols2[col_id + 3 * neuron.is_left - 1]
        # Dendrite lines, lightened, drawn underneath without a soma marker.
        plotter.plot_neuron_projection(
            ax,
            neuron,
            projection="horizontal",
            select="dendrites",
            c=pltltr.shift_lum(col, 0.15),
            soma_s=0,
            alpha=0.8,
            lw=0.4,
            zorder=-1,
            rasterized=True,
        )
        # Zero line width with soma_s=30: presumably only the soma marker
        # shows, slightly darkened, on top of the dendrites.
        plotter.plot_neuron_projection(
            ax,
            neuron,
            projection="horizontal",
            select="dendrites",
            c=pltltr.shift_lum(col, -0.05),
            soma_s=30,
            lw=0,
            zorder=1,
            rasterized=True,
        )

# Right panel: all rosette dendrite points as a faint scatter over the
# IPN + glomeruli outlines.
ax = axs[1]
plotter.plot_on_axis(ax, projection="horizontal", labels=False, title=False)
plotter.ax_scatterplot(
    ax, "horizontal", all_coords, rasterized=True, alpha=0.004, c="k", s=0.1
)

# Hand-tuned label positions for glomeruli G1-G3: each label is drawn twice,
# at mid - dx and mid + dx + 3 (presumably one per hemisphere).
mid = 107
dx = [7, 25, 14]
dy = [54, 58, 85]
for i, (x_offset, y) in enumerate(zip(dx, dy)):
    for x in (mid - x_offset, mid + x_offset + 3):
        ax.text(x, y, f"G{i + 1}", fontsize=8, va="center", ha="center", c="Maroon")

pltltr.savefig("rosette")

# From histology

In [None]:
from tifffile import imread

In [None]:
# Anti-GAD67 (Alexa488) histology stack.
# NOTE(review): absolute path on a shared volume — this cell only runs where
# /Volumes/Shared is mounted; consider moving the file under DATASET_LOCATION.
staining = imread(r"/Volumes/Shared/Luigi/best_GAD67-Alexa488.tif")

In [None]:
#  Specify axes limits over all dimensions:
bs = dict(frontal=(40, 170), vertical=(20, 180), sagittal=(20, 120))
lw = 0.4

# Separate IPN + glomeruli plotter for the staining overlay (its fill style is
# customised in the next cell). Each projection takes the limits of its
# (first axis, second axis) pair.
plotter_r = pltltr.AtlasPlotter(
    structures=["ipn", "glomeruli"],
    mask_slices=dict(frontal=slice(0, 400)),
    bounds_dict={
        proj: [bs[first_axis], bs[second_axis]]
        for proj, first_axis, second_axis in (
            ("frontal", "vertical", "frontal"),
            ("horizontal", "sagittal", "frontal"),
            ("sagittal", "vertical", "sagittal"),
        )
    },
)

In [None]:
# Outline style for the atlas structures drawn over the grayscale staining.
plotter_r.fill_kw["linewidth"] = 0.5
plotter_r.fill_kw["edgecolor"] = pltltr.shift_lum("Maroon", 0.3)
plotter_r.fill_kw["alpha"] = 0.8

f, ax = plt.subplots(1, 1, figsize=(2, 2), gridspec_kw=dict(top=1, left=0, right=1))
# Crop of index 11 along the first axis of the stack, mapped onto plot
# coordinates via `extent` (left=160 > right=55, so the x direction is
# reversed). NOTE(review): crop indices and extent are hand-tuned for this file.
ax.imshow(
    staining[11, 640:820, 420:610],
    vmax=800,
    cmap="gray",
    origin="lower",
    extent=[160, 55, 20, 120],
)
plotter_r.plot_on_axis(ax, projection="horizontal")
ax.text(
    160, 20, "antiGAD67-Alexa488", fontsize=8, c=(0.4,) * 3, ha="right", va="bottom"
)

pltltr.savefig("rosette_staining")

# Glomeruli connections 

In [None]:
# NOTE(review): neurons_g1 is not used in the cells shown here; remove it if
# nothing else in the notebook fills it.
neurons_g1 = []

In [None]:
# Morphological cluster labels for the aHB neurons (presumably keyed by
# neuron id, like neurons_clust — confirm).
neurons_clust_ahb = fl.load(anatomy_location / "ahb_morph_clust.h5")

In [None]:
# Inspection only: displays the full cluster mapping.
neurons_clust_ahb

In [None]:
#  Specify axes limits over all dimensions:
bs = dict(frontal=(20, 190), vertical=(40, 180), sagittal=(-20, 120))

# NOTE(review): this rebinds `plotter` (previously the IPN + glomeruli
# plotter); re-running earlier figure cells after this one would use the
# IPN + dIPN outlines instead.
plotter = pltltr.AtlasPlotter(
    structures=["ipn", "dipn"],
    mask_slices=dict(frontal=slice(0, 120)),
    bounds_dict={
        proj: [bs[first_axis], bs[second_axis]]
        for proj, first_axis, second_axis in (
            ("frontal", "vertical", "frontal"),
            ("horizontal", "sagittal", "frontal"),
            ("sagittal", "vertical", "sagittal"),
        )
    },
)

In [None]:
# Ids in neurons_clust_ahb whose cluster label is 1.
[neuron_id for neuron_id, cluster_label in neurons_clust_ahb.items() if cluster_label == 1]

In [None]:
f, axs = plotter.generate_projection_plots(figsize=(4.2, 2), labels=True)

# i_a0
# NOTE(review): presumably the cluster-1 aHB neurons listed by the previous
# cell — confirm.
for neuron_id in ("p041", "p048", "p056", "p070"):
    plotter.plot_neurons(axs, neurons[neuron_id], lw=0.4)

In [None]:
def plot_bicol_n(
    ax,
    neuron,
    plotter,
    projection="horizontal",
    color=None,
    shift=0.15,
    linewidth=0.6,
    alpha=1,
    raster=True,
    col_axon=None,
):
    """Draw one neuron in two tones via ``plotter.plot_neuron_projection``.

    Soma and dendrites use ``color``; the axon (only if ``neuron.has_axon``)
    uses ``col_axon``, defaulting to ``color`` luminance-shifted by ``shift``.

    Parameters
    ----------
    ax : Axes to draw on, forwarded to the plotter.
    neuron : neuron forwarded to the plotter; ``neuron.has_axon`` decides
        whether the axon pass is drawn.
    plotter : object exposing ``plot_neuron_projection``.
    projection : projection name forwarded to the plotter.
    color : colour of soma and dendrites.
    shift : luminance shift passed to ``pltltr.shift_lum`` for the default
        axon colour.
    linewidth : dendrite line width; the axon is drawn ``linewidth - 0.2``.
    alpha : transparency used for every pass.
    raster : forwarded as ``rasterized``.
    col_axon : explicit axon colour, overriding the shifted one.
    """
    # The shifted colour is computed unconditionally, even if col_axon is given.
    shifted = pltltr.shift_lum(color, shift)
    axon_color = shifted if col_axon is None else col_axon

    shared_kw = dict(
        projection=projection,
        rasterized=raster,
        alpha=alpha,
        label="__nolegend__",
    )

    # Pass 1: zero line width with soma_s=30 (presumably soma marker only), at the back.
    plotter.plot_neuron_projection(
        ax,
        neuron,
        select="dendrites",
        color=color,
        linewidth=0,
        soma_s=30,
        zorder=-100,
        **shared_kw,
    )
    # Pass 2: dendrite lines in front, no soma marker.
    plotter.plot_neuron_projection(
        ax,
        neuron,
        select="dendrites",
        color=color,
        linewidth=linewidth,
        soma_s=0,
        zorder=100,
        **shared_kw,
    )

    if not neuron.has_axon:
        return

    # Pass 3: thinner axon at the back.
    plotter.plot_neuron_projection(
        ax,
        neuron,
        select="axon",
        color=axon_color,
        linewidth=linewidth - 0.2,
        soma_s=0,
        zorder=-100,
        **shared_kw,
    )

In [None]:
f, ax = plt.subplots(1, 1, figsize=(2, 1.5), gridspec_kw=dict(top=1, left=0.1, right=1))
ax.axis("off")
ax.axis("equal")
projection = "horizontal"
plotter.plot_on_axis(ax, projection=projection, labels=False, title=False)

# (neuron id, colour, alpha): two neurons per colour, one of each pair faded.
for neuron_id, color, alpha in [
    ("p041", cols[0], 1),
    ("p048", cols[0], 0.2),
    ("p056", cols[1], 0.2),
    ("p070", cols[1], 1),
]:
    plot_bicol_n(
        ax, neurons[neuron_id], plotter, projection, color, shift=0.2, alpha=alpha
    )

pltltr.savefig("bidirect_connect_glom")