# Visualization example: neuron with spines

In [None]:
# Imports
from copy import deepcopy

import numpy as np
from neurom import NeuriteType
from scipy.spatial import KDTree
from trimesh import Trimesh
from trimesh import util as triutil
from trimesh.exchange import obj
from trimesh.visual.color import ColorVisuals, hsv_to_rgba

from morph_spines.utils.morph_spine_loader import load_morphology_with_spines

In [None]:
# Load the neuron morphology together with its spines from a single H5 file
filepath = "./data/morphology_with_spines/864691134884740346.h5"
morph_w_spines = load_morphology_with_spines(filepath)

# Alternatively, the morphology and the spines can be loaded into two separate
# variables (both read from the same H5 file):
# from morph_spines.utils.morph_spine_loader import load_morphology, load_spines
# morph = load_morphology(filepath)
# spines = load_spines(filepath)

In [None]:
# Visualize all spines: shuffle them into one group per color and paint each group
colors = [[255, 150, 255]]
n = len(colors) + 1
spine_count = morph_w_spines.spines.spine_count

# Split the shuffled spine ids into len(colors) chunks of (nearly) equal size
splt = np.linspace(0, spine_count, n).astype(int)
rnd = np.random.permutation(spine_count)
grps = [rnd[start:stop] for start, stop in zip(splt[:-1], splt[1:], strict=True)]

# One concatenated mesh per group, colored with that group's color
colored_groups = []
for grp, col in zip(grps, colors, strict=False):
    grp_mesh = triutil.concatenate([morph_w_spines.spines.spine_mesh(i) for i in grp])
    grp_mesh.visual = ColorVisuals(mesh=grp_mesh, face_colors=col)
    colored_groups.append(grp_mesh)

# Merge every colored group into a single mesh (reused by later cells)
spine_grp_meshes = triutil.concatenate(colored_groups)
spine_grp_meshes.show()

In [None]:
# Show only the spines attached to a single section
section_id = 9
section_spine_meshes = morph_w_spines.spines.spine_meshes_for_section(section_id)
section_mesh = triutil.concatenate(section_spine_meshes)
section_mesh.show()

In [None]:
# Load the morphology mesh separately to check that the spine locations are correct
mesh_path = "./data/morphology_meshes/864691134884740346.obj"
with open(mesh_path) as fid:
    all_mesh = obj.load_obj(fid)
all_mesh = all_mesh["geometry"][str(mesh_path)]
# Scale vertex coordinates by 1/1000 (presumably nm -> um to match the spine
# meshes; confirm against the data source)
all_mesh["vertices"] = all_mesh["vertices"] / 1000.0
all_mesh = Trimesh(**all_mesh)
all_mesh.visual = ColorVisuals(mesh=all_mesh, face_colors=[100, 100, 255])

# Drop every morphology vertex lying within `spine_radius` of a spine-mesh vertex,
# together with the faces that use it. Querying the spine KD-tree with the
# morphology vertices and only counting neighbours avoids building a KD-tree over
# the whole morphology mesh and materializing one neighbour list per vertex.
spine_radius = 8e-2
n_close = spine_grp_meshes.kdtree.query_ball_point(
    all_mesh.vertices, spine_radius, return_length=True
)
far_from_spines = n_close == 0
# Keep a face only when all three of its vertices are far from the spines
all_mesh.update_faces(far_from_spines[all_mesh.faces].all(axis=1))
all_mesh.update_vertices(all_mesh.referenced_vertices)

all_mesh.show()

In [None]:
# Helper functions used to plot the spine and morphology meshes together
def apply_filter_to_mesh(func, mesh):
    """Return a filtered copy of ``mesh``; the input mesh is left untouched.

    ``func`` is called with the vertex array of the copy and must return one
    truthy/falsy value per vertex. A face survives only if all of its vertices
    are selected; vertices no longer referenced by a surviving face are then
    removed.
    """
    filtered = deepcopy(mesh)
    keep_vertex = np.asarray(func(filtered.vertices)).astype(bool)
    # Look the per-vertex decision up for each face corner; keep fully-selected faces
    keep_face = keep_vertex[filtered.faces].all(axis=1)
    filtered.update_faces(keep_face)
    filtered.update_vertices(filtered.referenced_vertices)
    return filtered


def filter_close_to_section_points(morph_obj, section_id, radius=1.5):
    """Return a function that selects the 3D points close to a section's geometry.

    Args:
        morph_obj: morphology exposing ``section(section_id).points``; only the
            first three columns of ``points`` (x, y, z) are used.
        section_id: id passed to ``morph_obj.section``.
        radius: maximum Euclidean distance from any section point for a point
            to be selected. Defaults to 1.5, the value previously hard-coded.

    Returns:
        A function mapping an array of 3D points to a boolean array, True for
        each point with at least one section point within ``radius``.
    """
    sec_pts = morph_obj.section(section_id).points[:, :3]
    # Build the KD-tree once, when the filter is created, rather than on every call.
    tree = KDTree(sec_pts)

    def func(pts):
        # Only the neighbour count matters, so skip building the neighbour lists.
        return tree.query_ball_point(pts, radius, return_length=True) > 0

    return func


