In [None]:
%matplotlib widget
from pathlib import Path

import flammkuchen as fl
import matplotlib as mpl
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sch
from bg_atlasapi import BrainGlobeAtlas
from lotr import DATASET_LOCATION
from lotr import plotting as pltltr
from lotr.em.loading import load_skeletons_from_xml
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter
from tqdm import tqdm

# Colour palettes used throughout the notebook. Each is tiled several times so
# that zip()-ing it against a longer list of items never truncates the items.
_PALETTE_REPEATS = 5
cols, cols2 = (
    pltltr.COLS[palette_name] * _PALETTE_REPEATS
    for palette_name in ("qualitative", "fish_cols")
)

# Make matplotlib cycle through the qualitative palette by default.
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=cols)

# BrainGlobe atlas of the zebrafish IPN (resolution presumably 0.5 um, per its name).
atlas = BrainGlobeAtlas("ipn_zfish_0.5um")

In [None]:
# All anatomy data sits in an "anatomy" folder next to the main dataset
# location; synapse annotation XML exports are in its "synapse_annotations" subfolder.
master_folder = DATASET_LOCATION.parent / "anatomy"
data_folder = master_folder / "synapse_annotations"

In [None]:
# Ids of the annotation files (annotation_<id>.xml) to load.
annotations_n = [53, 84]

# annotation id -> {"neuron": [skeletons], "synapses": [skeletons]}
annotations_dict = dict()

for annotation in annotations_n:
    skeletons = load_skeletons_from_xml(data_folder / f"annotation_{annotation}.xml")

    # Skeletons whose id contains "Syna" are treated as synapse annotations;
    # everything else is taken to be the traced neuron itself.
    neuron = [s for s in skeletons if "Syna" not in s.id]
    synapses = [s for s in skeletons if "Syna" in s.id]

    print([s.id for s in neuron])
    annotations_dict[annotation] = dict(neuron=neuron, synapses=synapses)

In [None]:
# Axis limits for each anatomical dimension; each projection view below is
# bounded by the two dimensions it displays.
bs = dict(frontal=(30, 180), vertical=(20, 190), sagittal=(-20, 120))

view_bounds = {
    "frontal": [bs["vertical"], bs["frontal"]],
    "horizontal": [bs["sagittal"], bs["frontal"]],
    "sagittal": [bs["vertical"], bs["sagittal"]],
}

plotter = pltltr.AtlasPlotter(
    structures=["ipn", "dipn"],
    mask_slices={"frontal": slice(0, 120)},
    bounds_dict=view_bounds,
)

In [None]:
# Traced neurons (one palette colour each) with their annotated synapses drawn
# in dark grey on top, across the three projection views. Lines are rasterized
# (presumably to keep the saved vector figure small).
f, axs = plt.subplots(
    1,
    3,
    figsize=(6.2, 2.2),
    gridspec_kw=dict(left=0.02, right=0.95, top=1, hspace=0, wspace=0.1),
)
f, axs = plotter.generate_projection_plots(axs, labels=True, edge=120)
for n, col in zip(annotations_n, cols):
    n_dict = annotations_dict[n]
    plotter.plot_neurons(
        axs, n_dict["neuron"], select="all", color=col, linewidth=0.5, rasterized=True
    )
    plotter.plot_neurons(
        axs,
        n_dict["synapses"],
        select="all",
        color=(0.1,) * 3,
        linewidth=1,
        rasterized=True,
    )

pltltr.savefig("synapses_traced")

In [None]:
# Number of annotated synapses per annotation file. Computed explicitly from
# annotations_dict rather than from a loop variable left over from a plotting
# cell, which only ever reflected the last annotation.
synapse_counts = {
    annotation: len(entry["synapses"]) for annotation, entry in annotations_dict.items()
}
synapse_counts

In [None]:
# Numeric part of the skeleton ids ("p041" ... "p056") of the post-synaptic partners.
POSTSYN_IDS = range(41, 57)

# Load every traced skeleton (same "anatomy" folder as `master_folder`) and
# select the post-synaptic partners by id.
neurons = fl.load(master_folder / "annotated_traced_neurons" / "all_skeletons.h5")
postsyn_partners = [neurons[f"p0{i}"] for i in POSTSYN_IDS]

In [None]:
# Same figure as the traced-synapses cell above, but without rasterization.
# NOTE(review): this saves under the same name, "synapses_traced", so it
# presumably overwrites the rasterized version saved earlier — keep only one
# of the two cells or give this one a distinct filename.
f, axs = plt.subplots(
    1,
    3,
    figsize=(6.2, 2.2),
    gridspec_kw=dict(left=0.02, right=0.95, top=1, hspace=0, wspace=0.1),
)
f, axs = plotter.generate_projection_plots(axs, labels=True, edge=120)
for n, col in zip(annotations_n, cols):
    n_dict = annotations_dict[n]
    plotter.plot_neurons(axs, n_dict["neuron"], select="all", color=col, linewidth=0.5)
    plotter.plot_neurons(
        axs, n_dict["synapses"], select="all", color=(0.1,) * 3, linewidth=1
    )

pltltr.savefig("synapses_traced")

In [None]:
def _comment_field(neuron, index: int = 1) -> str:
    """Return the `index`-th " - "-separated field of `neuron.comments`.

    Returns "" when the comment has too few fields, instead of raising
    IndexError, so such neurons fall through to the "others" category.
    """
    fields = neuron.comments.split(" - ")
    return fields[index] if len(fields) > index else ""


def _is_ipn_neuron(neuron) -> bool:
    """True if the second " - "-separated comment field starts with "ipn"."""
    return _comment_field(neuron)[:3] == "ipn"


# Group post-synaptic partners by the keywords in their comment strings.
# NOTE(review): the first three groups are not mutually exclusive (a comment
# containing "hab" whose second field starts with "ipn" lands in two groups);
# only "others" is guaranteed disjoint from the rest — confirm this is intended.
postsyn_dict = {
    "Habenular axons": [n for n in postsyn_partners if "hab" in n.comments],
    "aHB neurons": [n for n in postsyn_partners if "ahb" in n.comments],
    "IPN neurons": [n for n in postsyn_partners if _is_ipn_neuron(n)],
    "others": [
        n
        for n in postsyn_partners
        if not _is_ipn_neuron(n)
        and "ahb" not in n.comments
        and "hab" not in n.comments
    ],
}

In [None]:
# Post-synaptic partners coloured by category (one palette colour per key of
# postsyn_dict), drawn with alpha=0.4 so overlapping arbors remain visible.
f, axs = plt.subplots(
    1,
    3,
    figsize=(6.2, 2.2),
    gridspec_kw=dict(left=0.02, right=0.95, top=1, hspace=0, wspace=0.1),
)
f, axs = plotter.generate_projection_plots(axs, labels=True, edge=120)

# Use a dedicated loop name: reusing `neurons` here would clobber the skeleton
# dictionary loaded from all_skeletons.h5.
for category_neurons, col in zip(postsyn_dict.values(), cols):
    for n in category_neurons:
        plotter.plot_neurons(axs, n, select="all", color=col, linewidth=0.5, alpha=0.4)

pltltr.savefig("postsyn_partners")

In [None]:
# List the comment string of every post-synaptic partner, grouped by category.
# A dedicated loop name avoids overwriting the `neurons` skeleton dictionary.
for category, category_neurons in postsyn_dict.items():
    print(f"------- {category}")
    for n in category_neurons:
        print(n.comments)