Skip to content

Commit

Permalink
Make 'async def' tests work (#2583)
Browse files Browse the repository at this point in the history
- Add pytest-asyncio dev dependency
- Detect undecorated async def methods (which would otherwise vacuously pass) via conftest.py/pytest_pyfunc_call
- Refactor existing async tests into async def test methods
- Remove cirq.testing.assert_asyncio_will_have_result
- Remove cirq.testing.assert_asyncio_will_raise
- Remove cirq.testing.assert_asyncio_still_running
- Add cirq.testing.asyncio_not_finishing with forgot-to-await detection
  • Loading branch information
Strilanc authored and CirqBot committed Nov 22, 2019
1 parent d963a02 commit 55c1590
Show file tree
Hide file tree
Showing 12 changed files with 150 additions and 189 deletions.
9 changes: 9 additions & 0 deletions cirq/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import inspect

import matplotlib.pyplot as plt


def pytest_configure(config):
# Use matplotlib agg backend which does not require a display.
plt.switch_backend('agg')


def pytest_pyfunc_call(pyfuncitem):
if inspect.iscoroutinefunction(pyfuncitem._obj):
# coverage: ignore
raise ValueError(f'{pyfuncitem._obj.__name__} is async but not '
f'decorated with "@pytest.mark.asyncio".')
5 changes: 3 additions & 2 deletions cirq/sim/simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ def text(self, to_print):
assert p.text_pretty == 'SimulationTrialResult(...)'


def test_async_sample():
@pytest.mark.asyncio
async def test_async_sample():
m = {'mock': np.array([[0], [1]])}

class MockSimulator(cirq.SimulatesSamples):
Expand All @@ -299,7 +300,7 @@ def _run(self, circuit, param_resolver, repetitions):

q = cirq.LineQubit(0)
f = MockSimulator().run_async(cirq.Circuit(cirq.measure(q)), repetitions=10)
result = cirq.testing.assert_asyncio_will_have_result(f)
result = await f
np.testing.assert_equal(result.measurements, m)


Expand Down
5 changes: 1 addition & 4 deletions cirq/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
"""Utilities for testing code."""

from cirq.testing.asynchronous import (
assert_asyncio_still_running,
assert_asyncio_will_have_result,
assert_asyncio_will_raise,
)
asyncio_pending,)

