Skip to content

Commit

Permalink
Merge pull request #15414 from oscarbenjamin/dsolve_algebraic
Browse files Browse the repository at this point in the history
Fix redundant solution checking in nth_algebraic
  • Loading branch information
asmeurer committed Oct 31, 2018
2 parents 60347e0 + 1b7e8aa commit cd98ba0
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 38 deletions.
94 changes: 56 additions & 38 deletions sympy/solvers/ode.py
Expand Up @@ -249,7 +249,8 @@
from sympy.core.symbol import Symbol, Wild, Dummy, symbols
from sympy.core.sympify import sympify

from sympy.logic.boolalg import BooleanAtom, And, Or, Not
from sympy.logic.boolalg import (BooleanAtom, And, Or, Not, BooleanTrue,
BooleanFalse)
from sympy.functions import cos, exp, im, log, re, sin, tan, sqrt, \
atan2, conjugate, Piecewise
from sympy.functions.combinatorial.factorials import factorial
Expand Down Expand Up @@ -4102,61 +4103,78 @@ def _nth_algebraic_remove_redundant_solutions(eq, solns, order, var):
f(x) = -x + C1 and in this case the two solutions are not equivalent wrt
initial conditions so both should be returned.
"""
# I believe that any algebraic solutions can only emerge before *any*
# integrations occur (although I haven't proved this and it depends on the
# particular way that diffx is defined at the time of writing). This means
# that an algebraic solution for f(x) will not have any integration
# constants and any integral solution will have a number of constants that
# matches the order of the ODE.
solns_algebraic = []
solns_integral = {} # {soln1: constants1, ...}
for soln in solns:
constants = soln.free_symbols - eq.free_symbols
if len(constants) == 0:
solns_algebraic.append(soln)
elif len(constants) == order:
solns_integral[soln] = constants
else:
assert False, "Solution should have 0 or order constants..."

# Compare each algebraic solution with each integral solution to remove
# redundant algebraic solutions.
solns = solns[:]
for soln in solns_algebraic:
for soln_integral, constants in solns_integral.items():
if _nth_algebraic_is_special_case_of(soln, soln_integral, constants, var):
solns.remove(soln)
def is_special_case_of(soln1, soln2):
return _nth_algebraic_is_special_case_of(soln1, soln2, eq, order, var)

unique_solns = []
for soln1 in solns:
for soln2 in unique_solns[:]:
if is_special_case_of(soln1, soln2):
break
elif is_special_case_of(soln2, soln1):
unique_solns.remove(soln2)
else:
unique_solns.append(soln1)

return solns
return unique_solns

def _nth_algebraic_is_special_case_of(soln1, soln2, constants2, var):
def _nth_algebraic_is_special_case_of(soln1, soln2, eq, order, var):
r"""
True if soln1 is found to be a special case of soln2 wrt some value of the
constants that appear in soln2. False otherwise.
"""
# The solutions returned by nth_algebraic should be given explicitly as in
# Eq(f(x), expr). We will equate the RHSs and try to solve for the
# integration constants. If we get any solutions for the constants that
# don't depend on x then that shows that those values of the constants
# make soln1 a special case of soln2 impying that soln1 is redundant.
# Eq(f(x), expr). We will equate the RHSs of the two solutions giving an
# equation f1(x) = f2(x).
#
# Since this is supposed to hold for all x it also holds for derivatives
# f1'(x) and f2'(x). For an order n ode we should be able to differentiate
# each solution n times to get n+1 equations.
#
# We then try to solve those n+1 equations for the integrations constants
# in f2(x). If we can find a solution that doesn't depend on x then it
# means that some value of the constants in f1(x) is a special case of
# f2(x) corresponding to a paritcular choice of the integration constants.

constants1 = soln1.free_symbols.difference(eq.free_symbols)
constants2 = soln2.free_symbols.difference(eq.free_symbols)

constants1_new = get_numbered_constants(soln1.rhs - soln2.rhs, len(constants1))
if len(constants1) == 1:
constants1_new = {constants1_new}
for c_old, c_new in zip(constants1, constants1_new):
soln1 = soln1.subs(c_old, c_new)

# n equations for f1(x)=f2(x), f1'(x)=f2'(x), ...
lhs = soln1.rhs.doit()
rhs = soln2.rhs.doit()
eqns = [Eq(lhs, rhs)]
for n in range(1, order):
lhs = lhs.diff(var)
rhs = rhs.diff(var)
eq = Eq(lhs, rhs)
eqns.append(eq)

# BooleanTrue/False awkwardly show up for trivial equations
if any(isinstance(eq, BooleanFalse) for eq in eqns):
return False
eqns = [eq for eq in eqns if not isinstance(eq, BooleanTrue)]

constant_solns = solve(Eq(soln1.rhs, soln2.rhs), constants2)
constant_solns = solve(eqns, constants2)

# Handling all the types potentially returned by solve is awkward...
# Sometimes returns a dict and sometimes a list of dicts
if isinstance(constant_solns, dict):
constant_solns = list(constant_solns.values())
elif not isinstance(constant_solns, list):
constant_solns = [constant_solns]
if len(constants2) == 1:
constant_solns = [[soln] for soln in constant_solns]

# If any solution gives all constants as expressions that don't depend on
# x then there exists constants for soln2 that give soln1
for constant_soln in constant_solns:
if not any(c.has(var) for c in constant_soln):
if not any(c.has(var) for c in constant_soln.values()):
return True
else:
return False


def _nth_linear_match(eq, func, order):
r"""
Matches a differential equation to the linear form:
Expand Down
30 changes: 30 additions & 0 deletions sympy/solvers/tests/test_ode.py
Expand Up @@ -2975,9 +2975,39 @@ def test_nth_algebraic():

