From 6d0a7805682d8ac6996cf5858e71e7c21b1ca6f8 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 24 Nov 2025 15:15:57 -0800 Subject: [PATCH 1/2] improve engine caching and fix bugs --- py/torch_tensorrt/dynamo/_refit.py | 11 +- py/torch_tensorrt/dynamo/_settings.py | 1 - py/torch_tensorrt/dynamo/backend/backends.py | 4 - .../dynamo/conversion/_TRTInterpreter.py | 95 +--------- .../dynamo/conversion/_conversion.py | 163 +++++++++++++++--- .../models/test_weight_stripped_engine.py | 92 +++++++++- 6 files changed, 232 insertions(+), 134 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 9aae901f87..467daab529 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -52,7 +52,7 @@ logger = logging.getLogger(__name__) -@needs_refit +@needs_refit # type: ignore[misc] def construct_refit_mapping( module: torch.fx.GraphModule, inputs: Sequence[Input], @@ -85,7 +85,7 @@ def construct_refit_mapping( return weight_refit_map -@needs_refit +@needs_refit # type: ignore[misc] def construct_refit_mapping_from_weight_name_map( weight_name_map: dict[Any, Any], state_dict: dict[Any, Any], @@ -128,7 +128,7 @@ def construct_refit_mapping_from_weight_name_map( return engine_weight_map -@needs_refit +@needs_refit # type: ignore[misc] def _refit_single_trt_engine_with_gm( new_gm: torch.fx.GraphModule, old_engine: trt.ICudaEngine, @@ -211,7 +211,7 @@ def _refit_single_trt_engine_with_gm( raise AssertionError("Refitting failed.") -@needs_refit +@needs_refit # type: ignore[misc] def refit_module_weights( compiled_module: torch.fx.GraphModule | ExportedProgram, new_weight_module: ExportedProgram, @@ -484,9 +484,10 @@ def refit_module_weights( weight_name_map=None, ) - # clear EXCLUDE_WEIGHTS flag + # clear EXCLUDE_WEIGHTS flag and set INCLUDE_REFIT flag to make the engine refittable serialization_config = engine.create_serialization_config() serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT) serialized_engine = engine.serialize_with_config(serialization_config) if isinstance(compiled_submodule, PythonTorchTensorRTModule): diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d8f6809eae..a64c5b3800 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -167,7 +167,6 @@ def __setstate__(self, state: dict[str, Any]) -> None: "engine_capability", "hardware_compatible", "refit_identical_engine_weights", - "strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default? "immutable_weights", "enable_weight_streaming", "tiling_optimization_level", diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index b834d8087f..5d2f50af2f 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -157,10 +157,6 @@ def _pretraced_backend( logger.warning( "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt" ) - if settings.strip_engine_weights: - logger.error( - "strip_engine_weights arg is not supported for torch.compile()" - ) trt_compiled = compile_module( gm, torchtrt_inputs, diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index e31c423581..d4735baa12 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -31,7 +31,7 @@ from torch_tensorrt._utils import is_tensorrt_version_supported from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._engine_cache import BaseEngineCache -from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible +from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, @@ -594,79 +594,6 @@ def _save_weight_mapping(self) -> None: gc.collect() torch.cuda.empty_cache() - @needs_refit # type: ignore[misc] - def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: - # query the cached TRT engine - cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr] - if cached_data is not None: # hit the cache - ( - serialized_engine, - self._input_names, - self._output_names, - cached_engine_input_specs, - engine_compilation_settings, - self.weight_name_map, - self.ctx.requires_output_allocator, - ) = cached_data - - setting_compatiblity, incompattible_settings = settings_are_compatible( - self.compilation_settings, engine_compilation_settings - ) - assert ( - setting_compatiblity - ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})" - - for i, e in enumerate( - [ - Input.equivalent_spec(c, i) - for c, i in zip(cached_engine_input_specs, self.input_specs) - ] - ): - assert ( - e - ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}" - - _LOGGER.info( - "Found the cached engine that corresponds to this graph. It is directly loaded." - ) - - # refit the cached engine with the new graph module - if not self.compilation_settings.strip_engine_weights: - runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) - - from torch_tensorrt.dynamo._refit import ( - _refit_single_trt_engine_with_gm, - ) - - _refit_single_trt_engine_with_gm( - new_gm=self.module, - old_engine=engine, - input_list=self.input_specs, - settings=self.compilation_settings, - weight_name_map=self.weight_name_map, - ) - - # TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine - # # EXCLUDE_WEIGHTS flag must be cleared - # serialization_config = engine.create_serialization_config() - # serialization_config.clear_flag( - # trt.SerializationFlag.EXCLUDE_WEIGHTS - # ) - # serialized_engine = engine.serialize_with_config( - # serialization_config - # ) - # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller - - return TRTInterpreterResult( - engine, - self._input_names, - self._output_names, - self.weight_name_map, - self.ctx.requires_output_allocator, - ) - return None - def run( self, strict_type_constraints: bool = False, @@ -682,26 +609,6 @@ def run( Return: TRTInterpreterResult """ - # self.engine_cache could be None if: - # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or - # 2) both cache_built_engines and reuse_cached_engines are False - if ( - self.engine_cache is not None - and not self.compilation_settings.immutable_weights - ): - if ( - self.compilation_settings.cache_built_engines - or self.compilation_settings.reuse_cached_engines - ): - hash_val = self.engine_cache.get_hash( - self.module, self.input_specs, self.compilation_settings - ) - - if self.compilation_settings.reuse_cached_engines: - interpreter_result = self._pull_cached_engine(hash_val) - if interpreter_result is not None: # hit the cache - return interpreter_result # type: ignore[no-any-return] - self._construct_trt_network_def() _LOGGER.debug( f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB" diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 76926107a4..bf51e8e9c9 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -4,19 +4,24 @@ import logging from typing import Any, List, NamedTuple, Optional, Sequence +import tensorrt as trt import torch from torch_tensorrt._enums import dtype -from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt._features import ENABLED_FEATURES, needs_refit from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._engine_cache import BaseEngineCache -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter +from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible +from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( + TRTInterpreter, + TRTInterpreterResult, +) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule from torch_tensorrt.dynamo.utils import ( get_cpu_memory_usage, get_output_dtypes, release_host_and_device_memory, ) +from torch_tensorrt.logging import TRT_LOGGER logger = logging.getLogger(__name__) @@ -63,6 +68,128 @@ def interpret_module_to_result( SerializedInterpreterResult """ + def _insert_engine_to_cache( + hash_val: str, interpreter_result: TRTInterpreterResult + ) -> None: # type: ignore[unused-ignore] + # Cache the weight-stripped engine regardless of the `strip_engine_weights` setting + if engine_cache.check(hash_val) is not None: # type: ignore[union-attr] + logger.info(f"Engine already exists in cache for hash: {hash_val}") + return + if not settings.strip_engine_weights: + # set EXCLUDE_WEIGHTS flag to strip weights + serialization_config = ( + interpreter_result.engine.create_serialization_config() + ) + serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + weight_stripped_serialized_engine = ( + interpreter_result.engine.serialize_with_config(serialization_config) + ) + else: + weight_stripped_serialized_engine = interpreter_result.engine.serialize() + + # Insert weight-stripped engine to cache + engine_cache.insert( # type: ignore[union-attr] + hash_val, + ( + weight_stripped_serialized_engine, + interpreter_result.input_names, + interpreter_result.output_names, + inputs, + settings, + interpreter_result.weight_name_map, + interpreter_result.requires_output_allocator, + ), + ) + logger.info(f"Engine was successfully inserted into cache for hash: {hash_val}") + + @needs_refit # type: ignore[misc] + def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]: + # query the cached TRT engine + cached_data = engine_cache.check(hash_val) # type: ignore[union-attr] + if cached_data is not None: # hit the cache + ( + serialized_engine, # weight-stripped engine + input_names, + output_names, + cached_engine_inputs, + cached_engine_compilation_settings, + weight_name_map, + requires_output_allocator, + ) = cached_data + + setting_compatiblity, incompattible_settings = settings_are_compatible( + settings, cached_engine_compilation_settings + ) + assert ( + setting_compatiblity + ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {cached_engine_compilation_settings}, new_settings: {settings})" + + for i, e in enumerate( + [ + Input.equivalent_spec(c, i) + for c, i in zip(cached_engine_inputs, inputs) + ] + ): + assert ( + e + ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_inputs[i]}, new size: {inputs[i]}" + + logger.info( + "Found the cached engine that corresponds to this graph. It is directly loaded." + ) + + # refit the cached engine with the new graph module + if not settings.strip_engine_weights: + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine( + serialized_engine + ) # weight-stripped engine + + from torch_tensorrt.dynamo._refit import ( + _refit_single_trt_engine_with_gm, + ) + + # weight-stripped engine --in place--> weight-included engine + _refit_single_trt_engine_with_gm( + new_gm=module, + old_engine=engine, + input_list=inputs, + settings=settings, + weight_name_map=weight_name_map, + ) + + # EXCLUDE_WEIGHTS flag must be cleared and INCLUDE_REFIT flag must be set + serialization_config = engine.create_serialization_config() + serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT) + serialized_engine = engine.serialize_with_config(serialization_config) + # Start from here, the engine is weight-included and refittable + + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) + serialized_engine = engine_bytes.getvalue() + + return SerializedInterpreterResult( + serialized_engine=serialized_engine, + input_names=input_names, + output_names=output_names, + weight_name_map=weight_name_map, + requires_output_allocator=requires_output_allocator, + ) + return None + + # engine_cache could be None if: + # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or + # 2) both cache_built_engines and reuse_cached_engines are False + if engine_cache is not None and not settings.immutable_weights: + if settings.cache_built_engines or settings.reuse_cached_engines: + hash_val = engine_cache.get_hash(module, inputs, settings) + + if settings.reuse_cached_engines: + serialized_interpreter_result = _pull_cached_engine(hash_val) + if serialized_interpreter_result is not None: # hit the cache + return serialized_interpreter_result # type: ignore[no-any-return] + output_dtypes = infer_module_output_dtypes( module, truncate_double=settings.truncate_double ) @@ -86,32 +213,20 @@ def interpret_module_to_result( f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB" ) - serialized_engine = interpreter_result.engine.serialize() - with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - serialized_engine = engine_bytes.getvalue() - logger.debug( - f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB" - ) - # Engine caching only for refittable engines if ( not settings.immutable_weights and settings.cache_built_engines and engine_cache is not None ): - hash_val = engine_cache.get_hash(module, inputs, settings) - engine_cache.insert( - hash_val, - ( - serialized_engine, - interpreter_result.input_names, - interpreter_result.output_names, - inputs, - settings, - interpreter_result.weight_name_map, - interpreter_result.requires_output_allocator, - ), + _insert_engine_to_cache(hash_val, interpreter_result) + + serialized_engine = interpreter_result.engine.serialize() + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) + serialized_engine = engine_bytes.getvalue() + logger.debug( + f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB" ) serialized_interpreter_result = SerializedInterpreterResult( @@ -122,7 +237,7 @@ def interpret_module_to_result( requires_output_allocator=interpreter_result.requires_output_allocator, ) - return serialized_interpreter_result + return serialized_interpreter_result # type: ignore[no-any-return] def convert_module( diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index d34d7d1edc..d70cd527e2 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -454,11 +454,11 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): not torch_trt.ENABLED_FEATURES.refit, "Engine caching requires refit feature that is not supported in Python 3.13 or higher", ) - def test_different_args_dont_share_cached_engine(self): + def test_different_args_share_cached_engine(self): class MyModel(torch.nn.Module): def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(3, 4, 3, stride=1, bias=True) + self.conv = torch.nn.Conv2d(512, 64, 32, stride=1, bias=True) self.relu = torch.nn.ReLU() def forward(self, x): @@ -468,11 +468,11 @@ def forward(self, x): pyt_model = MyModel().eval().to("cuda") - engine_cache_dir = "/tmp/test_different_args_dont_share_cached_engine" + engine_cache_dir = "/tmp/test_different_args_share_cached_engine" if os.path.exists(engine_cache_dir): shutil.rmtree(engine_cache_dir) - inputs = [torch.rand((4, 3, 32, 32)).to("cuda")] + inputs = [torch.rand((64, 512, 32, 32)).to("cuda")] for i in range(2): if i == 0: @@ -498,8 +498,8 @@ def forward(self, x): assertions.assertEqual( len(os.listdir(engine_cache_dir)), - 2, - msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 2 engines", + 1, + msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 1 engine", ) @unittest.skipIf( @@ -636,3 +636,83 @@ def test_refit_identical_engine_weights(self): ) except Exception as e: pass + + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) + @unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", + ) + def test_refit_weight_stripped_engine_multiple_times(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + # Mark the dim0 of inputs as dynamic + batch = torch.export.Dim("batch", min=1, max=200) + exp_program = torch.export.export( + pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + ) + + inputs = (torch.rand((128, 3, 224, 224)).to("cuda"),) + + trt_gm = torch_trt.dynamo.compile( + exp_program, + inputs, + use_python_runtime=True, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + cache_built_engines=False, + reuse_cached_engines=False, + strip_engine_weights=True, + refit_identical_engine_weights=False, + ) + output = trt_gm(*inputs) + assertions.assertEqual( + output.sum(), 0, msg="weight-stripped engine results should be all zeros" + ) + + # Refit the weight-stripped engine with the same weights + refitted_trt_gm = refit_module_weights(trt_gm, exp_program) + refitted_output = refitted_trt_gm(*inputs) + assertions.assertNotEqual( + refitted_output.sum(), + 0, + msg="refitted engine results should not be all zeros", + ) + + inputs2 = (torch.rand((64, 3, 224, 224)).to("cuda"),) + exp_program2 = torch.export.export( + pyt_model, args=inputs2, dynamic_shapes={"x": {0: batch}} + ) + + # Refit with different weights + refitted_trt_gm = refit_module_weights(refitted_trt_gm, exp_program2) + refitted_output = refitted_trt_gm(*inputs2) + assertions.assertNotEqual( + refitted_output.sum(), + 0, + msg="refitted engine results should not be all zeros", + ) + + compiled_model = torch.compile( + pyt_model, + backend="tensorrt", + options={ + "use_python_runtime": False, + "enabled_precisions": {torch.float}, + "min_block_size": 1, + "immutable_weights": False, + "cache_built_engines": False, + "reuse_cached_engines": False, + "refit_identical_engine_weights": False, + "strip_engine_weights": False, + }, + ) + compiled_model_output = compiled_model(*inputs2) + cos_sim = cosine_similarity(refitted_output, compiled_model_output) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"refitted_output doesn't match with compiled_model_output. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) From d39a29e77674cb175a0a2bdede37adec60694fe5 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 24 Nov 2025 15:35:53 -0800 Subject: [PATCH 2/2] reduce dims in tests --- tests/py/dynamo/models/test_weight_stripped_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index d70cd527e2..a2f00d68b1 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -458,7 +458,7 @@ def test_different_args_share_cached_engine(self): class MyModel(torch.nn.Module): def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(512, 64, 32, stride=1, bias=True) + self.conv = torch.nn.Conv2d(3, 4, 3, stride=1, bias=True) self.relu = torch.nn.ReLU() def forward(self, x): @@ -472,7 +472,7 @@ def forward(self, x): if os.path.exists(engine_cache_dir): shutil.rmtree(engine_cache_dir) - inputs = [torch.rand((64, 512, 32, 32)).to("cuda")] + inputs = [torch.rand((4, 3, 32, 32)).to("cuda")] for i in range(2): if i == 0: