Lower aten::silu to x * sigmoid(x)

Summary of this patch (free text ahead of the first "diff --git" header):
  * core/lowering/passes/silu_to_sigmoid_multiplication.cpp (new): adds the
    pass SiluToSigmoidMultipication, which registers one rewrite pattern on a
    torch::jit::SubgraphRewriter and runs it on the graph. The pattern
    replaces `aten::silu(%x)` with `aten::mul(%x, aten::sigmoid(%x))`.
  * core/lowering/lowering.cpp: LowerGraph calls the new pass after
    AliasOperators and before EliminateDeadCode.
  * core/lowering/passes/passes.h, core/lowering/passes/BUILD: declare the
    pass and add its source file to the cc_library.
  * tests/core/lowering: new gtest that runs the pass on a one-node silu graph
    and asserts the sigmoid/mul target pattern matches the result.

NOTE(review): the identifier is spelled "Multipication" (missing "l"). The
  spelling is kept because the declaration, definition, call site and test
  must agree; rename all four together if it is corrected.
NOTE(review): the new pass contains a stray empty statement (`;` on its own
  line after the second raw string). Harmless, left untouched so hunk line
  counts and index hashes stay as recorded.
NOTE(review): the new test file has no trailing newline (see the marker at
  the end of this patch).
NOTE(review): indentation inside the hunks and the angle-bracketed tokens
  (`<torch::jit::Graph>`, the `#include <...>` targets) were reconstructed
  from a whitespace-collapsed copy; `<string>` in the test is a best guess.
  Confirm against the working tree before applying.

diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index 08ce69b72c..aec9ebb8e1 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -49,6 +49,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   passes::UnpackLogSoftmax(g);
   passes::RemoveNOPs(g);
   passes::AliasOperators(g);
+  passes::SiluToSigmoidMultipication(g);
   torch::jit::EliminateDeadCode(g);
   LOG_GRAPH(*g);
 }
diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD
index e22d0a59b1..f213a2539a 100644
--- a/core/lowering/passes/BUILD
+++ b/core/lowering/passes/BUILD
@@ -25,7 +25,8 @@ cc_library(
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
         "unpack_log_softmax.cpp",
-        "op_aliasing.cpp"
+        "op_aliasing.cpp",
+        "silu_to_sigmoid_multiplication.cpp"
     ],
     deps = [
         "//core/util:prelude",
diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index d6bf083a18..770982f67f 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -20,6 +20,7 @@ void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
+void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 
 } // namespace passes
 } // namespace lowering
diff --git a/core/lowering/passes/silu_to_sigmoid_multiplication.cpp b/core/lowering/passes/silu_to_sigmoid_multiplication.cpp
new file mode 100644
index 0000000000..782e659788
--- /dev/null
+++ b/core/lowering/passes/silu_to_sigmoid_multiplication.cpp
@@ -0,0 +1,31 @@
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string silu_pattern = R"IR(
+        graph(%x):
+            %1 : Tensor = aten::silu(%x)
+            return (%1))IR";
+  std::string sigmoid_multiplication_pattern = R"IR(
+        graph(%x):
+            %1 : Tensor = aten::sigmoid(%x)
+            %2 : Tensor = aten::mul(%x, %1)
+            return (%2))IR";
+  ;
+
+  torch::jit::SubgraphRewriter map_silu_to_sigmoid_multiplication;
+  map_silu_to_sigmoid_multiplication.RegisterRewritePattern(silu_pattern, sigmoid_multiplication_pattern);
+  map_silu_to_sigmoid_multiplication.runOnGraph(graph);
+  LOG_GRAPH("Post map silu -> x * sigmoid(x): " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch
diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index dd77960b8c..7742a07e06 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -23,6 +23,10 @@ lowering_test(
     name = "test_operator_aliasing_pass",
 )
 
+lowering_test(
+    name = "test_silu_to_sigmoid_multiplication",
+)
+
 test_suite(
     name = "lowering_tests",
     tests = [
diff --git a/tests/core/lowering/test_silu_to_sigmoid_multiplication.cpp b/tests/core/lowering/test_silu_to_sigmoid_multiplication.cpp
new file mode 100644
index 0000000000..fec02711ed
--- /dev/null
+++ b/tests/core/lowering/test_silu_to_sigmoid_multiplication.cpp
@@ -0,0 +1,29 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, RemoveSiluLowersCorrectly) {
+  std::string source_graph = R"IR(
+        graph(%x.1 : Tensor):
+            %2 : Tensor = aten::silu(%x.1)
+            return (%2))IR";
+  std::string target_graph = R"IR(
+        graph(%x.1):
+            %2 : Tensor = aten::sigmoid(%x.1)
+            %3 : Tensor = aten::mul(%x.1, %2)
+            return (%3))IR";
+
+  trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  trtorch::core::lowering::passes::SiluToSigmoidMultipication(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
\ No newline at end of file