Skip to content

Commit

Permalink
Merge pull request #14145 from Upabjojr/imageset_manysets
Browse files Browse the repository at this point in the history
ImageSet now supports multiple sets
  • Loading branch information
Upabjojr committed Feb 11, 2018
2 parents 98d5dd9 + 296e818 commit af4f447
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 152 deletions.
8 changes: 5 additions & 3 deletions sympy/printing/latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1747,10 +1747,12 @@ def _print_Complexes(self, i):
return r"\mathbb{C}"

def _print_ImageSet(self, s):
return r"\left\{%s\; |\; %s \in %s\right\}" % (
sets = s.args[1:]
varsets = [r"%s \in %s" % (self._print(var), self._print(setv))
for var, setv in zip(s.lamda.variables, sets)]
return r"\left\{%s\; |\; %s\right\}" % (
self._print(s.lamda.expr),
', '.join([self._print(var) for var in s.lamda.variables]),
self._print(s.base_set))
', '.join(varsets))

def _print_ConditionSet(self, s):
vars_print = ', '.join([self._print(var) for var in Tuple(s.sym)])
Expand Down
11 changes: 7 additions & 4 deletions sympy/printing/pretty/pretty.py
Original file line number Diff line number Diff line change
Expand Up @@ -1722,12 +1722,15 @@ def _print_ImageSet(self, ts):
inn = u"\N{SMALL ELEMENT OF}"
else:
inn = 'in'
variables = self._print_seq(ts.lamda.variables)
variables = ts.lamda.variables
expr = self._print(ts.lamda.expr)
bar = self._print("|")
base = self._print(ts.base_set)

return self._print_seq((expr, bar, variables, inn, base), "{", "}", ' ')
sets = [self._print(i) for i in ts.args[1:]]
if len(sets) == 1:
return self._print_seq((expr, bar, variables[0], inn, sets[0]), "{", "}", ' ')
else:
pargs = tuple(j for var, setv in zip(variables, sets) for j in (var, inn, setv, ","))
return self._print_seq((expr, bar) + pargs[:-1], "{", "}", ' ')

def _print_ConditionSet(self, ts):
if self._use_unicode:
Expand Down
20 changes: 20 additions & 0 deletions sympy/printing/pretty/tests/test_pretty.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from sympy.tensor.functions import TensorProduct

from sympy.sets.setexpr import SetExpr
from sympy.sets import ImageSet

import sympy as sym
class lowergamma(sym.lowergamma):
Expand Down Expand Up @@ -3600,6 +3601,25 @@ def test_pretty_SetExpr():
assert upretty(se) == ucode_str


def test_pretty_ImageSet():
imgset = ImageSet(Lambda((x, y), x + y), {1, 2, 3}, {3, 4})
ascii_str = '{x + y | x in {1, 2, 3} , y in {3, 4}}'
ucode_str = u('{x + y | x ∊ {1, 2, 3} , y ∊ {3, 4}}')
assert pretty(imgset) == ascii_str
assert upretty(imgset) == ucode_str

imgset = ImageSet(Lambda(x, x**2), S.Naturals)
ascii_str = \
' 2 \n'\
'{x | x in S.Naturals}'
ucode_str = u('''\
⎧ 2 ⎫\n\
⎨x | x ∊ ℕ⎬\n\
⎩ ⎭''')
assert pretty(imgset) == ascii_str
assert upretty(imgset) == ucode_str


def test_pretty_ConditionSet():
from sympy import ConditionSet
ascii_str = '{x | x in (-oo, oo) and sin(x) = 0}'
Expand Down
3 changes: 3 additions & 0 deletions sympy/printing/tests/test_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,9 @@ def test_latex_ImageSet():
x = Symbol('x')
assert latex(ImageSet(Lambda(x, x**2), S.Naturals)) == \
r"\left\{x^{2}\; |\; x \in \mathbb{N}\right\}"
y = Symbol('y')
imgset = ImageSet(Lambda((x, y), x + y), {1, 2, 3}, {3, 4})
assert latex(imgset) == r"\left\{x + y\; |\; x \in \left\{1, 2, 3\right\}, y \in \left\{3, 4\right\}\right\}"


