diff --git a/requirements.txt b/requirements.txt index 1e998da..ddef771 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,6 @@ orjson numpy>=1.13.1 pathos pbr -tenacity +tenacity>=8.0.1 tqdm requests>2,<3 \ No newline at end of file diff --git a/taskqueue/aws_queue_api.py b/taskqueue/aws_queue_api.py index cea291e..04221aa 100644 --- a/taskqueue/aws_queue_api.py +++ b/taskqueue/aws_queue_api.py @@ -3,6 +3,7 @@ import types import boto3 +import botocore.exceptions import botocore.errorfactory from .lib import toiter, sip, jsonify @@ -11,8 +12,20 @@ AWS_DEFAULT_REGION ) +import tenacity + AWS_BATCH_SIZE = 10 # send_message_batch's max batch size is 10 +class ClientSideError(Exception): + pass + +retry = tenacity.retry( + reraise=True, + stop=tenacity.stop_after_attempt(4), + wait=tenacity.wait_random_exponential(0.5, 60.0), + retry=tenacity.retry_if_not_exception_type(ClientSideError), +) + class AWSTaskQueueAPI(object): def __init__(self, qurl, region_name=AWS_DEFAULT_REGION, **kwargs): """ @@ -42,14 +55,18 @@ def __init__(self, qurl, region_name=AWS_DEFAULT_REGION, **kwargs): ) if self.qurl is None: - try: - self.qurl = self.sqs.get_queue_url(QueueName=qurl)["QueueUrl"] - except Exception: - print(qurl) - raise + self.qurl = self._get_qurl(qurl) self.batch_size = AWS_BATCH_SIZE + @retry + def _get_qurl(self, qurl): + try: + return self.sqs.get_queue_url(QueueName=qurl)["QueueUrl"] + except Exception as err: + print(f"Failed to fetch queue URL for: {qurl}") + raise + @property def enqueued(self): status = self.status() @@ -71,6 +88,7 @@ def leased(self): def is_empty(): return self.enqueued == 0 + @retry def status(self): resp = self.sqs.get_queue_attributes( QueueUrl=self.qurl, @@ -78,6 +96,7 @@ def status(self): ) return resp['Attributes'] + @retry def insert(self, tasks, delay_seconds=0): tasks = toiter(tasks) @@ -94,10 +113,18 @@ def insert(self, tasks, delay_seconds=0): } for j, task in enumerate(batch) ] - resp = self.sqs.send_message_batch( - QueueUrl=self.qurl, - Entries=entries, - ) + try: + resp = self.sqs.send_message_batch( + QueueUrl=self.qurl, + Entries=entries, + ) + except botocore.exceptions.ClientError as error: + http_code = error.response['ResponseMetadata']['HTTPStatusCode'] + if 400 <= int(http_code) < 500: + raise ClientSideError(error) + else: + raise error + total += len(entries) return total