# Full test

Now that hyperparameters have been tuned for the best approximation of the problem, let's put it into practice on a brain (or an hemisphere).

In [1]:
from helper_loader import *

In [2]:
volume_path = Path("resources/d2783eb27223868f57e159f01ce7a1b2.npz")

In [3]:
atlas_array = load_volume(
    get_structure_mask_path("root", Resolution.MICRONS_25),
    return_raw_array=True,
)
masked_array = np.where(atlas_array, np.load(volume_path)["array"], 0)

In [4]:
known_coordinates = np.nonzero(masked_array)
known_points = np.array(known_coordinates).T

known_values = masked_array[known_coordinates]

target_coordinates = np.nonzero(atlas_array)
target_points = np.array(target_coordinates).T

In [5]:
interpolator = RBFInterpolator(
    known_points,
    known_values,
    kernel="multiquadric",
    neighbors=27,
    epsilon=1,
)

In [20]:
chunk_size = 1_000_000
chunk_count = math.ceil(np.prod(target_points.shape[0]) / chunk_size)
chunk_index = 1

failed_chunks = []

interpolated_array = np.zeros_like(atlas_array, dtype=np.float64)

In [21]:
chunk_start = 0
chunk_end = chunk_start + chunk_size
while True:
    if chunk_start >= target_points.shape[0]:
        break

    logging.info(f"Computing chunk {chunk_index}/{chunk_count}.")

    chunk = target_points[chunk_start:chunk_end]
    coordinates = tuple(
        coordinates[chunk_start:chunk_end] for coordinates in target_coordinates
    )

    try:
        interpolated_array[coordinates] = interpolator(chunk)
    except np.linalg.LinAlgError:
        failed_chunks.append([chunk_start, chunk_end])
        logging.info("Failed to interpolate chunk.")

    chunk_start += chunk_size
    chunk_end += chunk_size

    chunk_index += 1

