# Define clusters for IPN neurons

This is a finicky notebook written in the rush of finishing the thesis; be careful when running it.

In [None]:
%matplotlib widget
from pathlib import Path

import flammkuchen as fl
import matplotlib as mpl
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch
from bg_atlasapi import BrainGlobeAtlas
from lotr import DATASET_LOCATION
from lotr import plotting as pltltr
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from sklearn.metrics.pairwise import euclidean_distances
from tqdm import tqdm

# Qualitative palette repeated 5 times (presumably so that the colour cycle
# does not run out when many clusters are drawn), installed as the default
# matplotlib colour cycle.
cols = pltltr.COLS["qualitative"] * 5
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=cols)

# Atlas handle. NOTE(review): ``atlas`` is not referenced again in the cells
# visible in this notebook - confirm whether it is still needed.
atlas = BrainGlobeAtlas("ipn_zfish_0.5um")

In [None]:
# The anatomy data live in an "anatomy" folder next to the dataset folder.
anatomy_location = DATASET_LOCATION.parent / "anatomy"
# ``neurons`` is used below as a mapping (``.keys()``, ``.items()``, lookup by
# key); each value exposes ``comments``, ``has_axon``, ``id`` and coordinates.
neurons = fl.load(anatomy_location / "annotated_traced_neurons" / "all_skeletons.h5")

In [None]:
# Total number of traced neurons in the loaded mapping.
len(neurons)

In [None]:
# Show the annotation of every traced neuron. ``neurons`` is a mapping, so
# iterate over its values (iterating it directly yields the keys), and read
# the ``comments`` attribute that the selection cells below rely on.
[n.comments for n in neurons.values()]

In [None]:
# Neurons annotated as "ahb", not flagged "unknown", and with a traced axon.
# All three groups below are drawn from this common pool, in file order.
ahb_candidates = [
    n
    for n in neurons.values()
    if "ahb" in n.comments and "unknown" not in n.comments and n.has_axon
]

# Split the pool by the projection annotation found in the comments.
ahb_vipn_neurons = [n for n in ahb_candidates if "vipn" in n.comments]
ahb_dipn_neurons = [n for n in ahb_candidates if "dipn" in n.comments]

# Candidates carrying neither projection annotation; printed for inspection.
leftout = [
    n for n in ahb_candidates if not ("dipn" in n.comments or "vipn" in n.comments)
]
print([n.comments for n in leftout])

In [None]:
# Number of selected aHB neurons annotated "vipn".
len(ahb_vipn_neurons)

In [None]:
# Number of selected aHB neurons annotated "dipn".
len(ahb_dipn_neurons)

In [None]:
# Number of neurons annotated "ahb", regardless of axon or "unknown" flags.
sum("ahb" in n.comments for n in neurons.values())

In [None]:
def median_distances(c1, c2):
    """Return the median nearest-neighbour distance between two point sets.

    Parameters
    ----------
    c1, c2
        Point coordinates, passed as-is to ``euclidean_distances`` (one point
        per row).

    Returns
    -------
    tuple
        ``(m12, m21)``: ``m12`` is the median, over the points of ``c1``, of
        the distance to the closest point of ``c2``; ``m21`` is the same with
        the roles of the two sets swapped.
    """
    # Rows index the points of c1, columns the points of c2.
    pairwise = euclidean_distances(c1, c2)
    nearest_from_c1 = pairwise.min(axis=1)
    nearest_from_c2 = pairwise.min(axis=0)

    return np.median(nearest_from_c1), np.median(nearest_from_c2)


def get_dendr_axon_dist_mtx(neurons):
    """Compute symmetric dendrite and axon distance matrices between neurons.

    For every pair of neurons, ``median_distances`` gives the two directed
    median nearest-point distances (i -> j and j -> i). The value reported for
    the pair is the larger of the two, so the result is symmetric.

    Parameters
    ----------
    neurons
        Sequence of neuron objects, indexable by position. Each must expose
        ``coords_unilat_ipn`` (point coordinates), ``dendr_idxs`` and
        ``axon_idxs`` (row indices into the coordinates) and ``has_axon``.

    Returns
    -------
    tuple of np.ndarray
        ``(symm_cell_dist_dendr, symm_cell_dist_axon)``, both of shape
        ``(n_neurons, n_neurons)``. Axon entries are NaN for every pair in
        which at least one neuron has no axon.
    """
    n_neurons = len(neurons)
    cell_dist_dendr = np.full((n_neurons, n_neurons), np.nan)
    cell_dist_axon = np.full((n_neurons, n_neurons), np.nan)

    # Slice every arbor once instead of once per pair.
    dendr_coords = [n.coords_unilat_ipn[n.dendr_idxs, :] for n in neurons]
    axon_coords = [
        n.coords_unilat_ipn[n.axon_idxs, :] if n.has_axon else None for n in neurons
    ]

    for i in tqdm(range(n_neurons)):
        # One call fills both [i, j] and [j, i], so visiting the upper triangle
        # (diagonal included) is enough; looping over all (i, j) did each pair
        # twice.
        for j in range(i, n_neurons):
            cell_dist_dendr[i, j], cell_dist_dendr[j, i] = median_distances(
                dendr_coords[i], dendr_coords[j]
            )
            if axon_coords[i] is not None and axon_coords[j] is not None:
                cell_dist_axon[i, j], cell_dist_axon[j, i] = median_distances(
                    axon_coords[i], axon_coords[j]
                )

    # Element-wise maximum of the two directed distances. ``np.fmax`` ignores
    # a NaN on one side and returns NaN only when both are NaN, which matches
    # ``nanmax`` over the stacked pair without the all-NaN RuntimeWarning.
    symm_cell_dist_dendr = np.fmax(cell_dist_dendr, cell_dist_dendr.T)
    symm_cell_dist_axon = np.fmax(cell_dist_axon, cell_dist_axon.T)

    return symm_cell_dist_dendr, symm_cell_dist_axon

In [None]:
# Pairwise dendrite and axon distance matrices for the "vipn" neurons.
symm_cell_dist_dendr, symm_cell_dist_axon = get_dendr_axon_dist_mtx(ahb_vipn_neurons)

In [None]:
# For axons, only cells with a traced axon can be clustered; for dendrites this
# filtering step is not really needed.
n_neurons = symm_cell_dist_axon.shape[0]
# A neuron's axon distance to itself is finite only if it has an axon, so the
# diagonal identifies the usable cells. Reading row 0 instead would discard
# every cell whenever the first neuron happens to lack an axon.
selector = ~np.isnan(np.diag(symm_cell_dist_axon))
# Positions of the retained cells in the original neuron list.
filt_idxs = np.flatnonzero(selector)
dist_mtx_toclust = symm_cell_dist_axon[np.ix_(selector, selector)]

In [None]:
# NOTE(review): ``dist_mtx_toclust`` is a square (n, n) array. Per the SciPy
# documentation, ``linkage`` treats a 2-D input as n observation vectors, not
# as a distance matrix (distances must be passed in condensed 1-D form). Each
# neuron is therefore clustered on its row of distances to all other neurons.
# Confirm this is intended; switching to the condensed form would change the
# clusters and require re-tuning the threshold used below.
linkage = sch.linkage(dist_mtx_toclust, method="ward")

In [None]:
# Linkage height at which the tree is cut into clusters (also drawn as a line).
thr = 80

f = plt.figure(figsize=(2, 3))
# Temp useless plot, just to get axes of right size:
ax = f.add_axes((0.2, 0.1, 0.6, 0.5))
ax.imshow(dist_mtx_toclust, vmin=0, vmax=50, cmap="pink")
pltltr.despine(ax, "all")

# Dendrogram axes start at the top edge of the matrix axes. They are the most
# recently added axes, so the pyplot-level calls below act on them.
ax1 = f.add_axes((0.2, ax.get_position().y1, 0.6, 0.4))
with plt.rc_context({"lines.linewidth": 0.5}):
    dendrogram = sch.dendrogram(
        linkage, color_threshold=thr, above_threshold_color=(0.5,) * 3
    )
# Leaf order of the dendrogram, used to sort the matrix rows and columns.
s_idxs = dendrogram["leaves"]
plt.xticks([])
plt.yticks([])
plt.axhline(thr, c=(0.3,) * 3, lw=0.5)

