diff --git a/README.md b/README.md
index 2b27283..d948954 100644
--- a/README.md
+++ b/README.md
@@ -200,7 +200,7 @@ key: "ENABLE_CACHE_CLEANING"
 * `INTER_OP_THREAD_COUNT`:
 
 PyTorch allows using multiple CPU threads during TorchScript model inference.
-One or more inference threads execute a model’s forward pass on the given
+One or more inference threads execute a model's forward pass on the given
 inputs. Each inference thread invokes a JIT interpreter that executes the ops
 of a model inline, one by one. This parameter sets the size of this thread pool.
 The default value of this setting is the number of cpu cores. Please refer
@@ -218,6 +218,11 @@ key: "INTER_OP_THREAD_COUNT"
 }
 ```
 
+> [!NOTE]
+> This parameter is set globally for the PyTorch backend.
+> The value from the first model config file that specifies this parameter will be used.
+> Subsequent values from other model config files, if different, will be ignored.
+
 * `INTRA_OP_THREAD_COUNT`:
 
 In addition to the inter-op parallelism, PyTorch can also utilize multiple threads
@@ -238,6 +243,11 @@ key: "INTRA_OP_THREAD_COUNT"
 }
 ```
 
+> [!NOTE]
+> This parameter is set globally for the PyTorch backend.
+> The value from the first model config file that specifies this parameter will be used.
+> Subsequent values from other model config files, if different, will be ignored.
+
 * Additional Optimizations: Three additional boolean parameters are available to disable
 certain Torch optimizations that can sometimes cause latency regressions in models with
 complex execution modes and dynamic shapes. If not specified, all are enabled by default.
diff --git a/src/libtorch.cc b/src/libtorch.cc
index 26a2960..c873375 100644
--- a/src/libtorch.cc
+++ b/src/libtorch.cc
@@ -28,6 +28,7 @@
 #include <stdint.h>
 
 #include <exception>
+#include <mutex>
 
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
@@ -66,6 +67,11 @@
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
 //
 
+namespace {
+std::once_flag pytorch_interop_threads_flag;
+std::once_flag pytorch_intraop_threads_flag;
+}  // namespace
+
 namespace triton { namespace backend { namespace pytorch {
 
 //
@@ -509,11 +515,15 @@ ModelState::ParseParameters()
       }
     } else {
       if (intra_op_thread_count > 0) {
-        at::set_num_threads(intra_op_thread_count);
+        // at::set_num_threads() does not throw if called more than once, but
+        // issues warnings. std::call_once() is useful to limit these.
+        std::call_once(pytorch_intraop_threads_flag, [intra_op_thread_count]() {
+          at::set_num_threads(intra_op_thread_count);
+        });
         LOG_MESSAGE(
             TRITONSERVER_LOG_INFO,
             (std::string("Intra op thread count is set to ") +
-             std::to_string(intra_op_thread_count) + " for model instance '" +
+             std::to_string(at::get_num_threads()) + " for model instance '" +
              Name() + "'")
                 .c_str());
       }
@@ -533,12 +543,22 @@ ModelState::ParseParameters()
       }
     } else {
       if (inter_op_thread_count > 0) {
-        at::set_num_interop_threads(inter_op_thread_count);
+        // at::set_num_interop_threads() throws if called more than once.
+        // std::call_once() should prevent this, but try/catch is additionally
+        // used for safety.
+        std::call_once(pytorch_interop_threads_flag, [inter_op_thread_count]() {
+          try {
+            at::set_num_interop_threads(inter_op_thread_count);
+          }
+          catch (const c10::Error& e) {
+            // do nothing
+          }
+        });
         LOG_MESSAGE(
             TRITONSERVER_LOG_INFO,
             (std::string("Inter op thread count is set to ") +
-             std::to_string(inter_op_thread_count) + " for model instance '" +
-             Name() + "'")
+             std::to_string(at::get_num_interop_threads()) +
+             " for model instance '" + Name() + "'")
                 .c_str());
       }
     }
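For reviewers, here is a minimal, self-contained sketch of the call-once behavior this patch relies on. It uses plain C++ with no LibTorch dependency; the `SetThreadCountOnce` helper and the `effective_thread_count` global are illustrative stand-ins, not backend code. The point it demonstrates: the first value applied through `std::call_once` wins, later calls become no-ops, and the log line should therefore report the effective value rather than whatever each model requested.

```cpp
#include <atomic>
#include <iostream>
#include <mutex>

namespace {
// Illustrative stand-ins for the process-wide PyTorch thread-pool setting.
std::once_flag thread_count_flag;
std::atomic<int> effective_thread_count{4};  // pretend default: number of CPU cores

// Applies the requested count only on the first call; later calls are no-ops.
void
SetThreadCountOnce(int requested)
{
  std::call_once(thread_count_flag, [requested]() {
    effective_thread_count.store(requested);
  });
}
}  // namespace

int
main()
{
  // First "model config" to specify the parameter wins.
  SetThreadCountOnce(2);
  // A second model config asking for a different value is ignored.
  SetThreadCountOnce(8);

  // Log the effective value (what at::get_num_threads() would report),
  // not the value each model requested.
  std::cout << "Intra op thread count is set to "
            << effective_thread_count.load() << std::endl;  // prints 2
  return 0;
}
```

The same reasoning explains the asymmetry in the patch: per the in-code comments, `at::set_num_threads()` only warns when invoked repeatedly, so the `std::call_once` guard is enough to suppress the extra warnings, while `at::set_num_interop_threads()` throws if called more than once, hence the additional try/catch inside the guarded lambda.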