In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
"""Tests for verifying process/thread usage in parallelized functions."""

from __future__ import annotations

from collections.abc import Callable
from functools import partial

import dask.array as da
import numba
import numpy as np
import pytest  # type: ignore[import]

from squidpy._utils import Signal, parallelize

# Functions to be parallelized


@numba.njit(parallel=True)
def numba_parallel_func(x, y) -> np.ndarray:
    return x * 2 + y


@numba.njit(parallel=False)
def numba_serial_func(x, y) -> np.ndarray:
    return x * 2 + y


def dask_func(x, y) -> np.ndarray:
    return (da.from_array(x) * 2 + y).compute()


def vanilla_func(x, y) -> np.ndarray:
    return x * 2 + y


# Mock runner function


def mock_runner(x, y, queue, func):
    for i in range(len(x)):
        x[i] = func(x[i], y)
        if queue is not None:
            queue.put(Signal.UPDATE)
    if queue is not None:
        queue.put(Signal.FINISH)
    return x


@pytest.fixture(params=["numba_parallel", "numba_serial", "dask", "vanilla"])
def func(request) -> Callable:
    return {
        "numba_parallel": numba_parallel_func,
        "numba_serial": numba_serial_func,
        "dask": dask_func,
        "vanilla": vanilla_func,
    }[request.param]

In [8]:
n = 8
func = numba_parallel_func
arr1 = [np.arange(n) for _ in range(n)]
arr2 = np.arange(n)
runner = partial(mock_runner, func=func)
# expected = [func(arr1[i], arr2) for i in range(len(arr1))]
p_func = parallelize(runner, arr1, n_jobs=2, backend="loky", use_ixs=False, n_splits=len(arr1))
result = p_func(arr2)[0]
assert len(result) == len(expected), f"Expected: {expected} but got {result}. Length mismatch"
for i in range(len(arr1)):
    assert np.all(result[i] == expected[i]), f"Expected {expected[i]} but got {result[i]}"

  0%|          | 0/8 [00:00<?, ?/s]

8 8
8 8




8 8
8 8
8 8
8 8
8 8
8 8


AssertionError: Expected: [array([ 0,  3,  6,  9, 12, 15, 18, 21]), array([ 0,  3,  6,  9, 12, 15, 18, 21]), array([ 0,  3,  6,  9, 12, 15, 18, 21]), array([ 0,  3,  6,  9, 12, 15, 18, 21]), array([ 0,  3,  6,  9, 12, 15, 18, 21]), array([ 0,  3,  6,  9, 12, 15, 18, 21]), array([ 0,  3,  6,  9, 12, 15, 18, 21]), array([ 0,  3,  6,  9, 12, 15, 18, 21])] but got [array([ 0,  3,  6,  9, 12, 15, 18, 21]), array([ 0,  3,  6,  9, 12, 15, 18, 21]), array([ 0,  3,  6,  9, 12, 15, 18, 21]), array([ 0,  3,  6,  9, 12, 15, 18, 21])]. Length mismatch