
Custom serializers (#143)
* feat: add ability to specify custom serializer and deserializer functions

* feat: add _serialize, _deserialize arguments to create_pool

* fix: add back PickleError for backward compatibility

* chore: format code

* fix: apply code review suggestions

* chore: add tests for custom serializers

* chore: add docs paragraph about serializers and an example

* chore: make serializer type hints stricter and define constants in arq.jobs

* chore: make serialization functions public and uniform the names to *lizer

* fix: pass _job_*lizer arguments to recursive call of create_pool

* chore: make job_*lizer arguments public in create_pool

* chore: remove unnecessary defer in serializer example

* chore: add precise error value when testing incompatible serializers

* chore: format code

* chore: add history entry

* chore: make the serializer arguments in ArqRedis public

* chore: adjust indentation in serialization tests

* fix: SerializationError tests

* chore: fix ArqRedis' docstring
rubik authored and samuelcolvin committed Aug 11, 2019
1 parent b2d397a commit f0fa671
Showing 12 changed files with 283 additions and 71 deletions.
4 changes: 4 additions & 0 deletions HISTORY.rst
@@ -3,6 +3,10 @@
History
-------

v0.17 (unreleased)
..................
* custom serializers, eg. to use msgpack rather than pickle, #143 by @rubik

v0.16.1 (2019-08-02)
....................
* prevent duplicate ``job_id`` when job result exists, fix #137
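As a usage sketch of the feature in the history entry above (msgpack is just one possible choice and is not an arq dependency; 'download_content' is a made-up function name), the two callables are passed to create_pool via the new keyword arguments shown in the connections.py diff below:

    import asyncio
    import functools

    import msgpack  # assumed to be installed separately

    from arq.connections import RedisSettings, create_pool


    async def main():
        # both callables must satisfy the new type aliases in arq.jobs:
        #   Serializer = Callable[[Dict[str, Any]], bytes]
        #   Deserializer = Callable[[bytes], Dict[str, Any]]
        redis = await create_pool(
            RedisSettings(),  # default host/port
            job_serializer=msgpack.packb,
            job_deserializer=functools.partial(msgpack.unpackb, raw=False),
        )
        await redis.enqueue_job('download_content', 'https://example.com')


    if __name__ == '__main__':
        asyncio.run(main())

Jobs serialized this way can only be read back with a matching deserializer; an incompatible one raises the new SerializationError (see arq/jobs.py below).
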
47 changes: 37 additions & 10 deletions arq/connections.py
@@ -1,4 +1,5 @@
import asyncio
import functools
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
@@ -10,7 +11,7 @@
from aioredis import MultiExecError, Redis

from .constants import default_queue_name, job_key_prefix, result_key_prefix
from .jobs import Job, JobResult, pickle_job
from .jobs import Deserializer, Job, JobResult, Serializer, serialize_job
from .utils import timestamp_ms, to_ms, to_unix_ms

logger = logging.getLogger('arq.connections')
@@ -43,8 +44,24 @@ def __repr__(self):
class ArqRedis(Redis):
"""
Thin subclass of ``aioredis.Redis`` which adds :func:`arq.connections.enqueue_job`.
:param redis_settings: an instance of ``arq.connections.RedisSettings``.
:param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps
:param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads
:param kwargs: keyword arguments directly passed to ``aioredis.Redis``.
"""

def __init__(
self,
pool_or_conn,
job_serializer: Optional[Serializer] = None,
job_deserializer: Optional[Deserializer] = None,
**kwargs,
) -> None:
self.job_serializer = job_serializer
self.job_deserializer = job_deserializer
super().__init__(pool_or_conn, **kwargs)

async def enqueue_job(
self,
function: str,
@@ -98,7 +115,7 @@ async def enqueue_job(

expires_ms = expires_ms or score - enqueue_time_ms + expires_extra_ms

job = pickle_job(function, args, kwargs, _job_try, enqueue_time_ms)
job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
tr = conn.multi_exec()
tr.psetex(job_key, expires_ms, job)
tr.zadd(_queue_name, score, job_id)
@@ -107,11 +124,11 @@
except MultiExecError:
# job got enqueued since we checked 'job_exists'
return
return Job(job_id, self)
return Job(job_id, redis=self, _deserializer=self.job_deserializer)

async def _get_job_result(self, key):
job_id = key[len(result_key_prefix) :]
job = Job(job_id, self)
job = Job(job_id, self, _deserializer=self.job_deserializer)
r = await job.result_info()
r.job_id = job_id
return r
@@ -125,7 +142,13 @@ async def all_job_results(self) -> List[JobResult]:
return sorted(results, key=attrgetter('enqueue_time'))


async def create_pool(settings: RedisSettings = None, *, _retry: int = 0) -> ArqRedis:
async def create_pool(
settings: RedisSettings = None,
*,
retry: int = 0,
job_serializer: Optional[Serializer] = None,
job_deserializer: Optional[Deserializer] = None,
) -> ArqRedis:
"""
Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.
@@ -141,29 +164,33 @@ async def create_pool(settings: RedisSettings = None, *, _retry: int = 0) -> Arq
password=settings.password,
timeout=settings.conn_timeout,
encoding='utf8',
commands_factory=ArqRedis,
commands_factory=functools.partial(
ArqRedis, job_serializer=job_serializer, job_deserializer=job_deserializer
),
)
except (ConnectionError, OSError, aioredis.RedisError, asyncio.TimeoutError) as e:
if _retry < settings.conn_retries:
if retry < settings.conn_retries:
logger.warning(
'redis connection error %s:%s %s %s, %d retries remaining...',
settings.host,
settings.port,
e.__class__.__name__,
e,
settings.conn_retries - _retry,
settings.conn_retries - retry,
)
await asyncio.sleep(settings.conn_retry_delay)
else:
raise
else:
if _retry > 0:
if retry > 0:
logger.info('redis connection successful')
return pool

# recursively attempt to create the pool outside the except block to avoid
# "During handling of the above exception..." madness
return await create_pool(settings, _retry=_retry + 1)
return await create_pool(
settings, retry=retry + 1, job_serializer=job_serializer, job_deserializer=job_deserializer
)


async def log_redis_info(redis, log_func):
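For reference, a rough equivalent of what the functools.partial commands_factory above achieves: aioredis calls the factory with the freshly created connection pool, so the serializer arguments have to be pre-bound. A minimal sketch, assuming aioredis 1.x and, purely for illustration, msgpack:

    import functools

    import aioredis  # aioredis 1.x API, which arq targets at this point
    import msgpack   # illustrative choice only

    from arq.connections import ArqRedis


    async def connect() -> ArqRedis:
        pool = await aioredis.create_pool('redis://localhost', encoding='utf8')
        # wrapping the pool by hand is equivalent to
        # commands_factory=functools.partial(ArqRedis, job_serializer=..., job_deserializer=...)
        return ArqRedis(
            pool,
            job_serializer=msgpack.packb,
            job_deserializer=functools.partial(msgpack.unpackb, raw=False),
        )
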
126 changes: 83 additions & 43 deletions arq/jobs.py
@@ -4,13 +4,16 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Optional
from typing import Any, Callable, Dict, Optional

from .constants import default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix
from .utils import ms_to_datetime, poll, timestamp_ms

logger = logging.getLogger('arq.jobs')

Serializer = Callable[[Dict[str, Any]], bytes]
Deserializer = Callable[[bytes], Dict[str, Any]]


class JobStatus(str, Enum):
"""
@@ -53,12 +56,15 @@ class Job:
Holds data and a reference to a job.
"""

__slots__ = 'job_id', '_redis', '_queue_name'
__slots__ = 'job_id', '_redis', '_queue_name', '_deserializer'

def __init__(self, job_id: str, redis, _queue_name: str = default_queue_name):
def __init__(
self, job_id: str, redis, _queue_name: str = default_queue_name, _deserializer: Optional[Deserializer] = None
):
self.job_id = job_id
self._redis = redis
self._queue_name = _queue_name
self._deserializer = _deserializer

async def result(self, timeout: Optional[float] = None, *, pole_delay: float = 0.5) -> Any:
"""
@@ -87,7 +93,7 @@ async def info(self) -> Optional[JobDef]:
if not info:
v = await self._redis.get(job_key_prefix + self.job_id, encoding=None)
if v:
info = unpickle_job(v)
info = deserialize_job(v, deserializer=self._deserializer)
if info:
info.score = await self._redis.zscore(self._queue_name, self.job_id)
return info
@@ -99,7 +105,7 @@ async def result_info(self) -> Optional[JobResult]:
"""
v = await self._redis.get(result_key_prefix + self.job_id, encoding=None)
if v:
return unpickle_result(v)
return deserialize_result(v, deserializer=self._deserializer)

async def status(self) -> JobStatus:
"""
@@ -119,19 +125,29 @@ def __repr__(self):
return f'<arq job {self.job_id}>'


class PickleError(RuntimeError):
class SerializationError(RuntimeError):
pass


def pickle_job(function_name: str, args: tuple, kwargs: dict, job_try: int, enqueue_time_ms: int):
def serialize_job(
function_name: str,
args: tuple,
kwargs: dict,
job_try: int,
enqueue_time_ms: int,
*,
serializer: Optional[Serializer] = None,
) -> Optional[bytes]:
data = {'t': job_try, 'f': function_name, 'a': args, 'k': kwargs, 'et': enqueue_time_ms}
if serializer is None:
serializer = pickle.dumps
try:
return pickle.dumps(data)
return serializer(data)
except Exception as e:
raise PickleError(f'unable to pickle job "{function_name}"') from e
raise SerializationError(f'unable to serialize job "{function_name}"') from e


def pickle_result(
def serialize_result(
function: str,
args: tuple,
kwargs: dict,
@@ -142,6 +158,8 @@ def pickle_result(
start_ms: int,
finished_ms: int,
ref: str,
*,
serializer: Optional[Serializer] = None,
) -> Optional[bytes]:
data = {
't': job_try,
@@ -154,41 +172,63 @@
'st': start_ms,
'ft': finished_ms,
}
if serializer is None:
serializer = pickle.dumps
try:
return pickle.dumps(data)
return serializer(data)
except Exception:
logger.warning('error pickling result of %s', ref, exc_info=True)
logger.warning('error serializing result of %s', ref, exc_info=True)

data.update(r=PickleError('unable to pickle result'), s=False)
data.update(r=SerializationError('unable to serialize result'), s=False)
try:
return pickle.dumps(data)
return serializer(data)
except Exception:
logger.critical('error pickling result of %s even after replacing result', ref, exc_info=True)


def unpickle_job(r: bytes) -> JobDef:
d = pickle.loads(r)
return JobDef(
function=d['f'], args=d['a'], kwargs=d['k'], job_try=d['t'], enqueue_time=ms_to_datetime(d['et']), score=None
)


def unpickle_job_raw(r: bytes) -> tuple:
d = pickle.loads(r)
return d['f'], d['a'], d['k'], d['t'], d['et']


def unpickle_result(r: bytes) -> JobResult:
d = pickle.loads(r)
return JobResult(
job_try=d['t'],
function=d['f'],
args=d['a'],
kwargs=d['k'],
enqueue_time=ms_to_datetime(d['et']),
score=None,
success=d['s'],
result=d['r'],
start_time=ms_to_datetime(d['st']),
finish_time=ms_to_datetime(d['ft']),
)
logger.critical('error serializing result of %s even after replacing result', ref, exc_info=True)


def deserialize_job(r: bytes, *, deserializer: Optional[Deserializer] = None) -> JobDef:
if deserializer is None:
deserializer = pickle.loads
try:
d = deserializer(r)
return JobDef(
function=d['f'],
args=d['a'],
kwargs=d['k'],
job_try=d['t'],
enqueue_time=ms_to_datetime(d['et']),
score=None,
)
except Exception as e:
raise SerializationError(f'unable to deserialize job: {r!r}') from e


def deserialize_job_raw(r: bytes, *, deserializer: Optional[Deserializer] = None) -> tuple:
if deserializer is None:
deserializer = pickle.loads
try:
d = deserializer(r)
return d['f'], d['a'], d['k'], d['t'], d['et']
except Exception as e:
raise SerializationError(f'unable to deserialize job: {r!r}') from e


def deserialize_result(r: bytes, *, deserializer: Optional[Deserializer] = None) -> JobResult:
if deserializer is None:
deserializer = pickle.loads
try:
d = deserializer(r)
return JobResult(
job_try=d['t'],
function=d['f'],
args=d['a'],
kwargs=d['k'],
enqueue_time=ms_to_datetime(d['et']),
score=None,
success=d['s'],
result=d['r'],
start_time=ms_to_datetime(d['st']),
finish_time=ms_to_datetime(d['ft']),
)
except Exception as e:
raise SerializationError(f'unable to deserialize job result: {r!r}') from e
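A quick round trip through the new helpers, sketched with made-up job data, shows the serializer hooks and how an incompatible deserializer surfaces as SerializationError:

    import functools

    import msgpack  # illustrative choice only

    from arq.jobs import SerializationError, deserialize_job, serialize_job

    raw = serialize_job(
        'download_content',           # made-up function name
        ('https://example.com',),     # args
        {},                           # kwargs
        1,                            # job_try
        1565000000000,                # enqueue_time_ms
        serializer=msgpack.packb,
    )

    job = deserialize_job(raw, deserializer=functools.partial(msgpack.unpackb, raw=False))
    print(job.function, job.args)

    try:
        # the default deserializer is pickle.loads, which cannot read msgpack bytes
        deserialize_job(raw)
    except SerializationError as e:
        print('incompatible serializers:', e)
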
