diff --git a/extension/pybindings/portable_lib.py b/extension/pybindings/portable_lib.py index 0982d55b474..27468c8b7b5 100644 --- a/extension/pybindings/portable_lib.py +++ b/extension/pybindings/portable_lib.py @@ -65,6 +65,7 @@ _load_program, # noqa: F401 _load_program_from_buffer, # noqa: F401 _reset_profile_results, # noqa: F401 + _threadpool_get_thread_count, # noqa: F401 _unsafe_reset_threadpool, # noqa: F401 BundledModule, # noqa: F401 ExecuTorchMethod, # noqa: F401 diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index c3cd4ed0b47..eb81bda22f7 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -1558,6 +1558,13 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) { }, py::arg("num_threads"), call_guard); + m.def( + "_threadpool_get_thread_count", + []() { + return ::executorch::extension::threadpool::get_threadpool() + ->get_thread_count(); + }, + call_guard); py::class_(m, "ExecuTorchModule") .def( diff --git a/extension/pybindings/pybindings.pyi b/extension/pybindings/pybindings.pyi index a3b75780369..9e5ab6211ce 100644 --- a/extension/pybindings/pybindings.pyi +++ b/extension/pybindings/pybindings.pyi @@ -288,3 +288,12 @@ def _unsafe_reset_threadpool(num_threads: int) -> None: This API is experimental and subject to change without notice. """ ... + +@experimental("This API is experimental and subject to change without notice.") +def _threadpool_get_thread_count() -> int: + """ + .. warning:: + + This API is experimental and subject to change without notice. + """ + ...