Skip to content

Commit

Permalink
Merge pull request #19998 from mijo2/simplification
Browse files Browse the repository at this point in the history
[GSoC] Simplification for the solutions of systems of ODEs
  • Loading branch information
oscarbenjamin committed Oct 2, 2020
2 parents 33cdd3b + 5cac44d commit 8064e60
Show file tree
Hide file tree
Showing 3 changed files with 946 additions and 865 deletions.
2 changes: 1 addition & 1 deletion sympy/solvers/ode/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def dsolve(eq, func=None, hint="default", simplify=True,
# changed to show the systems that haven't
# been solved.
try:
sol = dsolve_system(eq, funcs=func, ics=ics)
sol = dsolve_system(eq, funcs=func, ics=ics, doit=True)
return sol[0] if len(sol) == 1 else sol
except NotImplementedError:
pass
Expand Down
170 changes: 160 additions & 10 deletions sympy/solvers/ode/systems.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sympy.core import Add, Mul
from sympy.core import Add, Mul, S
from sympy.core.containers import Tuple
from sympy.core.compatibility import iterable
from sympy.core.exprtools import factor_terms
Expand All @@ -11,8 +11,8 @@
piecewise_fold, sqrt, log)
from sympy.functions.combinatorial.factorials import factorial
from sympy.matrices import zeros, Matrix, NonSquareMatrixError, MatrixBase, eye
from sympy.polys import Poly
from sympy.simplify import collect
from sympy.polys import Poly, together
from sympy.simplify import collect, radsimp, signsimp
from sympy.simplify.powsimp import powdenest, powsimp
from sympy.simplify.ratsimp import ratsimp
from sympy.simplify.simplify import simplify
Expand Down Expand Up @@ -70,6 +70,137 @@ def _solsimp(e, t):
return no_t + has_t


def simpsol(sol, wrt1, wrt2, doit=True):
"""Simplify solutions from dsolve_system."""

# The parameter sol is the solution as returned by dsolve (list of Eq).
#
# The parameters wrt1 and wrt2 are lists of symbols to be collected for
# with those in wrt1 being collected for first. This allows for collecting
# on any factors involving the independent variable before collecting on
# the integration constants or vice versa using e.g.:
#
# sol = simpsol(sol, [t], [C1, C2]) # t first, constants after
# sol = simpsol(sol, [C1, C2], [t]) # constants first, t after
#
# If doit=True (default) then simpsol will begin by evaluating any
# unevaluated integrals. Since many integrals will appear multiple times
# in the solutions this is done intelligently by computing each integral
# only once.
#
# The strategy is to first perform simple cancellation with factor_terms
# and then multiply out all brackets with expand_mul. This gives an Add
# with many terms.
#
# We split each term into two multiplicative factors dep and coeff where
# all factors that involve wrt1 are in dep and any constant factors are in
# coeff e.g.
# sqrt(2)*C1*exp(t) -> ( exp(t) , sqrt(2)*C1 )
#
# The dep factors are simplified using powsimp to combine expanded
# exponential factors e.g.
# exp(a*t)*exp(b*t) -> exp(t*(a+b))
#
# We then collect coefficients for all terms having the same (simplified)
# dep. The coefficients are then simplified using together and ratsimp and
# lastly by recursively applying the same transformation to the
# coefficients to collect on wrt2.
#
# Finally the result is recombined into an Add and signsimp is used to
# normalise any minus signs.

def simprhs(rhs, rep, wrt1, wrt2):
"""Simplify the rhs of an ODE solution"""
if rep:
rhs = rhs.subs(rep)
rhs = factor_terms(rhs)
rhs = simp_coeff_dep(rhs, wrt1, wrt2)
rhs = signsimp(rhs)
return rhs

def simp_coeff_dep(expr, wrt1, wrt2=None):
"""Split rhs into terms, split terms into dep and coeff and collect on dep"""
add_dep_terms = lambda e: e.is_Add and e.has(*wrt1)
expandable = lambda e: e.is_Mul and any(map(add_dep_terms, e.args))
expand_func = lambda e: expand_mul(e, deep=False)
expand_mul_mod = lambda e: e.replace(expandable, expand_func)
terms = Add.make_args(expand_mul_mod(expr))
dc = {}
for term in terms:
coeff, dep = term.as_independent(*wrt1, as_Add=False)
# Collect together the coefficients for terms that have the same
# dependence on wrt1 (after dep is normalised using simpdep).
dep = simpdep(dep, wrt1)

# See if the dependence on t cancels out...
if dep is not S.One:
dep2 = factor_terms(dep)
if not dep2.has(*wrt1):
coeff *= dep2
dep = S.One

if dep not in dc:
dc[dep] = coeff
else:
dc[dep] += coeff
# Apply the method recursively to the coefficients but this time
# collecting on wrt2 rather than wrt2.
termpairs = ((simpcoeff(c, wrt2), d) for d, c in dc.items())
if wrt2 is not None:
termpairs = ((simp_coeff_dep(c, wrt2), d) for c, d in termpairs)
return Add(*(c * d for c, d in termpairs))

def simpdep(term, wrt1):
"""Normalise factors involving t with powsimp and recombine exp"""
def canonicalise(a):
# Using factor_terms here isn't quite right because it leads to things
# like exp(t*(1+t)) that we don't want. We do want to cancel factors
# and pull out a common denominator but ideally the numerator would be
# expressed as a standard form polynomial in t so we expand_mul
# and collect afterwards.
a = factor_terms(a)
num, den = a.as_numer_denom()
num = expand_mul(num)
num = collect(num, wrt1)
return num / den

