Skip to content

Commit

Permalink
merge baseset within condition
Browse files Browse the repository at this point in the history
Fix bug in FiniteSet::set_intersection()
  • Loading branch information
ranjithkumar007 committed Jun 10, 2017
1 parent f911edb commit 64530e1
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 94 deletions.
38 changes: 37 additions & 1 deletion symengine/logic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ int Contains::compare(const Basic &o) const
return unified_compare(get_set(), c.get_set());
}

RCP<const Basic> Contains::create(const RCP<const Basic> &lhs,
const RCP<const Set> &rhs) const
{
return contains(lhs, rhs);
}

RCP<const Boolean> contains(const RCP<const Basic> &expr,
const RCP<const Set> &set)
{
Expand Down Expand Up @@ -209,6 +215,11 @@ const set_boolean &And::get_container() const
return container_;
}

RCP<const Basic> And::create(const set_boolean &a) const
{
return logical_and(a);
}

RCP<const Boolean> And::logical_not() const
{
auto container = this->get_container();
Expand Down Expand Up @@ -422,7 +433,32 @@ RCP<const Boolean> and_or(const set_boolean &s, const bool &op_x_notx)
return *(args.begin());
else if (args.size() == 0)
return boolean(not op_x_notx);
return make_rcp<const caller>(args);

set_boolean rest;
std::map<RCP<const Basic>, set_set, RCPBasicKeyLess> mp;
for (auto &a : args) {
if (is_a<Contains>(*a)) {
mp[down_cast<const Contains &>(*a).get_expr()].insert(
down_cast<const Contains &>(*a).get_set());
} else {
rest.insert(a);
}
}

for (auto &elem : mp) {
if (op_x_notx) {
rest.insert(
contains(elem.first, SymEngine::set_union(elem.second)));
} else {
rest.insert(
contains(elem.first, SymEngine::set_intersection(elem.second)));
}
}
if (rest.size() == 1)
return *(rest.begin());
else if (rest.size() == 0)
return boolean(not op_x_notx);
return make_rcp<const caller>(rest);
}

RCP<const Boolean> logical_not(const RCP<const Boolean> &s)
Expand Down
3 changes: 3 additions & 0 deletions symengine/logic.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class Contains : public Boolean
RCP<const Set> get_set() const;
virtual vec_basic get_args() const;
virtual bool __eq__(const Basic &o) const;
virtual RCP<const Basic> create(const RCP<const Basic> &lhs,
const RCP<const Set> &rhs) const;
//! Structural equality comparator
virtual int compare(const Basic &o) const;
};
Expand Down Expand Up @@ -113,6 +115,7 @@ class And : public Boolean
//! \return the hash
hash_t __hash__() const;
virtual vec_basic get_args() const;
virtual RCP<const Basic> create(const set_boolean &a) const;
virtual bool __eq__(const Basic &o) const;
//! Structural equality comparator
virtual int compare(const Basic &o) const;
Expand Down
3 changes: 1 addition & 2 deletions symengine/printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,7 @@ void StrPrinter::bvisit(const ConditionSet &x)
s1 << apply(*p);
}
s1 << "}";
s << "{" << s1.str() << " | " << s1.str() << " in "
<< apply(x.get_baseset()) << " and " << apply(x.get_condition()) << "}";
s << "{" << s1.str() << " | " << apply(x.get_condition()) << "}";
str_ = s.str();
}

Expand Down
139 changes: 87 additions & 52 deletions symengine/sets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ RCP<const Set> Interval::close() const

RCP<const Boolean> Interval::contains(const RCP<const Basic> &a) const
{
if (not is_a_Number(*a) and not is_a<Constant>(*a))
if (not is_a_Number(*a))
return make_rcp<Contains>(a, rcp_from_this_cast<const Set>());
if (eq(*start_, *a))
return boolean(not left_open_);
Expand Down Expand Up @@ -336,7 +336,19 @@ int FiniteSet::compare(const Basic &o) const

RCP<const Boolean> FiniteSet::contains(const RCP<const Basic> &a) const
{
return boolean(container_.find(a) != container_.end());
set_basic rest;
for (const auto &elem : container_) {
auto cont = Eq(elem, a);
if (eq(*cont, *boolTrue))
return boolTrue;
if (not eq(*cont, *boolFalse))
rest.insert(elem);
}
if (rest.empty()) {
return boolFalse;
} else {
return make_rcp<Contains>(a, finiteset(rest));
}
}

RCP<const Set> FiniteSet::set_union(const RCP<const Set> &o) const
Expand Down Expand Up @@ -620,29 +632,41 @@ RCP<const Set> Complement::set_complement(const RCP<const Set> &o) const
return container_->set_complement(newuniv);
}

ConditionSet::ConditionSet(const vec_sym syms, RCP<const Set> base,
RCP<const Boolean> condition)
: syms_(syms), base_(base), condition_(condition)
ConditionSet::ConditionSet(const vec_sym syms, RCP<const Boolean> condition)
: syms_(syms), condition_(condition)
{
SYMENGINE_ASSIGN_TYPEID()
SYMENGINE_ASSERT(ConditionSet::is_canonical(syms, base, condition))
SYMENGINE_ASSERT(ConditionSet::is_canonical(syms, condition))
}

