Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AWS Auth Propagation for API #923

Merged
merged 26 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
64 changes: 61 additions & 3 deletions skyplane/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
import typer
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

from skyplane.api.config import TransferConfig
from skyplane.api.provisioner import Provisioner
from skyplane.api.obj_store import ObjectStore
from skyplane.config import SkyplaneConfig
from skyplane.config_paths import config_path
from skyplane.obj_store.object_store_interface import ObjectStoreInterface
from skyplane.obj_store.storage_interface import StorageInterface
from skyplane.api.usage import get_clientid
from skyplane.utils import logger
from skyplane.utils.definitions import tmp_log_dir
from skyplane.utils.path import parse_path

from skyplane.api.pipeline import Pipeline

Expand All @@ -28,6 +33,10 @@ def __init__(
ibmcloud_config: Optional["IBMCloudConfig"] = None,
transfer_config: Optional[TransferConfig] = None,
log_dir: Optional[str] = None,
disable_aws: bool = False,
disable_azure: bool = False,
disable_gcp: bool = False,
disable_ibm: bool = False
):
"""
:param aws_config: aws cloud configurations
Expand Down Expand Up @@ -66,16 +75,53 @@ def __init__(
azure_auth=self.azure_auth,
gcp_auth=self.gcp_auth,
ibmcloud_auth=self.ibmcloud_auth,
disable_aws=disable_aws,
disable_azure=disable_azure,
disable_gcp=disable_gcp,
disable_ibm=disable_ibm
)

def pipeline(self, planning_algorithm: Optional[str] = "direct", max_instances: Optional[int] = 1, debug=False):
self.config = SkyplaneConfig.default_config()
if not disable_aws:
self.config.aws_enabled = True
if aws_config:
self.config.aws_access_key = aws_config.aws_access_key
self.config.aws_secret_key = aws_config.aws_secret_key
if not disable_azure:
self.config.azure_enabled = True
if azure_config:
self.config.azure_subscription_id=azure_config.azure_subscription_id
self.config.azure_resource_group=azure_config.azure_resource_group
self.config.azure_principal_id=azure_config.azure_umi_id
self.config.azure_umi_name=azure_config.azure_umi_name
self.config.azure_client_id=azure_config.azure_umi_client_id
if not disable_gcp:
self.config.gcp_enabled = True
if gcp_config:
self.config.gcp_project_id=gcp_config.gcp_project_id
if not disable_ibm:
self.config.ibm_enabled = True
if ibm_config:
self.config.ibmcloud_access_id=ibm_config.ibmcloud_access_id
self.config.ibmcloud_secret_key=ibm_config.ibmcloud_secret_key
self.config.ibmcloud_iam_key=ibm_config.ibmcloud_iam_key
self.config.ibmcloud_iam_endpoint=ibm_config.ibmcloud_iam_endpoint
self.config.ibmcloud_useragent=ibm_config.ibmcloud_useragent
self.config.ibmcloud_resource_group_id=ibm_config.ibmcloud_resource_group_id

self.config.to_config_file(config_path)
typer.secho(f"\nConfig file saved to {config_path}", fg="green")

def pipeline(
    self,
    planning_algorithm: Optional[str] = "direct",
    max_instances: Optional[int] = 1,
    src_iface: Optional[ObjectStoreInterface] = None,
    dst_ifaces: Optional[List[ObjectStoreInterface]] = None,
    debug=False,
):
    """
    Build a Pipeline bound to this client, onto which transfer jobs can be queued.

    The client's own ``clientid``, ``provisioner`` and ``transfer_config`` are
    forwarded together with the arguments below.

    :param planning_algorithm: name of the planning algorithm handed to the Pipeline (default: "direct")
    :type planning_algorithm: str
    :param max_instances: maximum number of instances handed to the Pipeline (default: 1)
    :type max_instances: int
    :param src_iface: optional pre-built source object store interface, passed through unchanged
    :type src_iface: ObjectStoreInterface
    :param dst_ifaces: optional pre-built destination object store interfaces, passed through unchanged
    :type dst_ifaces: List[ObjectStoreInterface]
    :param debug: debug flag handed to the Pipeline (default: False)
    :type debug: bool
    """
    # Collect the per-call options first, then add the client-owned state.
    pipeline_options = dict(
        planning_algorithm=planning_algorithm,
        max_instances=max_instances,
        clientid=self.clientid,
        provisioner=self.provisioner,
        transfer_config=self.transfer_config,
        src_iface=src_iface,
        dst_ifaces=dst_ifaces,
        debug=debug,
    )
    return Pipeline(**pipeline_options)

Expand All @@ -93,8 +139,20 @@ def copy(self, src: str, dst: str, recursive: bool = False, max_instances: Optio
:param max_instances: The maximum number of instances to use per region (default: 1)
:type max_instances: int
"""
provider_src, bucket_src, _ = parse_path(src)

