From d4c0523a4a7c9e886cbb59adac94dccf7fd38e17 Mon Sep 17 00:00:00 2001 From: kratman Date: Tue, 5 Mar 2024 09:42:20 -0500 Subject: [PATCH 1/5] Rename have_optional_dependency --- CONTRIBUTING.md | 8 ++++---- pybamm/__init__.py | 2 +- pybamm/citations.py | 12 +++++------ pybamm/expression_tree/array.py | 4 ++-- pybamm/expression_tree/binary_operators.py | 10 +++++----- pybamm/expression_tree/concatenations.py | 4 ++-- pybamm/expression_tree/functions.py | 12 +++++------ .../expression_tree/independent_variable.py | 6 +++--- pybamm/expression_tree/operations/latexify.py | 8 ++++---- pybamm/expression_tree/parameter.py | 6 +++--- pybamm/expression_tree/scalar.py | 4 ++-- pybamm/expression_tree/symbol.py | 12 +++++------ pybamm/expression_tree/unary_operators.py | 14 +++++++------ pybamm/expression_tree/variable.py | 4 ++-- pybamm/meshes/scikit_fem_submeshes.py | 4 ++-- pybamm/models/base_model.py | 4 ++-- pybamm/plotting/plot.py | 4 ++-- pybamm/plotting/plot2D.py | 4 ++-- pybamm/plotting/plot_summary_variables.py | 4 ++-- pybamm/plotting/plot_voltage_components.py | 4 ++-- pybamm/plotting/quick_plot.py | 20 +++++++++---------- pybamm/simulation.py | 4 ++-- .../spatial_methods/scikit_finite_element.py | 16 +++++++-------- pybamm/util.py | 12 +++-------- .../test_binary_operators.py | 4 ++-- .../test_concatenations.py | 4 ++-- .../test_expression_tree/test_functions.py | 4 ++-- .../test_independent_variable.py | 4 ++-- .../test_expression_tree/test_parameter.py | 6 +++--- .../test_printing/test_sympy_overrides.py | 4 ++-- .../unit/test_expression_tree/test_symbol.py | 4 ++-- .../test_unary_operators.py | 10 ++++++---- .../test_expression_tree/test_variable.py | 4 ++-- tests/unit/test_util.py | 4 ++-- 34 files changed, 114 insertions(+), 116 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b9800dcd61..fc8e848bb5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -104,13 +104,13 @@ Only 'core pybamm' is installed by default. The others have to be specified expl PyBaMM utilizes optional dependencies to allow users to choose which additional libraries they want to use. Managing these optional dependencies and their imports is essential to provide flexibility to PyBaMM users. -PyBaMM provides a utility function `have_optional_dependency`, to check for the availability of optional dependencies within methods. This function can be used to conditionally import optional dependencies only if they are available. Here's how to use it: +PyBaMM provides a utility function `import_optional_dependency`, to check for the availability of optional dependencies within methods. This function can be used to conditionally import optional dependencies only if they are available. Here's how to use it: Optional dependencies should never be imported at the module level, but always inside methods. For example: ``` def use_pybtex(x,y,z): - pybtex = have_optional_dependency("pybtex") + pybtex = import_optional_dependency("pybtex") ... ``` @@ -118,7 +118,7 @@ While importing a specific module instead of an entire package/library: ```python def use_parse_file(x, y, z): - parse_file = have_optional_dependency("pybtex.database", "parse_file") + parse_file = import_optional_dependency("pybtex.database", "parse_file") ... ``` @@ -143,7 +143,7 @@ class TestUtil(TestCase): pybamm.function_using_pybtex(x, y, z) # Test that the function works when pybtex is available - sys.modules["pybtex"] = pybamm.util.have_optional_dependency("pybtex") + sys.modules["pybtex"] = pybamm.util.import_optional_dependency("pybtex") pybamm.function_using_pybtex(x, y, z) ``` diff --git a/pybamm/__init__.py b/pybamm/__init__.py index ab2e72ed28..c2654ea9cf 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -47,7 +47,7 @@ get_parameters_filepath, have_jax, install_jax, - have_optional_dependency, + import_optional_dependency, is_jax_compatible, get_git_commit_info, ) diff --git a/pybamm/citations.py b/pybamm/citations.py index da371bbd84..74f477c117 100644 --- a/pybamm/citations.py +++ b/pybamm/citations.py @@ -7,7 +7,7 @@ import os import warnings from sys import _getframe -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class Citations: @@ -74,7 +74,7 @@ def read_citations(self): """Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited by passing a BibTeX citation to :meth:`register`. """ - parse_file = have_optional_dependency("pybtex.database", "parse_file") + parse_file = import_optional_dependency("pybtex.database", "parse_file") citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib") bib_data = parse_file(citations_file, bib_format="bibtex") for key, entry in bib_data.entries.items(): @@ -85,7 +85,7 @@ def _add_citation(self, key, entry): previous entry is overwritten """ - Entry = have_optional_dependency("pybtex.database", "Entry") + Entry = import_optional_dependency("pybtex.database", "Entry") # Check input types are correct if not isinstance(key, str) or not isinstance(entry, Entry): raise TypeError() @@ -151,8 +151,8 @@ def _parse_citation(self, key): key: str A BibTeX formatted citation """ - PybtexError = have_optional_dependency("pybtex.scanner", "PybtexError") - parse_string = have_optional_dependency("pybtex.database", "parse_string") + PybtexError = import_optional_dependency("pybtex.scanner", "PybtexError") + parse_string = import_optional_dependency("pybtex.database", "parse_string") try: # Parse string as a bibtex citation, and check that a citation was found bib_data = parse_string(key, bib_format="bibtex") @@ -219,7 +219,7 @@ def print(self, filename=None, output_format="text", verbose=False): """ # Parse citations that were not known keys at registration, but do not # fail if they cannot be parsed - pybtex = have_optional_dependency("pybtex") + pybtex = import_optional_dependency("pybtex") try: for key in self._unknown_citations: self._parse_citation(key) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index 0bb6168a7c..e9f1535e50 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType if TYPE_CHECKING: # pragma: no cover @@ -157,7 +157,7 @@ def is_constant(self): def to_equation(self) -> sympy.Array: """Returns the value returned by the node when evaluated.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") entries_list = self.entries.tolist() return sympy.Array(entries_list) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index d10844798b..4ee99cbaca 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -9,7 +9,7 @@ import functools import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from typing import Callable, cast @@ -180,7 +180,7 @@ def _sympy_operator(self, left, right): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -388,7 +388,7 @@ def _binary_evaluate(self, left, right): def _sympy_operator(self, left, right): """Override :meth:`pybamm.BinaryOperator._sympy_operator`""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") left = sympy.Matrix(left) right = sympy.Matrix(right) return left * right @@ -737,7 +737,7 @@ def _binary_new_copy( def _sympy_operator(self, left, right): """Override :meth:`pybamm.BinaryOperator._sympy_operator`""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") return sympy.Min(left, right) @@ -782,7 +782,7 @@ def _binary_new_copy( def _sympy_operator(self, left, right): """Override :meth:`pybamm.BinaryOperator._sympy_operator`""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") return sympy.Max(left, right) diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 06718c311d..c452daadf1 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -10,7 +10,7 @@ from typing import Sequence import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class Concatenation(pybamm.Symbol): @@ -159,7 +159,7 @@ def is_constant(self): def _sympy_operator(self, *children): """Apply appropriate SymPy operators.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") self.concat_latex = tuple(map(sympy.latex, children)) if self.print_name is not None: diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 72c9d4074a..6094e79dfb 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -9,7 +9,7 @@ from typing_extensions import TypeVar import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class Function(pybamm.Symbol): @@ -97,7 +97,7 @@ def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float): Derivative with respect to child number 'idx'. See :meth:`pybamm.Symbol._diff()`. """ - autograd = have_optional_dependency("autograd") + autograd = import_optional_dependency("autograd") # Store differentiated function, needed in case we want to convert to CasADi if self.derivative == "autograd": return Function( @@ -210,7 +210,7 @@ def _sympy_operator(self, child): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -275,7 +275,7 @@ def _function_new_copy(self, children): def _sympy_operator(self, child): """Apply appropriate SymPy operators.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") class_name = self.__class__.__name__.lower() sympy_function = getattr(sympy, class_name) return sympy_function(child) @@ -332,7 +332,7 @@ def _function_diff(self, children, idx): def _sympy_operator(self, child): """Override :meth:`pybamm.Function._sympy_operator`""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") return sympy.asinh(child) @@ -360,7 +360,7 @@ def _function_diff(self, children, idx): def _sympy_operator(self, child): """Override :meth:`pybamm.Function._sympy_operator`""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") return sympy.atan(child) diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index 0dca6dba46..dddf5c73f4 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -6,7 +6,7 @@ import numpy as np import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType KNOWN_COORD_SYS = ["cartesian", "cylindrical polar", "spherical polar"] @@ -58,7 +58,7 @@ def _jac(self, variable) -> pybamm.Scalar: def to_equation(self) -> sympy.Symbol: """Convert the node and its subtree into a SymPy equation.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -102,7 +102,7 @@ def _evaluate_for_shape(self): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") return sympy.Symbol("t") diff --git a/pybamm/expression_tree/operations/latexify.py b/pybamm/expression_tree/operations/latexify.py index c16ab4b83d..aec7ed9402 100644 --- a/pybamm/expression_tree/operations/latexify.py +++ b/pybamm/expression_tree/operations/latexify.py @@ -9,7 +9,7 @@ import pybamm from pybamm.expression_tree.printing.sympy_overrides import custom_print_func -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency def get_rng_min_max_name(rng, min_or_max): @@ -89,7 +89,7 @@ def _get_bcs_displays(self, var): Returns a list of boundary condition equations with ranges in front of the equations. """ - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") bcs_eqn_list = [] bcs = self.model.boundary_conditions.get(var, None) @@ -120,7 +120,7 @@ def _get_bcs_displays(self, var): def _get_param_var(self, node): """Returns a list of parameters and a list of variables.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") param_list = [] var_list = [] dfs_nodes = [node] @@ -163,7 +163,7 @@ def _get_param_var(self, node): return param_list, var_list def latexify(self, output_variables=None): - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") # Voltage is the default output variable if it exists if output_variables is None: if "Voltage [V]" in self.model.variables: diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index e646ff234d..4b2da9819b 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -11,7 +11,7 @@ import sympy import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class Parameter(pybamm.Symbol): @@ -48,7 +48,7 @@ def is_constant(self) -> Literal[False]: def to_equation(self) -> sympy.Symbol: """Convert the node and its subtree into a SymPy equation.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -244,7 +244,7 @@ def _evaluate_for_shape(self): def to_equation(self) -> sympy.Symbol: """Convert the node and its subtree into a SymPy equation.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index 26bbabfcf0..f41fd233df 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -6,7 +6,7 @@ from typing import Literal import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.type_definitions import Numeric @@ -87,7 +87,7 @@ def is_constant(self) -> Literal[True]: def to_equation(self): """Returns the value returned by the node when evaluated.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 70f4e82db6..62e3620a4c 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Sequence, cast import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.expression_tree.printing.print_name import prettify_print_name if TYPE_CHECKING: # pragma: no cover @@ -478,7 +478,7 @@ def render(self): # pragma: no cover """ Print out a visual representation of the tree (this node and its children) """ - anytree = have_optional_dependency("anytree") + anytree = import_optional_dependency("anytree") for pre, _, node in anytree.RenderTree(self): if isinstance(node, pybamm.Scalar) and node.name != str(node.value): print(f"{pre}{node.name} = {node.value}") @@ -497,7 +497,7 @@ def visualise(self, filename: str): filename to output, must end in ".png" """ - DotExporter = have_optional_dependency("anytree.exporter", "DotExporter") + DotExporter = import_optional_dependency("anytree.exporter", "DotExporter") # check that filename ends in .png. if filename[-4:] != ".png": raise ValueError("filename should end in .png") @@ -517,7 +517,7 @@ def relabel_tree(self, symbol: Symbol, counter: int): Finds all children of a symbol and assigns them a new id so that they can be visualised properly using the graphviz output """ - anytree = have_optional_dependency("anytree") + anytree = import_optional_dependency("anytree") name = symbol.name if name == "div": name = "∇⋅" @@ -560,7 +560,7 @@ def pre_order(self): a b """ - anytree = have_optional_dependency("anytree") + anytree = import_optional_dependency("anytree") return anytree.PreOrderIter(self) def __str__(self): @@ -1056,7 +1056,7 @@ def print_name(self, name): self._print_name = prettify_print_name(name) def to_equation(self): - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") return sympy.Symbol(str(self.name)) def to_json(self): diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 319499a9fc..1818ce630b 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -6,7 +6,7 @@ import numpy as np from scipy.sparse import csr_matrix, issparse import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.type_definitions import DomainsType @@ -108,7 +108,7 @@ def _sympy_operator(self, child): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -450,7 +450,9 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" - sympy_Gradient = have_optional_dependency("sympy.vector.operators", "Gradient") + sympy_Gradient = import_optional_dependency( + "sympy.vector.operators", "Gradient" + ) return sympy_Gradient(child) @@ -484,7 +486,7 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" - sympy_Divergence = have_optional_dependency( + sympy_Divergence = import_optional_dependency( "sympy.vector.operators", "Divergence" ) return sympy_Divergence(child) @@ -672,7 +674,7 @@ def _evaluates_on_edges(self, dimension: str) -> bool: def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") return sympy.Integral(child, sympy.Symbol("xn")) @@ -996,7 +998,7 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") if ( self.child.domain[0] in ["negative particle", "positive particle"] and self.side == "right" diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index eb0d90cdb6..63c4573a5b 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -5,7 +5,7 @@ import numpy as np import numbers import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.type_definitions import ( DomainType, AuxiliaryDomainType, @@ -135,7 +135,7 @@ def _evaluate_for_shape(self): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: diff --git a/pybamm/meshes/scikit_fem_submeshes.py b/pybamm/meshes/scikit_fem_submeshes.py index 82a7bd72f1..e52f58f069 100644 --- a/pybamm/meshes/scikit_fem_submeshes.py +++ b/pybamm/meshes/scikit_fem_submeshes.py @@ -5,7 +5,7 @@ from .meshes import SubMesh import numpy as np -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class ScikitSubMesh2D(SubMesh): @@ -27,7 +27,7 @@ class ScikitSubMesh2D(SubMesh): """ def __init__(self, edges, coord_sys, tabs): - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") self.edges = edges self.nodes = dict.fromkeys(["y", "z"]) for var in self.nodes.keys(): diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index b6b5a9b2da..0266300dea 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -13,7 +13,7 @@ import pybamm from pybamm.expression_tree.operations.serialise import Serialise -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class BaseModel: @@ -1185,7 +1185,7 @@ def latexify(self, filename=None, newline=True, output_variables=None): This will return first five model equations >>> model.latexify(newline=False)[1:5] """ - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") if sympy: from pybamm.expression_tree.operations.latexify import Latexify diff --git a/pybamm/plotting/plot.py b/pybamm/plotting/plot.py index cf5c972a87..4037ab8fbf 100644 --- a/pybamm/plotting/plot.py +++ b/pybamm/plotting/plot.py @@ -3,7 +3,7 @@ # import pybamm from .quick_plot import ax_min, ax_max -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency def plot(x, y, ax=None, show_plot=True, **kwargs): @@ -27,7 +27,7 @@ def plot(x, y, ax=None, show_plot=True, **kwargs): Keyword arguments, passed to plt.plot """ - plt = have_optional_dependency("matplotlib.pyplot") + plt = import_optional_dependency("matplotlib.pyplot") if not isinstance(x, pybamm.Array): raise TypeError("x must be 'pybamm.Array'") diff --git a/pybamm/plotting/plot2D.py b/pybamm/plotting/plot2D.py index a37cd1e2ed..7d1f3c6bae 100644 --- a/pybamm/plotting/plot2D.py +++ b/pybamm/plotting/plot2D.py @@ -3,7 +3,7 @@ # import pybamm from .quick_plot import ax_min, ax_max -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency def plot2D(x, y, z, ax=None, show_plot=True, **kwargs): @@ -27,7 +27,7 @@ def plot2D(x, y, z, ax=None, show_plot=True, **kwargs): only display the plot after plt.show() has been called. """ - plt = have_optional_dependency("matplotlib.pyplot") + plt = import_optional_dependency("matplotlib.pyplot") if not isinstance(x, pybamm.Array): raise TypeError("x must be 'pybamm.Array'") diff --git a/pybamm/plotting/plot_summary_variables.py b/pybamm/plotting/plot_summary_variables.py index 33642c4d5a..bd4db0ee6c 100644 --- a/pybamm/plotting/plot_summary_variables.py +++ b/pybamm/plotting/plot_summary_variables.py @@ -3,7 +3,7 @@ # import numpy as np import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency def plot_summary_variables( @@ -27,7 +27,7 @@ def plot_summary_variables( Keyword arguments, passed to plt.subplots. """ - plt = have_optional_dependency("matplotlib.pyplot") + plt = import_optional_dependency("matplotlib.pyplot") if isinstance(solutions, pybamm.Solution): solutions = [solutions] diff --git a/pybamm/plotting/plot_voltage_components.py b/pybamm/plotting/plot_voltage_components.py index 3b155b71de..0d1bb7b573 100644 --- a/pybamm/plotting/plot_voltage_components.py +++ b/pybamm/plotting/plot_voltage_components.py @@ -3,7 +3,7 @@ # import numpy as np -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.simulation import Simulation from pybamm.solvers.solution import Solution @@ -42,7 +42,7 @@ def plot_voltage_components( solution = input_data.solution elif isinstance(input_data, Solution): solution = input_data - plt = have_optional_dependency("matplotlib.pyplot") + plt = import_optional_dependency("matplotlib.pyplot") # Set a default value for alpha, the opacity kwargs_fill = {"alpha": 0.6, **kwargs_fill} diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py index c6d3054abb..dbb8f37555 100644 --- a/pybamm/plotting/quick_plot.py +++ b/pybamm/plotting/quick_plot.py @@ -5,7 +5,7 @@ import numpy as np import pybamm from collections import defaultdict -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class LoopList(list): @@ -46,7 +46,7 @@ def split_long_string(title, max_words=None): def close_plots(): """Close all open figures""" - plt = have_optional_dependency("matplotlib.pyplot") + plt = import_optional_dependency("matplotlib.pyplot") plt.close("all") @@ -473,10 +473,10 @@ def plot(self, t, dynamic=False): Dimensional time (in 'time_units') at which to plot. """ - plt = have_optional_dependency("matplotlib.pyplot") - gridspec = have_optional_dependency("matplotlib.gridspec") - cm = have_optional_dependency("matplotlib", "cm") - colors = have_optional_dependency("matplotlib", "colors") + plt = import_optional_dependency("matplotlib.pyplot") + gridspec = import_optional_dependency("matplotlib.gridspec") + cm = import_optional_dependency("matplotlib", "cm") + colors = import_optional_dependency("matplotlib", "colors") t_in_seconds = t * self.time_scaling_factor self.fig = plt.figure(figsize=self.figsize) @@ -674,8 +674,8 @@ def dynamic_plot(self, show_plot=True, step=None): continuous_update=False, ) else: - plt = have_optional_dependency("matplotlib.pyplot") - Slider = have_optional_dependency("matplotlib.widgets", "Slider") + plt = import_optional_dependency("matplotlib.pyplot") + Slider = import_optional_dependency("matplotlib.widgets", "Slider") # create an initial plot at time self.min_t self.plot(self.min_t, dynamic=True) @@ -779,8 +779,8 @@ def create_gif(self, number_of_images=80, duration=0.1, output_filename="plot.gi Name of the generated GIF file. """ - imageio = have_optional_dependency("imageio.v2") - plt = have_optional_dependency("matplotlib.pyplot") + imageio = import_optional_dependency("imageio.v2") + plt = import_optional_dependency("matplotlib.pyplot") # time stamps at which the images/plots will be created time_array = np.linspace(self.min_t, self.max_t, num=number_of_images) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index c1d62a47e9..aa58022b6d 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -11,7 +11,7 @@ import sys from functools import lru_cache from datetime import timedelta -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from pybamm.expression_tree.operations.serialise import Serialise @@ -699,7 +699,7 @@ def solve( # check if a user has tqdm installed if showprogress: - tqdm = have_optional_dependency("tqdm") + tqdm = import_optional_dependency("tqdm") cycle_lengths = tqdm.tqdm( self.experiment.cycle_lengths, desc="Cycling", diff --git a/pybamm/spatial_methods/scikit_finite_element.py b/pybamm/spatial_methods/scikit_finite_element.py index 07a3c0e1be..e65e29f7f8 100644 --- a/pybamm/spatial_methods/scikit_finite_element.py +++ b/pybamm/spatial_methods/scikit_finite_element.py @@ -7,7 +7,7 @@ from scipy.sparse.linalg import inv import numpy as np -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class ScikitFiniteElement(pybamm.SpatialMethod): @@ -88,7 +88,7 @@ def gradient(self, symbol, discretised_symbol, boundary_conditions): to the y-component of the gradient and the second column corresponds to the z component of the gradient. """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") domain = symbol.domain[0] mesh = self.mesh[domain] @@ -144,7 +144,7 @@ def gradient_matrix(self, symbol, boundary_conditions): :class:`pybamm.Matrix` The (sparse) finite element gradient matrix for the domain """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] @@ -190,7 +190,7 @@ def laplacian(self, symbol, discretised_symbol, boundary_conditions): Contains the result of acting the discretised gradient on the child discretised_symbol """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") domain = symbol.domain[0] mesh = self.mesh[domain] @@ -258,7 +258,7 @@ def stiffness_matrix(self, symbol, boundary_conditions): :class:`pybamm.Matrix` The (sparse) finite element stiffness matrix for the domain """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] @@ -321,7 +321,7 @@ def definite_integral_matrix(self, child, vector_type="row"): :class:`pybamm.Matrix` The finite element integral vector for the domain """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") # get primary domain mesh domain = child.domain[0] mesh = self.mesh[domain] @@ -383,7 +383,7 @@ def boundary_integral_vector(self, domain, region): :class:`pybamm.Matrix` The finite element integral vector for the domain """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") # get primary domain mesh mesh = self.mesh[domain[0]] @@ -501,7 +501,7 @@ def assemble_mass_form(self, symbol, boundary_conditions, region="interior"): :class:`pybamm.Matrix` The (sparse) mass matrix for the spatial method. """ - skfem = have_optional_dependency("skfem") + skfem = import_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] diff --git a/pybamm/util.py b/pybamm/util.py index 1149327cf7..275dec7c79 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -352,25 +352,19 @@ def install_jax(arguments=None): # pragma: no cover ) -# https://docs.pybamm.org/en/latest/source/user_guide/contributing.html#managing-optional-dependencies-and-their-imports -def have_optional_dependency(module_name, attribute=None): +def import_optional_dependency(module_name, attribute=None): err_msg = f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details." try: - # Attempt to import the specified module module = importlib.import_module(module_name) - if attribute: - # If an attribute is specified, check if it's available if hasattr(module, attribute): imported_attribute = getattr(module, attribute) - return imported_attribute # Return the imported attribute + # Return the imported attribute + return imported_attribute else: - # Raise an ModuleNotFoundError if the attribute is not available raise ModuleNotFoundError(err_msg) # pragma: no cover else: # Return the entire module if no attribute is specified return module - except ModuleNotFoundError: - # Raise an ModuleNotFoundError if the module or attribute is not available raise ModuleNotFoundError(err_msg) diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index b6cbe093eb..eab0fd8d45 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -9,7 +9,7 @@ from scipy.sparse import coo_matrix import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency EMPTY_DOMAINS = { "primary": [], @@ -790,7 +790,7 @@ def test_inner_simplifications(self): self.assertEqual(pybamm.inner(a3, a3).evaluate(), 9) def test_to_equation(self): - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") # Test print_name pybamm.Addition.print_name = "test" self.assertEqual(pybamm.Addition(1, 2).to_equation(), sympy.Symbol("test")) diff --git a/tests/unit/test_expression_tree/test_concatenations.py b/tests/unit/test_expression_tree/test_concatenations.py index 691b6a7ee2..28b89b6e28 100644 --- a/tests/unit/test_expression_tree/test_concatenations.py +++ b/tests/unit/test_expression_tree/test_concatenations.py @@ -8,7 +8,7 @@ import numpy as np import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency from tests import get_discretisation_for_testing, get_mesh_for_testing @@ -371,7 +371,7 @@ def test_numpy_concatenation(self): ) def test_to_equation(self): - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") a = pybamm.Symbol("a", domain="test a") b = pybamm.Symbol("b", domain="test b") func_symbol = sympy.Symbol(r"\begin{cases}a\\b\end{cases}") diff --git a/tests/unit/test_expression_tree/test_functions.py b/tests/unit/test_expression_tree/test_functions.py index 33e11459ab..2c8f969984 100644 --- a/tests/unit/test_expression_tree/test_functions.py +++ b/tests/unit/test_expression_tree/test_functions.py @@ -9,7 +9,7 @@ from scipy import special import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency def test_function(arg): @@ -121,7 +121,7 @@ def test_function_unnamed(self): self.assertEqual(fun.name, "function (cos)") def test_to_equation(self): - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") a = pybamm.Symbol("a", domain="test") # Test print_name diff --git a/tests/unit/test_expression_tree/test_independent_variable.py b/tests/unit/test_expression_tree/test_independent_variable.py index b748a6fbe9..377bb48c13 100644 --- a/tests/unit/test_expression_tree/test_independent_variable.py +++ b/tests/unit/test_expression_tree/test_independent_variable.py @@ -6,7 +6,7 @@ import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class TestIndependentVariable(TestCase): @@ -64,7 +64,7 @@ def test_spatial_variable_edge(self): self.assertTrue(x.evaluates_on_edges("primary")) def test_to_equation(self): - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") # Test print_name func = pybamm.IndependentVariable("a") func.print_name = "test" diff --git a/tests/unit/test_expression_tree/test_parameter.py b/tests/unit/test_expression_tree/test_parameter.py index 6940ac38fe..27f80210c1 100644 --- a/tests/unit/test_expression_tree/test_parameter.py +++ b/tests/unit/test_expression_tree/test_parameter.py @@ -6,7 +6,7 @@ import unittest import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class TestParameter(TestCase): @@ -20,7 +20,7 @@ def test_evaluate_for_shape(self): self.assertIsInstance(a.evaluate_for_shape(), numbers.Number) def test_to_equation(self): - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") func = pybamm.Parameter("test_string") func1 = pybamm.Parameter("test_name") @@ -107,7 +107,7 @@ def _myfun(x): self.assertEqual(_myfun(x).print_name, None) def test_function_parameter_to_equation(self): - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") func = pybamm.FunctionParameter("test", {"x": pybamm.Scalar(1)}) func1 = pybamm.FunctionParameter("func", {"var": pybamm.Variable("var")}) diff --git a/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py b/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py index de3ff08c43..eef5016d1e 100644 --- a/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py +++ b/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py @@ -6,12 +6,12 @@ import pybamm from pybamm.expression_tree.printing.sympy_overrides import custom_print_func -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class TestCustomPrint(TestCase): def test_print_Derivative(self): - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") # Test force_partial der1 = sympy.Derivative("y", "x") der1.force_partial = True diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 9a7939c66d..71e5b205bf 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -12,7 +12,7 @@ import pybamm from pybamm.expression_tree.binary_operators import _Heaviside -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class TestSymbol(TestCase): @@ -485,7 +485,7 @@ def test_test_shape(self): (y1 + y2).test_shape() def test_to_equation(self): - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") self.assertEqual(pybamm.Symbol("test").to_equation(), sympy.Symbol("test")) def test_numpy_array_ufunc(self): diff --git a/tests/unit/test_expression_tree/test_unary_operators.py b/tests/unit/test_expression_tree/test_unary_operators.py index 6ae6b62d05..d03a32ab32 100644 --- a/tests/unit/test_expression_tree/test_unary_operators.py +++ b/tests/unit/test_expression_tree/test_unary_operators.py @@ -9,7 +9,7 @@ from scipy.sparse import diags import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class TestUnaryOperators(TestCase): @@ -678,11 +678,13 @@ def test_not_constant(self): self.assertFalse((2 * a).is_constant()) def test_to_equation(self): - sympy = have_optional_dependency("sympy") - sympy_Divergence = have_optional_dependency( + sympy = import_optional_dependency("sympy") + sympy_Divergence = import_optional_dependency( "sympy.vector.operators", "Divergence" ) - sympy_Gradient = have_optional_dependency("sympy.vector.operators", "Gradient") + sympy_Gradient = import_optional_dependency( + "sympy.vector.operators", "Gradient" + ) a = pybamm.Symbol("a", domain="negative particle") b = pybamm.Symbol("b", domain="current collector") diff --git a/tests/unit/test_expression_tree/test_variable.py b/tests/unit/test_expression_tree/test_variable.py index 0d5aa251d2..cb5178a377 100644 --- a/tests/unit/test_expression_tree/test_variable.py +++ b/tests/unit/test_expression_tree/test_variable.py @@ -7,7 +7,7 @@ import numpy as np import pybamm -from pybamm.util import have_optional_dependency +from pybamm.util import import_optional_dependency class TestVariable(TestCase): @@ -55,7 +55,7 @@ def test_variable_bounds(self): pybamm.Variable("var", bounds=(1, 1)) def test_to_equation(self): - sympy = have_optional_dependency("sympy") + sympy = import_optional_dependency("sympy") # Test print_name func = pybamm.Variable("test_string") func.print_name = "test" diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 24f204b6df..6fe43c096d 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -99,7 +99,7 @@ def test_git_commit_info(self): self.assertIsInstance(git_commit_info, str) self.assertEqual(git_commit_info[:2], "v2") - def test_have_optional_dependency(self): + def test_import_optional_dependency(self): with self.assertRaisesRegex( ModuleNotFoundError, "Optional dependency pybtex is not available." ): @@ -119,7 +119,7 @@ def test_have_optional_dependency(self): sym.visualise(test_name) sys.modules["pybtex"] = pybtex - pybamm.util.have_optional_dependency("pybtex") + pybamm.util.import_optional_dependency("pybtex") pybamm.print_citations() From f4bb55cf11edb5ab395210b62612ad7d7c3d17cc Mon Sep 17 00:00:00 2001 From: kratman Date: Tue, 5 Mar 2024 09:46:25 -0500 Subject: [PATCH 2/5] Change log --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cade77684..8bf54b563e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ ## Breaking changes +- Renamed "have_optional_dependency" to "import_optional_dependency" ([#3866](https://github.com/pybamm-team/PyBaMM/pull/3866)) - Renamed "testing" argument for plots to "show_plot" and flipped its meaning (show_plot=True is now the default and shows the plot) ([#3842](https://github.com/pybamm-team/PyBaMM/pull/3842)) - Dropped support for BPX version 0.3.0 and below ([#3414](https://github.com/pybamm-team/PyBaMM/pull/3414)) From 4d78a0841551d990e8350f78b5496db1d0f9d670 Mon Sep 17 00:00:00 2001 From: kratman Date: Tue, 5 Mar 2024 10:21:38 -0500 Subject: [PATCH 3/5] Fix import --- pybamm/models/base_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 194b8938af..28530df90b 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -1185,8 +1185,7 @@ def latexify(self, filename=None, newline=True, output_variables=None): This will return first five model equations >>> model.latexify(newline=False)[1:5] """ - if sympy: - from pybamm.expression_tree.operations.latexify import Latexify + from pybamm.expression_tree.operations.latexify import Latexify return Latexify(self, filename, newline).latexify( output_variables=output_variables From 8ab78ed885b3c9804c1f0db9beaf28d0418e762c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Mar 2024 15:26:09 +0000 Subject: [PATCH 4/5] style: pre-commit fixes --- pybamm/models/base_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 28530df90b..007897370f 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -13,7 +13,6 @@ import pybamm from pybamm.expression_tree.operations.serialise import Serialise -import sympy class BaseModel: From cb32f632a713c4a99cb46aae1b76e3c5e23fe475 Mon Sep 17 00:00:00 2001 From: Arjun Verma Date: Tue, 5 Mar 2024 21:28:39 +0530 Subject: [PATCH 5/5] Update pybamm/util.py --- pybamm/util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pybamm/util.py b/pybamm/util.py index 275dec7c79..0722b8174b 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -352,6 +352,7 @@ def install_jax(arguments=None): # pragma: no cover ) +# https://docs.pybamm.org/en/latest/source/user_guide/contributing.html#managing-optional-dependencies-and-their-imports def import_optional_dependency(module_name, attribute=None): err_msg = f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details." try: