Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve dsolve #21743

Merged
merged 11 commits into from
Jul 22, 2021
4 changes: 2 additions & 2 deletions doc/src/modules/solvers/ode.rst
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ nth_algebraic

nth_order_reducible
^^^^^^^^^^^^^^^^^^^
.. autoclass:: sympy.solvers.ode.ode::NthOrderReducible
.. autoclass:: sympy.solvers.ode.single::NthOrderReducible
:members:

separable
Expand All @@ -187,7 +187,7 @@ separable_reduced

lie_group
^^^^^^^^^
.. autoclass:: sympy.solvers.ode.ode::LieGroup
.. autoclass:: sympy.solvers.ode.single::LieGroup
:members:

2nd_hypergeometric
Expand Down
2 changes: 1 addition & 1 deletion sympy/solvers/deutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _desolve(eq, func=None, hint="default", ics=None, simplify=True, *, prep=Tru
# recursive calls.
if kwargs.get('classify', True):
hints = classifier(eq, func, dict=True, ics=ics, xi=xi, eta=eta,
n=terms, x0=x0, prep=prep)
n=terms, x0=x0, hint=hint, prep=prep)

else:
# Here is what all this means:
Expand Down
70 changes: 23 additions & 47 deletions sympy/solvers/ode/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def recur_len(l):
if all_:
retdict = {}
failed_hints = {}
gethints = classify_ode(eq, dict=True)
gethints = classify_ode(eq, dict=True, hint='all')
orderedhints = gethints['ordered_hints']
for hint in hints:
try:
Expand Down Expand Up @@ -1014,45 +1014,29 @@ class in it. Note that a hint may do this anyway if
else:
raise ValueError("Enter boundary conditions of the form ics={f(point): value, f(x).diff(x, order).subs(x, point): value}")

# Any ODE that can be solved with a combination of algebra and
# integrals e.g.:
# d^3/dx^3(x y) = F(x)
ode = SingleODEProblem(eq_orig, func, x, prep=prep, xi=xi, eta=eta)
solvers = {
NthAlgebraic: ('nth_algebraic',),
FirstExact:('1st_exact',),
FirstLinear: ('1st_linear',),
AlmostLinear: ('almost_linear',),
Bernoulli: ('Bernoulli',),
Factorable: ('factorable',),
RiccatiSpecial: ('Riccati_special_minus2',),
SecondNonlinearAutonomousConserved: ('2nd_nonlinear_autonomous_conserved',),
Liouville: ('Liouville',),
Separable: ('separable',),
SeparableReduced: ('separable_reduced',),
HomogeneousCoeffSubsDepDivIndep: ('1st_homogeneous_coeff_subs_dep_div_indep',),
HomogeneousCoeffSubsIndepDivDep: ('1st_homogeneous_coeff_subs_indep_div_dep',),
HomogeneousCoeffBest: ('1st_homogeneous_coeff_best',),
LinearCoefficients: ('linear_coefficients',),
NthOrderReducible: ('nth_order_reducible',),
SecondHypergeometric: ('2nd_hypergeometric',),
NthLinearConstantCoeffHomogeneous: ('nth_linear_constant_coeff_homogeneous',),
NthLinearConstantCoeffVariationOfParameters: ('nth_linear_constant_coeff_variation_of_parameters',),
NthLinearConstantCoeffUndeterminedCoefficients: ('nth_linear_constant_coeff_undetermined_coefficients',),
NthLinearEulerEqHomogeneous: ('nth_linear_euler_eq_homogeneous',),
NthLinearEulerEqNonhomogeneousVariationOfParameters: ('nth_linear_euler_eq_nonhomogeneous_variation_of_parameters',),
NthLinearEulerEqNonhomogeneousUndeterminedCoefficients: ('nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients',),
SecondLinearBessel: ('2nd_linear_bessel',),
SecondLinearAiry: ('2nd_linear_airy',),
LieGroup: ('lie_group',),
}
for solvercls in solvers:
solver = solvercls(ode)
user_hint = kwargs.get('hint','default')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be a space after a comma.

