Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Protocol
from typing import Protocol, runtime_checkable


class Identity(Protocol):
Expand All @@ -18,3 +18,24 @@ class Identity(Protocol):
def is_expired(self) -> bool:
"""Whether the identity is expired."""
...


@runtime_checkable
class AWSCredentialsIdentity(Protocol):
"""AWS Credentials Identity."""

# The access key ID.
access_key_id: str

# The secret access key.
secret_access_key: str

# The session token.
session_token: str | None

expiration: datetime | None = None

@property
def is_expired(self) -> bool:
"""Whether the identity is expired."""
...
21 changes: 11 additions & 10 deletions packages/aws-sdk-signers/src/aws_sdk_signers/signers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .interfaces.io import AsyncSeekable, Seekable
from ._http import URI, AWSRequest, Field
from ._identity import AWSCredentialIdentity
from .interfaces.identity import AWSCredentialsIdentity as _AWSCredentialsIdentity
from ._io import AsyncBytesReader
from .exceptions import AWSSDKWarning, MissingExpectedParameterException

Expand Down Expand Up @@ -49,14 +50,14 @@ def sign(
self,
*,
signing_properties: SigV4SigningProperties,
request: AWSRequest,
http_request: AWSRequest,
identity: AWSCredentialIdentity,
) -> AWSRequest:
"""Generate and apply a SigV4 Signature to a copy of the supplied request.

:param signing_properties: SigV4SigningProperties to define signing primitives
such as the target service, region, and date.
:param request: An AWSRequest to sign prior to sending to the service.
:param http_request: An AWSRequest to sign prior to sending to the service.
:param identity: A set of credentials representing an AWS Identity or role
capacity.
"""
Expand All @@ -68,7 +69,7 @@ def sign(
)
assert "date" in new_signing_properties

new_request = self._generate_new_request(request=request)
new_request = self._generate_new_request(request=http_request)
self._apply_required_fields(
request=new_request,
signing_properties=new_signing_properties,
Expand Down Expand Up @@ -159,7 +160,7 @@ def _hash(self, key: bytes, value: str) -> bytes:

def _validate_identity(self, *, identity: AWSCredentialIdentity) -> None:
"""Perform runtime and expiration checks before attempting signing."""
if not isinstance(identity, AWSCredentialIdentity): # pyright: ignore
if not isinstance(identity, _AWSCredentialsIdentity): # pyright: ignore
raise ValueError(
"Received unexpected value for identity parameter. Expected "
f"AWSCredentialIdentity but received {type(identity)}."
Expand Down Expand Up @@ -413,14 +414,14 @@ async def sign(
self,
*,
signing_properties: SigV4SigningProperties,
request: AWSRequest,
http_request: AWSRequest,
identity: AWSCredentialIdentity,
) -> AWSRequest:
"""Generate and apply a SigV4 Signature to a copy of the supplied request.

:param signing_properties: SigV4SigningProperties to define signing primitives
such as the target service, region, and date.
:param request: An AWSRequest to sign prior to sending to the service.
:param http_request: An AWSRequest to sign prior to sending to the service.
:param identity: A set of credentials representing an AWS Identity or role
capacity.
"""
Expand All @@ -431,7 +432,7 @@ async def sign(
new_signing_properties = await self._normalize_signing_properties(
signing_properties=signing_properties
)
new_request = await self._generate_new_request(request=request)
new_request = await self._generate_new_request(request=http_request)
await self._apply_required_fields(
request=new_request,
signing_properties=new_signing_properties,
Expand All @@ -441,7 +442,7 @@ async def sign(
# Construct core signing components
canonical_request = await self.canonical_request(
signing_properties=signing_properties,
request=request,
request=http_request,
)
string_to_sign = await self.string_to_sign(
canonical_request=canonical_request,
Expand All @@ -453,7 +454,7 @@ async def sign(
signing_properties=new_signing_properties,
)

signing_fields = await self._normalize_signing_fields(request=request)
signing_fields = await self._normalize_signing_fields(request=http_request)
credential_scope = await self._scope(signing_properties=new_signing_properties)
credential = f"{identity.access_key_id}/{credential_scope}"
authorization = await self.generate_authorization_field(
Expand Down Expand Up @@ -522,7 +523,7 @@ async def _hash(self, key: bytes, value: str) -> bytes:

async def _validate_identity(self, *, identity: AWSCredentialIdentity) -> None:
"""Perform runtime and expiration checks before attempting signing."""
if not isinstance(identity, AWSCredentialIdentity): # pyright: ignore
if not isinstance(identity, _AWSCredentialsIdentity): # pyright: ignore
raise ValueError(
"Received unexpected value for identity parameter. Expected "
f"AWSCredentialIdentity but received {type(identity)}."
Expand Down
4 changes: 2 additions & 2 deletions packages/aws-sdk-signers/tests/unit/auth/test_sigv4.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _test_signature_version_4_sync(test_case_name: str, signer: SigV4Signer) ->
with pytest.warns(AWSSDKWarning):
signed_request = signer.sign(
signing_properties=signing_props,
request=request,
http_request=request,
identity=test_case.credentials,
)
assert (
Expand Down Expand Up @@ -154,7 +154,7 @@ async def _test_signature_version_4_async(
with pytest.warns(AWSSDKWarning):
signed_request = await signer.sign(
signing_properties=signing_props,
request=request,
http_request=request,
identity=test_case.credentials,
)
assert (
Expand Down
12 changes: 6 additions & 6 deletions packages/aws-sdk-signers/tests/unit/test_signers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_sign(
) -> None:
signed_request = self.SIGV4_SYNC_SIGNER.sign(
signing_properties=signing_properties,
request=aws_request,
http_request=aws_request,
identity=aws_identity,
)
assert isinstance(signed_request, AWSRequest)
Expand All @@ -86,7 +86,7 @@ def test_sign_with_invalid_identity(
with pytest.raises(ValueError):
self.SIGV4_SYNC_SIGNER.sign(
signing_properties=signing_properties,
request=aws_request,
http_request=aws_request,
identity=identity,
)

Expand All @@ -102,7 +102,7 @@ def test_sign_with_expired_identity(
with pytest.raises(ValueError):
self.SIGV4_SYNC_SIGNER.sign(
signing_properties=signing_properties,
request=aws_request,
http_request=aws_request,
identity=identity,
)

Expand All @@ -118,7 +118,7 @@ async def test_sign(
) -> None:
signed_request = await self.SIGV4_ASYNC_SIGNER.sign(
signing_properties=signing_properties,
request=aws_request,
http_request=aws_request,
identity=aws_identity,
)
assert isinstance(signed_request, AWSRequest)
Expand All @@ -137,7 +137,7 @@ async def test_sign_with_invalid_identity(
with pytest.raises(ValueError):
await self.SIGV4_ASYNC_SIGNER.sign(
signing_properties=signing_properties,
request=aws_request,
http_request=aws_request,
identity=identity,
)

Expand All @@ -153,6 +153,6 @@ async def test_sign_with_expired_identity(
with pytest.raises(ValueError):
await self.SIGV4_ASYNC_SIGNER.sign(
signing_properties=signing_properties,
request=aws_request,
http_request=aws_request,
identity=identity,
)