feat(runtime): support binding TRTEngine execution to an external CUDA stream #4232
Open
shoumikhin wants to merge 9 commits into pytorch:main
Conversation
Adds opt-in support for binding torch-tensorrt's TRT engine execution to
externally-managed CUDA streams -- typically streams created via
`cuGreenCtxStreamCreate` for SM partitioning via CUDA Green Contexts (CUDA
12.4+).
Currently, the runtime in `core/runtime/execute_engine.cpp` lazily pulls a
stream from torch's global stream pool on first execute. That pool is bound
to the primary CUDA context, so even when a caller sets a green-context-bound
stream as current, the TRT engine bypasses it and uses a primary-context pool
stream -- defeating any SM partitioning the caller set up.
Purely additive: no behavior change for callers that don't opt in.
This change adds two complementary mechanisms:
(1) Per-engine binding (Python / dynamo / Exported Program path):
- C++ API on `TRTEngine` (exposed via torchbind):
void set_external_stream(int64_t stream_handle);
void clear_external_stream();
int64_t get_external_stream() const;
The handle is `reinterpret_cast<int64_t>(cudaStream_t)`. Reachable from
Python and external C++ via `torch.classes.tensorrt.Engine`.
- Python facade in `torch_tensorrt.runtime.set_external_stream(module, ...)`
with optional per-engine binding via `Dict[submodule_name, StreamLike]`
and RAII context-manager semantics. Walks `named_modules()` so deeply
nested TRT submodules (e.g. HF blocks under wrapper GraphModules) are
reachable.
(2) Process-wide stream passthrough (AOTI / .pt2 C++ path):
- New global flag `ENGINE_STREAM_PASSTHROUGH` and accessors:
bool get_engine_stream_passthrough();
void set_engine_stream_passthrough(bool);
When enabled, `execute_engine` honors the caller's *current* CUDA stream
(`c10::cuda::getCurrentCUDAStream`) instead of acquiring a pool stream.
This unblocks `output_format="aot_inductor"` users whose `TRTEngine`
torchbind constants live inside `OSSProxyExecutor::custom_objs_`
(private, no public PyTorch accessor) and so are unreachable for the
per-engine API. Users wrap `loader.run(...)` in a `CUDAStreamGuard`
bound to e.g. a Green Context stream and the engine inherits it.
- Python facade: `torch_tensorrt.runtime.set_engine_stream_passthrough(bool)`
/ `get_engine_stream_passthrough()`.
Mutual exclusion with CUDA Graphs is enforced for both mechanisms (throws at
execute time). Setter and clearer also invalidate any captured graph so a
subsequent recapture happens cleanly (avoids replaying against a stale stream
identity).
Multi-GPU correctness: `engine_stream` and `caller_stream` are now pinned to
the engine's actual `device_info.id` in the constructor body (the in-class
initializers default to device 0; without this, the lazy pool re-acquire in
`execute_engine` skipped firing on `cuda:N` for `N>0` because the
`engine_stream == getDefaultCUDAStream(current_device_id)` check was always
false).
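A sketch of the constructor-body fix, assuming the members are initialized from `c10::cuda::getDefaultCUDAStream` (inferred from the equality check quoted above; exact statements are illustrative):

```cpp
// Sketch (field names from this PR; exact statements inferred): pin both
// cached streams to the engine's real device so the lazy pool re-acquire
// check in execute_engine can fire on cuda:N for N > 0.
caller_stream = c10::cuda::getDefaultCUDAStream(device_info.id);
engine_stream = c10::cuda::getDefaultCUDAStream(device_info.id);
```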
Same code path serves both the C++ AOTI runtime (model.so dispatch into
`execute_engine.cpp` via the C-shim) and the dynamo Python runtime
(`PythonTorchTensorRTModule`). Per-engine binding lets callers map distinct
green contexts to distinct TRT subgraphs in one compiled model. The
process-wide passthrough is the alternative for callers who can't reach the
engines individually (AOTI's private custom_objs_ map being the canonical
case).
Files changed:
- core/runtime/TRTEngine.{h,cpp} setter / clearer / getter, `external_stream` and `engine_stream_is_external` fields, cudagraph invalidation in setter & clearer, multi-GPU default-stream init in ctor
- core/runtime/execute_engine.cpp stream-resolve sites in both lambdas (regular + output-allocator paths) with per-engine + passthrough + pool fallback precedence; cudagraph mutual-exclusion guards
- core/runtime/runtime.{h,cpp} `ENGINE_STREAM_PASSTHROUGH` global + accessors
- core/runtime/register_jit_hooks.cpp torchbind exposure for all three per-engine methods + the two passthrough globals
- py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py Python runtime parity
- py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py passthrough on the C++-backed runtime
- py/torch_tensorrt/runtime/_external_stream.py top-level facade + context manager + passthrough toggles
- py/torch_tensorrt/runtime/__init__.py re-export
- tests/py/dynamo/runtime/test_006_external_stream.py covers both runtimes (swap, clear, per-engine binding, cudagraph guard, validation, passthrough routing, passthrough+cudagraph mutex)
Test plan:
- pytest tests/py/dynamo/runtime/test_006_external_stream.py -- both
PythonTorchTensorRTModule and TorchTensorRTModule runtime classes.
- GPU runtime test on H100 / Hopper: register a green-context-bound stream,
run a small TRT engine, verify via nsys profile that kernel launches are
confined to the green context's SM partition.
- GPU runtime test on Jetson Thor (Blackwell): same as above with sm_110.
- AOTI C++ test: `set_engine_stream_passthrough(True)`, wrap
`AOTIModelPackageLoader::run()` with a `CUDAStreamGuard` on a green-context
stream, verify SM-partitioned execution.
Open item (deliberately not in this commit, can land separately):
- Device-affinity validation in `set_external_stream`. The current sanity
check (`cudaStreamGetFlags`) confirms the handle is real but does not
validate the stream's device against `device_info.id`. A multi-GPU caller
could silently bind a wrong-device stream. Clean fix uses `cuStreamGetCtx`
+ `cuCtxGetDevice` (driver API) or `cudaStreamGetDevice` (CUDA 12.5+).
Author
Long-term plan: opened upstream PyTorch PR pytorch/pytorch#182149 to add
`AOTIModelPackageLoader::get_custom_objs()`, so AOTI users can reach the live
engines and use the per-engine API directly:

    torch::inductor::AOTIModelPackageLoader loader("model.pt2");
    for (auto& [name, ivalue] : loader.get_custom_objs()) {
      if (auto e = ivalue.toCustomClass<torch_tensorrt::TRTEngine>()) {
        e->set_external_stream(reinterpret_cast<int64_t>(my_green_stream));
      }
    }
    loader.run(inputs);

This PR (#4232) ships the process-wide passthrough as the interim mechanism
until that lands.
Bundle 1 (must-fix from reviewers):
1. Device-affinity validation in set_external_stream
- cuStreamGetCtx + cuCtxPushCurrent + cuCtxGetDevice resolves the stream's
device and asserts it matches engine.device_info.id. Catches the silent
cross-device launch (cuda:1 stream bound to cuda:0 engine) before any
enqueueV3, where the failure would otherwise surface as a confusing
CUDA error far from the bind site. (A sketch of this check, together with
item 2's rejection, follows the list.)
2. Reject magic stream values
- cudaStreamLegacy / cudaStreamPerThread are now explicitly rejected.
Binding them latches engine_stream_is_external onto a non-isolated
stream that defeats the whole point of the API.
3. Atomic rollback on partial multi-engine bind (Python facade)
- The set_external_stream loop now records each successful application
and reverses them on any failure, so an engine's per-handle validation
throwing midway through a Dict-shaped binding can no longer leave
earlier engines in a half-bound state.
4. Re-entrancy / deadlock fix
- mu is now std::recursive_mutex everywhere on TRTEngine. Allows TRT
plugin -> Python -> set_external_stream re-entry on the same thread
without self-deadlock. Zero downside for the non-reentrant path.
5. Cudagraph mutual-exclusion check moved to set time
- set_external_stream now asserts CUDAGRAPHS_MODE == STANDARD up front
instead of waiting until next execute. Faster failure, clearer call
site, no wasted input migration etc. before the throw. The execute-
time guard remains as defense-in-depth (covers cudagraphs being
enabled AFTER an external stream is bound).
6. is_external_stream_set() companion accessor
- Avoids the ambiguous get_external_stream() == 0 sentinel pattern.
ABI-safe, cheap, exposed via torchbind.
7. Error message typo fix
- 'wraps a non-null CUDA stream is required' -> 'must wrap a non-null
CUDA stream'.
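A minimal sketch of the bind-time validation in items 1 and 2, using the driver-API sequence named above; the function name, error type, and exact messages are illustrative:

```cpp
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdexcept>

// Sketch: validate a caller-supplied stream handle before latching it as the
// engine's external stream.
void validate_external_stream(cudaStream_t stream, int engine_device_id) {
  // (2) Reject null and the legacy / per-thread magic stream values outright.
  if (stream == nullptr || stream == cudaStreamLegacy || stream == cudaStreamPerThread) {
    throw std::invalid_argument("set_external_stream must wrap a non-null, non-magic CUDA stream");
  }
  // (1) Resolve the stream's context -> device and compare to the engine's.
  CUcontext ctx;
  if (cuStreamGetCtx(reinterpret_cast<CUstream>(stream), &ctx) != CUDA_SUCCESS ||
      cuCtxPushCurrent(ctx) != CUDA_SUCCESS) {
    throw std::invalid_argument("stream handle is not a valid CUDA stream");
  }
  CUdevice dev;
  CUresult r = cuCtxGetDevice(&dev);
  cuCtxPopCurrent(nullptr);
  if (r != CUDA_SUCCESS || static_cast<int>(dev) != engine_device_id) {
    throw std::invalid_argument("stream's device does not match engine device_info.id");
  }
}
```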
Defer to follow-up: Python torch.cuda.default_stream(self.device) one-char
fix, additional tests (green-context smoke, restore-non-zero-prior,
serialize round-trip), passthrough relocation, NCCL+external_stream
LOG_WARNING.
Summary
Adds opt-in support for binding torch-tensorrt's TRT engine execution to externally-managed CUDA streams — typically streams created via `cuGreenCtxStreamCreate` for SM partitioning via CUDA Green Contexts (CUDA 12.4+). The motivating workload is edge / on-device multi-tenant inference on Jetson-class hardware, where a vision encoder, a policy net, and a diffusion head all share one process and need disjoint SM partitions to avoid time-slicing.

Currently, `core/runtime/execute_engine.cpp` lazily pulls a stream from torch's global stream pool on first execute. That pool is bound to the primary CUDA context, so even when a caller sets a green-context-bound stream as current (via `c10::cuda::CUDAStreamGuard` / `torch.cuda.stream(...)`), the TRT engine bypasses it and uses a primary-context pool stream — defeating any SM partitioning the caller set up.

Purely additive: no behavior change for callers that don't opt in.
Two complementary mechanisms
The PR ships two ways to bind a stream, sized to two different deployment shapes:
(1) Per-engine binding — for Python / dynamo / `output_format="exported_program"`

Reach the `TRTEngine` torchbind through the wrapping `nn.Module`'s `named_modules()` and bind a stream per engine. This is the canonical multi-engine SM-partitioning case, where one compiled model contains several TRT subgraphs that should each run on a distinct green context.

New C++ API on `TRTEngine` (exposed via torchbind):
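```cpp
void set_external_stream(int64_t stream_handle);  // handle = reinterpret_cast<int64_t>(cudaStream_t)
void clear_external_stream();
int64_t get_external_stream() const;
```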
Reachable from Python and external C++ via `torch.classes.tensorrt.Engine`.

New Python facade with RAII context-manager semantics.
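A minimal usage sketch (the facade name, Dict form, and context-manager semantics are from this PR; `trt_module`, `x`, the `green_stream_*` handles, and the `_run_on_acc_*` submodule names are illustrative — the streams stand in for `torch.cuda.ExternalStream` wrappers around green-context streams, creation not shown):

```python
import torch
import torch_tensorrt

# Bind one stream to every TRT engine in the module for the duration of the block.
with torch_tensorrt.runtime.set_external_stream(trt_module, green_stream_a):
    out = trt_module(x)

# Or map distinct streams to distinct TRT submodules (dotted names from
# named_modules()); bindings roll back atomically on failure and are cleared
# on context exit.
with torch_tensorrt.runtime.set_external_stream(
    trt_module,
    {"_run_on_acc_0": green_stream_a, "_run_on_acc_1": green_stream_b},
):
    out = trt_module(x)
```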
`set_external_stream` walks `named_modules()` recursively, so deeply nested TRT submodules (e.g. HF blocks under wrapper `GraphModule`s) are reachable. Submodule names are dotted paths, validated up front so a bad value cannot leave a partially-bound module. The setter validates the stream's device affinity against the engine's target device (via `cuStreamGetCtx` + `cuCtxGetDevice`) and rejects the legacy / per-thread magic stream IDs; the binding is applied atomically across multiple engines (any per-engine failure rolls back successfully-applied bindings before re-raising).

(2) Process-wide stream passthrough — for AOTI / `.pt2` C++ deployments

When the model is exported with `output_format="aot_inductor"` and consumed in pure C++ via `torch::inductor::AOTIModelPackageLoader`, the live `TRTEngine` torchbind instances live inside `OSSProxyExecutor::custom_objs_` — private, with no public PyTorch accessor. Re-parsing the `.pt2` only yields independent `IValue` copies that the running `.so` never invokes, so the per-engine API in (1) is unreachable.

The fix: a process-wide opt-in flag that makes `execute_engine` honor the caller's current CUDA stream instead of the lazy pool stream. Users wrap `loader.run(...)` in a `CUDAStreamGuard` and the engine inherits it.

New globals (C++ + Python):
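```cpp
// core/runtime/runtime.{h,cpp}
bool get_engine_stream_passthrough();
void set_engine_stream_passthrough(bool);
```

Python: `torch_tensorrt.runtime.set_engine_stream_passthrough(bool)` / `get_engine_stream_passthrough()`.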
C++ usage after merge:
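(A sketch: the passthrough setter's namespace, `torch_tensorrt::core::runtime` here, is illustrative, and `my_green_stream` stands in for a `cudaStream_t` created via `cuGreenCtxStreamCreate`, creation not shown.)

```cpp
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>

// Illustrative namespace for the new global; see core/runtime/runtime.h.
namespace rt = torch_tensorrt::core::runtime;

std::vector<at::Tensor> run_on_green_stream(
    torch::inductor::AOTIModelPackageLoader& loader,
    std::vector<at::Tensor> inputs,
    cudaStream_t my_green_stream) {
  rt::set_engine_stream_passthrough(true);

  // Wrap the externally created stream so CUDAStreamGuard can make it current.
  auto stream = c10::cuda::getStreamFromExternal(my_green_stream, /*device_index=*/0);
  c10::cuda::CUDAStreamGuard guard(stream);

  // execute_engine now honors the current stream instead of a pool stream.
  return loader.run(inputs);
}
```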
Python usage after merge (also valid for AOTI-loaded models via `torch._inductor.aoti_load_package`):
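(A sketch: `model`, `x`, and `raw_ptr`, the integer handle of a green-context stream created through CUDA driver bindings, are illustrative; `torch.cuda.ExternalStream` wraps the handle for the stream context.)

```python
import torch
import torch_tensorrt

torch_tensorrt.runtime.set_engine_stream_passthrough(True)

# Wrap the externally created green-context stream handle.
green_stream = torch.cuda.ExternalStream(raw_ptr)

with torch.cuda.stream(green_stream):
    out = model(x)  # TRT engines inherit the current stream
```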
Precedence

When multiple sources are configured, the resolver picks in this order, on every call (so set/clear take effect immediately without recreating the engine):

1. `external_stream` (set via `TRTEngine::set_external_stream`)
2. `ENGINE_STREAM_PASSTHROUGH` → caller's current CUDA stream
3. torch's global stream pool (`getStreamFromPool`) — unchanged default behavior
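A compact sketch of that resolution order (the helper name is illustrative; the real logic sits at the stream-resolve sites in `execute_engine.cpp`, and the field names come from this PR):

```cpp
// Sketch of the per-call stream resolution, mirroring the precedence above.
c10::cuda::CUDAStream resolve_engine_stream(const TRTEngine& engine) {
  if (engine.engine_stream_is_external) {
    // 1. Per-engine binding wins.
    return c10::cuda::getStreamFromExternal(
        reinterpret_cast<cudaStream_t>(engine.external_stream), engine.device_info.id);
  }
  if (get_engine_stream_passthrough()) {
    // 2. Process-wide passthrough: honor whatever stream the caller made current.
    return c10::cuda::getCurrentCUDAStream(engine.device_info.id);
  }
  // 3. Default: torch's global stream pool (pre-existing behavior).
  return c10::cuda::getStreamFromPool(false, engine.device_info.id);
}
```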
Mutual exclusion with CUDA Graphs

Both mechanisms are mutually exclusive with CUDA Graphs. The check fires at bind time (`set_external_stream` / `set_engine_stream_passthrough(true)` throw if cudagraphs are currently enabled) and again at execute time as defense-in-depth (covering cudagraphs being enabled after the binding).

The setter and clearer also invalidate any captured graph (`cudagraph.reset()`) so a subsequent recapture happens cleanly and never replays against a stale stream identity.
Multi-GPU correctness fix folded in

`TRTEngine::engine_stream` and `TRTEngine::caller_stream` are now pinned to the engine's actual `device_info.id` in the constructor body. The in-class initializers at `TRTEngine.h:211-212` default to device 0 (no device arg). Without this fix, the lazy pool re-acquire in `execute_engine` checked `engine_stream == getDefaultCUDAStream(current_device_id)` — always false on `cuda:N` for `N>0` — so the engine ran on `cuda:0`'s default stream regardless of the input device. Pre-existing bug; fixed here while we were in the area.

Files changed
- `core/runtime/TRTEngine.{h,cpp}`: `external_stream` + `engine_stream_is_external` fields, setter / clearer / getter, cudagraph invalidation in setter & clearer, multi-GPU default-stream init in ctor
- `core/runtime/execute_engine.cpp`: stream-resolve sites in both lambdas (regular + output-allocator paths) with per-engine + passthrough + pool fallback precedence; cudagraph mutual-exclusion guards
- `core/runtime/runtime.{h,cpp}`: `ENGINE_STREAM_PASSTHROUGH` global + `get_`/`set_engine_stream_passthrough()` accessors
- `core/runtime/register_jit_hooks.cpp`: torchbind exposure for the per-engine methods + the two passthrough globals
- `py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py`: Python runtime parity
- `py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py`: passthrough on the C++-backed runtime
- `py/torch_tensorrt/runtime/_external_stream.py`: top-level facade + context manager + passthrough toggles
- `py/torch_tensorrt/runtime/__init__.py`: re-export
- `tests/py/dynamo/runtime/test_006_external_stream.py`: covers both runtimes (swap, clear, per-engine binding, cudagraph guard, validation, passthrough routing, passthrough+cudagraph mutex)

Test plan
- `pytest tests/py/dynamo/runtime/test_006_external_stream.py` — covers both `PythonTorchTensorRTModule` and `TorchTensorRTModule` runtime classes, including the new passthrough tests.
- GPU runtime test: register a green-context-bound stream, run a small TRT engine, and verify via `nsys profile` that kernel launches are confined to the green context's SM partition.
- AOTI C++ test: `set_engine_stream_passthrough(true)` + wrap `AOTIModelPackageLoader::run()` with a `CUDAStreamGuard` on a green-context stream, verify SM-partitioned execution in `nsys`.

Out of scope (future PRs)

- Upstream `AOTIModelPackageLoader::get_custom_objs()` so AOTI users can also use the per-engine API (when they want different streams per submodule inside one `.pt2`). The passthrough flag in this PR is the interim mechanism while that lands and reaches stable.
- A `torch_tensorrt::aoti::TRTAOTILoader` C++ wrapper behind a `TORCH_TRT_HAVE_AOTI_CUSTOM_OBJS` CMake probe — depends on the upstream PR.
- The `requires_native_multidevice` path may need follow-up if a user combines both.
AOTIModelPackageLoader::get_custom_objs()so AOTI users can also use the per-engine API (when they want different streams per submodule inside one.pt2). The passthrough flag in this PR is the interim mechanism while that lands and reaches stable.torch_tensorrt::aoti::TRTAOTILoaderC++ wrapper behind aTORCH_TRT_HAVE_AOTI_CUSTOM_OBJSCMake probe — depends on the upstream PR.requires_native_multidevicepath may need follow-up if a user combines both.