From 0b21719746a892bd1d62978a038b7aa7aa140b16 Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Tue, 11 Mar 2025 14:07:22 -0400 Subject: [PATCH 01/14] Add basic IMDS credentials resolver --- .../credentials_resolvers/__init__.py | 7 +- .../credentials_resolvers/imds.py | 230 ++++++++++++++++++ 2 files changed, 236 insertions(+), 1 deletion(-) create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py index 3aead11b3..bf7d08c0a 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py @@ -2,5 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 from .environment import EnvironmentCredentialsResolver from .static import StaticCredentialsResolver +from .imds import IMDSCredentialsResolver -__all__ = ("EnvironmentCredentialsResolver", "StaticCredentialsResolver") +__all__ = ( + "EnvironmentCredentialsResolver", + "StaticCredentialsResolver", + "IMDSCredentialsResolver", +) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py new file mode 100644 index 000000000..4f625c35b --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -0,0 +1,230 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import json +import threading +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Literal + +from smithy_core import URI +from smithy_core.aio.interfaces.identity import IdentityResolver +from smithy_core.exceptions import SmithyIdentityException +from smithy_core.interfaces.identity import IdentityProperties +from smithy_core.interfaces.retries import RetryStrategy +from smithy_core.retries import SimpleRetryStrategy +from smithy_http import Field, Fields +from smithy_http.aio import HTTPRequest +from smithy_http.aio.interfaces import HTTPClient + +from smithy_aws_core.identity import AWSCredentialsIdentity + + +class Token: + """Represents an IMDSv2 session token with a value and method for checking + expiration.""" + + def __init__(self, value: bytes, ttl: int): + self._value = value + self._ttl = ttl + self._created_time = datetime.now() + + def is_expired(self) -> bool: + """Check if the token has expired.""" + return datetime.now() - self._created_time >= timedelta(seconds=self._ttl) + + @property + def value(self) -> bytes: + return self._value + + +class TokenCache: + """Holds the token needed to fetch instance metadata. In addition, it knows how to + refresh itself. + + :param HTTPClient http_client: The client used for making http requests. + :param int token_ttl: The time in seconds before a token expires. + """ + + _MIN_TTL = 5 + _MAX_TTL = 21600 + _TOKEN_PATH = "/latest/api/token" + + def __init__( + self, http_client: HTTPClient, base_uri: URI, token_ttl: int = _MAX_TTL + ): + self._http_client = http_client + self._base_uri = base_uri + self._token_ttl = self._validate_token_ttl(token_ttl) + self._refresh_lock = threading.Lock() + self._token = None + + def _validate_token_ttl(self, ttl: int) -> int: + """Validates the token TTL value.""" + if not self._MIN_TTL <= ttl <= self._MAX_TTL: + raise ValueError( + f"Token TTL must be between {self._MIN_TTL} and {self._MAX_TTL} seconds." + ) + return ttl + + def _should_refresh(self) -> bool: + """Determines if the token should be refreshed.""" + return self._token is None or self._token.is_expired() + + async def _refresh(self) -> None: + """Refreshes the token if needed, with thread safety.""" + with self._refresh_lock: + if not self._should_refresh(): + return + headers = Fields( + [ + # TODO: Add user-agent + Field( + name="x-aws-ec2-metadata-token-ttl-seconds", + values=[str(self._token_ttl)], + ), + ] + ) + request = HTTPRequest( + method="PUT", + destination=URI( + scheme=self._base_uri.scheme, + host=self._base_uri.host, + path=self._TOKEN_PATH, + ), + fields=headers, + ) + response = await self._http_client.send(request) + token_value = await response.consume_body_async() + self._token = Token(token_value, self._token_ttl) + + async def get_token(self) -> Token: + """Get the current token, refreshing it if expired.""" + if self._should_refresh(): + await self._refresh() + assert self._token is not None + return self._token + + +@dataclass(init=False) +class Config: + """Configuration for EC2Metadata.""" + + retry_strategy: RetryStrategy + endpoint_uri: URI + endpoint_mode: Literal["IPv4", "IPv6"] + port: int + token_ttl: int + + def __init__( + self, + *, + retry_strategy: RetryStrategy | None = None, + endpoint_uri: URI | None = None, + endpoint_mode: Literal["IPv4", "IPv6"] = "IPv4", + port: int = 80, + token_ttl: int = 21600, + ec2_instance_profile_name: str | None = None, + ): + self.retry_strategy = retry_strategy or SimpleRetryStrategy(max_attempts=3) + self.endpoint_mode = endpoint_mode + self.endpoint_uri = self._resolve_endpoint(endpoint_uri, endpoint_mode) + self.port = port + self.token_ttl = token_ttl + self.ec2_instance_profile_name = ec2_instance_profile_name + + def _resolve_endpoint( + self, endpoint_uri: URI | None, endpoint_mode: Literal["IPv4", "IPv6"] + ) -> URI: + if endpoint_uri is not None: + return endpoint_uri + + host_mapping = {"IPv4": "169.254.169.254", "IPv6": "[fd00:ec2::254]"} + + return URI( + scheme="http", host=host_mapping.get(endpoint_mode, host_mapping["IPv4"]) + ) + + +class EC2Metadata: + def __init__(self, http_client: HTTPClient, config: Config | None = None): + self._http_client = http_client + self._config = config or Config() + self._token_cache = TokenCache( + http_client=self._http_client, + base_uri=self._config.endpoint_uri, + token_ttl=self._config.token_ttl, + ) + + async def get(self, *, path: str) -> str: + token = await self._token_cache.get_token() + headers = Fields( + [ + # TODO: Add user-agent + Field( + name="x-aws-ec2-metadata-token", + values=[token.value.decode("utf-8")], + ) + ] + ) + request = HTTPRequest( + method="GET", + destination=URI( + scheme=self._config.endpoint_uri.scheme, + host=self._config.endpoint_uri.host, + port=self._config.port, + path=path, + ), + fields=headers, + ) + response = await self._http_client.send(request=request) + body = await response.consume_body_async() + return body.decode("utf-8") + + +class IMDSCredentialsResolver( + IdentityResolver[AWSCredentialsIdentity, IdentityProperties] +): + """Resolves AWS Credentials from an EC2 Instance Metadata Service (IMDS) client.""" + + # TODO: Handle fallback to legacy path when a 404 is received. + _METADATA_PATH_BASE = "/latest/meta-data/iam/security-credentials-extended/" + + def __init__(self, http_client: HTTPClient, config: Config | None = None): + self._http_client = http_client + self._ec2_metadata_client = EC2Metadata(http_client=http_client, config=config) + self._config = config or Config() + self._credentials = None + self._profile_name = self._config.ec2_instance_profile_name + + async def get_identity( + self, *, identity_properties: IdentityProperties + ) -> AWSCredentialsIdentity: + if self._credentials is not None: + return self._credentials + + profile = self._profile_name + if profile is None: + profile = await self._ec2_metadata_client.get(path=self._METADATA_PATH_BASE) + + creds_str = await self._ec2_metadata_client.get( + path=f"{self._METADATA_PATH_BASE}/{profile}" + ) + creds = json.loads(creds_str) + + access_key_id = creds.get("AccessKeyId") + secret_access_key = creds.get("SecretAccessKey") + session_token = creds.get("Token") + account_id = creds.get("AccountId") + + if access_key_id is None or secret_access_key is None: + raise SmithyIdentityException( + "AccessKeyId and SecretAccessKey are required" + ) + + self._credentials = AWSCredentialsIdentity( + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + account_id=account_id, + ) + return self._credentials From 3d7489e2a184832924a99c6b1e2f021ffdbb085d Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Wed, 12 Mar 2025 16:30:03 -0400 Subject: [PATCH 02/14] Use old `security-credentials` instead of `security-credentials-extended` --- .../src/smithy_aws_core/credentials_resolvers/imds.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py index 4f625c35b..58aeee38c 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -186,8 +186,7 @@ class IMDSCredentialsResolver( ): """Resolves AWS Credentials from an EC2 Instance Metadata Service (IMDS) client.""" - # TODO: Handle fallback to legacy path when a 404 is received. - _METADATA_PATH_BASE = "/latest/meta-data/iam/security-credentials-extended/" + _METADATA_PATH_BASE = "/latest/meta-data/iam/security-credentials" def __init__(self, http_client: HTTPClient, config: Config | None = None): self._http_client = http_client From a6cee6006da62eec6ba540a07b1d41b6cc88cc1e Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Wed, 12 Mar 2025 16:30:29 -0400 Subject: [PATCH 03/14] Use async lock --- .../src/smithy_aws_core/credentials_resolvers/imds.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py index 58aeee38c..d318ee74f 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -1,7 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import json -import threading +import asyncio from dataclasses import dataclass from datetime import datetime, timedelta from typing import Literal @@ -55,7 +55,7 @@ def __init__( self._http_client = http_client self._base_uri = base_uri self._token_ttl = self._validate_token_ttl(token_ttl) - self._refresh_lock = threading.Lock() + self._refresh_lock = asyncio.Lock() self._token = None def _validate_token_ttl(self, ttl: int) -> int: @@ -72,7 +72,7 @@ def _should_refresh(self) -> bool: async def _refresh(self) -> None: """Refreshes the token if needed, with thread safety.""" - with self._refresh_lock: + async with self._refresh_lock: if not self._should_refresh(): return headers = Fields( From df131f304d7ce49b3f266099042b563cf956588f Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Thu, 13 Mar 2025 09:52:20 -0400 Subject: [PATCH 04/14] Address initial feedback --- .../credentials_resolvers/imds.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py index d318ee74f..ab81c0f20 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -3,7 +3,7 @@ import json import asyncio from dataclasses import dataclass -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Literal from smithy_core import URI @@ -38,11 +38,9 @@ def value(self) -> bytes: class TokenCache: - """Holds the token needed to fetch instance metadata. In addition, it knows how to - refresh itself. + """Holds the token needed to fetch instance metadata. - :param HTTPClient http_client: The client used for making http requests. - :param int token_ttl: The time in seconds before a token expires. + In addition, it knows how to refresh itself. """ _MIN_TTL = 5 @@ -109,6 +107,8 @@ async def get_token(self) -> Token: class Config: """Configuration for EC2Metadata.""" + _HOST_MAPPING = {"IPv4": "169.254.169.254", "IPv6": "[fd00:ec2::254]"} + retry_strategy: RetryStrategy endpoint_uri: URI endpoint_mode: Literal["IPv4", "IPv6"] @@ -138,10 +138,9 @@ def _resolve_endpoint( if endpoint_uri is not None: return endpoint_uri - host_mapping = {"IPv4": "169.254.169.254", "IPv6": "[fd00:ec2::254]"} - return URI( - scheme="http", host=host_mapping.get(endpoint_mode, host_mapping["IPv4"]) + scheme="http", + host=self._HOST_MAPPING.get(endpoint_mode, self._HOST_MAPPING["IPv4"]), ) @@ -198,7 +197,11 @@ def __init__(self, http_client: HTTPClient, config: Config | None = None): async def get_identity( self, *, identity_properties: IdentityProperties ) -> AWSCredentialsIdentity: - if self._credentials is not None: + if ( + self._credentials is not None + and self._credentials.expiration + and datetime.now(timezone.utc) < self._credentials.expiration + ): return self._credentials profile = self._profile_name @@ -214,6 +217,9 @@ async def get_identity( secret_access_key = creds.get("SecretAccessKey") session_token = creds.get("Token") account_id = creds.get("AccountId") + expiration = creds.get("Expiration") + if expiration is not None: + expiration = datetime.fromisoformat(expiration).replace(tzinfo=timezone.utc) if access_key_id is None or secret_access_key is None: raise SmithyIdentityException( @@ -224,6 +230,7 @@ async def get_identity( access_key_id=access_key_id, secret_access_key=secret_access_key, session_token=session_token, + expiration=expiration, account_id=account_id, ) return self._credentials From 3f79e6f18ef030b4a3060aa4a4f0daea656e8c60 Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Thu, 13 Mar 2025 12:22:28 -0400 Subject: [PATCH 05/14] Move ttl validation to Config. Make config a parameter for TokenCache. --- .../credentials_resolvers/imds.py | 119 +++++++++--------- 1 file changed, 58 insertions(+), 61 deletions(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py index ab81c0f20..0562a09ec 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -19,6 +19,58 @@ from smithy_aws_core.identity import AWSCredentialsIdentity +@dataclass(init=False) +class Config: + """Configuration for EC2Metadata.""" + + _HOST_MAPPING = {"IPv4": "169.254.169.254", "IPv6": "[fd00:ec2::254]"} + _MIN_TTL = 5 + _MAX_TTL = 21600 + + retry_strategy: RetryStrategy + endpoint_uri: URI + endpoint_mode: Literal["IPv4", "IPv6"] + port: int + token_ttl: int + + def __init__( + self, + *, + retry_strategy: RetryStrategy | None = None, + endpoint_uri: URI | None = None, + endpoint_mode: Literal["IPv4", "IPv6"] = "IPv4", + port: int = 80, + token_ttl: int = _MAX_TTL, + ec2_instance_profile_name: str | None = None, + ): + # TODO: Implement retries. + self.retry_strategy = retry_strategy or SimpleRetryStrategy(max_attempts=3) + self.endpoint_mode = endpoint_mode + self.endpoint_uri = self._resolve_endpoint(endpoint_uri, endpoint_mode) + self.port = port + self.token_ttl = self._validate_token_ttl(token_ttl) + self.ec2_instance_profile_name = ec2_instance_profile_name + + def _validate_token_ttl(self, ttl: int) -> int: + """Validates the token TTL value.""" + if not self._MIN_TTL <= ttl <= self._MAX_TTL: + raise ValueError( + f"Token TTL must be between {self._MIN_TTL} and {self._MAX_TTL} seconds." + ) + return ttl + + def _resolve_endpoint( + self, endpoint_uri: URI | None, endpoint_mode: Literal["IPv4", "IPv6"] + ) -> URI: + if endpoint_uri is not None: + return endpoint_uri + + return URI( + scheme="http", + host=self._HOST_MAPPING.get(endpoint_mode, self._HOST_MAPPING["IPv4"]), + ) + + class Token: """Represents an IMDSv2 session token with a value and method for checking expiration.""" @@ -43,27 +95,15 @@ class TokenCache: In addition, it knows how to refresh itself. """ - _MIN_TTL = 5 - _MAX_TTL = 21600 _TOKEN_PATH = "/latest/api/token" - def __init__( - self, http_client: HTTPClient, base_uri: URI, token_ttl: int = _MAX_TTL - ): + def __init__(self, http_client: HTTPClient, config: Config): self._http_client = http_client - self._base_uri = base_uri - self._token_ttl = self._validate_token_ttl(token_ttl) + self._config = config + self._base_uri = config.endpoint_uri self._refresh_lock = asyncio.Lock() self._token = None - def _validate_token_ttl(self, ttl: int) -> int: - """Validates the token TTL value.""" - if not self._MIN_TTL <= ttl <= self._MAX_TTL: - raise ValueError( - f"Token TTL must be between {self._MIN_TTL} and {self._MAX_TTL} seconds." - ) - return ttl - def _should_refresh(self) -> bool: """Determines if the token should be refreshed.""" return self._token is None or self._token.is_expired() @@ -78,7 +118,7 @@ async def _refresh(self) -> None: # TODO: Add user-agent Field( name="x-aws-ec2-metadata-token-ttl-seconds", - values=[str(self._token_ttl)], + values=[str(self._config.token_ttl)], ), ] ) @@ -93,7 +133,7 @@ async def _refresh(self) -> None: ) response = await self._http_client.send(request) token_value = await response.consume_body_async() - self._token = Token(token_value, self._token_ttl) + self._token = Token(token_value, self._config.token_ttl) async def get_token(self) -> Token: """Get the current token, refreshing it if expired.""" @@ -103,55 +143,12 @@ async def get_token(self) -> Token: return self._token -@dataclass(init=False) -class Config: - """Configuration for EC2Metadata.""" - - _HOST_MAPPING = {"IPv4": "169.254.169.254", "IPv6": "[fd00:ec2::254]"} - - retry_strategy: RetryStrategy - endpoint_uri: URI - endpoint_mode: Literal["IPv4", "IPv6"] - port: int - token_ttl: int - - def __init__( - self, - *, - retry_strategy: RetryStrategy | None = None, - endpoint_uri: URI | None = None, - endpoint_mode: Literal["IPv4", "IPv6"] = "IPv4", - port: int = 80, - token_ttl: int = 21600, - ec2_instance_profile_name: str | None = None, - ): - self.retry_strategy = retry_strategy or SimpleRetryStrategy(max_attempts=3) - self.endpoint_mode = endpoint_mode - self.endpoint_uri = self._resolve_endpoint(endpoint_uri, endpoint_mode) - self.port = port - self.token_ttl = token_ttl - self.ec2_instance_profile_name = ec2_instance_profile_name - - def _resolve_endpoint( - self, endpoint_uri: URI | None, endpoint_mode: Literal["IPv4", "IPv6"] - ) -> URI: - if endpoint_uri is not None: - return endpoint_uri - - return URI( - scheme="http", - host=self._HOST_MAPPING.get(endpoint_mode, self._HOST_MAPPING["IPv4"]), - ) - - class EC2Metadata: def __init__(self, http_client: HTTPClient, config: Config | None = None): self._http_client = http_client self._config = config or Config() self._token_cache = TokenCache( - http_client=self._http_client, - base_uri=self._config.endpoint_uri, - token_ttl=self._config.token_ttl, + http_client=self._http_client, config=self._config ) async def get(self, *, path: str) -> str: From 28cc2d5f2f87f56f7acaa9cdd95b1d3eb6678b65 Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Thu, 13 Mar 2025 18:26:09 -0400 Subject: [PATCH 06/14] Store Token value as a string instead of bytes --- .../src/smithy_aws_core/credentials_resolvers/imds.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py index 0562a09ec..5b47c5e00 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -75,7 +75,7 @@ class Token: """Represents an IMDSv2 session token with a value and method for checking expiration.""" - def __init__(self, value: bytes, ttl: int): + def __init__(self, value: str, ttl: int): self._value = value self._ttl = ttl self._created_time = datetime.now() @@ -85,7 +85,7 @@ def is_expired(self) -> bool: return datetime.now() - self._created_time >= timedelta(seconds=self._ttl) @property - def value(self) -> bytes: + def value(self) -> str: return self._value @@ -133,7 +133,7 @@ async def _refresh(self) -> None: ) response = await self._http_client.send(request) token_value = await response.consume_body_async() - self._token = Token(token_value, self._config.token_ttl) + self._token = Token(token_value.decode("utf-8"), self._config.token_ttl) async def get_token(self) -> Token: """Get the current token, refreshing it if expired.""" @@ -158,7 +158,7 @@ async def get(self, *, path: str) -> str: # TODO: Add user-agent Field( name="x-aws-ec2-metadata-token", - values=[token.value.decode("utf-8")], + values=[token.value], ) ] ) From 878e7b2d693ae5ae13abbeaba6561e1c18e41ca0 Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Thu, 13 Mar 2025 19:21:28 -0400 Subject: [PATCH 07/14] Add tests for Config, Token, and TokenCache --- .../unit/credentials_resolvers/test_imds.py | 114 ++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py diff --git a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py new file mode 100644 index 000000000..4d8fa8af6 --- /dev/null +++ b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py @@ -0,0 +1,114 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# pyright: reportPrivateUsage=false +import pytest +import time +from smithy_core.retries import SimpleRetryStrategy +from smithy_core import URI +from smithy_aws_core.credentials_resolvers.imds import Config, Token, TokenCache +from unittest.mock import MagicMock, AsyncMock + + +def test_config_defaults(): + config = Config() + assert isinstance(config.retry_strategy, SimpleRetryStrategy) + assert config.endpoint_uri == URI(scheme="http", host=Config._HOST_MAPPING["IPv4"]) + assert config.endpoint_mode == "IPv4" + assert config.port == 80 + assert config.token_ttl == 21600 + + +def test_endpoint_resolution(): + config_ipv4 = Config(endpoint_mode="IPv4") + config_ipv6 = Config(endpoint_mode="IPv6") + assert config_ipv4.endpoint_uri.host == Config._HOST_MAPPING["IPv4"] + assert config_ipv6.endpoint_uri.host == Config._HOST_MAPPING["IPv6"] + + +def test_config_uses_custom_endpoint(): + # The custom endpoint should take precedence over IPv4 endpoint resolution. + config = Config( + endpoint_uri=URI(scheme="http", host="test.host"), endpoint_mode="IPv4" + ) + assert config.endpoint_uri == URI(scheme="http", host="test.host") + + # The custom endpoint takes precedence over IPv6 endpoint resolution. + config = Config( + endpoint_uri=URI(scheme="http", host="test.host"), endpoint_mode="IPv6" + ) + assert config.endpoint_uri == URI(scheme="http", host="test.host") + + +def test_config_ttl_validation(): + # TTL values < _MIN_TTL should throw a ValueError + with pytest.raises(ValueError): + Config(token_ttl=Config._MIN_TTL - 1) + # TTL values > _MAX_TTL should throw a ValueError + with pytest.raises(ValueError): + Config(token_ttl=Config._MAX_TTL + 1) + + +def test_token_creation(): + token = Token(value="test-token", ttl=100) + assert token._value == "test-token" + assert token._ttl == 100 + assert not token.is_expired() + + +def test_token_expiration(): + token = Token(value="test-token", ttl=1) + assert not token.is_expired() + time.sleep(1.1) + assert token.is_expired() + + +async def test_token_cache_should_refresh(): + http_client = AsyncMock() + config = MagicMock() + # A new token cache needs a refresh + token_cache = TokenCache(http_client, config) + assert token_cache._should_refresh() + # A token cache with an unexpired token doesn't need a refresh + token_cache._token = MagicMock() + token_cache._token.is_expired.return_value = False + assert not token_cache._should_refresh() + # A token cache with an expired token needs a refresh + token_cache._token.is_expired.return_value = True + assert token_cache._should_refresh() + + +async def test_token_cache_refresh(): + # Test that TokenCache correctly refreshes the token when needed + http_client = AsyncMock() + config = MagicMock() + config.token_ttl = 100 + config.endpoint_uri.scheme = "http" + config.endpoint_uri.host = "169.254.169.254" + response_mock = AsyncMock() + response_mock.consume_body_async.return_value = b"new-token-value" + http_client.send.return_value = response_mock + token_cache = TokenCache(http_client, config) + assert token_cache._should_refresh() + await token_cache._refresh() + assert token_cache._token is not None + assert token_cache._token.value == "new-token-value" + assert token_cache._token._ttl == 100 + + +async def test_token_cache_get_token(): + # Test that TokenCache correctly returns an existing token or refreshes if expired + http_client = AsyncMock() + config = MagicMock() + token_cache = TokenCache(http_client, config) + token_cache._refresh = AsyncMock() + token_cache._token = MagicMock() + token_cache._token.is_expired.return_value = False + token = await token_cache.get_token() + assert token == token_cache._token + token_cache._refresh.assert_not_awaited() + token_cache._token.is_expired.return_value = True + await token_cache.get_token() + token_cache._refresh.assert_awaited() + +# TODO: Add tests for EC2Metadata and IMDSCredentialsResolver \ No newline at end of file From 5fb2c4348b5329e1d3a7b8251e13bec20e29d34a Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Thu, 13 Mar 2025 19:24:17 -0400 Subject: [PATCH 08/14] whitespace linter fix --- .../tests/unit/credentials_resolvers/test_imds.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py index 4d8fa8af6..9d7a900d0 100644 --- a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py +++ b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py @@ -111,4 +111,5 @@ async def test_token_cache_get_token(): await token_cache.get_token() token_cache._refresh.assert_awaited() -# TODO: Add tests for EC2Metadata and IMDSCredentialsResolver \ No newline at end of file + +# TODO: Add tests for EC2Metadata and IMDSCredentialsResolver From f7daf20c2d782d2ce0bd3ee1f0c154e0369fc69f Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Fri, 14 Mar 2025 10:56:39 -0400 Subject: [PATCH 09/14] Add tests for EC2Metadata and IMDSCredentialsResolver classes --- .../unit/credentials_resolvers/test_imds.py | 66 ++++++++++++++++++- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py index 9d7a900d0..ef9999c85 100644 --- a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py +++ b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py @@ -2,11 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 # pyright: reportPrivateUsage=false +import json import pytest import time +from datetime import datetime, timezone from smithy_core.retries import SimpleRetryStrategy from smithy_core import URI -from smithy_aws_core.credentials_resolvers.imds import Config, Token, TokenCache +from smithy_http.aio import HTTPRequest +from smithy_aws_core.credentials_resolvers.imds import ( + Config, + Token, + TokenCache, + EC2Metadata, + IMDSCredentialsResolver, +) from unittest.mock import MagicMock, AsyncMock @@ -112,4 +121,57 @@ async def test_token_cache_get_token(): token_cache._refresh.assert_awaited() -# TODO: Add tests for EC2Metadata and IMDSCredentialsResolver +async def test_ec2_metadata_get(): + # Test EC2Metadata.get() method to retrieve metadata from IMDS + http_client = AsyncMock() + config = Config() + response = AsyncMock() + response.consume_body_async.return_value = b"metadata-response" + http_client.send.return_value = response + + ec2_metadata = EC2Metadata(http_client, config) + ec2_metadata._token_cache.get_token = AsyncMock( + return_value=Token("mocked-token", config.token_ttl) + ) + + result = await ec2_metadata.get(path="/test-path") + assert result == "metadata-response" + + request = http_client.send.call_args.kwargs["request"] + assert isinstance(request, HTTPRequest) + assert request.destination.path == "/test-path" + assert request.method == "GET" + assert request.fields["x-aws-ec2-metadata-token"].values == ["mocked-token"] + + +async def test_imds_credentials_resolver(): + # Test IMDSCredentialsResolver retrieving credentials + http_client = AsyncMock() + config = Config() + ec2_metadata = AsyncMock() + resolver = IMDSCredentialsResolver(http_client, config) + resolver._ec2_metadata_client = ec2_metadata + + # Mock EC2Metadata client get responses + ec2_metadata.get.side_effect = [ + "test-profile", + json.dumps( + { + "AccessKeyId": "test-access-key", + "SecretAccessKey": "test-secret-key", + "Token": "test-session-token", + "AccountId": "test-account", + "Expiration": "2025-03-13T07:28:47Z", + } + ), + ] + + credentials = await resolver.get_identity(identity_properties=MagicMock()) + assert credentials.access_key_id == "test-access-key" + assert credentials.secret_access_key == "test-secret-key" + assert credentials.session_token == "test-session-token" + assert credentials.account_id == "test-account" + assert credentials.expiration == datetime( + 2025, 3, 13, 7, 28, 47, tzinfo=timezone.utc + ) + ec2_metadata.get.assert_awaited() From 1b920a5cc016073e3cb43721aa8f522eaa185fcb Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Fri, 14 Mar 2025 11:12:48 -0400 Subject: [PATCH 10/14] Add todo for aws shared config and env var support --- .../src/smithy_aws_core/credentials_resolvers/imds.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py index 5b47c5e00..cd6348028 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -185,6 +185,7 @@ class IMDSCredentialsResolver( _METADATA_PATH_BASE = "/latest/meta-data/iam/security-credentials" def __init__(self, http_client: HTTPClient, config: Config | None = None): + # TODO: Respect IMDS specific config values from aws shared config file and environment. self._http_client = http_client self._ec2_metadata_client = EC2Metadata(http_client=http_client, config=config) self._config = config or Config() From b0622c38626dd672dcaebb9ce8f5893d18112af5 Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Mon, 17 Mar 2025 16:46:42 -0400 Subject: [PATCH 11/14] Only use reuse connections in the pool if they're open --- packages/smithy-http/src/smithy_http/aio/crt.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/packages/smithy-http/src/smithy_http/aio/crt.py b/packages/smithy-http/src/smithy_http/aio/crt.py index 0e17eba49..6183a232b 100644 --- a/packages/smithy-http/src/smithy_http/aio/crt.py +++ b/packages/smithy-http/src/smithy_http/aio/crt.py @@ -273,13 +273,15 @@ async def _get_connection( ) -> "crt_http.HttpClientConnection": # TODO: Use CRT connection pooling instead of this basic kind connection_key = (url.scheme, url.host, url.port) - if connection_key in self._connections: - return self._connections[connection_key] - else: - connection = await self._create_connection(url) - self._connections[connection_key] = connection + connection = self._connections.get(connection_key) + + if connection and connection.is_open(): return connection + connection = await self._create_connection(url) + self._connections[connection_key] = connection + return connection + def _build_new_connection( self, url: core_interfaces.URI ) -> ConcurrentFuture["crt_http.HttpClientConnection"]: From 8db26d2d2ee735bbff026b3fc53bda731c7a5fd7 Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Tue, 18 Mar 2025 13:59:18 -0400 Subject: [PATCH 12/14] Remove unnecessary doc strings --- .../src/smithy_aws_core/credentials_resolvers/imds.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py index cd6348028..184348aae 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -52,7 +52,6 @@ def __init__( self.ec2_instance_profile_name = ec2_instance_profile_name def _validate_token_ttl(self, ttl: int) -> int: - """Validates the token TTL value.""" if not self._MIN_TTL <= ttl <= self._MAX_TTL: raise ValueError( f"Token TTL must be between {self._MIN_TTL} and {self._MAX_TTL} seconds." @@ -81,7 +80,6 @@ def __init__(self, value: str, ttl: int): self._created_time = datetime.now() def is_expired(self) -> bool: - """Check if the token has expired.""" return datetime.now() - self._created_time >= timedelta(seconds=self._ttl) @property @@ -105,11 +103,9 @@ def __init__(self, http_client: HTTPClient, config: Config): self._token = None def _should_refresh(self) -> bool: - """Determines if the token should be refreshed.""" return self._token is None or self._token.is_expired() async def _refresh(self) -> None: - """Refreshes the token if needed, with thread safety.""" async with self._refresh_lock: if not self._should_refresh(): return @@ -136,7 +132,6 @@ async def _refresh(self) -> None: self._token = Token(token_value.decode("utf-8"), self._config.token_ttl) async def get_token(self) -> Token: - """Get the current token, refreshing it if expired.""" if self._should_refresh(): await self._refresh() assert self._token is not None From dfac021fd9b5c883ac1882611fec9c82e997de69 Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Tue, 18 Mar 2025 14:01:47 -0400 Subject: [PATCH 13/14] testing --- .../smithy_aws_core/credentials_resolvers/imds.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py index 184348aae..224501036 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json import asyncio +import smithy_aws_core from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import Literal @@ -18,6 +19,11 @@ from smithy_aws_core.identity import AWSCredentialsIdentity +_USER_AGENT_FIELD = Field( + name="User-Agent", + values=[f"aws-sdk-python-imds-client/{smithy_aws_core.__version__}"], +) + @dataclass(init=False) class Config: @@ -111,7 +117,7 @@ async def _refresh(self) -> None: return headers = Fields( [ - # TODO: Add user-agent + _USER_AGENT_FIELD, Field( name="x-aws-ec2-metadata-token-ttl-seconds", values=[str(self._config.token_ttl)], @@ -150,11 +156,11 @@ async def get(self, *, path: str) -> str: token = await self._token_cache.get_token() headers = Fields( [ - # TODO: Add user-agent + _USER_AGENT_FIELD, Field( name="x-aws-ec2-metadata-token", values=[token.value], - ) + ), ] ) request = HTTPRequest( From 690de0ebcb95f144f26db3f36bb5e0d6d4a5ccbc Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Tue, 18 Mar 2025 14:33:06 -0400 Subject: [PATCH 14/14] Remove port config option and use endpoint_uri.port instead --- .../smithy_aws_core/credentials_resolvers/imds.py | 7 +++---- .../tests/unit/credentials_resolvers/test_imds.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py index 224501036..1ba05ab4c 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -36,7 +36,6 @@ class Config: retry_strategy: RetryStrategy endpoint_uri: URI endpoint_mode: Literal["IPv4", "IPv6"] - port: int token_ttl: int def __init__( @@ -45,7 +44,6 @@ def __init__( retry_strategy: RetryStrategy | None = None, endpoint_uri: URI | None = None, endpoint_mode: Literal["IPv4", "IPv6"] = "IPv4", - port: int = 80, token_ttl: int = _MAX_TTL, ec2_instance_profile_name: str | None = None, ): @@ -53,7 +51,6 @@ def __init__( self.retry_strategy = retry_strategy or SimpleRetryStrategy(max_attempts=3) self.endpoint_mode = endpoint_mode self.endpoint_uri = self._resolve_endpoint(endpoint_uri, endpoint_mode) - self.port = port self.token_ttl = self._validate_token_ttl(token_ttl) self.ec2_instance_profile_name = ec2_instance_profile_name @@ -73,6 +70,7 @@ def _resolve_endpoint( return URI( scheme="http", host=self._HOST_MAPPING.get(endpoint_mode, self._HOST_MAPPING["IPv4"]), + port=80, ) @@ -129,6 +127,7 @@ async def _refresh(self) -> None: destination=URI( scheme=self._base_uri.scheme, host=self._base_uri.host, + port=self._base_uri.port, path=self._TOKEN_PATH, ), fields=headers, @@ -168,7 +167,7 @@ async def get(self, *, path: str) -> str: destination=URI( scheme=self._config.endpoint_uri.scheme, host=self._config.endpoint_uri.host, - port=self._config.port, + port=self._config.endpoint_uri.port, path=path, ), fields=headers, diff --git a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py index ef9999c85..ebee43f17 100644 --- a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py +++ b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py @@ -22,9 +22,10 @@ def test_config_defaults(): config = Config() assert isinstance(config.retry_strategy, SimpleRetryStrategy) - assert config.endpoint_uri == URI(scheme="http", host=Config._HOST_MAPPING["IPv4"]) + assert config.endpoint_uri == URI( + scheme="http", host=Config._HOST_MAPPING["IPv4"], port=80 + ) assert config.endpoint_mode == "IPv4" - assert config.port == 80 assert config.token_ttl == 21600 @@ -38,15 +39,17 @@ def test_endpoint_resolution(): def test_config_uses_custom_endpoint(): # The custom endpoint should take precedence over IPv4 endpoint resolution. config = Config( - endpoint_uri=URI(scheme="http", host="test.host"), endpoint_mode="IPv4" + endpoint_uri=URI(scheme="https", host="test.host", port=123), + endpoint_mode="IPv4", ) - assert config.endpoint_uri == URI(scheme="http", host="test.host") + assert config.endpoint_uri == URI(scheme="https", host="test.host", port=123) # The custom endpoint takes precedence over IPv6 endpoint resolution. config = Config( - endpoint_uri=URI(scheme="http", host="test.host"), endpoint_mode="IPv6" + endpoint_uri=URI(scheme="https", host="test.host", port=123), + endpoint_mode="IPv6", ) - assert config.endpoint_uri == URI(scheme="http", host="test.host") + assert config.endpoint_uri == URI(scheme="https", host="test.host", port=123) def test_config_ttl_validation():