diff --git a/aioboto3/s3/inject.py b/aioboto3/s3/inject.py
index 37e3ca0..33a3f47 100644
--- a/aioboto3/s3/inject.py
+++ b/aioboto3/s3/inject.py
@@ -3,16 +3,34 @@
 import inspect
 import logging
 from io import BytesIO
-from typing import Optional, Callable, BinaryIO, Dict, Any, Union, Tuple
+from typing import Optional, Callable, BinaryIO, Dict, Any, Union
+from abc import abstractmethod
 
 from botocore.exceptions import ClientError
 from boto3 import utils
-from boto3.s3.transfer import S3TransferConfig
+from boto3.s3.transfer import S3TransferConfig, S3Transfer
 from boto3.s3.inject import bucket_upload_file, bucket_download_file, bucket_copy, bucket_upload_fileobj, bucket_download_fileobj
+from s3transfer.upload import UploadSubmissionTask
 
 logger = logging.getLogger(__name__)
 
 
+TransferCallback = Callable[[int], None]
+
+
+class _AsyncBinaryIO:
+    @abstractmethod
+    async def seek(self, offset: int, whence: int = 0) -> int:
+        pass
+
+    @abstractmethod
+    async def write(self, s: Union[bytes, bytearray]) -> int:
+        pass
+
+
+AnyFileObject = Union[_AsyncBinaryIO, BinaryIO]
+
+
 def inject_s3_transfer_methods(class_attributes, **kwargs):
     utils.inject_attribute(class_attributes, 'upload_file', upload_file)
     utils.inject_attribute(class_attributes, 'download_file', download_file)
@@ -52,20 +70,27 @@ async def object_summary_load(self, *args, **kwargs):
 
 
 async def download_file(
-    self, Bucket: str, Key: str, Filename: str, ExtraArgs=None, Callback=None, Config=None
+    self,
+    Bucket: str,
+    Key: str,
+    Filename: str,
+    ExtraArgs: Optional[Dict[str, Any]] = None,
+    Callback: Optional[TransferCallback] = None,
+    Config: Optional[S3TransferConfig] = None
 ):
     """Download an S3 object to a file asynchronously.
 
     Usage::
 
         import aioboto3
-        s3 = aioboto3.resource('s3')
-        await s3.meta.client.download_file('mybucket', 'hello.txt', '/tmp/hello.txt')
+
+        async with aioboto3.resource('s3') as s3:
+            await s3.meta.client.download_file('mybucket', 'hello.txt', '/tmp/hello.txt')
 
     Similar behaviour as S3Transfer's download_file() method,
     except that parameters are capitalised.
""" - async with aiofiles.open(Filename, 'wb') as fileobj: + async with aiofiles.open(Filename, 'wb') as fileobj: # type: _AsyncBinaryIO await download_fileobj( self, Bucket, @@ -77,21 +102,24 @@ async def download_file( ) -async def _download_part(self, bucket: str, key: str, headers, start, file: Union[BytesIO, Any, None], semaphore, callback=None, io_queue: Optional[asyncio.Queue] = None) -> Union[None, Tuple[int, bytes]]: +async def _download_part(self, bucket: str, key: str, headers: Dict[str, str], start: int, file: AnyFileObject, semaphore: asyncio.Semaphore, write_lock: asyncio.Lock, + callback=None, io_queue: Optional[asyncio.Queue] = None) -> None: async with semaphore: # limit number of concurrent downloads response = await self.get_object( Bucket=bucket, Key=key, Range=headers['Range'] ) content = await response['Body'].read() - # If stream is not seekable, return it so it can be queued up to be written + # If stream is not seekable, return the offset and data so it can be queued up to be written if io_queue: await io_queue.put((start, content)) else: # Check if it's aiofiles file if inspect.iscoroutinefunction(file.seek) and inspect.iscoroutinefunction(file.write): - await file.seek(start) - await file.write(content) + # These operations need to happen sequentially, which is non-deterministic when dealing with event loops + async with write_lock: + await file.seek(start) + await file.write(content) else: # Fallback to synchronous operations for file objects that are not async file.seek(start) @@ -106,7 +134,13 @@ async def _download_part(self, bucket: str, key: str, headers, start, file: Unio async def download_fileobj( - self, Bucket, Key, Fileobj, ExtraArgs=None, Callback=None, Config=None + self, + Bucket: str, + Key: str, + Fileobj: AnyFileObject, + ExtraArgs: Optional[Dict[str, Any]] = None, + Callback: Optional[TransferCallback] = None, + Config: Optional[S3TransferConfig] = None ): """Download an object from S3 to a file-like object. @@ -146,9 +180,10 @@ async def download_fileobj( download. """ + Config = Config or S3TransferConfig() + ExtraArgs = ExtraArgs or {} + try: - if ExtraArgs is None: - ExtraArgs = {} # Get object metadata to determine the total size head_response = await self.head_object(Bucket=Bucket, Key=Key, **ExtraArgs) except ClientError as err: @@ -158,13 +193,11 @@ async def download_fileobj( raise # Semaphore to limit the number of concurrent downloads - semaphore = asyncio.Semaphore(10) - - # Size of each part (8MB) - part_size = 8 * 1024 * 1024 + semaphore = asyncio.Semaphore(Config.max_request_concurrency) + write_mutex = asyncio.Lock() total_size = head_response['ContentLength'] - total_parts = (total_size + part_size - 1) // part_size + total_parts = (total_size + Config.multipart_chunksize - 1) // Config.multipart_chunksize # Keep track of total downloaded bytes total_downloaded = 0 @@ -178,34 +211,51 @@ def wrapper_callback(bytes_transferred): except: # noqa: E722 pass - is_seekable = inspect.isfunction(Fileobj.seek) or inspect.iscoroutinefunction(Fileobj.seek) + is_seekable = hasattr(Fileobj, "seek") # This'll have around `semaphore` length items, somewhat more if writing is slow # TODO add limits so we dont fill up this list n blow out ram io_list = [] + + # This should be Config.io_concurrency but as we're gathering all coro's we cant guarantee + # that the co-routines will start in relative order so we could fill up the queue with the + # x chunks and if we're not writing to a seekable stream then it'll deadlock. 
     io_queue = asyncio.Queue()
 
     async def queue_reader():
