Skip to content

Commit

Permalink
ENH: TST: add a threadpoolctl hook to limit OpenBLAS parallelism
Browse files Browse the repository at this point in the history
This is taken over with minor modifications from scikit-learn.
The most dramatic effect of this is for `scipy.linalg`, its test
suite runs in ~30 sec on a 12-core machine without this change,
and in ~5 sec with it when using `linalg.test(parallel=12)`. The full
test suite is also sped up significantly.

Closes gh-14425
  • Loading branch information
rgommers committed Nov 17, 2021
1 parent 9484b9d commit e300cbc
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ ignore_missing_imports = True
[mypy-mpmath]
ignore_missing_imports = True

[mypy-threadpoolctl]
ignore_missing_imports = True

#
# Extension modules without stubs.
#
Expand Down
35 changes: 35 additions & 0 deletions scipy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from distutils.version import LooseVersion
import numpy as np
import numpy.testing as npt
from scipy._lib._fpumode import get_fpu_mode
from scipy._lib._testutils import FPUModeChangeWarning

Expand Down Expand Up @@ -39,6 +40,40 @@ def pytest_runtest_setup(item):
if mark is not None and np.intp(0).itemsize < 8:
pytest.xfail('Fails on our 32-bit test platform(s): %s' % (mark.args[0],))

# Older versions of threadpoolctl have an issue that may lead to this
# warning being emitted, see gh-14441
with npt.suppress_warnings() as sup:
sup.filter(pytest.PytestUnraisableExceptionWarning)

try:
from threadpoolctl import threadpool_limits

HAS_THREADPOOLCTL = True
except Exception: # observed in gh-14441: (ImportError, AttributeError)
# Optional dependency only. All exceptions are caught, for robustness
HAS_THREADPOOLCTL = False

if HAS_THREADPOOLCTL:
# Set the number of openmp threads based on the number of workers
# xdist is using to prevent oversubscription. Simplified version of what
# sklearn does (it can rely on threadpoolctl and its builtin OpenMP helper
# functions)
try:
xdist_worker_count = int(os.environ['PYTEST_XDIST_WORKER_COUNT'])
except KeyError:
# raises when pytest-xdist is not installed
return

if not os.getenv('OMP_NUM_THREADS'):
max_openmp_threads = os.cpu_count() // 2 # use nr of physical cores
threads_per_worker = max(max_openmp_threads // xdist_worker_count, 1)
try:
threadpool_limits(threads_per_worker, user_api='blas')
except Exception:
# May raise AttributeError for older versions of OpenBLAS.
# Catch any error for robustness.
return


@pytest.fixture(scope="function", autouse=True)
def check_fpu_mode(request):
Expand Down

0 comments on commit e300cbc

Please sign in to comment.