Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 858 ddt #888

Merged
merged 31 commits into from
Mar 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5127652
#858 update state vector to allow derivative of y or ydot
martinjrobins Mar 4, 2020
4c971a6
#858 add stateVectorDot
martinjrobins Mar 5, 2020
96c6a41
#858 add y_dot arg to all evaluate functions
martinjrobins Mar 6, 2020
5489aa3
#858 add VariableDot
martinjrobins Mar 6, 2020
e2b4ba7
#858 add d_dt unary operator, fixes to jac to support this
martinjrobins Mar 6, 2020
130bbb9
#858 discretisation of VariableDot results in StateVectorDot
martinjrobins Mar 6, 2020
1f03394
#858 time derivative of state vector gives state vector dot, raise er…
martinjrobins Mar 6, 2020
4ce9dbd
#858 add check for time derivs in model.rhs, fix some evaluate arg bugs
martinjrobins Mar 6, 2020
d330e67
#858 added make_semi_explicit to convert implicit dae equations to al…
martinjrobins Mar 6, 2020
eb1b381
#858 make_semi_explicit updates initial conditions
martinjrobins Mar 6, 2020
5a0da61
#858 some fixes, move time deriv model checks to base_model
martinjrobins Mar 7, 2020
9ba63f6
#858 misc fixes
martinjrobins Mar 7, 2020
c8eeb81
#858 style fixes, remove d_dt helper function, fix bugs
martinjrobins Mar 7, 2020
06ccfe8
#858 fixes for diff
martinjrobins Mar 7, 2020
b341616
#858 add vectordot and state_vectordot to api docs
martinjrobins Mar 14, 2020
6ca684d
#858 remove make_semi_explicit
martinjrobins Mar 14, 2020
f3747fa
#858 fix static analysis errors
martinjrobins Mar 14, 2020
6da78cb
#858 fix test failures
martinjrobins Mar 15, 2020
40feb08
#858 not testing solver anymore
martinjrobins Mar 15, 2020
3cdb6dc
#858 fix simplify bug, fix expression tree notebook
martinjrobins Mar 15, 2020
00990f9
#858 improve coverage
martinjrobins Mar 15, 2020
ff5e810
#858 merge develop
martinjrobins Mar 15, 2020
445ae54
#858 fix external variables test
martinjrobins Mar 15, 2020
7ac7848
#858 fix coverage and flake8
martinjrobins Mar 15, 2020
345be6a
#858 fix test errors
martinjrobins Mar 15, 2020
0c5270e
#858 improve coverage
martinjrobins Mar 16, 2020
fa5373e
#858 add coverage
martinjrobins Mar 18, 2020
d21f0b6
#858 improve coverage, document exception for InputParameter in evalu…
martinjrobins Mar 18, 2020
d8808a6
#858 improve coverage
martinjrobins Mar 18, 2020
bc32883
#858 improve coverage
martinjrobins Mar 18, 2020
088e625
#858 flake8
martinjrobins Mar 18, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/expression_tree/state_vector.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ State Vector

.. autoclass:: pybamm.StateVector
:members:

.. autoclass:: pybamm.StateVectorDot
:members:
4 changes: 4 additions & 0 deletions docs/source/expression_tree/variable.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,9 @@ Variable
.. autoclass:: pybamm.Variable
:members:

.. autoclass:: pybamm.VariableDot
:members:

.. autoclass:: pybamm.ExternalVariable
:members:

17 changes: 12 additions & 5 deletions examples/notebooks/expression_tree/expression-tree.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also calculate the expression tree representing the gradient of the equation with respect to $t$ (which is of course simply the scalar value 1),"
"We can also calculate the expression tree representing the gradient of the equation with respect to $t$,"
]
},
{
Expand All @@ -84,7 +84,7 @@
"![](expression_tree2.png)\n",
"\n",
"\n",
"...and evaluate this expression, which will again give 1."
"...and evaluate this expression,"
]
},
{
Expand All @@ -95,7 +95,7 @@
{
"data": {
"text/plain": [
"1.0"
"array([[-11.]])"
]
},
"execution_count": 4,
Expand All @@ -104,7 +104,7 @@
}
],
"source": [
"diff_wrt_equation.evaluate(1, np.array([2]))"
"diff_wrt_equation.evaluate(t=1, y=np.array([2]), y_dot=np.array([2]))"
]
},
{
Expand Down Expand Up @@ -202,6 +202,13 @@
"\n",
"After the third stage, our expression tree is now able to be evaluated by one of the solver classes. Note that we have used a single equation above to illustrate the different types of expression trees in PyBaMM, but any given models will consist of many RHS or algebraic equations, along with boundary conditions. See [here](https://github.com/pybamm-team/PyBaMM/blob/master/examples/notebooks/add-model.ipynb) for more details of PyBaMM models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -220,7 +227,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.6.7"
}
},
"nbformat": 4,
Expand Down
Binary file modified examples/notebooks/expression_tree/expression_tree2.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 3 additions & 2 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,12 @@ def version(formatted=False):
from .expression_tree.parameter import Parameter, FunctionParameter
from .expression_tree.broadcasts import *
from .expression_tree.scalar import Scalar
from .expression_tree.variable import Variable, ExternalVariable
from .expression_tree.variable import Variable, ExternalVariable, VariableDot
from .expression_tree.variable import VariableBase
from .expression_tree.independent_variable import *
from .expression_tree.independent_variable import t
from .expression_tree.vector import Vector
from .expression_tree.state_vector import StateVector
from .expression_tree.state_vector import StateVectorBase, StateVector, StateVectorDot

from .expression_tree.exceptions import *

Expand Down
8 changes: 8 additions & 0 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ def process_rhs_and_algebraic(self, model):
equations) and processed_concatenated_algebraic

