✨[Feature] Multiple Optimization Profiles for Disjoint Input Shape Regimes #4313
Replies: 4 comments 22 replies
@cehongwang is this something we could add to the Input class instead of an additional API, e.g. by allowing disjoint shape ranges? Something like:
Input(
    profiles={
        "prefill": {"min": ..., "max": ..., "opt": ...},
        "decode": {"min": ..., "max": ..., "opt": ...},
    }
)
With some cross-input error checking?
There is a somewhat related thread that @apbose should be working on: named tuples for different dimensions, to allow cross-input dynamic dimensions. These two features should work together.
How do we store the information about the different profiles so it can be used at runtime?
There are APIs to check the number of profiles and each profile's tensor shapes, so we can detect this from the engine. But if the engine is compiled with torchtrt, why do we need to detect it? Users can manage this via a context manager. Do you have a better way to deal with it?
RFC: Multiple Optimization Profiles for Disjoint Input Shape Regimes
torch_tensorrt.dynamo (AOT compile / torch.compile backend)
_Input, _tracer, _TRTInterpreter, _TRTEngine, partitioning/common.py, engine cache, runtime
1. Problem
Torch-TensorRT builds one TRT IOptimizationProfile per engine (_TRTInterpreter.__init__ line 123; placeholder() always writes optimization_profiles[0], see the TODO at line 712).
That is fine for unimodal shape distributions. It is a poor fit for bimodal workloads like LLM inference:
| Input | Prefill | Decode |
|---|---|---|
| input_ids | [B, 32..2048] | [B, 1] |
One profile spanning [1, 2048] with opt=1024 picks kernels that are wrong for both phases. TensorRT supports N profiles per engine; we should expose that.
Measured gap (Alpamayo / Edge-LLM)
Decode benchmark on umb-b200-247, nvcr.io/nvidia/pytorch:25.12-py3, TensorRT v10.14 (trtexec --dumpProfile --dumpLayerInfo --profilingVerbosity=detailed --separateProfileRun). All runs at decode shape: batch=6, inputs_embeds=6×1×4096, past_key_values=6×2×8×4096×128 (KV len 4096).
[Results table; only these row-label fragments survive: "[...] opt, seq≈3424)" and "[...] inputs_embeds opt/max 6×1×4096)"]
Decode-profile Torch-TRT is ~2.1× faster than the prefill-oriented build and matches ONNX-TRT decode. Layer profiles show the original Torch-TRT engine pays a large FC penalty tuned for long context; the decode-profile engine's qkv/o/up/down matmul totals align with ONNX-TRT decode, consistent with opt targeting seq=3424 while the benchmark runs seq=1.
Today this required a separate sanity-check engine with a decode-only profile. Multi-profile support would keep one engine and switch at runtime.
2. Goals / Non-goals
Goals
- [...] optimization_profile(trt_gm, name); auto-select opt-in via optimization_profile(trt_gm, "auto") or compile flag when shapes alone should determine the profile.
- No Input.profiles → today's zero/one-profile behavior (min/opt/max_shape or EP-only).
- [...] torch.export (AOT path) and torch.compile(backend="tensorrt") (JIT path).
3. Design: export once, specialize in TRT
torch.export gives each dynamic dim one contiguous [min, max]. There is no disjoint-union dim.
Chosen approach: export once using the union of all profile ranges per dim, then attach N TRT profiles at build time.
Example: prefill seq_len ∈ [32, 2048], decode seq_len = 1 → export with Dim("seq_len", min=1, max=2048).
Rejected alternatives:
- [...] Dims: out of scope.
The engine only accepts shapes inside a declared profile, even though export accepts the full union envelope.
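The "export once over the union" step can be illustrated with a small pure-Python sketch. It is not part of the proposed implementation: the helper name union_envelope, the dict layout, the (batch, seq_len) shapes, and the prefill opt value are assumptions made for illustration only.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]
Profile = Dict[str, Shape]  # keys: "min", "opt", "max"


def union_envelope(profiles: Dict[str, Profile]) -> Tuple[Shape, Shape]:
    """Return the elementwise (min, max) envelope that covers every profile."""
    mins = [p["min"] for p in profiles.values()]
    maxs = [p["max"] for p in profiles.values()]
    if len({len(shape) for shape in mins + maxs}) != 1:
        raise ValueError("all profile shapes must have the same rank")
    lo = tuple(min(dims) for dims in zip(*mins))
    hi = tuple(max(dims) for dims in zip(*maxs))
    return lo, hi


# Hypothetical token-input profiles shaped (batch, seq_len); opt=1024 is illustrative.
LLM_PROFILES = {
    "prefill": {"min": (1, 32), "opt": (1, 1024), "max": (1, 2048)},
    "decode": {"min": (1, 1), "opt": (1, 1), "max": (1, 1)},
}

envelope_min, envelope_max = union_envelope(LLM_PROFILES)

# Axes whose envelope min and max differ are the ones that would need a Dim at export.
dynamic_axes = [
    axis for axis, (lo, hi) in enumerate(zip(envelope_min, envelope_max)) if lo != hi
]
```

For the prefill/decode example above, the seq_len axis comes out as [1, 2048], which matches the Dim("seq_len", min=1, max=2048) envelope; the two TRT profiles would still be built from the original, narrower ranges.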
3.1 Export envelope vs. profile ranges
Multi-profile relies on one torch.export over an envelope that covers every profile. Export traces once at the opt example (Input.torch_tensor = example_tensor("opt_shape") in _Input.py; _tracer.py passes that tensor to export()). Each dynamic axis gets one Dim(min, max) whose range must span all profiles the user (or we) intend to build.
Worked example (two disjoint image-size profiles):
| Profile | min | opt | max |
|---|---|---|---|
| "small" | [1, 3, 64, 64] | [1, 3, 128, 128] | [1, 3, 256, 256] |
| "large" | [1, 3, 1024, 1024] | [1, 3, 2048, 2048] | [1, 3, 4096, 4096] |
Export must declare img_dim = Dim("img_dim", min=64, max=4096) (elementwise union over profiles). TRT then gets two optimization profiles with those disjoint ranges. If the user hands us a profile whose [min, max] falls outside the export envelope, we cannot build it; reject at compile time with a clear error.
Validation rules (compile time):
- [...] [min, max] exceeds export Dim envelope: [...] Dim or shrink the profile
- [...] Dim(max=8192) but profiles only use up to 4096): [...] [64, 8192] but the engine only accepts shapes within the declared profiles
Post-export, for each dynamic dim in Input.profiles, assert every profile corner ⊆ ShapeEnv.var_to_range for the corresponding export symbol. Document clearly: you (or we, when deriving the envelope from Input.profiles) must export with the union range. Profiles are subsets of export, never the other way around.
Other failure modes (export trace itself):
- [...] if seq_len == 1: specializes to the opt branch; decode path never traced
- [...] extract_var_range_info also remaps min=2 → 1): [...] Dim(min=1, …) for decode; validate post-export that the symbol was not narrowed
Optional corner check: run the exported / partitioned module at profile corners and confirm shapes match expectations. Not required for v1 if the model satisfies §5.
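A minimal sketch of the compile-time checks described in this section, for one dynamic dim. It also folds in the min ≥ 1 and min ≤ opt ≤ max rules that the RFC states in §4.1 and §5. The function name validate_dim_profiles, the tuple layout, and the message wording are hypothetical; whether the "envelope wider than the profiles" case should warn or only log is left open here, so the sketch just returns the message.

```python
from typing import Dict, List, Tuple

Corner = Tuple[int, int, int]  # (min, opt, max) of one dynamic dim


def validate_dim_profiles(
    profiles: Dict[str, Corner], export_range: Tuple[int, int]
) -> List[str]:
    """Raise on invalid profiles; return notes when the export envelope is wider."""
    env_lo, env_hi = export_range
    for name, (lo, opt, hi) in profiles.items():
        if lo < 1:
            raise ValueError(f"profile {name!r}: min must be >= 1, got {lo}")
        if not lo <= opt <= hi:
            raise ValueError(
                f"profile {name!r}: need min <= opt <= max, got {lo}, {opt}, {hi}"
            )
        if lo < env_lo or hi > env_hi:
            raise ValueError(
                f"profile {name!r} range [{lo}, {hi}] exceeds export Dim envelope "
                f"[{env_lo}, {env_hi}]; widen the Dim or shrink the profile"
            )
    # Every profile fits, so any mismatch means the envelope is strictly wider.
    union_lo = min(lo for lo, _, _ in profiles.values())
    union_hi = max(hi for _, _, hi in profiles.values())
    notes = []
    if (env_lo, env_hi) != (union_lo, union_hi):
        notes.append(
            f"export accepts [{env_lo}, {env_hi}] but the engine only accepts "
            f"shapes within the declared profiles (union [{union_lo}, {union_hi}])"
        )
    return notes


# The img_dim bounds from the worked example above.
IMG_PROFILES = {"small": (64, 128, 256), "large": (1024, 2048, 4096)}
```

With these bounds, an export range of (64, 4096) passes silently, (64, 8192) passes with one note, and (64, 2048) is rejected because "large" no longer fits.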
3.1.1 Overlapping profiles
Overlapping profiles are allowed. Two or more profiles may share part of their [min, max] envelope on the same dynamic dim. This is common in LLM workloads:
| Profile | seq_len range |
|---|---|
| "prefill" | [1, 4096] |
| "decode" | [1, 1] |
At runtime, input shape seq_len=1 satisfies both profiles. Overlap is a feature, not an error: it lets one engine cover adjacent regimes without forcing disjoint bounds.
Compile time: no restriction beyond §3.1 validation (each profile must still fit inside the export envelope). Overlapping [min, max] ranges on the same binding are fine.
Runtime (auto-selection enabled, §4.2): after ruling out invalid profiles per input and intersecting survivors:
- [...] optimization_profile(trt_gm, name) always overrides auto and skips this tie-break.
Distance metric: for each surviving profile p, compute a scalar distance to the current inputs. For every input binding b and every dynamic dim d where the profile defines an opt value:
dist(p) = Σ_b Σ_d |shape_b[d] − opt_{p,b}[d]|
Use the profile's declared opt tuple per binding (from compile-time Input.profiles, cached in _profile_dim_ranges alongside min/max at load). Static dims contribute 0. Pick p with minimum dist(p); break ties deterministically by lowest profile index.
Example (LLM overlap at decode): inputs_embeds shape (6, 1, 4096).
[Table: seq opt and |1 − opt| for "prefill" and "decode"; cell values not recoverable]
Auto-selection picks "decode" even though "prefill" also accepts seq=1, because decode's opt is closest to the actual shape.
Example (mid-prefill): inputs_embeds shape (6, 512, 4096).
| Profile | [min, max] on seq |
|---|---|
| "prefill" | [1, 4096] |
| "decode" | [1, 1] |
Only "prefill" survives → selected unambiguously. The closest-opt tie-break (§3.1.1 step 3) applies only among profiles that already pass the [min, max] check.
Rationale: closest-opt picks the profile TensorRT already tuned kernels for, minimizing the gap between runtime shape and the profile's specialization point. Users who need a specific profile regardless of opt distance (e.g. force decode kernels during a warm-up shape that technically fits prefill) should pin manually.
3.2 Graph breaks: propagate profiles to intermediate submodules
Partitioning can produce a mix of Torch and TRT submodules. Each TRT submodule needs the same named profiles as the top-level inputs (e.g. "decode" on every engine, including intermediate graphs). We do not re-export or re-trace per profile.
Where intermediate shapes come from: torch.export attaches symbolic shapes to every placeholder via meta["val"] (FakeTensor / SymInt). After a graph break, construct_submodule_inputs() reads those placeholders and builds TRT Inputs.
Today this yields one min/opt/max envelope per submodule input: the union range implied by export's ShapeEnv (see construct_dynamic_input, extract_var_range_info). That is correct for single-profile; multi-profile extends it.
Propagation rule: intermediate tensor shapes are sympy expressions over the same source symbols assigned at export (e.g. top-level s0). User-provided profile bounds apply to those source symbols; intermediate dims inherit profile bounds by evaluating the expression.
Example: top-level input shape [1, 3, s0, s0], intermediate shape [1, 3, s0/4, s0/4].
| Profile | Source symbol bounds |
|---|---|
| "small" | s0: min=64, opt=128, max=256 |
| "large" | s0: min=1024, opt=2048, max=4096 |
For each profile name, substitute the profile's {min, opt, max} value of every free symbol into the intermediate SymInt expression (expr.xreplace(...)) and evaluate to an integer. Shape ops in export produce affine expressions (s0, s0/4, 2*s0, …) that are monotonic in each source symbol, so per-corner substitution is exact, with no separate trace per profile.
Algorithm (per TRT submodule, after partition):
- [...] Input.profiles dynamic dim → export symbol name (read from top-level placeholder meta["val"] SymInt nodes). Build profile_source_bounds: {profile_name: {symbol: {min, opt, max}}}.
- [...] construct_submodule_inputs(submodule) as today to get symbolic shapes from placeholders; for each SymInt dim, call extract_var_range_info only for the union fallback (or evaluate per profile via substitution as above).
- [...] Input per submodule placeholder with: min_shape/opt_shape/max_shape (backward compatible), and profiles={name: {min, opt, max}} using the same profile names as the top level.
- [...] _TRTInterpreter: write all N profiles for that submodule (§6).
Shape-tensor (scalar SymInt) inputs: handled by the same symbolic path in construct_submodule_inputs(is_shape_tensor=True). Profile substitution applies to the scalar expression the same way.
4. User API
4.1 Compile: Input.profiles
Multi-profile shape ranges live on Input, not a separate compile kwarg. Each dynamic input declares named regimes; min/opt/max are full shape tuples (same semantics as today's min_shape/opt_shape/max_shape).
Default (single profile): unchanged. Either no profiles key: [...]
How compile() uses Input.profiles:
- [...] Inputs that define profiles.
- [...] "prefill", zip each input's "prefill" entry into one TRT profile (internal normalized form for _TRTInterpreter).
- [...] Dim per dynamic axis in _tracer.py.
- [...] profiles) keep one shape; that shape is reused in every TRT profile.
- [...] Input.profiles by symbolic propagation from the same profile names; no second export.
Rules:
- profiles and top-level min_shape/opt_shape/max_shape are mutually exclusive on the same Input.
- min ≤ opt ≤ max element-wise; min ≥ 1 on every dim; each profile must fit inside the EP's Dim envelope.
- [...] "prefill", "decode") are compile-time labels on Input.profiles; the engine blob stores only TRT profile indices 0…N−1.
4.2 Runtime: profile selection (manual default, auto opt-in)
No serialization format change. After deserialize_cuda_engine, rebuild profile bounds from the TRT API (ICudaEngine::getTensorProfileShape / Python get_tensor_profile_shape). Cache once in _setup_engine().
Works for engines compiled in-process, loaded from cache, or deserialized from disk, with no new fields in _serialized_engine_layout.
Selection modes
Profile selection is manual by default. Auto-selection is opt-in so users who know their phase boundaries (e.g. LLM prefill/decode) can avoid per-forward matching overhead.
| Mode | How |
|---|---|
| Manual (default) | with optimization_profile(trt_gm, "decode"): or optimization_profile(trt_gm, 0) |
| Auto (opt-in) | with optimization_profile(trt_gm, "auto"): enables auto-selection for that span; or compile flag auto_profile_selection=True for module-wide auto |
Priority: pinned name/index > auto (if enabled) > error on ambiguity. Outside an "auto" span and without an explicit pin, do not switch profiles: the active profile from the last optimization_profile(...) call remains in effect (or profile 0 if never set).
Profile-switch overhead: setOptimizationProfileAsync may invalidate CUDA Graph captures and incurs a small device sync. Benchmark switching cost during implementation; document expected overhead in the user guide. Manual selection avoids the per-forward candidate scan entirely.
Optional: persist Input.profiles name→index map in the existing serialized_metadata JSON blob (no C++ layout change) so optimization_profile(m, "decode") works after reload. If absent, the context manager accepts indices only.
Auto-selection algorithm (when enabled)
Runs in setup_input_tensors before set_input_shape:
- [...] P = {0, …, N−1}.
- [...] [min, max] for that binding.
- [...] {0, 1} and input B leaves {1, 2} with no intersection → raise (conflicting shape signals).
- [...] set_active_profile(p) if not already active (idempotent).
- [...] set_active_profile(p) if not already active.
See §3.1.1 for the distance metric and overlap examples. Manual pin via optimization_profile(trt_gm, name) skips steps 1–6 entirely.
Context manager
Stack semantics, idempotent switch when already on the requested profile, CUDA Graph invalidation on change, C++ op parity.
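The runtime behaviour described in §4.2 and §3.1.1 can be sketched in pure Python, under several stated assumptions. EngineStub stands in for the real TRT module, optimization_profile_sketch is a hypothetical name (not the proposed API), the distance is assumed to be a plain sum of |dim − opt|, the prefill opt of 2048 is made up, and "stack semantics" is read here as "restore the previous state on exit", which the RFC does not spell out.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List, Tuple, Union

Shape = Tuple[int, ...]
Bounds = Tuple[Shape, Shape, Shape]  # (min, opt, max) shapes for one binding
Profile = Dict[str, Bounds]          # binding name -> bounds


def select_profile(inputs: Dict[str, Shape], profiles: List[Profile]) -> int:
    """Rule out profiles per input, intersect the survivors, pick the closest opt."""
    candidates = set(range(len(profiles)))
    for name, shape in inputs.items():
        fitting = set()
        for index, profile in enumerate(profiles):
            lo, _, hi = profile[name]
            if len(shape) == len(lo) and all(
                a <= s <= b for s, a, b in zip(shape, lo, hi)
            ):
                fitting.add(index)
        candidates &= fitting
    if not candidates:
        raise ValueError(f"no optimization profile accepts input shapes {inputs}")

    def opt_distance(index: int) -> int:
        # Assumed metric: sum of |dim - opt| over all bindings and dims.
        # Static dims add 0 because a fitting input equals opt on those dims.
        return sum(
            abs(s - o)
            for name, shape in inputs.items()
            for s, o in zip(shape, profiles[index][name][1])
        )

    # Equal distances fall back to the lowest profile index.
    return min(candidates, key=lambda index: (opt_distance(index), index))


class EngineStub:
    """Hypothetical stand-in for the TRT module; it only tracks the active profile."""

    def __init__(self, profiles: List[Profile], names: Dict[str, int]) -> None:
        self.profiles = profiles
        self.names = names
        self.active_profile = 0
        self.auto = False
        self.switch_count = 0

    def set_active_profile(self, index: int) -> None:
        if index != self.active_profile:  # idempotent when already active
            self.active_profile = index
            self.switch_count += 1  # a real runtime would invalidate CUDA Graphs here

    def forward(self, inputs: Dict[str, Shape]) -> int:
        if self.auto:
            self.set_active_profile(select_profile(inputs, self.profiles))
        return self.active_profile


@contextmanager
def optimization_profile_sketch(
    engine: EngineStub, profile: Union[str, int]
) -> Iterator[EngineStub]:
    """Pin a profile by name or index, or enable auto-selection with "auto"."""
    saved = (engine.active_profile, engine.auto)
    if profile == "auto":
        engine.auto = True
    else:
        engine.auto = False
        index = engine.names[profile] if isinstance(profile, str) else profile
        engine.set_active_profile(index)
    try:
        yield engine
    finally:
        # Assumed stack semantics: restore whatever was in effect before this span.
        engine.auto = saved[1]
        engine.set_active_profile(saved[0])


PROFILES = [
    {"inputs_embeds": ((6, 1, 4096), (6, 2048, 4096), (6, 4096, 4096))},  # prefill
    {"inputs_embeds": ((6, 1, 4096), (6, 1, 4096), (6, 1, 4096))},        # decode
]
engine = EngineStub(PROFILES, {"prefill": 0, "decode": 1})
```

With these bounds, a (6, 1, 4096) input fits both profiles and resolves to decode by opt distance, while (6, 512, 4096) fits prefill only. A manual pin never calls select_profile, which is the overhead argument made for the manual default.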
4.3 JIT: torch.compile(backend="tensorrt")
The JIT backend (backend/backends.py → compile_module → same _TRTInterpreter) shares conversion and runtime with AOT. The only differences are in how shape ranges are supplied.
[Comparison table, dynamo.compile(ep) vs. torch.compile(..., backend="tensorrt"); surviving cell fragments: Dim(min, max) + optional profiles; torch.export once over union envelope; aot_export_joint_simple on dynamo-traced graph; optimization_profile(...) context manager; GraphModule]
Usage: pass the same Input(profiles=...) objects via options.
Rules for JIT:
- [...] Input(profiles=...) in options["arg_inputs"] (or kwarg_inputs); do not rely on prepare_inputs(first_tensor) alone.
- [...] dynamic=True so prefill and decode hit one dynamo compile / one engine; switch profiles at runtime via §4.2.
- [...] optimization_profile() unchanged from §4.2.
Not supported in v1: inferring profiles from recompilation (guards building a new engine per shape regime). LLM users who need serialization should prefer AOT dynamo.compile(ep, ...).
5. LLM constraints
No min=0. Reject in Input.profiles validation at Input construction / compile() entry.
Use a static KV cache (tools/llm/static_cache_v1.py): fixed [B, H, max_seq_len, D] tensors; track the valid region with scalar start_idx / end_idx. Multi-profile only needs to specialize token inputs (input_ids, position_ids), not cache shape.
Shape-polymorphic over the union: the exported graph must not branch on token length in Python (§3.1). With static cache + index-based attention, intermediate TRT subgraphs see symbolic shapes derived from token inputs only; profile names propagate automatically (§3.2).
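The automatic propagation referred to above (§3.2) amounts to evaluating each intermediate dim at a profile's min, opt, and max corner. The sketch below replays the §3.2 image example with plain Python callables standing in for sympy expressions; the real path would substitute into SymInt expressions instead. The names propagate_profiles and quarter_of_s0 are hypothetical, and per-corner evaluation is only valid for expressions that are monotonically non-decreasing in each symbol, as the RFC assumes.

```python
from typing import Callable, Dict, Tuple, Union

SymbolValues = Dict[str, int]
# A dim is either a static int or an expression over source symbols.
DimExpr = Union[int, Callable[[SymbolValues], int]]
SourceBounds = Dict[str, Dict[str, Dict[str, int]]]  # profile -> symbol -> corner -> value


def propagate_profiles(
    shape: Tuple[DimExpr, ...], profile_source_bounds: SourceBounds
) -> Dict[str, Dict[str, Tuple[int, ...]]]:
    """Evaluate a symbolic shape at each profile's min/opt/max corner."""
    result = {}
    for profile_name, symbol_bounds in profile_source_bounds.items():
        corners = {}
        for corner in ("min", "opt", "max"):
            values = {symbol: bounds[corner] for symbol, bounds in symbol_bounds.items()}
            corners[corner] = tuple(
                dim(values) if callable(dim) else dim for dim in shape
            )
        result[profile_name] = corners
    return result


def quarter_of_s0(values: SymbolValues) -> int:
    """Stand-in for the symbolic expression s0/4 (integer division)."""
    return values["s0"] // 4


# Intermediate tensor [1, 3, s0/4, s0/4] from the §3.2 example.
INTERMEDIATE_SHAPE = (1, 3, quarter_of_s0, quarter_of_s0)
SOURCE_BOUNDS = {
    "small": {"s0": {"min": 64, "opt": 128, "max": 256}},
    "large": {"s0": {"min": 1024, "opt": 2048, "max": 4096}},
}
PROPAGATED = propagate_profiles(INTERMEDIATE_SHAPE, SOURCE_BOUNDS)
```

The result carries the same profile names as the top-level input, so a submodule engine can be given matching "small" and "large" profiles without a second trace.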
6. Implementation sketch
| File | Change |
|---|---|
| _Input.py | profiles: Dict[str, {min, opt, max}]; mutual exclusion with min/opt/max_shape; validation |
| _compiler.py | collect_optimization_profiles(inputs) → normalized list for interpreter + cache hash; build profile_source_bounds from top-level Input.profiles + EP placeholder symbols; after partition, attach propagated profiles to submodule inputs |
| _tracer.py | [...] Input.profiles (AOT); post-export validation: profile ⊆ export envelope, log narrowing when export envelope ⊃ profile union (§3.1); JIT: same in backend when Input.profiles set |
| partitioning/common.py | [...] construct_dynamic_input / add construct_dynamic_input_multi_profile: given symbolic shape + profile_source_bounds, evaluate per-profile min/opt/max via SymInt expr substitution; extend construct_submodule_inputs to emit Input.profiles when top-level profiles present |
| dynamo/utils.py | extract_var_range_info_for_profile(symint, profile_bounds, mode) (or generalize extract_var_range_info) for per-profile evaluation; keep existing function for union / single-profile |
| _TRTInterpreter.py | [...] placeholder() instead of writing index 0 only |
| _settings.py | optimization_profiles list (built from inputs, picklable for cache/serialization) |
| _TRTEngine.py | [...] _profile_dim_ranges (min/max) and cache per-profile opt shapes for closest-opt tie-break (§3.1.1); opt-in auto-select in setup_input_tensors; set_active_profile; benchmark profile-switch overhead |
| runtime/_optimization_profile.py | [...] "auto" span; overlap disambiguation |
| [...] | [...] Input.profiles content at compile time |
| [...] | [...] serialized_metadata only |
Prior art: FX converter already loops over N shape_ranges in fx2trt.py. Submodule symbolic shapes already flow from export through construct_submodule_inputs today; multi-profile extends construct_dynamic_input + extract_var_range_info from one envelope to N named profiles.
7. Example (Alpamayo)
Alpamayo decode benchmark shape: batch=6, inputs_embeds=(6, 1, 4096), static KV (6, 2, 8, 4096, 128). Only inputs_embeds seq len varies between prefill and decode; KV stays static (§5).
Decode under "decode" should match ONNX-TRT decode (~5.1 ms GPU Compute on the Alpamayo repro) instead of the ~2× slower single-profile engine tuned for prefill opt=3424.
Graph-break example (symbolic propagation)
Suppose partition yields a TRT submodule whose placeholder shape is (B, H, s0, D) where the top-level inputs_embeds seq dim is s0. User profiles on inputs_embeds:
[Table: top-level s0 (seq) and submodule s0 bounds for "prefill" and "decode"; cell values not recoverable]
Both engines (root and submodule) get TRT optimization profiles "prefill" / "decode" with identical bounds on s0, derived from the user's top-level Input.profiles, not a second trace.
If an intermediate op divides spatially, e.g. submodule input (B, H, s0/4, D):
[Table: s0 and s0/4 (min / opt / max) for "prefill" and "decode"; cell values not recoverable]
No user action required on the intermediate tensor; propagation is automatic from export's symbolic metadata.
8. Backward compatibility
8. Backward compatibility
- No Input.profiles → existing zero/one-profile paths unchanged; same cache key.
- [...] _profile_dim_ranges empty or single-entry; behavior unchanged.
- [...] get_tensor_profile_shape once that build writes N profiles into the blob.
9. Implementation plan
- _Input.py: profiles dict + validation; collect_optimization_profiles() in _compiler.py.
- _tracer.py: union envelope; post-export profile ⊆ envelope validation (§3.1).
- partitioning/common.py + utils.py: per-profile symbolic propagation for submodule inputs (§3.2); wire into _compiler.py partition loop (before / instead of per-submodule single-profile construct_submodule_inputs only).
- _TRTInterpreter.py: multi-profile loop.
- _TRTEngine: _profile_dim_ranges from get_tensor_profile_shape at load; opt-in auto-select + per-input rule-out algorithm (§4.2); set_active_profile; profile-switch overhead benchmark.
- runtime/_optimization_profile.py + C++ op: manual pin, "auto" span, stack semantics.
- [...] serialized_metadata.