feat: reintroduce TRT-RTX runtime cache, dynamic shapes, and native CUDA graph support#4294
feat: reintroduce TRT-RTX runtime cache, dynamic shapes, and native CUDA graph support#4294tp5uiuc wants to merge 5 commits into
Conversation
…on TRTEngine
The Python runtime rework moved the runtime-cache and dynamic-shape kernel
specialization machinery off of PythonTorchTensorRTModule, but only the
``runtime_config = None`` stub was carried over to the new TRTEngine.
Everything that actually populated those attributes -- the runtime config
build-up, the file-locked cache load/save, the dynamic-shape strategy
mapping -- was dropped. On non-RTX builds this manifested as
``AttributeError: 'TRTEngine' object has no attribute '_runtime_cache'``
in test_no_runtime_config_for_standard_trt; on RTX builds the runtime
cache was effectively dead code.
This restores the missing pieces on TRTEngine:
- New module-level helper _get_dynamic_shapes_kernel_strategy mapping the
setting string ('lazy' / 'eager' / 'none') to the TRT-RTX enum.
- __init__ and __setstate__ unconditionally initialize runtime_config and
runtime_cache to None so the destructor's save path is safe even on
partially-loaded engines. runtime_cache_path is set in
_load_serialized_info once compilation settings are decoded, so it is
always available regardless of build flavor.
- _setup_engine calls _setup_runtime_config on RTX builds and rebuilds the
execution context so it picks up the new IRuntimeConfig.
_create_execution_context routes through runtime_config when one is
present, falling back to the strategy-based path otherwise.
- Three new helpers ported from the old PythonTorchTensorRTModule:
_setup_runtime_config (builds the IRuntimeConfig, sets allocation
strategy + dynamic-shape strategy, creates the runtime cache, hydrates
it from disk, binds it to the config), _load_runtime_cache (shared
file-locked deserialize) and _save_runtime_cache (exclusive file-locked
serialize, getattr-guarded so destructor paths on partially-constructed
engines stay safe).
- __del__ now calls _save_runtime_cache before reset_captured_graph, with
an exception swallow so engine teardown never throws.
Tests were also unified on the public attribute names:
- test_000_runtime_cache: the four assertions that previously read the
underscored ``engine._runtime_config`` / ``engine._runtime_cache`` now
use the public names (matching what test_001 was already expecting).
This is what makes test_no_runtime_config_for_standard_trt stop raising
AttributeError on non-RTX CI.
- test_001_dynamic_shapes_kernel_strategy: the non-RTX assertion was a
defensive ``getattr(engine, '_runtime_config', None)`` that would have
silently passed even with the attribute missing; switched to a direct
``engine.runtime_config`` read so it actually exercises the contract.
Verified on an A100 RTX build: all 12 RTX runtime-cache tests and all 6
RTX dynamic-shape strategy tests pass; non-RTX gated tests are skipped on
this build as expected.
TensorRT-RTX has native CUDA graph support via IRuntimeConfig.cuda_graph_strategy, where the JIT compiler handles capture/replay/invalidation internally. This is strictly safer than manual torch.cuda.CUDAGraph capture on RTX because: - lazy-compiled specialized kernels can replace a captured fallback path on the fly, which would invalidate a manually captured graph - runtime allocation / data-dependent shapes can cause cudaStreamBeginCapture to fail outright - the JIT compiler tracks graph staleness (shape changes, pointer changes, kernel readiness) for us Settings + entry points (additive, no behavior change for non-RTX): - _defaults: new CUDA_GRAPH_STRATEGY = "disabled" - _settings.CompilationSettings: new cuda_graph_strategy field (str) - _compiler.compile / cross_compile_for_windows / convert_exported_program_to_serialized_trt_engine: new cuda_graph_strategy kwarg threaded through TRTEngine wiring: - New module-level helper _get_cuda_graph_strategy mapping "disabled" / "whole_graph_capture" -> trt.CudaGraphStrategy. - __init__ / __setstate__ initialize self._rtx_native_cudagraphs = False so the forward path can always read it (including on partially-constructed engines). - _setup_runtime_config also sets self.runtime_config.cuda_graph_strategy from settings on RTX builds (with a paired log line). - _setup_engine latches self._rtx_native_cudagraphs = (RTX and cuda_graph_strategy != "disabled") right after _setup_runtime_config. - New _is_monolithic_capturable(stream): non-RTX returns True (preserves existing behavior); RTX returns False if the IExecutionContext is not stream-capturable or if dynamic-shape strategy is "lazy" (lazy-compiled specialized kernels would invalidate a captured graph). - New _enable_rtx_native_cudagraphs(): rewrites cuda_graph_strategy on the IRuntimeConfig to WHOLE_GRAPH_CAPTURE and rebuilds the execution context. - _execute_standard reads cudagraphs_enabled once at the top; on RTX + cudagraphs enabled + not yet RTX-native, transparently switches to RTX-native (with a warning that tells the user how to set it at compile time). Computes effective_cudagraphs = cudagraphs_enabled and not _rtx_native_cudagraphs and uses it everywhere downstream so the manual torch.cuda.CUDAGraph capture path is bypassed when TRT-RTX owns capture. - Debug log appends " (RTX native)" when _rtx_native_cudagraphs is set. CudaGraphsTorchTensorRTModule wiring (whole-graph mode with mixed TRT + PyTorch): - New _check_monolithic_capturability(stream) iterates the compiled subgraph looking for TorchTensorRTModule whose .engine is a TRTEngine. For each, it calls engine._is_monolithic_capturable and raises RuntimeError if any fails. If an engine has RTX-native cudagraphs on, this turns them off (sets the IRuntimeConfig back to DISABLED and rebuilds the context) so the inner RTX capture cannot interfere with the outer torch.cuda.CUDAGraph capture. - The check fires from forward() just before need_cudagraphs_record allocates the outer torch.cuda.CUDAGraph. Tests: - runtime/test_001_cuda_graph_strategy.py (new): 17 cases covering setup, RTX-native override under SUBGRAPH cudagraphs, _is_monolithic_capturable for each dynamic-shape strategy, context-recreation on _enable_rtx_native_cudagraphs, cudagraphs mode toggle, and a non-RTX gated case. Mirrors the test-helper convention from test_001_dynamic_shapes_kernel_strategy.py (_find_python_trt_engine returns the engine, not the wrapping module). - models/test_cuda_graph_strategy_models.py (new): end-to-end resnet18 and dynamic-batch ConvModel tests for both "disabled" and "whole_graph_capture" strategies. Verified on an A100 RTX build: - test_001_cuda_graph_strategy: 17 passed, 1 skipped (non-RTX gated) - test_000_runtime_cache: 12 passed, 2 skipped (no regression vs. commit 1) - test_001_dynamic_shapes_kernel_strategy: 6 passed, 1 skipped (no regression)
Squash candidates for the two commits on this branch; landed as a separate fixup so each individual review-suggestion is traceable. - Drop the verbose comment blocks on the runtime_config / runtime_cache and _rtx_native_cudagraphs init lines; keep only the destructor / manual-capture rationale. - Remove the self.runtime_cache_path attribute. Inline self.settings.runtime_cache_path in _load_runtime_cache and _save_runtime_cache; the test that previously asserted engine.runtime_cache_path now reads engine.settings.runtime_cache_path. - Refactor _create_execution_context to a clean if/else where the RTX branch asserts runtime_config is not None and uses it directly. Drop the getattr defensive check. - Move _setup_runtime_config to before the first execution-context creation in _setup_engine so we only create the context once. The NCCL barrier still runs against a live context; the RTX runtime config feeds straight into that single context creation. - Add _save_runtime_cache to close(). __del__ now delegates to close(). Drop the try/except in __del__ -- _save_runtime_cache already swallows exceptions internally. - Inline the alloc_strategy decision in _setup_runtime_config as a one-liner ternary. - Drop the getattr-guard in _save_runtime_cache; rely on the runtime_cache = None init in __init__ / __setstate__. - _is_monolithic_capturable now uses any(...) over a tuple of not-capturable conditions instead of multiple if/return False. - Shorten the section header to "# --- TensorRT-RTX ---". - Drop the redundant comment block above the RTX-native override in _execute_standard -- the warning message says the same thing. - Reword the effective_cudagraphs comment to "the manual torch.cuda.CUDAGraph machinery is skipped". Tested on A100 RTX (jobid 2322243, container git_trt_tejaswinp_xaajfdwx): - test_000_runtime_cache: 12 passed, 2 skipped - test_001_dynamic_shapes_kernel_strategy: 6 passed, 1 skipped, 3 subtests passed - test_001_cuda_graph_strategy: 17 passed, 1 skipped, 2 subtests passed
|
[by Claude Code]
|
| def set_use_output_allocator(self, enable: bool) -> None: | ||
| self.use_output_allocator_outputs = enable | ||
|
|
||
| def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None: |
There was a problem hiding this comment.
Whats the difference between monolithic capture and the "whole graph capture mode"?
There was a problem hiding this comment.
Hi Naren, thanks for the Q. What I meant is the following:
-
Monolithic capture: This path is similar to standard TRT. It relies on the outer torch.cuda.CUDAGraph recorded by _CudaGraphsTorchTensorRTModule.forward (after the user calls
torchtrt.runtime.enable_cudagraphs()) capturing TRT-RTX engines. Similar to standard TRT, it wraps the entire compiled subgraph (every TRT engine and any intervening PyTorch glue ops) into onemonolithictorch-side graph. Scope = whole pytorch graph. This code-path gets triggered forCudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS -
Whole graph capture: This is a TRT-RTX exclusive path and obtained when setting
cuda_graph_strategy="whole_graph_capture"per engine. This relies on TRT-RTX to manage captures : internally the JIT compiler captures/replays/invalidates the graph inside execute_async_v3, per-engine. The name "whole graph" is for the "whole" TRT-RTX forwarded graph. Scope = one TRT engine. This code-path gets triggered forCudaGraphsMode.SUBGRAPH_CUDAGRAPHS
When to use what
These two paths are mutually exclusive, even when running through enable_cuda_graphs().
The guidance would be to prefer (2) whole graph capture, as this would get best perf (per an TRT-RTX engine). However the condition here is that there are no intervening pytorch ops (as this will break TRT-RTX's internal graph capture status), which is consistent with it getting triggered for CudaGraphsMode.SUBGRAPH_CUDAGRAPHS.
However, in case there are PyTorch ops in between (because of graph breaks, op incompleteness), the enable_cuda_graphs() will choose the CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS mode. This means we will attempt to wrap N TRT-RTX engines + PyTorch ops in one outer torch.cuda.CUDAGraph and each engine has to leave its own (internal RTX-cudagraph) capture off, otherwise the RTX-native capture would interfere with the outer torch capture. That's what _check_monolithic_capturability enforces : it asserts every engine can be stream-captured (is_stream_capturable, no lazy-with-dynamic-shapes), and forces RTX-native off on each one (cuda_graph_strategy=DISABLED + context rebuild) before the outer torch.cuda.graph() block runs.
Name sharpening
Now, naming-wise, "monolithic" was picked to avoid the literal collision with "whole_graph_capture" since the latter is a TRT-RTX-defined enum name but also confusing with Torch-TRT's CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS. The difference lies in the distinction of "whole" for TRT-RTX vs torch-TRT. I am open to renaming either side if you have a preference.
Perhaps a good middle ground is to rename the enum in the newly introduced cuda_graph_strategy to disabled and per_engine_capture. per_engine_capture can internally map to TRT-RTX's whole_graph_capture mode. This is easier for the users at least IMO.
P.S. I will also add these best practices/guidance to the documentation. Also the behavior matrix (of how cudagraph mode interacts with TRT-RTX) is documented more in the PR description.
…n runtime) Both python-runtime variants in test_004_weight_streaming.py combine torchtrt.runtime.enable_cudagraphs() (manual whole-graph torch CUDA graph capture) with enable_weight_streaming=True. This combination is fundamentally unsupported on TRT-RTX: weight H2D copies run on a dedicated stream with cross-stream event synchronization, which a single-stream torch.cuda.CUDAGraph capture cannot record. A captured graph would replay against stale or uninitialized weights. The new monolithic-capturability check in the CUDA graph strategy feature already raises RuntimeError for this case at runtime and points at the supported path (cuda_graph_strategy="whole_graph_capture" with set_cudagraphs_mode(True)). The skip avoids the noisy failure during CI sweeps. Skip condition keys off ENABLED_FEATURES.tensorrt_rtx alone, since both tests live in TestWeightStreamingPython and are already python-runtime only.
_is_monolithic_capturable previously returned False whenever the kernel
specialization strategy was "lazy", regardless of whether the engine
actually compiles shape-specialized kernels at runtime. For a
static-shape engine (no DYNAMIC_DIM on any input binding) the lazy
strategy is a no-op -- there are no further specializations possible
after build, so a captured CUDA graph cannot be invalidated by
mid-replay specialization. Empirically TRT-RTX's own
context.is_stream_capturable() returns True for static-shape engines
under lazy strategy, confirming the kernel-readiness concern does not
apply.
Keep the lazy clause but gate it on
any(DYNAMIC_DIM in shape for shape in self.input_shapes)
so it only fires when shape-specialized kernels can actually appear
later. This removes a false-negative that was blocking monolithic
capture for static-shape RTX users on the default ("lazy") strategy.
TRTEngine did not previously cache self.input_shapes the way it caches
self.output_shapes; this commit adds the parallel population in
_setup_engine, mirroring the convention used elsewhere. DYNAMIC_DIM is
imported from torch_tensorrt.dynamo.utils.
Description
PR #4222 (Python runtime rework) replaced
PythonTorchTensorRTModulewith the newTRTEnginebut the TRT-RTX feature surface from #4180 / #4184 / #4187 was lost in the move. This PR reintroduces it onTRTEngine+_CudaGraphsTorchTensorRTModule. Design mirrors #4187, retargeted onto the new split.Two commits for ease of review.
Commit 1 —
fix(runtime): restore TRT-RTX runtime-cache and dynamic-shapes setupFixes the
AttributeError: 'TRTEngine' object has no attribute '_runtime_cache'failure intest_no_runtime_config_for_standard_trton standard-TRT CI and restores the runtime-cache + dynamic-shape-strategy machinery on RTX. Ports_setup_runtime_config,_load_runtime_cache,_save_runtime_cachefrom the deletedPythonTorchTensorRTModule.__init__/__setstate__unconditionally initializeruntime_configandruntime_cachetoNoneso destructor save paths are safe. Tests intest_000_runtime_cache.py/test_001_dynamic_shapes_kernel_strategy.pyare realigned on the publicruntime_config/runtime_cacheattribute names (the underscored variant was reading attributes that no longer existed after #4222).Commit 2 —
feat(runtime): add TRT-RTX native CUDA graph supportSame intent as #4187:
cuda_graph_strategysetting ("disabled"/"whole_graph_capture") threaded through the three_compilerentry points._setup_runtime_configalso setsIRuntimeConfig.cuda_graph_strategyfrom settings._setup_enginelatches_rtx_native_cudagraphs._is_monolithic_capturable(stream)—Falseon RTX when the context isn't stream-capturable or when dynamic-shape strategy is"lazy"(lazy specialization would invalidate a captured graph)._enable_rtx_native_cudagraphs()— flips toWHOLE_GRAPH_CAPTUREand rebuilds the context._execute_standardtransparently overrides to RTX-native when SUBGRAPH cudagraphs are requested on RTX (with a warning pointing at the compile-time setting), and gates the manualtorch.cuda.CUDAGraphpath oneffective_cudagraphs = cudagraphs_enabled and not _rtx_native_cudagraphs._CudaGraphsTorchTensorRTModule._check_monolithic_capturabilitywalks the compiled subgraph, asserts each TRT engine is capturable, and forces RTX-native engines back toDISABLED(rebuilding their contexts) so the outer monolithictorch.cuda.CUDAGraphcapture isn't interfered with.Behavior matrix (unchanged from #4187)
RuntimeErrorVerification
A100 RTX build:
runtime/test_000_runtime_cache.py: 12 passed, 2 skipped (non-RTX gated, including the originally-failing test)runtime/test_001_dynamic_shapes_kernel_strategy.py: 6 passed, 1 skippedruntime/test_001_cuda_graph_strategy.py(new): 17 passed, 1 skippedmodels/test_cuda_graph_strategy_models.py(new) — included with the same shape as feat: add TRT-RTX native CUDA graph support #4187Non-RTX path is correct by inspection (
runtime_config/runtime_cacheare initialized toNonebefore the RTX gate). Worth a non-RTX CI confirm before landing.Type of change
Checklist