RecursionError _solve _tsolve _vsolve in expression involving Mod and floor #24368

Closed
ezyang opened this issue Dec 11, 2022 · 17 comments · Fixed by #24999

Comments

@ezyang

ezyang commented Dec 11, 2022

This expression was extracted from some PyTorch model shape expressions involving SymPy. We have discovered that the expression in question is zero, and we would like to know whether s2 is now statically known. We don't mind if solving reports that it is not implemented, or that there are no solutions (there are probably none). Instead, however, solving the equation overflows the stack.

import sympy
from sympy import Mod, floor

s2 = sympy.Symbol('s2', integer=True, positive=True)

sympy.solve((Mod(floor(s2/2 - 1/2)**2/(floor(s2/2 - 1/2) + 1) + 2*floor(s2/2 - 1/2)/(floor(s2/2 - 1/2) + 1) + 1/(floor(s2/2 - 1/2) + 1), 1))*floor(s2/2 - 1/2) + Mod(floor(s2/2 - 1/2)**2/(floor(s2/2 - 1/2) + 1) + 2*floor(s2/2 - 1/2)/(floor(s2/2 - 1/2) + 1) + 1/(floor(s2/2 - 1/2) + 1), 1) - 0, s2)

Fragment of the backtrace:

  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2752, in _tsolve
    return _vsolve(rewrite - rhs, sym, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1630, in _solve
    inv = _vsolve(u - t, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1682, in _solve
    soln = _tsolve(f_num, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2752, in _tsolve
    return _vsolve(rewrite - rhs, sym, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1630, in _solve
    inv = _vsolve(u - t, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1682, in _solve
    soln = _tsolve(f_num, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2752, in _tsolve
    return _vsolve(rewrite - rhs, sym, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1630, in _solve
    inv = _vsolve(u - t, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1682, in _solve
    soln = _tsolve(f_num, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2752, in _tsolve
    return _vsolve(rewrite - rhs, sym, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1630, in _solve
    inv = _vsolve(u - t, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1682, in _solve
    soln = _tsolve(f_num, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2752, in _tsolve
    return _vsolve(rewrite - rhs, sym, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1630, in _solve
    inv = _vsolve(u - t, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1682, in _solve
    soln = _tsolve(f_num, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2752, in _tsolve
    return _vsolve(rewrite - rhs, sym, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1630, in _solve
    inv = _vsolve(u - t, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1682, in _solve
    soln = _tsolve(f_num, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2752, in _tsolve
    return _vsolve(rewrite - rhs, sym, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1630, in _solve
    inv = _vsolve(u - t, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1682, in _solve
    soln = _tsolve(f_num, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2752, in _tsolve
    return _vsolve(rewrite - rhs, sym, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1630, in _solve
    inv = _vsolve(u - t, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1682, in _solve
    soln = _tsolve(f_num, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2752, in _tsolve
    return _vsolve(rewrite - rhs, sym, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1630, in _solve
    inv = _vsolve(u - t, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1682, in _solve
    soln = _tsolve(f_num, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2752, in _tsolve
    return _vsolve(rewrite - rhs, sym, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1630, in _solve
    inv = _vsolve(u - t, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1682, in _solve
    soln = _tsolve(f_num, symbol, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2752, in _tsolve
    return _vsolve(rewrite - rhs, sym, **flags)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2584, in _vsolve
    return [i[s] for i in _solve(e, s, **flags)]
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 1425, in _solve
    f_num, sol = solve_linear(f, symbols=symbols)
  File "/Users/ezyang/Dev/sympy/sympy/solvers/solvers.py", line 2132, in solve_linear
    dnewn_dxi = newn.diff(xi)
  File "/Users/ezyang/Dev/sympy/sympy/core/expr.py", line 3577, in diff
    return _derivative_dispatch(self, *symbols, **assumptions)
  File "/Users/ezyang/Dev/sympy/sympy/core/function.py", line 1913, in _derivative_dispatch
    return Derivative(expr, *variables, **kwargs)
  File "/Users/ezyang/Dev/sympy/sympy/core/function.py", line 1441, in __new__
    obj = cls._dispatch_eval_derivative_n_times(expr, v, count)
  File "/Users/ezyang/Dev/sympy/sympy/core/function.py", line 1902, in _dispatch_eval_derivative_n_times
    return expr._eval_derivative_n_times(v, count)
  File "/Users/ezyang/Dev/sympy/sympy/core/basic.py", line 1835, in _eval_derivative_n_times
    obj2 = obj._eval_derivative(s)
  File "/Users/ezyang/Dev/sympy/sympy/core/cache.py", line 72, in wrapper
    retval = cfunc(*args, **kwargs)
  File "/Users/ezyang/Dev/sympy/sympy/core/add.py", line 508, in _eval_derivative
    return self.func(*[a.diff(s) for a in self.args])
  File "/Users/ezyang/Dev/sympy/sympy/core/add.py", line 508, in <listcomp>
    return self.func(*[a.diff(s) for a in self.args])
  File "/Users/ezyang/Dev/sympy/sympy/core/expr.py", line 3577, in diff
    return _derivative_dispatch(self, *symbols, **assumptions)
  File "/Users/ezyang/Dev/sympy/sympy/core/function.py", line 1913, in _derivative_dispatch
    return Derivative(expr, *variables, **kwargs)
  File "/Users/ezyang/Dev/sympy/sympy/core/function.py", line 1441, in __new__
    obj = cls._dispatch_eval_derivative_n_times(expr, v, count)
  File "/Users/ezyang/Dev/sympy/sympy/core/function.py", line 1902, in _dispatch_eval_derivative_n_times
    return expr._eval_derivative_n_times(v, count)
  File "/Users/ezyang/Dev/sympy/sympy/core/basic.py", line 1835, in _eval_derivative_n_times
    obj2 = obj._eval_derivative(s)
  File "/Users/ezyang/Dev/sympy/sympy/core/function.py", line 608, in _eval_derivative
    df = self.fdiff(i)
  File "/Users/ezyang/Dev/sympy/sympy/core/function.py", line 791, in fdiff
    return Subs(Derivative(self.func(*args), D), D, A)
  File "/Users/ezyang/Dev/sympy/sympy/core/function.py", line 2205, in __new__
    s_pts = {p: Symbol(pre + mystr(p)) for p in pts}
  File "/Users/ezyang/Dev/sympy/sympy/core/function.py", line 2205, in <dictcomp>
    s_pts = {p: Symbol(pre + mystr(p)) for p in pts}
  File "/Users/ezyang/Dev/sympy/sympy/core/function.py", line 2203, in mystr
    return p.doprint(expr)
  File "/Users/ezyang/Dev/sympy/sympy/printing/printer.py", line 293, in doprint
    return self._str(self._print(expr))
  File "/Users/ezyang/Dev/sympy/sympy/printing/printer.py", line 332, in _print
    return printmethod(expr, **kwargs)
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 57, in _print_Add
    t = self._print(term)
  File "/Users/ezyang/Dev/sympy/sympy/printing/printer.py", line 332, in _print
    return printmethod(expr, **kwargs)
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 161, in _print_Function
    return expr.func.__name__ + "(%s)" % self.stringify(expr.args, ", ")
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 41, in stringify
    return sep.join([self.parenthesize(item, level) for item in args])
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 41, in <listcomp>
    return sep.join([self.parenthesize(item, level) for item in args])
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 38, in parenthesize
    return self._print(item)
  File "/Users/ezyang/Dev/sympy/sympy/printing/printer.py", line 332, in _print
    return printmethod(expr, **kwargs)
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 364, in _print_Mul
    a_str = [self.parenthesize(x, prec, strict=False) for x in a]
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 364, in <listcomp>
    a_str = [self.parenthesize(x, prec, strict=False) for x in a]
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 38, in parenthesize
    return self._print(item)
  File "/Users/ezyang/Dev/sympy/sympy/printing/printer.py", line 332, in _print
    return printmethod(expr, **kwargs)
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 161, in _print_Function
    return expr.func.__name__ + "(%s)" % self.stringify(expr.args, ", ")
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 41, in stringify
    return sep.join([self.parenthesize(item, level) for item in args])
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 41, in <listcomp>
    return sep.join([self.parenthesize(item, level) for item in args])
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 38, in parenthesize
    return self._print(item)
  File "/Users/ezyang/Dev/sympy/sympy/printing/printer.py", line 332, in _print
    return printmethod(expr, **kwargs)
  File "/Users/ezyang/Dev/sympy/sympy/printing/str.py", line 52, in _print_Add
    terms = self._as_ordered_terms(expr, order=order)
  File "/Users/ezyang/Dev/sympy/sympy/printing/printer.py", line 350, in _as_ordered_terms
    return expr.as_ordered_terms(order=order)
  File "/Users/ezyang/Dev/sympy/sympy/core/expr.py", line 1146, in as_ordered_terms
    terms, gens = self.as_terms()
  File "/Users/ezyang/Dev/sympy/sympy/core/expr.py", line 1176, in as_terms
    coeff = complex(coeff)
  File "/Users/ezyang/Dev/sympy/sympy/core/expr.py", line 355, in __complex__
    re, im = result.as_real_imag()
  File "/Users/ezyang/Dev/sympy/sympy/core/expr.py", line 1931, in as_real_imag
    return (re(self), im(self))
  File "/Users/ezyang/Dev/sympy/sympy/core/cache.py", line 72, in wrapper
    retval = cfunc(*args, **kwargs)
  File "/Users/ezyang/Dev/sympy/sympy/core/numbers.py", line 1375, in __eq__
    other = _sympify(other)
  File "/Users/ezyang/Dev/sympy/sympy/core/sympify.py", line 528, in _sympify
    return sympify(a, strict=True)
  File "/Users/ezyang/Dev/sympy/sympy/core/sympify.py", line 361, in sympify
    is_sympy = getattr(a, '__sympy__', None)
RecursionError: maximum recursion depth exceeded while calling a Python object

Reproduces on 1fbd995 (master at time of writing)

@ezyang
Author

ezyang commented Dec 11, 2022

I have a few more repros from different models, but on inspecting the stack I'm guessing they're all the same issue.

import sympy 
from sympy import Mod, floor 
 
s0 = sympy.Symbol('s0', integer=True, positive=True) 
s2 = sympy.Symbol('s2', integer=True, positive=True) 
 
sympy.solve(Mod(116, floor(116*s0*floor(floor(floor(s2/2 - 1/2)/2)/2)**2/(s0*floor(floor(floor(s2/2 - 1/2)/2)/2)**2 + 2*s0*floor(floor(floor(s2/2 - 1/2)/2)/2) + s0) + 232*s0*floor(floor(floor(s2/2 - 1/2)/2)/2)/(s0*floor(floor(floor(s2/2 - 1/2)/2)/2)**2 + 2*s0*floor(floor(floor(s2/2 - 1/2)/2)/2) + s0) + 116*s0/(s0*floor(floor(floor(s2/2 - 1/2)/2)/2)**2 + 2*s0*floor(floor(floor(s2/2 - 1/2)/2)/2) + s0))) - 0, s2) 

@smichr
Member

smichr commented Dec 11, 2022

A check is made to see if the lhs, when rewritten in terms of exp, has changed. Maybe the rewritten expression should be entered into and tested for being in flags['tsolve_saw'].
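
A toy model in plain Python may help picture this guard. It is not SymPy's code: `rewrite_exp`, `canonicalise`, and `toy_tsolve` are hypothetical stand-ins, and expressions are plain strings. The only idea borrowed from the suggestion is remembering which expressions have already been tried and refusing to recurse on a repeat.

```python
def rewrite_exp(expr):
    """Stand-in for lhs.rewrite(exp): turns x**2 into exp(2*log(x))."""
    return expr.replace("x**2", "exp(2*log(x))")


def canonicalise(expr):
    """Stand-in for a later step that undoes the rewrite."""
    return expr.replace("exp(2*log(x))", "x**2")


def toy_tsolve(expr, seen=None):
    """Recurse on the rewritten form, but only if it has not been tried before."""
    seen = set() if seen is None else seen
    if expr in seen:
        # Without this check the two helpers would ping-pong forever.
        raise NotImplementedError("already tried: " + expr)
    seen.add(expr)
    rewritten = rewrite_exp(expr)
    if rewritten != expr:
        return toy_tsolve(canonicalise(rewritten), seen)
    raise NotImplementedError("no algorithm for: " + expr)


try:
    toy_tsolve("Mod(x**2, 49)")
except NotImplementedError as err:
    outcome = str(err)

print(outcome)
```

In this sketch the second visit to the same string ends the recursion with NotImplementedError instead of a RecursionError.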

@asmeurer
Member

It's really just the big mod factor that produces the recursion error. The other factor isn't implemented (which doesn't leave much hope that the other one would actually produce an answer):

>>> sympy.solve(floor(s2/2 - S(1)/2) + 1, s2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/aaronmeurer/Documents/Python/sympy/sympy/./sympy/solvers/solvers.py", line 1144, in solve
    solution = _solve(f[0], *symbols, **flags)
  File "/Users/aaronmeurer/Documents/Python/sympy/sympy/./sympy/solvers/solvers.py", line 1690, in _solve
    raise NotImplementedError('\n'.join([msg, not_impl_msg % f]))
NotImplementedError:
No algorithms are implemented to solve equation floor(s2/2 - 1/2) + 1

Solving floor(expr) = n means solving n <= expr < n+1. Here this seems to be implemented incorrectly for reduce_inequalities (the solution should be 3 <= s2 < 5, giving 3 and 4 as the integral solutions).

>>> reduce_inequalities(((s2 - 1)/2 < 2) & ((s2 - 1)/2 >= 1))
False

@ezyang
Author

ezyang commented Dec 11, 2022

I found a very simple mod that induces the infinite loop:

>>> from sympy import Mod, Symbol, solve
>>> s = Symbol("s")
>>> solve(Mod(s**2, 49), s)

@smichr
Member

smichr commented Dec 12, 2022

seems to be implemented incorrectly

Solving floor(s2/2 - 1/2) + 1 means solving floor(s/2-1/2) = -1, or

>>> reduce_inequalities([(s-1)/2>=-1, (s2-1)/2<1])
(-1 <= x) & (x < 2)

@asmeurer
Member

Ah, I didn't realize reduce_inequalities takes a list (also I mistakenly used 1 instead of -1 above).

So it seems the following steps are needed to make this work:

  • Convert expressions with floor into equivalent inequalities (using the law floor(x) = n <=> n <= x < n + 1).
  • Solve these inequalities with reduce_inequalities.
  • Convert this solution into a list of integers.
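
For the specific argument discussed in this thread, (s - 1)/2, the first and last steps can be sketched with integer arithmetic alone. The helper name `floor_half_solutions` is made up for illustration and is not a SymPy function. Multiplying n <= (s - 1)/2 < n + 1 through by 2 gives 2n + 1 <= s < 2n + 3.

```python
def floor_half_solutions(n):
    """Integers s with floor((s - 1)/2) == n, from n <= (s - 1)/2 < n + 1."""
    return range(2 * n + 1, 2 * n + 3)


# Cross-check against a brute-force scan; // is exact floor division on ints.
n = 1
brute_force = [s for s in range(-50, 51) if (s - 1) // 2 == n]
closed_form = list(floor_half_solutions(n))

print(closed_form, brute_force)
```

With n = 1 both lists contain 3 and 4, which agrees with the 3 <= s2 < 5 example given earlier in the thread.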

What's the best way to achieve the last item? A set of reduced inequalities might not represent a finite set, of course, so maybe this is a better fit for solveset.

It seems this is possible with

>>> reduce_inequalities([(s-1)/2>=-1, (s-1)/2<1]).as_set().intersection(Integers)
Range(-1, 3, 1)

Actually this is wrong (it should be Range(-1, 2, 1)), but that's presumably just a simple bug that can be fixed.

The result being a Range is annoying. Ideally it would be a FiniteSet when the answer is finite and not too large.

Regarding the actual recursion error: _tsolve is stuck in an infinite loop, recursively passing the same expression to itself. It looks like the problem is that it tries to rewrite the expression in terms of exp, producing Mod(exp(2*log(x)), 49). But then it passes this through Poly, which is smart enough to convert it back to Mod(x**2, 49). There are several possible ways to fix this. I'm not sure what the idea behind the rewrite(exp) is, but is it really supposed to rewrite powers like x**2? It seems to me that it should be limited to rewriting powers where the symbol is in the exponent.

If we can fix that, then comes the hard part, which is making things like solve(Mod(x**2, 49)) work. Note that the OP's Mod(floor(s2/2 - 1/2)**2/(floor(s2/2 - 1/2) + 1) + 2*floor(s2/2 - 1/2)/(floor(s2/2 - 1/2) + 1) + 1/(floor(s2/2 - 1/2) + 1), 1) simplifies:

>>> simplify(S("Mod(floor(s2/2 - 1/2)**2/(floor(s2/2 - 1/2) + 1) + 2*floor(s2/2 - 1/2)/(floor(s2/2 - 1/2) + 1) + 1/(floor(s2/2 - 1/2) + 1), 1)"))
Mod(floor(s2/2 - 1/2), 1)

However, cancel() doesn't simplify inside of the Mod, which might be why it leads to the same recursion as the Mod(x**2, 49) example.

This really ought to simplify away completely to just 0, since floor() is always an integer.

So in fact, the solution set to this particular equation is all real numbers (you can verify this by substituting random values).
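
The substitution check mentioned here can be done exactly with the standard library, at least for the positive integers that s2 is assumed to range over in the original report. This is only a sketch: `op_expression` is a hypothetical name, and `Fraction` with Python's `%` stands in for SymPy's rational arithmetic and `Mod`.

```python
from fractions import Fraction
from math import floor


def op_expression(s2):
    """Evaluate the expression from the original report exactly at an integer s2."""
    f = floor(Fraction(s2, 2) - Fraction(1, 2))   # floor(s2/2 - 1/2)
    d = f + 1                                     # shared denominator, nonzero for s2 >= 1
    inner = Fraction(f**2, d) + Fraction(2 * f, d) + Fraction(1, d)
    m = inner % 1                                 # plays the role of Mod(inner, 1)
    return m * f + m


values = [op_expression(s2) for s2 in range(1, 201)]
all_zero = all(v == 0 for v in values)
print(all_zero)
```

The inner sum is (f + 1)**2/(f + 1) = f + 1, an integer, so its remainder modulo 1 vanishes for every value tried.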

@ezyang is this intended? Would it be useful to be able to solve the other factor (floor(s2/2 - S(1)/2) + 1 = 0), or things like Mod(x**2, 49) = 0, or even things similar to Mod(floor(s2/2 - 1/2), 1) where the argument to the Mod is not always an integer?

@ezyang
Author

ezyang commented Jan 18, 2023

I don't think we care too much about getting solutions here. We would like sympy to try some easy solving, and if it fails, no biggie; we will keep going.
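
A caller-side guard in that spirit might look like the sketch below. `try_solve` is a hypothetical helper, not part of SymPy or PyTorch, and the two failing functions are stand-ins so the example runs without SymPy installed.

```python
def try_solve(solver, *args, **kwargs):
    """Call solver(*args, **kwargs); return None if it gives up or overflows the stack."""
    try:
        return solver(*args, **kwargs)
    except (NotImplementedError, RecursionError):
        return None


# Stand-ins for a solver call; with SymPy this would be try_solve(sympy.solve, expr, s2).
def runaway(depth=0):
    return runaway(depth + 1)      # never terminates, so Python raises RecursionError


def unsupported():
    raise NotImplementedError("no algorithm")


results = [try_solve(runaway), try_solve(unsupported), try_solve(lambda: [3, 4])]
print(results)
```

Catching RecursionError only hides the symptom: the call still burns through the whole stack before failing, so a fix inside the solver is preferable when one is available.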

@ezyang
Author

ezyang commented Jan 18, 2023

Cc @Chillee

@smichr
Member

smichr commented Jan 18, 2023

Actually this is wrong (it should be Range(-1, 2, 1))

It's ok: Range, like range, does not include the end point.
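
The half-open convention is easy to confirm with Python's built-in `range`, which this comment says SymPy's `Range` mirrors. The snippet uses only the built-in, so it shows the convention rather than SymPy's own behaviour.

```python
# Built-in range is half-open: the stop value itself is excluded.
members = list(range(-1, 3, 1))
stop_included = 3 in range(-1, 3, 1)

print(members, stop_included)
```

Under that convention a stop of 3 corresponds to the inequality s < 3, so the largest member is 2.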

asmeurer added a commit to asmeurer/sympy that referenced this issue Jan 18, 2023
Setting it to False (it is True by default) disables rewriting cases that
create unevaluated Pow/exp objects.

This is used in solve to avoid a recursion error (fixes sympy#24368) which is
caused by the call to rewrite(exp) in solve, which then gets rewritten back to
a Pow by Poly. There doesn't appear to be any benefit to solving to rewrite
powers with nonsymbolic exponents to exp.
@asmeurer
Member

Fix for the recursion error is at #24549. Let's open new issues for the other things identified above, like getting the actual solutions to work.

@asmeurer
Member

It's ok: Range, like range, does not include the end point:

Oh, my bad. I was misled by your comment that said the endpoint should be 2, but I forgot to double check it myself #24368 (comment).

@oscarbenjamin
Contributor

The problem is that this code is flakey:

sympy/sympy/solvers/solvers.py

Lines 2750 to 2752 in f4176c7

rewrite = lhs.rewrite(exp)
if rewrite != lhs:
    return _vsolve(rewrite - rhs, sym, **flags)

I have seen this cause recursion error in other examples as well, not necessarily involving evaluate=False. This runs the risk of infinite recursion because it presumes that just because rewrite changed the expression in some way we can recurse again. The new expression should be equivalent to the one we started with though so there is a high chance that we will get back to the same place with the same expression because of any code that tries to canonicalise the form of the expression.

If possible we should just remove these lines:

diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py
index 5c0f3ea..2ebbaa7 100644
--- a/sympy/solvers/solvers.py
+++ b/sympy/solvers/solvers.py
@@ -2747,9 +2747,6 @@ def equal(expr1, expr2):
             elif lhs.func == LambertW:
                 return _vsolve(lhs.args[0] - rhs*exp(rhs), sym, **flags)
 
-        rewrite = lhs.rewrite(exp)
-        if rewrite != lhs:
-            return _vsolve(rewrite - rhs, sym, **flags)
     except NotImplementedError:
         pass
 

That change passes all except two of the not slow tests in solvers. The failing examples are:

assert solve(5**(x/2) - 2**(3/x)) == [-b, b]
assert solve(sin(x) + tan(x)) == [0, -pi, pi, 2*pi]

Both give NotImplementedError. Each of these should be solved in a different way though.

For one thing it is possible to be specific about what types you want to rewrite:

In [11]: (1/x + cos(x)).rewrite(exp)
Out[11]: 
 ⅈ⋅x    -ⅈ⋅x            
ℯ      ℯ        -log(x)
──── + ───── + ℯ       
 2       2             

In [12]: (1/x + cos(x)).rewrite([cos, sin], exp)
Out[12]: 
 ⅈ⋅x    -ⅈ⋅x    
ℯ      ℯ       1
──── + ───── + ─
 2       2     x

Alternatively other rewrites can be used e.g.:

In [17]: (sin(x) + tan(x)).rewrite('sincos')
Out[17]: 
         sin(x)
sin(x) + ──────
         cos(x)

For the case of two exponentials:

In [19]: (5**(x/2) - 2**(3/x)).rewrite(exp)
Out[19]: 
   3⋅log(2)    x⋅log(5)
   ────────    ────────
      x           2    
- ℯ         + ℯ        

Here we can just check if we have a**f(x) = b**g(x) and conclude that if a and b are both real and positive then this is equivalent to saying that log(a)*f(x) = log(b)*g(x):

In [30]: solve(log(5)*(x/2) - log(2)*3/x)
Out[30]: 
⎡      ________        ________⎤
⎢-√6⋅╲╱ log(2)    √6⋅╲╱ log(2) ⎥
⎢───────────────, ─────────────⎥
⎢     ________        ________ ⎥
⎣   ╲╱ log(5)       ╲╱ log(5)  ⎦
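
The two real roots can be checked numerically against the original equation with the standard library. This is a floating-point sanity check under the assumption that the roots are x = ±sqrt(6*log(2)/log(5)), which follows from log(5)*x/2 = 3*log(2)/x; `residual` is a name chosen here for illustration.

```python
import math

# Positive root of log(5)*x/2 = 3*log(2)/x, i.e. x**2 = 6*log(2)/log(5).
b = math.sqrt(6 * math.log(2) / math.log(5))


def residual(x):
    """Left-hand side of 5**(x/2) - 2**(3/x) = 0 evaluated in floating point."""
    return 5 ** (x / 2) - 2 ** (3 / x)


residuals = [residual(-b), residual(b)]
print(residuals)
```

Both residuals should be zero up to rounding error, consistent with the expected [-b, b] from the test.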

Of course the full solution set here is infinite and corresponds to solutions of log(5)*(x/2)/log(2) - 3/x + 2*n*pi for integer n but in that situation solve usually returns a single solution for the "period".

@asmeurer
Member

My thinking is that we do want to keep the exponential rewriting for trig solving, since implementing a trig solver that works without rewriting will likely just be a lot of work that would ultimately be duplicated. We could limit the rewrite to only rewrite trig functions, although that could be a regression if we miss a function, or if there is some custom user function that relied on being able to work with solve with an exp rewrite (but maybe that's not something that should be considered "breaking").

For the non-base e exponentials, I think we just need to fix the code that does the inversion so that it works the same way. It looks like there is a branch here that produces a different result for 5**(x/2) - 2**(3/x) vs. (5**(x/2) - 2**(3/x)).rewrite(exp)

sympy/sympy/solvers/solvers.py

Lines 3162 to 3171 in f4176c7

if a_base == b_base:
    # a = -b
    lhs = powsimp(powdenest(ad/bd))
    rhs = -bi/ai
else:
    rat = ad/bd
    _lhs = powsimp(ad/bd)
    if _lhs != rat:
        lhs = _lhs
        rhs = -bi/ai

asmeurer added a commit to asmeurer/sympy that referenced this issue Jan 19, 2023
This can lead to recursion errors (e.g., sympy#24368) because it is too easy for
the expression to get rewritten back to the way it was again in a recursive
call.

This still has some test failures. We need to handle trigonometric equations
in some other way (e.g., by doing a more targeted rewrite of them).
@asmeurer
Member

OK, I have started a more targeted fix at removing the rewrite here #24553. As for #24549, I think we can remove the evaluate flag if #24553 works out, since it won't actually be used anywhere. Some of the other cleanups in that PR are still useful, though, so if that happens I'll salvage them.

@oscarbenjamin
Contributor

My thinking is that we do want to keep the exponential rewriting for trig solving

I think that a better way to do this is by rewriting to a system of polynomial equations:

In [24]: eq = sin(x) + cos(x)

In [25]: solve(eq, x)
Out[25]: 
⎡-π ⎤
⎢───⎥
⎣ 4 ⎦

In [26]: s, c = symbols('s, c')

In [27]: sols = solve([s + c, s**2 + c**2 - 1], [s, c], dict=True)

In [28]: sols
Out[28]: 
⎡⎧   -√2       √2⎫  ⎧   √2     -√2 ⎫⎤
⎢⎨c: ────, s: ──⎬, ⎨c: ──, s: ────⎬⎥
⎣⎩    2       2 ⎭  ⎩   2       2  ⎭⎦

In [29]: [{x: atan2(sol[s], sol[c])} for sol in sols]
Out[29]: 
⎡⎧   3π⎫  ⎧   -π ⎫⎤
⎢⎨x: ───⎬, ⎨x: ───⎬⎥
⎣⎩    4 ⎭  ⎩    4 ⎭⎦

In [30]: [checksol(eq, sol) for sol in _]
Out[30]: [True, True]
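
The same route can be replayed in floating point with only the standard library. It is a sketch of the idea rather than a proposed implementation: the polynomial system is solved by hand here (s = -c and 2*c**2 = 1), and the variable names are illustrative.

```python
import math

# From s + c = 0 and s**2 + c**2 = 1: s = -c and 2*c**2 = 1.
c_values = [-math.sqrt(0.5), math.sqrt(0.5)]
pairs = [(-c, c) for c in c_values]            # (s, c) solutions of the system

# Recover the angle from each (sin, cos) pair.
angles = [math.atan2(s, c) for s, c in pairs]

# Substitute back into sin(x) + cos(x).
residuals = [math.sin(x) + math.cos(x) for x in angles]
print(angles, residuals)
```

Going through atan2 yields one angle per (s, c) pair, which is how this approach picks up 3π/4 in addition to the -π/4 that solve returned above.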

@asmeurer
Member

I implemented the rewrite in terms of just trig functions for now. I'll need to look into how hard it is to implement your idea with the existing solve helpers.

asmeurer added a commit to asmeurer/sympy that referenced this issue Jan 31, 2023
asmeurer added a commit to asmeurer/sympy that referenced this issue Feb 1, 2023
This makes some tests still work similar to how they did before, but still
leaves sympy#24368 fixed.

However, it does change the behavior of some of the slow Lambert W tests, so
this approach may still require some thought.
asmeurer added a commit to asmeurer/sympy that referenced this issue Apr 1, 2023
This can lead to infinite recursion.

Fixes sympy#24368
@asmeurer
Member

asmeurer commented Apr 1, 2023

OK, third attempt here: #24999. This is the simplest possible fix for this issue, namely, just rebuilding the expression after rewriting to filter out any evaluate=False rewrites. (I've only applied it in this one place in the solve code; a more general way to do this is to add evaluate=True to rewrite as in #24549, but that approach was more controversial.)
