Skip to content

Commit

Permalink
Merge pull request #9 from yahoo/thread_pool_barrier_executor
Browse files Browse the repository at this point in the history
Thread pool barrier executor
  • Loading branch information
huyng committed Dec 16, 2017
2 parents 708eb1f + 7d63219 commit d3f8a8a
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 5 deletions.
18 changes: 17 additions & 1 deletion graphkit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,28 @@ def __init__(self, **kwargs):
self.net = kwargs.pop('net')
Operation.__init__(self, **kwargs)

# set execution mode to single-threaded sequential by default
self._execution_method = "sequential"

def _compute(self, named_inputs, outputs=None):
    """
    Run the wrapped network on the given inputs.

    Args:
        named_inputs:
            Input values, passed through to ``self.net.compute``.
        outputs:
            Optional selection of outputs, passed through to
            ``self.net.compute``.  Defaults to None.

    Returns:
        Whatever ``self.net.compute`` returns.
    """
    # Forward the configured execution method.  Returning before this call
    # (as the stale pre-change line did) left the mode chosen through
    # set_execution_method without any effect.
    return self.net.compute(outputs, named_inputs,
                            method=self._execution_method)

def __call__(self, *args, **kwargs):
    """
    Make the object callable.

    All positional and keyword arguments are handed unchanged to
    ``self._compute`` and its result is returned.
    """
    result = self._compute(*args, **kwargs)
    return result

def set_execution_method(self, method):
    """
    Determine how the network will be executed.

    Args:
        method: str
            If "parallel", execute graph operations concurrently
            using a threadpool.  If "sequential", execute them one at a
            time in a single thread.

    Raises:
        ValueError: if ``method`` is not one of the supported options.
    """
    options = ('parallel', 'sequential')
    # Raise explicitly instead of asserting: assert statements are removed
    # under ``python -O``, which would silently accept an invalid mode.
    if method not in options:
        raise ValueError(
            "execution method must be one of %s, got %r"
            % (", ".join(options), method))
    self._execution_method = method

def plot(self, filename=None, show=False):
    """
    Plot the wrapped network by delegating to ``self.net.plot``.

    Args:
        filename:
            Forwarded to ``self.net.plot`` as ``filename``.
        show:
            Forwarded to ``self.net.plot`` as ``show``.
    """
    plot_options = {"filename": filename, "show": show}
    self.net.plot(**plot_options)

Expand Down
134 changes: 130 additions & 4 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, **kwargs):
self.steps = []



def add_op(self, operation):
"""
Adds the given operation and its data requirements to the network graph
Expand Down Expand Up @@ -184,7 +185,7 @@ def _find_necessary_steps(self, outputs, inputs):
for output_name in outputs:
if not self.graph.has_node(output_name):
raise ValueError("graphkit graph does not have an output "
"node named %s" % output_name)
"node named %s" % output_name)
necessary_nodes |= nx.ancestors(self.graph, output_name)

# Get rid of the unnecessary nodes from the set of necessary ones.
Expand All @@ -194,7 +195,7 @@ def _find_necessary_steps(self, outputs, inputs):
return [step for step in self.steps if step in necessary_nodes]


def compute(self, outputs, named_inputs):
def compute(self, outputs, named_inputs, method=None):
"""
Run the graph. Any inputs to the network must be passed in by name.
Expand All @@ -214,9 +215,86 @@ def compute(self, outputs, named_inputs):

# assert that network has been compiled
assert self.steps, "network must be compiled before calling compute."
assert isinstance(outputs, (list, tuple)) or outputs == None,\
assert isinstance(outputs, (list, tuple)) or outputs is None,\
"The outputs argument must be a list"


# choose a method of execution
if method == "parallel":
return self._compute_thread_pool_barrier_method(named_inputs,
outputs)
else:
return self._compute_sequential_method(named_inputs,
outputs)


def _compute_thread_pool_barrier_method(self, named_inputs, outputs,
                                        thread_pool_size=10):
    """
    This method runs the graph using a parallel pool of thread executors.
    You may achieve lower total latency if your graph is sufficiently
    subdivided into operations using this method.

    Execution proceeds in rounds.  In each round every operation whose
    upstream operations have all finished is submitted to the thread pool;
    the round ends (the "barrier") once all submitted operations have
    returned and their results have been merged into the data cache.

    Args:
        named_inputs:
            Mapping of input names to values; used to seed the data cache.
        outputs:
            Names of the values to return.  If empty or None, the whole
            data cache is returned.
        thread_pool_size: int
            Number of worker threads.  Only used the first time this
            method runs on an instance: the pool is created once and then
            reused from ``self._thread_pool``.

    Returns:
        A dict of computed values, restricted to ``outputs`` when given.
    """
    from multiprocessing.dummy import Pool

    # if we have not already created a thread_pool, create one
    if not hasattr(self, "_thread_pool"):
        self._thread_pool = Pool(thread_pool_size)
    pool = self._thread_pool

    cache = {}
    cache.update(named_inputs)
    necessary_nodes = self._find_necessary_steps(outputs, named_inputs)

    # this keeps track of all nodes that have already executed
    has_executed = set()

    # with each loop iteration, we determine a set of operations that can be
    # scheduled, then schedule them onto a thread pool, then collect their
    # results onto a memory cache for use upon the next iteration.
    while True:

        # the upnext list contains the operations to schedule in the
        # current round
        upnext = []
        for node in necessary_nodes:
            if isinstance(node, DeleteInstruction):
                # Evict data only when the caller asked for specific
                # outputs and this value is not one of them; otherwise a
                # requested value could be popped before the result is
                # built.  The cheap checks run first so the graph is only
                # consulted for data that is actually cached.
                # NOTE(review): presumably this mirrors the eviction rule
                # of the sequential executor -- confirm against it.
                if outputs and node not in outputs and node in cache:
                    # only delete if all successors for the data node
                    # have been executed
                    if ready_to_delete_data_node(node,
                                                 has_executed,
                                                 self.graph):
                        cache.pop(node)

            # continue if this node is anything but an operation node
            if not isinstance(node, Operation):
                continue

            # skip finished operations before the (more expensive)
            # ancestor walk done by ready_to_schedule_operation
            if node in has_executed:
                continue

            if ready_to_schedule_operation(node, has_executed, self.graph):
                upnext.append(node)

        # stop if no nodes left to schedule, exit out of the loop
        if not upnext:
            break

        # every scheduled operation receives the shared cache as its
        # inputs; results are merged as soon as each operation finishes
        done_iterator = pool.imap_unordered(
            lambda op: (op, op._compute(cache)),
            upnext)
        for op, result in done_iterator:
            cache.update(result)
            has_executed.add(op)

    if not outputs:
        return cache
    return {k: v for k, v in cache.items() if k in outputs}

