In [None]:
%matplotlib widget
from pathlib import Path

import flammkuchen as fl
import matplotlib as mpl
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch
from bg_atlasapi import BrainGlobeAtlas
from lotr import DATASET_LOCATION
from lotr import plotting as pltltr
from matplotlib import pyplot as plt
from sklearn.metrics.pairwise import euclidean_distances
from tqdm import tqdm

# Colour list used throughout the notebook: the "qualitative" entry of the
# lotr plotting palette, repeated 5 times (presumably so that indexing by
# cluster/neuron number does not run past the end - TODO confirm).
cols = pltltr.COLS["qualitative"] * 5
# Make this list matplotlib's default colour cycle for all later axes.
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=cols)

# BrainGlobe atlas object for the "ipn_zfish_0.5um" atlas.
atlas = BrainGlobeAtlas("ipn_zfish_0.5um")

In [None]:
# The anatomy data is expected in an "anatomy" folder next to the dataset folder.
anatomy_location = DATASET_LOCATION.parent / "anatomy"
# Load all traced skeletons; the result is used below as a mapping whose
# values are neuron objects (see `neurons.values()`).
neurons = fl.load(anatomy_location / "annotated_traced_neurons" / "all_skeletons.h5")

In [None]:
def _is_valid_ipn_neuron(neuron):
    """Tell whether a neuron passes the annotation filter used in this notebook.

    A neuron is kept when its ``comments`` annotation
      * starts with "p" and its second " - "-separated field starts with "i", or
      * starts with "n",
    and, in either case, does not contain "??".

    The meaning of these annotation codes is not documented in this notebook.
    """
    comments = neuron.comments
    first_char = comments[0]
    if first_char == "p":
        code_matches = comments.split(" - ")[1][0] == "i"
    else:
        code_matches = first_char == "n"
    return code_matches and "??" not in comments


# Neurons that pass the annotation filter, in the iteration order of `neurons`.
valid_ipn_neurons = list(filter(_is_valid_ipn_neuron, neurons.values()))

In [None]:
# Distance between every pair of neurons, using all of their points.
# cell_dist[a, b] is the median, over the points of neuron a, of the distance
# to the closest point of neuron b. It is directional (not symmetric).
n_neurons = len(valid_ipn_neurons)
cell_dist = np.zeros((n_neurons,) * 2)

for i in tqdm(range(n_neurons)):
    c1 = valid_ipn_neurons[i].coords_unilat_ipn  # coords of cell 1
    # One point-to-point distance matrix yields both directions of a pair, so
    # each unordered pair (j <= i) only needs to be computed once. This halves
    # the number of euclidean_distances calls while producing the same values.
    for j in range(i + 1):
        c2 = valid_ipn_neurons[j].coords_unilat_ipn  # coords of cell 2
        # Pairwise Euclidean distance, shape (n points cell 1, n points cell 2)
        coords_dist = euclidean_distances(c1, c2)
        # Nearest-neighbour distance for each point of cell 1 (d1) and cell 2 (d2)
        d1, d2 = np.min(coords_dist, 1), np.min(coords_dist, 0)
        cell_dist[i, j], cell_dist[j, i] = np.median(d1), np.median(d2)

# Symmetrize by keeping the larger of the two directional distances.
symm_cell_dist = np.max(np.stack([cell_dist, cell_dist.T]), axis=0)

In [None]:
# Distance between every pair of neurons, restricted to their dendrite points.
# cell_dist[a, b] is the median, over the dendrite points of neuron a, of the
# distance to the closest dendrite point of neuron b (directional).
n_neurons = len(valid_ipn_neurons)
cell_dist = np.zeros((n_neurons,) * 2)

