From c275dd0b72cb591c71837c3d2a56ff7f6e92f0ea Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Sat, 7 May 2022 19:53:17 -0700 Subject: [PATCH 1/5] tests: Update to new exception operator Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- tests/core/lowering/test_exception_elimination_pass.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/core/lowering/test_exception_elimination_pass.cpp b/tests/core/lowering/test_exception_elimination_pass.cpp index e8a96ca97a..441582ef0f 100644 --- a/tests/core/lowering/test_exception_elimination_pass.cpp +++ b/tests/core/lowering/test_exception_elimination_pass.cpp @@ -42,7 +42,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) { g->insertNode(bool_node); auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0); auto if_block0 = if_node->addBlock(); - auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0); + auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); if_block0->appendNode(exception_node); auto if_block1 = if_node->addBlock(); g->insertNode(if_node); @@ -50,7 +50,9 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) { g->insertNode(cat_node); g->registerOutput(cat_node->output()); + std::cout << "Source Graph: " << *g << std::endl; torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g); + std::cout << "Modified Graph: " << *g << std::endl; for (auto node : g->nodes()) { EXPECT_NE(node, if_node); } @@ -97,14 +99,16 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block1) { auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0); auto if_block0 = if_node->addBlock(); auto if_block1 = if_node->addBlock(); - auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0); + auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); if_block1->appendNode(exception_node); g->insertNode(if_node); auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val}); g->insertNode(cat_node); g->registerOutput(cat_node->output()); + std::cout << "Source Graph: " << *g << std::endl; torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g); + std::cout << "Modified Graph: " << *g << std::endl; for (auto node : g->nodes()) { EXPECT_NE(node, if_node); } From 828d12056d31233f3ad5623b2d212f081e3f7dd5 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Sat, 7 May 2022 19:54:59 -0700 Subject: [PATCH 2/5] tests: Update fp16 test for new API Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- tests/accuracy/test_fp16_accuracy.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/accuracy/test_fp16_accuracy.cpp b/tests/accuracy/test_fp16_accuracy.cpp index dd68202312..f32f8c1df0 100644 --- a/tests/accuracy/test_fp16_accuracy.cpp +++ b/tests/accuracy/test_fp16_accuracy.cpp @@ -25,8 +25,10 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) { } torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100; - std::vector> input_shape = {{32, 3, 32, 32}}; - auto compile_spec = torch_tensorrt::ts::CompileSpec({input_shape}); + std::vector input_shape = {32, 3, 32, 32}; + auto input = torch_tensorrt::Input(input_shape); + input.dtype = torch::kF16; + auto compile_spec = torch_tensorrt::ts::CompileSpec({input}); compile_spec.enabled_precisions.insert(torch::kF16); auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec); From 65dbf90a130ce545d567e6f6b6baa07701b554ba Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Sat, 7 May 2022 19:55:46 -0700 Subject: [PATCH 3/5] feat(aten::add): adding string concat evaluator Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/evaluators/aten.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 30cdeaa46a..018b565421 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -342,6 +342,10 @@ auto aten_registrations TORCHTRT_UNUSED = auto a = args.at(n->input(0)).unwrapToDouble(); auto b = args.at(n->input(1)).unwrapToDouble(); return a + b; + } else if (args.at(n->input(0)).IValue()->isString()) { + auto a = args.at(n->input(0)).unwrapToString(); + auto b = args.at(n->input(1)).unwrapToString(); + return a + b; } else { TORCHTRT_THROW_ERROR( "Unimplemented data type for aten::add evaluator: " @@ -349,8 +353,11 @@ auto aten_registrations TORCHTRT_UNUSED = return {}; } }, - EvalOptions().validSchemas( - {"aten::add.int(int a, int b) -> (int)", "aten::add.float(float a, float b) -> (float)"})}) + EvalOptions().validSchemas({ + "aten::add.int(int a, int b) -> (int)", + "aten::add.float(float a, float b) -> (float)", + "aten::add.str(str a, str b) -> (str)" + })}) .evaluator({c10::Symbol::fromQualString("aten::add_"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isList()) { From 79b8392aa1ce1fe3168b5c96dde02b9116451e87 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Sat, 7 May 2022 20:42:30 -0700 Subject: [PATCH 4/5] refactor(//core/lowering): Make logic a bit clearer in EE pass Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- .../lowering/passes/exception_elimination.cpp | 37 +++++++++---------- .../test_exception_elimination_pass.cpp | 6 +-- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/core/lowering/passes/exception_elimination.cpp b/core/lowering/passes/exception_elimination.cpp index a2cc5cc694..77aec78b08 100644 --- a/core/lowering/passes/exception_elimination.cpp +++ b/core/lowering/passes/exception_elimination.cpp @@ -44,10 +44,10 @@ struct ExceptionOrPassPatternElimination { bool arm1_starts_with_exception = (*arm1_start)->kind() == prim::RaiseException; bool arm2_starts_with_exception = (*arm2_start)->kind() == prim::RaiseException; - if (!arm1_starts_with_exception && !arm2_starts_with_exception) { + //if (!arm1_starts_with_exception && !arm2_starts_with_exception) { // Neither arm matches the pattern - return false; - } + // return false; + //} /// Check if this Node hosts a pattern like so: /// = prim::If(%5958) @@ -57,14 +57,12 @@ struct ExceptionOrPassPatternElimination { /// block1(): /// -> () if (arm1_starts_with_exception) { - if ((*(++arm1_start))->kind() != prim::Return) { + if ((*(++arm1_start))->kind() == prim::Return) { // Make sure that block0 is solely just the exception and the return - return false; - } - - if ((*(arm2_start))->kind() != prim::Return) { - // Make sure that block1 is solely the return - return false; + if ((*(arm2_start))->kind() == prim::Return) { + // Make sure that block1 is solely the return + return true; + } } } @@ -76,25 +74,23 @@ struct ExceptionOrPassPatternElimination { /// = prim::RaiseException(%45) /// -> () if (arm2_starts_with_exception) { - if ((*(++arm2_start))->kind() != prim::Return) { + if ((*(++arm2_start))->kind() == prim::Return) { // Make sure that block1 is solely just the exception and the return - return false; - } - - if ((*(arm1_start))->kind() != prim::Return) { - // Make sure that block0 is solely the return - return false; + if ((*(arm1_start))->kind() == prim::Return) { + // Make sure that block0 is solely the return + return true; + } } } - return true; + return false; } void findExceptionOrPassNodes(Block* b) { for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { auto n = *it; if (n->kind() == prim::If && isExceptionOrPassNode(n)) { - LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl); + LOG_ERROR("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl); it.destroyCurrent(); } } @@ -107,6 +103,9 @@ struct ExceptionOrPassPatternElimination { void EliminateExceptionOrPassPattern(std::shared_ptr graph) { ExceptionOrPassPatternElimination eppe(std::move(graph)); eppe.run(); + if (graph) { + LOG_ERROR("Post Eliminate Exception or Pass Patterns: " << *graph); + } } } // namespace passes diff --git a/tests/core/lowering/test_exception_elimination_pass.cpp b/tests/core/lowering/test_exception_elimination_pass.cpp index 441582ef0f..5a0931ee8d 100644 --- a/tests/core/lowering/test_exception_elimination_pass.cpp +++ b/tests/core/lowering/test_exception_elimination_pass.cpp @@ -44,7 +44,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) { auto if_block0 = if_node->addBlock(); auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); if_block0->appendNode(exception_node); - auto if_block1 = if_node->addBlock(); + /*auto if_block1 =*/ if_node->addBlock(); g->insertNode(if_node); auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val}); g->insertNode(cat_node); @@ -97,7 +97,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block1) { bool_node->output()->setType(torch::jit::BoolType::get()); g->insertNode(bool_node); auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0); - auto if_block0 = if_node->addBlock(); + /*auto if_block0 = */if_node->addBlock(); auto if_block1 = if_node->addBlock(); auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); if_block1->appendNode(exception_node); @@ -154,7 +154,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) { auto if_block0 = if_node->addBlock(); auto append_node = g->create(torch::jit::aten::append, {list_node->output(), y}); if_block0->appendNode(append_node); - auto if_block1 = if_node->addBlock(); + /*auto if_block1 = */if_node->addBlock(); g->insertNode(if_node); auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val}); g->insertNode(cat_node); From b89a6f501306c1b7e7b7cbb32a2aa446dfbf5743 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Sat, 7 May 2022 20:51:43 -0700 Subject: [PATCH 5/5] refactor: apply linting Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/evaluators/aten.cpp | 8 +++----- core/lowering/passes/exception_elimination.cpp | 6 +++--- core/lowering/register_trt_placeholder_ops.cpp | 5 ++++- core/partitioning/partitioning.cpp | 0 tests/core/lowering/test_exception_elimination_pass.cpp | 6 +++--- 5 files changed, 13 insertions(+), 12 deletions(-) mode change 100755 => 100644 core/partitioning/partitioning.cpp diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 018b565421..4632744790 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -353,11 +353,9 @@ auto aten_registrations TORCHTRT_UNUSED = return {}; } }, - EvalOptions().validSchemas({ - "aten::add.int(int a, int b) -> (int)", - "aten::add.float(float a, float b) -> (float)", - "aten::add.str(str a, str b) -> (str)" - })}) + EvalOptions().validSchemas({"aten::add.int(int a, int b) -> (int)", + "aten::add.float(float a, float b) -> (float)", + "aten::add.str(str a, str b) -> (str)"})}) .evaluator({c10::Symbol::fromQualString("aten::add_"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isList()) { diff --git a/core/lowering/passes/exception_elimination.cpp b/core/lowering/passes/exception_elimination.cpp index 77aec78b08..abed644151 100644 --- a/core/lowering/passes/exception_elimination.cpp +++ b/core/lowering/passes/exception_elimination.cpp @@ -44,9 +44,9 @@ struct ExceptionOrPassPatternElimination { bool arm1_starts_with_exception = (*arm1_start)->kind() == prim::RaiseException; bool arm2_starts_with_exception = (*arm2_start)->kind() == prim::RaiseException; - //if (!arm1_starts_with_exception && !arm2_starts_with_exception) { - // Neither arm matches the pattern - // return false; + // if (!arm1_starts_with_exception && !arm2_starts_with_exception) { + // Neither arm matches the pattern + // return false; //} /// Check if this Node hosts a pattern like so: diff --git a/core/lowering/register_trt_placeholder_ops.cpp b/core/lowering/register_trt_placeholder_ops.cpp index 5ba8171208..17d7d3f47a 100644 --- a/core/lowering/register_trt_placeholder_ops.cpp +++ b/core/lowering/register_trt_placeholder_ops.cpp @@ -10,7 +10,10 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() { RegisterOperators trt_placeholder_ops_reg({ /// Op marks a Tensor to be conveted from an Torch Tensor /// to a TRT constant Tensor - Operator("trt::const(Tensor val) -> Tensor", [](Stack& stack) { /*noop*/ }, aliasAnalysisFromSchema()), + Operator( + "trt::const(Tensor val) -> Tensor", + [](Stack& stack) { /*noop*/ }, + aliasAnalysisFromSchema()), }); } // namespace jit diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp old mode 100755 new mode 100644 diff --git a/tests/core/lowering/test_exception_elimination_pass.cpp b/tests/core/lowering/test_exception_elimination_pass.cpp index 5a0931ee8d..b7e4ac00d1 100644 --- a/tests/core/lowering/test_exception_elimination_pass.cpp +++ b/tests/core/lowering/test_exception_elimination_pass.cpp @@ -44,7 +44,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) { auto if_block0 = if_node->addBlock(); auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); if_block0->appendNode(exception_node); - /*auto if_block1 =*/ if_node->addBlock(); + /*auto if_block1 =*/if_node->addBlock(); g->insertNode(if_node); auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val}); g->insertNode(cat_node); @@ -97,7 +97,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block1) { bool_node->output()->setType(torch::jit::BoolType::get()); g->insertNode(bool_node); auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0); - /*auto if_block0 = */if_node->addBlock(); + /*auto if_block0 = */ if_node->addBlock(); auto if_block1 = if_node->addBlock(); auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0); if_block1->appendNode(exception_node); @@ -154,7 +154,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) { auto if_block0 = if_node->addBlock(); auto append_node = g->create(torch::jit::aten::append, {list_node->output(), y}); if_block0->appendNode(append_node); - /*auto if_block1 = */if_node->addBlock(); + /*auto if_block1 = */ if_node->addBlock(); g->insertNode(if_node); auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val}); g->insertNode(cat_node);