Added non seekable download queue
terricain committed May 27, 2024
1 parent 8e837e9 commit 72b52e1
Showing 1 changed file with 65 additions and 24 deletions.
89 changes: 65 additions & 24 deletions aioboto3/s3/inject.py
@@ -2,7 +2,8 @@
 import aiofiles
 import inspect
 import logging
-from typing import Optional, Callable, BinaryIO, Dict, Any
+from io import BytesIO
+from typing import Optional, Callable, BinaryIO, Dict, Any, Union, Tuple
 
 from botocore.exceptions import ClientError
 from boto3 import utils
@@ -51,7 +52,7 @@ async def object_summary_load(self, *args, **kwargs):
 
 
 async def download_file(
-    self, Bucket, Key, Filename, ExtraArgs=None, Callback=None, Config=None
+    self, Bucket: str, Key: str, Filename: str, ExtraArgs=None, Callback=None, Config=None
 ):
     """Download an S3 object to a file asynchronously.
 
@@ -64,33 +65,37 @@ async def download_file(
     Similar behaviour as S3Transfer's download_file() method,
     except that parameters are capitalised.
     """
-    async with aiofiles.open(Filename, 'wb') as open_file:
+    async with aiofiles.open(Filename, 'wb') as fileobj:
         await download_fileobj(
             self,
             Bucket,
             Key,
-            open_file,
+            fileobj,
             ExtraArgs=ExtraArgs,
             Callback=Callback,
             Config=Config
         )
 
 
-async def _download_part(self, bucket, key, headers, start, file, semaphore, callback=None):
+async def _download_part(self, bucket: str, key: str, headers, start, file: Union[BytesIO, Any, None], semaphore, callback=None, io_queue: Optional[asyncio.Queue] = None) -> Union[None, Tuple[int, bytes]]:
     async with semaphore:  # limit number of concurrent downloads
         response = await self.get_object(
             Bucket=bucket, Key=key, Range=headers['Range']
         )
         content = await response['Body'].read()
 
-        # Check if it's aiofiles file
-        if inspect.iscoroutinefunction(file.seek) and inspect.iscoroutinefunction(file.write):
-            await file.seek(start)
-            await file.write(content)
+        # If stream is not seekable, return it so it can be queued up to be written
+        if io_queue:
+            await io_queue.put((start, content))
         else:
-            # Fallback to synchronous operations for file objects that are not async
-            file.seek(start)
-            file.write(content)
+            # Check if it's aiofiles file
+            if inspect.iscoroutinefunction(file.seek) and inspect.iscoroutinefunction(file.write):
+                await file.seek(start)
+                await file.write(content)
+            else:
+                # Fallback to synchronous operations for file objects that are not async
+                file.seek(start)
+                file.write(content)
 
         # Call the wrapper callback with the number of bytes written, if provided
         if callback:
@@ -107,7 +112,7 @@ async def download_fileobj(
     The file-like object must be in binary mode.
 
-    This is a managed transfer which will perform a multipart download
+    This is a managed transfer which will perform a multipart download
     with asyncio if necessary.
 
     Usage::
 
@@ -144,13 +149,23 @@
     try:
         if ExtraArgs is None:
             ExtraArgs = {}
-        resp = await self.get_object(Bucket=Bucket, Key=Key, **ExtraArgs)
+        # Get object metadata to determine the total size
+        head_response = await self.head_object(Bucket=Bucket, Key=Key, **ExtraArgs)
     except ClientError as err:
         if err.response['Error']['Code'] == 'NoSuchKey':
             # Convert to 404 so it looks the same when boto3.download_file fails
             raise ClientError({'Error': {'Code': '404', 'Message': 'Not Found'}}, 'HeadObject')
         raise
 
+    # Semaphore to limit the number of concurrent downloads
+    semaphore = asyncio.Semaphore(10)
+
+    # Size of each part (8MB)
+    part_size = 8 * 1024 * 1024
+
+    total_size = head_response['ContentLength']
+    total_parts = (total_size + part_size - 1) // part_size
+
     # Keep track of total downloaded bytes
     total_downloaded = 0
 
@@ -163,33 +178,59 @@ def wrapper_callback(bytes_transferred):
         except:  # noqa: E722
             pass
 
-    # Size of each part (8MB)
-    part_size = 8 * 1024 * 1024
+    is_seekable = inspect.isfunction(Fileobj.seek) or inspect.iscoroutinefunction(Fileobj.seek)
 
-    try:
-        # Get object metadata to determine the total size
-        response = await self.head_object(Bucket=Bucket, Key=Key, **ExtraArgs)
-        total_size = response['ContentLength']
-        total_parts = (total_size + part_size - 1) // part_size
-
-        # Semaphore to limit the number of concurrent downloads
-        semaphore = asyncio.Semaphore(10)
+    # This'll have around `semaphore` length items, somewhat more if writing is slow
+    # TODO add limits so we dont fill up this list n blow out ram
+    io_list = []
+    io_queue = asyncio.Queue()
+
+    async def queue_reader():
+        is_async = inspect.iscoroutinefunction(Fileobj.write)
+
+        written_pos = 0
+        while written_pos < total_size:
+            io_list.append(await io_queue.get())
+
+            # Stuff might be out of order in io_list
+            # so spin until there's nothing to queue off
+            done_nothing = False
+            while not done_nothing:
+                done_nothing = True
+
+                for chunk_start, data in io_list:
+                    if chunk_start == written_pos:
+                        if is_async:
+                            await Fileobj.write(data)
+                        else:
+                            Fileobj.write(data)
+                        written_pos += len(data)
+                        done_nothing = False
+
+    queue_reader_future = None
+    if not is_seekable:
+        queue_reader_future = asyncio.ensure_future(queue_reader())
 
+    try:
         tasks = []
         for i in range(total_parts):
             start = i * part_size
             end = min(
                 start + part_size, total_size
             )  # Ensure we don't go beyond the total size
+            # Range headers, start at 0 so end which would be total_size, minus 1 = 0 indexed.
             headers = {'Range': f'bytes={start}-{end - 1}'}
            # Create a task for each part download
             tasks.append(
-                _download_part(self, Bucket, Key, headers, start, Fileobj, semaphore, wrapper_callback)
+                _download_part(self, Bucket, Key, headers, start, Fileobj, semaphore, wrapper_callback, io_queue if not is_seekable else None)
             )
 
         # Run all the download tasks concurrently
         await asyncio.gather(*tasks)
 
+        if queue_reader_future:
+            await queue_reader_future
+
         logger.info(f'Downloaded file from {Bucket}/{Key}')
 
     except ClientError as e:

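The core of the change is queue_reader: part downloads finish in arbitrary order, but a non-seekable stream has to receive its bytes sequentially, so every completed (start, content) pair is parked in io_list until the chunk starting at the current write position can be flushed. Below is a standalone sketch of that reordering pattern; all names and sizes are illustrative, not part of the commit:

```python
import asyncio
import random

async def demo() -> None:
    part_size = 4
    total_size = 5 * part_size
    queue: asyncio.Queue = asyncio.Queue()

    async def producer(i: int) -> None:
        # Simulate parts finishing out of order, like _download_part does.
        await asyncio.sleep(random.random() / 10)
        await queue.put((i * part_size, f'p{i:02d}_'.encode()))

    async def ordered_writer(sink: bytearray) -> None:
        pending = []  # plays the role of io_list in queue_reader
        written = 0
        while written < total_size:
            pending.append(await queue.get())
            # Flush every parked chunk that is now contiguous with
            # what has already been written.
            progressed = True
            while progressed:
                progressed = False
                for start, data in pending:
                    if start == written:
                        sink += data
                        written += len(data)
                        progressed = True

    sink = bytearray()
    await asyncio.gather(*(producer(i) for i in range(5)), ordered_writer(sink))
    print(sink.decode())  # p00_p01_p02_p03_p04_

asyncio.run(demo())
```

As the TODO in the diff acknowledges, io_list is unbounded: a writer that is slower than the downloads will accumulate parked parts in memory.
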
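The part arithmetic pairs ceil division with S3's inclusive Range headers, which is why the header uses end - 1. For example, with the commit's 8 MB part size and a hypothetical 20 MB object:

```python
part_size = 8 * 1024 * 1024
total_size = 20 * 1024 * 1024  # hypothetical object size
total_parts = (total_size + part_size - 1) // part_size  # ceil division -> 3

for i in range(total_parts):
    start = i * part_size
    end = min(start + part_size, total_size)  # clamp the final part
    print(f'bytes={start}-{end - 1}')

# bytes=0-8388607
# bytes=8388608-16777215
# bytes=16777216-20971519
```
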
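Finally, a minimal usage sketch of the non-seekable path. The bucket and key are placeholders and StreamSink is hypothetical; note that the commit's probe checks Fileobj.seek with inspect.isfunction/inspect.iscoroutinefunction, so an object whose seek is an ordinary bound method is classified as non-seekable and routed through the queue:

```python
import asyncio
import io
import aioboto3

class StreamSink:
    """Hypothetical write-only sink, e.g. the front of a streaming pipeline."""
    def __init__(self) -> None:
        self.received = bytearray()

    def write(self, data: bytes) -> int:
        self.received += data
        return len(data)

    def seek(self, *args):  # present but unusable: this sink cannot rewind
        raise io.UnsupportedOperation('seek')

async def main() -> None:
    session = aioboto3.Session()
    async with session.client('s3') as s3:
        sink = StreamSink()
        # Parts are fetched concurrently, but reach sink.write() in order
        # via the io_queue / queue_reader machinery added here.
        await s3.download_fileobj('my-bucket', 'my-key', sink)
        print(f'{len(sink.received)} bytes received')

asyncio.run(main())
```

Genuinely seekable targets, such as the aiofiles handles opened by download_file, keep the original seek-and-write behaviour.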