for i in tqdm(range(n_neurons)):
    neuron_i = valid_ipn_neurons[i]
    c1 = neuron_i.coords_unilat_ipn[neuron_i.dendr_idxs, :]  # coords of cell 1
    # One point-to-point distance matrix yields both directions of a pair, so
    # each unordered pair (j <= i) only needs to be computed once. This halves
    # the number of euclidean_distances calls while producing the same values.
    for j in range(i + 1):
        neuron_j = valid_ipn_neurons[j]
        c2 = neuron_j.coords_unilat_ipn[neuron_j.dendr_idxs, :]  # coords of cell 2
        # Pairwise Euclidean distance, shape (n points cell 1, n points cell 2)
        coords_dist = euclidean_distances(c1, c2)
        # Nearest-neighbour distance for each point of cell 1 (d1) and cell 2 (d2)
        d1, d2 = np.min(coords_dist, 1), np.min(coords_dist, 0)
        cell_dist[i, j], cell_dist[j, i] = np.median(d1), np.median(d2)

# Symmetrize by keeping the larger of the two directional distances.
symm_cell_dist_dendr = np.nanmax(np.stack([cell_dist, cell_dist.T]), axis=0)

In [None]:
# Distance between every pair of neurons, restricted to their axon points.
# cell_dist[a, b] is the median, over the axon points of neuron a, of the
# distance to the closest axon point of neuron b (directional). Pairs in which
# either neuron has no axon are set to NaN.
n_neurons = len(valid_ipn_neurons)
cell_dist = np.zeros((n_neurons,) * 2)

for i in tqdm(range(n_neurons)):
    neuron_i = valid_ipn_neurons[i]
    # Only slice out the axon when there is one; axon_idxs is never touched
    # for neurons without an axon.
    c1 = (
        neuron_i.coords_unilat_ipn[neuron_i.axon_idxs, :]
        if neuron_i.has_axon
        else None
    )  # coords of cell 1
    # One point-to-point distance matrix yields both directions of a pair, so
    # each unordered pair (j <= i) only needs to be computed once. This halves
    # the number of euclidean_distances calls while producing the same values.
    for j in range(i + 1):
        neuron_j = valid_ipn_neurons[j]
        if c1 is not None and neuron_j.has_axon:
            c2 = neuron_j.coords_unilat_ipn[neuron_j.axon_idxs, :]  # coords of cell 2
            # Pairwise Euclidean distance, shape (n points cell 1, n points cell 2)
            coords_dist = euclidean_distances(c1, c2)
            # Nearest-neighbour distance for each point of cell 1 (d1) and cell 2 (d2)
            d1, d2 = np.min(coords_dist, 1), np.min(coords_dist, 0)
            cell_dist[i, j], cell_dist[j, i] = np.median(d1), np.median(d2)
        else:
            cell_dist[i, j], cell_dist[j, i] = np.nan, np.nan

# Symmetrize by keeping the larger of the two directional distances; entries
# that are NaN in both directions stay NaN.
symm_cell_dist_axon = np.nanmax(np.stack([cell_dist, cell_dist.T]), axis=0)

In [None]:
# For axons, we can't cluster all cells. For dendrites this step is not really
# needed, as that matrix contains no NaN:
dist_mtx_toclust = symm_cell_dist_dendr
filt_idxs = np.arange(n_neurons)
# A cell that could not be compared (e.g. no axon) has NaN in its whole row and
# column, diagonal included, so the diagonal identifies exactly those cells.
# Testing row 0 instead would drop every cell whenever cell 0 itself is one of
# the excluded ones, because row 0 is then entirely NaN.
selector = ~np.isnan(np.diag(dist_mtx_toclust))
# filt_idxs maps rows of the filtered matrix back to indices into valid_ipn_neurons.
filt_idxs = filt_idxs[selector]
dist_mtx_toclust = dist_mtx_toclust[:, selector][selector, :]

In [None]:
# Hierarchical clustering with Ward's method.
# NOTE(review): dist_mtx_toclust is a square 2-D matrix. Per the SciPy docs,
# sch.linkage interprets 2-D input as an observation matrix (each row is a
# feature vector) rather than as precomputed distances, which would have to be
# passed in condensed form (e.g. via scipy.spatial.distance.squareform).
# Confirm this is intended; the cut height used downstream depends on it.
linkage = sch.linkage(dist_mtx_toclust, method="ward")