+        """
+        Pull chunks off the queue into io_list, then write them out to the
+        file object in offset order.
+        """
         is_async = inspect.iscoroutinefunction(Fileobj.write)
 
-        written_pos = 0
-        while written_pos < total_size:
-            io_list.append(await io_queue.get())
-
-            # Stuff might be out of order in io_list
-            # so spin until there's nothing to queue off
-            done_nothing = False
-            while not done_nothing:
-                done_nothing = True
-
-                for chunk_start, data in io_list:
-                    if chunk_start == written_pos:
-                        if is_async:
-                            await Fileobj.write(data)
-                        else:
-                            Fileobj.write(data)
-                        written_pos += len(data)
-                        done_nothing = False
+        try:
+            written_pos = 0
+            while written_pos < total_size:
+                io_list.append(await io_queue.get())
+
+                # Stuff might be out of order in io_list
+                # so spin until there's nothing to queue off
+                done_nothing = False
+                while not done_nothing:
+                    done_nothing = True
+
+                    indexes_to_remove = []
+                    for index, (chunk_start, data) in enumerate(io_list):
+                        if chunk_start == written_pos:
+                            if is_async:
+                                await Fileobj.write(data)
+                            else:
+                                Fileobj.write(data)
+
+                            indexes_to_remove.append(index)
+                            written_pos += len(data)
+                            done_nothing = False
+
+                    for index in reversed(indexes_to_remove):
+                        io_list.pop(index)
+        except asyncio.CancelledError:
+            pass
 
     queue_reader_future = None
     if not is_seekable:
@@ -214,19 +264,19 @@ async def queue_reader():
     try:
         tasks = []
         for i in range(total_parts):
-            start = i * part_size
+            start = i * Config.multipart_chunksize
             end = min(
-                start + part_size, total_size
+                start + Config.multipart_chunksize, total_size
             )  # Ensure we don't go beyond the total size
             # Range headers, start at 0 so end which would be total_size, minus 1 = 0 indexed.
             headers = {'Range': f'bytes={start}-{end - 1}'}
             # Create a task for each part download
             tasks.append(
-                _download_part(self, Bucket, Key, headers, start, Fileobj, semaphore, wrapper_callback, io_queue if not is_seekable else None)
+                _download_part(self, Bucket, Key, headers, start, Fileobj, semaphore, write_mutex, wrapper_callback, io_queue if not is_seekable else None)
             )
         # Run all the download tasks concurrently
-        await asyncio.gather(*tasks)
+        await asyncio.gather(*tasks)  # TODO might not be worth spamming the eventloop with 1000's of tasks, but deal with it when it's a problem.
 
         if queue_reader_future:
             await queue_reader_future
