Skip to content

Commit

Permalink
Added Crank-Nicolson solver (#513)
Browse files Browse the repository at this point in the history
* Accelerated and improved the implicit Euler sovler
* Updated some documentation on the other solvers
  • Loading branch information
david-zwicker committed Jan 5, 2024
1 parent 01ea2c1 commit 622e8b2
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 41 deletions.
5 changes: 4 additions & 1 deletion pde/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@
~explicit.ExplicitSolver
~explicit_mpi.ExplicitMPISolver
~implicit.ImplicitSolver
~crank_nicolson.CrankNicolsonSolver
~scipy.ScipySolver
~registered_solvers
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from typing import List

from .controller import Controller
from .crank_nicolson import CrankNicolsonSolver
from .explicit import ExplicitSolver
from .implicit import ImplicitSolver
from .scipy import ScipySolver
Expand All @@ -43,6 +45,7 @@ def registered_solvers() -> List[str]:
"Controller",
"ExplicitSolver",
"ImplicitSolver",
"CrankNicolsonSolver",
"ScipySolver",
"registered_solvers",
]
6 changes: 5 additions & 1 deletion pde/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@
from ..tools.typing import BackendType


class ConvergenceError(RuntimeError):
"""indicates that an implicit step did not converge"""


class SolverBase(metaclass=ABCMeta):
"""base class for solvers"""
"""base class for PDE solvers"""

dt_default: float = 1e-3
"""float: default time step used if no time step was specified"""
Expand Down
2 changes: 1 addition & 1 deletion pde/solvers/controller.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Defines a class controlling the simulations of PDEs.
Defines a class controlling the simulations of PDEs
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
Expand Down
124 changes: 124 additions & 0 deletions pde/solvers/crank_nicolson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
Defines a Crank-Nicolson solver
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
from __future__ import annotations

from typing import Callable

import numba as nb
import numpy as np

from ..fields.base import FieldBase
from ..pdes.base import PDEBase
from ..tools.typing import BackendType
from .base import ConvergenceError, SolverBase


class CrankNicolsonSolver(SolverBase):
"""solving partial differential equations using the Crank-Nicolson scheme"""

name = "crank-nicolson"

def __init__(
self,
pde: PDEBase,
*,
maxiter: int = 100,
maxerror: float = 1e-4,
explicit_fraction: float = 0,
backend: BackendType = "auto",
):
"""
Args:
pde (:class:`~pde.pdes.base.PDEBase`):
The instance describing the pde that needs to be solved
maxiter (int):
The maximal number of iterations per step
maxerror (float):
The maximal error that is permitted in each step
explicit_fraction (float):
Hyperparameter determinig the fraction of explicit time stepping in the
implicit step. `explicit_fraction == 0` is the simple Crank-Nicolson
scheme, while `explicit_fraction == 1` reduces to the explicit Euler
method. Intermediate values can improve convergence.
backend (str):
Determines how the function is created. Accepted values are 'numpy` and
'numba'. Alternatively, 'auto' lets the code decide for the most optimal
backend.
"""
super().__init__(pde, backend=backend)
self.maxiter = maxiter
self.maxerror = maxerror
self.explicit_fraction = explicit_fraction

def _make_single_step_fixed_dt(
self, state: FieldBase, dt: float
) -> Callable[[np.ndarray, float], None]:
"""return a function doing a single step with an implicit Euler scheme
Args:
state (:class:`~pde.fields.base.FieldBase`):
An example for the state from which the grid and other information can
be extracted
dt (float):
Time step of the implicit step
"""
if self.pde.is_sde:
raise RuntimeError("Cannot use implicit stepper with stochastic equation")

self.info["function_evaluations"] = 0
self.info["scheme"] = "implicit-euler"
self.info["stochastic"] = False
self.info["dt_adaptive"] = False

rhs = self._make_pde_rhs(state, backend=self.backend)
maxiter = int(self.maxiter)
maxerror2 = self.maxerror**2
α = self.explicit_fraction

# handle deterministic version of the pde
def crank_nicolson_step(state_data: np.ndarray, t: float) -> None:
"""compiled inner loop for speed"""
nfev = 0 # count function evaluations

# keep values at the current time t point used in iteration
state_t = state_data.copy()
rate_t = rhs(state_t, t)

# new state is weighted average of explicit and Crank-Nicolson steps
state_cn = state_t + dt / 2 * (rhs(state_data, t + dt) + rate_t)
state_data[:] = α * state_data + (1 - α) * state_cn
state_prev = np.empty_like(state_data)