# Distance matrix with rows and columns reordered to match the dendrogram.
im = ax.imshow(dist_mtx_toclust[s_idxs, :][:, s_idxs], vmin=0, vmax=50, cmap="pink")
print(ax.get_position())

pltltr.add_cbar(
    im, ax, (1.15, 0.05, 0.05, 0.2), orientation="vertical", title="Dist.", titlesize=8
)
ax.set(
    xlabel="Sorted cell n.", ylabel="Sorted cell n.",
)
pltltr.despine(ax, "all")
pltltr.despine(ax1, "all")

plt.show()
# NOTE(review): the matrix clustered here is the axon distance matrix, yet the
# figure is saved as "clustering_dendr"; the dIPN section further down saves
# under the same name, so one figure presumably overwrites the other - confirm.
pltltr.savefig("clustering_dendr")

In [None]:
# Flat cluster label of every clustered cell, from cutting the tree at ``thr``.
clusters = np.asarray(sch.cut_tree(linkage, height=thr)[:, 0])

# Cluster labels in the order in which they first appear along the dendrogram
# leaves (left to right), without repetitions.
leaf_ordered_labels = clusters[np.array(dendrogram["leaves"])]
tree_sort_clusters = list(dict.fromkeys(leaf_ordered_labels))

# Relabel clusters so that cluster 0 is the leftmost one in the dendrogram,
# cluster 1 the next one, and so on.
rank_of_label = {label: rank for rank, label in enumerate(tree_sort_clusters)}
reindexed_clusters = np.array([rank_of_label[c] for c in clusters], dtype=int)

In [None]:
#  Specify axes limits over all dimensions:
bs = dict(frontal=(0, 210), vertical=(0, 180), sagittal=(-20, 120))

plotter = pltltr.AtlasPlotter(
    structures=["ipn", "dipn"],
    mask_slices=dict(frontal=slice(0, 120)),
    bounds_dict=dict(
        frontal=[bs["vertical"], bs["frontal"]],
        horizontal=[bs["sagittal"], bs["frontal"]],
        sagittal=[bs["vertical"], bs["sagittal"]],
    ),
)

In [None]:
# One projection figure per vIPN cluster, with the last neuron of each cluster
# highlighted.
raster = True

col_k = 0  # next palette entry to hand out to a multi-neuron cluster
base_col = (0.3,) * 3
# ``i_l`` and ``l`` are always equal here (enumerate over a range).
for i_l, l in enumerate(range(reindexed_clusters.max() + 1)):
    f, axs = plotter.generate_projection_plots(figsize=(6.5, 2), labels=True)
    # NOTE(review): indices refer to rows of the clustered matrix and are used
    # directly on ``ahb_vipn_neurons``; this is only valid while the axon
    # filter above keeps every neuron (earlier version, kept for reference):
    # idxs = filt_idxs[np.argwhere(clusters == l)[:, 0]]
    idxs = np.argwhere(reindexed_clusters == l)[:, 0]
    # Despite its name, ``alpha`` holds a per-neuron luminance shift: 0 for
    # all neurons except the last one of the cluster, which gets 0.3.
    alpha = [0] * len(idxs)
    alpha[-1] = 0.3
    a = 0.5  # opacity of the non-highlighted neurons
    a1 = 0.95  # opacity of the highlighted neuron
    # Label colour: a palette colour for clusters with several neurons, grey
    # for single-neuron clusters.
    if len(idxs) > 1:
        base_col = cols[col_k]
        col_k += 1
    else:
        base_col = (0.3,) * 3
    for i, s in zip(idxs, alpha):
        neuron = ahb_vipn_neurons[i]
        # Dendrite lines.
        plotter.plot_neurons(
            axs,
            neuron,
            select="dendrites",
            color=pltltr.shift_lum(cols[0], 0.15 - s),
            linewidth=0.5,
            rasterized=raster,
            alpha=a if s == 0 else a1,
            soma_s=0,
            label="__nolegend__",
        )
        # Second pass with zero line width and the default soma size, on top
        # of everything else (presumably draws only the soma marker - confirm).
        plotter.plot_neurons(
            axs,
            neuron,
            select="dendrites",
            color=pltltr.shift_lum(cols[1], 0.15 - s),
            linewidth=0,
            rasterized=raster,
            alpha=a if s == 0 else a1,
            zorder=100,
            label="__nolegend__",
        )
        if neuron.has_axon:
            # Axon lines.
            plotter.plot_neurons(
                axs,
                neuron,
                select="axon",
                color=pltltr.shift_lum(cols[3], -s),
                linewidth=0.4,
                rasterized=raster,
                alpha=a if s == 0 else a1,
                soma_s=0,
                # zorder=100,
                label="__nolegend__",
            )

        # alpha = 0.4

    # Cluster label with the number of cells, drawn on the second projection.
    axs[1].text(
        106,
        0,
        f"c{i_l} ({len(idxs)} cells)",
        fontsize=8,
        c=base_col,
        ha="center",
        va="top",
    )

In [None]:
# Map neuron id -> vIPN cluster label (dendrogram order).
clust_dict = {
    neuron.id: reindexed_clusters[pos] for pos, neuron in enumerate(ahb_vipn_neurons)
}

## dIPN projections

In [None]:
# Pairwise distance matrices for the "dipn" neurons. This overwrites the vIPN
# matrices computed earlier in the notebook.
symm_cell_dist_dendr, symm_cell_dist_axon = get_dendr_axon_dist_mtx(ahb_dipn_neurons)

In [None]:
# No NaN filtering here: every neuron in ``ahb_dipn_neurons`` was selected with
# ``has_axon``, so all axon distances are defined.
dist_mtx_toclust = symm_cell_dist_axon
# NOTE(review): as a square 2-D array, the matrix is treated by ``linkage`` as
# observation vectors (rows) rather than as pairwise distances - confirm this
# is intended before changing it, since ``thr`` is tuned to this behaviour.
linkage = sch.linkage(dist_mtx_toclust, method="ward")

In [None]:
# Linkage height at which the tree is cut into clusters (also drawn as a line).
thr = 200

f = plt.figure(figsize=(2, 3))
# Temp useless plot, just to get axes of right size:
ax = f.add_axes((0.2, 0.1, 0.6, 0.5))
ax.imshow(dist_mtx_toclust, vmin=0, vmax=50, cmap="pink")
pltltr.despine(ax, "all")

# Dendrogram axes start at the top edge of the matrix axes. They are the most
# recently added axes, so the pyplot-level calls below act on them.
ax1 = f.add_axes((0.2, ax.get_position().y1, 0.6, 0.4))
with plt.rc_context({"lines.linewidth": 0.5}):
    dendrogram = sch.dendrogram(
        linkage, color_threshold=thr, above_threshold_color=(0.5,) * 3
    )
# Leaf order of the dendrogram, used to sort the matrix rows and columns.
s_idxs = dendrogram["leaves"]
plt.xticks([])
plt.yticks([])
plt.axhline(thr, c=(0.3,) * 3, lw=0.5)

# Distance matrix with rows and columns reordered to match the dendrogram.
im = ax.imshow(dist_mtx_toclust[s_idxs, :][:, s_idxs], vmin=0, vmax=50, cmap="pink")
print(ax.get_position())

pltltr.add_cbar(
    im, ax, (1.15, 0.05, 0.05, 0.2), orientation="vertical", title="Dist.", titlesize=8
)
ax.set(
    xlabel="Sorted cell n.", ylabel="Sorted cell n.",
)
pltltr.despine(ax, "all")
pltltr.despine(ax1, "all")

plt.show()
# NOTE(review): same output name as the vIPN dendrogram figure saved earlier in
# the notebook, so one presumably overwrites the other; the name also says
# "dendr" although axon distances are clustered - confirm.
pltltr.savefig("clustering_dendr")

In [None]:
# Flat cluster label of every clustered cell, from cutting the tree at ``thr``.
clusters = np.asarray(sch.cut_tree(linkage, height=thr)[:, 0])

