Commit
fix: bug in elementwise base for static inputs (#2819)
zewenli98 committed May 8, 2024
1 parent db67cb9 commit cd61e54
Showing 6 changed files with 100 additions and 54 deletions.
74 changes: 71 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch import SymBool, SymFloat, SymInt
@@ -15,11 +16,12 @@
    ConverterRegistry,
    DynamoConverterImplSignature,
)
from torch_tensorrt.fx.converters.converter_utils import get_axes_for_reduce_op
from torch_tensorrt.fx.converters.converter_utils import (
    broadcast,
    get_axes_for_reduce_op,
)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor

import tensorrt as trt

_LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -205,6 +207,72 @@ def broadcastable(
    return True


def broadcast_to_same_shape(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    lhs_val: TRTTensor,
    rhs_val: TRTTensor,
) -> Tuple[TRTTensor, TRTTensor]:
    """Broadcast ITensors `lhs_val` and `rhs_val` to the same shape. If the shapes are already the same, return the
    original tensors. If the shapes are different, broadcast the tensors to the same shape.

    This helper function is different from fx/converter_utils.broadcast.
    fx/converter_utils.broadcast only broadcasts two ITensors to the same number of dimensions (rank)
    by prepending 1s, while this function broadcasts two ITensors to the same shape.

    For example, given original ITensors with lhs_val.shape: (2, 3) and rhs_val.shape: (2, 2, 1, 3):
    calling fx/converter_utils.broadcast gives lhs_val.shape: (1, 1, 2, 3) and rhs_val.shape: (2, 2, 1, 3),
    while calling this function gives lhs_val.shape: (2, 2, 2, 3) and rhs_val.shape: (2, 2, 2, 3).

    Args:
        lhs_val (TRTTensor): A TensorRT ITensor.
        rhs_val (TRTTensor): A TensorRT ITensor.

    Returns:
        Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors that are broadcasted to the same shape.
    """
    lhs_val, rhs_val = broadcast(
        ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
    )

    lhs_val_shape = lhs_val.shape
    rhs_val_shape = rhs_val.shape

    if tuple(lhs_val_shape) != tuple(rhs_val_shape):
        rank = len(lhs_val_shape)
        expanded_dims = [-1] * len(lhs_val_shape)

        for dim in range(rank):
            expanded_dims[dim] = max(lhs_val_shape[dim], rhs_val_shape[dim])

        expanded_shape = tuple(expanded_dims)

        if lhs_val_shape != expanded_shape:
            lhs_val = impl.slice.expand(
                ctx,
                target,
                source_ir,
                f"{name}_expand_lhs_val",
                lhs_val,
                expanded_shape,
            )

        if rhs_val_shape != expanded_shape:
            rhs_val = impl.slice.expand(
                ctx,
                target,
                source_ir,
                f"{name}_expand_rhs_val",
                rhs_val,
                expanded_shape,
            )

    return lhs_val, rhs_val


get_axes_for_reduce_op = functools.partial(
    get_axes_for_reduce_op, has_implicit_batch_dimension=False
)
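To make the shape semantics described in the docstring above concrete, here is a minimal standalone sketch in plain PyTorch (the helper name broadcast_to_same_shape_demo and the sample tensors are illustrative only, not part of the library): rank-align first, then expand every dimension to the elementwise maximum of the two sizes.

import torch

def broadcast_to_same_shape_demo(lhs: torch.Tensor, rhs: torch.Tensor):
    # Step 1 (what fx/converter_utils.broadcast does): align ranks by prepending 1s.
    rank = max(lhs.dim(), rhs.dim())
    lhs = lhs.reshape((1,) * (rank - lhs.dim()) + tuple(lhs.shape))
    rhs = rhs.reshape((1,) * (rank - rhs.dim()) + tuple(rhs.shape))

    # Step 2 (what broadcast_to_same_shape adds): expand each dimension to the
    # max of the two sizes, so both tensors end up with the same shape.
    expanded_shape = tuple(max(l, r) for l, r in zip(lhs.shape, rhs.shape))
    return lhs.expand(expanded_shape), rhs.expand(expanded_shape)

lhs, rhs = torch.zeros(2, 3), torch.zeros(2, 2, 1, 3)
out_lhs, out_rhs = broadcast_to_same_shape_demo(lhs, rhs)
print(out_lhs.shape, out_rhs.shape)  # torch.Size([2, 2, 2, 3]) torch.Size([2, 2, 2, 3])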
39 changes: 5 additions & 34 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -8,9 +8,9 @@
from torch.fx.node import Target
from torch_tensorrt import _enums
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    broadcast_to_same_shape,
    cast_trt_tensor,
    get_trt_tensor,
)
@@ -152,41 +152,12 @@ def convert_binary_elementwise(

    if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
        lhs_val, rhs_val = broadcast(
            ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
            ctx.net, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
        )
    else:
        lhs_val_shape = lhs_val.shape
        rhs_val_shape = rhs_val.shape
        rank_diff = len(lhs_val_shape) - len(rhs_val_shape)
        if rank_diff > 0:
            rhs_val = impl.slice.expand(
                ctx, target, source_ir, f"{name}_expand_rhs_val", rhs_val, lhs_val_shape
            )
        elif rank_diff < 0:
            lhs_val = impl.slice.expand(
                ctx, target, source_ir, f"{name}_expand_lhs_val", lhs_val, rhs_val_shape
            )
        else:
            if tuple(lhs_val_shape) != tuple(rhs_val_shape):
                sum_diff = sum(lhs_val_shape) - sum(rhs_val_shape)
                if sum_diff > 0:
                    rhs_val = impl.slice.expand(
                        ctx,
                        target,
                        source_ir,
                        f"{name}_expand_rhs_val",
                        rhs_val,
                        lhs_val_shape,
                    )
                elif sum_diff < 0:
                    lhs_val = impl.slice.expand(
                        ctx,
                        target,
                        source_ir,
                        f"{name}_expand_lhs_val",
                        lhs_val,
                        rhs_val_shape,
                    )
        lhs_val, rhs_val = broadcast_to_same_shape(
            ctx, target, source_ir, f"{name}_broadcast_to_same_shape", lhs_val, rhs_val
        )

    layer = ctx.net.add_elementwise(lhs_val, rhs_val, op_type)
    set_layer_name(layer, target, name, source_ir)
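For context on the static-shape branch above: the replaced logic expanded only one operand to the other operand's shape (choosing a side by rank difference, or by comparing sums of dimension sizes), which cannot produce a valid result when both operands need expanding. A minimal sketch of such a case in plain PyTorch (the shapes are illustrative, and this is ordinary tensor code, not the converter API):

import torch

# Neither operand's shape equals the broadcast result, so expanding one side
# to the other side's shape cannot work; both sides must be expanded.
lhs = torch.randn(2, 3)     # shape (2, 3)
rhs = torch.randn(2, 1, 3)  # shape (2, 1, 3)

target = torch.broadcast_shapes(lhs.shape, rhs.shape)
print(target)  # torch.Size([2, 2, 3]) -- differs from both input shapes

out = lhs.expand(target) + rhs.expand(target)
print(out.shape)  # torch.Size([2, 2, 3])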
2 changes: 0 additions & 2 deletions tests/py/dynamo/conversion/test_binary_ops_aten.py
@@ -1,4 +1,3 @@
import unittest
from typing import Callable

import torch
@@ -59,7 +58,6 @@ def forward(self, x):
        self.run_test(m, inputs)

    @parameterized.expand([(op[0].__name__, op[0]) for op in elementwise_ops])
    @unittest.skip("Pending reimplementation of all binary converters in Dynamo")
    def test_elementwise_ops_mismatched_dtypes(self, name, orig_op: Callable):
        class TestModule(nn.Module):
            def __init__(self, orig_op):
13 changes: 8 additions & 5 deletions tests/py/dynamo/conversion/test_bitwise_and_aten.py
@@ -9,18 +9,21 @@
class TestBitwiseAndConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ("2d", (5, 3)),
            ("3d", (5, 3, 2)),
            ("2d", (2, 3), (2, 3)),
            ("3d", (5, 3, 2), (5, 3, 2)),
            ("3d_broadcast", (2, 3), (2, 1, 3)),
            ("4d_broadcast_1", (2, 3), (1, 2, 1, 3)),
            ("4d_broadcast_2", (2, 3), (2, 2, 2, 3)),
        ]
    )
    def test_bitwise_and_tensor(self, _, shape):
    def test_bitwise_and_tensor(self, _, lhs_shape, rhs_shape):
        class bitwise_and(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val)

        inputs = [
            torch.randint(0, 2, shape, dtype=bool),
            torch.randint(0, 2, shape, dtype=bool),
            torch.randint(0, 2, lhs_shape, dtype=bool),
            torch.randint(0, 2, rhs_shape, dtype=bool),
        ]
        self.run_test(
            bitwise_and(),
13 changes: 8 additions & 5 deletions tests/py/dynamo/conversion/test_bitwise_or_aten.py
@@ -9,18 +9,21 @@
class TestBitwiseOrConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ("2d", (5, 3)),
            ("3d", (5, 3, 2)),
            ("2d", (2, 3), (2, 3)),
            ("3d", (5, 3, 2), (5, 3, 2)),
            ("3d_broadcast", (2, 3), (2, 1, 3)),
            ("4d_broadcast_1", (2, 3), (1, 2, 1, 3)),
            ("4d_broadcast_2", (2, 3), (2, 2, 2, 3)),
        ]
    )
    def test_bitwise_or_tensor(self, _, shape):
    def test_bitwise_or_tensor(self, _, lhs_shape, rhs_shape):
        class bitwise_or(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.bitwise_or.Tensor(lhs_val, rhs_val)

        inputs = [
            torch.randint(0, 2, shape, dtype=bool),
            torch.randint(0, 2, shape, dtype=bool),
            torch.randint(0, 2, lhs_shape, dtype=bool),
            torch.randint(0, 2, rhs_shape, dtype=bool),
        ]
        self.run_test(
            bitwise_or(),
13 changes: 8 additions & 5 deletions tests/py/dynamo/conversion/test_bitwise_xor_aten.py
@@ -9,18 +9,21 @@
class TestBitwiseXorConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ("2d", (5, 3)),
            ("3d", (5, 3, 2)),
            ("2d", (2, 3), (2, 3)),
            ("3d", (5, 3, 2), (5, 3, 2)),
            ("3d_broadcast", (2, 3), (2, 1, 3)),
            ("4d_broadcast_1", (2, 3), (1, 2, 1, 3)),
            ("4d_broadcast_2", (2, 3), (2, 2, 2, 3)),
        ]
    )
    def test_bitwise_xor_tensor(self, _, shape):
    def test_bitwise_xor_tensor(self, _, lhs_shape, rhs_shape):
        class bitwise_xor(nn.Module):
            def forward(self, lhs_val, rhs_val):
                return torch.ops.aten.bitwise_xor.Tensor(lhs_val, rhs_val)

        inputs = [
            torch.randint(0, 2, shape, dtype=bool),
            torch.randint(0, 2, shape, dtype=bool),
            torch.randint(0, 2, lhs_shape, dtype=bool),
            torch.randint(0, 2, rhs_shape, dtype=bool),
        ]
        self.run_test(
            bitwise_xor(),
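The shape pairs added to the three bitwise tests above are all mutually broadcastable; a quick standalone sanity check of the expected output shapes (plain PyTorch, separate from the test suite and shown for bitwise_and only):

import torch

shape_pairs = [
    ("2d", (2, 3), (2, 3)),
    ("3d", (5, 3, 2), (5, 3, 2)),
    ("3d_broadcast", (2, 3), (2, 1, 3)),
    ("4d_broadcast_1", (2, 3), (1, 2, 1, 3)),
    ("4d_broadcast_2", (2, 3), (2, 2, 2, 3)),
]

for name, lhs_shape, rhs_shape in shape_pairs:
    lhs = torch.randint(0, 2, lhs_shape, dtype=torch.bool)
    rhs = torch.randint(0, 2, rhs_shape, dtype=torch.bool)
    out = torch.ops.aten.bitwise_and.Tensor(lhs, rhs)
    print(name, tuple(out.shape))
# 2d (2, 3)
# 3d (5, 3, 2)
# 3d_broadcast (2, 2, 3)
# 4d_broadcast_1 (1, 2, 2, 3)
# 4d_broadcast_2 (2, 2, 2, 3)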