# fixed point iteration for improving state after dt
for n in range(maxiter):
state_prev[:] = state_data # keep previous state to judge convergence
# new state is weighted average of explicit and Crank-Nicolson steps
state_cn = state_t + dt / 2 * (rhs(state_data, t + dt) + rate_t)
state_data[:] = α * state_data + (1 - α) * state_cn

# calculate mean squared error
err = 0.0
for j in range(state_data.size):
diff = state_data.flat[j] - state_prev.flat[j]
err += (diff.conjugate() * diff).real
err /= state_data.size

if err < maxerror2:
# fix point iteration converged
break
else:
with nb.objmode:
self._logger.warning(
"Crank-Nicolson step did not converge after %d iterations "
"at t=%g (error=%g)",
maxiter,
t,
err,
)
raise ConvergenceError("Crank-Nicolson step did not converge.")
nfev += n + 2

return crank_nicolson_step
2 changes: 1 addition & 1 deletion pde/solvers/explicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


class ExplicitSolver(AdaptiveSolverBase):
"""class for solving partial differential equations explicitly"""
"""various explicit PDE solvers"""

name = "explicit"

Expand Down
2 changes: 1 addition & 1 deletion pde/solvers/explicit_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class ExplicitMPISolver(ExplicitSolver):
"""class for solving partial differential equations explicitly using MPI
"""various explicit PDE solve using MPI
Warning:
This solver can only be used if MPI is properly installed. In particular, python
Expand Down
34 changes: 14 additions & 20 deletions pde/solvers/implicit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Defines an implicit solver
Defines an implicit Euler solver
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""
Expand All @@ -13,15 +13,11 @@
from ..fields.base import FieldBase
from ..pdes.base import PDEBase
from ..tools.typing import BackendType
from .base import SolverBase


class ConvergenceError(RuntimeError):
"""indicates that the implicit step did not converge"""
from .base import ConvergenceError, SolverBase


class ImplicitSolver(SolverBase):
"""class for solving partial differential equations implicitly"""
"""implicit (backward) Euler PDE solver"""

name = "implicit"

Expand Down Expand Up @@ -78,30 +74,29 @@ def implicit_step(state_data: np.ndarray, t: float) -> None:
"""compiled inner loop for speed"""
nfev = 0 # count function evaluations

# save state at current time point t for stepping
state_t = state_data.copy()

# estimate state at next time point
evolution_last = dt * rhs(state_data, t)
state_data[:] = state_t + dt * rhs(state_data, t)
state_prev = np.empty_like(state_data)

# fixed point iteration for improving state after dt
for n in range(maxiter):
# fixed point iteration for improving state after dt
state_guess = state_data + evolution_last
evolution_this = dt * rhs(state_guess, t + dt)
state_prev[:] = state_data # keep previous state to judge convergence
# another interation to improve estimate
state_data[:] = state_t + dt * rhs(state_data, t + dt)

# calculate mean squared error
# calculate mean squared error to judge convergence
err = 0.0
for j in range(state_data.size):
diff = (
state_guess.flat[j]
- state_data.flat[j]
- evolution_this.flat[j]
)
diff = state_data.flat[j] - state_prev.flat[j]
err += (diff.conjugate() * diff).real
err /= state_data.size

if err < maxerror2:
# fix point iteration converged
break

evolution_last = evolution_this
else:
with nb.objmode:
self._logger.warning(
Expand All @@ -113,6 +108,5 @@ def implicit_step(state_data: np.ndarray, t: float) -> None:
)
raise ConvergenceError("Implicit Euler step did not converge.")
nfev += n + 1
state_data += evolution_this

return implicit_step
5 changes: 3 additions & 2 deletions pde/solvers/scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@