# Cluster labels in the order in which they first appear along the dendrogram
# leaves (left to right), without repetitions.
leaf_ordered_labels = clusters[np.array(dendrogram["leaves"])]
tree_sort_clusters = list(dict.fromkeys(leaf_ordered_labels))

# Relabel clusters so that cluster 0 is the leftmost one in the dendrogram,
# cluster 1 the next one, and so on.
rank_of_label = {label: rank for rank, label in enumerate(tree_sort_clusters)}
reindexed_clusters = np.array([rank_of_label[c] for c in clusters], dtype=int)

In [None]:
# One projection figure per dIPN cluster, with the last neuron of each cluster
# highlighted.
raster = True

col_k = 0  # next palette entry to hand out to a multi-neuron cluster
base_col = (0.3,) * 3
# ``i_l`` and ``l`` are always equal here (enumerate over a range).
# NOTE: ``i_l`` is still read by the final cluster figure cell further down,
# so it must keep being defined here.
for i_l, l in enumerate(range(reindexed_clusters.max() + 1)):
    f, axs = plotter.generate_projection_plots(figsize=(6.5, 2), labels=True)
    # idxs = filt_idxs[np.argwhere(clusters == l)[:, 0]]
    idxs = np.argwhere(reindexed_clusters == l)[:, 0]
    # Despite its name, ``alpha`` holds a per-neuron luminance shift: 0 for
    # all neurons except the last one of the cluster, which gets 0.3.
    alpha = [0] * len(idxs)
    alpha[-1] = 0.3
    a = 0.5  # opacity of the non-highlighted neurons
    a1 = 0.95  # opacity of the highlighted neuron
    # Label colour: a palette colour for clusters with several neurons, grey
    # for single-neuron clusters.
    if len(idxs) > 1:
        base_col = cols[col_k]
        col_k += 1
    else:
        base_col = (0.3,) * 3
    for i, s in zip(idxs, alpha):
        neuron = ahb_dipn_neurons[i]
        # Dendrite lines.
        plotter.plot_neurons(
            axs,
            neuron,
            select="dendrites",
            color=pltltr.shift_lum(cols[0], 0.15 - s),
            linewidth=0.5,
            rasterized=raster,
            alpha=a if s == 0 else a1,
            soma_s=0,
            label="__nolegend__",
        )
        # Second pass with zero line width and the default soma size, on top
        # of everything else (presumably draws only the soma marker - confirm).
        plotter.plot_neurons(
            axs,
            neuron,
            select="dendrites",
            color=pltltr.shift_lum(cols[1], 0.15 - s),
            linewidth=0,
            rasterized=raster,
            alpha=a if s == 0 else a1,
            zorder=100,
            label="__nolegend__",
        )
        if neuron.has_axon:
            # Axon lines.
            plotter.plot_neurons(
                axs,
                neuron,
                select="axon",
                color=pltltr.shift_lum(cols[3], -s),
                linewidth=0.4,
                rasterized=raster,
                alpha=a if s == 0 else a1,
                soma_s=0,
                # zorder=100,
                label="__nolegend__",
            )

        # alpha = 0.4

    # Cluster label with the number of cells, drawn on the second projection.
    axs[1].text(
        106,
        0,
        f"c{i_l} ({len(idxs)} cells)",
        fontsize=8,
        c=base_col,
        ha="center",
        va="top",
    )

In [None]:
# Append the dIPN clusters to the id -> cluster mapping, shifting their labels
# past the labels already in use so that the two sets do not collide.
max_clust = max(clust_dict.values()) + 1
clust_dict.update(
    {
        neuron.id: reindexed_clusters[pos] + max_clust
        for pos, neuron in enumerate(ahb_dipn_neurons)
    }
)

# Final cluster plot - all clusters and individual neurons

In [None]:
# Coordinates of all selected neurons stacked together; used below for an
# invisible scatter plot in every panel of the final figure.
all_coords = np.concatenate([n.coords_ipn for n in ahb_dipn_neurons + ahb_vipn_neurons])

In [None]:
# Position, within each cluster, of the example neuron to highlight (one entry
# per cluster, hand-picked for the clustering obtained above).
nice_n = [1, 0, 4, 2]

raster = True
n_clust = max(clust_dict.values()) + 1

