Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def __init__(self, **kwargs):
# a compiled list of steps to evaluate layers *in order* and free mem.
self.steps = []

# This holds a cache of results for the _find_necessary_steps
# function; this speeds up the compute call as well as avoiding
# a multithreading issue that occurs when accessing the
# graph in networkx
self._necessary_steps_cache = {}


def add_op(self, operation):
Expand Down Expand Up @@ -157,15 +162,23 @@ def _find_necessary_steps(self, outputs, inputs):
provided inputs and requested outputs.
"""

# Return the cached steps if they have already been computed for this set of inputs and outputs
outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set)) else outputs
inputs_keys = tuple(sorted(inputs.keys()))
cache_key = (inputs_keys, outputs)
if cache_key in self._necessary_steps_cache:
return self._necessary_steps_cache[cache_key]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i had no idea you could have a tuple of tuples as a key. seems to work though.


graph = self.graph
if not outputs:

# If caller requested all outputs, the necessary nodes are all
# nodes that are reachable from one of the inputs. Ignore input
# names that aren't in the graph.
necessary_nodes = set()
for input_name in iter(inputs):
if self.graph.has_node(input_name):
necessary_nodes |= nx.descendants(self.graph, input_name)
if graph.has_node(input_name):
necessary_nodes |= nx.descendants(graph, input_name)

else:

Expand All @@ -175,24 +188,30 @@ def _find_necessary_steps(self, outputs, inputs):
# in the graph.
unnecessary_nodes = set()
for input_name in iter(inputs):
if self.graph.has_node(input_name):
unnecessary_nodes |= nx.ancestors(self.graph, input_name)
if graph.has_node(input_name):
unnecessary_nodes |= nx.ancestors(graph, input_name)

# Find the nodes we need to be able to compute the requested
# outputs. Raise an exception if a requested output doesn't
# exist in the graph.
necessary_nodes = set()
for output_name in outputs:
if not self.graph.has_node(output_name):
if not graph.has_node(output_name):
raise ValueError("graphkit graph does not have an output "
"node named %s" % output_name)
necessary_nodes |= nx.ancestors(self.graph, output_name)
necessary_nodes |= nx.ancestors(graph, output_name)

# Get rid of the unnecessary nodes from the set of necessary ones.
necessary_nodes -= unnecessary_nodes


necessary_steps = [step for step in self.steps if step in necessary_nodes]

# save this result in a precomputed cache for future lookup
self._necessary_steps_cache[cache_key] = necessary_steps

# Return an ordered list of the needed steps.
return [step for step in self.steps if step in necessary_nodes]
return necessary_steps


def compute(self, outputs, named_inputs, method=None):
Expand Down
37 changes: 37 additions & 0 deletions test/test_graphkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,43 @@ def fn3(z, k=1):
# make sure results are the same using either method
assert result_sequential == result_threaded

def test_multi_threading():
    """Stress-test pipeline computation from many concurrent threads.

    Exercises the network's ``_find_necessary_steps`` cache under concurrent
    access by running the same pipeline from thread pools of increasing size
    and checking that every call yields exactly the requested outputs.
    """
    import time
    import random
    from multiprocessing.dummy import Pool

    def op_a(a, b):
        time.sleep(random.random() * .02)
        return a + b

    def op_b(c, b):
        time.sleep(random.random() * .02)
        return c + b

    def op_c(a, b):
        time.sleep(random.random() * .02)
        return a * b

    pipeline = compose(name="pipeline", merge=True)(
        operation(name="op_a", needs=['a', 'b'], provides='c')(op_a),
        operation(name="op_b", needs=['c', 'b'], provides='d')(op_b),
        operation(name="op_c", needs=['a', 'b'], provides='e')(op_c),
    )

    def infer(i):
        # Every call must produce exactly the requested outputs, regardless
        # of which worker thread runs it.
        outputs = ["c", "d", "e"]
        results = pipeline({"a": 1, "b": 2}, outputs)
        assert tuple(sorted(results.keys())) == tuple(sorted(outputs)), (outputs, results)
        return results

    N = 100
    for i in range(20, 200):
        # A fresh pool each iteration exercises thread counts from 20 to
        # 199.  join() (missing before) reclaims the worker threads, and
        # the try/finally guarantees cleanup even if an assertion raised
        # inside infer propagates out of map().
        pool = Pool(i)
        try:
            pool.map(infer, range(N))
        finally:
            pool.close()
            pool.join()


####################################
# Backwards compatibility
####################################
Expand Down