From 89cf1679e9cb537fbb79ec3e9a0a4cd2ae8a6b9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Ingvar=20Dahlgren?= Date: Tue, 14 Mar 2023 16:26:27 +0100 Subject: [PATCH 1/5] Handle Piecewise in TransformVisitor (for e.g. cse) --- CMakeLists.txt | 4 +++- symengine/tests/basic/test_cse.cpp | 17 +++++++++++++++++ symengine/visitor.cpp | 25 +++++++++++++++++++++++++ symengine/visitor.h | 1 + 4 files changed, 46 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ce6968bc88..02f4e7c010 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,7 @@ cmake_minimum_required(VERSION 2.8.12) - +if (POLICY CMP0074) + cmake_policy(SET CMP0074 NEW) +endif() set(CMAKE_USER_MAKE_RULES_OVERRIDE ${CMAKE_CURRENT_SOURCE_DIR}/cmake/UserOverride.cmake) project(symengine LANGUAGES C CXX) diff --git a/symengine/tests/basic/test_cse.cpp b/symengine/tests/basic/test_cse.cpp index 3be6bead52..da37bd8940 100644 --- a/symengine/tests/basic/test_cse.cpp +++ b/symengine/tests/basic/test_cse.cpp @@ -1,19 +1,25 @@ #include "catch.hpp" +#include "symengine/dict.h" +#include "symengine/logic.h" #include #include #include #include #include +#include using SymEngine::add; using SymEngine::Basic; +using SymEngine::boolTrue; using SymEngine::cse; using SymEngine::div; +using SymEngine::Gt; using SymEngine::integer; using SymEngine::mul; using SymEngine::neg; using SymEngine::one; +using SymEngine::piecewise; using SymEngine::pow; using SymEngine::RCP; using SymEngine::sin; @@ -302,6 +308,17 @@ TEST_CASE("CSE: simple", "[cse]") REQUIRE(unified_eq(substs, {{x0, add(x, y)}, {x1, add(x0, z)}})); REQUIRE(unified_eq(reduced, {x0, add(i2, x0), x1, add(i3, x1)})); } + { + auto pw1 = piecewise( + {{pow(add(x, y), i2), Gt(x, y)}, {sqrt(add(x, y)), boolTrue}}); + + vec_pair substs; + vec_basic reduced; + cse(substs, reduced, {pw1}); + REQUIRE(unified_eq(substs, {{x0, add(x, y)}})); + REQUIRE(unified_eq(reduced, {piecewise({{pow(x0, i2), Gt(x, y)}, + {sqrt(x0), boolTrue}})})); + } } TEST_CASE("CSE: regression test gh-1463", "[cse]") diff --git a/symengine/visitor.cpp b/symengine/visitor.cpp index 1220f9681b..ed160591c8 100644 --- a/symengine/visitor.cpp +++ b/symengine/visitor.cpp @@ -201,6 +201,31 @@ void TransformVisitor::bvisit(const MultiArgFunction &x) result_ = nbarg; } +void TransformVisitor::bvisit(const Piecewise &x) +{ + auto branch_cond_pairs = x.get_vec(); + PiecewiseVec new_pairs; + bool changed = false; + for (const auto &branch_cond : branch_cond_pairs) { + auto branch = branch_cond.first; + auto cond = branch_cond.second; + auto new_branch = apply(branch); + // decltype(cond) new_cond = + // rcp_static_cast(apply(rcp_static_cast(cond))); + if (!changed) { + changed |= !eq(*new_branch, *branch); + // changed |= !eq(*new_cond, *cond); + } + new_pairs.push_back({new_branch, cond /*new_cond*/}); + } + if (changed) { + result_ = piecewise(new_pairs); + } else { + result_ = x.rcp_from_this(); + } +} + void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v) { b.accept(v); diff --git a/symengine/visitor.h b/symengine/visitor.h index f201bf1350..5cec4ea29f 100644 --- a/symengine/visitor.h +++ b/symengine/visitor.h @@ -262,6 +262,7 @@ class TransformVisitor : public BaseVisitor } void bvisit(const MultiArgFunction &x); + void bvisit(const Piecewise &x); }; template From fed09385b8b97d479602cebef5d661b81a544f5e Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Tue, 14 Mar 2023 15:21:45 -0500 Subject: [PATCH 2/5] handle replacing condition --- symengine/visitor.cpp | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/symengine/visitor.cpp b/symengine/visitor.cpp index ed160591c8..da46fec40f 100644 --- a/symengine/visitor.cpp +++ b/symengine/visitor.cpp @@ -205,25 +205,18 @@ void TransformVisitor::bvisit(const Piecewise &x) { auto branch_cond_pairs = x.get_vec(); PiecewiseVec new_pairs; - bool changed = false; for (const auto &branch_cond : branch_cond_pairs) { auto branch = branch_cond.first; auto cond = branch_cond.second; auto new_branch = apply(branch); - // decltype(cond) new_cond = - // rcp_static_cast(apply(rcp_static_cast(cond))); - if (!changed) { - changed |= !eq(*new_branch, *branch); - // changed |= !eq(*new_cond, *cond); + auto new_cond = apply(cond); + if (!is_a_Boolean(*new_cond)) { + new_cond = Eq(new_cond, boolTrue); } - new_pairs.push_back({new_branch, cond /*new_cond*/}); - } - if (changed) { - result_ = piecewise(new_pairs); - } else { - result_ = x.rcp_from_this(); + new_pairs.push_back( + {new_branch, rcp_static_cast(new_cond)}); } + result_ = piecewise(new_pairs); } void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v) From 0d291d930ddce647f3634b89d3597dce49a9db4c Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Tue, 14 Mar 2023 17:48:35 -0500 Subject: [PATCH 3/5] Do not replace atoms --- symengine/cse.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/symengine/cse.cpp b/symengine/cse.cpp index 6731fc80b8..08f8c276a8 100644 --- a/symengine/cse.cpp +++ b/symengine/cse.cpp @@ -501,7 +501,8 @@ void tree_cse(vec_pair &replacements, vec_basic &reduced_exprs, std::function & expr)> find_repeated; find_repeated = [&](RCP expr) -> void { - if (is_a_Number(*expr)) { + // Do not replace atoms + if (is_a_Number(*expr) or is_a(*expr)) { return; } From 6c644fac22053eb9e507908d6376b8c107810465 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Ingvar=20Dahlgren?= Date: Wed, 15 Mar 2023 12:05:00 +0100 Subject: [PATCH 4/5] Add test of cse in conditions of Piecewise --- symengine/tests/basic/test_cse.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/symengine/tests/basic/test_cse.cpp b/symengine/tests/basic/test_cse.cpp index da37bd8940..177a705fa9 100644 --- a/symengine/tests/basic/test_cse.cpp +++ b/symengine/tests/basic/test_cse.cpp @@ -319,6 +319,19 @@ TEST_CASE("CSE: simple", "[cse]") REQUIRE(unified_eq(reduced, {piecewise({{pow(x0, i2), Gt(x, y)}, {sqrt(x0), boolTrue}})})); } + { + auto pw2 = piecewise({{pow(x, i2), Gt(add(x, y), i2)}, + {sqrt(y), Gt(add(x, y), i3)}, + {sqrt(x), boolTrue}}); + + vec_pair substs; + vec_basic reduced; + cse(substs, reduced, {pw2}); + REQUIRE(unified_eq(substs, {{x0, add(x, y)}})); + REQUIRE(unified_eq(reduced, {piecewise({{pow(x, i2), Gt(x0, i2)}, + {sqrt(y), Gt(x0, i3)}, + {sqrt(x), boolTrue}})})); + } } TEST_CASE("CSE: regression test gh-1463", "[cse]") From 83877e48c65bfbf3a2c07ec40261586f3f251533 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Ingvar=20Dahlgren?= Date: Thu, 16 Mar 2023 09:15:33 +0100 Subject: [PATCH 5/5] drop auto-inserted includes --- symengine/tests/basic/test_cse.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/symengine/tests/basic/test_cse.cpp b/symengine/tests/basic/test_cse.cpp index 177a705fa9..b34dc1a496 100644 --- a/symengine/tests/basic/test_cse.cpp +++ b/symengine/tests/basic/test_cse.cpp @@ -1,6 +1,4 @@ #include "catch.hpp" -#include "symengine/dict.h" -#include "symengine/logic.h" #include #include