From 08e749d481cd9cbdd28a7efbaf36309ad288c6ea Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Sun, 20 Jun 2021 15:44:33 -0500 Subject: [PATCH] Fix atan2(a, a) --- symengine/constants.cpp | 20 ++++---- symengine/functions.cpp | 62 ++++++++---------------- symengine/tests/basic/test_functions.cpp | 4 +- 3 files changed, 31 insertions(+), 55 deletions(-) diff --git a/symengine/constants.cpp b/symengine/constants.cpp index b900e84056..951894a2c4 100644 --- a/symengine/constants.cpp +++ b/symengine/constants.cpp @@ -111,20 +111,20 @@ umap_basic_basic inverse_cst = { }; umap_basic_basic inverse_tct = { - {div(one, sq3), mul(i2, i3)}, - {div(minus_one, sq3), mul(im2, i3)}, + {div(one, sq3), integer(6)}, + {div(minus_one, sq3), integer(-6)}, {sq3, i3}, {mul(minus_one, sq3), im3}, - {add(one, sq2), div(pow(i2, i3), i3)}, - {mul(minus_one, add(one, sq2)), div(pow(i2, i3), im3)}, - {sub(sq2, one), pow(i2, i3)}, - {sub(one, sq2), pow(im2, i3)}, - {sub(i2, sq3), mul(mul(i2, i2), i3)}, - {sub(sq3, i2), mul(mul(im2, i2), i3)}, + {add(one, sq2), div(integer(8), i3)}, + {mul(minus_one, add(one, sq2)), div(integer(8), im3)}, + {sub(sq2, one), integer(8)}, + {sub(one, sq2), integer(-8)}, + {sub(i2, sq3), integer(12)}, + {sub(sq3, i2), integer(-12)}, {sqrt(add(i5, mul(i2, sqrt(i5)))), div(i5, i2)}, {mul(minus_one, sqrt(add(i5, mul(i2, sqrt(i5))))), div(im5, i2)}, - {one, pow(i2, i2)}, - {minus_one, mul(minus_one, pow(i2, i2))}, + {one, integer(4)}, + {minus_one, integer(-4)}, }; } // namespace SymEngine diff --git a/symengine/functions.cpp b/symengine/functions.cpp index 2f760760ec..3278423811 100644 --- a/symengine/functions.cpp +++ b/symengine/functions.cpp @@ -522,6 +522,13 @@ RCP sign(const RCP &arg) return mul(s, make_rcp(Mul::from_dict(one, std::move(dict)))); } + if (is_a(*arg)) { + RCP pow_arg = rcp_static_cast(arg); + RCP s = sign(pow_arg->get_base()); + if (not is_a(*s) and not eq(*s, *pow_arg->get_base())) { + return sign(pow(s, pow_arg->get_exp())); + } + } return make_rcp(arg); } @@ -1571,54 +1578,23 @@ RCP ATan2::create(const RCP &a, RCP atan2(const RCP &num, const RCP &den) { if (eq(*num, *zero)) { - if (is_a_Number(*den)) { - RCP den_new = rcp_static_cast(den); - if (den_new->is_negative()) - return pi; - else if (den_new->is_positive()) - return zero; - else { - return Nan; - } + if (eq(*den, *zero)) { + return Nan; } + return mul(div(pi, im2), sub(sign(den), one)); } else if (eq(*den, *zero)) { - if (is_a_Number(*num)) { - RCP num_new = rcp_static_cast(num); - if (num_new->is_negative()) - return div(pi, im2); - else - return div(pi, i2); - } + return mul(div(pi, i2), sign(num)); } - RCP index; - bool b = inverse_lookup(inverse_tct, div(num, den), outArg(index)); + RCP divided = div(num, den); + RCP index_b; + bool b = inverse_lookup(inverse_tct, divided, outArg(index_b)); if (b) { - // Ideally the answer should depend on the signs of `num` and `den` - // Currently is_positive() and is_negative() is not implemented for - // types other than `Number` - // Hence this will give exact answers in case when num and den are - // numbers in SymEngine sense and when num and den are positive. - // for the remaining cases in which we just return the value from - // the lookup table. - // TODO: update once is_positive() and is_negative() is implemented - // in `Basic` - if (is_a_Number(*den) and is_a_Number(*num)) { - RCP den_new = rcp_static_cast(den); - RCP num_new = rcp_static_cast(num); - - if (den_new->is_positive()) { - return div(pi, index); - } else if (den_new->is_negative()) { - if (num_new->is_negative()) { - return sub(div(pi, index), pi); - } else { - return add(div(pi, index), pi); - } - } else { - return div(pi, index); - } + SYMENGINE_ASSERT(is_a_Number(*index_b)); + RCP index = rcp_static_cast(index_b); + if (index->is_positive()) { + return add(div(pi, index), mul(div(pi, i2), sub(sign(den), one))); } else { - return div(pi, index); + return sub(div(pi, index), mul(div(pi, i2), sub(sign(den), one))); } } else { return make_rcp(num, den); diff --git a/symengine/tests/basic/test_functions.cpp b/symengine/tests/basic/test_functions.cpp index 0f0b059117..572bd3acce 100644 --- a/symengine/tests/basic/test_functions.cpp +++ b/symengine/tests/basic/test_functions.cpp @@ -2218,7 +2218,7 @@ TEST_CASE("Atan2: functions", "[functions]") REQUIRE(eq(*r1, *r2)); r1 = atan2(add(one, sqrt(i2)), im1); - r2 = div(mul(pi, i3), integer(-8)); + r2 = div(mul(pi, i5), integer(8)); REQUIRE(eq(*r1, *r2)); r1 = atan2(sub(sqrt(i2), one), i1); @@ -2230,7 +2230,7 @@ TEST_CASE("Atan2: functions", "[functions]") REQUIRE(eq(*r1, *r2)); r1 = atan2(sqrt(add(i5, mul(i2, sqrt(i5)))), im1); - r2 = div(mul(pi, im2), i5); + r2 = div(mul(pi, i3), i5); REQUIRE(eq(*r1, *r2)); r1 = atan2(y, x)->diff(x);