diff --git a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py index 234682b4a..359e39329 100644 --- a/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py +++ b/packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py @@ -1,11 +1,11 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from collections.abc import AsyncIterable -from typing import Protocol, runtime_checkable, TYPE_CHECKING, Any +from typing import Protocol, runtime_checkable, TYPE_CHECKING from ...interfaces import URI, Endpoint, TypedProperties from ...interfaces import StreamingBlob as SyncStreamingBlob - +from ...documents import TypeRegistry if TYPE_CHECKING: from ...schemas import APIOperation @@ -126,7 +126,7 @@ async def deserialize_response[ operation: "APIOperation[OperationInput, OperationOutput]", request: I, response: O, - error_registry: Any, # TODO: add error registry + error_registry: TypeRegistry, context: TypedProperties, ) -> OperationOutput: """Deserializes the output from the tranport response or throws an exception. diff --git a/packages/smithy-core/src/smithy_core/documents.py b/packages/smithy-core/src/smithy_core/documents.py index b9af55cf5..828f7a53e 100644 --- a/packages/smithy-core/src/smithy_core/documents.py +++ b/packages/smithy-core/src/smithy_core/documents.py @@ -143,6 +143,11 @@ def shape_type(self) -> ShapeType: """The Smithy data model type for the underlying contents of the document.""" return self._type + @property + def discriminator(self) -> ShapeID: + """The shape ID that corresponds to the contents of the document.""" + return self._schema.id + def is_none(self) -> bool: """Indicates whether the document contains a null value.""" return self._value is None and self._raw_value is None @@ -633,3 +638,55 @@ def read_document(self, schema: "Schema") -> Document: @override def read_timestamp(self, schema: "Schema") -> datetime.datetime: return self._value.as_timestamp() + + +class TypeRegistry: + """A registry for on-demand deserialization of types by using a mapping of shape IDs + to their deserializers.""" + + def __init__( + self, + types: dict[ShapeID, type[DeserializeableShape]], + sub_registry: "TypeRegistry | None" = None, + ): + """Initialize a TypeRegistry. + + :param types: A mapping of ShapeID to the shapes they deserialize to. + :param sub_registry: A registry to delegate to if an ID is not found in types. + """ + self._types = types + self._sub_registry = sub_registry + + def get(self, shape: ShapeID) -> type[DeserializeableShape]: + """Get the deserializable shape for the given shape ID. + + :param shape: The shape ID to get from the registry. + :returns: The corresponding deserializable shape. + :raises KeyError: If the shape ID is not found in the registry. + """ + if shape in self._types: + return self._types[shape] + if self._sub_registry is not None: + return self._sub_registry.get(shape) + raise KeyError(f"Unknown shape: {shape}") + + def __getitem__(self, shape: ShapeID): + """Get the deserializable shape for the given shape ID. + + :param shape: The shape ID to get from the registry. + :returns: The corresponding deserializable shape. + :raises KeyError: If the shape ID is not found in the registry. + """ + return self.get(shape) + + def __contains__(self, item: object, /): + """Check if the registry contains the given shape. + + :param shape: The shape ID to check for. + """ + return item in self._types or ( + self._sub_registry is not None and item in self._sub_registry + ) + + def deserialize(self, document: Document) -> DeserializeableShape: + return document.as_shape(self.get(document.discriminator)) diff --git a/packages/smithy-core/src/smithy_core/schemas.py b/packages/smithy-core/src/smithy_core/schemas.py index 72f8f66f8..fc65b3c62 100644 --- a/packages/smithy-core/src/smithy_core/schemas.py +++ b/packages/smithy-core/src/smithy_core/schemas.py @@ -8,8 +8,8 @@ from .shapes import ShapeID, ShapeType from .traits import Trait, DynamicTrait, IdempotencyTokenTrait, StreamingTrait - if TYPE_CHECKING: + from .documents import TypeRegistry from .serializers import SerializeableShape from .deserializers import DeserializeableShape @@ -289,8 +289,7 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]: output_schema: Schema """The schema of the operation's output shape.""" - # TODO: Add a type registry for errors - error_registry: Any + error_registry: "TypeRegistry" """A TypeRegistry used to create errors.""" effective_auth_schemes: Sequence[ShapeID] diff --git a/packages/smithy-core/tests/unit/test_type_registry.py b/packages/smithy-core/tests/unit/test_type_registry.py new file mode 100644 index 000000000..f2d8bc403 --- /dev/null +++ b/packages/smithy-core/tests/unit/test_type_registry.py @@ -0,0 +1,69 @@ +from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer +from smithy_core.documents import Document, TypeRegistry +from smithy_core.schemas import Schema +from smithy_core.shapes import ShapeID, ShapeType +import pytest + + +def test_get(): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + + result = registry[ShapeID("com.example#Test")] + + assert result == TestShape + + +def test_contains(): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + + assert ShapeID("com.example#Test") in registry + + +def test_get_sub_registry(): + sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + registry = TypeRegistry({}, sub_registry) + + result = registry[ShapeID("com.example#Test")] + + assert result == TestShape + + +def test_contains_sub_registry(): + sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + registry = TypeRegistry({}, sub_registry) + + assert ShapeID("com.example#Test") in registry + + +def test_get_no_match(): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + + with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"): + registry[ShapeID("com.example#Test2")] + + +def test_contains_no_match(): + registry = TypeRegistry({ShapeID("com.example#Test"): TestShape}) + + assert ShapeID("com.example#Test2") not in registry + + +def test_deserialize(): + shape_id = ShapeID("com.example#Test") + registry = TypeRegistry({shape_id: TestShape}) + + result = registry.deserialize(Document("abc123", schema=TestShape.schema)) + + assert isinstance(result, TestShape) and result.value == "abc123" + + +class TestShape(DeserializeableShape): + __test__ = False + schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING) + + def __init__(self, value: str): + self.value = value + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> "TestShape": + return TestShape(deserializer.read_string(schema=TestShape.schema))