[feat] support converter for torch.log2
bowang007 committed Mar 4, 2024
1 parent ab08c63 commit ad74a73
Showing 3 changed files with 82 additions and 0 deletions.
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2695,6 +2695,23 @@ def aten_ops_flip(
)


@dynamo_tensorrt_converter(torch.ops.aten.log2.default)
def aten_ops_log2(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.unary.log2(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default)
def aten_ops_scalar_tensor(
    ctx: ConversionContext,
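For context: under the dynamo frontend, torch.log2 lowers to torch.ops.aten.log2.default, which is the overload registered by the decorator above. A rough end-to-end sketch of exercising the converter (not part of the commit; it assumes TensorRT and a CUDA device are available, and compile options may differ between torch_tensorrt versions):

import torch
import torch_tensorrt


class Log2Module(torch.nn.Module):
    def forward(self, x):
        return torch.log2(x)


# Compile with the dynamo IR so the aten.log2 converter registered above
# is used when lowering the graph to TensorRT.
model = Log2Module().eval().cuda()
inputs = [torch.rand(2, 3, 4).cuda() + 0.1]  # keep inputs positive
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(torch.allclose(trt_model(*inputs), torch.log2(inputs[0]), atol=1e-5))
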
16 changes: 16 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -77,6 +77,22 @@ def log10(
)


def log2(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
) -> TRTTensor:
    # log2(x) = ln(x) / ln(2): reuse the natural-log layer, then divide by ln(2).
    log_layer_output = log(ctx, target, source_ir, f"{name}_log", input_val)

    ln2 = 0.693147180559945309

    return impl.elementwise.div(
        ctx, target, source_ir, f"{name}_div", log_layer_output, ln2
    )


def sqrt(
    ctx: ConversionContext,
    target: Target,
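The helper above relies on the change-of-base identity log2(x) = ln(x) / ln(2): it reuses the existing natural-log layer and divides the result by the constant 0.693147180559945309, i.e. ln(2). A minimal sanity check of that identity in plain PyTorch (illustrative only, not part of the commit):

import math

import torch

x = torch.rand(2, 3, 4) + 0.1                  # keep inputs positive so log is finite
via_identity = torch.log(x) / math.log(2.0)    # ln(x) / ln(2)
print(torch.allclose(via_identity, torch.log2(x), atol=1e-6))  # expected: True
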
49 changes: 49 additions & 0 deletions tests/py/dynamo/conversion/test_log2.py
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestLog2Converter(DispatchTestCase):
    @parameterized.expand(
        [
            ((10,), torch.float),
            ((1, 20), torch.float),
            ((2, 3, 4), torch.float),
            ((2, 3, 4, 5), torch.float),
        ]
    )
    def test_log2_float(self, input_shape, dtype):
        class log2(nn.Module):
            def forward(self, input):
                return torch.ops.aten.log2.default(input)

        inputs = [torch.randn(input_shape, dtype=dtype)]
        self.run_test(
            log2(),
            inputs,
        )

    @parameterized.expand(
        [
            ((10,), torch.int, 0, 5),
            ((1, 20), torch.int, -10, 10),
            ((2, 3, 4), torch.int, -5, 5),
        ]
    )
    def test_log2_int(self, input_shape, dtype, low, high):
        class log2(nn.Module):
            def forward(self, input):
                return torch.ops.aten.log2.default(input)

        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
        self.run_test(
            log2(),
            inputs,
        )


if __name__ == "__main__":
    run_tests()
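
One note on the integer cases above: torch.log2 promotes integer inputs to float and returns -inf for zero and NaN for negative values, both of which the sampled ranges can produce. A quick illustration in plain PyTorch (not part of the commit):

import torch

ints = torch.tensor([-2, 0, 1, 4], dtype=torch.int)
print(torch.log2(ints))        # nan, -inf, 0., 2.
print(torch.log2(ints).dtype)  # torch.float32 -- integer input is promoted to float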
