In [6]:
import logging

import numpy as np
from scipy.spatial.transform import Rotation
import vedo

from helper_loader import *
from histalign.backend.ccf.paths import get_atlas_path
from histalign.backend.io import gather_alignment_paths, load_alignment_settings
from histalign.backend.models import HistologySettings, Orientation, VolumeSettings
from histalign.backend.registration.alignment import (
    ALIGNMENT_VOLUMES_CACHE_DIRECTORY,
    generate_hash_from_targets,
)
from histalign.backend.workspace import Volume, VolumeSlicer

_module_logger = logging.getLogger("histalign.notebook")

In [2]:
def compute_normal(pitch: int, yaw: int, orientation: Orientation) -> np.ndarray:
    match orientation:
        case Orientation.CORONAL:
            normal = [-1, 0, 0]
            rotation = Rotation.from_euler("ZY", [pitch, yaw], degrees=True)
        case Orientation.HORIZONTAL:
            normal = [0, 1, 0]
            rotation = Rotation.from_euler("ZX", [pitch, yaw], degrees=True)
        case Orientation.SAGITTAL:
            normal = [0, 0, 1]
            rotation = Rotation.from_euler("XY", [pitch, yaw], degrees=True)

    return rotation.apply(normal)


# Overly complicated but matches logic of alternative
# def find_nearest_grid_point(point: np.ndarray) -> np.ndarray:
#     points = [np.floor(point)]
#     for i in range(3):
#         point1 = points[0].copy()
#         point1[-i - 1] += 1
#         points.append(point1)
#     other_points = [np.ceil(point)]
#     for i in range(3):
#         point2 = other_points[-1].copy()
#         point2[-i - 1] -= 1
#         other_points.insert(0, point2)
#
#     points.extend(other_points)
#
#     distances = np.sum(np.abs(points - point), axis=1)
#
#     return points[np.argmin(distances)]


def find_nearest_grid_point(point: np.ndarray) -> np.ndarray:
    """Snaps `point` to the closest whole-number grid coordinate.

    Rounding is delegated to NumPy, which rounds exact halves to the nearest even
    integer (e.g. 0.5 -> 0.0, 1.5 -> 2.0, 2.5 -> 2.0). Integer input keeps its dtype.

    Args:
        point (np.ndarray): Coordinates to snap.

    Returns:
        np.ndarray: `point` rounded element-wise to whole numbers.
    """
    snapped = np.round(point, decimals=0)
    return snapped

## Visualisation

In [3]:
cube_length = 5
orientation = Orientation.CORONAL
camera = coronal_camera
pitch = 45
yaw = 0

update_cameras((cube_length,) * 3)

# One point per integer voxel coordinate of a `cube_length`-sided cube, as an
# (N, 3) array of (i, j, k) indices.
points = vedo.Points(np.indices((cube_length,) * 3).reshape(3, -1).T, r=8)

