diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index a263440128..a24a203cab 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch import SymBool, SymFloat, SymInt
@@ -15,11 +16,12 @@
     ConverterRegistry,
     DynamoConverterImplSignature,
 )
-from torch_tensorrt.fx.converters.converter_utils import get_axes_for_reduce_op
+from torch_tensorrt.fx.converters.converter_utils import (
+    broadcast,
+    get_axes_for_reduce_op,
+)
 from torch_tensorrt.fx.types import TRTDataType, TRTTensor
 
-import tensorrt as trt
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -205,6 +207,72 @@ def broadcastable(
     return True
 
 
+def broadcast_to_same_shape(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    lhs_val: TRTTensor,
+    rhs_val: TRTTensor,
+) -> Tuple[TRTTensor, TRTTensor]:
+    """Broadcast ITensors `lhs_val` and `rhs_val` to the same shape. If the shapes are already the same, the
+    original tensors are returned; if the shapes differ, both tensors are broadcast to the common shape.
+
+    This helper function is different from fx/converter_utils.broadcast.
+    fx/converter_utils.broadcast only broadcasts two ITensors to the same number of dimensions (rank)
+    by prepending 1s, while this function broadcasts two ITensors to the same shape.
+
+    For example, given original ITensors lhs_val.shape: (2, 3) and rhs_val.shape: (2, 2, 1, 3):
+    calling fx/converter_utils.broadcast yields lhs_val.shape: (1, 1, 2, 3), rhs_val.shape: (2, 2, 1, 3);
+    calling this function broadcast_to_same_shape yields lhs_val.shape: (2, 2, 2, 3), rhs_val.shape: (2, 2, 2, 3).
+
+    Args:
+        lhs_val (TRTTensor): A TensorRT ITensor.
+        rhs_val (TRTTensor): A TensorRT ITensor.
+
+    Returns:
+        Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors that are broadcasted to the same shape
+
+    """
+    lhs_val, rhs_val = broadcast(
+        ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
+    )
+
+    lhs_val_shape = lhs_val.shape
+    rhs_val_shape = rhs_val.shape
+
+    if tuple(lhs_val_shape) != tuple(rhs_val_shape):
+        rank = len(lhs_val_shape)
+        expanded_dims = [-1] * len(lhs_val_shape)
+
+        for dim in range(rank):
+            expanded_dims[dim] = max(lhs_val_shape[dim], rhs_val_shape[dim])
+
+        expanded_shape = tuple(expanded_dims)
+
+        if lhs_val_shape != expanded_shape:
+            lhs_val = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_lhs_val",
+                lhs_val,
+                expanded_shape,
+            )
+
+        if rhs_val_shape != expanded_shape:
+            rhs_val = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_rhs_val",
+                rhs_val,
+                expanded_shape,
+            )
+
+    return lhs_val, rhs_val
+
+
 get_axes_for_reduce_op = functools.partial(
     get_axes_for_reduce_op, has_implicit_batch_dimension=False
 )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
index ffac049140..e9e80593e9 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -8,9 +8,9 @@
 from torch.fx.node import Target
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    broadcast_to_same_shape,
     cast_trt_tensor,
     get_trt_tensor,
 )