def test_latex_ConditionSet():
Expand Down
132 changes: 122 additions & 10 deletions sympy/sets/dispatchers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from sympy.multipledispatch import dispatch, Dispatcher
from sympy.core import Basic, Expr, Function, Add, Mul, Pow, Dummy, Integer
from sympy import Min, Max, Set, sympify, Lambda, symbols, exp, log
from sympy.sets import imageset, Interval, FiniteSet, Union, ImageSet, ProductSet
from sympy import Min, Max, Set, sympify, Lambda, symbols, exp, log, S
from sympy.sets import (imageset, Interval, FiniteSet, Union, ImageSet,
ProductSet, EmptySet, Intersection)
from sympy.core.function import FunctionClass
from sympy.logic.boolalg import And, Or, Not, true, false


_x, _y = symbols("x y")


@dispatch(Set, Set)
def add_sets(x, y):
return ImageSet(Lambda((_x, _y), (_x+_y)), ProductSet(x, y))
return ImageSet(Lambda((_x, _y), (_x+_y)), x, y)


@dispatch(Expr, Expr)
Expand All @@ -35,7 +37,7 @@ def sub_sets(x, y):

@dispatch(Set, Set)
def sub_sets(x, y):
return ImageSet(Lambda((_x, _y), (_x - _y)), ProductSet(x, y))
return ImageSet(Lambda((_x, _y), (_x - _y)), x, y)


@dispatch(Interval, Interval)
Expand All @@ -50,7 +52,7 @@ def sub_sets(x, y):

@dispatch(Set, Set)
def mul_sets(x, y):
return ImageSet(Lambda((_x, _y), (_x * _y)), ProductSet(x, y))
return ImageSet(Lambda((_x, _y), (_x * _y)), x, y)


@dispatch(Expr, Expr)
Expand Down Expand Up @@ -89,7 +91,7 @@ def div_sets(x, y):

@dispatch(Set, Set)
def div_sets(x, y):
return ImageSet(Lambda((_x, _y), (_x / _y)), ProductSet(x, y))
return ImageSet(Lambda((_x, _y), (_x / _y)), x, y)


@dispatch(Interval, Interval)
Expand All @@ -106,7 +108,7 @@ def div_sets(x, y):

@dispatch(Set, Set)
def pow_sets(x, y):
return ImageSet(Lambda((_x, _y), (_x ** _y)), ProductSet(x, y))
return ImageSet(Lambda((_x, _y), (_x ** _y)), x, y)


@dispatch(Expr, Expr)
Expand Down Expand Up @@ -138,9 +140,98 @@ def pow_sets(x, y):
return Interval(x.start**exponent, x.end**exponent, x.left_open, x.right_open)


@dispatch(FunctionClass, FiniteSet)
FunctionUnion = (FunctionClass, Lambda)


@dispatch(FunctionUnion, FiniteSet)
def function_sets(f, x):
return FiniteSet(*map(f, x))

@dispatch(Lambda, Interval)
def function_sets(f, x):
return FiniteSet(*[f(i) for i in x])
from sympy.functions.elementary.miscellaneous import Min, Max
from sympy.solvers.solveset import solveset
from sympy.core.function import diff, Lambda
from sympy.series import limit
from sympy.calculus.singularities import singularities
from sympy.sets import Complement
# TODO: handle functions with infinitely many solutions (eg, sin, tan)
# TODO: handle multivariate functions

expr = f.expr
if len(expr.free_symbols) > 1 or len(f.variables) != 1:
return
var = f.variables[0]

if expr.is_Piecewise:
result = S.EmptySet
domain_set = x
for (p_expr, p_cond) in expr.args:
if p_cond is true:
intrvl = domain_set
else:
intrvl = p_cond.as_set()
intrvl = Intersection(domain_set, intrvl)

if p_expr.is_Number:
image = FiniteSet(p_expr)
else:
image = imageset(Lambda(var, p_expr), intrvl)
result = Union(result, image)

