Skip to content

Commit

Permalink
[plugins/aws][fix] Set default region on all sessions (#1623)
Browse files Browse the repository at this point in the history
* [plugins/aws][fix] Set default region on all sessions

* Probe partition for manually specified accounts

* Syntax
  • Loading branch information
lloesche committed May 31, 2023
1 parent 79cec0f commit 0f21094
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 16 deletions.
63 changes: 51 additions & 12 deletions plugins/aws/resoto_plugin_aws/__init__.py
Expand Up @@ -45,6 +45,8 @@
["account"],
)

GLOBAL_REGIONS = ("us-east-1", "us-gov-west-1", "cn-north-1")


class AWSCollectorPlugin(BaseCollectorPlugin):
cloud = "aws"
Expand All @@ -60,7 +62,7 @@ def add_config(cfg: Config) -> None:

@staticmethod
def auto_enableable() -> bool:
for region in ("us-east-1", "us-gov-west-1", "cn-north-1"):
for region in GLOBAL_REGIONS:
try:
account_id = (
boto3.session.Session(region_name=region).client("sts").get_caller_identity().get("Account")
Expand Down Expand Up @@ -139,7 +141,12 @@ def collect_aws(self) -> None:
def regions(self, profile: Optional[str] = None, partition: str = "aws") -> List[str]:
if len(self.__regions) == 0:
if not Config.aws.region or (isinstance(Config.aws.region, list) and len(Config.aws.region) == 0):
log.debug("AWS region not specified, assuming all regions")
add_log_str = ""
if profile:
add_log_str += f" profile {profile}"
if partition:
add_log_str += f" partition {partition}"
log.debug(f"AWS region not specified, assuming all regions{add_log_str}")
self.__regions = all_regions(profile=profile, partition=partition)
else:
self.__regions = list(Config.aws.region)
Expand Down Expand Up @@ -410,11 +417,22 @@ def cleanup(config: Config, resource: BaseResource, graph: Graph) -> bool:

def authenticated(account: AwsAccount, core_feedback: CoreFeedback) -> bool:
try:
log.debug(f"AWS testing credentials for {account.rtdname}")
add_log_str = ""
if account.role:
add_log_str += f" role {account.role}"
if account.profile:
add_log_str += f" profile {account.profile}"
if account.partition:
add_log_str += f" partition {account.partition}"
log.debug(f"AWS testing credentials for {account.rtdname}{add_log_str}")
session = aws_session(
account=account.id, role=account.role, profile=account.profile, partition=account.partition
)
_ = session.client("sts").get_caller_identity().get("Account")
_ = (
session.client("sts", region_name=global_region_by_partition(account.partition))
.get_caller_identity()
.get("Account")
)
except botocore.exceptions.NoCredentialsError:
core_feedback.error(f"No AWS credentials found for {account.rtdname}", log)
except botocore.exceptions.ClientError as e:
Expand All @@ -439,9 +457,28 @@ def current_account_id(profile: Optional[str] = None) -> str:
return account_id


def probe_partition(account: Optional[str] = None, role: Optional[str] = None, profile: Optional[str] = None) -> str:
for region in GLOBAL_REGIONS:
partition = arn_partition_by_region(region)
try:
session = aws_session(account=account, role=role, profile=profile, partition=partition)
_ = session.client("sts", region_name=region).get_caller_identity().get("Account")
except Exception:
pass
else:
return partition
return "aws"


def current_account_id_and_partition(profile: Optional[str] = None) -> Tuple[str, str]:
interesting_exception = None
for region in ("us-east-1", "us-gov-west-1", "cn-north-1"):
add_log_str = ""
if profile:
add_log_str = f" with profile {profile}"
log.debug(f"Trying to determine current account id and partition{add_log_str}")
for region in GLOBAL_REGIONS:
partition = arn_partition_by_region(region)
log.debug(f"Probing region {region}")
try:
if profile:
account_id = (
Expand All @@ -461,11 +498,14 @@ def current_account_id_and_partition(profile: Optional[str] = None) -> Tuple[str
.get_caller_identity()
.get("Account")
)
return account_id, arn_partition_by_region(region)
log.debug(f"Determined partition for account {account_id} to be {partition}")
return account_id, partition
except botocore.exceptions.ClientError as e:
log.debug(f"Got an exception when probing partition {partition}: {e}")
if e.response["Error"]["Code"] != "InvalidClientTokenId":
interesting_exception = e
except Exception as e:
log.debug(f"Got an exception when probing partition {partition}: {e}")
interesting_exception = e
if interesting_exception:
raise interesting_exception
Expand Down Expand Up @@ -543,12 +583,11 @@ def get_accounts(core_feedback: CoreFeedback) -> List[AwsAccount]:
accounts.append(AwsAccount(id=account_id, partition=partition))
elif Config.aws.role and Config.aws.account:
log.debug("Both, role and list of accounts specified")
accounts.extend(
[
AwsAccount(id=aws_account_id, role=Config.aws.role, profile=profile)
for aws_account_id in Config.aws.account
]
)
for aws_account_id in Config.aws.account:
partition = probe_partition(aws_account_id, profile=profile)
accounts.append(
AwsAccount(id=aws_account_id, role=Config.aws.role, profile=profile, partition=partition)
)
else:
account_id, partition = current_account_id_and_partition(profile=profile)
accounts.extend([AwsAccount(id=account_id, profile=profile, partition=partition)])
Expand Down
12 changes: 8 additions & 4 deletions plugins/aws/resoto_plugin_aws/configuration.py
Expand Up @@ -34,12 +34,15 @@ class AwsSessionHolder:

# noinspection PyUnusedLocal
@lru_cache(maxsize=128)
def __direct_session(self, profile: Optional[str]) -> BotoSession:
def __direct_session(self, profile: Optional[str], partition: str) -> BotoSession:
global_region = global_region_by_partition(partition)
if profile:
return self.session_class_factory(profile_name=profile)
return self.session_class_factory(profile_name=profile, region_name=global_region)
else:
return self.session_class_factory(
aws_access_key_id=self.access_key_id, aws_secret_access_key=self.secret_access_key
aws_access_key_id=self.access_key_id,
aws_secret_access_key=self.secret_access_key,
region_name=global_region,
)

# noinspection PyUnusedLocal
Expand Down Expand Up @@ -68,6 +71,7 @@ def __sts_session(
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
region_name=global_region,
)

def _session(
Expand All @@ -82,7 +86,7 @@ def _session(
Consider using the client() and resource() methods instead.
"""
if aws_role is None:
return self.__direct_session(aws_profile)
return self.__direct_session(aws_profile, aws_partition)
else:
# Use sts to create a temporary token for the given account and role
# Sts session is at least valid for 900 seconds (default 1 hour)
Expand Down
3 changes: 3 additions & 0 deletions plugins/aws/resoto_plugin_aws/utils.py
Expand Up @@ -63,16 +63,19 @@ def aws_session(
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
region_name=global_region,
)
else:
if profile:
return BotoSession(
profile_name=profile,
region_name=global_region,
)
else:
return BotoSession(
aws_access_key_id=Config.aws.access_key_id,
aws_secret_access_key=Config.aws.secret_access_key,
region_name=global_region,
)


Expand Down

0 comments on commit 0f21094

Please sign in to comment.