# Define clusters for IPN neurons

This is a finicky notebook written in the rush of finishing the thesis; be careful when running it.

In [None]:
%matplotlib widget
from pathlib import Path

import flammkuchen as fl
import matplotlib as mpl
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch
from bg_atlasapi import BrainGlobeAtlas
from lotr import DATASET_LOCATION
from lotr import plotting as pltltr
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from sklearn.metrics.pairwise import euclidean_distances
from tqdm import tqdm

# Qualitative palette repeated five times (presumably so that indexing
# ``cols[i]`` by cluster number never runs past the end of the base palette).
cols = pltltr.COLS["qualitative"] * 5
# Make the same colour sequence matplotlib's default property cycle.
mpl.rcParams["axes.prop_cycle"] = mpl.cycler("color", cols)

# Handle on the "ipn_zfish_0.5um" BrainGlobe atlas.
atlas = BrainGlobeAtlas("ipn_zfish_0.5um")

In [None]:
# Anatomy files live in an "anatomy" folder next to the dataset folder.
anatomy_location = DATASET_LOCATION.parent / "anatomy"
skeletons_file = anatomy_location / "annotated_traced_neurons" / "all_skeletons.h5"
# `neurons` is consumed below through ``.values()``, i.e. it is used as a
# mapping whose values are the neuron objects.
neurons = fl.load(skeletons_file)

In [None]:
# Select neurons from their free-text annotation (`comments`). A neuron is
# kept when the annotation starts with "n", or when it starts with "p" and
# the field after the first " - " separator starts with "i".
# NOTE(review): presumably these tags identify IPN-seeded neurons; confirm
# against the annotation scheme.
all_ipn_neurons = []
for neuron in neurons.values():
    tag = neuron.comments
    if tag[0] == "n" or (tag[0] == "p" and tag.split(" - ")[1][0] == "i"):
        all_ipn_neurons.append(neuron)

# Discard neurons whose annotation contains "??" and neurons without an axon.
valid_ipn_neurons = []
for neuron in all_ipn_neurons:
    if "??" in neuron.comments:
        continue
    if neuron.has_axon:
        valid_ipn_neurons.append(neuron)

print(
    f"included {len(valid_ipn_neurons)} out of {len(all_ipn_neurons)} seeded IPN neurons"
)

# Pairwise morphological distance between neurons, using all their coordinates.
# For an ordered pair (a, b) the value is the median, over the points of `a`,
# of the distance to the closest point of `b`; it is therefore asymmetric.
n_neurons = len(valid_ipn_neurons)
cell_dist = np.zeros((n_neurons,) * 2)

# A single point-to-point distance matrix provides both directions of a pair,
# so only j <= i is visited. The previous full i/j sweep computed every matrix
# twice; this ordering keeps exactly the evaluation that used to be written
# last, so the resulting matrix is unchanged.
for i in tqdm(range(n_neurons)):
    c1 = valid_ipn_neurons[i].coords_unilat_ipn  # coords of cell 1
    for j in range(i + 1):
        c2 = valid_ipn_neurons[j].coords_unilat_ipn  # coords of cell 2
        coords_dist = euclidean_distances(c1, c2)  # rows: cell 1, cols: cell 2
        # Closest-point distance for every point of cell 1 and of cell 2:
        d1, d2 = np.min(coords_dist, 1), np.min(coords_dist, 0)
        cell_dist[i, j], cell_dist[j, i] = np.median(d1), np.median(d2)

# Symmetrise by keeping the larger of the two directed distances.
symm_cell_dist = np.max(np.stack([cell_dist, cell_dist.T]), axis=0)

In [None]:
# Cluster over dendrites:
# Same directed "median of closest-point distances" as for the whole neuron,
# restricted to the coordinates selected by each neuron's `dendr_idxs`.
n_neurons = len(valid_ipn_neurons)
cell_dist = np.zeros((n_neurons,) * 2)