bool ConditionSet::is_canonical(const vec_sym syms, RCP<const Set> base,
bool ConditionSet::is_canonical(const vec_sym syms,
RCP<const Boolean> condition)
{
if (eq(*condition, *boolFalse) or eq(*condition, *boolTrue)) {
return false;
} else if (is_a<EmptySet>(*base)) {
return false;
} else if (is_a<FiniteSet>(*base)) {
} else if (is_a<And>(*condition)) {
if (syms.size() == 1) {
for (const auto &elem :
down_cast<const FiniteSet &>(*base).get_container()) {
if (is_a_Number(*elem) or is_a<Constant>(*elem))
return false;
down_cast<const And &>(*condition).get_container()) {
if (is_a<Contains>(*elem)
and eq(*down_cast<const Contains &>(*elem).get_expr(),
*syms[0])
and is_a<FiniteSet>(
*down_cast<const Contains &>(*elem).get_set())) {
auto fset
= down_cast<const FiniteSet &>(
*down_cast<const Contains &>(*elem).get_set())
.get_container();
for (const auto &a : fset) {
if (is_a_Number(*a) or is_a<Constant>(*a))
return false;
}
}
}
}
} else if (is_a<Contains>(*condition)) {
if (syms.size() == 1)
return false;
}
return true;
}
Expand All @@ -652,7 +676,6 @@ hash_t ConditionSet::__hash__() const
hash_t seed = CONDITIONSET;
for (const auto &a : syms_)
hash_combine<Basic>(seed, *a);
hash_combine<Basic>(seed, *base_);
hash_combine<Basic>(seed, *condition_);
return seed;
}
Expand All @@ -662,7 +685,6 @@ bool ConditionSet::__eq__(const Basic &o) const
if (is_a<ConditionSet>(o)) {
const ConditionSet &other = down_cast<const ConditionSet &>(o);
return unified_eq(syms_, other.get_symbols())
and unified_eq(base_, other.get_baseset())
and unified_eq(condition_, other.get_condition());
}
return false;
Expand All @@ -676,12 +698,7 @@ int ConditionSet::compare(const Basic &o) const
if (c1 != 0) {
return c1;
} else {
int c2 = unified_compare(base_, other.get_baseset());
if (c2 != 0) {
return c2;
} else {
return unified_compare(condition_, other.get_condition());
}
return unified_compare(condition_, other.get_condition());
}
}

Expand All @@ -693,8 +710,12 @@ RCP<const Set> ConditionSet::set_union(const RCP<const Set> &o) const
RCP<const Set> ConditionSet::set_intersection(const RCP<const Set> &o) const
{
if (not is_a<ConditionSet>(*o)) {
auto newbase = SymEngine::set_intersection({o, base_});
return conditionset(syms_, newbase, condition_);
auto it = syms_.begin();
RCP<const Boolean> newcond
= logical_and({condition_, o->contains(*it)});
for (; it != syms_.end(); it++)
newcond = logical_and({newcond, o->contains(*it)});
return conditionset(syms_, newcond);
}
throw std::runtime_error("Not implemented Intersection class");
}
Expand All @@ -706,8 +727,6 @@ RCP<const Set> ConditionSet::set_complement(const RCP<const Set> &o) const

