In [None]:
import hashlib

from helper_loader import *
from histalign.backend.registration.alignment import (
    INTERPOLATED_VOLUMES_CACHE_DIRECTORY,
    _module_logger,
    build_alignment_volume,
)

# DEBUG level so debug-level messages, such as cache hits
# ("Found cached volume. Loading from file."), appear in cell output.
set_log_level("DEBUG")

In [6]:
# Alignment project directory the volume is built from.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
alignment_path = (
    Path("/home/ediun/git/histalign/projects") / "project_cortical_depth" / "93e6cae680"
)

# As the captured output shows, a previously cached volume is loaded from file
# when one exists. return_raw_array=True presumably yields a plain NumPy array
# (it is passed straight to np.where / astype below) — TODO confirm.
array = build_alignment_volume(
    alignment_path,
    return_raw_array=True,
)

[2024-11-06 16:32:06] - [   DEBUG ] - Found cached volume. Loading from file. (histalign.backend.registration.alignment:55)


In [7]:
def generate_hash_from_targets(targets: list[Path]) -> str:
    """Return a deterministic MD5 hex digest fingerprinting a list of paths.

    The string forms of ``targets`` are concatenated in order, with no
    separator, and the result is hashed. The digest therefore fingerprints the
    concatenated text: ``["a", "bc"]`` and ``["abc"]`` produce the same value.
    In this notebook the digest is used as a cache-file key, not for security.

    Args:
        targets: Paths to fingerprint; their order matters.

    Returns:
        The 32-character lowercase hexadecimal MD5 digest.
    """
    joined = "".join(str(target) for target in targets)
    # usedforsecurity=False keeps this working on FIPS-restricted OpenSSL
    # builds, where a plain hashlib.md5() call is refused. The digest value
    # is identical either way.
    return hashlib.md5(joined.encode("UTF-8"), usedforsecurity=False).hexdigest()


def interpolate_sparse_3d_array(
    array: np.ndarray,
    reference_mask: Optional[np.ndarray] = None,
    pre_masked: bool = False,
    kernel: str = "multiquadric",
    neighbours: int = 27,
    epsilon: int = 1,
    degree: Optional[int] = None,
    chunk_size: Optional[int] = 1_000_000,
    recursive: bool = False,
    use_cache: bool = False,
    alignment_directory: str | Path = "",
    mask_name: str = "",
) -> np.ndarray:
    """Fill the zero-valued voxels of a sparse volume by RBF interpolation.

    Non-zero voxels are treated as known samples. An ``RBFInterpolator`` built
    on them is evaluated at the target voxels: the whole grid when
    ``reference_mask`` is None, otherwise only voxels where the mask is
    non-zero. Targets are evaluated in chunks of ``chunk_size`` points.

    A chunk that raises ``np.linalg.LinAlgError`` is skipped. When
    ``recursive`` is True, skipped chunks are retried with an interpolator
    rebuilt from every non-zero value obtained so far. Retrying stops when
    nothing fails or a pass makes no progress.

    Args:
        array: Sparse input volume; zero means "unknown".
        reference_mask: Optional array of the same shape as ``array``. Unless
            ``pre_masked``, ``array`` is zeroed outside it. Interpolation
            targets are its non-zero voxels.
        pre_masked: Skip zeroing ``array`` outside ``reference_mask``.
        kernel: Forwarded to ``RBFInterpolator``.
        neighbours: Forwarded to ``RBFInterpolator`` as ``neighbors``.
        epsilon: Forwarded to ``RBFInterpolator``.
        degree: Forwarded to ``RBFInterpolator``.
        chunk_size: Target points evaluated per interpolator call. None
            evaluates all targets in a single call.
        recursive: Retry failed chunks with the newly interpolated data.
        use_cache: Load the result from, and save it to,
            ``INTERPOLATED_VOLUMES_CACHE_DIRECTORY``.
        alignment_directory: Identifies the cache entry, via the files
            ``gather_alignment_paths`` returns for it. Required with
            ``use_cache``.
        mask_name: Identifies the mask in the cache file name. Required with
            ``use_cache`` when ``reference_mask`` is given.

    Returns:
        A float64 array with the same shape as ``array``, or the cached array
        when ``use_cache`` finds one.

    Raises:
        ValueError: If cache-identifying information is missing, the shapes of
            ``array`` and ``reference_mask`` differ, or ``chunk_size`` is not
            positive.
    """
    start_time = time.perf_counter()

    if use_cache and not alignment_directory:
        raise ValueError(
            "Cannot use cache without 'alignment_directory' identifying information."
        )
    if use_cache and reference_mask is not None and not mask_name:
        raise ValueError(
            "Cannot use cache with reference mask but no 'mask_name' "
            "identifying information."
        )
    if chunk_size is not None and chunk_size < 1:
        # A zero or negative step would never advance through the targets.
        raise ValueError(f"'chunk_size' must be a positive integer (got {chunk_size}).")
    alignment_directory = Path(alignment_directory)

    if reference_mask is not None and array.shape != reference_mask.shape:
        raise ValueError(
            f"Array and reference mask have different shapes "
            f"({array.shape} vs {reference_mask.shape})."
        )

    # Only touch the alignment files when caching is requested. With
    # use_cache=False, alignment_directory may be "" (i.e. the cwd).
    cache_path = None
    if use_cache:
        cache_path = _interpolation_cache_path(
            alignment_directory, mask_name if reference_mask is not None else None
        )
        if cache_path.exists():
            _module_logger.debug("Found cached array. Loading from file.")
            # The context manager closes the underlying .npz file handle.
            with np.load(cache_path) as cached:
                return cached["array"]

    # Mask the array if necessary.
    if reference_mask is not None and not pre_masked:
        array = np.where(reference_mask, array, 0)

    # astype returns a new array, so the caller's data is never modified.
    interpolated_array = array.astype(np.float64)

    if reference_mask is None:
        # Every voxel of the grid, in C (row-major) order.
        target_coordinates = tuple(
            np.indices(interpolated_array.shape, dtype=int).reshape(
                interpolated_array.ndim, -1
            )
        )
    else:
        # Only the non-zero voxels of the mask.
        target_coordinates = np.nonzero(reference_mask)
    target_points = np.array(target_coordinates).T

    if chunk_size is None:
        # max(..., 1) keeps the chunk arithmetic valid when there are no targets.
        chunk_size = max(target_points.shape[0], 1)

    _module_logger.info(
        f"Starting interpolation with parameters "
        f"{{"
        f"kernel: {kernel}, "
        f"neighbours: {neighbours}, "
        f"epsilon: {epsilon}, "
        f"degree: {degree}, "
        f"chunk size: {chunk_size:,}, "
        f"recursive: {recursive}"
        f"}}."
    )

    previous_target_size = target_points.shape[0]
    while True:
        # Known samples are all non-zero values so far: the input data on the
        # first pass, plus successfully interpolated values on retry passes.
        known_coordinates = np.nonzero(interpolated_array)
        interpolator = RBFInterpolator(
            np.array(known_coordinates).T,
            interpolated_array[known_coordinates],
            kernel=kernel,
            neighbors=neighbours,
            epsilon=epsilon,
            degree=degree,
        )

        failed_chunks = []
        target_count = target_points.shape[0]
        chunk_count = math.ceil(target_count / chunk_size)
        for chunk_index, chunk_start in enumerate(
            range(0, target_count, chunk_size), start=1
        ):
            chunk_end = chunk_start + chunk_size
            _module_logger.info(
                f"Interpolating chunk {chunk_index}/{chunk_count} "
                f"({chunk_index / chunk_count:.0%})."
            )

            chunk_coordinates = tuple(
                coordinate[chunk_start:chunk_end] for coordinate in target_coordinates
            )
            try:
                interpolated_array[chunk_coordinates] = interpolator(
                    target_points[chunk_start:chunk_end]
                )
            except np.linalg.LinAlgError:
                failed_chunks.append((chunk_start, chunk_end))
                _module_logger.info(f"Failed to interpolate chunk {chunk_index}.")

        if not recursive or not failed_chunks:
            break

        # Next pass only targets the points of the chunks that failed.
        target_coordinates = tuple(
            np.concatenate([coordinate[start:end] for start, end in failed_chunks])
            for coordinate in target_coordinates
        )
        target_points = np.array(target_coordinates).T

        # Avoid infinitely looping when a pass makes no progress.
        if previous_target_size == target_points.shape[0]:
            _module_logger.error(
                f"Interpolation is not fully solvable with current combination of "
                f"kernel, neighbours parameter and chunk size. "
                f"Returning current result."
            )
            break
        previous_target_size = target_points.shape[0]

        _module_logger.info(
            f"There were {len(failed_chunks)} failed chunks of size {chunk_size}. "
            f"Recursing with newly interpolated data."
        )

    elapsed = _format_duration(time.perf_counter() - start_time)
    _module_logger.info(f"Finished interpolation in {elapsed}.")

    if cache_path is not None:
        _module_logger.debug("Caching interpolated array to file.")
        os.makedirs(INTERPOLATED_VOLUMES_CACHE_DIRECTORY, exist_ok=True)
        np.savez_compressed(cache_path, array=interpolated_array)

    return interpolated_array


def _interpolation_cache_path(
    alignment_directory: Path, mask_name: Optional[str]
) -> Path:
    # Cache file keyed on the alignment files, plus the normalised mask name if given.
    cache_hash = generate_hash_from_targets(gather_alignment_paths(alignment_directory))
    suffix = ""
    if mask_name is not None:
        suffix = "_" + "-".join(mask_name.split(" ")).lower()
    return INTERPOLATED_VOLUMES_CACHE_DIRECTORY / f"{cache_hash}{suffix}.npz"


def _format_duration(seconds: float) -> str:
    # Render e.g. "1h 5m 3s": hours/minutes omitted when zero, minutes/seconds width-2.
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return (
        f"{f'{hours:.0f}h' if hours else ''}"
        f"{f'{minutes:>2.0f}m' if minutes else ''}"
        f"{seconds:>2.0f}s"
    )

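Partial progress log for `neighbours=16` without a reference mask (78 chunks of 1,000,000 points; only the first 17 chunks are shown):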
```text
[2024-11-06 16:19:08] - [    INFO ] - Starting interpolation with parameters {kernel: multiquadric, neighbours: 16, epsilon: 1, degree: None, chunk size: 1,000,000, recursive: False}. (histalign.backend.registration.alignment:84)
[2024-11-06 16:19:09] - [    INFO ] - Interpolating chunk 1/78 (1%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:19:33] - [    INFO ] - Interpolating chunk 2/78 (3%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:19:56] - [    INFO ] - Interpolating chunk 3/78 (4%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:20:21] - [    INFO ] - Interpolating chunk 4/78 (5%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:20:47] - [    INFO ] - Interpolating chunk 5/78 (6%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:21:16] - [    INFO ] - Interpolating chunk 6/78 (8%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:21:45] - [    INFO ] - Interpolating chunk 7/78 (9%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:22:17] - [    INFO ] - Interpolating chunk 8/78 (10%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:22:50] - [    INFO ] - Interpolating chunk 9/78 (12%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:23:25] - [    INFO ] - Interpolating chunk 10/78 (13%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:24:03] - [    INFO ] - Interpolating chunk 11/78 (14%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:24:46] - [    INFO ] - Interpolating chunk 12/78 (15%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:25:32] - [    INFO ] - Interpolating chunk 13/78 (17%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:26:24] - [    INFO ] - Interpolating chunk 14/78 (18%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:27:22] - [    INFO ] - Interpolating chunk 15/78 (19%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:28:26] - [    INFO ] - Interpolating chunk 16/78 (21%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:29:38] - [    INFO ] - Interpolating chunk 17/78 (22%). (histalign.backend.registration.alignment:118)
```

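Partial progress log for `neighbours=16` with the `root` reference mask (33 chunks of 1,000,000 points; only the first 7 chunks are shown):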
```text
[2024-11-06 16:32:08] - [    INFO ] - Starting interpolation with parameters {kernel: multiquadric, neighbours: 16, epsilon: 1, degree: None, chunk size: 1,000,000, recursive: False}. (histalign.backend.registration.alignment:84)
[2024-11-06 16:32:08] - [    INFO ] - Interpolating chunk 1/33 (3%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:32:21] - [    INFO ] - Interpolating chunk 2/33 (6%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:32:41] - [    INFO ] - Interpolating chunk 3/33 (9%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:33:19] - [    INFO ] - Interpolating chunk 4/33 (12%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:34:17] - [    INFO ] - Interpolating chunk 5/33 (15%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:35:35] - [    INFO ] - Interpolating chunk 6/33 (18%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:37:02] - [    INFO ] - Interpolating chunk 7/33 (21%). (histalign.backend.registration.alignment:118)
```

In [8]:
# Load the `root` structure mask (Resolution.MICRONS_25) and use it to
# restrict interpolation to voxels inside the mask.
mask = load_structure_mask("root", Resolution.MICRONS_25)
interpolated_array = interpolate_sparse_3d_array(
    array,
    neighbours=16,
    reference_mask=mask,
    mask_name="root",
    alignment_directory=alignment_path,
)

[2024-11-06 16:32:08] - [    INFO ] - Starting interpolation with parameters {kernel: multiquadric, neighbours: 16, epsilon: 1, degree: None, chunk size: 1,000,000, recursive: False}. (histalign.backend.registration.alignment:84)
[2024-11-06 16:32:08] - [    INFO ] - Interpolating chunk 1/33 (3%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:32:21] - [    INFO ] - Interpolating chunk 2/33 (6%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:32:41] - [    INFO ] - Interpolating chunk 3/33 (9%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:33:19] - [    INFO ] - Interpolating chunk 4/33 (12%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:34:17] - [    INFO ] - Interpolating chunk 5/33 (15%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:35:35] - [    INFO ] - Interpolating chunk 6/33 (18%). (histalign.backend.registration.alignment:118)
[2024-11-06 16:37:02] - [    INFO ] - Interpolating chunk 7/33 (21%). (