# remove the part which has been `imaged`
domain_set = Complement(domain_set, intrvl)
if domain_set.is_EmptySet:
break
return result

if not x.start.is_comparable or not x.end.is_comparable:
return

try:
sing = [i for i in singularities(expr, var)
if i.is_real and i in x]
except NotImplementedError:
return

if x.left_open:
_start = limit(expr, var, x.start, dir="+")
elif x.start not in sing:
_start = f(x.start)
if x.right_open:
_end = limit(expr, var, x.end, dir="-")
elif x.end not in sing:
_end = f(x.end)

if len(sing) == 0:
solns = list(solveset(diff(expr, var), var))

extr = [_start, _end] + [f(i) for i in solns
if i.is_real and i in x]
start, end = Min(*extr), Max(*extr)

left_open, right_open = False, False
if _start <= _end:
# the minimum or maximum value can occur simultaneously
# on both the edge of the interval and in some interior
# point
if start == _start and start not in solns:
left_open = x.left_open
if end == _end and end not in solns:
right_open = x.right_open
else:
if start == _end and start not in solns:
left_open = x.right_open
if end == _start and end not in solns:
right_open = x.left_open

return Interval(start, end, left_open, right_open)
else:
return imageset(f, Interval(x.start, sing[0],
x.left_open, True)) + \
Union(*[imageset(f, Interval(sing[i], sing[i + 1], True, True))
for i in range(0, len(sing) - 1)]) + \
imageset(f, Interval(sing[-1], x.end, True, x.right_open))

@dispatch(FunctionClass, Interval)
def function_sets(f, x):
Expand All @@ -150,6 +241,27 @@ def function_sets(f, x):
return Interval(log(x.start), log(x.end), x.left_open, x.right_open)
return ImageSet(Lambda(_x, f(_x)), x)

@dispatch(FunctionClass, Set)
@dispatch(FunctionUnion, Union)
def function_sets(f, x):
return Union(imageset(f, arg) for arg in x.args)

@dispatch(FunctionUnion, Intersection)
def function_sets(f, x):
# If the function is invertible, intersect the maps of the sets.
u = symbols("u")
fdiff = f(u).diff(u)
# TODO: find a better condition for invertible functions:
if ((f in (exp, log)) # functions known to be invertible
or (fdiff > 0) == True or (fdiff < 0) == True # monotonous funcs
):
return Intersection(imageset(f, arg) for arg in x.args)
else:
return ImageSet(Lambda(_x, f(_x)), x)

@dispatch(FunctionUnion, EmptySet)
def function_sets(f, x):
return x

@dispatch(FunctionUnion, Set)
def function_sets(f, x):
return ImageSet(Lambda(_x, f(_x)), x)
14 changes: 7 additions & 7 deletions sympy/sets/fancysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,15 @@ class ImageSet(Set):
========
sympy.sets.sets.imageset
"""
def __new__(cls, lamda, base_set):
if not isinstance(lamda, Lambda):
def __new__(cls, flambda, *sets):
if not isinstance(flambda, Lambda):
raise ValueError('first argument must be a Lambda')
if lamda is S.IdentityFunction:
return base_set
if not lamda.expr.free_symbols or not lamda.expr.args:
return FiniteSet(lamda.expr)
if flambda is S.IdentityFunction and len(sets) == 1:
return sets[0]
if not flambda.expr.free_symbols or not flambda.expr.args:
return FiniteSet(flambda.expr)

return Basic.__new__(cls, lamda, base_set)
return Basic.__new__(cls, flambda, *sets)

lamda = property(lambda self: self.args[0])
base_set = property(lambda self: self.args[1])
Expand Down
2 changes: 2 additions & 0 deletions sympy/sets/setexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,7 @@ def apply_operation(op, x, y):
if isinstance(y, SetExpr):
y = y.set
out = dispatch_on_operation(x, y, op)
if not isinstance(out, Set):
import pdb; pdb.set_trace()
assert isinstance(out, Set)
return SetExpr(out)

0 comments on commit af4f447

Please sign in to comment.