New version: Task Streams via redis xgroups #168

Merged 1 commit on Jul 1, 2023
41 changes: 41 additions & 0 deletions examples/test_pubsub.py
@@ -0,0 +1,41 @@
import asyncio
import time
from qw.client import QClient, TaskWrapper
from qw.utils import cPrint

qw = QClient()

print('SERVER : ', qw.get_servers())


async def very_long_task(seconds: int):
    if seconds == 17:
        raise ValueError('BAD BOYS')
    print(f'This function sleeps for {seconds} sec.')
    await asyncio.sleep(seconds)


async def queue_task():
    await qw.publish(very_long_task, 10)
    task = TaskWrapper(
        program='troc',
        task='organizations',
        debug=True,
        ignore_results=True
    )
    res = await asyncio.gather(
        *[
            qw.publish(task)
        ]
    )
    print(f'Task Queued: {res!s}')
    await qw.publish(very_long_task, 15)

if __name__ == '__main__':
    start_time = time.time()
    loop = asyncio.get_event_loop()
    top = loop.run_until_complete(
        queue_task()
    )
    end_time = time.time() - start_time
    cPrint(f'Task took {end_time} seconds to run', level='DEBUG')
65 changes: 63 additions & 2 deletions qw/client.py
@@ -2,13 +2,15 @@
import asyncio
import itertools
import random
import uuid
import warnings
import inspect
import socket
import base64
from typing import Any, Union
from collections.abc import Callable, Awaitable
from collections import defaultdict
from functools import partial
import aioredis
from redis import asyncio as aioredis
import pickle
import cloudpickle
import jsonpickle
@@ -26,6 +28,7 @@
WORKER_DEFAULT_HOST,
WORKER_DEFAULT_PORT,
WORKER_REDIS,
REDIS_WORKER_STREAM,
USE_DISCOVERY,
WORKER_SECRET_KEY,
expected_message
@@ -478,3 +481,61 @@ async def queue(self, fn: Any, *args, use_wrapper: bool = True, **kwargs):
self.logger.exception(
f'Error Serializing Task: {err!s}'
)

async def publish(self, fn: Any, *args, use_wrapper: bool = True, **kwargs):
    """Publish a function to the Worker Redis Stream.

    Send-and-forget functionality: ships a task to a Queue Worker through
    the Redis Stream consumed by the worker group.

    Args:
        fn: any function, object or callable to be sent to the Worker.
        args: any non-keyword arguments.
        kwargs: keyword arguments.

    Returns:
        dict: queue status, task representation and the new stream entry id.

    Raises:
        ConfigError: bad instructions to Worker Client.
        ConnectionError: unable to connect to Worker.
        Exception: any unhandled error.
    """
    self.logger.debug(
        f'Sending function {fn!s} to Stream {REDIS_WORKER_STREAM}'
    )
    host = socket.gethostbyname(socket.gethostname())
    # serializing
    func = self.get_wrapped_function(
        fn,
        host,
        *args,
        use_wrapper=use_wrapper,
        queued=True,
        **kwargs
    )
    if use_wrapper is True:
        uid = func.id
    else:
        uid = uuid.uuid1(
            node=random.getrandbits(48) | 0x010000000000
        )
    serialized_task = cloudpickle.dumps(func)
    encoded_task = base64.b64encode(serialized_task).decode('utf-8')
    conn = aioredis.from_url(
        WORKER_REDIS,
        decode_responses=True,
        encoding='utf-8'
    )
    message = {
        "uid": str(uid),
        "task": encoded_task
    }
    # Add the task to the stream; xadd returns the new entry id
    result = await conn.xadd(REDIS_WORKER_STREAM, message)
    # release the one-off connection used for publishing
    await conn.close()
    serialized_result = {
        "status": "Queued",
        "task": f"{func!r}",
        "message": result
    }
    return serialized_result
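
For context, a minimal usage sketch of the new publish() path (not part of this diff), modeled on examples/test_pubsub.py above. The coroutine refresh_cache and its argument are placeholders invented for illustration.

import asyncio
from qw.client import QClient


async def refresh_cache(key: str):
    # Any cloudpickle-serializable coroutine can be shipped to a worker.
    print(f'refreshing {key}')


async def main():
    qw = QClient()
    # Fire-and-forget: the wrapped task is XADD'ed to the worker stream and a
    # small status dict comes back immediately with the new stream entry id.
    res = await qw.publish(refresh_cache, 'customers')
    print(res['status'], res['message'])


if __name__ == '__main__':
    asyncio.run(main())

Because publish() only writes to the stream, it returns as soon as XADD succeeds; it does not wait for the worker to execute the task.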
3 changes: 2 additions & 1 deletion qw/conf.py
@@ -46,7 +46,8 @@ def get_worker_list(workers: list):
REDIS_HOST = config.get('REDIS_HOST', fallback='localhost')
REDIS_PORT = config.getint('REDIS_PORT', fallback=6379)
REDIS_WORKER_DB = config.getint('REDIS_WORKER_DB', fallback=4)
REDIS_WORKER_CHANNEL = config.get('REDIS_WORKER_CHANNEL', fallback='WorkerChannel')
REDIS_WORKER_GROUP = config.get('REDIS_WORKER_GROUP', fallback='QWorkerGroup')
REDIS_WORKER_STREAM = config.get('REDIS_WORKER_STREAM', fallback='QWorkerStream')

WORKER_REDIS = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_WORKER_DB}"

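A quick verification sketch (not part of this change set), assuming the stream already exists because a worker started or a task was published: it reads the names resolved in qw/conf.py and inspects them in Redis with calls available in redis-py.

import asyncio
from redis import asyncio as aioredis

from qw.conf import WORKER_REDIS, REDIS_WORKER_STREAM, REDIS_WORKER_GROUP


async def main():
    conn = aioredis.from_url(WORKER_REDIS, decode_responses=True, encoding='utf-8')
    # Stream info: raises if the stream has not been created yet.
    info = await conn.xinfo_stream(REDIS_WORKER_STREAM)
    print('stream length:', info['length'])
    # The worker group appears once a worker has created it on startup.
    for group in await conn.xinfo_groups(REDIS_WORKER_STREAM):
        if group['name'] == REDIS_WORKER_GROUP:
            print('group:', group['name'], 'pending:', group['pending'])
    await conn.close()


if __name__ == '__main__':
    asyncio.run(main())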
90 changes: 76 additions & 14 deletions qw/server.py
@@ -3,6 +3,7 @@
import time
import socket
import uuid
import base64
import asyncio
import inspect
import random
@@ -18,12 +19,14 @@
)
from qw.utils import make_signature
from redis import asyncio as aioredis
from redis.exceptions import ResponseError
from .conf import (
WORKER_DEFAULT_HOST,
WORKER_DEFAULT_PORT,
expected_message,
WORKER_SECRET_KEY,
REDIS_WORKER_CHANNEL,
REDIS_WORKER_STREAM,
REDIS_WORKER_GROUP,
WORKER_REDIS
)
from .utils.json import json_encoder
@@ -91,17 +94,75 @@ def start_redis(self):
)
self.redis = aioredis.Redis(connection_pool=self.pool)

async def start_subscription(self):
"""Starts PUB/SUB system based on Redis."""
async def ensure_group_exists(self):
try:
# Try to create the group. This will fail if the group already exists.
await self.redis.xgroup_create(
REDIS_WORKER_STREAM, REDIS_WORKER_GROUP, id='$', mkstream=True
)
except ResponseError as e:
if "BUSYGROUP Consumer Group name already exists" not in str(e):
raise
except Exception as exc:
self.logger.exception(exc, stack_info=True)
raise
try:
self.pubsub = self.redis.pubsub()
await self.pubsub.subscribe(REDIS_WORKER_CHANNEL)
# create the consumer:
await self.redis.xgroup_createconsumer(
REDIS_WORKER_STREAM, REDIS_WORKER_GROUP, self._name
)
self.logger.debug(
f":: Creating Consumer {self._name} on Stream {REDIS_WORKER_STREAM}"
)
except Exception as exc:
print(exc)
self.logger.exception(exc, stack_info=True)
raise

async def start_subscription(self):
"""Starts stream consumer group based on Redis."""
try:
await self.ensure_group_exists()
info = await self.redis.xinfo_groups(REDIS_WORKER_STREAM)
self.logger.debug(f'Groups Info: {info}')
while self._running:
try:
msg = await self.pubsub.get_message()
if msg and msg['type'] == 'message':
self.logger.debug(f'Received Task: {msg}')
message_groups = await self.redis.xreadgroup(
REDIS_WORKER_GROUP,
self._name,
streams={REDIS_WORKER_STREAM: '>'},
block=100,
count=1
)
for _, messages in message_groups:
for _id, fn in messages:
try:
encoded_task = fn['task']
task_id = fn['uid']
# Process the task
serialized_task = base64.b64decode(encoded_task)
task = cloudpickle.loads(serialized_task)
self.logger.info(
f'TASK RECEIVED: {task} with id {task_id} at {int(time.time())}'
)
try:
executor = TaskExecutor(task)
await executor.run()
self.logger.info(
f":: TASK {task}.{task_id} Executed at {int(time.time())}"
)
except Exception as e:
self.logger.error(
f"Task {task}:{task_id} failed with error {e}"
)
# xack still runs when the task itself failed (that error was caught above);
# it is only skipped if decoding or deserializing the message raised
await self.redis.xack(
REDIS_WORKER_STREAM,
REDIS_WORKER_GROUP,
_id
)
except Exception as e:
self.logger.error(f"Error processing message: {e}")
await asyncio.sleep(0.001) # sleep a bit to prevent high CPU usage
except ConnectionResetError:
self.logger.error(
@@ -110,26 +171,26 @@ async def start_subscription(self):
await asyncio.sleep(1) # Wait for a bit before trying to reconnect
await self.start_subscription() # Try to restart the subscription
except asyncio.CancelledError:
await self.pubsub.unsubscribe(REDIS_WORKER_CHANNEL)
break
except KeyboardInterrupt:
break
except Exception as exc:
# Handle other exceptions as necessary
self.logger.error(
f"Error in start_subscription: {exc}"
f"Error Getting Message: {exc}"
)
break
except Exception as exc:
self.logger.error(
f"Could not establish initial connection: {exc}"
)

async def close_redis(self):
try:
# Get a new pubsub object and unsubscribe from 'channel'
try:
await self.pubsub.unsubscribe(REDIS_WORKER_CHANNEL)
# remove this worker's consumer from the group:
await self.redis.xgroup_delconsumer(
REDIS_WORKER_STREAM, REDIS_WORKER_GROUP, self._name
)
await asyncio.wait_for(self.redis.close(), timeout=2.0)
except asyncio.TimeoutError:
self.logger.error(
@@ -483,7 +544,6 @@ async def connection_handler(
message=f'Task {task!s} was discarded, queue full',
writer=writer
)
print('RESULT > ', result)
if result is None:
# Not always a Task returns Value, sometimes returns None.
result = [
@@ -573,6 +633,8 @@ def start_server(num_worker, host, port, debug: bool):
loop.run_until_complete(
worker.shutdown()
)
except Exception:
pass
finally:
if loop:
loop.close() # Close the event loop
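
Since the worker only XACKs a stream entry after decoding it and handing it to TaskExecutor, entries that could not be decoded, or that were read by a consumer which crashed before acknowledging, stay in the group's Pending Entries List. A small inspection sketch (not part of this PR) to list those unacknowledged entries:

import asyncio
from redis import asyncio as aioredis

from qw.conf import WORKER_REDIS, REDIS_WORKER_STREAM, REDIS_WORKER_GROUP


async def main():
    conn = aioredis.from_url(WORKER_REDIS, decode_responses=True, encoding='utf-8')
    # Summary of the Pending Entries List for the worker group.
    summary = await conn.xpending(REDIS_WORKER_STREAM, REDIS_WORKER_GROUP)
    print('pending entries:', summary['pending'])
    # Oldest unacknowledged entries, with the consumer that read them
    # and how long ago they were delivered (milliseconds).
    entries = await conn.xpending_range(
        REDIS_WORKER_STREAM, REDIS_WORKER_GROUP, min='-', max='+', count=10
    )
    for entry in entries:
        print(entry['message_id'], entry['consumer'], entry['time_since_delivered'])
    await conn.close()


if __name__ == '__main__':
    asyncio.run(main())

Entries left by a crashed consumer could be reassigned with XCLAIM or XAUTOCLAIM; this PR does not add that, so pending entries remain attached to the consumer that first read them.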
2 changes: 1 addition & 1 deletion qw/version.py
@@ -6,7 +6,7 @@
__description__ = ('QueueWorker is asynchronous Task Queue implementation '
'built on top of Asyncio.'
'You can spawn distributed workers to run functions inside workers.')
__version__ = '1.8.8'
__version__ = '1.9.0'
__author__ = 'Jesus Lara'
__author_email__ = 'jesuslarag@gmail.com'
__license__ = 'MIT'