-
Notifications
You must be signed in to change notification settings - Fork 394
feat(runtime): add TensorRT-RTX runtime cache, dynamic shapes strategy, and native CUDA graph support to C++ runtime #4202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b96440e
01b9f38
481455f
2b630e8
54f9ccd
1fa8c82
a4989c7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| #include <algorithm> | ||
| #include <filesystem> | ||
|
|
||
| #include <cuda_runtime.h> | ||
| #include "NvInfer.h" | ||
|
|
@@ -54,26 +55,28 @@ void DynamicOutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims | |
| } | ||
|
|
||
| TRTEngine::TRTEngine( | ||
| const std::string& serialized_engine, | ||
| std::string serialized_engine, | ||
| const RTDevice& cuda_device, | ||
| const std::vector<std::string>& _in_binding_names, | ||
| const std::vector<std::string>& _out_binding_names, | ||
| const Platform& target_platform, | ||
| bool hardware_compatible, | ||
| bool requires_output_allocator, | ||
| const std::string& serialized_metadata, | ||
| const ResourceAllocationStrategy resource_allocation_strategy) | ||
| std::string serialized_metadata, | ||
| const ResourceAllocationStrategy resource_allocation_strategy, | ||
| TRTRuntimeConfig runtime_cfg) | ||
| : TRTEngine( | ||
| "deserialized_trt", | ||
| serialized_engine, | ||
| std::move(serialized_engine), | ||
| cuda_device, | ||
| _in_binding_names, | ||
| _out_binding_names, | ||
| target_platform, | ||
| hardware_compatible, | ||
| requires_output_allocator, | ||
| serialized_metadata, | ||
| resource_allocation_strategy) {} | ||
| std::move(serialized_metadata), | ||
| resource_allocation_strategy, | ||
| std::move(runtime_cfg)) {} | ||
|
|
||
| TRTEngine::TRTEngine(std::vector<std::string> serialized_info) | ||
| : TRTEngine( | ||
|
|
@@ -88,19 +91,22 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info) | |
| serialized_info[SERIALIZED_METADATA_IDX], | ||
| (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) | ||
| ? ResourceAllocationStrategy::kDynamic | ||
| : ResourceAllocationStrategy::kStatic)) {} | ||
| : ResourceAllocationStrategy::kStatic), | ||
| make_runtime_config_from_serialized(serialized_info)) {} | ||
|
|
||
| TRTEngine::TRTEngine( | ||
| const std::string& mod_name, | ||
| const std::string& serialized_engine, | ||
| std::string mod_name, | ||
| std::string serialized_engine, | ||
| const RTDevice& cuda_device, | ||
| const std::vector<std::string>& _in_binding_names, | ||
| const std::vector<std::string>& _out_binding_names, | ||
| const Platform& target_platform, | ||
| bool hardware_compatible, | ||
| bool requires_output_allocator, | ||
| const std::string& serialized_metadata, | ||
| const ResourceAllocationStrategy resource_allocation_strategy) { | ||
| std::string serialized_metadata, | ||
| const ResourceAllocationStrategy resource_allocation_strategy, | ||
| TRTRuntimeConfig runtime_cfg) { | ||
| this->runtime_cfg = std::move(runtime_cfg); | ||
| TORCHTRT_CHECK( | ||
| is_supported_on_current_platform(target_platform), | ||
| "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " | ||
|
|
@@ -111,15 +117,15 @@ TRTEngine::TRTEngine( | |
| auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); | ||
| TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); | ||
|
|
||
| this->serialized_metadata = serialized_metadata; | ||
| this->serialized_metadata = std::move(serialized_metadata); | ||
| this->requires_output_allocator = requires_output_allocator; | ||
| device_info = most_compatible_device.value(); | ||
| multi_gpu_device_check(); | ||
| set_rt_device(device_info); | ||
|
|
||
| rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger())); | ||
|
|
||
| name = slugify(mod_name); | ||
| name = slugify(std::move(mod_name)); | ||
|
|
||
| cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size())); | ||
| TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine"); | ||
|
|
@@ -134,13 +140,7 @@ TRTEngine::TRTEngine( | |
| LOG_DEBUG( | ||
| "Resource allocation strategy: " | ||
| << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static")); | ||
| if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { | ||
| this->exec_ctx = | ||
| make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); | ||
| } else { | ||
| this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); | ||
| } | ||
| TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context"); | ||
| recreate_execution_context(); | ||
|
|
||
| // Pre-allocate placeholder for empty tensors (TensorRT requires non-null addresses) | ||
| cudaMalloc(&empty_tensor_placeholder, 1); | ||
|
|
@@ -265,6 +265,9 @@ TRTEngine::TRTEngine( | |
|
|
||
| TRTEngine::~TRTEngine() { | ||
| torch::cuda::synchronize(device_info.id); | ||
| // Marked noexcept by the type system, so safe to invoke from a destructor without | ||
| // explicit try/catch; any I/O error is logged internally. | ||
| runtime_cfg.save_runtime_cache(); | ||
| trt_engine_profiler.reset(); | ||
| exec_ctx.reset(); | ||
| cuda_engine.reset(); | ||
|
|
@@ -278,8 +281,7 @@ void TRTEngine::disable_profiling() { | |
| torch::cuda::synchronize(device_info.id); | ||
| profile_execution = false; | ||
| trt_engine_profiler.reset(); | ||
| exec_ctx = make_trt(cuda_engine->createExecutionContext()); | ||
| TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context"); | ||
| recreate_execution_context(); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Disabling profiling doesn't seem to respect the configured resource allocation strategy — routing it through `recreate_execution_context()` ensures the strategy is applied when the execution context is rebuilt. |
||
| } | ||
|
|
||
| void TRTEngine::dump_engine_layer_info_to_file(const std::string& path) { | ||
|
|
@@ -376,10 +378,7 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) { | |
| trt_engine_profiler.reset(); | ||
| } | ||
| bool result = cuda_engine->setWeightStreamingBudgetV2(budget); | ||
| exec_ctx = make_trt(cuda_engine->createExecutionContext()); | ||
| TORCHTRT_CHECK( | ||
| (exec_ctx.get() != nullptr), | ||
| "Unable to recreate TensorRT execution context after setting new device memory budget"); | ||
| recreate_execution_context(); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as https://github.com/pytorch/TensorRT/pull/4202/changes#r3120574611 — confirm that the execution context recreated after changing the device memory budget also respects the resource allocation strategy. |
||
| if (profile_execution) { | ||
| enable_profiling(); | ||
| } | ||
|
|
@@ -428,6 +427,7 @@ std::string TRTEngine::to_str() const { | |
| ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl; | ||
| ss << " Target Platform: " << target_platform << std::endl; | ||
| ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl; | ||
| ss << runtime_cfg; | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Use ss << runtime_cfg.to_str() here. |
||
| // clang-format on | ||
| return ss.str(); | ||
| } | ||
|
|
@@ -472,7 +472,14 @@ FlattenedState TRTEngine::__obj_flatten__() { | |
| std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]), | ||
| std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), | ||
| std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), | ||
| std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])); | ||
| std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]) | ||
| #ifdef TRT_MAJOR_RTX | ||
| , | ||
| std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]), | ||
| std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), | ||
| std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX]) | ||
| #endif | ||
| ); | ||
| } | ||
|
|
||
| std::vector<std::string> TRTEngine::serialize() { | ||
|
|
@@ -497,6 +504,13 @@ std::vector<std::string> TRTEngine::serialize() { | |
| serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize(); | ||
| serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = | ||
| this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; | ||
| #ifdef TRT_MAJOR_RTX | ||
| serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path; | ||
| serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string( | ||
| static_cast<std::underlying_type_t<DynamicShapesKernelStrategy>>(runtime_cfg.dynamic_shapes_kernel_strategy)); | ||
| serialized_info[CUDA_GRAPH_STRATEGY_IDX] = | ||
| std::to_string(static_cast<std::underlying_type_t<CudaGraphStrategyOption>>(runtime_cfg.cuda_graph_strategy)); | ||
| #endif | ||
|
|
||
| return serialized_info; | ||
| } | ||
|
|
@@ -508,17 +522,38 @@ void TRTEngine::reset_captured_graph() { | |
| void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) { | ||
| if (new_strategy != this->resource_allocation_strategy) { | ||
| this->resource_allocation_strategy = new_strategy; | ||
| if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) { | ||
| LOG_DEBUG("Setting resource allocation strategy to dynamic"); | ||
| this->exec_ctx = | ||
| make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); | ||
| } else { | ||
| LOG_DEBUG("Setting resource allocation strategy to static"); | ||
| this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); | ||
| } | ||
| LOG_DEBUG( | ||
| "Setting resource allocation strategy to " | ||
| << (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic ? "dynamic" | ||
| : "static")); | ||
| recreate_execution_context(); | ||
| } | ||
| } | ||
|
|
||
| bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const { | ||
| return runtime_cfg.is_monolithic_capturable(exec_ctx.get(), stream); | ||
| } | ||
|
|
||
| void TRTEngine::disable_rtx_native_cudagraphs() { | ||
| bool was_disabled = runtime_cfg.rtx_native_cudagraphs_disabled; | ||
| runtime_cfg.disable_rtx_native_cudagraphs(name); | ||
| if (!was_disabled && runtime_cfg.rtx_native_cudagraphs_disabled) { | ||
| // The CUDA graph strategy on the IRuntimeConfig has been flipped; rebuild exec_ctx | ||
| // so the new strategy takes effect for subsequent enqueueV3 calls. | ||
| recreate_execution_context(); | ||
| } | ||
| } | ||
|
|
||
| void TRTEngine::recreate_execution_context() { | ||
| runtime_cfg.ensure_initialized(cuda_engine.get()); | ||
| runtime_cfg.set_execution_context_allocation_strategy( | ||
| resource_allocation_strategy == ResourceAllocationStrategy::kDynamic | ||
| ? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED | ||
| : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC); | ||
| exec_ctx = make_trt(cuda_engine->createExecutionContext(runtime_cfg.config.get())); | ||
| TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); | ||
| } | ||
|
|
||
| } // namespace runtime | ||
| } // namespace core | ||
| } // namespace torch_tensorrt | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,12 +13,25 @@ | |
| #include "torch/custom_class.h" | ||
|
|
||
| #include "core/runtime/TRTEngineProfiler.h" | ||
| #include "core/runtime/TRTRuntimeConfig.h" | ||
| #include "core/util/prelude.h" | ||
|
|
||
| namespace torch_tensorrt { | ||
| namespace core { | ||
| namespace runtime { | ||
|
|
||
| #ifdef TRT_MAJOR_RTX | ||
| // Extra FlattenedState entries for TensorRT-RTX-only fields. Leading comma so this | ||
| // macro can be dropped directly into the std::tuple parameter pack after the final | ||
| // shared entry without duplicating the per-entry type in both branches. | ||
| #define TRTRTX_FLATTENED_STATE_EXTRAS \ | ||
| , std::tuple<std::string, std::string> /* Runtime Cache Path */ \ | ||
| , std::tuple<std::string, std::string> /* Dynamic Shapes Kernel Strategy */ \ | ||
| , std::tuple<std::string, std::string> /* CUDA Graph Strategy */ | ||
| #else | ||
|
tp5uiuc marked this conversation as resolved.
|
||
| #define TRTRTX_FLATTENED_STATE_EXTRAS | ||
| #endif | ||
|
|
||
| using FlattenedState = std::tuple< | ||
| std::tuple<std::string, std::string>, // ABI_VERSION | ||
| std::tuple<std::string, std::string>, // name | ||
|
|
@@ -30,7 +43,8 @@ using FlattenedState = std::tuple< | |
| std::tuple<std::string, std::string>, // requires_output_allocator | ||
| std::tuple<std::string, std::string>, // serialized metadata | ||
| std::tuple<std::string, std::string>, // Platform | ||
| std::tuple<std::string, std::string>>; // Resource Allocation Strategy | ||
| std::tuple<std::string, std::string> /* Resource Allocation Strategy */ | ||
| TRTRTX_FLATTENED_STATE_EXTRAS>; | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. TODO: Inline the `TRTRTX_FLATTENED_STATE_EXTRAS` macro and fix the conditional `FlattenedState` tuple definition. |
||
|
|
||
| struct TorchTRTRuntimeStates { | ||
| // Indicates whether CUDAGraphs were enabled in the previous execute_engine | ||
|
|
@@ -125,31 +139,33 @@ struct TRTEngine : torch::CustomClassHolder { | |
|
|
||
| ~TRTEngine(); | ||
| TRTEngine( | ||
| const std::string& serialized_engine, | ||
| std::string serialized_engine, | ||
| const RTDevice& cuda_device, | ||
| const std::vector<std::string>& in_binding_names, | ||
| const std::vector<std::string>& out_binding_names, | ||
| const Platform& target_platform = get_current_platform(), | ||
| bool hardware_compatible = false, | ||
| bool requires_output_allocator = false, | ||
| const std::string& serialized_metadata = "", | ||
| std::string serialized_metadata = "", | ||
| const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = | ||
| TRTEngine::ResourceAllocationStrategy::kStatic); | ||
| TRTEngine::ResourceAllocationStrategy::kStatic, | ||
| TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); | ||
|
|
||
| TRTEngine(std::vector<std::string> serialized_info); | ||
|
|
||
| TRTEngine( | ||
| const std::string& mod_name, | ||
| const std::string& serialized_engine, | ||
| std::string mod_name, | ||
| std::string serialized_engine, | ||
| const RTDevice& cuda_device, | ||
| const std::vector<std::string>& in_binding_names, | ||
| const std::vector<std::string>& out_binding_names, | ||
| const Platform& target_platform = get_current_platform(), | ||
| bool hardware_compatible = false, | ||
| bool requires_output_allocator = false, | ||
| const std::string& serialized_metadata = "", | ||
| std::string serialized_metadata = "", | ||
| const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = | ||
| TRTEngine::ResourceAllocationStrategy::kStatic); | ||
| TRTEngine::ResourceAllocationStrategy::kStatic, | ||
| TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); | ||
|
|
||
| TRTEngine& operator=(const TRTEngine& other); | ||
| std::string to_str() const; | ||
|
|
@@ -217,6 +233,24 @@ struct TRTEngine : torch::CustomClassHolder { | |
| ResourceAllocationStrategy resource_allocation_strategy = kStatic; | ||
| void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy); | ||
| ResourceAllocationStrategy get_resource_allocation_strategy(); | ||
|
|
||
| // All TensorRT-RTX-specific IRuntimeConfig state lives here. On non-RTX builds this | ||
| // still owns a shared IRuntimeConfig (so the execution-context allocation strategy is | ||
| // applied via the uniform code path) but the RTX-only setters become no-ops. | ||
| TRTRuntimeConfig runtime_cfg; | ||
|
|
||
| // Monolithic-capturability check used when this engine is wrapped by an outer whole-graph | ||
| // capture (e.g. CudaGraphsTorchTensorRTModule). Non-RTX builds always return true. | ||
| bool is_monolithic_capturable(cudaStream_t stream) const; | ||
|
|
||
| // Disable TensorRT-RTX native CUDA graph capture on this engine (one-shot, invoked when | ||
| // an outer stream capture is detected around execute_engine). No-op on non-RTX. | ||
| void disable_rtx_native_cudagraphs(); | ||
|
|
||
| private: | ||
| // Single entry point that (re)creates exec_ctx. Also creates (once) the IRuntimeConfig | ||
| // owned by runtime_cfg and applies all runtime config settings. | ||
| void recreate_execution_context(); | ||
| }; | ||
|
|
||
| } // namespace runtime | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.