diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp
index eb986c6d60..2dbd6737ac 100644
--- a/core/conversion/converters/impl/reduce.cpp
+++ b/core/conversion/converters/impl/reduce.cpp
@@ -9,6 +9,23 @@ namespace impl {
 namespace {
 auto reduced_registrations = RegisterNodeConversionPatterns()
   .pattern({
+    "aten::mean(Tensor self, *, ScalarType? dtype=None) -> (Tensor)",
+    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+      auto in_tensor = args[0].ITensor();
+      auto in_dims = util::toVec(in_tensor->getDimensions());
+      LOG_WARNING("Mean Converter disregards dtype");
+
+      uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1);
+
+      auto mean_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kAVG, axis_mask, false);
+
+      TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n);
+
+      mean_layer->setName(util::node_info(n).c_str());
+      ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0));
+      return true;
+    }
+  }).pattern({
     "aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in_tensor = args[0].ITensor();
@@ -23,7 +40,7 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
       TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n);
 
       mean_layer->setName(util::node_info(n).c_str());
-      associate_value_and_tensor(ctx, n->outputs()[0], mean_layer->getOutput(0));
+      ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0));
       return true;
     }
   });
@@ -32,5 +49,64 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace trtorch
+} // namespace trtorch
+
+// #include "core/util/prelude.h"
+// #include "core/conversion/converters/converters.h"
+
+// namespace trtorch {
+// namespace core {
+// namespace conversion {
+// namespace converters {
+// namespace impl {
+// namespace {
+
+// #define convert(unary, trt_type)                                            \
+//   auto unary##_registrations TRTORCH_UNUSED =                               \
+//       RegisterNodeConversionPatterns().pattern(                             \
+//           {"aten::" #unary "(Tensor self) -> Tensor",                       \
+//            [](ConversionCtx *ctx, const torch::jit::Node *n,                \
+//               args &args) -> bool {                                         \
+//              auto in = args[0].ITensor();                                   \
+//              auto unary =                                                   \
+//                  ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \
+//                                                                             \
+//              TRTORCH_CHECK(                                                 \
+//                  unary,                                                     \
+//                  "Unable to create " #unary " layer from node: " << *n);    \
+//                                                                             \
+//              unary->setName(util::node_info(n).c_str());                    \
+//              auto out_tensor = ctx->AssociateValueAndTensor(                \
+//                  n->outputs()[0],                                           \
+//                  unary->getOutput(0));                                      \
+//              LOG_DEBUG(                                                     \
+//                  "Output tensor shape: " << out_tensor->getDimensions());   \
+//                                                                             \
+//              return true;                                                   \
+//            }});
+
+// convert(cos, kCOS);
+// convert(acos, kACOS);
+// convert(cosh, kCOSH);
+// convert(sin, kSIN);
+// convert(asin, kASIN);
+// convert(sinh, kSINH);
+// convert(tan, kTAN);
+// convert(atan, kATAN);
+// convert(abs, kABS);
+// convert(floor, kFLOOR);
+// convert(reciprocal, kRECIP);
+// convert(log, kLOG);
+// convert(ceil, kCEIL);
+// convert(sqrt, kSQRT);
+// convert(exp, kEXP);
+// convert(neg, kNEG);
+
+// #undef convert
+// } // namespace
+// } // namespace impl
+// } // namespace converters
+// } // namespace conversion
+// } // namespace core
+// } // namespace trtorch
diff --git a/tests/core/converters/test_reduce.cpp b/tests/core/converters/test_reduce.cpp
index b263ab74e1..ed67c50331 100644
--- a/tests/core/converters/test_reduce.cpp
+++ b/tests/core/converters/test_reduce.cpp
@@ -5,13 +5,55 @@
 #include "core/compiler.h"
 
 TEST(Converters, ATenMeanConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %4 : None = prim::Constant()
+      %5 : Tensor = aten::mean(%0, %4)
+      return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::script::parseIR(graph, &*g);
+
+  auto in = at::randint(-5, 5, {4, 4}, at::kCUDA);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
+
+TEST(Converters, ATenMeanHigherDimensionConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %4 : None = prim::Constant()
+      %5 : Tensor = aten::mean(%0, %4)
+      return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::script::parseIR(graph, &*g);
+
+  auto in = at::randint(-5, 5, {4, 4, 4, 4}, at::kCUDA);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
+
+TEST(Converters, ATenMeanRowConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor):
       %1 : int = prim::Constant[value=1]()
-      %2 : int[] = prim::ListConstruct(%1) 
+      %2 : int[] = prim::ListConstruct(%1)
       %3 : bool = prim::Constant[value=0]()
       %4 : None = prim::Constant()
-      %5 : Tensor = aten::mean(%0, %2, %3, %4) 
+      %5 : Tensor = aten::mean(%0, %2, %3, %4)
       return (%5))IR";
 
   auto g = std::make_shared<torch::jit::Graph>();
@@ -24,7 +66,7 @@ TEST(Converters, ATenMeanConvertsCorrectly) {
   in = at::clone(in);
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
-  
+
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
 }
 
@@ -32,10 +74,10 @@ TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor):
       %1 : int = prim::Constant[value=1]()
-      %2 : int[] = prim::ListConstruct(%1) 
+      %2 : int[] = prim::ListConstruct(%1)
      %3 : bool = prim::Constant[value=1]()
       %4 : None = prim::Constant()
-      %5 : Tensor = aten::mean(%0, %2, %3, %4) 
+      %5 : Tensor = aten::mean(%0, %2, %3, %4)
       return (%5))IR";
 
   auto g = std::make_shared<torch::jit::Graph>();
@@ -48,6 +90,6 @@ TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
   in = at::clone(in);
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
-  
+
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
 }
\ No newline at end of file
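For reference, the new aten::mean pattern above reduces over every axis: it builds a bitmask with one bit per input dimension and passes it to addReduce with ReduceOperation::kAVG and keep_dims=false, while the dtype argument is ignored (hence the LOG_WARNING). Below is a minimal standalone sketch of that bitmask arithmetic only, with no TensorRT dependency; the helper name reduce_all_axes is hypothetical and chosen here just for illustration.

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the axis-mask expression in the converter above:
// one bit set per input dimension.
static uint32_t reduce_all_axes(const std::vector<int64_t>& dims) {
  return (uint32_t)(((uint64_t)1 << dims.size()) - 1);
}

int main() {
  // A rank-4 input (e.g. the {4, 4, 4, 4} tensor from the tests)
  // yields 0b1111: every dimension participates in the kAVG reduction.
  assert(reduce_all_axes({4, 4, 4, 4}) == 0xF);
  // A rank-2 input (the {4, 4} case) yields 0b11.
  assert(reduce_all_axes({4, 4}) == 0x3);
  return 0;
}

In the converter itself this mask feeds nvinfer1::INetworkDefinition::addReduce, and with keep_dims=false the result collapses to a single value, matching aten::mean with no dim argument.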