Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions core/conversion/converters/impl/conv_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patter
bool deterministic, bool cudnn_enabled) -> (Tensor))SIG",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor(); // assumes non-static input Tensor

auto w = Weights(ctx, args[1].unwrapToTensor());
auto stride = util::toDims(args[3].unwrapToIntList());
LOG_DEBUG("stride: " << stride);
Expand All @@ -29,36 +28,45 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patter
auto out_padding = util::toDims(args[7].unwrapToIntList());
LOG_DEBUG("out_padding: " << out_padding);
int64_t groups = args[8].unwrapToInt();
LOG_DEBUG("groups: " << groups);

nvinfer1::ILayer* new_layer;
if (transposed) {
nvinfer1::IDeconvolutionLayer* deconv;
Weights bias;
if (args[2].IValue()->isTensor()) {
Weights b(ctx, args[2].IValue()->toTensor());
deconv = ctx->net->addDeconvolutionNd(*in, w.num_input_maps, w.kernel_shape, w.data, b.data);
bias = Weights(ctx, args[2].unwrapToTensor());
} else {
deconv = ctx->net->addDeconvolutionNd(*in, w.num_input_maps, w.kernel_shape, w.data, {});
bias = Weights(ctx, torch::zeros(args[1].unwrapToTensor().sizes()[1] * groups));
}

// shape of deconvolution's weight: [in, out/groups, ...]
auto deconv = ctx->net->addDeconvolutionNd(
*in, args[1].unwrapToTensor().sizes()[1] * groups, w.kernel_shape, w.data, bias.data);
TRTORCH_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);

deconv->setStrideNd(stride);
deconv->setPaddingNd(padding);
#if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR == 1)
deconv->setDilationNd(dilation);
deconv->setNbGroups(groups);
#else
TRTORCH_CHECK(groups == 1, "for deconv with groups > 1, require TensorRT version >= 7.1");
for (auto it = dilation.begin(); it != dilation.end(); ++it) {
TRTORCH_CHECK(*it == 1, "for deconv with dilation > 1, require TensorRT version >= 7.1");
}
#endif
new_layer = deconv;
} else {
nvinfer1::IConvolutionLayer* conv;
Weights bias;
if (args[2].IValue()->isTensor()) {
Weights b(ctx, args[2].unwrapToTensor());
conv = ctx->net->addConvolutionNd(*in, w.num_output_maps, w.kernel_shape, w.data, b.data);
bias = Weights(ctx, args[2].unwrapToTensor());
} else {
Weights b(ctx, torch::zeros(args[1].unwrapToTensor().sizes()[0]));
conv = ctx->net->addConvolutionNd(*in, w.num_output_maps, w.kernel_shape, w.data, b.data);
bias = Weights(ctx, torch::zeros(args[1].unwrapToTensor().sizes()[0]));
}

// shape of convolution's weight: [out, in/groups, ...]
auto conv =
ctx->net->addConvolutionNd(*in, args[1].unwrapToTensor().sizes()[0], w.kernel_shape, w.data, bias.data);
TRTORCH_CHECK(conv, "Unable to create convolution layer from node: " << *n);

conv->setStrideNd(stride);
Expand Down
147 changes: 85 additions & 62 deletions tests/core/converters/test_conv_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,65 +532,88 @@ TEST(Converters, ATenConvTransposeWithPaddingConvertsCorrectly) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

// TEST(Converters, ATenConvolutionWithDialationConvertsCorrectly) {
// const auto graph = R"IR(
// graph(%0 : Tensor,
// %1 : Float(8, 3, 5, 5),
// %2 : Float(8)):
// %3 : int = prim::Constant[value=1]()
// %4 : int = prim::Constant[value=0]()
// %5 : int = prim::Constant[value=2]()
// %6 : int = prim::Constant[value=0]()
// %7 : bool = prim::Constant[value=0]()
// %8 : int[] = prim::ListConstruct(%3, %3)
// %9 : int[] = prim::ListConstruct(%4, %4)
// %10 : int[] = prim::ListConstruct(%5, %5)
// %11 : int[] = prim::ListConstruct(%6, %6)
// %12 : int = prim::Constant[value=1]()
// %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11,
// %12, %7, %7, %7) return (%13))IR";

// conv_test_helper(graph);
// }

// TEST(Converters, ATenConvolutionWithPostPaddingConvertsCorrectly) {
// const auto graph = R"IR(
// graph(%0 : Tensor,
// %1 : Float(8, 3, 5, 5),
// %2 : Float(8)):
// %3 : int = prim::Constant[value=1]()
// %4 : int = prim::Constant[value=0]()
// %5 : int = prim::Constant[value=1]()
// %6 : int = prim::Constant[value=2]()
// %7 : bool = prim::Constant[value=0]()
// %8 : int[] = prim::ListConstruct(%3, %3)
// %9 : int[] = prim::ListConstruct(%4, %4)
// %10 : int[] = prim::ListConstruct(%5, %5)
// %11 : int[] = prim::ListConstruct(%6, %6)
// %12 : int = prim::Constant[value=1]()
// %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11,
// %12, %7, %7, %7) return (%13))IR";

// conv_test_helper(graph);
// }

// TEST(Converters, ATenConvolutionWithGroupConvertsCorrectly) {
// const auto graph = R"IR(
// graph(%0 : Tensor,
// %1 : Float(8, 3, 5, 5),
// %2 : Float(8)):
// %3 : int = prim::Constant[value=1]()
// %4 : int = prim::Constant[value=0]()
// %5 : int = prim::Constant[value=1]()
// %6 : int = prim::Constant[value=0]()
// %7 : bool = prim::Constant[value=0]()
// %8 : int[] = prim::ListConstruct(%3, %3)
// %9 : int[] = prim::ListConstruct(%4, %4)
// %10 : int[] = prim::ListConstruct(%5, %5)
// %11 : int[] = prim::ListConstruct(%6, %6)
// %12 : int = prim::Constant[value=2]()
// %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11,
// %12, %7, %7, %7) return (%13))IR";

// conv_test_helper(graph);
// }
TEST(Converters, ATenConvolutionWithGroupConvertsCorrectly) {
  // Grouped convolution: weight [8, 1, 2, 2] with groups=4 over a 4-channel
  // input — exercises the groups path of the aten::_convolution converter.
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Float(8:48, 1:16, 2:4, 2:1),
            %2 : Float(8:1)):
        %3 : int = prim::Constant[value=1]()
        %4 : int = prim::Constant[value=2]()
        %5 : int = prim::Constant[value=1]()
        %6 : int = prim::Constant[value=0]()
        %7 : bool = prim::Constant[value=0]()
        %8 : int[] = prim::ListConstruct(%3, %3)
        %9 : int[] = prim::ListConstruct(%4, %4)
        %10 : int[] = prim::ListConstruct(%5, %5)
        %11 : int[] = prim::ListConstruct(%6, %6)
        %12 : int = prim::Constant[value=4]()
        %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
        return (%13))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // One shared set of random fixtures; each runner receives its own clones so
  // neither path can perturb the other's inputs.
  const auto input = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA});
  const auto weight = at::randint(1, 10, {8, 1, 2, 2}, {at::kCUDA});
  const auto bias = at::randint(1, 10, {8}, {at::kCUDA});

  // Reference result via the TorchScript interpreter.
  auto jit_params =
      trtorch::core::conversion::get_named_params(parsed_graph->inputs(), {at::clone(weight), at::clone(bias)});
  auto jit_results = trtorch::tests::util::RunGraph(parsed_graph, jit_params, {at::clone(input)});

  // Result via the converted TensorRT engine.
  auto trt_params =
      trtorch::core::conversion::get_named_params(parsed_graph->inputs(), {at::clone(weight), at::clone(bias)});
  auto trt_results = trtorch::tests::util::RunGraphEngine(parsed_graph, trt_params, {at::clone(input)});

  // TRT output may come back flattened; reshape before the element-wise compare.
  auto trt_out = trt_results[0].reshape(jit_results[0].sizes());
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_out, 2e-6));
}