from cirq.testing.circuit_compare import (
assert_circuits_with_terminal_measurements_are_equivalent,
Expand Down
130 changes: 44 additions & 86 deletions cirq/testing/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,103 +13,61 @@
# limitations under the License.

import asyncio
import re
from collections.abc import Awaitable, Coroutine
from typing import Any, Type, Union
from typing import Union, Awaitable, Coroutine

# A placeholder default value used to detect that callers did not specify an
# 'expected' value argument in `assert_asyncio_will_have_result`, and so the
# result should be returned without checking it.
JUST_RETURN_RESULT = object() # type: Any

def asyncio_pending(future: Union[Awaitable, asyncio.Future, Coroutine],
timeout: float = 0.001) -> Awaitable[bool]:
"""Gives the given future a chance to complete, and determines if it didn't.
def _run_loop_waiting_for(future: Union[Awaitable, asyncio.Future, Coroutine],
timeout: float):
return asyncio.get_event_loop().run_until_complete(
asyncio.wait_for(asyncio.shield(future), timeout=timeout))


def assert_asyncio_still_running(
future: Union[Awaitable, asyncio.Future, Coroutine],
timeout: float = 0.001):
"""Checks that the given asyncio future has not completed.
Works by running the asyncio event loop for a short amount of time.
This method is used in tests checking that a future actually depends on some
given event having happened. The test can assert, before the event, that the
future is still pending and then assert, after the event, that the future
has a result.
Args:
future: The future that should not yet be resolved.
timeout: The number of seconds to wait for the future. Make sure this is
a small value, because it holds up the passing test!
future: The future that may or may not be able to resolve when given
a bit of time.
timeout: The number of seconds to wait for the future. This should
generally be a small value (milliseconds) when expecting the future
to not resolve, and a large value (seconds) when expecting the
future to resolve.
Raises:
AssertError: The future completed or failed within the timeout.
Returns:
True if the future is still pending after the timeout elapses. False if
the future did complete (or fail) or was already completed (or already
failed).
Examples:
>>> import asyncio
>>> import pytest
>>> @pytest.mark.asyncio
... async def test_completion_only_when_expected():
... f = asyncio.Future()
... assert await cirq.testing.asyncio_pending(f)
... f.set_result(5)
... assert await f == 5
"""
try:
_run_loop_waiting_for(future, timeout)
assert False, "Not running: {!r}".format(future)
except asyncio.TimeoutError:
pass

async def body():
f = asyncio.shield(future)
t = asyncio.ensure_future(asyncio.sleep(timeout))
done, _ = await asyncio.wait([f, t],
return_when=asyncio.FIRST_COMPLETED)
t.cancel()
return f not in done

def assert_asyncio_will_have_result(
future: Union[Awaitable, asyncio.Future, Coroutine],
expected: Any = JUST_RETURN_RESULT,
timeout: float = 1.0) -> Any:
"""Checks that the given asyncio future completes with the given value.
return _AwaitBeforeAssert(body())

Works by running the asyncio event loop for up to the given timeout.

Args:
future: The asyncio awaitable that should complete.
expected: The result that the future should have after it completes.
If not specified, nothing is asserted about the result.
timeout: The maximum number of seconds to run the event loop until the
future resolves.
class _AwaitBeforeAssert:

Returns:
The future's result.
def __init__(self, awaitable: Awaitable):
self.awaitable = awaitable

Raises:
AssertError: The future did not complete in time, or did not contain
the expected result.
"""
try:
actual = _run_loop_waiting_for(future, timeout)
if expected is not JUST_RETURN_RESULT:
assert actual == expected, "{!r} != {!r} from {!r}".format(
actual, expected, future)
return actual
except asyncio.TimeoutError:
assert False, "Not done: {!r}".format(future)
def __bool__(self):
raise RuntimeError('You forgot the "await" in '
'"assert await cirq.testing.asyncio_pending(...)".')


def assert_asyncio_will_raise(
future: Union[Awaitable, asyncio.Future, Coroutine],
expected: Type,
*,
match: str,
timeout: float = 1.0):
"""Checks that the given asyncio future fails with a matching error.
Works by running the asyncio event loop for up to the given timeout.
Args:
future: The asyncio awaitable that should error.
expected: The exception type that the future should end up containing.
match: A regex that must match the exception's message.
timeout: The maximum number of seconds to run the event loop until the
future resolves.
Raises:
AssertError: The future did not resolve in time, or did not contain
a matching exception.
"""
try:
_run_loop_waiting_for(future, timeout)
except expected as exc:
if match and not re.search(match, str(exc)):
assert False, "Pattern '{}' not found in '{}'".format(match, exc)
except asyncio.TimeoutError:
assert False, "Not done: {!r}".format(future)
else:
assert False, "DID NOT RAISE {}".format(expected)
def __await__(self):
return self.awaitable.__await__()
49 changes: 14 additions & 35 deletions cirq/testing/asynchronous_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,48 +19,27 @@
import cirq


def test_assert_still_running():
@pytest.mark.asyncio
async def test_asyncio_pending():
f = asyncio.Future()
cirq.testing.assert_asyncio_still_running(f)

assert await cirq.testing.asyncio_pending(f)
f.set_result(5)
with pytest.raises(AssertionError, match="Not running"):
cirq.testing.assert_asyncio_still_running(f)
assert not await cirq.testing.asyncio_pending(f)
assert not await cirq.testing.asyncio_pending(f, timeout=100)

e = asyncio.Future()
e.set_exception(ValueError('test fail'))
with pytest.raises(ValueError, match="test fail"):
cirq.testing.assert_asyncio_still_running(e)


def test_assert_will_have_result():
f = asyncio.Future()
with pytest.raises(AssertionError, match="Not done"):
cirq.testing.assert_asyncio_will_have_result(f, 5, timeout=0.01)

f.set_result(5)
assert cirq.testing.assert_asyncio_will_have_result(f, 5) == 5
with pytest.raises(AssertionError, match="!="):
cirq.testing.assert_asyncio_will_have_result(f, 6)

e = asyncio.Future()
assert await cirq.testing.asyncio_pending(e)
e.set_exception(ValueError('test fail'))
with pytest.raises(ValueError, match="test fail"):
cirq.testing.assert_asyncio_will_have_result(e, 5)
assert not await cirq.testing.asyncio_pending(e)
assert not await cirq.testing.asyncio_pending(e, timeout=100)


def test_assert_will_raise():
@pytest.mark.asyncio
async def test_asyncio_pending_common_mistake_caught():
f = asyncio.Future()
with pytest.raises(AssertionError, match="Not done"):
cirq.testing.assert_asyncio_will_raise(f,
ValueError,
match='',
timeout=0.01)

f.set_result(5)
with pytest.raises(BaseException, match="DID NOT RAISE"):
cirq.testing.assert_asyncio_will_raise(f, ValueError, match='')

e = asyncio.Future()
e.set_exception(ValueError('test fail'))
cirq.testing.assert_asyncio_will_raise(e, ValueError, match="test fail")
pending = cirq.testing.asyncio_pending(f)
with pytest.raises(RuntimeError, match='forgot the "await"'):
assert pending
assert await pending
6 changes: 4 additions & 2 deletions cirq/work/collector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import cirq

Expand All @@ -34,7 +35,8 @@ def test_circuit_sample_job_repr():
tag='guess'))


def test_async_collect():
@pytest.mark.asyncio
async def test_async_collect():
received = []

class TestCollector(cirq.Collector):
Expand All @@ -52,7 +54,7 @@ def on_job_result(self, job, result):
completion = TestCollector().collect_async(sampler=cirq.Simulator(),
max_total_samples=100,
concurrency=5)
cirq.testing.assert_asyncio_will_have_result(completion, None)
assert await completion is None
assert received == ['test'] * 10


Expand Down
12 changes: 8 additions & 4 deletions cirq/work/pauli_sum_collector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import cirq


def test_pauli_string_sample_collector():
@pytest.mark.asyncio
async def test_pauli_string_sample_collector():
a, b = cirq.LineQubit.range(2)
p = cirq.PauliSumCollector(circuit=cirq.Circuit(cirq.H(a), cirq.CNOT(a, b),
cirq.X(a), cirq.Z(b)),
Expand All @@ -24,18 +27,19 @@ def test_pauli_string_sample_collector():
4 * cirq.Z(a) * cirq.Z(b),
samples_per_term=100)
completion = p.collect_async(sampler=cirq.Simulator())
cirq.testing.assert_asyncio_will_have_result(completion, None)
assert await completion is None
assert p.estimated_energy() == 11


def test_pauli_string_sample_single():
@pytest.mark.asyncio
async def test_pauli_string_sample_single():
a, b = cirq.LineQubit.range(2)
p = cirq.PauliSumCollector(circuit=cirq.Circuit(cirq.H(a), cirq.CNOT(a, b),
cirq.X(a), cirq.Z(b)),
observable=cirq.X(a) * cirq.X(b),
samples_per_term=100)
completion = p.collect_async(sampler=cirq.Simulator())
cirq.testing.assert_asyncio_will_have_result(completion, None)
assert await completion is None
assert p.estimated_energy() == -1


Expand Down
22 changes: 11 additions & 11 deletions cirq/work/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,21 @@
import cirq


def test_sampler_async_fail():
@pytest.mark.asyncio
async def test_sampler_async_fail():

class FailingSampler(cirq.Sampler):

def run_sweep(self, program, params, repetitions: int = 1):
raise ValueError('test')

cirq.testing.assert_asyncio_will_raise(FailingSampler().run_async(
cirq.Circuit(), repetitions=1),
ValueError,
match='test')
with pytest.raises(ValueError, match='test'):
await FailingSampler().run_async(cirq.Circuit(), repetitions=1)

cirq.testing.assert_asyncio_will_raise(FailingSampler().run_sweep_async(
cirq.Circuit(), repetitions=1, params=None),
ValueError,
match='test')
with pytest.raises(ValueError, match='test'):
await FailingSampler().run_sweep_async(cirq.Circuit(),
repetitions=1,
params=None)


def test_sampler_sample_multiple_params():
Expand Down Expand Up @@ -141,7 +140,8 @@ def test_sampler_sample_inconsistent_keys():
])


def test_sampler_async_not_run_inline():
@pytest.mark.asyncio
async def test_sampler_async_not_run_inline():
ran = False

class S(cirq.Sampler):
Expand All @@ -153,5 +153,5 @@ def run_sweep(self, *args, **kwargs):

a = S().run_sweep_async(cirq.Circuit(), params=None)
assert not ran
cirq.testing.assert_asyncio_will_have_result(a, [])
assert await a == []
assert ran

0 comments on commit 55c1590

Please sign in to comment.