13 changes: 9 additions & 4 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -59,7 +59,11 @@
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
post_lowering,
pre_export_lowering,
)
from torch_tensorrt.dynamo.utils import (
get_torch_inputs,
parse_complex_tensor_structs,
@@ -181,23 +185,24 @@ def compile(

# Prepare torch_trt inputs
inputs = prepare_inputs(inputs)
torch_inputs = get_torch_inputs(inputs, device)
device = to_torch_tensorrt_device(device)

if not isinstance(exported_program, ExportedProgram):
raise AssertionError(
f"Input graph should be an ExportedProgram but got type {type(exported_program)}"
)

exported_program = pre_export_lowering(exported_program, torch_inputs)
exported_program = exported_program.run_decompositions(
get_decompositions(enable_experimental_decompositions)
)
gm = exported_program.module()
logger.debug("Input graph: " + str(gm.graph))

# Apply lowering on the graph module
torch_inputs = get_torch_inputs(inputs, device)
gm = apply_lowering_passes(gm, torch_inputs)
gm = post_lowering(gm, torch_inputs)
logger.debug("Lowered Input graph: " + str(gm.graph))

enabled_precisions = set(enabled_precisions)

if (
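As a usage sketch of the reworked flow (the model, shapes, and settings below are illustrative, not part of this PR): compile() now expects an ExportedProgram, runs pre_export_lowering on it ahead of decompositions, and post_lowering on the decomposed GraphModule.

import torch
import torch_tensorrt

# Illustrative module and inputs; any exportable model follows the same path.
model = torch.nn.Linear(8, 8).eval().cuda()
inputs = (torch.randn(1, 8).cuda(),)

# compile() asserts it received an ExportedProgram, applies
# pre_export_lowering before decompositions, then post_lowering on the
# decomposed GraphModule, as shown in the diff above.
exported = torch.export.export(model, inputs)
trt_module = torch_tensorrt.dynamo.compile(exported, inputs=inputs)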
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -11,8 +11,8 @@
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo._compiler import compile_module
from torch_tensorrt.dynamo.lowering import (
apply_lowering_passes,
get_decompositions,
post_lowering,
repair_input_aliasing,
)
from torch_tensorrt.dynamo.utils import (
@@ -87,7 +87,7 @@ def _pretraced_backend(

logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

gm = apply_lowering_passes(gm, sample_inputs)
gm = post_lowering(gm, sample_inputs)

torchtrt_inputs = prepare_inputs(
sample_inputs, disable_memory_format_check=True
18 changes: 18 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -573,6 +573,24 @@ def aten_ops_neg(
)


@dynamo_tensorrt_converter(torch.ops.ptq.scaled_e4m3.default)
def aten_ops_quantize_fp8(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.quantize.quantize_fp8(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)
def aten_ops_squeeze(
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -18,6 +18,7 @@
pad,
permutation,
pool,
quantize,
reduce,
select,
shape,
43 changes: 43 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/quantize.py
@@ -0,0 +1,43 @@
from typing import Optional

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


def quantize_fp8(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_tensor: TRTTensor,
scale: np.ndarray,
) -> TRTTensor:
"""
Adds quantize and dequantize ops (QDQ) which quantize to INT8 or FP8 based
on the output_type set and dequantizes them back.
"""
if (isinstance(input_tensor, TRTTensor)) and not (
input_tensor.dtype != trt.float32 or input_tensor.dtype != trt.float16
):
raise ValueError(
f"quantize_fp8 converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
)

if isinstance(scale, np.ndarray):
scale = get_trt_tensor(ctx, scale, name + "_scale")
# Add Q node
quantize_layer = ctx.net.add_quantize(input_tensor, scale)
set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
q_output = quantize_layer.get_output(0)
# Add DQ node
dequantize_layer = ctx.net.add_dequantize(q_output, scale)
set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
dq_output = dequantize_layer.get_output(0)

return dq_output
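For reference, a minimal standalone sketch of the Q/DQ pattern this converter emits, built directly on the TensorRT network API. The network shape and scale value are assumptions for illustration, and TensorRT's default quantized type applies since no output type is set here (the converter above targets FP8 via the scaled_e4m3 op).

import numpy as np
import tensorrt as trt

# Minimal sketch of the quantize/dequantize pair quantize_fp8 inserts,
# using the same add_quantize/add_dequantize calls (shapes are illustrative).
trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

x = network.add_input("x", trt.float32, (1, 3, 8, 8))
scale = network.add_constant((1,), np.array([0.1], dtype=np.float32)).get_output(0)

q = network.add_quantize(x, scale)                   # Q node
dq = network.add_dequantize(q.get_output(0), scale)  # DQ node
network.mark_output(dq.get_output(0))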
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -5,4 +5,4 @@
from ._decompositions import get_decompositions # noqa: F401
from ._fusers import * # noqa: F401
from ._repair_input_aliasing import repair_input_aliasing
from .passes import apply_lowering_passes
from .passes import post_lowering, pre_export_lowering
41 changes: 31 additions & 10 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -8,12 +8,13 @@
from .lower_linear import lower_linear
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
from .remove_detach import remove_detach
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .view_to_reshape import view_to_reshape

ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
remove_input_alias_fixing_clones,
constant_fold,
@@ -26,6 +27,12 @@
]
)

ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
remove_detach,
]
)

logger = logging.getLogger(__name__)


@@ -48,9 +55,9 @@ def _aten_lowering_pass(
def add_lowering_pass(
lowering_pass: LoweringPassSignature,
) -> LoweringPassSignature:
ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
logger.debug(
f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
)
return lowering_pass

@@ -72,23 +79,37 @@ def add_lowering_pass(

def _remove_lowering_pass(*, index: int) -> None:
"""Removes a lowering pass at a specific index from the registry"""
ATEN_LOWERING_PASSES.remove_pass_with_index(index)
ATEN_POST_LOWERING_PASSES.remove_pass_with_index(index)
logger.debug(
f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
f"Removed lowering pass at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
)
return


def apply_lowering_passes(
def post_lowering(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Applies the lowering passes to a graph module, returns the modified GraphModule"""
"""Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule"""
logging.debug(
f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}"
)
return ATEN_POST_LOWERING_PASSES(gm, sample_inputs)


def pre_export_lowering(
    ep: torch.export.ExportedProgram, sample_inputs: Sequence[torch.Tensor]
) -> torch.export.ExportedProgram:
    """Applies pre-export lowering passes to an ExportedProgram before decompositions run, then re-exports and returns the transformed ExportedProgram"""
logging.debug(
f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}"
)
return ATEN_LOWERING_PASSES(gm, sample_inputs)
gm = ep.module()
gm = ATEN_PRE_LOWERING_PASSES(gm, sample_inputs)
# TODO: Check if re-exporting changes the metadata
transformed_ep = torch.export.export(gm, tuple(sample_inputs), strict=False)
return transformed_ep


def dump_lowering_passes() -> str:
"""Returns a string containing the lowering passes"""
return str(ATEN_LOWERING_PASSES)
return str(ATEN_POST_LOWERING_PASSES)
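For downstream users, custom passes still register into the (now renamed) post-lowering list through the _aten_lowering_pass decorator defined in this file. A sketch, assuming the decorator's bare form appends the pass to the end of the list as before; the pass body is illustrative.

import torch
from typing import Sequence
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
)

# Sketch: appends a pass to ATEN_POST_LOWERING_PASSES; a real pass would
# rewrite graph nodes before returning the module.
@_aten_lowering_pass
def inspect_graph(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    print(gm.graph)
    return gm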
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/remove_detach.py
@@ -0,0 +1,32 @@
import logging
from typing import Sequence

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def remove_detach(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Remove detach ops in the graph"""

    modified_graph = False
    count = 0
    for node in gm.graph.nodes:
        # If the node is a detach node
        if node.target == torch.ops.aten.detach.default:
            # Detach nodes have exactly one input; reroute all users to it
            node_input = node.all_input_nodes[0]
            node.replace_all_uses_with(node_input)
            gm.graph.erase_node(node)
            modified_graph = True
            count += 1

if modified_graph:
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Removed {count} detach nodes:\n{gm.graph}")

return gm
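A quick way to exercise the new pass in isolation; the toy module is illustrative, and it assumes torch.export keeps the aten.detach.default node in the traced graph rather than folding it away.

import torch
from torch_tensorrt.dynamo.lowering.passes.remove_detach import remove_detach

class Detacher(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.detach() + 1

inputs = (torch.randn(2, 2),)
gm = torch.export.export(Detacher(), inputs).module()
gm = remove_detach(gm, list(inputs))
print(gm.graph)  # the detach node is gone; its users now read the input directly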