# Using Dask to skip only the impossible interpolations

The idea is to use Dask (more specifically `dask.array.map_blocks()`) to keep a `LinAlgError: Singular matrix` raised by `RBFInterpolator` from aborting the whole interpolation.  
To do so, we need an intermediary function that catches the error.

In [25]:
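from functools import partial
import os
from pathlib import Path
from typing import Any

import dask.array as da
import numpy as np
from scipy.interpolate import RBFInterpolator
import vedo

# Project helpers used below: load_volume, get_structure_mask_path, Resolution, get_cmap, show.
from notebook_helpers import *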

## Getting the arrays

In [2]:
volume_path = Path(
    "/home/ediun/.local/share/histalign/alignment_volumes/d2783eb27223868f57e159f01ce7a1b2.npz"
)

In [3]:
aligned_array = np.load(volume_path)["array"]

atlas_array = load_volume(
    get_structure_mask_path("root", Resolution.MICRONS_25), return_raw_array=True
)

masked_aligned_array = np.where(atlas_array > 0, aligned_array, 0)

In [4]:
focused_area = (
    (aligned_array.shape[0] // 4, aligned_array.shape[0] // 4 * 3),
    (aligned_array.shape[1] // 4, aligned_array.shape[1] // 4 * 3),
    (0, aligned_array.shape[0] // 2),  # NOTE: third axis uses shape[0]; shape[2] may have been intended
)


aligned_array = aligned_array[
    focused_area[0][0] : focused_area[0][1],
    focused_area[1][0] : focused_area[1][1],
    focused_area[2][0] : focused_area[2][1],
]
atlas_array = atlas_array[
    focused_area[0][0] : focused_area[0][1],
    focused_area[1][0] : focused_area[1][1],
    focused_area[2][0] : focused_area[2][1],
]
masked_aligned_array = masked_aligned_array[
    focused_area[0][0] : focused_area[0][1],
    focused_area[1][0] : focused_area[1][1],
    focused_area[2][0] : focused_area[2][1],
]

## Preparing the parameters

In [5]:
known_coordinates = np.nonzero(masked_aligned_array)
known_points = np.array(known_coordinates).T

known_values = masked_aligned_array[known_coordinates]

target_coordinates = np.nonzero(atlas_array[..., : atlas_array.shape[2] // 2])
target_points = np.array(target_coordinates).T

## Initialising the interpolator

A dense `RBFInterpolator` weighs every known point when interpolating each unknown point, which would require roughly 20 TiB of RAM here. To avoid that, we use the `neighbors` parameter, which restricts each interpolation to the nearest known points.  
For now we set it manually, but it would be useful to compute it from the resolution and the size of the window we want to interpolate over.
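
As a rough sketch of the two points above: the first helper estimates the size of a dense N x N `float64` kernel matrix, and the second derives a `neighbors` value from a window size. Both function names, the cubic-window assumption, and the micron units are illustrative assumptions; neither helper is part of this notebook or of SciPy.

```python
import math


def dense_rbf_matrix_bytes(n_known_points: int, itemsize: int = 8) -> int:
    """Rough size in bytes of a dense (N x N) kernel matrix of float64 values."""
    return n_known_points**2 * itemsize


def neighbors_for_window(window_um: float, resolution_um: float) -> int:
    """Number of voxels in a cube whose side spans `window_um` microns.

    Assumption: the interpolation window is a cube and every voxel in it
    could hold a known value, so the cube's voxel count is used as `neighbors`.
    """
    side = max(1, math.ceil(window_um / resolution_um))
    return side**3


# One million known points already need about 7.3 TiB for the dense matrix.
print(dense_rbf_matrix_bytes(1_000_000) / 2**40)

# A 125 um window at 25 um resolution gives the 5**3 used in this notebook.
print(neighbors_for_window(125, 25))
```

Whether a cube's voxel count is the right value depends on how densely the known points fill the window, so treat the result as a starting point to tune.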

In [6]:
interpolator = RBFInterpolator(
    known_points, known_values, kernel="cubic", neighbors=5**3
)

## Dask to the rescue

If we ran the interpolator directly on `target_points`, we would hit `LinAlgError: Singular matrix. The matrix of monomials evaluated at the data point coordinates does not have full column rank (3/4).` for some points. As I understand it, this happens when the nearest known points around a target are degenerate: a rank of 3 out of 4 means they all lie in a single plane, so the linear polynomial term (1, x, y, z) that comes with the cubic kernel cannot be determined.  
The problem is that the whole interpolation fails even if a single point cannot be interpolated. To solve this, we use Dask to split the work into chunks that are allowed to fail individually, while the output of every other chunk is kept.
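
One way to reproduce this behaviour on toy data is to make all the nearest known points of a target lie in one plane. The sketch below uses made-up data (a small cube of points plus a flat sheet far away); it is only meant to show that a single bad target makes the whole batch fail.

```python
import numpy as np
from scipy.interpolate import RBFInterpolator

# Toy data (hypothetical): a 3x3x3 cube of known points near the origin...
cube = np.array(
    [(x, y, z) for x in range(3) for y in range(3) for z in range(3)], dtype=float
)
# ...and a flat 5x5 sheet of known points far away, all with z == 0.
sheet = np.array([(100 + x, y, 0) for x in range(5) for y in range(5)], dtype=float)

known_points = np.vstack([cube, sheet])
known_values = known_points.sum(axis=1)

interpolator = RBFInterpolator(
    known_points, known_values, kernel="cubic", neighbors=10
)

good_target = np.array([[1.0, 1.0, 1.0]])  # its neighbours come from the cube
bad_target = np.array([[102.0, 2.0, 1.0]])  # its neighbours all lie in the sheet

# On its own, the good target interpolates fine.
good_value = interpolator(good_target)

# Together with the bad target, the whole call raises.
try:
    interpolator(np.vstack([good_target, bad_target]))
    batch_failed = False
except np.linalg.LinAlgError as error:
    batch_failed = True
    print(error)
```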

### Creating the intermediate function

To allow the interpolation to fail, we need an intermediate function which will catch the error and simply return a dummy array (e.g., only zeros) if the interpolation fails.

In [7]:
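import logging


def interpolation_function(chunk: np.ndarray, interpolator: Any) -> np.ndarray:
    """Interpolate one chunk of points, falling back to zeros if it fails."""
    try:
        interpolated_data = interpolator(chunk)
    except np.linalg.LinAlgError:
        # One bad point makes the whole chunk fail: log it and return a dummy array.
        logging.error("Failed to interpolate.")
        interpolated_data = np.zeros(shape=(chunk.shape[0],), dtype=np.float64)

    return interpolated_data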
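
A possible refinement, sketched here as an assumption rather than what this notebook runs: when a chunk fails, retry its points one at a time so that only the points that really cannot be interpolated are lost, and mark those with NaN so they can be told apart from genuine zeros. The function name, the `fill_value` parameter, and the toy data are hypothetical.

```python
import numpy as np
from scipy.interpolate import RBFInterpolator


def interpolate_with_fallback(chunk, interpolator, fill_value=np.nan):
    """Interpolate a chunk; if it fails, retry point by point.

    Points that still fail keep `fill_value` instead of zeroing the whole chunk.
    """
    try:
        return interpolator(chunk)
    except np.linalg.LinAlgError:
        values = np.full(chunk.shape[0], fill_value, dtype=np.float64)
        for index, point in enumerate(chunk):
            try:
                values[index] = interpolator(point[np.newaxis, :])[0]
            except np.linalg.LinAlgError:
                pass  # keep the fill value for this point
        return values


# Toy data (hypothetical): a 3x3x3 cube of points and a flat sheet with z == 0.
cube = np.array(
    [(x, y, z) for x in range(3) for y in range(3) for z in range(3)], dtype=float
)
sheet = np.array([(100 + x, y, 0) for x in range(5) for y in range(5)], dtype=float)
known_points = np.vstack([cube, sheet])
toy_interpolator = RBFInterpolator(
    known_points, known_points.sum(axis=1), kernel="cubic", neighbors=10
)

# The first target is surrounded by the cube, the second only by the flat sheet.
chunk = np.array([[1.0, 1.0, 1.0], [102.0, 2.0, 1.0]])
result = interpolate_with_fallback(chunk, toy_interpolator)
print(result)
```

The per-point retry is much slower than a single batched call, so it only pays off when failing chunks are a small fraction of the total.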

### Trying it out

In [8]:
da_target_points = da.from_array(target_points, chunks=(1_000, 3))
da_target_points

|            | Array                        | Chunk     |
|------------|------------------------------|-----------|
| Bytes      | 80.03 MiB                    | 23.44 kiB |
| Shape      | (3496755, 3)                 | (1000, 3) |
| Dask graph | 3497 chunks in 1 graph layer |           |
| Data type  | int64 numpy.ndarray          |           |


In [9]:
interpolation_function = partial(interpolation_function, interpolator=interpolator)
interpolation_function

functools.partial(<function interpolation_function at 0x7f7aa9351900>, interpolator=<scipy.interpolate._rbfinterp.RBFInterpolator object at 0x7f7aa94f3340>)

In [10]:
target_values = da.map_blocks(
    interpolation_function, da_target_points, drop_axis=1, dtype=np.float64
)
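
The same pattern can be seen on toy data. In this minimal sketch, `safe_row_sum` is a hypothetical stand-in for the interpolation wrapper: it "fails" on any chunk containing a negative value and returns zeros for that chunk only. `drop_axis=1` tells Dask that each `(n, 3)` block becomes an `(n,)` block.

```python
import dask.array as da
import numpy as np

points = np.arange(30, dtype=np.float64).reshape(10, 3)
points[4] = -1.0  # poison one row so that its chunk "fails"


def safe_row_sum(chunk: np.ndarray) -> np.ndarray:
    """Sum each row, returning zeros for the whole chunk if it cannot be processed."""
    try:
        if (chunk < 0).any():
            raise np.linalg.LinAlgError("toy failure")
        return chunk.sum(axis=1)
    except np.linalg.LinAlgError:
        return np.zeros(chunk.shape[0], dtype=np.float64)


# Chunks of 4 rows: rows 0-3, rows 4-7 (contains the poisoned row), rows 8-9.
da_points = da.from_array(points, chunks=(4, 3))
row_sums = da.map_blocks(
    safe_row_sum, da_points, drop_axis=1, dtype=np.float64
).compute()

print(row_sums)
```

Only the chunk holding the poisoned row comes back as zeros, which also shows the cost of the approach: every point sharing a chunk with a bad point is lost, so smaller chunks lose fewer points at the price of more scheduling overhead.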

### Fingers crossed!

In [11]:
target_values = target_values.compute()

ERROR:root:Failed to interpolate.
[... same message repeated; output truncated ...]

## Well, well, well

After a mere 16 minutes and a bunch of failed interpolations, we have the interpolated data.

In [24]:
padding = 35

print(f"{'Number of points:':<{padding}} {target_points.shape[0]}")
print(f"{'Number of interpolated values:':<{padding}} {target_values.shape[0]}")

print(f"{'Interpolated min:':<{padding}} {target_values.min()}")
print(f"{'Interpolated max:':<{padding}} {target_values.max()}")

Number of points:                   3496755
Number of interpolated values:      3496755
Interpolated min:                   -392372.4157710974
Interpolated max:                   127920.15405381192


Better save it now.

In [27]:
os.makedirs("resources", exist_ok=True)
# Save under an explicit "array" key, matching how the aligned volume is loaded above.
np.savez_compressed(f"resources/{volume_path.stem}.interp.npz", array=target_values)
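
A minimal round-trip sketch of saving and reloading such an archive. Arrays passed positionally to `np.savez_compressed` are stored under automatic names (`arr_0`, ...), whereas a keyword argument stores the array under that name. The file name and the `array` key here are illustrative choices.

```python
import os
import tempfile

import numpy as np

values = np.array([0.5, -1.25, 3.0])

with tempfile.TemporaryDirectory() as directory:
    path = os.path.join(directory, "example.interp.npz")
    # Passing the array as a keyword stores it under that name in the archive.
    np.savez_compressed(path, array=values)

    with np.load(path) as archive:
        restored = archive["array"]

print(np.array_equal(values, restored))
```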

## What does it look like?

In [95]:
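# Use the interpolated values' dtype so they are not cast to the volume's dtype.
interpolated_array = np.zeros_like(masked_aligned_array, dtype=target_values.dtype)
interpolated_array[target_coordinates] = target_values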
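
The dtype of the aligned volume is not shown in this notebook. If it is an integer type, writing float values into an array created with a plain `np.zeros_like` silently casts them. The toy sketch below (made-up mask and values) shows the scatter-back step with and without an explicit float dtype.

```python
import numpy as np

mask = np.zeros((2, 2, 2), dtype=np.uint8)
mask[0, 0, 0] = 1
mask[1, 1, 0] = 1

coordinates = np.nonzero(mask)  # tuple of index arrays, one per axis
values = np.array([1.7, 2.9])  # stand-in for interpolated values

# zeros_like inherits uint8, so the fractional part is dropped on assignment.
as_integers = np.zeros_like(mask)
as_integers[coordinates] = values

# With an explicit float dtype, the values are stored unchanged.
as_floats = np.zeros_like(mask, dtype=np.float64)
as_floats[coordinates] = values

print(as_integers[coordinates], as_floats[coordinates])
```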

In [97]:
interpolated_volume = vedo.Volume(interpolated_array)
_ = interpolated_volume.cmap(**get_cmap(interpolated_volume))

In [51]:
atlas_volume = vedo.Volume(atlas_array)

In [42]:
masked_aligned_volume = vedo.Volume(masked_aligned_array)

In [112]:
show(
    {
        "Atlas volume": atlas_volume,
        "Interpolated volume": interpolated_volume,
        "Aligned volume": masked_aligned_volume,
    },
    n=3,
)