feat: support aten.log1p converter (#2823)
chohk88 committed May 22, 2024
1 parent d499f2e commit c78fa7c
Showing 3 changed files with 141 additions and 34 deletions.
85 changes: 51 additions & 34 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1165,6 +1165,57 @@ def aten_ops_log(
    )


@dynamo_tensorrt_converter(torch.ops.aten.log2.default)
def aten_ops_log2(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.unary.log2(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.log10.default)
def aten_ops_log10(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.unary.log10(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.log1p.default)
def aten_ops_log1p(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.unary.log1p(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.sqrt.default)
def aten_ops_sqrt(
    ctx: ConversionContext,
@@ -2849,23 +2900,6 @@ def aten_ops_flip(
    )


@dynamo_tensorrt_converter(torch.ops.aten.log2.default)
def log2(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.unary.log2(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default)
def aten_ops_scalar_tensor(
    ctx: ConversionContext,
@@ -2879,23 +2913,6 @@ def aten_ops_scalar_tensor(
    )


@dynamo_tensorrt_converter(torch.ops.aten.log10.default)
def log10(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.unary.log10(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.roll.default)
@enforce_tensor_types(
    {
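As a side note, a minimal end-to-end sketch (not part of this commit) of how the newly registered aten.log1p converter can be exercised through the Dynamo frontend; the module, shapes, and compile settings below are illustrative assumptions:

import torch
import torch_tensorrt


class Log1pModule(torch.nn.Module):
    def forward(self, x):
        # torch.log1p lowers to aten.log1p.default, handled by the converter above
        return torch.log1p(x)


model = Log1pModule().eval().cuda()
# keep inputs well inside log1p's domain (x > -1)
inputs = [torch.rand(2, 3, 4, device="cuda") + 0.001]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(torch.allclose(trt_model(*inputs), model(*inputs), rtol=1e-3, atol=1e-3))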
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -119,6 +119,23 @@ def log2(
    )


def log1p(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
) -> TRTTensor:
    """
    Computes log(1 + x) for each element of the input tensor.
    """
    one_plus_x = impl.elementwise.add(
        ctx, target, source_ir, f"{name}_add", input_val, 1
    )

    return log(ctx, target, source_ir, f"{name}_log", one_plus_x)


def sqrt(
    ctx: ConversionContext,
    target: Target,
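The log1p implementation above follows the elementwise identity log1p(x) = log(1 + x): an add layer produces 1 + x and feeds the existing log op. A quick eager-mode check of that identity (illustrative only; it does not go through TensorRT). Note that the naive log(1 + x) form gives up log1p's extra accuracy for very small x, so results can differ slightly from eager PyTorch near zero:

import torch

x = torch.rand(2, 3, 4) + 0.001
assert torch.allclose(torch.log1p(x), torch.log(x + 1), atol=1e-6)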
73 changes: 73 additions & 0 deletions tests/py/dynamo/conversion/test_log1p.py
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestLog1pConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ((10,), torch.float),
            ((1, 20), torch.float),
            ((2, 3, 4), torch.float),
            ((2, 3, 4, 5), torch.float),
        ]
    )
    def test_log1p_float(self, input_shape, dtype):
        class Log1p(nn.Module):
            def forward(self, input):
                return torch.ops.aten.log1p.default(input)

        inputs = [
            torch.randn(input_shape, dtype=dtype).abs() + 0.001
        ]  # ensure positive input
        self.run_test(
            Log1p(),
            inputs,
        )

    @parameterized.expand(
        [
            ((10,), torch.int, 0, 5),
            ((1, 20), torch.int, 0, 10),
            ((2, 3, 4), torch.int, 0, 5),
            ((2, 3, 4, 5), torch.int, 0, 5),
        ]
    )
    def test_log1p_int(self, input_shape, dtype, low, high):
        class Log1p(nn.Module):
            def forward(self, input):
                return torch.ops.aten.log1p.default(input)

        inputs = [
            torch.randint(low, high, input_shape, dtype=dtype).abs() + 0.001
        ]  # ensure positive input
        self.run_test(
            Log1p(),
            inputs,
        )

    @parameterized.expand(
        [
            (torch.full((1, 20), 2, dtype=torch.float),),
            (torch.full((2, 3, 4), 3, dtype=torch.float),),
            (torch.full((2, 3, 4, 5), 4, dtype=torch.float),),
        ]
    )
    def test_log1p_const_float(self, data):
        class Log1p(nn.Module):
            def forward(self, input):
                return torch.ops.aten.log1p.default(input)

        inputs = [data]
        self.run_test(
            Log1p(),
            inputs,
        )


if __name__ == "__main__":
    run_tests()
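The new tests use the repository's DispatchTestCase harness; assuming the usual test environment (CUDA and TensorRT available), they can presumably be run in isolation with python -m pytest tests/py/dynamo/conversion/test_log1p.py.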