[2024-09-23 10:52:35] - [    INFO ] - Computing chunk 1/33. (root:7)
[2024-09-23 10:52:54] - [    INFO ] - Computing chunk 2/33. (root:7)
[2024-09-23 10:53:18] - [    INFO ] - Computing chunk 3/33. (root:7)
[2024-09-23 10:53:46] - [    INFO ] - Computing chunk 4/33. (root:7)
[2024-09-23 10:54:17] - [    INFO ] - Computing chunk 5/33. (root:7)
[2024-09-23 10:54:49] - [    INFO ] - Computing chunk 6/33. (root:7)
[2024-09-23 10:55:23] - [    INFO ] - Computing chunk 7/33. (root:7)
[2024-09-23 10:55:58] - [    INFO ] - Computing chunk 8/33. (root:7)
[2024-09-23 10:56:33] - [    INFO ] - Computing chunk 9/33. (root:7)
[2024-09-23 10:57:09] - [    INFO ] - Computing chunk 10/33. (root:7)
[2024-09-23 10:57:45] - [    INFO ] - Computing chunk 11/33. (root:7)
[2024-09-23 10:58:20] - [    INFO ] - Computing chunk 12/33. (root:7)
[2024-09-23 10:58:56] - [    INFO ] - Computing chunk 13/33. (root:7)
[2024-09-23 10:59:33] - [    INFO ] - Computing chunk 14/33. (root:7)
[2024-09-23 11:00:09] - [    

In [34]:
volume = vedo.Volume(interpolated_array[..., : interpolated_array.shape[2] // 2])
volume.cmap(**get_cmap(volume, vmin=0, vmax=interpolated_array.max()))
show(volume)

In [67]:
from ipywidgets import interact


def update(index: int = 0):
    imshow(interpolated_array[..., index].T)


interact(update, index=(0, interpolated_array.shape[2] // 2));

interactive(children=(IntSlider(value=0, description='index', max=228), Output()), _dom_classes=('widget-inter…

In [73]:
def interpolate_3d_array(
    array: np.ndarray,
    reference_mask: Optional[np.ndarray] = None,
    pre_masked: bool = False,
    kernel: str = "multiquadric",
    neighbours: int = 27,
    epsilon: int = 1,
    chunk_size: Optional[int] = 100_000,
    recursive: bool = False,
) -> np.ndarray:
    if reference_mask is not None and (array_shape := array.shape) != (
        reference_shape := reference_mask.shape
    ):
        raise ValueError(
            f"Array and reference mask have different shapes "
            f"({array_shape} vs {reference_shape})."
        )

    # Mask the array
    if reference_mask is not None and not pre_masked:
        array = np.where(reference_mask, array, 0)

    interpolated_array = array.copy()

    if reference_mask is None:
        # Interpolate the whole grid
        target_coordinates = tuple(
            array.flatten().astype(int)
            for array in np.meshgrid(
                np.linspace(
                    0, interpolated_array.shape[0] - 1, interpolated_array.shape[0]
                ),
                np.linspace(
                    0, interpolated_array.shape[1] - 1, interpolated_array.shape[1]
                ),
                np.linspace(
                    0, interpolated_array.shape[2] - 1, interpolated_array.shape[2]
                ),
                indexing="ij",
            )
        )
    else:
        # Interpolate only non-zero coordinates of mask
        target_coordinates = np.nonzero(reference_mask)
    target_points = np.array(target_coordinates).T

    if chunk_size is None:
        chunk_size = target_points.shape[0]

    logging.info("Starting interpolation.")

    failed_chunks = []
    previous_target_size = target_points.shape[0]
    while True:
        known_coordinates = np.nonzero(interpolated_array)
        known_points = np.array(known_coordinates).T

        known_values = array[known_coordinates]

        interpolator = RBFInterpolator(
            known_points,
            known_values,
            kernel=kernel,
            neighbors=neighbours,
            epsilon=epsilon,
        )

        chunk_start = 0
        chunk_end = chunk_size
        chunk_index = 1
        chunk_count = math.ceil(target_points.shape[0] / chunk_size)
        while chunk_start < target_points.shape[0]:
            logging.info(f"Interpolating chunk {chunk_index}/{chunk_count}.")

            chunk = target_points[chunk_start:chunk_end]
            coordinates = tuple(
                coordinate[chunk_start:chunk_end] for coordinate in target_coordinates
            )

            try:
                interpolated_array[coordinates] = interpolator(chunk)
            except np.linalg.LinAlgError:
                failed_chunks.append([chunk_start, chunk_end])
                logging.info(f"Failed to interpolate chunk {chunk_index}.")

            chunk_start += chunk_size
            chunk_end += chunk_size
            chunk_index += 1

        if not recursive or len(failed_chunks) == 0:
            break

        # Prepare the next loop
        target_coordinates = tuple(
            np.concatenate(
                [target_coordinate[start:end] for start, end in failed_chunks]
            )
            for target_coordinate in target_coordinates
        )
        target_points = np.array(target_coordinates).T
        failed_chunks = []

        # Avoid infinitely looping
        if previous_target_size == target_coordinates.shape[0]:
            logging.error(
                f"Interpolation is not fully solvable with current combination of"
                f"neighbours parameter and chunk size."
            )
            break
        previous_target_size = target_coordinates.shape[0]

        logging.info(
            f"There were {len(failed_chunks)} failed chunks. "
            f"Recursing with newly interpolated data."
        )

    return interpolated_array

In [74]:
custom_interpolated_array = interpolate_3d_array(
    masked_array,
    reference_mask=atlas_array,
    pre_masked=True,
    recursive=True,
)

[2024-09-23 12:19:53] - [    INFO ] - Starting interpolation. (root:50)
[2024-09-23 12:19:55] - [    INFO ] - Interpolating chunk 1/324. (root:73)
[2024-09-23 12:19:58] - [    INFO ] - Interpolating chunk 2/324. (root:73)
[2024-09-23 12:20:00] - [    INFO ] - Interpolating chunk 3/324. (root:73)
[2024-09-23 12:20:02] - [    INFO ] - Interpolating chunk 4/324. (root:73)
[2024-09-23 12:20:04] - [    INFO ] - Interpolating chunk 5/324. (root:73)
[2024-09-23 12:20:06] - [    INFO ] - Interpolating chunk 6/324. (root:73)
[2024-09-23 12:20:08] - [    INFO ] - Interpolating chunk 7/324. (root:73)
[2024-09-23 12:20:10] - [    INFO ] - Interpolating chunk 8/324. (root:73)
[2024-09-23 12:20:12] - [    INFO ] - Interpolating chunk 9/324. (root:73)
[2024-09-23 12:20:14] - [    INFO ] - Interpolating chunk 10/324. (root:73)


KeyboardInterrupt: 