From 6b292ea4db2c4a6e4f4adb71a4bffaf302e3a05c Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 9 Sep 2025 13:37:38 -0700 Subject: [PATCH 1/3] add benchmark models and polish codes --- py/torch_tensorrt/dynamo/_compiler.py | 238 ++++++++++++++++---------- tools/perf/README.md | 5 +- tools/perf/benchmark.sh | 42 ++++- tools/perf/custom_models.py | 33 +++- tools/perf/perf_run.py | 166 +++++++++++------- tools/perf/utils.py | 32 +++- 6 files changed, 348 insertions(+), 168 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 608c8e84c9..db8e014879 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -53,7 +53,7 @@ logger = logging.getLogger(__name__) -@needs_cross_compile +@needs_cross_compile # type: ignore[misc] def cross_compile_for_windows( exported_program: ExportedProgram, inputs: Optional[Sequence[Sequence[Any]]] = None, @@ -238,15 +238,6 @@ def cross_compile_for_windows( "`immutable_weights` must be False when `refit_identical_engine_weights` is True." ) - if ( - not immutable_weights - and not refit_identical_engine_weights - and enable_weight_streaming - ): - raise ValueError( - "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" - ) - engine_capability = EngineCapability._from(engine_capability) if torch_executed_modules is not None and torch_executed_modules: @@ -450,7 +441,7 @@ def compile( Arguments: exported_program (torch.export.ExportedProgram): Source module, running torch.export on a ``torch.nn.Module`` - inputs (Tuple[Any, ...]): List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using + inputs (Optional[Sequence[Sequence[Any]]]): List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. @@ -469,8 +460,8 @@ def compile( ] Keyword Arguments: - arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs. - kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. + arg_inputs (Optional[Sequence[Sequence[Any]]]): Same as inputs. Alias for better understanding with kwarg_inputs. + kwarg_inputs (Optional[dict[Any, Any]]): kwarg inputs to the module forward function. device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on :: device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True) @@ -478,8 +469,8 @@ def compile( disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. - enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels + enabled_precisions (Union[Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]]): The set of datatypes that TensorRT can use when selecting kernels + engine_capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels workspace_size (int): Maximum size of workspace given to TensorRT dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. @@ -488,10 +479,10 @@ def compile( truncate_double (bool): Truncate weights provided in double (float64) to float32 require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch min_block_size (int): The minimum number of contiguous TensorRT convertible operations in order to run a set of operations in TensorRT - torch_executed_ops (Collection[Target]): Set of aten operators that must be run in PyTorch. An error will be thrown if this set is not empty but ``require_full_compilation`` is True - torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True + torch_executed_ops (Optional[Collection[Target]]): Set of aten operators that must be run in PyTorch. An error will be thrown if this set is not empty but ``require_full_compilation`` is True + torch_executed_modules (Optional[List[str]]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True pass_through_build_failures (bool): Error out if there are issues during compilation (only applicable to torch.compile workflows) - max_aux_stream (Optional[int]): Maximum streams in the engine + max_aux_streams (Optional[int]): Maximum streams in the engine version_compatible (bool): Build the TensorRT engines compatible with future versions of TensorRT (Restrict to lean runtime operators to provide version forward compatibility for the engines) optimization_level: (Optional[int]): Setting a higher optimization level allows TensorRT to spend longer engine building time searching for more optimization options. The resulting engine may have better performance compared to an engine built with a lower optimization level. The default optimization level is 3. Valid values include integers from 0 to the maximum optimization level, which is currently 5. Setting it to be greater than the maximum level results in identical behavior to the maximum level. use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization @@ -503,8 +494,8 @@ def compile( lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage - engine_cache_dir (Optional[str]): Directory to store the cached TRT engines - engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default + engine_cache_dir (str): Directory to store the cached TRT engines + engine_cache_size (int): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored. use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. @@ -632,11 +623,6 @@ def compile( device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} - if not isinstance(exported_program, ExportedProgram): - raise AssertionError( - f"Input graph should be an ExportedProgram but got type {type(exported_program)}" - ) - engine_cache = None if cache_built_engines or reuse_cached_engines: engine_cache = ( @@ -723,7 +709,7 @@ def compile( return trt_gm -@fn_supports_debugger +@fn_supports_debugger # type: ignore[misc] def compile_module( gm: torch.fx.GraphModule, sample_arg_inputs: Sequence[Input], @@ -1016,32 +1002,40 @@ def convert_exported_program_to_serialized_trt_engine( *, arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, + device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE, + disable_tf32: bool = _defaults.DISABLE_TF32, + assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, + sparse_weights: bool = _defaults.SPARSE_WEIGHTS, enabled_precisions: Union[ Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] ] = _defaults.ENABLED_PRECISIONS, - assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, + engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, + num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, + dla_sram_size: int = _defaults.DLA_SRAM_SIZE, + dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, + dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, + truncate_double: bool = _defaults.TRUNCATE_DOUBLE, + require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, min_block_size: int = _defaults.MIN_BLOCK_SIZE, - torch_executed_ops: Optional[Set[str]] = None, + torch_executed_ops: Optional[Collection[Target]] = None, + torch_executed_modules: Optional[List[str]] = None, pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, version_compatible: bool = _defaults.VERSION_COMPATIBLE, optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, - use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME, - truncate_double: bool = _defaults.TRUNCATE_DOUBLE, + use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME, use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - device: Device = Device._current_device(), - require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, - disable_tf32: bool = _defaults.DISABLE_TF32, - sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, - dla_sram_size: int = _defaults.DLA_SRAM_SIZE, - dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, - allow_shape_tensors: bool = False, + dryrun: bool = _defaults.DRYRUN, + hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, + lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, + cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, + reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, + engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR, + engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE, + custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE, use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, use_fp32_acc: bool = _defaults.USE_FP32_ACC, refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS, @@ -1058,16 +1052,14 @@ def convert_exported_program_to_serialized_trt_engine( Converts an ExportedProgram to a serialized TensorRT engine given a dictionary of conversion settings Arguments: - exported_program (torch.export.ExportedProgram): Source module - - Keyword Args: - inputs (Optional[Sequence[torch_tensorrt.Input | torch.Tensor]]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using + exported_program (torch.export.ExportedProgram): Source module, running torch.export on a ``torch.nn.Module`` + inputs (Optional[Sequence[Sequence[Any]]]): List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. .. code-block:: py - inputs=[ + inputs=[ torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1 torch_tensorrt.Input( min_shape=(1, 224, 224, 3), @@ -1078,34 +1070,45 @@ def convert_exported_program_to_serialized_trt_engine( ), # Dynamic input shape for input #2 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] - enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use - workspace_size (int): Workspace TRT is allowed to use for the module (0 is default) - min_block_size (int): Minimum number of operators per TRT-Engine Block - torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage - pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False) - max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine - version_compatible (bool): Provide version forward-compatibility for engine plan files - optimization_level (Optional[int]): Builder optimization 0-5, higher levels imply longer build time, - searching for more optimization options. TRT defaults to 3 - use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime - based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the - argument as None - truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32 - use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system - enable_experimental_decompositions (bool): Whether to enable all core aten decompositions - or only a selected subset of them - device (Device): GPU to compile the model on - require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT. - Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path - disable_tf32 (bool): Whether to disable TF32 computation for TRT layers - sparse_weights (bool): Whether to allow the builder to use sparse weights - engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels + + Keyword Arguments: + arg_inputs (Optional[Sequence[Sequence[Any]]]): Same as inputs. Alias for better understanding with kwarg_inputs. + kwarg_inputs (Optional[dict[Any, Any]]): kwarg inputs to the module forward function. + device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on :: + + device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True) + + disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas + assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False + sparse_weights (bool): Enable sparsity for convolution and fully connected layers. + enabled_precisions (Union[Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]]): The set of datatypes that TensorRT can use when selecting kernels + engine_capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels + workspace_size (int): Maximum size of workspace given to TensorRT dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution - allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT + truncate_double (bool): Truncate weights provided in double (float64) to float32 + require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch + min_block_size (int): The minimum number of contiguous TensorRT convertible operations in order to run a set of operations in TensorRT + torch_executed_ops (Optional[Collection[Target]]): Set of aten operators that must be run in PyTorch. An error will be thrown if this set is not empty but ``require_full_compilation`` is True + torch_executed_modules (Optional[List[str]]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True + pass_through_build_failures (bool): Error out if there are issues during compilation (only applicable to torch.compile workflows) + max_aux_streams (Optional[int]): Maximum streams in the engine + version_compatible (bool): Build the TensorRT engines compatible with future versions of TensorRT (Restrict to lean runtime operators to provide version forward compatibility for the engines) + optimization_level: (Optional[int]): Setting a higher optimization level allows TensorRT to spend longer engine building time searching for more optimization options. The resulting engine may have better performance compared to an engine built with a lower optimization level. The default optimization level is 3. Valid values include integers from 0 to the maximum optimization level, which is currently 5. Setting it to be greater than the maximum level results in identical behavior to the maximum level. + use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization + use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global paritioner (``False``) if looking for best performance + enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT. + dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs + hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation + lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. + cache_built_engines (bool): Whether to save the compiled TRT engines to storage + reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage + engine_cache_dir (str): Directory to store the cached TRT engines + engine_cache_size (int): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default + custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored. use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs. @@ -1114,6 +1117,8 @@ def convert_exported_program_to_serialized_trt_engine( enable_weight_streaming (bool): Enable weight streaming. tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). + offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. + **kwargs: Any, Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ @@ -1170,45 +1175,84 @@ def convert_exported_program_to_serialized_trt_engine( ) if ( - not immutable_weights - and not refit_identical_engine_weights - and enable_weight_streaming + "enable_cross_compile_for_windows" in kwargs.keys() + and kwargs["enable_cross_compile_for_windows"] ): raise ValueError( - "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" + "Please use cross_compile_for_windows() api if you want to cross compile the module in Linux for inferencing in Windows." + ) + + engine_capability = EngineCapability._from(engine_capability) + + if torch_executed_modules is not None and torch_executed_modules: + logger.warning( + f"Detected torch_executed_modules was non-empty: {torch_executed_modules}" + "\nThis feature is unimplemented in Torch-TRT Dynamo currently." ) + if use_explicit_typing: + if len(enabled_precisions) != 1 or not any( + x in enabled_precisions + for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4} + ): + raise AssertionError( + f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True" + ) + + if use_fp32_acc: + logger.debug( + "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \ + This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation." + ) + + if enable_weight_streaming and not use_explicit_typing: + raise AssertionError( + "When enable_weight_streaming is enabled, it requires use_explicit_typing to be set to True" + ) + # Aliasing inputs to arg_inputs for better understanding if not arg_inputs and not kwarg_inputs and not inputs: raise AssertionError( "'arg_inputs', 'kwarg_inputs' and 'inputs' should not all be None." ) - elif arg_inputs is not None and inputs is not None: + elif arg_inputs and inputs: raise AssertionError( "'arg_inputs' and 'inputs' should not be used at the same time." ) arg_inputs = inputs or arg_inputs - torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() + if kwarg_inputs is None: kwarg_inputs = {} - # Prepare torch_trt inputs - arg_input_list = list(prepare_inputs(arg_inputs)) - kwarg_input_list = prepare_inputs(kwarg_inputs) - flattened_input_list = get_flat_args_with_check( - exported_program, arg_input_list, kwarg_input_list - )[0] + if not isinstance(arg_inputs, collections.abc.Sequence): + arg_inputs = [arg_inputs] # type: ignore + # Prepare torch_trt inputs + trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) + trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) device = to_torch_tensorrt_device(device) - enabled_precisions = {dtype._from(e) for e in enabled_precisions} + enabled_precisions = {dtype._from(p) for p in enabled_precisions} + + engine_cache = None + if cache_built_engines or reuse_cached_engines: + engine_cache = ( + custom_engine_cache + if custom_engine_cache is not None + else DiskEngineCache(engine_cache_dir, engine_cache_size) + ) compilation_options = { + "enabled_precisions": ( + enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS + ), + "device": device, "assume_dynamic_shape_support": assume_dynamic_shape_support, - "enabled_precisions": enabled_precisions, "workspace_size": workspace_size, "min_block_size": min_block_size, - "torch_executed_ops": torch_executed_ops, + "torch_executed_ops": ( + torch_executed_ops if torch_executed_ops is not None else set() + ), "pass_through_build_failures": pass_through_build_failures, "max_aux_streams": max_aux_streams, "version_compatible": version_compatible, @@ -1216,22 +1260,27 @@ def convert_exported_program_to_serialized_trt_engine( "use_python_runtime": use_python_runtime, "truncate_double": truncate_double, "use_fast_partitioner": use_fast_partitioner, + "num_avg_timing_iters": num_avg_timing_iters, "enable_experimental_decompositions": enable_experimental_decompositions, - "device": device, "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, "engine_capability": engine_capability, - "num_avg_timing_iters": num_avg_timing_iters, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, "dla_global_dram_size": dla_global_dram_size, + "dryrun": dryrun, + "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, + "lazy_engine_init": lazy_engine_init, + "cache_built_engines": cache_built_engines, + "reuse_cached_engines": reuse_cached_engines, "use_explicit_typing": use_explicit_typing, "use_fp32_acc": use_fp32_acc, "refit_identical_engine_weights": refit_identical_engine_weights, "strip_engine_weights": strip_engine_weights, "immutable_weights": immutable_weights, + "enable_cross_compile_for_windows": False, "enable_weight_streaming": enable_weight_streaming, "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, @@ -1240,21 +1289,20 @@ def convert_exported_program_to_serialized_trt_engine( settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) - exported_program = pre_export_lowering(exported_program, settings) - # Decompose the exported program exported_program = exported_program.run_decompositions( get_decompositions(enable_experimental_decompositions) ) + gm = exported_program.module() + # Move the weights in the state_dict to CPU logger.debug("Input graph: " + str(gm.graph)) # Apply lowering on the graph module gm = post_lowering(gm, settings) logger.debug("Lowered Input graph: " + str(gm.graph)) - # Configure user compilation settings to converters. - CONVERTERS.set_compilation_settings(settings) + # Move the weights in the state_dict to CPU if offload_module_to_cpu: deallocate_module(exported_program.module(), delete_module=False) logger.info( @@ -1266,13 +1314,19 @@ def convert_exported_program_to_serialized_trt_engine( logger.warning( "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True" ) + + flattened_input_list = get_flat_args_with_check( + exported_program, list(trt_arg_inputs), trt_kwarg_inputs + )[0] + try: interpreter_result = interpret_module_to_result( gm, inputs=flattened_input_list, - arg_inputs=arg_input_list, - kwarg_inputs=kwarg_input_list, + arg_inputs=list(trt_arg_inputs), + kwarg_inputs=trt_kwarg_inputs, settings=settings, + engine_cache=engine_cache, ) except UnsupportedOperatorException: logger.error( @@ -1289,7 +1343,7 @@ def convert_exported_program_to_serialized_trt_engine( return serialized_engine -@needs_cross_compile +@needs_cross_compile # type: ignore[misc] def save_cross_compiled_exported_program( gm: torch.fx.GraphModule, file_path: str, diff --git a/tools/perf/README.md b/tools/perf/README.md index 36c85386f7..5e48f3fce6 100644 --- a/tools/perf/README.md +++ b/tools/perf/README.md @@ -6,7 +6,8 @@ This is a comprehensive Python benchmark suite to run perf runs using different 2. Torch-TensorRT [Torchscript] 3. Torch-TensorRT [Dynamo] 4. Torch-TensorRT [torch_compile] -5. TensorRT +5. Torch Inductor +6. ONNX-TensorRT ## Prerequisite @@ -42,7 +43,7 @@ Benchmark scripts depends on following Python packages in addition to requiremen Here are the list of `CompileSpec` options that can be provided directly to compile the pytorch module -* `--backends` : Comma separated string of backends. Eg: torch, torch_compile, dynamo, tensorrt +* `--backends` : Comma separated string of backends. Eg: torch, ts_trt, dynamo, torch_compile, inductor, onnx_trt * `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `dynamo` or `torch_compile`, the input should be a Pytorch module (instead of a torchscript module). * `--model_torch` : Name of the PyTorch model file (optional, only necessary if `dynamo` or `torch_compile` is a chosen backend) * `--onnx` : ONNX model file which helps bypass the step of exporting ONNX from `model_torch`. If this argument is provided, the ONNX will be directly converted to TRT engine diff --git a/tools/perf/benchmark.sh b/tools/perf/benchmark.sh index 2770db3999..50bb45e4e2 100644 --- a/tools/perf/benchmark.sh +++ b/tools/perf/benchmark.sh @@ -7,8 +7,8 @@ python hub.py batch_sizes=(1 2 4 8 16 32 64 128 256) large_model_batch_sizes=(1 2 4 8 16 32 64) -backends=("torch" "ts_trt" "dynamo" "torch_compile" "inductor" "tensorrt") -backends_no_torchscript=("torch" "dynamo" "torch_compile" "inductor" "tensorrt") +backends=("torch" "ts_trt" "dynamo" "torch_compile" "inductor" "onnx_trt") +backends_no_torchscript=("torch" "dynamo" "torch_compile" "inductor" "onnx_trt") # Benchmark VGG16 model @@ -107,18 +107,48 @@ do done done -# Benchmark Stable Diffusion UNet model -echo "Benchmarking SD UNet model" +# Benchmark Stable Diffusion v1.4 UNet model +echo "Benchmarking SD-v1.4 UNet model" for bs in ${large_model_batch_sizes[@]} do for backend in ${backends_no_torchscript[@]} do - python perf_run.py --model_torch sd_unet \ + python perf_run.py --model_torch sd1.4_unet \ --precision fp16 --inputs="(${bs}, 4, 64, 64);(${bs});(${bs}, 1, 768)" \ --batch_size ${bs} \ --truncate \ --backends ${backend} \ - --report "sd_unet_perf_bs${bs}_backend_${backend}.csv" + --report "sd1.4_unet_perf_bs${bs}_backend_${backend}.csv" + done +done + +# Benchmark Stable Diffusion v2.1 UNet model +echo "Benchmarking SD-v2.1 UNet model" +for bs in ${large_model_batch_sizes[@]} +do + for backend in ${backends_no_torchscript[@]} + do + python perf_run.py --model_torch sd2.1_unet \ + --precision fp16 --inputs="(${bs}, 4, 64, 64);(${bs});(${bs}, 1, 1024)" \ + --batch_size ${bs} \ + --truncate \ + --backends ${backend} \ + --report "sd2.1_unet_perf_bs${bs}_backend_${backend}.csv" + done +done + +# Benchmark Stable Diffusion v2.1 VAE decoder model +echo "Benchmarking SD-v2.1 VAE decoder model" +for bs in ${large_model_batch_sizes[@]} +do + for backend in ${backends_no_torchscript[@]} + do + python perf_run.py --model_torch sd2.1_vae_decoder \ + --precision fp16 --inputs="(${bs}, 4, 64, 64)" \ + --batch_size ${bs} \ + --truncate \ + --backends ${backend} \ + --report "sd2.1_vae_decoder_perf_bs${bs}_backend_${backend}.csv" done done diff --git a/tools/perf/custom_models.py b/tools/perf/custom_models.py index ba5e6f2ddf..16657f34c9 100644 --- a/tools/perf/custom_models.py +++ b/tools/perf/custom_models.py @@ -26,7 +26,7 @@ def BertInputs(): return [tokens_tensor, segments_tensors] -def StableDiffusionUnet(): +def StableDiffusion1_4_Unet(): from diffusers import DiffusionPipeline pipe = DiffusionPipeline.from_pretrained( @@ -35,7 +35,25 @@ def StableDiffusionUnet(): return pipe.unet -def UNet(): +def StableDiffusion2_1_Unet(): + from diffusers import StableDiffusionPipeline + + pipe = StableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ) + return pipe.unet + + +def StableDiffusion2_1_VaeDecoder(): + from diffusers import StableDiffusionPipeline + + pipe = StableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ) + return pipe.vae.decoder + + +def MonaiUNet(): from monai.networks.nets import UNet model = UNet( @@ -46,4 +64,13 @@ def UNet(): strides=(2, 2), num_res_units=2, ) - return model.eval().cuda() + return model + + +def GoogleViTForImageClassification(): + from transformers import ViTForImageClassification + + model = ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224", torch_dtype=torch.float16 + ) + return model diff --git a/tools/perf/perf_run.py b/tools/perf/perf_run.py index 1ea82d0936..80b71b56b8 100644 --- a/tools/perf/perf_run.py +++ b/tools/perf/perf_run.py @@ -15,6 +15,7 @@ # Importing supported Backends import torch import torch_tensorrt as torchtrt +from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine from utils import ( BENCHMARK_MODELS, export_llm, @@ -29,6 +30,15 @@ WARMUP_ITER = 10 results = [] +SUPPORTED_BACKENDS = [ + "all", + "torch", + "ts_trt", + "dynamo", + "torch_compile", + "inductor", + "onnx_trt", +] def run_with_try_except(func): @@ -87,8 +97,8 @@ def record_llm_perf( # We only support single input (B x seq_len) for LLMs now input_seq = input_tensors[0] with torch.no_grad(): - # Warm up for 3 iterations - _ = time_generate(model, input_seq, output_seq_length, iterations=iterations) + # Warm up + _ = time_generate(model, input_seq, output_seq_length, iterations=WARMUP_ITER) torch.cuda.synchronize() @@ -141,7 +151,7 @@ def record_perf( def run_torch(model, input_tensors, params, precision, batch_size): print("Running Torch for precision: ", precision, " batch_size : ", batch_size) iters = params.get("iterations", 20) - model = model.to("cuda:0") + if params["is_text_llm"]: output_seq_length = params["output_sequence_length"] return record_llm_perf( @@ -173,7 +183,6 @@ def run_ts_trt(model, input_tensors, params, precision, batch_size): compile_settings = { "inputs": input_tensors, "enabled_precisions": {precision_to_dtype(precision)}, - "truncate_double": params.get("truncate", False), } if precision == "int8": @@ -207,25 +216,40 @@ def run_hf_dynamo(model, input_tensors, params, precision, batch_size): """ Compile the huggingface model using Torch-TensorRT dynamo frontend and record performance stats """ - osl = params["output_sequence_length"] iters = params.get("iterations", 20) - # Move the model and inputs to cpu and trace it. - model = model.to("cpu") - inputs_cpu = [tensor.clone().cpu() for tensor in input_tensors] - exp_program = export_llm(model, inputs_cpu, min_seq_len=1, max_seq_len=osl) - start_compile = timeit.default_timer() + compilation_options = { + "enabled_precisions": {precision_to_dtype(precision)}, + "min_block_size": params.get("min_block_size", 1), + "truncate_double": params.get("truncate", False), + "immutable_weights": params.get("immutable_weights", True), + "strip_engine_weights": params.get("strip_engine_weights", False), + "refit_identical_engine_weights": params.get( + "refit_identical_engine_weights", False + ), + "cache_built_engines": params.get("cache_built_engines", False), + "reuse_cached_engines": params.get("reuse_cached_engines", False), + "use_python_runtime": params.get("use_python_runtime", False), + "optimization_level": params.get("optimization_level", 3), + } + start_compile = timeit.default_timer() + exp_program = export_llm(model, input_tensors, min_seq_len=1, max_seq_len=osl) trt_model = torchtrt.dynamo.compile( - exp_program, - inputs=input_tensors, - enabled_precisions={precision_to_dtype(precision)}, - truncate_double=params.get("truncate", False), - use_python_runtime=params.get("use_python_runtime", False), + exp_program, inputs=input_tensors, **compilation_options ) end_compile = timeit.default_timer() compile_time_s = end_compile - start_compile + if params.get("save_dynamo_trt_engine", False): + serialized_engine = convert_exported_program_to_serialized_trt_engine( + exp_program, + arg_inputs=input_tensors, + **compilation_options, + ) + with open(f"{params['model_torch']}-dynamo-trt-engine.trt", "wb") as f: + f.write(serialized_engine) + if params.get("enable_cuda_graph", False): with torchtrt.runtime.enable_cudagraphs(trt_model) as cudagraphs_module: record_llm_perf( @@ -265,28 +289,38 @@ def run_dynamo(model, input_tensors, params, precision, batch_size): if params["is_text_llm"]: return run_hf_dynamo(model, input_tensors, params, precision, batch_size) - start_compile = timeit.default_timer() - model = torchtrt.compile( - model, - inputs=input_tensors, - ir="dynamo", - enabled_precisions={precision_to_dtype(precision)}, - min_block_size=params.get("min_block_size", 1), - truncate_double=params.get("truncate", False), - immutable_weights=params.get("immutable_weights", True), - strip_engine_weights=params.get("strip_engine_weights", False), - refit_identical_engine_weights=params.get( + compilation_options = { + "enabled_precisions": {precision_to_dtype(precision)}, + "min_block_size": params.get("min_block_size", 1), + "truncate_double": params.get("truncate", False), + "immutable_weights": params.get("immutable_weights", True), + "strip_engine_weights": params.get("strip_engine_weights", False), + "refit_identical_engine_weights": params.get( "refit_identical_engine_weights", False ), - cache_built_engines=params.get("cache_built_engines", False), - reuse_cached_engines=params.get("reuse_cached_engines", False), - use_python_runtime=params.get("use_python_runtime", False), - optimization_level=params.get("optimization_level", 3), + "cache_built_engines": params.get("cache_built_engines", False), + "reuse_cached_engines": params.get("reuse_cached_engines", False), + "use_python_runtime": params.get("use_python_runtime", False), + "optimization_level": params.get("optimization_level", 3), + } + start_compile = timeit.default_timer() + exp_program = torch.export.export(model, input_tensors) + model = torchtrt.dynamo.compile( + exp_program, inputs=input_tensors, **compilation_options ) end_compile = timeit.default_timer() compile_time_s = end_compile - start_compile iters = params.get("iterations", 20) + if params.get("save_dynamo_trt_engine", False): + serialized_engine = convert_exported_program_to_serialized_trt_engine( + exp_program, + arg_inputs=input_tensors, + **compilation_options, + ) + with open(f"{params['model_torch']}-dynamo-trt-engine.trt", "wb") as f: + f.write(serialized_engine) + if params.get("enable_cuda_graph", False): with torchtrt.runtime.enable_cudagraphs(model) as cudagraphs_module: record_perf( @@ -309,8 +343,6 @@ def run_torch_compile(model, input_tensors, params, precision, batch_size): """ Compile the given model using Torch-TensorRT torch.compile frontend and record performance stats """ - # Move the model to GPU - model = model.to("cuda:0") torch._dynamo.reset() print( @@ -361,6 +393,8 @@ def run_hf_inductor(model, input_tensors, params, precision, batch_size): """ Compile the huggingface model using torch inductor and record performance stats """ + torch._dynamo.reset() + osl = params["output_sequence_length"] # Mark dynamic shapes for input sequence input_seq = input_tensors[0] @@ -395,7 +429,7 @@ def run_inductor(model, input_tensors, params, precision, batch_size): Compile the given model using torch inductor and record performance stats """ torch._dynamo.reset() - model = model.to("cuda:0") + print( "Running Torch [inductor] for precision: ", precision, @@ -428,7 +462,7 @@ def run_inductor(model, input_tensors, params, precision, batch_size): @run_with_try_except -def run_tensorrt( +def run_onnx_trt( model, input_tensors, params, @@ -445,8 +479,9 @@ def run_tensorrt( if params["onnx"]: onnx_path = params["onnx"] else: - onnx_path = "./onnx-trt.onnx" + onnx_path = f"{params['model_torch']}-onnx-trt.onnx" torch.onnx.export(model, tuple(input_tensors), onnx_path, dynamo=True) + start_compile = timeit.default_timer() builder = trt.Builder(logger) network = builder.create_network( 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) @@ -460,18 +495,24 @@ def run_tensorrt( if precision == "fp16": config.set_flag(trt.BuilderFlag.FP16) config.builder_optimization_level = params.get("optimization_level", 3) - start_compile = timeit.default_timer() serialized_engine = builder.build_serialized_network(network, config) end_compile = timeit.default_timer() compile_time_s = end_compile - start_compile + # Deserialize the TensorRT engine with trt.Runtime(logger) as runtime: engine = runtime.deserialize_cuda_engine(serialized_engine) - print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size) + # save the generated TRT engine + if params.get("save_onnx_trt_engine", False): + with open(f"{params['model_torch']}-onnx-trt-engine.trt", "wb") as f: + f.write(serialized_engine) + + print( + "Running ONNX-TensorRT for precision: ", precision, " batch_size : ", batch_size + ) iters = params.get("iterations", 20) - start_time = timeit.default_timer() # Get I/O tensor information using TensorRT 10 API input_names = [] output_names = [] @@ -511,8 +552,6 @@ def run_tensorrt( dedicated_stream = torch.cuda.Stream() current_stream = torch.cuda.current_stream() - setup_time = timeit.default_timer() - # Warm up for i in range(WARMUP_ITER): # Wait for current stream to finish @@ -522,23 +561,19 @@ def run_tensorrt( current_stream.wait_stream(dedicated_stream) torch.cuda.synchronize() - infer_start_time = timeit.default_timer() # Performance measurement for i in range(iters): + infer_start_time = timeit.default_timer() # Wait for current stream to finish dedicated_stream.wait_stream(current_stream) context.execute_async_v3(dedicated_stream.cuda_stream) # Wait for TensorRT stream to finish current_stream.wait_stream(dedicated_stream) torch.cuda.synchronize() + infer_end_time = timeit.default_timer() + timings.append(infer_end_time - infer_start_time) - end_time = timeit.default_timer() - - # to compare against torch-trt dynamo apples to apples - infer_time = (end_time - infer_start_time + setup_time - start_time) / iters - timings.append(infer_time) - - recordStats("TensorRT", timings, precision, batch_size, compile_time_s) + recordStats("ONNX-TensorRT", timings, precision, batch_size, compile_time_s) # Deploys inference run for different backend configurations @@ -590,7 +625,7 @@ def run( precision, batch_size, ) - run_tensorrt( + run_onnx_trt( model_torch, input_tensors, params, @@ -610,8 +645,8 @@ def run( precision, batch_size, ) - elif backend == "tensorrt": - run_tensorrt( + elif backend == "onnx_trt": + run_onnx_trt( model_torch, input_tensors, params, @@ -636,7 +671,7 @@ def run( arg_parser.add_argument( "--backends", type=str, - help="Comma separated string of backends. Eg: torch, ts_trt, dynamo, torch_compile, inductor, tensorrt", + help="Comma separated string of backends. Eg: torch, ts_trt, dynamo, torch_compile, inductor, onnx_trt", ) arg_parser.add_argument( "--model", type=str, default="", help="Name of torchscript model file" @@ -743,6 +778,16 @@ def run( action="store_true", help="Whether to load the compiled TRT engines from storage.", ) + arg_parser.add_argument( + "--save_onnx_trt_engine", + action="store_true", + help="Whether to save the ONNX-TRT backend generated TRT engine.", + ) + arg_parser.add_argument( + "--save_dynamo_trt_engine", + action="store_true", + help="Whether to save the Torch-TRT Dynamo backend generated TRT engine.", + ) args = arg_parser.parse_args() # Create random input tensor of certain size @@ -779,12 +824,17 @@ def run( ) backends = parse_backends(params["backends"]) - if any( - backend in ["dynamo", "torch_compile", "tensorrt"] for backend in backends - ) and (model_torch is None): - raise ValueError( - "No Pytorch model (nn.Module) is provided for torchdynamo compilation. Please provide a pytorch model using --model_torch argument" - ) + for backend in backends: + if backend not in SUPPORTED_BACKENDS: + raise ValueError( + f"Backend {backend} is not supported. Please provide a valid backend." + ) + if backend in ["dynamo", "torch_compile", "onnx_trt", "all"] and ( + model_torch is None + ): + raise ValueError( + "No Pytorch model (nn.Module) is provided for torchdynamo compilation. Please provide a pytorch model using --model_torch argument" + ) batch_size = params["batch_size"] is_trt_engine = params["is_trt_engine"] diff --git a/tools/perf/utils.py b/tools/perf/utils.py index 13d7deac43..a491716ad7 100644 --- a/tools/perf/utils.py +++ b/tools/perf/utils.py @@ -16,8 +16,11 @@ "efficientnet_b0", "vit", "vit_large", + "google_vit", "bert_base_uncased", - "sd_unet", + "sd1.4_unet", + "sd2.1_unet", + "sd2.1_vae_decoder", "meta-llama/Llama-2-7b-chat-hf", "gpt2", "meta-llama/Meta-Llama-3-8B", @@ -25,7 +28,7 @@ "apple/DCLM-7B", "mistralai/Mistral-7B-Instruct-v0.3", "microsoft/Phi-3-mini-4k-instruct", - "monai/unet", + "monai_unet", } @@ -78,15 +81,30 @@ def __getitem__(self, name: str): "model": timm.create_model("vit_giant_patch14_224", pretrained=False), "path": ["script", "pytorch"], } + elif name == "google_vit": + return { + "model": cm.GoogleViTForImageClassification(), + "path": "pytorch", + } elif name == "bert_base_uncased": return { "model": cm.BertModule(), "inputs": cm.BertInputs(), "path": ["trace", "pytorch"], } - elif name == "sd_unet": + elif name == "sd1.4_unet": + return { + "model": cm.StableDiffusion1_4_Unet(), + "path": "pytorch", + } + elif name == "sd2.1_unet": + return { + "model": cm.StableDiffusion2_1_Unet(), + "path": "pytorch", + } + elif name == "sd2.1_vae_decoder": return { - "model": cm.StableDiffusionUnet(), + "model": cm.StableDiffusion2_1_VaeDecoder(), "path": "pytorch", } elif name in [ @@ -109,9 +127,9 @@ def __getitem__(self, name: str): "model": hf_artifact["model"], "path": "pytorch", } - elif name == "monai/unet": + elif name == "monai_unet": return { - "model": cm.UNet(), + "model": cm.MonaiUNet(), "path": "pytorch", } else: @@ -164,7 +182,7 @@ def parse_inputs(user_inputs, dtype): else: torchtrt_inputs.append(torch.Tensor([1.0]).cuda()) - return torchtrt_inputs + return tuple(torchtrt_inputs) def parse_backends(backends): From 9851bfea7a48cd8d9d8e159d9e14bf4a6a5844a8 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 16 Sep 2025 12:07:30 -0700 Subject: [PATCH 2/3] add unset_fake_temporarily with minor changes --- py/torch_tensorrt/dynamo/_compiler.py | 3 ++- .../dynamo/conversion/impl/normalization/ops.py | 5 +++-- tools/perf/README.md | 8 ++++---- tools/perf/perf_run.py | 10 +++++++++- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index db8e014879..7bb7d3f62c 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -513,7 +513,7 @@ def compile( if kwargs.get("debug", False): warnings.warn( - "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality", + "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality.", DeprecationWarning, stacklevel=2, ) @@ -1122,6 +1122,7 @@ def convert_exported_program_to_serialized_trt_engine( Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ + if kwargs.get("debug", False): warnings.warn( "`debug` is deprecated. Please use `with torch_tensorrt.dynamo.Debugger(...)` to wrap your compilation call to enable debugging functionality.", diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 75edf2b44d..f12b16b150 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -325,8 +325,9 @@ def native_group_norm( shape = [1, group] + [1] * (rank - 2) - weight_torch = torch.ones(shape) - bias_torch = torch.zeros(shape) + with unset_fake_temporarily(): + weight_torch = torch.ones(shape) + bias_torch = torch.zeros(shape) weight_one = get_trt_tensor(ctx, weight_torch, f"{name}_weight_one", input.dtype) bias_zero = get_trt_tensor(ctx, bias_torch, f"{name}_bias_zero", input.dtype) diff --git a/tools/perf/README.md b/tools/perf/README.md index 5e48f3fce6..1b7c90686c 100644 --- a/tools/perf/README.md +++ b/tools/perf/README.md @@ -44,7 +44,7 @@ Benchmark scripts depends on following Python packages in addition to requiremen Here are the list of `CompileSpec` options that can be provided directly to compile the pytorch module * `--backends` : Comma separated string of backends. Eg: torch, ts_trt, dynamo, torch_compile, inductor, onnx_trt -* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `dynamo` or `torch_compile`, the input should be a Pytorch module (instead of a torchscript module). +* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (pairing with `--is_trt_engine`)). If the backend is `dynamo` or `torch_compile`, the input should be a Pytorch module (instead of a torchscript module). * `--model_torch` : Name of the PyTorch model file (optional, only necessary if `dynamo` or `torch_compile` is a chosen backend) * `--onnx` : ONNX model file which helps bypass the step of exporting ONNX from `model_torch`. If this argument is provided, the ONNX will be directly converted to TRT engine * `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT @@ -61,16 +61,16 @@ Eg: ``` python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \ --model_torch ${MODELS_DIR}/vgg16_torch.pt \ - --precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \ + --precision fp32,fp16 \ + --inputs "(1, 3, 224, 224)@fp32" \ --batch_size 1 \ - --backends torch,ts_trt,dynamo,torch_compile,tensorrt \ + --backends torch,ts_trt,dynamo,torch_compile,inductor,onnx_trt \ --report "vgg_perf_bs1.txt" ``` Note: 1. Please note that measuring INT8 performance is only supported via a `calibration cache` file or QAT mode for `torch_tensorrt` backend. -2. TensorRT engine filename should end with `.plan` otherwise it will be treated as Torchscript module. ### Example models diff --git a/tools/perf/perf_run.py b/tools/perf/perf_run.py index 80b71b56b8..37fa079e76 100644 --- a/tools/perf/perf_run.py +++ b/tools/perf/perf_run.py @@ -480,7 +480,15 @@ def run_onnx_trt( onnx_path = params["onnx"] else: onnx_path = f"{params['model_torch']}-onnx-trt.onnx" - torch.onnx.export(model, tuple(input_tensors), onnx_path, dynamo=True) + len_output = len(model(*input_tensors)) + # to match the output names with Torch-TRT engine's + torch.onnx.export( + model, + tuple(input_tensors), + onnx_path, + dynamo=True, + output_names=[f"output{i}" for i in range(len_output)], + ) start_compile = timeit.default_timer() builder = trt.Builder(logger) network = builder.create_network( From 8e7342e30cecf386e1635625b530f97c138cd92e Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 16 Sep 2025 13:54:18 -0700 Subject: [PATCH 3/3] rephrase prompt --- py/torch_tensorrt/dynamo/_compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 7bb7d3f62c..0dc4654db0 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -568,7 +568,7 @@ def compile( and kwargs["enable_cross_compile_for_windows"] ): raise ValueError( - "Please use cross_compile_for_windows() api if you want to cross compile the module in Linux for inferencing in Windows." + "Please use torch_tensorrt.dynamo.cross_compile_for_windows() if you want to cross compile the module in Linux for inferencing in Windows." ) engine_capability = EngineCapability._from(engine_capability) @@ -1180,7 +1180,7 @@ def convert_exported_program_to_serialized_trt_engine( and kwargs["enable_cross_compile_for_windows"] ): raise ValueError( - "Please use cross_compile_for_windows() api if you want to cross compile the module in Linux for inferencing in Windows." + "Please use torch_tensorrt.dynamo.cross_compile_for_windows() if you want to cross compile the module in Linux for inferencing in Windows." ) engine_capability = EngineCapability._from(engine_capability)