Skip to content

Commit

Permalink
fix: add queue tests
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed Aug 7, 2023
1 parent db51c8d commit e9f7178
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 8 deletions.
24 changes: 17 additions & 7 deletions src/async_wrapper/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ContextManager,
Generator,
Generic,
Literal,
TypeVar,
)

Expand Down Expand Up @@ -70,7 +71,7 @@ class Queue(Generic[ValueT]):
>>> async with anyio.create_task_group() as task_group:
>>> async with queue.aputter:
>>> for i in range(10):
>>> task_group.start_soon(aput, queue.clone(putter=True), i)
>>> task_group.start_soon(aput, queue.cloning.putter, i)
>>>
>>> async with queue.agetter:
>>> result = {x async for x in queue}
Expand Down Expand Up @@ -414,8 +415,8 @@ class _RestrictedQueue(Queue[ValueT], Generic[ValueT]):

def __init__(self, queue: Queue[ValueT], *, putter: bool, getter: bool) -> None:
self._queue = queue
if not getter and not putter:
raise QueueRestrictedError("putter and getter are all False")
if getter is putter:
raise QueueRestrictedError("putter and getter are the same")
self._do_putter = putter
self._do_getter = getter

Expand Down Expand Up @@ -443,7 +444,7 @@ def _close_stream(self) -> bool:
return self._close_getter
if self._do_putter:
return self._close_putter
raise RuntimeError("never")
raise RuntimeError("never") # pragma: no cover

