Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from collections.abc import AsyncIterable
from typing import Protocol, runtime_checkable, TYPE_CHECKING, Any
from typing import Protocol, runtime_checkable, TYPE_CHECKING

from ...interfaces import URI, Endpoint, TypedProperties
from ...interfaces import StreamingBlob as SyncStreamingBlob

from ...documents import TypeRegistry

if TYPE_CHECKING:
from ...schemas import APIOperation
Expand Down Expand Up @@ -126,7 +126,7 @@ async def deserialize_response[
operation: "APIOperation[OperationInput, OperationOutput]",
request: I,
response: O,
error_registry: Any, # TODO: add error registry
error_registry: TypeRegistry,
context: TypedProperties,
) -> OperationOutput:
"""Deserializes the output from the tranport response or throws an exception.
Expand Down
57 changes: 57 additions & 0 deletions packages/smithy-core/src/smithy_core/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def shape_type(self) -> ShapeType:
"""The Smithy data model type for the underlying contents of the document."""
return self._type

@property
def discriminator(self) -> ShapeID:
"""The shape ID that corresponds to the contents of the document."""
return self._schema.id

def is_none(self) -> bool:
"""Indicates whether the document contains a null value."""
return self._value is None and self._raw_value is None
Expand Down Expand Up @@ -633,3 +638,55 @@ def read_document(self, schema: "Schema") -> Document:
@override
def read_timestamp(self, schema: "Schema") -> datetime.datetime:
return self._value.as_timestamp()


class TypeRegistry:
"""A registry for on-demand deserialization of types by using a mapping of shape IDs
to their deserializers."""

def __init__(
self,
types: dict[ShapeID, type[DeserializeableShape]],
sub_registry: "TypeRegistry | None" = None,
):
"""Initialize a TypeRegistry.

:param types: A mapping of ShapeID to the shapes they deserialize to.
:param sub_registry: A registry to delegate to if an ID is not found in types.
"""
self._types = types
self._sub_registry = sub_registry

def get(self, shape: ShapeID) -> type[DeserializeableShape]:
"""Get the deserializable shape for the given shape ID.

:param shape: The shape ID to get from the registry.
:returns: The corresponding deserializable shape.
:raises KeyError: If the shape ID is not found in the registry.
"""
if shape in self._types:
return self._types[shape]
if self._sub_registry is not None:
return self._sub_registry.get(shape)
raise KeyError(f"Unknown shape: {shape}")

def __getitem__(self, shape: ShapeID):
"""Get the deserializable shape for the given shape ID.

:param shape: The shape ID to get from the registry.
:returns: The corresponding deserializable shape.
:raises KeyError: If the shape ID is not found in the registry.
"""
return self.get(shape)

def __contains__(self, item: object, /):
"""Check if the registry contains the given shape.

:param shape: The shape ID to check for.
"""
return item in self._types or (
self._sub_registry is not None and item in self._sub_registry
)

def deserialize(self, document: Document) -> DeserializeableShape:
return document.as_shape(self.get(document.discriminator))
5 changes: 2 additions & 3 deletions packages/smithy-core/src/smithy_core/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from .shapes import ShapeID, ShapeType
from .traits import Trait, DynamicTrait, IdempotencyTokenTrait, StreamingTrait


if TYPE_CHECKING:
from .documents import TypeRegistry
from .serializers import SerializeableShape
from .deserializers import DeserializeableShape

Expand Down Expand Up @@ -289,8 +289,7 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]:
output_schema: Schema
"""The schema of the operation's output shape."""

# TODO: Add a type registry for errors
error_registry: Any
error_registry: "TypeRegistry"
"""A TypeRegistry used to create errors."""

effective_auth_schemes: Sequence[ShapeID]
Expand Down
69 changes: 69 additions & 0 deletions packages/smithy-core/tests/unit/test_type_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer
from smithy_core.documents import Document, TypeRegistry
from smithy_core.schemas import Schema
from smithy_core.shapes import ShapeID, ShapeType
import pytest


def test_get():
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})

result = registry[ShapeID("com.example#Test")]

assert result == TestShape


def test_contains():
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})

assert ShapeID("com.example#Test") in registry


def test_get_sub_registry():
sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
registry = TypeRegistry({}, sub_registry)

result = registry[ShapeID("com.example#Test")]

assert result == TestShape


def test_contains_sub_registry():
sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
registry = TypeRegistry({}, sub_registry)

assert ShapeID("com.example#Test") in registry


def test_get_no_match():
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})

with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"):
registry[ShapeID("com.example#Test2")]


def test_contains_no_match():
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})

assert ShapeID("com.example#Test2") not in registry


def test_deserialize():
shape_id = ShapeID("com.example#Test")
registry = TypeRegistry({shape_id: TestShape})

result = registry.deserialize(Document("abc123", schema=TestShape.schema))

assert isinstance(result, TestShape) and result.value == "abc123"


class TestShape(DeserializeableShape):
__test__ = False
schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING)

def __init__(self, value: str):
self.value = value

@classmethod
def deserialize(cls, deserializer: ShapeDeserializer) -> "TestShape":
return TestShape(deserializer.read_string(schema=TestShape.schema))