term = powsimp(term)
rep = {e: exp(canonicalise(e.args[0])) for e in term.atoms(exp)}
term = term.subs(rep)
return term

def simpcoeff(coeff, wrt2):
"""Bring to a common fraction and cancel with ratsimp"""
coeff = together(coeff)
if coeff.is_polynomial():
# Calling ratsimp can be expensive. The main reason is to simplify
# sums of terms with irrational denominators so we limit ourselves
# to the case where the expression is polynomial in any symbols.
# Maybe there's a better approach...
coeff = ratsimp(radsimp(coeff))
# collect on secondary variables first and any remaining symbols after
if wrt2 is not None:
syms = list(wrt2) + list(ordered(coeff.free_symbols - set(wrt2)))
else:
syms = list(ordered(coeff.free_symbols))
coeff = collect(coeff, syms)
coeff = together(coeff)
return coeff

# There are often repeated integrals. Collect unique integrals and
# evaluate each once and then substitute into the final result to replace
# all occurrences in each of the solution equations.
if doit:
integrals = set().union(*(s.atoms(Integral) for s in sol))
rep = {i: factor_terms(i).doit() for i in integrals}
else:
rep = {}

sol = [Eq(s.lhs, simprhs(s.rhs, rep, wrt1, wrt2)) for s in sol]

return sol


def linodesolve_type(A, t, b=None):
r"""
Helper function that determines the type of the system of ODEs for solving with :obj:`sympy.solvers.ode.systems.linodesolve()`
Expand Down Expand Up @@ -1518,7 +1649,8 @@ def _higher_order_ode_solver(match):
else:
new_eqs, new_funcs = _higher_order_to_first_order(eqs, sysorder, t, funcs=funcs,
type=type, J=match.get('J', None),
f_t=match.get('f(t)', None))
f_t=match.get('f(t)', None),
P=match.get('P', None), b=match.get('rhs', None))

if is_transformed:
t = match.get('t_', t)
Expand Down Expand Up @@ -1728,7 +1860,7 @@ def _second_order_to_first_order(eqs, funcs, t, type="auto", A1=None,
return _higher_order_to_first_order(eqs, sys_order, t, funcs=funcs)


def _higher_order_type2_to_sub_systems(J, f_t, funcs, t, max_order):
def _higher_order_type2_to_sub_systems(J, f_t, funcs, t, max_order, b=None, P=None):

# Note: To add a test for this ValueError
if J is None or f_t is None or not _matrix_is_constant(J, t):
Expand All @@ -1737,8 +1869,17 @@ def _higher_order_type2_to_sub_systems(J, f_t, funcs, t, max_order):
Type 2
'''))

if P is None and b is not None and not b.is_zero_matrix:
raise ValueError(filldedent('''
Provide the keyword 'P' for matrix P in A = P * J * P-1.
'''))

new_funcs = Matrix([Function(Dummy('{}__0'.format(f.func.__name__)))(t) for f in funcs])
new_eqs = new_funcs.diff(t, max_order) - f_t * J * new_funcs

if b is not None and not b.is_zero_matrix:
new_eqs -= P.inv() * b

new_eqs = canonical_odes(new_eqs, new_funcs, t)[0]

return new_eqs, new_funcs
Expand Down Expand Up @@ -1804,9 +1945,11 @@ def _get_coeffs_from_subs_expression(expr):
if type == "type2":
J = kwargs.get('J', None)
f_t = kwargs.get('f_t', None)
b = kwargs.get('b', None)
P = kwargs.get('P', None)
max_order = max(sys_order[func] for func in funcs)

return _higher_order_type2_to_sub_systems(J, f_t, funcs, t, max_order)
return _higher_order_type2_to_sub_systems(J, f_t, funcs, t, max_order, P=P, b=b)

# Note: To be changed to this after doit option is disabled for default cases
# new_sysorder = _get_func_order(new_eqs, new_funcs)
Expand Down Expand Up @@ -1836,7 +1979,7 @@ def _get_coeffs_from_subs_expression(expr):
return eqs, new_funcs


def dsolve_system(eqs, funcs=None, t=None, ics=None, doit=False):
def dsolve_system(eqs, funcs=None, t=None, ics=None, doit=False, simplify=True):
r"""
Solves any(supported) system of Ordinary Differential Equations
Expand Down Expand Up @@ -1877,7 +2020,13 @@ def dsolve_system(eqs, funcs=None, t=None, ics=None, doit=False):
ics : Dict or None
Set of initial boundary/conditions for the system of ODEs
doit : Boolean
Evaluate the solutions if True. Default value is False
Evaluate the solutions if True. Default value is True. Can be
set to false if the integral evaluation takes too much time and/or
isn't required.
simplify: Boolean
Simplify the solutions for the systems. Default value is True.
Can be set to false if simplification takes too much time and/or
isn't required.
Examples
========
Expand Down Expand Up @@ -1984,8 +2133,9 @@ def dsolve_system(eqs, funcs=None, t=None, ics=None, doit=False):
solved_constants = solve_ics(sol, funcs, constants, ics)
sol = [s.subs(solved_constants) for s in sol]

if doit:
sol = [s.doit() for s in sol]
if simplify:
constants = Tuple(*sol).free_symbols - variables
sol = simpsol(sol, [t], constants, doit=doit)

final_sols.append(sol)

Expand Down

0 comments on commit 8064e60

Please sign in to comment.