From e0433433b658984c3ee62efaf21be8aad7b9d046 Mon Sep 17 00:00:00 2001 From: Srajan Garg Date: Fri, 22 Jan 2016 03:09:01 +0530 Subject: [PATCH] fixed main trigno functions, to simplify if arg is inverse --- symengine/functions.cpp | 42 +++++++++++++++++------- symengine/tests/basic/test_functions.cpp | 24 ++++++++++++++ 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/symengine/functions.cpp b/symengine/functions.cpp index dfb34955d4..1e2f63fa36 100644 --- a/symengine/functions.cpp +++ b/symengine/functions.cpp @@ -327,9 +327,12 @@ RCP sin(const RCP &arg) if (is_a_Number(*arg) and not static_cast(*arg).is_exact()) { return static_cast(*arg).get_eval().sin(*arg); } + if (is_a(*arg)) { + return rcp_static_cast(arg)->get_arg(); + } + RCP ret_arg; - int index; - int sign; + int index, sign; bool conjugate = eval(arg, 2, 1, 0, //input outArg(ret_arg), index, sign); //output @@ -406,9 +409,12 @@ RCP cos(const RCP &arg) if (is_a_Number(*arg) and not static_cast(*arg).is_exact()) { return static_cast(*arg).get_eval().cos(*arg); } + if (is_a(*arg)) { + return rcp_static_cast(arg)->get_arg(); + } + RCP ret_arg; - int index; - int sign; + int index, sign; bool conjugate = eval(arg, 2, 0, 1, //input outArg(ret_arg), index, sign); //output @@ -484,9 +490,12 @@ RCP tan(const RCP &arg) if (is_a_Number(*arg) and not static_cast(*arg).is_exact()) { return static_cast(*arg).get_eval().tan(*arg); } + if (is_a(*arg)) { + return rcp_static_cast(arg)->get_arg(); + } + RCP ret_arg; - int index; - int sign; + int index, sign; bool conjugate = eval(arg, 1, 1, 1, //input outArg(ret_arg), index, sign); //output @@ -563,9 +572,12 @@ RCP cot(const RCP &arg) if (is_a_Number(*arg) and not static_cast(*arg).is_exact()) { return static_cast(*arg).get_eval().cot(*arg); } + if (is_a(*arg)) { + return rcp_static_cast(arg)->get_arg(); + } + RCP ret_arg; - int index; - int sign; + int index, sign; bool conjugate = eval(arg, 1, 1, 1, //input outArg(ret_arg), index, sign); //output @@ -643,9 +655,12 @@ RCP csc(const RCP &arg) if (is_a_Number(*arg) and not static_cast(*arg).is_exact()) { return static_cast(*arg).get_eval().csc(*arg); } + if (is_a(*arg)) { + return rcp_static_cast(arg)->get_arg(); + } + RCP ret_arg; - int index; - int sign; + int index, sign; bool conjugate = eval(arg, 2, 1, 0, //input outArg(ret_arg), index, sign); //output @@ -723,9 +738,12 @@ RCP sec(const RCP &arg) if (is_a_Number(*arg) and not static_cast(*arg).is_exact()) { return static_cast(*arg).get_eval().sec(*arg); } + if (is_a(*arg)) { + return rcp_static_cast(arg)->get_arg(); + } + RCP ret_arg; - int index; - int sign; + int index, sign; bool conjugate = eval(arg, 2, 0, 1, //input outArg(ret_arg), index, sign); //output diff --git a/symengine/tests/basic/test_functions.cpp b/symengine/tests/basic/test_functions.cpp index e6efdee5e0..ad3faf5e80 100644 --- a/symengine/tests/basic/test_functions.cpp +++ b/symengine/tests/basic/test_functions.cpp @@ -170,6 +170,10 @@ TEST_CASE("Sin: functions", "[functions]") r2 = sin(y); REQUIRE(eq(*r1, *r2)); + // sin(asin(x)) = x + r1 = sin(asin(x)); + REQUIRE(eq(*r1, *x)); + // sin(pi + y) = -sin(y) r1 = sin(add(pi, y)); r2 = mul(im1, sin(y)); @@ -263,6 +267,10 @@ TEST_CASE("Cos: functions", "[functions]") r2 = cos(y); REQUIRE(eq(*r1, *r2)); + // cos(acos(x)) = x + r1 = cos(acos(x)); + REQUIRE(eq(*r1, *x)); + // cos(pi - y) = -cos(y) r1 = cos(sub(pi, y)); r2 = mul(im1, cos(y)); @@ -368,6 +376,10 @@ TEST_CASE("Tan: functions", "[functions]") r2 = mul(im1, tan(y)); REQUIRE(eq(*r1, *r2)); + // tan(atan(x)) = x + r1 = tan(atan(x)); + REQUIRE(eq(*r1, *x)); + // tan(pi + y) = -tan(y) r1 = tan(add(pi, y)); r2 = tan(y); @@ -459,6 +471,10 @@ TEST_CASE("Cot: functions", "[functions]") r2 = mul(im1, cot(y)); REQUIRE(eq(*r1, *r2)); + // cot(acot(x)) = x + r1 = cot(acot(x)); + REQUIRE(eq(*r1, *x)); + // cot(pi + y) = -cot(y) r1 = cot(add(pi, y)); r2 = cot(y); @@ -551,6 +567,10 @@ TEST_CASE("Csc: functions", "[functions]") r2 = csc(y); REQUIRE(eq(*r1, *r2)); + // csc(acsc(x)) = x + r1 = csc(acsc(x)); + REQUIRE(eq(*r1, *x)); + // csc(pi + y) = -csc(y) r1 = csc(add(pi, y)); r2 = mul(im1, csc(y)); @@ -642,6 +662,10 @@ TEST_CASE("Sec: functions", "[functions]") r2 = mul(im1, sec(y)); REQUIRE(eq(*r1, *r2)); + // sec(asec(x)) = x + r1 = sec(asec(x)); + REQUIRE(eq(*r1, *x)); + // sec(pi + y) = -sec(y) r1 = sec(add(pi, y)); r2 = mul(im1, sec(y));