diff --git a/pde/solvers/explicit_mpi.py b/pde/solvers/explicit_mpi.py index 16b4cb9d..a79ec2bb 100644 --- a/pde/solvers/explicit_mpi.py +++ b/pde/solvers/explicit_mpi.py @@ -121,12 +121,14 @@ def _make_error_synchronizer(self) -> Callable[[float], float]: """return helper function that synchronizes errors between multiple processes""" if mpi.parallel_run: # in a parallel run, we need to return the maximal error - from ..tools.mpi import mpi_allreduce + from ..tools.mpi import Operator, mpi_allreduce + + operator_max_id = Operator.MAX @register_jitable def synchronize_errors(error: float) -> float: """return maximal error accross all cores""" - return mpi_allreduce(error, "MAX") # type: ignore + return mpi_allreduce(error, operator_max_id) # type: ignore return synchronize_errors # type: ignore else: diff --git a/pde/tools/mpi.py b/pde/tools/mpi.py index 6538c9f6..eb6e37fe 100644 --- a/pde/tools/mpi.py +++ b/pde/tools/mpi.py @@ -27,6 +27,11 @@ from numba import types from numba.extending import overload, register_jitable +try: + from numba.types import Literal +except ImportError: + from numba.types.misc import Literal + if TYPE_CHECKING: from numba_mpi import Operator @@ -184,18 +189,26 @@ def ol_mpi_allreduce(data, operator: int | str | None = None): if operator is None or isinstance(operator, nb.types.NoneType): op_id = -1 # value will not be used - elif isinstance(operator, nb.types.misc.StringLiteral): - op_id = Operator.id(operator.literal_value) - elif isinstance(operator, nb.types.misc.Literal): - op_id = int(operator) + elif isinstance(operator, Literal): + # an operator is specified (using a literal value + if isinstance(operator.literal_value, str): + # an operator is specified by it's name + op_id = Operator.id(operator.literal_value) + else: + # assume an operator is specified by it's id + op_id = int(operator.literal_value) + elif isinstance(operator, nb.types.Integer): + op_id = None # use given value of operator else: - raise RuntimeError("`operator` must be a literal type") + raise RuntimeError(f"`operator` must be a literal type, not {operator}") @register_jitable def _allreduce(sendobj, recvobj, operator: int | str | None = None) -> int: """helper function that calls `numba_mpi.allreduce`""" if operator is None: return numba_mpi.allreduce(sendobj, recvobj) # type: ignore + elif op_id is None: + return numba_mpi.allreduce(sendobj, recvobj, operator) # type: ignore else: return numba_mpi.allreduce(sendobj, recvobj, op_id) # type: ignore diff --git a/pde/tools/resources/requirements_basic.txt b/pde/tools/resources/requirements_basic.txt index 9d406830..01982929 100644 --- a/pde/tools/resources/requirements_basic.txt +++ b/pde/tools/resources/requirements_basic.txt @@ -1,7 +1,7 @@ # These are the basic requirements for the package matplotlib>=3.1 numba>=0.59 -numpy>=1.22 +numpy>=1.22,<2 scipy>=1.10 sympy>=1.9 tqdm>=4.66 diff --git a/pde/tools/resources/requirements_full.txt b/pde/tools/resources/requirements_full.txt index 16bab919..54881188 100644 --- a/pde/tools/resources/requirements_full.txt +++ b/pde/tools/resources/requirements_full.txt @@ -3,7 +3,7 @@ ffmpeg-python>=0.2 h5py>=2.10 matplotlib>=3.1 numba>=0.59 -numpy>=1.22 +numpy>=1.22,<2 pandas>=2 py-modelrunner>=0.18 rocket-fft>=0.2.4 diff --git a/pde/tools/resources/requirements_mpi.txt b/pde/tools/resources/requirements_mpi.txt index af4b800c..5aa6a093 100644 --- a/pde/tools/resources/requirements_mpi.txt +++ b/pde/tools/resources/requirements_mpi.txt @@ -4,7 +4,7 @@ matplotlib>=3.1 mpi4py>=3 numba>=0.59 numba-mpi>=0.22 -numpy>=1.22 +numpy>=1.22,<2 pandas>=2 scipy>=1.10 sympy>=1.9 diff --git a/pyproject.toml b/pyproject.toml index 0aeb1801..bd22625c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ ] # Requirements for setuptools -dependencies = ["matplotlib>=3.1", "numba>=0.59", "numpy>=1.22", "scipy>=1.10", "sympy>=1.9", "tqdm>=4.66"] +dependencies = ["matplotlib>=3.1", "numba>=0.59", "numpy>=1.22,<2", "scipy>=1.10", "sympy>=1.9", "tqdm>=4.66"] [project.optional-dependencies] io = ["h5py>=2.10", "pandas>=2", "ffmpeg-python>=0.2"] diff --git a/requirements.txt b/requirements.txt index a043079f..c634cc4a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ matplotlib>=3.1 numba>=0.59 -numpy>=1.22 +numpy>=1.22,<2 scipy>=1.10 sympy>=1.9 tqdm>=4.66 diff --git a/scripts/create_requirements.py b/scripts/create_requirements.py index ce35c6f1..2eccc5fd 100755 --- a/scripts/create_requirements.py +++ b/scripts/create_requirements.py @@ -72,7 +72,7 @@ def line(self, relation: str = ">=") -> str: ), Requirement( name="numpy", - version_min="1.22", + version_min="1.22,<2", usage="Handling numerical data", essential=True, ), diff --git a/tests/requirements_full.txt b/tests/requirements_full.txt index 16bab919..54881188 100644 --- a/tests/requirements_full.txt +++ b/tests/requirements_full.txt @@ -3,7 +3,7 @@ ffmpeg-python>=0.2 h5py>=2.10 matplotlib>=3.1 numba>=0.59 -numpy>=1.22 +numpy>=1.22,<2 pandas>=2 py-modelrunner>=0.18 rocket-fft>=0.2.4 diff --git a/tests/requirements_min.txt b/tests/requirements_min.txt index fde8f671..64383746 100644 --- a/tests/requirements_min.txt +++ b/tests/requirements_min.txt @@ -1,7 +1,7 @@ # These are the minimal requirements used to test compatibility matplotlib~=3.1 numba~=0.59 -numpy~=1.22 +numpy~=1.22,<2 scipy~=1.10 sympy~=1.9 tqdm~=4.66 diff --git a/tests/requirements_mpi.txt b/tests/requirements_mpi.txt index 82ab6622..b53de48f 100644 --- a/tests/requirements_mpi.txt +++ b/tests/requirements_mpi.txt @@ -4,7 +4,7 @@ matplotlib>=3.1 mpi4py>=3 numba>=0.59 numba-mpi>=0.22 -numpy>=1.22 +numpy>=1.22,<2 pandas>=2 scipy>=1.10 sympy>=1.9