In [None]:
# Figure: dendrogram on top, distance matrix sorted by dendrogram leaves below.
# thr is the dendrogram colour threshold, drawn as a horizontal line here and
# reused as the tree cut height when assigning clusters.
thr = 120

f = plt.figure(figsize=(2, 3))
# Temp useless plot, just to get axes of right size:
ax = f.add_axes((0.2, 0.1, 0.6, 0.5))
ax.imshow(dist_mtx_toclust, vmin=0, vmax=50, cmap="pink")
pltltr.despine(ax, "all")

# Dendrogram axes, placed directly on top of the matrix axes.
ax1 = f.add_axes((0.2, ax.get_position().y1, 0.6, 0.4))
with plt.rc_context({"lines.linewidth": 0.5}):
    # No `ax` is passed, so the dendrogram is drawn on the current pyplot axes
    # (expected to be ax1, the axes just added).
    dendrogram = sch.dendrogram(
        linkage, color_threshold=thr, above_threshold_color=(0.5,) * 3
    )
# Leaf order of the dendrogram, used to sort rows/columns of the matrix.
s_idxs = dendrogram["leaves"]
# The pyplot calls below also act on the current axes (the dendrogram).
plt.xticks([])
plt.yticks([])
plt.axhline(thr, c=(0.3,) * 3, lw=0.5)

# Draw the sorted matrix; this adds a second image on top of the placeholder one.
im = ax.imshow(dist_mtx_toclust[s_idxs, :][:, s_idxs], vmin=0, vmax=50, cmap="pink")
# NOTE(review): looks like leftover debug output - confirm it can be removed.
print(ax.get_position())

pltltr.add_cbar(
    im, ax, (1.15, 0.05, 0.05, 0.2), orientation="vertical", title="Dist.", titlesize=8
)
ax.set(
    xlabel="Sorted cell n.", ylabel="Sorted cell n.",
)
pltltr.despine(ax, "all")
pltltr.despine(ax1, "all")

plt.show()

In [None]:
# Save a figure under the name "clust_dend" via the lotr plotting helper
# (presumably the current figure, i.e. the dendrogram above - TODO confirm).
pltltr.savefig("clust_dend")

In [None]:
# Cut the tree at the same height used as colour threshold in the dendrogram.
# cut_tree returns a 2-D array; [:, 0] keeps its first column, giving one
# cluster label per row of dist_mtx_toclust.
clusters = sch.cut_tree(linkage, height=thr)[:, 0]

In [None]:
# Cluster labels in the order in which they first appear along the dendrogram
# leaves, so that clusters can be visited in the same order as in the figure.
# dict.fromkeys removes duplicates while preserving first-seen order.
tree_sort_clusters = list(dict.fromkeys(clusters[np.array(dendrogram["leaves"])]))
tree_sort_clusters

In [None]:
# Axis limits along each of the three dimensions:
bs = {"frontal": (30, 180), "vertical": (20, 180), "sagittal": (-20, 120)}

# For every projection, the pair of limits (taken from `bs`) passed as bounds.
projection_bounds = {
    "frontal": [bs["vertical"], bs["frontal"]],
    "horizontal": [bs["sagittal"], bs["frontal"]],
    "sagittal": [bs["vertical"], bs["sagittal"]],
}

# Atlas plotter for the "ipn" and "dipn" structures, shared by the plots below.
plotter = pltltr.AtlasPlotter(
    structures=["ipn", "dipn"],
    mask_slices={"frontal": slice(0, 120)},
    bounds_dict=projection_bounds,
)

In [None]:
# One projection figure per cluster, visited in dendrogram order.
raster = True

# Keyword arguments shared by every plot_neurons call in this cell.
common_kw = dict(rasterized=raster, label="__nolegend__")

