fix: torch-TRT runtime cache attribute + standard-TRT fast refit regression#4225
Merged
zewenli98 merged 3 commits intoMay 4, 2026
Merged
Conversation
PythonTorchTensorRTModule.__getstate__ pops runtime_config and runtime_cache (they hold non-picklable native handles), but __setstate__ never restored them. setup_engine() only re-creates them inside _setup_runtime_config(), which is gated by ENABLED_FEATURES.tensorrt_rtx. On standard TRT the gate is false, so on every unpickle / deepcopy the attributes simply do not exist. __del__ -> _save_runtime_cache() then reads self.runtime_cache, nn.Module's __getattr__ raises AttributeError, and Python emits a PytestUnraisableExceptionWarning across the refit and weight-stripped-engine test suites. Re-initialize both fields to None inside __setstate__ before calling setup_engine(). Mirror the same init in _load_from_state_dict so the method is self-contained even though __init__ usually runs first on that path. Also annotate the deliberate "import tensorrt as trt" placement (after torch_tensorrt, so the tensorrt_rtx alias resolves) with an isort:skip marker, mirroring the convention already used in _refit.py and the weight-stripped-engine test. No behavior change on TRT-RTX, where setup_engine() proceeds to populate the real handles via _setup_runtime_config() as before. Signed-off-by: tejaswinp <tejaswinp@nvidia.com>
Pre-existing F401 surfaced by ruff when other changes touch this file. Signed-off-by: tejaswinp <tejaswinp@nvidia.com>
The strict check added in pytorch#4198 — comparing weights actually set against the full engine weight list — broke fast refit on standard TRT for any engine that contains CONSTANT layers intentionally absent from the mapping (e.g. batch-norm eps constants, which are baked in at build time and not expected to be refit). The pre-existing warn-and-continue branch makes that absence the contract; the strict check made it an error. Concretely, on standard TRT compiling resnet18 through torch.compile + the disk engine cache, the cache-hit path now asserts with "0 missing, 20 unset". _pretraced_backend swallows the assertion and falls back to the plain GraphModule, silently disabling TRT compilation for the test. The strict check exists for the TRT-RTX case where each weight lives in its own independent wtsEngine and get_missing_weights() can under-report. On standard TRT, get_missing_weights() is authoritative because connected weight engines surface any unset weight transitively, so the additional unset-weights cross-check is unnecessary and actively wrong. Gate the unset-weights assertion behind ENABLED_FEATURES.tensorrt_rtx to restore the standard-TRT contract while keeping the TRT-RTX safety net. Also rewrite both fast-refit assertion messages to surface counts vs total plus example unset weights for diagnostic purposes. Signed-off-by: tejaswinp <tejaswinp@nvidia.com>
ec2d418 to
dc978a0
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Two related fixes to issues introduced by recent torch-TensorRT runtime-cache and refit work, both surfaced by the L2 dynamo compile tests.
PythonTorchTensorRTModule.__getstate__popsruntime_configandruntime_cache(they hold non-picklable native handles), but__setstate__never restored them.setup_engine()only re-creates them inside_setup_runtime_config(), which is gated byENABLED_FEATURES.tensorrt_rtx. On standard TRT the gate is false, so on every unpickle / deepcopy the attributes simply do not exist —__del__ -> _save_runtime_cache()then readsself.runtime_cacheand Python emitsPytestUnraisableExceptionWarningacross the refit and weight-stripped-engine test suites. Re-init both fields toNonein__setstate__and_load_from_state_dictbeforesetup_engine(). No behavior change on TRT-RTX.unset_weights = {w for w in weight_list if w not in mapping}added previously broke fast refit on standard TRT for any engine with CONSTANT layers intentionally absent from the mapping (e.g. batch-normepsconstants baked at build time). The pre-existing warn-and-continue branch makes that absence the contract; the strict check made it an error. On a resnet18 disk-engine-cache hit, the path now asserts with "0 missing, 20 unset" and_pretraced_backendsilently falls back to GraphModule (visible as XPASS ontest_dynamo_compile_with_default_disk_engine_cacheandtest_torch_compile_with_default_disk_engine_cache). The strict check exists only to guard the TRT-RTX case where each weight lives in its own independent wtsEngine andget_missing_weights()can under-report; on standard TRT,get_missing_weights()is authoritative because connected weight engines surface any unset weight transitively. Gate the assertion behindENABLED_FEATURES.tensorrt_rtx.CPU_DEVICEimport — pre-existing F401 surfaced by ruff on the touched file.Test plan
Verified end-to-end on a fresh standard-TRT (non-RTX) install of nightly torch_tensorrt on Linux + CUDA 13.0:
pytest tests/py/dynamo/models/test_model_refit.py::test_refit_one_engine_with_weightmap test_refit_one_engine_python_runtime_with_weightmap test_complex_buffer_with_real_param_refit— baseline reproduces 2×AttributeError: ...runtime_cache; with patch, 0 occurrences.pytest tests/py/dynamo/models/test_engine_cache.py::TestEngineCache::test_dynamo_compile_with_default_disk_engine_cache test_torch_compile_with_default_disk_engine_cache— baseline reproduces 1×AssertionError: Fast refitting failed due to incomplete mapping (0 missing, 20 unset)plusReturning GraphModule forward insteadsilent fallback; with patch, 0 of either, TRT compilation actually succeeds on the cache-hit path.PytestUnraisableExceptionWarninggroups disappear.🤖 Generated with Claude Code