diff --git a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py index f20ec0187..234682b4a 100644 --- a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py +++ b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py @@ -3,7 +3,7 @@ from collections.abc import AsyncIterable from typing import Protocol, runtime_checkable, TYPE_CHECKING, Any -from ...interfaces import URI, Endpoint +from ...interfaces import URI, Endpoint, TypedProperties from ...interfaces import StreamingBlob as SyncStreamingBlob @@ -93,7 +93,7 @@ def serialize_request[ operation: "APIOperation[OperationInput, OperationOutput]", input: OperationInput, endpoint: URI, - context: dict[str, Any], + context: TypedProperties, ) -> I: """Serialize an operation input into a transport request. @@ -127,7 +127,7 @@ async def deserialize_response[ request: I, response: O, error_registry: Any, # TODO: add error registry - context: dict[str, Any], # TODO: replace with a typed context bag + context: TypedProperties, ) -> OperationOutput: """Deserializes the output from the tranport response or throws an exception. diff --git a/packages/smithy-core/src/smithy_core/interceptors.py b/packages/smithy-core/src/smithy_core/interceptors.py index 7b358736d..850d1fa8d 100644 --- a/packages/smithy-core/src/smithy_core/interceptors.py +++ b/packages/smithy-core/src/smithy_core/interceptors.py @@ -1,7 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from copy import copy, deepcopy -from typing import Any, TypeVar +from typing import TypeVar + +from .types import TypedProperties Request = TypeVar("Request") Response = TypeVar("Response") @@ -34,7 +36,7 @@ def __init__( self._response = response self._transport_request = transport_request self._transport_response = transport_response - self._properties: dict[str, Any] = {} + self._properties = TypedProperties() @property def request(self) -> Request: @@ -73,7 +75,7 @@ def transport_response(self) -> TransportResponse: return self._transport_response @property - def properties(self) -> dict[str, Any]: + def properties(self) -> TypedProperties: """Retrieve the generic property bag. These untyped properties will be made available to all other interceptors or diff --git a/packages/smithy-core/src/smithy_core/interfaces/__init__.py b/packages/smithy-core/src/smithy_core/interfaces/__init__.py index 7cb030cae..1311b732b 100644 --- a/packages/smithy-core/src/smithy_core/interfaces/__init__.py +++ b/packages/smithy-core/src/smithy_core/interfaces/__init__.py @@ -1,7 +1,17 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from asyncio import iscoroutinefunction -from typing import Protocol, runtime_checkable, Any, TypeGuard +from typing import ( + Protocol, + runtime_checkable, + Any, + TypeGuard, + overload, + Iterator, + KeysView, + ValuesView, + ItemsView, +) class URI(Protocol): @@ -99,3 +109,91 @@ class Endpoint(Protocol): For example, in some AWS use cases this might contain HTTP headers to add to each request. """ + + +@runtime_checkable +class PropertyKey[T](Protocol): + """A typed properties key. + + Used with :py:class:`Context` to set and get typed values. + + For a concrete implementation, see :py:class:`smithy_core.types.PropertyKey`. + """ + + key: str + """The string key used to access the value.""" + + value_type: type[T] + """The type of the associated value in the properties bag.""" + + def __str__(self) -> str: + return self.key + + +# This is currently strongly tied to being compatible with a dict[str, Any], but we +# could remove that to allow for potentially more efficient maps. That might introduce +# unacceptable usability penalties or footguns though. +@runtime_checkable +class TypedProperties(Protocol): + """A properties map with typed setters and getters. + + Keys can be either a string or a :py:class:`PropertyKey`. Using a PropertyKey instead + of a string enables type checkers to narrow to the associated value type rather + than having to use Any. + + Letting the value be either a string or PropertyKey allows consumers who care about + typing to get it, and those who don't care about typing to not have to think about + it. + + ..code-block:: python + + foo = PropertyKey(key="foo", value_type=str) + properties = TypedProperties() + properties[foo] = "bar" + + assert assert_type(properties[foo], str) == "bar + assert assert_type(properties["foo"], Any) == "bar + + + For a concrete implementation, see :py:class:`smithy_core.types.TypedProperties`. + """ + + @overload + def __getitem__[T](self, key: PropertyKey[T]) -> T: ... + @overload + def __getitem__(self, key: str) -> Any: ... + + @overload + def __setitem__[T](self, key: PropertyKey[T], value: T) -> None: ... + @overload + def __setitem__(self, key: str, value: Any) -> None: ... + + def __delitem__(self, key: str | PropertyKey[Any]) -> None: ... + + @overload + def get[T](self, key: PropertyKey[T], default: None = None) -> T | None: ... + @overload + def get[T](self, key: PropertyKey[T], default: T) -> T: ... + @overload + def get[T, DT](self, key: PropertyKey[T], default: DT) -> T | DT: ... + @overload + def get(self, key: str, default: None = None) -> Any | None: ... + @overload + def get[T](self, key: str, default: T) -> Any | T: ... + + @overload + def pop[T](self, key: PropertyKey[T], default: None = None) -> T | None: ... + @overload + def pop[T](self, key: PropertyKey[T], default: T) -> T: ... + @overload + def pop[T, DT](self, key: PropertyKey[T], default: DT) -> T | DT: ... + @overload + def pop(self, key: str, default: None = None) -> Any | None: ... + @overload + def pop[T](self, key: str, default: T) -> Any | T: ... + + def __iter__(self) -> Iterator[str]: ... + def items(self) -> ItemsView[str, Any]: ... + def keys(self) -> KeysView[str]: ... + def values(self) -> ValuesView[Any]: ... + def __contains__(self, key: object) -> bool: ... diff --git a/packages/smithy-core/src/smithy_core/types.py b/packages/smithy-core/src/smithy_core/types.py index f9ce45b71..3b1613c37 100644 --- a/packages/smithy-core/src/smithy_core/types.py +++ b/packages/smithy-core/src/smithy_core/types.py @@ -2,11 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import json import re +import sys +from collections import UserDict from collections.abc import Mapping, Sequence from datetime import datetime from email.utils import format_datetime, parsedate_to_datetime from enum import Enum -from typing import Any +from typing import Any, overload from dataclasses import dataclass from .exceptions import ExpectationNotMetException @@ -17,6 +19,8 @@ serialize_epoch_seconds, serialize_rfc3339, ) +from .interfaces import PropertyKey as _PropertyKey +from .interfaces import TypedProperties as _TypedProperties _GREEDY_LABEL_RE = re.compile(r"\{(\w+)\+\}") @@ -153,3 +157,99 @@ def format(self, *args: object, **kwargs: str) -> str: f'Path must not contain empty segments, but was "{result}".' ) return result + + +@dataclass(kw_only=True, frozen=True, slots=True, init=False) +class PropertyKey[T](_PropertyKey[T]): + """A typed property key.""" + + key: str + """The string key used to access the value.""" + + value_type: type[T] + """The type of the associated value in the property bag.""" + + def __init__(self, *, key: str, value_type: type[T]) -> None: + # Intern the key to speed up dict access + object.__setattr__(self, "key", sys.intern(key)) + object.__setattr__(self, "value_type", value_type) + + +class TypedProperties(UserDict[str, Any], _TypedProperties): + """A map with typed setters and getters. + + Keys can be either a string or a :py:class:`smithy_core.interfaces.PropertyKey`. + Using a PropertyKey instead of a string enables type checkers to narrow to the + associated value type rather than having to use Any. + + Letting the value be either a string or PropertyKey allows consumers who care about + typing to get it, and those who don't care about typing to not have to think about + it. + + ..code-block:: python + + foo = PropertyKey(key="foo", value_type=str) + properties = TypedProperties() + properties[foo] = "bar" + + assert assert_type(properties[foo], str) == "bar + assert assert_type(properties["foo"], Any) == "bar + """ + + @overload + def __getitem__[T](self, key: _PropertyKey[T]) -> T: ... + @overload + def __getitem__(self, key: str) -> Any: ... + def __getitem__(self, key: str | _PropertyKey[Any]) -> Any: + return self.data[key if isinstance(key, str) else key.key] + + @overload + def __setitem__[T](self, key: _PropertyKey[T], value: T) -> None: ... + @overload + def __setitem__(self, key: str, value: Any) -> None: ... + def __setitem__(self, key: str | _PropertyKey[Any], value: Any) -> None: + if isinstance(key, _PropertyKey): + if not isinstance(value, key.value_type): + raise ValueError( + f"Expected value type of {key.value_type}, but was {type(value)}" + ) + key = key.key + self.data[key] = value + + def __delitem__(self, key: str | _PropertyKey[Any]) -> None: + del self.data[key if isinstance(key, str) else key.key] + + def __contains__(self, key: object) -> bool: + return super().__contains__(key.key if isinstance(key, _PropertyKey) else key) + + @overload + def get[T](self, key: _PropertyKey[T], default: None = None) -> T | None: ... + @overload + def get[T](self, key: _PropertyKey[T], default: T) -> T: ... + @overload + def get[T, DT](self, key: _PropertyKey[T], default: DT) -> T | DT: ... + @overload + def get(self, key: str, default: None = None) -> Any | None: ... + @overload + def get[T](self, key: str, default: T) -> Any | T: ... + + # pyright has trouble detecting compatible overrides when both the superclass + # and subclass have overloads. + def get(self, key: str | _PropertyKey[Any], default: Any = None) -> Any: # type: ignore + return self.data.get(key if isinstance(key, str) else key.key, default) + + @overload + def pop[T](self, key: _PropertyKey[T], default: None = None) -> T | None: ... + @overload + def pop[T](self, key: _PropertyKey[T], default: T) -> T: ... + @overload + def pop[T, DT](self, key: _PropertyKey[T], default: DT) -> T | DT: ... + @overload + def pop(self, key: str, default: None = None) -> Any | None: ... + @overload + def pop[T](self, key: str, default: T) -> Any | T: ... + + # pyright has trouble detecting compatible overrides when both the superclass + # and subclass have overloads. + def pop(self, key: str | _PropertyKey[Any], default: Any = None) -> Any: # type: ignore + return self.data.pop(key if isinstance(key, str) else key.key, default) diff --git a/packages/smithy-core/tests/unit/test_types.py b/packages/smithy-core/tests/unit/test_types.py index f1d8b6205..d36c8b06a 100644 --- a/packages/smithy-core/tests/unit/test_types.py +++ b/packages/smithy-core/tests/unit/test_types.py @@ -3,11 +3,19 @@ # pyright: reportPrivateUsage=false from datetime import UTC, datetime +from typing import Any, assert_type import pytest from smithy_core.exceptions import ExpectationNotMetException -from smithy_core.types import JsonBlob, JsonString, TimestampFormat, PathPattern +from smithy_core.types import ( + JsonBlob, + JsonString, + TimestampFormat, + PathPattern, + PropertyKey, + TypedProperties, +) def test_json_string() -> None: @@ -219,3 +227,99 @@ def test_path_pattern_disallows_empty_segments(greedy: bool, value: str): pattern = PathPattern("/{foo+}/" if greedy else "/{foo}/") with pytest.raises(ValueError): pattern.format(foo=value) + + +def test_properties_typed_get() -> None: + foo_key = PropertyKey(key="foo", value_type=str) + properties = TypedProperties() + properties[foo_key] = "bar" + + assert assert_type(properties[foo_key], str) == "bar" + assert assert_type(properties["foo"], Any) == "bar" + + assert assert_type(properties.get(foo_key), str | None) == "bar" + assert assert_type(properties.get(foo_key, "spam"), str) == "bar" + assert assert_type(properties.get(foo_key, 1), str | int) == "bar" + + assert assert_type(properties.get("foo"), Any | None) == "bar" + assert assert_type(properties.get("foo", "spam"), Any | str) == "bar" + assert assert_type(properties.get("foo", 1), Any | int) == "bar" + + baz_key = PropertyKey(key="baz", value_type=str) + assert assert_type(properties.get(baz_key), str | None) is None + assert assert_type(properties.get(baz_key, "spam"), str) == "spam" + assert assert_type(properties.get(baz_key, 1), str | int) == 1 + + assert assert_type(properties.get("baz"), Any | None) is None + assert assert_type(properties.get("baz", "spam"), Any | str) == "spam" + assert assert_type(properties.get("baz", 1), Any | int) == 1 + + +def test_properties_typed_set() -> None: + foo_key = PropertyKey(key="foo", value_type=str) + properties = TypedProperties() + + properties[foo_key] = "foo" + assert properties.data["foo"] == "foo" + + with pytest.raises(ValueError): + properties[foo_key] = b"foo" # type: ignore + + +def test_properties_del() -> None: + foo_key = PropertyKey(key="foo", value_type=str) + properties = TypedProperties() + properties[foo_key] = "bar" + + assert "foo" in properties.data + del properties[foo_key] + assert "foo" not in properties.data + + properties[foo_key] = "bar" + + assert "foo" in properties.data + del properties["foo"] + assert "foo" not in properties.data + + +def test_properties_contains() -> None: + foo_key = PropertyKey(key="foo", value_type=str) + bar_key = PropertyKey(key="bar", value_type=str) + properties = TypedProperties() + properties[foo_key] = "bar" + + assert "foo" in properties + assert foo_key in properties + assert "bar" not in properties + assert bar_key not in properties + + +def test_properties_typed_pop() -> None: + foo_key = PropertyKey(key="foo", value_type=str) + properties = TypedProperties() + + properties[foo_key] = "bar" + assert assert_type(properties.pop(foo_key), str | None) == "bar" + assert "foo" not in properties.data + + properties[foo_key] = "bar" + assert assert_type(properties.pop(foo_key, "foo"), str) == "bar" + assert "foo" not in properties.data + + properties[foo_key] = "bar" + assert assert_type(properties.pop(foo_key, 1), str | int) == "bar" + assert "foo" not in properties.data + + properties[foo_key] = "bar" + assert assert_type(properties.pop("foo"), Any | None) == "bar" + assert "foo" not in properties.data + + properties[foo_key] = "bar" + assert assert_type(properties.pop("foo", "baz"), Any | str) == "bar" + assert "foo" not in properties.data + + properties[foo_key] = "bar" + assert assert_type(properties.pop("foo", 1), Any | int) == "bar" + assert "foo" not in properties.data + + assert properties.pop(foo_key) is None diff --git a/pyproject.toml b/pyproject.toml index ae3e34ba0..b10ae5f41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dependencies = [] dev = [ "black>=25.1.0", "docformatter>=1.7.5", - "pyright>=1.1.394", + "pyright>=1.1.396", "pytest>=8.3.4", "pytest-asyncio>=0.25.3", "pytest-cov>=6.0.0", diff --git a/uv.lock b/uv.lock index 729168555..788a60bc5 100644 --- a/uv.lock +++ b/uv.lock @@ -551,15 +551,15 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.394" +version = "1.1.396" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b1/e4/79f4d8a342eed6790fdebdb500e95062f319ee3d7d75ae27304ff995ae8c/pyright-1.1.394.tar.gz", hash = "sha256:56f2a3ab88c5214a451eb71d8f2792b7700434f841ea219119ade7f42ca93608", size = 3809348 } +sdist = { url = "https://files.pythonhosted.org/packages/bd/73/f20cb1dea1bdc1774e7f860fb69dc0718c7d8dea854a345faec845eb086a/pyright-1.1.396.tar.gz", hash = "sha256:142901f5908f5a0895be3d3befcc18bedcdb8cc1798deecaec86ef7233a29b03", size = 3814400 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d6/4c/50c74e3d589517a9712a61a26143b587dba6285434a17aebf2ce6b82d2c3/pyright-1.1.394-py3-none-any.whl", hash = "sha256:5f74cce0a795a295fb768759bbeeec62561215dea657edcaab48a932b031ddbb", size = 5679540 }, + { url = "https://files.pythonhosted.org/packages/80/be/ecb7cfb42d242b7ee764b52e6ff4782beeec00e3b943a3ec832b281f9da6/pyright-1.1.396-py3-none-any.whl", hash = "sha256:c635e473095b9138c471abccca22b9fedbe63858e0b40d4fc4b67da041891844", size = 5689355 }, ] [[package]] @@ -752,7 +752,7 @@ dev = [ dev = [ { name = "black", specifier = ">=25.1.0" }, { name = "docformatter", specifier = ">=1.7.5" }, - { name = "pyright", specifier = ">=1.1.394" }, + { name = "pyright", specifier = ">=1.1.396" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-asyncio", specifier = ">=0.25.3" }, { name = "pytest-cov", specifier = ">=6.0.0" },