# Only j <= i is visited: one distance matrix fills both [i, j] and [j, i].
# The previous full sweep computed each pair twice; this ordering keeps the
# evaluation that used to be written last, so the result is unchanged.
for i in tqdm(range(n_neurons)):
    cell_i = valid_ipn_neurons[i]
    c1 = cell_i.coords_unilat_ipn[cell_i.dendr_idxs, :]  # coords of cell 1
    for j in range(i + 1):
        cell_j = valid_ipn_neurons[j]
        c2 = cell_j.coords_unilat_ipn[cell_j.dendr_idxs, :]  # coords of cell 2
        coords_dist = euclidean_distances(c1, c2)  # rows: cell 1, cols: cell 2
        d1, d2 = np.min(coords_dist, 1), np.min(coords_dist, 0)
        cell_dist[i, j], cell_dist[j, i] = np.median(d1), np.median(d2)
# Symmetrise with the larger directed distance, ignoring NaNs.
symm_cell_dist_dendr = np.nanmax(np.stack([cell_dist, cell_dist.T]), axis=0)

In [None]:
# Cluster over axons:
# Same directed "median of closest-point distances", restricted to the
# coordinates selected by each neuron's `axon_idxs`. Pairs in which either
# neuron has no axon are set to NaN (with the current selection of
# `valid_ipn_neurons`, which already requires `has_axon`, this branch is
# only a safeguard).
n_neurons = len(valid_ipn_neurons)
cell_dist = np.zeros((n_neurons,) * 2)

# Only j <= i is visited: one distance matrix fills both [i, j] and [j, i].
# The previous full sweep computed each pair twice; this ordering keeps the
# evaluation that used to be written last, so the result is unchanged.
for i in tqdm(range(n_neurons)):
    cell_i = valid_ipn_neurons[i]
    for j in range(i + 1):
        cell_j = valid_ipn_neurons[j]
        if cell_i.has_axon and cell_j.has_axon:
            c1 = cell_i.coords_unilat_ipn[cell_i.axon_idxs, :]  # coords of cell 1
            c2 = cell_j.coords_unilat_ipn[cell_j.axon_idxs, :]  # coords of cell 2
            coords_dist = euclidean_distances(c1, c2)  # rows: cell 1, cols: cell 2
            d1, d2 = np.min(coords_dist, 1), np.min(coords_dist, 0)
            cell_dist[i, j], cell_dist[j, i] = np.median(d1), np.median(d2)
        else:
            cell_dist[i, j], cell_dist[j, i] = np.nan, np.nan
# Symmetrise with the SMALLER directed distance, ignoring NaNs.
# NOTE(review): the whole-neuron and dendrite matrices keep the larger value
# instead; confirm that the minimum is intended for axons.
symm_cell_dist_axon = np.nanmin(np.stack([cell_dist, cell_dist.T]), axis=0)

In [None]:
# Keep only cells with defined distances before clustering. This matters for
# the axon matrix, which may hold NaNs; for dendrites nothing is removed.
dist_mtx_toclust = symm_cell_dist_dendr
# Validity is decided from the first row of the matrix only.
selector = ~np.isnan(dist_mtx_toclust[0, :])
# Positions (in `valid_ipn_neurons`) of the cells that pass the filter.
filt_idxs = np.arange(n_neurons)[selector]
# Apply the same mask to rows and columns.
dist_mtx_toclust = dist_mtx_toclust[selector, :][:, selector]

In [None]:
# Ward linkage computed from the (filtered) cell-to-cell distance matrix.
# NOTE(review): a square 2-D input is presumably read by scipy as one
# observation vector per row, not as precomputed distances (those would have
# to be passed in condensed form). The cut heights used further down were
# picked for the current behaviour, so confirm before changing this call.
linkage = sch.linkage(y=dist_mtx_toclust, method="ward", metric="euclidean")

In [None]:
# Height at which the dendrogram is coloured and later cut into clusters.
thr = 120

f = plt.figure(figsize=(2, 3))
# Placeholder image, drawn only so the matrix axes get their final size before
# the dendrogram axes are stacked on top of them:
ax = f.add_axes((0.2, 0.1, 0.6, 0.5))
ax.imshow(dist_mtx_toclust, vmin=0, vmax=50, cmap="pink")
pltltr.despine(ax, "all")

