diff --git a/symengine/subs.h b/symengine/subs.h index 11d449df0b..06012780ad 100644 --- a/symengine/subs.h +++ b/symengine/subs.h @@ -276,6 +276,229 @@ class SubsVisitor : public BaseVisitor : BaseVisitor(subs_dict_) { } + void bvisit(const Mul &x) + { + RCP coef = x.get_coef(); + map_basic_basic dict = x.get_dict(); + map_basic_basic d; + bool fast_exec = false; + for (const auto &p : x.get_dict()) { + RCP factor_old; + if (eq(*p.second, *one)) { + factor_old = p.first; + } else { + factor_old = make_rcp(p.first, p.second); + } + RCP factor = apply(factor_old); + if (factor == factor_old) { + Mul::dict_add_term_new(outArg(coef), d, p.second, p.first); + } else if (is_a_Number(*factor)) { + fast_exec = true; + if (down_cast(*factor).is_zero()) { + result_ = factor; + return; + } + imulnum(outArg(coef), rcp_static_cast(factor)); + } else if (is_a(*factor)) { + fast_exec = true; + RCP tmp = rcp_static_cast(factor); + imulnum(outArg(coef), tmp->get_coef()); + for (const auto &q : tmp->get_dict()) { + Mul::dict_add_term_new(outArg(coef), d, q.second, q.first); + } + } else { + fast_exec = true; + RCP exp, t; + Mul::as_base_exp(factor, outArg(exp), outArg(t)); + Mul::dict_add_term_new(outArg(coef), d, exp, t); + } + } + if (fast_exec) { + result_ = Mul::from_dict(coef, std::move(d)); + return; + } + for (const auto &iter : subs_dict_) { + d.clear(); + bool exists = true; + auto sub1 = iter.first; + auto rep = iter.second; + if (is_a(*sub1)) { + RCP subst = rcp_static_cast(sub1); + for (auto &p : subst->get_dict()) { + auto it = dict.find(p.first); + RCP diff_; + if (it != dict.end()) + diff_ = sub(it->second, p.second); + if (it == dict.end() + || down_cast(*diff_).is_negative()) { + exists = false; + break; + } else { + if (!down_cast(*diff_).is_zero()) + Mul::dict_add_term_new(outArg(coef), d, + sub(it->second, p.second), + p.first); + } + } + if (exists) { + for (const auto &p : dict) { + auto it = subst->get_dict().find(p.first); + if (it == subst->get_dict().end()) + Mul::dict_add_term_new(outArg(coef), d, p.second, + p.first); + } + if (is_a_Number(*rep)) { + if (down_cast(*rep).is_zero()) { + result_ = rep; + return; + } + imulnum(outArg(coef), + rcp_static_cast(rep)); + } else if (is_a(*rep)) { + RCP tmp = rcp_static_cast(rep); + imulnum(outArg(coef), tmp->get_coef()); + for (const auto &q : tmp->get_dict()) { + Mul::dict_add_term_new(outArg(coef), d, q.second, + q.first); + } + } else { + RCP exp, t; + Mul::as_base_exp(rep, outArg(exp), outArg(t)); + Mul::dict_add_term_new(outArg(coef), d, exp, t); + } + } else + d = x.get_dict(); + } else if (is_a(*sub1)) { + RCP subst = rcp_static_cast(sub1); + auto sub1_exp = subst->get_exp(); + auto sub1_base = subst->get_base(); + exists = false; + if (is_a_Number(*sub1_exp)) { + for (const auto &p : dict) { + auto diff_ = sub(p.second, sub1_exp); + if (eq(*sub1_base, *(p.first)) + and eq(*sub1_exp, *p.second)) { + exists = true; + } else if (eq(*sub1_base, *(p.first)) + and down_cast(*diff_) + .is_positive()) { + exists = true; + Mul::dict_add_term_new(outArg(coef), d, + sub(p.second, sub1_exp), + p.first); + } else { + Mul::dict_add_term_new(outArg(coef), d, p.second, + p.first); + } + } + } else { + for (const auto &p : dict) { + if (eq(*sub1_base, *(p.first)) + and eq(*sub1_exp, *p.second)) { + exists = true; + } else { + Mul::dict_add_term_new(outArg(coef), d, p.second, + p.first); + } + } + } + if (exists) { + if (is_a_Number(*rep)) { + if (down_cast(*rep).is_zero()) { + result_ = rep; + return; + } + imulnum(outArg(coef), + rcp_static_cast(rep)); + } else if (is_a(*rep)) { + RCP tmp = rcp_static_cast(rep); + imulnum(outArg(coef), tmp->get_coef()); + for (const auto &q : tmp->get_dict()) { + Mul::dict_add_term_new(outArg(coef), d, q.second, + q.first); + } + } else { + RCP exp, t; + Mul::as_base_exp(rep, outArg(exp), outArg(t)); + Mul::dict_add_term_new(outArg(coef), d, exp, t); + } + } else + d = x.get_dict(); + } else if (is_a(*sub1)) { + exists = false; + for (const auto &p : dict) { + if (eq(*sub1, *(p.first)) and eq(*one, *p.second)) { + exists = true; + } else if (eq(*sub1, *(p.first)) + and not eq(*one, *p.second)) { + exists = true; + Mul::dict_add_term_new(outArg(coef), d, + sub(p.second, one), p.first); + } else { + Mul::dict_add_term_new(outArg(coef), d, p.second, + p.first); + } + } + if (exists) { + if (is_a_Number(*rep)) { + if (down_cast(*rep).is_zero()) { + result_ = rep; + return; + } + imulnum(outArg(coef), + rcp_static_cast(rep)); + } else if (is_a(*rep)) { + RCP tmp = rcp_static_cast(rep); + imulnum(outArg(coef), tmp->get_coef()); + for (const auto &q : tmp->get_dict()) { + Mul::dict_add_term_new(outArg(coef), d, q.second, + q.first); + } + } else { + RCP exp, t; + Mul::as_base_exp(rep, outArg(exp), outArg(t)); + Mul::dict_add_term_new(outArg(coef), d, exp, t); + } + } else + d = x.get_dict(); + } else { + exists = false; + for (const auto &p : dict) { + if (eq(*sub1, *(p.first))) { + exists = true; + } else { + Mul::dict_add_term_new(outArg(coef), d, p.second, + p.first); + } + } + if (exists) { + if (is_a_Number(*rep)) { + if (down_cast(*rep).is_zero()) { + result_ = rep; + return; + } + imulnum(outArg(coef), + rcp_static_cast(rep)); + } else if (is_a(*rep)) { + RCP tmp = rcp_static_cast(rep); + imulnum(outArg(coef), tmp->get_coef()); + for (const auto &q : tmp->get_dict()) { + Mul::dict_add_term_new(outArg(coef), d, q.second, + q.first); + } + } else { + RCP exp, t; + Mul::as_base_exp(rep, outArg(exp), outArg(t)); + Mul::dict_add_term_new(outArg(coef), d, exp, t); + } + } else + d = x.get_dict(); + } + dict.clear(); + dict.insert(d.begin(), d.end()); + } + result_ = Mul::from_dict(coef, std::move(d)); + } void bvisit(const Pow &x) { diff --git a/symengine/tests/basic/test_subs.cpp b/symengine/tests/basic/test_subs.cpp index d39bc29dbf..d06309969e 100644 --- a/symengine/tests/basic/test_subs.cpp +++ b/symengine/tests/basic/test_subs.cpp @@ -152,6 +152,77 @@ TEST_CASE("Mul: subs", "[subs]") r2 = z; REQUIRE(eq(*r1->subs(d), *r2)); + d.clear(); + d[mul(x, y)] = z; + r1 = mul(mul(x, y), z); + r2 = pow(z, i2); + REQUIRE(eq(*r1->subs(d), *r2)); + + d.clear(); + d[mul(x, y)] = mul(y, z); + r1 = mul(mul(x, y), z); + r2 = mul(y, pow(z, i2)); + REQUIRE(eq(*r1->subs(d), *r2)); + + d.clear(); + d[mul(x, y)] = i4; + r1 = mul(mul(x, y), z); + r2 = mul(z, i4); + REQUIRE(eq(*r1->subs(d), *r2)); + + d.clear(); + d[mul(x, y)] = z; + r1 = add(add(mul(mul(pow(y, i2), x), i2), mul(i3, pow(x, i2))), + mul(i4, pow(y, i2))); + r2 = add(add(mul(i3, pow(x, i2)), mul(i4, pow(y, i2))), mul(mul(i2, y), z)); + REQUIRE(eq(*r1->subs(d), *r2)); + + d.clear(); + d[mul(mul(x, y), z)] = i2; + r1 = add( + add(mul(mul(mul(pow(y, i3), x), i2), pow(z, i2)), mul(i3, pow(x, i2))), + mul(i4, pow(y, i2))); + r2 = add(add(mul(i3, pow(x, i2)), mul(i4, pow(y, i2))), + mul(mul(pow(y, i2), z), i4)); + REQUIRE(eq(*r1->subs(d), *r2)); + + d.clear(); + d[mul(x, y)] = z; + r1 = mul(x, z); + r2 = mul(x, z); + REQUIRE(eq(*r1->subs(d), *r2)); + + d.clear(); + d[pow(x, y)] = z; + r1 = mul(z, pow(x, y)); + r2 = pow(z, i2); + REQUIRE(eq(*r1->subs(d), *r2)); + + d.clear(); + d[pow(x, i2)] = z; + r1 = mul(z, pow(x, i2)); + r2 = pow(z, i2); + REQUIRE(eq(*r1->subs(d), *r2)); + + d.clear(); + d[mul(y, pow(z, i2))] = x; + r1 = add(add(mul(mul(x, y), pow(z, i2)), mul(mul(z, y), pow(x, i2))), + mul(y, pow(z, i3))); + r2 = add(add(pow(x, i2), mul(mul(z, y), pow(x, i2))), mul(x, z)); + REQUIRE(eq(*r1->subs(d), *r2)); + + d.clear(); + d[x] = i2; + r1 = mul(pow(x, i2), y); + r2 = mul(i4, y); + REQUIRE(eq(*r1->subs(d), *r2)); + + d.clear(); + d[x] = i2; + r1 = add(mul(pow(x, i2), y), mul(pow(x, i2), z)); + r2 = add(mul(i4, y), mul(i4, z)); + REQUIRE(eq(*r1->subs(d), *r2)); + d.clear(); d[pow(x, y)] = z; r1 = mul(i2, pow(x, y));