From 44313b3fc230e8865b927b14fb95a86b0bef3fe1 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 9 May 2024 15:37:08 +0900 Subject: [PATCH 1/2] feat: support aten.log1p converter --- .../dynamo/conversion/aten_ops_converters.py | 85 +++++++++++-------- .../dynamo/conversion/impl/unary/ops.py | 17 ++++ tests/py/dynamo/conversion/test_log1p.py | 73 ++++++++++++++++ 3 files changed, 141 insertions(+), 34 deletions(-) create mode 100644 tests/py/dynamo/conversion/test_log1p.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1705dd06db..9bf1e3ead3 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1165,6 +1165,57 @@ def aten_ops_log( ) +@dynamo_tensorrt_converter(torch.ops.aten.log2.default) +def log2( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.log2( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.log10.default) +def log10( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.log10( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.log1p.default) +def aten_ops_log1p( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.log1p( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.sqrt.default) def aten_ops_sqrt( ctx: ConversionContext, @@ -2829,23 +2880,6 @@ def aten_ops_flip( ) -@dynamo_tensorrt_converter(torch.ops.aten.log2.default) -def log2( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.unary.log2( - ctx, - target, - SourceIR.ATEN, - name, - args[0], - ) - - @dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default) def aten_ops_scalar_tensor( ctx: ConversionContext, @@ -2859,23 +2893,6 @@ def aten_ops_scalar_tensor( ) -@dynamo_tensorrt_converter(torch.ops.aten.log10.default) -def log10( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.unary.log10( - ctx, - target, - SourceIR.ATEN, - name, - args[0], - ) - - @dynamo_tensorrt_converter(torch.ops.aten.roll.default) @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 9f2ad07612..beb13fca9b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -119,6 +119,23 @@ def log2( ) +def log1p( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +) -> TRTTensor: + """ + Computes log(1 + x) for each element of the input tensor. + """ + one_plus_x = impl.elementwise.add( + ctx, target, source_ir, f"{name}_add", input_val, 1 + ) + + return log(ctx, target, source_ir, f"{name}_log", one_plus_x) + + def sqrt( ctx: ConversionContext, target: Target, diff --git a/tests/py/dynamo/conversion/test_log1p.py b/tests/py/dynamo/conversion/test_log1p.py new file mode 100644 index 0000000000..7e59cc16d3 --- /dev/null +++ b/tests/py/dynamo/conversion/test_log1p.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestLog1pConverter(DispatchTestCase): + @parameterized.expand( + [ + ((10,), torch.float), + ((1, 20), torch.float), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_log1p_float(self, input_shape, dtype): + class Log1p(nn.Module): + def forward(self, input): + return torch.ops.aten.log1p.default(input) + + inputs = [ + torch.randn(input_shape, dtype=dtype).abs() + 0.001 + ] # ensure positive input + self.run_test( + Log1p(), + inputs, + ) + + @parameterized.expand( + [ + ((10,), torch.int, 0, 5), + ((1, 20), torch.int, 0, 10), + ((2, 3, 4), torch.int, 0, 5), + ((2, 3, 4, 5), torch.int, 0, 5), + ] + ) + def test_log1p_int(self, input_shape, dtype, low, high): + class Log1p(nn.Module): + def forward(self, input): + return torch.ops.aten.log1p.default(input) + + inputs = [ + torch.randint(low, high, input_shape, dtype=dtype).abs() + 0.001 + ] # ensure positive input + self.run_test( + Log1p(), + inputs, + ) + + @parameterized.expand( + [ + (torch.full((1, 20), 2, dtype=torch.float),), + (torch.full((2, 3, 4), 3, dtype=torch.float),), + (torch.full((2, 3, 4, 5), 4, dtype=torch.float),), + ] + ) + def test_log1p_const_float(self, data): + class Log1p(nn.Module): + def forward(self, input): + return torch.ops.aten.log1p.default(input) + + inputs = [data] + self.run_test( + Log1p(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From dbb97721074c76446507f4a037445f255c70c827 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Wed, 22 May 2024 12:22:48 +0900 Subject: [PATCH 2/2] chore: minor naming changes --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 9bf1e3ead3..08efc51ce1 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1166,7 +1166,7 @@ def aten_ops_log( @dynamo_tensorrt_converter(torch.ops.aten.log2.default) -def log2( +def aten_ops_log2( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], @@ -1183,7 +1183,7 @@ def log2( @dynamo_tensorrt_converter(torch.ops.aten.log10.default) -def log10( +def aten_ops_log10( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...],