# Dendrogram axes start where the matrix axes end.
ax1 = f.add_axes((0.2, ax.get_position().y1, 0.6, 0.4))
with plt.rc_context({"lines.linewidth": 0.5}):
    dendrogram = sch.dendrogram(
        linkage, color_threshold=thr, above_threshold_color=(0.5,) * 3
    )
s_idxs = dendrogram["leaves"]
plt.xticks([])
plt.yticks([])
plt.axhline(thr, c=(0.3,) * 3, lw=0.5)

# The leaf indices refer to rows of the matrix that was clustered, so the
# sorted image is taken from `dist_mtx_toclust`; indexing the unfiltered
# matrix would only be correct while the NaN filter removes nothing.
im = ax.imshow(dist_mtx_toclust[s_idxs, :][:, s_idxs], vmin=0, vmax=50, cmap="pink")
print(ax.get_position())

pltltr.add_cbar(
    im, ax, (1.15, 0.05, 0.05, 0.2), orientation="vertical", title="Dist.", titlesize=8
)
ax.set(
    xlabel="Sorted cell n.", ylabel="Sorted cell n.",
)
pltltr.despine(ax, "all")
pltltr.despine(ax1, "all")

# Save before showing: on blocking backends the figure is released when
# `plt.show()` returns, and a later save would write an empty figure.
pltltr.savefig("clustering_dendr")
plt.show()

In [None]:
# Flat cluster label of every clustered cell, from cutting the tree at `thr`.
# CAREFUL: `thr` is rebound further down for the axon sub-clustering, so this
# cell yields a different partition if it is re-run after that point.
cut = sch.cut_tree(linkage, height=thr)
clusters = np.array(cut[:, 0])

In [None]:
# Work out how the clusters map onto the dendrogram drawn above: list the
# cluster labels in order of first appearance along the dendrogram leaves
# (a dict keeps insertion order and drops repeats).
leaf_order = np.array(dendrogram["leaves"])
tree_sort_clusters = list(dict.fromkeys(clusters[leaf_order]))

# Relabel every cell with the rank of its cluster in that left-to-right order.
reindexed_clusters = np.zeros(len(clusters), dtype=int)
for new_label, old_label in enumerate(tree_sort_clusters):
    reindexed_clusters[clusters == old_label] = new_label

In [None]:
# Axis limits along each anatomical dimension:
bs = dict(frontal=(30, 180), vertical=(20, 180), sagittal=(-20, 120))

# Every projection view is bounded by a pair of the limits above.
projection_bounds = {
    "frontal": [bs["vertical"], bs["frontal"]],
    "horizontal": [bs["sagittal"], bs["frontal"]],
    "sagittal": [bs["vertical"], bs["sagittal"]],
}

plotter = pltltr.AtlasPlotter(
    structures=["ipn", "dipn"],
    mask_slices={"frontal": slice(0, 120)},
    bounds_dict=projection_bounds,
)

In [None]:
raster = True