src_iface = ObjectStoreInterface.create(f"{provider_src}:infer", bucket_src, aws_auth=self.aws_auth, azure_auth=self.azure_auth, gcp_auth=self.gcp_auth)

if isinstance(dst, str):
provider_dst, bucket_dst, _ = parse_path(dst)
dst_ifaces = [StorageInterface.create(f"{provider_dst}:infer", bucket_dst, aws_auth=self.aws_auth, azure_auth=self.azure_auth, gcp_auth=self.gcp_auth)]
else:
dst_ifaces = []
for dst_path in dst:
provider_dst, bucket_dst, _ = parse_path(dst_path)
dst_ifaces.append(StorageInterface.create(f"{provider_dst}:infer", bucket_dst, aws_auth=self.aws_auth, azure_auth=self.azure_auth, gcp_auth=self.gcp_auth))

pipeline = self.pipeline(max_instances=max_instances, debug=debug)
pipeline = self.pipeline(max_instances=max_instances, debug=debug, src_iface = src_iface, dst_ifaces=dst_ifaces)
pipeline.queue_copy(src, dst, recursive=recursive)
pipeline.start(progress=True)

Expand Down
11 changes: 9 additions & 2 deletions skyplane/api/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skyplane.api.transfer_job import CopyJob, SyncJob, TransferJob
from skyplane.api.config import TransferConfig

from skyplane.obj_store.object_store_interface import ObjectStoreInterface
from skyplane.planner.planner import MulticastDirectPlanner, DirectPlannerSourceOneSided, DirectPlannerDestOneSided
from skyplane.planner.topology import TopologyPlanGateway
from skyplane.utils import logger
Expand All @@ -33,6 +34,8 @@ def __init__(
max_instances: Optional[int] = 1,
n_connections: Optional[int] = 64,
planning_algorithm: Optional[str] = "direct",
src_iface: Optional[ObjectStoreInterface] = None,
dst_ifaces: Optional[List[ObjectStoreInterface]] = None,
debug: Optional[bool] = False,
):
"""
Expand Down Expand Up @@ -70,6 +73,10 @@ def __init__(
else:
raise ValueError(f"No such planning algorithm {planning_algorithm}")

# obj store interfaces
self.src_iface = src_iface
self.dst_ifaces = dst_ifaces

# transfer logs
self.transfer_dir = tmp_log_dir / "transfer_logs" / datetime.now().strftime("%Y%m%d_%H%M%S")
self.transfer_dir.mkdir(exist_ok=True, parents=True)
Expand Down Expand Up @@ -146,7 +153,7 @@ def queue_copy(
"""
if isinstance(dst, str):
dst = [dst]
job = CopyJob(src, dst, recursive, requester_pays=self.transfer_config.requester_pays)
job = CopyJob(src, dst, recursive, requester_pays=self.transfer_config.requester_pays, src_iface=self.src_iface, dst_ifaces=self.dst_ifaces)
logger.fs.debug(f"[SkyplaneClient] Queued copy job {job}")
self.jobs_to_dispatch.append(job)
return job.uuid
Expand All @@ -169,7 +176,7 @@ def queue_sync(
"""
if isinstance(dst, str):
dst = [dst]
job = SyncJob(src, dst, requester_pays=self.transfer_config.requester_pays)
job = SyncJob(src, dst, requester_pays=self.transfer_config.requester_pays, src_iface=self.src_iface, dst_ifaces=self.dst_ifaces)
logger.fs.debug(f"[SkyplaneClient] Queued sync job {job}")
self.jobs_to_dispatch.append(job)
return job.uuid
Expand Down
24 changes: 16 additions & 8 deletions skyplane/api/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def __init__(
gcp_auth: Optional[compute.GCPAuthentication] = None,
host_uuid: Optional[str] = None,
ibmcloud_auth: Optional[compute.IBMCloudAuthentication] = None,
disable_aws: bool = False,
disable_azure: bool = False,
disable_gcp: bool = False,
disable_ibm: bool = False
):
"""
:param aws_auth: authentication information for aws
Expand All @@ -70,21 +74,25 @@ def __init__(
self.gcp_auth = gcp_auth
self.host_uuid = host_uuid
self.ibmcloud_auth = ibmcloud_auth
self._make_cloud_providers()
self._make_cloud_providers(disable_aws, disable_azure, disable_gcp, disable_ibm)
self.temp_nodes: Set[compute.Server] = set() # temporary area to store nodes that should be terminated upon exit
self.pending_provisioner_tasks: List[ProvisionerTask] = []
self.provisioned_vms: Dict[str, compute.Server] = {}

# store GCP firewall rules to be deleted upon exit
self.gcp_firewall_rules: Set[str] = set()

def _make_cloud_providers(self):
self.aws = compute.AWSCloudProvider(
key_prefix=f"skyplane{'-'+self.host_uuid.replace('-', '') if self.host_uuid else ''}", auth=self.aws_auth
)
self.azure = compute.AzureCloudProvider(auth=self.azure_auth)
self.gcp = compute.GCPCloudProvider(auth=self.gcp_auth)
self.ibmcloud = compute.IBMCloudProvider(auth=self.ibmcloud_auth)
def _make_cloud_providers(self, disable_aws=False, disable_azure=False, disable_gcp=False, disable_ibm=False):
    """
    Instantiate the per-cloud provider objects used by this provisioner.

    All four provider attributes (``aws``, ``azure``, ``gcp``, ``ibmcloud``) are
    always assigned. A disabled cloud gets ``None`` instead of being left unset,
    so later attribute access cannot fail with ``AttributeError``.

    :param disable_aws: when True, no AWS provider is constructed and ``self.aws`` is None
    :type disable_aws: bool
    :param disable_azure: when True, no Azure provider is constructed and ``self.azure`` is None
    :type disable_azure: bool
    :param disable_gcp: when True, no GCP provider is constructed and ``self.gcp`` is None
    :type disable_gcp: bool
    :param disable_ibm: when True, no IBM Cloud provider is constructed and ``self.ibmcloud`` is None
    :type disable_ibm: bool
    """
    if disable_aws:
        self.aws = None
    else:
        # The key prefix carries the host uuid (dashes removed) when one is set.
        self.aws = compute.AWSCloudProvider(
            key_prefix=f"skyplane{'-'+self.host_uuid.replace('-', '') if self.host_uuid else ''}", auth=self.aws_auth
        )

    if disable_azure:
        self.azure = None
    else:
        self.azure = compute.AzureCloudProvider(auth=self.azure_auth)

    if disable_gcp:
        self.gcp = None
    else:
        self.gcp = compute.GCPCloudProvider(auth=self.gcp_auth)

    if disable_ibm:
        self.ibmcloud = None
    else:
        self.ibmcloud = compute.IBMCloudProvider(auth=self.ibmcloud_auth)

def init_global(self, aws: bool = True, azure: bool = True, gcp: bool = True, ibmcloud: bool = True):
"""
Expand Down
16 changes: 11 additions & 5 deletions skyplane/api/transfer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ def __init__(
recursive: bool = False,
requester_pays: bool = False,
job_id: Optional[str] = None,
src_iface: Optional[ObjectStoreInterface] = None,
dst_ifaces: Optional[List[ObjectStoreInterface]] = None
):
self.src_path = src_path
self.dst_paths = dst_paths
Expand All @@ -477,6 +479,8 @@ def __init__(
self.uuid = str(uuid.uuid4())
else:
self.uuid = job_id
self._src_iface = src_iface
self._dst_ifaces = dst_ifaces

@property
def transfer_type(self) -> str:
Expand All @@ -495,7 +499,7 @@ def src_prefix(self) -> Optional[str]:
@property
def src_iface(self) -> StorageInterface:
"""Return the source object store interface"""
if not hasattr(self, "_src_iface"):
if not self._src_iface:
provider_src, bucket_src, _ = parse_path(self.src_path)
self._src_iface = ObjectStoreInterface.create(f"{provider_src}:infer", bucket_src)
if self.requester_pays:
Expand All @@ -515,7 +519,7 @@ def dst_prefixes(self) -> List[str]:
@property
def dst_ifaces(self) -> List[StorageInterface]:
"""Return the destination object store interface"""
if not hasattr(self, "_dst_iface"):
if not self._dst_ifaces:
if self.transfer_type == "unicast":
provider_dst, bucket_dst, _ = parse_path(self.dst_paths[0])
self._dst_ifaces = [StorageInterface.create(f"{provider_dst}:infer", bucket_dst)]
Expand Down Expand Up @@ -570,8 +574,10 @@ def __init__(
recursive: bool = False,
requester_pays: bool = False,
job_id: Optional[str] = None,
src_iface: Optional[ObjectStoreInterface] = None,
dst_ifaces: Optional[List[ObjectStoreInterface]] = None
):
super().__init__(src_path, dst_paths, recursive, requester_pays, job_id)
super().__init__(src_path, dst_paths, recursive, requester_pays, job_id, src_iface, dst_ifaces)
self.transfer_list = []
self.multipart_transfer_list = []

Expand Down Expand Up @@ -772,8 +778,8 @@ def size_gb(self):
class SyncJob(CopyJob):
"""sync job that copies the source objects that does not exist in the destination bucket to the destination"""

def __init__(self, src_path: str, dst_paths: List[str] or str, requester_pays: bool = False, job_id: Optional[str] = None):
super().__init__(src_path, dst_paths, True, requester_pays, job_id)
def __init__(self, src_path: str, dst_paths: List[str] or str, requester_pays: bool = False, job_id: Optional[str] = None, src_iface: Optional[ObjectStoreInterface] = None, dst_ifaces: Optional[List[ObjectStoreInterface]] = None):
super().__init__(src_path, dst_paths, True, requester_pays, job_id, src_iface, dst_ifaces)
self.transfer_list = []
self.multipart_transfer_list = []

Expand Down
4 changes: 4 additions & 0 deletions skyplane/compute/aws/aws_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def __init__(self, config: Optional[SkyplaneConfig] = None, access_key: Optional
self.config_mode = "manual"
self._access_key = access_key
self._secret_key = secret_key
elif self.config.aws_access_key and self.config.aws_secret_key:
self.config_mode = "manual"
self._access_key = self.config.aws_access_key
self._secret_key = self.config.aws_secret_key
else:
self.config_mode = "iam_inferred"
self._access_key = None
Expand Down
4 changes: 2 additions & 2 deletions skyplane/compute/aws/aws_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_instance_list(exceptions, self, region: str) -> List[AWSServer]:
except exceptions.ClientError as e:
logger.error(f"error provisioning in {region}: {e}")
return []
return [AWSServer(f"aws:{region}", i) for i in instance_ids]
return [AWSServer(f"aws:{region}", i, auth=self.auth) for i in instance_ids]

def setup_global(self, iam_name: str = "skyplane_gateway", attach_policy_arn: Optional[str] = None):
# Create IAM role if it doesn't exist and grant managed role if given.
Expand Down Expand Up @@ -246,4 +246,4 @@ def start_instance(subnet_id: str):
logger.fs.warning(f"Terminating instance {instance[0].id} due to keyboard interrupt")
instance[0].terminate()
raise
return AWSServer(f"aws:{region}", instance[0].id)
return AWSServer(f"aws:{region}", instance[0].id, auth=self.auth)
4 changes: 2 additions & 2 deletions skyplane/compute/aws/aws_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
class AWSServer(Server):
"""AWS Server class to support basic SSH operations"""

def __init__(self, region_tag, instance_id, log_dir=None):
def __init__(self, region_tag, instance_id, log_dir=None, auth: Optional[AWSAuthentication] = None):
super().__init__(region_tag, log_dir=log_dir)
assert self.region_tag.split(":")[0] == "aws"
self.auth = AWSAuthentication()
self.auth = AWSAuthentication() if auth is None else auth
self.key_manager = AWSKeyManager(self.auth)
self.aws_region = self.region_tag.split(":")[1]
self.instance_id = instance_id
Expand Down
14 changes: 14 additions & 0 deletions skyplane/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class SkyplaneConfig:
cloudflare_enabled: bool
ibmcloud_enabled: bool
anon_clientid: str
aws_access_key: Optional[str] = None
aws_secret_key: Optional[str] = None
azure_principal_id: Optional[str] = None
azure_subscription_id: Optional[str] = None
azure_resource_group: Optional[str] = None
Expand Down Expand Up @@ -148,9 +150,15 @@ def load_config(cls, path) -> "SkyplaneConfig":
anon_clientid = cls.generate_machine_id()

aws_enabled = False
aws_access_key = None
aws_secret_access_key = None
if "aws" in config:
if "aws_enabled" in config["aws"]:
aws_enabled = config.getboolean("aws", "aws_enabled")
if "aws_access_key" in config["aws"]:
aws_access_key = config.get("aws", "aws_access_key")
if "aws_secret_key" in config["aws"]:
aws_secret_key = config.get("aws", "aws_secret_key")

azure_enabled = False
azure_subscription_id = None
Expand Down Expand Up @@ -215,6 +223,8 @@ def load_config(cls, path) -> "SkyplaneConfig":
gcp_enabled=gcp_enabled,
ibmcloud_enabled=ibmcloud_enabled,
cloudflare_enabled=cloudflare_enabled,
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
anon_clientid=anon_clientid,
azure_principal_id=azure_principal_id,
azure_subscription_id=azure_subscription_id,
Expand Down Expand Up @@ -248,6 +258,10 @@ def to_config_file(self, path):
if "aws" not in config:
config.add_section("aws")
config.set("aws", "aws_enabled", str(self.aws_enabled))
if self.aws_access_key:
config.set("aws", "aws_access_key", self.aws_access_key)
if self.aws_secret_key:
config.set("aws", "aws_secret_key", self.aws_secret_key)

if "ibmcloud" not in config:
config.add_section("ibmcloud")
Expand Down
7 changes: 5 additions & 2 deletions skyplane/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@


class Planner:
def __init__(self, transfer_config: TransferConfig, quota_limits_file: Optional[str] = None):
def __init__(self, transfer_config: TransferConfig, config: Optional[SkyplaneConfig] = None, quota_limits_file: Optional[str] = None):
self.transfer_config = transfer_config
self.config = SkyplaneConfig.load_config(config_path)
if config_path.exists():
self.config = SkyplaneConfig.load_config(config_path)
else:
self.config = SkyplaneConfig.default_config()
self.n_instances = self.config.get_flag("max_instances")

# Loading the quota information, add ibm cloud when it is supported
Expand Down