Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,7 @@ def aten_ops_expand(
)


@dynamo_tensorrt_converter(torch.ops.aten.amax.default)
@dynamo_tensorrt_converter(torch.ops.aten.amax.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
Expand All @@ -1109,7 +1109,7 @@ def aten_ops_amax(
)


@dynamo_tensorrt_converter(torch.ops.aten.amin.default)
@dynamo_tensorrt_converter(torch.ops.aten.amin.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
Expand All @@ -1133,9 +1133,9 @@ def aten_ops_amin(
)


@dynamo_tensorrt_converter(torch.ops.aten.sum.default)
@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
@dynamo_tensorrt_converter(torch.ops.prims.sum.default)
@dynamo_tensorrt_converter(torch.ops.aten.sum.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.prims.sum.default, supports_dynamic_shapes=True)
def aten_ops_sum(
ctx: ConversionContext,
target: Target,
Expand Down Expand Up @@ -1167,8 +1167,8 @@ def aten_ops_sum(
return sum_


@dynamo_tensorrt_converter(torch.ops.aten.prod.default)
@dynamo_tensorrt_converter(torch.ops.aten.prod.dim_int)
@dynamo_tensorrt_converter(torch.ops.aten.prod.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.prod.dim_int, supports_dynamic_shapes=True)
def aten_ops_prod(
ctx: ConversionContext,
target: Target,
Expand All @@ -1187,9 +1187,14 @@ def aten_ops_prod(
)


@dynamo_tensorrt_converter(torch.ops.aten.max.default)
@dynamo_tensorrt_converter(
torch.ops.aten.max.dim, capability_validator=one_user_validator
torch.ops.aten.max.default,
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
torch.ops.aten.max.dim,
capability_validator=one_user_validator,
supports_dynamic_shapes=True,
)
def aten_ops_max(
ctx: ConversionContext,
Expand All @@ -1210,9 +1215,14 @@ def aten_ops_max(
)


@dynamo_tensorrt_converter(torch.ops.aten.min.default)
@dynamo_tensorrt_converter(
torch.ops.aten.min.dim, capability_validator=one_user_validator
torch.ops.aten.min.default,
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
torch.ops.aten.min.dim,
capability_validator=one_user_validator,
supports_dynamic_shapes=True,
)
def aten_ops_min(
ctx: ConversionContext,
Expand Down
33 changes: 33 additions & 0 deletions tests/py/dynamo/conversion/test_amax_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand Down Expand Up @@ -90,6 +91,38 @@ def forward(self, x):
check_dtype=False,
)

@parameterized.expand(
[
((0, 1), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
((0,), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
(1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
(2, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
(-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
((-1, 0), True, (2, 2, 5), (3, 3, 6), (4, 5, 7)),
]
)
def test_amax_dynamic_shape(self, dim, keep_dim, min_shape, opt_shape, max_shape):
class Amax(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim

def forward(self, x):
return torch.ops.aten.amax.default(x, dim, keep_dim)

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
Amax(dim),
input_specs,
)


if __name__ == "__main__":
run_tests()
33 changes: 33 additions & 0 deletions tests/py/dynamo/conversion/test_amin_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand Down Expand Up @@ -90,6 +91,38 @@ def forward(self, x):
check_dtype=False,
)

@parameterized.expand(
[
((0, 1), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
((0,), False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
(1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
(2, False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
(-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
((-1, 0), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
]
)
def test_amin_dynamic_shape(self, dim, keep_dim, min_shape, opt_shape, max_shape):
class Amin(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim

def forward(self, x):
return torch.ops.aten.amin.default(x, dim, keep_dim)

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
Amin(dim),
input_specs,
)


if __name__ == "__main__":
run_tests()
57 changes: 57 additions & 0 deletions tests/py/dynamo/conversion/test_max_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand Down Expand Up @@ -65,6 +66,62 @@ def forward(self, x):
check_dtype=False,
)

@parameterized.expand(
[
(1, True, (2, 2, 3), (2, 3, 3), (3, 3, 4)),
(2, False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
(-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
]
)
def test_max_dim_dynamic_shape(
self, dim, keep_dim, min_shape, opt_shape, max_shape
):
class Max(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim

def forward(self, x):
return torch.ops.aten.max.dim(x, dim, keep_dim)[0]

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
Max(dim),
input_specs,
)

@parameterized.expand(
[
((2, 2, 3), (2, 3, 3), (3, 3, 4)),
((2, 3, 5), (3, 4, 6), (4, 5, 7)),
((2, 3, 5), (3, 4, 6), (4, 5, 7)),
]
)
def test_max_default_dynamic_shape(self, min_shape, opt_shape, max_shape):
class Max(nn.Module):
def forward(self, x):
return torch.ops.aten.max.default(x)

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
Max(),
input_specs,
)


if __name__ == "__main__":
run_tests()
57 changes: 57 additions & 0 deletions tests/py/dynamo/conversion/test_min_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand Down Expand Up @@ -65,6 +66,62 @@ def forward(self, x):
check_dtype=False,
)

@parameterized.expand(
[
(1, True, (2, 2, 3), (2, 3, 3), (3, 3, 4)),
(2, False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
(-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
]
)
def test_min_dim_dynamic_shape(
self, dim, keep_dim, min_shape, opt_shape, max_shape
):
class Min(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim

def forward(self, x):
return torch.ops.aten.min.dim(x, dim, keep_dim)[0]

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
Min(dim),
input_specs,
)

@parameterized.expand(
[
((2, 2, 3), (2, 3, 3), (3, 3, 4)),
((2, 3, 5), (3, 4, 6), (4, 5, 7)),
((2, 3, 5), (3, 4, 6), (4, 5, 7)),
]
)
def test_min_default_dynamic_shape(self, min_shape, opt_shape, max_shape):
class Min(nn.Module):
def forward(self, x):
return torch.ops.aten.min.default(x)

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
Min(),
input_specs,
)


if __name__ == "__main__":
run_tests()
28 changes: 28 additions & 0 deletions tests/py/dynamo/conversion/test_prod_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

Expand Down Expand Up @@ -68,6 +69,33 @@ def forward(self, x):
use_dynamo_tracer=True,
)

@parameterized.expand(
[
(0, (2, 3), (2, 4), (3, 5)),
(1, (2, 3), (2, 4), (3, 5)),
(2, (2, 2, 4), (2, 3, 4), (3, 4, 5)),
(-1, (2, 2, 4), (2, 3, 4), (3, 4, 5)),
]
)
def test_prod_dynamic_shape(self, dim, min_shape, opt_shape, max_shape):
class Prod(nn.Module):
def forward(self, x):
return torch.prod(x, dim)

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
Prod(),
input_specs,
use_dynamo_tracer=True,
)


if __name__ == "__main__":
run_tests()
Loading