for cluster_id in tree_sort_clusters:
    f, axs = plotter.generate_projection_plots(figsize=(6.5, 2), labels=True)
    # Rows of the clustered (filtered) matrix -> indices into valid_ipn_neurons.
    idxs = filt_idxs[np.argwhere(clusters == cluster_id)[:, 0]]
    # All neurons are drawn faint, except the last one of the cluster which is
    # drawn fully opaque.
    alphas = [0.3] * len(idxs)
    alphas[-1] = 1
    for i, a in zip(idxs, alphas):
        neuron = valid_ipn_neurons[i]
        # Dendrites as thin lines, with the soma marker size set to 0.
        plotter.plot_neurons(
            axs,
            neuron,
            select="dendrites",
            color=cols[0],
            linewidth=0.5,
            alpha=a,
            soma_s=0,
            **common_kw,
        )
        # Same selection with zero line width and the default soma size
        # (presumably so that only the soma marker is visible - TODO confirm).
        plotter.plot_neurons(
            axs,
            neuron,
            select="dendrites",
            color=cols[1],
            linewidth=0,
            alpha=a,
            **common_kw,
        )
        try:
            # Axon as thin lines, with the soma marker size set to 0.
            plotter.plot_neurons(
                axs,
                neuron,
                select="axon",
                color=cols[3],
                linewidth=0.4,
                alpha=a,
                soma_s=0,
                **common_kw,
            )
        except ValueError:
            # Deliberate best effort: the axon is skipped when plotting it
            # fails (presumably for neurons without an axon - TODO confirm).
            pass

In [None]:
# NOTE(review): `neurons_dict` is not defined in any cell visible above (the
# loaded skeletons are stored in `neurons`); unless it is defined elsewhere,
# this cell fails on a fresh kernel. Confirm, then rename or delete.
neurons_dict

In [None]:
# Scatter plots of the neurons belonging to a cluster.
# NOTE(review): `plot_bg_mod`, `bg_dict` and `neurons_list` are not defined in
# any cell visible above; this looks like a leftover from an earlier version of
# the notebook. Confirm whether it should be updated or removed.
s = 3  # base marker size for the scatter plots


# The range below contains a single value, clusters.max(), so only that
# cluster is processed.
for l in range(clusters.max(), clusters.max() + 1):  # clusters.max()):
    # NOTE(review): these are positions in the clustered matrix and are used to
    # index `neurons_list` directly, without going through `filt_idxs` - confirm.
    idxs = np.argwhere(clusters == l)[:, 0]
    # Only clusters with more than two members are plotted.
    if idxs.shape[0] > 2:
        f, axs = plot_bg_mod(bg_dict)

        # for t in range(3):
        for j, idx in enumerate(idxs):
            n = neurons_list[idx]
            # One column of axes per dropped coordinate axis.
            for i in range(3):
                uni_coords = n.coords_unilat  # [n.dendr_idxs, :]
                # Drop coordinate column i, keeping the two remaining columns.
                pts = np.delete(uni_coords, i, axis=1)[:, :]
                # Row 0: dendrite points in the cluster colour, nearly transparent.
                axs[0, i].scatter(
                    pts[n.dendr_idxs, 1],
                    pts[n.dendr_idxs, 0],
                    s=s,
                    color=cols[l],
                    alpha=0.02,
                )

                coords = n.coords_mpin  # [n.dendr_idxs, :]
                pts = np.delete(coords, i, axis=1)[:, :]
                # Row 1: dendrite points, row 2: axon points; one colour per neuron.
                for m, sel in enumerate([n.dendr_idxs, n.axon_idxs]):
                    axs[1 + m, i].scatter(
                        pts[sel, 1], pts[sel, 0], s=s / 20, color=cols[j % len(cols)]
                    )
                    # Larger marker at the soma position.
                    axs[1 + m, i].scatter(
                        pts[n.soma_idx : n.soma_idx + 1, 1],
                        pts[n.soma_idx : n.soma_idx + 1, 0],
                        s=s * 10,
                        color=cols[j % len(cols)],
                    )

        axs[0, 2].set_title(f"(Clust. {l} {idxs})")

    # f.savefig(fig_folder / f"clust_{l}.png")