# One three-projection figure per cluster, numbered in dendrogram order.
for l in range(reindexed_clusters.max() + 1):
    f, axs = plotter.generate_projection_plots(figsize=(6.5, 2), labels=True)
    idxs = np.argwhere(reindexed_clusters == l)[:, 0]
    # Luminance shift per neuron: zero for all but the LAST neuron of the
    # cluster, which is also the only one drawn with the higher opacity `a1`.
    # (A former `alpha = 0.4` rebinding inside the loop below had no effect on
    # the iteration and has been removed.)
    alpha = [0] * len(idxs)
    alpha[-1] = 0.3
    a = 0.5
    a1 = 0.95
    for i, s in zip(idxs, alpha):
        shared = dict(
            rasterized=raster, alpha=a if s == 0 else a1, label="__nolegend__"
        )
        # Dendrite lines, soma marker size set to zero.
        plotter.plot_neurons(
            axs,
            valid_ipn_neurons[i],
            select="dendrites",
            color=pltltr.shift_lum(cols[0], 0.15 - s),
            linewidth=0.5,
            soma_s=0,
            **shared,
        )
        # Same selection with zero line width and the default soma size
        # (presumably leaving only the soma marker visible), drawn on top.
        plotter.plot_neurons(
            axs,
            valid_ipn_neurons[i],
            select="dendrites",
            color=pltltr.shift_lum(cols[1], 0.15 - s),
            linewidth=0,
            zorder=100,
            **shared,
        )
        # Axon lines, soma marker size set to zero.
        plotter.plot_neurons(
            axs,
            valid_ipn_neurons[i],
            select="axon",
            color=pltltr.shift_lum(cols[3], -s),
            linewidth=0.4,
            soma_s=0,
            **shared,
        )

    # Cluster name and size, in the cluster's own colour.
    axs[1].text(
        106,
        0,
        f"c{l} ({len(idxs)} cells)",
        fontsize=8,
        c=cols[l],
        ha="center",
        va="top",
    )

In [None]:
# Quick check: number of cells in the cluster at dendrogram position 5.
in_cluster = clusters == tree_sort_clusters[5]
idxs = filt_idxs[np.argwhere(in_cluster)[:, 0]]
len(idxs)

In [None]:
# Refine one dendrite-based cluster using the axon distances of its members.
clust_to_split = 5  # position of the cluster in dendrogram order
l = tree_sort_clusters[clust_to_split]  # its label in `clusters`
members = np.argwhere(clusters == l)[:, 0]
idxs_cur_clust = filt_idxs[members]
# Axon distance sub-matrix restricted to the members (rows, then columns).
clust_axons_dmat = symm_cell_dist_axon[idxs_cur_clust, :]
clust_axons_dmat = clust_axons_dmat[:, idxs_cur_clust]
# NOTE(review): as for the dendrite clustering, the square matrix is passed
# to `linkage` directly; confirm that this is the intended input form.
linkage_axons_clust = sch.linkage(y=clust_axons_dmat, method="ward")

In [None]:
# Cut heights for the axon-based sub-clustering, keyed by the cluster split.
thr_dict = {"c5": 55, "c6": 90}

# CAREFUL: this rebinds the global `thr` that was used to cut the dendrite
# tree, and the `dendrogram` variable is overwritten below as well.
thr = thr_dict[f"c{clust_to_split}"]


f = plt.figure(figsize=(2, 3))
# Placeholder image so that the matrix axes reach their final size first:
ax = f.add_axes((0.2, 0.1, 0.6, 0.5))
ax.imshow(clust_axons_dmat, vmin=0, vmax=50, cmap="pink")
pltltr.despine(ax, "all")

# Dendrogram axes, stacked right above the matrix axes.
matrix_top = ax.get_position().y1
ax1 = f.add_axes((0.2, matrix_top, 0.6, 0.4))
with plt.rc_context({"lines.linewidth": 0.5}):
    dendrogram = sch.dendrogram(
        linkage_axons_clust, color_threshold=thr, above_threshold_color=(0.5,) * 3
    )
s_idxs = dendrogram["leaves"]
# `ax1` is the current axes at this point, so it is addressed explicitly.
ax1.set_xticks([])
ax1.set_yticks([])
ax1.axhline(thr, c=(0.3,) * 3, lw=0.5)

# Distance matrix with rows and columns reordered like the dendrogram leaves.
sorted_dmat = clust_axons_dmat[s_idxs, :][:, s_idxs]
im = ax.imshow(sorted_dmat, vmin=0, vmax=50, cmap="pink")
print(ax.get_position())

pltltr.add_cbar(
    im, ax, (1.15, 0.05, 0.05, 0.2), orientation="vertical", title="Dist.", titlesize=8
)
ax.set_xlabel("Sorted cell n.")
ax.set_ylabel("Sorted cell n.")
for axes in (ax, ax1):
    pltltr.despine(axes, "all")

plt.show()

