diff --git a/designs/serialization.md b/designs/serialization.md index 497d275ab..e970260f7 100644 --- a/designs/serialization.md +++ b/designs/serialization.md @@ -317,6 +317,9 @@ class ShapeSerializer(Protocol): def write_document(self, schema: "Schema", value: "Document") -> None: ... + def write_data_stream(self, schema: "Schema", value: StreamingBlob) -> None: + raise NotImplementedError() + @runtime_checkable class MapSerializer(Protocol): @@ -531,6 +534,9 @@ class ShapeDeserializer(Protocol): def read_timestamp(self, schema: "Schema") -> datetime.datetime: ... + def read_data_stream(self, schema: "Schema") -> StreamingBlob: + raise NotImplementedError() + @runtime_checkable class DeserializeableShape(Protocol): diff --git a/packages/smithy-core/src/smithy_core/deserializers.py b/packages/smithy-core/src/smithy_core/deserializers.py index 148cec41f..b3045c7d1 100644 --- a/packages/smithy-core/src/smithy_core/deserializers.py +++ b/packages/smithy-core/src/smithy_core/deserializers.py @@ -3,11 +3,12 @@ from decimal import Decimal from typing import TYPE_CHECKING, Never, Protocol, Self, runtime_checkable -from .exceptions import SmithyException +from .exceptions import SmithyException, UnsupportedStreamException if TYPE_CHECKING: from .documents import Document from .schemas import Schema + from .aio.interfaces import StreamingBlob as _Stream @runtime_checkable @@ -171,6 +172,22 @@ def read_timestamp(self, schema: "Schema") -> datetime.datetime: """ ... + def read_data_stream(self, schema: "Schema") -> "_Stream": + """Read a data stream from the underlying data. + + The data itself MUST NOT be read by this method. The value returned is intended + to be read later by the consumer. In an HTTP implementation, for example, this + would directly return the HTTP body stream. The stream MAY be wrapped to provide + a more consistent interface or to avoid exposing implementation details. + + Data streams are only supported at the top-level input and output for + operations. + + :param schema: The shape's schema. + :returns: A data stream derived from the underlying data. + """ + raise UnsupportedStreamException() + class SpecificShapeDeserializer(ShapeDeserializer): """Expects to deserialize a specific kind of shape, failing if other shapes are @@ -247,6 +264,9 @@ def read_document(self, schema: "Schema") -> "Document": def read_timestamp(self, schema: "Schema") -> datetime.datetime: self._invalid_state(schema) + def read_data_stream(self, schema: "Schema") -> "_Stream": + self._invalid_state(schema) + @runtime_checkable class DeserializeableShape(Protocol): diff --git a/packages/smithy-core/src/smithy_core/exceptions.py b/packages/smithy-core/src/smithy_core/exceptions.py index 2a86c17c5..104390567 100644 --- a/packages/smithy-core/src/smithy_core/exceptions.py +++ b/packages/smithy-core/src/smithy_core/exceptions.py @@ -28,3 +28,8 @@ class MissingDependencyException(SmithyException): class AsyncBodyException(SmithyException): """Exception indicating that a request with an async body type was created in a sync context.""" + + +class UnsupportedStreamException(SmithyException): + """Indicates that a serializer or deserializer's stream method was called, but data + streams are not supported.""" diff --git a/packages/smithy-core/src/smithy_core/interfaces/__init__.py b/packages/smithy-core/src/smithy_core/interfaces/__init__.py index a9547c175..d39e4b755 100644 --- a/packages/smithy-core/src/smithy_core/interfaces/__init__.py +++ b/packages/smithy-core/src/smithy_core/interfaces/__init__.py @@ -1,6 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Protocol, runtime_checkable +from asyncio import iscoroutinefunction +from typing import Protocol, runtime_checkable, Any, TypeGuard class URI(Protocol): @@ -58,6 +59,27 @@ class BytesReader(Protocol): def read(self, size: int = -1, /) -> bytes: ... +def is_bytes_reader(obj: Any) -> TypeGuard[BytesReader]: + """Determines whether the given object conforms to the BytesReader protocol. + + This is necessary to distinguish this from an async reader, since runtime_checkable + doesn't make that distinction. + + :param obj: The object to inspect. + """ + return isinstance(obj, BytesReader) and not iscoroutinefunction( + getattr(obj, "read") + ) + + # A union of all acceptable streaming blob types. Deserialized payloads will # always return a ByteStream, or AsyncByteStream if async is enabled. type StreamingBlob = BytesReader | bytes | bytearray + + +def is_streaming_blob(obj: Any) -> TypeGuard[StreamingBlob]: + """Determines whether the given object is a StreamingBlob. + + :param obj: The object to inspect. + """ + return isinstance(obj, bytes | bytearray) or is_bytes_reader(obj) diff --git a/packages/smithy-core/src/smithy_core/serializers.py b/packages/smithy-core/src/smithy_core/serializers.py index 2da8eab99..59fad7acf 100644 --- a/packages/smithy-core/src/smithy_core/serializers.py +++ b/packages/smithy-core/src/smithy_core/serializers.py @@ -5,11 +5,12 @@ from decimal import Decimal from typing import TYPE_CHECKING, Never, Protocol, runtime_checkable -from .exceptions import SmithyException +from .exceptions import SmithyException, UnsupportedStreamException if TYPE_CHECKING: from .documents import Document from .schemas import Schema + from .aio.interfaces import StreamingBlob as _Stream @runtime_checkable @@ -198,6 +199,24 @@ def write_document(self, schema: "Schema", value: "Document") -> None: """ ... + def write_data_stream(self, schema: "Schema", value: "_Stream") -> None: + """Write a data stream to the output. + + If the value is a stream (i.e. not bytes or bytearray) it MUST NOT be read + directly by this method. Such values are intended to only be read as needed when + sending a message, and so should be bound directly to the request / response + type and then read by the transport. + + Data streams are only supported at the top-level input and output for + operations. + + :param schema: The shape's schema. + :param value: The streaming value to write. + """ + if isinstance(value, bytes | bytearray): + self.write_blob(schema, bytes(value)) + raise UnsupportedStreamException() + def flush(self) -> None: """Flush the underlying data.""" @@ -324,6 +343,10 @@ def write_document(self, schema: "Schema", value: "Document") -> None: self.before(schema).write_document(schema, value) self.after(schema) + def write_data_stream(self, schema: "Schema", value: "_Stream") -> None: + self.before(schema).write_data_stream(schema, value) + self.after(schema) + class SpecificShapeSerializer(ShapeSerializer): """Expects to serialize a specific kind of shape, failing if other shapes are @@ -393,6 +416,9 @@ def write_timestamp(self, schema: "Schema", value: datetime.datetime) -> None: def write_document(self, schema: "Schema", value: "Document") -> None: self._invalid_state(schema) + def write_data_stream(self, schema: "Schema", value: "_Stream") -> None: + self._invalid_state(schema) + @runtime_checkable class SerializeableShape(Protocol): diff --git a/packages/smithy-core/src/smithy_core/traits.py b/packages/smithy-core/src/smithy_core/traits.py index 0e10310bb..4e1ca740f 100644 --- a/packages/smithy-core/src/smithy_core/traits.py +++ b/packages/smithy-core/src/smithy_core/traits.py @@ -7,11 +7,11 @@ # they're correct regardless, so it's okay if the checks are stripped out. # ruff: noqa: S101 -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING, ClassVar, Mapping -from .types import TimestampFormat +from .types import TimestampFormat, PathPattern from .shapes import ShapeID if TYPE_CHECKING: @@ -193,3 +193,117 @@ def __post_init__(self): @property def value(self) -> str: return self.document_value # type: ignore + + +# TODO: Get all this moved over to the http package +@dataclass(init=False, frozen=True) +class HTTPTrait(Trait, id=ShapeID("smithy.api#http")): + path: PathPattern = field(repr=False, hash=False, compare=False) + code: int = field(repr=False, hash=False, compare=False) + query: str | None = field(default=None, repr=False, hash=False, compare=False) + + def __init__(self, value: "DocumentValue | DynamicTrait" = None): + super().__init__(value) + assert isinstance(self.document_value, Mapping) + assert isinstance(self.document_value["method"], str) + + code = self.document_value.get("code", 200) + assert isinstance(code, int) + object.__setattr__(self, "code", code) + + uri = self.document_value["uri"] + assert isinstance(uri, str) + parts = uri.split("?", 1) + + object.__setattr__(self, "path", PathPattern(parts[0])) + object.__setattr__(self, "query", parts[1] if len(parts) == 2 else None) + + @property + def method(self) -> str: + return self.document_value["method"] # type: ignore + + +@dataclass(init=False, frozen=True) +class HTTPErrorTrait(Trait, id=ShapeID("smithy.api#httpError")): + def __post_init__(self): + assert isinstance(self.document_value, int) + + @property + def code(self) -> int: + return self.document_value # type: ignore + + +@dataclass(init=False, frozen=True) +class HTTPHeaderTrait(Trait, id=ShapeID("smithy.api#httpHeader")): + def __post_init__(self): + assert isinstance(self.document_value, str) + + @property + def key(self) -> str: + return self.document_value # type: ignore + + +@dataclass(init=False, frozen=True) +class HTTPLabelTrait(Trait, id=ShapeID("smithy.api#httpLabel")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class HTTPPayloadTrait(Trait, id=ShapeID("smithy.api#httpPayload")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class HTTPPrefixHeadersTrait(Trait, id=ShapeID("smithy.api#httpPrefixHeaders")): + def __post_init__(self): + assert isinstance(self.document_value, str) + + @property + def prefix(self) -> str: + return self.document_value # type: ignore + + +@dataclass(init=False, frozen=True) +class HTTPQueryTrait(Trait, id=ShapeID("smithy.api#httpQuery")): + def __post_init__(self): + assert isinstance(self.document_value, str) + + @property + def key(self) -> str: + return self.document_value # type: ignore + + +@dataclass(init=False, frozen=True) +class HTTPQueryParamsTrait(Trait, id=ShapeID("smithy.api#httpQueryParams")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class HTTPResponseCodeTrait(Trait, id=ShapeID("smithy.api#httpResponseCode")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class HTTPChecksumRequiredTrait(Trait, id=ShapeID("smithy.api#httpChecksumRequired")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class EndpointTrait(Trait, id=ShapeID("smithy.api#endpoint")): + def __post_init__(self): + assert isinstance(self.document_value, str) + + @property + def host_prefix(self) -> str: + return self.document_value["hostPrefix"] # type: ignore + + +@dataclass(init=False, frozen=True) +class HostLabelTrait(Trait, id=ShapeID("smithy.api#hostLabel")): + def __post_init__(self): + assert self.document_value is None diff --git a/packages/smithy-core/src/smithy_core/types.py b/packages/smithy-core/src/smithy_core/types.py index f59327754..f9ce45b71 100644 --- a/packages/smithy-core/src/smithy_core/types.py +++ b/packages/smithy-core/src/smithy_core/types.py @@ -1,11 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import json +import re from collections.abc import Mapping, Sequence from datetime import datetime from email.utils import format_datetime, parsedate_to_datetime from enum import Enum from typing import Any +from dataclasses import dataclass from .exceptions import ExpectationNotMetException from .utils import ( @@ -16,6 +18,8 @@ serialize_rfc3339, ) +_GREEDY_LABEL_RE = re.compile(r"\{(\w+)\+\}") + type Document = ( Mapping[str, "Document"] | Sequence["Document"] | str | int | float | bool | None ) @@ -111,3 +115,41 @@ def deserialize(self, value: str | float) -> datetime: return ensure_utc(parsedate_to_datetime(expect_type(str, value))) case TimestampFormat.DATE_TIME: return ensure_utc(datetime.fromisoformat(expect_type(str, value))) + + +@dataclass(init=False, frozen=True) +class PathPattern: + """A formattable URI path pattern. + + The pattern may contain formattable labels, which may be normal labels or greedy + labels. Normal labels forbid path separators, greedy labels allow them. + """ + + pattern: str + """The path component of the URI which is a formattable string.""" + + greedy_labels: set[str] + """The pattern labels whose values may contain path separators.""" + + def __init__(self, pattern: str) -> None: + object.__setattr__(self, "pattern", pattern) + object.__setattr__( + self, "greedy_labels", set(_GREEDY_LABEL_RE.findall(pattern)) + ) + + def format(self, *args: object, **kwargs: str) -> str: + if args: + raise ValueError("PathPattern formatting requires only keyword arguments.") + + for key, value in kwargs.items(): + if "/" in value and key not in self.greedy_labels: + raise ValueError( + 'Non-greedy labels must not contain path separators ("/").' + ) + + result = self.pattern.replace("+}", "}").format(**kwargs) + if "//" in result: + raise ValueError( + f'Path must not contain empty segments, but was "{result}".' + ) + return result diff --git a/packages/smithy-core/tests/unit/test_types.py b/packages/smithy-core/tests/unit/test_types.py index b7c48b697..f1d8b6205 100644 --- a/packages/smithy-core/tests/unit/test_types.py +++ b/packages/smithy-core/tests/unit/test_types.py @@ -7,7 +7,7 @@ import pytest from smithy_core.exceptions import ExpectationNotMetException -from smithy_core.types import JsonBlob, JsonString, TimestampFormat +from smithy_core.types import JsonBlob, JsonString, TimestampFormat, PathPattern def test_json_string() -> None: @@ -180,3 +180,42 @@ def test_invalid_timestamp_format_type_raises( ): with pytest.raises(ExpectationNotMetException): format.deserialize(value) + + +def test_path_pattern_without_labels(): + assert PathPattern("/foo/").format() == "/foo/" + + +def test_path_pattern_with_normal_label(): + assert PathPattern("/{foo}/").format(foo="foo") == "/foo/" + + +def test_path_pattern_with_greedy_label(): + assert PathPattern("/{foo+}/").format(foo="foo") == "/foo/" + + +def test_path_pattern_greedy_label_allows_path_sep(): + assert PathPattern("/{foo+}/").format(foo="foo/bar") == "/foo/bar/" + + +def test_path_pattern_normal_label_disallows_path_sep(): + with pytest.raises(ValueError): + PathPattern("/{foo}").format(foo="foo/bar") + + +@pytest.mark.parametrize( + "greedy, value", + [ + (False, ""), + (True, ""), + (True, "/"), + (True, "/foo"), + (True, "foo/"), + (True, "/foo/"), + (True, "foo//bar"), + ], +) +def test_path_pattern_disallows_empty_segments(greedy: bool, value: str): + pattern = PathPattern("/{foo+}/" if greedy else "/{foo}/") + with pytest.raises(ValueError): + pattern.format(foo=value) diff --git a/packages/smithy-http/src/smithy_http/deserializers.py b/packages/smithy-http/src/smithy_http/deserializers.py new file mode 100644 index 000000000..50d91bbbd --- /dev/null +++ b/packages/smithy-http/src/smithy_http/deserializers.py @@ -0,0 +1,260 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Callable +from typing import TYPE_CHECKING +from decimal import Decimal +import datetime + +from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer +from smithy_core.codecs import Codec +from smithy_core.schemas import Schema +from smithy_core.traits import ( + HTTPTrait, + HTTPHeaderTrait, + HTTPPrefixHeadersTrait, + HTTPPayloadTrait, + HTTPResponseCodeTrait, + TimestampFormatTrait, +) +from smithy_core.utils import strict_parse_bool, strict_parse_float, ensure_utc +from smithy_core.types import TimestampFormat +from smithy_core.shapes import ShapeType +from smithy_core.exceptions import UnsupportedStreamException +from smithy_core.interfaces import is_bytes_reader, is_streaming_blob + +from .aio.interfaces import HTTPResponse + +from .interfaces import Field, Fields + +if TYPE_CHECKING: + from smithy_core.interfaces import StreamingBlob as SyncStreamingBlob + from smithy_core.aio.interfaces import StreamingBlob as AsyncStreamingBlob + + +__all__ = ["HTTPResponseDeserializer"] + + +class HTTPResponseDeserializer(SpecificShapeDeserializer): + """Binds :py:class:`HTTPResponse` properties to a DeserializableShape.""" + + # Note: caller will have to read the body if it's async and not streaming + def __init__( + self, + payload_codec: Codec, + http_trait: HTTPTrait, + response: HTTPResponse, + body: "SyncStreamingBlob | None" = None, + ) -> None: + """Initialize an HTTPResponseDeserializer. + + :param payload_codec: The Codec to use to deserialize the payload, if present. + :param http_trait: The HTTP trait of the operation being handled. + :param response: The HTTP response to read from. + :param body: The HTTP response body in a synchronously readable form. This is + necessary for async response bodies when there is no streaming member. + """ + self._payload_codec = payload_codec + self._response = response + self._http_trait = http_trait + self._body = body + + def read_struct( + self, schema: Schema, consumer: Callable[[Schema, ShapeDeserializer], None] + ) -> None: + has_body = False + payload_member: Schema | None = None + + for member in schema.members.values(): + if (trait := member.get_trait(HTTPHeaderTrait)) is not None: + header = self._response.fields.entries.get(trait.key.lower()) + if header is not None: + if member.shape_type is ShapeType.LIST: + consumer(member, HTTPHeaderListDeserializer(header)) + else: + consumer(member, HTTPHeaderDeserializer(header.as_string())) + elif (trait := member.get_trait(HTTPPrefixHeadersTrait)) is not None: + consumer( + member, + HTTPHeaderMapDeserializer(self._response.fields, trait.prefix), + ) + elif HTTPPayloadTrait in member: + has_body = True + payload_member = member + elif HTTPResponseCodeTrait in member: + consumer(member, HTTPResponseCodeDeserializer(self._response.status)) + else: + has_body = True + + if has_body: + deserializer = self._create_payload_deserializer(payload_member) + if payload_member is not None: + consumer(payload_member, deserializer) + else: + deserializer.read_struct(schema, consumer) + + def _create_payload_deserializer( + self, payload_member: Schema | None + ) -> ShapeDeserializer: + body = self._body if self._body is not None else self._response.body + if payload_member is not None and payload_member.shape_type in ( + ShapeType.BLOB, + ShapeType.STRING, + ): + return RawPayloadDeserializer(body) + + if not is_streaming_blob(body): + raise UnsupportedStreamException( + "Unable to read async stream. This stream must be buffered prior " + "to creating the deserializer." + ) + + if isinstance(body, bytearray): + body = bytes(body) + + return self._payload_codec.create_deserializer(body) + + +class HTTPHeaderDeserializer(SpecificShapeDeserializer): + """Binds HTTP header values to a deserializable shape. + + For headers with list values, see :py:class:`HTTPHeaderListDeserializer`. + """ + + def __init__(self, value: str) -> None: + """Initialize an HTTPHeaderDeserializer. + + :param value: The string value of the header. + """ + self._value = value + + def read_boolean(self, schema: Schema) -> bool: + return strict_parse_bool(self._value) + + def read_byte(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_short(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_integer(self, schema: Schema) -> int: + return int(self._value) + + def read_long(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_big_integer(self, schema: Schema) -> int: + return self.read_integer(schema) + + def read_float(self, schema: Schema) -> float: + return strict_parse_float(self._value) + + def read_double(self, schema: Schema) -> float: + return self.read_float(schema) + + def read_big_decimal(self, schema: Schema) -> Decimal: + return Decimal(self._value).canonical() + + def read_string(self, schema: Schema) -> str: + return self._value + + def read_timestamp(self, schema: Schema) -> datetime.datetime: + format = TimestampFormat.HTTP_DATE + if (trait := schema.get_trait(TimestampFormatTrait)) is not None: + format = trait.format + return ensure_utc(format.deserialize(self._value)) + + +class HTTPHeaderListDeserializer(SpecificShapeDeserializer): + """Binds HTTP header lists to a deserializable shape.""" + + def __init__(self, field: Field) -> None: + """Initialize an HTTPHeaderListDeserializer. + + :param field: The field to deserialize. + """ + self._field = field + + def read_list( + self, schema: Schema, consumer: Callable[["ShapeDeserializer"], None] + ) -> None: + for value in self._field.values: + consumer(HTTPHeaderDeserializer(value)) + + +class HTTPHeaderMapDeserializer(SpecificShapeDeserializer): + """Binds HTTP header maps to a deserializable shape.""" + + def __init__(self, fields: Fields, prefix: str = "") -> None: + """Initialize an HTTPHeaderMapDeserializer. + + :param fields: The collection of headers to search for map values. + :param prefix: An optional prefix to limit which headers are pulled in to the + map. By default, all headers are pulled in, including headers that are bound + to other properties on the shape. + """ + self._prefix = prefix.lower() + self._fields = fields + + def read_map( + self, + schema: Schema, + consumer: Callable[[str, "ShapeDeserializer"], None], + ) -> None: + trim = len(self._prefix) + for field in self._fields: + if field.name.lower().startswith(self._prefix): + consumer(field.name[trim:], HTTPHeaderDeserializer(field.as_string())) + + +class HTTPResponseCodeDeserializer(SpecificShapeDeserializer): + """Binds HTTP response codes to a deserializeable shape.""" + + def __init__(self, response_code: int) -> None: + """Initialize an HTTPResponseCodeDeserializer. + + :param response_code: The response code to bind. + """ + self._response_code = response_code + + def read_byte(self, schema: Schema) -> int: + return self._response_code + + def read_short(self, schema: Schema) -> int: + return self._response_code + + def read_integer(self, schema: Schema) -> int: + return self._response_code + + +class RawPayloadDeserializer(SpecificShapeDeserializer): + """Binds an HTTP payload to a deserializeable shape.""" + + def __init__(self, payload: "AsyncStreamingBlob") -> None: + """Initialize a RawPayloadDeserializer. + + :param payload: The payload to bind. If the member that is bound to the payload + is a string or blob, it MUST NOT be an async stream. Async streams MUST be + buffered into a synchronous stream ahead of time. + """ + self._payload = payload + + def read_string(self, schema: Schema) -> str: + return self._consume_payload().decode("utf-8") + + def read_blob(self, schema: Schema) -> bytes: + return self._consume_payload() + + def read_data_stream(self, schema: Schema) -> "AsyncStreamingBlob": + return self._payload + + def _consume_payload(self) -> bytes: + if isinstance(self._payload, bytes): + return self._payload + if isinstance(self._payload, bytearray): + return bytes(self._payload) + if is_bytes_reader(self._payload): + return self._payload.read() + raise UnsupportedStreamException( + "Unable to read async stream. This stream must be buffered prior " + "to creating the deserializer." + ) diff --git a/packages/smithy-http/src/smithy_http/serializers.py b/packages/smithy-http/src/smithy_http/serializers.py new file mode 100644 index 000000000..b3016f045 --- /dev/null +++ b/packages/smithy-http/src/smithy_http/serializers.py @@ -0,0 +1,606 @@ +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from datetime import datetime +from decimal import Decimal +from io import BytesIO +from typing import Any, TYPE_CHECKING +from urllib.parse import quote as urlquote +from asyncio import iscoroutinefunction + +from smithy_core.serializers import ( + MapSerializer, + ShapeSerializer, + SpecificShapeSerializer, + InterceptingSerializer, +) +from smithy_core.codecs import Codec +from smithy_core.types import TimestampFormat, PathPattern +from smithy_core.schemas import Schema +from smithy_core.traits import ( + HTTPTrait, + HTTPPayloadTrait, + HTTPHeaderTrait, + HTTPPrefixHeadersTrait, + HTTPQueryParamsTrait, + HTTPQueryTrait, + HTTPLabelTrait, + HTTPResponseCodeTrait, + HostLabelTrait, + TimestampFormatTrait, + EndpointTrait, + HTTPErrorTrait, +) +from smithy_core.shapes import ShapeType +from smithy_core.utils import serialize_float + +from .aio import HTTPRequest as _HTTPRequest +from .aio import HTTPResponse as _HTTPResponse +from .aio.interfaces import HTTPRequest, HTTPResponse +from smithy_core import URI +from . import tuples_to_fields +from .utils import join_query_params + + +if TYPE_CHECKING: + from smithy_core.aio.interfaces import StreamingBlob as AsyncStreamingBlob + + +__all__ = ["HTTPRequestSerializer", "HTTPResponseSerializer"] + + +class HTTPRequestSerializer(SpecificShapeSerializer): + """Binds a serializable shape to an HTTP request. + + The resultant HTTP request is not immediately sendable. In particular, the host of + the destination URI is incomplete and MUST be suffixed before sending. + """ + + def __init__( + self, + payload_codec: Codec, + http_trait: HTTPTrait, + endpoint_trait: EndpointTrait | None = None, + ) -> None: + """Initialize an HTTPRequestSerializer. + + :param payload_codec: The codec to use to serialize the HTTP payload, if one is + present. + :param http_trait: The HTTP trait of the operation being handled. + :param endpoint_trait: The optional endpoint trait of the operation being + handled. + """ + self._http_trait = http_trait + self._endpoint_trait = endpoint_trait + self._payload_codec = payload_codec + self.result: HTTPRequest | None = None + + @contextmanager + def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: + payload: Any + binding_serializer: HTTPRequestBindingSerializer + + host_prefix = "" + if self._endpoint_trait is not None: + host_prefix = self._endpoint_trait.host_prefix + + if (payload_member := self._get_payload_member(schema)) is not None: + if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING): + payload_serializer = RawPayloadSerializer() + binding_serializer = HTTPRequestBindingSerializer( + payload_serializer, self._http_trait.path, host_prefix + ) + yield binding_serializer + payload = payload_serializer.payload + else: + payload = BytesIO() + payload_serializer = self._payload_codec.create_serializer(payload) + binding_serializer = HTTPRequestBindingSerializer( + payload_serializer, self._http_trait.path, host_prefix + ) + yield binding_serializer + else: + payload = BytesIO() + payload_serializer = self._payload_codec.create_serializer(payload) + with payload_serializer.begin_struct(schema) as body_serializer: + binding_serializer = HTTPRequestBindingSerializer( + body_serializer, self._http_trait.path, host_prefix + ) + yield binding_serializer + + if ( + seek := getattr(payload, "seek", None) + ) is not None and not iscoroutinefunction(seek): + seek(0) + + self.result = _HTTPRequest( + method=self._http_trait.method, + destination=URI( + host=binding_serializer.host_prefix_serializer.host_prefix, + path=binding_serializer.path_serializer.path, + query=join_query_params( + params=binding_serializer.query_serializer.query_params, + prefix=self._http_trait.query or "", + ), + ), + fields=tuples_to_fields(binding_serializer.header_serializer.headers), + body=payload, + ) + + def _get_payload_member(self, schema: Schema) -> Schema | None: + for member in schema.members.values(): + if HTTPPayloadTrait in member: + return member + return None + + +class HTTPRequestBindingSerializer(InterceptingSerializer): + """Delegates HTTP request bindings to binding-location-specific serializers.""" + + def __init__( + self, + payload_serializer: ShapeSerializer, + path_pattern: PathPattern, + host_prefix_pattern: str, + ) -> None: + """Initialize an HTTPRequestBindingSerializer. + + :param payload_serializer: The :py:class:`ShapeSerializer` to use to serialize + the payload, if necessary. + :param path_pattern: The pattern used to construct the path. + :host_prefix_pattern: The pattern used to construct the host prefix. + """ + self._payload_serializer = payload_serializer + self.header_serializer = HTTPHeaderSerializer() + self.query_serializer = HTTPQuerySerializer() + self.path_serializer = HTTPPathSerializer(path_pattern) + self.host_prefix_serializer = HostPrefixSerializer( + payload_serializer, host_prefix_pattern + ) + + def before(self, schema: Schema) -> ShapeSerializer: + if HTTPHeaderTrait in schema or HTTPPrefixHeadersTrait in schema: + return self.header_serializer + if HTTPQueryTrait in schema or HTTPQueryParamsTrait in schema: + return self.query_serializer + if HTTPLabelTrait in schema: + return self.path_serializer + if HostLabelTrait in schema: + return self.host_prefix_serializer + + return self._payload_serializer + + def after(self, schema: Schema) -> None: + pass + + +class HTTPResponseSerializer(SpecificShapeSerializer): + """Binds a serializable shape to an HTTP response.""" + + def __init__( + self, + payload_codec: Codec, + http_trait: HTTPTrait, + ) -> None: + """Initialize an HTTPResponseSerializer. + + :param payload_codec: The codec to use to serialize the HTTP payload, if one is + present. + :param http_trait: The HTTP trait of the operation being handled. + """ + self._http_trait = http_trait + self._payload_codec = payload_codec + self.result: HTTPResponse | None = None + + @contextmanager + def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: + payload: Any + binding_serializer: HTTPResponseBindingSerializer + + if (payload_member := self._get_payload_member(schema)) is not None: + if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING): + payload_serializer = RawPayloadSerializer() + binding_serializer = HTTPResponseBindingSerializer(payload_serializer) + yield binding_serializer + payload = payload_serializer.payload + else: + payload = BytesIO() + payload_serializer = self._payload_codec.create_serializer(payload) + binding_serializer = HTTPResponseBindingSerializer(payload_serializer) + yield binding_serializer + else: + payload = BytesIO() + payload_serializer = self._payload_codec.create_serializer(payload) + with payload_serializer.begin_struct(schema) as body_serializer: + binding_serializer = HTTPResponseBindingSerializer(body_serializer) + yield binding_serializer + + if ( + seek := getattr(payload, "seek", None) + ) is not None and not iscoroutinefunction(seek): + seek(0) + + default_code = self._http_trait.code + explicit_code = binding_serializer.response_code_serializer.response_code + if (http_error_trait := schema.get_trait(HTTPErrorTrait)) is not None: + default_code = http_error_trait.code + + self.result = _HTTPResponse( + fields=tuples_to_fields(binding_serializer.header_serializer.headers), + body=payload, + status=explicit_code or default_code, + ) + + def _get_payload_member(self, schema: Schema) -> Schema | None: + for member in schema.members.values(): + if HTTPPayloadTrait in member: + return member + return None + + +class HTTPResponseBindingSerializer(InterceptingSerializer): + """Delegates HTTP response bindings to binding-location-specific serializers.""" + + def __init__(self, payload_serializer: ShapeSerializer) -> None: + """Initialize an HTTPResponseBindingSerializer. + + :param payload_serializer: The :py:class:`ShapeSerializer` to use to serialize + the payload, if necessary. + """ + self._payload_serializer = payload_serializer + self.header_serializer = HTTPHeaderSerializer() + self.response_code_serializer = HTTPResponseCodeSerializer() + + def before(self, schema: Schema) -> ShapeSerializer: + if HTTPHeaderTrait in schema or HTTPPrefixHeadersTrait in schema: + return self.header_serializer + if HTTPResponseCodeTrait in schema: + return self.response_code_serializer + + return self._payload_serializer + + def after(self, schema: Schema) -> None: + pass + + +class RawPayloadSerializer(SpecificShapeSerializer): + """Binds properties of serializable shape to an HTTP payload.""" + + payload: "AsyncStreamingBlob | None" + """The serialized payload. + + This will only be non-null after serialization. + """ + + def __init__(self) -> None: + """Initialize a RawPayloadSerializer.""" + self.payload: "AsyncStreamingBlob | None" = None + + def write_string(self, schema: Schema, value: str) -> None: + self.payload = value.encode("utf-8") + + def write_blob(self, schema: Schema, value: bytes) -> None: + self.payload = value + + def write_data_stream(self, schema: Schema, value: "AsyncStreamingBlob") -> None: + self.payload = value + + +class HTTPHeaderSerializer(SpecificShapeSerializer): + """Binds properties of a serializable shape to HTTP headers.""" + + headers: list[tuple[str, str]] + """A list of serialized headers. + + This should only be accessed after serialization. + """ + + def __init__( + self, key: str | None = None, headers: list[tuple[str, str]] | None = None + ) -> None: + """Initialize an HTTPHeaderSerializer. + + :param key: An optional key to specifically write. If not set, the + :py:class:`HTTPHeaderTrait` will be checked instead. Required when + collecting list entries. + :param headers: An optional list of header tuples to append to. If not + set, one will be created. + """ + self.headers: list[tuple[str, str]] = headers if headers is not None else [] + self._key = key + + @contextmanager + def begin_list(self, schema: Schema, size: int) -> Iterator[ShapeSerializer]: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + delegate = HTTPHeaderSerializer(key=key, headers=self.headers) + yield delegate + + @contextmanager + def begin_map(self, schema: Schema, size: int) -> Iterator[MapSerializer]: + prefix = schema.expect_trait(HTTPPrefixHeadersTrait).prefix + yield HTTPHeaderMapSerializer(prefix, self.headers) + + def write_boolean(self, schema: Schema, value: bool) -> None: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + self.headers.append((key, "true" if value else "false")) + + def write_byte(self, schema: Schema, value: int) -> None: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + self.headers.append((key, str(value))) + + def write_short(self, schema: Schema, value: int) -> None: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + self.headers.append((key, str(value))) + + def write_integer(self, schema: Schema, value: int) -> None: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + self.headers.append((key, str(value))) + + def write_long(self, schema: Schema, value: int) -> None: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + self.headers.append((key, str(value))) + + def write_big_integer(self, schema: Schema, value: int) -> None: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + self.headers.append((key, str(value))) + + def write_float(self, schema: Schema, value: float) -> None: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + self.headers.append((key, str(value))) + + def write_double(self, schema: Schema, value: float) -> None: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + self.headers.append((key, str(value))) + + def write_big_decimal(self, schema: Schema, value: Decimal) -> None: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + self.headers.append((key, str(value.canonical()))) + + def write_string(self, schema: Schema, value: str) -> None: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + self.headers.append((key, value)) + + def write_timestamp(self, schema: Schema, value: datetime) -> None: + key = self._key or schema.expect_trait(HTTPHeaderTrait).key + format = TimestampFormat.HTTP_DATE + if (trait := schema.get_trait(TimestampFormatTrait)) is not None: + format = trait.format + self.headers.append((key, str(format.serialize(value)))) + + +class HTTPHeaderMapSerializer(MapSerializer): + """Binds a mapping property of a serializeable shape to multiple HTTP headers.""" + + def __init__(self, prefix: str, headers: list[tuple[str, str]]) -> None: + """Initialize an HTTPHeaderMapSerializer. + + :param prefix: The prefix to prepend to each of the map keys. + :param headers: The list of header tuples to append to. + """ + self._prefix = prefix + self._headers = headers + self._delegate = CapturingSerializer() + + def entry(self, key: str, value_writer: Callable[[ShapeSerializer], None]): + value_writer(self._delegate) + assert self._delegate.result is not None + self._headers.append((self._prefix + key, self._delegate.result)) + + +class CapturingSerializer(SpecificShapeSerializer): + """Directly passes along a string through a serializer.""" + + result: str | None + """The captured string. + + This will only be set after the serializer has been used. + """ + + def __init__(self) -> None: + self.result = None + + def write_string(self, schema: Schema, value: str) -> None: + self.result = value + + +class HTTPQuerySerializer(SpecificShapeSerializer): + """Binds properties of a serializable shape to HTTP URI query params.""" + + def __init__( + self, key: str | None = None, params: list[tuple[str, str]] | None = None + ) -> None: + """Initialize an HTTPQuerySerializer. + + :param key: An optional key to specifically write. If not set, the + :py:class:`HTTPQueryTrait` will be checked instead. Required when + collecting list or map entries. + :param headers: An optional list of header tuples to append to. If not + set, one will be created. + """ + self.query_params: list[tuple[str, str]] = params if params is not None else [] + self._key = key + + @contextmanager + def begin_list(self, schema: Schema, size: int) -> Iterator[ShapeSerializer]: + key = self._key or schema.expect_trait(HTTPQueryTrait).key + yield HTTPQuerySerializer(key=key, params=self.query_params) + + @contextmanager + def begin_map(self, schema: Schema, size: int) -> Iterator[MapSerializer]: + yield HTTPQueryMapSerializer(self.query_params) + + def write_boolean(self, schema: Schema, value: bool) -> None: + key = self._key or schema.expect_trait(HTTPQueryTrait).key + self.query_params.append((key, "true" if value else "false")) + + def write_byte(self, schema: Schema, value: int) -> None: + self.write_integer(schema, value) + + def write_short(self, schema: Schema, value: int) -> None: + self.write_integer(schema, value) + + def write_integer(self, schema: Schema, value: int) -> None: + key = self._key or schema.expect_trait(HTTPQueryTrait).key + self.query_params.append((key, str(value))) + + def write_long(self, schema: Schema, value: int) -> None: + self.write_integer(schema, value) + + def write_big_integer(self, schema: Schema, value: int) -> None: + self.write_integer(schema, value) + + def write_float(self, schema: Schema, value: float) -> None: + key = self._key or schema.expect_trait(HTTPQueryTrait).key + self.query_params.append((key, serialize_float(value))) + + def write_double(self, schema: Schema, value: float) -> None: + self.write_float(schema, value) + + def write_big_decimal(self, schema: Schema, value: Decimal) -> None: + key = self._key or schema.expect_trait(HTTPQueryTrait).key + self.query_params.append((key, serialize_float(value))) + + def write_string(self, schema: Schema, value: str) -> None: + key = self._key or schema.expect_trait(HTTPQueryTrait).key + self.query_params.append((key, urlquote(value, safe=""))) + + def write_timestamp(self, schema: Schema, value: datetime) -> None: + key = self._key or schema.expect_trait(HTTPQueryTrait).key + format = TimestampFormat.DATE_TIME + if (trait := schema.get_trait(TimestampFormatTrait)) is not None: + format = trait.format + self.query_params.append((key, str(format.serialize(value)))) + + +class HTTPPathSerializer(SpecificShapeSerializer): + """Binds properties of a serializable shape to the HTTP URI path.""" + + def __init__(self, path_pattern: PathPattern) -> None: + """Initialize an HTTPPathSerializer. + + :param path_pattern: The pattern to bind properties to. This is also used to + detect greedy labels, which have different escaping requirements. + """ + self._path_pattern = path_pattern + self._path_params: dict[str, str] = {} + + @property + def path(self) -> str: + """Get the formatted path. + + This must not be accessed before serialization is complete, otherwise an + exception will be raised. + """ + return self._path_pattern.format(**self._path_params) + + def write_boolean(self, schema: Schema, value: bool) -> None: + self._path_params[schema.expect_member_name()] = "true" if value else "false" + + def write_byte(self, schema: Schema, value: int) -> None: + self.write_integer(schema, value) + + def write_short(self, schema: Schema, value: int) -> None: + self.write_integer(schema, value) + + def write_integer(self, schema: Schema, value: int) -> None: + self._path_params[schema.expect_member_name()] = str(value) + + def write_long(self, schema: Schema, value: int) -> None: + self.write_integer(schema, value) + + def write_big_integer(self, schema: Schema, value: int) -> None: + self.write_integer(schema, value) + + def write_float(self, schema: Schema, value: float) -> None: + self._path_params[schema.expect_member_name()] = serialize_float(value) + + def write_double(self, schema: Schema, value: float) -> None: + self.write_float(schema, value) + + def write_big_decimal(self, schema: Schema, value: Decimal) -> None: + self._path_params[schema.expect_member_name()] = serialize_float(value) + + def write_string(self, schema: Schema, value: str) -> None: + key = schema.expect_member_name() + if key in self._path_pattern.greedy_labels: + value = urlquote(value) + else: + value = urlquote(value, safe="") + self._path_params[schema.expect_member_name()] = value + + def write_timestamp(self, schema: Schema, value: datetime) -> None: + format = TimestampFormat.DATE_TIME + if (trait := schema.get_trait(TimestampFormatTrait)) is not None: + format = trait.format + self._path_params[schema.expect_member_name()] = urlquote( + str(format.serialize(value)) + ) + + +class HTTPQueryMapSerializer(MapSerializer): + """Binds properties of a serializable shape to a map of HTTP query params.""" + + def __init__(self, query_params: list[tuple[str, str]]) -> None: + """Initialize an HTTPQueryMapSerializer. + + :param query_params: The list of query param tuples to append to. + """ + self._query_params = query_params + self._delegate = CapturingSerializer() + + def entry(self, key: str, value_writer: Callable[[ShapeSerializer], None]): + value_writer(self._delegate) + assert self._delegate.result is not None + self._query_params.append((key, urlquote(self._delegate.result, safe=""))) + + +class HostPrefixSerializer(SpecificShapeSerializer): + """Binds properites of a serializable shape to the HTTP URI host. + + These properties are also bound to the payload. + """ + + def __init__( + self, payload_serializer: ShapeSerializer, host_prefix_pattern: str + ) -> None: + """Initialize a HostPrefixSerializer. + + :param host_prefix_pattern: The pattern to bind properties to. + :param payload_serializer: The payload serializer to additionally write + properties to. + """ + self._prefix_params: dict[str, str] = {} + self._host_prefix_pattern = host_prefix_pattern + self._payload_serializer = payload_serializer + + @property + def host_prefix(self) -> str: + """The formatted host prefix. + + This must not be accessed before serialization is complete, otherwise an + exception will be raised. + """ + return self._host_prefix_pattern.format(**self._prefix_params) + + def write_string(self, schema: Schema, value: str) -> None: + self._payload_serializer.write_string(schema, value) + self._prefix_params[schema.expect_member_name()] = urlquote(value, safe=".") + + +class HTTPResponseCodeSerializer(SpecificShapeSerializer): + """Binds properties of a serializable shape to the HTTP response code.""" + + response_code: int | None + """The bound response code, or None if one hasn't been bound.""" + + def __init__(self) -> None: + """Initialize an HTTPResponseCodeSerializer.""" + self.response_code: int | None = None + + def write_byte(self, schema: Schema, value: int) -> None: + self.response_code = value + + def write_short(self, schema: Schema, value: int) -> None: + self.response_code = value + + def write_integer(self, schema: Schema, value: int) -> None: + self.response_code = value diff --git a/packages/smithy-http/src/smithy_http/utils.py b/packages/smithy-http/src/smithy_http/utils.py index d1ef92612..c01ac3541 100644 --- a/packages/smithy-http/src/smithy_http/utils.py +++ b/packages/smithy-http/src/smithy_http/utils.py @@ -1,5 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence from urllib.parse import quote as urlquote from smithy_core.exceptions import SmithyException @@ -105,7 +106,9 @@ def _consume_until( return result, end_index + 1 -def join_query_params(params: list[tuple[str, str | None]], prefix: str = "") -> str: +def join_query_params( + params: Sequence[tuple[str, str | None]], prefix: str = "" +) -> str: """Join a list of query parameter key-value tuples. :param params: The list of key-value query parameter tuples. diff --git a/packages/smithy-http/tests/unit/test_serializers.py b/packages/smithy-http/tests/unit/test_serializers.py new file mode 100644 index 000000000..306b0cc05 --- /dev/null +++ b/packages/smithy-http/tests/unit/test_serializers.py @@ -0,0 +1,1694 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import datetime +from asyncio import iscoroutinefunction +from decimal import Decimal +from dataclasses import dataclass, field +from typing import ClassVar, Self, Any, Protocol +from datetime import UTC +from io import BytesIO + +import pytest + +from smithy_core import URI +from smithy_core.aio.interfaces import StreamingBlob +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import ( + HTTPLabelTrait, + TimestampFormatTrait, + HTTPPayloadTrait, + StreamingTrait, + HTTPHeaderTrait, + HTTPResponseCodeTrait, + HTTPPrefixHeadersTrait, + Trait, + HTTPQueryTrait, + HTTPQueryParamsTrait, + HTTPTrait, + HostLabelTrait, + EndpointTrait, +) +from smithy_core.schemas import Schema +from smithy_core.serializers import ShapeSerializer, SerializeableShape +from smithy_core.deserializers import ShapeDeserializer, DeserializeableShape +from smithy_core.prelude import ( + BLOB, + STRING, + INTEGER, + FLOAT, + BIG_DECIMAL, + BOOLEAN, + TIMESTAMP, +) +from smithy_core.aio.types import AsyncBytesReader +from smithy_http.deserializers import HTTPResponseDeserializer +from smithy_json import JSONCodec +from smithy_http.aio import HTTPResponse as _HTTPResponse +from smithy_http import tuples_to_fields, Fields +from smithy_http.serializers import HTTPRequestSerializer, HTTPResponseSerializer + +# TODO: empty header prefix, query map + +BOOLEAN_LIST = Schema.collection( + id=ShapeID("com.smithy#BooleanList"), + shape_type=ShapeType.LIST, + members={"member": {"index": 0, "target": BOOLEAN}}, +) +STRING_LIST = Schema.collection( + id=ShapeID("com.smithy#StringList"), + shape_type=ShapeType.LIST, + members={"member": {"index": 0, "target": STRING}}, +) +INTEGER_LIST = Schema.collection( + id=ShapeID("com.smithy#IntegerList"), + shape_type=ShapeType.LIST, + members={"member": {"index": 0, "target": INTEGER}}, +) +FLOAT_LIST = Schema.collection( + id=ShapeID("com.smithy#FloatList"), + shape_type=ShapeType.LIST, + members={"member": {"index": 0, "target": FLOAT}}, +) +BIG_DECIMAL_LIST = Schema.collection( + id=ShapeID("com.smithy#BigDecimalList"), + shape_type=ShapeType.LIST, + members={"member": {"index": 0, "target": BIG_DECIMAL}}, +) +BARE_TIMESTAMP_LIST = Schema.collection( + id=ShapeID("com.smithy#BareTimestampList"), + shape_type=ShapeType.LIST, + members={"member": {"index": 0, "target": TIMESTAMP}}, +) +HTTP_DATE_TIMESTAMP_LIST = Schema.collection( + id=ShapeID("com.smithy#HttpDateTimestampList"), + shape_type=ShapeType.LIST, + members={ + "member": { + "index": 0, + "target": TIMESTAMP, + "traits": [TimestampFormatTrait("http-date")], + } + }, +) +DATE_TIME_TIMESTAMP_LIST = Schema.collection( + id=ShapeID("com.smithy#DateTimeTimestampList"), + shape_type=ShapeType.LIST, + members={ + "member": { + "index": 0, + "target": TIMESTAMP, + "traits": [TimestampFormatTrait("date-time")], + } + }, +) +EPOCH_TIMESTAMP_LIST = Schema.collection( + id=ShapeID("com.smithy#EpochTimestampList"), + shape_type=ShapeType.LIST, + members={ + "member": { + "index": 0, + "target": TIMESTAMP, + "traits": [TimestampFormatTrait("epoch-seconds")], + } + }, +) +STRING_MAP = Schema.collection( + id=ShapeID("com.smithy#StringMap"), + shape_type=ShapeType.MAP, + members={ + "key": {"index": 0, "target": STRING}, + "value": {"index": 1, "target": STRING}, + }, +) + + +@dataclass +class _HTTPMapping(Protocol): + boolean_member: bool | None = None + boolean_list_member: list[bool] = field(default_factory=list) + integer_member: int | None = None + integer_list_member: list[int] = field(default_factory=list) + float_member: float | None = None + float_list_member: list[float] = field(default_factory=list) + big_decimal_member: Decimal | None = None + big_decimal_list_member: list[Decimal] = field(default_factory=list) + string_member: str | None = None + string_list_member: list[str] = field(default_factory=list) + default_timestamp_member: datetime.datetime | None = None + http_date_timestamp_member: datetime.datetime | None = None + http_date_list_timestamp_member: list[datetime.datetime] = field( + default_factory=list + ) + date_time_timestamp_member: datetime.datetime | None = None + date_time_list_timestamp_member: list[datetime.datetime] = field( + default_factory=list + ) + epoch_timestamp_member: datetime.datetime | None = None + epoch_list_timestamp_member: list[datetime.datetime] = field(default_factory=list) + string_map_member: dict[str, str] = field(default_factory=dict) + + ID: ClassVar[ShapeID] + SCHEMA: ClassVar[Schema] + + def __init_subclass__( + cls, id: ShapeID, trait: type[Trait], map_trait: Trait + ) -> None: + cls.ID = id + cls.SCHEMA = Schema.collection( + id=id, + members={ + "boolean_member": { + "index": 0, + "target": BOOLEAN, + "traits": [trait("boolean")], + }, + "boolean_list_member": { + "index": 1, + "target": BOOLEAN_LIST, + "traits": [trait("booleanList")], + }, + "integer_member": { + "index": 2, + "target": INTEGER, + "traits": [trait("integer")], + }, + "integer_list_member": { + "index": 3, + "target": INTEGER_LIST, + "traits": [trait("integerList")], + }, + "float_member": { + "index": 4, + "target": FLOAT, + "traits": [trait("float")], + }, + "float_list_member": { + "index": 5, + "target": FLOAT_LIST, + "traits": [trait("floatList")], + }, + "big_decimal_member": { + "index": 6, + "target": BIG_DECIMAL, + "traits": [trait("bigDecimal")], + }, + "big_decimal_list_member": { + "index": 7, + "target": BIG_DECIMAL_LIST, + "traits": [trait("bigDecimalList")], + }, + "string_member": { + "index": 8, + "target": STRING, + "traits": [trait("string")], + }, + "string_list_member": { + "index": 9, + "target": STRING_LIST, + "traits": [trait("stringList")], + }, + "default_timestamp_member": { + "index": 10, + "target": TIMESTAMP, + "traits": [trait("defaultTimestamp")], + }, + "http_date_timestamp_member": { + "index": 11, + "target": TIMESTAMP, + "traits": [ + trait("httpDateTimestamp"), + TimestampFormatTrait("http-date"), + ], + }, + "http_date_list_timestamp_member": { + "index": 12, + "target": HTTP_DATE_TIMESTAMP_LIST, + "traits": [trait("httpDateListTimestamp")], + }, + "date_time_timestamp_member": { + "index": 13, + "target": TIMESTAMP, + "traits": [ + trait("dateTimeTimestamp"), + TimestampFormatTrait("date-time"), + ], + }, + "date_time_list_timestamp_member": { + "index": 14, + "target": DATE_TIME_TIMESTAMP_LIST, + "traits": [trait("dateTimeListTimestamp")], + }, + "epoch_timestamp_member": { + "index": 15, + "target": TIMESTAMP, + "traits": [ + trait("epochTimestamp"), + TimestampFormatTrait("epoch-seconds"), + ], + }, + "epoch_list_timestamp_member": { + "index": 16, + "target": EPOCH_TIMESTAMP_LIST, + "traits": [trait("epochListTimestamp")], + }, + "string_map_member": { + "index": 17, + "target": STRING_MAP, + "traits": [map_trait], + }, + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + if self.boolean_member is not None: + serializer.write_boolean( + self.SCHEMA.members["boolean_member"], self.boolean_member + ) + if self.boolean_list_member: + with serializer.begin_list( + self.SCHEMA.members["boolean_list_member"], + len(self.boolean_list_member), + ) as ls: + s = BOOLEAN_LIST.members["member"] + for e in self.boolean_list_member: + ls.write_boolean(s, e) + if self.integer_member is not None: + serializer.write_integer( + self.SCHEMA.members["integer_member"], self.integer_member + ) + if self.integer_list_member: + with serializer.begin_list( + self.SCHEMA.members["integer_list_member"], + len(self.integer_list_member), + ) as ls: + s = INTEGER_LIST.members["member"] + for e in self.integer_list_member: + ls.write_integer(s, e) + if self.float_member is not None: + serializer.write_float( + self.SCHEMA.members["float_member"], self.float_member + ) + if self.float_list_member: + with serializer.begin_list( + self.SCHEMA.members["float_list_member"], len(self.float_list_member) + ) as ls: + s = FLOAT_LIST.members["member"] + for e in self.float_list_member: + ls.write_float(s, e) + if self.big_decimal_member is not None: + serializer.write_big_decimal( + self.SCHEMA.members["big_decimal_member"], self.big_decimal_member + ) + if self.big_decimal_list_member: + with serializer.begin_list( + self.SCHEMA.members["big_decimal_list_member"], + len(self.big_decimal_list_member), + ) as ls: + s = BIG_DECIMAL_LIST.members["member"] + for e in self.big_decimal_list_member: + ls.write_big_decimal(s, e) + if self.string_member is not None: + serializer.write_string( + self.SCHEMA.members["string_member"], self.string_member + ) + if self.string_list_member: + with serializer.begin_list( + self.SCHEMA.members["string_list_member"], len(self.string_list_member) + ) as ls: + s = STRING_LIST.members["member"] + for e in self.string_list_member: + ls.write_string(s, e) + if self.default_timestamp_member is not None: + serializer.write_timestamp( + self.SCHEMA.members["default_timestamp_member"], + self.default_timestamp_member, + ) + if self.http_date_timestamp_member is not None: + serializer.write_timestamp( + self.SCHEMA.members["http_date_timestamp_member"], + self.http_date_timestamp_member, + ) + if self.http_date_list_timestamp_member: + with serializer.begin_list( + self.SCHEMA.members["http_date_list_timestamp_member"], + len(self.http_date_list_timestamp_member), + ) as ls: + s = HTTP_DATE_TIMESTAMP_LIST.members["member"] + for e in self.http_date_list_timestamp_member: + ls.write_timestamp(s, e) + if self.date_time_timestamp_member is not None: + serializer.write_timestamp( + self.SCHEMA.members["date_time_timestamp_member"], + self.date_time_timestamp_member, + ) + if self.date_time_list_timestamp_member: + with serializer.begin_list( + self.SCHEMA.members["date_time_list_timestamp_member"], + len(self.date_time_list_timestamp_member), + ) as ls: + s = DATE_TIME_TIMESTAMP_LIST.members["member"] + for e in self.date_time_list_timestamp_member: + ls.write_timestamp(s, e) + if self.epoch_timestamp_member is not None: + serializer.write_timestamp( + self.SCHEMA.members["epoch_timestamp_member"], + self.epoch_timestamp_member, + ) + if self.epoch_list_timestamp_member: + with serializer.begin_list( + self.SCHEMA.members["epoch_list_timestamp_member"], + len(self.epoch_list_timestamp_member), + ) as ls: + s = EPOCH_TIMESTAMP_LIST.members["member"] + for e in self.epoch_list_timestamp_member: + ls.write_timestamp(s, e) + if self.string_map_member: + with serializer.begin_map( + self.SCHEMA.members["string_map_member"], len(self.string_map_member) + ) as ms: + s = STRING_MAP.members["value"] + for k, v in self.string_map_member.items(): + ms.entry(k, lambda vs: vs.write_string(s, v)) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["boolean_member"] = de.read_boolean( + cls.SCHEMA.members["boolean_member"] + ) + case 1: + list_value: list[Any] = [] + de.read_list( + cls.SCHEMA.members["boolean_list_member"], + lambda d: list_value.append(d.read_boolean(BOOLEAN)), + ) + kwargs["boolean_list_member"] = list_value + case 2: + kwargs["integer_member"] = de.read_integer( + cls.SCHEMA.members["integer_member"] + ) + case 3: + list_value: list[Any] = [] + de.read_list( + cls.SCHEMA.members["integer_list_member"], + lambda d: list_value.append(d.read_integer(INTEGER)), + ) + kwargs["integer_list_member"] = list_value + case 4: + kwargs["float_member"] = de.read_float( + cls.SCHEMA.members["float_member"] + ) + case 5: + list_value: list[Any] = [] + de.read_list( + cls.SCHEMA.members["float_list_member"], + lambda d: list_value.append(d.read_float(FLOAT)), + ) + kwargs["float_list_member"] = list_value + case 6: + kwargs["big_decimal_member"] = de.read_big_decimal( + cls.SCHEMA.members["big_decimal_member"] + ) + case 7: + list_value: list[Any] = [] + de.read_list( + cls.SCHEMA.members["big_decimal_list_member"], + lambda d: list_value.append(d.read_big_decimal(BIG_DECIMAL)), + ) + kwargs["big_decimal_list_member"] = list_value + case 8: + kwargs["string_member"] = de.read_string( + cls.SCHEMA.members["string_member"] + ) + case 9: + list_value: list[Any] = [] + de.read_list( + cls.SCHEMA.members["string_list_member"], + lambda d: list_value.append(d.read_string(STRING)), + ) + kwargs["string_list_member"] = list_value + case 10: + kwargs["default_timestamp_member"] = de.read_timestamp( + cls.SCHEMA.members["default_timestamp_member"] + ) + case 11: + kwargs["http_date_timestamp_member"] = de.read_timestamp( + cls.SCHEMA.members["http_date_timestamp_member"] + ) + case 12: + list_value: list[Any] = [] + de.read_list( + cls.SCHEMA.members["http_date_list_timestamp_member"], + lambda d: list_value.append( + d.read_timestamp(HTTP_DATE_TIMESTAMP_LIST.members["member"]) + ), + ) + kwargs["http_date_list_timestamp_member"] = list_value + case 13: + kwargs["date_time_timestamp_member"] = de.read_timestamp( + cls.SCHEMA.members["date_time_timestamp_member"] + ) + case 14: + list_value: list[Any] = [] + de.read_list( + cls.SCHEMA.members["date_time_list_timestamp_member"], + lambda d: list_value.append( + d.read_timestamp(DATE_TIME_TIMESTAMP_LIST.members["member"]) + ), + ) + kwargs["date_time_list_timestamp_member"] = list_value + case 15: + kwargs["epoch_timestamp_member"] = de.read_timestamp( + cls.SCHEMA.members["epoch_timestamp_member"] + ) + case 16: + list_value: list[Any] = [] + de.read_list( + cls.SCHEMA.members["epoch_list_timestamp_member"], + lambda d: list_value.append( + d.read_timestamp(EPOCH_TIMESTAMP_LIST.members["member"]) + ), + ) + kwargs["epoch_list_timestamp_member"] = list_value + case 17: + map_value: dict[str, Any] = {} + de.read_map( + cls.SCHEMA.members["string_map_member"], + lambda k, d: map_value.__setitem__(k, d.read_string(STRING)), + ) + kwargs["string_map_member"] = map_value + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPHeaders( + _HTTPMapping, + id=ShapeID("com.smithy#HttpHeaders"), + trait=HTTPHeaderTrait, + map_trait=HTTPPrefixHeadersTrait("x-"), +): ... + + +@dataclass +class HTTPEmptyPrefixHeaders( + _HTTPMapping, + id=ShapeID("com.smithy#HttpHeaders"), + trait=HTTPHeaderTrait, + map_trait=HTTPPrefixHeadersTrait(""), +): ... + + +@dataclass +class HTTPQuery( + _HTTPMapping, + id=ShapeID("com.smithy#HTTPQuery"), + trait=HTTPQueryTrait, + map_trait=HTTPQueryParamsTrait(), +): ... + + +@dataclass +class HTTPResponseCode: + code: int = 200 + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPResponseCode") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "code": {"index": 0, "target": INTEGER, "traits": [HTTPResponseCodeTrait()]} + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_integer(self.SCHEMA.members["code"], self.code) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["code"] = de.read_integer(cls.SCHEMA.members["code"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPImplicitPayload: + header: str | None = None + payload_member: str | None = None + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPImplicitPayload") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "header": { + "index": 0, + "target": STRING, + "traits": [HTTPHeaderTrait("header")], + }, + "payload_member": {"index": 1, "target": STRING}, + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + if self.header is not None: + serializer.write_string(self.SCHEMA.members["header"], self.header) + if self.payload_member is not None: + serializer.write_string( + self.SCHEMA.members["payload_member"], self.payload_member + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["header"] = de.read_string(cls.SCHEMA.members["header"]) + case 1: + kwargs["payload_member"] = de.read_string( + cls.SCHEMA.members["payload_member"] + ) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPStringPayload: + payload: str + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPStringPayload") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "payload": {"index": 0, "target": STRING, "traits": [HTTPPayloadTrait()]} + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_string(self.SCHEMA.members["payload"], self.payload) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["payload"] = de.read_string(cls.SCHEMA.members["payload"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPBlobPayload: + payload: bytes + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPBlobPayload") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "payload": {"index": 0, "target": BLOB, "traits": [HTTPPayloadTrait()]} + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_blob(self.SCHEMA.members["payload"], self.payload) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["payload"] = de.read_blob(cls.SCHEMA.members["payload"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPStreamingPayload: + payload: StreamingBlob + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPStreamingPayload") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "payload": { + "index": 0, + "target": BLOB, + "traits": [HTTPPayloadTrait(), StreamingTrait()], + } + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_data_stream(self.SCHEMA.members["payload"], self.payload) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["payload"] = de.read_data_stream( + cls.SCHEMA.members["payload"] + ) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPStructuredPayload: + payload: HTTPStringPayload + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPStructuredPayload") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "payload": { + "index": 0, + "target": HTTPStringPayload.SCHEMA, + "traits": [HTTPPayloadTrait()], + } + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_struct(self.SCHEMA.members["payload"], self.payload) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["payload"] = HTTPStringPayload.deserialize(de) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPStringLabel: + label: str + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPStringLabel") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={"label": {"index": 0, "target": STRING, "traits": [HTTPLabelTrait()]}}, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_string(self.SCHEMA.members["label"], self.label) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["label"] = de.read_string(cls.SCHEMA.members["label"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPIntegerLabel: + label: int + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPIntegerLabel") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "label": {"index": 0, "target": INTEGER, "traits": [HTTPLabelTrait()]} + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_integer(self.SCHEMA.members["label"], self.label) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["label"] = de.read_integer(cls.SCHEMA.members["label"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPFloatLabel: + label: float + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPFloatLabel") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={"label": {"index": 0, "target": FLOAT, "traits": [HTTPLabelTrait()]}}, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_float(self.SCHEMA.members["label"], self.label) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["label"] = de.read_float(cls.SCHEMA.members["label"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPBigDecimalLabel: + label: Decimal + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPBigDecimalLabel") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "label": {"index": 0, "target": BIG_DECIMAL, "traits": [HTTPLabelTrait()]} + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_big_decimal(self.SCHEMA.members["label"], self.label) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["label"] = de.read_big_decimal(cls.SCHEMA.members["label"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPBooleanLabel: + label: bool + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPBooleanLabel") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "label": {"index": 0, "target": BOOLEAN, "traits": [HTTPLabelTrait()]} + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_boolean(self.SCHEMA.members["label"], self.label) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["label"] = de.read_boolean(cls.SCHEMA.members["label"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPDefaultTimestampLabel: + label: datetime.datetime + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPDefaultTimestampLabel") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "label": {"index": 0, "target": TIMESTAMP, "traits": [HTTPLabelTrait()]} + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_timestamp(self.SCHEMA.members["label"], self.label) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["label"] = de.read_timestamp(cls.SCHEMA.members["label"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPEpochTimestampLabel: + label: datetime.datetime + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPEpochTimestampLabel") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "label": { + "index": 0, + "target": TIMESTAMP, + "traits": [HTTPLabelTrait(), TimestampFormatTrait("epoch-seconds")], + } + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_timestamp(self.SCHEMA.members["label"], self.label) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["label"] = de.read_timestamp(cls.SCHEMA.members["label"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPDateTimestampLabel: + label: datetime.datetime + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPDateTimestampLabel") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "label": { + "index": 0, + "target": TIMESTAMP, + "traits": [HTTPLabelTrait(), TimestampFormatTrait("http-date")], + } + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_timestamp(self.SCHEMA.members["label"], self.label) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["label"] = de.read_timestamp(cls.SCHEMA.members["label"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPDateTimeTimestampLabel: + label: datetime.datetime + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPDateTimeTimestampLabel") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "label": { + "index": 0, + "target": TIMESTAMP, + "traits": [HTTPLabelTrait(), TimestampFormatTrait("date-time")], + } + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_timestamp(self.SCHEMA.members["label"], self.label) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["label"] = de.read_timestamp(cls.SCHEMA.members["label"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HostLabel: + label: str + + ID: ClassVar[ShapeID] = ShapeID("com.smithy#HostLabel") + SCHEMA: ClassVar[Schema] = Schema.collection( + id=ID, + members={ + "label": { + "index": 0, + "target": STRING, + "traits": [HostLabelTrait()], + } + }, + ) + + def serialize(self, serializer: ShapeSerializer) -> None: + with serializer.begin_struct(self.SCHEMA) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_string(self.SCHEMA.members["label"], self.label) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["label"] = de.read_string(cls.SCHEMA.members["label"]) + case _: + raise Exception(f"Unexpected schema: {schema}") + + deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class HTTPMessage: + method: str = "POST" + destination: URI = URI(host="", path="/") + fields: Fields = field(default_factory=Fields) + body: StreamingBlob = field(repr=False, default=b"") + status: int = 200 + + +class Shape(SerializeableShape, DeserializeableShape, Protocol): ... + + +@dataclass +class HTTPMessageTestCase: + shape: Shape + request: HTTPMessage + http_trait: HTTPTrait = HTTPTrait({"method": "POST", "code": 200, "uri": "/"}) + endpoint_trait: EndpointTrait | None = None + + +# All of these test cases need to be created indirectly because they have mutable +# values and the individual cases are re-used. It would be possible to make a +# test generator in conftest.py that achieves a similar effect, but then you lose +# typing. +def header_cases() -> list[HTTPMessageTestCase]: + return [ + HTTPMessageTestCase( + HTTPHeaders(boolean_member=True), + HTTPMessage( + fields=tuples_to_fields([("boolean", "true")]), + ), + ), + HTTPMessageTestCase( + HTTPHeaders(boolean_list_member=[True, False]), + HTTPMessage( + fields=tuples_to_fields( + [("booleanList", "true"), ("booleanList", "false")] + ), + ), + ), + HTTPMessageTestCase( + HTTPHeaders(integer_member=1), + HTTPMessage( + fields=tuples_to_fields([("integer", "1")]), + ), + ), + HTTPMessageTestCase( + HTTPHeaders(integer_list_member=[1, 2]), + HTTPMessage( + fields=tuples_to_fields([("integerList", "1"), ("integerList", "2")]), + ), + ), + HTTPMessageTestCase( + HTTPHeaders(float_member=1.1), + HTTPMessage( + fields=tuples_to_fields([("float", "1.1")]), + ), + ), + HTTPMessageTestCase( + HTTPHeaders(float_list_member=[1.1, 2.2]), + HTTPMessage( + fields=tuples_to_fields([("floatList", "1.1"), ("floatList", "2.2")]), + ), + ), + HTTPMessageTestCase( + HTTPHeaders(big_decimal_member=Decimal("1.1")), + HTTPMessage( + fields=tuples_to_fields([("bigDecimal", "1.1")]), + ), + ), + HTTPMessageTestCase( + HTTPHeaders(big_decimal_list_member=[Decimal("1.1"), Decimal("2.2")]), + HTTPMessage( + fields=tuples_to_fields( + [("bigDecimalList", "1.1"), ("bigDecimalList", "2.2")] + ), + ), + ), + HTTPMessageTestCase( + HTTPHeaders(string_member="foo"), + HTTPMessage( + fields=tuples_to_fields([("string", "foo")]), + ), + ), + HTTPMessageTestCase( + HTTPHeaders(string_list_member=["spam", "eggs"]), + HTTPMessage( + fields=tuples_to_fields( + [("stringList", "spam"), ("stringList", "eggs")] + ), + ), + ), + HTTPMessageTestCase( + HTTPHeaders( + default_timestamp_member=datetime.datetime(2025, 1, 1, tzinfo=UTC) + ), + HTTPMessage( + fields=tuples_to_fields( + [("defaultTimestamp", "Wed, 01 Jan 2025 00:00:00 GMT")] + ), + ), + ), + HTTPMessageTestCase( + HTTPHeaders( + http_date_timestamp_member=datetime.datetime(2025, 1, 1, tzinfo=UTC) + ), + HTTPMessage( + fields=tuples_to_fields( + [("httpDateTimestamp", "Wed, 01 Jan 2025 00:00:00 GMT")] + ), + ), + ), + HTTPMessageTestCase( + HTTPHeaders( + http_date_list_timestamp_member=[ + datetime.datetime(2025, 1, 1, tzinfo=UTC), + datetime.datetime(2024, 1, 1, tzinfo=UTC), + ] + ), + HTTPMessage( + fields=tuples_to_fields( + [ + ("httpDateListTimestamp", "Wed, 01 Jan 2025 00:00:00 GMT"), + ("httpDateListTimestamp", "Mon, 01 Jan 2024 00:00:00 GMT"), + ] + ), + ), + ), + HTTPMessageTestCase( + HTTPHeaders( + date_time_timestamp_member=datetime.datetime(2025, 1, 1, tzinfo=UTC) + ), + HTTPMessage( + fields=tuples_to_fields( + [("dateTimeTimestamp", "2025-01-01T00:00:00Z")] + ), + ), + ), + HTTPMessageTestCase( + HTTPHeaders( + date_time_list_timestamp_member=[ + datetime.datetime(2025, 1, 1, tzinfo=UTC), + datetime.datetime(2024, 1, 1, tzinfo=UTC), + ] + ), + HTTPMessage( + fields=tuples_to_fields( + [ + ("dateTimeListTimestamp", "2025-01-01T00:00:00Z"), + ("dateTimeListTimestamp", "2024-01-01T00:00:00Z"), + ] + ), + ), + ), + HTTPMessageTestCase( + HTTPHeaders( + epoch_timestamp_member=datetime.datetime(2025, 1, 1, tzinfo=UTC) + ), + HTTPMessage( + fields=tuples_to_fields([("epochTimestamp", "1735689600")]), + ), + ), + HTTPMessageTestCase( + HTTPHeaders( + epoch_list_timestamp_member=[ + datetime.datetime(2025, 1, 1, tzinfo=UTC), + datetime.datetime(2024, 1, 1, tzinfo=UTC), + ] + ), + HTTPMessage( + fields=tuples_to_fields( + [ + ("epochListTimestamp", "1735689600"), + ("epochListTimestamp", "1704067200"), + ] + ), + ), + ), + HTTPMessageTestCase( + HTTPHeaders(string_map_member={"foo": "bar", "baz": "bam"}), + HTTPMessage( + fields=tuples_to_fields([("x-foo", "bar"), ("x-baz", "bam")]), + ), + ), + ] + + +def empty_prefix_header_ser_cases() -> list[HTTPMessageTestCase]: + return [ + HTTPMessageTestCase( + HTTPEmptyPrefixHeaders( + string_map_member={"foo": "bar", "baz": "bam", "string": "string"}, + ), + HTTPMessage( + fields=tuples_to_fields( + [("foo", "bar"), ("baz", "bam"), ("string", "string")] + ), + ), + ), + ] + + +def empty_prefix_header_deser_cases() -> list[HTTPMessageTestCase]: + return [ + HTTPMessageTestCase( + HTTPEmptyPrefixHeaders( + string_member="string", + string_map_member={"foo": "bar", "baz": "bam", "string": "string"}, + ), + HTTPMessage( + fields=tuples_to_fields( + [("foo", "bar"), ("baz", "bam"), ("string", "string")] + ), + ), + ), + ] + + +def query_cases() -> list[HTTPMessageTestCase]: + return [ + HTTPMessageTestCase( + HTTPQuery(boolean_member=True), + HTTPMessage( + destination=URI(host="", path="/", query="boolean=true"), + ), + ), + HTTPMessageTestCase( + HTTPQuery(boolean_list_member=[True, False]), + HTTPMessage( + destination=URI( + host="", path="/", query="booleanList=true&booleanList=false" + ), + ), + ), + HTTPMessageTestCase( + HTTPQuery(integer_member=1), + HTTPMessage(destination=URI(host="", path="/", query="integer=1")), + ), + HTTPMessageTestCase( + HTTPQuery(integer_list_member=[1, 2]), + HTTPMessage( + destination=URI(host="", path="/", query="integerList=1&integerList=2") + ), + ), + HTTPMessageTestCase( + HTTPQuery(float_member=1.1), + HTTPMessage(destination=URI(host="", path="/", query="float=1.1")), + ), + HTTPMessageTestCase( + HTTPQuery(float_list_member=[1.1, 2.2]), + HTTPMessage( + destination=URI(host="", path="/", query="floatList=1.1&floatList=2.2") + ), + ), + HTTPMessageTestCase( + HTTPQuery(big_decimal_member=Decimal("1.1")), + HTTPMessage(destination=URI(host="", path="/", query="bigDecimal=1.1")), + ), + HTTPMessageTestCase( + HTTPQuery(big_decimal_list_member=[Decimal("1.1"), Decimal("2.2")]), + HTTPMessage( + destination=URI( + host="", path="/", query="bigDecimalList=1.1&bigDecimalList=2.2" + ) + ), + ), + HTTPMessageTestCase( + HTTPQuery(string_member="foo"), + HTTPMessage(destination=URI(host="", path="/", query="string=foo")), + ), + HTTPMessageTestCase( + HTTPQuery(string_list_member=["spam", "eggs"]), + HTTPMessage( + destination=URI( + host="", path="/", query="stringList=spam&stringList=eggs" + ) + ), + ), + HTTPMessageTestCase( + HTTPQuery( + default_timestamp_member=datetime.datetime(2025, 1, 1, tzinfo=UTC) + ), + HTTPMessage( + destination=URI( + host="", + path="/", + query="defaultTimestamp=2025-01-01T00%3A00%3A00Z", + ) + ), + ), + HTTPMessageTestCase( + HTTPQuery( + http_date_timestamp_member=datetime.datetime(2025, 1, 1, tzinfo=UTC) + ), + HTTPMessage( + destination=URI( + host="", + path="/", + query="httpDateTimestamp=Wed%2C%2001%20Jan%202025%2000%3A00%3A00%20GMT", + ) + ), + ), + HTTPMessageTestCase( + HTTPQuery( + http_date_list_timestamp_member=[ + datetime.datetime(2025, 1, 1, tzinfo=UTC), + datetime.datetime(2024, 1, 1, tzinfo=UTC), + ] + ), + HTTPMessage( + destination=URI( + host="", + path="/", + query=( + "httpDateListTimestamp=Wed%2C%2001%20Jan%202025%2000%3A00%3A00%20GMT" + "&httpDateListTimestamp=Mon%2C%2001%20Jan%202024%2000%3A00%3A00%20GMT" + ), + ) + ), + ), + HTTPMessageTestCase( + HTTPQuery( + date_time_timestamp_member=datetime.datetime(2025, 1, 1, tzinfo=UTC) + ), + HTTPMessage( + destination=URI( + host="", + path="/", + query="dateTimeTimestamp=2025-01-01T00%3A00%3A00Z", + ) + ), + ), + HTTPMessageTestCase( + HTTPQuery( + date_time_list_timestamp_member=[ + datetime.datetime(2025, 1, 1, tzinfo=UTC), + datetime.datetime(2024, 1, 1, tzinfo=UTC), + ] + ), + HTTPMessage( + destination=URI( + host="", + path="/", + query=( + "dateTimeListTimestamp=2025-01-01T00%3A00%3A00Z" + "&dateTimeListTimestamp=2024-01-01T00%3A00%3A00Z" + ), + ) + ), + ), + HTTPMessageTestCase( + HTTPQuery(epoch_timestamp_member=datetime.datetime(2025, 1, 1, tzinfo=UTC)), + HTTPMessage( + destination=URI(host="", path="/", query="epochTimestamp=1735689600") + ), + ), + HTTPMessageTestCase( + HTTPQuery( + epoch_list_timestamp_member=[ + datetime.datetime(2025, 1, 1, tzinfo=UTC), + datetime.datetime(2024, 1, 1, tzinfo=UTC), + ] + ), + HTTPMessage( + destination=URI( + host="", + path="/", + query="epochListTimestamp=1735689600&epochListTimestamp=1704067200", + ) + ), + ), + HTTPMessageTestCase( + HTTPQuery(string_map_member={"foo": "bar", "baz": "bam"}), + HTTPMessage(destination=URI(host="", path="/", query="foo=bar&baz=bam")), + ), + HTTPMessageTestCase( + HTTPQuery(string_member="foo"), + HTTPMessage( + destination=URI(host="", path="/", query="spam=eggs&string=foo") + ), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/?spam=eggs"}), + ), + HTTPMessageTestCase( + HTTPQuery(string_member="foo"), + HTTPMessage(destination=URI(host="", path="/", query="spam&string=foo")), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/?spam"}), + ), + ] + + +def label_cases() -> list[HTTPMessageTestCase]: + return [ + HTTPMessageTestCase( + HTTPStringLabel(label="foo/bar"), + HTTPMessage(destination=URI(host="", path="/foo%2Fbar")), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), + ), + HTTPMessageTestCase( + HTTPStringLabel(label="foo/bar"), + HTTPMessage(destination=URI(host="", path="/foo/bar")), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label+}"}), + ), + HTTPMessageTestCase( + HTTPFloatLabel(label=1.1), + HTTPMessage(destination=URI(host="", path="/1.1")), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), + ), + HTTPMessageTestCase( + HTTPBigDecimalLabel(label=Decimal("1.1")), + HTTPMessage(destination=URI(host="", path="/1.1")), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), + ), + HTTPMessageTestCase( + HTTPBooleanLabel(label=True), + HTTPMessage(destination=URI(host="", path="/true")), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), + ), + HTTPMessageTestCase( + HTTPDefaultTimestampLabel(label=datetime.datetime(2025, 1, 1, tzinfo=UTC)), + HTTPMessage(destination=URI(host="", path="/2025-01-01T00%3A00%3A00Z")), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), + ), + HTTPMessageTestCase( + HTTPEpochTimestampLabel(label=datetime.datetime(2025, 1, 1, tzinfo=UTC)), + HTTPMessage(destination=URI(host="", path="/1735689600")), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), + ), + HTTPMessageTestCase( + HTTPDateTimeTimestampLabel(label=datetime.datetime(2025, 1, 1, tzinfo=UTC)), + HTTPMessage(destination=URI(host="", path="/2025-01-01T00%3A00%3A00Z")), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), + ), + HTTPMessageTestCase( + HTTPDateTimestampLabel(label=datetime.datetime(2025, 1, 1, tzinfo=UTC)), + HTTPMessage( + destination=URI( + host="", path="/Wed%2C%2001%20Jan%202025%2000%3A00%3A00%20GMT" + ) + ), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), + ), + ] + + +def host_cases() -> list[HTTPMessageTestCase]: + return [ + HTTPMessageTestCase( + HostLabel("foo"), + HTTPMessage( + destination=URI(host="foo.", path="/"), body=BytesIO(b'{"label":"foo"}') + ), + endpoint_trait=EndpointTrait({"hostPrefix": "{label}."}), + ), + HTTPMessageTestCase( + HTTPHeaders(), + HTTPMessage(destination=URI(host="foo.", path="/")), + endpoint_trait=EndpointTrait({"hostPrefix": "foo."}), + ), + ] + + +def payload_cases() -> list[HTTPMessageTestCase]: + return [ + HTTPMessageTestCase( + HTTPImplicitPayload(header="foo", payload_member="bar"), + HTTPMessage( + fields=tuples_to_fields([("header", "foo")]), + body=BytesIO(b'{"payload_member":"bar"}'), + ), + ), + HTTPMessageTestCase( + HTTPStringPayload(payload="foo"), + HTTPMessage(body=b"foo"), + ), + HTTPMessageTestCase( + HTTPBlobPayload(payload=b"\xde\xad\xbe\xef"), + HTTPMessage(body=b"\xde\xad\xbe\xef"), + ), + HTTPMessageTestCase( + HTTPStructuredPayload(payload=HTTPStringPayload(payload="foo")), + HTTPMessage(body=BytesIO(b'{"payload":"foo"}')), + ), + ] + + +def async_streaming_payload_cases() -> list[HTTPMessageTestCase]: + return [ + HTTPMessageTestCase( + HTTPStreamingPayload(payload=AsyncBytesReader(b"\xde\xad\xbe\xef")), + HTTPMessage(body=AsyncBytesReader(b"\xde\xad\xbe\xef")), + ), + ] + + +REQUEST_SER_CASES = ( + header_cases() + + empty_prefix_header_ser_cases() + + query_cases() + + label_cases() + + host_cases() + + payload_cases() + + async_streaming_payload_cases() +) + + +@pytest.mark.parametrize("case", REQUEST_SER_CASES) +async def test_serialize_http_request(case: HTTPMessageTestCase) -> None: + serializer = HTTPRequestSerializer( + payload_codec=JSONCodec(), + http_trait=case.http_trait, + endpoint_trait=case.endpoint_trait, + ) + case.shape.serialize(serializer) + actual = serializer.result + expected = case.request + + assert actual is not None + assert actual.method == expected.method + assert actual.destination.host == expected.destination.host + assert actual.destination.path == expected.destination.path + actual_query = actual.destination.query or "" + expected_query = case.request.destination.query or "" + assert actual_query == expected_query + assert actual.fields == expected.fields + + if case.request.body: + actual_body_value = await AsyncBytesReader(actual.body).read() + expected_body_value = await AsyncBytesReader(case.request.body).read() + assert actual_body_value == expected_body_value + assert type(actual.body) is type(case.request.body) + + +RESPONSE_SER_CASES: list[HTTPMessageTestCase] = ( + header_cases() + empty_prefix_header_ser_cases() + payload_cases() +) + + +@pytest.mark.parametrize("case", RESPONSE_SER_CASES) +async def test_serialize_http_response(case: HTTPMessageTestCase) -> None: + serializer = HTTPResponseSerializer( + payload_codec=JSONCodec(), http_trait=case.http_trait + ) + case.shape.serialize(serializer) + actual = serializer.result + expected = case.request + + assert actual is not None + assert actual.fields == expected.fields + assert actual.status == expected.status + + if case.request.body: + actual_body_value = await AsyncBytesReader(actual.body).read() + expected_body_value = await AsyncBytesReader(case.request.body).read() + assert actual_body_value == expected_body_value + assert type(actual.body) is type(case.request.body) + + +RESPONSE_DESER_CASES: list[HTTPMessageTestCase] = ( + header_cases() + empty_prefix_header_deser_cases() + payload_cases() +) + + +# TODO: Move this to a separate file +@pytest.mark.parametrize("case", RESPONSE_DESER_CASES) +async def test_deserialize_http_response(case: HTTPMessageTestCase) -> None: + body = case.request.body + if (read := getattr(body, "read", None)) is not None and iscoroutinefunction(read): + body = BytesIO(await read()) + deserializer = HTTPResponseDeserializer( + payload_codec=JSONCodec(), + http_trait=case.http_trait, + response=_HTTPResponse( + body=case.request.body, + status=case.request.status, + fields=case.request.fields, + ), + body=body, # type: ignore + ) + actual = type(case.shape).deserialize(deserializer) + assert actual == case.shape + + +async def test_deserialize_http_response_with_async_stream() -> None: + stream = AsyncBytesReader(b"\xde\xad\xbe\xef") + + deserializer = HTTPResponseDeserializer( + payload_codec=JSONCodec(), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/"}), + response=_HTTPResponse(body=stream, status=200, fields=Fields()), + ) + actual = HTTPStreamingPayload.deserialize(deserializer) + assert actual == HTTPStreamingPayload(stream)