In [None]:
%matplotlib widget
%gui qt5
from pathlib import Path

import flammkuchen as fl
import matplotlib as mpl
import napari
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch
from bg_atlasapi import BrainGlobeAtlas
from lotr import DATASET_LOCATION
from lotr import plotting as pltltr
from matplotlib import pyplot as plt
from sklearn.metrics.pairwise import euclidean_distances
from tqdm import tqdm

# Colour cycle: the qualitative lotr palette, repeated 5x so that plots with
# many lines do not run out of colours.
cols = 5 * pltltr.COLS["qualitative"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=cols)

# BrainGlobe reference atlases: IPN reference (0.5 um) and mpin reference (1 um).
atlas, atlas_mpin = (
    BrainGlobeAtlas(atlas_name) for atlas_name in ("ipn_zfish_0.5um", "mpin_zfish_1um")
)

In [None]:
anatomy_location = DATASET_LOCATION / "anatomy"
all_neurons = fl.load(
    anatomy_location / "annotated_traced_neurons" / "all_skeletons.h5"
)


def _is_ipn_neuron(comment):
    """Return True if a neuron's annotation comment marks it as an IPN neuron.

    Kept: comments starting with "n", and comments starting with "p" whose
    second " - "-separated field starts with "i". Comments containing "??"
    (uncertain annotations) are always excluded.
    """
    first_char = comment[0]
    if first_char == "n":
        keep = True
    elif first_char == "p":
        keep = comment.split(" - ")[1][0] == "i"
    else:
        keep = False
    return keep and "??" not in comment


# include only ipn neurons:
neurons_list = [
    neuron for neuron in all_neurons.values() if _is_ipn_neuron(neuron.comments)
]

In [None]:
labels = fl.load(anatomy_location / "em_soma_segmentation" / "soma_labels_uint16.h5")

# Raw labels 0 and 1 both mean background: map them to 0 and shift every cell
# label down by one so that cells are numbered from 1 (dtype is preserved).
labels = np.where(labels > 1, labels - 1, 0).astype(labels.dtype)

In [None]:
data = fl.load(anatomy_location / "em_soma_segmentation" / "coords.h5")
# Per-segmented-soma coordinates (compared below against EM coordinates of the
# traced somas) and areas.
seg_somas_coords = data["coords"]
seg_somas_areas = data["areas"]

In [None]:
# Offsets of the columns (previously used values: x=20000, z=6000, y=20000).
# NOTE(review): none of these are used in the cells shown here — the stack
# offset used below is hard-coded in `stack_offset`.
start_x, start_z, start_y = 17000, 17000, 5800

dwn, mag = 1, 4  # downsampling factor, magnification

In [None]:
# Mean projection of the label stack along its first axis.
fig, ax = plt.subplots()
ax.imshow(labels.mean(axis=0))

In [None]:
# Stack the EM coordinates of the dendrite nodes of every neuron whose comment
# starts with "n".
# NOTE: `coords` is overwritten by the next cell (soma coordinates).
coords = np.concatenate(
    [
        neuron.coords_em[neuron.dendr_idxs, :]
        for neuron in all_neurons.values()
        if neuron.comments[0] == "n"
    ],
    axis=0,
)

In [None]:
# EM coordinates of the soma node of every neuron whose comment starts with "n",
# one row per neuron (presumably `soma_idx` is a single node index).
coords = np.concatenate(
    [
        neuron.coords_em[neuron.soma_idx, :][None, :]
        for neuron in all_neurons.values()
        if neuron.comments[0] == "n"
    ],
    axis=0,
)

In [None]:
# Position of the label stack in EM coordinates.
stack_offset = np.array([17000, 17000, 5800])
# Volume (um^3) of one voxel of the 8x-downsampled label stack, presumably
# from 14 x 14 x 25 nm EM voxels.
vox_vol = (0.014 * 0.014 * 0.025) * (8 ** 3)
# Extent of the label stack in EM units (one label voxel = 8 EM voxels).
stack_shape_vox = np.array(labels.shape) * 8

