diff --git a/symengine/assumptions.cpp b/symengine/assumptions.cpp index 4287e707ba..df44b24ace 100644 --- a/symengine/assumptions.cpp +++ b/symengine/assumptions.cpp @@ -14,7 +14,9 @@ Assumptions::Assumptions(const set_basic &statements) const auto expr = contains.get_expr(); const auto set = contains.get_set(); if (is_a(*expr)) { - if (is_a(*set)) { + if (is_a(*set)) { + complex_symbols_.insert(expr); + } else if (is_a(*set)) { complex_symbols_.insert(expr); real_symbols_.insert(expr); } else if (is_a(*set)) { diff --git a/symengine/tests/basic/test_assumptions.cpp b/symengine/tests/basic/test_assumptions.cpp index 27b0f8a6db..8f218e84a2 100644 --- a/symengine/tests/basic/test_assumptions.cpp +++ b/symengine/tests/basic/test_assumptions.cpp @@ -4,6 +4,7 @@ using SymEngine::Assumptions; using SymEngine::Basic; +using SymEngine::complexes; using SymEngine::integer; using SymEngine::integers; using SymEngine::Number; @@ -39,6 +40,10 @@ TEST_CASE("Test assumptions", "[assumptions]") RCP rel14 = Ne(x, integer(0)); RCP rel; + Assumptions a = Assumptions({complexes()->contains(x)}); + REQUIRE(is_true(a.is_complex(x))); + REQUIRE(is_indeterminate(a.is_real(x))); + auto a1 = Assumptions({s1->contains(x)}); REQUIRE(is_true(a1.is_real(x))); REQUIRE(is_indeterminate(a1.is_integer(x))); @@ -204,7 +209,7 @@ TEST_CASE("Test assumptions", "[assumptions]") REQUIRE(is_false(a18.is_zero(x))); rel = Eq(x, integer(1)); - auto a = Assumptions({rel}); + a = Assumptions({rel}); REQUIRE(is_true(a.is_nonzero(x))); REQUIRE(is_false(a.is_zero(x)));