In [None]:
# Axon-based subcluster label of each cell of the cluster being split, in
# the same order as `idxs_cur_clust`.
sub_cut = sch.cut_tree(linkage_axons_clust, height=thr)
subclusters = np.array(sub_cut[:, 0])

In [None]:
raster = True

# One three-projection figure per axon-based subcluster.
for l in range(subclusters.max() + 1):
    f, axs = plotter.generate_projection_plots(figsize=(6.5, 2), labels=True)
    idxs = idxs_cur_clust[subclusters == l]
    # Luminance shift per neuron: zero for all but the LAST neuron of the
    # subcluster, which is also the only one drawn with the higher opacity.
    # (A former `alpha = 0.4` rebinding inside the loop below had no effect on
    # the iteration and has been removed.)
    alpha = [0] * len(idxs)
    alpha[-1] = 0.3
    a = 0.5
    a1 = 0.95
    for i, s in zip(idxs, alpha):
        shared = dict(
            rasterized=raster, alpha=a if s == 0 else a1, label="__nolegend__"
        )
        # Dendrite lines, soma marker size set to zero, drawn at the back.
        plotter.plot_neurons(
            axs,
            valid_ipn_neurons[i],
            select="dendrites",
            color=pltltr.shift_lum(cols[0], 0.15 - s),
            linewidth=0.5,
            soma_s=0,
            zorder=-100,
            **shared,
        )
        # Same selection with zero line width and the default soma size
        # (presumably leaving only the soma marker visible), drawn on top.
        plotter.plot_neurons(
            axs,
            valid_ipn_neurons[i],
            select="dendrites",
            color=pltltr.shift_lum(cols[1], 0.15 - s),
            linewidth=0,
            zorder=1000,
            **shared,
        )
        try:
            # Axon lines, soma marker size set to zero, drawn at the back.
            plotter.plot_neurons(
                axs,
                valid_ipn_neurons[i],
                select="axon",
                color=pltltr.shift_lum(cols[3], -s),
                linewidth=0.4,
                soma_s=0,
                zorder=-100,
                **shared,
            )
        except ValueError:
            # Deliberate best effort: a neuron whose axon cannot be drawn is
            # simply shown without it.
            # NOTE(review): presumably raised for empty axon selections;
            # confirm, since other ValueErrors are silenced here too.
            pass

## Decision on subclusters
The last subcluster is clearly a well-defined category, and it seems that the second-to-last subcluster is made of neurons with a similar projection but incomplete axons.

In [None]:
# Cells of the split cluster that belong to neither axon subcluster 3 nor 5
# are moved to a brand-new cluster label (current maximum + 1).
# CAREFUL: `reindexed_clusters` is edited in place and the membership on the
# second line is read from that same array, so running this cell twice does
# not give the same result as running it once.
split_group = ~((subclusters == 5) | (subclusters == 3))
idxs_clust_to_split = np.argwhere(reindexed_clusters == clust_to_split)[:, 0]
new_label = reindexed_clusters.max() + 1
reindexed_clusters[idxs_clust_to_split[split_group]] = new_label

In [None]:
raster = True

