Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 749 cache symbol shape #780

Merged
merged 14 commits into from
Jan 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

## Optimizations

- Added caching for shape evaluation, used during discretisation ([#780](https://github.com/pybamm-team/PyBaMM/pull/780))
- Added an option to skip model checks during discretisation, which could be slow for large models ([#739](https://github.com/pybamm-team/PyBaMM/pull/739))
- Use CasADi's automatic differentation algorithms by default when solving a model ([#714](https://github.com/pybamm-team/PyBaMM/pull/714))
- Avoid re-checking size when making a copy of an `Index` object ([#656](https://github.com/pybamm-team/PyBaMM/pull/656))
Expand All @@ -53,6 +54,7 @@

## Breaking changes

- Removed `Outer` and `Kron` nodes as no longer used ([#777](https://github.com/pybamm-team/PyBaMM/pull/777))
- Moved `results` to separate repositories ([#761](https://github.com/pybamm-team/PyBaMM/pull/761))
- The parameters "Bruggeman coefficient" must now be specified separately as "Bruggeman coefficient (electrolyte)" and "Bruggeman coefficient (electrode)"
- The current classes (`GetConstantCurrent`, `GetUserCurrent` and `GetUserData`) have now been removed. Please refer to the [`change-input-current` notebook](https://github.com/pybamm-team/PyBaMM/blob/master/examples/notebooks/change-input-current.ipynb) for information on how to specify an input current
Expand Down
8 changes: 0 additions & 8 deletions docs/source/expression_tree/binary_operator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,7 @@ Binary Operators
.. autoclass:: pybamm.Inner
:members:

.. autoclass:: pybamm.Outer
:members:

.. autoclass:: pybamm.Kron
:members:

.. autoclass:: pybamm.Heaviside
:members:

.. autofunction:: pybamm.outer

.. autofunction:: pybamm.source
3 changes: 0 additions & 3 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ def version(formatted=False):
Division,
Inner,
inner,
Outer,
Kron,
Heaviside,
outer,
source,
)
from .expression_tree.concatenations import (
Expand Down
4 changes: 1 addition & 3 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def check_variables(self, model):
"""
Check variables in variable list against rhs
Be lenient with size check if the variable in model.variables is broadcasted, or
a concatenation, or an outer product
a concatenation
(if broadcasted, variable is a multiplication with a vector of ones)
"""
for rhs_var in model.rhs.keys():
Expand All @@ -1001,7 +1001,6 @@ def check_variables(self, model):
)

not_concatenation = not isinstance(var, pybamm.Concatenation)
not_outer = not isinstance(var, pybamm.Outer)

not_mult_by_one_vec = not (
isinstance(var, pybamm.Multiplication)
Expand All @@ -1012,7 +1011,6 @@ def check_variables(self, model):
if (
different_shapes
and not_concatenation
and not_outer
and not_mult_by_one_vec
):
raise pybamm.ModelError(
Expand Down
115 changes: 5 additions & 110 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import numbers
from scipy.sparse import issparse, csr_matrix, kron
from scipy.sparse import issparse, csr_matrix


def is_scalar_zero(expr):
Expand Down Expand Up @@ -79,17 +79,8 @@ class BinaryOperator(pybamm.Symbol):
def __init__(self, name, left, right):
left, right = self.format(left, right)

# Check and process domains, except for Outer symbol which takes the outer
# product of two smbols in different domains, and gives it the domain of the
# right child.
if isinstance(self, (pybamm.Outer, pybamm.Kron)):
domain = right.domain
auxiliary_domains = {}
if domain != []:
auxiliary_domains["secondary"] = left.domain
else:
domain = self.get_children_domains(left.domain, right.domain)
auxiliary_domains = self.get_children_auxiliary_domains([left, right])
domain = self.get_children_domains(left.domain, right.domain)
auxiliary_domains = self.get_children_auxiliary_domains([left, right])
super().__init__(
name,
children=[left, right],
Expand All @@ -116,11 +107,7 @@ def format(self, left, right):
)

# Do some broadcasting in special cases, to avoid having to do this manually
if (
not isinstance(self, (Outer, Kron))
and left.domain != []
and right.domain != []
):
if left.domain != [] and right.domain != []:
if (
left.domain != right.domain
and "secondary" in right.auxiliary_domains
Expand Down Expand Up @@ -192,7 +179,7 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None):
right = self.right.evaluate(t, y, u)
return self._binary_evaluate(left, right)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape()`. """
left = self.children[0].evaluate_for_shape()
right = self.children[1].evaluate_for_shape()
Expand Down Expand Up @@ -654,86 +641,6 @@ def inner(left, right):
return pybamm.Inner(left, right)


class Outer(BinaryOperator):
"""A node in the expression tree representing an outer product.
This takes a 1D vector in the current collector domain of size (n,1) and a 1D
variable of size (m,1), takes their outer product, and reshapes this into a vector
of size (nm,1). It can also take in a vector in a single particle and a vector
of the electrolyte domain to repeat that particle.
Note: this class might be a bit dangerous, so at the moment it is very restrictive
in what symbols can be passed to it

**Extends:** :class:`BinaryOperator`
"""

def __init__(self, left, right):
""" See :meth:`pybamm.BinaryOperator.__init__()`. """
# cannot have certain types of objects in the right symbol, as these
# can already be 2D objects (so we can't take an outer product with them)
if right.has_symbol_of_classes(
(pybamm.Variable, pybamm.StateVector, pybamm.Matrix, pybamm.SpatialVariable)
):
raise TypeError("right child must only contain Vectors and Scalars" "")

super().__init__("outer product", left, right)

def __str__(self):
""" See :meth:`pybamm.Symbol.__str__()`. """
return "outer({!s}, {!s})".format(self.left, self.right)

def diff(self, variable):
""" See :meth:`pybamm.Symbol.diff()`. """
raise NotImplementedError("diff not implemented for symbol of type 'Outer'")

def _outer_jac(self, left_jac, right_jac, variable):
"""
Calculate jacobian of outer product.
See :meth:`pybamm.Jacobian._jac()`.
"""
# right cannot be a StateVector, so no need for product rule
left, right = self.orphans
if left.evaluates_to_number():
# Return zeros of correct size
return pybamm.Matrix(
csr_matrix((self.size, variable.evaluation_array.count(True)))
)
else:
return pybamm.Kron(left_jac, right)

def _binary_evaluate(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_evaluate()`. """

return np.outer(left, right).reshape(-1, 1)


class Kron(BinaryOperator):
"""A node in the expression tree representing a (sparse) kronecker product operator

**Extends:** :class:`BinaryOperator`
"""

def __init__(self, left, right):
""" See :meth:`pybamm.BinaryOperator.__init__()`. """

super().__init__("kronecker product", left, right)

def __str__(self):
""" See :meth:`pybamm.Symbol.__str__()`. """
return "kron({!s}, {!s})".format(self.left, self.right)

def diff(self, variable):
""" See :meth:`pybamm.Symbol.diff()`. """
raise NotImplementedError("diff not implemented for symbol of type 'Kron'")

def _binary_jac(self, left_jac, right_jac):
""" See :meth:`pybamm.BinaryOperator._binary_jac()`. """
raise NotImplementedError("jac not implemented for symbol of type 'Kron'")

def _binary_evaluate(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_evaluate()`. """
return csr_matrix(kron(left, right))


class Heaviside(BinaryOperator):
"""A node in the expression tree representing a heaviside step function.

Expand Down Expand Up @@ -783,18 +690,6 @@ def _binary_new_copy(self, left, right):
return Heaviside(left, right, self.equal)


def outer(left, right):
"""
Return outer product of two symbols. If the symbols have the same domain, the outer
product is just a multiplication. If they have different domains, make a copy of the
left child with same domain as right child, and then take outer product.
"""
try:
return left * right
except pybamm.DomainError:
return pybamm.Outer(left, right)


def source(left, right, boundary=False):
"""A convinience function for creating (part of) an expression tree representing
a source term. This is necessary for spatial methods where the mass matrix
Expand Down
6 changes: 3 additions & 3 deletions pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return PrimaryBroadcast(child, self.broadcast_domain)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
Expand Down Expand Up @@ -210,7 +210,7 @@ def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return SecondaryBroadcast(child, self.broadcast_domain)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
Expand Down Expand Up @@ -253,7 +253,7 @@ def _unary_new_copy(self, child):
""" See :meth:`pybamm.UnaryOperator.simplify()`. """
return FullBroadcast(child, self.broadcast_domain, self.auxiliary_domains)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a Broadcast.
See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _concatenation_simplify(self, children):
new_symbol.clear_domains()
return new_symbol

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape` """
if len(self.children) == 0:
return np.array([])
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def evaluate(self, t=None, y=None, u=None, known_evals=None):
evaluated_children = [child.evaluate(t, y, u) for child in self.children]
return self._function_evaluate(evaluated_children)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Default behaviour: has same shape as all child
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class IndependentVariable(pybamm.Symbol):
def __init__(self, name, domain=None, auxiliary_domains=None):
super().__init__(name, domain=domain, auxiliary_domains=auxiliary_domains)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
""" See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()` """
return pybamm.evaluate_for_shape_using_domain(
self.domain, self.auxiliary_domains
Expand Down Expand Up @@ -57,7 +57,7 @@ def _base_evaluate(self, t, y=None):
raise ValueError("t must be provided")
return t

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Return the scalar '0' to represent the shape of the independent variable `Time`.
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return InputParameter(self.name)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns the scalar 'NaN' to represent the shape of a parameter.
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
7 changes: 2 additions & 5 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,8 @@ def _convert(self, symbol, t=None, y=None, u=None):
# process children
converted_left = self.convert(left, t, y, u)
converted_right = self.convert(right, t, y, u)
if isinstance(symbol, pybamm.Outer):
return casadi.kron(converted_left, converted_right)
else:
# _binary_evaluate defined in derived classes for specific rules
return symbol._binary_evaluate(converted_left, converted_right)
# _binary_evaluate defined in derived classes for specific rules
return symbol._binary_evaluate(converted_left, converted_right)

elif isinstance(symbol, pybamm.UnaryOperator):
converted_child = self.convert(symbol.child, t, y, u)
Expand Down
8 changes: 0 additions & 8 deletions pybamm/expression_tree/operations/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,6 @@ def find_symbols(symbol, constant_symbols, variable_symbols):
"if scipy.sparse.issparse({1}) else "
"{0} * {1}".format(children_vars[0], children_vars[1])
)
elif isinstance(symbol, pybamm.Outer):
symbol_str = "np.outer({}, {}).reshape(-1, 1)".format(
children_vars[0], children_vars[1]
)
elif isinstance(symbol, pybamm.Kron):
symbol_str = "scipy.sparse.csr_matrix(scipy.sparse.kron({}, {}))".format(
children_vars[0], children_vars[1]
)
else:
symbol_str = children_vars[0] + " " + symbol.name + " " + children_vars[1]

Expand Down
11 changes: 2 additions & 9 deletions pybamm/expression_tree/operations/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,8 @@ def _jac(self, symbol, variable):
# process children
left_jac = self.jac(left, variable)
right_jac = self.jac(right, variable)
# Need to treat outer differently. If the left child of an Outer
# evaluates to number then we need to return a matrix of zeros
# of the correct size, which requires variable.evaluation_array
if isinstance(symbol, pybamm.Outer):
# _outer_jac defined in pybamm.Outer
jac = symbol._outer_jac(left_jac, right_jac, variable)
else:
# _binary_jac defined in derived classes for specific rules
jac = symbol._binary_jac(left_jac, right_jac)
# _binary_jac defined in derived classes for specific rules
jac = symbol._binary_jac(left_jac, right_jac)

elif isinstance(symbol, pybamm.UnaryOperator):
child_jac = self.jac(symbol.child, variable)
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return Parameter(self.name, self.domain)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns the scalar 'NaN' to represent the shape of a parameter.
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down Expand Up @@ -118,7 +118,7 @@ def _function_parameter_new_copy(self, children):
"""
return FunctionParameter(self.name, *children, diff_variable=self.diff_variable)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns the sum of the evaluated children
See :meth:`pybamm.Symbol.evaluate_for_shape()`
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def new_copy(self):
evaluation_array=self.evaluation_array,
)

def evaluate_for_shape(self):
def _evaluate_for_shape(self):
"""
Returns a vector of NaNs to represent the shape of a StateVector.
The size of a StateVector is the number of True elements in its evaluation_array
Expand Down
8 changes: 8 additions & 0 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,14 @@ def evaluate_for_shape(self):
shape is returned instead, using the symbol's domain.
See :meth:`pybamm.Symbol.evaluate()`
"""
try:
return self._saved_evaluate_for_shape
except AttributeError:
self._saved_evaluate_for_shape = self._evaluate_for_shape()
return self._saved_evaluate_for_shape

def _evaluate_for_shape(self):
"See :meth:`Symbol.evaluate_for_shape`"
return self.evaluate()

def is_constant(self):
Expand Down
Loading