diff --git a/README.md b/README.md
index e71aef0d8..01d3cefef 100644
--- a/README.md
+++ b/README.md
@@ -107,6 +107,13 @@ Copy https://github.com/skyplane-project/skyplane/blob/main/skyplane/compute/ibm into `~/.bluemix/ibm_credentials` and fill your IBM IAM key and credentials to your IBM Cloud object storage
+---> For SCP:
+$ # Create the directory if required
+$ mkdir -p ~/.scp
+$ # Add the lines for "access_key", "secret_key", and "project_id" to the scp_credential file
+$ echo "access_key = " >> ~/.scp/scp_credential
+$ echo "secret_key = " >> ~/.scp/scp_credential
+$ echo "project_id = " >> ~/.scp/scp_credential
 ```
 
 After authenticating with each cloud provider, you can run `skyplane init` to create a configuration file for Skyplane.
@@ -149,6 +156,11 @@ $ skyplane init
     Enter the GCP project ID [XXXXXXX]:
     GCP region config file saved to /home/ubuntu/.skyplane/gcp_config
+(5) Configuring SCP:
+    Loaded SCP credentials from the scp_credential file [access key: ...XXXXXX]
+    SCP region config file saved to /home/ubuntu/.skyplane/scp_config
+
+
 Config file saved to /home/ubuntu/.skyplane/config
 ```
diff --git a/pyproject.toml b/pyproject.toml
index 00ee8be99..d59fe74a6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,6 +68,7 @@ aws = ["boto3"]
 azure = ["azure-identity", "azure-mgmt-authorization", "azure-mgmt-compute", "azure-mgmt-network", "azure-mgmt-resource", "azure-mgmt-storage", "azure-mgmt-quota", "azure-mgmt-subscription", "azure-storage-blob"]
 gcp = ["google-api-python-client", "google-auth", "google-cloud-compute", "google-cloud-storage"]
 ibm = ["ibm-cloud-sdk-core", "ibm-cos-sdk", "ibm-vpc"]
+scp = ["boto3"]
 all = ["boto3", "azure-identity", "azure-mgmt-authorization", "azure-mgmt-compute", "azure-mgmt-network", "azure-mgmt-resource", "azure-mgmt-storage", "azure-mgmt-subscription", "azure-storage-blob", "google-api-python-client", "google-auth", "google-cloud-compute", "google-cloud-storage", "ibm-cloud-sdk-core", "ibm-cos-sdk", "ibm-vpc"]
 gateway = ["flask", "lz4", "pynacl", "pyopenssl", "werkzeug"]
 solver = ["cvxpy", "graphviz", "matplotlib", "numpy"]
diff --git a/skyplane/__init__.py b/skyplane/__init__.py
index d2f4ed239..473b9ab26 100644
--- a/skyplane/__init__.py
+++ b/skyplane/__init__.py
@@ -23,5 +23,6 @@
     "AWSConfig",
     "AzureConfig",
     "GCPConfig",
+    "SCPConfig",
     "TransferHook",
 ]
diff --git a/skyplane/api/client.py b/skyplane/api/client.py
index 7f0bb1437..eb87ea427 100644
--- a/skyplane/api/client.py
+++ b/skyplane/api/client.py
@@ -14,7 +14,7 @@
 from skyplane.api.pipeline import Pipeline
 
 if TYPE_CHECKING:
-    from skyplane.api.config import AWSConfig, AzureConfig, GCPConfig, TransferConfig, IBMCloudConfig
+    from skyplane.api.config import AWSConfig, AzureConfig, GCPConfig, TransferConfig, IBMCloudConfig, SCPConfig
 
 
 class SkyplaneClient:
@@ -26,6 +26,7 @@ def __init__(
         self,
         aws_config: Optional["AWSConfig"] = None,
         azure_config: Optional["AzureConfig"] = None,
         gcp_config: Optional["GCPConfig"] = None,
         ibmcloud_config: Optional["IBMCloudConfig"] = None,
+        scp_config: Optional["SCPConfig"] = None,
         transfer_config: Optional[TransferConfig] = None,
         log_dir: Optional[str] = None,
     ):
@@ -48,6 +49,7 @@
         self.azure_auth = azure_config.make_auth_provider() if azure_config else None
         self.gcp_auth = gcp_config.make_auth_provider() if gcp_config else None
         self.ibmcloud_auth = ibmcloud_config.make_auth_provider() if ibmcloud_config else None
+        self.scp_auth = scp_config.make_auth_provider() if scp_config else None
         self.transfer_config = transfer_config if transfer_config else TransferConfig()
         self.log_dir = (
             tmp_log_dir
/ "transfer_logs" / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}-{uuid.uuid4().hex[:8]}" @@ -66,6 +68,7 @@ def __init__( azure_auth=self.azure_auth, gcp_auth=self.gcp_auth, ibmcloud_auth=self.ibmcloud_auth, + scp_auth=self.scp_auth, ) def pipeline(self, planning_algorithm: Optional[str] = "direct", max_instances: Optional[int] = 1, debug=False): diff --git a/skyplane/api/config.py b/skyplane/api/config.py index b65e820ac..3195ddc9d 100644 --- a/skyplane/api/config.py +++ b/skyplane/api/config.py @@ -4,7 +4,7 @@ from skyplane import compute -from skyplane.config_paths import aws_quota_path, gcp_quota_path, azure_standardDv5_quota_path +from skyplane.config_paths import aws_quota_path, gcp_quota_path, azure_standardDv5_quota_path, scp_quota_path from pathlib import Path @@ -61,6 +61,17 @@ def make_auth_provider(self) -> compute.IBMCloudAuthentication: # pytype: enable=attribute-error +@dataclass +class SCPConfig(AuthenticationConfig): + scp_access_key: Optional[str] = None + scp_secret_key: Optional[str] = None + scp_project_id: Optional[str] = None + scp_enabled: bool = False + + def make_auth_provider(self) -> compute.SCPAuthentication: + return compute.SCPAuthentication(config=self) # type: ignore + + @dataclass(frozen=True) class TransferConfig: autoterminate_minutes: int = 15 @@ -82,16 +93,22 @@ class TransferConfig: azure_use_spot_instances: bool = False gcp_use_spot_instances: bool = False ibmcloud_use_spot_instances: bool = False + # Add SCP Support + scp_use_spot_instances: bool = False aws_instance_class: str = "m5.8xlarge" azure_instance_class: str = "Standard_D2_v5" gcp_instance_class: str = "n2-standard-16" ibmcloud_instance_class: str = "bx2-2x8" gcp_use_premium_network: bool = True + # Add SCP Support + scp_instance_class: str = "h1v32m128" aws_vcpu_file: Path = aws_quota_path gcp_vcpu_file: Path = gcp_quota_path azure_vcpu_file: Path = azure_standardDv5_quota_path + # Add SCP Support + scp_vcpu_file: Path = scp_quota_path # TODO: add ibmcloud when the quota info is available # multipart config diff --git a/skyplane/api/dataplane.py b/skyplane/api/dataplane.py index c8253a922..69916093d 100644 --- a/skyplane/api/dataplane.py +++ b/skyplane/api/dataplane.py @@ -21,6 +21,7 @@ from skyplane.utils import logger from skyplane.utils.definitions import gateway_docker_image, tmp_log_dir from skyplane.utils.fn import PathLike, do_parallel +from skyplane.utils.retry import retry_backoff if TYPE_CHECKING: from skyplane.api.provisioner import Provisioner @@ -156,6 +157,7 @@ def provision( is_azure_used = any(n.region_tag.startswith("azure:") for n in self.topology.get_gateways()) is_gcp_used = any(n.region_tag.startswith("gcp:") for n in self.topology.get_gateways()) is_ibmcloud_used = any(n.region_tag.startswith("ibmcloud:") for n in self.topology.get_gateways()) + is_scp_used = any(n.region_tag.startswith("scp:") for n in self.topology.get_gateways()) # create VMs from the topology for node in self.topology.get_gateways(): @@ -172,7 +174,7 @@ def provision( ) # initialize clouds - self.provisioner.init_global(aws=is_aws_used, azure=is_azure_used, gcp=is_gcp_used, ibmcloud=is_ibmcloud_used) + self.provisioner.init_global(aws=is_aws_used, azure=is_azure_used, gcp=is_gcp_used, ibmcloud=is_ibmcloud_used, scp=is_scp_used) # provision VMs uuids = self.provisioner.provision( @@ -273,9 +275,13 @@ def deprovision(self, max_jobs: int = 64, spinner: bool = False): def check_error_logs(self) -> Dict[str, List[str]]: """Get the error log from remote gateways if there is any error.""" + def 
http_pool_request(instance): + return self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/errors") + def get_error_logs(args): _, instance = args - reply = self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/errors") + # reply = self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/errors") + reply = retry_backoff(partial(http_pool_request, instance)) if reply.status != 200: raise Exception(f"Failed to get error logs from gateway instance {instance.instance_name()}: {reply.data.decode('utf-8')}") return json.loads(reply.data.decode("utf-8"))["errors"] diff --git a/skyplane/api/obj_store.py b/skyplane/api/obj_store.py index cbf57a05b..6436f709b 100644 --- a/skyplane/api/obj_store.py +++ b/skyplane/api/obj_store.py @@ -39,6 +39,8 @@ def create_bucket(self, region: str, bucket_name: str): return f"s3://{bucket_name}" elif provider == "gcp": return f"gs://{bucket_name}" + elif provider == "scp": + return f"scp://{bucket_name}" else: raise NotImplementedError(f"Provider {provider} not implemented") diff --git a/skyplane/api/provisioner.py b/skyplane/api/provisioner.py index 09e89cf52..7e0981742 100644 --- a/skyplane/api/provisioner.py +++ b/skyplane/api/provisioner.py @@ -52,6 +52,7 @@ def __init__( gcp_auth: Optional[compute.GCPAuthentication] = None, host_uuid: Optional[str] = None, ibmcloud_auth: Optional[compute.IBMCloudAuthentication] = None, + scp_auth: Optional[compute.SCPAuthentication] = None, ): """ :param aws_auth: authentication information for aws @@ -64,12 +65,15 @@ def __init__( :type host_uuid: string :param ibmcloud_auth: authentication information for aws :type ibmcloud_auth: compute.IBMCloudAuthentication + :param scp_auth: authentication information for scp + :type scp_auth: compute.SCPAuthentication """ self.aws_auth = aws_auth self.azure_auth = azure_auth self.gcp_auth = gcp_auth self.host_uuid = host_uuid self.ibmcloud_auth = ibmcloud_auth + self.scp_auth = scp_auth self._make_cloud_providers() self.temp_nodes: Set[compute.Server] = set() # temporary area to store nodes that should be terminated upon exit self.pending_provisioner_tasks: List[ProvisionerTask] = [] @@ -85,8 +89,9 @@ def _make_cloud_providers(self): self.azure = compute.AzureCloudProvider(auth=self.azure_auth) self.gcp = compute.GCPCloudProvider(auth=self.gcp_auth) self.ibmcloud = compute.IBMCloudProvider(auth=self.ibmcloud_auth) + self.scp = compute.SCPCloudProvider(auth=self.scp_auth) - def init_global(self, aws: bool = True, azure: bool = True, gcp: bool = True, ibmcloud: bool = True): + def init_global(self, aws: bool = True, azure: bool = True, gcp: bool = True, ibmcloud: bool = True, scp: bool = True): """ Initialize the global cloud providers by configuring with credentials @@ -110,6 +115,9 @@ def init_global(self, aws: bool = True, azure: bool = True, gcp: bool = True, ib jobs.append(self.gcp.setup_global) if ibmcloud: jobs.append(self.ibmcloud.setup_global) + if scp: + jobs.append(self.scp.create_ssh_key) + jobs.append(self.scp.setup_global) do_parallel(lambda fn: fn(), jobs, spinner=False) @@ -174,6 +182,10 @@ def _provision_task(self, task: ProvisionerTask): elif task.cloud_provider == "ibmcloud": assert self.ibmcloud.auth.enabled(), "IBM Cloud credentials not configured" server = self.ibmcloud.provision_instance(task.region, task.vm_type, tags=task.tags) + elif task.cloud_provider == "scp": + assert self.scp.auth.enabled(), "SCP credentials not configured" + # print('def _provision_task : ', task.region, task.vm_type, task.tags) + server = 
self.scp.provision_instance(task.region, task.vm_type, tags=task.tags) else: raise NotImplementedError(f"Unknown provider {task.cloud_provider}") logger.fs.debug(f"[Provisioner._provision_task] Provisioned {server} in {t.elapsed:.2f}s") @@ -206,6 +218,8 @@ def provision(self, authorize_firewall: bool = True, max_jobs: int = 16, spinner azure_provisioned = any([task.cloud_provider == "azure" for task in provision_tasks]) gcp_provisioned = any([task.cloud_provider == "gcp" for task in provision_tasks]) ibmcloud_provisioned = any([task.cloud_provider == "ibmcloud" for task in provision_tasks]) + scp_regions = set([task.region for task in provision_tasks if task.cloud_provider == "scp"]) + scp_provisioned = any([task.cloud_provider == "scp" for task in provision_tasks]) # configure regions if aws_provisioned: @@ -224,6 +238,25 @@ def provision(self, authorize_firewall: bool = True, max_jobs: int = 16, spinner ) logger.fs.info(f"[Provisioner.provision] Configured IBM Cloud regions {ibmcloud_regions}") + if scp_provisioned: + logger.fs.info("SCP provisioning may sometimes take several minutes. Please be patient.") + do_parallel( + self.scp.setup_region, + list(set(scp_regions)), + spinner=spinner, + spinner_persist=False, + desc="Configuring SCP regions", + ) + # server group create, add provision_tasks on tags(region) + for r in set(scp_regions): + servergroup = self.scp.network.create_server_group(r) + for task in provision_tasks: + if task.cloud_provider == "scp" and task.region == r: + task.tags["servergroup"] = servergroup + # print('provisioner.py - task.tags : ', task.tags) + + logger.fs.info(f"[Provisioner.provision] Configured SCP regions {scp_regions}") + # provision VMs logger.fs.info(f"[Provisioner.provision] Provisioning {len(provision_tasks)} VMs") results: List[Tuple[ProvisionerTask, compute.Server]] = do_parallel( @@ -253,6 +286,18 @@ def authorize_gcp_gateways(): self.gcp_firewall_rules.add(self.gcp.authorize_gateways(public_ips + private_ips)) authorize_ip_jobs.append(authorize_gcp_gateways) + if scp_provisioned: + # configure firewall for each scp region + for r in set(scp_regions): + scp_ips = [s.private_ip() for t, s in results if t.cloud_provider == "scp" and t.region == r] + # vpcids = [s.vpc_id for t, s in results if t.cloud_provider == "scp" and t.region == r] # pytype: disable=bad-return-type + vpcids = [ + s.vpc_id if isinstance(s, compute.SCPServer) else None + for t, s in results + if t.cloud_provider == "scp" and t.region == r + ] + # print('provisioner.py - scp_ips : ', scp_ips, ', vpcids : ', vpcids) + authorize_ip_jobs.extend([partial(self.scp.add_firewall_rule_all, r, scp_ips, vpcids)]) do_parallel( lambda fn: fn(), @@ -303,6 +348,7 @@ def deprovision_gateway_instance(server: compute.Server): azure_deprovisioned = any([s.provider == "azure" for s in servers]) gcp_deprovisioned = any([s.provider == "gcp" for s in servers]) ibmcloud_deprovisioned = any([s.provider == "ibmcloud" for s in servers]) + scp_deprovisioned = any([s.provider == "scp" for s in servers]) if azure_deprovisioned: logger.warning("Azure deprovisioning is very slow. 
Please be patient.") logger.fs.info(f"[Provisioner.deprovision] Deprovisioning {len(servers)} VMs") @@ -327,6 +373,14 @@ def deprovision_gateway_instance(server: compute.Server): if gcp_deprovisioned: jobs.extend([partial(self.gcp.remove_gateway_rule, rule) for rule in self.gcp_firewall_rules]) logger.fs.info(f"[Provisioner.deprovision] Deauthorizing GCP gateways with firewalls: {self.gcp_firewall_rules}") + if scp_deprovisioned: + scp_regions = set([s.region() for s in servers if s.provider == "scp"]) + for r in set(scp_regions): + scp_servers = [s for s in servers if s.provider == "scp" and s.region() == r] + scp_ips = [s.private_ip() for s in scp_servers] + vpcids = [s.vpc_id for s in scp_servers] + jobs.extend([partial(self.scp.remove_gateway_rule_region, r, scp_ips, vpcids)]) + logger.fs.info(f"[Provisioner.deprovision] Deauthorizing SCP gateways with firewalls: {scp_ips}") do_parallel( lambda fn: fn(), jobs, n=max_jobs, spinner=spinner, spinner_persist=False, desc="Deauthorizing gateways from firewalls" ) diff --git a/skyplane/api/tracker.py b/skyplane/api/tracker.py index 0439dc5a0..dbd2e56e5 100644 --- a/skyplane/api/tracker.py +++ b/skyplane/api/tracker.py @@ -3,6 +3,7 @@ from abc import ABC from datetime import datetime from threading import Thread +from functools import partial import urllib3 from typing import TYPE_CHECKING, Dict, List, Optional, Set @@ -16,6 +17,7 @@ from skyplane.utils.fn import do_parallel from skyplane.api.usage import UsageClient from skyplane.utils.definitions import GB +from skyplane.utils.retry import retry_backoff from skyplane.cli.impl.common import print_stats_completed @@ -335,10 +337,14 @@ def monitor_transfer(pd, self, region_tag): def _chunk_to_job_map(self): return {chunk_id: job_uuid for job_uuid, cr_dict in self.job_chunk_requests.items() for chunk_id in cr_dict.keys()} + def http_pool_request(self, instance): + return self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/chunk_status_log") + def _query_chunk_status(self): def get_chunk_status(args): node, instance = args - reply = self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/chunk_status_log") + # reply = self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/chunk_status_log") + reply = retry_backoff(partial(self.http_pool_request, instance)) if reply.status != 200: raise Exception( f"Failed to get chunk status from gateway instance {instance.instance_name()}: {reply.data.decode('utf-8')}" diff --git a/skyplane/api/transfer_job.py b/skyplane/api/transfer_job.py index 0155b7c17..7bf0f2509 100644 --- a/skyplane/api/transfer_job.py +++ b/skyplane/api/transfer_job.py @@ -307,6 +307,10 @@ def transfer_pair_generator( from skyplane.obj_store.r2_interface import R2Object dest_obj = R2Object(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key) + elif dest_provider == "scp": + from skyplane.obj_store.scp_interface import SCPObject + + dest_obj = SCPObject(provider=dest_provider, bucket=dst_iface.bucket(), key=dest_key) else: raise ValueError(f"Invalid dest_region {dest_region}, unknown provider") dest_objs[dst_iface.region_tag()] = dest_obj @@ -626,6 +630,25 @@ def dispatch( n_multiparts = 0 time.time() + # rare case where we need to retry + def mapping_request(dst_gateway, mappings): + reply = self.http_pool.request( + "POST", + f"{dst_gateway.gateway_api_url}/api/v1/upload_id_maps", + body=json.dumps(mappings).encode("utf-8"), + headers={"Content-Type": "application/json"}, + ) + return reply + + def chunk_request(server, chunk_batch, 
n_added): + reply = self.http_pool.request( + "POST", + f"{server.gateway_api_url}/api/v1/chunk_requests", + body=json.dumps([chunk.as_dict() for chunk in chunk_batch[n_added:]]).encode("utf-8"), + headers={"Content-Type": "application/json"}, + ) + return reply + for batch in batches: # send upload_id mappings to sink gateways upload_id_batch = [cr for cr in batch if cr.upload_id_mapping is not None] @@ -641,12 +664,14 @@ def dispatch( mappings[region_tag][key] = id # send mapping to gateway - reply = self.http_pool.request( - "POST", - f"{dst_gateway.gateway_api_url}/api/v1/upload_id_maps", - body=json.dumps(mappings).encode("utf-8"), - headers={"Content-Type": "application/json"}, - ) + # reply = self.http_pool.request( + # "POST", + # f"{dst_gateway.gateway_api_url}/api/v1/upload_id_maps", + # body=json.dumps(mappings).encode("utf-8"), + # headers={"Content-Type": "application/json"}, + # ) + reply = retry_backoff(partial(mapping_request, dst_gateway, mappings), initial_backoff=0.5) + # TODO: assume that only destination nodes would write to the obj store if reply.status != 200: raise Exception( @@ -667,12 +692,13 @@ def dispatch( # TODO: make async st = time.time() - reply = self.http_pool.request( - "POST", - f"{server.gateway_api_url}/api/v1/chunk_requests", - body=json.dumps([chunk.as_dict() for chunk in chunk_batch[n_added:]]).encode("utf-8"), - headers={"Content-Type": "application/json"}, - ) + # reply = self.http_pool.request( + # "POST", + # f"{server.gateway_api_url}/api/v1/chunk_requests", + # body=json.dumps([chunk.as_dict() for chunk in chunk_batch[n_added:]]).encode("utf-8"), + # headers={"Content-Type": "application/json"}, + # ) + reply = retry_backoff(partial(chunk_request, server, chunk_batch, n_added), initial_backoff=0.5) if reply.status != 200: raise Exception(f"Failed to dispatch chunk requests {server.instance_name()}: {reply.data.decode('utf-8')}") et = time.time() diff --git a/skyplane/api/usage.py b/skyplane/api/usage.py index f2e17482a..f939dcae3 100644 --- a/skyplane/api/usage.py +++ b/skyplane/api/usage.py @@ -300,6 +300,8 @@ def make_error( if dest_region_tags is None: dest_provider, dest_region = None, None else: + if isinstance(dest_region_tags, str): + dest_region_tags = [dest_region_tags] dest_providers = [tag.split(":")[0] for tag in dest_region_tags] dest_regions = [tag.split(":")[1] for tag in dest_region_tags] diff --git a/skyplane/cli/cli.py b/skyplane/cli/cli.py index eb1d38f5d..ae1ef369d 100644 --- a/skyplane/cli/cli.py +++ b/skyplane/cli/cli.py @@ -48,6 +48,12 @@ def deprovision( if instances: typer.secho(f"Deprovisioning {len(instances)} instances", fg="yellow", bold=True) do_parallel(lambda instance: instance.terminate_instance(), instances, desc="Deprovisioning", spinner=True, spinner_persist=True) + # if compute.SCPAuthentication().enabled(): + if any([instance.region_tag.split(":")[0] == "scp" for instance in instances]): + typer.secho(f"Removing SCP Gateway Rules", fg="yellow", bold=True) + scp_instances = [instance for instance in instances if instance.region_tag.split(":")[0] == "scp"] + scp = compute.SCPCloudProvider() + scp.remove_gateway_rule_all(scp_instances) else: typer.secho("No instances to deprovision", fg="yellow", bold=True) @@ -61,6 +67,9 @@ def deprovision( if compute.AzureAuthentication().enabled(): azure = compute.AzureCloudProvider() azure.teardown_global() + if compute.SCPAuthentication().enabled(): + scp = compute.SCPCloudProvider() + scp.teardown_global() @app.command() diff --git a/skyplane/cli/cli_init.py 
b/skyplane/cli/cli_init.py
index 5bd6e3d2f..95881310c 100644
--- a/skyplane/cli/cli_init.py
+++ b/skyplane/cli/cli_init.py
@@ -16,7 +16,7 @@
 from skyplane.cli.impl.common import print_header
 from skyplane.api.usage import UsageClient, UsageStatsStatus
 from skyplane.config import SkyplaneConfig
-from skyplane.config_paths import aws_config_path, gcp_config_path, config_path, ibmcloud_config_path
+from skyplane.config_paths import aws_config_path, gcp_config_path, config_path, ibmcloud_config_path, scp_config_path
 from skyplane.utils import logger
@@ -471,17 +471,80 @@ def load_config():
     return config
 
 
+def load_scp_config(config: SkyplaneConfig, force_init: bool = False, non_interactive: bool = False) -> SkyplaneConfig:
+    def disable_scp_support():
+        typer.secho("    Disabling SCP support", fg="blue")
+        config.scp_enabled = False
+        config.scp_access_key = None
+        config.scp_secret_key = None
+        compute.SCPAuthentication.clear_region_config()
+        return config
+
+    if non_interactive or typer.confirm("    Do you want to configure SCP support in Skyplane?", default=True):
+        # SCP credentials check; the key names match the scp_credential file format documented in the README
+        auth = compute.SCPAuthentication(config=config)
+        credentials_path = auth.scp_credential_path
+
+        if not os.path.exists(os.path.expanduser(credentials_path)):
+            config.scp_enabled = False
+        else:
+            with open(os.path.expanduser(credentials_path), "r") as f:
+                lines = [line.strip() for line in f.readlines() if " = " in line]
+                credentials_scp = {}
+                for line in lines:
+                    key, value = line.split(" = ")
+                    credentials_scp[key] = value.strip()
+
+            if (
+                "access_key" not in credentials_scp
+                or "secret_key" not in credentials_scp
+                or credentials_scp["access_key"] is None
+                or credentials_scp["secret_key"] is None
+            ):
+                config.scp_enabled = False
+            else:
+                config.scp_enabled = True
+
+        if force_init:
+            typer.secho("    SCP configurations will be re-initialized", fg="red", err=True)
+            compute.SCPAuthentication.clear_region_config()
+
+        if config.scp_enabled:
+            typer.secho(
+                f"    Loaded SCP credentials from the scp_credential file [access key: ...{credentials_scp['access_key'][-6:]}]",
+                fg="blue",
+            )
+            auth.save_region_config(config)
+            typer.secho(f"    SCP region config file saved to {scp_config_path}", fg="blue")
+            config.scp_enabled = True
+            return config
+        else:
+            typer.secho(
+                "    SCP credentials not found; please check the scp_credential file against the SCP credential guide",
+                fg="red",
+                err=True,
+            )
+            return disable_scp_support()
+
+    else:
+        return disable_scp_support()
+
+
 def init(
     non_interactive: bool = typer.Option(False, "--non-interactive", "-y", help="Run non-interactively"),
     reinit_azure: bool = False,
     reinit_gcp: bool = False,
     reinit_ibm: bool = False,
     reinit_cloudflare: bool = False,
+    reinit_scp: bool = False,
     disable_config_aws: bool = False,
     disable_config_azure: bool = False,
     disable_config_gcp: bool = False,
     disable_config_ibm: bool = True,  # TODO: eventually enable IBM
     disable_config_cloudflare: bool = False,
+    disable_config_scp: bool = False,
 ):
     """
     It loads the configuration file, and if it doesn't exist, it creates a default one.
Then it creates @@ -493,6 +556,8 @@ def init( :type reinit_gcp: bool :param reinit_ibm: If true, will reinitialize the IBM Cloud region list and credentials :type reinit_ibm: bool + :param reinit_scp: If true, will reinitialize the SCP region list and credentials + :type reinit_scp: bool :param disable_config_aws: If true, will disable AWS configuration (may still be enabled if environment variables are set) :type disable_config_aws: bool :param disable_config_azure: If true, will disable Azure configuration (may still be enabled if environment variables are set) @@ -503,6 +568,8 @@ def init( :type disable_config_ibm: bool :param disable_config_cloudflare: If true, will disable Cloudflare configuration (may still be enabled if environment variables are set) :type disable_config_cloudflare: bool + :param disable_config_scp: If true, will disable SCP configuration (may still be enabled if environment variables are set) + :type disable_config_scp: bool """ print_header() @@ -516,13 +583,13 @@ def init( cloud_config = SkyplaneConfig.default_config() # load AWS config - if not (reinit_azure or reinit_gcp or reinit_ibm): + if not (reinit_azure or reinit_gcp or reinit_ibm or reinit_cloudflare or reinit_scp): typer.secho("\n(1) Configuring AWS:", fg="yellow", bold=True) if not disable_config_aws: cloud_config = load_aws_config(cloud_config, non_interactive=non_interactive) # load Azure config - if not (reinit_gcp or reinit_ibm): + if not (reinit_gcp or reinit_ibm or reinit_cloudflare or reinit_scp): if reinit_azure: typer.secho("\nConfiguring Azure:", fg="yellow", bold=True) else: @@ -531,7 +598,7 @@ def init( cloud_config = load_azure_config(cloud_config, force_init=reinit_azure, non_interactive=non_interactive) # load GCP config - if not reinit_azure: + if not (reinit_azure or reinit_ibm or reinit_cloudflare or reinit_scp): if reinit_gcp: typer.secho("\nConfiguring GCP:", fg="yellow", bold=True) else: @@ -540,18 +607,27 @@ def init( cloud_config = load_gcp_config(cloud_config, force_init=reinit_gcp, non_interactive=non_interactive) # load cloudflare config - if not reinit_cloudflare and not disable_config_cloudflare: # TODO: fix reinit logic + if not (reinit_gcp or reinit_cloudflare or reinit_scp) and not disable_config_cloudflare: # TODO: fix reinit logic typer.secho("\n(4) Configuring Cloudflare R2:", fg="yellow", bold=True) if not disable_config_cloudflare: cloud_config = load_cloudflare_config(cloud_config, non_interactive=non_interactive) # load IBMCloud config - if not disable_config_ibm and not reinit_ibm: + if not disable_config_ibm and not (reinit_ibm or reinit_scp): # TODO: fix IBM configuration to not fail on file finding typer.secho("\n(4) Configuring IBM Cloud:", fg="yellow", bold=True) if not disable_config_ibm: cloud_config = load_ibmcloud_config(cloud_config, force_init=reinit_ibm, non_interactive=non_interactive) + # load SCP config + if not (reinit_azure or reinit_gcp or reinit_ibm or reinit_cloudflare): + if reinit_scp: + typer.secho("\nConfiguring SCP:", fg="yellow", bold=True) + else: + typer.secho("\n(5) Configuring SCP:", fg="yellow", bold=True) + if not disable_config_scp: + cloud_config = load_scp_config(cloud_config, force_init=reinit_scp, non_interactive=non_interactive) + cloud_config.to_config_file(config_path) typer.secho(f"\nConfig file saved to {config_path}", fg="green") diff --git a/skyplane/cli/cli_transfer.py b/skyplane/cli/cli_transfer.py index c1d6620fd..ebb38592f 100644 --- a/skyplane/cli/cli_transfer.py +++ b/skyplane/cli/cli_transfer.py @@ -9,7 +9,7 @@ 
from rich.progress import Progress, TextColumn, SpinnerColumn import skyplane -from skyplane.api.config import TransferConfig, AWSConfig, GCPConfig, AzureConfig, IBMCloudConfig +from skyplane.api.config import TransferConfig, AWSConfig, GCPConfig, AzureConfig, IBMCloudConfig, SCPConfig from skyplane.api.transfer_job import CopyJob, SyncJob, TransferJob from skyplane.cli.impl.cp_replicate_fallback import ( replicate_onprem_cp_cmd, @@ -52,7 +52,9 @@ class SkyplaneCLI: def __init__(self, src_region_tag: str, dst_region_tag: str, args: Dict[str, Any], skyplane_config: Optional[SkyplaneConfig] = None): self.src_region_tag, self.dst_region_tag = src_region_tag, dst_region_tag self.args = args - self.aws_config, self.azure_config, self.gcp_config, self.ibmcloud_config = self.to_api_config(skyplane_config or cloud_config) + self.aws_config, self.azure_config, self.gcp_config, self.ibmcloud_config, self.scp_config = self.to_api_config( + skyplane_config or cloud_config + ) # update config # TODO: set remaining config params @@ -68,6 +70,7 @@ def __init__(self, src_region_tag: str, dst_region_tag: str, args: Dict[str, Any gcp_config=self.gcp_config, transfer_config=self.transfer_config, ibmcloud_config=self.ibmcloud_config, + scp_config=self.scp_config, ) typer.secho(f"Using Skyplane version {skyplane.__version__}", fg="bright_black") @@ -84,13 +87,19 @@ def to_api_config(self, config: SkyplaneConfig): ibmcloud_resource_group_id=config.ibmcloud_resource_group_id, ibmcloud_enabled=config.ibmcloud_enabled, ) + scp_config = SCPConfig( + scp_access_key=config.scp_access_key, + scp_secret_key=config.scp_secret_key, + scp_project_id=config.scp_project_id, + scp_enabled=config.scp_enabled, + ) if not config.azure_resource_group or not config.azure_umi_name: typer.secho( "Azure resource group and UMI name not configured correctly. 
Please reinit Azure with `skyplane init --reinit-azure`.", fg=typer.colors.RED, err=True, ) - return aws_config, None, gcp_config, ibmcloud_config + return aws_config, None, gcp_config, ibmcloud_config, scp_config azure_config = AzureConfig( config.azure_subscription_id, config.azure_resource_group, @@ -99,7 +108,7 @@ def to_api_config(self, config: SkyplaneConfig): config.azure_client_id, config.azure_enabled, ) - return aws_config, azure_config, gcp_config, ibmcloud_config + return aws_config, azure_config, gcp_config, ibmcloud_config, scp_config def make_transfer_config(self, config: SkyplaneConfig) -> TransferConfig: intraregion = self.src_region_tag == self.dst_region_tag @@ -118,6 +127,7 @@ def make_transfer_config(self, config: SkyplaneConfig) -> TransferConfig: gcp_instance_class=config.get_flag("gcp_instance_class"), ibmcloud_instance_class=config.get_flag("ibmcloud_instance_class"), gcp_use_premium_network=config.get_flag("gcp_use_premium_network"), + scp_instance_class=config.get_flag("scp_instance_class"), multipart_enabled=config.get_flag("multipart_enabled"), multipart_threshold_mb=config.get_flag("multipart_min_threshold_mb"), multipart_chunk_size_mb=config.get_flag("multipart_chunk_size_mb"), diff --git a/skyplane/cli/impl/common.py b/skyplane/cli/impl/common.py index 6280a4fb8..366241da8 100644 --- a/skyplane/cli/impl/common.py +++ b/skyplane/cli/impl/common.py @@ -56,6 +56,8 @@ def run(): query_jobs.append(catch_error(lambda: compute.AzureCloudProvider().get_matching_instances())) if compute.GCPAuthentication().enabled(): query_jobs.append(catch_error(lambda: compute.GCPCloudProvider().get_matching_instances())) + if compute.SCPAuthentication().enabled(): + query_jobs.append(catch_error(lambda: compute.SCPCloudProvider().get_matching_instances())) # query in parallel for instance_list in do_parallel( lambda f: f(), query_jobs, n=-1, return_args=False, spinner=True, desc="Querying clouds for instances" diff --git a/skyplane/cli/impl/cp_replicate_fallback.py b/skyplane/cli/impl/cp_replicate_fallback.py index eb7025b5b..00c648f12 100644 --- a/skyplane/cli/impl/cp_replicate_fallback.py +++ b/skyplane/cli/impl/cp_replicate_fallback.py @@ -37,6 +37,14 @@ def fallback_cmd_azure_sync(src_path, dest_path): return f"azcopy sync {src_path} {dest_path}" +def fallback_cmd_scp_cp(src_path: str, dest_path: str, recursive: bool) -> str: + return f"cp {src_path} {dest_path}" if not recursive else f"cp -r {src_path} {dest_path}" + + +def fallback_cmd_scp_sync(src_path, dest_path): + return f"rsync -av {src_path} {dest_path}" + + def replicate_onprem_cp_cmd(src, dst, recursive=True) -> Optional[str]: provider_src, _, _ = parse_path(src) provider_dst, _, _ = parse_path(dst) @@ -92,6 +100,9 @@ def replicate_small_cp_cmd(src, dst, recursive=True) -> Optional[str]: # azure -> azure elif provider_src == "azure" and provider_dst == "azure": return fallback_cmd_azure_cp(src, dst, recursive) + # scp -> scp + elif provider_src == "scp" and provider_dst == "scp": + return fallback_cmd_scp_cp(src, dst, recursive) # unsupported fallback else: return None @@ -110,6 +121,9 @@ def replicate_small_sync_cmd(src, dst) -> Optional[str]: # azure -> azure elif provider_src == "azure" and provider_dst == "azure": return fallback_cmd_azure_sync(src, dst) + # scp -> scp + elif provider_src == "scp" and provider_dst == "scp": + return fallback_cmd_scp_sync(src, dst) # unsupported fallback else: return None diff --git a/skyplane/compute/__init__.py b/skyplane/compute/__init__.py index 729207a65..bb65e9bf8 100644 --- 
a/skyplane/compute/__init__.py +++ b/skyplane/compute/__init__.py @@ -12,6 +12,9 @@ from skyplane.compute.ibmcloud.ibmcloud_auth import IBMCloudAuthentication from skyplane.compute.ibmcloud.ibmcloud_provider import IBMCloudProvider from skyplane.compute.server import Server, ServerState +from skyplane.compute.scp.scp_auth import SCPAuthentication +from skyplane.compute.scp.scp_cloud_provider import SCPCloudProvider +from skyplane.compute.scp.scp_server import SCPServer __all__ = [ "CloudProvider", @@ -29,4 +32,7 @@ "IBMCloudAuthentication", "IBMCloudProvider", "GCPServer", + "SCPAuthentication", + "SCPCloudProvider", + "SCPServer", ] diff --git a/skyplane/compute/cloud_provider.py b/skyplane/compute/cloud_provider.py index 56497e6ca..96ebe2f3d 100644 --- a/skyplane/compute/cloud_provider.py +++ b/skyplane/compute/cloud_provider.py @@ -46,6 +46,12 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): from skyplane.compute.gcp.gcp_cloud_provider import GCPCloudProvider return GCPCloudProvider.get_transfer_cost(f"gcp:{_}", dst_key, premium_tier) + elif src_provider == "scp": + from skyplane.compute.scp.scp_cloud_provider import SCPCloudProvider + + return SCPCloudProvider.get_transfer_cost(src_key, dst_key, premium_tier) + elif src_provider == "test": + return 0 else: raise ValueError(f"Unknown provider {src_provider}") diff --git a/skyplane/compute/gcp/gcp_pricing.py b/skyplane/compute/gcp/gcp_pricing.py index 120eef4b5..031e41c3d 100644 --- a/skyplane/compute/gcp/gcp_pricing.py +++ b/skyplane/compute/gcp/gcp_pricing.py @@ -30,7 +30,7 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): return 0.15 else: return 0.08 - elif dst_provider in ["aws", "azure", "cloudflare"] and premium_tier: + elif dst_provider in ["aws", "azure", "cloudflare", "scp"] and premium_tier: is_dst_australia = (dst == "ap-southeast-2") if dst_provider == "aws" else (dst.startswith("australia")) # singapore or tokyo or osaka if src_continent == "asia" and (src_region == "southeast2" or src_region == "northeast1" or src_region == "northeast2"): @@ -43,7 +43,7 @@ def get_transfer_cost(src_key, dst_key, premium_tier=True): return 0.19 if is_dst_australia else 0.147 else: return 0.19 if is_dst_australia else 0.12 - elif dst_provider in ["aws", "azure", "cloudflare"] and not premium_tier: + elif dst_provider in ["aws", "azure", "cloudflare", "scp"] and not premium_tier: if src_continent == "us" or src_continent == "europe" or src_continent == "northamerica": return 0.085 elif src_continent == "southamerica" or src_continent == "australia": diff --git a/skyplane/compute/scp/scp_auth.py b/skyplane/compute/scp/scp_auth.py new file mode 100644 index 000000000..c6c96d244 --- /dev/null +++ b/skyplane/compute/scp/scp_auth.py @@ -0,0 +1,130 @@ +# (C) Copyright Samsung SDS. 2023 + +# + +# Licensed under the Apache License, Version 2.0 (the "License"); + +# you may not use this file except in compliance with the License. + +# You may obtain a copy of the License at + +# + +# http://www.apache.org/licenses/LICENSE-2.0 + +# + +# Unless required by applicable law or agreed to in writing, software + +# distributed under the License is distributed on an "AS IS" BASIS, + +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +# See the License for the specific language governing permissions and + +# limitations under the License. 
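+
+# scp_auth.py: loads SCP (Samsung Cloud Platform) credentials from the
+# ~/.scp/scp_credential file and writes the region list and vCPU-quota
+# config files (scp_config_path, scp_quota_path) consumed by Skyplane.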
+import json
+from typing import List, Optional
+from skyplane.config_paths import config_path, scp_config_path, scp_quota_path
+
+from skyplane.compute.server import key_root
+from skyplane.config import SkyplaneConfig
+from skyplane.compute.scp.scp_utils import SCPClient
+from skyplane.compute.scp.scp_utils import CREDENTIALS_PATH
+
+
+class SCPAuthentication:
+    def __init__(self, config: Optional[SkyplaneConfig] = None):
+        if config is not None:
+            self.config = config
+        else:
+            self.config = SkyplaneConfig.load_config(config_path)
+
+        self._scp_access_key = self.config.scp_access_key
+        self._scp_secret_key = self.config.scp_secret_key
+        self._scp_project_id = self.config.scp_project_id
+
+    def _get_scp_vm_quota(self, regions):
+        """SCP VM quota: there is no restriction on vCPUs, but if the quota is insufficient, SCP returns an error."""
+        result_list = []
+        for region_name in regions:
+            region_dict = {"on_demand_standard_vcpus": 500, "service_zone_name": region_name}
+            result_list.append(region_dict)
+
+        return result_list
+
+    def save_region_config(self, config: SkyplaneConfig):
+        if not self.config.scp_enabled:
+            self.clear_region_config()
+            return
+        region_list = self.get_region_config()
+
+        with open(scp_config_path, "w") as f:
+            for region in region_list:
+                f.write(region + "\n")
+
+        quota_infos = self._get_scp_vm_quota(region_list)
+        with scp_quota_path.open("w") as f:
+            f.write(json.dumps(quota_infos, indent=2))
+
+    @staticmethod
+    def get_zones() -> List[dict]:
+        scp_client = SCPClient()
+        url = f"/project/v3/projects/{scp_client.project_id}/zones"
+
+        response = scp_client._get(url)
+        return response
+
+    @staticmethod
+    def get_region_config() -> List[str]:
+        zones = SCPAuthentication.get_zones()
+        service_zone_locations = [zone["serviceZoneName"] for zone in zones]
+        return service_zone_locations
+
+    @staticmethod
+    def clear_region_config():
+        with scp_config_path.open("w") as f:
+            f.write("")
+        with scp_quota_path.open("w") as f:
+            f.write("")
+
+    def enabled(self):
+        return self.config.scp_enabled
+
+    def get_zone_location(self, zone_name):
+        zones = self.get_zones()
+        for zone in zones:
+            if zone["serviceZoneName"] == zone_name:
+                return zone["serviceZoneLocation"]
+        return None
+
+    @property
+    def credential_path(self):
+        credential_path = key_root / "scp" / "scp_credential"
+        return credential_path
+
+    @property
+    def scp_credential_path(self):
+        return CREDENTIALS_PATH
+
+    @property
+    def scp_access_key(self):
+        if self._scp_access_key is None:
+            self._scp_access_key, self._scp_secret_key, self._scp_project_id = self.infer_credentials()
+        return self._scp_access_key
+
+    @property
+    def scp_secret_key(self):
+        if self._scp_secret_key is None:
+            self._scp_access_key, self._scp_secret_key, self._scp_project_id = self.infer_credentials()
+        return self._scp_secret_key
+
+    @property
+    def scp_project_id(self):
+        if self._scp_project_id is None:
+            self._scp_access_key, self._scp_secret_key, self._scp_project_id = self.infer_credentials()
+        return self._scp_project_id
+
+    def infer_credentials(self):
+        scp_client = SCPClient()
+        return scp_client.access_key, scp_client.secret_key, scp_client.project_id
diff --git a/skyplane/compute/scp/scp_cloud_provider.py b/skyplane/compute/scp/scp_cloud_provider.py
new file mode 100644
index 000000000..ddeba4157
--- /dev/null
+++ b/skyplane/compute/scp/scp_cloud_provider.py
@@ -0,0 +1,285 @@
+# (C) Copyright Samsung SDS.
2023 + +# + +# Licensed under the Apache License, Version 2.0 (the "License"); + +# you may not use this file except in compliance with the License. + +# You may obtain a copy of the License at + +# + +# http://www.apache.org/licenses/LICENSE-2.0 + +# + +# Unless required by applicable law or agreed to in writing, software + +# distributed under the License is distributed on an "AS IS" BASIS, + +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +# See the License for the specific language governing permissions and + +# limitations under the License. +import os +from typing import List, Optional +import uuid +from pathlib import Path + +from skyplane.compute.key_utils import generate_keypair + +from skyplane import exceptions +from skyplane.compute.scp.scp_auth import SCPAuthentication +from skyplane.compute.scp.scp_server import SCPServer +from skyplane.compute.scp.scp_network import SCPNetwork +from skyplane.compute.cloud_provider import CloudProvider +from skyplane.compute.server import key_root +from skyplane.utils import logger +from skyplane.utils.fn import do_parallel + + +class SCPCloudProvider(CloudProvider): + def __init__( + self, + key_prefix: str = "skyplane", + key_root=key_root / "scp", + auth: Optional[SCPAuthentication] = None, + network: Optional[SCPNetwork] = None, + ): + super().__init__() + self.key_name = key_prefix + self.auth = auth if auth else SCPAuthentication() + self.network = network if network else SCPNetwork(self.auth) + key_root.mkdir(parents=True, exist_ok=True) + self.private_key_path = key_root / "scp_key" + self.public_key_path = key_root / "scp_key.pub" + + @property + def name(self): + return "scp" + + @staticmethod + def region_list() -> List[str]: + return SCPAuthentication.get_region_config() + + @classmethod + def get_transfer_cost(cls, src_key, dst_key, premium_tier=True): + assert src_key.startswith("scp:") + dst_provider, dst_region = dst_key.split(":") + + return 0.077 + + def create_ssh_key(self): + public_key_path = Path(self.public_key_path) + private_key_path = Path(self.private_key_path) + if not private_key_path.exists(): + private_key_path.parent.mkdir(parents=True, exist_ok=True) + generate_keypair(public_key_path, private_key_path) + + def get_init_script(self): + cmd_st = "mkdir -p ~/.ssh/; touch ~/.ssh/authorized_keys;" + with open(os.path.expanduser(self.public_key_path)) as f: + pub_key = f.read() + cmd = "echo '{}' &>>~/.ssh/authorized_keys;".format(pub_key) + cmd_ed = "chmod 644 ~/.ssh/authorized_keys; chmod 700 ~/.ssh/; " + cmd_ed += "mkdir -p ~/.scp; " + + return cmd_st + cmd + cmd_ed + + def get_instance_list(self, region) -> List[SCPServer]: + try: + service_zone_id = self.network.get_service_zone_id(region) + servergroups = self.network.list_server_group() + servergroups = [sg for sg in servergroups if sg["serviceZoneId"] == service_zone_id] + + instances = [] + for sg in servergroups: + instances += self.network.list_server_group_instance(sg["serverGroupId"]) + + instancelist = [] + validstate = ["RUNNING", "STOPPED", "STOPPING", "STARTING"] + for instance in instances: + if instance["virtualServerState"] in validstate: + instancelist.append( + {"virtualServerId": instance["virtualServerId"], "virtualServerName": instance["virtualServerName"]} + ) + + return [ + SCPServer(f"scp:{region}", virtualServerName=i["virtualServerName"], virtualServerId=i["virtualServerId"]) + for i in instancelist + ] + except Exception as e: + raise exceptions.SkyplaneException(f"Failed to get instance list: {e}") + + 
def setup_global(self, **kwargs): + pass + + def setup_region(self, region: str): + self.network.make_vpc(region) + + def provision_instance( + self, + region: str, + instance_class: str, + disk_size: int = 100, + use_spot_instances: bool = False, + name: Optional[str] = None, + tags={"skyplane": "true"}, + instance_os: str = "ubuntu", + ) -> SCPServer: + assert not region.startswith("scp:"), "Region should be SCP region" + if name is None: + name = f"skyplane-scp-{uuid.uuid4().hex[:8]}" + url = f"/virtual-server/v3/virtual-servers" + + service_zone_id = self.network.get_service_zone_id(region) + + vpc_name = self.network.define_vpc_name(region) + vpc_id = self.network.list_vpcs(service_zone_id, vpc_name)[0]["vpcId"] + init_script = self.get_init_script() + subnet_id = self.network.list_subnets(vpc_id, None)[0]["subnetId"] + security_group_id = self.network.list_security_groups(vpc_id, None)[0]["securityGroupId"] + servergroup = tags.get("servergroup", "default") + image_id = "IMAGE-mX_5UKOJqriGnWZ2GXEbEg" # Ubuntu 22.04 + + req_body = { + "availabilityZoneName": "AZ1" if region == "KOREA-WEST-MAZ-SCP-B001" else None, + "blockStorage": { + "blockStorageName": "Skyplane-Block-Storage", + "diskSize": disk_size, + }, + "imageId": image_id, + "initialScript": { + "encodingType": "plain", + "initialScriptShell": "bash", + "initialScriptType": "text", + "initialScriptContent": init_script, + }, + "nic": { + "natEnabled": "true", + # "publicIpId" : "PUBLIC_IP-xxxxxx", + "subnetId": subnet_id, + }, + "osAdmin": {"osUserId": "root", "osUserPassword": "test123$"}, + "securityGroupIds": [security_group_id], + "serverGroupId": servergroup, + "serverType": instance_class, + "serviceZoneId": service_zone_id, + "tags": [ + {"tagKey": "skyplane", "tagValue": "true"}, + {"tagKey": "skyplaneclientid", "tagValue": tags.get("skyplaneclientid", "unknown")}, + ], + "virtualServerName": name, + } + try: + # self.network.scp_client.random_wait() + response = self.network.scp_client._post(url, req_body) + # print(response) + virtualServerId = response["resourceId"] + + # wait for server editing + create_completion_condition = lambda: self.network.get_vs_details(virtualServerId)["virtualServerState"] == "EDITING" + self.network.scp_client.wait_for_completion( + f"[{region}:{name}:{instance_class}] Creating SCP Skyplane Virtual Server", create_completion_condition + ) + + # wait for virtual server ip + create_completion_condition = lambda: self.network.get_vs_details(virtualServerId).get("ip") is not None + self.network.scp_client.wait_for_completion( + f"[{region}:{name}:{instance_class}] Creating SCP Skyplane Virtual Server ip", create_completion_condition + ) + + # Add firewall rule + self.network.add_firewall_22_rule(region, servergroup, virtualServerId) + + # wait for server running + create_completion_condition = lambda: self.network.get_vs_details(virtualServerId)["virtualServerState"] == "RUNNING" + self.network.scp_client.wait_for_completion( + f"[{region}:{name}:{instance_class}] Running SCP Skyplane Virtual Server", create_completion_condition + ) + + server = SCPServer(f"scp:{region}", virtualServerName=name, virtualServerId=virtualServerId, vpc_id=vpc_id) + + except KeyboardInterrupt: + logger.warning(f"keyboard interrupt. 
You may need to use the deprovision command to terminate instance {name}.")
+            # wait for the server to reach RUNNING before terminating it
+            create_completion_condition = lambda: self.network.get_vs_details(virtualServerId)["virtualServerState"] == "RUNNING"
+            self.network.scp_client.wait_for_completion(
+                f"[{region}:{name}:{instance_class}] Running SCP Skyplane Virtual Server", create_completion_condition
+            )
+            server = SCPServer(f"scp:{region}", virtualServerName=name, virtualServerId=virtualServerId, vpc_id=vpc_id)
+            server.terminate_instance_impl()
+            raise
+        except Exception as e:
+            # a 404 means no cluster could satisfy the requested scale
+            if "404" in str(e):
+                raise Exception(
+                    f"Failed to provision the instance due to insufficient cluster capacity for the requested scale: {e}. Check the SCP console for more details."
+                )
+            else:
+                raise Exception(f"Exception occurred while provisioning the instance: {e}. Check the SCP console for more details.")
+        return server
+
+    def add_firewall_rule_all(self, region, private_ips, vpcids):
+        self.network.add_firewall_rule_all(region, private_ips, vpcids)
+
+    def remove_gateway_rule(self, server: SCPServer):
+        ip = server.private_ip()
+        igw = self.network.get_igw(server.vpc_id)
+        firewall = self.network.get_firewallId(igw)
+
+        ruleIds = self.network.list_firewall_rules(firewall)
+
+        ruleDetails = [self.network.get_firewall_rule_details(firewall, id) for id in ruleIds]
+        ruleDetails = [detail for detail in ruleDetails if "ruleId" in detail]  # HACK: the API sometimes returns malformed rule details
+
+        rule_source = [detail["ruleId"] for detail in ruleDetails if ip in detail["sourceIpAddresses"]]
+        rule_dest = [detail["ruleId"] for detail in ruleDetails if ip in detail["destinationIpAddresses"]]
+        rule_ids = list(set(rule_source + rule_dest))
+        logger.fs.debug(f"remove_gateway_rule: rule_ids - {rule_ids}")
+        self.network.delete_firewall_rules(server.region(), firewall, rule_ids)
+
+    def remove_gateway_rule_region(self, region, private_ips, vpcids):
+        igw = self.network.get_igw(vpcids[0])
+        firewallId = self.network.get_firewallId(igw)
+
+        ruleIds = self.network.list_firewall_rules(firewallId)
+
+        # For each private IP, check if it appears in the source or destination IPs of ruleIds' details.
+ ruleDetails = [self.network.get_firewall_rule_details(firewallId, id) for id in ruleIds] + + rule_ids = list( + set( + detail["ruleId"] + for ip in private_ips + for detail in ruleDetails + if ip in detail["sourceIpAddresses"] or ip in detail["destinationIpAddresses"] + ) + ) + + self.network.delete_firewall_rules(region, firewallId, rule_ids) + + def remove_gateway_rule_all(self, instances): + # Get all regions + regions = list(set([instance.region() for instance in instances])) + + # Get all private IPs and VPC IDs for each region, then remove gateway rules for each region + for region in regions: + private_ips = [instance.private_ip() for instance in instances if instance.region() == region] + vpcids = [instance.vpc_id for instance in instances if instance.region() == region] + self.remove_gateway_rule_region(region, private_ips, vpcids) + + def teardown_global(self): + args = [] + for region in self.region_list(): + service_zone_id = self.network.get_service_zone_id(region) + vpc_name = self.network.define_vpc_name(region) + skyplane_vpc = self.network.find_vpc_id(service_zone_id, vpc_name) + if skyplane_vpc: + args.append((region, skyplane_vpc)) + + if len(args) > 0: + do_parallel(lambda x: self.network.delete_vpc(*x), args, desc="Removing VPCs", spinner=True, spinner_persist=True) diff --git a/skyplane/compute/scp/scp_network.py b/skyplane/compute/scp/scp_network.py new file mode 100644 index 000000000..b1a1f49ef --- /dev/null +++ b/skyplane/compute/scp/scp_network.py @@ -0,0 +1,655 @@ +# (C) Copyright Samsung SDS. 2023 + +# + +# Licensed under the Apache License, Version 2.0 (the "License"); + +# you may not use this file except in compliance with the License. + +# You may obtain a copy of the License at + +# + +# http://www.apache.org/licenses/LICENSE-2.0 + +# + +# Unless required by applicable law or agreed to in writing, software + +# distributed under the License is distributed on an "AS IS" BASIS, + +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +# See the License for the specific language governing permissions and + +# limitations under the License. 
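+
+# scp_network.py: wrapper over the SCP networking REST APIs (VPCs, subnets,
+# internet gateways, security groups, and firewall rules) used to prepare
+# each region before Skyplane gateway VMs are provisioned.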
+import time +from typing import List, Optional +import uuid + +from skyplane.compute.scp.scp_auth import SCPAuthentication +from skyplane.compute.scp.scp_utils import SCPClient +from skyplane.utils import logger + + +class SCPNetwork: + def __init__(self, auth: SCPAuthentication, vpc_name="skyplane", sg_name="skyplane"): + self.auth = auth + self.scp_client = SCPClient() + + def get_service_zone_id(self, region: str) -> str: + zones = SCPAuthentication.get_zones() + for zone in zones: + if zone["serviceZoneName"] == region: + return zone["serviceZoneId"] + raise ValueError(f"Region {region} not found in SCP") + + def get_service_zoneName(self, zoneId: str) -> str: + zones = SCPAuthentication.get_zones() + for zone in zones: + if zone["serviceZoneId"] == zoneId: + return zone["serviceZoneName"] + raise ValueError(f"ZoneId {zoneId} not found in SCP") + + def get_vpc_info(self, vpcId): + url = f"/vpc/v2/vpcs?vpcId={vpcId}" + response = self.scp_client._get(url) + return response[0]["serviceZoneId"], response[0]["vpcName"] + + def list_vpcs(self, service_zone_id, vpc_name): + url = f"/vpc/v2/vpcs?serviceZoneId={service_zone_id}" + + response = self.scp_client._get(url) + vpcs = [result for result in response if result["vpcName"] == vpc_name] + + if len(vpcs) == 0: + return None + else: + return vpcs + + def get_vpc_detail(self, vpcId): + url = f"/vpc/v2/vpcs/{vpcId}" + response = self.scp_client._getDetail(url) + return response + + def delete_vpc(self, region, vpcId): + service_zone_id, vpc_name = self.get_vpc_info(vpcId) + # delete subnets + subnets = self.list_subnets(vpcId, subnetId=None) + for subnet in subnets: + self.delete_subnet(region, subnet["subnetId"], vpcId) + + # delete igw - firewall first + igws = self.list_igws(vpcId) + for igw in igws: + firewallId = self.get_firewallId(igw["internetGatewayId"]) + self.delete_firewall_rules(region, firewallId, None) + self.delete_igw(region, igw["internetGatewayId"], vpcId) + + # delete security groups + sgs = self.list_security_groups(vpcId, securityGroupName=None) + for sg in sgs: + self.delete_security_group(region, sg["securityGroupId"], sg["securityGroupName"]) + + # delete vpc + url = f"/vpc/v2/vpcs/{vpcId}" + try: + self.scp_client._delete(url) + delete_completion_condition = lambda: not self.list_vpcs(service_zone_id, vpc_name) + self.scp_client.wait_for_completion(f"[{region}] Deleting SCP Skyplane VPC", delete_completion_condition) + + except Exception as e: + logger.error(e) + raise e + # return response + + def create_vpc(self, region, service_zone_id, vpc_name): + url = "/vpc/v3/vpcs" + req_body = {"serviceZoneId": service_zone_id, "vpcName": vpc_name, "vpcDescription": "Skyplane VPC"} + vpcId = None + try: + response = self.scp_client._post(url, req_body) + vpcId = response["resourceId"] + create_completion_condition = lambda: self.get_vpc_detail(vpcId)["vpcState"] == "ACTIVE" + self.scp_client.wait_for_completion(f"[{region}] Creating SCP Skyplane VPC", create_completion_condition) + except Exception as e: + logger.error(e) + raise e + return vpcId + + def list_igws(self, vpcId: str): + url = f"/internet-gateway/v2/internet-gateways" + response = self.scp_client._get(url) + igws = [result for result in response if result["vpcId"] == vpcId] + return igws + + def delete_igw(self, region, igwId: str, vpcId: str): + url = f"/internet-gateway/v2/internet-gateways/{igwId}" + try: + self.scp_client._delete(url) + delete_completion_condition = lambda: not self.list_igws(vpcId) + self.scp_client.wait_for_completion(f"[{region}] 
Deleting SCP Skyplane IGW", delete_completion_condition) + except Exception as e: + logger.error(e) + raise e + + def create_igw(self, region, service_zone_id, vpcId: str): + url = "/internet-gateway/v2/internet-gateways" + req_body = { + "firewallEnabled": True, + "serviceZoneId": service_zone_id, + "vpcId": vpcId, + "InternetGatewayDescription": "Default Internet Gateway for Skyplane VPC", + } + try: + response = self.scp_client._post(url, req_body) + igwId = response["resourceId"] + return igwId + except Exception as e: + logger.error(e) + raise e + + def check_created_igw(self, region, igwId: str, vpcId: str): + create_completion_condition = lambda: self.list_igws(vpcId) and self.list_igws(vpcId)[0].get("internetGatewayState") == "ATTACHED" + self.scp_client.wait_for_completion(f"[{region}] Creating SCP Skyplane Internet Gateway", create_completion_condition) + + def list_subnets(self, vpcId: str, subnetId: Optional[str] = None): + if subnetId is None: + url = f"/subnet/v2/subnets?vpcId={vpcId}" + else: + url = f"/subnet/v2/subnets?subnetId={subnetId}" + + response = self.scp_client._get(url) + # subnet = [result for result in response if result['vpcId'] == vpcId] + return response + + def create_subnet(self, region, vpc_name, vpcId: str): + url = "/subnet/v2/subnets" + req_body = { + "subnetCidrBlock": "192.168.0.0/24", + "subnetName": vpc_name + "Sub", + "subnetType": "PUBLIC", + "vpcId": vpcId, + "subnetDescription": "Default Subnet for Skyplane VPC", + } + try: + response = self.scp_client._post(url, req_body) + subnetId = response["resourceId"] + return subnetId + + except Exception as e: + logger.error(e) + raise e + + def check_created_subnet(self, region, subnetId: str, vpcId: str): + create_completion_condition = lambda: self.list_subnets(vpcId, subnetId)[0]["subnetState"] == "ACTIVE" + self.scp_client.wait_for_completion(f"[{region}] Creating SCP Skyplane subnet", create_completion_condition) + + def delete_subnet(self, region, subnetId: str, vpcId: str): + url = f"/subnet/v2/subnets/{subnetId}" + try: + response = self.scp_client._delete(url) + delete_completion_condition = lambda: not self.list_subnets(vpcId, subnetId) + self.scp_client.wait_for_completion(f"[{region}] Deleting SCP Skyplane subnet", delete_completion_condition) + except Exception as e: + logger.error(e) + raise e + return response + + def delete_security_group(self, region, securityGroupId: str, securityGroupName: str): + url = f"/security-group/v2/security-groups/{securityGroupId}" + try: + self.scp_client._delete(url) + delete_completion_condition = lambda: not self.list_security_groups(None, securityGroupName) + self.scp_client.wait_for_completion(f"[{region}] Deleting SCP Skyplane Security Group", delete_completion_condition) + except Exception as e: + logger.error(e) + raise e + + def list_security_groups(self, vpcId, securityGroupName: Optional[str] = None): + if securityGroupName is None: + url = f"/security-group/v2/security-groups?vpcId={vpcId}" + else: + url = f"/security-group/v2/security-groups?securityGroupName={securityGroupName}" + response = self.scp_client._get(url) + return response + + def create_security_group(self, region, service_zone_id, vpcId: str): + url = "/security-group/v3/security-groups" + req_body = { + "loggable": False, + "securityGroupName": "SkyplaneSecuGroup", + "serviceZoneId": service_zone_id, + "vpcId": vpcId, + "securityGroupDescription": "Default Security Group for Skyplane VPC", + } + try: + response = self.scp_client._post(url, req_body) + securityGroupId = 
response["resourceId"] + create_completion_condition = ( + lambda: self.list_security_groups(vpcId, req_body["securityGroupName"])[0]["securityGroupState"] == "ACTIVE" + ) + self.scp_client.wait_for_completion(f"[{region}] Creating SCP Skyplane Security Group", create_completion_condition) + return securityGroupId + + except Exception as e: + logger.error(e) + raise e + + def find_vpc_id(self, service_zone_id, vpc_name): + vpc_content = self.list_vpcs(service_zone_id, vpc_name) + if vpc_content is None: + return None + + vpc_list = [item["vpcId"] for item in vpc_content] + return vpc_list[0] + + def find_valid_vpc(self, service_zone_id, vpc_name): + vpc_content = self.list_vpcs(service_zone_id, vpc_name) + if vpc_content is None: + return None + + vpc_list = [item["vpcId"] for item in vpc_content if item["vpcState"] == "ACTIVE"] + for vpc_id in vpc_list: + igw_list = self.list_igws(vpc_id) + igw_list = [igw for igw in igw_list if igw["internetGatewayState"] == "ATTACHED"] + subnet_list = self.list_subnets(vpc_id) + subnet_list = [subnet for subnet in subnet_list if subnet["subnetState"] == "ACTIVE" and subnet["subnetType"] == "PUBLIC"] + if len(igw_list) > 0 and len(subnet_list) > 0: + return vpc_id + + return None + + def define_vpc_name(self, region: str): + parts = region.split("-") + return "Skyplane" + parts[0][0:2] + parts[1] + parts[2] + + def make_vpc(self, region: str): + service_zone_id = self.get_service_zone_id(region) + + vpc_name = self.define_vpc_name(region) + + # find matching valid VPC + skyplane_vpc = self.find_valid_vpc(service_zone_id, vpc_name) + + try: + if skyplane_vpc is None: + # create VPC + skyplane_vpc = self.create_vpc(region, service_zone_id, vpc_name) + + # create Subnet + subnet = self.create_subnet(region, vpc_name, skyplane_vpc) + + # create Internet Gateway - firewall, routing table creation is done by SCP + igw = self.create_igw(region, service_zone_id, skyplane_vpc) + + # create security group + sg = self.create_security_group(region, service_zone_id, skyplane_vpc) + self.create_security_group_in_rule(region, sg) + self.create_security_group_out_rule(region, sg) + + # create firewall rules + self.check_created_igw(region, igw, skyplane_vpc) + self.check_created_subnet(region, subnet, skyplane_vpc) + + return skyplane_vpc + except Exception as e: + logger.error(e) + raise e + + def delete_security_group_rules(self, region, securityGroupId: str, ruleIds): + url = f"/security-group/v2/security-groups/{securityGroupId}/rules" + req_body = {"ruleDeletionType": "PARTIAL", "ruleIds": ruleIds} + try: + self.scp_client._delete(url, req_body) + delete_completion_condition = lambda: all(not self.list_security_group_rules(securityGroupId, rule) for rule in ruleIds) + self.scp_client.wait_for_completion(f"[{region}] Deleting SCP Skyplane Security Group Rule", delete_completion_condition) + except Exception as e: + logger.error(e) + raise e + + def list_security_group_rules(self, securityGroupId: str, ruleId: str): + url = f"/security-group/v2/security-groups/{securityGroupId}/rules" + response = self.scp_client._get(url) + if ruleId is not None: + result = [result for result in response if result["ruleId"] == ruleId] + else: # list all ruleIds + result = [rule["ruleId"] for rule in response] + return result + + def create_security_group_in_rule(self, region, securityGroupId: str): + url = f"/security-group/v2/security-groups/{securityGroupId}/rules" + req_body = { + "ruleDirection": "IN", + "services": [ + { + "serviceType": "TCP_ALL", + } + ], + 
"sourceIpAddresses": ["0.0.0.0/0"], + "ruleDescription": "Skyplane Security Group In rule(ssh)", + } + try: + response = self.scp_client._post(url, req_body) + ruleId = response["resourceId"] + create_completion_condition = lambda: self.list_security_group_rules(securityGroupId, ruleId)[0]["ruleState"] == "ACTIVE" + self.scp_client.wait_for_completion(f"[{region}] Creating SCP Skyplane Security Group In Rule", create_completion_condition) + return ruleId + + except Exception as e: + logger.error(e) + raise e + + def create_security_group_out_rule(self, region, securityGroupId: str): + url = f"/security-group/v2/security-groups/{securityGroupId}/rules" + req_body = { + "ruleDirection": "OUT", + "services": [ + { + "serviceType": "TCP_ALL", + } + ], + "destinationIpAddresses": ["0.0.0.0/0"], + "ruleDescription": "Skyplane Security Group Out rule", + } + try: + response = self.scp_client._post(url, req_body) + ruleId = response["resourceId"] + create_completion_condition = lambda: self.list_security_group_rules(securityGroupId, ruleId)[0]["ruleState"] == "ACTIVE" + self.scp_client.wait_for_completion(f"[{region}] Creating SCP Skyplane Security Group Out Rule", create_completion_condition) + return ruleId + except Exception as e: + logger.error(e) + raise e + + def delete_firewall_rules(self, region, firewallId: str, ruleIds, retry_limit=10): + url = f"/firewall/v2/firewalls/{firewallId}/rules" + + if ruleIds is None: + ruleIds = self.list_firewall_rules(firewallId, None) + + req_body = {"ruleDeletionType": "PARTIAL", "ruleIds": ruleIds} + retries = 0 + while retries < retry_limit: + try: + # print(f"try delete_firewall_rules: req_body - [{region}] retries : {retries} \n {req_body}") + response = self.scp_client._delete(url, req_body) + # delete_completion_condition = lambda: all(not self.list_firewall_rules(firewallId, rule) for rule in ruleIds) + # self.scp_client.wait_for_completion(f"[{region}] Deleting SCP Skyplane Firewall Rules", delete_completion_condition) + return response + except Exception as e: + # error code 500 retry + if "500" in e.args[0]: + retries += 1 + # print(f"Error deleting firewall rules(): {e}") + logger.fs.debug(f"Error deleting firewall rules(): {e}") + if retries == retry_limit: + logger.error(e) + raise e + time.sleep(5) + else: + logger.error(e) + raise e + + def get_firewallId(self, igwId: str): + url = f"/firewall/v2/firewalls?objectId={igwId}" + result = self.scp_client._get(url)[0]["firewallId"] + return result + + def get_firewall_rule_details(self, firewallId: str, ruleId: str): + url = f"/firewall/v2/firewalls/{firewallId}/rules/{ruleId}" + response = self.scp_client._getDetail(url) + return response + + def list_firewall_rules(self, firewallId: str, ruleId=None): + url = f"/firewall/v2/firewalls/{firewallId}/rules?size=1000" + response = self.scp_client._get(url) + if ruleId is None: + result = [rule["ruleId"] for rule in response] + else: + result = [result for result in response if result["ruleId"] == ruleId] + return result + + def add_firewall_22_rule(self, region, servergroup, virtualServerId, retry_limit=5): + # check rule exist + def check_rule_exist(server_id, firewall_id): + ip = self.get_vs_details(server_id)["ip"] + rules = self.list_firewall_rules(firewall_id, None) + for rule in rules: + rule_details = self.get_firewall_rule_details(firewall_id, rule) + if rule_details["ruleDirection"] == "IN" and rule_details["ruleDescription"] == "Skyplane Firewall ssh rule": + if ip in rule_details["destinationIpAddresses"]: + # print(f"Pass - 
{region}-{servergroup} : Already exists firewall rule") + return True + + servers = self.list_server_group_instance(servergroup) + try: + private_ips = [self.get_vs_details(server["virtualServerId"])["ip"] for server in servers] + except Exception as e: + # print(f"Pass - {region}-{servergroup} : Waiting for all servers to be Editing status") + return e + + vpcId = self.get_vs_details(virtualServerId)["vpcId"] + igwId = self.list_igws(vpcId)[0]["internetGatewayId"] + firewallId = self.get_firewallId(igwId) + + url = f"/firewall/v2/firewalls/{firewallId}/rules" + req_body = { + "sourceIpAddresses": ["0.0.0.0/0"], + "destinationIpAddresses": private_ips, + "services": [{"serviceType": "TCP", "serviceValue": "22"}], + "ruleDirection": "IN", + "ruleAction": "ALLOW", + "isRuleEnabled": True, + "ruleLocationType": "LAST", + "ruleDescription": "Skyplane Firewall ssh rule", + } + retries = 0 + while retries < retry_limit: + try: + if not check_rule_exist(virtualServerId, firewallId): + response = self.scp_client.post(url, req_body) + ruleId = response["resourceId"] + create_completion_condition = lambda: self.get_firewall_rule_details(firewallId, ruleId)["ruleState"] == "ACTIVE" + self.scp_client.wait_for_completion( + f"[{region}:{private_ips}] Creating SCP Skyplane Firewall 22 Rule", create_completion_condition + ) + return ruleId + else: + return None + except Exception as e: + if "500" in e.args[0]: + # print(f"Pass - Error creating firewall in rule(): {e}") + logger.fs.debug(f"Pass - Error creating firewall in rule(): {e}") + return None + else: + retries += 1 + # print(f"Error creating firewall in rule(): {e}") + logger.fs.debug(f"Error creating firewall in rule(): {e}") + if retries == retry_limit: + logger.fs.error(e) + raise e + time.sleep(1) + + def add_firewall_rule(self, region: str, virtualServerIds: Optional[List[str]] = None): + for virtualServerId in virtualServerIds: + try: + vpcId = self.get_vs_details(virtualServerId)["vpcId"] + ip = self.get_vs_details(virtualServerId)["ip"] + + igwId = self.list_igws(vpcId)[0]["internetGatewayId"] + firewallId = self.get_firewallId(igwId) + + self.create_firewall_in_rule(region, firewallId, ip) + self.create_firewall_out_rule(region, firewallId, ip) + except Exception as e: + logger.error(e) + raise e + + def add_firewall_rule_all(self, region, private_ips, vpcids): + vpcId = vpcids[0] + try: + # igwId = self.list_igws(vpcId)[0]['internetGatewayId'] + igwId = self.get_igw(vpcId) + firewallId = self.get_firewallId(igwId) + + self.create_firewall_in_rule(region, firewallId, private_ips) + self.create_firewall_out_rule(region, firewallId, private_ips) + except Exception as e: + logger.error(e) + raise e + + def get_igw(self, vpc_Id: str): + return self.list_igws(vpc_Id)[0]["internetGatewayId"] + + def create_firewall_in_rule(self, region, firewallId, internalIp, retry_limit=60): + url = f"/firewall/v2/firewalls/{firewallId}/rules" + req_body = { + "sourceIpAddresses": ["0.0.0.0/0"], + "destinationIpAddresses": internalIp, + "services": [{"serviceType": "TCP_ALL"}], + "ruleDirection": "IN", + "ruleAction": "ALLOW", + "isRuleEnabled": True, + "ruleLocationType": "LAST", + "ruleDescription": "Skyplane Firewall In rule", + } + retries = 0 + while retries < retry_limit: + try: + response = self.scp_client._post(url, req_body) + # print(response) + # time.sleep(3) + ruleId = response["resourceId"] + create_completion_condition = lambda: self.get_firewall_rule_details(firewallId, ruleId)["ruleState"] == "ACTIVE" + 
self.scp_client.wait_for_completion(f"[{region}] Creating SCP Skyplane Firewall In Rule", create_completion_condition) + return ruleId + except Exception as e: + retries += 1 + print(f"Error creating firewall in rule(): {e}") + if retries == retry_limit: + logger.error(e) + raise e + time.sleep(2) + + def create_firewall_out_rule(self, region, firewallId, internalIp, retry_limit=60): + url = f"/firewall/v2/firewalls/{firewallId}/rules" + req_body = { + "sourceIpAddresses": internalIp, + "destinationIpAddresses": ["0.0.0.0/0"], + "services": [{"serviceType": "TCP_ALL"}], + "ruleDirection": "OUT", + "ruleAction": "ALLOW", + "isRuleEnabled": True, + "ruleLocationType": "FIRST", + "ruleDescription": "Skyplane Firewall Out rule", + } + retries = 0 + while retries < retry_limit: + try: + response = self.scp_client._post(url, req_body) + ruleId = response["resourceId"] + create_completion_condition = lambda: self.get_firewall_rule_details(firewallId, ruleId)["ruleState"] == "ACTIVE" + self.scp_client.wait_for_completion(f"[{region}] Creating SCP Skyplane Firewall Out Rule", create_completion_condition) + return ruleId + except Exception as e: + retries += 1 + print(f"Error creating firewall out rule(): {e}") + if retries == retry_limit: + logger.error(e) + raise e + time.sleep(2) + + def list_instances(self, instanceId): + if instanceId is None: + url = f"/virtual-server/v2/virtual-servers" + else: + try: + instanceName = self.get_vs_details(instanceId)["virtualServerName"] + except Exception as e: + if "404" in e.args[0]: + return None + url = f"/virtual-server/v2/virtual-servers?virtualServerName={instanceName}" + + result = self.scp_client._get(url) + if len(result) == 0: + return None + else: + return result + + def start_instance(self, instanceId: str): + url = f"/virtual-server/v2/virtual-servers/{instanceId}/start" + try: + response = self.scp_client._post(url, req_body={}) + instanceId = response["resourceId"] + completion_condition = lambda: self.list_instances(instanceId)[0]["virtualServerState"] == "RUNNING" + self.scp_client.wait_for_completion(f" Starting SCP Skyplane Virtual Server", completion_condition) + except Exception as e: + logger.error(e) + raise e + + def stop_instance(self, instanceId: str): + url = f"/virtual-server/v2/virtual-servers/{instanceId}/stop" + try: + response = self.scp_client._post(url, req_body={}) + instanceId = response["resourceId"] + completion_condition = lambda: self.list_instances(instanceId)[0]["virtualServerState"] == "STOPPED" + self.scp_client.wait_for_completion(f" Stopping SCP Skyplane Virtual Server", completion_condition) + except Exception as e: + logger.error(e) + raise e + + def terminate_instance(self, instanceId: str): + url = f"/virtual-server/v2/virtual-servers/{instanceId}" + try: + response = self.scp_client._delete(url, req_body={}) + instanceId = response["resourceId"] + # completion_condition = lambda: self.list_instances(instanceId) is None + # self.scp_client.wait_for_completion(f" Terminating SCP Skyplane Virtual Server", completion_condition) + except Exception as e: + logger.error(e) + raise e + + def get_nic_details(self, virtualServerId): + url = f"/virtual-server/v2/virtual-servers/{virtualServerId}/nics" + response = self.scp_client._get(url) + return response + + def get_vs_details(self, virtualServerId): + url = f"/virtual-server/v3/virtual-servers/{virtualServerId}" + response = self.scp_client._getDetail(url) + return response + + def create_server_group(self, region): + url = f"/server-group/v2/server-groups" + 
serviceZoneId = self.get_service_zone_id(region) + req_body = { + "serverGroupName": f"skyplane-{uuid.uuid4().hex[:8]}", + "serviceZoneId": serviceZoneId, + "serviceFor": "VIRTUAL_SERVER", + "servicedGroupFor": "COMPUTE", + } + response = self.scp_client._post(url, req_body) + serverGroupId = response["serverGroupId"] + return serverGroupId + + def list_server_group(self): + url = f"/server-group/v2/server-groups" + response = self.scp_client._get(url) + return response + + def delete_server_group(self, serverGroupId): + url = f"/server-group/v2/server-groups/{serverGroupId}" + if len(self.list_server_group_instance(serverGroupId)) == 0: + response = self.scp_client._delete(url) + return response + + def detail_server_group(self, serverGroupId): + url = f"/server-group/v2/server-groups/{serverGroupId}" + response = self.scp_client._getDetail(url) + return response + + def list_server_group_instance(self, serverGroupId): + url = f"/virtual-server/v2/virtual-servers?serverGroupId={serverGroupId}" + response = self.scp_client._get(url) + return response diff --git a/skyplane/compute/scp/scp_server.py b/skyplane/compute/scp/scp_server.py new file mode 100644 index 000000000..98d8bfc0b --- /dev/null +++ b/skyplane/compute/scp/scp_server.py @@ -0,0 +1,210 @@ +# (C) Copyright Samsung SDS. 2023 + +# + +# Licensed under the Apache License, Version 2.0 (the "License"); + +# you may not use this file except in compliance with the License. + +# You may obtain a copy of the License at + +# + +# http://www.apache.org/licenses/LICENSE-2.0 + +# + +# Unless required by applicable law or agreed to in writing, software + +# distributed under the License is distributed on an "AS IS" BASIS, + +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +# See the License for the specific language governing permissions and + +# limitations under the License. 
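+"""SCP (Samsung Cloud Platform) virtual server.
+
+Adapts an SCP virtual server to Skyplane's Server interface (SSH, Docker, lifecycle)."""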
+import time
+import warnings
+from typing import Optional
+from pathlib import Path
+from skyplane.utils import logger
+
+from cryptography.utils import CryptographyDeprecationWarning
+
+with warnings.catch_warnings():
+    warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning)
+    import paramiko
+
+from skyplane import exceptions
+from skyplane.compute.server import Server, ServerState, key_root
+from skyplane.compute.scp.scp_network import SCPNetwork
+from skyplane.utils.fn import PathLike
+
+from skyplane.compute.scp.scp_auth import SCPAuthentication
+from skyplane.compute.scp.scp_utils import SCPClient
+
+
+class SCPServer(Server):
+    def __init__(
+        self,
+        region_tag: str,
+        virtualServerName: str,
+        # virtualServerId: Optional[str] = None,
+        key_root: PathLike = key_root / "scp",
+        virtualServerId: Optional[str] = None,
+        log_dir: Optional[PathLike] = None,
+        ssh_private_key=None,
+        vpc_id=None,
+    ):
+        super().__init__(region_tag, log_dir=log_dir)
+        self.auth = SCPAuthentication()
+        self.network = SCPNetwork(self.auth)
+        self.scp_client = SCPClient()
+
+        assert self.region_tag.split(":")[0] == "scp", f"Region name doesn't match pattern scp: {self.region_tag}"
+        self.scp_region = self.region_tag.split(":")[1]
+        self.virtualserver_name = virtualServerName
+        # self.virtualserver_name = self.region_tag.split(":")[2] if virtualServerName is None else virtualServerName
+        self.virtualserver_id = self.virtualserver_id() if virtualServerId is None else virtualServerId
+
+        self.vpc_id = self.vpc_id() if vpc_id is None else vpc_id
+
+        key_root = Path(key_root)
+        key_root.mkdir(parents=True, exist_ok=True)
+        if ssh_private_key is None:
+            self.ssh_private_key = key_root / "scp_key"
+        else:
+            self.ssh_private_key = ssh_private_key
+
+        self.internal_ip, self.external_ip = self._init_ips()
+
+    def uuid(self):
+        return f"{self.region_tag}:{self.virtualserver_name}"
+
+    def _init_ips(self):
+        internal_ip = self.network.get_vs_details(self.virtualserver_id)["ip"]
+        external_ip = None
+        nics = self.network.get_nic_details(self.virtualserver_id)
+        for nic in nics:
+            if nic["ip"] == internal_ip and nic["subnetType"] == "PUBLIC":
+                # print(nic['natIp'])
+                external_ip = nic["natIp"]
+                break
+        return internal_ip, external_ip
+
+    def public_ip(self) -> str:
+        return self.external_ip
+
+    def private_ip(self):
+        return self.internal_ip
+
+    def instance_class(self):
+        return self.network.get_vs_details(self.virtualserver_id)["serverType"]
+
+    def instance_state(self):
+        return ServerState.from_scp_state(self.network.get_vs_details(self.virtualserver_id)["virtualServerState"])
+
+    def region(self):
+        return self.scp_region
+
+    def instance_name(self):
+        return self.virtualserver_name
+
+    def virtualserver_id(self):
+        url = f"/virtual-server/v2/virtual-servers?virtualServerName={self.virtualserver_name}"
+        response = self.scp_client._get(url)
+        return response[0]["virtualServerId"]
+
+    def tags(self):
+        url = f"/tag/v2/resources/{self.virtualserver_id}/tags"
+        tags = self.scp_client._get(url)
+        return {tag["tagKey"]: tag["tagValue"] for tag in tags} if tags else {}
+
+    def network_tier(self):
+        return "PREMIUM"
+
+    def vpc_id(self):
+        url = f"/virtual-server/v3/virtual-servers/{self.virtualserver_id}"
+        response = self.scp_client._getDetail(url)
+        return response["vpcId"]
+
+    def terminate_instance_impl(self):
+        self.network.terminate_instance(self.virtualserver_id)
+        # pass
+
+    def get_sftp_client(self):
+        t = 
paramiko.Transport((self.public_ip(), 22)) + pkey = paramiko.RSAKey.from_private_key_file(str(self.ssh_private_key), password="test123$") + t.connect(username="root", pkey=pkey) + return paramiko.SFTPClient.from_transport(t) + + def get_ssh_client_impl(self): + """Return paramiko client that connects to this instance.""" + ssh_client = paramiko.SSHClient() + ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + try: + ssh_client.connect( + hostname=self.public_ip(), + username="root", + pkey=paramiko.RSAKey.from_private_key_file(str(self.ssh_private_key), password="test123$"), + look_for_keys=False, + banner_timeout=200, + ) + return ssh_client + except paramiko.AuthenticationException as e: + raise exceptions.BadConfigException( + f"Failed to connect to SCP server {self.uuid()}. Delete local SCP keys and retry: `rm -rf {key_root / 'scp'}`" + ) from e + + def open_ssh_tunnel_impl(self, remote_port): + import sshtunnel + + try: + tunnel = sshtunnel.SSHTunnelForwarder( + (self.public_ip(), 22), + ssh_username="root", + ssh_pkey=str(self.ssh_private_key), + ssh_private_key_password="test123$", + host_pkey_directories=[], + local_bind_address=("127.0.0.1", 0), + remote_bind_address=("127.0.0.1", remote_port), + ) + tunnel.start() + return tunnel + except Exception as e: + logger.error(f"Error opening SSH tunnel: {str(e)}") + return None + + def get_ssh_cmd(self) -> str: + return f"ssh -i {self.ssh_private_key} {'root'}@{self.public_ip()}" + + def install_docker(self): + # print("install_docker in scp_server.py") + try: + return super().install_docker() + except Exception as e: + import re + + pid_pattern = re.compile(r"process (\d+) ") + match_result = pid_pattern.search(str(e)) + if match_result: + pid = match_result.group(1) + self.run_command(f"sudo kill -9 {pid}; sudo rm /var/lib/dpkg/lock-frontend") + # print(f"Killed process with PID {pid} on {self.region_tag}:{self.instance_name()}, {self.public_ip()}") + logger.fs.debug(f"Killed process with PID {pid} on {self.region_tag}:{self.instance_name()}, {self.public_ip()}") + time.sleep(1) + else: + if "sudo dpkg --configure -a" in str(e): + self.run_command("sudo dpkg --configure -a") + # print(f"Ran 'sudo dpkg --configure -a' on {self.region_tag}:{self.instance_name()}, {self.public_ip()}") + logger.fs.debug(f"Ran 'sudo dpkg --configure -a' on {self.region_tag}:{self.instance_name()}, {self.public_ip()}") + time.sleep(1) + elif "metricbeat" in str(e): + # print(f"metricbeat error on {self.region_tag}:{self.instance_name()}, {self.public_ip()}") + logger.fs.debug(f"metricbeat on {self.region_tag}:{self.instance_name()}, {self.public_ip()}") + time.sleep(5) + raise RuntimeError(f"Failed to install Docker on {self.region_tag}, {self.public_ip()}: error {e}") + raise RuntimeError(f"Failed to install Docker on {self.region_tag}, {self.public_ip()}: error {e}") diff --git a/skyplane/compute/scp/scp_utils.py b/skyplane/compute/scp/scp_utils.py new file mode 100644 index 000000000..8702f6a8f --- /dev/null +++ b/skyplane/compute/scp/scp_utils.py @@ -0,0 +1,248 @@ +# (C) Copyright Samsung SDS. 2023 + +# + +# Licensed under the Apache License, Version 2.0 (the "License"); + +# you may not use this file except in compliance with the License. 
+ +# You may obtain a copy of the License at + +# + +# http://www.apache.org/licenses/LICENSE-2.0 + +# + +# Unless required by applicable law or agreed to in writing, software + +# distributed under the License is distributed on an "AS IS" BASIS, + +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +# See the License for the specific language governing permissions and + +# limitations under the License. +""" SCP Open-API utility functions """ + +import base64 +import datetime +from functools import wraps +import hashlib +import hmac +import os +import json +import random +import time +from urllib import parse +import requests +from skyplane.utils import logger + +CREDENTIALS_PATH = "~/.scp/scp_credential" +API_ENDPOINT = "https://openapi.samsungsdscloud.com" +TEMP_VM_JSON_PATH = "/tmp/json/tmp_vm_body.json" + + +class SCPClientError(Exception): + pass + + +class SCPOngoingRequestError(Exception): + pass + + +class SCPCreationFailError(Exception): + pass + + +def raise_scp_error(response: requests.Response) -> None: + """Raise SCPCloudError if appropriate.""" + status_code = response.status_code + if status_code == 200 or status_code == 202: + return + try: + resp_json = response.json() + message = resp_json["message"] + except (KeyError, json.decoder.JSONDecodeError): + # print(f"response: {response.content}") + raise SCPClientError(f"Unexpected error. Status code: {status_code}") + + raise SCPClientError(f"{status_code}: {message}") + + +def _retry(method, max_tries=60, backoff_s=1): + @wraps(method) + def method_with_retries(self, *args, **kwargs): + try_count = 0 + while try_count < max_tries: + try: + return method(self, *args, **kwargs) + except Exception as e: + # print(e.args[0]) + retry_codes = ["500", "412", "403"] # add 403 for object storage + # if any(code in e.args[0] for code in retry_codes): + if any(code in str(e) for code in retry_codes): + try_count += 1 + # console.print(f"[yellow] retries: {method.__name__} - {e}, try_count : {try_count}[/yellow]") + logger.fs.debug(f"retries: {method.__name__} - {e}, try_count : {try_count}") + if try_count < max_tries: + time.sleep(backoff_s) + else: + raise e + elif "Connection aborted" in str(e): + try_count += 1 + # with open(f"/skyplane/retry_nocode_error.txt", "w") as f: + # f.write(str(e)) + logger.fs.debug(f"retries: {method.__name__} - {e}, try_count : {try_count}") + if try_count < max_tries: + time.sleep(backoff_s) + else: + raise e + else: + # with open(f"/skyplane/retry_nocode_error.txt", "w") as f: + # f.write(str(e)) + raise e + + return method_with_retries + + +class SCPClient: + """SCP Open-API client""" + + def __init__(self) -> None: + # print('SCPClient init') + self.credentials = os.path.expanduser(CREDENTIALS_PATH) + if not os.path.exists(self.credentials): + self.credentials = os.path.expanduser("/pkg/data/scp_credential") + assert os.path.exists(self.credentials), "SCP Credentials not found" + with open(self.credentials, "r") as f: + lines = [line.strip() for line in f.readlines() if " = " in line] + self._credentials = {line.split(" = ")[0]: line.split(" = ")[1] for line in lines} + + self.access_key = self._credentials["scp_access_key"] + self.secret_key = self._credentials["scp_secret_key"] + self.project_id = self._credentials["scp_project_id"] + self.client_type = "OpenApi" + self.timestamp = "" + self.signature = "" + + self.headers = { + "X-Cmp-AccessKey": f"{self.access_key}", + "X-Cmp-ClientType": f"{self.client_type}", + "X-Cmp-ProjectId": f"{self.project_id}", + 
"X-Cmp-Timestamp": f"{self.timestamp}", + "X-Cmp-Signature": f"{self.signature}", + } + + @_retry + def _get(self, url, contents_key="contents"): + method = "GET" + url = f"{API_ENDPOINT}{url}" + self.set_timestamp() + self.set_signature(url=url, method=method) + + response = requests.get(url, headers=self.headers) + raise_scp_error(response) + + if contents_key is not None: + return response.json().get(contents_key, []) + else: + return response.json() + + @_retry + def _getDetail(self, url): + method = "GET" + url = f"{API_ENDPOINT}{url}" + self.set_timestamp() + self.set_signature(url=url, method=method) + + response = requests.get(url, headers=self.headers) + response_detail = response.content.decode("utf-8") + raise_scp_error(response) + return json.loads(response_detail) + + def post(self, url, req_body): + method = "POST" + url = f"{API_ENDPOINT}{url}" + self.set_timestamp() + self.set_signature(url=url, method=method) + + response = requests.post(url, json=req_body, headers=self.headers) + raise_scp_error(response) + return response.json() + + @_retry + def _post(self, url, req_body): + return self.post(url, req_body) + + @_retry + def _delete(self, url, req_body=None): + method = "DELETE" + url = f"{API_ENDPOINT}{url}" + self.set_timestamp() + self.set_signature(url=url, method=method) + + # try: + if req_body: + response = requests.delete(url, json=req_body, headers=self.headers) + else: + response = requests.delete(url, headers=self.headers) + raise_scp_error(response) + try: + return response.json() + except json.JSONDecodeError as e: + # return response.content.decode('utf-8') + return response + + def wait_for_completion(self, task_name, completion_condition, retry_limit=5, timeout=180): + start = time.time() + # console.print(f"[bright_black] Waiting for... {task_name} [/bright_black]") + logger.fs.debug(f"Waiting for... {task_name}") + + retries = 0 + while time.time() - start < timeout and retries < retry_limit: + try: + if completion_condition(): + break + except Exception as e: + retries += 1 + # print(f"Exception occurred during completion_condition(): {task_name} - {e}, retries : {retries}") + logger.fs.debug(f"Exception occurred during completion_condition(): {task_name} - {e}, retries : {retries}") + time.sleep(1) + if retries == retry_limit: + # console.print(f"[red] {task_name}... Retry limit reached [/red]") + logger.fs.error(f"{task_name}... Retry limit reached") + raise SCPCreationFailError(f"Failed to create {task_name}") + else: + # console.print(f"[green] {task_name}... Completed [/green]") + logger.fs.debug(f"{task_name}... 
Completed")
+
+    def set_timestamp(self) -> None:
+        # self.timestamp = str(int(round((datetime.datetime.now() - datetime.timedelta(minutes=1)).timestamp() * 1000)))
+        current_time = datetime.datetime.now()
+        self.timestamp = str(int(round(current_time.timestamp() * 1000)))
+        self.headers["X-Cmp-Timestamp"] = self.timestamp
+
+    def set_signature(self, url: str, method: str) -> None:
+        self.signature = self.get_signature(url=url, method=method)
+        self.headers["X-Cmp-Signature"] = f"{self.signature}"
+
+    def get_signature(self, url: str, method: str) -> str:
+        url_info = parse.urlsplit(url)
+        url = f"{url_info.scheme}://{url_info.netloc}{parse.quote(url_info.path)}"
+
+        if url_info.query:
+            enc_params = [(item[0], parse.quote(item[1][0])) for item in parse.parse_qs(url_info.query).items()]
+            url = f"{url}?{parse.urlencode(enc_params)}"
+
+        message = method + url + self.timestamp + self.access_key + self.project_id + self.client_type
+        message = message.encode("utf-8")
+        secret = self.secret_key.encode("utf-8")
+
+        signature = base64.b64encode(hmac.new(secret, message, digestmod=hashlib.sha256).digest()).decode("utf-8")
+
+        return signature
+
+    def random_wait(self):
+        random_int = random.randint(1, 100)
+        time.sleep(1 * random_int / 100)
diff --git a/skyplane/compute/server.py b/skyplane/compute/server.py
index 585773105..c741ccfb3 100644
--- a/skyplane/compute/server.py
+++ b/skyplane/compute/server.py
@@ -2,6 +2,7 @@
 import logging
 import os
 import socket
+import time
 from contextlib import closing
 from enum import Enum, auto
 from functools import partial
@@ -81,6 +82,19 @@ def from_ibmcloud_state(ibmcloud_state):
         }
         return mapping.get(ibmcloud_state, ServerState.UNKNOWN)
 
+    @staticmethod
+    def from_scp_state(scp_state):
+        mapping = {
+            "CREATING": ServerState.PENDING,
+            "EDITING": ServerState.PENDING,
+            "RUNNING": ServerState.RUNNING,
+            "STARTING": ServerState.SUSPENDED,
+            "STOPPING": ServerState.SUSPENDED,
+            "STOPPED": ServerState.SUSPENDED,
+            "TERMINATING": ServerState.TERMINATED,
+        }
+        return mapping.get(scp_state, ServerState.UNKNOWN)
+
 
 class Server:
     """Abstract server class to support basic SSH operations"""
@@ -127,7 +141,8 @@ def get_ssh_cmd(self) -> str:
     def ssh_client(self):
         """Create SSH client and cache."""
         if not hasattr(self, "_ssh_client"):
-            self._ssh_client = self.get_ssh_client_impl()
+            # retry for AWS & GCP Ubuntu instances
+            self._ssh_client = retry_backoff(partial(self.get_ssh_client_impl))
         return self._ssh_client
 
     def tunnel_port(self, remote_port: int) -> int:
@@ -338,6 +353,13 @@ def check_stderr(tup):
         if self.provider == "aws":
             docker_envs["AWS_DEFAULT_REGION"] = self.region_tag.split(":")[1]
 
+        if self.provider == "scp":
+            credential_path = Path(self.auth.scp_credential_path).expanduser()
+            credential_file = os.path.basename(credential_path)
+            self.upload_file(credential_path, f"/tmp/{credential_file}")
+            docker_envs["SCP_CREDENTIAL_FILE"] = f"/pkg/data/{credential_file}"
+            docker_run_flags += f" -v /tmp/{credential_file}:/pkg/data/{credential_file}"
+
         # copy E2EE keys
         if e2ee_key_bytes is not None:
             e2ee_key_file = "e2ee_key"
@@ -386,10 +408,13 @@ def is_api_ready():
             status_val = json.loads(http_pool.request("GET", api_url).data.decode("utf-8"))
             is_up = status_val.get("status") == "ok"
             return is_up
-        except Exception:
+        except Exception as e:
+            logger.fs.error(f"Gateway {self.instance_name()} : {e}")
             return False
 
     try:
+        # avoid console SSH error log
+        time.sleep(0.5)
         logging.disable(logging.CRITICAL)
         wait_for(is_api_ready, timeout=30, interval=0.1, desc=f"Waiting for 
gateway {self.uuid()} to start") except TimeoutError as e: diff --git a/skyplane/config.py b/skyplane/config.py index f1210d8a9..322838659 100644 --- a/skyplane/config.py +++ b/skyplane/config.py @@ -40,6 +40,9 @@ "requester_pays": bool, "native_cmd_enabled": bool, "native_cmd_threshold_gb": int, + "scp_use_spot_instances": bool, + "scp_instance_class": str, + "scp_default_region": str, } _DEFAULT_FLAGS = { @@ -75,6 +78,9 @@ "requester_pays": False, "native_cmd_enabled": True, "native_cmd_threshold_gb": 2, + "scp_use_spot_instances": False, + "scp_instance_class": "h1v32m128", + "scp_default_region": "KR-WEST-1", } @@ -103,6 +109,7 @@ class SkyplaneConfig: gcp_enabled: bool cloudflare_enabled: bool ibmcloud_enabled: bool + scp_enabled: bool anon_clientid: str azure_principal_id: Optional[str] = None azure_subscription_id: Optional[str] = None @@ -118,6 +125,9 @@ class SkyplaneConfig: ibmcloud_iam_endpoint: Optional[str] = None ibmcloud_useragent: Optional[str] = None ibmcloud_resource_group_id: Optional[str] = None + scp_access_key: Optional[str] = None + scp_secret_key: Optional[str] = None + scp_project_id: Optional[str] = None @staticmethod def generate_machine_id() -> str: @@ -131,6 +141,7 @@ def default_config(cls) -> "SkyplaneConfig": gcp_enabled=False, ibmcloud_enabled=False, cloudflare_enabled=False, + scp_enabled=False, anon_clientid=cls.generate_machine_id(), ) @@ -183,6 +194,20 @@ def load_config(cls, path) -> "SkyplaneConfig": if "cloudflare_secret_access_key" in config["cloudflare"]: cloudflare_secret_access_key = config.get("cloudflare", "cloudflare_secret_access_key") + scp_enabled = False + scp_access_key = None + scp_secret_key = None + scp_project_id = None + if "scp" in config: + if "scp_enabled" in config["scp"]: + scp_enabled = config.getboolean("scp", "scp_enabled") + if "scp_access_key" in config["scp"]: + scp_access_key = config.get("scp", "scp_access_key") + if "scp_secret_key" in config["scp"]: + scp_secret_key = config.get("scp", "scp_secret_key") + if "scp_project_id" in config["scp"]: + scp_project_id = config.get("scp", "scp_project_id") + gcp_enabled = False gcp_project_id = None if "gcp" in config: @@ -230,6 +255,10 @@ def load_config(cls, path) -> "SkyplaneConfig": ibmcloud_iam_endpoint=ibmcloud_iam_endpoint, ibmcloud_useragent=ibmcloud_useragent, ibmcloud_resource_group_id=ibmcloud_resource_group_id, + scp_enabled=scp_enabled, + scp_access_key=scp_access_key, + scp_secret_key=scp_secret_key, + scp_project_id=scp_project_id, ) if "flags" in config: @@ -274,6 +303,16 @@ def to_config_file(self, path): if self.cloudflare_secret_access_key: config.set("cloudflare", "cloudflare_secret_access_key", self.cloudflare_secret_access_key) + if "scp" not in config: + config.add_section("scp") + config.set("scp", "scp_enabled", str(self.scp_enabled)) + if self.scp_access_key: + config.set("scp", "scp_access_key", self.scp_access_key) + if self.scp_secret_key: + config.set("scp", "scp_secret_key", self.scp_secret_key) + if self.scp_project_id: + config.set("scp", "scp_project_id", self.scp_project_id) + if "azure" not in config: config.add_section("azure") config.set("azure", "azure_enabled", str(self.azure_enabled)) diff --git a/skyplane/config_paths.py b/skyplane/config_paths.py index c252869b3..a70b83565 100644 --- a/skyplane/config_paths.py +++ b/skyplane/config_paths.py @@ -13,6 +13,8 @@ gcp_config_path = __config_root__ / "gcp_config" ibmcloud_config_path = __config_root__ / "ibmcloud_config" gcp_quota_path = __config_root__ / "gcp_quota" +scp_config_path = 
__config_root__ / "scp_config" +scp_quota_path = __config_root__ / "scp_quota" @functools.lru_cache(maxsize=None) diff --git a/skyplane/data/vcpu_info.csv b/skyplane/data/vcpu_info.csv index cf48a3d3a..3b206ee6a 100644 --- a/skyplane/data/vcpu_info.csv +++ b/skyplane/data/vcpu_info.csv @@ -25,3 +25,11 @@ Standard_D16_v5,azure,16 Standard_D8_v5,azure,8 Standard_D4_v5,azure,4 Standard_D2_v5,azure,2 +h1v96m384,scp,96 +h1v64m256,scp,64 +h1v48m192,scp,48 +h1v32m128,scp,32 +s1v16m64,scp,16 +s1v8m32,scp,8 +s1v4m16,scp,4 +s1v2m8,scp,2 diff --git a/skyplane/gateway/operators/gateway_operator.py b/skyplane/gateway/operators/gateway_operator.py index c574caebc..3ff0d1ea5 100644 --- a/skyplane/gateway/operators/gateway_operator.py +++ b/skyplane/gateway/operators/gateway_operator.py @@ -272,6 +272,16 @@ def process(self, chunk_req: ChunkRequest, dst_host: str): chunk_ids = [chunk_req.chunk.chunk_id] chunk_reqs = [chunk_req] + + # retry http_pool.request + def http_pool_request(register_body): + return self.http_pool.request( + "POST", + f"https://{dst_host}:8080/api/v1/chunk_requests", + body=register_body, + headers={"Content-Type": "application/json"}, + ) + try: with Timer(f"pre-register chunks {chunk_ids} to {dst_host}"): # TODO: remove chunk request wrapper @@ -288,12 +298,13 @@ def process(self, chunk_req: ChunkRequest, dst_host: str): while n_added < len(chunk_reqs): register_body = json.dumps([c.chunk.as_dict() for c in chunk_reqs[n_added:]]).encode("utf-8") # print(f"[sender-{self.worker_id}]:{chunk_ids} register body {register_body}") - response = self.http_pool.request( - "POST", - f"https://{dst_host}:8080/api/v1/chunk_requests", - body=register_body, - headers={"Content-Type": "application/json"}, - ) + # response = self.http_pool.request( + # "POST", + # f"https://{dst_host}:8080/api/v1/chunk_requests", + # body=register_body, + # headers={"Content-Type": "application/json"}, + # ) + response = retry_backoff(partial(http_pool_request, register_body)) reply_json = json.loads(response.data.decode("utf-8")) print(f"[sender-{self.worker_id}]", n_added, reply_json, dst_host) n_added += reply_json["n_added"] @@ -308,11 +319,23 @@ def process(self, chunk_req: ChunkRequest, dst_host: str): raise e # contact server to set up socket connection - if self.destination_ports.get(dst_host) is None: - print(f"[sender-{self.worker_id}]:{chunk_ids} creating new socket") + def socket_connection(): self.destination_sockets[dst_host] = retry_backoff( - partial(self.make_socket, dst_host), max_retries=3, exception_class=socket.timeout + # all exceptions are retried + # partial(self.make_socket, dst_host), max_retries=3, exception_class=socket.timeout + partial(self.make_socket, dst_host), + max_retries=3, + exception_class=Exception, ) + + if self.destination_ports.get(dst_host) is None: + print(f"[sender-{self.worker_id}]:{chunk_ids} creating new socket") + # self.destination_sockets[dst_host] = retry_backoff( + # # all exceptions are retried + # # partial(self.make_socket, dst_host), max_retries=3, exception_class=socket.timeout + # partial(self.make_socket, dst_host), max_retries=3, exception_class=Exception + # ) + socket_connection() print(f"[sender-{self.worker_id}]:{chunk_ids} created new socket") sock = self.destination_sockets[dst_host] @@ -348,18 +371,40 @@ def process(self, chunk_req: ChunkRequest, dst_host: str): is_compressed=(compressed_length is not None), ) # print(f"[sender-{self.worker_id}]:{chunk_id} sending chunk header {header}") - header.to_socket(sock) - # 
print(f"[sender-{self.worker_id}]:{chunk_id} sent chunk header") - - # send chunk data - assert chunk_file_path.exists(), f"chunk file {chunk_file_path} does not exist" - # file_size = os.path.getsize(chunk_file_path) - - with Timer() as t: - sock.sendall(data) + # fix socket timeouterror case - we need to create a new socket, so retry_backoff is not easy to use. + retry_count = 0 + max_retries = 3 + while retry_count < max_retries: + try: + header.to_socket(sock) + # print(f"[sender-{self.worker_id}]:{chunk_id} sent chunk header") + + # send chunk data + assert chunk_file_path.exists(), f"chunk file {chunk_file_path} does not exist" + # file_size = os.path.getsize(chunk_file_path) + + with Timer() as t: + sock.sendall(data) + except Exception as e: + print(f"[sender-{self.worker_id}]:{chunk_id} error sending chunk {e}") + retry_count += 1 + if retry_count < max_retries: + # print(f"SCP DEV - Retrying (attempt {retry_count} of {max_retries})...") + # with open(f"/skyplane/header.to_socket_{retry_count}_{dst_host}_error.txt", "w") as f: + # f.write(str(e)) + time.sleep(0.5) + socket_connection() + sock = self.destination_sockets[dst_host] + else: + # print(f"SCP DEV - Max retries ({max_retries}) reached. Giving up.") + raise e + else: + break # logger.debug(f"[sender:{self.worker_id}]:{chunk_id} sent at {chunk.chunk_length_bytes * 8 / t.elapsed / MB:.2f}Mbps") - print(f"[sender:{self.worker_id}]:{chunk_id} sent at {wire_length * 8 / t.elapsed / MB:.2f}Mbps") + print( + f"[sender:{self.worker_id}]:{chunk_id} sent at {wire_length * 8 / t.elapsed / MB:.2f}Mbps - {dst_host} : {self.destination_ports[dst_host]}" + ) if dst_host not in self.sent_chunk_ids: self.sent_chunk_ids[dst_host] = [] diff --git a/skyplane/obj_store/scp_interface.py b/skyplane/obj_store/scp_interface.py new file mode 100644 index 000000000..3058fdc82 --- /dev/null +++ b/skyplane/obj_store/scp_interface.py @@ -0,0 +1,483 @@ +# (C) Copyright Samsung SDS. 2023 + +# + +# Licensed under the Apache License, Version 2.0 (the "License"); + +# you may not use this file except in compliance with the License. + +# You may obtain a copy of the License at + +# + +# http://www.apache.org/licenses/LICENSE-2.0 + +# + +# Unless required by applicable law or agreed to in writing, software + +# distributed under the License is distributed on an "AS IS" BASIS, + +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +# See the License for the specific language governing permissions and + +# limitations under the License. 
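+"""SCP Object Storage interface.
+
+Implements Skyplane's ObjectStoreInterface on SCP Object Storage: bucket management
+uses the SCP Open API; object I/O uses an S3-compatible boto3 client."""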
+import base64
+import hashlib
+import os
+import time
+import boto3
+
+from functools import lru_cache, wraps
+from typing import Iterator, Any, List, Optional, Tuple
+
+from skyplane import exceptions, compute
+from skyplane.exceptions import NoSuchObjectException
+from skyplane.obj_store.object_store_interface import ObjectStoreInterface, ObjectStoreObject
+from skyplane.config_paths import cloud_config
+from skyplane.utils import logger, imports
+from skyplane.compute.scp.scp_network import SCPNetwork
+import skyplane.compute.scp.scp_utils as sc
+
+
+def _retry(method, max_tries=60, backoff_s=1):
+    @wraps(method)
+    def method_with_retries(self, *args, **kwargs):
+        try_count = 0
+        while try_count < max_tries:
+            try:
+                return method(self, *args, **kwargs)
+            except Exception as e:
+                try_count += 1
+                # print(e)
+                # with open(f"/skyplane/scp_interface_{method.__name__}_{try_count}_error.txt", "w") as f:
+                #     f.write(str(e))
+                logger.fs.debug(f"retries: {method.__name__} - {e}, try_count : {try_count}")
+                if try_count < max_tries:
+                    time.sleep(backoff_s)
+                else:
+                    raise e
+
+    return method_with_retries
+
+
+class SCPObject(ObjectStoreObject):
+    def full_path(self):
+        return f"scp://{self.bucket}/{self.key}"
+
+
+class SCPInterface(ObjectStoreInterface):
+    def __init__(self, bucket_name: str):
+        # OpenAPI Auth
+        self.auth = compute.SCPAuthentication()
+        self.scp_client = sc.SCPClient()
+
+        self.bucket_name = bucket_name
+        self.obsBucketId = self._get_bucket_id()
+        self.obs_access_key = ""
+        self.obs_secret_key = ""
+        self.obs_endpoint = ""
+        self.obs_region = ""
+
+        self.requester_pays = False
+        self._cached_s3_clients = {}
+        if self.obsBucketId is not None:
+            self._set_s3_credentials()
+            self._s3_client()
+
+    @_retry
+    def _set_s3_credentials(self):
+        # if self.obsBucketId is None:
+        max_retries = 60
+        retries = 0
+        while self.obsBucketId is None and retries < max_retries:
+            self.obsBucketId = self._get_bucket_id()
+            retries += 1
+            time.sleep(0.25)
+        # scp_client = self.scp_client
+        if self.obs_access_key == "":
+            # uri_path = f"/object-storage/v3/buckets/{self.obsBucketId}/api-info"
+            uri_path = f"/object-storage/v4/buckets/{self.obsBucketId}/access-info"
+            response = self.scp_client._get(uri_path, None)
+            # print(response)
+            try:
+                self.obs_access_key = response["objectStorageBucketAccessKey"]
+                self.obs_secret_key = response["objectStorageBucketSecretKey"]
+                self.obs_endpoint = response["objectStorageBucketPublicEndpointUrl"]
+                # self.obs_region = response["serviceZoneId"]
+                network = SCPNetwork(self.auth)
+                self.obs_region = network.get_service_zoneName(response["serviceZoneId"])
+            except Exception as e:
+                if "An error occurred (AccessDenied) when calling the GetBucketLocation operation" in str(e):
+                    logger.warning(f"Bucket location {self.bucket_name} is not public.")
+                logger.warning(f"Specified bucket {self.bucket_name} does not exist, got SCP error: {e}")
+                # print("Error getting SCP region", e)
+                raise exceptions.MissingBucketException(f"SCP bucket {self.bucket_name} does not exist") from e
+
+    @_retry
+    def _s3_client(self, region=None):
+        region = region if region is not None else self.obs_region
+        if region not in self._cached_s3_clients:
+            obsSession = boto3.Session(
+                aws_access_key_id=self.obs_access_key,
+                aws_secret_access_key=self.obs_secret_key,
+                region_name=self.obs_region
+                # config=Config(connect_timeout=timeout_seconds, read_timeout=timeout_seconds)
+            )
+            # s3_client = obsSession.client('s3', endpoint_url=self.obs_endpoint)
+            timeout_seconds = 120
+            self._cached_s3_clients[region] = obsSession.client(
+                "s3",
+                endpoint_url=self.obs_endpoint,
+                config=boto3.session.Config(connect_timeout=timeout_seconds, read_timeout=timeout_seconds),
+            )
+        return self._cached_s3_clients[region]
+
+    @property
+    def provider(self):
+        return "scp"
+
+    def path(self):
+        return f"scp://{self.bucket_name}"
+
+    def region_tag(self):
+        return "scp:" + self.scp_region
+
+    def bucket(self) -> str:
+        return self.bucket_name
+
+    @property
+    @lru_cache(maxsize=1)
+    def scp_region(self):
+        # scp_client = self.scp_client
+        # obsBucketId = self._get_bucket_id()
+        if self.obsBucketId is None:
+            raise exceptions.MissingBucketException(f"SCP bucket {self.bucket_name} does not exist")
+
+        default_region = cloud_config.get_flag("scp_default_region")
+        # uri_path = f"/object-storage/v3/buckets/{self.obsBucketId}/api-info"
+        uri_path = f"/object-storage/v4/buckets/{self.obsBucketId}/access-info"
+        try:
+            response = self.scp_client._get(uri_path, None)  # No value
+            network = SCPNetwork(self.auth)
+            region = network.get_service_zoneName(response["serviceZoneId"])
+            return region if region is not None else default_region
+        except Exception as e:
+            if "An error occurred (AccessDenied) when calling the GetBucketLocation operation" in str(e):
+                logger.warning(f"Bucket location {self.bucket_name} is not public. Assuming region is {default_region}")
+                return default_region
+            logger.warning(f"Specified bucket {self.bucket_name} does not exist, got SCP error: {e}")
+            # print("Error getting SCP region", e)
+            raise exceptions.MissingBucketException(f"SCP bucket {self.bucket_name} does not exist") from e
+
+    def bucket_exists(self) -> bool:
+        try:
+            obsBuckets = self.bucket_lists()
+            # pytype: disable=unsupported-operands
+            for bucket in obsBuckets:
+                if bucket["objectStorageBucketName"] == self.bucket_name:
+                    return True
+            return False
+            # pytype: enable=unsupported-operands
+        except Exception as e:
+            logger.warning(f"Specified bucket {self.bucket_name} does not exist, got error: {e}")
+            # print("Error getting bucket: ", e)
+            return False
+
+    def bucket_lists(self) -> List[dict]:
+        # scp_client = self.scp_client
+        # uri_path = "/object-storage/v3/buckets?size=999"
+        uri_path = "/object-storage/v4/buckets?size=999"
+        response = self.scp_client._get(uri_path)
+        return response
+
+    def create_object_repr(self, key: str) -> SCPObject:
+        return SCPObject(provider=self.provider, bucket=self.bucket(), key=key)
+
+    def _get_bucket_id(self) -> Optional[str]:
+        # url = f"/object-storage/v3/buckets?obsBucketName={self.bucket_name}"
+        url = f"/object-storage/v4/buckets?objectStorageBucketName={self.bucket_name}"
+        try:
+            response = self.scp_client._get(url)
+            # print(response)
+            if len(response) == 0:
+                return None
+            else:
+                return response[0]["objectStorageBucketId"]
+        except Exception as e:
+            logger.warning(f"Specified bucket {self.bucket_name} does not exist, got error: {e}")
+            # print("Error getting bucket id : ", e)
+            return None
+
+    def get_objectstorage_id(self, zone_id=None):
+        # scp_client = self.scp_client
+        zone_id = zone_id if zone_id is not None else None
+        # uri_path = f"/object-storage/v3/object-storages?zoneId={zone_id}"
+        uri_path = f"/object-storage/v4/object-storages?serviceZoneId={zone_id}"
+
+        response = self.scp_client._get(uri_path)
+        return response[0]["objectStorageId"]
+
+    def create_bucket(self, scp_region):
+        if not self.bucket_exists():
+            network = SCPNetwork(self.auth)
+            zone_id = network.get_service_zone_id(scp_region)
+            obs_id = self.get_objectstorage_id(zone_id)
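+
+            # Note: this v4 bucket-creation request passes the enable/disable
+            # flags below as strings ("false") rather than booleans.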
+            # uri_path = "/object-storage/v3/buckets"
+            uri_path = "/object-storage/v4/buckets"
+            req_body = {
+                "objectStorageBucketAccessControlEnabled": "false",
+                "objectStorageBucketFileEncryptionEnabled": "false",
+                "objectStorageBucketName": self.bucket_name,
+                "objectStorageBucketVersionEnabled": "false",
+                "objectStorageId": obs_id,
+                "productNames": ["Object Storage"],
+                "serviceZoneId": zone_id,
+                "tags": [{"tagKey": "skycomputing", "tagValue": "skyplane"}],
+            }
+            self.scp_client._post(uri_path, req_body)
+            time.sleep(3)
+        else:
+            logger.warning(f"Bucket {self.bucket_name} in region {scp_region} already exists")
+
+    def delete_bucket(self):
+        if self.bucket_exists():
+            # obsBucketId = self._get_bucket_id()
+            # scp_client = self.scp_client
+            # uri_path = f"/object-storage/v3/buckets/{self.obsBucketId}"
+            uri_path = f"/object-storage/v4/buckets/{self.obsBucketId}"
+
+            self.scp_client._delete(uri_path)
+        else:
+            logger.warning(f"Bucket {self.bucket_name} in region {self.scp_region} does not exist")
+
+    # def list_objects(self, prefix: str = None, recursive: bool = False) -> Iterator[SCPObject]:
+    @_retry
+    def list_objects(self, prefix="", obs_region=None) -> Iterator[SCPObject]:
+        self._set_s3_credentials()
+        paginator = self._s3_client(self.obs_region).get_paginator("list_objects_v2")
+        requester_pays = {"RequestPayer": "requester"} if self.requester_pays else {}
+        page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix, **requester_pays)
+        for page in page_iterator:
+            objs = []
+            for obj in page.get("Contents", []):
+                objs.append(
+                    SCPObject(
+                        obj["Key"],
+                        provider=self.provider,
+                        bucket=self.bucket(),
+                        size=obj["Size"],
+                        last_modified=obj["LastModified"],
+                        mime_type=obj.get("ContentType"),
+                    )
+                )
+            yield from objs
+            # return objs  # For Test
+
+    def delete_objects(self, keys: List[str]):
+        # self._set_s3_credentials()
+        s3_client = self._s3_client()
+        while keys:
+            batch, keys = keys[:1000], keys[1000:]  # take up to 1000 keys at a time
+            s3_client.delete_objects(Bucket=self.bucket_name, Delete={"Objects": [{"Key": k} for k in batch]})
+
+    @lru_cache(maxsize=1024)
+    @imports.inject("botocore.exceptions", pip_extra="aws")
+    def get_obj_metadata(botocore_exceptions, self, obj_name):
+        self._set_s3_credentials()
+        s3_client = self._s3_client()
+        try:
+            return s3_client.head_object(Bucket=self.bucket_name, Key=str(obj_name))
+        except botocore_exceptions.ClientError as e:
+            raise NoSuchObjectException(f"Object {obj_name} does not exist, or you do not have permission to access it") from e
+
+    def get_obj_size(self, obj_name):
+        return self.get_obj_metadata(obj_name)["ContentLength"]
+
+    def get_obj_last_modified(self, obj_name):
+        return self.get_obj_metadata(obj_name)["LastModified"]
+
+    def get_obj_mime_type(self, obj_name):
+        return self.get_obj_metadata(obj_name)["ContentType"]
+
+    def exists(self, obj_name):
+        try:
+            self.get_obj_metadata(obj_name)
+            return True
+        except NoSuchObjectException:
+            return False
+
+    def download_object(
+        self,
+        src_object_name,
+        dst_file_path,
+        offset_bytes=None,
+        size_bytes=None,
+        write_at_offset=False,
+        generate_md5=False,
+        write_block_size=2**16,
+    ) -> Tuple[Optional[str], Optional[bytes]]:
+        src_object_name, dst_file_path = str(src_object_name), str(dst_file_path)
+        # self._set_s3_credentials()
+        max_retries = 10
+        retries = 0
+        md5 = None
+        mime_type = None
+        while retries < max_retries:
+            s3_client = self._s3_client()
+            try:
+                assert len(src_object_name) > 0, f"Source object name must be non-empty: '{src_object_name}'"
+                args = {"Bucket": self.bucket_name, "Key": src_object_name}
+                assert not (offset_bytes and not size_bytes), f"Cannot specify {offset_bytes} without {size_bytes}"
+                if offset_bytes is not None and size_bytes is not None:
+                    args["Range"] = f"bytes={offset_bytes}-{offset_bytes + size_bytes - 1}"
+                if self.requester_pays:
+                    args["RequestPayer"] = "requester"
+                response = s3_client.get_object(**args)
+
+                # write response data
+                if os.path.exists(dst_file_path):
+                    os.remove(dst_file_path)
+                if not os.path.exists(dst_file_path):
+                    open(dst_file_path, "a").close()
+                if generate_md5:
+                    m = hashlib.md5()
+                with open(dst_file_path, "wb+" if write_at_offset else "wb") as f:
+                    f.seek(offset_bytes if write_at_offset else 0)
+                    b = response["Body"].read(write_block_size)
+                    while b:
+                        if generate_md5:
+                            m.update(b)
+                        f.write(b)
+                        b = response["Body"].read(write_block_size)
+                response["Body"].close()
+                md5 = m.digest() if generate_md5 else None
+                mime_type = response["ContentType"]
+            except Exception as e:
+                retries += 1
+                if retries == max_retries:
+                    raise
+                else:
+                    chunk_name = os.path.basename(dst_file_path)
+                    logger.warning(f"Download failed for {chunk_name}, retrying ({retries}/{max_retries})")
+                    # Create an error file at this location for debugging
+                    with open(f"/skyplane/download_object_{chunk_name}_{retries}_{max_retries}_error.txt", "w") as f:
+                        f.write(str(e))
+                    time.sleep(1)
+            else:
+                break
+        return mime_type, md5
+
+    @imports.inject("botocore.exceptions", pip_extra="aws")
+    def upload_object(
+        botocore_exceptions, self, src_file_path, dst_object_name, part_number=None, upload_id=None, check_md5=None, mime_type=None
+    ):
+        dst_object_name, src_file_path = str(dst_object_name), str(src_file_path)
+        self._set_s3_credentials()
+        s3_client = self._s3_client()
+        assert len(dst_object_name) > 0, f"Destination object name must be non-empty: '{dst_object_name}'"
+        b64_md5sum = base64.b64encode(check_md5).decode("utf-8") if check_md5 else None
+        checksum_args = dict(ContentMD5=b64_md5sum) if b64_md5sum else dict()
+
+        max_retries = 10
+        retries = 0
+        while retries < max_retries:
+            try:
+                with open(src_file_path, "rb") as f:
+                    if upload_id:
+                        # if part_number is None:  # Upload ALL parts
+                        #     part_size = 10 * 1024 * 1024  # 10MB
+                        #     pn = 1
+                        #     while True:
+                        #         data = f.read(part_size)
+                        #         if not len(data):
+                        #             break
+                        #         s3_client.upload_part(
+                        #             Body=data,
+                        #             Key=dst_object_name,
+                        #             Bucket=self.bucket_name,
+                        #             PartNumber=pn,
+                        #             UploadId=upload_id.strip(),  # TODO: figure out why whitespace gets added,
+                        #             **checksum_args,
+                        #         )
+                        #         pn += 1
+                        # else:
+                        s3_client.upload_part(
+                            Body=f,
+                            Key=dst_object_name,
+                            Bucket=self.bucket_name,
+                            PartNumber=part_number,
+                            UploadId=upload_id.strip(),  # TODO: figure out why whitespace gets added,
+                            **checksum_args,
+                        )
+                    else:
+                        mime_args = dict(ContentType=mime_type) if mime_type else dict()
+                        s3_client.put_object(Body=f, Key=dst_object_name, Bucket=self.bucket_name, **checksum_args, **mime_args)
+            except botocore_exceptions.ClientError as e:
+                retries += 1
+                if retries == max_retries:
+                    # catch MD5 mismatch error and raise appropriate exception
+                    if "Error" in e.response and "Code" in e.response["Error"] and e.response["Error"]["Code"] == "InvalidDigest":
+                        raise exceptions.ChecksumMismatchException(f"Checksum mismatch for object {dst_object_name}") from e
+                    raise
+                else:
+                    logger.warning(f"Upload failed, retrying ({retries}/{max_retries})")
+                    # Create an error file at this location for debugging
+                    with open(f"/skyplane/upload_object_{dst_object_name}_{retries}_error.txt", "w") as f:
+                        f.write(str(e))
+                    time.sleep(1)
+            else:
+                break
+
+    def 
initiate_multipart_upload(self, dst_object_name: str, mime_type: Optional[str] = None) -> str: + self._set_s3_credentials() + client = self._s3_client() + assert len(dst_object_name) > 0, f"Destination object name must be non-empty: '{dst_object_name}'" + response = client.create_multipart_upload( + Bucket=self.bucket_name, Key=dst_object_name, **(dict(ContentType=mime_type) if mime_type else dict()) + ) + if "UploadId" in response: + return response["UploadId"] + else: + raise exceptions.SkyplaneException(f"Failed to initiate multipart upload for {dst_object_name}: {response}") + + def complete_multipart_upload(self, dst_object_name, upload_id, metadata: Optional[Any] = None): + self._set_s3_credentials() + s3_client = self._s3_client() + all_parts = [] + while True: + response = s3_client.list_parts( + Bucket=self.bucket_name, Key=dst_object_name, MaxParts=100, UploadId=upload_id, PartNumberMarker=len(all_parts) + ) + if "Parts" not in response: + break + else: + if len(response["Parts"]) == 0: + break + all_parts += response["Parts"] + all_parts = sorted(all_parts, key=lambda d: d["PartNumber"]) + response = s3_client.complete_multipart_upload( + UploadId=upload_id, + Bucket=self.bucket_name, + Key=dst_object_name, + MultipartUpload={"Parts": [{"PartNumber": p["PartNumber"], "ETag": p["ETag"]} for p in all_parts]}, + ) + assert "ETag" in response, f"Failed to complete multipart upload for {dst_object_name}: {response}" + + # Lists in-progress multipart uploads. + @imports.inject("botocore.exceptions", pip_extra="aws") + def list_multipart_uploads(self): + self._set_s3_credentials() + s3_client = self._s3_client() + response = s3_client.list_multipart_uploads(Bucket=self.bucket()) + return response["Uploads"] if response is not None else None + + # Aborts a multipart upload. 
+    # Aborts a multipart upload. After a multipart upload is aborted, no additional parts can be uploaded using that upload ID.
+    def abort_multipart_upload(self, key, upload_id):
+        self._set_s3_credentials()
+        s3_client = self._s3_client()
+        response = s3_client.abort_multipart_upload(Bucket=self.bucket(), Key=key, UploadId=upload_id)
+        return response
diff --git a/skyplane/obj_store/storage_interface.py b/skyplane/obj_store/storage_interface.py
index a430573f8..7694b67e4 100644
--- a/skyplane/obj_store/storage_interface.py
+++ b/skyplane/obj_store/storage_interface.py
@@ -60,6 +60,10 @@ def create(region_tag: str, bucket: str):
             logger.fs.debug(f"attempting to create hdfs bucket {bucket}")
             return HDFSInterface(host=bucket)
+        elif region_tag.startswith("scp"):
+            from skyplane.obj_store.scp_interface import SCPInterface
+
+            return SCPInterface(bucket)
         elif region_tag.startswith("local"):
             # from skyplane.obj_store.file_system_interface import FileSystemInterface
             from skyplane.obj_store.posix_file_interface import POSIXInterface
diff --git a/skyplane/planner/planner.py b/skyplane/planner/planner.py
index 2bb21e0ac..dde2846ba 100644
--- a/skyplane/planner/planner.py
+++ b/skyplane/planner/planner.py
@@ -23,7 +23,7 @@ import json
 from skyplane.utils.fn import do_parallel
-from skyplane.config_paths import config_path, azure_standardDv5_quota_path, aws_quota_path, gcp_quota_path
+from skyplane.config_paths import config_path, azure_standardDv5_quota_path, aws_quota_path, gcp_quota_path, scp_quota_path
 from skyplane.config import SkyplaneConfig
@@ -48,6 +48,9 @@ def __init__(self, transfer_config: TransferConfig, quota_limits_file: Optional[
         if os.path.exists(gcp_quota_path):
             with gcp_quota_path.open("r") as f:
                 quota_limits["gcp"] = json.load(f)
+        if os.path.exists(scp_quota_path):
+            with scp_quota_path.open("r") as f:
+                quota_limits["scp"] = json.load(f)
         self.quota_limits = quota_limits
 
         # Loading the vcpu information - a dictionary of dictionaries
@@ -102,6 +105,10 @@ def _get_quota_limits_for(self, cloud_provider: str, region: str, spot: bool = F
             for quota in quota_limits:
                 if quota["region_name"] == region:
                     return quota["spot_standard_vcpus"] if spot else quota["on_demand_standard_vcpus"]
+        elif cloud_provider == "scp":
+            for quota in quota_limits:
+                if quota["service_zone_name"] == region:
+                    return quota["on_demand_standard_vcpus"]
         return None
 
     def _calculate_vm_types(self, region_tag: str) -> Optional[Tuple[str, int]]:
diff --git a/skyplane/utils/path.py b/skyplane/utils/path.py
index 9670934f8..6a826fa56 100644
--- a/skyplane/utils/path.py
+++ b/skyplane/utils/path.py
@@ -63,6 +63,14 @@ def is_plausible_local_path(path_test: str):
             raise ValueError(f"Invalid HDFS path: {path}")
         host, path = match.groups()
         return "hdfs", host, path
+    elif path.startswith("scp://"):
+        provider, parsed = "scp", path[6:]
+        if len(parsed) == 0:
+            logger.error(f"Invalid SCP path: '{path}'", fg="red", err=True)
+            raise ValueError(f"Invalid SCP path: '{path}'")
+        bucket, *keys = parsed.split("/", 1)
+        key = keys[0] if len(keys) > 0 else ""
+        return provider, bucket, key
     else:
         if not is_plausible_local_path(path):
             logger.warning(f"Local path '{path}' does not exist")
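The new `scp://` branch strips the six-character scheme prefix and splits the remainder once, so everything after the first `/` is preserved as the object key. Assuming the enclosing function is `parse_path` (the usual entry point in `skyplane/utils/path.py`), the expected behavior looks like:

```python
from skyplane.utils.path import parse_path

# bucket only: the key defaults to the empty string
assert parse_path("scp://my-bucket") == ("scp", "my-bucket", "")

# the split happens once, so nested keys survive intact
assert parse_path("scp://my-bucket/dir/sub/obj.bin") == ("scp", "my-bucket", "dir/sub/obj.bin")
```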
diff --git a/tests/integration/cp.py b/tests/integration/cp.py
index b96128465..b39af2560 100644
--- a/tests/integration/cp.py
+++ b/tests/integration/cp.py
@@ -12,11 +12,15 @@ def setup_buckets(src_region, dest_region, n_files=1, file_size_mb=1):
     src_provider, src_zone = src_region.split(":")
     dest_provider, dest_zone = dest_region.split(":")
     if src_provider == "azure":
-        src_bucket_name = f"skyplanetest{src_zone}/{str(uuid.uuid4()).replace('-', '')}"
+        src_bucket_name = f"integration{src_zone}/{str(uuid.uuid4()).replace('-', '')}"
+    elif src_provider == "scp":  # SCP object storage bucket names must be lowercase
+        src_bucket_name = f"integration-{src_zone.lower()}-{str(uuid.uuid4())[:8]}"
     else:
-        src_bucket_name = f"skyplane-integration-{src_zone}-{str(uuid.uuid4())[:8]}"
+        src_bucket_name = f"integration{src_zone}-{str(uuid.uuid4())[:8]}"
     if dest_provider == "azure":
-        dest_bucket_name = f"skyplanetest{dest_zone}/{str(uuid.uuid4()).replace('-', '')}"
+        dest_bucket_name = f"integration{dest_zone}/{str(uuid.uuid4()).replace('-', '')}"
+    elif dest_provider == "scp":  # SCP object storage bucket names must be lowercase
+        dest_bucket_name = f"skyplane-integration-{dest_zone.lower()}-{str(uuid.uuid4())[:8]}"
     else:
         dest_bucket_name = f"skyplane-integration-{dest_zone}-{str(uuid.uuid4())[:8]}"
     logger.debug(f"creating buckets {src_bucket_name} and {dest_bucket_name}")
@@ -59,6 +63,8 @@ def map_path(region, bucket, prefix):
         return f"https://{storage_account}.blob.core.windows.net/{container}/{prefix}"
     elif provider == "gcp":
         return f"gs://{bucket}/{prefix}"
+    elif provider == "scp":
+        return f"scp://{bucket}/{prefix}"
     else:
         raise Exception(f"Unknown provider {provider}")
diff --git a/tests/integration/test_api_client.py b/tests/integration/test_api_client.py
index 1652f9ada..95508703c 100644
--- a/tests/integration/test_api_client.py
+++ b/tests/integration/test_api_client.py
@@ -57,3 +57,8 @@ def test_aws_interface():
 def test_gcp_interface():
     test_region("gcp:us-central1-a")
     return True
+
+
+def test_scp_interface():
+    test_region("scp:KOREA-WEST-1-SCP-B001")
+    return True
diff --git a/tests/unit_scp/test_scp_obj_interface.py b/tests/unit_scp/test_scp_obj_interface.py
new file mode 100644
index 000000000..25008f501
--- /dev/null
+++ b/tests/unit_scp/test_scp_obj_interface.py
@@ -0,0 +1,31 @@
+import uuid
+from skyplane.obj_store.object_store_interface import ObjectStoreInterface
+from tests.interface_util import interface_test_framework
+
+
+def test_scp_singlepart():
+    assert interface_test_framework("scp:KOREA-WEST-2-SCP-B001", f"test-skyplane-{uuid.uuid4()}", False, test_delete_bucket=True)
+
+
+def test_scp_singlepart_zero_bytes():
+    assert interface_test_framework(
+        "scp:KOREA-WEST-2-SCP-B001", f"test-skyplane-{uuid.uuid4()}", False, test_delete_bucket=True, file_size_mb=0
+    )
+
+
+def test_scp_multipart():
+    assert interface_test_framework("scp:KOREA-WEST-2-SCP-B001", f"test-skyplane-{uuid.uuid4()}", True, test_delete_bucket=True)
+
+
+def test_scp_bucket_exists():
+    # test a public bucket with objects
+    iface = ObjectStoreInterface.create("scp:infer", "skyplane")
+    assert iface.bucket_exists()
+
+    # test a random bucket that doesn't exist
+    iface = ObjectStoreInterface.create("scp:infer", f"skyplane-does-not-exist-{uuid.uuid4()}")
+    assert not iface.bucket_exists()
+
+    # test public but empty bucket
+    # iface = ObjectStoreInterface.create("scp:infer", "skyplane-test-empty-public-bucket")
+    # assert iface.bucket_exists()
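These tests reach `SCPInterface` through the `region_tag.startswith("scp")` dispatch added to `storage_interface.py`; `scp:infer` works because only the `scp` prefix is inspected. A small sketch of that routing, with a hypothetical bucket name:

```python
from skyplane.obj_store.object_store_interface import ObjectStoreInterface
from skyplane.obj_store.scp_interface import SCPInterface

# any region tag starting with "scp" is routed to SCPInterface
iface = ObjectStoreInterface.create("scp:KOREA-WEST-2-SCP-B001", "my-demo-bucket")
assert isinstance(iface, SCPInterface)
```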