From eac8809729af0eefa6fc9c66bde2207148d8b87d Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 27 Oct 2025 21:59:11 -0700 Subject: [PATCH 1/6] implement autocast --- core/runtime/execute_engine.cpp | 6 - examples/dynamo/autocast_example.py | 98 ++++++ py/torch_tensorrt/dynamo/_compiler.py | 63 ++++ py/torch_tensorrt/dynamo/_defaults.py | 5 + py/torch_tensorrt/dynamo/_settings.py | 22 ++ .../dynamo/conversion/_TRTInterpreter.py | 14 +- .../lowering/passes/_aten_lowering_pass.py | 11 +- .../dynamo/lowering/passes/nodeclassifier.py | 299 ++++++++++++++++++ .../lowering/passes/rule_based_autocast.py | 129 ++++++++ .../runtime/_PythonTorchTensorRTModule.py | 8 - 10 files changed, 626 insertions(+), 29 deletions(-) create mode 100644 examples/dynamo/autocast_example.py create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 64b111750f..6ea4d01d03 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -107,12 +107,6 @@ void setup_input_tensors( TORCHTRT_CHECK( inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device()); - auto expected_type = - util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); - TORCHTRT_CHECK( - inputs[i].dtype() == expected_type, - "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype()); - auto dims = core::util::toDims(inputs[i].sizes()); auto shape = core::util::toVec(dims); LOG_DEBUG("Input Name: " << name << " Shape: " << dims); diff --git a/examples/dynamo/autocast_example.py b/examples/dynamo/autocast_example.py new file mode 100644 index 0000000000..d047fd1d24 --- /dev/null +++ b/examples/dynamo/autocast_example.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +import torch_tensorrt +import torchvision + + +class MyModule(torch.nn.Module): + def forward(self, a_float32, b_float32, c_float32, d_float32): + with torch.autocast(device_type="cuda"): + e_float16 = torch.mm(a_float32, b_float32) + with torch.autocast(device_type="cuda", enabled=False): + # Calls e_float16.float() to ensure float32 execution + # (necessary because e_float16 was created in an autocasted region) + f_float32 = torch.mm(c_float32, e_float16.float()) + + # No manual casts are required when re-entering the autocast-enabled region. + # torch.mm again runs in float16 and produces float16 output, regardless of input types. + g_float16 = torch.mm(d_float32, f_float32) + return g_float16 + + +class AutocastExample(nn.Module): + def __init__(self): + super(AutocastExample, self).__init__() + self.conv1 = nn.Conv2d( + in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1 + ) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d( + in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(16 * 8 * 8, 10) + + def forward(self, x, y): + out = self.pool1(self.relu1(self.conv1(x))) # fp16 + x = self.pool2(self.relu2(self.conv2(out))) # fp16 + x = self.flatten(x) + with torch.autocast(x.device.type, enabled=True, dtype=torch.float32): + x = self.fc1(x) # fp32 + with torch.autocast(x.device.type, enabled=False): + x = torch.sub(x.half(), y) # fp16 + out2 = torch.add(x, x) # fp16 + with torch.autocast(x.device.type, enabled=True, dtype=torch.float16): + out2 = torch.log(out2) # fp32 + return x, out, out2 + + +class MyResNet18Wrapper(torch.nn.Module): + def __init__(self, num_classes=1000, pretrained=True): + super(MyResNet18Wrapper, self).__init__() + self.resnet = torchvision.models.resnet18( + num_classes=num_classes, weights="IMAGENET1K_V1" if pretrained else None + ) + + def forward(self, x): + x = self.resnet(x) + return x + + +if __name__ == "__main__": + # model = MyModule().cuda().eval() + # inputs = (torch.randn((8, 8), device="cuda"), + # torch.randn((8, 8), device="cuda"), + # torch.randn((8, 8), device="cuda"), + # torch.randn((8, 8), device="cuda"),) + + # model = AutocastExample().cuda().eval() + # inputs = (torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"), + # torch.randn((1,), dtype=torch.float16, device="cuda"),) + + model = MyResNet18Wrapper().cuda().eval() + inputs = (torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda"),) + + ep = torch.export.export(model, inputs) + + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=".", + engine_builder_monitor=False, + ): + trt_mod = torch_tensorrt.compile( + ep.module(), + arg_inputs=inputs, + use_explicit_typing=False, + min_block_size=1, + use_python_runtime=True, + low_precision_type=torch.float16, + # nodes_to_exclude={"^conv2d$"}, + targets_to_exclude={}, + data_max=512, + max_depth_of_reduction=None, + ) + + trt_out = trt_mod(*inputs) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c8ad938032..615e5cba7b 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -434,6 +434,13 @@ def compile( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + low_precision_type: Optional[ + Union[torch.dtype, dtype] + ] = _defaults.LOW_PRECISION_TYPE, + nodes_to_exclude: Collection[str] = _defaults.NODES_TO_EXCLUDE, + targets_to_exclude: Collection[Target] = _defaults.TARGETS_TO_EXCLUDE, + data_max: float = _defaults.DATA_MAX, + max_depth_of_reduction: Optional[int] = _defaults.MAX_DEPTH_OF_REDUCTION, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -511,6 +518,11 @@ def compile( l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. + nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. + targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. + data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -593,6 +605,19 @@ def compile( f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True" ) + if low_precision_type is not None: + if not isinstance(low_precision_type, (torch.dtype, dtype)): + raise ValueError( + f"low_precision_type must be a torch.dtype or dtype, got {type(low_precision_type)}" + ) + if low_precision_type not in { + torch.float16, + torch.bfloat16, + } and low_precision_type not in {dtype.f16, dtype.bf16}: + raise ValueError( + f"low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {low_precision_type}" + ) + if use_fp32_acc: logger.debug( "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \ @@ -622,6 +647,38 @@ def compile( if not isinstance(arg_inputs, collections.abc.Sequence): arg_inputs = [arg_inputs] # type: ignore + # save intermediate outputs of each node for Autocast + intermediate_node_outputs = {} + if not use_explicit_typing: + + class DumpInterpreter(torch.fx.Interpreter): # type: ignore[misc] + """Dump intermediate outputs of each node""" + + def run_node(self, n: torch.fx.Node) -> Any: + if ( + n.op == "call_function" + and n.target != torch.ops.higher_order.wrap_with_autocast + ): + out = super().run_node(n) + if not isinstance(out, torch.Tensor): + raise ValueError( + f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}." + ) + intermediate_node_outputs[n.name] = out + return out + return super().run_node(n) + + def _materialize(x: Input | torch.Tensor) -> torch.Tensor: + """Materialize an Input object to a tensor""" + if isinstance(x, Input): + return x.torch_tensor + return x + + with torch.no_grad(): + mat_args = tuple(_materialize(a) for a in arg_inputs) + mat_kwargs = {k: _materialize(v) for k, v in kwarg_inputs.items()} + DumpInterpreter(exported_program.module()).run(*mat_args, **mat_kwargs) + # Prepare torch_trt inputs trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) @@ -680,6 +737,12 @@ def compile( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "low_precision_type": low_precision_type, + "nodes_to_exclude": nodes_to_exclude, + "targets_to_exclude": targets_to_exclude, + "data_max": data_max, + "max_depth_of_reduction": max_depth_of_reduction, + "intermediate_node_outputs": intermediate_node_outputs, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index de970ecd81..d278dd4238 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -57,6 +57,11 @@ L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False +LOW_PRECISION_TYPE = None +NODES_TO_EXCLUDE = set[str]() +TARGETS_TO_EXCLUDE = set[torch.fx.node.Target]() +DATA_MAX = 512 +MAX_DEPTH_OF_REDUCTION = None if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d8f6809eae..fbf842421e 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,12 +1,14 @@ from dataclasses import dataclass, field from typing import Any, Collection, Optional, Set, Tuple, Union +import torch from torch.fx.node import Target from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, CACHE_BUILT_ENGINES, + DATA_MAX, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -21,8 +23,11 @@ IMMUTABLE_WEIGHTS, L2_LIMIT_FOR_TILING, LAZY_ENGINE_INIT, + LOW_PRECISION_TYPE, MAX_AUX_STREAMS, + MAX_DEPTH_OF_REDUCTION, MIN_BLOCK_SIZE, + NODES_TO_EXCLUDE, NUM_AVG_TIMING_ITERS, OFFLOAD_MODULE_TO_CPU, OPTIMIZATION_LEVEL, @@ -32,6 +37,7 @@ REUSE_CACHED_ENGINES, SPARSE_WEIGHTS, STRIP_ENGINE_WEIGHTS, + TARGETS_TO_EXCLUDE, TILING_OPTIMIZATION_LEVEL, TIMING_CACHE_PATH, TRUNCATE_DOUBLE, @@ -97,6 +103,12 @@ class CompilationSettings: tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. + nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. + targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. + data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. + intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}. """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -140,6 +152,16 @@ class CompilationSettings: l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU + low_precision_type: Optional[dtype] = LOW_PRECISION_TYPE + nodes_to_exclude: Collection[str] = field(default_factory=lambda: NODES_TO_EXCLUDE) + targets_to_exclude: Collection[Target] = field( + default_factory=lambda: TARGETS_TO_EXCLUDE + ) + data_max: float = DATA_MAX + max_depth_of_reduction: Optional[int] = MAX_DEPTH_OF_REDUCTION + intermediate_node_outputs: dict[str, torch.Tensor] = field( + default_factory=lambda: {} + ) def __getstate__(self) -> dict[str, Any]: from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 73af09448e..4a86a7f907 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -292,17 +292,9 @@ def _populate_trt_builder_config( ) if not self.compilation_settings.use_explicit_typing: - if dtype.float16 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.FP16) - - if dtype.int8 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.INT8) - - if dtype.fp8 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.FP8) - - if dtype.bfloat16 in self.compilation_settings.enabled_precisions: - builder_config.set_flag(trt.BuilderFlag.BF16) + _LOGGER.info( + "Torch-TensorRT uses Autocast to determine the precision of the graph, because weak typing has been deprecated in TensorRT 10.12." + ) if self.compilation_settings.sparse_weights: builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index e5183668ae..1499e670bd 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -15,6 +15,13 @@ from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes from .repair_input_as_output import repair_input_as_output from .replace_max_pool_with_indices import replace_max_pool_with_indices +from .rule_based_autocast import rule_based_autocast + +pre_lowering_pass_list = [ + remove_detach, + rule_based_autocast, + remove_assert_nodes, # rule_based_autocast might insert assert nodes +] post_lowering_pass_list = [ remove_input_alias_fixing_clones, @@ -27,10 +34,6 @@ complex_graph_detection, ] -pre_lowering_pass_list = [ - remove_detach, -] - if not is_tegra_platform(): from .fuse_distributed_ops import fuse_distributed_ops diff --git a/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py new file mode 100644 index 0000000000..08f3dcdd8d --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py @@ -0,0 +1,299 @@ +# Borrowed from ModelOpt AutoCast's nodeclassifier.py, modified to fit Torch-TensorRT's needs. +import abc +import logging +import operator +import re +from typing import Collection, Optional + +import torch + +logger = logging.getLogger(__name__) + + +class NodeRuleBase: + """Base class for node classification rules. + + This class defines the interface for rules that determine whether a node + should be kept in high precision or converted to low precision. + """ + + @abc.abstractmethod + def _check_inner(self, node): + """Implement this method to check if node conversion should be skipped based on rule criteria.""" + + def _log_skipped(self, node, **kwargs): + """Log information about skipped nodes.""" + logger.info(f"Skipping node {node.name}: {self.__class__.__name__}") + + def check(self, node): + """Check if a node should be skipped based on the rule. + + Args: + node: The ONNX node to check. + + Returns: + bool: True if the node should be kept in high precision, False otherwise. + """ + result = self._check_inner(node) + if result: + self._log_skipped(node) + return True + return False + + +class DisabledNodeNameRegexRule(NodeRuleBase): + """Rule for keeping nodes with matching names in high precision.""" + + def __init__(self, disabled_node_name_regex): + """Initialize the rule. + + Args: + disabled_node_name_regex: List of regex patterns for node names to keep in high precision. + """ + self.disabled_node_name_regex = disabled_node_name_regex + + def _check_inner(self, node): + return any( + re.match(regex, node.name) for regex in self.disabled_node_name_regex + ) + + +class DisabledTargets(NodeRuleBase): + """Rule for keeping nodes with specific operation types in high precision.""" + + def __init__(self, targets_to_exclude): + """Initialize the rule. + + Args: + targets_to_exclude: List of operation types to keep in high precision. + """ + self.targets_to_exclude = targets_to_exclude + + def _check_inner(self, node): + return node.target in self.targets_to_exclude + + +class IORangeRule(NodeRuleBase): + """Rule for keeping nodes with out-of-range inputs/outputs in high precision.""" + + def __init__(self, data_max, reference_data): + """Initialize the rule. + + Args: + data_max: Maximum absolute value allowed for node I/O. + reference_data: Reference data for checking I/O ranges. + """ + self.data_max = data_max + self.reference_data = reference_data + self.output_data = None + + def _check_inner(self, node): + def is_io_out_of_range(node): + tensor_name = node.name + if tensor_name not in self.reference_data: + logger.debug( + f"Node {node.name}: Tensor {tensor_name} not found in reference data. Skipping I/O range check." + ) + return False + ref_data = self.reference_data[tensor_name] + if ref_data.numel() == 0: + logger.debug( + f"Node {node.name}: Tensor {tensor_name} has 0 elements. Skipping I/O range check." + ) + return False + logger.debug( + f"Node {node.name}: reference data: min={ref_data.min()}, max={ref_data.max()}" + ) + if torch.any(torch.abs(ref_data) > self.data_max): + self.output_data = ref_data + return True + + if self.reference_data: + for in_node in node.all_input_nodes: + if is_io_out_of_range(in_node): + return True + for out_node in list(node.users): + if is_io_out_of_range(out_node): + return True + return False + + def _log_skipped(self, node, **kwargs): + """Log information about skipped nodes with I/O range violations.""" + if self.output_data is not None: + logger.info( + f"Skipping node {node.name}: reference IO out of range: min={torch.min(self.output_data)}, " + f"max={torch.max(self.output_data)}, range=[{-self.data_max}, {self.data_max}]" + ) + else: + super()._log_skipped(node, **kwargs) + + +class DepthOfReductionRule(NodeRuleBase): + """Rule for keeping nodes with high depth of reduction in high precision.""" + + def __init__(self, max_depth_of_reduction, reference_data): + """Initialize the rule. + + Args: + max_depth_of_reduction: Maximum depth of reduction allowed in low precision. + reference_data: Reference data for checking I/O ranges. + """ + self.max_depth_of_reduction = max_depth_of_reduction + self.reference_data = reference_data + self.reduction_depth = 0 + + def _get_tensor_shape(self, tensor_name): + """Get tensor shape from reference data.""" + if tensor_name in self.reference_data: + return self.reference_data[tensor_name].shape + return None + + def _log_skipped(self, node, **kwargs): + """Log information about skipped nodes with depth of reduction violations.""" + if self.reduction_depth > 0: + logger.info( + f"Skipping node {node.name}: depth of reduction {self.reduction_depth} exceeds " + f"{self.max_depth_of_reduction}." + ) + else: + super()._log_skipped(node, **kwargs) + + def _check_inner(self, node): + # All reduction ops rely on shape of input[0] + input_0_dims = ( + self._get_tensor_shape(node.all_input_nodes[0].name) + if len(node.all_input_nodes) > 0 + else None + ) + if input_0_dims is None: + return False + self.reduction_depth = 0 + if node.target in [ + torch.ops.aten.scaled_dot_product_attention, + ]: + # Attention: input (batch_size, sequence_length, hidden_size) + # or (batch_size, kv_num_heads, total_sequence_length, head_size) + assert len(input_0_dims) == 3 or len(input_0_dims) == 4 + hidden_size = ( + input_0_dims[2] + if len(input_0_dims) == 3 + else input_0_dims[1] * input_0_dims[3] + ) + self.reduction_depth = hidden_size + elif node.target in [ + torch.ops.aten.convolution, + torch.ops.aten.conv1d, + torch.ops.aten.conv2d, + torch.ops.aten.conv3d, + ]: + # Conv: input (N x C x D1 x D2 ... x Dn) + # weight (out_channels, in_channels, kD1, kD2, ... kDn) + # Reduction depth = in_channels * kernel_volume + weight_shape = ( + self._get_tensor_shape(node.all_input_nodes[1].name) + if len(node.all_input_nodes) > 1 + else None + ) + if weight_shape is None: + return False + in_channels = weight_shape[1] + kernel_volume = torch.prod(weight_shape[2:]) + self.reduction_depth = in_channels * kernel_volume + elif node.target in [ + torch.ops.aten.matmul, + torch.ops.aten.dot, + torch.ops.aten.mm, + torch.ops.aten.mv, + torch.ops.aten.bmm, + ]: + # GEMM: A (M, K) @ B (K, N) = C (M, N) + self.reduction_depth = input_0_dims[-1] + # TODO: Add more reduction ops here + return self.reduction_depth > self.max_depth_of_reduction + + +class NodeClassifier: + """Main class for classifying nodes into high and low precision groups.""" + + def __init__( + self, + nodes, + nodes_to_exclude: Collection[str] | None = None, + targets_to_exclude: Collection[torch.fx.node.Target] | None = None, + custom_rule: NodeRuleBase | None = None, + data_max: float | None = 1000.0, + max_depth_of_reduction: int | None = None, + ): + """Initialize the node classifier. + + Args: + nodes: The nodes to classify. + nodes_to_exclude: Collection of regex patterns for node names to keep in high precision. + targets_to_exclude: Collection of targets to keep in high precision. + custom_rule: Optional custom classification rule. + data_max: Maximum absolute value allowed for node I/O. + max_depth_of_reduction: Maximum depth of reduction allowed in low precision. + """ + self.nodes = nodes + self.nodes_to_exclude = nodes_to_exclude + self.targets_to_exclude = targets_to_exclude + self.custom_rule = custom_rule + self.data_max = data_max + self.max_depth_of_reduction = max_depth_of_reduction + + def _gen_block_node_rules(self, reference_data): + """Generate list of rules for blocking nodes from precision conversion. + + Args: + reference_data: Reference data for checking I/O ranges. + + Returns: + list[NodeRuleBase]: List of rules to apply. + """ + block_node_rules: list[NodeRuleBase] = [] + if self.nodes_to_exclude: + block_node_rules.append(DisabledNodeNameRegexRule(self.nodes_to_exclude)) + if self.targets_to_exclude: + block_node_rules.append(DisabledTargets(self.targets_to_exclude)) + if reference_data: + block_node_rules.append(IORangeRule(self.data_max, reference_data)) + if self.max_depth_of_reduction is not None: + block_node_rules.append( + DepthOfReductionRule( + self.max_depth_of_reduction, + reference_data, + ) + ) + if self.custom_rule: + block_node_rules.append(self.custom_rule) + return block_node_rules + + def run( + self, ref_outputs_dict: Optional[dict[str, torch.Tensor]] = None + ) -> tuple[list[str], list[str]]: + """Run node classification. + + Args: + ref_outputs_dict: Optional tensors' reference data. + + Returns: + tuple: Lists of node names (low_precision_nodes, high_precision_nodes). + """ + block_node_rules = self._gen_block_node_rules(ref_outputs_dict) + low_precision_nodes = [] + high_precision_nodes = [] + for node in self.nodes: + if node.op == "call_function": + if ( + node.target == torch.ops.higher_order.wrap_with_autocast + or node.target == operator.getitem + ): + continue + # If any condition is met - node will be executed in high precision + if any(rule.check(node) for rule in block_node_rules): + high_precision_nodes.append(node.name) + else: + low_precision_nodes.append(node.name) + logger.debug(f"Low Precision Nodes: {low_precision_nodes}") + logger.debug(f"High Precision Nodes: {high_precision_nodes}") + return low_precision_nodes, high_precision_nodes diff --git a/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py new file mode 100644 index 0000000000..02aa4e5a5e --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py @@ -0,0 +1,129 @@ +import logging +import operator +from typing import Any + +import torch +from torch._export.passes.replace_autocast_with_hop_pass import ( + replace_autocast_with_hop_pass, +) +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo._settings import CompilationSettings + +from .nodeclassifier import NodeClassifier +from .pass_utils import clean_up_graph_after_modifications + +logger = logging.getLogger(__name__) + + +def is_tensor_node(n: torch.fx.Node) -> bool: + val = n.meta.get("val", None) + if hasattr(val, "dtype"): + return True + return False + + +def rule_based_autocast( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Rule-based autocast""" + if settings.use_explicit_typing: + logger.debug("Strong typing is enabled, skipping rule-based autocast.") + return gm + + # nodes = list(gm.graph.nodes) + # # insert enter autocast node in the beginning of the graph + # with gm.graph.inserting_before(nodes[0]): + # enter_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._enter_autocast, args=("cuda", torch.float16, True, True)) + # enter_autocast_node.meta.update(getattr(nodes[0], "meta", {})) + + # # insert exit autocast node before the return node, assuming the return node is the last node + # with gm.graph.inserting_before(nodes[-1]): + # exit_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._exit_autocast, args=(enter_autocast_node,)) + # exit_autocast_node.meta.update(getattr(nodes[-1], "meta", {})) + + # gm = clean_up_graph_after_modifications(gm) + # gm, new_signature = replace_autocast_with_hop_pass(gm, None) + # logger.debug("Graph after replace_autocast_with_hop_pass:\n%s", gm.graph) + + # get config from settings + low_precision_type = settings.low_precision_type + if low_precision_type is None: + return gm + if isinstance(low_precision_type, dtype): + low_precision_type = low_precision_type.to(torch.dtype) + high_precision_type = torch.float32 + nodes_to_exclude = settings.nodes_to_exclude + targets_to_exclude = settings.targets_to_exclude + data_max = settings.data_max + max_depth_of_reduction = settings.max_depth_of_reduction + reference_data: dict[str, torch.Tensor] = settings.intermediate_node_outputs + + node_classifier = NodeClassifier( + gm.graph.nodes, + nodes_to_exclude=nodes_to_exclude, + targets_to_exclude=targets_to_exclude, + data_max=data_max, + max_depth_of_reduction=max_depth_of_reduction, + ) + low_precision_nodes, high_precision_nodes = node_classifier.run(reference_data) + + for node in list(gm.graph.nodes): + if node.op == "call_function": + if ( + node.target == torch.ops.higher_order.wrap_with_autocast + or node.target == operator.getitem + ): + continue + + def _cast_all_tensor_args_to_dtype(arg: Any, dtype: torch.dtype) -> Any: + """Cast all tensor args to the given dtype + + Args: + arg: The argument to cast + dtype: The dtype to cast to + + Returns: + The casted argument + """ + if isinstance(arg, torch.fx.Node) and is_tensor_node(arg): + val = arg.meta.get("val", None) + with gm.graph.inserting_before(node): + cast = gm.graph.call_function( + torch.ops.aten.to.dtype, args=(arg, dtype) + ) + + if isinstance(val, torch.Tensor): + arg.meta["val"] = val.to(dtype) + cast.meta.update(arg.meta) + return cast + elif isinstance(arg, (tuple, list)): + return type(arg)( + _cast_all_tensor_args_to_dtype(a, dtype) for a in arg + ) + elif isinstance(arg, dict): + return { + k: _cast_all_tensor_args_to_dtype(v, dtype) + for k, v in arg.items() + } + else: + return arg + + if node.name in low_precision_nodes: + node.args = _cast_all_tensor_args_to_dtype( + node.args, low_precision_type + ) + node.kwargs = _cast_all_tensor_args_to_dtype( + node.kwargs, low_precision_type + ) + elif node.name in high_precision_nodes: + node.args = _cast_all_tensor_args_to_dtype( + node.args, high_precision_type + ) + node.kwargs = _cast_all_tensor_args_to_dtype( + node.kwargs, high_precision_type + ) + + gm = clean_up_graph_after_modifications(gm) + logger.debug("Graph after Autocast based on the rules:\n%s", gm.graph) + + return gm diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index d18a5674e0..0eb5ebbbca 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -275,10 +275,6 @@ def setup_engine(self) -> None: len(self.input_names) + len(self.output_names) ) - self.input_dtypes = [ - dtype._from(self.engine.get_tensor_dtype(input_name)) - for input_name in self.input_names - ] self.input_shapes = [ self.engine.get_tensor_shape(input_name) for input_name in self.input_names ] @@ -371,10 +367,6 @@ def setup_input_tensors( + contiguous_inputs[i + 1 :] ) - assert ( - contiguous_inputs[i].dtype == self.input_dtypes[i] - ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." - if need_cudagraphs_record: # If cudagraphs is enabled, this memory is reserved for future cudagraph runs # Clone is required to avoid re-using user-provided GPU memory From f6c7c7c92c82bf997e5a5b105d5e5e8a3331399b Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 27 Oct 2025 22:10:49 -0700 Subject: [PATCH 2/6] fix bug --- .../dynamo/lowering/passes/nodeclassifier.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py index 08f3dcdd8d..72bb376291 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py @@ -169,7 +169,7 @@ def _check_inner(self, node): return False self.reduction_depth = 0 if node.target in [ - torch.ops.aten.scaled_dot_product_attention, + torch.ops.aten.scaled_dot_product_attention.default, ]: # Attention: input (batch_size, sequence_length, hidden_size) # or (batch_size, kv_num_heads, total_sequence_length, head_size) @@ -181,10 +181,10 @@ def _check_inner(self, node): ) self.reduction_depth = hidden_size elif node.target in [ - torch.ops.aten.convolution, - torch.ops.aten.conv1d, - torch.ops.aten.conv2d, - torch.ops.aten.conv3d, + torch.ops.aten.convolution.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv3d.default, ]: # Conv: input (N x C x D1 x D2 ... x Dn) # weight (out_channels, in_channels, kD1, kD2, ... kDn) @@ -201,10 +201,11 @@ def _check_inner(self, node): self.reduction_depth = in_channels * kernel_volume elif node.target in [ torch.ops.aten.matmul, - torch.ops.aten.dot, - torch.ops.aten.mm, - torch.ops.aten.mv, - torch.ops.aten.bmm, + torch.ops.aten.matmul.default, + torch.ops.aten.dot.default, + torch.ops.aten.mm.default, + torch.ops.aten.mv.default, + torch.ops.aten.bmm.default, ]: # GEMM: A (M, K) @ B (K, N) = C (M, N) self.reduction_depth = input_0_dims[-1] From f7d80686f0f31a0a0dded18232812e6a1397cffe Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 28 Oct 2025 19:10:46 -0700 Subject: [PATCH 3/6] add arg enable_autocast --- examples/dynamo/autocast_example.py | 7 ++++++- py/torch_tensorrt/dynamo/_compiler.py | 11 +++++++++-- py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 4 ++++ .../dynamo/conversion/_TRTInterpreter.py | 14 +++++++++++--- .../dynamo/lowering/passes/rule_based_autocast.py | 4 ++-- .../runtime/_CudaGraphsTorchTensorRTModule.py | 4 ---- 7 files changed, 33 insertions(+), 12 deletions(-) diff --git a/examples/dynamo/autocast_example.py b/examples/dynamo/autocast_example.py index d047fd1d24..9feeb6d751 100644 --- a/examples/dynamo/autocast_example.py +++ b/examples/dynamo/autocast_example.py @@ -85,9 +85,14 @@ def forward(self, x): trt_mod = torch_tensorrt.compile( ep.module(), arg_inputs=inputs, - use_explicit_typing=False, min_block_size=1, use_python_runtime=True, + ##### weak typing ##### + # use_explicit_typing=False, + # enabled_precisions={torch.float16}, + ##### strong typing + autocast ##### + use_explicit_typing=True, + enable_autocast=True, low_precision_type=torch.float16, # nodes_to_exclude={"^conv2d$"}, targets_to_exclude={}, diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 615e5cba7b..511d215335 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -141,7 +141,7 @@ def cross_compile_for_windows( disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. - enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels + enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels workspace_size (int): Maximum size of workspace given to TensorRT @@ -434,6 +434,7 @@ def compile( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + enable_autocast: bool = _defaults.ENABLE_AUTOCAST, low_precision_type: Optional[ Union[torch.dtype, dtype] ] = _defaults.LOW_PRECISION_TYPE, @@ -518,6 +519,7 @@ def compile( l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True. low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. @@ -596,6 +598,10 @@ def compile( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." ) + if enable_autocast: + use_explicit_typing = True + logger.debug("Autocast is enabled, setting use_explicit_typing to True.") + if use_explicit_typing: if len(enabled_precisions) != 1 or not any( x in enabled_precisions @@ -608,7 +614,7 @@ def compile( if low_precision_type is not None: if not isinstance(low_precision_type, (torch.dtype, dtype)): raise ValueError( - f"low_precision_type must be a torch.dtype or dtype, got {type(low_precision_type)}" + f"low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(low_precision_type)}" ) if low_precision_type not in { torch.float16, @@ -737,6 +743,7 @@ def _materialize(x: Input | torch.Tensor) -> torch.Tensor: "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "enable_autocast": enable_autocast, "low_precision_type": low_precision_type, "nodes_to_exclude": nodes_to_exclude, "targets_to_exclude": targets_to_exclude, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index d278dd4238..e69cda70c7 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -57,6 +57,7 @@ L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False +ENABLE_AUTOCAST = False LOW_PRECISION_TYPE = None NODES_TO_EXCLUDE = set[str]() TARGETS_TO_EXCLUDE = set[torch.fx.node.Target]() diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index fbf842421e..e406bba615 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -14,6 +14,7 @@ DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, DRYRUN, + ENABLE_AUTOCAST, ENABLE_CROSS_COMPILE_FOR_WINDOWS, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENABLE_WEIGHT_STREAMING, @@ -103,6 +104,7 @@ class CompilationSettings: tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True. low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. @@ -152,6 +154,7 @@ class CompilationSettings: l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU + enable_autocast: bool = ENABLE_AUTOCAST low_precision_type: Optional[dtype] = LOW_PRECISION_TYPE nodes_to_exclude: Collection[str] = field(default_factory=lambda: NODES_TO_EXCLUDE) targets_to_exclude: Collection[Target] = field( @@ -179,6 +182,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__.update(state) +# If any of the following setting is changed, the engine should be rebuilt. _SETTINGS_TO_BE_ENGINE_INVARIANT = ( "enabled_precisions", "max_aux_streams", diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 4a86a7f907..73af09448e 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -292,9 +292,17 @@ def _populate_trt_builder_config( ) if not self.compilation_settings.use_explicit_typing: - _LOGGER.info( - "Torch-TensorRT uses Autocast to determine the precision of the graph, because weak typing has been deprecated in TensorRT 10.12." - ) + if dtype.float16 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.FP16) + + if dtype.int8 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.INT8) + + if dtype.fp8 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.FP8) + + if dtype.bfloat16 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.BF16) if self.compilation_settings.sparse_weights: builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py index 02aa4e5a5e..6a824a6a90 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py @@ -26,8 +26,8 @@ def rule_based_autocast( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: """Rule-based autocast""" - if settings.use_explicit_typing: - logger.debug("Strong typing is enabled, skipping rule-based autocast.") + if not settings.enable_autocast: + logger.debug("Autocast is not enabled, skipping rule-based autocast.") return gm # nodes = list(gm.graph.nodes) diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 9e54fbac3d..24166eb895 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -154,10 +154,6 @@ def forward( + contiguous_inputs[i + 1 :] ) - assert ( - contiguous_inputs[i].dtype == inputs[i].dtype - ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}." - if need_cudagraphs_record: # If cudagraphs is enabled, this memory is reserved for future cudagraph runs # Clone is required to avoid re-using user-provided GPU memory From e15ce946fa28a8faf5cbd6ec220329a86954255e Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 4 Nov 2025 18:50:01 -0800 Subject: [PATCH 4/6] change names of API and support for user specified node names --- examples/dynamo/autocast_example.py | 76 ++++++------------- py/torch_tensorrt/dynamo/_compiler.py | 52 +++++++------ py/torch_tensorrt/dynamo/_defaults.py | 10 +-- py/torch_tensorrt/dynamo/_settings.py | 38 +++++----- .../dynamo/lowering/passes/nodeclassifier.py | 32 ++++---- .../lowering/passes/rule_based_autocast.py | 50 +++++------- 6 files changed, 111 insertions(+), 147 deletions(-) diff --git a/examples/dynamo/autocast_example.py b/examples/dynamo/autocast_example.py index 9feeb6d751..f1487cfb72 100644 --- a/examples/dynamo/autocast_example.py +++ b/examples/dynamo/autocast_example.py @@ -1,22 +1,6 @@ import torch import torch.nn as nn import torch_tensorrt -import torchvision - - -class MyModule(torch.nn.Module): - def forward(self, a_float32, b_float32, c_float32, d_float32): - with torch.autocast(device_type="cuda"): - e_float16 = torch.mm(a_float32, b_float32) - with torch.autocast(device_type="cuda", enabled=False): - # Calls e_float16.float() to ensure float32 execution - # (necessary because e_float16 was created in an autocasted region) - f_float32 = torch.mm(c_float32, e_float16.float()) - - # No manual casts are required when re-entering the autocast-enabled region. - # torch.mm again runs in float16 and produces float16 output, regardless of input types. - g_float16 = torch.mm(d_float32, f_float32) - return g_float16 class AutocastExample(nn.Module): @@ -36,44 +20,32 @@ def __init__(self): self.fc1 = nn.Linear(16 * 8 * 8, 10) def forward(self, x, y): - out = self.pool1(self.relu1(self.conv1(x))) # fp16 - x = self.pool2(self.relu2(self.conv2(out))) # fp16 - x = self.flatten(x) + x = self.conv1(x) # fp32 because of "^conv1$" in `autocast_excluded_nodes` + x = self.relu1(x) # fp32 because of "relu" in `autocast_excluded_nodes` + out = self.pool1(x) # fp16 + x = self.conv2(out) # fp16 + x = self.relu2(x) # fp32 because of "relu" in `autocast_excluded_nodes` + x = self.pool2(x) # fp16 + x = self.flatten( + x + ) # fp32 because of `torch.ops.aten.flatten.using_ints` in `autocast_excluded_ops` + # Respect the precisions in the pytorch autocast context with torch.autocast(x.device.type, enabled=True, dtype=torch.float32): - x = self.fc1(x) # fp32 + x = self.fc1(x) with torch.autocast(x.device.type, enabled=False): - x = torch.sub(x.half(), y) # fp16 - out2 = torch.add(x, x) # fp16 + x = torch.sub(x.half(), y) + out2 = torch.add(x, x) with torch.autocast(x.device.type, enabled=True, dtype=torch.float16): - out2 = torch.log(out2) # fp32 + out2 = torch.log(out2) return x, out, out2 -class MyResNet18Wrapper(torch.nn.Module): - def __init__(self, num_classes=1000, pretrained=True): - super(MyResNet18Wrapper, self).__init__() - self.resnet = torchvision.models.resnet18( - num_classes=num_classes, weights="IMAGENET1K_V1" if pretrained else None - ) - - def forward(self, x): - x = self.resnet(x) - return x - - if __name__ == "__main__": - # model = MyModule().cuda().eval() - # inputs = (torch.randn((8, 8), device="cuda"), - # torch.randn((8, 8), device="cuda"), - # torch.randn((8, 8), device="cuda"), - # torch.randn((8, 8), device="cuda"),) - - # model = AutocastExample().cuda().eval() - # inputs = (torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"), - # torch.randn((1,), dtype=torch.float16, device="cuda"),) - - model = MyResNet18Wrapper().cuda().eval() - inputs = (torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda"),) + model = AutocastExample().cuda().eval() + inputs = ( + torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"), + torch.randn((1,), dtype=torch.float16, device="cuda"), + ) ep = torch.export.export(model, inputs) @@ -93,11 +65,11 @@ def forward(self, x): ##### strong typing + autocast ##### use_explicit_typing=True, enable_autocast=True, - low_precision_type=torch.float16, - # nodes_to_exclude={"^conv2d$"}, - targets_to_exclude={}, - data_max=512, - max_depth_of_reduction=None, + autocast_low_precision_type=torch.float16, + autocast_excluded_nodes={"^conv1$", "relu"}, + autocast_excluded_ops={torch.ops.aten.flatten.using_ints}, + autocast_data_max=512, + autocast_max_depth_of_reduction=None, ) trt_out = trt_mod(*inputs) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 511d215335..a78ae0a813 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -435,13 +435,15 @@ def compile( offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, enable_autocast: bool = _defaults.ENABLE_AUTOCAST, - low_precision_type: Optional[ + autocast_low_precision_type: Optional[ Union[torch.dtype, dtype] - ] = _defaults.LOW_PRECISION_TYPE, - nodes_to_exclude: Collection[str] = _defaults.NODES_TO_EXCLUDE, - targets_to_exclude: Collection[Target] = _defaults.TARGETS_TO_EXCLUDE, - data_max: float = _defaults.DATA_MAX, - max_depth_of_reduction: Optional[int] = _defaults.MAX_DEPTH_OF_REDUCTION, + ] = _defaults.AUTOCAST_LOW_PRECISION_TYPE, + autocast_excluded_nodes: Collection[str] = _defaults.AUTOCAST_EXCLUDED_NODES, + autocast_excluded_ops: Collection[Target] = _defaults.AUTOCAST_EXCLUDED_OPS, + autocast_data_max: float = _defaults.AUTOCAST_DATA_MAX, + autocast_max_depth_of_reduction: Optional[ + int + ] = _defaults.AUTOCAST_MAX_DEPTH_OF_REDUCTION, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -520,11 +522,11 @@ def compile( offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True. - low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. - nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. - targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. - data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. - max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. + autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. + autocast_excluded_nodes (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. + autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. + autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -611,17 +613,17 @@ def compile( f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True" ) - if low_precision_type is not None: - if not isinstance(low_precision_type, (torch.dtype, dtype)): + if autocast_low_precision_type is not None: + if not isinstance(autocast_low_precision_type, (torch.dtype, dtype)): raise ValueError( - f"low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(low_precision_type)}" + f"autocast_low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(autocast_low_precision_type)}" ) - if low_precision_type not in { + if autocast_low_precision_type not in { torch.float16, torch.bfloat16, - } and low_precision_type not in {dtype.f16, dtype.bf16}: + } and autocast_low_precision_type not in {dtype.f16, dtype.bf16}: raise ValueError( - f"low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {low_precision_type}" + f"autocast_low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {autocast_low_precision_type}" ) if use_fp32_acc: @@ -654,7 +656,7 @@ def compile( arg_inputs = [arg_inputs] # type: ignore # save intermediate outputs of each node for Autocast - intermediate_node_outputs = {} + autocast_intermediate_node_outputs = {} if not use_explicit_typing: class DumpInterpreter(torch.fx.Interpreter): # type: ignore[misc] @@ -670,7 +672,7 @@ def run_node(self, n: torch.fx.Node) -> Any: raise ValueError( f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}." ) - intermediate_node_outputs[n.name] = out + autocast_intermediate_node_outputs[n.name] = out return out return super().run_node(n) @@ -744,12 +746,12 @@ def _materialize(x: Input | torch.Tensor) -> torch.Tensor: "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, "enable_autocast": enable_autocast, - "low_precision_type": low_precision_type, - "nodes_to_exclude": nodes_to_exclude, - "targets_to_exclude": targets_to_exclude, - "data_max": data_max, - "max_depth_of_reduction": max_depth_of_reduction, - "intermediate_node_outputs": intermediate_node_outputs, + "autocast_low_precision_type": autocast_low_precision_type, + "autocast_excluded_nodes": autocast_excluded_nodes, + "autocast_excluded_ops": autocast_excluded_ops, + "autocast_data_max": autocast_data_max, + "autocast_max_depth_of_reduction": autocast_max_depth_of_reduction, + "autocast_intermediate_node_outputs": autocast_intermediate_node_outputs, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index e69cda70c7..a92fcf9d4e 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -58,11 +58,11 @@ USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False ENABLE_AUTOCAST = False -LOW_PRECISION_TYPE = None -NODES_TO_EXCLUDE = set[str]() -TARGETS_TO_EXCLUDE = set[torch.fx.node.Target]() -DATA_MAX = 512 -MAX_DEPTH_OF_REDUCTION = None +AUTOCAST_LOW_PRECISION_TYPE = None +AUTOCAST_EXCLUDED_NODES = set[str]() +AUTOCAST_EXCLUDED_OPS = set[torch.fx.node.Target]() +AUTOCAST_DATA_MAX = 512 +AUTOCAST_MAX_DEPTH_OF_REDUCTION = None if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index e406bba615..814e75f917 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -7,8 +7,12 @@ from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, + AUTOCAST_DATA_MAX, + AUTOCAST_EXCLUDED_NODES, + AUTOCAST_EXCLUDED_OPS, + AUTOCAST_LOW_PRECISION_TYPE, + AUTOCAST_MAX_DEPTH_OF_REDUCTION, CACHE_BUILT_ENGINES, - DATA_MAX, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -24,11 +28,8 @@ IMMUTABLE_WEIGHTS, L2_LIMIT_FOR_TILING, LAZY_ENGINE_INIT, - LOW_PRECISION_TYPE, MAX_AUX_STREAMS, - MAX_DEPTH_OF_REDUCTION, MIN_BLOCK_SIZE, - NODES_TO_EXCLUDE, NUM_AVG_TIMING_ITERS, OFFLOAD_MODULE_TO_CPU, OPTIMIZATION_LEVEL, @@ -38,7 +39,6 @@ REUSE_CACHED_ENGINES, SPARSE_WEIGHTS, STRIP_ENGINE_WEIGHTS, - TARGETS_TO_EXCLUDE, TILING_OPTIMIZATION_LEVEL, TIMING_CACHE_PATH, TRUNCATE_DOUBLE, @@ -105,12 +105,12 @@ class CompilationSettings: l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True. - low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. - nodes_to_exclude (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. - targets_to_exclude (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. - data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. - max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. - intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}. + autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. + autocast_excluded_nodes (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. + autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. + autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. + autocast_intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}. """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -155,14 +155,16 @@ class CompilationSettings: use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU enable_autocast: bool = ENABLE_AUTOCAST - low_precision_type: Optional[dtype] = LOW_PRECISION_TYPE - nodes_to_exclude: Collection[str] = field(default_factory=lambda: NODES_TO_EXCLUDE) - targets_to_exclude: Collection[Target] = field( - default_factory=lambda: TARGETS_TO_EXCLUDE + autocast_low_precision_type: Optional[dtype] = AUTOCAST_LOW_PRECISION_TYPE + autocast_excluded_nodes: Collection[str] = field( + default_factory=lambda: AUTOCAST_EXCLUDED_NODES ) - data_max: float = DATA_MAX - max_depth_of_reduction: Optional[int] = MAX_DEPTH_OF_REDUCTION - intermediate_node_outputs: dict[str, torch.Tensor] = field( + autocast_excluded_ops: Collection[Target] = field( + default_factory=lambda: AUTOCAST_EXCLUDED_OPS + ) + autocast_data_max: float = AUTOCAST_DATA_MAX + autocast_max_depth_of_reduction: Optional[int] = AUTOCAST_MAX_DEPTH_OF_REDUCTION + autocast_intermediate_node_outputs: dict[str, torch.Tensor] = field( default_factory=lambda: {} ) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py index 72bb376291..b7b7c770c3 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py @@ -53,24 +53,28 @@ def __init__(self, disabled_node_name_regex): self.disabled_node_name_regex = disabled_node_name_regex def _check_inner(self, node): + stack = node.meta.get("nn_module_stack") + node_name = next(reversed(stack), "").split("__")[ + -1 + ] # get the user specified name of the node return any( - re.match(regex, node.name) for regex in self.disabled_node_name_regex + re.match(regex, node_name) for regex in self.disabled_node_name_regex ) -class DisabledTargets(NodeRuleBase): +class DisabledOpTypes(NodeRuleBase): """Rule for keeping nodes with specific operation types in high precision.""" - def __init__(self, targets_to_exclude): + def __init__(self, excluded_ops): """Initialize the rule. Args: - targets_to_exclude: List of operation types to keep in high precision. + excluded_ops: List of operation types to keep in high precision. """ - self.targets_to_exclude = targets_to_exclude + self.excluded_ops = excluded_ops def _check_inner(self, node): - return node.target in self.targets_to_exclude + return node.target in self.excluded_ops class IORangeRule(NodeRuleBase): @@ -219,8 +223,8 @@ class NodeClassifier: def __init__( self, nodes, - nodes_to_exclude: Collection[str] | None = None, - targets_to_exclude: Collection[torch.fx.node.Target] | None = None, + excluded_nodes: Collection[str] | None = None, + excluded_ops: Collection[torch.fx.node.Target] | None = None, custom_rule: NodeRuleBase | None = None, data_max: float | None = 1000.0, max_depth_of_reduction: int | None = None, @@ -236,8 +240,8 @@ def __init__( max_depth_of_reduction: Maximum depth of reduction allowed in low precision. """ self.nodes = nodes - self.nodes_to_exclude = nodes_to_exclude - self.targets_to_exclude = targets_to_exclude + self.excluded_nodes = excluded_nodes + self.excluded_ops = excluded_ops self.custom_rule = custom_rule self.data_max = data_max self.max_depth_of_reduction = max_depth_of_reduction @@ -252,10 +256,10 @@ def _gen_block_node_rules(self, reference_data): list[NodeRuleBase]: List of rules to apply. """ block_node_rules: list[NodeRuleBase] = [] - if self.nodes_to_exclude: - block_node_rules.append(DisabledNodeNameRegexRule(self.nodes_to_exclude)) - if self.targets_to_exclude: - block_node_rules.append(DisabledTargets(self.targets_to_exclude)) + if self.excluded_nodes: + block_node_rules.append(DisabledNodeNameRegexRule(self.excluded_nodes)) + if self.excluded_ops: + block_node_rules.append(DisabledOpTypes(self.excluded_ops)) if reference_data: block_node_rules.append(IORangeRule(self.data_max, reference_data)) if self.max_depth_of_reduction is not None: diff --git a/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py index 6a824a6a90..097e17a944 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py @@ -3,9 +3,6 @@ from typing import Any import torch -from torch._export.passes.replace_autocast_with_hop_pass import ( - replace_autocast_with_hop_pass, -) from torch_tensorrt._enums import dtype from torch_tensorrt.dynamo._settings import CompilationSettings @@ -30,40 +27,27 @@ def rule_based_autocast( logger.debug("Autocast is not enabled, skipping rule-based autocast.") return gm - # nodes = list(gm.graph.nodes) - # # insert enter autocast node in the beginning of the graph - # with gm.graph.inserting_before(nodes[0]): - # enter_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._enter_autocast, args=("cuda", torch.float16, True, True)) - # enter_autocast_node.meta.update(getattr(nodes[0], "meta", {})) - - # # insert exit autocast node before the return node, assuming the return node is the last node - # with gm.graph.inserting_before(nodes[-1]): - # exit_autocast_node = gm.graph.call_function(torch.amp.autocast_mode._exit_autocast, args=(enter_autocast_node,)) - # exit_autocast_node.meta.update(getattr(nodes[-1], "meta", {})) - - # gm = clean_up_graph_after_modifications(gm) - # gm, new_signature = replace_autocast_with_hop_pass(gm, None) - # logger.debug("Graph after replace_autocast_with_hop_pass:\n%s", gm.graph) - # get config from settings - low_precision_type = settings.low_precision_type - if low_precision_type is None: + autocast_low_precision_type = settings.autocast_low_precision_type + if autocast_low_precision_type is None: return gm - if isinstance(low_precision_type, dtype): - low_precision_type = low_precision_type.to(torch.dtype) + if isinstance(autocast_low_precision_type, dtype): + autocast_low_precision_type = autocast_low_precision_type.to(torch.dtype) high_precision_type = torch.float32 - nodes_to_exclude = settings.nodes_to_exclude - targets_to_exclude = settings.targets_to_exclude - data_max = settings.data_max - max_depth_of_reduction = settings.max_depth_of_reduction - reference_data: dict[str, torch.Tensor] = settings.intermediate_node_outputs + autocast_excluded_nodes = settings.autocast_excluded_nodes + autocast_excluded_ops = settings.autocast_excluded_ops + autocast_data_max = settings.autocast_data_max + autocast_max_depth_of_reduction = settings.autocast_max_depth_of_reduction + reference_data: dict[str, torch.Tensor] = ( + settings.autocast_intermediate_node_outputs + ) node_classifier = NodeClassifier( gm.graph.nodes, - nodes_to_exclude=nodes_to_exclude, - targets_to_exclude=targets_to_exclude, - data_max=data_max, - max_depth_of_reduction=max_depth_of_reduction, + excluded_nodes=autocast_excluded_nodes, + excluded_ops=autocast_excluded_ops, + data_max=autocast_data_max, + max_depth_of_reduction=autocast_max_depth_of_reduction, ) low_precision_nodes, high_precision_nodes = node_classifier.run(reference_data) @@ -110,10 +94,10 @@ def _cast_all_tensor_args_to_dtype(arg: Any, dtype: torch.dtype) -> Any: if node.name in low_precision_nodes: node.args = _cast_all_tensor_args_to_dtype( - node.args, low_precision_type + node.args, autocast_low_precision_type ) node.kwargs = _cast_all_tensor_args_to_dtype( - node.kwargs, low_precision_type + node.kwargs, autocast_low_precision_type ) elif node.name in high_precision_nodes: node.args = _cast_all_tensor_args_to_dtype( From 94757d28d7468c792d78412778c22ca79951080f Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 6 Nov 2025 10:47:33 -0800 Subject: [PATCH 5/6] support dataloader for calibration --- py/torch_tensorrt/dynamo/_compiler.py | 38 +++---------------- py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 7 ++-- .../lowering/passes/_aten_lowering_pass.py | 30 +++++++++++++++ 4 files changed, 40 insertions(+), 36 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index a78ae0a813..2ee509f660 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -444,6 +444,9 @@ def compile( autocast_max_depth_of_reduction: Optional[ int ] = _defaults.AUTOCAST_MAX_DEPTH_OF_REDUCTION, + autocast_calibration_dataloader: Optional[ + torch.utils.data.DataLoader + ] = _defaults.AUTOCAST_CALIBRATION_DATALOADER, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -527,6 +530,7 @@ def compile( autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. + autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -655,38 +659,6 @@ def compile( if not isinstance(arg_inputs, collections.abc.Sequence): arg_inputs = [arg_inputs] # type: ignore - # save intermediate outputs of each node for Autocast - autocast_intermediate_node_outputs = {} - if not use_explicit_typing: - - class DumpInterpreter(torch.fx.Interpreter): # type: ignore[misc] - """Dump intermediate outputs of each node""" - - def run_node(self, n: torch.fx.Node) -> Any: - if ( - n.op == "call_function" - and n.target != torch.ops.higher_order.wrap_with_autocast - ): - out = super().run_node(n) - if not isinstance(out, torch.Tensor): - raise ValueError( - f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}." - ) - autocast_intermediate_node_outputs[n.name] = out - return out - return super().run_node(n) - - def _materialize(x: Input | torch.Tensor) -> torch.Tensor: - """Materialize an Input object to a tensor""" - if isinstance(x, Input): - return x.torch_tensor - return x - - with torch.no_grad(): - mat_args = tuple(_materialize(a) for a in arg_inputs) - mat_kwargs = {k: _materialize(v) for k, v in kwarg_inputs.items()} - DumpInterpreter(exported_program.module()).run(*mat_args, **mat_kwargs) - # Prepare torch_trt inputs trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) @@ -751,7 +723,7 @@ def _materialize(x: Input | torch.Tensor) -> torch.Tensor: "autocast_excluded_ops": autocast_excluded_ops, "autocast_data_max": autocast_data_max, "autocast_max_depth_of_reduction": autocast_max_depth_of_reduction, - "autocast_intermediate_node_outputs": autocast_intermediate_node_outputs, + "autocast_calibration_dataloader": autocast_calibration_dataloader, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index a92fcf9d4e..772629f204 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -63,6 +63,7 @@ AUTOCAST_EXCLUDED_OPS = set[torch.fx.node.Target]() AUTOCAST_DATA_MAX = 512 AUTOCAST_MAX_DEPTH_OF_REDUCTION = None +AUTOCAST_CALIBRATION_DATALOADER = None if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 814e75f917..fefbae0f30 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -7,6 +7,7 @@ from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, + AUTOCAST_CALIBRATION_DATALOADER, AUTOCAST_DATA_MAX, AUTOCAST_EXCLUDED_NODES, AUTOCAST_EXCLUDED_OPS, @@ -110,7 +111,7 @@ class CompilationSettings: autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. - autocast_intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}. + autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None. """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -164,8 +165,8 @@ class CompilationSettings: ) autocast_data_max: float = AUTOCAST_DATA_MAX autocast_max_depth_of_reduction: Optional[int] = AUTOCAST_MAX_DEPTH_OF_REDUCTION - autocast_intermediate_node_outputs: dict[str, torch.Tensor] = field( - default_factory=lambda: {} + autocast_calibration_dataloader: Optional[torch.utils.data.DataLoader] = ( + AUTOCAST_CALIBRATION_DATALOADER ) def __getstate__(self) -> dict[str, Any]: diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 1499e670bd..1c8134f2c0 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -138,6 +138,36 @@ def pre_export_lowering( logging.debug( f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}" ) + + # Only for rule-based autocast to collect the intermediate node outputs + if settings.enable_autocast: + autocast_intermediate_node_outputs: dict[str, torch.Tensor] = {} + + class IntermediateNodeTracer(torch.fx.Interpreter): # type: ignore[misc] + def run_node(self, n: torch.fx.Node) -> Any: + out = super().run_node(n) + if ( + n.op == "call_function" + and n.target != torch.ops.higher_order.wrap_with_autocast + ): + if not isinstance(out, torch.Tensor): + raise ValueError( + f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}." + ) + if n.name in autocast_intermediate_node_outputs: + autocast_intermediate_node_outputs[n.name] = torch.cat( + [autocast_intermediate_node_outputs[n.name], out], dim=0 + ) + else: + autocast_intermediate_node_outputs[n.name] = out + return out + + if settings.autocast_calibration_dataloader is not None: + tracer = IntermediateNodeTracer(ep.module()) + for batch in settings.autocast_calibration_dataloader: + tracer.run(tuple(batch)) + settings.autocast_intermediate_node_outputs = autocast_intermediate_node_outputs + gm = ep.graph_module gm = ATEN_PRE_LOWERING_PASSES(gm, settings) return ep From 4bf12e7a033896df2a3dbd482679695450431f70 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 6 Nov 2025 18:22:41 -0800 Subject: [PATCH 6/6] fix comments --- examples/dynamo/autocast_example.py | 16 ++-- py/torch_tensorrt/dynamo/_compiler.py | 10 +-- py/torch_tensorrt/dynamo/_defaults.py | 2 +- py/torch_tensorrt/dynamo/_settings.py | 10 +-- .../lowering/passes/_aten_lowering_pass.py | 36 ++------ .../dynamo/lowering/passes/nodeclassifier.py | 35 ++++---- .../dynamo/lowering/passes/pass_utils.py | 43 +++++++++- .../lowering/passes/rule_based_autocast.py | 83 ++++++++++--------- 8 files changed, 136 insertions(+), 99 deletions(-) diff --git a/examples/dynamo/autocast_example.py b/examples/dynamo/autocast_example.py index f1487cfb72..24a31fc90f 100644 --- a/examples/dynamo/autocast_example.py +++ b/examples/dynamo/autocast_example.py @@ -31,12 +31,14 @@ def forward(self, x, y): ) # fp32 because of `torch.ops.aten.flatten.using_ints` in `autocast_excluded_ops` # Respect the precisions in the pytorch autocast context with torch.autocast(x.device.type, enabled=True, dtype=torch.float32): - x = self.fc1(x) + x = self.fc1(x) # fp32 with torch.autocast(x.device.type, enabled=False): - x = torch.sub(x.half(), y) - out2 = torch.add(x, x) + x = torch.sub(x.half(), y) # fp16 + out2 = torch.add(x, x) # fp16 with torch.autocast(x.device.type, enabled=True, dtype=torch.float16): - out2 = torch.log(out2) + out2 = torch.log( + out2 + ) # fp32 because Pytorch Autocast requires `log` to be in fp32 return x, out, out2 @@ -46,6 +48,9 @@ def forward(self, x, y): torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"), torch.randn((1,), dtype=torch.float16, device="cuda"), ) + calibration_dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(*inputs), batch_size=1, shuffle=False + ) ep = torch.export.export(model, inputs) @@ -68,8 +73,9 @@ def forward(self, x, y): autocast_low_precision_type=torch.float16, autocast_excluded_nodes={"^conv1$", "relu"}, autocast_excluded_ops={torch.ops.aten.flatten.using_ints}, - autocast_data_max=512, + autocast_max_output_threshold=512, autocast_max_depth_of_reduction=None, + autocast_calibration_dataloader=calibration_dataloader, ) trt_out = trt_mod(*inputs) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 2ee509f660..0ff86ad235 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -440,7 +440,7 @@ def compile( ] = _defaults.AUTOCAST_LOW_PRECISION_TYPE, autocast_excluded_nodes: Collection[str] = _defaults.AUTOCAST_EXCLUDED_NODES, autocast_excluded_ops: Collection[Target] = _defaults.AUTOCAST_EXCLUDED_OPS, - autocast_data_max: float = _defaults.AUTOCAST_DATA_MAX, + autocast_max_output_threshold: float = _defaults.AUTOCAST_MAX_OUTPUT_THRESHOLD, autocast_max_depth_of_reduction: Optional[ int ] = _defaults.AUTOCAST_MAX_DEPTH_OF_REDUCTION, @@ -526,10 +526,10 @@ def compile( use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True. autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. - autocast_excluded_nodes (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. + autocast_excluded_nodes (Collection[str]): The set of regex patterns to match user-specified node names that should remain in FP32. Default is []. autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. - autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. - autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. + autocast_max_output_threshold (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low precision formats. If not provided, infinity will be used. Default is None. autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None. **kwargs: Any, Returns: @@ -721,7 +721,7 @@ def compile( "autocast_low_precision_type": autocast_low_precision_type, "autocast_excluded_nodes": autocast_excluded_nodes, "autocast_excluded_ops": autocast_excluded_ops, - "autocast_data_max": autocast_data_max, + "autocast_max_output_threshold": autocast_max_output_threshold, "autocast_max_depth_of_reduction": autocast_max_depth_of_reduction, "autocast_calibration_dataloader": autocast_calibration_dataloader, } diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 772629f204..3a238c11ee 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -61,7 +61,7 @@ AUTOCAST_LOW_PRECISION_TYPE = None AUTOCAST_EXCLUDED_NODES = set[str]() AUTOCAST_EXCLUDED_OPS = set[torch.fx.node.Target]() -AUTOCAST_DATA_MAX = 512 +AUTOCAST_MAX_OUTPUT_THRESHOLD = 512 AUTOCAST_MAX_DEPTH_OF_REDUCTION = None AUTOCAST_CALIBRATION_DATALOADER = None diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index fefbae0f30..d62c75e0da 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -8,11 +8,11 @@ from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, AUTOCAST_CALIBRATION_DATALOADER, - AUTOCAST_DATA_MAX, AUTOCAST_EXCLUDED_NODES, AUTOCAST_EXCLUDED_OPS, AUTOCAST_LOW_PRECISION_TYPE, AUTOCAST_MAX_DEPTH_OF_REDUCTION, + AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, @@ -107,10 +107,10 @@ class CompilationSettings: use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True. autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. - autocast_excluded_nodes (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. + autocast_excluded_nodes (Collection[str]): The set of regex patterns to match user-specified node names that should remain in FP32. Default is []. autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. - autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. - autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. + autocast_max_output_threshold (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low precision formats. If not provided, infinity will be used. Default is None. autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None. """ @@ -163,7 +163,7 @@ class CompilationSettings: autocast_excluded_ops: Collection[Target] = field( default_factory=lambda: AUTOCAST_EXCLUDED_OPS ) - autocast_data_max: float = AUTOCAST_DATA_MAX + autocast_max_output_threshold: float = AUTOCAST_MAX_OUTPUT_THRESHOLD autocast_max_depth_of_reduction: Optional[int] = AUTOCAST_MAX_DEPTH_OF_REDUCTION autocast_calibration_dataloader: Optional[torch.utils.data.DataLoader] = ( AUTOCAST_CALIBRATION_DATALOADER diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 1c8134f2c0..0d8bdadb9d 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -1,9 +1,13 @@ import logging +import operator from typing import Any, Callable, Optional, Sequence, Union import torch from torch_tensorrt._utils import is_tegra_platform from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + trace_intermediate_node_outputs, +) from .complex_graph_rewrite import complex_graph_detection from .constant_folding import constant_fold @@ -141,33 +145,11 @@ def pre_export_lowering( # Only for rule-based autocast to collect the intermediate node outputs if settings.enable_autocast: - autocast_intermediate_node_outputs: dict[str, torch.Tensor] = {} - - class IntermediateNodeTracer(torch.fx.Interpreter): # type: ignore[misc] - def run_node(self, n: torch.fx.Node) -> Any: - out = super().run_node(n) - if ( - n.op == "call_function" - and n.target != torch.ops.higher_order.wrap_with_autocast - ): - if not isinstance(out, torch.Tensor): - raise ValueError( - f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}." - ) - if n.name in autocast_intermediate_node_outputs: - autocast_intermediate_node_outputs[n.name] = torch.cat( - [autocast_intermediate_node_outputs[n.name], out], dim=0 - ) - else: - autocast_intermediate_node_outputs[n.name] = out - return out - - if settings.autocast_calibration_dataloader is not None: - tracer = IntermediateNodeTracer(ep.module()) - for batch in settings.autocast_calibration_dataloader: - tracer.run(tuple(batch)) - settings.autocast_intermediate_node_outputs = autocast_intermediate_node_outputs - + settings.autocast_intermediate_node_outputs = trace_intermediate_node_outputs( + ep.module(), + settings.autocast_calibration_dataloader, + [torch.ops.higher_order.wrap_with_autocast, operator.getitem], + ) gm = ep.graph_module gm = ATEN_PRE_LOWERING_PASSES(gm, settings) return ep diff --git a/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py index b7b7c770c3..3c221701b2 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py @@ -29,7 +29,7 @@ def check(self, node): """Check if a node should be skipped based on the rule. Args: - node: The ONNX node to check. + node: The torch.fx.Node to check. Returns: bool: True if the node should be kept in high precision, False otherwise. @@ -42,13 +42,13 @@ def check(self, node): class DisabledNodeNameRegexRule(NodeRuleBase): - """Rule for keeping nodes with matching names in high precision.""" + """Rule for keeping nodes with matching user-specified names in high precision.""" def __init__(self, disabled_node_name_regex): """Initialize the rule. Args: - disabled_node_name_regex: List of regex patterns for node names to keep in high precision. + disabled_node_name_regex: List of regex patterns for user-specified node names to keep in high precision. """ self.disabled_node_name_regex = disabled_node_name_regex @@ -63,13 +63,13 @@ def _check_inner(self, node): class DisabledOpTypes(NodeRuleBase): - """Rule for keeping nodes with specific operation types in high precision.""" + """Rule for keeping nodes with specific ATen ops in high precision.""" def __init__(self, excluded_ops): """Initialize the rule. Args: - excluded_ops: List of operation types to keep in high precision. + excluded_ops: List of ATen ops that should remain in FP32. """ self.excluded_ops = excluded_ops @@ -80,14 +80,14 @@ def _check_inner(self, node): class IORangeRule(NodeRuleBase): """Rule for keeping nodes with out-of-range inputs/outputs in high precision.""" - def __init__(self, data_max, reference_data): + def __init__(self, max_output_threshold, reference_data): """Initialize the rule. Args: - data_max: Maximum absolute value allowed for node I/O. + max_output_threshold: Maximum absolute value allowed for node I/O. reference_data: Reference data for checking I/O ranges. """ - self.data_max = data_max + self.max_output_threshold = max_output_threshold self.reference_data = reference_data self.output_data = None @@ -108,7 +108,7 @@ def is_io_out_of_range(node): logger.debug( f"Node {node.name}: reference data: min={ref_data.min()}, max={ref_data.max()}" ) - if torch.any(torch.abs(ref_data) > self.data_max): + if torch.any(torch.abs(ref_data) > self.max_output_threshold): self.output_data = ref_data return True @@ -126,14 +126,17 @@ def _log_skipped(self, node, **kwargs): if self.output_data is not None: logger.info( f"Skipping node {node.name}: reference IO out of range: min={torch.min(self.output_data)}, " - f"max={torch.max(self.output_data)}, range=[{-self.data_max}, {self.data_max}]" + f"max={torch.max(self.output_data)}, range=[{-self.max_output_threshold}, {self.max_output_threshold}]" ) else: super()._log_skipped(node, **kwargs) class DepthOfReductionRule(NodeRuleBase): - """Rule for keeping nodes with high depth of reduction in high precision.""" + """ + Rule for keeping nodes with high depth of reduction in high precision. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low precision formats. + Reduction ops are those that aggregate data across one or more axes, decreasing the dimensionality of the input tensor, such as convolution, gemm, etc. + """ def __init__(self, max_depth_of_reduction, reference_data): """Initialize the rule. @@ -226,7 +229,7 @@ def __init__( excluded_nodes: Collection[str] | None = None, excluded_ops: Collection[torch.fx.node.Target] | None = None, custom_rule: NodeRuleBase | None = None, - data_max: float | None = 1000.0, + max_output_threshold: float | None = 512, max_depth_of_reduction: int | None = None, ): """Initialize the node classifier. @@ -236,14 +239,14 @@ def __init__( nodes_to_exclude: Collection of regex patterns for node names to keep in high precision. targets_to_exclude: Collection of targets to keep in high precision. custom_rule: Optional custom classification rule. - data_max: Maximum absolute value allowed for node I/O. + max_output_threshold: Maximum absolute value allowed for node I/O. max_depth_of_reduction: Maximum depth of reduction allowed in low precision. """ self.nodes = nodes self.excluded_nodes = excluded_nodes self.excluded_ops = excluded_ops self.custom_rule = custom_rule - self.data_max = data_max + self.max_output_threshold = max_output_threshold self.max_depth_of_reduction = max_depth_of_reduction def _gen_block_node_rules(self, reference_data): @@ -261,7 +264,9 @@ def _gen_block_node_rules(self, reference_data): if self.excluded_ops: block_node_rules.append(DisabledOpTypes(self.excluded_ops)) if reference_data: - block_node_rules.append(IORangeRule(self.data_max, reference_data)) + block_node_rules.append( + IORangeRule(self.max_output_threshold, reference_data) + ) if self.max_depth_of_reduction is not None: block_node_rules.append( DepthOfReductionRule( diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py index 1736a234a2..e92c90e578 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, Dict, List, Sequence import torch @@ -68,3 +68,44 @@ def is_node_complex(node: torch.fx.Node, complexNodes): complexNodes[node.name] = True return True return False + + +def trace_intermediate_node_outputs( + gm: torch.fx.GraphModule, + calibration_dataloader: torch.utils.data.DataLoader, + excluded_ops: Sequence[torch.fx.node.Target] = [], +) -> Dict[str, torch.Tensor]: + """Trace the intermediate node outputs of a graph module. + + Args: + gm (torch.fx.GraphModule): The graph module to trace the intermediate node outputs of. + calibration_dataloader (torch.utils.data.DataLoader): The dataloader to use for tracing. + excluded_ops (Set[torch.fx.node.Target]): The set of ATen ops that should be excluded from the trace. For example, `{torch.ops.higher_order.wrap_with_autocast, operator.getitem}`. Default is an empty set. + + Returns: + Dict[str, torch.Tensor]: A dictionary of intermediate node outputs. The key is the node name and the value is the tensor. + """ + + intermediate_node_outputs: Dict[str, torch.Tensor] = {} + + class IntermediateNodeTracer(torch.fx.Interpreter): # type: ignore[misc] + def run_node(self, n: torch.fx.Node) -> Any: + out = super().run_node(n) + if n.op == "call_function" and n.target not in excluded_ops: + if not isinstance(out, torch.Tensor): + raise ValueError( + f"Please file a bug with Torch-TensorRT because it expects a torch.Tensor but got {type(out)} for node {n.name}." + ) + if n.name in intermediate_node_outputs: + intermediate_node_outputs[n.name] = torch.cat( + [intermediate_node_outputs[n.name], out], dim=0 + ) + else: + intermediate_node_outputs[n.name] = out + return out + + if calibration_dataloader is not None: + tracer = IntermediateNodeTracer(gm) + for batch in calibration_dataloader: + tracer.run(tuple(batch)) + return intermediate_node_outputs diff --git a/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py index 097e17a944..6a771e4706 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py @@ -33,10 +33,10 @@ def rule_based_autocast( return gm if isinstance(autocast_low_precision_type, dtype): autocast_low_precision_type = autocast_low_precision_type.to(torch.dtype) - high_precision_type = torch.float32 + autocast_high_precision_type = torch.float32 autocast_excluded_nodes = settings.autocast_excluded_nodes autocast_excluded_ops = settings.autocast_excluded_ops - autocast_data_max = settings.autocast_data_max + autocast_max_output_threshold = settings.autocast_max_output_threshold autocast_max_depth_of_reduction = settings.autocast_max_depth_of_reduction reference_data: dict[str, torch.Tensor] = ( settings.autocast_intermediate_node_outputs @@ -46,11 +46,47 @@ def rule_based_autocast( gm.graph.nodes, excluded_nodes=autocast_excluded_nodes, excluded_ops=autocast_excluded_ops, - data_max=autocast_data_max, + max_output_threshold=autocast_max_output_threshold, max_depth_of_reduction=autocast_max_depth_of_reduction, ) low_precision_nodes, high_precision_nodes = node_classifier.run(reference_data) + def _cast_all_tensor_args_to_dtype( + node: torch.fx.Node, arg: Any, dtype: torch.dtype + ) -> Any: + """Cast all tensor args to the given dtype + + Args: + node: The node to insert the cast before + arg: The argument to cast + dtype: The dtype to cast to + + Returns: + The casted argument + """ + if isinstance(arg, torch.fx.Node) and is_tensor_node(arg): + val = arg.meta.get("val", None) + with gm.graph.inserting_before(node): + cast = gm.graph.call_function( + torch.ops.aten.to.dtype, args=(arg, dtype) + ) + + if isinstance(val, torch.Tensor): + arg.meta["val"] = val.to(dtype) + cast.meta.update(arg.meta) + return cast + elif isinstance(arg, (tuple, list)): + return type(arg)( + _cast_all_tensor_args_to_dtype(node, a, dtype) for a in arg + ) + elif isinstance(arg, dict): + return { + k: _cast_all_tensor_args_to_dtype(node, v, dtype) + for k, v in arg.items() + } + else: + return arg + for node in list(gm.graph.nodes): if node.op == "call_function": if ( @@ -59,52 +95,19 @@ def rule_based_autocast( ): continue - def _cast_all_tensor_args_to_dtype(arg: Any, dtype: torch.dtype) -> Any: - """Cast all tensor args to the given dtype - - Args: - arg: The argument to cast - dtype: The dtype to cast to - - Returns: - The casted argument - """ - if isinstance(arg, torch.fx.Node) and is_tensor_node(arg): - val = arg.meta.get("val", None) - with gm.graph.inserting_before(node): - cast = gm.graph.call_function( - torch.ops.aten.to.dtype, args=(arg, dtype) - ) - - if isinstance(val, torch.Tensor): - arg.meta["val"] = val.to(dtype) - cast.meta.update(arg.meta) - return cast - elif isinstance(arg, (tuple, list)): - return type(arg)( - _cast_all_tensor_args_to_dtype(a, dtype) for a in arg - ) - elif isinstance(arg, dict): - return { - k: _cast_all_tensor_args_to_dtype(v, dtype) - for k, v in arg.items() - } - else: - return arg - if node.name in low_precision_nodes: node.args = _cast_all_tensor_args_to_dtype( - node.args, autocast_low_precision_type + node, node.args, autocast_low_precision_type ) node.kwargs = _cast_all_tensor_args_to_dtype( - node.kwargs, autocast_low_precision_type + node, node.kwargs, autocast_low_precision_type ) elif node.name in high_precision_nodes: node.args = _cast_all_tensor_args_to_dtype( - node.args, high_precision_type + node, node.args, autocast_high_precision_type ) node.kwargs = _cast_all_tensor_args_to_dtype( - node.kwargs, high_precision_type + node, node.kwargs, autocast_high_precision_type ) gm = clean_up_graph_after_modifications(gm)