"""

# Discretise right-hand sides, passing domain from variable
processed_rhs = self.process_dict(model.rhs)

Expand Down Expand Up @@ -856,6 +857,13 @@ def _process_symbol(self, symbol):
disc_children = [self.process_symbol(child) for child in symbol.children]
return symbol._function_new_copy(disc_children)

elif isinstance(symbol, pybamm.VariableDot):
return pybamm.StateVectorDot(
*self.y_slices[symbol.get_variable().id],
domain=symbol.domain,
auxiliary_domains=symbol.auxiliary_domains
)

elif isinstance(symbol, pybamm.Variable):
# Check if variable is a standard variable or an external variable
if any(symbol.id == var.id for var in self.external_variables.values()):
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@ def new_copy(self):
self.entries_string,
)

def _base_evaluate(self, t=None, y=None, u=None):
def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
""" See :meth:`pybamm.Symbol._base_evaluate()`. """
return self._entries
34 changes: 19 additions & 15 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def is_scalar_zero(expr):
Utility function to test if an expression evaluates to a constant scalar zero
"""
if expr.is_constant():
result = expr.evaluate_ignoring_errors()
result = expr.evaluate_ignoring_errors(t=None)
return isinstance(result, numbers.Number) and result == 0
else:
return False
Expand All @@ -24,20 +24,20 @@ def is_matrix_zero(expr):
Utility function to test if an expression evaluates to a constant matrix zero
"""
if expr.is_constant():
result = expr.evaluate_ignoring_errors()
result = expr.evaluate_ignoring_errors(t=None)
return (issparse(result) and result.count_nonzero() == 0) or (
isinstance(result, np.ndarray) and np.all(result == 0)
)
else:
return False


def is_one(expr):
def is_scalar_one(expr):
"""
Utility function to test if an expression evaluates to a constant scalar one
"""
if expr.is_constant():
result = expr.evaluate_ignoring_errors()
result = expr.evaluate_ignoring_errors(t=None)
return isinstance(result, numbers.Number) and result == 1
else:
return False
Expand Down Expand Up @@ -162,21 +162,21 @@ def _binary_new_copy(self, left, right):
"Default behaviour for new_copy"
return self.__class__(left, right)

def evaluate(self, t=None, y=None, u=None, known_evals=None):
def evaluate(self, t=None, y=None, y_dot=None, u=None, known_evals=None):
""" See :meth:`pybamm.Symbol.evaluate()`. """
if known_evals is not None:
id = self.id
try:
return known_evals[id], known_evals
except KeyError:
left, known_evals = self.left.evaluate(t, y, u, known_evals)
right, known_evals = self.right.evaluate(t, y, u, known_evals)
left, known_evals = self.left.evaluate(t, y, y_dot, u, known_evals)
right, known_evals = self.right.evaluate(t, y, y_dot, u, known_evals)
value = self._binary_evaluate(left, right)
known_evals[id] = value
return value, known_evals
else:
left = self.left.evaluate(t, y, u)
right = self.right.evaluate(t, y, u)
left = self.left.evaluate(t, y, y_dot, u)
right = self.right.evaluate(t, y, y_dot, u)
return self._binary_evaluate(left, right)

def _evaluate_for_shape(self):
Expand Down Expand Up @@ -252,8 +252,12 @@ def _binary_simplify(self, left, right):
if is_scalar_zero(right):
return pybamm.Scalar(1)

# anything to the power of one is itself
# zero to the power of anything is zero
if is_scalar_zero(left):
return pybamm.Scalar(0)

# anything to the power of one is itself
if is_scalar_one(right):
return left

return self.__class__(left, right)
Expand Down Expand Up @@ -425,9 +429,9 @@ def _binary_simplify(self, left, right):
return zeros_of_shape(shape)

# anything multiplied by a scalar one returns itself
if is_one(left):
if is_scalar_one(left):
return right
if is_one(right):
if is_scalar_one(right):
return left

return pybamm.simplify_multiplication_division(self.__class__, left, right)
Expand Down Expand Up @@ -549,7 +553,7 @@ def _binary_simplify(self, left, right):
return pybamm.Array(np.inf * np.ones(left.shape_for_testing))

# anything divided by one is itself
if is_one(right):
if is_scalar_one(right):
return left

return pybamm.simplify_multiplication_division(self.__class__, left, right)
Expand Down Expand Up @@ -622,9 +626,9 @@ def _binary_simplify(self, left, right):
return zeros_of_shape(shape)

# anything multiplied by a scalar one returns itself
if is_one(left):
if is_scalar_one(left):
return right
if is_one(right):
if is_scalar_one(right):
return left

return pybamm.simplify_multiplication_division(self.__class__, left, right)
Expand Down
6 changes: 3 additions & 3 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,22 @@ def _concatenation_evaluate(self, children_eval):
else:
return self.concatenation_function(children_eval)

def evaluate(self, t=None, y=None, u=None, known_evals=None):
def evaluate(self, t=None, y=None, y_dot=None, u=None, known_evals=None):
""" See :meth:`pybamm.Symbol.evaluate()`. """
children = self.cached_children
if known_evals is not None:
if self.id not in known_evals:
children_eval = [None] * len(children)
for idx, child in enumerate(children):
children_eval[idx], known_evals = child.evaluate(
t, y, u, known_evals
t, y, y_dot, u, known_evals
)
known_evals[self.id] = self._concatenation_evaluate(children_eval)
return known_evals[self.id], known_evals
else:
children_eval = [None] * len(children)
for idx, child in enumerate(children):
children_eval[idx] = child.evaluate(t, y, u)
children_eval[idx] = child.evaluate(t, y, y_dot, u)
return self._concatenation_evaluate(children_eval)

def new_copy(self):
Expand Down
7 changes: 4 additions & 3 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,19 +152,20 @@ def _function_jac(self, children_jacs):

return jacobian

def evaluate(self, t=None, y=None, u=None, known_evals=None):
def evaluate(self, t=None, y=None, y_dot=None, u=None, known_evals=None):
""" See :meth:`pybamm.Symbol.evaluate()`. """
if known_evals is not None:
if self.id not in known_evals:
evaluated_children = [None] * len(self.children)
for i, child in enumerate(self.children):
evaluated_children[i], known_evals = child.evaluate(
t, y, u, known_evals=known_evals
t, y, y_dot, u, known_evals=known_evals
)
known_evals[self.id] = self._function_evaluate(evaluated_children)
return known_evals[self.id], known_evals
else:
evaluated_children = [child.evaluate(t, y, u) for child in self.children]
evaluated_children = [child.evaluate(t, y, y_dot, u)
for child in self.children]
return self._function_evaluate(evaluated_children)

def _evaluate_for_shape(self):
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return Time()

def _base_evaluate(self, t, y=None, u=None):
def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
""" See :meth:`pybamm.Symbol._base_evaluate()`. """
if t is None:
raise ValueError("t must be provided")
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _jac(self, variable):
""" See :meth:`pybamm.Symbol._jac()`. """
return pybamm.Scalar(0)

def _base_evaluate(self, t=None, y=None, u=None):
def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
# u should be a dictionary
# convert 'None' to empty dictionary for more informative error
if u is None:
Expand Down
25 changes: 16 additions & 9 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, casadi_symbols=None):

pybamm.citations.register("Andersson2019")

def convert(self, symbol, t=None, y=None, u=None):
def convert(self, symbol, t, y, y_dot, u):
"""
This function recurses down the tree, converting the PyBaMM expression tree to
a CasADi expression tree
Expand All @@ -39,12 +39,12 @@ def convert(self, symbol, t=None, y=None, u=None):
except KeyError:
# Change u to empty dictionary if it's None
u = u or {}
casadi_symbol = self._convert(symbol, t, y, u)
casadi_symbol = self._convert(symbol, t, y, y_dot, u)
self._casadi_symbols[symbol.id] = casadi_symbol

