Skip to content

Added custom consume many in AsyncConsumer.__call__ #2120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions channels/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def get_handler_name(message):
raise ValueError("Malformed type in message (leading underscore)")
return handler_name


DEFAULT_AWAIT_MANY_DISPATCH = await_many_dispatch
class AsyncConsumer:
"""
Base consumer class. Implements the ASGI application spec, and adds on
Expand All @@ -33,7 +33,8 @@ class AsyncConsumer:

_sync = False
channel_layer_alias = DEFAULT_CHANNEL_LAYER

await_many_and_dispatch = DEFAULT_AWAIT_MANY_DISPATCH
priority_message_types = ["websocket.disconnect"]
async def __call__(self, scope, receive, send):
"""
Dispatches incoming messages to type-based handlers asynchronously.
Expand All @@ -55,11 +56,11 @@ async def __call__(self, scope, receive, send):
# Pass messages in from channel layer or client to dispatch method
try:
if self.channel_layer is not None:
await await_many_dispatch(
await self.await_many_and_dispatch(
[receive, self.channel_receive], self.dispatch
)
else:
await await_many_dispatch([receive], self.dispatch)
await self.await_many_and_dispatch([receive], self.dispatch)
except StopConsumer:
# Exit cleanly
pass
Expand Down
88 changes: 87 additions & 1 deletion channels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def name_that_thing(thing):
return repr(thing)


async def await_many_dispatch(consumer_callables, dispatch):
async def await_many_dispatch(_, consumer_callables, dispatch):
"""
Given a set of consumer callables, awaits on them all and passes results
from them to the dispatch awaitable as they come in.
Expand Down Expand Up @@ -57,3 +57,89 @@ async def await_many_dispatch(consumer_callables, dispatch):
await task
except asyncio.CancelledError:
pass


class PriorityTaskManager:
def __init__(self, priority_message_types=None):
self.current_task = None
self.close_message_received = False
self.priority_message_types = priority_message_types

async def handle_message(self, message, dispatch):
if self.close_message_received:
return
if message["type"] in self.priority_message_types:
self.close_message_received = True

if self.current_task and not self.current_task.done():
self.current_task.cancel()
try:
await self.current_task
except asyncio.CancelledError:
pass
await dispatch(message)
else:
if self.current_task is None or self.current_task.done():
self.current_task = asyncio.create_task(dispatch(message))
await self.current_task


async def await_many_dispatch_with_priority(self, consumer_callables, dispatch):
"""
Given a set of consumer callables, awaits on them all and passes results
from them to the dispatch awaitable as they come in.
Separate the messages with type "websocket.disconnect" to a priority queue.
As they should be handled prior to shutdown by ASGI server(example: by daphne, if "websocket.disconnect" sent,
daphne waits for application_close_timeout (10 s default)) and kills the Application,
which may cause websocket.disconnect message not handled in consumer properly.
"""
# Call all callables, and ensure all return types are Futures
task_manager = PriorityTaskManager(priority_message_types=self.priority_message_types)
queue = asyncio.Queue()
priority_queue = asyncio.Queue()

async def receive_messages(consumer_callable):
try:
while True:
message = await consumer_callable()
if message.get("type") in self.priority_message_types:
priority_queue.put_nowait(message)
else:
queue.put_nowait(message)
except asyncio.CancelledError:
pass

async def process_messages(queue_to_process: asyncio.Queue):
try:
while True:
message = await queue_to_process.get()
await task_manager.handle_message(message, dispatch)
except asyncio.CancelledError:
pass

producer_tasks = [
asyncio.create_task(receive_messages(consumer_callable))
for consumer_callable in consumer_callables
]

processing_tasks = [
asyncio.create_task(process_messages(queue)),
asyncio.create_task(process_messages(priority_queue)),
]
try:
completed_tasks, pending = await asyncio.wait(
processing_tasks, return_when=asyncio.FIRST_COMPLETED
)
finally:
exception = None
for task in producer_tasks + processing_tasks:
if task.done() and task.exception():
exception = task.exception()
if not task.done():
task.cancel()


# Wait for cancellation to complete
await asyncio.gather(*producer_tasks, *processing_tasks, return_exceptions=True)
if exception:
raise exception