class ScipySolver(SolverBase):
"""class for solving partial differential equations using scipy
"""PDE solver using :func:`scipy.integrate.solve_ivp`.
This class is a thin wrapper around :func:`scipy.integrate.solve_ivp`. In
particular, it supports all the methods implemented by this function.
particular, it supports all the methods implemented by this function and exposes its
arguments, so details can be controlled.
"""

name = "scipy"
Expand Down
23 changes: 16 additions & 7 deletions scripts/performance_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,31 @@

import sys
from pathlib import Path
from typing import Literal

PACKAGE_PATH = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(PACKAGE_PATH))

import numpy as np

from pde import CahnHilliardPDE, Controller, DiffusionPDE, ScalarField, UnitGrid
from pde.solvers import ExplicitSolver, ImplicitSolver, ScipySolver
from pde.solvers import CrankNicolsonSolver, ExplicitSolver, ImplicitSolver, ScipySolver


def main(equation: str = "cahn-hilliard", t_range: float = 100, size: int = 32):
def main(
equation: Literal["diffusion", "cahn-hilliard"] = "cahn-hilliard",
t_range: float = 100,
size: int = 32,
):
"""main routine testing the performance
Args:
equation (str): Chooses the equation to consider
t_range (float): Sets the total duration that should be solved for
size (int): The number of grid points along each axis
equation (str):
Chooses the equation to consider
t_range (float):
Sets the total duration that should be solved for
size (int):
The number of grid points along each axis
"""
print("Reports duration in seconds (smaller is better)\n")

Expand All @@ -48,8 +56,9 @@ def main(equation: str = "cahn-hilliard", t_range: float = 100, size: int = 32):
"Euler, adaptive": (1e-3, ExplicitSolver(eq, scheme="euler", adaptive=True)),
"Runge-Kutta, fixed": (1e-2, ExplicitSolver(eq, scheme="rk", adaptive=False)),
"Runge-Kutta, adaptive": (1e-2, ExplicitSolver(eq, scheme="rk", adaptive=True)),
"implicit": (1e-2, ImplicitSolver(eq)),
"scipy": (None, ScipySolver(eq)),
"Implicit": (1e-2, ImplicitSolver(eq)),
"Crank-Nicolson": (1e-2, CrankNicolsonSolver(eq)),
"Scipy": (None, ScipySolver(eq)),
}

for name, (dt, solver) in solvers.items():
Expand Down
18 changes: 11 additions & 7 deletions tests/solvers/test_generic_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,23 @@
from pde import PDE, DiffusionPDE, FieldCollection, MemoryStorage, ScalarField, UnitGrid
from pde.solvers import (
Controller,
CrankNicolsonSolver,
ExplicitSolver,
ImplicitSolver,
ScipySolver,
registered_solvers,
)
from pde.solvers.base import AdaptiveSolverBase

SOLVER_CLASSES = [ExplicitSolver, ImplicitSolver, CrankNicolsonSolver, ScipySolver]


def test_solver_registration():
"""test solver registration"""
solvers = registered_solvers()
assert "explicit" in solvers
assert "implicit" in solvers
assert "crank-nicolson" in solvers
assert "scipy" in solvers


Expand All @@ -31,7 +35,7 @@ def test_solver_in_pde_class(rng):
eq.solve(field, t_range=1, solver=ScipySolver, tracker=None)


@pytest.mark.parametrize("solver_class", [ExplicitSolver, ImplicitSolver, ScipySolver])
@pytest.mark.parametrize("solver_class", SOLVER_CLASSES)
def test_compare_solvers(solver_class, rng):
"""compare several solvers"""
field = ScalarField.random_uniform(UnitGrid([8, 8]), -1, 1, rng=rng)
Expand All @@ -48,8 +52,9 @@ def test_compare_solvers(solver_class, rng):
np.testing.assert_allclose(s1.data, s2.data, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("solver_class", SOLVER_CLASSES)
@pytest.mark.parametrize("backend", ["numpy", "numba"])
def test_solvers_complex(backend):
def test_solvers_complex(solver_class, backend):
"""test solvers with a complex PDE"""
r = FieldCollection.scalar_random_uniform(2, UnitGrid([3]), labels=["a", "b"])
c = r["a"] + 1j * r["b"]
Expand All @@ -61,11 +66,10 @@ def test_solvers_complex(backend):
res_r = eq_r.solve(r, t_range=1e-2, dt=1e-3, backend="numpy", tracker=None)
exp_c = res_r[0].data + 1j * res_r[1].data

for solver_class in [ExplicitSolver, ImplicitSolver, ScipySolver]:
solver = solver_class(eq_c, backend=backend)
controller = Controller(solver, t_range=1e-2, tracker=None)
res_c = controller.run(c, dt=1e-3)
np.testing.assert_allclose(res_c.data, exp_c, rtol=1e-3, atol=1e-3)
solver = solver_class(eq_c, backend=backend)
controller = Controller(solver, t_range=1e-2, tracker=None)
res_c = controller.run(c, dt=1e-3)
np.testing.assert_allclose(res_c.data, exp_c, rtol=1e-3, atol=1e-3)


def test_basic_adaptive_solver():
Expand Down

0 comments on commit 622e8b2

Please sign in to comment.