Skip to content

Commit

Permalink
build_fn is always provided by "deps" argument
Browse files Browse the repository at this point in the history
  • Loading branch information
spirali committed Aug 14, 2019
1 parent e9da582 commit 2777351
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 41 deletions.
2 changes: 1 addition & 1 deletion examples/simple/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
runtime = Runtime("mydb.db")


def do_preprocessing(config):
def do_preprocessing(config, deps):
time.sleep(0.3) # Simulate computation
return random.randint(0, 10)

Expand Down
2 changes: 1 addition & 1 deletion examples/tournament/tournament.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
runtime = Runtime("mydb.db")


def train_player(config):
def train_player(config, deps):
time.sleep(random.randint(5, 15) / 10)
return {"strength": random.randint(0, 10)}

Expand Down
39 changes: 15 additions & 24 deletions orco/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import cloudpickle
import tqdm

from orco.ref import resolve_ref_keys, ref_to_refkey, RefKey
from orco.ref import resolve_ref_keys, ref_to_refkey, RefKey, collect_ref_keys
from .db import DB
from .task import Task

Expand Down Expand Up @@ -104,15 +104,14 @@ def _init(self, tasks):

for task in tasks:
count = 0
if task.inputs is not None:
for inp in task.inputs:
if isinstance(inp, Task):
count += 1
c = consumers.get(inp)
if c is None:
c = []
consumers[inp] = c
c.append(task)
for inp in task.inputs:
if isinstance(inp, Task):
count += 1
c = consumers.get(inp)
if c is None:
c = []
consumers[inp] = c
c.append(task)
if count == 0:
ready.append(task)
waiting_deps[task] = count
Expand Down Expand Up @@ -140,19 +139,13 @@ def submit(task):
collection = self.runtime._get_collection(task.ref)
pickled_fns = cloudpickle.dumps((collection.build_fn, collection.make_raw_entry))
pickle_cache[ref.collection_name] = pickled_fns
if task.inputs is not None:
inputs = [t.ref.ref_key()
if isinstance(t, Task) else t.ref_key() for t in task.inputs]
else:
inputs = None
return pool.submit(_run_task,
self.id,
db.path,
pickled_fns,
task.ref.ref_key(),
task.ref.config,
ref_to_refkey(task.dep_value),
inputs)
ref_to_refkey(task.dep_value))
self.stats = {
"n_tasks": len(all_tasks),
"n_completed": 0
Expand Down Expand Up @@ -201,19 +194,17 @@ def submit(task):
_per_process_db = None


def _run_task(executor_id, db_path, fns, ref_key, config, dep_value, deps):
def _run_task(executor_id, db_path, fns, ref_key, config, dep_value):
global _per_process_db
if _per_process_db is None:
_per_process_db = DB(db_path, threading=False)
build_fn, finalize_fn = cloudpickle.loads(fns)

start_time = time.time()

if deps is not None:
ref_map = {ref: _per_process_db.get_entry(ref.collection_name, ref.key) for ref in deps}
dep_value = resolve_ref_keys(dep_value, ref_map)
value = build_fn(config, dep_value)
else:
value = build_fn(config)
deps = collect_ref_keys(dep_value)
ref_map = {ref: _per_process_db.get_entry(ref.collection_name, ref.key) for ref in deps}
dep_value = resolve_ref_keys(dep_value, ref_map)
value = build_fn(config, dep_value)
end_time = time.time()
return finalize_fn(ref_key.collection_name, ref_key.key, None, value, end_time - start_time)
22 changes: 20 additions & 2 deletions orco/ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,31 @@ def __hash__(self):
return hash((self.collection, self.key))
"""


def collect_refs(obj):
result = set()
_collect_refs_helper(obj, result)
_collect_refs_helper(obj, result, Ref)
return result


def _collect_refs_helper(dep_value, ref_set, r_class):
if isinstance(dep_value, r_class):
ref_set.add(dep_value)
elif isinstance(dep_value, dict):
for val in dep_value.values():
_collect_refs_helper(val, ref_set, r_class)
elif isinstance(dep_value, Iterable):
for val in dep_value:
_collect_refs_helper(val, ref_set, r_class)


def collect_ref_keys(obj):
result = set()
_collect_refs_helper(obj, result, RefKey)
return result


def _collect_refs_helper(dep_value, ref_set):
def _collect_ref_keys_helper(dep_value, ref_set):
if isinstance(dep_value, Ref):
ref_set.add(dep_value)
elif isinstance(dep_value, dict):
Expand Down
2 changes: 1 addition & 1 deletion orco/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def make_task(ref):
for r in dep_refs:
global_deps.add((r, ref))
else:
inputs = None
inputs = ()
dep_value = None
if state is None and collection.build_fn is None:
raise Exception("Computation depends on a missing configuration '{}' in a fixed collection".format(ref))
Expand Down
2 changes: 1 addition & 1 deletion tests/interactive/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
rt.register_executor(executor3)
executor3.stop()

c_sleepers = rt.register_collection("sleepers", lambda c: time.sleep(c))
c_sleepers = rt.register_collection("sleepers", lambda c, d: time.sleep(c))
c_bedrooms = rt.register_collection("bedrooms", lambda c, d: None, lambda c: [c_sleepers.ref(x) for x in c["sleepers"]])

rt.compute(c_bedrooms.ref({"sleepers": [0.1]}))
Expand Down
19 changes: 10 additions & 9 deletions tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def test_collection_compute(env):

counter = env.file_storage("counter", 0)

def adder(config):
def adder(config, deps):
assert not deps
counter.write(counter.read() + 1)
return config["a"] + config["b"]

Expand Down Expand Up @@ -68,7 +69,7 @@ def test_collection_deps(env):

counter_file = env.file_storage("counter", [0, 0])

def builder1(config):
def builder1(config, input):
counter = counter_file.read()
counter[0] += 1
counter_file.write(counter)
Expand Down Expand Up @@ -125,7 +126,7 @@ def test_collection_deps_complex(env):
runtime = env.test_runtime()
runtime.register_executor(LocalExecutor(n_processes=1))

def builder1(config):
def builder1(config, input):
return config * 10

def builder2(config, deps):
Expand All @@ -150,7 +151,7 @@ def test_collection_double_ref(env):
runtime = env.test_runtime()
runtime.register_executor(LocalExecutor())

col1 = runtime.register_collection("col1", lambda c: c * 10)
col1 = runtime.register_collection("col1", lambda c, d: c * 10)
col2 = runtime.register_collection("col2",
(lambda c, d: sum(x.value for x in d)),
(lambda c: [col1.ref(10), col1.ref(10), col1.ref(10)]))
Expand All @@ -161,7 +162,7 @@ def test_collection_stored_deps(env):
runtime = env.test_runtime()
runtime.register_executor(LocalExecutor())

col1 = runtime.register_collection("col1", lambda c: c * 10)
col1 = runtime.register_collection("col1", lambda c, d: c * 10)
col2 = runtime.register_collection("col2",
(lambda c, d: sum(x.value for x in d)),
lambda c: [col1.ref(i) for i in range(c["start"], c["end"], c["step"])])
Expand Down Expand Up @@ -241,7 +242,7 @@ def test_collection_clean(env):
runtime = env.test_runtime()
runtime.register_executor(LocalExecutor())

col1 = runtime.register_collection("col1", lambda c: c)
col1 = runtime.register_collection("col1", lambda c, d: c)
col2 = runtime.register_collection("col2", lambda c, d: c, lambda c: [col1.ref(c)])

runtime.compute(col2.ref(1))
Expand All @@ -255,7 +256,7 @@ def test_collection_to_pandas(env):
runtime = env.test_runtime()
runtime.register_executor(LocalExecutor())

col1 = runtime.register_collection("col1", lambda c: c * 2)
col1 = runtime.register_collection("col1", lambda c, d: c * 2)
runtime.compute(col1.refs([1, 2, 3, 4]))
frame = runtime.to_pandas(col1)
assert len(frame) == 4
Expand All @@ -269,7 +270,7 @@ def test_collection_invalidate(env):
runtime = env.test_runtime()
runtime.register_executor(LocalExecutor())

col1 = runtime.register_collection("col1", lambda c: c)
col1 = runtime.register_collection("col1", lambda c, d: c)
col2 = runtime.register_collection("col2", lambda c, d: c, lambda c: [col1.ref(c)])
col3 = runtime.register_collection("col3", lambda c, d: c, lambda c: [col1.ref(c)])
col4 = runtime.register_collection("col4", lambda c, d: c, lambda c: [col2.ref(c)])
Expand All @@ -287,7 +288,7 @@ def test_collection_computed(env):
runtime = env.test_runtime()
runtime.register_executor(LocalExecutor(n_processes=1))

def build_fn(x):
def build_fn(x, deps):
return x * 10

collection = runtime.register_collection("col1", build_fn)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_executor_error(env):
executor = LocalExecutor(heartbeat_interval=1, n_processes=2)
runtime.register_executor(executor)

col0 = runtime.register_collection("col0", lambda c: c)
col0 = runtime.register_collection("col0", lambda c, d: c)
col1 = runtime.register_collection("col1", lambda c, d: 100 // d[0].value, lambda c: [col0.ref(c)])
col2 = runtime.register_collection("col2", lambda c, ds: sum(d.value for d in ds), lambda c: [col1.ref(x) for x in c])

Expand All @@ -72,7 +72,7 @@ def test_executor_error(env):

def test_executor_conflict(env, tmpdir):

def compute_0(c):
def compute_0(c, d):
path = tmpdir.join("test-{}".format(c))
assert not path.check()
path.write("Done")
Expand Down

0 comments on commit 2777351

Please sign in to comment.