diff --git a/graphkit/network.py b/graphkit/network.py index 8c440be3..0df3ddf8 100644 --- a/graphkit/network.py +++ b/graphkit/network.py @@ -49,6 +49,11 @@ def __init__(self, **kwargs): # a compiled list of steps to evaluate layers *in order* and free mem. self.steps = [] + # This holds a cache of results for the _find_necessary_steps + # function, this helps speed up the compute call as well avoid + # a multithreading issue that is occuring when accessing the + # graph in networkx + self._necessary_steps_cache = {} def add_op(self, operation): @@ -157,6 +162,14 @@ def _find_necessary_steps(self, outputs, inputs): provided inputs and requested outputs. """ + # return steps if it has already been computed before for this set of inputs and outputs + outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set)) else outputs + inputs_keys = tuple(sorted(inputs.keys())) + cache_key = (inputs_keys, outputs) + if cache_key in self._necessary_steps_cache: + return self._necessary_steps_cache[cache_key] + + graph = self.graph if not outputs: # If caller requested all outputs, the necessary nodes are all @@ -164,8 +177,8 @@ def _find_necessary_steps(self, outputs, inputs): # names that aren't in the graph. necessary_nodes = set() for input_name in iter(inputs): - if self.graph.has_node(input_name): - necessary_nodes |= nx.descendants(self.graph, input_name) + if graph.has_node(input_name): + necessary_nodes |= nx.descendants(graph, input_name) else: @@ -175,24 +188,30 @@ def _find_necessary_steps(self, outputs, inputs): # in the graph. unnecessary_nodes = set() for input_name in iter(inputs): - if self.graph.has_node(input_name): - unnecessary_nodes |= nx.ancestors(self.graph, input_name) + if graph.has_node(input_name): + unnecessary_nodes |= nx.ancestors(graph, input_name) # Find the nodes we need to be able to compute the requested # outputs. Raise an exception if a requested output doesn't # exist in the graph. necessary_nodes = set() for output_name in outputs: - if not self.graph.has_node(output_name): + if not graph.has_node(output_name): raise ValueError("graphkit graph does not have an output " "node named %s" % output_name) - necessary_nodes |= nx.ancestors(self.graph, output_name) + necessary_nodes |= nx.ancestors(graph, output_name) # Get rid of the unnecessary nodes from the set of necessary ones. necessary_nodes -= unnecessary_nodes + + necessary_steps = [step for step in self.steps if step in necessary_nodes] + + # save this result in a precomputed cache for future lookup + self._necessary_steps_cache[cache_key] = necessary_steps + # Return an ordered list of the needed steps. - return [step for step in self.steps if step in necessary_nodes] + return necessary_steps def compute(self, outputs, named_inputs, method=None): diff --git a/test/test_graphkit.py b/test/test_graphkit.py index fb8d6330..bd97b317 100644 --- a/test/test_graphkit.py +++ b/test/test_graphkit.py @@ -280,6 +280,43 @@ def fn3(z, k=1): # make sure results are the same using either method assert result_sequential == result_threaded +def test_multi_threading(): + import time + import random + from multiprocessing.dummy import Pool + + def op_a(a, b): + time.sleep(random.random()*.02) + return a+b + + def op_b(c, b): + time.sleep(random.random()*.02) + return c+b + + def op_c(a, b): + time.sleep(random.random()*.02) + return a*b + + pipeline = compose(name="pipeline", merge=True)( + operation(name="op_a", needs=['a', 'b'], provides='c')(op_a), + operation(name="op_b", needs=['c', 'b'], provides='d')(op_b), + operation(name="op_c", needs=['a', 'b'], provides='e')(op_c), + ) + + def infer(i): + # data = open("616039-bradpitt.jpg").read() + outputs = ["c", "d", "e"] + results = pipeline({"a": 1, "b":2}, outputs) + assert tuple(sorted(results.keys())) == tuple(sorted(outputs)), (outputs, results) + return results + + N = 100 + for i in range(20, 200): + pool = Pool(i) + pool.map(infer, range(N)) + pool.close() + + #################################### # Backwards compatibility ####################################