Skip to content

Commit

Permalink
AWS Auth Propagation for API (#923)
Browse files Browse the repository at this point in the history
  • Loading branch information
abiswal2001 committed Oct 20, 2023
1 parent e734ae5 commit 9ac0764
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 24 deletions.
64 changes: 61 additions & 3 deletions skyplane/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
import typer
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

from skyplane.api.config import TransferConfig
from skyplane.api.provisioner import Provisioner
from skyplane.api.obj_store import ObjectStore
from skyplane.config import SkyplaneConfig
from skyplane.config_paths import config_path
from skyplane.obj_store.object_store_interface import ObjectStoreInterface
from skyplane.obj_store.storage_interface import StorageInterface
from skyplane.api.usage import get_clientid
from skyplane.utils import logger
from skyplane.utils.definitions import tmp_log_dir
from skyplane.utils.path import parse_path

from skyplane.api.pipeline import Pipeline

Expand All @@ -28,6 +33,10 @@ def __init__(
ibmcloud_config: Optional["IBMCloudConfig"] = None,
transfer_config: Optional[TransferConfig] = None,
log_dir: Optional[str] = None,
disable_aws: bool = False,
disable_azure: bool = False,
disable_gcp: bool = False,
disable_ibm: bool = False
):
"""
:param aws_config: aws cloud configurations
Expand Down Expand Up @@ -66,16 +75,53 @@ def __init__(
azure_auth=self.azure_auth,
gcp_auth=self.gcp_auth,
ibmcloud_auth=self.ibmcloud_auth,
disable_aws=disable_aws,
disable_azure=disable_azure,
disable_gcp=disable_gcp,
disable_ibm=disable_ibm
)

def pipeline(self, planning_algorithm: Optional[str] = "direct", max_instances: Optional[int] = 1, debug=False):
self.config = SkyplaneConfig.default_config()
if not disable_aws:
self.config.aws_enabled = True
if aws_config:
self.config.aws_access_key = aws_config.aws_access_key
self.config.aws_secret_key = aws_config.aws_secret_key
if not disable_azure:
self.config.azure_enabled = True
if azure_config:
self.config.azure_subscription_id=azure_config.azure_subscription_id
self.config.azure_resource_group=azure_config.azure_resource_group
self.config.azure_principal_id=azure_config.azure_umi_id
self.config.azure_umi_name=azure_config.azure_umi_name
self.config.azure_client_id=azure_config.azure_umi_client_id
if not disable_gcp:
self.config.gcp_enabled = True
if gcp_config:
self.config.gcp_project_id=gcp_config.gcp_project_id
if not disable_ibm:
self.config.ibm_enabled = True
if ibm_config:
self.config.ibmcloud_access_id=ibm_config.ibmcloud_access_id
self.config.ibmcloud_secret_key=ibm_config.ibmcloud_secret_key
self.config.ibmcloud_iam_key=ibm_config.ibmcloud_iam_key
self.config.ibmcloud_iam_endpoint=ibm_config.ibmcloud_iam_endpoint
self.config.ibmcloud_useragent=ibm_config.ibmcloud_useragent
self.config.ibmcloud_resource_group_id=ibm_config.ibmcloud_resource_group_id

self.config.to_config_file(config_path)
typer.secho(f"\nConfig file saved to {config_path}", fg="green")

def pipeline(self, planning_algorithm: Optional[str] = "direct", max_instances: Optional[int] = 1, src_iface: Optional[ObjectStoreInterface] = None, dst_ifaces: Optional[List[ObjectStoreInterface]] = None, debug=False):
"""Create a pipeline object to queue jobs"""
return Pipeline(
planning_algorithm=planning_algorithm,
max_instances=max_instances,
clientid=self.clientid,
provisioner=self.provisioner,
transfer_config=self.transfer_config,
src_iface=src_iface,
dst_ifaces=dst_ifaces,
debug=debug,
)

Expand All @@ -93,8 +139,20 @@ def copy(self, src: str, dst: str, recursive: bool = False, max_instances: Optio
:param max_instances: The maximum number of instances to use per region (default: 1)
:type max_instances: int
"""
provider_src, bucket_src, _ = parse_path(src)

src_iface = ObjectStoreInterface.create(f"{provider_src}:infer", bucket_src, aws_auth=self.aws_auth, azure_auth=self.azure_auth, gcp_auth=self.gcp_auth)

if isinstance(dst, str):
provider_dst, bucket_dst, _ = parse_path(dst)
dst_ifaces = [StorageInterface.create(f"{provider_dst}:infer", bucket_dst, aws_auth=self.aws_auth, azure_auth=self.azure_auth, gcp_auth=self.gcp_auth)]
else:
dst_ifaces = []
for dst_path in dst:
provider_dst, bucket_dst, _ = parse_path(dst_path)
dst_ifaces.append(StorageInterface.create(f"{provider_dst}:infer", bucket_dst, aws_auth=self.aws_auth, azure_auth=self.azure_auth, gcp_auth=self.gcp_auth))

pipeline = self.pipeline(max_instances=max_instances, debug=debug)
pipeline = self.pipeline(max_instances=max_instances, debug=debug, src_iface = src_iface, dst_ifaces=dst_ifaces)
pipeline.queue_copy(src, dst, recursive=recursive)
pipeline.start(progress=True)

Expand Down
11 changes: 9 additions & 2 deletions skyplane/api/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skyplane.api.transfer_job import CopyJob, SyncJob, TransferJob
from skyplane.api.config import TransferConfig

from skyplane.obj_store.object_store_interface import ObjectStoreInterface
from skyplane.planner.planner import MulticastDirectPlanner, DirectPlannerSourceOneSided, DirectPlannerDestOneSided
from skyplane.planner.topology import TopologyPlanGateway
from skyplane.utils import logger
Expand All @@ -33,6 +34,8 @@ def __init__(
max_instances: Optional[int] = 1,
n_connections: Optional[int] = 64,
planning_algorithm: Optional[str] = "direct",
src_iface: Optional[ObjectStoreInterface] = None,
dst_ifaces: Optional[List[ObjectStoreInterface]] = None,
debug: Optional[bool] = False,
):
"""
Expand Down Expand Up @@ -70,6 +73,10 @@ def __init__(
else:
raise ValueError(f"No such planning algorithm {planning_algorithm}")

# obj store interfaces
self.src_iface = src_iface
self.dst_ifaces = dst_ifaces

# transfer logs
self.transfer_dir = tmp_log_dir / "transfer_logs" / datetime.now().strftime("%Y%m%d_%H%M%S")
self.transfer_dir.mkdir(exist_ok=True, parents=True)
Expand Down Expand Up @@ -146,7 +153,7 @@ def queue_copy(
"""
if isinstance(dst, str):
dst = [dst]
job = CopyJob(src, dst, recursive, requester_pays=self.transfer_config.requester_pays)
job = CopyJob(src, dst, recursive, requester_pays=self.transfer_config.requester_pays, src_iface=self.src_iface, dst_ifaces=self.dst_ifaces)
logger.fs.debug(f"[SkyplaneClient] Queued copy job {job}")
self.jobs_to_dispatch.append(job)
return job.uuid
Expand All @@ -169,7 +176,7 @@ def queue_sync(
"""
if isinstance(dst, str):
dst = [dst]
job = SyncJob(src, dst, requester_pays=self.transfer_config.requester_pays)
job = SyncJob(src, dst, requester_pays=self.transfer_config.requester_pays, src_iface=self.src_iface, dst_ifaces=self.dst_ifaces)
logger.fs.debug(f"[SkyplaneClient] Queued sync job {job}")
self.jobs_to_dispatch.append(job)
return job.uuid
Expand Down
24 changes: 16 additions & 8 deletions skyplane/api/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def __init__(
gcp_auth: Optional[compute.GCPAuthentication] = None,
host_uuid: Optional[str] = None,
ibmcloud_auth: Optional[compute.IBMCloudAuthentication] = None,
disable_aws: bool = False,
disable_azure: bool = False,
disable_gcp: bool = False,
disable_ibm: bool = False
):
"""
:param aws_auth: authentication information for aws
Expand All @@ -70,21 +74,25 @@ def __init__(
self.gcp_auth = gcp_auth
self.host_uuid = host_uuid
self.ibmcloud_auth = ibmcloud_auth
self._make_cloud_providers()
self._make_cloud_providers(disable_aws, disable_azure, disable_gcp, disable_ibm)
self.temp_nodes: Set[compute.Server] = set() # temporary area to store nodes that should be terminated upon exit
self.pending_provisioner_tasks: List[ProvisionerTask] = []
self.provisioned_vms: Dict[str, compute.Server] = {}

# store GCP firewall rules to be deleted upon exit
self.gcp_firewall_rules: Set[str] = set()

def _make_cloud_providers(self):
self.aws = compute.AWSCloudProvider(
key_prefix=f"skyplane{'-'+self.host_uuid.replace('-', '') if self.host_uuid else ''}", auth=self.aws_auth
)
self.azure = compute.AzureCloudProvider(auth=self.azure_auth)
self.gcp = compute.GCPCloudProvider(auth=self.gcp_auth)
self.ibmcloud = compute.IBMCloudProvider(auth=self.ibmcloud_auth)
def _make_cloud_providers(self, disable_aws, disable_azure, disable_gcp, disable_ibm):
if not disable_aws:
self.aws = compute.AWSCloudProvider(
key_prefix=f"skyplane{'-'+self.host_uuid.replace('-', '') if self.host_uuid else ''}", auth=self.aws_auth
)
if not disable_azure:
self.azure = compute.AzureCloudProvider(auth=self.azure_auth)
if not disable_gcp:
self.gcp = compute.GCPCloudProvider(auth=self.gcp_auth)
if not disable_ibm:
self.ibmcloud = compute.IBMCloudProvider(auth=self.ibmcloud_auth)

def init_global(self, aws: bool = True, azure: bool = True, gcp: bool = True, ibmcloud: bool = True):
"""
Expand Down
16 changes: 11 additions & 5 deletions skyplane/api/transfer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ def __init__(
recursive: bool = False,
requester_pays: bool = False,
job_id: Optional[str] = None,
src_iface: Optional[ObjectStoreInterface] = None,
dst_ifaces: Optional[List[ObjectStoreInterface]] = None
):
self.src_path = src_path
self.dst_paths = dst_paths
Expand All @@ -477,6 +479,8 @@ def __init__(
self.uuid = str(uuid.uuid4())
else:
self.uuid = job_id
self._src_iface = src_iface
self._dst_ifaces = dst_ifaces

@property
def transfer_type(self) -> str:
Expand All @@ -495,7 +499,7 @@ def src_prefix(self) -> Optional[str]:
@property
def src_iface(self) -> StorageInterface:
"""Return the source object store interface"""
if not hasattr(self, "_src_iface"):
if not self._src_iface:
provider_src, bucket_src, _ = parse_path(self.src_path)
self._src_iface = ObjectStoreInterface.create(f"{provider_src}:infer", bucket_src)
if self.requester_pays:
Expand All @@ -515,7 +519,7 @@ def dst_prefixes(self) -> List[str]:
@property
def dst_ifaces(self) -> List[StorageInterface]:
"""Return the destination object store interface"""
if not hasattr(self, "_dst_iface"):
if not self._dst_ifaces:
if self.transfer_type == "unicast":
provider_dst, bucket_dst, _ = parse_path(self.dst_paths[0])
self._dst_ifaces = [StorageInterface.create(f"{provider_dst}:infer", bucket_dst)]
Expand Down Expand Up @@ -570,8 +574,10 @@ def __init__(
recursive: bool = False,
requester_pays: bool = False,
job_id: Optional[str] = None,
src_iface: Optional[ObjectStoreInterface] = None,
dst_ifaces: Optional[List[ObjectStoreInterface]] = None
):
super().__init__(src_path, dst_paths, recursive, requester_pays, job_id)
super().__init__(src_path, dst_paths, recursive, requester_pays, job_id, src_iface, dst_ifaces)
self.transfer_list = []
self.multipart_transfer_list = []

Expand Down Expand Up @@ -772,8 +778,8 @@ def size_gb(self):
class SyncJob(CopyJob):
"""sync job that copies the source objects that does not exist in the destination bucket to the destination"""

def __init__(self, src_path: str, dst_paths: List[str] or str, requester_pays: bool = False, job_id: Optional[str] = None):
super().__init__(src_path, dst_paths, True, requester_pays, job_id)
def __init__(self, src_path: str, dst_paths: List[str] or str, requester_pays: bool = False, job_id: Optional[str] = None, src_iface: Optional[ObjectStoreInterface] = None, dst_ifaces: Optional[List[ObjectStoreInterface]] = None):
super().__init__(src_path, dst_paths, True, requester_pays, job_id, src_iface, dst_ifaces)
self.transfer_list = []
self.multipart_transfer_list = []

Expand Down
4 changes: 4 additions & 0 deletions skyplane/compute/aws/aws_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def __init__(self, config: Optional[SkyplaneConfig] = None, access_key: Optional
self.config_mode = "manual"
self._access_key = access_key
self._secret_key = secret_key
elif self.config.aws_access_key and self.config.aws_secret_key:
self.config_mode = "manual"
self._access_key = self.config.aws_access_key
self._secret_key = self.config.aws_secret_key
else:
self.config_mode = "iam_inferred"
self._access_key = None
Expand Down
4 changes: 2 additions & 2 deletions skyplane/compute/aws/aws_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_instance_list(exceptions, self, region: str) -> List[AWSServer]:
except exceptions.ClientError as e:
logger.error(f"error provisioning in {region}: {e}")
return []
return [AWSServer(f"aws:{region}", i) for i in instance_ids]
return [AWSServer(f"aws:{region}", i, auth=self.auth) for i in instance_ids]

def setup_global(self, iam_name: str = "skyplane_gateway", attach_policy_arn: Optional[str] = None):
# Create IAM role if it doesn't exist and grant managed role if given.
Expand Down Expand Up @@ -246,4 +246,4 @@ def start_instance(subnet_id: str):
logger.fs.warning(f"Terminating instance {instance[0].id} due to keyboard interrupt")
instance[0].terminate()
raise
return AWSServer(f"aws:{region}", instance[0].id)
return AWSServer(f"aws:{region}", instance[0].id, auth=self.auth)
4 changes: 2 additions & 2 deletions skyplane/compute/aws/aws_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
class AWSServer(Server):
"""AWS Server class to support basic SSH operations"""

def __init__(self, region_tag, instance_id, log_dir=None):
def __init__(self, region_tag, instance_id, log_dir=None, auth: Optional[AWSAuthentication] = None):
super().__init__(region_tag, log_dir=log_dir)
assert self.region_tag.split(":")[0] == "aws"
self.auth = AWSAuthentication()
self.auth = AWSAuthentication() if auth is None else auth
self.key_manager = AWSKeyManager(self.auth)
self.aws_region = self.region_tag.split(":")[1]
self.instance_id = instance_id
Expand Down
14 changes: 14 additions & 0 deletions skyplane/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class SkyplaneConfig:
cloudflare_enabled: bool
ibmcloud_enabled: bool
anon_clientid: str
aws_access_key: Optional[str] = None
aws_secret_key: Optional[str] = None
azure_principal_id: Optional[str] = None
azure_subscription_id: Optional[str] = None
azure_resource_group: Optional[str] = None
Expand Down Expand Up @@ -148,9 +150,15 @@ def load_config(cls, path) -> "SkyplaneConfig":
anon_clientid = cls.generate_machine_id()

aws_enabled = False
aws_access_key = None
aws_secret_access_key = None
if "aws" in config:
if "aws_enabled" in config["aws"]:
aws_enabled = config.getboolean("aws", "aws_enabled")
if "aws_access_key" in config["aws"]:
aws_access_key = config.get("aws", "aws_access_key")
if "aws_secret_key" in config["aws"]:
aws_secret_key = config.get("aws", "aws_secret_key")

azure_enabled = False
azure_subscription_id = None
Expand Down Expand Up @@ -215,6 +223,8 @@ def load_config(cls, path) -> "SkyplaneConfig":
gcp_enabled=gcp_enabled,
ibmcloud_enabled=ibmcloud_enabled,
cloudflare_enabled=cloudflare_enabled,
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
anon_clientid=anon_clientid,
azure_principal_id=azure_principal_id,
azure_subscription_id=azure_subscription_id,
Expand Down Expand Up @@ -248,6 +258,10 @@ def to_config_file(self, path):
if "aws" not in config:
config.add_section("aws")
config.set("aws", "aws_enabled", str(self.aws_enabled))
if self.aws_access_key:
config.set("aws", "aws_access_key", self.aws_access_key)
if self.aws_secret_key:
config.set("aws", "aws_secret_key", self.aws_secret_key)

if "ibmcloud" not in config:
config.add_section("ibmcloud")
Expand Down
7 changes: 5 additions & 2 deletions skyplane/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@


class Planner:
def __init__(self, transfer_config: TransferConfig, quota_limits_file: Optional[str] = None):
def __init__(self, transfer_config: TransferConfig, config: Optional[SkyplaneConfig] = None, quota_limits_file: Optional[str] = None):
self.transfer_config = transfer_config
self.config = SkyplaneConfig.load_config(config_path)
if config_path.exists():
self.config = SkyplaneConfig.load_config(config_path)
else:
self.config = SkyplaneConfig.default_config()
self.n_instances = self.config.get_flag("max_instances")

# Loading the quota information, add ibm cloud when it is supported
Expand Down

0 comments on commit 9ac0764

Please sign in to comment.