Skip to content

Commit

Permalink
Merge pull request #19502 from gschintgen/fix-19496-condset-dummy
Browse files Browse the repository at this point in the history
Fix ConditionSet.dummy_eq() and .as_dummy()
  • Loading branch information
oscarbenjamin committed Jun 7, 2020
2 parents 32c28eb + 71fadc2 commit eee759c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 25 deletions.
26 changes: 6 additions & 20 deletions sympy/sets/conditionset.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,12 @@ def __new__(cls, sym, condition, base_set=S.UniversalSet):

@property
def free_symbols(self):
s, c, b = self.args
return (c.free_symbols - s.free_symbols) | b.free_symbols
cond_syms = self.condition.free_symbols - self.sym.free_symbols
return cond_syms | self.base_set.free_symbols

@property
def bound_symbols(self):
return self.sym.free_symbols

def _contains(self, other):
return And(
Expand Down Expand Up @@ -246,21 +250,3 @@ def _eval_subs(self, old, new):
# __new__ we *don't* check if 'sym' actually belongs to
# 'base'. In other words: assumptions are ignored.
return self.func(self.sym, cond, base)

def dummy_eq(self, other, symbol=None):
if not isinstance(other, self.func):
return False
if isinstance(self.sym, Symbol) != isinstance(other.sym, Symbol):
# this test won't be necessary when unsolved equations
# syntax is removed
return False
if symbol:
raise ValueError('symbol arg not supported for ConditionSet')
o = other
if isinstance(self.sym, Symbol) and isinstance(other.sym, Symbol):
# this code will not need to be in an if-block when
# the unsolved equations syntax is removed
o = other.func(self.sym,
other.condition.subs(other.sym, self.sym),
other.base_set)
return self == o
43 changes: 38 additions & 5 deletions sympy/sets/tests/test_conditionset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sympy.sets import (ConditionSet, Intersection, FiniteSet,
EmptySet, Union, Contains, imageset)
from sympy import (Symbol, Eq, S, Abs, sin, asin, pi, Interval,
EmptySet, Union, Contains, ImageSet)
from sympy import (Symbol, Eq, Ne, S, Abs, sin, asin, pi, Interval,
And, Mod, oo, Function, Lambda)
from sympy.testing.pytest import raises, XFAIL, warns_deprecated_sympy

Expand Down Expand Up @@ -85,6 +85,30 @@ def test_free_symbols():
).free_symbols == {z}
assert ConditionSet(x, Eq(x, 0), FiniteSet(x, z)
).free_symbols == {x, z}
assert ConditionSet(x, Eq(x, 0), ImageSet(Lambda(y, y**2), S.Integers)
).free_symbols == set()


def test_bound_symbols():
assert ConditionSet(x, Eq(y, 0), FiniteSet(z)
).bound_symbols == {x}
assert ConditionSet(x, Eq(x, 0), FiniteSet(x, y)
).bound_symbols == {x}
assert ConditionSet(x, x < 10, ImageSet(Lambda(y, y**2), S.Integers)
).bound_symbols == {x}
assert ConditionSet(x, x < 10, ConditionSet(y, y > 1, S.Integers)
).bound_symbols == {x}


def test_as_dummy():
_0 = Symbol('_0')
assert ConditionSet(x, x < 1, Interval(y, oo)
).as_dummy() == ConditionSet(_0, _0 < 1, Interval(y, oo))
assert ConditionSet(x, x < 1, Interval(x, oo)
).as_dummy() == ConditionSet(_0, _0 < 1, Interval(x, oo))
assert ConditionSet(x, x < 1, ImageSet(Lambda(y, y**2), S.Integers)
).as_dummy() == ConditionSet(
_0, _0 < 1, ImageSet(Lambda(_0, _0**2), S.Integers))


def test_subs_CondSet():
Expand Down Expand Up @@ -132,8 +156,8 @@ def test_subs_CondSet():

# issue 17341
k = Symbol('k')
img1 = imageset(Lambda(k, 2*k*pi + asin(y)), S.Integers)
img2 = imageset(Lambda(k, 2*k*pi + asin(S.One/3)), S.Integers)
img1 = ImageSet(Lambda(k, 2*k*pi + asin(y)), S.Integers)
img2 = ImageSet(Lambda(k, 2*k*pi + asin(S.One/3)), S.Integers)
assert ConditionSet(x, Contains(
y, Interval(-1,1)), img1).subs(y, S.One/3).dummy_eq(img2)

Expand All @@ -154,7 +178,6 @@ def test_dummy_eq():
assert c.dummy_eq(C(y, y < 1, I))
assert c.dummy_eq(1) == False
assert c.dummy_eq(C(x, x < 1, S.Reals)) == False
raises(ValueError, lambda: c.dummy_eq(C(x, x < 1, S.Reals), z))

c1 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals)
c2 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals)
Expand All @@ -164,6 +187,16 @@ def test_dummy_eq():
assert c.dummy_eq(c1) is False
assert c1.dummy_eq(c) is False

# issue 19496
m = Symbol('m')
n = Symbol('n')
a = Symbol('a')
d1 = ImageSet(Lambda(m, m*pi), S.Integers)
d2 = ImageSet(Lambda(n, n*pi), S.Integers)
c1 = ConditionSet(x, Ne(a, 0), d1)
c2 = ConditionSet(x, Ne(a, 0), d2)
assert c1.dummy_eq(c2)


def test_contains():
assert 6 in ConditionSet(x, x > 5, Interval(1, 7))
Expand Down

0 comments on commit eee759c

Please sign in to comment.