From 99cea1b6d5f20083dc4e818411988a34ee4fb4cf Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Thu, 5 May 2022 16:43:47 -0700 Subject: [PATCH] fix: Resolve issues in exception elimination pass Signed-off-by: Michael Feliz --- .../lowering/passes/exception_elimination.cpp | 12 +- tests/core/lowering/BUILD | 5 + .../test_exception_elimination_pass.cpp | 167 ++++++++++++++++++ 3 files changed, 182 insertions(+), 2 deletions(-) create mode 100644 tests/core/lowering/test_exception_elimination_pass.cpp diff --git a/core/lowering/passes/exception_elimination.cpp b/core/lowering/passes/exception_elimination.cpp index cdd7603792..a2cc5cc694 100644 --- a/core/lowering/passes/exception_elimination.cpp +++ b/core/lowering/passes/exception_elimination.cpp @@ -41,6 +41,14 @@ struct ExceptionOrPassPatternElimination { auto arm1_start = arm1->nodes().begin(); auto arm2_start = arm2->nodes().begin(); + bool arm1_starts_with_exception = (*arm1_start)->kind() == prim::RaiseException; + bool arm2_starts_with_exception = (*arm2_start)->kind() == prim::RaiseException; + + if (!arm1_starts_with_exception && !arm2_starts_with_exception) { + // Neither arm matches the pattern + return false; + } + /// Check if this Node hosts a pattern like so: /// = prim::If(%5958) /// block0(): @@ -48,7 +56,7 @@ struct ExceptionOrPassPatternElimination { /// -> () /// block1(): /// -> () - if ((*arm1_start)->kind() == prim::RaiseException) { + if (arm1_starts_with_exception) { if ((*(++arm1_start))->kind() != prim::Return) { // Make sure that block0 is solely just the exception and the return return false; @@ -67,7 +75,7 @@ struct ExceptionOrPassPatternElimination { /// block1(): /// = prim::RaiseException(%45) /// -> () - if ((*arm2_start)->kind() == prim::RaiseException) { + if (arm2_starts_with_exception) { if ((*(++arm2_start))->kind() != prim::Return) { // Make sure that block1 is solely just the exception and the return return false; diff --git a/tests/core/lowering/BUILD
b/tests/core/lowering/BUILD index 26940d11eb..526fe467fa 100644 --- a/tests/core/lowering/BUILD +++ b/tests/core/lowering/BUILD @@ -30,6 +30,10 @@ lowering_test( name = "test_conv1d_pass", ) +lowering_test( + name = "test_exception_elimination_pass", +) + lowering_test( name = "test_remove_contiguous_pass", ) @@ -82,6 +86,7 @@ test_suite( name = "lowering_tests", tests = [ ":test_conv1d_pass", + ":test_exception_elimination_pass", ":test_linear_to_addmm", ":test_module_fallback_passes", ":test_operator_aliasing_pass", diff --git a/tests/core/lowering/test_exception_elimination_pass.cpp new file mode 100644 index 0000000000..e8a96ca97a --- /dev/null +++ b/tests/core/lowering/test_exception_elimination_pass.cpp @@ -0,0 +1,167 @@ +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) { + // parseIR does not support " = prim::If(%51)" with no return value + /*std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor): + %3 : NoneType = prim::Constant() + %4 : int = prim::Constant[value=0]() + %mod_list.1 : Tensor[] = prim::ListConstruct(%x.1) + %47 : Tensor = aten::sum(%x.1, %3) + %49 : Tensor = aten::sum(%y.1, %3) + %50 : Tensor = aten::gt(%47, %49) + %51 : bool = aten::Bool(%50) + = prim::If(%51) + block0(): + = prim::RaiseException(%45) + -> () + block1(): + -> () + %z.1 : Tensor = aten::cat(%mod_list.1, %4) + return (%z.1))IR";*/ + + auto g = std::make_shared<torch::jit::Graph>(); + auto x = g->insertInput(0, "x"); + auto y = g->insertInput(1, "y"); + torch::jit::IValue zero(0); + auto zero_const_val = g->insertConstant(zero); + auto none_const_val = g->insertConstant(torch::jit::IValue()); + torch::jit::IValue except("EXCEPTION"); + auto except_val = g->insertConstant(except); + auto list_node = g->createList(x->type(), torch::jit::ArrayRef<torch::jit::Value*>(x)); + g->insertNode(list_node); + auto sum_x_node =
g->create(torch::jit::aten::sum, {x, none_const_val}); + g->insertNode(sum_x_node); + auto sum_y_node = g->create(torch::jit::aten::sum, {y, none_const_val}); + g->insertNode(sum_y_node); + auto gt_node = g->create(torch::jit::aten::gt, {sum_x_node->output(), sum_y_node->output()}); + g->insertNode(gt_node); + auto bool_node = g->create(torch::jit::aten::Bool, {gt_node->output()}); + bool_node->output()->setType(torch::jit::BoolType::get()); + g->insertNode(bool_node); + auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0); + auto if_block0 = if_node->addBlock(); + auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0); + if_block0->appendNode(exception_node); + auto if_block1 = if_node->addBlock(); + g->insertNode(if_node); + auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val}); + g->insertNode(cat_node); + g->registerOutput(cat_node->output()); + + torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g); + for (auto node : g->nodes()) { + EXPECT_NE(node, if_node); + } +} + +TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block1) { + // parseIR does not support " = prim::If(%51)" with no return value + /*std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor): + %3 : NoneType = prim::Constant() + %4 : int = prim::Constant[value=0]() + %mod_list.1 : Tensor[] = prim::ListConstruct(%x.1) + %47 : Tensor = aten::sum(%x.1, %3) + %49 : Tensor = aten::sum(%y.1, %3) + %50 : Tensor = aten::gt(%47, %49) + %51 : bool = aten::Bool(%50) + = prim::If(%51) + block0(): + -> () + block1(): + = prim::RaiseException(%45) + -> () + %z.1 : Tensor = aten::cat(%mod_list.1, %4) + return (%z.1))IR";*/ + + auto g = std::make_shared<torch::jit::Graph>(); + auto x = g->insertInput(0, "x"); + auto y = g->insertInput(1, "y"); + torch::jit::IValue zero(0); + auto zero_const_val = g->insertConstant(zero); + auto none_const_val = g->insertConstant(torch::jit::IValue()); + torch::jit::IValue
except("EXCEPTION"); + auto except_val = g->insertConstant(except); + auto list_node = g->createList(x->type(), torch::jit::ArrayRef<torch::jit::Value*>(x)); + g->insertNode(list_node); + auto sum_x_node = g->create(torch::jit::aten::sum, {x, none_const_val}); + g->insertNode(sum_x_node); + auto sum_y_node = g->create(torch::jit::aten::sum, {y, none_const_val}); + g->insertNode(sum_y_node); + auto gt_node = g->create(torch::jit::aten::gt, {sum_x_node->output(), sum_y_node->output()}); + g->insertNode(gt_node); + auto bool_node = g->create(torch::jit::aten::Bool, {gt_node->output()}); + bool_node->output()->setType(torch::jit::BoolType::get()); + g->insertNode(bool_node); + auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0); + auto if_block0 = if_node->addBlock(); + auto if_block1 = if_node->addBlock(); + auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0); + if_block1->appendNode(exception_node); + g->insertNode(if_node); + auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val}); + g->insertNode(cat_node); + g->registerOutput(cat_node->output()); + + torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g); + for (auto node : g->nodes()) { + EXPECT_NE(node, if_node); + } +} + +TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) { + // parseIR does not support " = prim::If(%51)" with no return value + /*std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor): + %3 : NoneType = prim::Constant() + %4 : int = prim::Constant[value=0]() + %mod_list.1 : Tensor[] = prim::ListConstruct(%x.1) + %47 : Tensor = aten::sum(%x.1, %3) + %49 : Tensor = aten::sum(%y.1, %3) + %50 : Tensor = aten::gt(%47, %49) + %51 : bool = aten::Bool(%50) + = prim::If(%51) + block0(): + %10 : Tensor[] = aten::append(%mod_list.1, %y.1) + -> () + block1(): + -> () + %z.1 : Tensor = aten::cat(%mod_list.1, %4) + return (%z.1))IR";*/ + + auto g = std::make_shared<torch::jit::Graph>(); + auto x = g->insertInput(0, "x"); +
auto y = g->insertInput(1, "y"); + torch::jit::IValue zero(0); + auto zero_const_val = g->insertConstant(zero); + auto none_const_val = g->insertConstant(torch::jit::IValue()); + auto list_node = g->createList(x->type(), torch::jit::ArrayRef<torch::jit::Value*>(x)); + g->insertNode(list_node); + auto sum_x_node = g->create(torch::jit::aten::sum, {x, none_const_val}); + g->insertNode(sum_x_node); + auto sum_y_node = g->create(torch::jit::aten::sum, {y, none_const_val}); + g->insertNode(sum_y_node); + auto gt_node = g->create(torch::jit::aten::gt, {sum_x_node->output(), sum_y_node->output()}); + g->insertNode(gt_node); + auto bool_node = g->create(torch::jit::aten::Bool, {gt_node->output()}); + bool_node->output()->setType(torch::jit::BoolType::get()); + g->insertNode(bool_node); + auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0); + auto if_block0 = if_node->addBlock(); + auto append_node = g->create(torch::jit::aten::append, {list_node->output(), y}); + if_block0->appendNode(append_node); + auto if_block1 = if_node->addBlock(); + g->insertNode(if_node); + auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val}); + g->insertNode(cat_node); + g->registerOutput(cat_node->output()); + + torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g); + int if_count = 0; + for (auto node : g->nodes()) { + if (node == if_node) { + if_count++; + } + } + EXPECT_EQ(1, if_count); +} \ No newline at end of file