diff --git a/skyplane/api/client.py b/skyplane/api/client.py
index a1a43143a..d597fc432 100644
--- a/skyplane/api/client.py
+++ b/skyplane/api/client.py
@@ -2,14 +2,19 @@
 import typer
 from datetime import datetime
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 from skyplane.api.config import TransferConfig
 from skyplane.api.provisioner import Provisioner
 from skyplane.api.obj_store import ObjectStore
+from skyplane.config import SkyplaneConfig
+from skyplane.config_paths import config_path
+from skyplane.obj_store.object_store_interface import ObjectStoreInterface
+from skyplane.obj_store.storage_interface import StorageInterface
 from skyplane.api.usage import get_clientid
 from skyplane.utils import logger
 from skyplane.utils.definitions import tmp_log_dir
+from skyplane.utils.path import parse_path
 
 from skyplane.api.pipeline import Pipeline
@@ -28,6 +33,10 @@ def __init__(
         ibmcloud_config: Optional["IBMCloudConfig"] = None,
         transfer_config: Optional[TransferConfig] = None,
         log_dir: Optional[str] = None,
+        disable_aws: bool = False,
+        disable_azure: bool = False,
+        disable_gcp: bool = False,
+        disable_ibm: bool = False,
     ):
         """
         :param aws_config: aws cloud configurations
@@ -66,9 +75,44 @@ def __init__(
             azure_auth=self.azure_auth,
             gcp_auth=self.gcp_auth,
             ibmcloud_auth=self.ibmcloud_auth,
+            disable_aws=disable_aws,
+            disable_azure=disable_azure,
+            disable_gcp=disable_gcp,
+            disable_ibm=disable_ibm,
         )
 
-    def pipeline(self, planning_algorithm: Optional[str] = "direct", max_instances: Optional[int] = 1, debug=False):
+        # persist the effective per-provider settings so that later invocations
+        # (and the planner, which reads the config file) see the same credentials
+        self.config = SkyplaneConfig.default_config()
+        if not disable_aws:
+            self.config.aws_enabled = True
+            if aws_config:
+                self.config.aws_access_key = aws_config.aws_access_key
+                self.config.aws_secret_key = aws_config.aws_secret_key
+        if not disable_azure:
+            self.config.azure_enabled = True
+            if azure_config:
+                self.config.azure_subscription_id = azure_config.azure_subscription_id
+                self.config.azure_resource_group = azure_config.azure_resource_group
+                self.config.azure_principal_id = azure_config.azure_umi_id
+                self.config.azure_umi_name = azure_config.azure_umi_name
+                self.config.azure_client_id = azure_config.azure_umi_client_id
+        if not disable_gcp:
+            self.config.gcp_enabled = True
+            if gcp_config:
+                self.config.gcp_project_id = gcp_config.gcp_project_id
+        if not disable_ibm:
+            self.config.ibmcloud_enabled = True
+            if ibmcloud_config:
+                self.config.ibmcloud_access_id = ibmcloud_config.ibmcloud_access_id
+                self.config.ibmcloud_secret_key = ibmcloud_config.ibmcloud_secret_key
+                self.config.ibmcloud_iam_key = ibmcloud_config.ibmcloud_iam_key
+                self.config.ibmcloud_iam_endpoint = ibmcloud_config.ibmcloud_iam_endpoint
+                self.config.ibmcloud_useragent = ibmcloud_config.ibmcloud_useragent
+                self.config.ibmcloud_resource_group_id = ibmcloud_config.ibmcloud_resource_group_id
+
+        self.config.to_config_file(config_path)
+        typer.secho(f"\nConfig file saved to {config_path}", fg="green")
+
+    def pipeline(
+        self,
+        planning_algorithm: Optional[str] = "direct",
+        max_instances: Optional[int] = 1,
+        src_iface: Optional[ObjectStoreInterface] = None,
+        dst_ifaces: Optional[List[ObjectStoreInterface]] = None,
+        debug=False,
+    ):
         """Create a pipeline object to queue jobs"""
         return Pipeline(
             planning_algorithm=planning_algorithm,
@@ -76,6 +120,8 @@ def pipeline(self, planning_algorithm: Optional[str] = "direct", max_instances:
             clientid=self.clientid,
             provisioner=self.provisioner,
             transfer_config=self.transfer_config,
+            src_iface=src_iface,
+            dst_ifaces=dst_ifaces,
             debug=debug,
         )
 
@@ -93,8 +139,20 @@ def copy(self, src: str, dst: str, recursive: bool = False, max_instances: Optio
         :param max_instances: The maximum number of instances to use per region (default: 1)
         :type max_instances: int
         """
+        provider_src, bucket_src, _ = parse_path(src)
+        src_iface = ObjectStoreInterface.create(
+            f"{provider_src}:infer", bucket_src, aws_auth=self.aws_auth, azure_auth=self.azure_auth, gcp_auth=self.gcp_auth
+        )
+
+        if isinstance(dst, str):
+            provider_dst, bucket_dst, _ = parse_path(dst)
+            dst_ifaces = [
+                StorageInterface.create(
+                    f"{provider_dst}:infer", bucket_dst, aws_auth=self.aws_auth, azure_auth=self.azure_auth, gcp_auth=self.gcp_auth
+                )
+            ]
+        else:
+            dst_ifaces = []
+            for dst_path in dst:
+                provider_dst, bucket_dst, _ = parse_path(dst_path)
+                dst_ifaces.append(
+                    StorageInterface.create(
+                        f"{provider_dst}:infer", bucket_dst, aws_auth=self.aws_auth, azure_auth=self.azure_auth, gcp_auth=self.gcp_auth
+                    )
+                )
 
-        pipeline = self.pipeline(max_instances=max_instances, debug=debug)
+        pipeline = self.pipeline(max_instances=max_instances, debug=debug, src_iface=src_iface, dst_ifaces=dst_ifaces)
         pipeline.queue_copy(src, dst, recursive=recursive)
         pipeline.start(progress=True)
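The client-side flow above follows the public SDK surface (`skyplane.SkyplaneClient`, `skyplane.AWSConfig`, `client.copy`) documented in the project README. A minimal usage sketch of the new flags; the `disable_*` keyword names come from this diff, and the bucket names are placeholders:

```python
import skyplane

# AWS-only client: the Azure/GCP/IBM providers are never constructed, and only
# the [aws] section of the saved config file receives credentials.
client = skyplane.SkyplaneClient(
    aws_config=skyplane.AWSConfig(),
    disable_azure=True,
    disable_gcp=True,
    disable_ibm=True,
)

# copy() now builds the source/destination interfaces with the client's auth
# objects and threads them through pipeline() into the underlying CopyJob.
client.copy(src="s3://src-bucket/data/", dst="s3://dst-bucket/data/", recursive=True)
```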
diff --git a/skyplane/api/pipeline.py b/skyplane/api/pipeline.py
index 6fa3face4..19aab51fb 100644
--- a/skyplane/api/pipeline.py
+++ b/skyplane/api/pipeline.py
@@ -10,6 +10,7 @@
 from skyplane.api.transfer_job import CopyJob, SyncJob, TransferJob
 from skyplane.api.config import TransferConfig
+from skyplane.obj_store.object_store_interface import ObjectStoreInterface
 from skyplane.planner.planner import MulticastDirectPlanner, DirectPlannerSourceOneSided, DirectPlannerDestOneSided
 from skyplane.planner.topology import TopologyPlanGateway
 from skyplane.utils import logger
@@ -33,6 +34,8 @@ def __init__(
         max_instances: Optional[int] = 1,
         n_connections: Optional[int] = 64,
         planning_algorithm: Optional[str] = "direct",
+        src_iface: Optional[ObjectStoreInterface] = None,
+        dst_ifaces: Optional[List[ObjectStoreInterface]] = None,
         debug: Optional[bool] = False,
     ):
         """
@@ -70,6 +73,10 @@ def __init__(
         else:
             raise ValueError(f"No such planning algorithm {planning_algorithm}")
 
+        # pre-built object store interfaces, optionally supplied by the client
+        self.src_iface = src_iface
+        self.dst_ifaces = dst_ifaces
+
         # transfer logs
         self.transfer_dir = tmp_log_dir / "transfer_logs" / datetime.now().strftime("%Y%m%d_%H%M%S")
         self.transfer_dir.mkdir(exist_ok=True, parents=True)
@@ -146,7 +153,7 @@ def queue_copy(
         """
         if isinstance(dst, str):
             dst = [dst]
-        job = CopyJob(src, dst, recursive, requester_pays=self.transfer_config.requester_pays)
+        job = CopyJob(src, dst, recursive, requester_pays=self.transfer_config.requester_pays, src_iface=self.src_iface, dst_ifaces=self.dst_ifaces)
         logger.fs.debug(f"[SkyplaneClient] Queued copy job {job}")
         self.jobs_to_dispatch.append(job)
         return job.uuid
@@ -169,7 +176,7 @@ def queue_sync(
         """
         if isinstance(dst, str):
             dst = [dst]
-        job = SyncJob(src, dst, requester_pays=self.transfer_config.requester_pays)
+        job = SyncJob(src, dst, requester_pays=self.transfer_config.requester_pays, src_iface=self.src_iface, dst_ifaces=self.dst_ifaces)
         logger.fs.debug(f"[SkyplaneClient] Queued sync job {job}")
         self.jobs_to_dispatch.append(job)
         return job.uuid
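Because the pipeline now hands the same `src_iface`/`dst_ifaces` pair to every job it queues, a pipeline constructed with explicit interfaces should only queue jobs whose paths point at the buckets those interfaces wrap. A sketch under that assumption (`client` as in the earlier sketch; region tags and bucket names are placeholders):

```python
from skyplane.obj_store.object_store_interface import ObjectStoreInterface

# build the interfaces once; every job queued on this pipeline reuses them
src_iface = ObjectStoreInterface.create("aws:us-east-1", "src-bucket")
dst_ifaces = [ObjectStoreInterface.create("aws:us-west-2", "dst-bucket")]

pipeline = client.pipeline(max_instances=1, src_iface=src_iface, dst_ifaces=dst_ifaces)
pipeline.queue_copy("s3://src-bucket/data/", "s3://dst-bucket/data/", recursive=True)
pipeline.start(progress=True)
```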
diff --git a/skyplane/api/provisioner.py b/skyplane/api/provisioner.py
index 09e89cf52..49735ef17 100644
--- a/skyplane/api/provisioner.py
+++ b/skyplane/api/provisioner.py
@@ -52,6 +52,10 @@ def __init__(
         gcp_auth: Optional[compute.GCPAuthentication] = None,
         host_uuid: Optional[str] = None,
         ibmcloud_auth: Optional[compute.IBMCloudAuthentication] = None,
+        disable_aws: bool = False,
+        disable_azure: bool = False,
+        disable_gcp: bool = False,
+        disable_ibm: bool = False,
     ):
         """
         :param aws_auth: authentication information for aws
@@ -70,7 +74,7 @@ def __init__(
         self.gcp_auth = gcp_auth
         self.host_uuid = host_uuid
         self.ibmcloud_auth = ibmcloud_auth
-        self._make_cloud_providers()
+        self._make_cloud_providers(disable_aws, disable_azure, disable_gcp, disable_ibm)
         self.temp_nodes: Set[compute.Server] = set()  # temporary area to store nodes that should be terminated upon exit
         self.pending_provisioner_tasks: List[ProvisionerTask] = []
         self.provisioned_vms: Dict[str, compute.Server] = {}
@@ -78,13 +82,17 @@ def __init__(
         # store GCP firewall rules to be deleted upon exit
         self.gcp_firewall_rules: Set[str] = set()
 
-    def _make_cloud_providers(self):
-        self.aws = compute.AWSCloudProvider(
-            key_prefix=f"skyplane{'-'+self.host_uuid.replace('-', '') if self.host_uuid else ''}", auth=self.aws_auth
-        )
-        self.azure = compute.AzureCloudProvider(auth=self.azure_auth)
-        self.gcp = compute.GCPCloudProvider(auth=self.gcp_auth)
-        self.ibmcloud = compute.IBMCloudProvider(auth=self.ibmcloud_auth)
+    def _make_cloud_providers(self, disable_aws: bool, disable_azure: bool, disable_gcp: bool, disable_ibm: bool):
+        # skip construction entirely for disabled providers; the corresponding
+        # attribute is left unset, so callers must not assume it exists
+        if not disable_aws:
+            self.aws = compute.AWSCloudProvider(
+                key_prefix=f"skyplane{'-'+self.host_uuid.replace('-', '') if self.host_uuid else ''}", auth=self.aws_auth
+            )
+        if not disable_azure:
+            self.azure = compute.AzureCloudProvider(auth=self.azure_auth)
+        if not disable_gcp:
+            self.gcp = compute.GCPCloudProvider(auth=self.gcp_auth)
+        if not disable_ibm:
+            self.ibmcloud = compute.IBMCloudProvider(auth=self.ibmcloud_auth)
 
     def init_global(self, aws: bool = True, azure: bool = True, gcp: bool = True, ibmcloud: bool = True):
         """
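One behavioral consequence of the change above: when a provider is disabled, the corresponding attribute (`self.aws`, `self.azure`, and so on) is never assigned, so any caller that touches it unconditionally will raise `AttributeError`. A defensive access pattern, sketched with a hypothetical `provisioner` instance:

```python
# probe providers that may not have been constructed under the disable_* flags
for name in ("aws", "azure", "gcp", "ibmcloud"):
    provider = getattr(provisioner, name, None)
    print(f"{name}: {'available' if provider is not None else 'disabled'}")
```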
diff --git a/skyplane/api/transfer_job.py b/skyplane/api/transfer_job.py
index cf3322985..acd7f1cdb 100644
--- a/skyplane/api/transfer_job.py
+++ b/skyplane/api/transfer_job.py
@@ -468,6 +468,8 @@ def __init__(
         recursive: bool = False,
         requester_pays: bool = False,
         job_id: Optional[str] = None,
+        src_iface: Optional[ObjectStoreInterface] = None,
+        dst_ifaces: Optional[List[ObjectStoreInterface]] = None,
     ):
         self.src_path = src_path
         self.dst_paths = dst_paths
@@ -477,6 +479,8 @@ def __init__(
             self.uuid = str(uuid.uuid4())
         else:
             self.uuid = job_id
+        self._src_iface = src_iface
+        self._dst_ifaces = dst_ifaces
 
     @property
     def transfer_type(self) -> str:
@@ -495,7 +499,7 @@ def src_prefix(self) -> Optional[str]:
     @property
     def src_iface(self) -> StorageInterface:
         """Return the source object store interface"""
-        if not hasattr(self, "_src_iface"):
+        if self._src_iface is None:
             provider_src, bucket_src, _ = parse_path(self.src_path)
             self._src_iface = ObjectStoreInterface.create(f"{provider_src}:infer", bucket_src)
             if self.requester_pays:
@@ -515,7 +519,7 @@ def dst_prefixes(self) -> List[str]:
     @property
     def dst_ifaces(self) -> List[StorageInterface]:
         """Return the destination object store interface"""
-        if not hasattr(self, "_dst_iface"):
+        if self._dst_ifaces is None:
             if self.transfer_type == "unicast":
                 provider_dst, bucket_dst, _ = parse_path(self.dst_paths[0])
                 self._dst_ifaces = [StorageInterface.create(f"{provider_dst}:infer", bucket_dst)]
@@ -570,8 +574,10 @@ def __init__(
         recursive: bool = False,
         requester_pays: bool = False,
         job_id: Optional[str] = None,
+        src_iface: Optional[ObjectStoreInterface] = None,
+        dst_ifaces: Optional[List[ObjectStoreInterface]] = None,
     ):
-        super().__init__(src_path, dst_paths, recursive, requester_pays, job_id)
+        super().__init__(src_path, dst_paths, recursive, requester_pays, job_id, src_iface, dst_ifaces)
         self.transfer_list = []
         self.multipart_transfer_list = []
@@ -772,8 +778,8 @@ def size_gb(self):
 class SyncJob(CopyJob):
     """sync job that copies the source objects that does not exist in the destination bucket to the destination"""
 
-    def __init__(self, src_path: str, dst_paths: List[str] or str, requester_pays: bool = False, job_id: Optional[str] = None):
-        super().__init__(src_path, dst_paths, True, requester_pays, job_id)
+    def __init__(
+        self,
+        src_path: str,
+        dst_paths: List[str] or str,
+        requester_pays: bool = False,
+        job_id: Optional[str] = None,
+        src_iface: Optional[ObjectStoreInterface] = None,
+        dst_ifaces: Optional[List[ObjectStoreInterface]] = None,
+    ):
+        super().__init__(src_path, dst_paths, True, requester_pays, job_id, src_iface, dst_ifaces)
         self.transfer_list = []
         self.multipart_transfer_list = []
diff --git a/skyplane/compute/aws/aws_auth.py b/skyplane/compute/aws/aws_auth.py
index 532c01e20..ba249b040 100644
--- a/skyplane/compute/aws/aws_auth.py
+++ b/skyplane/compute/aws/aws_auth.py
@@ -18,6 +18,10 @@ def __init__(self, config: Optional[SkyplaneConfig] = None, access_key: Optional
             self.config_mode = "manual"
             self._access_key = access_key
             self._secret_key = secret_key
+        elif self.config.aws_access_key and self.config.aws_secret_key:
+            self.config_mode = "manual"
+            self._access_key = self.config.aws_access_key
+            self._secret_key = self.config.aws_secret_key
         else:
             self.config_mode = "iam_inferred"
             self._access_key = None
diff --git a/skyplane/compute/aws/aws_cloud_provider.py b/skyplane/compute/aws/aws_cloud_provider.py
index d82733f63..0613dd967 100644
--- a/skyplane/compute/aws/aws_cloud_provider.py
+++ b/skyplane/compute/aws/aws_cloud_provider.py
@@ -56,7 +56,7 @@ def get_instance_list(exceptions, self, region: str) -> List[AWSServer]:
         except exceptions.ClientError as e:
             logger.error(f"error provisioning in {region}: {e}")
             return []
-        return [AWSServer(f"aws:{region}", i) for i in instance_ids]
+        return [AWSServer(f"aws:{region}", i, auth=self.auth) for i in instance_ids]
 
     def setup_global(self, iam_name: str = "skyplane_gateway", attach_policy_arn: Optional[str] = None):
         # Create IAM role if it doesn't exist and grant managed role if given.
@@ -246,4 +246,4 @@ def start_instance(subnet_id: str):
                 logger.fs.warning(f"Terminating instance {instance[0].id} due to keyboard interrupt")
                 instance[0].terminate()
                 raise
-        return AWSServer(f"aws:{region}", instance[0].id)
+        return AWSServer(f"aws:{region}", instance[0].id, auth=self.auth)
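Two details worth calling out in the transfer_job hunks above. First, the old cache check for the destinations had a typo (`hasattr(self, "_dst_iface")` versus the `self._dst_ifaces` it assigns), so the interfaces were silently rebuilt on every property access; the always-assigned attributes fix that. Second, `is None` is preferable to a truthiness test here, since an injected empty list would otherwise retrigger construction. The resulting pattern, sketched generically (not Skyplane code):

```python
from typing import Optional


class LazyResource:
    """Lazy initialization with optional caller injection."""

    def __init__(self, resource: Optional[object] = None):
        self._resource = resource  # always assigned, possibly None

    @property
    def resource(self) -> object:
        # hasattr() would always be True now; test the None sentinel instead
        if self._resource is None:
            self._resource = object()  # stand-in for the expensive construction
        return self._resource
```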
diff --git a/skyplane/compute/aws/aws_server.py b/skyplane/compute/aws/aws_server.py
index a2c1d2a22..3226e94b7 100644
--- a/skyplane/compute/aws/aws_server.py
+++ b/skyplane/compute/aws/aws_server.py
@@ -21,10 +21,10 @@ class AWSServer(Server):
     """AWS Server class to support basic SSH operations"""
 
-    def __init__(self, region_tag, instance_id, log_dir=None):
+    def __init__(self, region_tag, instance_id, log_dir=None, auth: Optional[AWSAuthentication] = None):
         super().__init__(region_tag, log_dir=log_dir)
         assert self.region_tag.split(":")[0] == "aws"
-        self.auth = AWSAuthentication()
+        self.auth = AWSAuthentication() if auth is None else auth
         self.key_manager = AWSKeyManager(self.auth)
         self.aws_region = self.region_tag.split(":")[1]
         self.instance_id = instance_id
diff --git a/skyplane/config.py b/skyplane/config.py
index f1210d8a9..49993ee48 100644
--- a/skyplane/config.py
+++ b/skyplane/config.py
@@ -104,6 +104,8 @@ class SkyplaneConfig:
     cloudflare_enabled: bool
     ibmcloud_enabled: bool
     anon_clientid: str
+    aws_access_key: Optional[str] = None
+    aws_secret_key: Optional[str] = None
     azure_principal_id: Optional[str] = None
     azure_subscription_id: Optional[str] = None
     azure_resource_group: Optional[str] = None
@@ -148,9 +150,15 @@ def load_config(cls, path) -> "SkyplaneConfig":
             anon_clientid = cls.generate_machine_id()
 
         aws_enabled = False
+        aws_access_key = None
+        aws_secret_key = None
         if "aws" in config:
             if "aws_enabled" in config["aws"]:
                 aws_enabled = config.getboolean("aws", "aws_enabled")
+            if "aws_access_key" in config["aws"]:
+                aws_access_key = config.get("aws", "aws_access_key")
+            if "aws_secret_key" in config["aws"]:
+                aws_secret_key = config.get("aws", "aws_secret_key")
 
         azure_enabled = False
         azure_subscription_id = None
@@ -215,6 +223,8 @@ def load_config(cls, path) -> "SkyplaneConfig":
             gcp_enabled=gcp_enabled,
             ibmcloud_enabled=ibmcloud_enabled,
             cloudflare_enabled=cloudflare_enabled,
+            aws_access_key=aws_access_key,
+            aws_secret_key=aws_secret_key,
             anon_clientid=anon_clientid,
             azure_principal_id=azure_principal_id,
             azure_subscription_id=azure_subscription_id,
@@ -248,6 +258,10 @@ def to_config_file(self, path):
         if "aws" not in config:
             config.add_section("aws")
         config.set("aws", "aws_enabled", str(self.aws_enabled))
+        if self.aws_access_key:
+            config.set("aws", "aws_access_key", self.aws_access_key)
+        if self.aws_secret_key:
+            config.set("aws", "aws_secret_key", self.aws_secret_key)
 
         if "ibmcloud" not in config:
             config.add_section("ibmcloud")
diff --git a/skyplane/planner/planner.py b/skyplane/planner/planner.py
index 2bb21e0ac..d716a4f9d 100644
--- a/skyplane/planner/planner.py
+++ b/skyplane/planner/planner.py
@@ -28,9 +28,12 @@ class Planner:
-    def __init__(self, transfer_config: TransferConfig, quota_limits_file: Optional[str] = None):
+    def __init__(self, transfer_config: TransferConfig, config: Optional[SkyplaneConfig] = None, quota_limits_file: Optional[str] = None):
         self.transfer_config = transfer_config
-        self.config = SkyplaneConfig.load_config(config_path)
+        # fall back to defaults when no config file has been written yet
+        if config_path.exists():
+            self.config = SkyplaneConfig.load_config(config_path)
+        else:
+            self.config = SkyplaneConfig.default_config()
         self.n_instances = self.config.get_flag("max_instances")
 
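With `SkyplaneConfig` now round-tripping `aws_access_key`/`aws_secret_key`, the resolution order in `AWSAuthentication.__init__` becomes: explicit constructor arguments, then keys persisted in the config file, then IAM inference. Note that this persists a long-lived secret key in a plaintext config file. A read-back sketch using only calls that appear in this diff:

```python
from skyplane.compute.aws.aws_auth import AWSAuthentication
from skyplane.config import SkyplaneConfig
from skyplane.config_paths import config_path

config = SkyplaneConfig.load_config(config_path)  # round-trips aws_access_key / aws_secret_key
auth = AWSAuthentication(config=config)
print(auth.config_mode)  # "manual" if keys were persisted, otherwise "iam_inferred"
```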
         # Loading the quota information; add IBM Cloud once it is supported
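The `config_path.exists()` guard means `Planner` no longer assumes `skyplane init` has been run; pure-SDK users on a fresh machine fall back to defaults. (Within this hunk the new `config` parameter is accepted but not yet read.) A minimal sketch of the same fallback:

```python
from skyplane.config import SkyplaneConfig
from skyplane.config_paths import config_path

# mirror Planner.__init__: load the saved config if present, else use defaults
config = SkyplaneConfig.load_config(config_path) if config_path.exists() else SkyplaneConfig.default_config()
print(config.get_flag("max_instances"))
```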