Skip to content

Commit

Permalink
Use explicit Operator.MAX to aid compilation
Browse files Browse the repository at this point in the history
Fix handling of explicit operators in `allreduce`
  • Loading branch information
david-zwicker committed Jun 17, 2024
1 parent c7b34d4 commit ed4c3a8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
6 changes: 4 additions & 2 deletions pde/solvers/explicit_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,14 @@ def _make_error_synchronizer(self) -> Callable[[float], float]:
"""return helper function that synchronizes errors between multiple processes"""
if mpi.parallel_run:
# in a parallel run, we need to return the maximal error
from ..tools.mpi import mpi_allreduce
from ..tools.mpi import Operator, mpi_allreduce

operator_max_id = Operator.MAX

@register_jitable
def synchronize_errors(error: float) -> float:
"""return maximal error accross all cores"""
return mpi_allreduce(error, "MAX") # type: ignore
return mpi_allreduce(error, operator_max_id) # type: ignore

return synchronize_errors # type: ignore
else:
Expand Down
4 changes: 4 additions & 0 deletions pde/tools/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def ol_mpi_allreduce(data, operator: int | str | None = None):
else:
# assume an operator is specified by it's id
op_id = int(operator.literal_value)
elif isinstance(operator, nb.types.Integer):
op_id = None # use given value of operator

Check warning on line 201 in pde/tools/mpi.py

View check run for this annotation

Codecov / codecov/patch

pde/tools/mpi.py#L199-L201

Added lines #L199 - L201 were not covered by tests
else:
raise RuntimeError(f"`operator` must be a literal type, not {operator}")

Check warning on line 203 in pde/tools/mpi.py

View check run for this annotation

Codecov / codecov/patch

pde/tools/mpi.py#L203

Added line #L203 was not covered by tests

Expand All @@ -205,6 +207,8 @@ def _allreduce(sendobj, recvobj, operator: int | str | None = None) -> int:
"""helper function that calls `numba_mpi.allreduce`"""
if operator is None:
return numba_mpi.allreduce(sendobj, recvobj) # type: ignore
elif op_id is None:
return numba_mpi.allreduce(sendobj, recvobj, operator) # type: ignore

Check warning on line 211 in pde/tools/mpi.py

View check run for this annotation

Codecov / codecov/patch

pde/tools/mpi.py#L210-L211

Added lines #L210 - L211 were not covered by tests
else:
return numba_mpi.allreduce(sendobj, recvobj, op_id) # type: ignore

Expand Down

0 comments on commit ed4c3a8

Please sign in to comment.