Skip to content

Commit

Permalink
rudimentary API for parts
Browse files Browse the repository at this point in the history
  • Loading branch information
antonzabreyko committed Apr 23, 2023
1 parent 08391be commit faa53ae
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
1 change: 1 addition & 0 deletions skyplane/api/transfer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _run_multipart_chunk_thread(
parts.append(part_num)
part_num += 1
out_queue.put(chunk)
self.dst_iface.receive_parts(dest_object.key, parts)
self.multipart_upload_requests.append(dict(upload_id=upload_id, key=dest_object.key, parts=parts, region=region, bucket=bucket))

def to_chunk_requests(self, gen_in: Generator[Chunk, None, None]) -> Generator[ChunkRequest, None, None]:
Expand Down
55 changes: 34 additions & 21 deletions skyplane/obj_store/azure_blob_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,32 +153,45 @@ def download_object(
def upload_object(azure_blob, self, src_file_path, dst_object_name, part_number=None, upload_id=None, check_md5=None, mime_type=None):
if part_number is not None or upload_id is not None:
with open(src_file_path, "rb") as f:
client, part_list = self.multipart_map[dst_object_name]
logger.info(self.multipart_map)
if dst_object_name in self.multipart_map:
client, part_list = self.multipart_map[dst_object_name]
else:
client = self.blob_service_client.get_blob_client(container=self.container_name, blob=dst_object_name)
part_list = list()
self.multipart_map[dst_object_name] = (client, part_list)
client.stage_block(block_id=part_number, data=f)
part_list.append(BlobBlock(str(part_number)))
src_file_path, dst_object_name = str(src_file_path), str(dst_object_name)
with open(src_file_path, "rb") as f:
print(f"Uploading {src_file_path} to {dst_object_name}")
blob_client = self.container_client.upload_blob(
name=dst_object_name,
data=f,
length=os.path.getsize(src_file_path),
max_concurrency=self.max_concurrency,
overwrite=True,
content_settings=azure_blob.ContentSettings(content_type=mime_type),
)
if check_md5:
b64_md5sum = base64.b64encode(check_md5).decode("utf-8") if check_md5 else None
blob_md5 = blob_client.get_blob_properties().properties.content_settings.content_md5
if b64_md5sum != blob_md5:
raise exceptions.ChecksumMismatchException(
f"Checksum mismatch for object {dst_object_name} in Azure container {self.container_name}, "
+ f"expected {b64_md5sum}, got {blob_md5}"
else:
src_file_path, dst_object_name = str(src_file_path), str(dst_object_name)
with open(src_file_path, "rb") as f:
print(f"Uploading {src_file_path} to {dst_object_name}")
blob_client = self.container_client.upload_blob(
name=dst_object_name,
data=f,
length=os.path.getsize(src_file_path),
max_concurrency=self.max_concurrency,
overwrite=True,
content_settings=azure_blob.ContentSettings(content_type=mime_type),
)
if check_md5:
b64_md5sum = base64.b64encode(check_md5).decode("utf-8") if check_md5 else None
blob_md5 = blob_client.get_blob_properties().properties.content_settings.content_md5
if b64_md5sum != blob_md5:
raise exceptions.ChecksumMismatchException(
f"Checksum mismatch for object {dst_object_name} in Azure container {self.container_name}, "
+ f"expected {b64_md5sum}, got {blob_md5}"
)

def initiate_multipart_upload(self, dst_object_name: str, mime_type: Optional[str] = None) -> str:
self.multipart_map[dst_object_name] = (self.container_client, list())
return
#logger.info(dst_object_name + " cat")
#self.multipart_map[dst_object_name] = (self.container_client, list())

def receive_parts(self, dst_object_name: str, parts):
self.multipart_map[dst_object_name] = parts

def complete_multipart_upload(self, dst_object_name: str, upload_id: str) -> None:
client, part_list = self.multipart_map[dst_object_name]
#client, part_list = self.multipart_map[dst_object_name]
part_list = self.multipart_map[dst_object_name]
client.commit_block_list(part_list)
3 changes: 3 additions & 0 deletions skyplane/obj_store/object_store_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def delete_objects(self, keys: List[str]):
def initiate_multipart_upload(self, dst_object_name: str, mime_type: Optional[str] = None) -> str:
raise ValueError("Multipart uploads not supported")

def receive_parts(self, dst_object_name: str, parts: List[int]):
pass

def complete_multipart_upload(self, dst_object_name: str, upload_id: str) -> None:
raise ValueError("Multipart uploads not supported")

Expand Down

0 comments on commit faa53ae

Please sign in to comment.