TEST(Converters, ATenConvTransposeWithGroupConvertsCorrectly) {
  // Grouped transposed convolution: deconv weight [8, 4, 3, 3] with groups=4
  // produces 4 * 4 = 16 output channels, matching the 16-element bias.
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Float(8:56, 4:16, 3:3, 3:1),
            %2 : Float(16:1)):
        %3 : int = prim::Constant[value=1]()
        %4 : int = prim::Constant[value=1]()
        %5 : int = prim::Constant[value=1]()
        %6 : int = prim::Constant[value=0]()
        %7 : bool = prim::Constant[value=1]()
        %8 : int[] = prim::ListConstruct(%3, %3)
        %9 : int[] = prim::ListConstruct(%4, %4)
        %10 : int[] = prim::ListConstruct(%5, %5)
        %11 : int[] = prim::ListConstruct(%6, %6)
        %12 : int = prim::Constant[value=4]()
        %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
        return (%13))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  // One shared set of random fixtures; each runner receives its own clones so
  // neither path can perturb the other's inputs.
  const auto input = at::randint(1, 10, {1, 8, 5, 5}, {at::kCUDA});
  const auto weight = at::randint(1, 10, {8, 4, 3, 3}, {at::kCUDA});
  const auto bias = at::randint(1, 10, {16}, {at::kCUDA});

  // Reference result via the TorchScript interpreter.
  auto jit_params =
      trtorch::core::conversion::get_named_params(parsed_graph->inputs(), {at::clone(weight), at::clone(bias)});
  auto jit_results = trtorch::tests::util::RunGraph(parsed_graph, jit_params, {at::clone(input)});

  // Result via the converted TensorRT engine.
  auto trt_params =
      trtorch::core::conversion::get_named_params(parsed_graph->inputs(), {at::clone(weight), at::clone(bias)});
  auto trt_results = trtorch::tests::util::RunGraphEngine(parsed_graph, trt_params, {at::clone(input)});

  // TRT output may come back flattened; reshape before the element-wise compare.
  auto trt_out = trt_results[0].reshape(jit_results[0].sizes());
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_out, 2e-6));
}