Skip to content

Commit

Permalink
Reimplement channel_listen as context manager (#2856)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonathan Ehwald <github@ehwald.info>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Jul 4, 2023
1 parent 29baae0 commit c18be1f
Show file tree
Hide file tree
Showing 5 changed files with 479 additions and 15 deletions.
43 changes: 43 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
Release type: minor

This release updates the API to listen to Django Channels to avoid race conditions
when confirming GraphQL subscriptions.

**Deprecations:**

This release contains a deprecation for the Channels integration. The `channel_listen`
method will be replaced with an async context manager that returns an awaitable
AsyncGenerator. This method is called `listen_to_channel`.

An example of migrating existing code is given below:

```py
# Existing code
@strawberry.type
class MyDataType:
name: str

@strawberry.type
class Subscription:
@strawberry.subscription
async def my_data_subscription(
self, info: Info, groups: list[str]
) -> AsyncGenerator[MyDataType | None, None]:
yield None
async for message in info.context["ws"].channel_listen("my_data", groups=groups):
yield MyDataType(name=message["payload"])
```

```py
# New code
@strawberry.type
class Subscription:
@strawberry.subscription
async def my_data_subscription(
self, info: Info, groups: list[str]
) -> AsyncGenerator[MyDataType | None, None]:
async with info.context["ws"].listen_to_channel("my_data", groups=groups) as cm:
yield None
async for message in cm:
yield MyDataType(name=message["payload"])
```
66 changes: 53 additions & 13 deletions docs/integrations/channels.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,14 @@ class Subscription:
},
)

async for message in ws.channel_listen("chat.message", groups=room_ids):
if message["room_id"] in room_ids:
yield ChatRoomMessage(
room_name=message["room_id"],
message=message["message"],
current_user=user,
)
async with ws.listen_to_channel("chat.message", groups=room_ids) as cm:
async for message in cm:
if message["room_id"] in room_ids:
yield ChatRoomMessage(
room_name=message["room_id"],
message=message["message"],
current_user=user,
)
```

Explanation: `Info.context["ws"]` or `Info.context["request"]` is a pointer to the
Expand All @@ -153,11 +154,12 @@ message to all the channel_layer groups (specified in the subscription argument
<Note>

The `ChannelsConsumer` instance is shared between all subscriptions created in
a single websocket connection. The `ws.channel_listen` function will yield all
a single websocket connection. The `ws.listen_to_channel` context manager will return
a function to yield all
messages sent using the given message `type` (`chat.message` in the above example)
but does not ensure that the message was sent to the same group or groups that
it was called with - if another subscription using the same `ChannelsConsumer`
also uses `ws.channel_listen` with some other group names, those will be returned
also uses `ws.listen_to_channel` with some other group names, those will be returned
as well.

In the example we ensure `message["room_id"] in room_ids` before passing messages
Expand All @@ -168,9 +170,9 @@ the chat rooms requested in that subscription.

<Note>

We do not need to call `await channel_layer.group_add(room, ws.channel_name)` If
We do not need to call `await channel_layer.group_add(room, ws.channel_name)` if
we don't want to send an initial message while instantiating the subscription.
It is handled by `ws.channel_listen`.
It is handled by `ws.listen_to_channel`.

</Note>

Expand Down Expand Up @@ -355,6 +357,43 @@ Look here for some more complete examples:

---

### Confirming GraphQL Subscriptions

By default no confirmation message is sent to the GraphQL client once the
subscription has started. However, this is useful to be able to synchronize
actions and detect communication errors. The code below shows how the above
example can be adapted to send a null from the server to the client to confirm
that the subscription has successfully started. This includes confirming that
the Channels layer subscription has started.

```python
# mysite/gqlchat/subscription.py


