diff --git a/.gitignore b/.gitignore
index db4561ea..ce3d241b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -52,3 +52,8 @@ docs/_build/
# PyBuilder
target/
+
+# Plots generated when running sample code
+/*.png
+/*.svg
+/*.pdf
\ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
index d8657a8f..025017a7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -2,19 +2,28 @@ language: python
python:
- "2.7"
- - "3.4"
- "3.5"
+ - "3.6"
+ - "3.7"
+
+addons:
+ apt:
+ packages:
+ - graphviz
+
install:
- pip install Sphinx sphinx_rtd_theme codecov packaging
- "python -c $'import os, packaging.version as version\\nv = version.parse(os.environ.get(\"TRAVIS_TAG\", \"1.0\")).public\\nwith open(\"VERSION\", \"w\") as f: f.write(v)'"
- - python setup.py install
+ - pip install -e .[test]
- cd docs
- make clean html
- cd ..
script:
- - python setup.py nosetests --with-coverage --cover-package=graphkit
+ - pytest -v --cov=graphkit
+ # In case you adopt -m 'not slow' in setup.cfg.
+ #- pytest -vm slow --cov=graphkit
deploy:
provider: pypi
diff --git a/README.md b/README.md
index 0e1e95a4..81b1657d 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,8 @@
> It's a DAG all the way down
+
+
## Lightweight computation graphs for Python
GraphKit is a lightweight Python module for creating and running ordered graphs of computations, where the nodes of the graph correspond to computational operations, and the edges correspond to output --> input dependencies between those operations. Such graphs are useful in computer vision, machine learning, and many other domains.
@@ -14,9 +16,12 @@ GraphKit is a lightweight Python module for creating and running ordered graphs
Here's how to install:
-```
-pip install graphkit
-```
+ pip install graphkit
+
+OR with dependencies for plotting support (you also need to install the
+[`Graphviz`](https://graphviz.org) program separately with your OS tools):
+
+ pip install graphkit[plot]
Here's a Python script with an example GraphKit computation graph that produces multiple outputs (`a * b`, `a - a * b`, and `abs(a - a * b) ** 3`):
@@ -30,20 +35,20 @@ def abspow(a, p):
return c
# Compose the mul, sub, and abspow operations into a computation graph.
-graph = compose(name="graph")(
+graphop = compose(name="graphop")(
operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul),
operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub),
operation(name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"], params={"p": 3})(abspow)
)
# Run the graph and request all of the outputs.
-out = graph({'a': 2, 'b': 5})
+out = graphop({'a': 2, 'b': 5})
# Prints "{'a': 2, 'a_minus_ab': -8, 'b': 5, 'ab': 10, 'abs_a_minus_ab_cubed': 512}".
print(out)
# Run the graph and request a subset of the outputs.
-out = graph({'a': 2, 'b': 5}, outputs=["a_minus_ab"])
+out = graphop({'a': 2, 'b': 5}, outputs=["a_minus_ab"])
# Prints "{'a_minus_ab': -8}".
print(out)
@@ -51,6 +56,21 @@ print(out)
As you can see, any function can be used as an operation in GraphKit, even ones imported from system modules!
+
+## Plotting
+
+For debugging the above graph-operation you may plot it using these methods:
+
+```python
+graphop.plot(show=True, solution=out)  # open a matplotlib window with solution values in nodes
+graphop.plot("intro.svg")              # other supported formats: png, jpg, pdf, ...
+```
+
+
+
+
+> **TIP:** The `pydot.Dot` instances returned by `plot()` are rendered as SVG in *Jupyter/IPython*.
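+
+You may also render a legend of all graphkit diagram shapes (a sketch,
+assuming `graphkit.plot.legend()` accepts the same `filename`/`show`
+arguments as `plot()`):
+
+```python
+from graphkit import plot
+plot.legend(show=True)   # opens a matplotlib window with the legend
+```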
+
# License
Code licensed under the Apache License, Version 2.0 license. See LICENSE file for terms.
diff --git a/docs/source/graph_composition.rst b/docs/source/graph_composition.rst
index ba428a14..1d8e9f6d 100644
--- a/docs/source/graph_composition.rst
+++ b/docs/source/graph_composition.rst
@@ -30,15 +30,15 @@ The simplest use case for ``compose`` is assembling a collection of individual o
return c
# Compose the mul, sub, and abspow operations into a computation graph.
- graph = compose(name="graph")(
+ graphop = compose(name="graphop")(
operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul),
operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub),
operation(name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"], params={"p": 3})(abspow)
)
-The call here to ``compose()()`` yields a runnable computation graph that looks like this (where the circles are operations, squares are data, and octagons are parameters):
+The call here to ``compose()`` yields a runnable computation graph that looks like this (where the circles are operations, squares are data, and octagons are parameters):
-.. image:: images/example_graph.svg
+.. image:: images/intro.svg
.. _graph-computations:
@@ -49,7 +49,7 @@ Running a computation graph
The graph composed in the example above in :ref:`simple-graph-composition` can be run by simply calling it with a dictionary argument whose keys correspond to the names of inputs to the graph and whose values are the corresponding input values. For example, if ``graph`` is as defined above, we can run it like this::
# Run the graph and request all of the outputs.
- out = graph({'a': 2, 'b': 5})
+ out = graphop({'a': 2, 'b': 5})
# Prints "{'a': 2, 'a_minus_ab': -8, 'b': 5, 'ab': 10, 'abs_a_minus_ab_cubed': 512}".
print(out)
@@ -57,10 +57,10 @@ The graph composed in the example above in :ref:`simple-graph-composition` can b
Producing a subset of outputs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-By default, calling a graph on a set of inputs will yield all of that graph's outputs. You can use the ``outputs`` parameter to request only a subset. For example, if ``graph`` is as above::
+By default, calling a graph-operation on a set of inputs will yield all of that graph's outputs. You can use the ``outputs`` parameter to request only a subset. For example, if ``graphop`` is as above::
- # Run the graph and request a subset of the outputs.
- out = graph({'a': 2, 'b': 5}, outputs=["a_minus_ab"])
+ # Run the graph-operation and request a subset of the outputs.
+ out = graphop({'a': 2, 'b': 5}, outputs=["a_minus_ab"])
# Prints "{'a_minus_ab': -8}".
print(out)
@@ -70,17 +70,17 @@ When using ``outputs`` to request only a subset of a graph's outputs, GraphKit e
Short-circuiting a graph computation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-You can short-circuit a graph computation, making certain inputs unnecessary, by providing a value in the graph that is further downstream in the graph than those inputs. For example, in the graph we've been working with, you could provide the value of ``a_minus_ab`` to make the inputs ``a`` and ``b`` unnecessary::
+You can short-circuit a graph computation, making certain inputs unnecessary, by providing a value in the graph that is further downstream in the graph than those inputs. For example, in the graph-operation we've been working with, you could provide the value of ``a_minus_ab`` to make the inputs ``a`` and ``b`` unnecessary::
- # Run the graph and request a subset of the outputs.
- out = graph({'a_minus_ab': -8})
+ # Run the graph-operation and request a subset of the outputs.
+ out = graphop({'a_minus_ab': -8})
# Prints "{'a_minus_ab': -8, 'abs_a_minus_ab_cubed': 512}".
print(out)
When you do this, any ``operation`` nodes that are not on a path from the downstream input to the requested outputs (i.e. predecessors of the downstream input) are not computed. For example, the ``mul1`` and ``sub1`` operations are not executed here.
-This can be useful if you have a graph that accepts alternative forms of the same input. For example, if your graph requires a ``PIL.Image`` as input, you could allow your graph to be run in an API server by adding an earlier ``operation`` that accepts as input a string of raw image data and converts that data into the needed ``PIL.Image``. Then, you can either provide the raw image data string as input, or you can provide the ``PIL.Image`` if you have it and skip providing the image data string.
+This can be useful if you have a graph-operation that accepts alternative forms of the same input. For example, if your graph-operation requires a ``PIL.Image`` as input, you could allow your graph to be run in an API server by adding an earlier ``operation`` that accepts as input a string of raw image data and converts that data into the needed ``PIL.Image``. Then, you can either provide the raw image data string as input, or you can provide the ``PIL.Image`` if you have it and skip providing the image data string.
Adding on to an existing computation graph
------------------------------------------
@@ -109,7 +109,7 @@ Sometimes you will have two computation graphs---perhaps ones that share operati
combined_graph = compose(name="combined_graph")(graph1, graph2)
-However, if you want to combine graphs that share operations and don't want to pay the price of running redundant computations, you can set the ``merge`` parameter of ``compose()`` to ``True``. This will consolidate redundant ``operation`` nodes (based on ``name``) into a single node. For example, let's say we have ``graph``, as in the examples above, along with this graph::
+However, if you want to combine graphs that share operations and don't want to pay the price of running redundant computations, you can set the ``merge`` parameter of ``compose()`` to ``True``. This will consolidate redundant ``operation`` nodes (based on ``name``) into a single node. For example, let's say we have ``graphop``, as in the examples above, along with this graph::
# This graph shares the "mul1" operation with graph.
another_graph = compose(name="another_graph")(
@@ -117,9 +117,9 @@ However, if you want to combine graphs that share operations and don't want to p
operation(name="mul2", needs=["c", "ab"], provides=["cab"])(mul)
)
-We can merge ``graph`` and ``another_graph`` like so, avoiding a redundant ``mul1`` operation::
+We can merge ``graphop`` and ``another_graph`` like so, avoiding a redundant ``mul1`` operation::
- merged_graph = compose(name="merged_graph", merge=True)(graph, another_graph)
+ merged_graph = compose(name="merged_graph", merge=True)(graphop, another_graph)
This ``merged_graph`` will look like this:
diff --git a/docs/source/images/GraphkitLegend.svg b/docs/source/images/GraphkitLegend.svg
new file mode 100644
index 00000000..798ba3a5
--- /dev/null
+++ b/docs/source/images/GraphkitLegend.svg
@@ -0,0 +1,150 @@
+
+
+
+
+
diff --git a/docs/source/images/intro.svg b/docs/source/images/intro.svg
new file mode 100644
index 00000000..4469543f
--- /dev/null
+++ b/docs/source/images/intro.svg
@@ -0,0 +1,143 @@
+
+
+
+
+
diff --git a/docs/source/images/test_pruning_not_overrides_given_intermediate-asked.png b/docs/source/images/test_pruning_not_overrides_given_intermediate-asked.png
new file mode 100644
index 00000000..c8e6cdb4
Binary files /dev/null and b/docs/source/images/test_pruning_not_overrides_given_intermediate-asked.png differ
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5c5e505c..f542da58 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -38,6 +38,12 @@ Here's how to install::
pip install graphkit
+OR with dependencies for plotting support (you also need to install the
+`Graphviz <https://graphviz.org>`_ program separately with your OS tools)::
+
+ pip install graphkit[plot]
+
+
Here's a Python script with an example GraphKit computation graph that produces multiple outputs (``a * b``, ``a - a * b``, and ``abs(a - a * b) ** 3``)::
from operator import mul, sub
@@ -49,26 +55,54 @@ Here's a Python script with an example GraphKit computation graph that produces
return c
# Compose the mul, sub, and abspow operations into a computation graph.
- graph = compose(name="graph")(
+ graphop = compose(name="graphop")(
operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul),
operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub),
operation(name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"], params={"p": 3})(abspow)
)
- # Run the graph and request all of the outputs.
- out = graph({'a': 2, 'b': 5})
+ # Run the graph-operation and request all of the outputs.
+ out = graphop({'a': 2, 'b': 5})
# Prints "{'a': 2, 'a_minus_ab': -8, 'b': 5, 'ab': 10, 'abs_a_minus_ab_cubed': 512}".
print(out)
- # Run the graph and request a subset of the outputs.
- out = graph({'a': 2, 'b': 5}, outputs=["a_minus_ab"])
+ # Run the graph-operation and request a subset of the outputs.
+ out = graphop({'a': 2, 'b': 5}, outputs=["a_minus_ab"])
# Prints "{'a_minus_ab': -8}".
print(out)
As you can see, any function can be used as an operation in GraphKit, even ones imported from system modules!
+
+Plotting
+--------
+
+For debugging the above graph-operation you may plot it using these methods::
+
+ graphop.plot(show=True, solution=out) # open a matplotlib window with solution values in nodes
+ graphop.plot("intro.svg") # other supported formats: png, jpg, pdf, ...
+
+.. image:: images/intro.svg
+ :alt: Intro graph
+
+.. figure:: images/GraphkitLegend.svg
+ :alt: Graphkit Legend
+
+ The legend for all graphkit diagrams, generated by :func:`graphkit.plot.legend()`.
+
+.. Tip::
+ The ``pydot.Dot`` instances returned by ``plot()`` are rendered
+ directly in *Jupyter/IPython* notebooks as SVG images.
+
+.. NOTE::
+ For plots, the `Graphviz <https://graphviz.org>`_ program must be in your PATH,
+ and ``pydot`` & ``matplotlib`` python packages installed.
+ You may install both when installing ``graphkit`` with its ``plot`` extras::
+
+ pip install graphkit[plot]
+
License
-------
diff --git a/docs/source/operations.rst b/docs/source/operations.rst
index b7b4dbad..fbd6dea2 100644
--- a/docs/source/operations.rst
+++ b/docs/source/operations.rst
@@ -51,15 +51,15 @@ Let's look again at the operations from the script in :ref:`quick-start`, for ex
return c
# Compose the mul, sub, and abspow operations into a computation graph.
- graph = compose(name="graph")(
+ graphop = compose(name="graphop")(
operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul),
operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub),
operation(name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"], params={"p": 3})(abspow)
)
-The ``needs`` and ``provides`` arguments to the operations in this script define a computation graph that looks like this (where the circles are operations, squares are data, and octagons are parameters):
+The ``needs`` and ``provides`` arguments to the operations in this script define a computation graph that looks like this (where the ovals are operations and the squares/houses are data):
-.. image:: images/example_graph.svg
+.. image:: images/intro.svg
Constant operation parameters: ``params``
@@ -86,7 +86,7 @@ If you are defining your computation graph and the functions that comprise it al
def foo(a, b, c):
return c * (a + b)
- graph = compose(name='foo_graph')(foo)
+ graphop = compose(name='foo_graph')(foo)
Functional specification
^^^^^^^^^^^^^^^^^^^^^^^^
@@ -99,7 +99,7 @@ If the functions underlying your computation graph operations are defined elsewh
add_op = operation(name='add_op', needs=['a', 'b'], provides='sum')(add)
mul_op = operation(name='mul_op', needs=['c', 'sum'], provides='product')(mul)
- graph = compose(name='add_mul_graph')(add_op, mul_op)
+ graphop = compose(name='add_mul_graph')(add_op, mul_op)
The functional specification is also useful if you want to create multiple ``operation`` instances from the same function, perhaps with different parameter values, e.g.::
@@ -111,7 +111,7 @@ The functional specification is also useful if you want to create multiple ``ope
pow_op1 = operation(name='pow_op1', needs=['a'], provides='a_squared')(mypow)
pow_op2 = operation(name='pow_op2', needs=['a'], params={'p': 3}, provides='a_cubed')(mypow)
- graph = compose(name='two_pows_graph')(pow_op1, pow_op2)
+ graphop = compose(name='two_pows_graph')(pow_op1, pow_op2)
A slightly different approach can be used here to accomplish the same effect by creating an operation "factory"::
@@ -125,7 +125,7 @@ A slightly different approach can be used here to accomplish the same effect by
pow_op1 = pow_op_factory(name='pow_op1', needs=['a'], provides='a_squared')
pow_op2 = pow_op_factory(name='pow_op2', needs=['a'], params={'p': 3}, provides='a_cubed')
- graph = compose(name='two_pows_graph')(pow_op1, pow_op2)
+ graphop = compose(name='two_pows_graph')(pow_op1, pow_op2)
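+
+For example (a sketch, assuming ``mypow`` above defaults to ``p=2``), running
+either composition::
+
+ # Prints "{'a': 2, 'a_squared': 4, 'a_cubed': 8}".
+ print(graphop({'a': 2}))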
Modifiers on ``operation`` inputs and outputs
diff --git a/graphkit/base.py b/graphkit/base.py
index 1c04e8d5..5bf35d7e 100644
--- a/graphkit/base.py
+++ b/graphkit/base.py
@@ -1,5 +1,12 @@
# Copyright 2016, Yahoo Inc.
# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms.
+try:
+ from collections import abc
+except ImportError:
+ import collections as abc
+
+from . import plot
+
class Data(object):
"""
@@ -144,35 +151,64 @@ def __repr__(self):
self.provides)
-class NetworkOperation(Operation):
+class NetworkOperation(Operation, plot.Plotter):
def __init__(self, **kwargs):
self.net = kwargs.pop('net')
Operation.__init__(self, **kwargs)
# set execution mode to single-threaded sequential by default
self._execution_method = "sequential"
+ self._overwrites_collector = None
+
+ def _build_pydot(self, **kws):
+ """delegate to network"""
+ kws.setdefault("title", self.name)
+ plotter = self.net.last_plan or self.net
+ return plotter._build_pydot(**kws)
def _compute(self, named_inputs, outputs=None):
- return self.net.compute(outputs, named_inputs, method=self._execution_method)
+ return self.net.compute(
+ named_inputs, outputs, method=self._execution_method,
+ overwrites_collector=self._overwrites_collector)
def __call__(self, *args, **kwargs):
return self._compute(*args, **kwargs)
+ def compile(self, *args, **kwargs):
+ return self.net.compile(*args, **kwargs)
+
def set_execution_method(self, method):
"""
Determine how the network will be executed.
- Args:
- method: str
- If "parallel", execute graph operations concurrently
- using a threadpool.
+ :param str method:
+ If "parallel", execute graph operations concurrently
+ using a threadpool.
"""
- options = ['parallel', 'sequential']
- assert method in options
+ choices = ['parallel', 'sequential']
+ if method not in choices:
+ raise ValueError(
+ "Invalid computation method %r! Must be one of %s"
+ % (method, choices))
self._execution_method = method
- def plot(self, filename=None, show=False):
- self.net.plot(filename=filename, show=show)
+ def set_overwrites_collector(self, collector):
+ """
+ Asks to put all *overwrites* into the `collector` after computing
+
+ An "overwrites" is intermediate value calculated but NOT stored
+ into the results, becaues it has been given also as an intemediate
+ input value, and the operation that would overwrite it MUST run for
+ its other results.
+
+ :param collector:
+ a mutable dict to be filled with named values
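+
+ Example (a sketch; assumes an operation providing both "ab" and
+ another output, with "ab" also given as input, so the freshly
+ computed "ab" gets pinned)::
+
+ overwrites = {}
+ graphop.set_overwrites_collector(overwrites)
+ graphop({'a': 2, 'b': 5, 'ab': 1})
+ # `overwrites` now holds the calculated-but-discarded "ab" value;
+ # the results keep the given ``ab == 1``.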
+ """
+ if collector is not None and not isinstance(collector, abc.MutableMapping):
+ raise ValueError(
+ "Overwrites collector was not a MutableMapping, but: %r"
+ % collector)
+ self._overwrites_collector = collector
def __getstate__(self):
state = Operation.__getstate__(self)
diff --git a/graphkit/functional.py b/graphkit/functional.py
index 65388973..5b3735fe 100644
--- a/graphkit/functional.py
+++ b/graphkit/functional.py
@@ -1,7 +1,7 @@
# Copyright 2016, Yahoo Inc.
# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms.
-
-from itertools import chain
+from boltons.setutils import IndexedSet as iset
+import networkx as nx
from .base import Operation, NetworkOperation
from .network import Network
@@ -28,7 +28,7 @@ def _compute(self, named_inputs, outputs=None):
result = zip(self.provides, result)
if outputs:
- outputs = set(outputs)
+ outputs = sorted(set(outputs))
result = filter(lambda x: x[0] in outputs, result)
return dict(result)
@@ -185,27 +185,23 @@ def __call__(self, *operations):
# If merge is desired, deduplicate operations before building network
if self.merge:
- merge_set = set()
+ merge_set = iset() # Preserve given node order.
for op in operations:
if isinstance(op, NetworkOperation):
- net_ops = filter(lambda x: isinstance(x, Operation), op.net.steps)
- merge_set.update(net_ops)
+ netop_nodes = nx.topological_sort(op.net.graph)
+ merge_set.update(s for s in netop_nodes if isinstance(s, Operation))
else:
merge_set.add(op)
- operations = list(merge_set)
-
- def order_preserving_uniquifier(seq, seen=None):
- seen = seen if seen else set()
- seen_add = seen.add
- return [x for x in seq if not (x in seen or seen_add(x))]
+ operations = merge_set
- provides = order_preserving_uniquifier(chain(*[op.provides for op in operations]))
- needs = order_preserving_uniquifier(chain(*[op.needs for op in operations]), set(provides))
+ provides = iset(p for op in operations for p in op.provides)
+ # Mark them all as optional, now that #18 calmly ignores
+ # non-fully satisfied operations.
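+ # E.g. ops needing "a", "b" & "ab" while providing "ab" yield
+ # needs == [optional("a"), optional("b")] ("ab" is produced internally).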
+ needs = iset(optional(n) for op in operations for n in op.needs) - provides
- # compile network
+ # Build network
net = Network()
for op in operations:
net.add_op(op)
- net.compile()
return NetworkOperation(name=self.name, needs=needs, provides=provides, params={}, net=net)
diff --git a/graphkit/network.py b/graphkit/network.py
index 0df3ddf8..a0dc7663 100644
--- a/graphkit/network.py
+++ b/graphkit/network.py
@@ -1,19 +1,104 @@
# Copyright 2016, Yahoo Inc.
# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms.
-
-import time
+""""
+The main implementation of the network of operations & data to compute.
+
+The execution of network *operations* (aka computation) is split
+in 2 phases:
+
+- COMPILE: prune unsatisfied nodes, sort dag topologically & solve it, and
+ derive the *execution steps* (see below) based on the given *inputs*
+ and asked *outputs*.
+
+- EXECUTE: sequential or parallel invocation of the underlying functions
+ of the operations with arguments from the ``solution``.
+
+Computations are based on 5 data-structures:
+
+:attr:`Network.graph`
+ A ``networkx`` graph (a DAG, in fact) containing interchanging layers of
+ :class:`Operation` and :class:`DataPlaceholderNode` nodes.
+ They are laid out and connected by repeated calls of
+ :meth:`~Network.add_op`.
+
+ The computation starts with :meth:`~Network._prune_graph()` extracting
+ a *DAG subgraph* by *pruning* its nodes based on given inputs and
+ requested outputs in :meth:`~Network.compute()`.
+
+:attr:`ExecutionPlan.dag`
+ A directed-acyclic-graph containing the *pruned* nodes, as built by
+ :meth:`~Network._prune_graph()`. This pruned subgraph is used to decide
+ the :attr:`ExecutionPlan.steps` (below).
+ The containing :class:`ExecutionPlan` instance is cached
+ in :attr:`_cached_plans` across runs, keyed by inputs/outputs.
+
+:attr:`ExecutionPlan.steps`
+ It is the list of the operation-nodes only
+ from the dag (above), topologically sorted, and interspersed with
+ *instructions steps* needed to complete the run.
+ It is built by :meth:`~Network._build_execution_steps()` based on
+ the subgraph dag extracted above.
+ The containing :class:`ExecutionPlan` instance is cached
+ in :attr:`_cached_plans` across runs, keyed by inputs/outputs.
+
+ The *instructions* items achieve the following:
+
+ - :class:`DeleteInstruction`: delete items from `solution` as soon as
+ they are not needed further down the dag, to reduce memory footprint
+ while computing.
+
+ - :class:`PinInstruction`: avoid overwriting any given intermediate
+ inputs, and still allow their providing operations to run
+ (because they are needed for their other outputs).
+
+:var solution:
+ a local-var in :meth:`~Network.compute()`, initialized on each run
+ to hold the values of the given inputs, generated (intermediate) data,
+ and output values.
+ It is returned as is if no specific outputs requested; no data-eviction
+ happens then.
+
+:arg overwrites:
+ The optional argument given to :meth:`~Network.compute()` to collect the
+ intermediate *calculated* values that are overwritten by intermediate
+ (aka "pinned") input-values.
+"""
+import logging
import os
-import networkx as nx
-
+import sys
+import time
+from collections import defaultdict, namedtuple
from io import StringIO
+from itertools import chain
+
+import networkx as nx
+from boltons.setutils import IndexedSet as iset
+from . import plot
from .base import Operation
+from .modifiers import optional
+
+log = logging.getLogger(__name__)
+
+
+from networkx import DiGraph
+if sys.version_info < (3, 6):
+ """
+ Consistently ordered variant of :class:`~networkx.DiGraph`.
+
+ PY3.6 has insertion-ordered dicts, but PY3.5 does not.
+ And behavior (and TCs) in these environments may fail spuriously!
+ Still *subgraphs* may not match!
+
+ Fix from:
+ https://networkx.github.io/documentation/latest/reference/classes/ordered.html#module-networkx.classes.ordered
+ """
+ from networkx import OrderedDiGraph as DiGraph
class DataPlaceholderNode(str):
"""
- A node for the Network graph that describes the name of a Data instance
- produced or required by a layer.
+ Dag node naming a data-value produced or required by an operation.
"""
def __repr__(self):
return 'DataPlaceholderNode("%s")' % self
@@ -21,46 +106,65 @@ def __repr__(self):
class DeleteInstruction(str):
"""
- An instruction for the compiled list of evaluation steps to free or delete
- a Data instance from the Network's cache after it is no longer needed.
+ Execution step to delete a computed value from the `solution`.
+
+ It's a step in :attr:`ExecutionPlan.steps` for the data-node `str` that
+ frees its data-value from `solution` after it is no longer needed,
+ to reduce memory footprint while computing the graph.
"""
def __repr__(self):
return 'DeleteInstruction("%s")' % self
-class Network(object):
+class PinInstruction(str):
+ """
+ Execution step to replace a computed value in the `solution` from the inputs,
+
+ and to store the computed one in the ``overwrites`` instead
+ (both `solution` & ``overwrites`` are local-vars in :meth:`~Network.compute()`).
+
+ It's a step in :attr:`ExecutionPlan.steps` for the data-node `str` that
+ ensures the corresponding intermediate input-value is not overwritten when
+ its providing function(s) could not be pruned, because their other outputs
+ are needed elsewhere.
"""
- This is the main network implementation. The class contains all of the
- code necessary to weave together operations into a directed-acyclic-graph (DAG)
- and pass data through.
+ def __repr__(self):
+ return 'PinInstruction("%s")' % self
+
+
+class Network(plot.Plotter):
"""
+ Assemble operations & data into a directed-acyclic-graph (DAG) to run them.
- def __init__(self, **kwargs):
- """
- """
+ """
+ def __init__(self, **kwargs):
# directed graph of layer instances and data-names defining the net.
- self.graph = nx.DiGraph()
- self._debug = kwargs.get("debug", False)
+ self.graph = DiGraph()
- # this holds the timing information for eache layer
+ # this holds the timing information for each layer
self.times = {}
- # a compiled list of steps to evaluate layers *in order* and free mem.
- self.steps = []
+ #: Speed up :meth:`compile()` call and avoid a multithreading issue(?)
+ #: that is occurring when accessing the dag in networkx.
+ self._cached_plans = {}
+
+ #: the execution_plan of the last call to :meth:`compute()`
+ #: (not ``compile()``!), for debugging purposes.
+ self.last_plan = None
- # This holds a cache of results for the _find_necessary_steps
- # function, this helps speed up the compute call as well avoid
- # a multithreading issue that is occuring when accessing the
- # graph in networkx
- self._necessary_steps_cache = {}
+ def _build_pydot(self, **kws):
+ from .plot import build_pydot
+ kws.setdefault('graph', self.graph)
+
+ return build_pydot(**kws)
def add_op(self, operation):
"""
Adds the given operation and its data requirements to the network graph
- based on the name of the operation, the names of the operation's needs, and
- the names of the data it provides.
+ based on the name of the operation, the names of the operation's needs,
+ and the names of the data it provides.
:param Operation operation: Operation object to add.
"""
@@ -71,7 +175,9 @@ def add_op(self, operation):
assert operation.provides is not None, "Operation's 'provides' must be named"
# assert layer is only added once to graph
- assert operation not in self.graph.nodes(), "Operation may only be added once"
+ assert operation not in self.graph.nodes, "Operation may only be added once"
+
+ self._cached_plans = {}
# add nodes and edges to graph describing the data needs for this layer
for n in operation.needs:
@@ -81,174 +187,396 @@ def add_op(self, operation):
for p in operation.provides:
self.graph.add_edge(operation, DataPlaceholderNode(p))
- # clear compiled steps (must recompile after adding new layers)
- self.steps = []
+ def _build_execution_steps(self, dag, inputs, outputs):
+ """
+ Create the list of operation-nodes & *instructions* evaluating all
- def list_layers(self):
- assert self.steps, "network must be compiled before listing layers."
- return [(s.name, s) for s in self.steps if isinstance(s, Operation)]
-
+ operations & instructions needed a) to free memory and b) avoid
+ overwriting given intermediate inputs.
- def show_layers(self):
- """Shows info (name, needs, and provides) about all layers in this network."""
- for name, step in self.list_layers():
- print("layer_name: ", name)
- print("\t", "needs: ", step.needs)
- print("\t", "provides: ", step.provides)
- print("")
+ :param dag:
+ The original dag, pruned; not broken.
+ :param outputs:
+ output names, to decide whether (and which) del-instructions to add
+ In the returned list, :class:`DeleteInstruction` steps (DI) are inserted
+ between operation nodes to reduce the memory footprint of the solution.
+ A DI is inserted whenever a *need* is not used by any other *operation*
+ further down the DAG.
+ Note that since the `solution` is not shared across `compute()` calls,
+ any memory-reductions last only as long as a single computation runs.
- def compile(self):
- """Create a set of steps for evaluating layers
- and freeing memory as necessary"""
+ """
- # clear compiled steps
- self.steps = []
+ steps = []
# create an execution order such that each layer's needs are provided.
- ordered_nodes = list(nx.dag.topological_sort(self.graph))
+ ordered_nodes = iset(nx.topological_sort(dag))
- # add Operations evaluation steps, and instructions to free data.
+ # Add Operations evaluation steps, and instructions to free and "pin"
+ # data.
for i, node in enumerate(ordered_nodes):
if isinstance(node, DataPlaceholderNode):
- continue
+ if node in inputs and dag.pred[node]:
+ # Command pinning only when there is another operation
+ # generating this data as output.
+ steps.append(PinInstruction(node))
elif isinstance(node, Operation):
+ steps.append(node)
- # add layer to list of steps
- self.steps.append(node)
+ # Keep all values in solution if not specific outputs asked.
+ if not outputs:
+ continue
# Add instructions to delete predecessors as possible. A
# predecessor may be deleted if it is a data placeholder that
# is no longer needed by future Operations.
- for predecessor in self.graph.predecessors(node):
- if self._debug:
- print("checking if node %s can be deleted" % predecessor)
- predecessor_still_needed = False
+ for need in self.graph.pred[node]:
+ log.debug("checking if node %s can be deleted", need)
for future_node in ordered_nodes[i+1:]:
- if isinstance(future_node, Operation):
- if predecessor in future_node.needs:
- predecessor_still_needed = True
- break
- if not predecessor_still_needed:
- if self._debug:
- print(" adding delete instruction for %s" % predecessor)
- self.steps.append(DeleteInstruction(predecessor))
+ if (
+ isinstance(future_node, Operation)
+ and need in future_node.needs
+ ):
+ break
+ else:
+ if need not in outputs:
+ log.debug(" adding delete instruction for %s", need)
+ steps.append(DeleteInstruction(need))
else:
- raise TypeError("Unrecognized network graph node")
+ raise AssertionError("Unrecognized network graph node %r" % node)
+
+ return steps
+
+ def _collect_unsatisfied_operations(self, dag, inputs):
+ """
+ Traverse topologically sorted dag to collect un-satisfied operations.
+ Unsatisfied operations are those suffering from ANY of the following:
- def _find_necessary_steps(self, outputs, inputs):
+ - They are missing at least one compulsory need-input.
+ Since the dag is ordered, by the time we reach an operation,
+ all its needs have already been accounted for, so we can decide
+ its satisfaction.
+
+ - Their provided outputs are not linked to any data in the dag.
+ An operation might not have any output link when :meth:`_prune_graph()`
+ has broken them, due to given intermediate inputs.
+
+ :param dag:
+ a graph whose broken edges are those arriving at given inputs
+ :param inputs:
+ an iterable of the names of the input values
+ :return:
+ a list of unsatisfied operations to prune
"""
- Determines what graph steps need to pe run to get to the requested
- outputs from the provided inputs. Eliminates steps that come before
- (in topological order) any inputs that have been provided. Also
- eliminates steps that are not on a path from the provided inputs to
- the requested outputs.
+ # To collect data that will be produced.
+ ok_data = set(inputs)
+ # To collect the map of operations --> satisfied-needs.
+ op_satisfaction = defaultdict(set)
+ # To collect the operations to drop.
+ unsatisfied = []
+ for node in nx.topological_sort(dag):
+ if isinstance(node, Operation):
+ if not dag.adj[node]:
+ # Prune operations that ended up providing no output.
+ unsatisfied.append(node)
+ else:
+ real_needs = set(n for n in node.needs
+ if not isinstance(n, optional))
+ if real_needs.issubset(op_satisfaction[node]):
+ # We have a satisfied operation; mark its output-data
+ # as ok.
+ ok_data.update(dag.adj[node])
+ else:
+ # Prune operations with partial inputs.
+ unsatisfied.append(node)
+ elif isinstance(node, (DataPlaceholderNode, str)): # `str` are givens
+ if node in ok_data:
+ # mark satisfied-needs on all future operations
+ for future_op in dag.adj[node]:
+ op_satisfaction[future_op].add(node)
+ else:
+ raise AssertionError("Unrecognized network graph node %r" % node)
- :param list outputs:
+ return unsatisfied
+
+
+ def _prune_graph(self, outputs, inputs):
+ """
+ Determines what graph steps need to run to get to the requested
+ outputs from the provided inputs:
+
+ - Eliminate steps that are not on a path arriving at the requested outputs.
+ - Eliminate unsatisfied operations: partial inputs or no outputs needed.
+
+ :param iterable outputs:
A list of desired output names. This can also be ``None``, in which
case the necessary steps are all graph nodes that are reachable
from one of the provided inputs.
- :param dict inputs:
- A dictionary mapping names to values for all provided inputs.
+ :param iterable inputs:
+ The names of all the given inputs.
- :returns:
- Returns a list of all the steps that need to be run for the
- provided inputs and requested outputs.
+ :return:
+ the *pruned_dag*
"""
+ dag = self.graph
+
+ # Ignore input names that aren't in the graph.
+ graph_inputs = iset(dag.nodes) & inputs # preserve order
+
+ # Scream if some requested outputs aren't in the graph.
+ unknown_outputs = iset(outputs) - dag.nodes
+ if unknown_outputs:
+ raise ValueError(
+ "Unknown output node(s) requested: %s"
+ % ", ".join(unknown_outputs))
+
+ broken_dag = dag.copy() # preserve net's graph
+
+ # Break the incoming edges to all given inputs.
+ #
+ # Nodes producing any given intermediate inputs are unnecessary
+ # (unless they are also used elsewhere).
+ # To discover which ones to prune, we break their incoming edges
+ # and they will drop out while collecting ancestors from the outputs.
+ broken_edges = set() # unordered, not iterated
+ for given in graph_inputs:
+ broken_edges.update(broken_dag.in_edges(given))
+ broken_dag.remove_edges_from(broken_edges)
+
+ # Drop stray input values and operations (if any).
+ broken_dag.remove_nodes_from(nx.isolates(broken_dag))
+
+ if outputs:
+ # If caller requested specific outputs, we can prune any
+ # unrelated nodes further up the dag.
+ ending_in_outputs = set()
+ for output_name in outputs:
+ ending_in_outputs.update(nx.ancestors(dag, output_name))
+ broken_dag = broken_dag.subgraph(ending_in_outputs | set(outputs))
+
+
+ # Prune unsatisfied operations (those with partial inputs or no outputs).
+ unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs)
+ # Clone it so that it is picklable.
+ pruned_dag = dag.subgraph(broken_dag.nodes - unsatisfied).copy()
+
+ return pruned_dag, tuple(broken_edges)
+
+ def compile(self, inputs=(), outputs=()):
+ """
+ Create or get from cache an execution-plan for the given inputs/outputs.
- # return steps if it has already been computed before for this set of inputs and outputs
- outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set)) else outputs
- inputs_keys = tuple(sorted(inputs.keys()))
- cache_key = (inputs_keys, outputs)
- if cache_key in self._necessary_steps_cache:
- return self._necessary_steps_cache[cache_key]
+ See :meth:`_prune_graph()` and :meth:`_build_execution_steps()`
+ for detailed description.
- graph = self.graph
- if not outputs:
+ :param inputs:
+ An iterable with the names of all the given inputs.
- # If caller requested all outputs, the necessary nodes are all
- # nodes that are reachable from one of the inputs. Ignore input
- # names that aren't in the graph.
- necessary_nodes = set()
- for input_name in iter(inputs):
- if graph.has_node(input_name):
- necessary_nodes |= nx.descendants(graph, input_name)
+ :param outputs:
+ (optional) An iterable, or the name(s), of the requested outputs.
+ If missing, the requested outputs are assumed to be all graph nodes
+ reachable from one of the given inputs.
+ :return:
+ the cached or fresh new execution-plan
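+
+ For example (a sketch; the cache-key is order-insensitive because
+ the given names are sorted)::
+
+ plan1 = net.compile(inputs=("a", "b"), outputs=("ab",))
+ plan2 = net.compile(inputs=("b", "a"), outputs=("ab",))
+ assert plan1 is plan2 # fetched from `_cached_plans`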
+ """
+ # outputs must be iterable
+ if not outputs:
+ outputs = ()
+ elif isinstance(outputs, str):
+ outputs = (outputs, )
+
+ # Make a stable cache-key
+ cache_key = (tuple(sorted(inputs)), tuple(sorted(outputs)))
+ if cache_key in self._cached_plans:
+ # An execution plan has been compiled before
+ # for the same inputs & outputs.
+ plan = self._cached_plans[cache_key]
else:
+ # Build a new execution plan for the given inputs & outputs.
+ #
+ pruned_dag, broken_edges = self._prune_graph(outputs, inputs)
+ steps = self._build_execution_steps(pruned_dag, inputs, outputs)
+ plan = ExecutionPlan(
+ self,
+ tuple(inputs),
+ outputs,
+ pruned_dag,
+ tuple(broken_edges),
+ tuple(steps),
+ executed=iset(),
+ )
+
+ # Cache compilation results to speed up future runs
+ # with different values (but the same names of inputs/outputs).
+ self._cached_plans[cache_key] = plan
+
+ return plan
+
+ def compute(
+ self, named_inputs, outputs, method=None, overwrites_collector=None):
+ """
+ Solve & execute the graph, sequentially or parallel.
- # If the caller requested a subset of outputs, find any nodes that
- # are made unecessary because we were provided with an input that's
- # deeper into the network graph. Ignore input names that aren't
- # in the graph.
- unnecessary_nodes = set()
- for input_name in iter(inputs):
- if graph.has_node(input_name):
- unnecessary_nodes |= nx.ancestors(graph, input_name)
+ :param dict named_inputs:
+ A dict of key/value pairs where the keys represent the data nodes
+ you want to populate, and the values are the concrete values you
+ want to set for the data node.
- # Find the nodes we need to be able to compute the requested
- # outputs. Raise an exception if a requested output doesn't
- # exist in the graph.
- necessary_nodes = set()
- for output_name in outputs:
- if not graph.has_node(output_name):
- raise ValueError("graphkit graph does not have an output "
- "node named %s" % output_name)
- necessary_nodes |= nx.ancestors(graph, output_name)
+ :param list outputs:
+ The names of the data nodes you'd like to have returned
+ once all necessary computations are complete.
+ If you set this variable to ``None``, all data nodes will be kept
+ and returned at runtime.
- # Get rid of the unnecessary nodes from the set of necessary ones.
- necessary_nodes -= unnecessary_nodes
+ :param method:
+ if ``"parallel"``, launches multi-threading.
+ Set when invoking a composed graph or by
+ :meth:`~NetworkOperation.set_execution_method()`.
+ :param overwrites_collector:
+ (optional) a mutable dict to be filled with named values.
+ If missing, values are simply discarded.
- necessary_steps = [step for step in self.steps if step in necessary_nodes]
+ :returns: a dictionary of output data objects, keyed by name.
+ """
- # save this result in a precomputed cache for future lookup
- self._necessary_steps_cache[cache_key] = necessary_steps
+ assert isinstance(outputs, (list, tuple)) or outputs is None,\
+ "The outputs argument must be a list"
- # Return an ordered list of the needed steps.
- return necessary_steps
+ # Build the execution plan.
+ self.last_plan = plan = self.compile(named_inputs.keys(), outputs)
+ # start with fresh data solution.
+ solution = dict(named_inputs)
- def compute(self, outputs, named_inputs, method=None):
- """
- Run the graph. Any inputs to the network must be passed in by name.
+ plan.execute(solution, overwrites_collector, method)
- :param list output: The names of the data node you'd like to have returned
- once all necessary computations are complete.
- If you set this variable to ``None``, all
- data nodes will be kept and returned at runtime.
+ if outputs:
+ # Filter outputs to just return what's requested.
+ # Otherwise, return the whole solution as output,
+ # including input and intermediate data nodes.
+ # TODO: assert no other outputs exist, due to DeleteInstructions.
+ solution = dict(i for i in solution.items() if i[0] in outputs)
- :param dict named_inputs: A dict of key/value pairs where the keys
- represent the data nodes you want to populate,
- and the values are the concrete values you
- want to set for the data node.
+ return solution
- :returns: a dictionary of output data objects, keyed by name.
+class ExecutionPlan(
+ namedtuple("_ExePlan", "net inputs outputs dag broken_edges steps executed"),
+ plot.Plotter
+):
+ """
+ The result of the network's compilation phase.
+
+ Note the execution plan's attributes are on purpose immutable tuples.
+
+ :ivar net:
+ The parent :class:`Network`
+
+ :ivar inputs:
+ A tuple with the names of the given inputs used to construct the plan.
+
+ :ivar outputs:
+ A (possibly empty) tuple with the names of the requested outputs
+ used to construct the plan.
+
+ :ivar dag:
+ The regular (not broken) *pruned* subgraph of net-graph.
+
+ :ivar broken_edges:
+ Tuple of broken incoming edges to given data.
+
+ :ivar steps:
+ The tuple of operation-nodes & *instructions* needed to evaluate
+ the given inputs & asked outputs, free memory and avoid overwritting
+ any given intermediate inputs.
+ :ivar executed:
+ An empty set to collect all operations that have been executed so far.
+ """
+
+ @property
+ def broken_dag(self):
+ return nx.restricted_view(self.dag, nodes=(), edges=self.broken_edges)
+
+ def _build_pydot(self, **kws):
+ from .plot import build_pydot
+
+ clusters = None
+ if self.dag.nodes != self.net.graph.nodes:
+ clusters = {n: "after pruning" for n in self.dag.nodes}
+ mykws = {
+ "graph": self.net.graph,
+ "steps": self.steps,
+ "inputs": self.inputs,
+ "outputs": self.outputs,
+ "executed": self.executed,
+ "edge_props": {e: {"color": "wheat", "penwidth": 2} for e in self.broken_edges},
+ "clusters": clusters,
+ }
+ mykws.update(kws)
+
+ return build_pydot(**mykws)
+
+ def __repr__(self):
+ return (
+ "ExecutionPlan:\n +--inputs:%s, \n +--outputs=%s\n +--steps=%s)"
+ % (self.inputs, self.outputs, self.steps))
+
+ def get_data_node(self, name):
+ """
+ Return the data node from a graph using its name, or None.
"""
+ node = self.dag.nodes[name]
+ if isinstance(node, DataPlaceholderNode):
+ return node
- # assert that network has been compiled
- assert self.steps, "network must be compiled before calling compute."
- assert isinstance(outputs, (list, tuple)) or outputs is None,\
- "The outputs argument must be a list"
+ def _can_schedule_operation(self, op):
+ """
+ Determines if an Operation is ready to be scheduled for execution
+ based on what has already been executed.
- # choose a method of execution
- if method == "parallel":
- return self._compute_thread_pool_barrier_method(named_inputs,
- outputs)
- else:
- return self._compute_sequential_method(named_inputs,
- outputs)
+ :param op:
+ The Operation object to check
+ :return:
+ A boolean indicating whether the operation may be scheduled for
+ execution based on what has already been executed.
+ """
+ # Use `broken_dag` to allow executing operations after given inputs
+ # regardless of whether their producers have yet to run.
+ dependencies = set(n for n in nx.ancestors(self.broken_dag, op)
+ if isinstance(n, Operation))
+ return dependencies.issubset(self.executed)
+ def _can_evict_value(self, name):
+ """
+ Determines if a DataPlaceholderNode is ready to be deleted from the solution.
- def _compute_thread_pool_barrier_method(self, named_inputs, outputs,
- thread_pool_size=10):
+ :param name:
+ The name of the data node to check
+ :return:
+ A boolean indicating whether the data node can be deleted or not.
+ """
+ data_node = self.get_data_node(name)
+ # Use `broken_dag` not to block a successor waiting for this data,
+ # since in any case it will use a given input, not some pipe of this data.
+ return data_node and set(
+ self.broken_dag.successors(data_node)).issubset(self.executed)
+
+ def _pin_data_in_solution(self, value_name, solution, inputs, overwrites):
+ value_name = str(value_name)
+ if overwrites is not None:
+ overwrites[value_name] = solution[value_name]
+ solution[value_name] = inputs[value_name]
+
+ def _execute_thread_pool_barrier_method(self, inputs, solution, overwrites,
+ thread_pool_size=10
+ ):
"""
This method runs the graph using a parallel pool of thread executors.
You may achieve lower total latency if your graph is sufficiently
@@ -257,41 +585,42 @@ def _compute_thread_pool_barrier_method(self, named_inputs, outputs,
from multiprocessing.dummy import Pool
# if we have not already created a thread_pool, create one
- if not hasattr(self, "_thread_pool"):
- self._thread_pool = Pool(thread_pool_size)
- pool = self._thread_pool
-
- cache = {}
- cache.update(named_inputs)
- necessary_nodes = self._find_necessary_steps(outputs, named_inputs)
-
- # this keeps track of all nodes that have already executed
- has_executed = set()
+ if not hasattr(self.net, "_thread_pool"):
+ self.net._thread_pool = Pool(thread_pool_size)
+ pool = self.net._thread_pool
# with each loop iteration, we determine a set of operations that can be
# scheduled, then schedule them onto a thread pool, then collect their
- # results onto a memory cache for use upon the next iteration.
+ # results into the shared solution for use in the next iteration.
while True:
# the upnext list contains a list of operations for scheduling
# in the current round of scheduling
upnext = []
- for node in necessary_nodes:
- # only delete if all successors for the data node have been executed
- if isinstance(node, DeleteInstruction):
- if ready_to_delete_data_node(node,
- has_executed,
- self.graph):
- if node in cache:
- cache.pop(node)
-
- # continue if this node is anything but an operation node
- if not isinstance(node, Operation):
- continue
-
- if ready_to_schedule_operation(node, has_executed, self.graph) \
- and node not in has_executed:
+ for node in self.steps:
+ if (
+ isinstance(node, Operation)
+ and self._can_schedule_operation(node)
+ and node not in self.executed
+ ):
upnext.append(node)
+ elif isinstance(node, DeleteInstruction):
+ # Only delete if all successors for the data node
+ # have been executed.
+ # An optional need may not have a value in the solution.
+ if (
+ node in solution
+ and self._can_evict_value(node)
+ ):
+ log.debug("removing data '%s' from solution.", node)
+ del solution[node]
+ elif isinstance(node, PinInstruction):
+ # Always and repeatedly pin the value, even if not all
+ # providers of the data have executed.
+ # An optional need may not have a value in the solution.
+ if node in solution:
+ self._pin_data_in_solution(
+ node, solution, inputs, overwrites)
# stop if no nodes left to schedule, exit out of the loop
@@ -299,198 +628,76 @@ def _compute_thread_pool_barrier_method(self, named_inputs, outputs,
break
done_iterator = pool.imap_unordered(
- lambda op: (op,op._compute(cache)),
+ lambda op: (op,op._compute(solution)),
upnext)
for op, result in done_iterator:
- cache.update(result)
- has_executed.add(op)
+ solution.update(result)
+ self.executed.add(op)
- if not outputs:
- return cache
- else:
- return {k: cache[k] for k in iter(cache) if k in outputs}
- def _compute_sequential_method(self, named_inputs, outputs):
+ def _execute_sequential_method(self, inputs, solution, overwrites):
"""
This method runs the graph one operation at a time in a single thread
"""
- # start with fresh data cache
- cache = {}
-
- # add inputs to data cache
- cache.update(named_inputs)
-
- # Find the subset of steps we need to run to get to the requested
- # outputs from the provided inputs.
- all_steps = self._find_necessary_steps(outputs, named_inputs)
-
self.times = {}
- for step in all_steps:
+ for step in self.steps:
if isinstance(step, Operation):
- if self._debug:
- print("-"*32)
- print("executing step: %s" % step.name)
+ log.debug("%sexecuting step: %s", "-"*32, step.name)
# time execution...
t0 = time.time()
# compute layer outputs
- layer_outputs = step._compute(cache)
+ layer_outputs = step._compute(solution)
- # add outputs to cache
- cache.update(layer_outputs)
+ # add outputs to solution
+ solution.update(layer_outputs)
+ self.executed.add(step)
# record execution time
t_complete = round(time.time() - t0, 5)
self.times[step.name] = t_complete
- if self._debug:
- print("step completion time: %s" % t_complete)
+ log.debug("step completion time: %s", t_complete)
- # Process DeleteInstructions by deleting the corresponding data
- # if possible.
elif isinstance(step, DeleteInstruction):
+ # The solution value may be missing if it is optional.
+ if step in solution:
+ log.debug("removing data '%s' from solution.", step)
+ del solution[step]
- if outputs and step not in outputs:
- # Some DeleteInstruction steps may not exist in the cache
- # if they come from optional() needs that are not privoded
- # as inputs. Make sure the step exists before deleting.
- if step in cache:
- if self._debug:
- print("removing data '%s' from cache." % step)
- cache.pop(step)
-
+ elif isinstance(step, PinInstruction):
+ self._pin_data_in_solution(step, solution, inputs, overwrites)
else:
- raise TypeError("Unrecognized instruction.")
+ raise AssertionError("Unrecognized instruction.%r" % step)
- if not outputs:
- # Return the whole cache as output, including input and
- # intermediate data nodes.
- return cache
-
- else:
- # Filter outputs to just return what's needed.
- # Note: list comprehensions exist in python 2.7+
- return {k: cache[k] for k in iter(cache) if k in outputs}
-
-
- def plot(self, filename=None, show=False):
+ def execute(self, solution, overwrites=None, method=None):
"""
- Plot the graph.
-
- params:
- :param str filename:
- Write the output to a png, pdf, or graphviz dot file. The extension
- controls the output format.
-
- :param boolean show:
- If this is set to True, use matplotlib to show the graph diagram
- (Default: False)
-
- :returns:
- An instance of the pydot graph
-
+ :param solution:
+ a mutable mapping to collect the results; it must also contain
+ the given input values, at least for the compulsory inputs that
+ were specified when the plan was built (but this cannot be enforced!).
+
+ :param overwrites:
+ (optional) a mutable dict to collect calculated-but-discarded values
+ because they were "pinned" by input values.
+ If missing, the overwrites values are simply discarded.
"""
- import pydot
- import matplotlib.pyplot as plt
- import matplotlib.image as mpimg
-
- assert self.graph is not None
-
- def get_node_name(a):
- if isinstance(a, DataPlaceholderNode):
- return a
- return a.name
-
- g = pydot.Dot(graph_type="digraph")
-
- # draw nodes
- for nx_node in self.graph.nodes():
- if isinstance(nx_node, DataPlaceholderNode):
- node = pydot.Node(name=nx_node, shape="rect")
- else:
- node = pydot.Node(name=nx_node.name, shape="circle")
- g.add_node(node)
-
- # draw edges
- for src, dst in self.graph.edges():
- src_name = get_node_name(src)
- dst_name = get_node_name(dst)
- edge = pydot.Edge(src=src_name, dst=dst_name)
- g.add_edge(edge)
-
- # save plot
- if filename:
- basename, ext = os.path.splitext(filename)
- with open(filename, "w") as fh:
- if ext.lower() == ".png":
- fh.write(g.create_png())
- elif ext.lower() == ".dot":
- fh.write(g.to_string())
- elif ext.lower() in [".jpg", ".jpeg"]:
- fh.write(g.create_jpeg())
- elif ext.lower() == ".pdf":
- fh.write(g.create_pdf())
- elif ext.lower() == ".svg":
- fh.write(g.create_svg())
- else:
- raise Exception("Unknown file format for saving graph: %s" % ext)
+ # Clean executed operation from any previous execution.
+ self.executed.clear()
- # display graph via matplotlib
- if show:
- png = g.create_png()
- sio = StringIO(png)
- img = mpimg.imread(sio)
- plt.imshow(img, aspect="equal")
- plt.show()
-
- return g
-
-
-def ready_to_schedule_operation(op, has_executed, graph):
- """
- Determines if a Operation is ready to be scheduled for execution based on
- what has already been executed.
-
- Args:
- op:
- The Operation object to check
- has_executed: set
- A set containing all operations that have been executed so far
- graph:
- The networkx graph containing the operations and data nodes
- Returns:
- A boolean indicating whether the operation may be scheduled for
- execution based on what has already been executed.
- """
- dependencies = set(filter(lambda v: isinstance(v, Operation),
- nx.ancestors(graph, op)))
- return dependencies.issubset(has_executed)
+ # choose a method of execution
+ executor = (self._execute_thread_pool_barrier_method
+ if method == "parallel" else
+ self._execute_sequential_method)
-def ready_to_delete_data_node(name, has_executed, graph):
- """
- Determines if a DataPlaceholderNode is ready to be deleted from the
- cache.
+ # clone and keep original inputs in solution intact
+ executor(dict(solution), solution, overwrites)
- Args:
- name:
- The name of the data node to check
- has_executed: set
- A set containing all operations that have been executed so far
- graph:
- The networkx graph containing the operations and data nodes
- Returns:
- A boolean indicating whether the data node can be deleted or not.
- """
- data_node = get_data_node(name, graph)
- return set(graph.successors(data_node)).issubset(has_executed)
+ # return it, but caller can also see the results in `solution` dict.
+ return solution
-def get_data_node(name, graph):
- """
- Gets a data node from a graph using its name
- """
- for node in graph.nodes():
- if node == name and isinstance(node, DataPlaceholderNode):
- return node
- return None
+# TODO: maybe class Solution(object):
+# values = {}
+# overwrites = None
diff --git a/graphkit/plot.py b/graphkit/plot.py
new file mode 100644
index 00000000..69138d06
--- /dev/null
+++ b/graphkit/plot.py
@@ -0,0 +1,411 @@
+# Copyright 2016, Yahoo Inc.
+# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms.
+
+import io
+import logging
+import os
+
+
+log = logging.getLogger(__name__)
+
+
+class Plotter(object):
+ """
+ Classes wishing to plot their graphs should inherit this and ...
+
+ implement :meth:`_build_pydot()` to return a ``pydot.Dot`` from their
+ internal graph, with any extra args bound appropriately; :meth:`plot()`
+ then ends up rendering that through :func:`render_pydot()`.
+ The purpose is to avoid copying this function & documentation around.
+ """
+
+ def plot(self, filename=None, show=False, **kws):
+ """
+ :param str filename:
+ Write diagram into a file.
+ Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
+ call :func:`plot.supported_plot_formats()` for more.
+ :param show:
+ If it evaluates to true, opens the diagram in a matplotlib window.
+ If it equals `-1`, it plots but does not open the window.
+ :param inputs:
+ an optional name list, any nodes in there are plotted
+ as a "house"
+ :param outputs:
+ an optional name list, any nodes in there are plotted
+ as an "inverted-house"
+ :param solution:
+ an optional dict with values to annotate nodes, drawn "filled"
+ (currently content not shown, but node drawn as "filled")
+ :param executed:
+ an optional container with operations executed, drawn "filled"
+ :param title:
+ an optional string to display at the bottom of the graph
+ :param node_props:
+            an optional nested dict of Graphviz attributes for certain nodes
+        :param edge_props:
+            an optional nested dict of Graphviz attributes for certain edges
+ :param clusters:
+ an optional mapping of nodes --> cluster-names, to group them
+
+ :return:
+            A ``pydot.Dot`` instance
+            (or the matplotlib image, if ``show == -1``).
+            NOTE that the returned ``Dot`` instance is monkeypatched to
+            support direct rendering in *jupyter cells* as SVG.
+
+
+        Note that the `graph` argument is absent - each Plotter provides
+        its own graph internally; use :func:`build_pydot()` directly to plot
+        a different graph.
+
+ **Legend:**
+
+        .. figure:: ../images/Graphkitlegend.svg
+ :alt: Graphkit Legend
+
+        See :func:`legend()`.
+
+ *NODES:*
+
+ oval
+ function
+ egg
+ subgraph operation
+ house
+ given input
+        inverted-house
+ asked output
+ polygon
+            given both as input & asked as output (a rather odd case)
+ square
+ intermediate data, neither given nor asked.
+ red frame
+ delete-instruction, to free up memory.
+ blue frame
+            pin-instruction, not to overwrite an intermediate input.
+ filled
+ data node has a value in `solution` OR function has been executed.
+ thick frame
+ function/data node in execution `steps`.
+
+ *ARROWS*
+
+ solid black arrows
+            dependencies (source-data are ``need``-ed by target-operations,
+            source-operations ``provide`` target-data)
+ dashed black arrows
+ optional needs
+ wheat arrows
+ broken dependency (``provide``) during pruning
+ green-dotted arrows
+ execution steps labeled in succession
+
+ **Sample code:**
+
+        >>> from operator import add
+        >>> from graphkit import compose, operation
+        >>> from graphkit.modifiers import optional
+
+ >>> graphop = compose(name="graphop")(
+ ... operation(name="add", needs=["a", "b1"], provides=["ab1"])(add),
+ ... operation(name="sub", needs=["a", optional("b2")], provides=["ab2"])(lambda a, b=1: a-b),
+ ... operation(name="abb", needs=["ab1", "ab2"], provides=["asked"])(add),
+ ... )
+
+ >>> graphop.plot(show=True); # plot just the graph in a matplotlib window
+ >>> inputs = {'a': 1, 'b1': 2}
+ >>> solution = graphop(inputs) # now plots will include the execution-plan
+
+ >>> graphop.plot('plot1.svg', inputs=inputs, outputs=['asked', 'b1'], solution=solution);
+        >>> graphop.plot(solution=solution)  # just get the `pydot.Dot` object, renderable in Jupyter
+ """
+ dot = self._build_pydot(**kws)
+ return render_pydot(dot, filename=filename, show=show)
+
+ def _build_pydot(self, **kws):
+        raise AssertionError("Must implement `_build_pydot()`!")
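+
+# A minimal sketch of a `Plotter` implementor (hypothetical, for illustration
+# only; the real implementors are this project's network classes):
+#
+#     class GraphHolder(Plotter):
+#         def __init__(self, nx_graph):
+#             self.nx_graph = nx_graph
+#
+#         def _build_pydot(self, **kws):
+#             return build_pydot(self.nx_graph, **kws)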
+
+
+def _is_class_value_in_list(lst, cls, value):
+ return any(isinstance(i, cls) and i == value for i in lst)
+
+
+def _merge_conditions(*conds):
+ """combines conditions as a choice in binary range, eg, 2 conds --> [0, 3]"""
+ return sum(int(bool(c)) << i for i, c in enumerate(conds))
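+# For instance, _merge_conditions(True, False) --> 1,
+# _merge_conditions(False, True) --> 2, _merge_conditions(True, True) --> 3;
+# the result then indexes a list of choices (see `build_pydot()` below).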
+
+
+def _apply_user_props(dotobj, user_props, key):
+ if user_props and key in user_props:
+ dotobj.get_attributes().update(user_props[key])
+        # Delete matched items, to report the unmatched ones, AND not to annotate `steps`.
+ del user_props[key]
+
+
+def _report_unmatched_user_props(user_props, kind):
+ if user_props and log.isEnabledFor(logging.WARNING):
+ unmatched = "\n ".join(str(i) for i in user_props.items())
+ log.warning("Unmatched `%s_props`:\n +--%s", kind, unmatched)
+
+
+def _monkey_patch_for_jupyter(pydot):
+    # Ensure Dot instances render in Jupyter
+ # (see pydot/pydot#220)
+ if not hasattr(pydot.Dot, "_repr_svg_"):
+
+ def make_svg(self):
+ return self.create_svg().decode()
+
+ # monkey patch class
+ pydot.Dot._repr_svg_ = make_svg
+
+
+def build_pydot(
+ graph,
+ steps=None,
+ inputs=None,
+ outputs=None,
+ solution=None,
+ executed=None,
+ title=None,
+ node_props=None,
+ edge_props=None,
+ clusters=None,
+):
+ """
+    Build a *Graphviz* ``pydot.Dot`` out of a Network graph/steps/inputs/outputs and return it.
+
+ See :meth:`Plotter.plot()` for the arguments, sample code, and
+ the legend of the plots.
+ """
+ import pydot
+ from .base import NetworkOperation, Operation
+ from .modifiers import optional
+ from .network import DeleteInstruction, PinInstruction
+
+ _monkey_patch_for_jupyter(pydot)
+
+ assert graph is not None
+
+ steps_thickness = 3
+ fill_color = "wheat"
+ steps_color = "#009999"
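+    # `pydot.Cluster`s from the `clusters` mapping, created lazily
+    # by `append_or_cluster_node()` below.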
+ new_clusters = {}
+
+ def append_or_cluster_node(dot, nx_node, node):
+        if not clusters or nx_node not in clusters:
+ dot.add_node(node)
+ else:
+ cluster_name = clusters[nx_node]
+ node_cluster = new_clusters.get(cluster_name)
+ if not node_cluster:
+ node_cluster = new_clusters[cluster_name] = pydot.Cluster(
+ cluster_name, label=cluster_name
+ )
+ node_cluster.add_node(node)
+
+ def append_any_clusters(dot):
+ for cluster in new_clusters.values():
+ dot.add_subgraph(cluster)
+
+ def get_node_name(a):
+ if isinstance(a, Operation):
+ return a.name
+ return a
+
+ dot = pydot.Dot(graph_type="digraph", label=title, fontname="italic")
+
+ # draw nodes
+ for nx_node in graph.nodes:
+ if isinstance(nx_node, str):
+ kw = {}
+
+ # FrameColor change by step type
+ if steps and nx_node in steps:
+ choice = _merge_conditions(
+ _is_class_value_in_list(steps, DeleteInstruction, nx_node),
+ _is_class_value_in_list(steps, PinInstruction, nx_node),
+ )
+                # Index 0 ('NOPE') is unreachable: a data-node appearing
+                # in `steps` always appears as a Delete- or Pin-instruction.
+ color = "NOPE #990000 blue purple".split()[choice]
+ kw = {"color": color, "penwidth": steps_thickness}
+
+ # SHAPE change if with inputs/outputs.
+ # tip: https://graphviz.gitlab.io/_pages/doc/info/shapes.html
+ choice = _merge_conditions(
+ inputs and nx_node in inputs, outputs and nx_node in outputs
+ )
+ shape = "rect invhouse house hexagon".split()[choice]
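+            # (e.g. a node both given as input & asked as output
+            # --> choice 3 --> "hexagon")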
+
+ # LABEL change with solution.
+ if solution and nx_node in solution:
+ kw["style"] = "filled"
+ kw["fillcolor"] = fill_color
+            # kw["tooltip"] = str(solution.get(nx_node))  # not working :-(
+ node = pydot.Node(name=nx_node, shape=shape, **kw)
+ else: # Operation
+ kw = {}
+
+ if steps and nx_node in steps:
+                kw["penwidth"] = steps_thickness
+ shape = "egg" if isinstance(nx_node, NetworkOperation) else "oval"
+ if executed and nx_node in executed:
+ kw["style"] = "filled"
+ kw["fillcolor"] = fill_color
+ node = pydot.Node(name=nx_node.name, shape=shape, **kw)
+
+ _apply_user_props(node, node_props, key=node.get_name())
+
+ append_or_cluster_node(dot, nx_node, node)
+
+ _report_unmatched_user_props(node_props, "node")
+
+ append_any_clusters(dot)
+
+ # draw edges
+ for src, dst in graph.edges:
+ src_name = get_node_name(src)
+ dst_name = get_node_name(dst)
+ kw = {}
+ if isinstance(dst, Operation) and _is_class_value_in_list(
+ dst.needs, optional, src
+ ):
+ kw["style"] = "dashed"
+ edge = pydot.Edge(src=src_name, dst=dst_name, **kw)
+
+ _apply_user_props(edge, edge_props, key=(src, dst))
+
+ dot.add_edge(edge)
+
+ _report_unmatched_user_props(edge_props, "edge")
+
+ # draw steps sequence
+ if steps and len(steps) > 1:
+ it1 = iter(steps)
+ it2 = iter(steps)
+ next(it2)
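+        # `it2` runs one step ahead, so zip() below yields the consecutive
+        # (step, next_step) pairs, drawn as numbered green arrows.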
+ for i, (src, dst) in enumerate(zip(it1, it2), 1):
+ src_name = get_node_name(src)
+ dst_name = get_node_name(dst)
+ edge = pydot.Edge(
+ src=src_name,
+ dst=dst_name,
+ label=str(i),
+ style="dotted",
+ color=steps_color,
+ fontcolor=steps_color,
+ fontname="bold",
+ fontsize=18,
+ penwidth=steps_thickness,
+ arrowhead="vee",
+ )
+ dot.add_edge(edge)
+
+ return dot
+
+
+def supported_plot_formats():
+ """return automatically all `pydot` extensions"""
+ import pydot
+
+ return [".%s" % f for f in pydot.Dot().formats]
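+# e.g. ['.canon', '.cmap', ..., '.pdf', '.png', '.svg', ...] -- the exact
+# list depends on the installed `pydot` & Graphviz versions.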
+
+
+def render_pydot(dot, filename=None, show=False):
+ """
+    Render a *Graphviz* dot in a matplotlib window, write it into a file, or return it for Jupyter.
+
+ :param dot:
+ the pre-built *Graphviz* dot instance
+ :param str filename:
+ Write diagram into a file.
+ Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
+ call :func:`plot.supported_plot_formats()` for more.
+ :param show:
+ If it evaluates to true, opens the diagram in a matplotlib window.
+        If it equals `-1`, it returns the image but does not open a window.
+
+ :return:
+        the matplotlib image if `show` evaluates true, otherwise the `dot`.
+
+ See :meth:`Plotter.plot()` for sample code.
+ """
+ # Save plot
+ #
+ if filename:
+ formats = supported_plot_formats()
+ _basename, ext = os.path.splitext(filename)
+        if ext.lower() not in formats:
+            raise ValueError(
+                "Unknown file format for saving graph: %s\n"
+                "  File extensions must be one of: %s" % (ext, " ".join(formats))
+ )
+
+ dot.write(filename, format=ext.lower()[1:])
+
+ ## Display graph via matplotlib
+ #
+ if show:
+ import matplotlib.pyplot as plt
+ import matplotlib.image as mpimg
+
+ png = dot.create_png()
+ sio = io.BytesIO(png)
+ img = mpimg.imread(sio)
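+        # `show == -1` skips opening a window (handy for tests/headless runs)
+        # but still returns the rendered image.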
+ if show != -1:
+ plt.imshow(img, aspect="equal")
+ plt.show()
+
+ return img
+
+ return dot
+
+
+def legend(filename=None, show=None):
+ """Generate a legend for all plots (see Plotter.plot() for args)"""
+ import pydot
+
+ _monkey_patch_for_jupyter(pydot)
+
+ ## From https://stackoverflow.com/questions/3499056/making-a-legend-key-in-graphviz
+ dot_text = """
+ digraph {
+ rankdir=LR;
+ subgraph cluster_legend {
+ label="Graphkit Legend";
+
+ operation [shape=oval];
+ graphop [shape=egg label="graph operation"];
+ insteps [penwidth=3 label="execution step"];
+ executed [style=filled fillcolor=wheat];
+ operation -> graphop -> insteps -> executed [style=invis];
+
+ data [shape=rect];
+ input [shape=invhouse];
+ output [shape=house];
+ inp_out [shape=hexagon label="inp+out"];
+ evicted [shape=rect penwidth=3 color="#990000"];
+ pinned [shape=rect penwidth=3 color="blue"];
+ evpin [shape=rect penwidth=3 color=purple label="evict+pin"];
+ sol [shape=rect style=filled fillcolor=wheat label="in solution"];
+ data -> input -> output -> inp_out -> evicted -> pinned -> evpin -> sol [style=invis];
+
+ e1 [style=invis] e2 [color=invis label="dependency"];
+ e1 -> e2;
+ e3 [color=invis label="optional"];
+ e2 -> e3 [style=dashed];
+ e4 [color=invis penwidth=3 label="pruned dependency"];
+ e3 -> e4 [color=wheat penwidth=2];
+ e5 [color=invis penwidth=4 label="execution sequence"];
+ e4 -> e5 [color="#009999" penwidth=4 style=dotted arrowhead=vee label=1 fontcolor="#009999"];
+ }
+ }
+ """
+
+ dot = pydot.graph_from_dot_data(dot_text)[0]
+
+ return render_pydot(dot, filename=filename, show=show)
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 00000000..2e5ce9fc
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,10 @@
+## Python's setup.cfg for tool defaults:
+#
+[bdist_wheel]
+universal = 1
+
+
+[tool:pytest]
+# See http://doc.pytest.org/en/latest/mark.html#mark
+markers =
+ slow: marks tests as slow, select them with `-m slow` or `-m 'not slow'`
diff --git a/setup.py b/setup.py
index bd7883f4..bf91ab44 100644
--- a/setup.py
+++ b/setup.py
@@ -19,6 +19,15 @@
with io.open('graphkit/__init__.py', 'rt', encoding='utf8') as f:
version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1)
+plot_reqs = [
+ "matplotlib", # to test plot
+ "pydot", # to test plot
+]
+test_reqs = plot_reqs + [
+ "pytest",
+ "pytest-cov",
+]
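+# (exposed as pip extras below, e.g. `pip install graphkit[plot]`)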
+
setup(
name='graphkit',
version=version,
@@ -28,11 +37,16 @@
author_email='huyng@yahoo-inc.com',
url='http://github.com/yahoo/graphkit',
packages=['graphkit'],
- install_requires=['networkx'],
+ install_requires=[
+ "networkx; python_version >= '3.5'",
+ "networkx == 2.2; python_version < '3.5'",
+        "boltons",  # for IndexedSet
+ ],
extras_require={
- 'plot': ['pydot', 'matplotlib']
+ 'plot': plot_reqs,
+ 'test': test_reqs,
},
- tests_require=['numpy'],
+ tests_require=test_reqs,
license='Apache-2.0',
keywords=['graph', 'computation graph', 'DAG', 'directed acyclical graph'],
classifiers=[
diff --git a/test/test_graphkit.py b/test/test_graphkit.py
index bd97b317..dc9e0a6d 100644
--- a/test/test_graphkit.py
+++ b/test/test_graphkit.py
@@ -3,22 +3,43 @@
import math
import pickle
-
+from operator import add, floordiv, mul, sub
from pprint import pprint
-from operator import add
-from numpy.testing import assert_raises
-import graphkit.network as network
+import pytest
+
import graphkit.modifiers as modifiers
-from graphkit import operation, compose, Operation
+import graphkit.network as network
+from graphkit import Operation, compose, operation
+from graphkit.network import DeleteInstruction
+
+
+def scream(*args, **kwargs):
+ raise AssertionError(
+        "Must not have run!\n  args: %s\n  kwargs: %s" % (args, kwargs))
+
+
+def identity(x):
+ return x
+
+
+def filtdict(d, *keys):
+ """
+ Keep dict items with the given keys
+
+ >>> filtdict({"a": 1, "b": 2}, "b")
+    {'b': 2}
+ """
+ return type(d)(i for i in d.items() if i[0] in keys)
+
-def test_network():
+def test_network_smoke():
# Sum operation, late-bind compute function
sum_op1 = operation(name='sum_op1', needs=['a', 'b'], provides='sum_ab')(add)
# sum_op1 is callable
- print(sum_op1(1, 2))
+ assert sum_op1(1, 2) == 3
# Multiply operation, decorate in-place
@operation(name='mul_op1', needs=['sum_ab', 'b'], provides='sum_ab_times_b')
@@ -26,14 +47,14 @@ def mul_op1(a, b):
return a * b
# mul_op1 is callable
- print(mul_op1(1, 2))
+ assert mul_op1(1, 2) == 2
# Pow operation
@operation(name='pow_op1', needs='sum_ab', provides=['sum_ab_p1', 'sum_ab_p2', 'sum_ab_p3'], params={'exponent': 3})
def pow_op1(a, exponent=2):
return [math.pow(a, y) for y in range(1, exponent+1)]
- print(pow_op1._compute({'sum_ab':2}, ['sum_ab_p2']))
+ assert pow_op1._compute({'sum_ab':2}, ['sum_ab_p2']) == {'sum_ab_p2': 4.0}
# Partial operation that is bound at a later time
partial_op = operation(name='sum_op2', needs=['sum_ab_p1', 'sum_ab_p2'], provides='p1_plus_p2')
@@ -47,7 +68,7 @@ def pow_op1(a, exponent=2):
sum_op3 = sum_op_factory(name='sum_op3', needs=['a', 'b'], provides='sum_ab2')
# sum_op3 is callable
- print(sum_op3(5, 6))
+ assert sum_op3(5, 6) == 11
# compose network
net = compose(name='my network')(sum_op1, mul_op1, pow_op1, sum_op2, sum_op3)
@@ -57,13 +78,24 @@ def pow_op1(a, exponent=2):
#
# get all outputs
- pprint(net({'a': 1, 'b': 2}))
+ exp = {'a': 1,
+ 'b': 2,
+ 'p1_plus_p2': 12.0,
+ 'sum_ab': 3,
+ 'sum_ab2': 3,
+ 'sum_ab_p1': 3.0,
+ 'sum_ab_p2': 9.0,
+ 'sum_ab_p3': 27.0,
+ 'sum_ab_times_b': 6}
+ assert net({'a': 1, 'b': 2}) == exp
# get specific outputs
- pprint(net({'a': 1, 'b': 2}, outputs=["sum_ab_times_b"]))
+ exp = {'sum_ab_times_b': 6}
+ assert net({'a': 1, 'b': 2}, outputs=["sum_ab_times_b"]) == exp
# start with inputs already computed
- pprint(net({"sum_ab": 1, "b": 2}, outputs=["sum_ab_times_b"]))
+ exp = {'sum_ab_times_b': 2}
+ assert net({"sum_ab": 1, "b": 2}, outputs=["sum_ab_times_b"]) == exp
# visualize network graph
# net.plot(show=True)
@@ -75,15 +107,23 @@ def test_network_simple_merge():
sum_op2 = operation(name='sum_op2', needs=['a', 'b'], provides='sum2')(add)
sum_op3 = operation(name='sum_op3', needs=['sum1', 'c'], provides='sum3')(add)
net1 = compose(name='my network 1')(sum_op1, sum_op2, sum_op3)
- pprint(net1({'a': 1, 'b': 2, 'c': 4}))
+
+ exp = {'a': 1, 'b': 2, 'c': 4, 'sum1': 3, 'sum2': 3, 'sum3': 7}
+ sol = net1({'a': 1, 'b': 2, 'c': 4})
+ assert sol == exp
sum_op4 = operation(name='sum_op1', needs=['d', 'e'], provides='a')(add)
sum_op5 = operation(name='sum_op2', needs=['a', 'f'], provides='b')(add)
+
net2 = compose(name='my network 2')(sum_op4, sum_op5)
- pprint(net2({'d': 1, 'e': 2, 'f': 4}))
+ exp = {'a': 3, 'b': 7, 'd': 1, 'e': 2, 'f': 4}
+ sol = net2({'d': 1, 'e': 2, 'f': 4})
+ assert sol == exp
net3 = compose(name='merged')(net1, net2)
- pprint(net3({'c': 5, 'd': 1, 'e': 2, 'f': 4}))
+ exp = {'a': 3, 'b': 7, 'c': 5, 'd': 1, 'e': 2, 'f': 4, 'sum1': 10, 'sum2': 10, 'sum3': 15}
+ sol = net3({'c': 5, 'd': 1, 'e': 2, 'f': 4})
+ assert sol == exp
def test_network_deep_merge():
@@ -92,15 +132,40 @@ def test_network_deep_merge():
sum_op2 = operation(name='sum_op2', needs=['a', 'b'], provides='sum2')(add)
sum_op3 = operation(name='sum_op3', needs=['sum1', 'c'], provides='sum3')(add)
net1 = compose(name='my network 1')(sum_op1, sum_op2, sum_op3)
- pprint(net1({'a': 1, 'b': 2, 'c': 4}))
+
+ exp = {'a': 1, 'b': 2, 'c': 4, 'sum1': 3, 'sum2': 3, 'sum3': 7}
+ assert net1({'a': 1, 'b': 2, 'c': 4}) == exp
sum_op4 = operation(name='sum_op1', needs=['a', 'b'], provides='sum1')(add)
sum_op5 = operation(name='sum_op4', needs=['sum1', 'b'], provides='sum2')(add)
net2 = compose(name='my network 2')(sum_op4, sum_op5)
- pprint(net2({'a': 1, 'b': 2}))
+ exp = {'a': 1, 'b': 2, 'sum1': 3, 'sum2': 5}
+ assert net2({'a': 1, 'b': 2}) == exp
net3 = compose(name='merged', merge=True)(net1, net2)
- pprint(net3({'a': 1, 'b': 2, 'c': 4}))
+ exp = {'a': 1, 'b': 2, 'c': 4, 'sum1': 3, 'sum2': 3, 'sum3': 7}
+ assert net3({'a': 1, 'b': 2, 'c': 4}) == exp
+
+
+def test_network_merge_in_doctests():
+ def abspow(a, p):
+ c = abs(a) ** p
+ return c
+
+ graphop = compose(name="graphop")(
+ operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul),
+ operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub),
+ operation(name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"], params={"p": 3})
+ (abspow)
+ )
+
+ another_graph = compose(name="another_graph")(
+ operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul),
+ operation(name="mul2", needs=["c", "ab"], provides=["cab"])(mul)
+ )
+ merged_graph = compose(name="merged_graph", merge=True)(graphop, another_graph)
+ assert merged_graph.needs
+ assert merged_graph.provides
def test_input_based_pruning():
@@ -138,7 +203,7 @@ def test_output_based_pruning():
sum_op3 = operation(name='sum_op3', needs=['c', 'sum2'], provides='sum3')(add)
net = compose(name='test_net')(sum_op1, sum_op2, sum_op3)
- results = net({'c': c, 'd': d}, outputs=['sum3'])
+ results = net({'a': 0, 'b': 0, 'c': c, 'd': d}, outputs=['sum3'])
# Make sure we got expected result without having to pass a or b.
assert 'sum3' in results
@@ -180,8 +245,230 @@ def test_pruning_raises_for_bad_output():
# Request two outputs we can compute and one we can't compute. Assert
# that this raises a ValueError.
- assert_raises(ValueError, net, {'a': 1, 'b': 2, 'c': 3, 'd': 4},
- outputs=['sum1', 'sum3', 'sum4'])
+ with pytest.raises(ValueError) as exinfo:
+ net({'a': 1, 'b': 2, 'c': 3, 'd': 4},
+ outputs=['sum1', 'sum3', 'sum4'])
+ assert exinfo.match('sum4')
+
+def test_pruning_not_overrides_given_intermediate():
+ # Test #25: v1.2.4 overwrites intermediate data when no output asked
+ pipeline = compose(name="pipeline")(
+ operation(name="not run", needs=["a"], provides=["overriden"])(scream),
+ operation(name="op", needs=["overriden", "c"], provides=["asked"])(add),
+ )
+
+ inputs = {"a": 5, "overriden": 1, "c": 2}
+ exp = {"a": 5, "overriden": 1, "c": 2, "asked": 3}
+ # v1.2.4.ok
+ assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+ # FAILs
+    # - on v1.2.4 with (overriden, asked) = (5, 7) instead of (1, 3)
+ # - on #18(unsatisfied) + #23(ordered-sets) with (overriden, asked) = (5, 7) instead of (1, 3)
+ # FIXED on #26
+ assert pipeline(inputs) == exp
+
+    ## Test OVERWRITES
+ #
+ overwrites = {}
+ pipeline.set_overwrites_collector(overwrites)
+ assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+    assert overwrites == {}  # the unjust op must have been pruned
+
+ overwrites = {}
+ pipeline.set_overwrites_collector(overwrites)
+ assert pipeline(inputs) == exp
+    assert overwrites == {}  # the unjust op must have been pruned
+
+ ## Test Parallel
+ #
+ pipeline.set_execution_method("parallel")
+ overwrites = {}
+ pipeline.set_overwrites_collector(overwrites)
+ #assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+    assert overwrites == {}  # the unjust op must have been pruned
+
+ overwrites = {}
+ pipeline.set_overwrites_collector(overwrites)
+ assert pipeline(inputs) == exp
+    assert overwrites == {}  # the unjust op must have been pruned
+
+
+def test_pruning_multiouts_not_override_intermediates1():
+    # Test #25: v1.2.4 overwrites intermediate data when a previous operation
+ # must run for its other outputs (outputs asked or not)
+ pipeline = compose(name="pipeline")(
+ operation(name="must run", needs=["a"], provides=["overriden", "calced"])
+ (lambda x: (x, 2 * x)),
+ operation(name="add", needs=["overriden", "calced"], provides=["asked"])(add),
+ )
+
+ inputs = {"a": 5, "overriden": 1, "c": 2}
+ exp = {"a": 5, "overriden": 1, "calced": 10, "asked": 11}
+ # FAILs
+ # - on v1.2.4 with (overriden, asked) = (5, 15) instead of (1, 11)
+ # - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4.
+ # FIXED on #26
+ assert pipeline({"a": 5, "overriden": 1}) == exp
+ # FAILs
+ # - on v1.2.4 with KeyError: 'e',
+ # - on #18(unsatisfied) + #23(ordered-sets) with empty result.
+ # FIXED on #26
+ assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+
+    ## Test OVERWRITES
+ #
+ overwrites = {}
+ pipeline.set_overwrites_collector(overwrites)
+ assert pipeline({"a": 5, "overriden": 1}) == exp
+ assert overwrites == {'overriden': 5}
+
+ overwrites = {}
+ pipeline.set_overwrites_collector(overwrites)
+ assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+ assert overwrites == {'overriden': 5}
+
+ ## Test parallel
+ #
+ pipeline.set_execution_method("parallel")
+ assert pipeline({"a": 5, "overriden": 1}) == exp
+ assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+
+
+def test_pruning_multiouts_not_override_intermediates2():
+    # Test #25: v1.2.4 overwrites intermediate data when a previous operation
+ # must run for its other outputs (outputs asked or not)
+ # SPURIOUS FAILS in < PY3.6 due to unordered dicts,
+ # eg https://travis-ci.org/ankostis/graphkit/jobs/594813119
+ pipeline = compose(name="pipeline")(
+ operation(name="must run", needs=["a"], provides=["overriden", "e"])
+ (lambda x: (x, 2 * x)),
+ operation(name="op1", needs=["overriden", "c"], provides=["d"])(add),
+ operation(name="op2", needs=["d", "e"], provides=["asked"])(mul),
+ )
+
+ inputs = {"a": 5, "overriden": 1, "c": 2}
+ exp = {"a": 5, "overriden": 1, "c": 2, "d": 3, "e": 10, "asked": 30}
+ # FAILs
+    # - on v1.2.4 with (overriden, asked) = (5, 70) instead of (1, 30)
+ # - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4.
+ # FIXED on #26
+ assert pipeline(inputs) == exp
+ # FAILs
+ # - on v1.2.4 with KeyError: 'e',
+ # - on #18(unsatisfied) + #23(ordered-sets) with empty result.
+    # FIXED on #26
+    assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+
+    ## Test OVERWRITES
+ #
+ overwrites = {}
+ pipeline.set_overwrites_collector(overwrites)
+ assert pipeline(inputs) == exp
+ assert overwrites == {'overriden': 5}
+
+ overwrites = {}
+ pipeline.set_overwrites_collector(overwrites)
+ assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+ assert overwrites == {'overriden': 5}
+
+ ## Test parallel
+ #
+ pipeline.set_execution_method("parallel")
+ assert pipeline(inputs) == exp
+ assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+
+
+def test_pruning_with_given_intermediate_and_asked_out():
+    # Test #24: v1.2.4 does not prune ops before a given intermediate datum
+    # when no outputs are asked, but wrongly does so when outputs are asked.
+ pipeline = compose(name="pipeline")(
+ operation(name="unjustly pruned", needs=["given-1"], provides=["a"])(identity),
+ operation(name="shortcuted", needs=["a", "b"], provides=["given-2"])(add),
+ operation(name="good_op", needs=["a", "given-2"], provides=["asked"])(add),
+ )
+
+ exp = {"given-1": 5, "b": 2, "given-2": 2, "a": 5, "asked": 7}
+ # v1.2.4 is ok
+ assert pipeline({"given-1": 5, "b": 2, "given-2": 2}) == exp
+ # FAILS
+ # - on v1.2.4 with KeyError: 'a',
+ # - on #18 (unsatisfied) with no result.
+ # FIXED on #18+#26 (new dag solver).
+ assert pipeline({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
+
+    ## Test OVERWRITES
+ #
+ overwrites = {}
+ pipeline.set_overwrites_collector(overwrites)
+ assert pipeline({"given-1": 5, "b": 2, "given-2": 2}) == exp
+ assert overwrites == {}
+
+ overwrites = {}
+ pipeline.set_overwrites_collector(overwrites)
+ assert pipeline({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
+ assert overwrites == {}
+
+ ## Test parallel
+ # FAIL! in #26!
+ #
+ pipeline.set_execution_method("parallel")
+ assert pipeline({"given-1": 5, "b": 2, "given-2": 2}) == exp
+ assert pipeline({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
+
+def test_unsatisfied_operations():
+    # Test that operations with partial inputs are culled and do not fail.
+ pipeline = compose(name="pipeline")(
+ operation(name="add", needs=["a", "b1"], provides=["a+b1"])(add),
+ operation(name="sub", needs=["a", "b2"], provides=["a-b2"])(sub),
+ )
+
+ exp = {"a": 10, "b1": 2, "a+b1": 12}
+ assert pipeline({"a": 10, "b1": 2}) == exp
+ assert pipeline({"a": 10, "b1": 2}, outputs=["a+b1"]) == filtdict(exp, "a+b1")
+
+ exp = {"a": 10, "b2": 2, "a-b2": 8}
+ assert pipeline({"a": 10, "b2": 2}) == exp
+ assert pipeline({"a": 10, "b2": 2}, outputs=["a-b2"]) == filtdict(exp, "a-b2")
+
+ ## Test parallel
+ #
+ pipeline.set_execution_method("parallel")
+ exp = {"a": 10, "b1": 2, "a+b1": 12}
+ assert pipeline({"a": 10, "b1": 2}) == exp
+ assert pipeline({"a": 10, "b1": 2}, outputs=["a+b1"]) == filtdict(exp, "a+b1")
+
+ exp = {"a": 10, "b2": 2, "a-b2": 8}
+ assert pipeline({"a": 10, "b2": 2}) == exp
+ assert pipeline({"a": 10, "b2": 2}, outputs=["a-b2"]) == filtdict(exp, "a-b2")
+
+def test_unsatisfied_operations_same_out():
+ # Test unsatisfied pairs of operations providing the same output.
+ pipeline = compose(name="pipeline")(
+ operation(name="mul", needs=["a", "b1"], provides=["ab"])(mul),
+ operation(name="div", needs=["a", "b2"], provides=["ab"])(floordiv),
+ operation(name="add", needs=["ab", "c"], provides=["ab_plus_c"])(add),
+ )
+
+ exp = {"a": 10, "b1": 2, "c": 1, "ab": 20, "ab_plus_c": 21}
+ assert pipeline({"a": 10, "b1": 2, "c": 1}) == exp
+ assert pipeline({"a": 10, "b1": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
+
+ exp = {"a": 10, "b2": 2, "c": 1, "ab": 5, "ab_plus_c": 6}
+ assert pipeline({"a": 10, "b2": 2, "c": 1}) == exp
+ assert pipeline({"a": 10, "b2": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
+
+ ## Test parallel
+ #
+ # FAIL! in #26
+ pipeline.set_execution_method("parallel")
+ exp = {"a": 10, "b1": 2, "c": 1, "ab": 20, "ab_plus_c": 21}
+ assert pipeline({"a": 10, "b1": 2, "c": 1}) == exp
+ assert pipeline({"a": 10, "b1": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
+ #
+ # FAIL! in #26
+ exp = {"a": 10, "b2": 2, "c": 1, "ab": 5, "ab_plus_c": 6}
+ assert pipeline({"a": 10, "b2": 2, "c": 1}) == exp
+ assert pipeline({"a": 10, "b2": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
def test_optional():
@@ -208,6 +495,68 @@ def addplusplus(a, b, c=0):
assert results['sum'] == sum(named_inputs.values())
+def test_optional_per_function_with_same_output():
+    # Test that the same need can be optional in one operation
+    # and mandatory in another.
+    #
+    ## ATTENTION: the function selected is NOT the one with more inputs,
+    # but the 1st satisfiable one added to the network.
+
+ add_op = operation(name='add', needs=['a', 'b'], provides='a+-b')(add)
+ sub_op_optional = operation(
+ name='sub_opt', needs=['a', modifiers.optional('b')], provides='a+-b'
+ )(lambda a, b=10: a - b)
+
+ # Normal order
+ #
+ pipeline = compose(name='partial_optionals')(add_op, sub_op_optional)
+ #
+ named_inputs = {'a': 1, 'b': 2}
+ assert pipeline(named_inputs) == {'a': 1, 'a+-b': 3, 'b': 2}
+ assert pipeline(named_inputs, ['a+-b']) == {'a+-b': 3}
+ #
+ named_inputs = {'a': 1}
+ assert pipeline(named_inputs) == {'a': 1, 'a+-b': -9}
+ assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -9}
+
+ # Inverse op order
+ #
+ pipeline = compose(name='partial_optionals')(sub_op_optional, add_op)
+ #
+ named_inputs = {'a': 1, 'b': 2}
+ assert pipeline(named_inputs) == {'a': 1, 'a+-b': -1, 'b': 2}
+ assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -1}
+ #
+ named_inputs = {'a': 1}
+ assert pipeline(named_inputs) == {'a': 1, 'a+-b': -9}
+ assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -9}
+
+ # PARALLEL + Normal order
+ #
+ pipeline = compose(name='partial_optionals')(add_op, sub_op_optional)
+ pipeline.set_execution_method("parallel")
+ #
+ named_inputs = {'a': 1, 'b': 2}
+ assert pipeline(named_inputs) == {'a': 1, 'a+-b': 3, 'b': 2}
+ assert pipeline(named_inputs, ['a+-b']) == {'a+-b': 3}
+ #
+ named_inputs = {'a': 1}
+ assert pipeline(named_inputs) == {'a': 1, 'a+-b': -9}
+ assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -9}
+
+ # PARALLEL + Inverse op order
+ #
+ pipeline = compose(name='partial_optionals')(sub_op_optional, add_op)
+ pipeline.set_execution_method("parallel")
+ #
+ named_inputs = {'a': 1, 'b': 2}
+ assert pipeline(named_inputs) == {'a': 1, 'a+-b': -1, 'b': 2}
+ assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -1}
+ #
+ named_inputs = {'a': 1}
+ assert pipeline(named_inputs) == {'a': 1, 'a+-b': -9}
+ assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -9}
+
+
def test_deleted_optional():
# Test that DeleteInstructions included for optionals do not raise
     # exceptions when the corresponding input is not provided.
@@ -226,22 +575,71 @@ def addplusplus(a, b, c=0):
assert 'sum2' in results
+def test_deleteinstructs_vary_with_inputs():
+    # Check #21: DeleteInstruction positions vary when inputs change.
+ def count_deletions(steps):
+ return sum(isinstance(n, DeleteInstruction) for n in steps)
+
+ pipeline = compose(name="pipeline")(
+ operation(name="a free without b", needs=["a"], provides=["aa"])(identity),
+ operation(name="satisfiable", needs=["a", "b"], provides=["ab"])(add),
+ operation(name="optional ab", needs=["aa", modifiers.optional("ab")], provides=["asked"])
+ (lambda a, ab=10: a + ab),
+ )
+
+ inp = {"a": 2, "b": 3}
+    exp = inp.copy()
+    exp.update({"aa": 2, "ab": 5, "asked": 7})
+ res = pipeline(inp)
+ assert res == exp # ok
+ steps11 = pipeline.compile(inp).steps
+ res = pipeline(inp, outputs=["asked"])
+ assert res == filtdict(exp, "asked") # ok
+ steps12 = pipeline.compile(inp, ["asked"]).steps
+
+ inp = {"a": 2}
+    exp = inp.copy()
+    exp.update({"aa": 2, "asked": 12})
+ res = pipeline(inp)
+ assert res == exp # ok
+ steps21 = pipeline.compile(inp).steps
+ res = pipeline(inp, outputs=["asked"])
+ assert res == filtdict(exp, "asked") # ok
+ steps22 = pipeline.compile(inp, ["asked"]).steps
+
+    # When no outputs asked, no delete-instructions are issued.
+ assert steps11 != steps12
+ assert count_deletions(steps11) == 0
+ assert steps21 != steps22
+ assert count_deletions(steps21) == 0
+
+ # Check steps vary with inputs
+ #
+ # FAILs in v1.2.4 + #18, PASS in #26
+ assert steps11 != steps21
+ # Check deletes vary with inputs
+ #
+ # FAILs in v1.2.4 + #18, PASS in #26
+ assert count_deletions(steps12) != count_deletions(steps22)
+
+
+@pytest.mark.slow
def test_parallel_execution():
import time
+ delay = 0.5
+
def fn(x):
- time.sleep(1)
+ time.sleep(delay)
print("fn %s" % (time.time() - t0))
return 1 + x
def fn2(a,b):
- time.sleep(1)
+ time.sleep(delay)
print("fn2 %s" % (time.time() - t0))
return a+b
def fn3(z, k=1):
- time.sleep(1)
+ time.sleep(delay)
print("fn3 %s" % (time.time() - t0))
return z + k
@@ -280,6 +678,7 @@ def fn3(z, k=1):
# make sure results are the same using either method
assert result_sequential == result_threaded
+@pytest.mark.slow
def test_multi_threading():
import time
import random
@@ -310,8 +709,8 @@ def infer(i):
assert tuple(sorted(results.keys())) == tuple(sorted(outputs)), (outputs, results)
return results
- N = 100
- for i in range(20, 200):
+ N = 33
+ for i in range(13, 61):
pool = Pool(i)
pool.map(infer, range(N))
pool.close()
@@ -353,6 +752,7 @@ def compute(self, inputs):
outputs.append(p)
return outputs
+
def test_backwards_compatibility():
sum_op1 = Sum(
diff --git a/test/test_plot.py b/test/test_plot.py
new file mode 100644
index 00000000..d17201cb
--- /dev/null
+++ b/test/test_plot.py
@@ -0,0 +1,190 @@
+# Copyright 2016, Yahoo Inc.
+# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms.
+
+import sys
+from operator import add
+
+import pytest
+
+from graphkit import base, compose, network, operation, plot
+from graphkit.modifiers import optional
+
+
+@pytest.fixture
+def pipeline():
+ return compose(name="netop")(
+ operation(name="add", needs=["a", "b1"], provides=["ab1"])(add),
+ operation(name="sub", needs=["a", optional("b2")], provides=["ab2"])(
+ lambda a, b=1: a - b
+ ),
+ operation(name="abb", needs=["ab1", "ab2"], provides=["asked"])(add),
+ )
+
+
+@pytest.fixture(params=[{"a": 1}, {"a": 1, "b1": 2}])
+def inputs(request):
+ return {"a": 1, "b1": 2}
+
+
+@pytest.fixture(params=[None, ("a", "b1")])
+def input_names(request):
+ return request.param
+
+
+@pytest.fixture(params=[None, ["asked", "b1"]])
+def outputs(request):
+ return request.param
+
+
+@pytest.fixture(params=[None, 1])
+def solution(pipeline, inputs, outputs, request):
+ return request.param and pipeline(inputs, outputs)
+
+
+###### TEST CASES #######
+##
+
+
+def test_plotting_docstring():
+ common_formats = ".png .dot .jpg .jpeg .pdf .svg".split()
+ for ext in common_formats:
+ assert ext in base.NetworkOperation.plot.__doc__
+ assert ext in network.Network.plot.__doc__
+
+
+@pytest.mark.slow
+def test_plot_formats(pipeline, tmp_path):
+ ## Generate all formats (not needing to save files)
+
+    # run it here (and not in a fixture) to ensure `last_plan` exists.
+ inputs = {"a": 1, "b1": 2}
+ outputs = ["asked", "b1"]
+ solution = pipeline(inputs, outputs)
+
+    # The formats in the 1st list do not work on my PC, nor on travis.
+ # NOTE: maintain the other lists manually from the Exception message.
+ failing_formats = ".dia .hpgl .mif .mp .pcl .pic .vtx .xlib".split()
+    # Format-names that produce the same render as a preceding format.
+ dupe_formats = [
+ ".cmapx_np", # .cmapx
+ ".imap_np", # .imap
+ ".jpeg", # .jpe
+ ".jpg", # .jpe
+ ".plain-ext", # .plain
+ ]
+ null_formats = ".cmap .ismap".split()
+ forbidden_formats = set(failing_formats + dupe_formats + null_formats)
+ formats_to_check = sorted(set(plot.supported_plot_formats()) - forbidden_formats)
+
+ # Collect old dots to detect dupes.
+ prev_renders = {}
+ dupe_errs = []
+ for ext in formats_to_check:
+ # Check Network.
+ #
+ render = pipeline.plot(solution=solution).create(format=ext[1:])
+ if not render:
+ dupe_errs.append("\n null: %s" % ext)
+
+ elif render in prev_renders.values():
+ dupe_errs.append(
+ "\n dupe: %s <--> %s"
+ % (ext, [pext for pext, pdot in prev_renders.items() if pdot == render])
+ )
+ else:
+ prev_renders[ext] = render
+
+ if dupe_errs:
+ raise AssertionError("Failed pydot formats: %s" % "".join(sorted(dupe_errs)))
+
+
+def test_plotters_hierarchy(pipeline, inputs, outputs):
+ # Plotting original network, no plan.
+ base_dot = str(pipeline.plot(inputs=inputs, outputs=outputs))
+ assert base_dot
+ assert pipeline.name in str(base_dot)
+
+ solution = pipeline(inputs, outputs)
+
+    # Plotting delegates to the network plan.
+ plan_dot = str(pipeline.plot(inputs=inputs, outputs=outputs))
+ assert plan_dot
+ assert plan_dot != base_dot
+ assert pipeline.name in str(plan_dot)
+
+ # Plot a plan + solution, which must be different from all before.
+ sol_plan_dot = str(pipeline.plot(inputs=inputs, outputs=outputs, solution=solution))
+ assert sol_plan_dot != base_dot
+ assert sol_plan_dot != plan_dot
+    assert pipeline.name in str(sol_plan_dot)
+
+ plan = pipeline.net.last_plan
+ pipeline.net.last_plan = None
+
+    # With `last_plan` reset, check that the original plot is reproduced.
+ base_dot2 = str(pipeline.plot(inputs=inputs, outputs=outputs))
+ assert str(base_dot2) == str(base_dot)
+
+    # Calling plot() directly on the plan misses the netop's name.
+ raw_plan_dot = str(plan.plot(inputs=inputs, outputs=outputs))
+ assert pipeline.name not in str(raw_plan_dot)
+
+    # Check the plan does not contain the solution, unless given.
+ raw_sol_plan_dot = str(plan.plot(inputs=inputs, outputs=outputs, solution=solution))
+ assert raw_sol_plan_dot != raw_plan_dot
+
+
+def test_plot_bad_format(pipeline, tmp_path):
+ with pytest.raises(ValueError, match="Unknown file format") as exinfo:
+ pipeline.plot(filename="bad.format")
+
+    ## Check the help-msg lists all supported formats
+ for ext in plot.supported_plot_formats():
+ assert exinfo.match(ext)
+
+
+def test_plot_write_file(pipeline, tmp_path):
+ # Try saving a file from one format.
+
+ fpath = tmp_path / "network.png"
+ dot1 = pipeline.plot(str(fpath))
+ assert fpath.exists()
+ assert dot1
+
+
+def _check_plt_img(img):
+ assert img is not None
+ assert len(img) > 0
+
+
+def test_plot_matplotlib(pipeline, tmp_path):
+    ## Try the matplotlib window, but without opening one.
+
+ if sys.version_info < (3, 5):
+ # On PY< 3.5 it fails with:
+ # nose.proxy.TclError: no display name and no $DISPLAY environment variable
+ # eg https://travis-ci.org/ankostis/graphkit/jobs/593957996
+ import matplotlib
+
+ matplotlib.use("Agg")
+ # do not open window in headless travis
+ img = pipeline.plot(show=-1)
+ _check_plt_img(img)
+
+
+def test_plot_jupyter(pipeline, tmp_path):
+ ## Try returned Jupyter SVG.
+
+ dot = pipeline.plot()
+ s = dot._repr_svg_()
+ assert "SVG" in s
+
+
+def test_plot_legend(pipeline, tmp_path):
+    ## Try the legend, as a Dot and as a matplotlib image.
+
+ dot = plot.legend()
+ assert dot
+
+ img = plot.legend(show=-1)
+ _check_plt_img(img)