Skip to content

Commit

Permalink
fix finalizing upload?
Browse files Browse the repository at this point in the history
  • Loading branch information
lynnliu030 committed Oct 20, 2023
1 parent 5c9dde9 commit c03c184
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 26 deletions.
29 changes: 23 additions & 6 deletions skyplane/api/transfer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,15 @@ def _run_multipart_chunk_thread(
metadata = (block_ids, mime_type)

self.multipart_upload_requests.append(
dict(upload_id=upload_id, key=dest_object.key, parts=parts, region=region, bucket=bucket, metadata=metadata)
dict(
upload_id=upload_id,
key=dest_object.key,
parts=parts,
region=region,
bucket=bucket,
metadata=metadata,
vm=True if dest_iface.provider == "vm" else False,
)
)
else:
mime_type = None
Expand Down Expand Up @@ -306,6 +314,7 @@ def transfer_pair_generator(
if isinstance(dst_iface, VMInterface):
# VM destination
from skyplane.obj_store.vm_interface import VMFile

host_ip = dst_iface.host_ip()

dest_obj = VMFile(provider=dest_provider, bucket=host_ip, key=dest_key)
Expand Down Expand Up @@ -353,7 +362,7 @@ def chunk(self, transfer_pair_generator: Generator[TransferPair, None, None]) ->
multipart_chunk_threads = []

# start chunking threads
if self.transfer_config.multipart_enabled: # and not isinstance(self.dst_ifaces[0], VMInterface):
if self.transfer_config.multipart_enabled: # and not isinstance(self.dst_ifaces[0], VMInterface):
for _ in range(self.concurrent_multipart_chunk_threads):
t = threading.Thread(
target=self._run_multipart_chunk_thread,
Expand Down Expand Up @@ -387,12 +396,12 @@ def chunk(self, transfer_pair_generator: Generator[TransferPair, None, None]) ->
)
)

if self.transfer_config.multipart_enabled: # and not isinstance(self.dst_ifaces[0], VMInterface):
if self.transfer_config.multipart_enabled: # and not isinstance(self.dst_ifaces[0], VMInterface):
# drain multipart chunk queue and yield with updated chunk IDs
while not multipart_chunk_queue.empty():
yield multipart_chunk_queue.get()

if self.transfer_config.multipart_enabled: # and not isinstance(self.dst_ifaces[0], VMInterface):
if self.transfer_config.multipart_enabled: # and not isinstance(self.dst_ifaces[0], VMInterface):
# wait for processing multipart requests to finish
logger.fs.debug("Waiting for multipart threads to finish")
# while not multipart_send_queue.empty():
Expand Down Expand Up @@ -722,11 +731,14 @@ def finalize(self):
for req in self.multipart_transfer_list:
if "region" not in req or "bucket" not in req:
raise Exception(f"Invalid multipart upload request: {req}")
groups[(req["region"], req["bucket"])].append(req)
groups[(req["region"], req["bucket"], req["vm"])].append(req)
for key, group in groups.items():
region, bucket = key
region, bucket, vm = key
batch_len = max(1, len(group) // 128)
batches = [group[i : i + batch_len] for i in range(0, len(group), batch_len)]
print(f"region: {region}, bucket: {bucket}")
if vm:
region = "vm:" + region
obj_store_interface = StorageInterface.create(region, bucket)

def complete_fn(batch):
Expand All @@ -748,14 +760,19 @@ def verify(self):
def verify_region(i):
dst_iface = self.dst_ifaces[i]
dst_prefix = self.dst_prefixes[i]
print("Dst prefix: ", dst_prefix)

# gather destination key mapping for this region
dst_keys = {pair.dst_objs[dst_iface.region_tag()].key: pair.src_obj for pair in self.transfer_list}
print(f"Destination key mappings: {dst_keys}")

# list and check destination prefix
for obj in dst_iface.list_objects(dst_prefix):
print(f"Object listed: {obj.key}")
# check metadata (src.size == dst.size) && (src.modified <= dst.modified)
src_obj = dst_keys.get(obj.key)
print(f"src_obj: {src_obj}")
print(f"Object: {obj}")
if src_obj and src_obj.size == obj.size and src_obj.last_modified <= obj.last_modified:
del dst_keys[obj.key]

Expand Down
4 changes: 2 additions & 2 deletions skyplane/cli/cli_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ def run_transfer(
# fallback option: transfer is too small
if cli.args["cmd"] == "cp":
job = CopyJob(src, [dst], recursive=recursive) # TODO: rever to using pipeline
if cli.estimate_small_transfer(job, 0.01 * GB): # Test small transfer
# if cli.estimate_small_transfer(job, cloud_config.get_flag("native_cmd_threshold_gb") * GB):
if cli.estimate_small_transfer(job, 0.01 * GB): # Test small transfer
# if cli.estimate_small_transfer(job, cloud_config.get_flag("native_cmd_threshold_gb") * GB):
small_transfer_status = cli.transfer_cp_small(src, dst, recursive)
return 0 if small_transfer_status else 1
else:
Expand Down
26 changes: 13 additions & 13 deletions skyplane/obj_store/vm_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import mimetypes
import os
from typing import Iterator, List, Optional
from typing import Any, Iterator, List, Optional
import uuid
from dateutil.parser import parse
import paramiko
Expand Down Expand Up @@ -39,8 +39,8 @@ def __init__(
self.region = region
self.private_key_path = private_key_path
self.local_path = local_path
self.temp_dir = "/tmp/multipart_uploads/" # directory on the VMs
self.temp_dir = "/tmp/multipart_uploads/" # directory on the VMs

# Set up SSH
ssh_client = paramiko.SSHClient()
ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
Expand Down Expand Up @@ -76,8 +76,8 @@ def __init__(

@property
def provider(self) -> str:
return "VM"
return "vm"

def region_tag(self) -> str:
return self.region

Expand All @@ -91,7 +91,7 @@ def key_path(self) -> str:
return str(self.private_key_path)

def bucket(self) -> str:
return self.local_path
return f"{self.region}@{self.username}@{self.host}:{self.local_path}?private_key_path={self.private_key_path}"

def host_ip(self) -> str:
return self.host
Expand Down Expand Up @@ -140,7 +140,7 @@ def download_object(self, src_object_name, dst_file_path):
sftp = self.client.open_sftp()
sftp.get(f"{self.path}/{src_object_name}", dst_file_path)

def upload_object(self, src_file_path, dst_object_name, part_number=None, upload_id=None):
def upload_object(self, src_file_path, dst_object_name, part_number=None, upload_id=None):
sftp = self.client.open_sftp()
if part_number and upload_id:
remote_part_path = f"{self.temp_dir}/{upload_id}/{part_number}"
Expand Down Expand Up @@ -177,17 +177,17 @@ def initiate_multipart_upload(self, dst_object_name: str, mime_type: Optional[st
if error_message:
raise exceptions.BadConfigException(f"Failed to create directory on VM: {error_message}")
return upload_id
def complete_multipart_upload(self, dst_object_name, upload_id):

def complete_multipart_upload(self, dst_object_name, upload_id, metadata: Optional[Any] = None):
_, stdout, _ = self.client.exec_command(f"ls {self.temp_dir}/{upload_id}")
parts = [f"{self.temp_dir}/{upload_id}/{part}" for part in sorted(stdout.read().decode().split(), key=int)]

# Concatenate all parts together
concatenated_parts = " ".join(parts)
_, stderr, _ = self.client.exec_command(f"cat {concatenated_parts} > {self.path}/{dst_object_name}")
_, stderr, _ = self.client.exec_command(f"cat {concatenated_parts} > {dst_object_name}")
error_message = stderr.read().decode().strip()
if error_message:
raise exceptions.BadConfigException(f"Failed to complete multipart upload on VM: {error_message}")

# Cleanup
_, _, _ = self.client.exec_command(f"rm -r {self.temp_dir}/{upload_id}")
# Cleanup
# _, _, _ = self.client.exec_command(f"rm -r {self.temp_dir}/{upload_id}")
2 changes: 1 addition & 1 deletion skyplane/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan:
vm_types[dst_region_tag] if vm_types else None,
instance_id=dst_vm_instance_id,
instance_path=dst_vm_instance_path,
instance_key_path=dst_vm_key_path
instance_key_path=dst_vm_key_path,
)

# initialize gateway programs per region
Expand Down
4 changes: 2 additions & 2 deletions skyplane/planner/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
gateway_vm: Optional[str],
gateway_instance_id: Optional[str] = None,
gateway_instance_path: Optional[str] = None,
gateway_key_path: Optional[str] = None
gateway_key_path: Optional[str] = None,
):
self.region_tag = region_tag
self.gateway_id = gateway_id
Expand Down Expand Up @@ -99,7 +99,7 @@ def add_gateway(
vm_type: Optional[str] = None,
instance_id: Optional[str] = None,
instance_path: Optional[str] = None,
instance_key_path: Optional[str] = None
instance_key_path: Optional[str] = None,
):
"""Create gateway in specified region"""
gateway_id = region_tag + str(len([gateway for gateway in self.gateways.values() if gateway.region_tag == region_tag]))
Expand Down
1 change: 1 addition & 0 deletions tests/interface_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def interface_test_framework(region, bucket, multipart: bool, test_delete_bucket
interface = ObjectStoreInterface.create(region, bucket)
return interface_test_from_iface(interface, multipart=multipart, test_delete_bucket=test_delete_bucket, file_size_mb=file_size_mb)


def interface_test_from_iface(interface, multipart: bool, test_delete_bucket: bool = False, file_size_mb: int = 1):
interface.create_bucket(region.split(":")[1])
time.sleep(5)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_vm/test_vm_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tests.interface_util import interface_test_framework
from skyplane.utils import logger


def provision_vm():
vm = GCPCloudProvider().provision_instance("us-east1", "n2-standard-2")
return vm
Expand All @@ -17,12 +18,11 @@ def test_vm_simple():
region = vm.region_tag
vm_host = "skyplane"
vm_private_key_path = vm.ssh_private_key

vm_iface = VMInterface(vm_host, vm.gcp_instance_name, vm_region, vm_private_key_path)

# test a provisioned vm exists
assert vm_iface.exists()

# test basic transfer
assert interface_test_from_iface(vm_iface)

0 comments on commit c03c184

Please sign in to comment.