Skip to content

Commit

Permalink
[Results] Allow unserializable return values (#1888)
Browse files Browse the repository at this point in the history
* fix: allow unserializable return values

* fix: review comments
  • Loading branch information
tchapi committed May 1, 2023
1 parent 36f5c88 commit 08cb311
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 3 deletions.
6 changes: 6 additions & 0 deletions rq/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,10 @@
DEFAULT_DEATH_PENALTY_CLASS = 'rq.timeouts.UnixSignalDeathPenalty'
""" The path for the default Death Penalty class to use.
Defaults to the `UnixSignalDeathPenalty` class within the `rq.timeouts` module
"""


UNSERIALIZABLE_RETURN_VALUE_PAYLOAD = 'Unserializable return value'
""" The value that we store in the job's _result property or in the Result's return_value
in case the return value of the actual job is not serializable
"""
4 changes: 2 additions & 2 deletions rq/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
from uuid import uuid4

from .defaults import CALLBACK_TIMEOUT
from .defaults import CALLBACK_TIMEOUT, UNSERIALIZABLE_RETURN_VALUE_PAYLOAD
from .timeouts import JobTimeoutException, BaseDeathPenalty

if TYPE_CHECKING:
Expand Down Expand Up @@ -887,7 +887,7 @@ def restore(self, raw_data) -> Any:
try:
self._result = self.serializer.loads(result)
except Exception:
self._result = "Unserializable return value"
self._result = UNSERIALIZABLE_RETURN_VALUE_PAYLOAD
self.timeout = parse_timeout(obj.get('timeout')) if obj.get('timeout') else None
self.result_ttl = int(obj.get('result_ttl')) if obj.get('result_ttl') else None
self.failure_ttl = int(obj.get('failure_ttl')) if obj.get('failure_ttl') else None
Expand Down
7 changes: 6 additions & 1 deletion rq/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from enum import Enum
from redis import Redis

from .defaults import UNSERIALIZABLE_RETURN_VALUE_PAYLOAD
from .utils import decode_redis_hash
from .job import Job
from .serializers import resolve_serializer
Expand Down Expand Up @@ -181,7 +182,11 @@ def serialize(self):
if self.exc_string is not None:
data['exc_string'] = b64encode(zlib.compress(self.exc_string.encode())).decode()

serialized = self.serializer.dumps(self.return_value)
try:
serialized = self.serializer.dumps(self.return_value)
except: # noqa
serialized = self.serializer.dumps(UNSERIALIZABLE_RETURN_VALUE_PAYLOAD)

if self.return_value is not None:
data['return_value'] = b64encode(serialized).decode()

Expand Down
18 changes: 18 additions & 0 deletions tests/test_results.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
import tempfile

from datetime import timedelta
from unittest.mock import patch, PropertyMock
Expand All @@ -7,6 +8,7 @@

from tests import RQTestCase

from rq.defaults import UNSERIALIZABLE_RETURN_VALUE_PAYLOAD
from rq.job import Job
from rq.queue import Queue
from rq.registry import StartedJobRegistry
Expand Down Expand Up @@ -236,3 +238,19 @@ def test_job_return_value_result_ttl_zero(self):

Result.create(job, Result.Type.SUCCESSFUL, ttl=0, return_value=1)
self.assertIsNone(job.return_value())

def test_job_return_value_unserializable(self):
"""Test job.return_value when it is not serializable"""
queue = Queue(connection=self.connection, result_ttl=0)
job = queue.enqueue(say_hello)

# Returns None when there's no result
self.assertIsNone(job.return_value())

# tempfile.NamedTemporaryFile() is not picklable
Result.create(job, Result.Type.SUCCESSFUL, ttl=10, return_value=tempfile.NamedTemporaryFile())
self.assertEqual(job.return_value(), UNSERIALIZABLE_RETURN_VALUE_PAYLOAD)
self.assertEqual(Result.count(job), 1)

Result.create(job, Result.Type.SUCCESSFUL, ttl=10, return_value=1)
self.assertEqual(Result.count(job), 2)

0 comments on commit 08cb311

Please sign in to comment.