Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/impl/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from torch_tensorrt.dynamo.conversion.converter_utils import (
SourceIR,
cast_trt_tensor,
extend_attr_to_tuple,
get_trt_tensor,
has_dynamic_shape,
set_layer_name,
Expand Down Expand Up @@ -159,10 +158,9 @@ def convNd(
# Expand parameters manually for Conv1D computations
if is_conv1d:
padding = (tuple(padding) + (0,)) if padding is not None else padding
stride = extend_attr_to_tuple(stride, 2) if stride is not None else stride
dilation = (
extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
)
# stride in conv1d is (2,) -> need to change to (2, 1) in conv2d
stride = (stride[0], 1) if stride is not None else stride
dilation = (dilation[0], 1) if dilation is not None else dilation

# Set relevant attributes of convolution layer
if padding is not None:
Expand Down
8 changes: 3 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/impl/deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
SourceIR,
extend_attr_to_tuple,
get_trt_tensor,
has_dynamic_shape,
to_torch,
Expand Down Expand Up @@ -142,10 +141,9 @@ def deconvNd(
# Expand parameters manually for Conv1D computations
if is_deconv1d:
padding = (tuple(padding) + (0,)) if padding is not None else padding
stride = extend_attr_to_tuple(stride, 2) if stride is not None else stride
dilation = (
extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
)
# stride in deconv1d is (2,) -> need to change to (2, 1) in deconv2d
stride = (stride[0], 1) if stride is not None else stride
dilation = (dilation[0], 1) if dilation is not None else dilation
output_padding = (
(tuple(output_padding) + (0,))
if output_padding is not None
Expand Down
5 changes: 5 additions & 0 deletions tests/py/dynamo/conversion/test_convolution_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class TestConvolutionConverter(DispatchTestCase):
param("non_zero_padding", 1, padding=1),
param("dilation", 1, dilation=2),
param("groups", 1, groups=3),
param("stride", 1, stride=1),
param("stride_2", 1, stride=2),
param("stride_tuple", 1, stride=(2,)),
]
)
def test_conv1d(
Expand Down Expand Up @@ -52,6 +55,7 @@ def forward(self, x):
("tuple_parameters", 1, (1), (1)),
param("non_zero_padding", 1, padding=1),
param("dilation", 1, dilation=2),
param("stride", 1, stride=2),
]
)
def test_conv1d_TRTTensor_weight(
Expand Down Expand Up @@ -140,6 +144,7 @@ def forward(self, x):
param("tuple_dilation", 2, dilation=(3, 3)),
param("list_dilation", 2, dilation=[3]),
param("groups", 1, groups=3),
param("stride", 1, stride=(2, 2)),
]
)
def test_conv2d(
Expand Down
Loading