# One three-projection figure per cluster, now including the cluster that was
# created by the axon-based split.
for l in range(reindexed_clusters.max() + 1):
    f, axs = plotter.generate_projection_plots(figsize=(6.5, 2), labels=True)
    idxs = np.argwhere(reindexed_clusters == l)[:, 0]
    # Luminance shift per neuron: zero for all but the LAST neuron of the
    # cluster, which is also the only one drawn with the higher opacity `a1`.
    # (A former `alpha = 0.4` rebinding inside the loop below had no effect on
    # the iteration and has been removed.)
    alpha = [0] * len(idxs)
    alpha[-1] = 0.3
    a = 0.5
    a1 = 0.95
    for i, s in zip(idxs, alpha):
        shared = dict(
            rasterized=raster, alpha=a if s == 0 else a1, label="__nolegend__"
        )
        # Dendrite lines, soma marker size set to zero.
        plotter.plot_neurons(
            axs,
            valid_ipn_neurons[i],
            select="dendrites",
            color=pltltr.shift_lum(cols[0], 0.15 - s),
            linewidth=0.5,
            soma_s=0,
            **shared,
        )
        # Same selection with zero line width and the default soma size
        # (presumably leaving only the soma marker visible), drawn on top.
        plotter.plot_neurons(
            axs,
            valid_ipn_neurons[i],
            select="dendrites",
            color=pltltr.shift_lum(cols[1], 0.15 - s),
            linewidth=0,
            zorder=100,
            **shared,
        )
        # Axon lines, soma marker size set to zero.
        plotter.plot_neurons(
            axs,
            valid_ipn_neurons[i],
            select="axon",
            color=pltltr.shift_lum(cols[3], -s),
            linewidth=0.4,
            soma_s=0,
            **shared,
        )

    # Cluster name and size, in the cluster's own colour.
    axs[1].text(
        106,
        0,
        f"c{l} ({len(idxs)} cells)",
        fontsize=8,
        c=cols[l],
        ha="center",
        va="top",
    )

# Final cluster plot - all clusters and individual neurons

In [None]:
# Coordinates of every clustered neuron stacked into one array; it is drawn
# below with zero marker size when the final cluster panels are built.
per_neuron_coords = [neuron.coords_ipn for neuron in valid_ipn_neurons]
all_coords = np.concatenate(per_neuron_coords, axis=0)

In [None]:
# For each cluster, the position (within that cluster's list of cells) of the
# example neuron that is drawn last and highlighted. Negative values count
# from the end of the list.
nice_n = [0, 1, 1, 0, -3, -1, -1, -1, 0, -1, 0]

raster = True
n_clust = reindexed_clusters.max() + 1

# Exclude the first cluster with only two neurons:
for i_fig, c_range in enumerate([range(1, 6), range(6, 11)]):
    # Five clusters per figure, one row of three projections for each.
    f, all_axs = plt.subplots(
        5,
        3,
        figsize=(6.2, 8),
        gridspec_kw=dict(top=1, bottom=0.1, left=0, right=1, wspace=0, hspace=0),
    )
    for i_l, l in enumerate(c_range):
        axs = all_axs[i_l, :]
        plotter.generate_projection_plots(axs, labels=True)
        idxs = np.argwhere(reindexed_clusters == l)[:, 0]
        # Plot nice neuron as last one. `np.delete` handles negative positions;
        # the former slice-based version (`idxs[p + 1:]`) turned into `idxs[0:]`
        # for p == -1, which duplicated every cell, drew each neuron twice and
        # inflated the cell count printed in the label.
        p = nice_n[l]
        idxs = np.concatenate([np.delete(idxs, p), idxs[[p]]])
        # Luminance shift: non-zero only for the last (example) neuron, which
        # is also the only one drawn with the higher opacity `a1`.
        alpha = [0] * len(idxs)
        alpha[-1] = 0.3
        a = 0.5
        a1 = 0.95
        # Zero-size scatter of all coordinates on the three projections. One
        # call covers the whole row; it used to be repeated once per axes
        # with identical arguments.
        plotter.axs_scatterplot(axs, all_coords, s=0, rasterized=raster)
        for i, s in zip(idxs, alpha):
            shared = dict(
                rasterized=raster, alpha=a if s == 0 else a1, label="__nolegend__"
            )
            # Dendrite lines, soma marker size set to zero, drawn at the back.
            plotter.plot_neurons(
                axs,
                valid_ipn_neurons[i],
                select="dendrites",
                color=pltltr.shift_lum(cols[0], 0.15 - s),
                linewidth=0.5,
                soma_s=0,
                zorder=-10,
                **shared,
            )
            # Same selection with zero line width and the default soma size
            # (presumably leaving only the soma marker visible), drawn on top.
            plotter.plot_neurons(
                axs,
                valid_ipn_neurons[i],
                select="dendrites",
                color=pltltr.shift_lum(cols[1], 0.15 - s),
                linewidth=0,
                zorder=1000,
                **shared,
            )
            # Axon lines, soma marker size set to zero, drawn at the back.
            plotter.plot_neurons(
                axs,
                valid_ipn_neurons[i],
                select="axon",
                color=pltltr.shift_lum(cols[3], -s),
                linewidth=0.4,
                soma_s=0,
                zorder=-10,
                **shared,
            )

        # Cluster name and size; the colour follows the row within the figure.
        axs[1].text(
            106,
            0,
            f"c{l} ({len(idxs)} cells)",
            fontsize=8,
            c=cols[i_l],
            ha="center",
            va="top",
        )

    pltltr.savefig(f"clusters_{i_fig}.pdf")

