Skip to content

Commit

Permalink
fix: Fix deconv kernel channel num_output_maps where wts are ITensor (#…
Browse files Browse the repository at this point in the history
…2678)

Signed-off-by: Anurag Dixit <a.dixit91@gmail.com>
  • Loading branch information
andi4191 authored May 22, 2024
1 parent 50206d5 commit 101fac6
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
8 changes: 5 additions & 3 deletions core/conversion/converters/impl/conv_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,12 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
filter_dim.nbDims = nbSpatialDims;
filter_dim.d[0] = kernel_dims.d[2];
filter_dim.d[1] = kernel_dims.d[3];
// For Conv2d layer, weights are in the shape of (out_channels, in_channels/groups,...)
int32_t num_output_maps = kernel_dims.d[0];
if (transposed) {
// For ConvTranspose layer, weights are in the shape of (in_channels, out_channel/groups,...)
num_output_maps = kernel_dims.d[1];
}
bool expand_dims = nbSpatialDims == 1;
if (expand_dims) {
// In case of Conv1D -> map it to 2D version
Expand All @@ -150,9 +155,6 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
LOG_DEBUG("Reshaping input dimensions to: " << in->getDimensions());
kernel = addPadding(ctx, n, kernel, 4, true, true, std::string(util::node_info(n) + "_kernel_shuffle"));
LOG_DEBUG("Reshaping kernel dimensions to: " << kernel->getDimensions());
if (transposed) {
num_output_maps = kernel_dims.d[1];
}
}

// Initialize a dummy constant kernel to pass it to INetwork->addConvolutionNd/addDeconvolutionNd API.
Expand Down
53 changes: 53 additions & 0 deletions tests/core/conversion/converters/test_conv_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,59 @@ TEST(Converters, ATenConvTransposeConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenConvTranspose2dWithWeightsAsTensorsConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Float(48, 56, 3, 3, strides=[504, 9, 3, 1])):
%2 : int = prim::Constant[value=-128]()
%3 : float = prim::Constant[value=3.5]()
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=127]()
%quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
%6 : int = prim::Constant[value=6]()
%7 : int = prim::Constant[value=56]()
%8 : Device = prim::Constant[value="cuda:0"]()
%9 : None = prim::Constant()
%10 : int[] = prim::ListConstruct(%7)
%11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
%12 : int[] = prim::ListConstruct(%7)
%13 : int = prim::Constant[value=1]()
%14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
%quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
%15 : None = prim::Constant()
%16 : bool = prim::Constant[value=1]()
%17 : int = prim::Constant[value=1]() # Adjusted padding
%17.1: int = prim::Constant[value=0]() # Adjusted out_padding
%18 : int = prim::Constant[value=1]() # Adjusted dilation
%19 : int = prim::Constant[value=2]() # Adjusted stride
%20 : int = prim::Constant[value=1]()
%21 : int[] = prim::ListConstruct(%17)
%22 : int[] = prim::ListConstruct(%17, %17)
%23 : int[] = prim::ListConstruct(%18, %18)
%23.1: int[] = prim::ListConstruct(%17.1, %17.1)
%24 : int[] = prim::ListConstruct(%19, %19)
%25 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %24, %22, %23, %16, %23.1, %17, %16, %16, %16, %16)
return (%25))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {1, 48, 2, 200}, {at::kCUDA});
auto w = at::randint(1, 2, {48, 56, 3, 3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto jit_w = at::clone(w);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in, jit_w});

auto trt_in = at::clone(in);
auto trt_w = at::clone(w);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in, trt_w}, nvinfer1::DataType::kINT8);

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenConvTransposeNoBiasConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
Expand Down

0 comments on commit 101fac6

Please sign in to comment.