From 7e01290450a00e18fd6b58ae6c7a916657a9e522 Mon Sep 17 00:00:00 2001 From: Hayden Baker Date: Fri, 14 Mar 2025 13:45:37 -0700 Subject: [PATCH 1/3] Add a type registry --- .../smithy_core/aio/interfaces/__init__.py | 6 +-- .../smithy-core/src/smithy_core/documents.py | 5 ++ .../smithy-core/src/smithy_core/schemas.py | 5 +- .../src/smithy_core/type_registry.py | 29 +++++++++++ .../tests/unit/test_type_registry.py | 48 +++++++++++++++++++ 5 files changed, 87 insertions(+), 6 deletions(-) create mode 100644 packages/smithy-core/src/smithy_core/type_registry.py create mode 100644 packages/smithy-core/tests/unit/test_type_registry.py diff --git a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py index 234682b4a..0d2318071 100644 --- a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py +++ b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py @@ -1,11 +1,11 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from collections.abc import AsyncIterable -from typing import Protocol, runtime_checkable, TYPE_CHECKING, Any +from typing import Protocol, runtime_checkable, TYPE_CHECKING from ...interfaces import URI, Endpoint, TypedProperties from ...interfaces import StreamingBlob as SyncStreamingBlob - +from ...type_registry import TypeRegistry if TYPE_CHECKING: from ...schemas import APIOperation @@ -126,7 +126,7 @@ async def deserialize_response[ operation: "APIOperation[OperationInput, OperationOutput]", request: I, response: O, - error_registry: Any, # TODO: add error registry + error_registry: TypeRegistry, context: TypedProperties, ) -> OperationOutput: """Deserializes the output from the tranport response or throws an exception. 
diff --git a/packages/smithy-core/src/smithy_core/documents.py b/packages/smithy-core/src/smithy_core/documents.py index b9af55cf5..13eecc707 100644 --- a/packages/smithy-core/src/smithy_core/documents.py +++ b/packages/smithy-core/src/smithy_core/documents.py @@ -143,6 +143,11 @@ def shape_type(self) -> ShapeType: """The Smithy data model type for the underlying contents of the document.""" return self._type + @property + def discriminator(self) -> ShapeID: + """The shape ID that corresponds to the contents of the document.""" + return self._schema.id + def is_none(self) -> bool: """Indicates whether the document contains a null value.""" return self._value is None and self._raw_value is None diff --git a/packages/smithy-core/src/smithy_core/schemas.py b/packages/smithy-core/src/smithy_core/schemas.py index 72f8f66f8..6fbbe4ed8 100644 --- a/packages/smithy-core/src/smithy_core/schemas.py +++ b/packages/smithy-core/src/smithy_core/schemas.py @@ -7,7 +7,7 @@ from .exceptions import ExpectationNotMetException, SmithyException from .shapes import ShapeID, ShapeType from .traits import Trait, DynamicTrait, IdempotencyTokenTrait, StreamingTrait - +from .type_registry import TypeRegistry if TYPE_CHECKING: from .serializers import SerializeableShape @@ -289,8 +289,7 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]: output_schema: Schema """The schema of the operation's output shape.""" - # TODO: Add a type registry for errors - error_registry: Any + error_registry: TypeRegistry """A TypeRegistry used to create errors.""" effective_auth_schemes: Sequence[ShapeID] diff --git a/packages/smithy-core/src/smithy_core/type_registry.py b/packages/smithy-core/src/smithy_core/type_registry.py new file mode 100644 index 000000000..b6b4f865d --- /dev/null +++ b/packages/smithy-core/src/smithy_core/type_registry.py @@ -0,0 +1,29 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from smithy_core.deserializers import ( + DeserializeableShape, +) +from smithy_core.documents import Document +from smithy_core.shapes import ShapeID + + +# A registry for on-demand deserialization of types by using a mapping of shape IDs to their deserializers. +class TypeRegistry: + def __init__( + self, + types: dict[ShapeID, type[DeserializeableShape]], + sub_registry: "TypeRegistry | None" = None, + ): + self._types = types + self._sub_registry = sub_registry + + def get(self, shape: ShapeID) -> type[DeserializeableShape]: + if shape in self._types: + return self._types[shape] + if self._sub_registry is not None: + return self._sub_registry.get(shape) + raise KeyError(f"Unknown shape: {shape}") + + def deserialize(self, document: Document) -> DeserializeableShape: + return document.as_shape(self.get(document.discriminator)) diff --git a/packages/smithy-core/tests/unit/test_type_registry.py b/packages/smithy-core/tests/unit/test_type_registry.py new file mode 100644 index 000000000..db17e0d94 --- /dev/null +++ b/packages/smithy-core/tests/unit/test_type_registry.py @@ -0,0 +1,48 @@ +from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer +from smithy_core.documents import Document +from smithy_core.schemas import Schema +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.type_registry import TypeRegistry +import pytest + + +class TestTypeRegistry: + def test_get(self): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + + result = registry.get(ShapeID("com.example#Test")) + + assert result == TestShape + + def test_get_sub_registry(self): + sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + registry = TypeRegistry({}, sub_registry) + + result = registry.get(ShapeID("com.example#Test")) + + assert result == TestShape + + def test_get_no_match(self): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + + with 
pytest.raises(KeyError, match="Unknown shape: com.example#Test2"): + registry.get(ShapeID("com.example#Test2")) + + def test_deserialize(self): + shape_id = ShapeID("com.example#Test") + registry = TypeRegistry({shape_id: TestShape}) + + result = registry.deserialize(Document("abc123", schema=TestShape.schema)) + + assert isinstance(result, TestShape) and result.value == "abc123" + + +class TestShape(DeserializeableShape): + schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING) + + def __init__(self, value: str): + self.value = value + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> "TestShape": + return TestShape(deserializer.read_string(schema=TestShape.schema)) From 037c68cd1eb58226cee7f2ecfc6dd65000727304 Mon Sep 17 00:00:00 2001 From: Hayden Baker Date: Fri, 14 Mar 2025 15:02:37 -0700 Subject: [PATCH 2/3] Move type registry into documents, fix circular import --- .../smithy_core/aio/interfaces/__init__.py | 2 +- .../smithy-core/src/smithy_core/documents.py | 21 +++++++++ .../smithy-core/src/smithy_core/schemas.py | 4 +- .../src/smithy_core/type_registry.py | 29 ------------- .../tests/unit/test_type_registry.py | 43 ++++++++++--------- 5 files changed, 46 insertions(+), 53 deletions(-) delete mode 100644 packages/smithy-core/src/smithy_core/type_registry.py diff --git a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py index 0d2318071..359e39329 100644 --- a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py +++ b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py @@ -5,7 +5,7 @@ from ...interfaces import URI, Endpoint, TypedProperties from ...interfaces import StreamingBlob as SyncStreamingBlob -from ...type_registry import TypeRegistry +from ...documents import TypeRegistry if TYPE_CHECKING: from ...schemas import APIOperation diff --git a/packages/smithy-core/src/smithy_core/documents.py 
b/packages/smithy-core/src/smithy_core/documents.py index 13eecc707..043d75891 100644 --- a/packages/smithy-core/src/smithy_core/documents.py +++ b/packages/smithy-core/src/smithy_core/documents.py @@ -638,3 +638,24 @@ def read_document(self, schema: "Schema") -> Document: @override def read_timestamp(self, schema: "Schema") -> datetime.datetime: return self._value.as_timestamp() + + +# A registry for on-demand deserialization of types by using a mapping of shape IDs to their deserializers. +class TypeRegistry: + def __init__( + self, + types: dict[ShapeID, type[DeserializeableShape]], + sub_registry: "TypeRegistry | None" = None, + ): + self._types = types + self._sub_registry = sub_registry + + def get(self, shape: ShapeID) -> type[DeserializeableShape]: + if shape in self._types: + return self._types[shape] + if self._sub_registry is not None: + return self._sub_registry.get(shape) + raise KeyError(f"Unknown shape: {shape}") + + def deserialize(self, document: Document) -> DeserializeableShape: + return document.as_shape(self.get(document.discriminator)) diff --git a/packages/smithy-core/src/smithy_core/schemas.py b/packages/smithy-core/src/smithy_core/schemas.py index 6fbbe4ed8..fc65b3c62 100644 --- a/packages/smithy-core/src/smithy_core/schemas.py +++ b/packages/smithy-core/src/smithy_core/schemas.py @@ -7,9 +7,9 @@ from .exceptions import ExpectationNotMetException, SmithyException from .shapes import ShapeID, ShapeType from .traits import Trait, DynamicTrait, IdempotencyTokenTrait, StreamingTrait -from .type_registry import TypeRegistry if TYPE_CHECKING: + from .documents import TypeRegistry from .serializers import SerializeableShape from .deserializers import DeserializeableShape @@ -289,7 +289,7 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]: output_schema: Schema """The schema of the operation's output shape.""" - error_registry: TypeRegistry + error_registry: "TypeRegistry" """A TypeRegistry used to create errors.""" 
effective_auth_schemes: Sequence[ShapeID] diff --git a/packages/smithy-core/src/smithy_core/type_registry.py b/packages/smithy-core/src/smithy_core/type_registry.py deleted file mode 100644 index b6b4f865d..000000000 --- a/packages/smithy-core/src/smithy_core/type_registry.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -from smithy_core.deserializers import ( - DeserializeableShape, -) -from smithy_core.documents import Document -from smithy_core.shapes import ShapeID - - -# A registry for on-demand deserialization of types by using a mapping of shape IDs to their deserializers. -class TypeRegistry: - def __init__( - self, - types: dict[ShapeID, type[DeserializeableShape]], - sub_registry: "TypeRegistry | None" = None, - ): - self._types = types - self._sub_registry = sub_registry - - def get(self, shape: ShapeID) -> type[DeserializeableShape]: - if shape in self._types: - return self._types[shape] - if self._sub_registry is not None: - return self._sub_registry.get(shape) - raise KeyError(f"Unknown shape: {shape}") - - def deserialize(self, document: Document) -> DeserializeableShape: - return document.as_shape(self.get(document.discriminator)) diff --git a/packages/smithy-core/tests/unit/test_type_registry.py b/packages/smithy-core/tests/unit/test_type_registry.py index db17e0d94..df0cef6d9 100644 --- a/packages/smithy-core/tests/unit/test_type_registry.py +++ b/packages/smithy-core/tests/unit/test_type_registry.py @@ -1,40 +1,41 @@ from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer -from smithy_core.documents import Document +from smithy_core.documents import Document, TypeRegistry from smithy_core.schemas import Schema from smithy_core.shapes import ShapeID, ShapeType -from smithy_core.type_registry import TypeRegistry import pytest -class TestTypeRegistry: - def test_get(self): - registry = TypeRegistry({ShapeID("com.example#Test"): 
TestShape}) +def test_get(): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) - result = registry.get(ShapeID("com.example#Test")) + result = registry.get(ShapeID("com.example#Test")) - assert result == TestShape + assert result == TestShape - def test_get_sub_registry(self): - sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) - registry = TypeRegistry({}, sub_registry) - result = registry.get(ShapeID("com.example#Test")) +def test_get_sub_registry(): + sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + registry = TypeRegistry({}, sub_registry) - assert result == TestShape + result = registry.get(ShapeID("com.example#Test")) - def test_get_no_match(self): - registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + assert result == TestShape - with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"): - registry.get(ShapeID("com.example#Test2")) - def test_deserialize(self): - shape_id = ShapeID("com.example#Test") - registry = TypeRegistry({shape_id: TestShape}) +def test_get_no_match(): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) - result = registry.deserialize(Document("abc123", schema=TestShape.schema)) + with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"): + registry.get(ShapeID("com.example#Test2")) - assert isinstance(result, TestShape) and result.value == "abc123" + +def test_deserialize(): + shape_id = ShapeID("com.example#Test") + registry = TypeRegistry({shape_id: TestShape}) + + result = registry.deserialize(Document("abc123", schema=TestShape.schema)) + + assert isinstance(result, TestShape) and result.value == "abc123" class TestShape(DeserializeableShape): From 652815c3d4a7bbbaa44cdb4bc81e211a0033f620 Mon Sep 17 00:00:00 2001 From: Hayden Baker Date: Mon, 17 Mar 2025 08:25:23 -0700 Subject: [PATCH 3/3] Address comments --- .../smithy-core/src/smithy_core/documents.py | 33 ++++++++++++++++++- .../tests/unit/test_type_registry.py 
| 26 +++++++++++++-- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/packages/smithy-core/src/smithy_core/documents.py b/packages/smithy-core/src/smithy_core/documents.py index 043d75891..828f7a53e 100644 --- a/packages/smithy-core/src/smithy_core/documents.py +++ b/packages/smithy-core/src/smithy_core/documents.py @@ -640,22 +640,53 @@ def read_timestamp(self, schema: "Schema") -> datetime.datetime: return self._value.as_timestamp() -# A registry for on-demand deserialization of types by using a mapping of shape IDs to their deserializers. class TypeRegistry: + """A registry for on-demand deserialization of types by using a mapping of shape IDs + to their deserializers.""" + def __init__( self, types: dict[ShapeID, type[DeserializeableShape]], sub_registry: "TypeRegistry | None" = None, ): + """Initialize a TypeRegistry. + + :param types: A mapping of ShapeID to the shapes they deserialize to. + :param sub_registry: A registry to delegate to if an ID is not found in types. + """ self._types = types self._sub_registry = sub_registry def get(self, shape: ShapeID) -> type[DeserializeableShape]: + """Get the deserializable shape for the given shape ID. + + :param shape: The shape ID to get from the registry. + :returns: The corresponding deserializable shape. + :raises KeyError: If the shape ID is not found in the registry. + """ if shape in self._types: return self._types[shape] if self._sub_registry is not None: return self._sub_registry.get(shape) raise KeyError(f"Unknown shape: {shape}") + def __getitem__(self, shape: ShapeID): + """Get the deserializable shape for the given shape ID. + + :param shape: The shape ID to get from the registry. + :returns: The corresponding deserializable shape. + :raises KeyError: If the shape ID is not found in the registry. + """ + return self.get(shape) + + def __contains__(self, item: object, /): + """Check if the registry contains the given shape. + + :param item: The shape ID to check for. 
+ """ + return item in self._types or ( + self._sub_registry is not None and item in self._sub_registry + ) + def deserialize(self, document: Document) -> DeserializeableShape: return document.as_shape(self.get(document.discriminator)) diff --git a/packages/smithy-core/tests/unit/test_type_registry.py b/packages/smithy-core/tests/unit/test_type_registry.py index df0cef6d9..f2d8bc403 100644 --- a/packages/smithy-core/tests/unit/test_type_registry.py +++ b/packages/smithy-core/tests/unit/test_type_registry.py @@ -8,25 +8,44 @@ def test_get(): registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) - result = registry.get(ShapeID("com.example#Test")) + result = registry[ShapeID("com.example#Test")] assert result == TestShape +def test_contains(): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + + assert ShapeID("com.example#Test") in registry + + def test_get_sub_registry(): sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) registry = TypeRegistry({}, sub_registry) - result = registry.get(ShapeID("com.example#Test")) + result = registry[ShapeID("com.example#Test")] assert result == TestShape +def test_contains_sub_registry(): + sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + registry = TypeRegistry({}, sub_registry) + + assert ShapeID("com.example#Test") in registry + + def test_get_no_match(): registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"): - registry.get(ShapeID("com.example#Test2")) + registry[ShapeID("com.example#Test2")] + + +def test_contains_no_match(): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + + assert ShapeID("com.example#Test2") not in registry def test_deserialize(): @@ -39,6 +58,7 @@ def test_deserialize(): class TestShape(DeserializeableShape): + __test__ = False schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING) def __init__(self, value: 
str):