diff --git a/.gitignore b/.gitignore
index db4561ea..ce3d241b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -52,3 +52,8 @@ docs/_build/
 
 # PyBuilder
 target/
+
+# Plots generated when running sample code
+/*.png
+/*.svg
+/*.pdf
\ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
index d8657a8f..025017a7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -2,19 +2,28 @@ language: python
 
 python:
   - "2.7"
-  - "3.4"
   - "3.5"
+  - "3.6"
+  - "3.7"
+
+addons:
+  apt:
+    packages:
+    - graphviz
+
 install:
   - pip install Sphinx sphinx_rtd_theme codecov packaging
   - "python -c $'import os, packaging.version as version\\nv = version.parse(os.environ.get(\"TRAVIS_TAG\", \"1.0\")).public\\nwith open(\"VERSION\", \"w\") as f: f.write(v)'"
-  - python setup.py install
+  - pip install -e .[test]
   - cd docs
   - make clean html
  - cd ..
 
 script:
-  - python setup.py nosetests --with-coverage --cover-package=graphkit
+  - pytest -v --cov=graphkit
+  # In case you adopt -m 'not slow' in setup.cfg.
+  #- pytest -vm slow --cov=graphkit
 
 deploy:
   provider: pypi
diff --git a/README.md b/README.md
index 0e1e95a4..81b1657d 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,8 @@
 
 > It's a DAG all the way down
 
+![Sample graph](docs/source/images/test_pruning_not_overrides_given_intermediate-asked.png "Sample graph")
+
 ## Lightweight computation graphs for Python
 
 GraphKit is a lightweight Python module for creating and running ordered graphs of computations, where the nodes of the graph correspond to computational operations, and the edges correspond to output --> input dependencies between those operations. Such graphs are useful in computer vision, machine learning, and many other domains.
@@ -14,9 +16,12 @@ GraphKit is a lightweight Python module for creating and running ordered graphs
 
 Here's how to install:
 
-```
-pip install graphkit
-```
+    pip install graphkit
+
+OR with dependencies for plotting support (you must also install the [`Graphviz`](https://graphviz.org)
+program separately with your OS tools):
+
+    pip install graphkit[plot]
 
 Here's a Python script with an example GraphKit computation graph that produces multiple outputs (`a * b`, `a - a * b`, and `abs(a - a * b) ** 3`):
 
@@ -30,20 +35,20 @@ def abspow(a, p):
     return c
 
 # Compose the mul, sub, and abspow operations into a computation graph.
-graph = compose(name="graph")(
+graphop = compose(name="graphop")(
    operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul),
    operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub),
    operation(name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"], params={"p": 3})(abspow)
 )
 
 # Run the graph and request all of the outputs.
-out = graph({'a': 2, 'b': 5})
+out = graphop({'a': 2, 'b': 5})
 
 # Prints "{'a': 2, 'a_minus_ab': -8, 'b': 5, 'ab': 10, 'abs_a_minus_ab_cubed': 512}".
 print(out)
 
 # Run the graph and request a subset of the outputs.
-out = graph({'a': 2, 'b': 5}, outputs=["a_minus_ab"])
+out = graphop({'a': 2, 'b': 5}, outputs=["a_minus_ab"])
 
 # Prints "{'a_minus_ab': -8}".
 print(out)
@@ -51,6 +56,21 @@ print(out)
 
 As you can see, any function can be used as an operation in GraphKit, even ones imported from system modules!
 
+
+## Plotting
+
+For debugging the above graph-operation you may plot it using these methods:
+
+```python
+    graphop.plot(show=True, solution=out)  # open a matplotlib window with solution values in nodes
+    graphop.plot("intro.svg")              # other supported formats: png, jpg, pdf, ...
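+
+    # A hedged extra (same names as above, per the `plot()` docs in this PR):
+    # grab the `pydot.Dot` object directly, e.g. to render it in a Jupyter
+    # notebook or post-process it with pydot's API.
+    dot = graphop.plot()
+    dot.write_pdf("intro.pdf")   # pydot's write_<format>() methods also work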
+``` + +![Intro graph](docs/source/images/intro.png "Intro graph") +![Graphkit Legend](docs/source/images/GraphkitLegend.svg "Graphkit Legend") + +> **TIP:** The `pydot.Dot` instances returned by `plot()` are rendered as SVG in *Jupyter/IPython*. + # License Code licensed under the Apache License, Version 2.0 license. See LICENSE file for terms. diff --git a/docs/source/graph_composition.rst b/docs/source/graph_composition.rst index ba428a14..1d8e9f6d 100644 --- a/docs/source/graph_composition.rst +++ b/docs/source/graph_composition.rst @@ -30,15 +30,15 @@ The simplest use case for ``compose`` is assembling a collection of individual o return c # Compose the mul, sub, and abspow operations into a computation graph. - graph = compose(name="graph")( + graphop = compose(name="graphop")( operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul), operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub), operation(name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"], params={"p": 3})(abspow) ) -The call here to ``compose()()`` yields a runnable computation graph that looks like this (where the circles are operations, squares are data, and octagons are parameters): +The call here to ``compose()`` yields a runnable computation graph that looks like this (where the circles are operations, squares are data, and octagons are parameters): -.. image:: images/example_graph.svg +.. image:: images/intro.svg .. _graph-computations: @@ -49,7 +49,7 @@ Running a computation graph The graph composed in the example above in :ref:`simple-graph-composition` can be run by simply calling it with a dictionary argument whose keys correspond to the names of inputs to the graph and whose values are the corresponding input values. For example, if ``graph`` is as defined above, we can run it like this:: # Run the graph and request all of the outputs. - out = graph({'a': 2, 'b': 5}) + out = graphop({'a': 2, 'b': 5}) # Prints "{'a': 2, 'a_minus_ab': -8, 'b': 5, 'ab': 10, 'abs_a_minus_ab_cubed': 512}". print(out) @@ -57,10 +57,10 @@ The graph composed in the example above in :ref:`simple-graph-composition` can b Producing a subset of outputs ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -By default, calling a graph on a set of inputs will yield all of that graph's outputs. You can use the ``outputs`` parameter to request only a subset. For example, if ``graph`` is as above:: +By default, calling a graph-operation on a set of inputs will yield all of that graph's outputs. You can use the ``outputs`` parameter to request only a subset. For example, if ``graphop`` is as above:: - # Run the graph and request a subset of the outputs. - out = graph({'a': 2, 'b': 5}, outputs=["a_minus_ab"]) + # Run the graph-operation and request a subset of the outputs. + out = graphop({'a': 2, 'b': 5}, outputs=["a_minus_ab"]) # Prints "{'a_minus_ab': -8}". print(out) @@ -70,17 +70,17 @@ When using ``outputs`` to request only a subset of a graph's outputs, GraphKit e Short-circuiting a graph computation ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -You can short-circuit a graph computation, making certain inputs unnecessary, by providing a value in the graph that is further downstream in the graph than those inputs. For example, in the graph we've been working with, you could provide the value of ``a_minus_ab`` to make the inputs ``a`` and ``b`` unnecessary:: +You can short-circuit a graph computation, making certain inputs unnecessary, by providing a value in the graph that is further downstream in the graph than those inputs. 
For example, in the graph-operation we've been working with, you could provide the value of ``a_minus_ab`` to make the inputs ``a`` and ``b`` unnecessary:: - # Run the graph and request a subset of the outputs. - out = graph({'a_minus_ab': -8}) + # Run the graph-operation and request a subset of the outputs. + out = graphop({'a_minus_ab': -8}) # Prints "{'a_minus_ab': -8, 'abs_a_minus_ab_cubed': 512}". print(out) When you do this, any ``operation`` nodes that are not on a path from the downstream input to the requested outputs (i.e. predecessors of the downstream input) are not computed. For example, the ``mul1`` and ``sub1`` operations are not executed here. -This can be useful if you have a graph that accepts alternative forms of the same input. For example, if your graph requires a ``PIL.Image`` as input, you could allow your graph to be run in an API server by adding an earlier ``operation`` that accepts as input a string of raw image data and converts that data into the needed ``PIL.Image``. Then, you can either provide the raw image data string as input, or you can provide the ``PIL.Image`` if you have it and skip providing the image data string. +This can be useful if you have a graph-operation that accepts alternative forms of the same input. For example, if your graph-operation requires a ``PIL.Image`` as input, you could allow your graph to be run in an API server by adding an earlier ``operation`` that accepts as input a string of raw image data and converts that data into the needed ``PIL.Image``. Then, you can either provide the raw image data string as input, or you can provide the ``PIL.Image`` if you have it and skip providing the image data string. Adding on to an existing computation graph ------------------------------------------ @@ -109,7 +109,7 @@ Sometimes you will have two computation graphs---perhaps ones that share operati combined_graph = compose(name="combined_graph")(graph1, graph2) -However, if you want to combine graphs that share operations and don't want to pay the price of running redundant computations, you can set the ``merge`` parameter of ``compose()`` to ``True``. This will consolidate redundant ``operation`` nodes (based on ``name``) into a single node. For example, let's say we have ``graph``, as in the examples above, along with this graph:: +However, if you want to combine graphs that share operations and don't want to pay the price of running redundant computations, you can set the ``merge`` parameter of ``compose()`` to ``True``. This will consolidate redundant ``operation`` nodes (based on ``name``) into a single node. For example, let's say we have ``graphop``, as in the examples above, along with this graph:: # This graph shares the "mul1" operation with graph. 
    another_graph = compose(name="another_graph")(
@@ -117,9 +117,9 @@ However, if you want to combine graphs that share operations and don't want to p
         operation(name="mul2", needs=["c", "ab"], provides=["cab"])(mul)
     )
 
-We can merge ``graph`` and ``another_graph`` like so, avoiding a redundant ``mul1`` operation::
+We can merge ``graphop`` and ``another_graph`` like so, avoiding a redundant ``mul1`` operation::
 
-    merged_graph = compose(name="merged_graph", merge=True)(graph, another_graph)
+    merged_graph = compose(name="merged_graph", merge=True)(graphop, another_graph)
 
 This ``merged_graph`` will look like this:
diff --git a/docs/source/images/GraphkitLegend.svg b/docs/source/images/GraphkitLegend.svg
new file mode 100644
index 00000000..798ba3a5
--- /dev/null
+++ b/docs/source/images/GraphkitLegend.svg
@@ -0,0 +1,150 @@
+[SVG markup elided — the "Graphkit Legend" diagram: node kinds (operation, graph
+operation, execution step, executed, data, input, output, inp+out, evicted, pinned,
+evict+pin, in solution) and edge kinds (dependency, optional, pruned dependency,
+execution sequence).]
diff --git a/docs/source/images/intro.svg b/docs/source/images/intro.svg
new file mode 100644
index 00000000..4469543f
--- /dev/null
+++ b/docs/source/images/intro.svg
@@ -0,0 +1,143 @@
+[SVG markup elided — the "graphop" example graph with an "after pruning" cluster:
+operations mul1, sub1, abspow1; data nodes a, b, ab, a_minus_ab, abs_a_minus_ab_cubed;
+numbered execution-sequence edges 1-4.]
diff --git a/docs/source/images/test_pruning_not_overrides_given_intermediate-asked.png b/docs/source/images/test_pruning_not_overrides_given_intermediate-asked.png
new file mode 100644
index 00000000..c8e6cdb4
Binary files /dev/null and b/docs/source/images/test_pruning_not_overrides_given_intermediate-asked.png differ
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5c5e505c..f542da58 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -38,6 +38,12 @@ Here's how to install::
 
     pip install graphkit
 
+OR with dependencies for plotting support (you must also install the `Graphviz
+<https://graphviz.org>`_ program separately with your OS tools)::
+
+    pip install graphkit[plot]
+
+
 Here's a Python script with an example GraphKit computation graph that produces multiple outputs (``a * b``, ``a - a * b``, and ``abs(a - a * b) ** 3``)::
 
     from operator import mul, sub
@@ -49,26 +55,54 @@ Here's a Python script with an example GraphKit computation graph that produces
         return c
 
     # Compose the mul, sub, and abspow operations into a computation graph.
-    graph = compose(name="graph")(
+    graphop = compose(name="graphop")(
         operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul),
         operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub),
         operation(name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"], params={"p": 3})(abspow)
     )
 
-    # Run the graph and request all of the outputs.
-    out = graph({'a': 2, 'b': 5})
+    # Run the graph-operation and request all of the outputs.
+    out = graphop({'a': 2, 'b': 5})
 
     # Prints "{'a': 2, 'a_minus_ab': -8, 'b': 5, 'ab': 10, 'abs_a_minus_ab_cubed': 512}".
     print(out)
 
-    # Run the graph and request a subset of the outputs.
-    out = graph({'a': 2, 'b': 5}, outputs=["a_minus_ab"])
+    # Run the graph-operation and request a subset of the outputs.
+    out = graphop({'a': 2, 'b': 5}, outputs=["a_minus_ab"])
 
     # Prints "{'a_minus_ab': -8}".
     print(out)
 
 As you can see, any function can be used as an operation in GraphKit, even ones imported from system modules!
 
+
+Plotting
+--------
+
+For debugging the above graph-operation you may plot it using these methods::
+
+    graphop.plot(show=True, solution=out)  # open a matplotlib window with solution values in nodes
+    graphop.plot("intro.svg")              # other supported formats: png, jpg, pdf, ...
+
+.. image:: images/intro.svg
+    :alt: Intro graph
+
+.. figure:: images/GraphkitLegend.svg
+    :alt: Graphkit Legend
+
+    The legend for all graphkit diagrams, generated by :func:`graphkit.plot.legend()`.
+
+.. Tip::
+    The ``pydot.Dot`` instances returned by ``plot()`` are rendered
+    directly in *Jupyter/IPython* notebooks as SVG images.
+
+.. NOTE::
+    For plots, the `Graphviz <https://graphviz.org>`_ program must be in your PATH,
+    and the ``pydot`` & ``matplotlib`` Python packages installed.
+    You may install both when installing ``graphkit`` with its ``plot`` extras::
+
+        pip install graphkit[plot]
+
 License
 -------
diff --git a/docs/source/operations.rst b/docs/source/operations.rst
index b7b4dbad..fbd6dea2 100644
--- a/docs/source/operations.rst
+++ b/docs/source/operations.rst
@@ -51,15 +51,15 @@ Let's look again at the operations from the script in :ref:`quick-start`, for ex
         return c
 
     # Compose the mul, sub, and abspow operations into a computation graph.
-    graph = compose(name="graph")(
+    graphop = compose(name="graphop")(
         operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul),
         operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub),
         operation(name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"], params={"p": 3})(abspow)
     )
 
-The ``needs`` and ``provides`` arguments to the operations in this script define a computation graph that looks like this (where the circles are operations, squares are data, and octagons are parameters):
+The ``needs`` and ``provides`` arguments to the operations in this script define a computation graph that looks like this (where ovals are operations and squares/houses are data):
 
-.. image:: images/example_graph.svg
+.. image:: images/intro.svg
 
 
 Constant operation parameters: ``params``
@@ -86,7 +86,7 @@ If you are defining your computation graph and the functions that comprise it al
     def foo(a, b, c):
         return c * (a + b)
 
-    graph = compose(name='foo_graph')(foo)
+    graphop = compose(name='foo_graph')(foo)
 
 Functional specification
 ^^^^^^^^^^^^^^^^^^^^^^^^
@@ -99,7 +99,7 @@ If the functions underlying your computation graph operations are defined elsewh
     add_op = operation(name='add_op', needs=['a', 'b'], provides='sum')(add)
     mul_op = operation(name='mul_op', needs=['c', 'sum'], provides='product')(mul)
 
-    graph = compose(name='add_mul_graph')(add_op, mul_op)
+    graphop = compose(name='add_mul_graph')(add_op, mul_op)
 
 The functional specification is also useful if you want to create multiple ``operation`` instances from the same function, perhaps with different parameter values, e.g.::
 
@@ -111,7 +111,7 @@ The functional specification is also useful if you want to create multiple ``ope
     pow_op1 = operation(name='pow_op1', needs=['a'], provides='a_squared')(mypow)
     pow_op2 = operation(name='pow_op2', needs=['a'], params={'p': 3}, provides='a_cubed')(mypow)
 
-    graph = compose(name='two_pows_graph')(pow_op1, pow_op2)
+    graphop = compose(name='two_pows_graph')(pow_op1, pow_op2)
 
 A slightly different approach can be used here to accomplish the same effect by creating an operation "factory"::
 
@@ -125,7 +125,7 @@ A slightly different approach can be used here to accomplish the same effect by
     pow_op1 = pow_op_factory(name='pow_op1', needs=['a'], provides='a_squared')
     pow_op2 = pow_op_factory(name='pow_op2', needs=['a'], params={'p': 3}, provides='a_cubed')
 
-    graph = compose(name='two_pows_graph')(pow_op1, pow_op2)
+    graphop = compose(name='two_pows_graph')(pow_op1, pow_op2)
 
 
 Modifiers on ``operation`` inputs and outputs
diff --git a/graphkit/base.py b/graphkit/base.py
index 1c04e8d5..5bf35d7e 100644
--- a/graphkit/base.py
+++ b/graphkit/base.py
@@ -1,5 +1,12 @@
 # Copyright 2016, Yahoo Inc.
 # Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms.
+try:
+    from collections import abc
+except ImportError:
+    import collections as abc
+
+from . import plot
+
 
 class Data(object):
     """
@@ -144,35 +151,64 @@ def __repr__(self):
                                              self.provides)
 
 
-class NetworkOperation(Operation):
+class NetworkOperation(Operation, plot.Plotter):
     def __init__(self, **kwargs):
         self.net = kwargs.pop('net')
         Operation.__init__(self, **kwargs)
 
         # set execution mode to single-threaded sequential by default
         self._execution_method = "sequential"
+        self._overwrites_collector = None
+
+    def _build_pydot(self, **kws):
+        """Delegate to the network (or the last plan, if any)."""
+        kws.setdefault("title", self.name)
+        plotter = self.net.last_plan or self.net
+        return plotter._build_pydot(**kws)
 
     def _compute(self, named_inputs, outputs=None):
-        return self.net.compute(outputs, named_inputs, method=self._execution_method)
+        return self.net.compute(
+            named_inputs, outputs, method=self._execution_method,
+            overwrites_collector=self._overwrites_collector)
 
     def __call__(self, *args, **kwargs):
         return self._compute(*args, **kwargs)
 
+    def compile(self, *args, **kwargs):
+        return self.net.compile(*args, **kwargs)
+
     def set_execution_method(self, method):
         """
         Determine how the network will be executed.
 
-        Args:
-            method: str
-                If "parallel", execute graph operations concurrently
-                using a threadpool.
+        :param str method:
+            If "parallel", execute graph operations concurrently
+            using a threadpool.
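+
+        A hedged usage sketch (``graphop`` being a composed
+        NetworkOperation, as in the docs examples)::
+
+            graphop.set_execution_method("parallel")
+            out = graphop({'a': 2, 'b': 5})  # independent ops now run in a thread-pool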
""" - options = ['parallel', 'sequential'] - assert method in options + choices = ['parallel', 'sequential'] + if method not in choices: + raise ValueError( + "Invalid computation method %r! Must be one of %s" + (method, choices)) self._execution_method = method - def plot(self, filename=None, show=False): - self.net.plot(filename=filename, show=show) + def set_overwrites_collector(self, collector): + """ + Asks to put all *overwrites* into the `collector` after computing + + An "overwrites" is intermediate value calculated but NOT stored + into the results, becaues it has been given also as an intemediate + input value, and the operation that would overwrite it MUST run for + its other results. + + :param collector: + a mutable dict to be fillwed with named values + """ + if collector is not None and not isinstance(collector, abc.MutableMapping): + raise ValueError( + "Overwrites collector was not a MutableMapping, but: %r" + % collector) + self._overwrites_collector = collector def __getstate__(self): state = Operation.__getstate__(self) diff --git a/graphkit/functional.py b/graphkit/functional.py index 65388973..5b3735fe 100644 --- a/graphkit/functional.py +++ b/graphkit/functional.py @@ -1,7 +1,7 @@ # Copyright 2016, Yahoo Inc. # Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms. - -from itertools import chain +from boltons.setutils import IndexedSet as iset +import networkx as nx from .base import Operation, NetworkOperation from .network import Network @@ -28,7 +28,7 @@ def _compute(self, named_inputs, outputs=None): result = zip(self.provides, result) if outputs: - outputs = set(outputs) + outputs = sorted(set(outputs)) result = filter(lambda x: x[0] in outputs, result) return dict(result) @@ -185,27 +185,23 @@ def __call__(self, *operations): # If merge is desired, deduplicate operations before building network if self.merge: - merge_set = set() + merge_set = iset() # Preseve given node order. for op in operations: if isinstance(op, NetworkOperation): - net_ops = filter(lambda x: isinstance(x, Operation), op.net.steps) - merge_set.update(net_ops) + netop_nodes = nx.topological_sort(op.net.graph) + merge_set.update(s for s in netop_nodes if isinstance(s, Operation)) else: merge_set.add(op) - operations = list(merge_set) - - def order_preserving_uniquifier(seq, seen=None): - seen = seen if seen else set() - seen_add = seen.add - return [x for x in seq if not (x in seen or seen_add(x))] + operations = merge_set - provides = order_preserving_uniquifier(chain(*[op.provides for op in operations])) - needs = order_preserving_uniquifier(chain(*[op.needs for op in operations]), set(provides)) + provides = iset(p for op in operations for p in op.provides) + # Mark them all as optional, now that #18 calmly ignores + # non-fully satisfied operations. + needs = iset(optional(n) for op in operations for n in op.needs) - provides - # compile network + # Build network net = Network() for op in operations: net.add_op(op) - net.compile() return NetworkOperation(name=self.name, needs=needs, provides=provides, params={}, net=net) diff --git a/graphkit/network.py b/graphkit/network.py index 0df3ddf8..a0dc7663 100644 --- a/graphkit/network.py +++ b/graphkit/network.py @@ -1,19 +1,104 @@ # Copyright 2016, Yahoo Inc. # Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms. - -import time +"""" +The main implementation of the network of operations & data to compute. 
+ +The execution of network *operations* (aka computation) is splitted +in 2 phases: + +- COMPILE: prune unsatisfied nodes, sort dag topologically & solve it, and + derive the *execution steps* (see below) based on the given *inputs* + and asked *outputs*. + +- EXECUTE: sequential or parallel invocation of the underlying functions + of the operations with arguments from the ``solution``. + +Computations are based on 5 data-structures: + +:attr:`Network.graph` + A ``networkx`` graph (yet a DAG) containing interchanging layers of + :class:`Operation` and :class:`DataPlaceholderNode` nodes. + They are layed out and connected by repeated calls of + :meth:`~Network.add_OP`. + + The computation starts with :meth:`~Network._prune_graph()` extracting + a *DAG subgraph* by *pruning* its nodes based on given inputs and + requested outputs in :meth:`~Network.compute()`. + +:attr:`ExecutionPlan.dag` + An directed-acyclic-graph containing the *pruned* nodes as build by + :meth:`~Network._prune_graph()`. This pruned subgraph is used to decide + the :attr:`ExecutionPlan.steps` (below). + The containing :class:`ExecutionPlan.steps` instance is cached + in :attr:`_cached_plans` across runs with inputs/outputs as key. + +:attr:`ExecutionPlan.steps` + It is the list of the operation-nodes only + from the dag (above), topologically sorted, and interspersed with + *instructions steps* needed to complete the run. + It is built by :meth:`~Network._build_execution_steps()` based on + the subgraph dag extracted above. + The containing :class:`ExecutionPlan.steps` instance is cached + in :attr:`_cached_plans` across runs with inputs/outputs as key. + + The *instructions* items achieve the following: + + - :class:`DeleteInstruction`: delete items from `solution` as soon as + they are not needed further down the dag, to reduce memory footprint + while computing. + + - :class:`PinInstruction`: avoid overwritting any given intermediate + inputs, and still allow their providing operations to run + (because they are needed for their other outputs). + +:var solution: + a local-var in :meth:`~Network.compute()`, initialized on each run + to hold the values of the given inputs, generated (intermediate) data, + and output values. + It is returned as is if no specific outputs requested; no data-eviction + happens then. + +:arg overwrites: + The optional argument given to :meth:`~Network.compute()` to colect the + intermediate *calculated* values that are overwritten by intermediate + (aka "pinned") input-values. +""" +import logging import os -import networkx as nx - +import sys +import time +from collections import defaultdict, namedtuple from io import StringIO +from itertools import chain + +import networkx as nx +from boltons.setutils import IndexedSet as iset +from . import plot from .base import Operation +from .modifiers import optional + +log = logging.getLogger(__name__) + + +from networkx import DiGraph +if sys.version_info < (3, 6): + """ + Consistently ordered variant of :class:`~networkx.DiGraph`. + + PY3.6 has inmsertion-order dicts, but PY3.5 has not. + And behvavior *and TCs) in these environments may fail spuriously! + Still *subgraphs* may not patch! + + Fix from: + https://networkx.github.io/documentation/latest/reference/classes/ordered.html#module-networkx.classes.ordered + """ + from networkx import OrderedDiGraph as DiGraph class DataPlaceholderNode(str): """ - A node for the Network graph that describes the name of a Data instance - produced or required by a layer. 
+ Dag node naming a data-value produced or required by an operation. """ def __repr__(self): return 'DataPlaceholderNode("%s")' % self @@ -21,46 +106,65 @@ def __repr__(self): class DeleteInstruction(str): """ - An instruction for the compiled list of evaluation steps to free or delete - a Data instance from the Network's cache after it is no longer needed. + Execution step to delete a computed value from the `solution`. + + It's a step in :attr:`ExecutionPlan.steps` for the data-node `str` that + frees its data-value from `solution` after it is no longer needed, + to reduce memory footprint while computing the graph. """ def __repr__(self): return 'DeleteInstruction("%s")' % self -class Network(object): +class PinInstruction(str): + """ + Execution step to replace a computed value in the `solution` from the inputs, + + and to store the computed one in the ``overwrites`` instead + (both `solution` & ``overwrites`` are local-vars in :meth:`~Network.compute()`). + + It's a step in :attr:`ExecutionPlan.steps` for the data-node `str` that + ensures the corresponding intermediate input-value is not overwritten when + its providing function(s) could not be pruned, because their other outputs + are needed elesewhere. """ - This is the main network implementation. The class contains all of the - code necessary to weave together operations into a directed-acyclic-graph (DAG) - and pass data through. + def __repr__(self): + return 'PinInstruction("%s")' % self + + +class Network(plot.Plotter): """ + Assemble operations & data into a directed-acyclic-graph (DAG) to run them. - def __init__(self, **kwargs): - """ - """ + """ + def __init__(self, **kwargs): # directed graph of layer instances and data-names defining the net. - self.graph = nx.DiGraph() - self._debug = kwargs.get("debug", False) + self.graph = DiGraph() - # this holds the timing information for eache layer + # this holds the timing information for each layer self.times = {} - # a compiled list of steps to evaluate layers *in order* and free mem. - self.steps = [] + #: Speed up :meth:`compile()` call and avoid a multithreading issue(?) + #: that is occuring when accessing the dag in networkx. + self._cached_plans = {} + + #: the execution_plan of the last call to :meth:`compute()` + #: (not ``compile()``!), for debugging purposes. + self.last_plan = None - # This holds a cache of results for the _find_necessary_steps - # function, this helps speed up the compute call as well avoid - # a multithreading issue that is occuring when accessing the - # graph in networkx - self._necessary_steps_cache = {} + def _build_pydot(self, **kws): + from .plot import build_pydot + kws.setdefault('graph', self.graph) + + return build_pydot(**kws) def add_op(self, operation): """ Adds the given operation and its data requirements to the network graph - based on the name of the operation, the names of the operation's needs, and - the names of the data it provides. + based on the name of the operation, the names of the operation's needs, + and the names of the data it provides. :param Operation operation: Operation object to add. 
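+
+        A hedged usage sketch (``compose()`` normally calls this for you;
+        ``mul`` & the names below are just illustrative)::
+
+            net = Network()
+            net.add_op(operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul))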
""" @@ -71,7 +175,9 @@ def add_op(self, operation): assert operation.provides is not None, "Operation's 'provides' must be named" # assert layer is only added once to graph - assert operation not in self.graph.nodes(), "Operation may only be added once" + assert operation not in self.graph.nodes, "Operation may only be added once" + + self._cached_plans = {} # add nodes and edges to graph describing the data needs for this layer for n in operation.needs: @@ -81,174 +187,396 @@ def add_op(self, operation): for p in operation.provides: self.graph.add_edge(operation, DataPlaceholderNode(p)) - # clear compiled steps (must recompile after adding new layers) - self.steps = [] + def _build_execution_steps(self, dag, inputs, outputs): + """ + Create the list of operation-nodes & *instructions* evaluating all - def list_layers(self): - assert self.steps, "network must be compiled before listing layers." - return [(s.name, s) for s in self.steps if isinstance(s, Operation)] - + operations & instructions needed a) to free memory and b) avoid + overwritting given intermediate inputs. - def show_layers(self): - """Shows info (name, needs, and provides) about all layers in this network.""" - for name, step in self.list_layers(): - print("layer_name: ", name) - print("\t", "needs: ", step.needs) - print("\t", "provides: ", step.provides) - print("") + :param dag: + The original dag, pruned; not broken. + :param outputs: + outp-names to decide whether to add (and which) del-instructions + In the list :class:`DeleteInstructions` steps (DA) are inserted between + operation nodes to reduce the memory footprint of solution. + A DA is inserted whenever a *need* is not used by any other *operation* + further down the DAG. + Note that since the `solutions` are not shared across `compute()` calls, + any memory-reductions are for as long as a single computation runs. - def compile(self): - """Create a set of steps for evaluating layers - and freeing memory as necessary""" + """ - # clear compiled steps - self.steps = [] + steps = [] # create an execution order such that each layer's needs are provided. - ordered_nodes = list(nx.dag.topological_sort(self.graph)) + ordered_nodes = iset(nx.topological_sort(dag)) - # add Operations evaluation steps, and instructions to free data. + # Add Operations evaluation steps, and instructions to free and "pin" + # data. for i, node in enumerate(ordered_nodes): if isinstance(node, DataPlaceholderNode): - continue + if node in inputs and dag.pred[node]: + # Command pinning only when there is another operation + # generating this data as output. + steps.append(PinInstruction(node)) elif isinstance(node, Operation): + steps.append(node) - # add layer to list of steps - self.steps.append(node) + # Keep all values in solution if not specific outputs asked. + if not outputs: + continue # Add instructions to delete predecessors as possible. A # predecessor may be deleted if it is a data placeholder that # is no longer needed by future Operations. 
- for predecessor in self.graph.predecessors(node): - if self._debug: - print("checking if node %s can be deleted" % predecessor) - predecessor_still_needed = False + for need in self.graph.pred[node]: + log.debug("checking if node %s can be deleted", need) for future_node in ordered_nodes[i+1:]: - if isinstance(future_node, Operation): - if predecessor in future_node.needs: - predecessor_still_needed = True - break - if not predecessor_still_needed: - if self._debug: - print(" adding delete instruction for %s" % predecessor) - self.steps.append(DeleteInstruction(predecessor)) + if ( + isinstance(future_node, Operation) + and need in future_node.needs + ): + break + else: + if need not in outputs: + log.debug(" adding delete instruction for %s", need) + steps.append(DeleteInstruction(need)) else: - raise TypeError("Unrecognized network graph node") + raise AssertionError("Unrecognized network graph node %r" % node) + + return steps + + def _collect_unsatisfied_operations(self, dag, inputs): + """ + Traverse topologically sorted dag to collect un-satisfied operations. + Unsatisfied operations are those suffering from ANY of the following: - def _find_necessary_steps(self, outputs, inputs): + - They are missing at least one compulsory need-input. + Since the dag is ordered, as soon as we're on an operation, + all its needs have been accounted, so we can get its satisfaction. + + - Their provided outputs are not linked to any data in the dag. + An operation might not have any output link when :meth:`_prune_graph()` + has broken them, due to given intermediate inputs. + + :param dag: + a graph with broken edges those arriving to existing inputs + :param inputs: + an iterable of the names of the input values + return: + a list of unsatisfied operations to prune """ - Determines what graph steps need to pe run to get to the requested - outputs from the provided inputs. Eliminates steps that come before - (in topological order) any inputs that have been provided. Also - eliminates steps that are not on a path from the provided inputs to - the requested outputs. + # To collect data that will be produced. + ok_data = set(inputs) + # To colect the map of operations --> satisfied-needs. + op_satisfaction = defaultdict(set) + # To collect the operations to drop. + unsatisfied = [] + for node in nx.topological_sort(dag): + if isinstance(node, Operation): + if not dag.adj[node]: + # Prune operations that ended up providing no output. + unsatisfied.append(node) + else: + real_needs = set(n for n in node.needs + if not isinstance(n, optional)) + if real_needs.issubset(op_satisfaction[node]): + # We have a satisfied operation; mark its output-data + # as ok. + ok_data.update(dag.adj[node]) + else: + # Prune operations with partial inputs. + unsatisfied.append(node) + elif isinstance(node, (DataPlaceholderNode, str)): # `str` are givens + if node in ok_data: + # mark satisfied-needs on all future operations + for future_op in dag.adj[node]: + op_satisfaction[future_op].add(node) + else: + raise AssertionError("Unrecognized network graph node %r" % node) - :param list outputs: + return unsatisfied + + + def _prune_graph(self, outputs, inputs): + """ + Determines what graph steps need to run to get to the requested + outputs from the provided inputs. : + - Eliminate steps that are not on a path arriving to requested outputs. + - Eliminate unsatisfied operations: partial inputs or no outputs needed. + + :param iterable outputs: A list of desired output names. 
This can also be ``None``, in which case the necessary steps are all graph nodes that are reachable from one of the provided inputs. - :param dict inputs: - A dictionary mapping names to values for all provided inputs. + :param iterable inputs: + The inputs names of all given inputs. - :returns: - Returns a list of all the steps that need to be run for the - provided inputs and requested outputs. + :return: + the *pruned_dag* """ + dag = self.graph + + # Ignore input names that aren't in the graph. + graph_inputs = iset(dag.nodes) & inputs # preserve order + + # Scream if some requested outputs aren't in the graph. + unknown_outputs = iset(outputs) - dag.nodes + if unknown_outputs: + raise ValueError( + "Unknown output node(s) requested: %s" + % ", ".join(unknown_outputs)) + + broken_dag = dag.copy() # preserve net's graph + + # Break the incoming edges to all given inputs. + # + # Nodes producing any given intermediate inputs are unecessary + # (unless they are also used elsewhere). + # To discover which ones to prune, we break their incoming edges + # and they will drop out while collecting ancestors from the outputs. + broken_edges = set() # unordered, not iterated + for given in graph_inputs: + broken_edges.update(broken_dag.in_edges(given)) + broken_dag.remove_edges_from(broken_edges) + + # Drop stray input values and operations (if any). + broken_dag.remove_nodes_from(nx.isolates(broken_dag)) + + if outputs: + # If caller requested specific outputs, we can prune any + # unrelated nodes further up the dag. + ending_in_outputs = set() + for input_name in outputs: + ending_in_outputs.update(nx.ancestors(dag, input_name)) + broken_dag = broken_dag.subgraph(ending_in_outputs | set(outputs)) + + + # Prune unsatisfied operations (those with partial inputs or no outputs). + unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs) + # Clone it so that it is picklable. + pruned_dag = dag.subgraph(broken_dag.nodes - unsatisfied).copy() + + return pruned_dag, tuple(broken_edges) + + def compile(self, inputs=(), outputs=()): + """ + Create or get from cache an execution-plan for the given inputs/outputs. - # return steps if it has already been computed before for this set of inputs and outputs - outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set)) else outputs - inputs_keys = tuple(sorted(inputs.keys())) - cache_key = (inputs_keys, outputs) - if cache_key in self._necessary_steps_cache: - return self._necessary_steps_cache[cache_key] + See :meth:`_prune_graph()` and :meth:`_build_execution_steps()` + for detailed description. - graph = self.graph - if not outputs: + :param inputs: + An iterable with the names of all the given inputs. - # If caller requested all outputs, the necessary nodes are all - # nodes that are reachable from one of the inputs. Ignore input - # names that aren't in the graph. - necessary_nodes = set() - for input_name in iter(inputs): - if graph.has_node(input_name): - necessary_nodes |= nx.descendants(graph, input_name) + :param outputs: + (optional) An iterable or the name of the output name(s). + If missing, requested outputs assumed all graph reachable nodes + from one of the given inputs. 
+ :return: + the cached or fresh new execution-plan + """ + # outputs must be iterable + if not outputs: + outputs = () + elif isinstance(outputs, str): + outputs = (outputs, ) + + # Make a stable cache-key + cache_key = (tuple(sorted(inputs)), tuple(sorted(outputs))) + if cache_key in self._cached_plans: + # An execution plan has been compiled before + # for the same inputs & outputs. + plan = self._cached_plans[cache_key] else: + # Build a new execution plan for the given inputs & outputs. + # + pruned_dag, broken_edges = self._prune_graph(outputs, inputs) + steps = self._build_execution_steps(pruned_dag, inputs, outputs) + plan = ExecutionPlan( + self, + tuple(inputs), + outputs, + pruned_dag, + tuple(broken_edges), + tuple(steps), + executed=iset(), + ) + + # Cache compilation results to speed up future runs + # with different values (but same number of inputs/outputs). + self._cached_plans[cache_key] = plan + + return plan + + def compute( + self, named_inputs, outputs, method=None, overwrites_collector=None): + """ + Solve & execute the graph, sequentially or parallel. - # If the caller requested a subset of outputs, find any nodes that - # are made unecessary because we were provided with an input that's - # deeper into the network graph. Ignore input names that aren't - # in the graph. - unnecessary_nodes = set() - for input_name in iter(inputs): - if graph.has_node(input_name): - unnecessary_nodes |= nx.ancestors(graph, input_name) + :param dict named_inputs: + A dict of key/value pairs where the keys represent the data nodes + you want to populate, and the values are the concrete values you + want to set for the data node. - # Find the nodes we need to be able to compute the requested - # outputs. Raise an exception if a requested output doesn't - # exist in the graph. - necessary_nodes = set() - for output_name in outputs: - if not graph.has_node(output_name): - raise ValueError("graphkit graph does not have an output " - "node named %s" % output_name) - necessary_nodes |= nx.ancestors(graph, output_name) + :param list output: + once all necessary computations are complete. + If you set this variable to ``None``, all data nodes will be kept + and returned at runtime. - # Get rid of the unnecessary nodes from the set of necessary ones. - necessary_nodes -= unnecessary_nodes + :param method: + if ``"parallel"``, launches multi-threading. + Set when invoking a composed graph or by + :meth:`~NetworkOperation.set_execution_method()`. + :param overwrites_collector: + (optional) a mutable dict to be fillwed with named values. + If missing, values are simply discarded. - necessary_steps = [step for step in self.steps if step in necessary_nodes] + :returns: a dictionary of output data objects, keyed by name. + """ - # save this result in a precomputed cache for future lookup - self._necessary_steps_cache[cache_key] = necessary_steps + assert isinstance(outputs, (list, tuple)) or outputs is None,\ + "The outputs argument must be a list" - # Return an ordered list of the needed steps. - return necessary_steps + # Build the execution plan. + self.last_plan = plan = self.compile(named_inputs.keys(), outputs) + # start with fresh data solution. + solution = dict(named_inputs) - def compute(self, outputs, named_inputs, method=None): - """ - Run the graph. Any inputs to the network must be passed in by name. + plan.execute(solution, overwrites_collector, method) - :param list output: The names of the data node you'd like to have returned - once all necessary computations are complete. 
- If you set this variable to ``None``, all - data nodes will be kept and returned at runtime. + if outputs: + # Filter outputs to just return what's requested. + # Otherwise, eturn the whole solution as output, + # including input and intermediate data nodes. + # TODO: assert no other outputs exists due to DelInstructs. + solution = dict(i for i in solution.items() if i[0] in outputs) - :param dict named_inputs: A dict of key/value pairs where the keys - represent the data nodes you want to populate, - and the values are the concrete values you - want to set for the data node. + return solution - :returns: a dictionary of output data objects, keyed by name. +class ExecutionPlan( + namedtuple("_ExePlan", "net inputs outputs dag broken_edges steps executed"), + plot.Plotter +): + """ + The result of the network's compilation phase. + + Note the execution plan's attributes are on purpose immutable tuples. + + :ivar net: + The parent :class:`Network` + + :ivar inputs: + A tuple with the names of the given inputs used to construct the plan. + + :ivar outputs: + A (possibly empy) tuple with the names of the requested outputs + used to construct the plan. + + :ivar dag: + The regular (not broken) *pruned* subgraph of net-graph. + + :ivar broken_edges: + Tuple of broken incoming edges to given data. + + :ivar steps: + The tuple of operation-nodes & *instructions* needed to evaluate + the given inputs & asked outputs, free memory and avoid overwritting + any given intermediate inputs. + :ivar executed: + An empty set to collect all operations that have been executed so far. + """ + + @property + def broken_dag(self): + return nx.restricted_view(self.dag, nodes=(), edges=self.broken_edges) + + def _build_pydot(self, **kws): + from .plot import build_pydot + + clusters = None + if self.dag.nodes != self.net.graph.nodes: + clusters = {n: "after prunning" for n in self.dag.nodes} + mykws = { + "graph": self.net.graph, + "steps": self.steps, + "inputs": self.inputs, + "outputs": self.outputs, + "executed": self.executed, + "edge_props": {e: {"color": "wheat", "penwidth": 2} for e in self.broken_edges}, + "clusters": clusters, + } + mykws.update(kws) + + return build_pydot(**mykws) + + def __repr__(self): + return ( + "ExecutionPlan:\n +--inputs:%s, \n +--outputs=%s\n +--steps=%s)" + % (self.inputs, self.outputs, self.steps)) + + def get_data_node(self, name): + """ + Retuen the data node from a graph using its name, or None. """ + node = self.dag.nodes[name] + if isinstance(node, DataPlaceholderNode): + return node - # assert that network has been compiled - assert self.steps, "network must be compiled before calling compute." - assert isinstance(outputs, (list, tuple)) or outputs is None,\ - "The outputs argument must be a list" + def _can_schedule_operation(self, op): + """ + Determines if a Operation is ready to be scheduled for execution + based on what has already been executed. - # choose a method of execution - if method == "parallel": - return self._compute_thread_pool_barrier_method(named_inputs, - outputs) - else: - return self._compute_sequential_method(named_inputs, - outputs) + :param op: + The Operation object to check + :return: + A boolean indicating whether the operation may be scheduled for + execution based on what has already been executed. + """ + # Use `broken_dag` to allow executing operations after given inputs + # regardless of whether their producers have yet to run. 
+ dependencies = set(n for n in nx.ancestors(self.broken_dag, op) + if isinstance(n, Operation)) + return dependencies.issubset(self.executed) + def _can_evict_value(self, name): + """ + Determines if a DataPlaceholderNode is ready to be deleted from solution. - def _compute_thread_pool_barrier_method(self, named_inputs, outputs, - thread_pool_size=10): + :param name: + The name of the data node to check + :return: + A boolean indicating whether the data node can be deleted or not. + """ + data_node = self.get_data_node(name) + # Use `broken_dag` not to block a successor waiting for this data, + # since in any case will use a given input, not some pipe of this data. + return data_node and set( + self.broken_dag.successors(data_node)).issubset(self.executed) + + def _pin_data_in_solution(self, value_name, solution, inputs, overwrites): + value_name = str(value_name) + if overwrites is not None: + overwrites[value_name] = solution[value_name] + solution[value_name] = inputs[value_name] + + def _execute_thread_pool_barrier_method(self, inputs, solution, overwrites, + thread_pool_size=10 + ): """ This method runs the graph using a parallel pool of thread executors. You may achieve lower total latency if your graph is sufficiently @@ -257,41 +585,42 @@ def _compute_thread_pool_barrier_method(self, named_inputs, outputs, from multiprocessing.dummy import Pool # if we have not already created a thread_pool, create one - if not hasattr(self, "_thread_pool"): - self._thread_pool = Pool(thread_pool_size) - pool = self._thread_pool - - cache = {} - cache.update(named_inputs) - necessary_nodes = self._find_necessary_steps(outputs, named_inputs) - - # this keeps track of all nodes that have already executed - has_executed = set() + if not hasattr(self.net, "_thread_pool"): + self.net._thread_pool = Pool(thread_pool_size) + pool = self.net._thread_pool # with each loop iteration, we determine a set of operations that can be # scheduled, then schedule them onto a thread pool, then collect their - # results onto a memory cache for use upon the next iteration. + # results onto a memory solution for use upon the next iteration. while True: # the upnext list contains a list of operations for scheduling # in the current round of scheduling upnext = [] - for node in necessary_nodes: - # only delete if all successors for the data node have been executed - if isinstance(node, DeleteInstruction): - if ready_to_delete_data_node(node, - has_executed, - self.graph): - if node in cache: - cache.pop(node) - - # continue if this node is anything but an operation node - if not isinstance(node, Operation): - continue - - if ready_to_schedule_operation(node, has_executed, self.graph) \ - and node not in has_executed: + for node in self.steps: + if ( + isinstance(node, Operation) + and self._can_schedule_operation(node) + and node not in self.executed + ): upnext.append(node) + elif isinstance(node, DeleteInstruction): + # Only delete if all successors for the data node + # have been executed. + # An optional need may not have a value in the solution. + if ( + node in solution + and self._can_evict_value(node) + ): + log.debug("removing data '%s' from solution.", node) + del solution[node] + elif isinstance(node, PinInstruction): + # Always and repeatedely pin the value, even if not all + # providers of the data have executed. + # An optional need may not have a value in the solution. 
+ if node in solution: + self._pin_data_in_solution( + node, solution, inputs, overwrites) # stop if no nodes left to schedule, exit out of the loop @@ -299,198 +628,76 @@ def _compute_thread_pool_barrier_method(self, named_inputs, outputs, break done_iterator = pool.imap_unordered( - lambda op: (op,op._compute(cache)), + lambda op: (op,op._compute(solution)), upnext) for op, result in done_iterator: - cache.update(result) - has_executed.add(op) + solution.update(result) + self.executed.add(op) - if not outputs: - return cache - else: - return {k: cache[k] for k in iter(cache) if k in outputs} - def _compute_sequential_method(self, named_inputs, outputs): + def _execute_sequential_method(self, inputs, solution, overwrites): """ This method runs the graph one operation at a time in a single thread """ - # start with fresh data cache - cache = {} - - # add inputs to data cache - cache.update(named_inputs) - - # Find the subset of steps we need to run to get to the requested - # outputs from the provided inputs. - all_steps = self._find_necessary_steps(outputs, named_inputs) - self.times = {} - for step in all_steps: + for step in self.steps: if isinstance(step, Operation): - if self._debug: - print("-"*32) - print("executing step: %s" % step.name) + log.debug("%sexecuting step: %s", "-"*32, step.name) # time execution... t0 = time.time() # compute layer outputs - layer_outputs = step._compute(cache) + layer_outputs = step._compute(solution) - # add outputs to cache - cache.update(layer_outputs) + # add outputs to solution + solution.update(layer_outputs) + self.executed.add(step) # record execution time t_complete = round(time.time() - t0, 5) self.times[step.name] = t_complete - if self._debug: - print("step completion time: %s" % t_complete) + log.debug("step completion time: %s", t_complete) - # Process DeleteInstructions by deleting the corresponding data - # if possible. elif isinstance(step, DeleteInstruction): + # Cache value may be missing if it is optional. + if step in solution: + log.debug("removing data '%s' from solution.", step) + del solution[step] - if outputs and step not in outputs: - # Some DeleteInstruction steps may not exist in the cache - # if they come from optional() needs that are not privoded - # as inputs. Make sure the step exists before deleting. - if step in cache: - if self._debug: - print("removing data '%s' from cache." % step) - cache.pop(step) - + elif isinstance(step, PinInstruction): + self._pin_data_in_solution(step, solution, inputs, overwrites) else: - raise TypeError("Unrecognized instruction.") + raise AssertionError("Unrecognized instruction.%r" % step) - if not outputs: - # Return the whole cache as output, including input and - # intermediate data nodes. - return cache - - else: - # Filter outputs to just return what's needed. - # Note: list comprehensions exist in python 2.7+ - return {k: cache[k] for k in iter(cache) if k in outputs} - - - def plot(self, filename=None, show=False): + def execute(self, solution, overwrites=None, method=None): """ - Plot the graph. - - params: - :param str filename: - Write the output to a png, pdf, or graphviz dot file. The extension - controls the output format. 
- - :param boolean show: - If this is set to True, use matplotlib to show the graph diagram - (Default: False) - - :returns: - An instance of the pydot graph - + :param solution: + a mutable maping to collect the results and that must contain also + the given input values for at least the compulsory inputs that + were specified when the plan was built (but cannot enforce that!). + + :param overwrites: + (optional) a mutable dict to collect calculated-but-discarded values + because they were "pinned" by input vaules. + If missing, the overwrites values are simply discarded. """ - import pydot - import matplotlib.pyplot as plt - import matplotlib.image as mpimg - - assert self.graph is not None - - def get_node_name(a): - if isinstance(a, DataPlaceholderNode): - return a - return a.name - - g = pydot.Dot(graph_type="digraph") - - # draw nodes - for nx_node in self.graph.nodes(): - if isinstance(nx_node, DataPlaceholderNode): - node = pydot.Node(name=nx_node, shape="rect") - else: - node = pydot.Node(name=nx_node.name, shape="circle") - g.add_node(node) - - # draw edges - for src, dst in self.graph.edges(): - src_name = get_node_name(src) - dst_name = get_node_name(dst) - edge = pydot.Edge(src=src_name, dst=dst_name) - g.add_edge(edge) - - # save plot - if filename: - basename, ext = os.path.splitext(filename) - with open(filename, "w") as fh: - if ext.lower() == ".png": - fh.write(g.create_png()) - elif ext.lower() == ".dot": - fh.write(g.to_string()) - elif ext.lower() in [".jpg", ".jpeg"]: - fh.write(g.create_jpeg()) - elif ext.lower() == ".pdf": - fh.write(g.create_pdf()) - elif ext.lower() == ".svg": - fh.write(g.create_svg()) - else: - raise Exception("Unknown file format for saving graph: %s" % ext) + # Clean executed operation from any previous execution. + self.executed.clear() - # display graph via matplotlib - if show: - png = g.create_png() - sio = StringIO(png) - img = mpimg.imread(sio) - plt.imshow(img, aspect="equal") - plt.show() - - return g - - -def ready_to_schedule_operation(op, has_executed, graph): - """ - Determines if a Operation is ready to be scheduled for execution based on - what has already been executed. - - Args: - op: - The Operation object to check - has_executed: set - A set containing all operations that have been executed so far - graph: - The networkx graph containing the operations and data nodes - Returns: - A boolean indicating whether the operation may be scheduled for - execution based on what has already been executed. - """ - dependencies = set(filter(lambda v: isinstance(v, Operation), - nx.ancestors(graph, op))) - return dependencies.issubset(has_executed) + # choose a method of execution + executor = (self._execute_thread_pool_barrier_method + if method == "parallel" else + self._execute_sequential_method) -def ready_to_delete_data_node(name, has_executed, graph): - """ - Determines if a DataPlaceholderNode is ready to be deleted from the - cache. + # clone and keep orignal inputs in solution intact + executor(dict(solution), solution, overwrites) - Args: - name: - The name of the data node to check - has_executed: set - A set containing all operations that have been executed so far - graph: - The networkx graph containing the operations and data nodes - Returns: - A boolean indicating whether the data node can be deleted or not. - """ - data_node = get_data_node(name, graph) - return set(graph.successors(data_node)).issubset(has_executed) + # return it, but caller can also see the results in `solution` dict. 
+ return solution -def get_data_node(name, graph): - """ - Gets a data node from a graph using its name - """ - for node in graph.nodes(): - if node == name and isinstance(node, DataPlaceholderNode): - return node - return None +# TODO: maybe class Solution(object): +# values = {} +# overwrites = None diff --git a/graphkit/plot.py b/graphkit/plot.py new file mode 100644 index 00000000..69138d06 --- /dev/null +++ b/graphkit/plot.py @@ -0,0 +1,411 @@ +# Copyright 2016, Yahoo Inc. +# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms. + +import io +import logging +import os + + +log = logging.getLogger(__name__) + + +class Plotter(object): + """ + Classes wishing to plot their graphs should inherit this and ... + + implement property ``_plot`` to return a "partial" callable that somehow + ends up calling :func:`plot.plot_graph()` with the `graph` or any other + args binded appropriately. + The purpose is to avoid copying this function & documentation here around. + """ + + def plot(self, filename=None, show=False, **kws): + """ + :param str filename: + Write diagram into a file. + Common extensions are ``.png .dot .jpg .jpeg .pdf .svg`` + call :func:`plot.supported_plot_formats()` for more. + :param show: + If it evaluates to true, opens the diagram in a matplotlib window. + If it equals `-1`, it plots but does not open the Window. + :param inputs: + an optional name list, any nodes in there are plotted + as a "house" + :param outputs: + an optional name list, any nodes in there are plotted + as an "inverted-house" + :param solution: + an optional dict with values to annotate nodes, drawn "filled" + (currently content not shown, but node drawn as "filled") + :param executed: + an optional container with operations executed, drawn "filled" + :param title: + an optional string to display at the bottom of the graph + :param node_props: + an optional nested dict of Grapvhiz attributes for certain nodes + :param edge_props: + an optional nested dict of Grapvhiz attributes for certain edges + :param clusters: + an optional mapping of nodes --> cluster-names, to group them + + :return: + A ``pydot.Dot`` instance. + NOTE that the returned instance is monkeypatched to support + direct rendering in *jupyter cells* as SVG. + + + Note that the `graph` argument is absent - Each Plotter provides + its own graph internally; use directly :func:`plot_graph()` to provide + a different graph. + + **Legend:** + + . figure:: ../images/Graphkitlegend.svg + :alt: Graphkit Legend + + see :func:`legend()` + + *NODES:* + + oval + function + egg + subgraph operation + house + given input + inversed-house + asked output + polygon + given both as input & asked as output (what?) + square + intermediate data, neither given nor asked. + red frame + delete-instruction, to free up memory. + blue frame + pinned-instruction, not to overwrite intermediate inputs. + filled + data node has a value in `solution` OR function has been executed. + thick frame + function/data node in execution `steps`. 
+
+        *ARROWS:*
+
+        solid black arrows
+            dependencies (source-data are ``need``-ed by target-operations;
+            source-operations ``provide`` target-data)
+        dashed black arrows
+            optional needs
+        wheat arrows
+            broken dependency (``provide``) during pruning
+        green-dotted arrows
+            execution steps labeled in succession
+
+        **Sample code:**
+
+        >>> from operator import add
+        >>> from graphkit import compose, operation
+        >>> from graphkit.modifiers import optional
+
+        >>> graphop = compose(name="graphop")(
+        ...     operation(name="add", needs=["a", "b1"], provides=["ab1"])(add),
+        ...     operation(name="sub", needs=["a", optional("b2")], provides=["ab2"])(lambda a, b=1: a-b),
+        ...     operation(name="abb", needs=["ab1", "ab2"], provides=["asked"])(add),
+        ... )
+
+        >>> graphop.plot(show=True);           # plot just the graph in a matplotlib window
+        >>> inputs = {'a': 1, 'b1': 2}
+        >>> solution = graphop(inputs)         # now plots will include the execution-plan
+
+        >>> graphop.plot('plot1.svg', inputs=inputs, outputs=['asked', 'b1'], solution=solution);
+        >>> graphop.plot(solution=solution)    # just get the `pydot.Dot` object, renderable in Jupyter
+        """
+        dot = self._build_pydot(**kws)
+        return render_pydot(dot, filename=filename, show=show)
+
+    def _build_pydot(self, **kws):
+        raise AssertionError("Must implement `_build_pydot()`!")
+
+
+def _is_class_value_in_list(lst, cls, value):
+    return any(isinstance(i, cls) and i == value for i in lst)
+
+
+def _merge_conditions(*conds):
+    """combines conditions as a choice in binary range, e.g. 2 conds --> [0, 3]"""
+    return sum(int(bool(c)) << i for i, c in enumerate(conds))
+
+
+def _apply_user_props(dotobj, user_props, key):
+    if user_props and key in user_props:
+        dotobj.get_attributes().update(user_props[key])
+        # Delete it, to report unmatched ones, AND not to annotate `steps`.
+        del user_props[key]
+
+
+def _report_unmatched_user_props(user_props, kind):
+    if user_props and log.isEnabledFor(logging.WARNING):
+        unmatched = "\n ".join(str(i) for i in user_props.items())
+        log.warning("Unmatched `%s_props`:\n +--%s", kind, unmatched)
+
+
+def _monkey_patch_for_jupyter(pydot):
+    # Ensure Dot instances render in Jupyter
+    # (see pydot/pydot#220)
+    if not hasattr(pydot.Dot, "_repr_svg_"):
+
+        def make_svg(self):
+            return self.create_svg().decode()
+
+        # monkey patch class
+        pydot.Dot._repr_svg_ = make_svg
+
+
+def build_pydot(
+    graph,
+    steps=None,
+    inputs=None,
+    outputs=None,
+    solution=None,
+    executed=None,
+    title=None,
+    node_props=None,
+    edge_props=None,
+    clusters=None,
+):
+    """
+    Build a *Graphviz* diagram out of a Network graph/steps/inputs/outputs and return it.
+
+    See :meth:`Plotter.plot()` for the arguments, sample code, and
+    the legend of the plots.
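+
+    A usage sketch (the ``netop.net.graph`` attribute below is an assumption --
+    any *networkx* graph of operation/data nodes may be passed):
+
+    >>> dot = build_pydot(netop.net.graph, title="sample")  # doctest: +SKIP
+    >>> dot.write_png("sample.png")                         # doctest: +SKIP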
+    """
+    import pydot
+    from .base import NetworkOperation, Operation
+    from .modifiers import optional
+    from .network import DeleteInstruction, PinInstruction
+
+    _monkey_patch_for_jupyter(pydot)
+
+    assert graph is not None
+
+    steps_thickness = 3
+    fill_color = "wheat"
+    steps_color = "#009999"
+    new_clusters = {}
+
+    def append_or_cluster_node(dot, nx_node, node):
+        if not clusters or nx_node not in clusters:
+            dot.add_node(node)
+        else:
+            cluster_name = clusters[nx_node]
+            node_cluster = new_clusters.get(cluster_name)
+            if not node_cluster:
+                node_cluster = new_clusters[cluster_name] = pydot.Cluster(
+                    cluster_name, label=cluster_name
+                )
+            node_cluster.add_node(node)
+
+    def append_any_clusters(dot):
+        for cluster in new_clusters.values():
+            dot.add_subgraph(cluster)
+
+    def get_node_name(a):
+        if isinstance(a, Operation):
+            return a.name
+        return a
+
+    dot = pydot.Dot(graph_type="digraph", label=title, fontname="italic")
+
+    # draw nodes
+    for nx_node in graph.nodes:
+        if isinstance(nx_node, str):
+            kw = {}
+
+            # FrameColor change by step type
+            if steps and nx_node in steps:
+                choice = _merge_conditions(
+                    _is_class_value_in_list(steps, DeleteInstruction, nx_node),
+                    _is_class_value_in_list(steps, PinInstruction, nx_node),
+                )
+                # Index 0 ("NOPE") is unreachable: `nx_node in steps` implies
+                # that at least one of the instruction-conditions above matched.
+                color = "NOPE #990000 blue purple".split()[choice]
+                kw = {"color": color, "penwidth": steps_thickness}
+
+            # SHAPE change if node is in inputs/outputs.
+            # tip: https://graphviz.gitlab.io/_pages/doc/info/shapes.html
+            choice = _merge_conditions(
+                inputs and nx_node in inputs, outputs and nx_node in outputs
+            )
+            shape = "rect invhouse house hexagon".split()[choice]
+
+            # FILL change if node has a value in the solution.
+            if solution and nx_node in solution:
+                kw["style"] = "filled"
+                kw["fillcolor"] = fill_color
+                # kw["tooltip"] = str(solution.get(nx_node))  # not working :-()
+            node = pydot.Node(name=nx_node, shape=shape, **kw)
+        else:  # Operation
+            kw = {}
+
+            if steps and nx_node in steps:
+                kw["penwidth"] = steps_thickness
+            shape = "egg" if isinstance(nx_node, NetworkOperation) else "oval"
+            if executed and nx_node in executed:
+                kw["style"] = "filled"
+                kw["fillcolor"] = fill_color
+            node = pydot.Node(name=nx_node.name, shape=shape, **kw)
+
+        _apply_user_props(node, node_props, key=node.get_name())
+
+        append_or_cluster_node(dot, nx_node, node)
+
+    _report_unmatched_user_props(node_props, "node")
+
+    append_any_clusters(dot)
+
+    # draw edges
+    for src, dst in graph.edges:
+        src_name = get_node_name(src)
+        dst_name = get_node_name(dst)
+        kw = {}
+        if isinstance(dst, Operation) and _is_class_value_in_list(
+            dst.needs, optional, src
+        ):
+            kw["style"] = "dashed"
+        edge = pydot.Edge(src=src_name, dst=dst_name, **kw)
+
+        _apply_user_props(edge, edge_props, key=(src, dst))
+
+        dot.add_edge(edge)
+
+    _report_unmatched_user_props(edge_props, "edge")
+
+    # draw steps sequence
+    if steps and len(steps) > 1:
+        it1 = iter(steps)
+        it2 = iter(steps)
+        next(it2)
+        for i, (src, dst) in enumerate(zip(it1, it2), 1):
+            src_name = get_node_name(src)
+            dst_name = get_node_name(dst)
+            edge = pydot.Edge(
+                src=src_name,
+                dst=dst_name,
+                label=str(i),
+                style="dotted",
+                color=steps_color,
+                fontcolor=steps_color,
+                fontname="bold",
+                fontsize=18,
+                penwidth=steps_thickness,
+                arrowhead="vee",
+            )
+            dot.add_edge(edge)
+
+    return dot
+
+
+def supported_plot_formats():
+    """Return all file extensions handled by `pydot`, auto-discovered."""
+    import pydot
+
+    return [".%s" % f for f in pydot.Dot().formats]
+
+
+def render_pydot(dot, filename=None, show=False):
""" + Plot a *Graphviz* dot in a matplotlib, in file or return it for Jupyter. + + :param dot: + the pre-built *Graphviz* dot instance + :param str filename: + Write diagram into a file. + Common extensions are ``.png .dot .jpg .jpeg .pdf .svg`` + call :func:`plot.supported_plot_formats()` for more. + :param show: + If it evaluates to true, opens the diagram in a matplotlib window. + If it equals `-1`, it returns the image but does not open the Window. + + :return: + the matplotlib image if ``show=-1``, or the `dot`. + + See :meth:`Plotter.plot()` for sample code. + """ + # Save plot + # + if filename: + formats = supported_plot_formats() + _basename, ext = os.path.splitext(filename) + if not ext.lower() in formats: + raise ValueError( + "Unknown file format for saving graph: %s" + " File extensions must be one of: %s" % (ext, " ".join(formats)) + ) + + dot.write(filename, format=ext.lower()[1:]) + + ## Display graph via matplotlib + # + if show: + import matplotlib.pyplot as plt + import matplotlib.image as mpimg + + png = dot.create_png() + sio = io.BytesIO(png) + img = mpimg.imread(sio) + if show != -1: + plt.imshow(img, aspect="equal") + plt.show() + + return img + + return dot + + +def legend(filename=None, show=None): + """Generate a legend for all plots (see Plotter.plot() for args)""" + import pydot + + _monkey_patch_for_jupyter(pydot) + + ## From https://stackoverflow.com/questions/3499056/making-a-legend-key-in-graphviz + dot_text = """ + digraph { + rankdir=LR; + subgraph cluster_legend { + label="Graphkit Legend"; + + operation [shape=oval]; + graphop [shape=egg label="graph operation"]; + insteps [penwidth=3 label="execution step"]; + executed [style=filled fillcolor=wheat]; + operation -> graphop -> insteps -> executed [style=invis]; + + data [shape=rect]; + input [shape=invhouse]; + output [shape=house]; + inp_out [shape=hexagon label="inp+out"]; + evicted [shape=rect penwidth=3 color="#990000"]; + pinned [shape=rect penwidth=3 color="blue"]; + evpin [shape=rect penwidth=3 color=purple label="evict+pin"]; + sol [shape=rect style=filled fillcolor=wheat label="in solution"]; + data -> input -> output -> inp_out -> evicted -> pinned -> evpin -> sol [style=invis]; + + e1 [style=invis] e2 [color=invis label="dependency"]; + e1 -> e2; + e3 [color=invis label="optional"]; + e2 -> e3 [style=dashed]; + e4 [color=invis penwidth=3 label="pruned dependency"]; + e3 -> e4 [color=wheat penwidth=2]; + e5 [color=invis penwidth=4 label="execution sequence"]; + e4 -> e5 [color="#009999" penwidth=4 style=dotted arrowhead=vee label=1 fontcolor="#009999"]; + } + } + """ + + dot = pydot.graph_from_dot_data(dot_text)[0] + # clus = pydot.Cluster("Graphkit legend", label="Graphkit legend") + # dot.add_subgraph(clus) + + # nodes = dot.Node() + # clus.add_node("operation") + + return render_pydot(dot, filename=filename, show=show) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..2e5ce9fc --- /dev/null +++ b/setup.cfg @@ -0,0 +1,10 @@ +## Python's setup.cfg for tool defaults: +# +[bdist_wheel] +universal = 1 + + +[tool:pytest] +# See http://doc.pytest.org/en/latest/mark.html#mark +markers = + slow: marks tests as slow, select them with `-m slow` or `-m 'not slow'` diff --git a/setup.py b/setup.py index bd7883f4..bf91ab44 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,15 @@ with io.open('graphkit/__init__.py', 'rt', encoding='utf8') as f: version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1) +plot_reqs = [ + "matplotlib", # to test plot + "pydot", # to test plot +] 
+test_reqs = plot_reqs + [
+    "pytest",
+    "pytest-cov",
+]
+
 setup(
     name='graphkit',
     version=version,
@@ -28,11 +37,16 @@
     author_email='huyng@yahoo-inc.com',
     url='http://github.com/yahoo/graphkit',
     packages=['graphkit'],
-    install_requires=['networkx'],
+    install_requires=[
+        "networkx; python_version >= '3.5'",
+        "networkx == 2.2; python_version < '3.5'",
+        "boltons"  # for IndexSet
+    ],
     extras_require={
-        'plot': ['pydot', 'matplotlib']
+        'plot': plot_reqs,
+        'test': test_reqs,
     },
-    tests_require=['numpy'],
+    tests_require=test_reqs,
     license='Apache-2.0',
     keywords=['graph', 'computation graph', 'DAG', 'directed acyclical graph'],
     classifiers=[
diff --git a/test/test_graphkit.py b/test/test_graphkit.py
index bd97b317..dc9e0a6d 100644
--- a/test/test_graphkit.py
+++ b/test/test_graphkit.py
@@ -3,22 +3,43 @@
 
 import math
 import pickle
-
+from operator import add, floordiv, mul, sub
 from pprint import pprint
-from operator import add
 
-from numpy.testing import assert_raises
 
-import graphkit.network as network
+import pytest
+
 import graphkit.modifiers as modifiers
+import graphkit.network as network
+from graphkit import Operation, compose, operation
+from graphkit.network import DeleteInstruction
+
+
+def scream(*args, **kwargs):
+    raise AssertionError(
+        "Must not have run!\n  args: %s\n  kwargs: %s" % (args, kwargs))
+
+
+def identity(x):
+    return x
+
+
+def filtdict(d, *keys):
+    """
+    Keep dict items with the given keys
+
+    >>> filtdict({"a": 1, "b": 2}, "b")
+    {'b': 2}
+    """
+    return type(d)(i for i in d.items() if i[0] in keys)
+
 
-def test_network():
+def test_network_smoke():
 
     # Sum operation, late-bind compute function
     sum_op1 = operation(name='sum_op1', needs=['a', 'b'], provides='sum_ab')(add)
 
     # sum_op1 is callable
-    print(sum_op1(1, 2))
+    assert sum_op1(1, 2) == 3
 
     # Multiply operation, decorate in-place
     @operation(name='mul_op1', needs=['sum_ab', 'b'], provides='sum_ab_times_b')
@@ -26,14 +47,14 @@ def mul_op1(a, b):
         return a * b
 
     # mul_op1 is callable
-    print(mul_op1(1, 2))
+    assert mul_op1(1, 2) == 2
 
     # Pow operation
     @operation(name='pow_op1', needs='sum_ab',
               provides=['sum_ab_p1', 'sum_ab_p2', 'sum_ab_p3'], params={'exponent': 3})
     def pow_op1(a, exponent=2):
         return [math.pow(a, y) for y in range(1, exponent+1)]
 
-    print(pow_op1._compute({'sum_ab':2}, ['sum_ab_p2']))
+    assert pow_op1._compute({'sum_ab':2}, ['sum_ab_p2']) == {'sum_ab_p2': 4.0}
 
     # Partial operation that is bound at a later time
     partial_op = operation(name='sum_op2', needs=['sum_ab_p1', 'sum_ab_p2'], provides='p1_plus_p2')
@@ -47,7 +68,7 @@
     sum_op3 = sum_op_factory(name='sum_op3', needs=['a', 'b'], provides='sum_ab2')
 
     # sum_op3 is callable
-    print(sum_op3(5, 6))
+    assert sum_op3(5, 6) == 11
 
     # compose network
     net = compose(name='my network')(sum_op1, mul_op1, pow_op1, sum_op2, sum_op3)
@@ -57,13 +78,24 @@
 
     #
     # get all outputs
-    pprint(net({'a': 1, 'b': 2}))
+    exp = {'a': 1,
+           'b': 2,
+           'p1_plus_p2': 12.0,
+           'sum_ab': 3,
+           'sum_ab2': 3,
+           'sum_ab_p1': 3.0,
+           'sum_ab_p2': 9.0,
+           'sum_ab_p3': 27.0,
+           'sum_ab_times_b': 6}
+    assert net({'a': 1, 'b': 2}) == exp
 
     # get specific outputs
-    pprint(net({'a': 1, 'b': 2}, outputs=["sum_ab_times_b"]))
+    exp = {'sum_ab_times_b': 6}
+    assert net({'a': 1, 'b': 2}, outputs=["sum_ab_times_b"]) == exp
 
     # start with inputs already computed
-    pprint(net({"sum_ab": 1, "b": 2}, outputs=["sum_ab_times_b"]))
+    exp = {'sum_ab_times_b': 2}
+    assert net({"sum_ab": 1, "b": 2},
outputs=["sum_ab_times_b"]) == exp # visualize network graph # net.plot(show=True) @@ -75,15 +107,23 @@ def test_network_simple_merge(): sum_op2 = operation(name='sum_op2', needs=['a', 'b'], provides='sum2')(add) sum_op3 = operation(name='sum_op3', needs=['sum1', 'c'], provides='sum3')(add) net1 = compose(name='my network 1')(sum_op1, sum_op2, sum_op3) - pprint(net1({'a': 1, 'b': 2, 'c': 4})) + + exp = {'a': 1, 'b': 2, 'c': 4, 'sum1': 3, 'sum2': 3, 'sum3': 7} + sol = net1({'a': 1, 'b': 2, 'c': 4}) + assert sol == exp sum_op4 = operation(name='sum_op1', needs=['d', 'e'], provides='a')(add) sum_op5 = operation(name='sum_op2', needs=['a', 'f'], provides='b')(add) + net2 = compose(name='my network 2')(sum_op4, sum_op5) - pprint(net2({'d': 1, 'e': 2, 'f': 4})) + exp = {'a': 3, 'b': 7, 'd': 1, 'e': 2, 'f': 4} + sol = net2({'d': 1, 'e': 2, 'f': 4}) + assert sol == exp net3 = compose(name='merged')(net1, net2) - pprint(net3({'c': 5, 'd': 1, 'e': 2, 'f': 4})) + exp = {'a': 3, 'b': 7, 'c': 5, 'd': 1, 'e': 2, 'f': 4, 'sum1': 10, 'sum2': 10, 'sum3': 15} + sol = net3({'c': 5, 'd': 1, 'e': 2, 'f': 4}) + assert sol == exp def test_network_deep_merge(): @@ -92,15 +132,40 @@ def test_network_deep_merge(): sum_op2 = operation(name='sum_op2', needs=['a', 'b'], provides='sum2')(add) sum_op3 = operation(name='sum_op3', needs=['sum1', 'c'], provides='sum3')(add) net1 = compose(name='my network 1')(sum_op1, sum_op2, sum_op3) - pprint(net1({'a': 1, 'b': 2, 'c': 4})) + + exp = {'a': 1, 'b': 2, 'c': 4, 'sum1': 3, 'sum2': 3, 'sum3': 7} + assert net1({'a': 1, 'b': 2, 'c': 4}) == exp sum_op4 = operation(name='sum_op1', needs=['a', 'b'], provides='sum1')(add) sum_op5 = operation(name='sum_op4', needs=['sum1', 'b'], provides='sum2')(add) net2 = compose(name='my network 2')(sum_op4, sum_op5) - pprint(net2({'a': 1, 'b': 2})) + exp = {'a': 1, 'b': 2, 'sum1': 3, 'sum2': 5} + assert net2({'a': 1, 'b': 2}) == exp net3 = compose(name='merged', merge=True)(net1, net2) - pprint(net3({'a': 1, 'b': 2, 'c': 4})) + exp = {'a': 1, 'b': 2, 'c': 4, 'sum1': 3, 'sum2': 3, 'sum3': 7} + assert net3({'a': 1, 'b': 2, 'c': 4}) == exp + + +def test_network_merge_in_doctests(): + def abspow(a, p): + c = abs(a) ** p + return c + + graphop = compose(name="graphop")( + operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul), + operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub), + operation(name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"], params={"p": 3}) + (abspow) + ) + + another_graph = compose(name="another_graph")( + operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul), + operation(name="mul2", needs=["c", "ab"], provides=["cab"])(mul) + ) + merged_graph = compose(name="merged_graph", merge=True)(graphop, another_graph) + assert merged_graph.needs + assert merged_graph.provides def test_input_based_pruning(): @@ -138,7 +203,7 @@ def test_output_based_pruning(): sum_op3 = operation(name='sum_op3', needs=['c', 'sum2'], provides='sum3')(add) net = compose(name='test_net')(sum_op1, sum_op2, sum_op3) - results = net({'c': c, 'd': d}, outputs=['sum3']) + results = net({'a': 0, 'b': 0, 'c': c, 'd': d}, outputs=['sum3']) # Make sure we got expected result without having to pass a or b. assert 'sum3' in results @@ -180,8 +245,230 @@ def test_pruning_raises_for_bad_output(): # Request two outputs we can compute and one we can't compute. Assert # that this raises a ValueError. 
-    assert_raises(ValueError, net, {'a': 1, 'b': 2, 'c': 3, 'd': 4},
-                  outputs=['sum1', 'sum3', 'sum4'])
+    with pytest.raises(ValueError) as exinfo:
+        net({'a': 1, 'b': 2, 'c': 3, 'd': 4},
+            outputs=['sum1', 'sum3', 'sum4'])
+    assert exinfo.match('sum4')
+
+
+def test_pruning_not_overrides_given_intermediate():
+    # Test #25: v1.2.4 overwrites intermediate data when no output asked
+    pipeline = compose(name="pipeline")(
+        operation(name="not run", needs=["a"], provides=["overriden"])(scream),
+        operation(name="op", needs=["overriden", "c"], provides=["asked"])(add),
+    )
+
+    inputs = {"a": 5, "overriden": 1, "c": 2}
+    exp = {"a": 5, "overriden": 1, "c": 2, "asked": 3}
+    # v1.2.4 is ok
+    assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+    # FAILs
+    # - on v1.2.4 with (overriden, asked) = (5, 7) instead of (1, 3)
+    # - on #18(unsatisfied) + #23(ordered-sets) with (overriden, asked) = (5, 7) instead of (1, 3)
+    # FIXED on #26
+    assert pipeline(inputs) == exp
+
+    ## Test OVERWRITES
+    #
+    overwrites = {}
+    pipeline.set_overwrites_collector(overwrites)
+    assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+    assert overwrites == {}  # empty, since the overwriting op must have been pruned
+
+    overwrites = {}
+    pipeline.set_overwrites_collector(overwrites)
+    assert pipeline(inputs) == exp
+    assert overwrites == {}  # empty, since the overwriting op must have been pruned
+
+    ## Test Parallel
+    #
+    pipeline.set_execution_method("parallel")
+    overwrites = {}
+    pipeline.set_overwrites_collector(overwrites)
+    #assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+    assert overwrites == {}  # empty, since the overwriting op must have been pruned
+
+    overwrites = {}
+    pipeline.set_overwrites_collector(overwrites)
+    assert pipeline(inputs) == exp
+    assert overwrites == {}  # empty, since the overwriting op must have been pruned
+
+
+def test_pruning_multiouts_not_override_intermediates1():
+    # Test #25: v1.2.4 overwrites intermediate data when a previous operation
+    # must run for its other outputs (outputs asked or not)
+    pipeline = compose(name="pipeline")(
+        operation(name="must run", needs=["a"], provides=["overriden", "calced"])
+        (lambda x: (x, 2 * x)),
+        operation(name="add", needs=["overriden", "calced"], provides=["asked"])(add),
+    )
+
+    inputs = {"a": 5, "overriden": 1, "c": 2}
+    exp = {"a": 5, "overriden": 1, "calced": 10, "asked": 11}
+    # FAILs
+    # - on v1.2.4 with (overriden, asked) = (5, 15) instead of (1, 11)
+    # - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4.
+    # FIXED on #26
+    assert pipeline({"a": 5, "overriden": 1}) == exp
+    # FAILs
+    # - on v1.2.4 with KeyError: 'e',
+    # - on #18(unsatisfied) + #23(ordered-sets) with empty result.
+    # FIXED on #26
+    assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+
+    ## Test OVERWRITES
+    #
+    overwrites = {}
+    pipeline.set_overwrites_collector(overwrites)
+    assert pipeline({"a": 5, "overriden": 1}) == exp
+    assert overwrites == {'overriden': 5}
+
+    overwrites = {}
+    pipeline.set_overwrites_collector(overwrites)
+    assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+    assert overwrites == {'overriden': 5}
+
+    ## Test parallel
+    #
+    pipeline.set_execution_method("parallel")
+    assert pipeline({"a": 5, "overriden": 1}) == exp
+    assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+
+
+def test_pruning_multiouts_not_override_intermediates2():
+    # Test #25: v1.2.4 overrides intermediate data when a previous operation
+    # must run for its other outputs (outputs asked or not)
+    # SPURIOUS FAILS in < PY3.6 due to unordered dicts,
+    # eg https://travis-ci.org/ankostis/graphkit/jobs/594813119
+    pipeline = compose(name="pipeline")(
+        operation(name="must run", needs=["a"], provides=["overriden", "e"])
+        (lambda x: (x, 2 * x)),
+        operation(name="op1", needs=["overriden", "c"], provides=["d"])(add),
+        operation(name="op2", needs=["d", "e"], provides=["asked"])(mul),
+    )
+
+    inputs = {"a": 5, "overriden": 1, "c": 2}
+    exp = {"a": 5, "overriden": 1, "c": 2, "d": 3, "e": 10, "asked": 30}
+    # FAILs
+    # - on v1.2.4 with (overriden, asked) = (5, 70) instead of (1, 13)
+    # - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4.
+    # FIXED on #26
+    assert pipeline(inputs) == exp
+    # FAILs
+    # - on v1.2.4 with KeyError: 'e',
+    # - on #18(unsatisfied) + #23(ordered-sets) with empty result.
+    assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+    # FIXED on #26
+
+    ## Test OVERWRITES
+    #
+    overwrites = {}
+    pipeline.set_overwrites_collector(overwrites)
+    assert pipeline(inputs) == exp
+    assert overwrites == {'overriden': 5}
+
+    overwrites = {}
+    pipeline.set_overwrites_collector(overwrites)
+    assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+    assert overwrites == {'overriden': 5}
+
+    ## Test parallel
+    #
+    pipeline.set_execution_method("parallel")
+    assert pipeline(inputs) == exp
+    assert pipeline(inputs, ["asked"]) == filtdict(exp, "asked")
+
+
+def test_pruning_with_given_intermediate_and_asked_out():
+    # Test #24: v1.2.4 does not prune before given intermediate data when
+    # outputs not asked, but does so when output asked.
+    pipeline = compose(name="pipeline")(
+        operation(name="unjustly pruned", needs=["given-1"], provides=["a"])(identity),
+        operation(name="shortcuted", needs=["a", "b"], provides=["given-2"])(add),
+        operation(name="good_op", needs=["a", "given-2"], provides=["asked"])(add),
+    )
+
+    exp = {"given-1": 5, "b": 2, "given-2": 2, "a": 5, "asked": 7}
+    # v1.2.4 is ok
+    assert pipeline({"given-1": 5, "b": 2, "given-2": 2}) == exp
+    # FAILS
+    # - on v1.2.4 with KeyError: 'a',
+    # - on #18 (unsatisfied) with no result.
+    # FIXED on #18+#26 (new dag solver).
+    assert pipeline({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
+
+    ## Test OVERWRITES
+    #
+    overwrites = {}
+    pipeline.set_overwrites_collector(overwrites)
+    assert pipeline({"given-1": 5, "b": 2, "given-2": 2}) == exp
+    assert overwrites == {}
+
+    overwrites = {}
+    pipeline.set_overwrites_collector(overwrites)
+    assert pipeline({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
+    assert overwrites == {}
+
+    ## Test parallel
+    # FAIL! in #26!
+    #
+    pipeline.set_execution_method("parallel")
+    assert pipeline({"given-1": 5, "b": 2, "given-2": 2}) == exp
+    assert pipeline({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
+
+
+def test_unsatisfied_operations():
+    # Test that operations with partial inputs are culled instead of failing.
+    pipeline = compose(name="pipeline")(
+        operation(name="add", needs=["a", "b1"], provides=["a+b1"])(add),
+        operation(name="sub", needs=["a", "b2"], provides=["a-b2"])(sub),
+    )
+
+    exp = {"a": 10, "b1": 2, "a+b1": 12}
+    assert pipeline({"a": 10, "b1": 2}) == exp
+    assert pipeline({"a": 10, "b1": 2}, outputs=["a+b1"]) == filtdict(exp, "a+b1")
+
+    exp = {"a": 10, "b2": 2, "a-b2": 8}
+    assert pipeline({"a": 10, "b2": 2}) == exp
+    assert pipeline({"a": 10, "b2": 2}, outputs=["a-b2"]) == filtdict(exp, "a-b2")
+
+    ## Test parallel
+    #
+    pipeline.set_execution_method("parallel")
+    exp = {"a": 10, "b1": 2, "a+b1": 12}
+    assert pipeline({"a": 10, "b1": 2}) == exp
+    assert pipeline({"a": 10, "b1": 2}, outputs=["a+b1"]) == filtdict(exp, "a+b1")
+
+    exp = {"a": 10, "b2": 2, "a-b2": 8}
+    assert pipeline({"a": 10, "b2": 2}) == exp
+    assert pipeline({"a": 10, "b2": 2}, outputs=["a-b2"]) == filtdict(exp, "a-b2")
+
+
+def test_unsatisfied_operations_same_out():
+    # Test unsatisfied pairs of operations providing the same output.
+    pipeline = compose(name="pipeline")(
+        operation(name="mul", needs=["a", "b1"], provides=["ab"])(mul),
+        operation(name="div", needs=["a", "b2"], provides=["ab"])(floordiv),
+        operation(name="add", needs=["ab", "c"], provides=["ab_plus_c"])(add),
+    )
+
+    exp = {"a": 10, "b1": 2, "c": 1, "ab": 20, "ab_plus_c": 21}
+    assert pipeline({"a": 10, "b1": 2, "c": 1}) == exp
+    assert pipeline({"a": 10, "b1": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
+
+    exp = {"a": 10, "b2": 2, "c": 1, "ab": 5, "ab_plus_c": 6}
+    assert pipeline({"a": 10, "b2": 2, "c": 1}) == exp
+    assert pipeline({"a": 10, "b2": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
+
+    ## Test parallel
+    #
+    # FAIL! in #26
+    pipeline.set_execution_method("parallel")
+    exp = {"a": 10, "b1": 2, "c": 1, "ab": 20, "ab_plus_c": 21}
+    assert pipeline({"a": 10, "b1": 2, "c": 1}) == exp
+    assert pipeline({"a": 10, "b1": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
+    #
+    # FAIL! in #26
+    exp = {"a": 10, "b2": 2, "c": 1, "ab": 5, "ab_plus_c": 6}
+    assert pipeline({"a": 10, "b2": 2, "c": 1}) == exp
+    assert pipeline({"a": 10, "b2": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
 
 
 def test_optional():
@@ -208,6 +495,68 @@ def addplusplus(a, b, c=0):
     assert results['sum'] == sum(named_inputs.values())
 
 
+def test_optional_per_function_with_same_output():
+    # Test that a need can be optional in one operation and required in another.
+    #
+    ## ATTENTION, the selected function is NOT the one with more inputs
+    # but the 1st satisfiable function added in the network.
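+    # (i.e. of two satisfiable operations providing the same output,
+    # the one composed earlier into the network wins -- see the two
+    # composition orders exercised below)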
+ + add_op = operation(name='add', needs=['a', 'b'], provides='a+-b')(add) + sub_op_optional = operation( + name='sub_opt', needs=['a', modifiers.optional('b')], provides='a+-b' + )(lambda a, b=10: a - b) + + # Normal order + # + pipeline = compose(name='partial_optionals')(add_op, sub_op_optional) + # + named_inputs = {'a': 1, 'b': 2} + assert pipeline(named_inputs) == {'a': 1, 'a+-b': 3, 'b': 2} + assert pipeline(named_inputs, ['a+-b']) == {'a+-b': 3} + # + named_inputs = {'a': 1} + assert pipeline(named_inputs) == {'a': 1, 'a+-b': -9} + assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -9} + + # Inverse op order + # + pipeline = compose(name='partial_optionals')(sub_op_optional, add_op) + # + named_inputs = {'a': 1, 'b': 2} + assert pipeline(named_inputs) == {'a': 1, 'a+-b': -1, 'b': 2} + assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -1} + # + named_inputs = {'a': 1} + assert pipeline(named_inputs) == {'a': 1, 'a+-b': -9} + assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -9} + + # PARALLEL + Normal order + # + pipeline = compose(name='partial_optionals')(add_op, sub_op_optional) + pipeline.set_execution_method("parallel") + # + named_inputs = {'a': 1, 'b': 2} + assert pipeline(named_inputs) == {'a': 1, 'a+-b': 3, 'b': 2} + assert pipeline(named_inputs, ['a+-b']) == {'a+-b': 3} + # + named_inputs = {'a': 1} + assert pipeline(named_inputs) == {'a': 1, 'a+-b': -9} + assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -9} + + # PARALLEL + Inverse op order + # + pipeline = compose(name='partial_optionals')(sub_op_optional, add_op) + pipeline.set_execution_method("parallel") + # + named_inputs = {'a': 1, 'b': 2} + assert pipeline(named_inputs) == {'a': 1, 'a+-b': -1, 'b': 2} + assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -1} + # + named_inputs = {'a': 1} + assert pipeline(named_inputs) == {'a': 1, 'a+-b': -9} + assert pipeline(named_inputs, ['a+-b']) == {'a+-b': -9} + + def test_deleted_optional(): # Test that DeleteInstructions included for optionals do not raise # exceptions when the corresponding input is not prodided. @@ -226,22 +575,71 @@ def addplusplus(a, b, c=0): assert 'sum2' in results +def test_deleteinstructs_vary_with_inputs(): + # Check #21: DeleteInstructions positions vary when inputs change. + def count_deletions(steps): + return sum(isinstance(n, DeleteInstruction) for n in steps) + + pipeline = compose(name="pipeline")( + operation(name="a free without b", needs=["a"], provides=["aa"])(identity), + operation(name="satisfiable", needs=["a", "b"], provides=["ab"])(add), + operation(name="optional ab", needs=["aa", modifiers.optional("ab")], provides=["asked"]) + (lambda a, ab=10: a + ab), + ) + + inp = {"a": 2, "b": 3} + exp = inp.copy(); exp.update({"aa": 2, "ab": 5, "asked": 7}) + res = pipeline(inp) + assert res == exp # ok + steps11 = pipeline.compile(inp).steps + res = pipeline(inp, outputs=["asked"]) + assert res == filtdict(exp, "asked") # ok + steps12 = pipeline.compile(inp, ["asked"]).steps + + inp = {"a": 2} + exp = inp.copy(); exp.update({"aa": 2, "asked": 12}) + res = pipeline(inp) + assert res == exp # ok + steps21 = pipeline.compile(inp).steps + res = pipeline(inp, outputs=["asked"]) + assert res == filtdict(exp, "asked") # ok + steps22 = pipeline.compile(inp, ["asked"]).steps + + # When no outs, no del-instructs. 
+    assert steps11 != steps12
+    assert count_deletions(steps11) == 0
+    assert steps21 != steps22
+    assert count_deletions(steps21) == 0
+
+    # Check steps vary with inputs
+    #
+    # FAILs in v1.2.4 + #18, PASS in #26
+    assert steps11 != steps21
+
+    # Check deletes vary with inputs
+    #
+    # FAILs in v1.2.4 + #18, PASS in #26
+    assert count_deletions(steps12) != count_deletions(steps22)
+
+
+@pytest.mark.slow
 def test_parallel_execution():
     import time
 
+    delay = 0.5
+
     def fn(x):
-        time.sleep(1)
+        time.sleep(delay)
         print("fn %s" % (time.time() - t0))
         return 1 + x
 
     def fn2(a,b):
-        time.sleep(1)
+        time.sleep(delay)
         print("fn2 %s" % (time.time() - t0))
         return a+b
 
     def fn3(z, k=1):
-        time.sleep(1)
+        time.sleep(delay)
         print("fn3 %s" % (time.time() - t0))
         return z + k
 
@@ -280,6 +678,7 @@
     # make sure results are the same using either method
     assert result_sequential == result_threaded
 
+@pytest.mark.slow
 def test_multi_threading():
     import time
     import random
@@ -310,8 +709,8 @@ def infer(i):
         assert tuple(sorted(results.keys())) == tuple(sorted(outputs)), (outputs, results)
         return results
 
-    N = 100
-    for i in range(20, 200):
+    N = 33
+    for i in range(13, 61):
         pool = Pool(i)
         pool.map(infer, range(N))
         pool.close()
@@ -353,6 +752,7 @@
         outputs.append(p)
         return outputs
 
+
 def test_backwards_compatibility():
 
     sum_op1 = Sum(
diff --git a/test/test_plot.py b/test/test_plot.py
new file mode 100644
index 00000000..d17201cb
--- /dev/null
+++ b/test/test_plot.py
@@ -0,0 +1,190 @@
+# Copyright 2016, Yahoo Inc.
+# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms.
+
+import sys
+from operator import add
+
+import pytest
+
+from graphkit import base, compose, network, operation, plot
+from graphkit.modifiers import optional
+
+
+@pytest.fixture
+def pipeline():
+    return compose(name="netop")(
+        operation(name="add", needs=["a", "b1"], provides=["ab1"])(add),
+        operation(name="sub", needs=["a", optional("b2")], provides=["ab2"])(
+            lambda a, b=1: a - b
+        ),
+        operation(name="abb", needs=["ab1", "ab2"], provides=["asked"])(add),
+    )
+
+
+@pytest.fixture(params=[{"a": 1}, {"a": 1, "b1": 2}])
+def inputs(request):
+    return {"a": 1, "b1": 2}
+
+
+@pytest.fixture(params=[None, ("a", "b1")])
+def input_names(request):
+    return request.param
+
+
+@pytest.fixture(params=[None, ["asked", "b1"]])
+def outputs(request):
+    return request.param
+
+
+@pytest.fixture(params=[None, 1])
+def solution(pipeline, inputs, outputs, request):
+    return request.param and pipeline(inputs, outputs)
+
+
+###### TEST CASES #######
+##
+
+
+def test_plotting_docstring():
+    common_formats = ".png .dot .jpg .jpeg .pdf .svg".split()
+    for ext in common_formats:
+        assert ext in base.NetworkOperation.plot.__doc__
+        assert ext in network.Network.plot.__doc__
+
+
+@pytest.mark.slow
+def test_plot_formats(pipeline, tmp_path):
+    ## Generate all formats (not needing to save files)
+
+    # run it here (and not in a fixture) to ensure `last_plan` exists.
+    inputs = {"a": 1, "b1": 2}
+    outputs = ["asked", "b1"]
+    solution = pipeline(inputs, outputs)
+
+    # The 1st list does not work on my PC, nor on travis.
+    # NOTE: maintain the other lists manually from the Exception message.
+    failing_formats = ".dia .hpgl .mif .mp .pcl .pic .vtx .xlib".split()
+    # The subsequent format names produce the same dot-file.
+    dupe_formats = [
+        ".cmapx_np",  # .cmapx
+        ".imap_np",  # .imap
+        ".jpeg",  # .jpe
+        ".jpg",  # .jpe
+        ".plain-ext",  # .plain
+    ]
+    null_formats = ".cmap .ismap".split()
+    forbidden_formats = set(failing_formats + dupe_formats + null_formats)
+    formats_to_check = sorted(set(plot.supported_plot_formats()) - forbidden_formats)
+
+    # Collect old dots to detect dupes.
+    prev_renders = {}
+    dupe_errs = []
+    for ext in formats_to_check:
+        # Check Network.
+        #
+        render = pipeline.plot(solution=solution).create(format=ext[1:])
+        if not render:
+            dupe_errs.append("\n  null: %s" % ext)
+
+        elif render in prev_renders.values():
+            dupe_errs.append(
+                "\n  dupe: %s <--> %s"
+                % (ext, [pext for pext, pdot in prev_renders.items() if pdot == render])
+            )
+        else:
+            prev_renders[ext] = render
+
+    if dupe_errs:
+        raise AssertionError("Failed pydot formats: %s" % "".join(sorted(dupe_errs)))
+
+
+def test_plotters_hierarchy(pipeline, inputs, outputs):
+    # Plotting original network, no plan.
+    base_dot = str(pipeline.plot(inputs=inputs, outputs=outputs))
+    assert base_dot
+    assert pipeline.name in str(base_dot)
+
+    solution = pipeline(inputs, outputs)
+
+    # Plotting delegates to the network plan.
+    plan_dot = str(pipeline.plot(inputs=inputs, outputs=outputs))
+    assert plan_dot
+    assert plan_dot != base_dot
+    assert pipeline.name in str(plan_dot)
+
+    # Plot a plan + solution, which must be different from all before.
+    sol_plan_dot = str(pipeline.plot(inputs=inputs, outputs=outputs, solution=solution))
+    assert sol_plan_dot != base_dot
+    assert sol_plan_dot != plan_dot
+    assert pipeline.name in str(plan_dot)
+
+    plan = pipeline.net.last_plan
+    pipeline.net.last_plan = None
+
+    # We reset `last_plan` to check that it reproduces the original.
+    base_dot2 = str(pipeline.plot(inputs=inputs, outputs=outputs))
+    assert str(base_dot2) == str(base_dot)
+
+    # Calling plot directly on the plan misses netop.name
+    raw_plan_dot = str(plan.plot(inputs=inputs, outputs=outputs))
+    assert pipeline.name not in str(raw_plan_dot)
+
+    # Check the plan does not contain the solution, unless given.
+    raw_sol_plan_dot = str(plan.plot(inputs=inputs, outputs=outputs, solution=solution))
+    assert raw_sol_plan_dot != raw_plan_dot
+
+
+def test_plot_bad_format(pipeline, tmp_path):
+    with pytest.raises(ValueError, match="Unknown file format") as exinfo:
+        pipeline.plot(filename="bad.format")
+
+    ## Check that the help msg lists all supported formats
+    for ext in plot.supported_plot_formats():
+        assert exinfo.match(ext)
+
+
+def test_plot_write_file(pipeline, tmp_path):
+    # Try saving a file from one format.
+
+    fpath = tmp_path / "network.png"
+    dot1 = pipeline.plot(str(fpath))
+    assert fpath.exists()
+    assert dot1
+
+
+def _check_plt_img(img):
+    assert img is not None
+    assert len(img) > 0
+
+
+def test_plot_matplotlib(pipeline, tmp_path):
+    ## Try matplotlib window, but without opening one.
+
+    if sys.version_info < (3, 5):
+        # On PY< 3.5 it fails with:
+        #   nose.proxy.TclError: no display name and no $DISPLAY environment variable
+        # eg https://travis-ci.org/ankostis/graphkit/jobs/593957996
+        import matplotlib
+
+        matplotlib.use("Agg")
+
+    # do not open a window in headless travis
+    img = pipeline.plot(show=-1)
+    _check_plt_img(img)
+
+
+def test_plot_jupyter(pipeline, tmp_path):
+    ## Try returned Jupyter SVG.
+
+    dot = pipeline.plot()
+    s = dot._repr_svg_()
+    assert "SVG" in s
+
+
+def test_plot_legend(pipeline, tmp_path):
+    ## Try the legend plot, returned as Jupyter SVG & matplotlib image.
+
+    dot = plot.legend()
+    assert dot
+
+    img = plot.legend(show=-1)
+    _check_plt_img(img)