@@ -241,7 +291,7 @@ async def queue_reader():
 
 async def upload_fileobj(
     self,
-    Fileobj: BinaryIO,
+    Fileobj: AnyFileObject,
     Bucket: str,
     Key: str,
     ExtraArgs: Optional[Dict[str, Any]] = None,
@@ -291,20 +341,19 @@
         by custom logic.
     """
     kwargs = ExtraArgs or {}
+    upload_part_args = {k: v for k, v in kwargs.items() if k in UploadSubmissionTask.UPLOAD_PART_ARGS}
+    complete_upload_args = {k: v for k, v in kwargs.items() if k in UploadSubmissionTask.COMPLETE_MULTIPART_ARGS}
+    Config = Config or S3TransferConfig()
 
     # I was debating setting up a queue etc...
     # If its too slow I'll then be bothered
-    multipart_chunksize = 8388608 if Config is None else Config.multipart_chunksize
-    io_chunksize = 262144 if Config is None else Config.io_chunksize
-    max_concurrency = 10 if Config is None else Config.max_request_concurrency
-    max_io_queue = 100 if Config is None else Config.max_io_queue_size
 
     # Start multipart upload
     resp = await self.create_multipart_upload(Bucket=Bucket, Key=Key, **kwargs)
     upload_id = resp['UploadId']
     finished_parts = []
     expected_parts = 0
-    io_queue = asyncio.Queue(maxsize=max_io_queue)
+    io_queue = asyncio.Queue(maxsize=Config.max_io_queue_size)
     exception_event = asyncio.Event()
     exception = None
     sent_bytes = 0
@@ -360,10 +409,10 @@ async def file_reader() -> None:
             part += 1
             multipart_payload = bytearray()
             loop_counter = 0
-            while len(multipart_payload) < multipart_chunksize:
+            while len(multipart_payload) < Config.multipart_chunksize:
                 try:
                     # Handles if .read() returns anything that can be awaited
-                    data_chunk = Fileobj.read(io_chunksize)
+                    data_chunk = Fileobj.read(Config.io_chunksize)
                     if inspect.isawaitable(data_chunk):
                         # noinspection PyUnresolvedReferences
                         data = await data_chunk
@@ -397,12 +446,12 @@ async def file_reader() -> None:
                 multipart_payload = Processing(multipart_payload)
 
             await io_queue.put({'Body': multipart_payload, 'Bucket': Bucket, 'Key': Key,
-                                'PartNumber': part, 'UploadId': upload_id})
+                                'PartNumber': part, 'UploadId': upload_id, **upload_part_args})
             logger.debug('Added part to io_queue')
             expected_parts += 1
 
     file_reader_future = asyncio.ensure_future(file_reader())
-    futures = [asyncio.ensure_future(uploader()) for _ in range(0, max_concurrency)]
+    futures = [asyncio.ensure_future(uploader()) for _ in range(0, Config.max_request_concurrency)]
 
     # Wait for file reader to finish
     await file_reader_future
@@ -428,7 +477,8 @@ async def file_reader() -> None:
             Bucket=Bucket,
             Key=Key,
             UploadId=upload_id,
-            MultipartUpload={'Parts': finished_parts}
+            MultipartUpload={'Parts': finished_parts},
+            **complete_upload_args
         )
     except Exception as err:
         # We failed to complete the upload, try and abort, then return the orginal error
@@ -467,20 +517,26 @@ async def file_reader() -> None:
 
 
 async def upload_file(
-    self, Filename, Bucket, Key, ExtraArgs=None, Callback=None, Config=None
+    self,
+    Filename: str,
+    Bucket: str,
+    Key: str,
+    ExtraArgs: Optional[Dict[str, Any]] = None,
+    Callback: Optional[TransferCallback] = None,
+    Config: Optional[S3TransferConfig] = None
 ):
     """Upload a file to an S3 object.
 
     Usage::
 
         import boto3
-        s3 = boto3.resource('s3')
-        s3.meta.client.upload_file('/tmp/hello.txt', 'mybucket', 'hello.txt')
+        async with aioboto3.resource('s3') as s3:
+            await s3.meta.client.upload_file('/tmp/hello.txt', 'mybucket', 'hello.txt')
 
     Similar behavior as S3Transfer's upload_file() method,
     except that parameters are capitalized.
""" - with open(Filename, 'rb') as open_file: + async with aiofiles.open(Filename, 'rb') as open_file: await upload_fileobj( self, open_file, @@ -493,7 +549,14 @@ async def upload_file( async def copy( - self, CopySource, Bucket, Key, ExtraArgs=None, Callback=None, SourceClient=None, Config=None + self, + CopySource: Dict[str, Any], + Bucket: str, + Key: str, + ExtraArgs: Optional[Dict[str, Any]] = None, + Callback: Optional[TransferCallback] = None, + SourceClient = None, # Should be aioboto3/aiobotocore client + Config: Optional[S3TransferConfig] = None ): assert 'Bucket' in CopySource assert 'Key' in CopySource @@ -504,9 +567,12 @@ async def copy( if ExtraArgs is None: ExtraArgs = {} + download_args = {k: v for k, v in ExtraArgs.items() if k in S3Transfer.ALLOWED_DOWNLOAD_ARGS} + upload_args = {k: v for k, v in ExtraArgs.items() if k in S3Transfer.ALLOWED_UPLOAD_ARGS} + try: resp = await SourceClient.get_object( - Bucket=CopySource['Bucket'], Key=CopySource['Key'], **ExtraArgs + Bucket=CopySource['Bucket'], Key=CopySource['Key'], **download_args ) except ClientError as err: if err.response['Error']['Code'] == 'NoSuchKey': @@ -520,7 +586,7 @@ async def copy( file_obj, Bucket, Key, - ExtraArgs=ExtraArgs, + ExtraArgs=upload_args, Callback=Callback, Config=Config )