In [None]:
# Mean projection of the label stack along axis `proj`, drawn in EM
# coordinates, with the traced somas (`coords`) overlaid in red.
proj = 1
# Axes that remain after projecting: the lower one becomes image rows (d2),
# the higher one image columns (d1). For proj=1 this is d1=2, d2=0.
d2, d1 = [axis_i for axis_i in range(3) if axis_i != proj]
fig, ax = plt.subplots()
ax.imshow(
    labels.mean(proj),
    # extent is (left, right, bottom, top). With imshow's default
    # origin="upper", row 0 is drawn at `top`, so `top` must be the stack
    # offset (and `bottom` offset + size) for voxel rows to line up with the
    # scatter coordinates below; the reverse order flips the image vertically.
    extent=[
        stack_offset[d1],
        stack_offset[d1] + stack_shape_vox[d1],
        stack_offset[d2] + stack_shape_vox[d2],
        stack_offset[d2],
    ],
)
ax.scatter(coords[:, d1], coords[:, d2], s=15, c="r")

In [None]:
from lotr.em.loading import neuron_from_xml

# Manually annotated IPN somas, stored as nodes of a single skeleton file.
data_path = Path(anatomy_location, "em_soma_segmentation", "cells_ipn.xml")
ipn_somas = neuron_from_xml(data_path)

In [None]:
# EM coordinates of the annotated IPN somas (nodes of the skeleton loaded from
# cells_ipn.xml); indexed as one row per soma in the matching cell below.
somas = ipn_somas.coords_em

In [None]:
# Mean projection of the label stack along axis `proj`, drawn in EM
# coordinates, with the segmented soma coordinates overlaid in red.
proj = 1
# Axes that remain after projecting: the lower one becomes image rows (d2),
# the higher one image columns (d1). For proj=1 this is d1=2, d2=0.
d2, d1 = [axis_i for axis_i in range(3) if axis_i != proj]
fig, ax = plt.subplots()
ax.imshow(
    labels.mean(proj),
    # extent is (left, right, bottom, top). With imshow's default
    # origin="upper", row 0 is drawn at `top`, so `top` must be the stack
    # offset (and `bottom` offset + size) for voxel rows to line up with the
    # scatter coordinates below; the reverse order flips the image vertically.
    extent=[
        stack_offset[d1],
        stack_offset[d1] + stack_shape_vox[d1],
        stack_offset[d2] + stack_shape_vox[d2],
        stack_offset[d2],
    ],
)
ax.scatter(seg_somas_coords[:, d1], seg_somas_coords[:, d2], s=15, c="r", alpha=1)

In [None]:
# Match each annotated IPN soma to the segmentation in two ways:
#   - soma_ids / label_dist: row index of, and distance to, the nearest entry
#     of `seg_somas_coords`;
#   - soma_labels: the label found in `labels` at the soma's own voxel.
stack_offset = np.array([17000, 17000, 5800])

label_dist = []
soma_ids = []
soma_labels = []
for soma_coords in somas:
    distances = np.sqrt(np.sum((seg_somas_coords - soma_coords) ** 2, 1))

    # EM coordinates -> voxel index in the 8x-downsampled label stack.
    # `np.int` was removed in NumPy 1.24, so use the builtin `int`; floor
    # (rather than truncate) so coordinates just below the offset do not
    # round up to index 0.
    s_c = np.floor((soma_coords - stack_offset) / 8).astype(int)
    # Negative indices would silently wrap to the far end of the stack and
    # too-large ones would raise: somas outside the stack get background (0).
    if np.all(s_c >= 0) and np.all(s_c < labels.shape):
        soma_labels.append(labels[s_c[0], s_c[1], s_c[2]])
    else:
        soma_labels.append(0)

    idx = np.argmin(distances)
    label_dist.append(distances[idx])
    soma_ids.append(idx)

