Bernoulli differential equation #22294

Open
nickme68 opened this issue Oct 16, 2021 · 12 comments · May be fixed by #22312


@nickme68

nickme68 commented Oct 16, 2021

Hi!

I have the new problem with solving Bernoulli equation:

from sympy import *
from IPython.display import display
x = symbols("x")
y = Function("y")
y_ = Derivative(y(x), x)
eq = Eq(y_-y(x), x*y(x)**2)
display(eq)
display(dsolve(eq))
eq2 = Eq(y_-y(x), -x*y(x)**2)
display(eq2)
display(dsolve(eq2))

The two equations are almost the same (they differ only in the sign of the x*y(x)**2 term), but the first one is solved correctly and the second one is solved only as a power series. Again, the online version of SymPy solves both equations correctly.

Regards, Nikolay
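Both equations are Bernoulli equations with n = 2, so the substitution v = 1/y turns each one into a linear first-order equation: v' + v = -x for the first and v' + v = x for the second. Solving those by hand gives the closed forms below. A minimal sketch, assuming a SymPy version that exports `checkodesol`, confirms them independently of which hint `dsolve` happens to pick:

```python
from sympy import Derivative, Eq, Function, checkodesol, exp, symbols

x, C1 = symbols("x C1")
y = Function("y")
y_ = Derivative(y(x), x)

# y' - y = x*y**2; with v = 1/y: v' + v = -x, so v = 1 - x + C1*exp(-x)
eq = Eq(y_ - y(x), x*y(x)**2)
sol = Eq(y(x), 1/(1 - x + C1*exp(-x)))

# y' - y = -x*y**2; with v = 1/y: v' + v = x, so v = x - 1 + C1*exp(-x)
eq2 = Eq(y_ - y(x), -x*y(x)**2)
sol2 = Eq(y(x), 1/(x - 1 + C1*exp(-x)))

# checkodesol substitutes each candidate back into its ODE;
# a first element of True means the residual simplified to zero
print(checkodesol(eq, sol))
print(checkodesol(eq2, sol2))
```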

@oscarbenjamin
Contributor

> I have the new problem with solving Bernoulli equation:

When you say this is a new problem, have you already opened an issue somewhere else? If so, please link to the other issue. Maybe this new problem should just be added there rather than listed as a separate issue.

The current master branch gives non-series solutions for both equations, but that's only because the second one is now solved by the 1st_rational_riccati solver. There is a bug in the matching code for the Bernoulli solver: it does not recognize that it can solve the equation:

In [3]: classify_ode(eq2)
Out[3]: ('factorable', '1st_rational_riccati', '1st_power_series', 'lie_group')
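To check whether a given SymPy installation is affected, the hint list can be inspected directly and, when the Bernoulli hint is present, forced with `hint=`. A small sketch; which branch runs depends on the installed version, so no particular output is assumed here:

```python
from sympy import Derivative, Eq, Function, classify_ode, dsolve, symbols

x = symbols("x")
y = Function("y")
eq2 = Eq(Derivative(y(x), x) - y(x), -x*y(x)**2)

# classify_ode returns a tuple of the hint names this SymPy version can apply
hints = classify_ode(eq2, y(x))
print(hints)

if "Bernoulli" in hints:
    # Force the Bernoulli solver instead of the default (first) hint
    print(dsolve(eq2, y(x), hint="Bernoulli"))
else:
    print("Bernoulli hint not matched by this SymPy version")
```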

skirpichev added a commit to skirpichev/diofant that referenced this issue Oct 16, 2021
@nickme68
Author

> When you say this is a new problem, have you already opened an issue somewhere else?

OK, I had a similar problem with separable equations 2-3 weeks ago: changing a sign from plus to minus made the equation unsolvable (#22155). Possibly it is the same problem? By the way, the online (older) version of SymPy correctly classifies the equation as a Bernoulli equation:
('Bernoulli', '1st_power_series', 'lie_group', 'Bernoulli_Integral')

@oscarbenjamin
Contributor

This example matched the Bernoulli hint until commit 059a1cf from #18403.

@oscarbenjamin
Contributor

CC @Mohitbalwani26

@oscarbenjamin
Contributor

The problem is a bug in match:

from sympy import *

x = Symbol('x')
y = Function('y')
P = Wild('P', exclude=[y(x)])
Q = Wild('Q', exclude=[y(x)])
n = Wild('n', exclude=[x, y(x), y(x).diff(x)])

eq = x*y(x)**2 - y(x) + Derivative(y(x), x)
pattern = P*y(x) - Q*y(x)**n + Derivative(y(x), x)

print(eq.match(pattern)) # gives None

This should match with {P: -1, Q: -x, n: 2} (the pattern contains -Q*y(x)**n, so the term x*y(x)**2 forces Q = -x).
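One way to sanity-check an expected match is to substitute it back into the pattern with `xreplace`: a correct match must rebuild `eq` exactly. This sketch reuses the definitions above and shows that {P: -1, Q: -x, n: 2} does, while Q = x gets the sign of the y(x)**2 term wrong:

```python
from sympy import Derivative, Function, Symbol, Wild

x = Symbol('x')
y = Function('y')
P = Wild('P', exclude=[y(x)])
Q = Wild('Q', exclude=[y(x)])
n = Wild('n', exclude=[x, y(x), y(x).diff(x)])

eq = x*y(x)**2 - y(x) + Derivative(y(x), x)
pattern = P*y(x) - Q*y(x)**n + Derivative(y(x), x)

# Substituting the expected match into the pattern must give back eq exactly
expected = {P: -1, Q: -x, n: 2}
print(pattern.xreplace(expected) == eq)

# What match() actually returns depends on the SymPy version (None when affected)
print(eq.match(pattern))
```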

@oscarbenjamin
Contributor

oscarbenjamin commented Oct 17, 2021

The underlying issue is this:

from sympy import *

x, z = symbols('x, z')
Q = Wild('Q')
n = Wild('n')

eq = x*z**2
pattern = - Q*z**n

print(pattern.matches(eq)) # None

That should match with {Q: -x, n: 2}.
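By hand, matching -Q*z**n against x*z**2 amounts to splitting off the z-free factor and reading the exponent off the rest. A minimal sketch of that decomposition with `as_independent` and `as_base_exp`, assuming Q must be free of z; this only illustrates the expected answer and is not how SymPy's matcher is implemented:

```python
from sympy import symbols

x, z = symbols('x z')
expr = x*z**2

# Split the product into the z-free factor and the factor that depends on z
coeff, dep = expr.as_independent(z)   # coeff = x, dep = z**2

# The pattern is -Q*z**n, so the z-free factor must equal -Q
Q_val = -coeff

# Match z**n against the z-dependent factor by comparing base and exponent
base, n_val = dep.as_base_exp()
print(Q_val, base, n_val)
```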

@oscarbenjamin
Contributor

oscarbenjamin commented Oct 17, 2021

In fact it's as simple as this:

from sympy import *

z = symbols('z')
n = Wild('n')

eq = z**2
pattern = z**n

print(pattern.matches(eq)) # None
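For comparison, `Pow.matches` works by splitting the target with `as_base_exp()` and matching base and exponent separately. The stripped-down sketch below applies that idea. `match_pow` is a hypothetical helper, not SymPy API, and it only handles patterns whose base contains no Wild. It shows that z**n against z**2 should bind n to 2:

```python
from sympy import Wild, symbols


def match_pow(pattern, expr):
    """Hypothetical helper (not SymPy API): match pattern = base**exp where
    only the exponent may contain a Wild. Returns a dict of bindings or None."""
    pb, pe = pattern.as_base_exp()
    b, e = expr.as_base_exp()
    if b != pb:  # this sketch requires the bases to be identical
        return None
    return pe.matches(e, {})


z, x = symbols('z x')
n = Wild('n')

print(match_pow(z**n, z**2))   # n -> 2
print(match_pow(z**n, 1/z))    # 1/z is z**-1, so n -> -1
print(match_pow(z**n, z))      # z is read as z**1, so n -> 1
print(match_pow(z**n, x**2))   # different base: no match
```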

@oscarbenjamin
Contributor

This seems to fix all the issues mentioned above, but it leads to test failures in core:

diff --git a/sympy/core/operations.py b/sympy/core/operations.py
index efacede..c4360c9 100644
--- a/sympy/core/operations.py
+++ b/sympy/core/operations.py
@@ -255,9 +255,6 @@ def _matches_commutative(self, expr, repl_dict=None, old=False):
                 # the matching continue
                 return None
             newexpr = self._combine_inverse(expr, exact)
-            if not old and (expr.is_Add or expr.is_Mul):
-                if newexpr.count_ops() > expr.count_ops():
-                    return None
             newpattern = self._new_rawargs(*wild_part)
             return newpattern.matches(newexpr, repl_dict)
 
diff --git a/sympy/core/power.py b/sympy/core/power.py
index 455481a..8185d9a 100644
--- a/sympy/core/power.py
+++ b/sympy/core/power.py
@@ -1558,19 +1558,12 @@ def matches(self, expr, repl_dict=None, old=False):
 
         b, e = expr.as_base_exp()
 
-        # special case number
-        sb, se = self.as_base_exp()
-        if sb.is_Symbol and se.is_Integer and expr:
-            if e.is_rational:
-                return sb.matches(b**(e/se), repl_dict)
-            return sb.matches(expr**(1/se), repl_dict)
-
         d = repl_dict.copy()
         d = self.base.matches(b, d)
         if d is None:
             return None
 
-        d = self.exp.xreplace(d).matches(e, d)
+        d = e.matches(self.exp.xreplace(d), d)
         if d is None:
             return Expr.matches(self, expr, repl_dict)
         return d

Failures:

============================================ FAILURES ==============================================
___________________________________________ test_replace ____________________________________________

    def test_replace():
        f = log(sin(x)) + tan(sin(x**2))
    
        assert f.replace(sin, cos) == log(cos(x)) + tan(cos(x**2))
        assert f.replace(
            sin, lambda a: sin(2*a)) == log(sin(2*x)) + tan(sin(2*x**2))
    
        a = Wild('a')
        b = Wild('b')
    
        assert f.replace(sin(a), cos(a)) == log(cos(x)) + tan(cos(x**2))
        assert f.replace(
            sin(a), lambda a: sin(2*a)) == log(sin(2*x)) + tan(sin(2*x**2))
        # test exact
        assert (2*x).replace(a*x + b, b - a, exact=True) == 2*x
        assert (2*x).replace(a*x + b, b - a) == 2*x
>       assert (2*x).replace(a*x + b, b - a, exact=False) == 2/x
E       assert -2/x**2 == (2 / x)
E        +  where -2/x**2 = <bound method Basic.replace of 2*x>(((a_ * x) + b_), (b_ - a_), exact=False)
E        +    where <bound method Basic.replace of 2*x> = (2 * x).replace

sympy/core/tests/test_expr.py:894: AssertionError
__________________________________________ test_issue_3773 __________________________________________

    def test_issue_3773():
        x = symbols('x')
        z, phi, r = symbols('z phi r')
        c, A, B, N = symbols('c A B N', cls=Wild)
        l = Wild('l', exclude=(0,))
    
        eq = z * sin(2*phi) * r**7
        matcher = c * sin(phi*N)**l * r**A * log(r)**B
    
>       assert eq.match(matcher) == {c: z, l: 1, N: 2, A: 7, B: 0}
E       assert None == {c_: z, l_: 1, N_: 2, A_: 7, ...}
E        +  where None = <bound method Basic.match of r**7*z*sin(2*phi)>(r**A_*c_*log(r)**B_*sin(phi*N_)**l_)
E        +    where <bound method Basic.match of r**7*z*sin(2*phi)> = r**7*z*sin(2*phi).match

sympy/core/tests/test_match.py:493: AssertionError
__________________________________________ test_issue_4559 __________________________________________

    def test_issue_4559():
        x = Symbol('x')
        e = Symbol('e')
        w = Wild('w', exclude=[x])
        y = Wild('y')
    
        # this is as it should be
    
        assert (3/x).match(w/y) == {w: 3, y: x}
        assert (3*x).match(w*y) == {w: 3, y: x}
>       assert (x/3).match(y/w) == {w: 3, y: x}
E       assert None == {w_: 3, y_: x}
E        +  where None = <bound method Basic.match of x/3>((y_ / w_))
E        +    where <bound method Basic.match of x/3> = (x / 3).match

sympy/core/tests/test_match.py:588: AssertionError
__________________________________________ test_issue_3539 __________________________________________

    def test_issue_3539():
        a = Wild('a')
        x = Symbol('x')
>       assert (x - 2).match(a - x) is None
E       assert {a_: 2*x - 2} is None
E        +  where {a_: 2*x - 2} = <bound method Basic.match of x - 2>((a_ - x))
E        +    where <bound method Basic.match of x - 2> = (x - 2).match

sympy/core/tests/test_match.py:678: AssertionError
________________________________________ test_gh_issue_2711 _________________________________________

    def test_gh_issue_2711():
        x = Symbol('x')
        f = meijerg(((), ()), ((0,), ()), x)
        a = Wild('a')
        b = Wild('b')
    
        assert f.find(a) == {(S.Zero,), ((), ()), ((S.Zero,), ()), x, S.Zero,
                                 (), meijerg(((), ()), ((S.Zero,), ()), x)}
        assert f.find(a + b) == \
            {meijerg(((), ()), ((S.Zero,), ()), x), x, S.Zero}
>       assert f.find(a**2) == {meijerg(((), ()), ((S.Zero,), ()), x), x}
E       assert set() == {meijerg(((),...), ()), x), x}
E         Extra items in the right set:
E         meijerg(((), ()), ((0,), ()), x)
E         x
E         Use -v to get the full diff

sympy/core/tests/test_match.py:692: AssertionError
______________________________________ test_match_issue_21942 _______________________________________

    def test_match_issue_21942():
        a, r, w = symbols('a, r, w', nonnegative=True)
        p = symbols('p', positive=True)
        g_ = Wild('g')
        pattern = g_ ** (1 / (1 - p))
        eq = (a * r ** (1 - p) + w ** (1 - p) * (1 - a)) ** (1 / (1 - p))
        m = {g_: a * r ** (1 - p) + w ** (1 - p) * (1 - a)}
        assert pattern.matches(eq) == m
        assert (-pattern).matches(-eq) == m
>       assert pattern.matches(signsimp(eq)) is None
E       assert {g_: a*r**(1 - p) - w**(1 - p)*(a - 1)} is None
E        +  where {g_: a*r**(1 - p) - w**(1 - p)*(a - 1)} = <bound method Pow.matches of g_**(1/(1 - p))>((a*r**(1 - p) - w**(1 - p)*(a - 1))**(-1/(p - 1)))
E        +    where <bound method Pow.matches of g_**(1/(1 - p))> = g_**(1/(1 - p)).matches
E        +    and   (a*r**(1 - p) - w**(1 - p)*(a - 1))**(-1/(p - 1)) = signsimp((a*r**(1 - p) + w**(1 - p)*(1 - a))**(1/(1 - p)))

sympy/core/tests/test_match.py:729: AssertionError
========================================= warnings summary ==========================================
sympy/core/tests/test_args.py::test_all_classes_are_tested
  /home/oscar/current/sympy/37venv/lib/python3.7/site-packages/theano/scalar/basic.py:2412: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
  Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
    self.ctor = getattr(np, o_type.dtype)

-- Docs: https://docs.pytest.org/en/stable/warnings.html
                                          DO *NOT* COMMIT!                                          
====================================== short test summary info ======================================
FAILED sympy/core/tests/test_expr.py::test_replace - assert -2/x**2 == (2 / x)
FAILED sympy/core/tests/test_match.py::test_issue_3773 - assert None == {c_: z, l_: 1, N_: 2, A_: ...
FAILED sympy/core/tests/test_match.py::test_issue_4559 - assert None == {w_: 3, y_: x}
FAILED sympy/core/tests/test_match.py::test_issue_3539 - assert {a_: 2*x - 2} is None
FAILED sympy/core/tests/test_match.py::test_gh_issue_2711 - assert set() == {meijerg(((),...), ())...
FAILED sympy/core/tests/test_match.py::test_match_issue_21942 - assert {g_: a*r**(1 - p) - w**(1 -...
====== 6 failed, 1864 passed, 69 skipped, 39 xfailed, 1 xpassed, 1 warning in 97.02s (0:01:37) ======

@oscarbenjamin
Contributor

Actually there was an unintended change in that diff. Here's a better one:

diff --git a/sympy/core/operations.py b/sympy/core/operations.py
index efacede..c4360c9 100644
--- a/sympy/core/operations.py
+++ b/sympy/core/operations.py
@@ -255,9 +255,6 @@ def _matches_commutative(self, expr, repl_dict=None, old=False):
                 # the matching continue
                 return None
             newexpr = self._combine_inverse(expr, exact)
-            if not old and (expr.is_Add or expr.is_Mul):
-                if newexpr.count_ops() > expr.count_ops():
-                    return None
             newpattern = self._new_rawargs(*wild_part)
             return newpattern.matches(newexpr, repl_dict)
 
diff --git a/sympy/core/power.py b/sympy/core/power.py
index 455481a..f0fb1ea 100644
--- a/sympy/core/power.py
+++ b/sympy/core/power.py
@@ -1558,13 +1558,6 @@ def matches(self, expr, repl_dict=None, old=False):
 
         b, e = expr.as_base_exp()
 
-        # special case number
-        sb, se = self.as_base_exp()
-        if sb.is_Symbol and se.is_Integer and expr:
-            if e.is_rational:
-                return sb.matches(b**(e/se), repl_dict)
-            return sb.matches(expr**(1/se), repl_dict)
-
         d = repl_dict.copy()
         d = self.base.matches(b, d)
         if d is None:

This still has some failures.

@mohajain
Contributor

I would like to help.

@oscarbenjamin
Contributor

> I would like to help.

Go ahead, but it might not be easy. The pattern-matching code is hard to debug.

@oscarbenjamin
Contributor

This is as far as I got:

diff --git a/sympy/core/operations.py b/sympy/core/operations.py
index efacede..c4360c9 100644
--- a/sympy/core/operations.py
+++ b/sympy/core/operations.py
@@ -255,9 +255,6 @@ def _matches_commutative(self, expr, repl_dict=None, old=False):
                 # the matching continue
                 return None
             newexpr = self._combine_inverse(expr, exact)
-            if not old and (expr.is_Add or expr.is_Mul):
-                if newexpr.count_ops() > expr.count_ops():
-                    return None
             newpattern = self._new_rawargs(*wild_part)
             return newpattern.matches(newexpr, repl_dict)
 
diff --git a/sympy/core/power.py b/sympy/core/power.py
index 455481a..840a7ba 100644
--- a/sympy/core/power.py
+++ b/sympy/core/power.py
@@ -1556,24 +1556,28 @@ def matches(self, expr, repl_dict=None, old=False):
         if not isinstance(expr, Expr):
             return None
 
-        b, e = expr.as_base_exp()
-
-        # special case number
         sb, se = self.as_base_exp()
-        if sb.is_Symbol and se.is_Integer and expr:
-            if e.is_rational:
-                return sb.matches(b**(e/se), repl_dict)
-            return sb.matches(expr**(1/se), repl_dict)
-
-        d = repl_dict.copy()
-        d = self.base.matches(b, d)
-        if d is None:
-            return None
-
-        d = self.exp.xreplace(d).matches(e, d)
-        if d is None:
-            return Expr.matches(self, expr, repl_dict)
-        return d
+        if expr.is_Rational and expr.p == 1 and expr.q != 1:
+            # handle e.g. Rational(1, 2) as Pow(2, -1)
+            _, denom = expr.as_numer_denom()
+            b, e = denom, S.NegativeOne
+        else:
+            b, e = expr.as_base_exp()
+
+        # First try matching se with e and sb with b
+        # Then try se with -e and sb with 1/b
+        for sign in [1, -1]:
+            d = repl_dict.copy()
+            d = se.matches(sign*e, d)
+            if d is None:
+                continue
+            d = sb.xreplace(d).matches(b**sign, d)
+            if d is None:
+                continue
+            return d
+
+        # Fall back to the generic handler
+        return Expr.matches(self, expr, repl_dict)
 
     def _eval_nseries(self, x, n, logx, cdir=0):
         # NOTE! This function is an important part of the gruntz algorithm
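The core of the new `Pow.matches` in this diff is the loop that first matches the exponent against e and the base against b, then retries with -e and 1/b, so that a pattern like w**2 can also match z**-2. It also reads Rational(1, 2) as 2**-1. The standalone sketch below reproduces that loop outside SymPy's class hierarchy. `base_exp` and `match_pow_signed` are hypothetical helper names, and the fallback to `Expr.matches` is omitted:

```python
from sympy import Rational, S, Wild, symbols


def base_exp(expr):
    """Split expr into (base, exponent), reading 1/q as q**-1 like the diff does."""
    if expr.is_Rational and expr.p == 1 and expr.q != 1:
        return S(expr.q), S.NegativeOne
    return expr.as_base_exp()


def match_pow_signed(pattern, expr):
    """Hypothetical sketch of the sign loop from the diff (not SymPy API).

    First match the pattern exponent with e and the pattern base with b,
    then retry with -e and 1/b. Returns a dict of Wild bindings or None.
    """
    sb, se = pattern.as_base_exp()
    b, e = base_exp(expr)
    for sign in (1, -1):
        d = se.matches(sign*e, {})
        if d is None:
            continue
        d = sb.xreplace(d).matches(b**sign, d)
        if d is not None:
            return d
    return None  # the diff falls back to Expr.matches at this point


z = symbols('z')
w, k = Wild('w'), Wild('k')

print(match_pow_signed(w**2, z**2))            # direct: w -> z
print(match_pow_signed(w**2, z**-2))           # via the sign flip: w -> 1/z
print(match_pow_signed(2**k, Rational(1, 2)))  # 1/2 read as 2**-1: k -> -1
```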

Some of the failing tests should probably just be changed, e.g.:

    def test_issue_3539():
        a = Wild('a')
        x = Symbol('x')
>       assert (x - 2).match(a - x) is None
E       assert {a_: 2*x - 2} is None

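I think this test is incorrect. With an unrestricted Wild('a'), {a: 2*x - 2} is a valid match because (2*x - 2) - x == x - 2. The test would be correct if a were declared as Wild('a', exclude=[x]).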
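The point about test_issue_3539 can be checked directly. Substituting the returned value back shows that it is a genuine solution, while a Wild that excludes x rejects any candidate containing x. A small sketch using only `xreplace` and `Wild.matches`:

```python
from sympy import S, Symbol, Wild

x = Symbol('x')
a = Wild('a')
a_excl = Wild('a', exclude=[x])

# With an unrestricted Wild, a = 2*x - 2 really does satisfy a - x == x - 2
print((a - x).xreplace({a: 2*x - 2}) == x - 2)

# A Wild that excludes x refuses any candidate containing x ...
print(a_excl.matches(2*x - 2))

# ... but still binds to x-free values
print(a_excl.matches(S(-2)))
```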

@mohajain mohajain linked a pull request Oct 18, 2021 that will close this issue