# Arrow from the cube centre along the plane normal, with the plane drawn at its tip.
start_point = np.array([cube_length // 2] * 3)
normal = compute_normal(pitch, yaw, orientation)
end_point = start_point + normal
arrow = vedo.Arrow(start_pt=start_point, end_pt=end_point)

plane = vedo.Plane(pos=end_point, normal=normal, s=(cube_length,) * 2)

show([points, arrow, plane], camera=camera)

## Computation

In [4]:
cube_length = 5
orientation = Orientation.CORONAL
pitch = 45
yaw = 0

start_point = np.array([cube_length // 2, cube_length // 2, cube_length // 2])
normal = compute_normal(pitch, yaw, orientation)
end_point = start_point + normal
print(f"End point: {end_point}")

# Plane through the start point, perpendicular to the normal. A grid point's distance
# to it measures how far that point lies along the normal. This only holds while the
# point projects inside the finite (cube_length x cube_length) plane.
plane = vedo.Plane(pos=start_point, normal=normal, s=(cube_length,) * 2)

nearest_grid_point = find_nearest_grid_point(end_point)
print(f"Nearest grid point to end point: {nearest_grid_point}")
distance_to_grid = vedo.Point(nearest_grid_point).distance_to(plane)[0]
print(f"Distance from plane to grid point: {distance_to_grid}")

End point: [1.29289322 1.29289322 2.        ]
Nearest grid point to end point: [1. 1. 2.]
Distance from plane to grid point: 1.4142135623730951


We assume a volume resolution of 25 microns and a Z-stack of 21 images spaced 10 microns apart.

The alignment is performed on a maximum intensity projection (MIP) of the whole stack, so the alignment coordinates represent the centre of the stack. For each multiple of the plane normal, we can build a new projection from the part of the stack around that distance and insert it into the volume:

1. Find the N closest grid points along the normal and group the Z indices by the point each one should map to.
2. Compute the maximum intensity projection of each group's segment of the stack.
3. Slice the volume at those points and insert the projections.

In [360]:
# Reference volume and Z-stack geometry used by the cells below.
# NOTE(review): `shape` is presumably the voxel shape of the 25 µm atlas; it is not
# used in this cell.
shape = (528, 320, 456)
orientation = Orientation.CORONAL
resolution = Resolution.MICRONS_25
z_stack_count = 21
z_stack_distance = 10

# Plane tilt in degrees and the resulting normal.
pitch = 10
yaw = 4
normal = compute_normal(pitch, yaw, orientation)

alignment_coordinates = np.array((264, 180, 228))


def group_to_closest_grid_point(
    volume_resolution: int,
    stack_count: int,
    stack_distance: int,
) -> list[int]:
    """Assigns each image of a Z-stack to the nearest grid step along the normal.

    Image offsets are measured from the middle image. For an even `stack_count` this
    is the lower of the two middle images, i.e. index `(stack_count - 1) // 2`. Each
    offset is then rounded to a whole number of `volume_resolution` steps. Exact
    half-step ties round towards negative infinity: with a resolution of 10, an
    offset of +5 maps to 0 and -5 maps to -1.

    Args:
        volume_resolution (int): Resolution in microns of the reference volume.
        stack_count (int): How many images are present in the stack.
        stack_distance (int): Distance in microns between consecutive images.

    Returns:
        list[int]: A group index list of length `stack_count`.
    """
    centre_index = (stack_count - 1) // 2
    half_resolution = volume_resolution / 2

    quotients_and_remainders = (
        divmod((index - centre_index) * stack_distance, volume_resolution)
        for index in range(stack_count)
    )
    return [
        quotient + int(remainder > half_resolution)
        for quotient, remainder in quotients_and_remainders
    ]


def subproject_array(array: np.ndarray, groups: list) -> np.ndarray:
    """Max-projects consecutive runs of equal group indices along the first axis.

    Each maximal run of identical consecutive values in `groups` selects a slab of
    `array` along its first axis. That slab is collapsed with a maximum intensity
    projection, and the projections are stacked in order of appearance.

    Args:
        array (np.ndarray): The 3D stack to project. A 2D array is returned
            unchanged for convenience.
        groups (list): One group index per entry of `array`'s first axis, e.g. the
            output of `group_to_closest_grid_point`. A 1D NumPy array also works.

    Returns:
        np.ndarray: An array of shape `(run_count, *array.shape[1:])`, or `array`
            itself when it is 2D.

    Raises:
        ValueError: When `array` is neither 2D nor 3D, when `groups` is empty, or
            when `len(groups)` differs from `array.shape[0]`.
    """
    dimension_count = array.ndim
    if dimension_count == 2:
        # Allow 2D arrays for convenience
        return array
    if dimension_count != 3:
        raise ValueError(f"Cannot project array with {dimension_count} dimensions.")

    # `len` rather than truthiness so NumPy group arrays are accepted too.
    if len(groups) == 0:
        raise ValueError("Groups is empty.")
    if len(groups) != array.shape[0]:
        raise ValueError(
            f"Expected {array.shape[0]} groups to cover input array's first "
            f"dimension, got {len(groups)}."
        )

    projections = []
    run_start = 0
    for index in range(1, len(groups) + 1):
        # Close the current run at the end of the list or when the group changes.
        # A trailing single-image group therefore gets its own projection.
        if index == len(groups) or groups[index] != groups[run_start]:
            projections.append(np.max(array[run_start:index], axis=0))
            run_start = index

    return np.stack(projections)


def generate_aligned_planes(
    alignment_volume: Volume | vedo.Volume,
    alignment_paths: list[Path],
) -> list[vedo.Plane]:
    """Builds one plane mesh per alignment file, carrying the registered image.

    For each alignment:
    - the histology image is forwarded through a `Registrator`;
    - the volume is sliced with the alignment's volume settings;
    - the padding recorded in the slice metadata is cropped off the registered image;
    - the cropped pixels are stored as the mesh's "ImageScalars" point data.

    Args:
        alignment_volume (Volume | vedo.Volume): Volume handed to the `VolumeSlicer`.
        alignment_paths (list[Path]): Alignment settings files to process.

    Returns:
        list[vedo.Plane]: The plane meshes, in the order of `alignment_paths`.
    """
    _module_logger.debug(
        f"Starting generation of {len(alignment_paths)} aligned planes."
    )

    slicer = VolumeSlicer(volume=alignment_volume)
    planes = []

    for position, path in enumerate(alignment_paths):
        if position and position % 5 == 0:
            _module_logger.debug(f"Generated {position} aligned planes...")

        settings = load_alignment_settings(path)
        histology_image = load_image(settings.histology_path)
        forwarded_image = Registrator(True, True).get_forwarded_image(
            histology_image, settings
        )

        mesh = slicer.slice(settings.volume_settings, return_mesh=True)

        # Undo padding. See `VolumeSlicer.slice` for more details.
        row_padding = mesh.metadata["i_padding"]
        column_padding = mesh.metadata["j_padding"]
        cropped_image = forwarded_image[
            row_padding[0] : forwarded_image.shape[0] - row_padding[1],
            column_padding[0] : forwarded_image.shape[1] - column_padding[1],
        ]

        mesh.pointdata["ImageScalars"] = cropped_image.flatten()
        planes.append(mesh)

    _module_logger.debug(
        f"Finished generating all {len(alignment_paths)} aligned planes."
    )
    return planes


def insert_aligned_planes_into_array(
    array: np.ndarray,
    planes: list[vedo.Plane],
    inplace: bool = True,
) -> np.ndarray:
    """Max-projects the point data of each plane into `array`.

    For every plane, a zero-filled `vedo.Volume` with `array`'s shape has the plane's
    data interpolated into it (`interpolate_data_from`, radius 1). The result is
    rounded, clipped to the uint16 range, and combined with `array` by an
    element-wise maximum.

    Args:
        array (np.ndarray): Destination array. NOTE(review): the uint16 cast below
            suggests this is expected to be a uint16 array; confirm for other dtypes.
        planes (list[vedo.Plane]): Planes carrying point data to insert.
        inplace (bool, optional): Whether to write into `array` directly. When False,
            a copy is modified and returned instead.

    Returns:
        np.ndarray: The array with all planes inserted (`array` itself when
            `inplace` is True).
    """
    _module_logger.debug(
        f"Starting insertion of {len(planes)} planes into alignment array."
    )

    if not inplace:
        array = array.copy()

    uint16_limits = np.iinfo(np.uint16)
    for index, plane in enumerate(planes):
        if index > 0 and index % 5 == 0:
            _module_logger.debug(f"Inserted {index} planes into alignment array...")

        # A fresh zero volume per plane so only this plane's data is interpolated.
        temporary_volume = vedo.Volume(np.zeros_like(array))
        temporary_volume.interpolate_data_from(plane, radius=1)

        temporary_array = temporary_volume.tonumpy()
        # Clip before casting: out-of-range values do not convert reliably to an
        # unsigned 16-bit integer (negative values would otherwise wrap around).
        temporary_array = np.clip(
            np.round(temporary_array), uint16_limits.min, uint16_limits.max
        ).astype(np.uint16)

        array[:] = np.maximum(array, temporary_array)

    _module_logger.debug(
        f"Finished inserting all {len(planes)} planes into alignment array."
    )
    return array


def build_aligned_volume(
    alignment_directory: str | Path,
    use_cache: bool = True,
    return_raw_array: bool = False,
) -> np.ndarray | vedo.Volume:
    """Builds, or loads from cache, a volume containing every aligned histology slice.

    Args:
        alignment_directory (str | Path): Directory whose alignment files are
            gathered with `gather_alignment_paths`.
        use_cache (bool, optional): Whether to load a previously cached volume when
            one exists, and to cache a newly built one.
        return_raw_array (bool, optional): Whether to return the NumPy array instead
            of a `vedo.Volume`.

    Returns:
        np.ndarray | vedo.Volume: The aligned volume as a raw array or vedo volume.

    Raises:
        ValueError: When no alignment files are found in `alignment_directory`.
    """
    alignment_directory = Path(alignment_directory)

    alignment_paths = gather_alignment_paths(alignment_directory)
    if not alignment_paths:
        raise ValueError("Cannot build aligned volume from empty alignment directory.")

    alignment_hash = generate_hash_from_targets(alignment_paths)
    cache_path = ALIGNMENT_VOLUMES_CACHE_DIRECTORY / f"{alignment_hash}.npz"

    if use_cache and cache_path.exists():
        _module_logger.debug("Found cached aligned volume. Loading from file.")

        # `np.load` on an .npz keeps the archive open; the context manager closes it
        # once the array has been read into memory.
        with np.load(cache_path) as cached_archive:
            array = cached_archive["array"]
        if return_raw_array:
            return array
        return vedo.Volume(array)

    # NOTE(review): assumes every alignment shares the first one's volume shape.
    reference_shape = load_alignment_settings(alignment_paths[0]).volume_settings.shape

    # Volume needs to be created before array as vedo makes a copy
    aligned_volume = vedo.Volume(np.zeros(shape=reference_shape, dtype=np.uint16))
    aligned_array = aligned_volume.tonumpy()

    planes = generate_aligned_planes(aligned_volume, alignment_paths)

    insert_aligned_planes_into_array(aligned_array, planes)
    aligned_volume.modified()  # Probably unnecessary but good practice

    if use_cache:
        _module_logger.debug("Caching volume to file as a NumPy array.")
        ALIGNMENT_VOLUMES_CACHE_DIRECTORY.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(cache_path, array=aligned_array)

    if return_raw_array:
        return aligned_array
    return aligned_volume


groups = group_to_closest_grid_point(resolution.value, z_stack_count, z_stack_distance)

# Deterministic mock Z-stack: every pixel of image `i` has value `i`, so each
# sub-projection's value is the highest image index in its group.
mock_stack = np.broadcast_to(
    np.arange(z_stack_count, dtype=float)[:, np.newaxis, np.newaxis],
    (z_stack_count, 10, 10),
).copy()

project_stack = subproject_array(mock_stack, groups)

print(groups)
print(project_stack.shape)

[-4, -4, -3, -3, -2, -2, -2, -1, -1, 0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4]
(9, 10, 10)


In [72]:
def euclidean_distance(point1: np.ndarray, point2: np.ndarray) -> float:
    """Returns the straight-line (L2) distance between two points.

    Args:
        point1 (np.ndarray): First point's coordinates.
        point2 (np.ndarray): Second point's coordinates, broadcastable with `point1`.

    Returns:
        float: Square root of the summed squared coordinate differences.
    """
    difference = point1 - point2
    squared_length = np.sum(difference ** 2)
    return np.sqrt(squared_length)

In [46]:
shape = (528, 320, 456)
orientation = Orientation.CORONAL
resolution = Resolution.MICRONS_25
z_stack_count = 21
z_stack_distance = 10

alignment_coordinates = np.array([264, 180, 228])
# Untilted alignment.
pitch = 0
yaw = 0

settings = AlignmentSettings(
    volume_path=get_atlas_path(resolution),
    volume_settings=VolumeSettings(
        orientation=orientation, resolution=resolution, pitch=pitch, yaw=yaw
    ),
    histology_settings=HistologySettings(),
)

# Reference atlas, and the plane it is sliced along for these settings.
atlas_volume = load_volume(get_atlas_path(resolution))
alignment_plane: vedo.Plane = VolumeSlicer(volume=atlas_volume).slice(
    settings.volume_settings, return_mesh=True
)
alignment_normal = compute_normal(pitch, yaw, orientation)

# Distance from the alignment plane to the point one normal-length away from
# `alignment_coordinates`.
vedo.Point(alignment_coordinates + alignment_normal).distance_to(alignment_plane)

array([1.])

In [73]:
# Second plane, parallel to the alignment plane and shifted by one normal-length.
plane2 = vedo.Plane(
    pos=alignment_coordinates + alignment_normal,
    normal=alignment_normal,
    s=(100, 100),
)

# Intersect a long line through `alignment_coordinates`, running along the normal,
# with both planes; the distance between the hits is the spacing between the planes.
point1, point2 = (
    target.intersect_with_line(
        alignment_coordinates - 1000 * alignment_normal,
        alignment_coordinates + 1000 * alignment_normal,
    )
    for target in (alignment_plane, plane2)
)

euclidean_distance(point1, point2)

1.0

In [32]:
# Sanity check: subtracting the original point from the shifted point recovers the
# normal, so this is the normal's length, which should be 1 up to floating-point error.
euclidean_distance(alignment_coordinates + alignment_normal, alignment_coordinates)

0.9999999999999748