In [None]:
# Coordinates of the segmented soma nearest to each annotated soma.
suppsedly = seg_somas_coords[np.asarray(soma_ids, dtype=int)]

In [None]:
# Volume on the same grid as `labels` in which matched segmented somas are
# painted with their soma id (0 = background). Note this replaces the
# `ipn_somas` skeleton loaded earlier.
# The dtype is the smallest unsigned type that can hold len(soma_ids): a fixed
# uint8 overflows as soon as there are more than 255 somas.
ipn_somas = np.zeros(labels.shape, dtype=np.min_scalar_type(len(soma_ids)))

In [None]:
n_ids = [n.id for n in neurons_list]

# Paint every segmented soma matched to an annotated IPN soma with a 1-based id
# (soma row i of `somas` -> value i + 1). Ids start at 1 so that 0 stays
# background; with 0-based ids the first soma was written as 0 and disappeared.
# NOTE(review): `soma_id` is a 0-based row index into `seg_somas_coords`, while
# `labels` uses 0 for background and 1.. for cells — confirm whether the
# matching label is `soma_id` or `soma_id + 1` (soma_id == 0 currently selects
# the whole background).
for n, soma_id in tqdm(enumerate(soma_ids, start=1), total=len(soma_ids)):
    ipn_somas[labels == soma_id] = n

In [None]:
# Max projection of the painted IPN somas along the second axis.
fig, ax = plt.subplots()
ax.imshow(ipn_somas.max(axis=1))

In [None]:
from scipy.ndimage import affine_transform

In [None]:
from lotr.em.transformations import em_2_ipnref, transform_points

In [None]:
# Homogeneous transform from voxel indices of the 8x-downsampled label stack to
# EM coordinates: scale every axis by 8 (the downsampling), then add the offset.
stack_to_em = np.array(
    [
        [8, 0, 0, stack_offset[0]],
        [0, 8, 0, stack_offset[1]],
        [0, 0, 8, stack_offset[2]],
        [0, 0, 0, 1],
    ]
)

# Matrix to account for the fact that the reference is 0.5 um in voxel size:
# scale each reference-space axis by 1 / atlas.resolution to get atlas voxels.
# Only the diagonal entry is set; assigning the whole row (`res_matrix[i] = ...`)
# would also fill the off-diagonal and translation entries and turn the
# scaling into a shear plus a shift.
res_matrix = np.eye(4)
for i, r in enumerate(atlas.resolution):
    res_matrix[i, i] = 1 / r

In [None]:
# Resample the label stack onto the reference atlas grid.
# Forward map: stack voxel -> EM coords -> IPN reference -> atlas voxel.
# affine_transform expects the map from *output* voxels back to *input* voxels,
# hence the inverse; order=0 (nearest neighbour) keeps label ids intact.
stack_to_atlas = res_matrix @ em_2_ipnref @ stack_to_em
out = np.zeros(atlas.shape, dtype=np.uint16)
affine_transform(
    labels,
    np.linalg.inv(stack_to_atlas),
    output_shape=out.shape,
    output=out,
    order=0,
)
# Read the result from the pre-allocated array: some SciPy versions return
# None when `output` is given as an array.
output = out

In [None]:
# Mean projection of the atlas-registered labels along the second axis.
fig, ax = plt.subplots()
ax.imshow(output.mean(axis=1))

In [None]:
# Mask of the "ipn" structure in the reference atlas, halved and cast to uint16.
# NOTE(review): the `/ 2` presumably rescales the structure id stored in the
# mask to a small label value — confirm against bg_atlasapi's get_structure_mask.
ipn_mask = (atlas.get_structure_mask("ipn") / 2).astype(np.uint16)

In [None]:
import napari

In [None]:
# Inspect the atlas-registered soma labels interactively.
napari.view_labels(output)