# One row of three projections per cluster. ``squeeze=False`` keeps the axes
# array 2-D even if there is a single cluster.
f, all_axs = plt.subplots(
    n_clust,
    3,
    figsize=(6.2, 7.5),
    gridspec_kw=dict(top=1, bottom=0.1, left=0.05, right=1, wspace=0, hspace=0.02),
    squeeze=False,
)

a = 0.5  # opacity of the non-highlighted neurons
a1 = 0.95  # opacity of the highlighted neuron
for l in range(n_clust):
    axs = all_axs[l, :]
    plotter.generate_projection_plots(axs, labels=True)
    # Ids of the neurons assigned to this cluster.
    idxs = [k for k, v in clust_dict.items() if v == l]
    # Move the example neuron to the end so that it is drawn last. If no
    # example was chosen for this cluster, keep the order (last one is used).
    p = nice_n[l] if l < len(nice_n) else len(idxs) - 1
    # Plain list operations keep the ids in their original type for the
    # ``neurons`` lookup below.
    idxs = idxs[:p] + idxs[p + 1 :] + idxs[p : p + 1]
    # Despite its name, ``alpha`` holds a per-neuron luminance shift: 0 for
    # all neurons except the last (highlighted) one.
    alpha = [0] * len(idxs)
    alpha[-1] = 0.3
    # Invisible scatter of all coordinates (marker size 0), presumably so that
    # every row has the same data extent. It receives all three axes at once,
    # so a single call is enough.
    plotter.axs_scatterplot(axs, all_coords, s=0, rasterized=raster)
    for i, s in zip(idxs, alpha):
        neuron = neurons[i]
        # Dendrite lines.
        plotter.plot_neurons(
            axs,
            neuron,
            select="dendrites",
            color=pltltr.shift_lum(cols[0], 0.15 - s),
            linewidth=0.5,
            rasterized=raster,
            alpha=a if s == 0 else a1,
            soma_s=0,
            zorder=-100,
            label="__nolegend__",
        )
        # Second pass with zero line width and the default soma size, on top
        # of everything else (presumably draws only the soma marker - confirm).
        plotter.plot_neurons(
            axs,
            neuron,
            select="dendrites",
            color=pltltr.shift_lum(cols[1], 0.15 - s),
            linewidth=0,
            rasterized=raster,
            alpha=a if s == 0 else a1,
            zorder=1000,
            label="__nolegend__",
        )
        # Axon lines.
        plotter.plot_neurons(
            axs,
            neuron,
            select="axon",
            color=pltltr.shift_lum(cols[3], -s),
            linewidth=0.4,
            rasterized=raster,
            alpha=a if s == 0 else a1,
            soma_s=0,
            zorder=-100,
            label="__nolegend__",
        )

    # Cluster label, coloured by the cluster's own index. The previous version
    # used ``i_l``, a loop variable left over from an earlier cell, which gave
    # every label the same colour.
    axs[1].text(
        106,
        0,
        f"c{l+1} ({len(idxs)} cells)",
        fontsize=8,
        c=cols[l],
        ha="center",
        va="top",
    )

pltltr.savefig("clusters_ahb.pdf")

# Export cluster identity

In [None]:
# Save the neuron id -> cluster label mapping (vIPN and dIPN clusters together)
# in the "anatomy" folder next to the dataset folder.
fl.save(DATASET_LOCATION.parent / "anatomy" / "ahb_morph_clust.h5", clust_dict)

In [None]:
# Cluster label of every selected neuron, dIPN neurons first. Labels are read
# from ``clust_dict``: at this point ``reindexed_clusters`` only covers the
# dIPN neurons, so zipping it with the concatenated list silently dropped the
# vIPN neurons and left the dIPN labels without their offset.
clust_ids = {n.id: clust_dict[n.id] for n in ahb_dipn_neurons + ahb_vipn_neurons}

In [None]:
# Display the id -> cluster mapping built in the previous cell.
clust_ids

In [None]:
from lotr import FIGURES_LOCATION