@strawberry.type
class Subscription:
@strawberry.subscription
async def join_chat_rooms(
self,
info: Info,
rooms: List[ChatRoom],
user: str,
) -> AsyncGenerator[ChatRoomMessage | None, None]:
...
async with ws.listen_to_channel("chat.message", groups=room_ids) as cm:
yield None
async for message in cm:
if message["room_id"] in room_ids:
yield ChatRoomMessage(
room_name=message["room_id"],
message=message["message"],
current_user=user,
)
```

Note the change in return signature for `join_chat_rooms` and the `yield None`
after entering the `listen_to_channel` context manger.

## Testing

We provide a minimal application communicator (`GraphQLWebsocketCommunicator`) for subscribing.
Expand Down Expand Up @@ -620,13 +659,14 @@ Every graphql session will have an instance of this class inside `info.context["
#### properties

```python
async def channel_listen(
@contextlib.asynccontextmanager
async def listen_to_channel(
self,
type: str,
*,
timeout: float | None = None,
groups: Sequence[str] | None = None
): # AsyncGenerator
) -> AsyncGenerator[Any, None]:
...
```

Expand Down
77 changes: 76 additions & 1 deletion strawberry/channels/handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import contextlib
import warnings
from collections import defaultdict
from typing import (
Any,
Expand Down Expand Up @@ -109,8 +110,9 @@ async def channel_listen(
using `self.channel_layer.group_add` at the beggining of the
execution and then discarded using `self.channel_layer.group_discard`
at the end of the execution.
"""

warnings.warn("Use listen_to_channel instead", DeprecationWarning, stacklevel=2)
if self.channel_layer is None:
raise RuntimeError(
"Layers integration is required listening for channels.\n"
Expand Down Expand Up @@ -144,6 +146,79 @@ async def channel_listen(
with contextlib.suppress(Exception):
await self.channel_layer.group_discard(group, self.channel_name)

@contextlib.asynccontextmanager
async def listen_to_channel(
self,
type: str,
*,
timeout: Optional[float] = None,
groups: Sequence[str] = (),
) -> AsyncGenerator[Any, None]:
"""Listen for messages sent to this consumer.
Utility to listen for channels messages for this consumer inside
a resolver (usually inside a subscription).
Parameters:
type:
The type of the message to wait for.
timeout:
An optional timeout to wait for each subsequent message
groups:
An optional sequence of groups to receive messages from.
When passing this parameter, the groups will be registered
using `self.channel_layer.group_add` at the beggining of the
execution and then discarded using `self.channel_layer.group_discard`
at the end of the execution.
"""

# Code to acquire resource (Channels subscriptions)
if self.channel_layer is None:
raise RuntimeError(
"Layers integration is required listening for channels.\n"
"Check https://channels.readthedocs.io/en/stable/topics/channel_layers.html " # noqa:E501
"for more information"
)

added_groups = []
# This queue will receive incoming messages for this generator instance
queue: asyncio.Queue = asyncio.Queue()
# Create a weak reference to the queue. Once we leave the current scope, it
# will be garbage collected
self.listen_queues[type].add(queue)

# Subscribe to all groups but return generator object to allow user
# code to run before blocking on incoming messages
for group in groups:
await self.channel_layer.group_add(group, self.channel_name)
added_groups.append(group)
try:
yield self._listen_to_channel_generator(queue, timeout)
finally:
# Code to release resource (Channels subscriptions)
for group in added_groups:
with contextlib.suppress(Exception):
await self.channel_layer.group_discard(group, self.channel_name)

async def _listen_to_channel_generator(
self, queue: asyncio.Queue, timeout: Optional[float]
) -> AsyncGenerator[Any, None]:
"""Generator for listen_to_channel method.
Seperated to allow user code to be run after subscribing to channels
and before blocking to wait for incoming channel messages.
"""

while True:
awaitable = queue.get()
if timeout is not None:
awaitable = asyncio.wait_for(awaitable, timeout)
try:
yield await awaitable
except asyncio.TimeoutError:
# TODO: shall we add log here and maybe in the suppress below?
return


class ChannelsWSConsumer(ChannelsConsumer, AsyncJsonWebsocketConsumer):
"""Base channels websocket async consumer."""
Loading

0 comments on commit c18be1f

Please sign in to comment.