RCP<const Boolean> ConditionSet::contains(const RCP<const Basic> &o) const
{
if (eq(*base_->contains(o), *boolFalse))
return boolean(false);
if (is_a<FiniteSet>(*o)) {
const FiniteSet &fs = down_cast<const FiniteSet &>(*o);
auto container = fs.get_container();
Expand All @@ -717,8 +736,6 @@ RCP<const Boolean> ConditionSet::contains(const RCP<const Basic> &o) const
map_basic_basic d;
int pos = 0;
for (const auto &elem : container) {
if (eq(*base_->contains(elem), *boolFalse))
return boolean(false);
d[syms_[pos]] = elem;
pos++;
}
Expand Down Expand Up @@ -803,7 +820,7 @@ RCP<const Set> set_intersection(const set_set &in)
bool present = true;
for (const auto &fset : fsets) {
auto contain = fset->contains(fselement);
if (is_a<Contains>(*contain)) {
if (not(eq(*contain, *boolTrue) or eq(*contain, *boolFalse))) {
throw std::runtime_error(
"Not implemented Intersection class");
}
Expand All @@ -813,7 +830,7 @@ RCP<const Set> set_intersection(const set_set &in)
continue;
for (const auto &oset : othersets) {
auto contain = oset->contains(fselement);
if (is_a<Contains>(*contain)) {
if (not(eq(*contain, *boolTrue) or eq(*contain, *boolFalse))) {
throw std::runtime_error(
"Not implemented Intersection class");
}
Expand Down Expand Up @@ -916,38 +933,56 @@ RCP<const Set> set_complement(const RCP<const Set> &universe,
return container->set_complement(universe);
}

RCP<const Set> conditionset(const vec_sym &syms, const RCP<const Set> &base,
RCP<const Set> conditionset(const vec_sym &syms,
const RCP<const Boolean> &condition)
{
if (ConditionSet::is_canonical(syms, base, condition)) {
return make_rcp<const ConditionSet>(syms, base, condition);
if (ConditionSet::is_canonical(syms, condition)) {
return make_rcp<const ConditionSet>(syms, condition);
}
if (eq(*condition, *boolean(false))) {
return emptyset();
} else if (eq(*condition, *boolean(true)) or is_a<EmptySet>(*base)) {
return base;
} else if (eq(*condition, *boolean(true))) {
return universalset();
}
if (is_a<FiniteSet>(*base)) {
const FiniteSet &fset = down_cast<const FiniteSet &>(*base);
auto &container = fset.get_container();
if (is_a<And>(*condition)) {
// Simplify if we have a single symbol.
const And &aset = down_cast<const And &>(*condition);
auto &container = aset.get_container();
set_basic present, others;
for (const auto &fselement : container) {
map_basic_basic d;
d[syms[0]] = fselement;
auto contain = condition->subs(d);
if (eq(*contain, *boolean(true))) {
present.insert(fselement);
} else if (not eq(*contain, *boolean(false))) {
others.insert(fselement);
for (auto it = container.begin(); it != container.end(); it++) {
if (is_a<Contains>(**it)
and eq(*down_cast<const Contains &>(**it).get_expr(), *syms[0])
and is_a<FiniteSet>(
*down_cast<const Contains &>(**it).get_set())) {
auto fset = down_cast<const FiniteSet &>(
*down_cast<const Contains &>(**it).get_set())
.get_container();
auto restCont = container;
restCont.erase(*it);
auto restCond = logical_and(restCont);
for (const auto &fselement : fset) {
map_basic_basic d;
d[syms[0]] = fselement;
auto contain = restCond->subs(d);
if (eq(*contain, *boolean(true))) {
present.insert(fselement);
} else if (not eq(*contain, *boolean(false))) {
others.insert(fselement);
}
}
if (others.empty()) {
return finiteset(present);
} else {
restCond = logical_and(
{finiteset(others)->contains(syms[0]), restCond});
return SymEngine::set_union(
{finiteset(present), conditionset(syms, restCond)});
}
}
}
if (others.empty()) {
return finiteset(present);
} else {
RCP<const Set> o = finiteset(others);
return SymEngine::set_union(
{finiteset(present), conditionset(syms, o, condition)});
}
}
if (is_a<Contains>(*condition)) {
return down_cast<const Contains &>(*condition).get_set();
}
throw std::runtime_error("Not implemented");
}
Expand Down
13 changes: 3 additions & 10 deletions symengine/sets.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ class ConditionSet : public Set
{
private:
vec_sym syms_;
RCP<const Set> base_;
RCP<const Boolean> condition_;

public:
Expand All @@ -265,10 +264,8 @@ class ConditionSet : public Set
{
return {};
}
ConditionSet(const vec_sym syms, RCP<const Set> base,
RCP<const Boolean> condition);
static bool is_canonical(const vec_sym syms, RCP<const Set> base,
RCP<const Boolean> condition);
ConditionSet(const vec_sym syms, RCP<const Boolean> condition);
static bool is_canonical(const vec_sym syms, RCP<const Boolean> condition);

virtual RCP<const Set> set_intersection(const RCP<const Set> &o) const;
virtual RCP<const Set> set_union(const RCP<const Set> &o) const;
Expand All @@ -279,10 +276,6 @@ class ConditionSet : public Set
{
return this->syms_;
}
inline const RCP<const Set> &get_baseset() const
{
return this->base_;
}
inline const RCP<const Boolean> &get_condition() const
{
return this->condition_;
Expand Down Expand Up @@ -337,7 +330,7 @@ RCP<const Set> set_complement(const RCP<const Set> &universe,
const RCP<const Set> &container);

//! \return RCP<const Set>
RCP<const Set> conditionset(const vec_sym &syms, const RCP<const Set> &base,
RCP<const Set> conditionset(const vec_sym &syms,
const RCP<const Boolean> &condition);
}
#endif
19 changes: 19 additions & 0 deletions symengine/subs.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,25 @@ class SubsVisitor : public BaseVisitor<SubsVisitor>
result_ = x.create(v);
}

void bvisit(const Contains &x)
{
RCP<const Basic> a = apply(x.get_expr());
RCP<const Set> b = rcp_static_cast<const Set>(apply(x.get_set()));
if (a == x.get_expr() and b == x.get_set())
result_ = x.rcp_from_this();
else
result_ = x.create(a, b);
}

void bvisit(const And &x)
{
set_boolean v;
for (const auto &elem : x.get_container()) {
v.insert(rcp_static_cast<const Boolean>(apply(elem)));
}
result_ = x.create(v);
}

void bvisit(const Derivative &x)
{
RCP<const Symbol> s;
Expand Down
Loading

0 comments on commit 64530e1

Please sign in to comment.