In [None]:
# Multi-page PDF with one page per dIPN cluster: the first row overlays all
# neurons of the cluster, each following row shows one neuron on its own.
#
# ``clusters`` and ``tree_sort_clusters`` come from the dIPN clustering at this
# point, so neurons are taken from ``ahb_dipn_neurons``. The previous version
# indexed an undefined ``valid_ipn_neurons`` list through ``filt_idxs``, which
# belongs to the vIPN population.
raster = True
# NOTE(review): assumes FIGURES_LOCATION is a writable directory path; it
# replaces a hard-coded path to one user's Desktop - confirm.
with PdfPages(Path(FIGURES_LOCATION) / "ipn_neurons_clusters.pdf") as pdf:
    for i_l, l in enumerate(tree_sort_clusters):
        # Positions in ``ahb_dipn_neurons`` of the members of this cluster.
        idxs = np.flatnonzero(clusters == l)

        f, all_axs = plt.subplots(
            len(idxs) + 1, 3, figsize=(6.5, 2.3 * (len(idxs) + 1))
        )
        for i in range(all_axs.shape[0]):
            plotter.generate_projection_plots(all_axs[i, :], labels=True)
        # Despite its name, ``alpha`` holds a per-neuron luminance shift: 0 for
        # all neurons except the last one of the cluster.
        alpha = [0] * len(idxs)
        alpha[-1] = 0.3
        a1 = 0.95
        for n, (i, s) in enumerate(zip(idxs, alpha)):
            neuron = ahb_dipn_neurons[i]
            # Draw each neuron twice: on the overview row (0) and on its own
            # row (n + 1), where it always uses the highlighted style.
            for n_plot, axs in enumerate(all_axs[[0, n + 1], :]):
                a = 0.5
                if n_plot > 0:
                    s = alpha[-1]
                    a = a1
                # Dendrite lines.
                plotter.plot_neurons(
                    axs,
                    neuron,
                    select="dendrites",
                    color=pltltr.shift_lum(cols[0], 0.15 - s),
                    linewidth=0.5,
                    rasterized=raster,
                    alpha=a if s == 0 else a1,
                    soma_s=0,
                    label="__nolegend__",
                )
                # Second pass with zero line width and the default soma size
                # (presumably draws only the soma marker - confirm).
                plotter.plot_neurons(
                    axs,
                    neuron,
                    select="dendrites",
                    color=pltltr.shift_lum(cols[1], 0.15 - s),
                    linewidth=0,
                    rasterized=raster,
                    alpha=a if s == 0 else a1,
                    label="__nolegend__",
                )
                # Axon lines.
                plotter.plot_neurons(
                    axs,
                    neuron,
                    select="axon",
                    color=pltltr.shift_lum(cols[3], -s),
                    linewidth=0.4,
                    rasterized=raster,
                    alpha=a if s == 0 else a1,
                    soma_s=0,
                    zorder=100,
                    label="__nolegend__",
                )

        all_axs[0, 1].text(
            106,
            0,
            f"c{i_l} ({len(idxs)} cells)",
            fontsize=8,
            c=cols[i_l],
            ha="center",
            va="top",
        )
        pdf.savefig(f)
        # The page is written; close the figure so that the (tall) per-cluster
        # figures do not accumulate in memory.
        plt.close(f)

In [None]:
# Display the luminance-shift list left over from the last plotting loop run.
alpha

###### plt.figure()
# Distribution of within-cluster axon distances, one vertically offset line per
# cluster (in dendrogram order).
for i, l in enumerate(tree_sort_clusters):
    # Members of this cluster, as positions in the current distance matrix.
    # ``clusters`` and ``symm_cell_dist_axon`` both refer to the population
    # clustered last, so no remapping through ``filt_idxs`` (computed for the
    # vIPN population) is applied.
    idxs = np.flatnonzero(clusters == l)
    filt = symm_cell_dist_axon[np.ix_(idxs, idxs)]
    # Exclude the zero self-distances from the histogram.
    np.fill_diagonal(filt, np.nan)
    s, b = np.histogram(filt, np.arange(0, 100, 5), density=True)
    # Plot against bin centres, shifting each cluster up by 1/20.
    plt.plot((b[:-1] + b[1:]) / 2, s + i / 20)

In [None]:
# Display the distance matrix that was clustered last.
dist_mtx_toclust