@@ -152,41 +152,12 @@ def convert_binary_elementwise(
 
     if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
         lhs_val, rhs_val = broadcast(
-            ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
+            ctx.net, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
         )
     else:
-        lhs_val_shape = lhs_val.shape
-        rhs_val_shape = rhs_val.shape
-        rank_diff = len(lhs_val_shape) - len(rhs_val_shape)
-        if rank_diff > 0:
-            rhs_val = impl.slice.expand(
-                ctx, target, source_ir, f"{name}_expand_rhs_val", rhs_val, lhs_val_shape
-            )
-        elif rank_diff < 0:
-            lhs_val = impl.slice.expand(
-                ctx, target, source_ir, f"{name}_expand_lhs_val", lhs_val, rhs_val_shape
-            )
-        else:
-            if tuple(lhs_val_shape) != tuple(rhs_val_shape):
-                sum_diff = sum(lhs_val_shape) - sum(rhs_val_shape)
-                if sum_diff > 0:
-                    rhs_val = impl.slice.expand(
-                        ctx,
-                        target,
-                        source_ir,
-                        f"{name}_expand_rhs_val",
-                        rhs_val,
-                        lhs_val_shape,
-                    )
-                elif sum_diff < 0:
-                    lhs_val = impl.slice.expand(
-                        ctx,
-                        target,
-                        source_ir,
-                        f"{name}_expand_lhs_val",
-                        lhs_val,
-                        rhs_val_shape,
-                    )
+        lhs_val, rhs_val = broadcast_to_same_shape(
+            ctx, target, source_ir, f"{name}_broadcast_to_same_shape", lhs_val, rhs_val
+        )
 
     layer = ctx.net.add_elementwise(lhs_val, rhs_val, op_type)
     set_layer_name(layer, target, name, source_ir)
diff --git a/tests/py/dynamo/conversion/test_binary_ops_aten.py b/tests/py/dynamo/conversion/test_binary_ops_aten.py
index ebe727b716..79c0d9430a 100644
--- a/tests/py/dynamo/conversion/test_binary_ops_aten.py
+++ b/tests/py/dynamo/conversion/test_binary_ops_aten.py
@@ -1,4 +1,3 @@
-import unittest
 from typing import Callable
 
 import torch
@@ -59,7 +58,6 @@
         self.run_test(m, inputs)
 
     @parameterized.expand([(op[0].__name__, op[0]) for op in elementwise_ops])
-    @unittest.skip("Pending reimplementation of all binary converters in Dynamo")
in Dynamo") def test_elementwise_ops_mismatched_dtypes(self, name, orig_op: Callable): class TestModule(nn.Module): def __init__(self, orig_op): diff --git a/tests/py/dynamo/conversion/test_bitwise_and_aten.py b/tests/py/dynamo/conversion/test_bitwise_and_aten.py index 5c2a78a18a..6cf0fab2cb 100644 --- a/tests/py/dynamo/conversion/test_bitwise_and_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_and_aten.py @@ -9,18 +9,21 @@ class TestBitwiseAndConverter(DispatchTestCase): @parameterized.expand( [ - ("2d", (5, 3)), - ("3d", (5, 3, 2)), + ("2d", (2, 3), (2, 3)), + ("3d", (5, 3, 2), (5, 3, 2)), + ("3d_broadcast", (2, 3), (2, 1, 3)), + ("4d_broadcast_1", (2, 3), (1, 2, 1, 3)), + ("4d_broadcast_2", (2, 3), (2, 2, 2, 3)), ] ) - def test_bitwise_and_tensor(self, _, shape): + def test_bitwise_and_tensor(self, _, lhs_shape, rhs_shape): class bitwise_and(nn.Module): def forward(self, lhs_val, rhs_val): return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val) inputs = [ - torch.randint(0, 2, shape, dtype=bool), - torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, lhs_shape, dtype=bool), + torch.randint(0, 2, rhs_shape, dtype=bool), ] self.run_test( bitwise_and(), diff --git a/tests/py/dynamo/conversion/test_bitwise_or_aten.py b/tests/py/dynamo/conversion/test_bitwise_or_aten.py index b5e0200734..99286a815b 100644 --- a/tests/py/dynamo/conversion/test_bitwise_or_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_or_aten.py @@ -9,18 +9,21 @@ class TestBitwiseOrConverter(DispatchTestCase): @parameterized.expand( [ - ("2d", (5, 3)), - ("3d", (5, 3, 2)), + ("2d", (2, 3), (2, 3)), + ("3d", (5, 3, 2), (5, 3, 2)), + ("3d_broadcast", (2, 3), (2, 1, 3)), + ("4d_broadcast_1", (2, 3), (1, 2, 1, 3)), + ("4d_broadcast_2", (2, 3), (2, 2, 2, 3)), ] ) - def test_bitwise_or_tensor(self, _, shape): + def test_bitwise_or_tensor(self, _, lhs_shape, rhs_shape): class bitwise_or(nn.Module): def forward(self, lhs_val, rhs_val): return torch.ops.aten.bitwise_or.Tensor(lhs_val, rhs_val) inputs = [ - torch.randint(0, 2, shape, dtype=bool), - torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, lhs_shape, dtype=bool), + torch.randint(0, 2, rhs_shape, dtype=bool), ] self.run_test( bitwise_or(), diff --git a/tests/py/dynamo/conversion/test_bitwise_xor_aten.py b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py index 8c1a8136ef..94a282d701 100644 --- a/tests/py/dynamo/conversion/test_bitwise_xor_aten.py +++ b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py @@ -9,18 +9,21 @@ class TestBitwiseXorConverter(DispatchTestCase): @parameterized.expand( [ - ("2d", (5, 3)), - ("3d", (5, 3, 2)), + ("2d", (2, 3), (2, 3)), + ("3d", (5, 3, 2), (5, 3, 2)), + ("3d_broadcast", (2, 3), (2, 1, 3)), + ("4d_broadcast_1", (2, 3), (1, 2, 1, 3)), + ("4d_broadcast_2", (2, 3), (2, 2, 2, 3)), ] ) - def test_bitwise_xor_tensor(self, _, shape): + def test_bitwise_xor_tensor(self, _, lhs_shape, rhs_shape): class bitwise_xor(nn.Module): def forward(self, lhs_val, rhs_val): return torch.ops.aten.bitwise_xor.Tensor(lhs_val, rhs_val) inputs = [ - torch.randint(0, 2, shape, dtype=bool), - torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, lhs_shape, dtype=bool), + torch.randint(0, 2, rhs_shape, dtype=bool), ] self.run_test( bitwise_xor(),