Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove anytree #1912

Merged
merged 6 commits into from
Jan 23, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
- Added an option to force install compatible versions of jax and jaxlib if already installed using CLI ([#1881](https://github.com/pybamm-team/PyBaMM/pull/1881))
- Allow pybamm.Solution.save_data() to return a string if filename is None, and added json to_format option ([#1909](https://github.com/pybamm-team/PyBaMM/pull/1909)

## Optimizations

- The `Symbol` nodes no longer subclasses `anytree.NodeMixIn`. This removes some checks that were not really needed ([#1912](https://github.com/pybamm-team/PyBaMM/pull/1912))

## Bug fixes

- Parameters can now be imported from any given path in `Windows` ([#1900](https://github.com/pybamm-team/PyBaMM/pull/1900))
Expand Down
14 changes: 5 additions & 9 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,18 +1043,14 @@ def _process_symbol(self, symbol):
# Return a new copy of the input parameter, but set the expected size
# according to the domain of the input parameter
expected_size = self._get_variable_size(symbol)
new_input_parameter = symbol.new_copy()
new_input_parameter.set_expected_size(expected_size)
new_input_parameter = pybamm.InputParameter(
symbol.name, symbol.domain, expected_size
)
return new_input_parameter

else:
# Backup option: return new copy of the object
try:
return symbol.new_copy()
except NotImplementedError:
raise NotImplementedError(
"Cannot discretise symbol of type '{}'".format(type(symbol))
)
# Backup option: return the object
return symbol

def concatenate(self, *symbols, sparse=False):
if sparse:
Expand Down
20 changes: 5 additions & 15 deletions pybamm/expression_tree/averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ def x_average(symbol):
)
for domain in symbol.domains.values()
):
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol
# If symbol is a broadcast, reduce by one dimension
if isinstance(
symbol,
Expand Down Expand Up @@ -217,9 +215,7 @@ def z_average(symbol):
)
# If symbol doesn't have a domain, its average value is itself
if symbol.domain == []:
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol
# If symbol is a Broadcast, its average value is its child
elif isinstance(symbol, pybamm.Broadcast):
return symbol.orphans[0]
Expand Down Expand Up @@ -252,9 +248,7 @@ def yz_average(symbol):
)
# If symbol doesn't have a domain, its average value is itself
if symbol.domain == []:
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol
# If symbol is a Broadcast, its average value is its child
elif isinstance(symbol, pybamm.Broadcast):
return symbol.orphans[0]
Expand Down Expand Up @@ -287,9 +281,7 @@ def r_average(symbol):
["negative particle"],
["working particle"],
]:
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol
# If symbol is a secondary broadcast onto "negative electrode" or
# "positive electrode", take the r-average of the child then broadcast back
elif isinstance(symbol, pybamm.SecondaryBroadcast) and symbol.domains[
Expand Down Expand Up @@ -334,9 +326,7 @@ def size_average(symbol, f_a_dist=None):
domain in [["negative particle size"], ["positive particle size"]]
for domain in list(symbol.domains.values())
):
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol

# If symbol is a primary broadcast to "particle size", take the orphan
elif isinstance(symbol, pybamm.PrimaryBroadcast) and symbol.domain in [
Expand Down
8 changes: 4 additions & 4 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __str__(self):

def _diff(self, variable):
"""See :meth:`pybamm.Symbol._diff()`."""
children_diffs = [child.diff(variable) for child in self.cached_children]
children_diffs = [child.diff(variable) for child in self.children]
if len(children_diffs) == 1:
diff = children_diffs[0]
else:
Expand Down Expand Up @@ -92,7 +92,7 @@ def _concatenation_evaluate(self, children_eval):

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
"""See :meth:`pybamm.Symbol.evaluate()`."""
children = self.cached_children
children = self.children
if known_evals is not None:
if self.id not in known_evals:
children_eval = [None] * len(children)
Expand Down Expand Up @@ -189,7 +189,7 @@ def __init__(self, *children):

def _concatenation_jac(self, children_jacs):
"""See :meth:`pybamm.Concatenation.concatenation_jac()`."""
children = self.cached_children
children = self.children
if len(children) == 0:
return pybamm.Scalar(0)
else:
Expand Down Expand Up @@ -252,7 +252,7 @@ def __init__(self, children, full_mesh, copy_this=None):

# create disc of domain => slice for each child
self._children_slices = [
self.create_slices(child) for child in self.cached_children
self.create_slices(child) for child in self.children
]
else:
self._full_mesh = copy.copy(copy_this._full_mesh)
Expand Down
21 changes: 7 additions & 14 deletions pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,21 @@ class InputParameter(pybamm.Symbol):
domain : iterable of str, or str
list of domains over which the node is valid (empty list indicates the symbol
is valid over all domains)
expected_size : int
The size of the input parameter expected, defaults to 1 (scalar input)
"""

def __init__(self, name, domain=None):
# Expected shape defaults to 1
self._expected_size = 1
def __init__(self, name, domain=None, expected_size=1):
self._expected_size = expected_size
super().__init__(name, domain=domain)

def create_copy(self):
"""See :meth:`pybamm.Symbol.new_copy()`."""
new_input_parameter = InputParameter(self.name, self.domain)
new_input_parameter._expected_size = self._expected_size
new_input_parameter = InputParameter(
self.name, self.domain, expected_size=self._expected_size
)
return new_input_parameter

def set_expected_size(self, size):
"""Specify the size that the input parameter should be."""
self._expected_size = size

# We also need to update the saved size and shape
self._saved_size = size
self._saved_shape = (size, 1)
self._saved_evaluate_for_shape = self._evaluate_for_shape()

def _evaluate_for_shape(self):
"""
Returns the scalar 'NaN' to represent the shape of a parameter.
Expand Down
6 changes: 2 additions & 4 deletions pybamm/expression_tree/operations/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def jac(self, symbol, variable):
return jac

def _jac(self, symbol, variable):
""" See :meth:`Jacobian.jac()`. """
"""See :meth:`Jacobian.jac()`."""

if isinstance(symbol, pybamm.BinaryOperator):
left, right = symbol.children
Expand All @@ -76,9 +76,7 @@ def _jac(self, symbol, variable):
jac = symbol._function_jac(children_jacs)

elif isinstance(symbol, pybamm.Concatenation):
children_jacs = [
self.jac(child, variable) for child in symbol.cached_children
]
children_jacs = [self.jac(child, variable) for child in symbol.children]
if len(children_jacs) == 1:
jac = children_jacs[0]
else:
Expand Down
5 changes: 2 additions & 3 deletions pybamm/expression_tree/operations/replace_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,5 @@ def _process_symbol(self, symbol):
else:
# Only other option is that the symbol is a leaf (doesn't have children)
# In this case, since we have already ruled out that the symbol is one of
# the symbols that needs to be replaced, we can just return a new copy of
# the symbol
return symbol.new_copy()
# the symbols that needs to be replaced, we can just return the symbol
return symbol
18 changes: 4 additions & 14 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#
# Base Symbol Class for the expression tree
#
import copy
import numbers

import anytree
Expand Down Expand Up @@ -179,7 +178,7 @@ def simplify_if_constant(symbol):
return symbol


class Symbol(anytree.NodeMixin):
class Symbol:
"""
Base node class for the expression tree.

Expand Down Expand Up @@ -210,19 +209,10 @@ def __init__(self, name, children=None, domain=None, auxiliary_domains=None):
if children is None:
children = []

# Store "orphans", which are separate from children as they do not have a
# parent node, so they do not cause tree corruption errors when used again
# in a different part of the tree
self._children = children
# Keep a separate "oprhans" attribute for backwards compatibility
self._orphans = children

for child in children:
# copy child before adding
# this also adds copy.copy(child) to self.children
copy.copy(child).parent = self

# cache children
self.cached_children = super(Symbol, self).children

# Set auxiliary domains
self._domains = {"primary": None}
self.auxiliary_domains = auxiliary_domains
Expand Down Expand Up @@ -250,7 +240,7 @@ def children(self):
Note: it is assumed that children of a node are not modified after initial
creation
"""
return self.cached_children
return self._children

@property
def name(self):
Expand Down
10 changes: 3 additions & 7 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,11 @@ def __init__(self, child):

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
child = self.child.new_copy()
return sign(child) * child.diff(variable)
return sign(self.child) * self.child.diff(variable)

def _unary_jac(self, child_jac):
"""See :meth:`pybamm.UnaryOperator._unary_jac()`."""
child = self.child.new_copy()
return sign(child) * child_jac
return sign(self.child) * child_jac

def _unary_evaluate(self, child):
"""See :meth:`UnaryOperator._unary_evaluate()`."""
Expand Down Expand Up @@ -1273,9 +1271,7 @@ def boundary_value(symbol, side):

# If symbol doesn't have a domain, its boundary value is itself
if symbol.domain == []:
new_symbol = symbol.new_copy()
new_symbol.parent = None
return new_symbol
return symbol
# If symbol is a primary or full broadcast, reduce by one dimension
if isinstance(symbol, (pybamm.PrimaryBroadcast, pybamm.FullBroadcast)):
return symbol.reduce_one_dimension()
Expand Down
13 changes: 3 additions & 10 deletions pybamm/parameters/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def _process_symbol(self, symbol):
):
# Wrap with NotConstant to avoid simplification,
# which would stop symbolic diff from working properly
new_child = pybamm.NotConstant(child.new_copy())
new_child = pybamm.NotConstant(child)
new_children.append(self.process_symbol(new_child))
else:
new_children.append(self.process_symbol(child))
Expand Down Expand Up @@ -766,15 +766,8 @@ def _process_symbol(self, symbol):
return symbol._concatenation_new_copy(new_children)

else:
# Backup option: return new copy of the object
try:
return symbol.new_copy()
except NotImplementedError:
raise NotImplementedError(
"Cannot process parameters for symbol of type '{}'".format(
type(symbol)
)
)
# Backup option: return the object
return symbol

def evaluate(self, symbol):
"""
Expand Down
18 changes: 5 additions & 13 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,23 +480,17 @@ def jacp(*args, **kwargs):
found_t = True
# Dimensional
elif symbol.right.id == (pybamm.t * model.timescale_eval).id:
expr = (
symbol.left.new_copy() / symbol.right.right.new_copy()
)
expr = symbol.left / symbol.right.right
found_t = True
elif symbol.left.id == (pybamm.t * model.timescale_eval).id:
expr = (
symbol.right.new_copy() / symbol.left.right.new_copy()
)
expr = symbol.right / symbol.left.right
found_t = True

# Update the events if the heaviside function depended on t
if found_t:
model.events.append(
pybamm.Event(
str(symbol),
expr.new_copy(),
pybamm.EventType.DISCONTINUITY,
str(symbol), expr, pybamm.EventType.DISCONTINUITY
)
)
elif isinstance(symbol, pybamm.Modulo):
Expand All @@ -507,9 +501,7 @@ def jacp(*args, **kwargs):
found_t = True
# Dimensional
elif symbol.left.id == (pybamm.t * model.timescale_eval).id:
expr = (
symbol.right.new_copy() / symbol.left.right.new_copy()
)
expr = symbol.right / symbol.left.right
found_t = True

# Update the events if the modulo function depended on t
Expand All @@ -523,7 +515,7 @@ def jacp(*args, **kwargs):
model.events.append(
pybamm.Event(
str(symbol),
expr.new_copy() * pybamm.Scalar(i + 1),
expr * pybamm.Scalar(i + 1),
pybamm.EventType.DISCONTINUITY,
)
)
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/test_expression_tree/test_input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ def test_input_parameter_init(self):
self.assertEqual(a.evaluate(inputs={"a": 1}), 1)
self.assertEqual(a.evaluate(inputs={"a": 5}), 5)

def test_set_expected_size(self):
a = pybamm.InputParameter("a")
a.set_expected_size(10)
a = pybamm.InputParameter("a", expected_size=10)
self.assertEqual(a._expected_size, 10)
np.testing.assert_array_equal(
a.evaluate(inputs="shape test"), np.nan * np.ones((10, 1))
)
y = np.linspace(0, 1, 10)
np.testing.assert_array_equal(a.evaluate(inputs={"a": y}), y[:, np.newaxis])

with self.assertRaisesRegex(
ValueError,
"Input parameter 'a' was given an object of size '1' but was expecting an "
Expand All @@ -34,7 +33,7 @@ def test_evaluate_for_shape(self):
self.assertTrue(np.isnan(a.evaluate_for_shape()))
self.assertEqual(a.shape, ())

a.set_expected_size(10)
a = pybamm.InputParameter("a", expected_size=10)
self.assertEqual(a.shape, (10, 1))
np.testing.assert_equal(a.evaluate_for_shape(), np.nan * np.ones((10, 1)))
self.assertEqual(a.evaluate_for_shape().shape, (10, 1))
Expand Down
Loading