Skip to content

Commit

Permalink
Merge pull request #2907 from pybamm-team/issue-2670-interp2d
Browse files Browse the repository at this point in the history
Issue 2670 interp2d
  • Loading branch information
valentinsulzer committed Apr 30, 2023
2 parents 9b4490b + 9ebbb11 commit dc4855c
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 269 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -7,6 +7,10 @@
- PyBaMM is now supported on Python `3.10` and `3.11` ([#2435](https://github.com/pybamm-team/PyBaMM/pull/2435))
- Updated to casadi 3.6, which required some changes to the casadi integrator. ([#2859](https://github.com/pybamm-team/PyBaMM/pull/2859))

# Optimizations

- Fixed deprecated `interp2d` method by switching to `xarray.DataArray` as the backend for `ProcessedVariable` ([#2907](https://github.com/pybamm-team/PyBaMM/pull/2907))

## Bug fixes

- Parameter sets can now contain the key "chemistry", and will ignore its value (this previously would give errors in some cases) ([#2901](https://github.com/pybamm-team/PyBaMM/pull/2901))
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Expand Up @@ -10,6 +10,7 @@ imageio>=2.9.0
jupyter # For example notebooks
pybtex
sympy >= 1.8
xarray
# Note: Matplotlib is loaded for debug plots but to ensure pybamm runs
# on systems without an attached display it should never be imported
# outside of plot() methods.
Expand Down
205 changes: 23 additions & 182 deletions pybamm/solvers/processed_variable.py
Expand Up @@ -5,8 +5,8 @@
import numbers
import numpy as np
import pybamm
import scipy.interpolate as interp
from scipy.integrate import cumulative_trapezoid
import xarray as xr


class ProcessedVariable(object):
Expand Down Expand Up @@ -131,18 +131,7 @@ def initialise_0D(self):
)

# set up interpolation
if len(self.t_pts) == 1:
# Variable is just a scalar value, but we need to create a callable
# function to be consistent with other processed variables
self._interpolation_function = Interpolant0D(entries)
else:
self._interpolation_function = interp.interp1d(
self.t_pts,
entries,
kind="linear",
fill_value=np.nan,
bounds_error=False,
)
self._xr_data_array = xr.DataArray(entries, coords=[("t", self.t_pts)])

self.entries = entries
self.dimensions = 0
Expand Down Expand Up @@ -211,22 +200,10 @@ def initialise_1D(self, fixed_t=False):
self.first_dim_pts = edges

# set up interpolation
if len(self.t_pts) == 1:
# function of space only
self._interpolation_function = Interpolant1D(
pts_for_interp, entries_for_interp
)
else:
# function of space and time. Note that the order of 't' and 'space'
# is the reverse of what you'd expect
self._interpolation_function = interp.interp2d(
self.t_pts,
pts_for_interp,
entries_for_interp,
kind="linear",
fill_value=np.nan,
bounds_error=False,
)
self._xr_data_array = xr.DataArray(
entries_for_interp,
coords=[(self.first_dimension, pts_for_interp), ("t", self.t_pts)],
)

def initialise_2D(self):
"""
Expand Down Expand Up @@ -362,21 +339,14 @@ def initialise_2D(self):
self.second_dim_pts = second_dim_edges

# set up interpolation
if len(self.t_pts) == 1:
# function of space only. Note the order of the points is the reverse
# of what you'd expect
self._interpolation_function = Interpolant2D(
first_dim_pts_for_interp, second_dim_pts_for_interp, entries_for_interp
)
else:
# function of space and time.
self._interpolation_function = interp.RegularGridInterpolator(
(first_dim_pts_for_interp, second_dim_pts_for_interp, self.t_pts),
entries_for_interp,
method="linear",
fill_value=np.nan,
bounds_error=False,
)
self._xr_data_array = xr.DataArray(
entries_for_interp,
coords={
self.first_dimension: first_dim_pts_for_interp,
self.second_dimension: second_dim_pts_for_interp,
"t": self.t_pts,
},
)

def initialise_2D_scikit_fem(self):
y_sol = self.mesh.edges["y"]
Expand Down Expand Up @@ -411,74 +381,21 @@ def initialise_2D_scikit_fem(self):
self.second_dim_pts = z_sol

# set up interpolation
if len(self.t_pts) == 1:
# function of space only. Note the order of the points is the reverse
# of what you'd expect
self._interpolation_function = Interpolant2D(
self.first_dim_pts, self.second_dim_pts, entries
)
else:
# function of space and time.
self._interpolation_function = interp.RegularGridInterpolator(
(self.first_dim_pts, self.second_dim_pts, self.t_pts),
entries,
method="linear",
fill_value=np.nan,
bounds_error=False,
)
self._xr_data_array = xr.DataArray(
entries,
coords={"y": y_sol, "z": z_sol, "t": self.t_pts},
)

def __call__(self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True):
"""
Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R),
using interpolation
"""
# If t is None and there is only one value of time in the soluton (i.e.
# the solution is independent of time) then we set t equal to the value
# stored in the solution. If the variable is constant (doesn't depend on
# time) evaluate arbitrarily at the first value of t. Otherwise, raise
# an error
if t is None:
if len(self.t_pts) == 1:
t = self.t_pts
elif len(self.base_variables) == 1 and self.base_variables[0].is_constant():
t = self.t_pts[0]
else:
raise ValueError(
"t cannot be None for variable {}".format(self.base_variables)
)

# Call interpolant of correct spatial dimension
if self.dimensions == 0:
out = self._interpolation_function(t)
elif self.dimensions == 1:
out = self.call_1D(t, x, r, z, R)
elif self.dimensions == 2:
out = self.call_2D(t, x, r, y, z, R)
if warn is True and np.isnan(out).any():
pybamm.logger.warning(
"Calling variable outside interpolation range (returns 'nan')"
)
return out

def call_1D(self, t, x, r, z, R):
"""Evaluate a 1D variable"""
spatial_var = eval_dimension_name(self.first_dimension, x, r, None, z, R)
return self._interpolation_function(t, spatial_var)

def call_2D(self, t, x, r, y, z, R):
"""Evaluate a 2D variable"""
first_dim = eval_dimension_name(self.first_dimension, x, r, y, z, R)
second_dim = eval_dimension_name(self.second_dimension, x, r, y, z, R)
if isinstance(first_dim, np.ndarray):
if isinstance(second_dim, np.ndarray) and isinstance(t, np.ndarray):
first_dim = first_dim[:, np.newaxis, np.newaxis]
second_dim = second_dim[:, np.newaxis]
elif isinstance(second_dim, np.ndarray) or isinstance(t, np.ndarray):
first_dim = first_dim[:, np.newaxis]
else:
if isinstance(second_dim, np.ndarray) and isinstance(t, np.ndarray):
second_dim = second_dim[:, np.newaxis]
return self._interpolation_function((first_dim, second_dim, t))
kwargs = {"t": t, "x": x, "r": r, "y": y, "z": z, "R": R}
# Remove any None arguments
kwargs = {key: value for key, value in kwargs.items() if value is not None}
# Use xarray interpolation, return numpy array
return self._xr_data_array.interp(**kwargs).values

@property
def data(self):
Expand Down Expand Up @@ -564,79 +481,3 @@ def initialise_sensitivity_explicit_forward(self):

# Save attribute
self._sensitivities = sensitivities


class Interpolant0D:
def __init__(self, entries):
self.entries = entries

def __call__(self, t):
return self.entries


class Interpolant1D:
def __init__(self, pts_for_interp, entries_for_interp):
self.interpolant = interp.interp1d(
pts_for_interp,
entries_for_interp[:, 0],
kind="linear",
fill_value=np.nan,
bounds_error=False,
)

def __call__(self, t, z):
if isinstance(z, np.ndarray):
return self.interpolant(z)[:, np.newaxis]
else:
return self.interpolant(z)


class Interpolant2D:
def __init__(
self, first_dim_pts_for_interp, second_dim_pts_for_interp, entries_for_interp
):
self.interpolant = interp.interp2d(
second_dim_pts_for_interp,
first_dim_pts_for_interp,
entries_for_interp[:, :, 0],
kind="linear",
fill_value=np.nan,
bounds_error=False,
)

def __call__(self, input):
"""
Calls and returns a 2D interpolant of the correct shape depending on the
shape of the input
"""
first_dim, second_dim, _ = input
if isinstance(first_dim, np.ndarray) and isinstance(second_dim, np.ndarray):
first_dim = first_dim[:, 0, 0]
second_dim = second_dim[:, 0]
return self.interpolant(second_dim, first_dim)
elif isinstance(first_dim, np.ndarray):
first_dim = first_dim[:, 0]
return self.interpolant(second_dim, first_dim)[:, 0]
elif isinstance(second_dim, np.ndarray):
second_dim = second_dim[:, 0]
return self.interpolant(second_dim, first_dim)
else:
return self.interpolant(second_dim, first_dim)[0]


def eval_dimension_name(name, x, r, y, z, R):
if name == "x":
out = x
elif name == "r":
out = r
elif name == "y":
out = y
elif name == "z":
out = z
elif name == "R":
out = R

if out is None:
raise ValueError("inputs {} cannot be None".format(name))
else:
return out
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -8,6 +8,7 @@ casadi >= 3.6.0
imageio>=2.9.0
pybtex>=0.24.0
sympy >= 1.8
xarray
bpx
tqdm
# Note: Matplotlib is loaded for debug plots but to ensure pybamm runs
Expand Down
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -214,6 +214,7 @@ def compile_KLU():
"importlib-metadata",
"pybtex>=0.24.0",
"sympy>=1.8",
"xarray",
"bpx",
"tqdm",
# Note: Matplotlib is loaded for debug plots, but to ensure pybamm runs
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_models/standard_output_comparison.py
Expand Up @@ -146,8 +146,8 @@ def __init__(self, time, solutions):
def test_all(self):
self.compare("Negative particle concentration [mol.m-3]")
self.compare("Positive particle concentration [mol.m-3]")
self.compare("Negative particle flux [mol.m-2.s-1]", rtol=0.05)
self.compare("Positive particle flux [mol.m-2.s-1]", rtol=0.05)
self.compare("Negative particle flux [mol.m-2.s-1]", atol=1e-7, rtol=0.05)
self.compare("Positive particle flux [mol.m-2.s-1]", atol=1e-7, rtol=0.05)


class PorosityComparison(BaseOutputComparison):
Expand Down
Expand Up @@ -75,15 +75,18 @@ def test_get_processed_variables(self):
# for each current collector model
for model in models[1:]:
solution = model.default_solver.solve(model)
vars = model.post_process(solution, param, V, I)
variables = model.post_process(solution, param, V, I)
pts = np.array([0.1, 0.5, 0.9]) * min(
param.evaluate(model.param.L_y), param.evaluate(model.param.L_z)
)
for var, processed_var in vars.items():
for var, processed_var in variables.items():
if "Voltage [V]" in var:
processed_var(t=solution_1D.t[5])
else:
processed_var(t=solution_1D.t[5], y=pts, z=pts)
if model.options["dimensionality"] == 1:
processed_var(t=solution_1D.t[5], z=pts)
else:
processed_var(t=solution_1D.t[5], y=pts, z=pts)


if __name__ == "__main__":
Expand Down
18 changes: 7 additions & 11 deletions tests/unit/test_plotting/test_quick_plot.py
Expand Up @@ -329,19 +329,15 @@ def test_loqs_spme(self):
)
quick_plot.plot(0)

qp_data = (
quick_plot.plots[("Electrolyte concentration [mol.m-3]",)][0][
0
].get_ydata(),
)[0]
qp_data = quick_plot.plots[("Electrolyte concentration [mol.m-3]",)][0][
0
].get_ydata()
np.testing.assert_array_almost_equal(qp_data, c_e[:, 0])
quick_plot.slider_update(t_eval[-1] / scale)

qp_data = (
quick_plot.plots[("Electrolyte concentration [mol.m-3]",)][0][
0
].get_ydata(),
)[0][:, 0]
quick_plot.slider_update(t_eval[-1] / scale)
qp_data = quick_plot.plots[("Electrolyte concentration [mol.m-3]",)][0][
0
].get_ydata()
np.testing.assert_array_almost_equal(qp_data, c_e[:, 1])

# test quick plot of particle for spme
Expand Down

0 comments on commit dc4855c

Please sign in to comment.