@property
def _stream(
Expand All @@ -453,12 +454,12 @@ def _stream(
return self._getter
if self._do_putter:
return self._putter
raise RuntimeError("never")
raise RuntimeError("never") # pragma: no cover

@property
@override
def _closed(self) -> bool:
return not self._close_stream and self._stream._closed # noqa: SLF001
return not self._close_stream or self._stream._closed # noqa: SLF001

@override
def qsize(self) -> int:
Expand Down Expand Up @@ -529,16 +530,25 @@ def __repr__(self) -> str:
elif self._do_putter:
where = "putter"
else:
raise RuntimeError("never")
raise RuntimeError("never") # pragma: no cover
return _render("RestrictedQueue", max=max_size, size=size, where=where)


class _Cloning(Generic[ValueT]):
__slots__ = ("_queue",)

def __init__(self, queue: Queue[ValueT]) -> None:
if queue._closed: # noqa: SLF001
raise QueueClosedError("queue is already closed")
self._queue = queue

def create(self, where: Literal["putter", "getter"]) -> _RestrictedQueue[ValueT]:
if where == "putter":
return self.putter
if where == "getter":
return self.getter
raise RuntimeError("never") # pragma: no cover

@property
def putter(self) -> _RestrictedQueue[ValueT]:
self._raise_if_closed()
Expand Down
162 changes: 161 additions & 1 deletion tests/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from anyio import CancelScope, create_task_group, fail_after, wait_all_tasks_blocked

from async_wrapper import Queue, create_queue
from async_wrapper.exception import QueueBrokenError, QueueClosedError
from async_wrapper.exception import (
QueueBrokenError,
QueueClosedError,
QueueRestrictedError,
)
from async_wrapper.queue import _RestrictedQueue


def test_invalid_max_buffer() -> None:
Expand Down Expand Up @@ -114,6 +119,26 @@ async def getter() -> None:
assert result == ["hello", "anyio"]


@pytest.mark.anyio()
async def test_iterate_using_cloning() -> None:
queue: Queue[str] = create_queue()
result: list[str] = []
getter_queue = queue.cloning.getter

async def getter() -> None:
async with getter_queue:
async for item in getter_queue:
result.append(item) # noqa: PERF402

async with create_task_group() as task_group:
task_group.start_soon(getter)
await queue.aput("hello")
await queue.aput("anyio")
await queue.aclose()

assert result == ["hello", "anyio"]


@pytest.mark.anyio()
async def test_aget_aput_closed_queue() -> None:
queue: Queue[Any] = create_queue()
Expand Down Expand Up @@ -142,13 +167,41 @@ async def test_clone() -> None:
assert queue2.get() == "hello"


@pytest.mark.anyio()
async def test_clone_using_cloning() -> None:
queue: Queue[str] = create_queue(1)
putter = queue.cloning.putter
getter = queue.cloning.getter

await queue.aclose()
putter.put("hello")
assert getter.get() == "hello"


@pytest.mark.anyio()
async def test_clone_closed() -> None:
queue: Queue[str] = create_queue(1)
await queue.aclose()
pytest.raises(QueueClosedError, queue.clone)


@pytest.mark.anyio()
async def test_clone_closed_using_cloning() -> None:
queue: Queue[str] = create_queue(1)
await queue.aclose()
with pytest.raises(QueueClosedError, match="queue is already closed"):
_ = queue.cloning


@pytest.mark.anyio()
async def test_clone_closed_using_cloning_after_create() -> None:
queue: Queue[str] = create_queue(1)
clone = queue.cloning
await queue.aclose()
with pytest.raises(QueueClosedError, match="queue is already closed"):
_ = clone.getter


@pytest.mark.anyio()
async def test_aget_when_cancelled() -> None:
queue: Queue[str] = create_queue()
Expand Down Expand Up @@ -319,6 +372,32 @@ async def test_get(q: Queue[Any]) -> None:
assert status.open_send_streams == 1


@pytest.mark.anyio()
async def test_clone_each_using_cloning():
queue: Queue[Any] = create_queue(1)

async def test_put(q: Queue[Any]) -> None:
async with q:
await q.aput(1)

async def test_get(q: Queue[Any]) -> None:
async with q:
await q.aget()

async with create_task_group() as task_group:
task_group.start_soon(test_put, queue.cloning.putter)
task_group.start_soon(test_put, queue.cloning.putter)
task_group.start_soon(test_get, queue.cloning.getter)
task_group.start_soon(test_get, queue.cloning.getter)

assert not queue._closed # noqa: SLF001
assert queue.empty()

status = queue.statistics()
assert status.open_receive_streams == 1
assert status.open_send_streams == 1


@pytest.mark.anyio()
def test_queue_clone_uset():
queue: Queue[Any] = create_queue(1)
Expand Down Expand Up @@ -484,3 +563,84 @@ async def test_queue_repr(x: int | None):
expected_max = x or "inf"
expected_repr = f"<Queue: max={expected_max}, size={size}>"
assert repr(queue) == expected_repr


@pytest.mark.anyio()
@pytest.mark.parametrize("x", chain((None,), range(1, 4)))
async def test_queue_getter_repr_using_cloning(x: int | None):
queue: Queue[Any] = create_queue(x)
size = random.randint(1, x or 10) # noqa: S311

async with create_task_group() as task_group:
for i in range(size):
task_group.start_soon(queue.aput, i)

expected_max = x or "inf"
expected_repr = f"<RestrictedQueue: max={expected_max}, size={size}, where={{}}>"
for where in ("getter", "putter"):
clone = queue.cloning.create(where)
assert repr(clone) == expected_repr.format(where)


@pytest.mark.anyio()
async def test_restricted_queue_error():
queue = create_queue()
clone = queue.cloning
getter = clone.getter
putter = clone.putter

putter.put(1)
with pytest.raises(QueueRestrictedError, match="putter is restricted"):
getter.put(1)
with pytest.raises(QueueRestrictedError, match="putter is restricted"):
await getter.aput(1)
with pytest.raises(QueueRestrictedError, match="getter is restricted"):
putter.get()
with pytest.raises(QueueRestrictedError, match="getter is restricted"):
await putter.aget()

with pytest.raises(TypeError, match="do not clone restricted queue"):
_ = getter.cloning
with pytest.raises(TypeError, match="do not clone restricted queue"):
_ = putter.cloning


@pytest.mark.anyio()
async def test_restricted_queue_eixt():
queue = create_queue()
result = []

async def aget(queue: Queue[Any]) -> None:
with queue:
value = await queue.aget()
assert queue._closed # noqa: SLF001
result.append(value)

def put(queue: Queue[Any]) -> None:
with queue:
queue.put(1)
assert queue._closed # noqa: SLF001

async with create_task_group() as task_group:
task_group.start_soon(aget, queue.cloning.getter)
put(queue.cloning.putter)

assert result == [1]


def test_create_restricted_queue_error():
queue = create_queue()

with pytest.raises(QueueRestrictedError, match="putter and getter are the same"):
_ = _RestrictedQueue(queue, putter=True, getter=True)

with pytest.raises(QueueRestrictedError, match="putter and getter are the same"):
_ = _RestrictedQueue(queue, putter=False, getter=False)


def test_restricted_queue_stats():
queue = create_queue()
getter, putter = queue.cloning.getter, queue.cloning.putter

assert getter.statistics() == queue.statistics()
assert putter.statistics() == queue.statistics()

0 comments on commit e9f7178

Please sign in to comment.