Rewrote s3.copy to utilise copy_object or upload_part_copy
terricain committed May 27, 2024
1 parent 367859a commit 0858ec0
Showing 1 changed file with 78 additions and 21 deletions.
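
In practice the rewritten method is invoked exactly like the old one; the split between CopyObject and UploadPartCopy happens internally based on the object's size. A minimal usage sketch, assuming aioboto3's Session API (the bucket names, key, and 16 MiB chunk size are illustrative, not part of the commit):

import asyncio

import aioboto3
from boto3.s3.transfer import S3TransferConfig


async def main():
    session = aioboto3.Session()
    async with session.client('s3') as s3:
        # Objects below Config.multipart_threshold are copied with a single
        # CopyObject call; larger objects are copied server-side in parts
        # via UploadPartCopy.
        await s3.copy(
            CopySource={'Bucket': 'source-bucket', 'Key': 'big-file.bin'},
            Bucket='dest-bucket',
            Key='big-file.bin',
            Config=S3TransferConfig(multipart_chunksize=16 * 1024 * 1024),
        )

asyncio.run(main())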
aioboto3/s3/inject.py: 78 additions & 21 deletions
@@ -2,6 +2,7 @@
 import aiofiles
 import inspect
 import logging
+import math
 from io import BytesIO
 from typing import Optional, Callable, BinaryIO, Dict, Any, Union
 from abc import abstractmethod
@@ -11,6 +12,7 @@
 from boto3.s3.transfer import S3TransferConfig, S3Transfer
 from boto3.s3.inject import bucket_upload_file, bucket_download_file, bucket_copy, bucket_upload_fileobj, bucket_download_fileobj
 from s3transfer.upload import UploadSubmissionTask
+from s3transfer.copies import CopySubmissionTask

 logger = logging.getLogger(__name__)

@@ -570,41 +572,96 @@ async def copy(
         Key: str,
         ExtraArgs: Optional[Dict[str, Any]] = None,
         Callback: Optional[TransferCallback] = None,
-        SourceClient = None, # Should be aioboto3/aiobotocore client
+        SourceClient=None,  # Should be an aioboto3/aiobotocore client
         Config: Optional[S3TransferConfig] = None
 ):
     assert 'Bucket' in CopySource
     assert 'Key' in CopySource

-    if SourceClient is None:
-        SourceClient = self
-
-    if ExtraArgs is None:
-        ExtraArgs = {}
-
-    download_args = {k: v for k, v in ExtraArgs.items() if k in S3Transfer.ALLOWED_DOWNLOAD_ARGS}
-    upload_args = {k: v for k, v in ExtraArgs.items() if k in S3Transfer.ALLOWED_UPLOAD_ARGS}
+    SourceClient = SourceClient or self
+    Config = Config or S3TransferConfig()
+    ExtraArgs = ExtraArgs or {}

     try:
-        resp = await SourceClient.get_object(
-            Bucket=CopySource['Bucket'], Key=CopySource['Key'], **download_args
-        )
+        head_object_kwargs = {}
+        for param, value in ExtraArgs.items():
+            if param in CopySubmissionTask.EXTRA_ARGS_TO_HEAD_ARGS_MAPPING:
+                head_object_kwargs[CopySubmissionTask.EXTRA_ARGS_TO_HEAD_ARGS_MAPPING[param]] = value
+
+        # Get the object metadata to determine the total size
+        head_response = await SourceClient.head_object(Bucket=CopySource['Bucket'], Key=CopySource['Key'], **head_object_kwargs)
     except ClientError as err:
         if err.response['Error']['Code'] == 'NoSuchKey':
             # Convert to 404 so it looks the same as when boto3.download_file fails
             raise ClientError({'Error': {'Code': '404', 'Message': 'Not Found'}}, 'HeadObject')
         raise

-    file_obj = resp['Body']
+    # CopyObject handles objects up to 5 GiB, but S3Transfer switches to multipart
+    # at Config.multipart_threshold, which defaults to 8 MiB
+    if head_response['ContentLength'] < Config.multipart_threshold:
+        await self.copy_object(CopySource=CopySource, Bucket=Bucket, Key=Key, **ExtraArgs)
+        return

-    await self.upload_fileobj(
-        file_obj,
-        Bucket,
-        Key,
-        ExtraArgs=upload_args,
-        Callback=Callback,
-        Config=Config
-    )
+    # Object exceeds the multipart threshold, so do a server-side multipart copy
+    create_multipart_kwargs = {k: v for k, v in ExtraArgs.items() if k not in CopySubmissionTask.CREATE_MULTIPART_ARGS_BLACKLIST}
+    create_multipart_upload_resp = await self.create_multipart_upload(Bucket=Bucket, Key=Key, **create_multipart_kwargs)
+
+    finished_parts = []
+    total_size = 0
+
+    async def uploader(size: int, part_args: Dict[str, Any]):
+        nonlocal total_size
+
+        upload_part_response = await self.upload_part_copy(**part_args)
+        finished_parts.append({'ETag': upload_part_response['CopyPartResult']['ETag'], 'PartNumber': part_args['PartNumber']})
+
+        # Invoke the progress callback with the cumulative byte count; exceptions
+        # are swallowed so a misbehaving callback cannot abort the copy
+        if Callback:
+            try:
+                total_size += size
+                Callback(total_size)
+            except:  # noqa: E722
+                pass
+
+    num_parts = int(math.ceil(head_response['ContentLength'] / float(Config.multipart_chunksize)))
+
+    tasks = []
+    upload_kwargs = {k: v for k, v in ExtraArgs.items() if k in CopySubmissionTask.UPLOAD_PART_COPY_ARGS}
+    upload_kwargs.update({'Bucket': Bucket, 'Key': Key, 'CopySource': CopySource, 'UploadId': create_multipart_upload_resp['UploadId']})
+    for part_number in range(1, num_parts + 1):
+        part_upload_kwargs = upload_kwargs.copy()
+        part_upload_kwargs['PartNumber'] = part_number
+
+        range_start = (part_number - 1) * Config.multipart_chunksize
+        range_end = range_start + Config.multipart_chunksize - 1
+        if part_number == num_parts:
+            range_end = head_response['ContentLength'] - 1
+
+        part_upload_kwargs['CopySourceRange'] = f'bytes={range_start}-{range_end}'
+
+        # Byte ranges are inclusive, so this part covers range_end - range_start + 1 bytes
+        tasks.append(uploader(range_end - range_start + 1, part_upload_kwargs))
+
+    try:
+        await asyncio.gather(*tasks)
+
+        assert len(finished_parts) == num_parts, "Number of finished upload parts does not match expected parts"
+
+        finished_parts.sort(key=lambda item: item['PartNumber'])
+
+        complete_upload_args = {k: v for k, v in ExtraArgs.items() if k in CopySubmissionTask.COMPLETE_MULTIPART_ARGS}
+        await self.complete_multipart_upload(
+            Bucket=Bucket,
+            Key=Key,
+            UploadId=create_multipart_upload_resp['UploadId'],
+            MultipartUpload={'Parts': finished_parts},
+            **complete_upload_args
+        )
+
+    except Exception as err:
+        # Best-effort cleanup so a failed copy does not leave billable parts behind
+        try:
+            await self.abort_multipart_upload(Bucket=Bucket, Key=Key, UploadId=create_multipart_upload_resp['UploadId'])
+        except Exception as err2:
+            raise err2 from err
+        raise err


 async def bucket_load(self, *args, **kwargs):
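
The CopySourceRange arithmetic above is easy to sanity-check in isolation. A standalone sketch of the same computation (the 20 MiB object size and 8 MiB chunk size are example values; min() replaces the if branch used in the diff but yields identical ranges):

import math

content_length = 20 * 1024 * 1024  # example source object size (20 MiB)
chunksize = 8 * 1024 * 1024        # Config.multipart_chunksize default

num_parts = math.ceil(content_length / chunksize)  # 3 parts
for part_number in range(1, num_parts + 1):
    range_start = (part_number - 1) * chunksize
    # Ranges are inclusive; the last part is truncated to the object size
    range_end = min(range_start + chunksize, content_length) - 1
    print(f'part {part_number}: bytes={range_start}-{range_end}')

# part 1: bytes=0-8388607
# part 2: bytes=8388608-16777215
# part 3: bytes=16777216-20971519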

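One behavioural detail worth flagging: the uploader coroutine hands Callback the cumulative byte count rather than the per-chunk increment boto3's transfer callbacks normally receive, and any exception the callback raises is swallowed. A hypothetical progress callback (the name on_progress is illustrative):

def on_progress(bytes_copied: int) -> None:
    # Receives the running total of bytes copied so far, not a delta
    print(f'copied {bytes_copied} bytes so far')

# await s3.copy(CopySource=..., Bucket=..., Key=..., Callback=on_progress)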