def _compute_sequential_method(self, named_inputs, outputs):
"""
This method runs the graph one operation at a time in a single thread
"""
# start with fresh data cache
cache = {}

Expand All @@ -227,7 +305,7 @@ def compute(self, outputs, named_inputs):
# outputs from the provided inputs.
all_steps = self._find_necessary_steps(outputs, named_inputs)

self.times={}
self.times = {}
for step in all_steps:

if isinstance(step, Operation):
Expand Down Expand Up @@ -349,3 +427,51 @@ def get_node_name(a):
plt.show()

return g


def ready_to_schedule_operation(op, has_executed, graph):
    """
    Decide whether an Operation can be scheduled for execution yet.

    An operation is ready once every Operation among its ancestors in the
    graph has already been executed.

    Args:
        op:
            The Operation object to check
        has_executed: set
            A set containing all operations that have been executed so far
        graph:
            The networkx graph containing the operations and data nodes
    Returns:
        A boolean indicating whether the operation may be scheduled for
        execution based on what has already been executed.
    """
    # data nodes among the ancestors are ignored; only upstream operations
    # gate scheduling
    return all(ancestor in has_executed
               for ancestor in nx.ancestors(graph, op)
               if isinstance(ancestor, Operation))

def ready_to_delete_data_node(name, has_executed, graph):
    """
    Decide whether a DataPlaceholderNode can be dropped from the cache.

    The data is no longer needed once every successor of its node in the
    graph has been executed.

    Args:
        name:
            The name of the data node to check
        has_executed: set
            A set containing all operations that have been executed so far
        graph:
            The networkx graph containing the operations and data nodes
    Returns:
        A boolean indicating whether the data node can be deleted or not.
    """
    data_node = get_data_node(name, graph)
    # successors that have not run yet still need this data
    still_pending = set(graph.successors(data_node)).difference(has_executed)
    return not still_pending

def get_data_node(name, graph):
    """
    Look up a data node in a graph by its name.

    Returns the first node of ``graph.nodes()`` that compares equal to
    ``name`` and is a DataPlaceholderNode, or None when there is no such
    node.
    """
    matches = (candidate for candidate in graph.nodes()
               if candidate == name
               and isinstance(candidate, DataPlaceholderNode))
    return next(matches, None)
53 changes: 53 additions & 0 deletions test/test_graphkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,59 @@ def addplusplus(a, b, c=0):



def test_parallel_execution():
    """
    Run one pipeline in "parallel" and in "sequential" execution mode and
    check that both modes produce the same result.

    Every operation sleeps for a second and prints the time elapsed since
    the current run started, so the scheduling order can be read from the
    printed output.
    """
    import time

    def add_one(x):
        time.sleep(1)
        print("fn %s" % (time.time() - t0))
        return 1 + x

    def add_pair(a, b):
        time.sleep(1)
        print("fn2 %s" % (time.time() - t0))
        return a + b

    def add_with_default(z, k=1):
        time.sleep(1)
        print("fn3 %s" % (time.time() - t0))
        return z + k

    steps = [
        # the following should execute in parallel under threaded
        # execution mode
        operation(name="a", needs="x", provides="ao")(add_one),
        operation(name="b", needs="x", provides="bo")(add_one),

        # this should execute after a and b have finished
        operation(name="c", needs=["ao", "bo"], provides="co")(add_pair),

        operation(name="d",
                  needs=["ao", modifiers.optional("k")],
                  provides="do")(add_with_default),

        operation(name="e", needs=["ao", "bo"], provides="eo")(add_pair),
        operation(name="f", needs="eo", provides="fo")(add_one),
        operation(name="g", needs="fo", provides="go")(add_one),
    ]
    pipeline = compose(name="l", merge=True)(*steps)

    def run(mode):
        # fresh input dict and output list on every run, as separate calls
        pipeline.set_execution_method(mode)
        return pipeline({"x": 10}, ["co", "go", "do"])

    # t0 is read by the operation functions above when they print
    t0 = time.time()
    result_threaded = run("parallel")
    print("threaded result")
    print(result_threaded)

    t0 = time.time()
    result_sequential = run("sequential")
    print("sequential result")
    print(result_sequential)

    # make sure results are the same using either method
    assert result_sequential == result_threaded

####################################
# Backwards compatibility
####################################
Expand Down

0 comments on commit d3f8a8a

Please sign in to comment.