Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ orjson
numpy>=1.13.1
pathos
pbr
tenacity
tenacity>=8.0.1
tqdm
requests>2,<3
45 changes: 36 additions & 9 deletions taskqueue/aws_queue_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import types

import boto3
import botocore.exceptions
import botocore.errorfactory

from .lib import toiter, sip, jsonify
Expand All @@ -11,8 +12,20 @@
AWS_DEFAULT_REGION
)

import tenacity

AWS_BATCH_SIZE = 10 # send_message_batch's max batch size is 10

class ClientSideError(Exception):
pass

retry = tenacity.retry(
reraise=True,
stop=tenacity.stop_after_attempt(4),
wait=tenacity.wait_random_exponential(0.5, 60.0),
retry=tenacity.retry_if_not_exception_type(ClientSideError),
)

class AWSTaskQueueAPI(object):
def __init__(self, qurl, region_name=AWS_DEFAULT_REGION, **kwargs):
"""
Expand Down Expand Up @@ -42,14 +55,18 @@ def __init__(self, qurl, region_name=AWS_DEFAULT_REGION, **kwargs):
)

if self.qurl is None:
try:
self.qurl = self.sqs.get_queue_url(QueueName=qurl)["QueueUrl"]
except Exception:
print(qurl)
raise
self.qurl = self._get_qurl(qurl)

self.batch_size = AWS_BATCH_SIZE

@retry
def _get_qurl(self, qurl):
try:
return self.sqs.get_queue_url(QueueName=qurl)["QueueUrl"]
except Exception as err:
print(f"Failed to fetch queue URL for: {qurl}")
raise

@property
def enqueued(self):
status = self.status()
Expand All @@ -71,13 +88,15 @@ def leased(self):
def is_empty():
return self.enqueued == 0

@retry
def status(self):
resp = self.sqs.get_queue_attributes(
QueueUrl=self.qurl,
AttributeNames=['ApproximateNumberOfMessages', 'ApproximateNumberOfMessagesNotVisible']
)
return resp['Attributes']

@retry
def insert(self, tasks, delay_seconds=0):
tasks = toiter(tasks)

Expand All @@ -94,10 +113,18 @@ def insert(self, tasks, delay_seconds=0):
} for j, task in enumerate(batch)
]

resp = self.sqs.send_message_batch(
QueueUrl=self.qurl,
Entries=entries,
)
try:
resp = self.sqs.send_message_batch(
QueueUrl=self.qurl,
Entries=entries,
)
except botocore.exceptions.ClientError as error:
http_code = error.response['ResponseMetadata']['HTTPStatusCode']
if 400 <= int(http_code) < 500:
raise ClientSideError(error)
else:
raise error

total += len(entries)

return total
Expand Down