diff --git a/symengine/cwrapper.cpp b/symengine/cwrapper.cpp index 8e9261faaa..facab7cacb 100644 --- a/symengine/cwrapper.cpp +++ b/symengine/cwrapper.cpp @@ -28,6 +28,7 @@ using SymEngine::ComplexDouble; using SymEngine::CSRMatrix; using SymEngine::DenseMatrix; using SymEngine::down_cast; +using SymEngine::E; using SymEngine::function_symbol; using SymEngine::FunctionSymbol; using SymEngine::has_symbol; @@ -286,6 +287,33 @@ int basic_has_symbol(const basic e, const basic s) return (int)(has_symbol(*(e->m), *(s->m))); } +int basic_is_Add(const basic s) +{ + return basic_get_type(s) == SYMENGINE_ADD; +} + +int basic_is_Mul(const basic s) +{ + return basic_get_type(s) == SYMENGINE_MUL; +} + +int basic_is_Pow(const basic s) +{ + return basic_get_type(s) == SYMENGINE_POW; +} + +int basic_is_Log(const basic s) +{ + return basic_get_type(s) == SYMENGINE_LOG; +} + +int basic_is_Exp(const basic s) +{ + SYMENGINE_ASSERT(basic_is_Pow(s) == 1); + auto args = s->m->get_args(); + return args[1] == E; +} + CWRAPPER_OUTPUT_TYPE integer_set_si(basic s, long i) { CWRAPPER_BEGIN diff --git a/symengine/cwrapper.h b/symengine/cwrapper.h index 37c0b907a7..f1bf1f2732 100644 --- a/symengine/cwrapper.h +++ b/symengine/cwrapper.h @@ -170,6 +170,16 @@ int number_is_complex(const basic s); //! Returns 1 if `e` contains the symbol `s`; 0 otherwise int basic_has_symbol(const basic e, const basic s); +//! Returns 1 if `s` is of type Add; 0 otherwise +int basic_is_Add(const basic s); +//! Returns 1 if `s` is of type Mul; 0 otherwise +int basic_is_Mul(const basic s); +//! Returns 1 if `s` is of type Pow; 0 otherwise +int basic_is_Pow(const basic s); +//! Returns 1 if `s` is of type Log; 0 otherwise +int basic_is_Log(const basic s); +//! Returns 1 if `s` is of type Exp; 0 otherwise +int basic_is_Exp(const basic s); //! Assign to s, a long. CWRAPPER_OUTPUT_TYPE integer_set_si(basic s, long i); diff --git a/symengine/tests/cwrapper/test_cwrapper.c b/symengine/tests/cwrapper/test_cwrapper.c index f16e4a3996..20ee2d22d4 100644 --- a/symengine/tests/cwrapper/test_cwrapper.c +++ b/symengine/tests/cwrapper/test_cwrapper.c @@ -59,10 +59,12 @@ void test_cwrapper() SYMENGINE_C_ASSERT(basic_has_symbol(e, y) == 0); SYMENGINE_C_ASSERT(basic_has_symbol(e, z) == 0); basic_add(e, e, x); + SYMENGINE_C_ASSERT(basic_is_Add(e) == 1); SYMENGINE_C_ASSERT(basic_has_symbol(e, x) == 1); SYMENGINE_C_ASSERT(basic_has_symbol(e, y) == 0); SYMENGINE_C_ASSERT(basic_has_symbol(e, z) == 0); basic_mul(e, e, y); + SYMENGINE_C_ASSERT(basic_is_Mul(e) == 1); SYMENGINE_C_ASSERT(basic_has_symbol(e, x) == 1); SYMENGINE_C_ASSERT(basic_has_symbol(e, y) == 1); SYMENGINE_C_ASSERT(basic_has_symbol(e, z) == 0); @@ -102,7 +104,9 @@ void test_cwrapper() integer_set_ui(e, 123); basic_sqrt(e, e); + SYMENGINE_C_ASSERT(basic_is_Pow(e) == 1); basic_exp(e, e); + SYMENGINE_C_ASSERT(basic_is_Exp(e) == 1); s = basic_str(e); SYMENGINE_C_ASSERT(strcmp(s, "exp(sqrt(123))") == 0); @@ -1389,6 +1393,7 @@ void test_functions() integer_set_ui(res, 2); basic_log(res, res); + SYMENGINE_C_ASSERT(basic_is_Log(res) == 1); SYMENGINE_C_ASSERT(basic_eq(res, ans)); real_double_set_d(res, 1.1);