# Export cluster identity

In [None]:
# Final cluster label of every clustered neuron, keyed by the neuron's `id`.
neuron_ids = (neuron.id for neuron in valid_ipn_neurons)
clust_ids = dict(zip(neuron_ids, reindexed_clusters))

In [None]:
# Write the id -> cluster mapping into the "anatomy" folder next to the dataset.
clusters_file = DATASET_LOCATION.parent / "anatomy" / "ipn_morph_clust.h5"
fl.save(clusters_file, clust_ids)

In [None]:
from lotr import FIGURES_LOCATION

In [None]:
raster = True
# One PDF page per dendrite-based cluster: the first row overlays all of its
# neurons and every following row shows a single neuron on its own.
# NOTE(review): membership is read from `clusters` / `tree_sort_clusters`, not
# from `reindexed_clusters`, so the axon-based split made earlier is not
# reflected in these pages - confirm that this is intended.
with PdfPages(FIGURES_LOCATION / "ipn_neurons_clusters.pdf") as pdf:
    for i_l, l in enumerate(tree_sort_clusters):
        idxs = filt_idxs[np.argwhere(clusters == l)[:, 0]]
        n_rows = len(idxs) + 1

        f, all_axs = plt.subplots(n_rows, 3, figsize=(6.5, 2.3 * n_rows))
        for row_axs in all_axs:
            plotter.generate_projection_plots(row_axs, labels=True)
        # Luminance shift per neuron in the overview row: non-zero only for
        # the last neuron of the cluster.
        alpha = [0] * len(idxs)
        alpha[-1] = 0.3
        a1 = 0.95
        for n, (i, s_overview) in enumerate(zip(idxs, alpha)):
            # (axes row, luminance shift, opacity): first the shared overview
            # row, then the neuron's own row, which always uses the
            # highlighted style.
            panels = (
                (all_axs[0, :], s_overview, 0.5 if s_overview == 0 else a1),
                (all_axs[n + 1, :], alpha[-1], a1),
            )
            for axs, s, opacity in panels:
                shared = dict(rasterized=raster, alpha=opacity, label="__nolegend__")
                # Dendrite lines, soma marker size set to zero.
                plotter.plot_neurons(
                    axs,
                    valid_ipn_neurons[i],
                    select="dendrites",
                    color=pltltr.shift_lum(cols[0], 0.15 - s),
                    linewidth=0.5,
                    soma_s=0,
                    **shared,
                )
                # Same selection with zero line width and default soma size.
                plotter.plot_neurons(
                    axs,
                    valid_ipn_neurons[i],
                    select="dendrites",
                    color=pltltr.shift_lum(cols[1], 0.15 - s),
                    linewidth=0,
                    **shared,
                )
                # Axon lines, soma marker size set to zero, drawn on top.
                plotter.plot_neurons(
                    axs,
                    valid_ipn_neurons[i],
                    select="axon",
                    color=pltltr.shift_lum(cols[3], -s),
                    linewidth=0.4,
                    soma_s=0,
                    zorder=100,
                    **shared,
                )

        # Cluster name and size on the middle panel of the overview row.
        all_axs[0, 1].text(
            106,
            0,
            f"c{i_l} ({len(idxs)} cells)",
            fontsize=8,
            c=cols[i_l],
            ha="center",
            va="top",
        )
        pdf.savefig(f)

In [None]:
# Bare expression: shows the list of clustered neurons as the cell output.
valid_ipn_neurons