Skip to content

Commit

Permalink
Merge pull request #1443 from isuruf/countops
Browse files Browse the repository at this point in the history
Implement count_ops
  • Loading branch information
isuruf committed May 11, 2018
2 parents af35ae0 + f7015cc commit b3a2159
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 1 deletion.
5 changes: 5 additions & 0 deletions symengine/tests/basic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,8 @@ endif()
add_executable(test_cse test_cse.cpp)
target_link_libraries(test_cse symengine catch)
add_test(test_cse ${PROJECT_BINARY_DIR}/test_cse)


add_executable(test_count_ops test_count_ops.cpp)
target_link_libraries(test_count_ops symengine catch)
add_test(test_count_ops ${PROJECT_BINARY_DIR}/test_count_ops)
46 changes: 46 additions & 0 deletions symengine/tests/basic/test_count_ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "catch.hpp"

#include <symengine/visitor.h>

using SymEngine::Basic;
using SymEngine::RCP;
using SymEngine::symbol;
using SymEngine::integer;
using SymEngine::one;
using SymEngine::I;
using SymEngine::pi;
using SymEngine::count_ops;

TEST_CASE("CountOps", "[count_ops]")
{
RCP<const Basic> x = symbol("x");
RCP<const Basic> y = symbol("y");
RCP<const Basic> i2 = integer(2);
RCP<const Basic> r1;

r1 = add(add(one, x), y);
REQUIRE(count_ops({r1}) == 2);

r1 = add(add(x, x), y);
REQUIRE(count_ops({r1}) == 2);

r1 = mul(mul(x, x), y);
REQUIRE(count_ops({r1}) == 2);

r1 = mul(mul(i2, x), y);
REQUIRE(count_ops({r1}) == 2);

r1 = add(add(I, one), sin(x));
REQUIRE(count_ops({r1}) == 3);

r1 = add(add(mul(i2, I), one), sin(x));
REQUIRE(count_ops({r1}) == 4);

r1 = add(I, pi);
REQUIRE(count_ops({r1}) == 1);

r1 = pow(pi, pi);
REQUIRE(count_ops({r1}) == 1);

REQUIRE(count_ops({x, y}) == 0);
}
90 changes: 90 additions & 0 deletions symengine/visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,94 @@ void preorder_traversal_local_stop(const Basic &b, LocalStopVisitor &v)
}
}

void CountOpsVisitor::apply(const Basic &b)
{
b.accept(*this);
}

void CountOpsVisitor::bvisit(const Mul &x)
{
if (neq(*(x.get_coef()), *one)) {
count++;
apply(*x.get_coef());
}

for (const auto &p : x.get_dict()) {
if (neq(*p.second, *one)) {
count++;
apply(*p.second);
}
apply(*p.first);
count++;
}
count--;
}

void CountOpsVisitor::bvisit(const Add &x)
{
if (neq(*(x.get_coef()), *zero)) {
count++;
apply(*x.get_coef());
}

unsigned i = 0;
for (const auto &p : x.get_dict()) {
if (neq(*p.second, *one)) {
count++;
apply(*p.second);
}
apply(*p.first);
count++;
i++;
}
count--;
}

void CountOpsVisitor::bvisit(const Pow &x)
{
count++;
apply(*x.get_exp());
apply(*x.get_base());
}

void CountOpsVisitor::bvisit(const Number &x)
{
}

void CountOpsVisitor::bvisit(const ComplexBase &x)
{
if (neq(*x.real_part(), *zero)) {
count++;
}

if (neq(*x.imaginary_part(), *one)) {
count++;
}
}

void CountOpsVisitor::bvisit(const Symbol &x)
{
}

void CountOpsVisitor::bvisit(const Constant &x)
{
}

void CountOpsVisitor::bvisit(const Basic &x)
{
count++;
for (const auto &p : x.get_args()) {
apply(*p);
}
}

unsigned count_ops(const vec_basic &a)
{
CountOpsVisitor v;
for (auto &p : a) {
v.apply(*p);
}
return v.count;
}

} // SymEngine
19 changes: 18 additions & 1 deletion symengine/visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,24 @@ inline set_basic atoms(const Basic &b)
{
AtomsVisitor<Args...> visitor;
return visitor.apply(b);
}
};

class CountOpsVisitor : public BaseVisitor<CountOpsVisitor>
{
public:
unsigned count = 0;
void apply(const Basic &b);
void bvisit(const Mul &x);
void bvisit(const Add &x);
void bvisit(const Pow &x);
void bvisit(const Number &x);
void bvisit(const ComplexBase &x);
void bvisit(const Symbol &x);
void bvisit(const Constant &x);
void bvisit(const Basic &x);
};

unsigned count_ops(const vec_basic &a);

} // SymEngine

Expand Down

0 comments on commit b3a2159

Please sign in to comment.