In [None]:
import hashlib
import logging
import os
from pathlib import Path
import re
from typing import Any, Optional

import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndimage
import vedo

from histalign import set_log_level
from histalign.backend.io import (
    gather_alignment_paths,
    load_alignment_settings,
    load_image,
)
from histalign.backend.maths import (
    compute_normal,
    compute_normal_from_raw,
    compute_origin_from_orientation,
)
from histalign.backend.models import AlignmentSettings, Orientation
from histalign.backend.registration import Registrator
from histalign.backend.registration.alignment import (
    ALIGNMENT_VOLUMES_CACHE_DIRECTORY,
    generate_hash_from_targets,
)
from histalign.backend.workspace import Volume, VolumeSlicer

vedo.settings.default_backend = "k3d"

_module_logger = logging.getLogger("histalign.notebook")

# Log level handed to `histalign.set_log_level` (e.g. "DEBUG"). Presumably needed to
# see the debug progress messages logged through `_module_logger` — confirm.
# `None` leaves the log level untouched.
LOG_LEVEL: Optional[str] = None
if LOG_LEVEL is not None:
    set_log_level(LOG_LEVEL)

In [None]:
# Type aliases that keep the signatures below readable.
Projection = np.ndarray  # A single (sub-)projected image
Coordinates = np.ndarray  # Coordinates stored as a NumPy array
CoordinatesTuple = tuple[int, ...]  # Hashable coordinates, usable as dictionary keys


def snap_coordinates(coordinates: Coordinates) -> Coordinates:
    """Rounds coordinates to the closest point of the integer grid.

    NumPy rounding is used, so exact halves go to the nearest even integer and the
    input dtype is kept (floating-point input stays floating-point).

    Args:
        coordinates (Coordinates): Coordinates, possibly fractional, to snap.

    Returns:
        Coordinates: A rounded copy of `coordinates`.
    """
    return np.round(coordinates, decimals=0)