# Used when dsolve is called without an explicit hint.
# We exit early to return the first valid match
early_exit = (user_hint=='default')
if user_hint.endswith('_Integral'):
user_hint = user_hint[:-len('_Integral')]
user_map = solver_map
# An explicit hint has been given to dsolve
# Skip matching code for other hints
if user_hint not in ['default', 'all', 'all_Integral', 'best'] and user_hint in solver_map:
user_map = {user_hint: solver_map[user_hint]}

for hint in user_map:
solver = user_map[hint](ode)
if solver.matches():
for hints in solvers[solvercls]:
matching_hints[hints] = solver
if solvercls.has_integral:
matching_hints[hints + "_Integral"] = solver
matching_hints[hint] = solver
if user_map[hint].has_integral:
matching_hints[hint + "_Integral"] = solver
if dict and early_exit:
matching_hints["default"] = hint
return matching_hints

eq = expand(eq)
# Precondition to try remove f(x) from highest order derivative
reduced_eq = None
Expand Down Expand Up @@ -3581,12 +3565,4 @@ def _nonlinear_3eq_order1_type5(x, y, z, t, eq):


#This import is written at the bottom to avoid circular imports.
from .single import (NthAlgebraic, Factorable, FirstLinear, AlmostLinear,
Bernoulli, SingleODEProblem, SingleODESolver, RiccatiSpecial,
SecondNonlinearAutonomousConserved, FirstExact, Liouville, Separable,
SeparableReduced, HomogeneousCoeffSubsDepDivIndep, HomogeneousCoeffSubsIndepDivDep,
HomogeneousCoeffBest, LinearCoefficients, NthOrderReducible, SecondHypergeometric,
NthLinearConstantCoeffHomogeneous, NthLinearConstantCoeffVariationOfParameters,
NthLinearConstantCoeffUndeterminedCoefficients, NthLinearEulerEqHomogeneous,
NthLinearEulerEqNonhomogeneousVariationOfParameters, LieGroup,
NthLinearEulerEqNonhomogeneousUndeterminedCoefficients, SecondLinearBessel, SecondLinearAiry)
from .single import SingleODEProblem, SingleODESolver, solver_map
29 changes: 29 additions & 0 deletions sympy/solvers/ode/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -2899,5 +2899,34 @@ def _get_general_solution(self, *, simplify_flag: bool = True):
return desols


solver_map = {
'factorable': Factorable,
'nth_algebraic': NthAlgebraic,
'separable': Separable,
'1st_exact': FirstExact,
'1st_linear': FirstLinear,
'Bernoulli': Bernoulli,
'Riccati_special_minus2': RiccatiSpecial,
'1st_homogeneous_coeff_best': HomogeneousCoeffBest,
'1st_homogeneous_coeff_subs_indep_div_dep': HomogeneousCoeffSubsIndepDivDep,
'1st_homogeneous_coeff_subs_dep_div_indep': HomogeneousCoeffSubsDepDivIndep,
'almost_linear': AlmostLinear,
'linear_coefficients': LinearCoefficients,
'separable_reduced': SeparableReduced,
'lie_group': LieGroup,
'nth_linear_constant_coeff_homogeneous': NthLinearConstantCoeffHomogeneous,
'nth_linear_euler_eq_homogeneous': NthLinearEulerEqHomogeneous,
'nth_linear_constant_coeff_undetermined_coefficients': NthLinearConstantCoeffUndeterminedCoefficients,
'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients': NthLinearEulerEqNonhomogeneousUndeterminedCoefficients,
'nth_linear_constant_coeff_variation_of_parameters': NthLinearConstantCoeffVariationOfParameters,
'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters': NthLinearEulerEqNonhomogeneousVariationOfParameters,
'Liouville': Liouville,
'2nd_linear_airy': SecondLinearAiry,
'2nd_linear_bessel': SecondLinearBessel,
'2nd_hypergeometric': SecondHypergeometric,
'nth_order_reducible': NthOrderReducible,
'2nd_nonlinear_autonomous_conserved': SecondNonlinearAutonomousConserved,
}

# Avoid circular import:
from .ode import dsolve, ode_sol_simplicity, odesimp, homogeneous_order