eqn = (1 - sin(f(x))) * f(x).diff(x)
sol = Eq(f(x), C1)
assert checkodesol(eqn, sol, order=1, solve_for_func=False)[0]
assert sol == dsolve(eqn, f(x), hint='nth_algebraic')
assert sol == dsolve(eqn, f(x))

M, m, r, t = symbols('M m r t')
phi = Function('phi')
eqn = Eq(-M * phi(t).diff(t),
Rational(3, 2) * m * r**2 * phi(t).diff(t) * phi(t).diff(t,t))
solns = [Eq(phi(t), C1), Eq(phi(t), C1 + C2*t - M*t**2/(3*m*r**2))]
assert checkodesol(eqn, solns[0], order=2, solve_for_func=False)[0]
assert checkodesol(eqn, solns[1], order=2, solve_for_func=False)[0]
assert set(solns) == set(dsolve(eqn, phi(t), hint='nth_algebraic'))
assert set(solns) == set(dsolve(eqn, phi(t)))

eqn = f(x) * f(x).diff(x) * f(x).diff(x, x)
sol = Eq(f(x), C1 + C2*x)
assert checkodesol(eqn, sol, order=1, solve_for_func=False)[0]
assert sol == dsolve(eqn, f(x), hint='nth_algebraic')
assert sol == dsolve(eqn, f(x))

eqn = f(x) * f(x).diff(x) * f(x).diff(x, x) * (f(x) - 1)
sol = Eq(f(x), C1 + C2*x)
assert checkodesol(eqn, sol, order=1, solve_for_func=False)[0]
assert sol == dsolve(eqn, f(x), hint='nth_algebraic')
assert sol == dsolve(eqn, f(x))

eqn = f(x) * f(x).diff(x) * f(x).diff(x, x) * (f(x) - 1) * (f(x).diff(x) - x)
solns = [Eq(f(x), C1 + x**2/2), Eq(f(x), C1 + C2*x)]
assert checkodesol(eqn, solns[0], order=2, solve_for_func=False)[0]
assert checkodesol(eqn, solns[1], order=2, solve_for_func=False)[0]
assert set(solns) == set(dsolve(eqn, f(x), hint='nth_algebraic'))
assert set(solns) == set(dsolve(eqn, f(x)))


def test_nth_algebraic_redundant_solutions():
# This one has a redundant solution that should be removed
Expand Down

0 comments on commit cd98ba0

Please sign in to comment.