From 101fac6d65be0a818dfc86337a68010b9f508e3c Mon Sep 17 00:00:00 2001 From: Anurag Dixit Date: Wed, 22 May 2024 11:49:50 -0700 Subject: [PATCH] fix: Fix deconv kernel channel num_output_maps where wts are ITensor (#2678) Signed-off-by: Anurag Dixit --- .../converters/impl/conv_deconv.cpp | 8 +-- .../converters/test_conv_deconv.cpp | 53 +++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp index c71007ac03..da6ce4b98a 100644 --- a/core/conversion/converters/impl/conv_deconv.cpp +++ b/core/conversion/converters/impl/conv_deconv.cpp @@ -139,7 +139,12 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) filter_dim.nbDims = nbSpatialDims; filter_dim.d[0] = kernel_dims.d[2]; filter_dim.d[1] = kernel_dims.d[3]; + // For Conv2d layer, weights are in the shape of (out_channels, in_channels/groups,...) int32_t num_output_maps = kernel_dims.d[0]; + if (transposed) { + // For ConvTranspose layer, weights are in the shape of (in_channels, out_channel/groups,...) + num_output_maps = kernel_dims.d[1]; + } bool expand_dims = nbSpatialDims == 1; if (expand_dims) { // In case of Conv1D -> map it to 2D version @@ -150,9 +155,6 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) LOG_DEBUG("Reshaping input dimensions to: " << in->getDimensions()); kernel = addPadding(ctx, n, kernel, 4, true, true, std::string(util::node_info(n) + "_kernel_shuffle")); LOG_DEBUG("Reshaping kernel dimensions to: " << kernel->getDimensions()); - if (transposed) { - num_output_maps = kernel_dims.d[1]; - } } // Initialize a dummy constant kernel to pass it to INetwork->addConvolutionNd/addDeconvolutionNd API. 
diff --git a/tests/core/conversion/converters/test_conv_deconv.cpp b/tests/core/conversion/converters/test_conv_deconv.cpp index faaf7f2474..662da71a6f 100644 --- a/tests/core/conversion/converters/test_conv_deconv.cpp +++ b/tests/core/conversion/converters/test_conv_deconv.cpp @@ -497,6 +497,59 @@ TEST(Converters, ATenConvTransposeConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } +TEST(Converters, ATenConvTranspose2dWithWeightsAsTensorsConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(48, 56, 3, 3, strides=[504, 9, 3, 1])): + %2 : int = prim::Constant[value=-128]() + %3 : float = prim::Constant[value=3.5]() + %4 : int = prim::Constant[value=0]() + %5 : int = prim::Constant[value=127]() + %quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5) + %6 : int = prim::Constant[value=6]() + %7 : int = prim::Constant[value=56]() + %8 : Device = prim::Constant[value="cuda:0"]() + %9 : None = prim::Constant() + %10 : int[] = prim::ListConstruct(%7) + %11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9) + %12 : int[] = prim::ListConstruct(%7) + %13 : int = prim::Constant[value=1]() + %14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9) + %quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5) + %15 : None = prim::Constant() + %16 : bool = prim::Constant[value=1]() + %17 : int = prim::Constant[value=1]() # Adjusted padding + %17.1: int = prim::Constant[value=0]() # Adjusted out_padding + %18 : int = prim::Constant[value=1]() # Adjusted dilation + %19 : int = prim::Constant[value=2]() # Adjusted stride + %20 : int = prim::Constant[value=1]() + %21 : int[] = prim::ListConstruct(%17) + %22 : int[] = prim::ListConstruct(%17, %17) + %23 : int[] = prim::ListConstruct(%18, %18) + %23.1: int[] = prim::ListConstruct(%17.1, %17.1) + %24 : int[] = prim::ListConstruct(%19, %19) + %25 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %24, 
%22, %23, %16, %23.1, %17, %16, %16, %16, %16) + return (%25))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 48, 2, 200}, {at::kCUDA}); + auto w = at::randint(1, 2, {48, 56, 3, 3}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto jit_w = at::clone(w); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in, jit_w}); + + auto trt_in = at::clone(in); + auto trt_w = at::clone(w); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in, trt_w}, nvinfer1::DataType::kINT8); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + TEST(Converters, ATenConvTransposeNoBiasConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor,