Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added floor implementation in solvers #18596

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
3 changes: 1 addition & 2 deletions sympy/solvers/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,8 +1723,7 @@ def _expand(p):
flags['simplify'] = flags.get('simplify', False)
result = soln

# fallback if above fails
# -----------------------

if result is False:
# try unrad
if flags.pop('_unrad', True):
Expand Down
82 changes: 81 additions & 1 deletion sympy/solvers/solveset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from sympy.functions import (log, Abs, tan, cot, sin, cos, sec, csc, exp,
acos, asin, acsc, asec, arg,
piecewise_fold, Piecewise)
from sympy.functions.elementary.integers import floor
from sympy.functions.elementary.trigonometric import (TrigonometricFunction,
HyperbolicFunction)
from sympy.functions.elementary.miscellaneous import real_root
Expand All @@ -53,6 +54,7 @@
from sympy.utilities.iterables import numbered_symbols, has_dups
from sympy.calculus.util import periodicity, continuous_domain
from sympy.core.compatibility import ordered, default_sort_key, is_sequence
from sympy.core.function import Function

from types import GeneratorType
from collections import defaultdict
Expand Down Expand Up @@ -785,6 +787,77 @@ def _solve_radical(f, symbol, solveset_solver):
return solution_set


def _recursive_check_f(f):
not_compute = False
floor_type = False

if type(f) == floor:
mijo2 marked this conversation as resolved.
Show resolved Hide resolved
floor_type = True
f, not_compute, _ = _recursive_check_f(f.args[0])
return f, not_compute, floor_type
elif isinstance(f, Function):
not_compute = True
return f, not_compute, floor_type
elif f.is_Add:
cp_f = 0
operation_type = "Add"
elif f.is_Mul:
cp_f = 1
operation_type = "Mul"
elif f.is_Pow:
cp_f, not_compute, floor_type = _recursive_check_f(f.args[0])
return cp_f ** f.args[1], not_compute, floor_type
else:
return f, not_compute, floor_type

arg = f.args
lst = list(arg)

for i, r in enumerate(arg):
if isinstance(r, Function):
if type(r) == floor:
floor_type = True
lst[i], temp_not_compute, _ = _recursive_check_f(r.args[0])
not_compute = temp_not_compute if temp_not_compute else not_compute
else:
not_compute = True
return f, not_compute, floor_type
if r.is_Mul or r.is_Add or r.is_Pow:
temp_r = r.args[0] if r.is_Pow else r
lst[i], temp_not_compute, temp_floor_type = _recursive_check_f(temp_r)
lst[i] = lst[i] ** r.args[1] if r.is_Pow else lst[i]
floor_type = floor_type if floor_type else temp_floor_type
not_compute = temp_not_compute if temp_not_compute else not_compute

cp_f = cp_f + lst[i] if operation_type == "Add" else cp_f * lst[i]

return cp_f, not_compute, floor_type


def _solve_floor(f, symbol, solver):
""" Helper functions to solve equations with floor """
floor_eq, not_compute, floor_type = _recursive_check_f(f)

if not_compute:
raise NotImplementedError(filldedent('''
The floor equation
cannot be solved with
the current implementation.
'''))

if floor_type is False:
return EmptySet
else:
result = set()

lower_limit = solver(floor_eq, symbol)
upper_limit = solver(floor_eq-1, symbol)

for l, r in zip(lower_limit, upper_limit):
result.add(Interval.Ropen(l, r) if l < r else Interval.Lopen(r, l))
return Union(*result)


def _solve_abs(f, symbol, domain):
""" Helper function to solve equation involving absolute value function """
if not domain.is_subset(S.Reals):
Expand Down Expand Up @@ -954,6 +1027,13 @@ def _solveset(f, symbol, domain, _check=False):
except NotImplementedError:
result = ConditionSet(symbol, f, domain)
return result
elif f.has(floor):
if not domain.is_subset(S.Reals):
raise ValueError(filldedent('''
Floor expressions have meaning
only in real plane not in imaginary
one.'''))
result = _solve_floor(f, symbol, solver)
elif _is_modular(f, symbol):
result = _solve_modular(f, symbol, domain)
else:
Expand Down Expand Up @@ -1977,7 +2057,7 @@ def solveset(f, symbol=None, domain=S.Complexes):
if not isinstance(f, (Expr, Relational, Number)):
raise ValueError("%s is not a valid SymPy expression" % f)

if not isinstance(symbol, (Expr, Relational)) and symbol is not None:
if not isinstance(symbol, (Expr, Relational)) and symbol is not None:
raise ValueError("%s is not a valid SymPy symbol" % symbol)

if not isinstance(domain, Set):
Expand Down
11 changes: 11 additions & 0 deletions sympy/solvers/tests/test_solveset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sympy.functions.elementary.exponential import (LambertW, exp, log)
from sympy.functions.elementary.hyperbolic import (HyperbolicFunction,
sinh, tanh, cosh, sech, coth)
from sympy.functions.elementary.integers import floor
from sympy.functions.elementary.miscellaneous import sqrt, Min, Max
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (
Expand Down Expand Up @@ -2313,3 +2314,13 @@ def test_solve_modular_fail():
ImageSet(Lambda(n, 74*n + 31), S.Integers)

# end of modular tests


def test_solve_floor():
assert solveset(floor(x), x, S.Reals) == Interval.Ropen(0, 1)
assert solveset(floor(x+1), x, S.Reals) == Interval.Ropen(-1, 0)
assert solveset(x*floor(x+1), x, S.Reals) == Interval(-1, 0)
assert solveset(floor(2*x-3)-5, x, S.Reals) == Interval.Ropen(4, 9/2)

raises(NotImplementedError, lambda: solveset(floor(x) + sin(x), x, S.Reals))
raises(NotImplementedError, lambda: solveset(floor(2/(1 + (x*sin(x)))), x, S.Reals))