return casadi_symbol

def _convert(self, symbol, t=None, y=None, u=None):
def _convert(self, symbol, t, y, y_dot, u):
""" See :meth:`CasadiConverter.convert()`. """
if isinstance(
symbol,
Expand All @@ -56,34 +56,41 @@ def _convert(self, symbol, t=None, y=None, u=None):
pybamm.ExternalVariable,
),
):
return casadi.MX(symbol.evaluate(t, y, u))
return casadi.MX(symbol.evaluate(t, y, y_dot, u))

elif isinstance(symbol, pybamm.StateVector):
if y is None:
raise ValueError("Must provide a 'y' for converting state vectors")
return casadi.vertcat(*[y[y_slice] for y_slice in symbol.y_slices])

elif isinstance(symbol, pybamm.StateVectorDot):
if y_dot is None:
raise ValueError("Must provide a 'y_dot' for converting state vectors")
return casadi.vertcat(*[y_dot[y_slice] for y_slice in symbol.y_slices])

elif isinstance(symbol, pybamm.BinaryOperator):
left, right = symbol.children
# process children
converted_left = self.convert(left, t, y, u)
converted_right = self.convert(right, t, y, u)
converted_left = self.convert(left, t, y, y_dot, u)
converted_right = self.convert(right, t, y, y_dot, u)

if isinstance(symbol, pybamm.Minimum):
return casadi.fmin(converted_left, converted_right)
if isinstance(symbol, pybamm.Maximum):
return casadi.fmax(converted_left, converted_right)

# _binary_evaluate defined in derived classes for specific rules
return symbol._binary_evaluate(converted_left, converted_right)

elif isinstance(symbol, pybamm.UnaryOperator):
converted_child = self.convert(symbol.child, t, y, u)
converted_child = self.convert(symbol.child, t, y, y_dot, u)
if isinstance(symbol, pybamm.AbsoluteValue):
return casadi.fabs(converted_child)
return symbol._unary_evaluate(converted_child)

elif isinstance(symbol, pybamm.Function):
converted_children = [
self.convert(child, t, y, u) for child in symbol.children
self.convert(child, t, y, y_dot, u) for child in symbol.children
]
# Special functions
if symbol.function == np.min:
Expand Down Expand Up @@ -114,7 +121,7 @@ def _convert(self, symbol, t=None, y=None, u=None):
return symbol._function_evaluate(converted_children)
elif isinstance(symbol, pybamm.Concatenation):
converted_children = [
self.convert(child, t, y, u) for child in symbol.children
self.convert(child, t, y, y_dot, u) for child in symbol.children
]
if isinstance(symbol, (pybamm.NumpyConcatenation, pybamm.SparseStack)):
return casadi.vertcat(*converted_children)
Expand Down