diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 6312532f1c..1a635b5af5 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -59,7 +59,11 @@ from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
-from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
+from torch_tensorrt.dynamo.lowering import (
+    get_decompositions,
+    post_lowering,
+    pre_export_lowering,
+)
 from torch_tensorrt.dynamo.utils import (
     get_torch_inputs,
     parse_complex_tensor_structs,
@@ -181,12 +185,15 @@ def compile(
     # Prepare torch_trt inputs
     inputs = prepare_inputs(inputs)
+    torch_inputs = get_torch_inputs(inputs, device)
     device = to_torch_tensorrt_device(device)
 
     if not isinstance(exported_program, ExportedProgram):
         raise AssertionError(
             f"Input graph should be an ExportedProgram but got type {type(exported_program)}"
         )
+
+    exported_program = pre_export_lowering(exported_program, torch_inputs)
     exported_program = exported_program.run_decompositions(
         get_decompositions(enable_experimental_decompositions)
     )
@@ -194,10 +201,8 @@ def compile(
     logger.debug("Input graph: " + str(gm.graph))
 
     # Apply lowering on the graph module
-    torch_inputs = get_torch_inputs(inputs, device)
-    gm = apply_lowering_passes(gm, torch_inputs)
+    gm = post_lowering(gm, torch_inputs)
     logger.debug("Lowered Input graph: " + str(gm.graph))
-
     enabled_precisions = set(enabled_precisions)
 
     if (
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 1fa2806181..961e1c9344 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -11,8 +11,8 @@ from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._compiler import compile_module
 from torch_tensorrt.dynamo.lowering import (
-    apply_lowering_passes,
     get_decompositions,
+    post_lowering,
     repair_input_aliasing,
 )
 from torch_tensorrt.dynamo.utils import (
@@ -87,7 +87,7 @@ def _pretraced_backend(
     logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-    gm = apply_lowering_passes(gm, sample_inputs)
+    gm = post_lowering(gm, sample_inputs)
 
     torchtrt_inputs = prepare_inputs(
         sample_inputs, disable_memory_format_check=True
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 478cf98dea..186047393c 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -573,6 +573,24 @@ def aten_ops_neg(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.ptq.scaled_e4m3.default)
+def aten_ops_quantize_fp8(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.quantize.quantize_fp8(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)
 def aten_ops_squeeze(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index ca71cb0b0c..a18155d6be 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -18,6 +18,7 @@
     pad,
     permutation,
     pool,
+    quantize,
     reduce,
     select,
     shape,
diff --git
a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py
new file mode 100644
index 0000000000..f5c3a4cd98
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py
@@ -0,0 +1,43 @@
+from typing import Optional
+
+import numpy as np
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def quantize_fp8(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_tensor: TRTTensor,
+    scale: np.ndarray,
+) -> TRTTensor:
+    """
+    Adds a quantize/dequantize (QDQ) op pair which quantizes the input to INT8
+    or FP8, depending on the output type set on the layer, and dequantizes it back.
+    """
+    if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in (
+        trt.float32, trt.float16
+    ):
+        raise ValueError(
+            f"quantize_fp8 converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
+        )
+
+    if isinstance(scale, np.ndarray):
+        scale = get_trt_tensor(ctx, scale, name + "_scale")
+    # Add Q node
+    quantize_layer = ctx.net.add_quantize(input_tensor, scale)
+    set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
+    q_output = quantize_layer.get_output(0)
+    # Add DQ node
+    dequantize_layer = ctx.net.add_dequantize(q_output, scale)
+    set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
+    dq_output = dequantize_layer.get_output(0)
+
+    return dq_output
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
index 7c4e9fdd2d..67a73354ad 100644
--- a/py/torch_tensorrt/dynamo/lowering/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -5,4 +5,4 @@
 from ._decompositions import get_decompositions  # noqa: F401
 from ._fusers import *  # noqa: F401
 from ._repair_input_aliasing import repair_input_aliasing
-from .passes import apply_lowering_passes
+from .passes import post_lowering, pre_export_lowering
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 489805cb43..05a171e264 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -8,12 +8,13 @@
 from .lower_linear import lower_linear
 from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
+from .remove_detach import remove_detach
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
 from .view_to_reshape import view_to_reshape
 
-ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
+ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
         remove_input_alias_fixing_clones,
         constant_fold,
@@ -26,6 +27,12 @@
     ]
 )
 
+ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
+    [
+        remove_detach,
+    ]
+)
+
 logger = logging.getLogger(__name__)
 
 
@@ -48,9 +55,9 @@ def _aten_lowering_pass(
     def add_lowering_pass(
         lowering_pass: LoweringPassSignature,
     ) -> LoweringPassSignature:
-        ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
+        ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
         logger.debug(
-            f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
+            f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
         )
         return lowering_pass
 
@@ -72,23 +79,37 @@ def add_lowering_pass(
 def _remove_lowering_pass(*, index: int) -> None:
     """Removes a lowering pass at a specific index from the registry"""
-    ATEN_LOWERING_PASSES.remove_pass_with_index(index)
+    ATEN_POST_LOWERING_PASSES.remove_pass_with_index(index)
     logger.debug(
-        f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
+        f"Removed lowering pass at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
     )
     return
 
 
-def apply_lowering_passes(
+def post_lowering(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
-    """Applies the lowering passes to a graph module, returns the modified GraphModule"""
+    """Applies lowering passes to a graph module after torch.export / torch.compile and their decompositions, returns the modified GraphModule"""
+    logging.debug(
+        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}"
+    )
+    return ATEN_POST_LOWERING_PASSES(gm, sample_inputs)
+
+
+def pre_export_lowering(
+    ep: torch.export.ExportedProgram, sample_inputs: Sequence[torch.Tensor]
+) -> torch.export.ExportedProgram:
+    """Applies pre-export lowering passes to the ExportedProgram's graph module before decompositions run, then re-exports and returns the transformed ExportedProgram"""
     logging.debug(
-        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
+        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}"
    )
-    return ATEN_LOWERING_PASSES(gm, sample_inputs)
+    gm = ep.module()
+    gm = ATEN_PRE_LOWERING_PASSES(gm, sample_inputs)
+    # TODO: Check if re-exporting changes the metadata
+    transformed_ep = torch.export.export(gm, tuple(sample_inputs), strict=False)
+    return transformed_ep
 
 
 def dump_lowering_passes() -> str:
     """Returns a string containing the lowering passes"""
-    return str(ATEN_LOWERING_PASSES)
+    return str(ATEN_POST_LOWERING_PASSES)
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_detach.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_detach.py
new file mode 100644
index 0000000000..5c47a8e288
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_detach.py
@@ -0,0 +1,31 @@
+import logging
+from typing import Sequence
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def remove_detach(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Remove detach ops in the graph"""
+
+    count = 0
+    for node in gm.graph.nodes:
+        # If the node is a detach node
+        if node.target == torch.ops.aten.detach.default:
+            # Detach node has only one input
+            node_input = node.all_input_nodes[0]
+            node.replace_all_uses_with(node_input)
+            gm.graph.erase_node(node)
+            count += 1
+
+    if count > 0:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Removed {count} detach nodes:\n{gm.graph}")
+
+    return gm
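
Reviewer note: the sketch below (not part of the patch) shows how the two lowering stages introduced here are meant to sequence, mirroring the _compiler.py hunk above. TinyModel and the input shape are illustrative placeholders.

import torch
from torch_tensorrt.dynamo.lowering import (
    get_decompositions,
    post_lowering,
    pre_export_lowering,
)


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The detach call exercises the new pre-export remove_detach pass.
        return torch.relu(x).detach()


torch_inputs = [torch.randn(1, 8)]
ep = torch.export.export(TinyModel(), tuple(torch_inputs))

# Stage 1: passes that must see the un-decomposed graph; re-exports internally.
ep = pre_export_lowering(ep, torch_inputs)

# Decompose to the core aten opset, as compile() does.
ep = ep.run_decompositions(get_decompositions(False))
gm = ep.module()

# Stage 2: passes that operate on the decomposed GraphModule
# (constant folding, view -> reshape, etc.).
gm = post_lowering(gm, torch_inputs)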
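
Reviewer note: a hedged sketch of registering a custom pass against the renamed registry via the _aten_lowering_pass decorator. The deep import path is an assumption, since the decorator is underscore-prefixed and may not be re-exported from the passes package.

from typing import Sequence

import torch
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
)


@_aten_lowering_pass  # with no index given, the pass lands at the end of ATEN_POST_LOWERING_PASSES
def log_decomposed_graph(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    # Illustrative no-op pass: inspect the graph after decompositions run.
    print(gm.graph)
    return gm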
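
Reviewer note: a quick demonstration of the new remove_detach pass on a symbolically traced module; WithDetach is an illustrative placeholder, not from the patch.

import torch
from torch_tensorrt.dynamo.lowering.passes.remove_detach import remove_detach


class WithDetach(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.aten.detach.default(torch.relu(x))


gm = torch.fx.symbolic_trace(WithDetach())
gm = remove_detach(gm, [torch.randn(1, 8)])
# The aten.detach.default node is gone and relu feeds the output directly,
# since detach is a no-op for an inference-only TensorRT engine.
print(gm.graph)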
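
Reviewer note: a standalone TensorRT sketch of the Q/DQ pair that quantize_fp8 emits, assuming the same add_quantize/add_dequantize API the patch uses. As far as I can tell, an IQuantizeLayer defaults to INT8 output unless the layer's output type is explicitly set (e.g. to FP8 on TensorRT >= 9), which is what the converter's docstring alludes to.

import numpy as np
import tensorrt as trt

trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

x = network.add_input("x", trt.float32, (1, 16))
scale = network.add_constant((1,), np.array([0.1], dtype=np.float32)).get_output(0)

q_layer = network.add_quantize(x, scale)  # Q node
dq_layer = network.add_dequantize(q_layer.get_output(0), scale)  # DQ node
network.mark_output(dq_layer.get_output(0))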