def get_normal_line_points(
    alignment_settings: AlignmentSettings,
    half_length: float = 1_000_000,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute points describing normal passing through volume centre.

    The segment passes through the alignment origin, which is obtained by placing the
    volume centre with `compute_origin_from_orientation`. It follows the alignment
    normal given by `compute_normal`.

    Args:
        alignment_settings (AlignmentSettings):
            Settings used for the alignment.
        half_length (float, optional):
            Multiple of the normal vector subtracted from and added to the origin to
            obtain the two end points. The large default presumably makes the
            segment behave like an infinite line when intersected with planes.

    Returns:
        tuple[np.ndarray, np.ndarray]:
            The end points `origin - half_length * normal` and
            `origin + half_length * normal`, in that order.
    """
    volume_settings = alignment_settings.volume_settings

    alignment_normal = compute_normal(volume_settings)
    alignment_origin = compute_origin_from_orientation(
        tuple((np.array(volume_settings.shape) - 1) / 2),
        volume_settings,
    )

    return (
        alignment_origin - half_length * alignment_normal,
        alignment_origin + half_length * alignment_normal,
    )


def sub_project_image_stack(
    image_stack: np.ndarray, groups: list[int]
) -> list[np.ndarray]:
    """Projects arrays based on their group.

    Each group gets the maximum intensity projection of all the images that belong
    to that group.

    Args:
        image_stack (np.ndarray):
            Image array to sub-project. This must be a 3D array whose first dimension is
            the Z index.
        groups (list[int]):
            Group each index in the stack belongs to. This must have exactly one entry
            per image of `image_stack`.

    Returns:
        list[np.ndarray]:
            The list of sub-projections. Each unique group ID has a single projection,
            even when its members are not contiguous in the stack. The projections are
            returned in the order in which groups are first encountered.

    Raises:
        ValueError: If `groups` does not have one entry per image of `image_stack`.
    """
    if len(groups) != image_stack.shape[0]:
        raise ValueError(
            f"Expected one group per stack image ({image_stack.shape[0]}), "
            f"got {len(groups)} groups."
        )

    # Built once and reused as a boolean mask for every group.
    group_array = np.asarray(groups)

    sub_projections = []
    # `dict.fromkeys` de-duplicates while preserving first-encounter order.
    for group in dict.fromkeys(groups):
        sub_stack = image_stack[group_array == group]
        sub_projections.append(np.max(sub_stack, axis=0))

    return sub_projections


def compute_closest_plane(
    target_planes: list[vedo.Plane], fixed_planes: list[vedo.Plane]
) -> list[int]:
    """Finds, for every target plane, the index of the nearest fixed plane.

    The distance between two planes is the largest value of the array returned by
    `distance_to`. When several fixed planes are equally close, the lowest index wins.

    Args:
        target_planes (list[vedo.Plane]): Planes to find the closest match for.
        fixed_planes (list[vedo.Plane]): Candidate planes to match against.

    Returns:
        list[int]: For each target plane, the index of its closest fixed plane.
    """
    closest_indices = []
    for target in target_planes:
        # Collect every distance array first, then reduce each to its maximum.
        distance_arrays = [target.distance_to(fixed) for fixed in fixed_planes]
        plane_distances = [np.max(array) for array in distance_arrays]
        closest_indices.append(plane_distances.index(min(plane_distances)))

    return closest_indices


def _snap_stack_to_grid(
    image_stack: np.ndarray, alignment_settings: AlignmentSettings
) -> dict[CoordinatesTuple, Projection]:
    """Snaps a Z stack to a grid through sub-projections.

    Args:
        image_stack (np.ndarray):
            Image array of the stack. The first dimension should be the stack index.
        alignment_settings (AlignmentSettings): Settings used for the alignment.

    Returns:
        dict[CoordinatesTuple, Projection]:
            A dictionary mapping grid coordinates to a sub-projection.
    """
    match alignment_settings.volume_settings.orientation:
        case Orientation.CORONAL:
            orientation_axis_length = alignment_settings.volume_settings.shape[0]
        case Orientation.HORIZONTAL:
            orientation_axis_length = alignment_settings.volume_settings.shape[1]
        case Orientation.SAGITTAL:
            orientation_axis_length = alignment_settings.volume_settings.shape[2]
        case other:
            raise Exception(f"ASSERT NOT REACHED ({other})")

    # Normal as-if no pitch or yaw are applied
    flat_normal = compute_normal_from_raw(
        0, 0, alignment_settings.volume_settings.orientation
    )

    # Points describing the normal using the aligned pitch and yaw
    normal_line_points = get_normal_line_points(alignment_settings)
    # Origin of the aligned image based on the offset
    alignment_origin = compute_origin_from_orientation(
        tuple((np.array(alignment_settings.volume_settings.shape) - 1) / 2),
        alignment_settings.volume_settings,
    )
    # Normal of the plane used for alignment
    alignment_normal = compute_normal(alignment_settings.volume_settings)

    # Intersection of the normal line and every plane orthogonal to flat normal
    free_floating_intersections = [
        vedo.Plane(
            pos=i * np.abs(flat_normal),
            normal=flat_normal,
            s=(1_000_000, 1_000_000),
        ).intersect_with_line(
            np.squeeze(normal_line_points[0]),
            np.squeeze(normal_line_points[1]),
        )
        for i in range(orientation_axis_length)
    ]
    # Intersections snapped to the closest grid point of the volume
    snapped_intersections = [
        snap_coordinates(point) for point in free_floating_intersections
    ]

    # Mock up a plane for each snapped intersection
    snapped_planes = [
        vedo.Plane(
            pos=np.squeeze(point), normal=alignment_normal, s=(1_000_000, 1_000_000)
        )
        for point in snapped_intersections
    ]

    # Mock up a plane for each Z-index of the stack
    # TODO: Obtain the real spacing
    z_distance = 200
    stack_spacing = z_distance / alignment_settings.volume_settings.resolution.value
    stack_planes = [
        vedo.Plane(
            pos=alignment_origin + i * alignment_normal * stack_spacing,
            normal=alignment_normal,
            s=(1_000_000, 1_000_000),
        )
        for i in range(-image_stack.shape[0] // 2 + 1, image_stack.shape[0] // 2 + 1)
    ]

    # Group Z-indices based on closest snapped intersection
    groups = compute_closest_plane(stack_planes, snapped_planes)

    # Find the coordinates of the Z indices
    coordinates = np.array(snapped_intersections)
    sub_projection_coordinates = []
    previous_group = None
    for index, group in enumerate(groups):
        if group == previous_group:
            continue
        previous_group = group

        sub_projection_coordinates.append(coordinates[group])

    # Sub-project
    sub_projections = sub_project_image_stack(image_stack, groups)

    return {
        tuple(np.squeeze(sub_projection_coordinates[i])): sub_projections[i]
        for i in range(len(sub_projection_coordinates))
    }


def snap_array_to_grid(
    image_array: np.ndarray, alignment_settings: AlignmentSettings
) -> dict[CoordinatesTuple, Projection]:
    """Snaps a 2D or 3D array to a grid.

    Args:
        image_array (np.ndarray): Array to snap. This can be a single image (2D) or a stack (3D).
        alignment_settings (AlignmentSettings): Settings used for the alignment.

    Returns:
        dict[CoordinatesTuple, Projection]:
            A dictionary mapping grid coordinates to a sub-projection. In the case of a 2D image,
            the dictionary has one key and one value: the alignment origin rounded to the
            nearest grid point. In the case of a stack, each image of the stack is snapped
            to the closest grid coordinate. When multiple images are snapped to the same
            grid point, their maximum intensity projection is taken.

    Raises:
        ValueError: If `image_array` is neither 2D nor 3D.
    """
    dimension_count = image_array.ndim
    if dimension_count not in (2, 3):
        raise ValueError(
            f"Unexpected shape of image array. Expected 2 or 3 dimensions, "
            f"got {dimension_count}."
        )

    if dimension_count == 3:
        # Z-stacks require a lot more work
        return _snap_stack_to_grid(image_array, alignment_settings)

    alignment_origin = compute_origin_from_orientation(
        tuple((np.array(alignment_settings.volume_settings.shape) - 1) / 2),
        alignment_settings.volume_settings,
    )
    # Round to the nearest grid point, consistent with how stack images are snapped
    # through `snap_coordinates`. Truncating with `int` alone would shift the origin
    # towards zero by up to one voxel.
    snapped_origin = snap_coordinates(np.asarray(alignment_origin, dtype=float))

    return {tuple(int(value) for value in np.ravel(snapped_origin)): image_array}


def undo_padding(image: np.ndarray, mesh: vedo.Mesh) -> np.ndarray:
    """Crops away the padding added around an image during registration.

    Args:
        image (np.ndarray): Registered (padded) image.
        mesh (vedo.Mesh):
            Mesh whose metadata stores the padding under `i_padding` and `j_padding`.
            Element 0 of each is the amount removed from the start of the first or
            second image axis, and element 1 the amount removed from its end. See
            `histalign.backend.registration.Registrator`.

    Returns:
        np.ndarray: A slice of `image` without the padding.
    """
    i_before, i_after = mesh.metadata["i_padding"][0], mesh.metadata["i_padding"][1]
    j_before, j_after = mesh.metadata["j_padding"][0], mesh.metadata["j_padding"][1]

    rows = slice(i_before, image.shape[0] - i_after)
    columns = slice(j_before, image.shape[1] - j_after)

    return image[rows, columns]


def get_plane_from_2d_image(
    image: np.ndarray,
    alignment_settings: AlignmentSettings,
    slicer: VolumeSlicer,
    origin: Optional[list[float]] = None,
) -> vedo.Mesh:
    """Builds a plane mesh carrying an image's registered pixel values.

    The image is processed in four steps:
    1. Forward it through the registration.
    2. Rotate it in-plane depending on the orientation.
    3. Crop away the registration padding.
    4. Write it to the `ImageScalars` point data of a plane sliced from `slicer`.

    Args:
        image (np.ndarray): Image whose values populate the plane.
        alignment_settings (AlignmentSettings): Settings used for the alignment.
        slicer (VolumeSlicer): Volume slicer from which to obtain the plane.
        origin (Optional[list[float]], optional):
            Origin passed to both the registration and the slicer. If not provided,
            the centre of the volume along the non-orientation axes is used (e.g.,
            centre along YZ when working coronally).

    Returns:
        vedo.Mesh:
            The sliced plane, whose `ImageScalars` point data holds the flattened
            registered image.
    """
    forwarded = Registrator(True, True).get_forwarded_image(
        image, alignment_settings, origin
    )

    volume_settings = alignment_settings.volume_settings
    orientation = volume_settings.orientation

    # Horizontal slices get an extra quarter turn.
    if orientation == Orientation.HORIZONTAL:
        forwarded = ndimage.rotate(forwarded, 90, reshape=False)
    # Counter-rotate by the pitch, except for sagittal slices.
    if orientation != Orientation.SAGITTAL:
        forwarded = ndimage.rotate(forwarded, -volume_settings.pitch, reshape=False)

    plane = slicer.slice(volume_settings, origin=origin, return_mesh=True)
    plane.pointdata["ImageScalars"] = undo_padding(forwarded, plane).flatten()

    return plane


def replace_path_parts(
    path: Path,
    channel_index: Optional[int] = None,
    channel_regex: Optional[str] = None,
    projection_regex: Optional[str] = None,
    misc_regexes: Optional[list[str]] = None,
    misc_subs: Optional[list[str]] = None,
) -> Path:
    """Extracts the original file name given the channel, Z indices, and optional parts.

    Careful not to trust the output of this function blindly if obtained from external
    input as `misc_regexes` and `misc_subs` can potentially replace the whole path.

    Args:
        path (Path): Path to remove parts on.
        channel_index (Optional[int], optional):
            Channel index to use in the returned path.
        channel_regex (Optional[str], optional):
            Channel regex identifying the channel part of `path`'s name.
        projection_regex (Optional[str], optional):
            Projection regex identifying the projection part of `path`'s name.
        misc_regexes (Optional[list[str]], optional):
            Miscellaneous regex identifying an extra part of alignment paths found in
            `alignment_directory`.
        misc_subs (Optional[list[str]], optional):
            Substitutions to replace `misc_regexes` with. This should have as many
            elements as `misc_regexes`.

    Returns:
        Path: The path with the parts removed.

    Examples:
        > replace_path_parts(Path("/data/filename_C0_max.h5"), 1, r"_C\d_", "_max")
        Path('/data/filename_C0.h5')
    """
    # Replace the channel index
    if channel_regex is not None:
        if channel_index is not None:
            path = path.with_name(
                re.sub(
                    channel_regex,
                    channel_regex.replace(r"\d", str(channel_index)),
                    path.name,
                    count=1,
                )
            )

    # Remove part of the file name that indicates the projection
    if projection_regex is not None:
        path = path.with_name(
            re.sub(projection_regex, "", path.name, count=1),
        )

    # Replace the miscellaneous parts
    if misc_regexes is not None and misc_subs is not None:
        if (len1 := len(misc_regexes)) != (len2 := len(misc_subs)):
            _module_logger.error(
                f"Received different numbers of misc regex and subs "
                f"({len1} vs {len2}). Skipping miscellaneous replacement."
            )
        else:
            for i in range(len(misc_regexes)):
                path = Path(
                    re.sub(misc_regexes[i], misc_subs[i], str(path)),
                )

    return path


def generate_aligned_planes(
    alignment_volume: Volume | vedo.Volume,
    alignment_paths: list[Path],
    channel_index: Optional[int] = None,
    channel_regex: Optional[str] = None,
    projection_regex: Optional[str] = None,
    misc_regexes: Optional[str | list[str]] = None,
    misc_subs: Optional[str | list[str]] = None,
) -> list[vedo.Mesh]:
    """Generates aligned planes for each image (2D or 3D) from the alignment paths.

    The histology path stored in each alignment settings file is rewritten with
    `replace_path_parts`. Files that do not exist after rewriting are logged and
    skipped.

    Args:
        alignment_volume (Volume | vedo.Volume):
            Volume to generate planes for. It is wrapped in a `VolumeSlicer`.
        alignment_paths (list[Path]):
            List of alignment settings paths to use when reconstructing the planes.
        channel_index (Optional[int], optional):
            Channel index to use when retrieving the original files.
        channel_regex (Optional[str], optional):
            Channel regex identifying the channel part of the histology file names
            referenced by the alignment settings.
        projection_regex (Optional[str], optional):
            Projection regex identifying the projection part of the histology file
            names referenced by the alignment settings.
        misc_regexes (Optional[str | list[str]], optional):
            Miscellaneous regex(es) identifying extra parts of the histology paths
            referenced by the alignment settings.
        misc_subs (Optional[str | list[str]], optional):
            Substitutions to replace `misc_regexes` with. This should have as many
            elements as `misc_regexes`.

    Returns:
        list[vedo.Mesh]:
            A list of all the aligned planes obtained from the alignment paths. A
            3D stack may contribute several planes.
    """
    _module_logger.debug("Starting generation of aligned planes.")

    planes = []
    slicer = VolumeSlicer(volume=alignment_volume)

    for index, alignment_path in enumerate(alignment_paths):
        if index > 0 and index % 5 == 0:
            _module_logger.debug(f"Generating plane(s) for {alignment_path.name}...")

        alignment_settings = load_alignment_settings(alignment_path)

        histology_path_with_replacement = replace_path_parts(
            alignment_settings.histology_path,
            channel_index,
            channel_regex,
            projection_regex,
            misc_regexes,
            misc_subs,
        )
        if not histology_path_with_replacement.exists():
            _module_logger.error(
                f"Could not find file '{histology_path_with_replacement}' "
                f"(original path: '{alignment_settings.histology_path}'). "
                f"Skipping it."
            )
            continue

        alignment_settings.histology_path = histology_path_with_replacement

        image_array = load_image(alignment_settings.histology_path, allow_stack=True)

        # A stack maps to several grid origins, a single image to one.
        projections_map = snap_array_to_grid(image_array, alignment_settings)
        for origin, projection in projections_map.items():
            planes.append(
                get_plane_from_2d_image(
                    projection, alignment_settings, origin=origin, slicer=slicer
                )
            )

    _module_logger.debug("Finished generating all aligned planes.")
    return planes


def insert_aligned_planes_into_array(
    array: np.ndarray,
    planes: list[vedo.Mesh],
    inplace: bool = True,
) -> np.ndarray:
    """Inserts aligned planes into a 3D numpy array.

    Each plane's point data is interpolated onto a zero volume shaped like `array`.
    The result is merged into `array` with an element-wise maximum, so overlapping
    planes keep the largest value.

    For integer arrays, interpolated values are rounded and clipped to the dtype's
    range before merging, so out-of-range values saturate instead of wrapping
    around. For other dtypes, interpolated values are cast as-is.

    Args:
        array (np.ndarray): Array to insert into.
        planes (list[vedo.Mesh]): Planes to insert.
        inplace (bool, optional): Whether to modify `array` in-place.

    Returns:
        np.ndarray:
            `array` if `inplace` is `True`, else a copy, with the planes inserted.
    """
    _module_logger.debug(
        f"Starting insertion of {len(planes)} planes into alignment array."
    )

    if not inplace:
        array = array.copy()

    is_integer = np.issubdtype(array.dtype, np.integer)

    for index, plane in enumerate(planes):
        if index > 0 and index % 5 == 0:
            _module_logger.debug(f"Inserted {index} planes into alignment array...")

        # Fresh zero volume so only the current plane's data is interpolated.
        temporary_volume = vedo.Volume(np.zeros_like(array))
        temporary_volume.interpolate_data_from(plane, radius=1)
        interpolated = temporary_volume.tonumpy()

        if is_integer:
            limits = np.iinfo(array.dtype)
            interpolated = np.clip(np.round(interpolated), limits.min, limits.max)
        interpolated = interpolated.astype(array.dtype, copy=False)

        # Writing through `out` keeps `array` (possibly a view on a vedo volume)
        # as the storage being updated.
        np.maximum(array, interpolated, out=array)

    _module_logger.debug(
        f"Finished inserting all {len(planes)} planes into alignment array."
    )

    return array


def build_aligned_volume(
    alignment_directory: str | Path,
    allow_cache_load: bool = True,
    allow_cache_save: bool = True,
    return_raw_array: bool = False,
    channel_index: Optional[int] = None,
    channel_regex: Optional[str] = None,
    projection_regex: Optional[str] = None,
    misc_regexes: Optional[str | list[str]] = None,
    misc_subs: Optional[str | list[str]] = None,
) -> np.ndarray | vedo.Volume:
    """Builds an aligned volume from alignment settings.

    Args:
        alignment_directory (str | Path):
            Directory from which the alignment settings paths are gathered (see
            `gather_alignment_paths`).
        allow_cache_load (bool, optional):
            Whether to use the cache for loading. If `False`, the entire volume is
            guaranteed to be built from scratch. If `True`, the cache is queried for
            an existing volume with the same alignment paths and the same path
            replacement arguments. These are `channel_index`, `channel_regex`,
            `projection_regex`, `misc_regexes` and `misc_subs`. If such a volume
            exists, it is returned. Otherwise, the volume is built as normal.
        allow_cache_save (bool, optional):
            Whether to save the built volume to the cache.
        return_raw_array (bool, optional):
            Whether to return a numpy array (`True`) or a vedo volume (`False`).
        channel_index (Optional[int], optional):
            Channel index to use when retrieving the original files.
        channel_regex (Optional[str], optional):
            Channel regex identifying the channel part of the histology file names
            referenced by the alignment settings.
        projection_regex (Optional[str], optional):
            Projection regex identifying the projection part of the histology file
            names referenced by the alignment settings.
        misc_regexes (Optional[str | list[str]], optional):
            Miscellaneous regex(es) identifying extra parts of the histology paths
            referenced by the alignment settings. A single string is treated as a
            one-element list.
        misc_subs (Optional[str | list[str]], optional):
            Substitutions to replace `misc_regexes` with. This should have as many
            elements as `misc_regexes`. A single string is treated as a one-element
            list.

    Returns:
        np.ndarray | vedo.Volume: The aligned volume or 3D array.

    Raises:
        ValueError: If no alignment settings are found in `alignment_directory`.
    """
    _module_logger.debug("Starting build of aligned volume.")

    if channel_regex is not None and channel_index is None:
        _module_logger.warning(
            "Received channel regex but no channel index. Building alignment "
            "volume using the same channel as was used for alignment."
        )
    elif channel_regex is None and channel_index is not None:
        _module_logger.warning(
            "Received channel index but no channel regex. Building alignment "
            "volume using the same channel as was used for alignment."
        )

    # Normalise single strings so they are not iterated character by character
    # downstream, and so "x" and ["x"] share a cache entry.
    if isinstance(misc_regexes, str):
        misc_regexes = [misc_regexes]
    if isinstance(misc_subs, str):
        misc_subs = [misc_subs]

    alignment_directory = Path(alignment_directory)

    alignment_paths = gather_alignment_paths(alignment_directory)
    if not alignment_paths:
        raise ValueError("Cannot build aligned volume from empty alignment directory.")

    cache_key = generate_hash_from_targets(alignment_paths)
    # The replacement arguments change which histology files are read, so they are
    # part of the cache key. When none is set, the plain hash is used so that
    # caches built without replacements keep their file names.
    replacement_arguments = (
        channel_index,
        channel_regex,
        projection_regex,
        misc_regexes,
        misc_subs,
    )
    if any(argument is not None for argument in replacement_arguments):
        arguments_digest = hashlib.sha256(
            repr(replacement_arguments).encode("utf-8")
        ).hexdigest()[:10]
        cache_key = f"{cache_key}_{arguments_digest}"
    cache_path = ALIGNMENT_VOLUMES_CACHE_DIRECTORY / f"{cache_key}.npz"

    if allow_cache_load and cache_path.exists():
        _module_logger.debug("Found cached aligned volume. Loading from file.")

        # Use the archive as a context manager so its file handle is closed.
        with np.load(cache_path) as cache_file:
            array = cache_file["array"]
        if return_raw_array:
            return array
        return vedo.Volume(array)

    reference_shape = load_alignment_settings(alignment_paths[0]).volume_settings.shape

    # Volume needs to be created before array as vedo makes a copy
    aligned_volume = vedo.Volume(np.zeros(shape=reference_shape, dtype=np.uint16))
    aligned_array = aligned_volume.tonumpy()

    planes = generate_aligned_planes(
        aligned_volume,
        alignment_paths,
        channel_index,
        channel_regex,
        projection_regex,
        misc_regexes,
        misc_subs,
    )

    insert_aligned_planes_into_array(aligned_array, planes)
    aligned_volume.modified()  # Probably unnecessary but good practice

    if allow_cache_save:
        _module_logger.debug("Caching volume to file as a NumPy array.")
        os.makedirs(ALIGNMENT_VOLUMES_CACHE_DIRECTORY, exist_ok=True)
        np.savez_compressed(cache_path, array=aligned_array)

    if return_raw_array:
        return aligned_array
    return aligned_volume


def imshow(
    image: np.ndarray,
    cmap: Optional[str] = None,
    title: str = "",
    colorbar: bool = False,
    full_range: bool = False,
    vmin: Optional[int | float] = None,
    vmax: Optional[int | float] = None,
    figsize: Optional[tuple[int, int]] = None,
    tight: bool = True,
) -> None:
    """Displays an image in a new matplotlib figure with the axes hidden.

    Args:
        image (np.ndarray): Image to display.
        cmap (Optional[str], optional): Matplotlib colormap name.
        title (str, optional): Figure title. Nothing is drawn when empty.
        colorbar (bool, optional): Whether to draw a colorbar next to the image.
        full_range (bool, optional):
            Whether to stretch the colour scale over the full range representable by
            `image.dtype`. When `True`, this overrides `vmin` and `vmax`.
        vmin (Optional[int | float], optional): Lower limit of the colour scale.
        vmax (Optional[int | float], optional): Upper limit of the colour scale.
        figsize (Optional[tuple[int, int]], optional): Matplotlib figure size.
        tight (bool, optional): Whether to apply a tight layout before showing.

    Raises:
        ValueError:
            If `full_range` is `True` and `image.dtype` is neither an integer nor an
            inexact (floating/complex) type.
    """
    figure, axes = plt.subplots(figsize=figsize)

    if full_range:
        dtype_info = (
            np.iinfo(image.dtype)
            if np.issubdtype(image.dtype, np.integer)
            else np.finfo(image.dtype)
        )
        vmin, vmax = dtype_info.min, dtype_info.max

    # Keep `image` as the data; `rendered` is the matplotlib artist.
    rendered = axes.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
    axes.set_axis_off()

    if title:
        figure.suptitle(title)
    if colorbar:
        figure.colorbar(rendered, ax=axes)

    if tight:
        figure.tight_layout()
    plt.show()


def show(
    volumes: vedo.CommonVisual | list[vedo.CommonVisual],
    camera: dict[str, Any] | None = None,
) -> None:
    """Opens an interactive vedo plotter showing one or more objects.

    Args:
        volumes (vedo.CommonVisual | list[vedo.CommonVisual]):
            Object, or list of objects, to add to the scene.
        camera (dict[str, Any] | None, optional):
            Camera settings forwarded to `vedo.Plotter.show`.
    """
    visuals = [volumes] if isinstance(volumes, vedo.CommonVisual) else volumes

    plotter = vedo.Plotter(axes=3, interactive=True, screensize=(1920, 1080))
    for visual in visuals:
        plotter.add(visual)

    plotter.show(camera=camera)

In [None]:
# Example: switch to channel 0, drop the "_max" projection suffix and rewrite parts
# of the directory path.
replace_path_parts(
    Path("/data/dataset1/filename_C2_max.h5"),
    channel_index=0,
    channel_regex=r"C\d",
    projection_regex="_max",
    misc_regexes=["data/", "dataset"],
    misc_subs=["my_data/", "experiment"],
)

In [None]:
# Directory holding the alignment settings to rebuild from. Set the
# HISTALIGN_ALIGNMENT_DIRECTORY environment variable instead of editing this path.
alignment_directory = os.environ.get(
    "HISTALIGN_ALIGNMENT_DIRECTORY",
    "/home/ediun/git/histalign/projects/project_z_stack/1ebb0c1a09",
)
# Force a rebuild (ignore any cached volume) and strip the "_max" projection suffix
# to read the original Z stacks instead of their projections.
aligned_volume = build_aligned_volume(
    alignment_directory,
    allow_cache_load=False,
    projection_regex="_max",
)

In [None]:
# Colour the aligned volume red. NOTE(review): the alpha values lie outside [0, 1];
# presumably intentional to saturate opacity — confirm against vedo's `cmap` docs.
aligned_volume.cmap(c="red", alpha=[-0.5, 5])
# Interpolation type 0 (presumably nearest neighbour in vedo/VTK — confirm).
aligned_volume.interpolation(0)
vedo.show(aligned_volume)

In [None]:
# Index along the last array axis of the aligned volume to preview.
SLICE_INDEX = 77
imshow(
    aligned_volume.tonumpy()[..., SLICE_INDEX].T,
    title=f"Aligned volume, index {SLICE_INDEX} of the last axis",
)