def cyclic_rgb_for_spines(sec_pos, n_hues=4):
    """Return one color per spine, cycling through hues along the section.

    Spines are ranked by their position on the section; the k-th spine in that
    order gets hue ``k % n_hues`` out of ``n_hues`` evenly spaced hues (saturation
    1, value 0.85). With ``n_hues > 1``, spines adjacent in position order
    therefore get different colors.

    Args:
        sec_pos: 1-D array-like (e.g. a pandas Series) of positions along the
            section, one per spine.
        n_hues: number of distinct hues to cycle through (must be >= 1).
            Defaults to 4.

    Returns:
        The result of ``hsv_to_rgba`` on the per-spine HSV values, in the same
        order as ``sec_pos``.
    """
    positions = np.asarray(sec_pos)
    hues = np.linspace(0, 1.0, n_hues + 1)[:-1]
    palette = np.column_stack([hues, np.ones(n_hues), np.full(n_hues, 0.85)])
    # rank[i] is the place of spine i in sorted order. Sorting with numpy (NaN
    # last) avoids pandas' Series.argsort, which yields -1 for NaN entries and
    # would corrupt the scatter; a stable sort keeps ties in input order.
    order = np.argsort(positions, kind="stable")
    rank = np.empty(len(order), dtype=np.intp)
    rank[order] = np.arange(len(order))
    return hsv_to_rgba(palette[rank % n_hues])


def meshes_for_spines_on(spine_obj, section_id):
    """Build one mesh holding every spine whose ``afferent_section_id`` is ``section_id``.

    Each spine mesh gets its own face color, computed by ``cyclic_rgb_for_spines``
    from the spines' ``afferent_section_pos`` values; the colored spine meshes
    are then concatenated into a single mesh.
    """
    table = spine_obj.spine_table
    on_section = table["afferent_section_id"] == section_id
    spine_ids = table.index[on_section]
    colors = cyclic_rgb_for_spines(table.loc[spine_ids, "afferent_section_pos"])

    colored = []
    for spine_id, color in zip(spine_ids, colors, strict=False):
        one_spine = spine_obj.spine_mesh(spine_id)
        one_spine.visual = ColorVisuals(mesh=one_spine, face_colors=color)
        colored.append(one_spine)
    return triutil.concatenate(colored)


def show_for_section_id(morphology_with_spines, section_id, mesh):
    """Display the part of ``mesh`` around one section together with that section's spines.

    ``mesh`` is cut down to the vertices near the morphology section, then
    concatenated with the colored spine meshes of the section and shown; the
    return value of ``show()`` is passed back to the caller.
    """
    # NOTE(review): the morphology section is looked up as ``section_id - 1``
    # while spines are matched on ``section_id`` itself; presumably the spine
    # table numbers sections with a +1 offset relative to the morphology — confirm.
    near_section = filter_close_to_section_points(
        morphology_with_spines.morphology, section_id - 1
    )
    morph_part = apply_filter_to_mesh(near_section, mesh)
    spines_part = meshes_for_spines_on(morphology_with_spines.spines, section_id)
    combined = triutil.concatenate([morph_part, spines_part])
    return combined.show()

In [None]:
# Visualize a selected morphology section with its spines (from separate meshes).
# `secs` lists the ids of all non-axon sections; `section_id` is an index into it.
secs = [
    section.id
    for section in morph_w_spines.morphology.sections
    if section.type != NeuriteType.axon
]
section_id = 2
show_for_section_id(morph_w_spines, secs[section_id], all_mesh)

In [None]:
# Visualize another selected morphology section with its spines (from separate meshes)
section_id = 3
chosen_section = secs[section_id]
show_for_section_id(morph_w_spines, chosen_section, all_mesh)

In [None]:
# Visualize another selected morphology section with its spines (from separate meshes)
section_id = 8
chosen_section = secs[section_id]
show_for_section_id(morph_w_spines, chosen_section, all_mesh)

In [None]:
# Visualize another selected morphology section with its spines (from separate meshes)
section_id = 9
chosen_section = secs[section_id]
show_for_section_id(morph_w_spines, chosen_section, all_mesh)

In [None]:
# Visualize the whole morphology mesh with all spines (if small enough).
# The mesh is reloaded because `all_mesh` was modified in place earlier
# (vertices close to the spines were removed).
with open(mesh_path) as fid:
    all_mesh = obj.load_obj(fid)
all_mesh = all_mesh["geometry"][str(mesh_path)]
# Same 1/1000 vertex scaling as when the mesh was first loaded
all_mesh["vertices"] = all_mesh["vertices"] / 1000.0
all_mesh = Trimesh(**all_mesh)
all_mesh.visual = ColorVisuals(mesh=all_mesh, face_colors=[200, 200, 200])

# Concatenate once: `all_mesh + spine_grp_meshes` already concatenates, so
# wrapping that sum in another `concatenate` call was redundant.
whole_mesh = triutil.concatenate([all_mesh, spine_grp_meshes])

print(whole_mesh)
whole_mesh.show()