Skip to content

Commit

Permalink
Refactored dataloader tests
Browse files Browse the repository at this point in the history
  • Loading branch information
syrusakbary committed Mar 3, 2017
1 parent efc1264 commit b21d626
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 37 deletions.
5 changes: 4 additions & 1 deletion promise/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,14 @@ def on_error(error):
on_error,
)

peek = Context.peek_context()
if peek:
peek.drain_queue()

if self._trace:
# If we wait, we drain the queue of the
# callbacks waiting on the context exit
# so we avoid a blocking state
Context.peek_context().drain_queue()
self._trace.drain_queue()

self._is_waiting = True
Expand Down
37 changes: 1 addition & 36 deletions tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,43 +43,8 @@ def call_fn(keys):
assert values == []


def test_batches_multiple_requests():
identity_loader, load_calls = id_loader()

@Promise.safe
def safe():
promise1 = identity_loader.load(1)
promise2 = identity_loader.load(2)
return promise1, promise2

promise1, promise2 = safe()
value1, value2 = Promise.all([promise1, promise2]).get()
assert value1 == 1
assert value2 == 2

assert load_calls == [[1, 2]]


def test_batches_multiple_requests_two():
identity_loader, load_calls = id_loader()

@Promise.safe
def safe():
promise1 = identity_loader.load(1)
promise2 = identity_loader.load(2)
return Promise.all([promise1, promise2])

p = safe()
value1, value2 = p.get()

assert value1 == 1
assert value2 == 2

assert load_calls == [[1, 2]]


@Promise.safe
def test_batches_multiple_requests_safe():
def test_batches_multiple_requests():
identity_loader, load_calls = id_loader()

promise1 = identity_loader.load(1)
Expand Down
65 changes: 65 additions & 0 deletions tests/test_dataloader_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from promise import Promise
from promise.dataloader import DataLoader


def id_loader(**options):
load_calls = []

def fn(keys):
load_calls.append(keys)
return Promise.resolve(keys)

identity_loader = DataLoader(fn, **options)
return identity_loader, load_calls


def test_batches_multiple_requests():
identity_loader, load_calls = id_loader()

@Promise.safe
def safe():
promise1 = identity_loader.load(1)
promise2 = identity_loader.load(2)
return promise1, promise2

promise1, promise2 = safe()
value1, value2 = Promise.all([promise1, promise2]).get()
assert value1 == 1
assert value2 == 2

assert load_calls == [[1, 2]]


def test_batches_multiple_requests_two():
identity_loader, load_calls = id_loader()

@Promise.safe
def safe():
promise1 = identity_loader.load(1)
promise2 = identity_loader.load(2)
return Promise.all([promise1, promise2])

p = safe()
value1, value2 = p.get()

assert value1 == 1
assert value2 == 2

assert load_calls == [[1, 2]]


@Promise.safe
def test_batches_multiple_requests_safe():
identity_loader, load_calls = id_loader()

promise1 = identity_loader.load(1)
promise2 = identity_loader.load(2)

p = Promise.all([promise1, promise2])

value1, value2 = p.get()

assert value1 == 1
assert value2 == 2

assert load_calls == [[1, 2]]

0 comments on commit b21d626

Please sign in to comment.