Enabled Azure Multipart (#863)
Edited the AzureBlobInterface class:
* Wrote the logic for staging/uploading a block for a multipart upload
in the method upload_object
* Created two functions to 1) initiate the multipart upload and 2)
complete the multipart upload
* Since Azure works differently from S3 and GCS in that it doesn't
provide a global upload ID for a destination object, I used the
destination object name as the upload ID instead to stay consistent with
the other object stores. This pseudo-upload ID keeps track of which
blocks (and their block IDs) belong to which destination object in the CopyJob/SyncJob.
* Once all blocks have been uploaded/staged, the blocks for a
destination object are committed together in a single operation (see the sketch below).
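
The sketch below shows roughly how this staging/commit flow maps onto the Azure block blob API (a minimal illustration with azure-storage-blob; the connection string, container, blob name, and part files are placeholders, and the block ID scheme here is simplified relative to the one described further down):

```python
import base64
from azure.storage.blob import BlobServiceClient, BlobBlock, ContentSettings

# Hypothetical setup; Skyplane obtains its clients through AzureBlobInterface.
service = BlobServiceClient.from_connection_string("<connection-string>")
blob_client = service.get_blob_client(container="my-container", blob="path/to/object")

# Stage each part as a block; block IDs must be base64 strings of equal length.
block_ids = []
for part_number, chunk_path in enumerate(["part1.bin", "part2.bin"], start=1):
    block_id = base64.b64encode(f"{part_number:05d}".encode()).decode()  # simplified ID scheme
    with open(chunk_path, "rb") as f:
        blob_client.stage_block(block_id=block_id, data=f)
    block_ids.append(block_id)

# Commit all staged blocks in one operation to materialize the blob.
blob_client.commit_block_list(
    [BlobBlock(block_id=b) for b in block_ids],
    content_settings=ContentSettings(content_type="application/octet-stream"),
)
```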

More things to consider about this implementation:

Upload ID handling: Azure doesn't have a concept equivalent to AWS's
upload IDs. Instead, blobs are created immediately and blocks are
associated with a blob via block IDs. My workaround of using the blob
name as the upload ID should work since I only use upload_id to
distinguish between requests in the finalize() method.
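
As a rough sketch of that call pattern (dest_iface, key, parts, and mime_type are placeholders for values the CopyJob/SyncJob already tracks; the method names match this change):

```python
# Hypothetical call pattern; names are placeholders for objects built by the CopyJob.
upload_id = dest_iface.initiate_multipart_upload(key)  # on Azure, this just returns key
block_ids = [AzureBlobInterface.id_to_base64_encoding(p, key) for p in parts]
dest_iface.complete_multipart_upload(key, upload_id, (block_ids, mime_type))
```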

Block IDs: Azure requires all block IDs within a blob to have the same
length. I handle this by padding each ID to a fixed length of 5 (the
number of digits in Azure's 50,000-block maximum) plus
len(destination_object_key) before base64-encoding it.
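
For example, the encoding behaves roughly like the snippet below (a small sketch mirroring id_to_base64_encoding from this change; the key "foo/bar" is purely illustrative):

```python
import base64

def id_to_base64_encoding(part_number: int, dest_key: str) -> str:
    # Fixed length = 5 (len("50000"), Azure's max block count) + len(dest_key),
    # so every block ID for the same blob encodes to the same length.
    max_length = 5 + len(dest_key)
    block_id = f"{part_number}{dest_key}".ljust(max_length, "0")
    return base64.b64encode(block_id.encode("utf-8")).decode("utf-8")

print(id_to_base64_encoding(1, "foo/bar"))   # MWZvby9iYXIwMDAw
print(id_to_base64_encoding(42, "foo/bar"))  # same length as the ID for part 1
```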

---------

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
xTRam1 and sarahwooders committed Jun 20, 2023
1 parent ac4d589 commit 99db23c
Showing 10 changed files with 164 additions and 60 deletions.
20 changes: 16 additions & 4 deletions skyplane/api/transfer_job.py
@@ -25,7 +25,7 @@
from skyplane import exceptions
from skyplane.api.config import TransferConfig
from skyplane.chunk import Chunk, ChunkRequest
from skyplane.obj_store.azure_blob_interface import AzureBlobObject
from skyplane.obj_store.azure_blob_interface import AzureBlobInterface, AzureBlobObject
from skyplane.obj_store.gcs_interface import GCSObject
from skyplane.obj_store.r2_interface import R2Object
from skyplane.obj_store.storage_interface import StorageInterface
@@ -158,8 +158,15 @@ def _run_multipart_chunk_thread(
region = dest_iface.region_tag()
dest_object = dest_objects[region]
_, upload_id = upload_id_mapping[region]

metadata = None
# Convert parts to base64 and store mime_type if destination interface is AzureBlobInterface
if isinstance(dest_iface, AzureBlobInterface):
block_ids = list(map(lambda part_num: AzureBlobInterface.id_to_base64_encoding(part_num, dest_object.key), parts))
metadata = (block_ids, mime_type)

self.multipart_upload_requests.append(
dict(upload_id=upload_id, key=dest_object.key, parts=parts, region=region, bucket=bucket)
dict(upload_id=upload_id, key=dest_object.key, parts=parts, region=region, bucket=bucket, metadata=metadata)
)
else:
mime_type = None
@@ -646,6 +653,7 @@ def dispatch(
assert Chunk.from_dict(chunk_batch[0].as_dict()) == chunk_batch[0], f"Invalid chunk request: {chunk_batch[0].as_dict}"

# TODO: make async
st = time.time()
reply = self.http_pool.request(
"POST",
f"{server.gateway_api_url}/api/v1/chunk_requests",
@@ -654,9 +662,10 @@
)
if reply.status != 200:
raise Exception(f"Failed to dispatch chunk requests {server.instance_name()}: {reply.data.decode('utf-8')}")
et = time.time()
reply_json = json.loads(reply.data.decode("utf-8"))
logger.fs.debug(f"Added {n_added} chunks to server {server}: {reply_json}")
n_added += reply_json["n_added"]
logger.fs.debug(f"Added {n_added} chunks to server {server} in {et-st}: {reply_json}")
queue_size[min_idx] = reply_json["qsize"] # update queue size
# dont try again with some gateway
min_idx = (min_idx + 1) % len(src_gateways)
@@ -685,7 +694,10 @@ def finalize(self):
def complete_fn(batch):
for req in batch:
logger.fs.debug(f"Finalize upload id {req['upload_id']} for key {req['key']}")
retry_backoff(partial(obj_store_interface.complete_multipart_upload, req["key"], req["upload_id"]), initial_backoff=0.5)
retry_backoff(
partial(obj_store_interface.complete_multipart_upload, req["key"], req["upload_id"], req["metadata"]),
initial_backoff=0.5,
)

do_parallel(complete_fn, batches, n=8)

121 changes: 100 additions & 21 deletions skyplane/obj_store/azure_blob_interface.py
@@ -3,13 +3,17 @@
import os
from functools import lru_cache

from typing import Iterator, List, Optional, Tuple
from typing import Any, Iterator, List, Optional, Tuple

from skyplane import exceptions, compute
from skyplane.exceptions import NoSuchObjectException
from skyplane.obj_store.azure_storage_account_interface import AzureStorageAccountInterface
from skyplane.obj_store.object_store_interface import ObjectStoreInterface, ObjectStoreObject
from skyplane.utils import logger, imports
from azure.storage.blob import ContentSettings


MAX_BLOCK_DIGITS = 5


class AzureBlobObject(ObjectStoreObject):
@@ -149,25 +153,100 @@ def download_object(

@imports.inject("azure.storage.blob", pip_extra="azure")
def upload_object(azure_blob, self, src_file_path, dst_object_name, part_number=None, upload_id=None, check_md5=None, mime_type=None):
if part_number is not None or upload_id is not None:
# todo implement multipart upload
raise NotImplementedError("Multipart upload is not implemented for Azure")
"""Uses the BlobClient instead of ContainerClient since BlobClient allows for
block/part level manipulation for multi-part uploads
"""
src_file_path, dst_object_name = str(src_file_path), str(dst_object_name)
with open(src_file_path, "rb") as f:
print(f"Uploading {src_file_path} to {dst_object_name}")
blob_client = self.container_client.upload_blob(
name=dst_object_name,
data=f,
length=os.path.getsize(src_file_path),
max_concurrency=self.max_concurrency,
overwrite=True,
content_settings=azure_blob.ContentSettings(content_type=mime_type),
)
if check_md5:
b64_md5sum = base64.b64encode(check_md5).decode("utf-8") if check_md5 else None
blob_md5 = blob_client.get_blob_properties().properties.content_settings.content_md5
if b64_md5sum != blob_md5:
raise exceptions.ChecksumMismatchException(
f"Checksum mismatch for object {dst_object_name} in Azure container {self.container_name}, "
+ f"expected {b64_md5sum}, got {blob_md5}"
print(f"Uploading {src_file_path} to {dst_object_name}")

try:
blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=dst_object_name)

# multipart upload
if part_number is not None and upload_id is not None:
with open(src_file_path, "rb") as f:
block_id = AzureBlobInterface.id_to_base64_encoding(part_number=part_number, dest_key=dst_object_name)
blob_client.stage_block(block_id=block_id, data=f, length=os.path.getsize(src_file_path)) # stage the block
return

# single upload
with open(src_file_path, "rb") as f:
blob_client.upload_blob(
data=f,
length=os.path.getsize(src_file_path),
max_concurrency=self.max_concurrency,
overwrite=True,
content_settings=azure_blob.ContentSettings(content_type=mime_type),
)

# check MD5 if required
if check_md5:
b64_md5sum = base64.b64encode(check_md5).decode("utf-8") if check_md5 else None
blob_md5 = blob_client.get_blob_properties().properties.content_settings.content_md5
if b64_md5sum != blob_md5:
raise exceptions.ChecksumMismatchException(
f"Checksum mismatch for object {dst_object_name} in Azure container {self.container_name}, "
+ f"expected {b64_md5sum}, got {blob_md5}"
)
except Exception as e:
raise ValueError(f"Failed to upload {dst_object_name} to bucket {self.container_name} upload id {upload_id}: {e}")

def initiate_multipart_upload(self, dst_object_name: str, mime_type: Optional[str] = None) -> str:
"""Azure does not have an equivalent function to return an upload ID like s3 and gcs do.
Blocks in Azure are uploaded and associated with an ID, and can then be committed in a single operation to create the blob.
We will just return the dst_object_name (blob name) as the "upload_id" to keep the return type consistent for the multipart thread.
:param dst_object_name: name of the destination object, also our pseudo-upload ID
:type dst_object_name: str
:param mime_type: unused in this function but is kept for consistency with the other interfaces (default: None)
:type mime_type: str
"""

assert len(dst_object_name) > 0, f"Destination object name must be non-empty: '{dst_object_name}'"

return dst_object_name

@imports.inject("azure.storage.blob", pip_extra="azure")
def complete_multipart_upload(azure_blob, self, dst_object_name: str, upload_id: str, metadata: Optional[Any] = None) -> None:
"""After all blocks of a blob are uploaded/staged with their unique block_id,
in order to complete the multipart upload, we commit them together.
:param dst_object_name: name of the destination object, also is used to index into our block mappings
:type dst_object_name: str
:param upload_id: upload_id to index into our block id mappings, should be the same as the dst_object_name in Azure
:type upload_id: str
:param metadata: In Azure, this custom data is the blockID list (parts) and the object mime_type from the TransferJob instance (default: None)
:type metadata: Optional[Any]
"""

assert upload_id == dst_object_name, "In Azure, upload_id should be the same as the blob name."
assert metadata is not None, "In Azure, the custom data should exist for multipart"

# Decouple the custom data
block_list, mime_type = metadata
assert block_list != [], "The blockID list shouldn't be empty for Azure multipart"
block_list = list(map(lambda block_id: azure_blob.BlobBlock(block_id=block_id), block_list))

blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=dst_object_name)
try:
# The below operation will create the blob from the uploaded blocks.
blob_client.commit_block_list(block_list=block_list, content_settings=azure_blob.ContentSettings(content_type=mime_type))
except Exception as e:
raise exceptions.SkyplaneException(f"Failed to complete multipart upload for {dst_object_name}: {str(e)}")

@staticmethod
def id_to_base64_encoding(part_number: int, dest_key: str) -> str:
"""Azure expects all blockIDs to be Base64 strings. This function serves to convert the part numbers to
base64-encoded strings of the same length. The maximum number of blocks one blob supports in Azure is
50,000 blocks, so the maximum length to pad zeroes to will be (#digits in 50,000 = len("50000") = 5) + len(dest_key)
:param part_number: part number of the block, determined while splitting the data into chunks before the transfer
:type part_number: int
:param dest_key: destination object key, used to distinguish between different objects during concurrent uploads to the same container
"""
max_length = MAX_BLOCK_DIGITS + len(dest_key)
block_id = f"{part_number}{dest_key}"
block_id = block_id.ljust(max_length, "0") # pad with zeroes to get consistent length
block_id = block_id.encode("utf-8")
block_id = base64.b64encode(block_id).decode("utf-8")
return block_id
4 changes: 2 additions & 2 deletions skyplane/obj_store/cos_interface.py
@@ -2,7 +2,7 @@
import hashlib
import os
from functools import lru_cache
from typing import Iterator, List, Optional, Tuple
from typing import Any, Iterator, List, Optional, Tuple


from skyplane import exceptions
@@ -217,7 +217,7 @@ def initiate_multipart_upload(self, dst_object_name: str, mime_type: Optional[st
else:
raise exceptions.SkyplaneException(f"Failed to initiate multipart upload for {dst_object_name}: {response}")

def complete_multipart_upload(self, dst_object_name, upload_id):
def complete_multipart_upload(self, dst_object_name, upload_id, metadata: Optional[Any] = None):
print("complete multipart upload")
cos_client = self._cos_client()
all_parts = []
4 changes: 2 additions & 2 deletions skyplane/obj_store/file_system_interface.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Iterator, List, Optional
from typing import Any, Iterator, List, Optional
from skyplane.obj_store.storage_interface import StorageInterface
import os

@@ -53,7 +53,7 @@ def delete_files(self, paths: List[str]):
def initiate_multipart_upload(self, dst_object_name: str) -> str:
raise ValueError("Multipart uploads not supported")

def complete_multipart_upload(self, dst_object_name: str, upload_id: str) -> None:
def complete_multipart_upload(self, dst_object_name: str, upload_id: str, metadata: Optional[Any] = None) -> None:
raise ValueError("Multipart uploads not supported")

@staticmethod
4 changes: 2 additions & 2 deletions skyplane/obj_store/gcs_interface.py
@@ -6,7 +6,7 @@
from xml.etree import ElementTree

import requests
from typing import Iterator, List, Optional, Tuple
from typing import Any, Iterator, List, Optional, Tuple

from skyplane import exceptions, compute
from skyplane.config_paths import cloud_config
@@ -254,7 +254,7 @@ def initiate_multipart_upload(self, dst_object_name: str, mime_type: Optional[st
response = self.send_xml_request(dst_object_name, {"uploads": None}, "POST", content_type=mime_type)
return ElementTree.fromstring(response.content)[2].text

def complete_multipart_upload(self, dst_object_name, upload_id):
def complete_multipart_upload(self, dst_object_name, upload_id, metadata: Optional[Any] = None):
# get parts
xml_data = ElementTree.Element("CompleteMultipartUpload")
next_part_number_marker = None
4 changes: 2 additions & 2 deletions skyplane/obj_store/hdfs_interface.py
@@ -3,7 +3,7 @@
import os
from pyarrow import fs
from dataclasses import dataclass
from typing import Iterator, List, Optional
from typing import Any, Iterator, List, Optional
from skyplane.exceptions import NoSuchObjectException
from skyplane.obj_store.object_store_interface import ObjectStoreInterface, ObjectStoreObject
from skyplane.utils import logger
@@ -150,7 +150,7 @@ def write_file(self, file_name, data, offset=0):
def initiate_multipart_upload(self, dst_object_name: str, mime_type: Optional[str] = None) -> str:
raise NotImplementedError(f"Multipart upload is not supported for the POSIX file system.")

def complete_multipart_upload(self, dst_object_name: str, upload_id: str) -> None:
def complete_multipart_upload(self, dst_object_name: str, upload_id: str, metadata: Optional[Any] = None) -> None:
raise NotImplementedError(f"Multipart upload is not supported for the POSIX file system.")

@lru_cache(maxsize=1024)
4 changes: 2 additions & 2 deletions skyplane/obj_store/object_store_interface.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass

from typing import Iterator, List, Optional, Tuple
from typing import Any, Iterator, List, Optional, Tuple

from skyplane.obj_store.storage_interface import StorageInterface
from skyplane.utils import logger
@@ -82,5 +82,5 @@ def delete_objects(self, keys: List[str]):
def initiate_multipart_upload(self, dst_object_name: str, mime_type: Optional[str] = None) -> str:
raise ValueError("Multipart uploads not supported")

def complete_multipart_upload(self, dst_object_name: str, upload_id: str) -> None:
def complete_multipart_upload(self, dst_object_name: str, upload_id: str, metadata: Optional[Any] = None) -> None:
raise ValueError("Multipart uploads not supported")
4 changes: 2 additions & 2 deletions skyplane/obj_store/posix_file_interface.py
@@ -2,7 +2,7 @@
import os
import sys
from dataclasses import dataclass
from typing import Iterator, List, Optional
from typing import Any, Iterator, List, Optional
from skyplane.exceptions import NoSuchObjectException
from skyplane.obj_store.object_store_interface import ObjectStoreInterface, ObjectStoreObject
import mimetypes
@@ -139,7 +139,7 @@ def write_file(self, file_name, data, offset=0):
def initiate_multipart_upload(self, dst_object_name: str, mime_type: Optional[str] = None) -> str:
raise NotImplementedError(f"Multipart upload is not supported for the POSIX file system.")

def complete_multipart_upload(self, dst_object_name: str, upload_id: str) -> None:
def complete_multipart_upload(self, dst_object_name: str, upload_id: str, metadata: Optional[Any] = None) -> None:
raise NotImplementedError(f"Multipart upload is not supported for the POSIX file system.")

@lru_cache(maxsize=1024)
5 changes: 3 additions & 2 deletions skyplane/obj_store/s3_interface.py
@@ -3,7 +3,7 @@
import os
from functools import lru_cache

from typing import Iterator, List, Optional, Tuple
from typing import Any, Iterator, List, Optional, Tuple

from skyplane import exceptions, compute
from skyplane.exceptions import NoSuchObjectException
@@ -43,6 +43,7 @@ def aws_region(self):
logger.warning(f"Bucket location {self.bucket_name} is not public. Assuming region is {default_region}")
return default_region
logger.warning(f"Specified bucket {self.bucket_name} does not exist, got AWS error: {e}")
print("Error getting AWS region", e)
raise exceptions.MissingBucketException(f"S3 bucket {self.bucket_name} does not exist") from e

def region_tag(self):
Expand Down Expand Up @@ -227,7 +228,7 @@ def initiate_multipart_upload(self, dst_object_name: str, mime_type: Optional[st
else:
raise exceptions.SkyplaneException(f"Failed to initiate multipart upload for {dst_object_name}: {response}")

def complete_multipart_upload(self, dst_object_name, upload_id):
def complete_multipart_upload(self, dst_object_name, upload_id, metadata: Optional[Any] = None):
s3_client = self._s3_client()
all_parts = []
while True:
Expand Down