From a2c694bf201f9b031e9a1d0fb8f8f536cb8eb5f0 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Fri, 28 Feb 2025 16:55:17 +0100 Subject: [PATCH 1/5] Add design for dynamically registered traits --- designs/serialization.md | 118 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 9 deletions(-) diff --git a/designs/serialization.md b/designs/serialization.md index 65f9c5def..497d275ab 100644 --- a/designs/serialization.md +++ b/designs/serialization.md @@ -80,27 +80,28 @@ implementation and/or additional helper methods. class Schema: id: ShapeID shape_type: ShapeType - traits: dict[ShapeID, "Trait"] = field(default_factory=dict) + traits: dict[ShapeID, "Trait | DynamicTrait"] = field(default_factory=dict) members: dict[str, "Schema"] = field(default_factory=dict) member_target: "Schema | None" = None member_index: int | None = None + @overload + def get_trait[T: "Trait"](self, t: type[T]) -> T | None: ... + @overload + def get_trait(self, t: ShapeID) -> "Trait | DynamicTrait | None": ... + def get_trait(self, t: "type[Trait] | ShapeID") -> "Trait | DynamicTrait | None":\ + return self.traits.get(t if isinstance(t, ShapeID) else t.id) + @classmethod def collection( cls, *, id: ShapeID, shape_type: ShapeType = ShapeType.STRUCTURE, - traits: list["Trait"] | None = None, + traits: list["Trait | DynamicTrait"] | None = None, members: Mapping[str, "MemberSchema"] | None = None, ) -> Self: ... - - -@dataclass(kw_only=True, frozen=True) -class Trait: - id: "ShapeID" - value: "DocumentValue" = field(default_factory=dict) ``` Below is an example Smithy `structure` shape, followed by the `Schema` it would @@ -122,13 +123,112 @@ EXAMPLE_STRUCTURE_SCHEMA = Schema.collection( "target": INTEGER, "index": 0, "traits": [ - Trait(id=ShapeID("smithy.api#default"), value=0), + DefaultTrait(0), ], }, }, ) ``` +### Traits + +Traits are model components that can be attached to shapes to describe +additional information about the shape; shapes provide the structure and layout +of an API, while traits provide refinement and style. Smithy provides a number +of built-in traits, plus a number of additional traits that may be found in +first-party dependencies. In addition to those first-party traits, traits may be +defined externally. + +In Python, there are two kinds of traits. The first is the `DynamicTrait`. This +represents traits that have no known associated Python class. Traits not defined +by Smithy itself may be unknown, for example, but still need representation. + +The other kind of trait inherits from the `Trait` class. This represents known +traits, such as those defined by Smithy itself or those defined externally but +made available in Python. Since these are concrete classes, they may be more +comfortable to use, providing better typed accessors to data or even relevant +utility functions. + +Both kinds of traits implement an inherent `Protocol` - they both have the `id` +and `document_value` properties with identical type signatures. This allows them +to be used interchangeably for those that don't care about the concrete types. +It also allows concrete types to be introduced later without a breaking change. + + +```python +@dataclass(kw_only=True, frozen=True, slots=True) +class DynamicTrait: + id: ShapeID + document_value: DocumentValue = None + + +@dataclass(init=False, frozen=True) +class Trait: + + _REGISTRY: ClassVar[dict[ShapeID, type["Trait"]]] = {} + + id: ClassVar[ShapeID] + + document_value: DocumentValue = None + + def __init_subclass__(cls, id: ShapeID) -> None: + cls.id = id + Trait._REGISTRY[id] = cls + + def __init__(self, value: DocumentValue | DynamicTrait = None): + if type(self) is Trait: + raise TypeError( + "Only subclasses of Trait may be directly instantiated. " + "Use DynamicTrait for traits without a concrete class." + ) + + if isinstance(value, DynamicTrait): + if value.id != self.id: + raise ValueError( + f"Attempted to instantiate an instance of {type(self)} from an " + f"invalid ID. Expected {self.id} but found {value.id}." + ) + # Note that setattr is needed because it's a frozen (read-only) dataclass + object.__setattr__(self, "document_value", value.document_value) + else: + object.__setattr__(self, "document_value", value) + + # Dynamically creates a subclass instance based on the trait id + @staticmethod + def new(id: ShapeID, value: "DocumentValue" = None) -> "Trait | DynamicTrait": + if (cls := Trait._REGISTRY.get(id, None)) is not None: + return cls(value) + return DynamicTrait(id=id, document_value=value) +``` + +The `Trait` class implements a dynamic registry that allows it to know about +trait implementations automatically. The base class maintains a mapping of trait +ID to the trait class. Since implementations must all share the same constructor +signature, it can then use that registry to dynamically construct concrete types +it knows about in the `new` factory method with a fallback to `DynamicTrait`. + +The `new` factory method will be used to construct traits when `Schema`s are +generated, so any generated schemas will be able to take advantage of the +registry. + +Below is an example of a `Trait` implementation. + +```python +@dataclass(init=False, frozen=True) +class TimestampFormatTrait(Trait, id=ShapeID("smithy.api#timestampFormat")): + format: TimestampFormat + + def __init__(self, value: "DocumentValue | DynamicTrait" = None): + super().__init__(value) + assert isinstance(self.document_value, str) + object.__setattr__(self, "format", TimestampFormat(self.document_value)) +``` + +Data in traits is intended to be immutable, so both `DynamicTrait` and `Trait` +are dataclasses with `frozen=True`, and all implementations of `Trait` must also +use that argument. This can be worked around during `__init__` using +`object.__setattr__` to set any additional properties the `Trait` defines. + ## Shape Serializers and Serializeable Shapes Serialization will function by the interaction of two interfaces: From 3af374ece6b93dd37bd3395731284fe3ac6d9ce3 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Fri, 28 Feb 2025 15:55:21 +0100 Subject: [PATCH 2/5] Make traits dynamically registerable This updates traits to use a dynamic registry to allow concrete implementations to be automatically discovered and created when constructing schemas. --- .../src/aws_event_stream/_private/__init__.py | 4 +- .../_private/deserializers.py | 7 +- .../aws_event_stream/_private/serializers.py | 11 +- .../src/aws_event_stream/_private/traits.py | 9 - .../tests/unit/_private/__init__.py | 18 +- .../smithy-core/src/smithy_core/prelude.py | 25 +-- .../smithy-core/src/smithy_core/schemas.py | 47 ++++- .../smithy-core/src/smithy_core/traits.py | 192 +++++++++++++++++- .../smithy-core/tests/unit/test_documents.py | 8 +- .../smithy-core/tests/unit/test_schemas.py | 45 +++- .../smithy-core/tests/unit/test_traits.py | 58 ++++++ .../src/smithy_json/_private/deserializers.py | 10 +- .../src/smithy_json/_private/documents.py | 12 +- .../src/smithy_json/_private/serializers.py | 12 +- .../src/smithy_json/_private/traits.py | 7 - packages/smithy-json/tests/unit/__init__.py | 17 +- 16 files changed, 380 insertions(+), 102 deletions(-) delete mode 100644 packages/aws-event-stream/src/aws_event_stream/_private/traits.py create mode 100644 packages/smithy-core/tests/unit/test_traits.py delete mode 100644 packages/smithy-json/src/smithy_json/_private/traits.py diff --git a/packages/aws-event-stream/src/aws_event_stream/_private/__init__.py b/packages/aws-event-stream/src/aws_event_stream/_private/__init__.py index 3b4609496..fed38cef2 100644 --- a/packages/aws-event-stream/src/aws_event_stream/_private/__init__.py +++ b/packages/aws-event-stream/src/aws_event_stream/_private/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from smithy_core.schemas import Schema -from .traits import EVENT_PAYLOAD_TRAIT +from smithy_core.traits import EventPayloadTrait INITIAL_REQUEST_EVENT_TYPE = "initial-request" INITIAL_RESPONSE_EVENT_TYPE = "initial-response" @@ -10,6 +10,6 @@ def get_payload_member(schema: Schema) -> Schema | None: for member in schema.members.values(): - if EVENT_PAYLOAD_TRAIT in member.traits: + if EventPayloadTrait.id in member.traits: return member return None diff --git a/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py b/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py index a1fa3619d..2e173f42d 100644 --- a/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py +++ b/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py @@ -23,7 +23,7 @@ INITIAL_RESPONSE_EVENT_TYPE, get_payload_member, ) -from .traits import EVENT_HEADER_TRAIT +from smithy_core.traits import EventHeaderTrait INITIAL_MESSAGE_TYPES = (INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE) @@ -158,7 +158,10 @@ def read_struct( headers_deserializer = EventHeaderDeserializer(self._headers) for key in self._headers.keys(): member_schema = schema.members.get(key) - if member_schema is not None and EVENT_HEADER_TRAIT in member_schema.traits: + if ( + member_schema is not None + and EventHeaderTrait.id in member_schema.traits + ): consumer(member_schema, headers_deserializer) if self._payload_deserializer: diff --git a/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py b/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py index a84d0a726..1914f0e5f 100644 --- a/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py +++ b/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py @@ -18,7 +18,6 @@ SpecificShapeSerializer, ) from smithy_core.shapes import ShapeType -from smithy_core.utils import expect_type from smithy_event_stream.aio.interfaces import AsyncEventPublisher from ..events import EventHeaderEncoder, EventMessage @@ -28,7 +27,7 @@ INITIAL_RESPONSE_EVENT_TYPE, get_payload_member, ) -from .traits import ERROR_TRAIT, EVENT_HEADER_TRAIT, MEDIA_TYPE_TRAIT +from smithy_core.traits import ErrorTrait, EventHeaderTrait, MediaTypeTrait _DEFAULT_STRING_CONTENT_TYPE = "text/plain" _DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream" @@ -103,7 +102,7 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]: headers_encoder = EventHeaderEncoder() - if ERROR_TRAIT in schema.traits: + if ErrorTrait.id in schema.traits: headers_encoder.encode_string(":message-type", "exception") headers_encoder.encode_string( ":exception-type", schema.expect_member_name() @@ -146,8 +145,8 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]: ) def _get_payload_media_type(self, schema: Schema, default: str) -> str: - if (media_type := schema.traits.get(MEDIA_TYPE_TRAIT)) is not None: - return expect_type(str, media_type.value) + if (media_type := schema.get_trait(MediaTypeTrait)) is not None: + return media_type.value match schema.shape_type: case ShapeType.STRING: @@ -215,7 +214,7 @@ def __init__( self._payload_struct_serializer = payload_struct_serializer def before(self, schema: "Schema") -> ShapeSerializer: - if EVENT_HEADER_TRAIT in schema.traits: + if EventHeaderTrait.id in schema.traits: return self._header_serializer return self._payload_struct_serializer diff --git a/packages/aws-event-stream/src/aws_event_stream/_private/traits.py b/packages/aws-event-stream/src/aws_event_stream/_private/traits.py deleted file mode 100644 index 738126282..000000000 --- a/packages/aws-event-stream/src/aws_event_stream/_private/traits.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -from smithy_core.shapes import ShapeID - -EVENT_HEADER_TRAIT = ShapeID("smithy.api#eventHeader") -EVENT_PAYLOAD_TRAIT = ShapeID("smithy.api#eventPayload") -ERROR_TRAIT = ShapeID("smithy.api#error") -MEDIA_TYPE_TRAIT = ShapeID("smithy.api#mediaType") diff --git a/packages/aws-event-stream/tests/unit/_private/__init__.py b/packages/aws-event-stream/tests/unit/_private/__init__.py index 6eb258f2a..b84c61c55 100644 --- a/packages/aws-event-stream/tests/unit/_private/__init__.py +++ b/packages/aws-event-stream/tests/unit/_private/__init__.py @@ -19,15 +19,21 @@ from smithy_core.schemas import Schema from smithy_core.serializers import ShapeSerializer from smithy_core.shapes import ShapeID, ShapeType -from smithy_core.traits import Trait +from smithy_core.traits import ( + EventHeaderTrait, + EventPayloadTrait, + ErrorTrait, + RequiredTrait, + StreamingTrait, +) from aws_event_stream.events import Byte, EventMessage, Long, Short -EVENT_HEADER_TRAIT = Trait(id=ShapeID("smithy.api#eventHeader")) -EVENT_PAYLOAD_TRAIT = Trait(id=ShapeID("smithy.api#eventPayload")) -ERROR_TRAIT = Trait(id=ShapeID("smithy.api#error"), value="client") -REQUIRED_TRAIT = Trait(id=ShapeID("smithy.api#required")) -STREAMING_TRAIT = Trait(id=ShapeID("smith.api#streaming")) +EVENT_HEADER_TRAIT = EventHeaderTrait() +EVENT_PAYLOAD_TRAIT = EventPayloadTrait() +ERROR_TRAIT = ErrorTrait("client") +REQUIRED_TRAIT = RequiredTrait() +STREAMING_TRAIT = StreamingTrait() SCHEMA_MESSAGE_EVENT = Schema.collection( diff --git a/packages/smithy-core/src/smithy_core/prelude.py b/packages/smithy-core/src/smithy_core/prelude.py index ec1d428ed..673b03edd 100644 --- a/packages/smithy-core/src/smithy_core/prelude.py +++ b/packages/smithy-core/src/smithy_core/prelude.py @@ -4,7 +4,8 @@ from .schemas import Schema from .shapes import ShapeID, ShapeType -from .traits import Trait +from .traits import DefaultTrait, UnitTypeTrait + BLOB = Schema( id=ShapeID("smithy.api#Blob"), @@ -71,54 +72,50 @@ shape_type=ShapeType.DOCUMENT, ) - -_DEFAULT = ShapeID("smithy.api#default") - - PRIMITIVE_BOOLEAN = Schema( id=ShapeID("smithy.api#PrimitiveBoolean"), shape_type=ShapeType.BOOLEAN, - traits=[Trait(id=_DEFAULT, value=False)], + traits=[DefaultTrait(False)], ) PRIMITIVE_BYTE = Schema( id=ShapeID("smithy.api#PrimitiveByte"), shape_type=ShapeType.BYTE, - traits=[Trait(id=_DEFAULT, value=0)], + traits=[DefaultTrait(0)], ) PRIMITIVE_SHORT = Schema( id=ShapeID("smithy.api#PrimitiveShort"), shape_type=ShapeType.SHORT, - traits=[Trait(id=_DEFAULT, value=0)], + traits=[DefaultTrait(0)], ) PRIMITIVE_INTEGER = Schema( id=ShapeID("smithy.api#PrimitiveInteger"), shape_type=ShapeType.INTEGER, - traits=[Trait(id=_DEFAULT, value=0)], + traits=[DefaultTrait(0)], ) PRIMITIVE_LONG = Schema( id=ShapeID("smithy.api#PrimitiveLong"), shape_type=ShapeType.LONG, - traits=[Trait(id=_DEFAULT, value=0)], + traits=[DefaultTrait(0)], ) PRIMITIVE_FLOAT = Schema( id=ShapeID("smithy.api#PrimitiveFloat"), shape_type=ShapeType.FLOAT, - traits=[Trait(id=_DEFAULT, value=0.0)], + traits=[DefaultTrait(0)], ) PRIMITIVE_DOUBLE = Schema( id=ShapeID("smithy.api#PrimitiveDouble"), shape_type=ShapeType.DOUBLE, - traits=[Trait(id=_DEFAULT, value=0.0)], + traits=[DefaultTrait(0)], ) UNIT = Schema( id=ShapeID("smithy.api#Unit"), - shape_type=ShapeType.DOUBLE, - traits=[Trait(id=ShapeID("smithy.api#UnitTypeTrait"))], + shape_type=ShapeType.STRUCTURE, + traits=[UnitTypeTrait()], ) diff --git a/packages/smithy-core/src/smithy_core/schemas.py b/packages/smithy-core/src/smithy_core/schemas.py index 967eb7a78..a9feb93d3 100644 --- a/packages/smithy-core/src/smithy_core/schemas.py +++ b/packages/smithy-core/src/smithy_core/schemas.py @@ -1,12 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 from collections.abc import Mapping from dataclasses import dataclass, field, replace -from typing import TYPE_CHECKING, NotRequired, Required, Self, TypedDict +from typing import TYPE_CHECKING, NotRequired, Required, Self, TypedDict, overload from .exceptions import ExpectationNotMetException, SmithyException from .shapes import ShapeID, ShapeType if TYPE_CHECKING: - from .traits import Trait + from .traits import Trait, DynamicTrait @dataclass(kw_only=True, frozen=True, init=False) @@ -15,7 +17,7 @@ class Schema: id: ShapeID shape_type: ShapeType - traits: dict[ShapeID, "Trait"] = field(default_factory=dict) + traits: dict[ShapeID, "Trait | DynamicTrait"] = field(default_factory=dict) members: dict[str, "Schema"] = field(default_factory=dict) member_target: "Schema | None" = None member_index: int | None = None @@ -25,7 +27,9 @@ def __init__( *, id: ShapeID, shape_type: ShapeType, - traits: list["Trait"] | dict[ShapeID, "Trait"] | None = None, + traits: list["Trait | DynamicTrait"] + | dict[ShapeID, "Trait | DynamicTrait"] + | None = None, members: list["Schema"] | dict[str, "Schema"] | None = None, member_target: "Schema | None" = None, member_index: int | None = None, @@ -121,13 +125,42 @@ def expect_member_index(self) -> int: ) return self.member_index + @overload + def get_trait[T: "Trait"](self, t: type[T]) -> T | None: ... + + @overload + def get_trait(self, t: ShapeID) -> "Trait | DynamicTrait | None": ... + + def get_trait(self, t: "type[Trait] | ShapeID") -> "Trait | DynamicTrait | None": + """Get a trait based on it's ShapeID or class. + + :returns: A Trait if the trait class is known, a DynamicTrait if it isn't, or + None if the trait is not present on the Schema. + """ + id = t if isinstance(t, ShapeID) else t.id + return self.traits.get(id) + + @overload + def expect_trait[T: "Trait"](self, t: type[T]) -> T: ... + + @overload + def expect_trait(self, t: ShapeID) -> "Trait | DynamicTrait": ... + + def expect_trait(self, t: "type[Trait] | ShapeID") -> "Trait | DynamicTrait": + """Get a trait based on it's ShapeID or class. + + :returns: A Trait if the trait class is known, a DynamicTrait if it isn't. + """ + id = t if isinstance(t, ShapeID) else t.id + return self.traits[id] + @classmethod def collection( cls, *, id: ShapeID, shape_type: ShapeType = ShapeType.STRUCTURE, - traits: list["Trait"] | None = None, + traits: list["Trait | DynamicTrait"] | None = None, members: Mapping[str, "MemberSchema"] | None = None, ) -> Self: """Create a schema for a collection shape. @@ -164,7 +197,7 @@ def member( id: ShapeID, target: "Schema", index: int, - member_traits: list["Trait"] | None = None, + member_traits: list["Trait | DynamicTrait"] | None = None, ) -> "Schema": """Create a schema for a member shape. @@ -203,4 +236,4 @@ class MemberSchema(TypedDict): target: Required[Schema] index: Required[int] - traits: NotRequired[list["Trait"]] + traits: NotRequired[list["Trait | DynamicTrait"]] diff --git a/packages/smithy-core/src/smithy_core/traits.py b/packages/smithy-core/src/smithy_core/traits.py index 21646a841..0e10310bb 100644 --- a/packages/smithy-core/src/smithy_core/traits.py +++ b/packages/smithy-core/src/smithy_core/traits.py @@ -1,19 +1,195 @@ -from dataclasses import dataclass, field -from typing import TYPE_CHECKING +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# This ruff check warns against using the assert statement, which can be stripped out +# when running Python with certain (common) optimization settings. Assert is used here +# for trait values. Since these are always generated, we can be fairly confident that +# they're correct regardless, so it's okay if the checks are stripped out. +# ruff: noqa: S101 + +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, ClassVar + +from .types import TimestampFormat +from .shapes import ShapeID if TYPE_CHECKING: from .documents import DocumentValue - from .shapes import ShapeID -@dataclass(kw_only=True, frozen=True) +@dataclass(kw_only=True, frozen=True, slots=True) +class DynamicTrait: + """A component that can be attached to a schema to describe additional information + about it. + + Typed traits can be used by creating a :py:class:`Trait` subclass. + """ + + id: ShapeID + """The ID of the trait.""" + + document_value: "DocumentValue" = None + """The value of the trait.""" + + +@dataclass(init=False, frozen=True) class Trait: """A component that can be attached to a schema to describe additional information about it. - :param id: The ID of the trait. - :param value: The document value of the trait. + This is a base class that registers subclasses. Any known subclasses will + automatically be used when constructing schemas. Any unknown traits may instead be + created as a :py:class:`DynamicTrait`. + + The `id` property of subclasses is set during subclass creation by + `__init_subclass__`, so it is not necessary for subclasses to set it manually. """ - id: "ShapeID" - value: "DocumentValue" = field(default_factory=dict) + _REGISTRY: ClassVar[dict[ShapeID, type["Trait"]]] = {} + + id: ClassVar[ShapeID] + """The ID of the trait.""" + + document_value: "DocumentValue" = None + """The value of the trait as a DocumentValue.""" + + def __init_subclass__(cls, id: ShapeID) -> None: + cls.id = id + Trait._REGISTRY[id] = cls + + def __init__(self, value: "DocumentValue | DynamicTrait" = None): + if type(self) is Trait: + raise TypeError( + "Only subclasses of Trait may be directly instantiated. " + "Use DynamicTrait for traits without a concrete class." + ) + + if isinstance(value, DynamicTrait): + if value.id != self.id: + raise ValueError( + f"Attempted to instantiate an instance of {type(self)} from an " + f"invalid ID. Expected {self.id} but found {value.id}." + ) + # Note that setattr is needed because it's a frozen (read-only) dataclass + object.__setattr__(self, "document_value", value.document_value) + else: + object.__setattr__(self, "document_value", value) + + # Dynamically creates a subclass instance based on the trait id + @staticmethod + def new(id: ShapeID, value: "DocumentValue" = None) -> "Trait | DynamicTrait": + """Dynamically create a new trait of the given ID. + + If the ID corresponds to a known Trait class, that class will be instantiated + and returned. Otherwise, a :py:class:`DynamicTrait` will be returned. + + :returns: A trait of the given ID with the given value. + """ + if (cls := Trait._REGISTRY.get(id, None)) is not None: + return cls(value) + return DynamicTrait(id=id, document_value=value) + + +@dataclass(init=False, frozen=True) +class DefaultTrait(Trait, id=ShapeID("smithy.appi#default")): + @property + def value(self) -> "DocumentValue": + return self.document_value + + +@dataclass(init=False, frozen=True) +class SparseTrait(Trait, id=ShapeID("smithy.api#sparse")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class TimestampFormatTrait(Trait, id=ShapeID("smithy.api#timestampFormat")): + format: TimestampFormat + + def __init__(self, value: "DocumentValue | DynamicTrait" = None): + super().__init__(value) + assert isinstance(self.document_value, str) + object.__setattr__(self, "format", TimestampFormat(self.document_value)) + + +class ErrorFault(Enum): + CLIENT = "client" + SERVER = "server" + + +@dataclass(init=False, frozen=True) +class ErrorTrait(Trait, id=ShapeID("smithy.api#error")): + fault: ErrorFault + + def __init__(self, value: "DocumentValue | DynamicTrait" = None): + super().__init__(value) + assert isinstance(self.document_value, str) + object.__setattr__(self, "fault", ErrorFault(self.document_value)) + + +@dataclass(init=False, frozen=True) +class RequiredTrait(Trait, id=ShapeID("smithy.api#required")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class InternalTrait(Trait, id=ShapeID("smithy.api#internal")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class SensitiveTrait(Trait, id=ShapeID("smithy.api#sensitive")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class StreamingTrait(Trait, id=ShapeID("smithy.api#streaming")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class UnitTypeTrait(Trait, id=ShapeID("smithy.api#UnitTypeTrait")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class MediaTypeTrait(Trait, id=ShapeID("smithy.api#mediaType")): + document_value: str | None = None + + def __post_init__(self): + assert isinstance(self.document_value, str) + + @property + def value(self) -> str: + return self.document_value # type: ignore + + +@dataclass(init=False, frozen=True) +class EventHeaderTrait(Trait, id=ShapeID("smithy.api#eventheader")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class EventPayloadTrait(Trait, id=ShapeID("smithy.api#eventPayload")): + def __post_init__(self): + assert self.document_value is None + + +@dataclass(init=False, frozen=True) +class JSONNameTrait(Trait, id=ShapeID("smithy.api#jsonName")): + document_value: str | None = None + + def __post_init__(self): + assert isinstance(self.document_value, str) + + @property + def value(self) -> str: + return self.document_value # type: ignore diff --git a/packages/smithy-core/tests/unit/test_documents.py b/packages/smithy-core/tests/unit/test_documents.py index 4f7fd93b0..04963435b 100644 --- a/packages/smithy-core/tests/unit/test_documents.py +++ b/packages/smithy-core/tests/unit/test_documents.py @@ -27,7 +27,8 @@ from smithy_core.schemas import Schema from smithy_core.serializers import ShapeSerializer from smithy_core.shapes import ShapeID, ShapeType -from smithy_core.traits import Trait + +from smithy_core.traits import SparseTrait @pytest.mark.parametrize( @@ -515,7 +516,6 @@ def test_is_none(): assert not Document("foo").is_none() -SPARSE_TRAIT = Trait(id=ShapeID("smithy.api#sparse")) STRING_LIST_SCHEMA = Schema.collection( id=ShapeID("smithy.example#StringList"), shape_type=ShapeType.LIST, @@ -533,7 +533,7 @@ def test_is_none(): id=ShapeID("smithy.example#StringList"), shape_type=ShapeType.LIST, members={"member": {"target": STRING, "index": 0}}, - traits=[SPARSE_TRAIT], + traits=[SparseTrait()], ) SPARSE_STRING_MAP_SCHEMA = Schema.collection( id=ShapeID("smithy.example#StringMap"), @@ -542,7 +542,7 @@ def test_is_none(): "key": {"target": STRING, "index": 0}, "value": {"target": STRING, "index": 1}, }, - traits=[SPARSE_TRAIT], + traits=[SparseTrait()], ) SCHEMA: Schema = Schema.collection( id=ShapeID("smithy.example#DocumentSerdeShape"), diff --git a/packages/smithy-core/tests/unit/test_schemas.py b/packages/smithy-core/tests/unit/test_schemas.py index d06c44f78..9e822b36b 100644 --- a/packages/smithy-core/tests/unit/test_schemas.py +++ b/packages/smithy-core/tests/unit/test_schemas.py @@ -5,17 +5,40 @@ from smithy_core.exceptions import ExpectationNotMetException from smithy_core.schemas import Schema from smithy_core.shapes import ShapeID, ShapeType -from smithy_core.traits import Trait +from smithy_core.traits import InternalTrait, DynamicTrait, SensitiveTrait ID: ShapeID = ShapeID("ns.foo#bar") STRING = Schema(id=ShapeID("smithy.api#String"), shape_type=ShapeType.STRING) def test_traits_list(): - trait_id = ShapeID("smithy.api#internal") - trait = Trait(id=trait_id, value=True) + trait = InternalTrait() schema = Schema(id=ID, shape_type=ShapeType.STRUCTURE, traits=[trait]) - assert schema.traits == {trait_id: trait} + assert schema.traits == {InternalTrait.id: trait} + + +def test_get_trait_by_class(): + trait = InternalTrait() + schema = Schema(id=ID, shape_type=ShapeType.STRUCTURE, traits=[trait]) + assert schema.get_trait(InternalTrait) is trait + + +def test_get_unknown_trait_by_class(): + trait = InternalTrait() + schema = Schema(id=ID, shape_type=ShapeType.STRUCTURE, traits=[trait]) + assert schema.get_trait(SensitiveTrait) is None + + +def test_get_trait_by_id(): + trait = InternalTrait() + schema = Schema(id=ID, shape_type=ShapeType.STRUCTURE, traits=[trait]) + assert schema.get_trait(InternalTrait.id) is trait + + +def test_get_unknown_trait_by_id(): + trait = InternalTrait() + schema = Schema(id=ID, shape_type=ShapeType.STRUCTURE, traits=[trait]) + assert schema.get_trait(SensitiveTrait.id) is None def test_members_list(): @@ -59,7 +82,7 @@ def test_member_expectations_raise_on_non_members(): def test_collection_constructor(): - trait_value = Trait(id=ShapeID("smithy.example#trait"), value="foo") + trait_value = DynamicTrait(id=ShapeID("smithy.example#trait"), document_value="foo") member_name = "baz" member = Schema( id=ID.with_member(member_name), @@ -79,8 +102,8 @@ def test_member_constructor(): target = Schema.collection( id=ShapeID("smithy.example#target"), traits=[ - Trait(id=ShapeID("smithy.api#sensitive")), - Trait(id=ShapeID("smithy.example#foo"), value="bar"), + SensitiveTrait(), + DynamicTrait(id=ShapeID("smithy.example#foo"), document_value="bar"), ], members={"spam": {"target": STRING, "index": 0}}, ) @@ -92,8 +115,8 @@ def test_member_constructor(): member_target=target, member_index=1, traits=[ - Trait(id=ShapeID("smithy.api#sensitive")), - Trait(id=ShapeID("smithy.example#foo"), value="baz"), + SensitiveTrait(), + DynamicTrait(id=ShapeID("smithy.example#foo"), document_value="baz"), ], ) @@ -101,7 +124,9 @@ def test_member_constructor(): id=member_id, target=target, index=1, - member_traits=[Trait(id=ShapeID("smithy.example#foo"), value="baz")], + member_traits=[ + DynamicTrait(id=ShapeID("smithy.example#foo"), document_value="baz") + ], ) assert actual == expected diff --git a/packages/smithy-core/tests/unit/test_traits.py b/packages/smithy-core/tests/unit/test_traits.py new file mode 100644 index 000000000..f53428859 --- /dev/null +++ b/packages/smithy-core/tests/unit/test_traits.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +from smithy_core.traits import ( + DynamicTrait, + Trait, + ErrorTrait, + ErrorFault, + JSONNameTrait, +) +from smithy_core.shapes import ShapeID + +import pytest + + +def test_trait_factory_constructs_dynamic_trait(): + trait_id = ShapeID("com.example#foo") + document_value = "bar" + trait = Trait.new(id=trait_id, value=document_value) + assert isinstance(trait, DynamicTrait) + assert trait.id == trait_id + assert trait.document_value == document_value + + +def test_trait_factory_constructs_prelude_trait(): + trait = Trait.new(ErrorTrait.id, "client") + assert isinstance(trait, ErrorTrait) + assert trait.fault is ErrorFault.CLIENT + + +def test_trait_factory_constructs_new_trait(): + trait_id = ShapeID("com.example#newTrait") + + @dataclass(init=False, frozen=True) + class NewTrait(Trait, id=trait_id): + pass + + trait = Trait.new(trait_id) + assert isinstance(trait, NewTrait) + assert NewTrait.id is trait_id + + +def test_cant_construct_base_trait(): + with pytest.raises(TypeError): + Trait("foo") + + +def test_construct_from_dynamic_trait(): + dynamic = DynamicTrait(id=ErrorTrait.id, document_value="server") + static = ErrorTrait(dynamic) + assert static.fault is ErrorFault.SERVER + + +def test_cant_construct_trait_from_non_matching_dynamic_trait(): + dynamic = DynamicTrait(id=JSONNameTrait.id, document_value="client") + with pytest.raises(ValueError): + ErrorTrait(dynamic) diff --git a/packages/smithy-json/src/smithy_json/_private/deserializers.py b/packages/smithy-json/src/smithy_json/_private/deserializers.py index 05c9b967e..09db3b410 100644 --- a/packages/smithy-json/src/smithy_json/_private/deserializers.py +++ b/packages/smithy-json/src/smithy_json/_private/deserializers.py @@ -18,7 +18,7 @@ from smithy_core.types import TimestampFormat from .documents import JSONDocument -from .traits import JSON_NAME, TIMESTAMP_FORMAT +from smithy_core.traits import TimestampFormatTrait, JSONNameTrait # TODO: put these type hints in a pyi somewhere. There here because ijson isn't # typed. @@ -191,8 +191,8 @@ def read_document(self, schema: Schema) -> Document: def read_timestamp(self, schema: Schema) -> datetime.datetime: format = self._default_timestamp_format if self._use_timestamp_format: - if format_trait := schema.traits.get(TIMESTAMP_FORMAT): - format = TimestampFormat(format_trait.value) + if format_trait := schema.get_trait(TimestampFormatTrait): + format = format_trait.format match format: case TimestampFormat.EPOCH_SECONDS: @@ -234,8 +234,8 @@ def _cache_json_names(self, schema: Schema): result: dict[str, str] = {} for member_name, member_schema in schema.members.items(): name: str = member_name - if json_name := member_schema.traits.get(JSON_NAME): - name = cast(str, json_name.value) + if json_name := member_schema.get_trait(JSONNameTrait): + name = json_name.value result[name] = member_name self._json_names[schema.id] = result diff --git a/packages/smithy-json/src/smithy_json/_private/documents.py b/packages/smithy-json/src/smithy_json/_private/documents.py index 30d92f88c..11750d5d8 100644 --- a/packages/smithy-json/src/smithy_json/_private/documents.py +++ b/packages/smithy-json/src/smithy_json/_private/documents.py @@ -5,7 +5,6 @@ from collections.abc import Mapping from datetime import datetime from decimal import Decimal -from typing import cast from smithy_core.documents import Document, DocumentValue from smithy_core.prelude import DOCUMENT @@ -14,7 +13,7 @@ from smithy_core.types import TimestampFormat from smithy_core.utils import expect_type -from .traits import JSON_NAME, TIMESTAMP_FORMAT +from smithy_core.traits import JSONNameTrait, TimestampFormatTrait class JSONDocument(Document): @@ -41,9 +40,8 @@ def __init__( ShapeType.UNION, ): for member_name, member_schema in schema.members.items(): - if json_name := member_schema.traits.get(JSON_NAME): - name = cast(str, json_name.value) - self._json_names[name] = member_name + if json_name := member_schema.get_trait(JSONNameTrait): + self._json_names[json_name.value] = member_name def as_blob(self) -> bytes: return b64decode(expect_type(str, self._value)) @@ -56,8 +54,8 @@ def as_float(self) -> float: def as_timestamp(self) -> datetime: format = self._default_timestamp_format if self._use_timestamp_format: - if format_trait := self._schema.traits.get(TIMESTAMP_FORMAT): - format = TimestampFormat(format_trait.value) + if format_trait := self._schema.get_trait(TimestampFormatTrait): + format = format_trait.format match self._value: case float() | int() | str(): diff --git a/packages/smithy-json/src/smithy_json/_private/serializers.py b/packages/smithy-json/src/smithy_json/_private/serializers.py index 91f8eb50f..9f48e43f7 100644 --- a/packages/smithy-json/src/smithy_json/_private/serializers.py +++ b/packages/smithy-json/src/smithy_json/_private/serializers.py @@ -8,7 +8,7 @@ from decimal import Decimal from io import BufferedWriter, RawIOBase from types import TracebackType -from typing import Self, cast +from typing import Self from smithy_core.documents import Document, DocumentValue from smithy_core.interfaces import BytesWriter @@ -19,9 +19,9 @@ ShapeSerializer, ) from smithy_core.types import TimestampFormat +from smithy_core.traits import TimestampFormatTrait, JSONNameTrait from . import Flushable -from .traits import JSON_NAME, TIMESTAMP_FORMAT _INF: float = float("inf") _NEG_INF: float = float("-inf") @@ -84,8 +84,8 @@ def write_blob(self, schema: "Schema", value: bytes) -> None: def write_timestamp(self, schema: "Schema", value: datetime) -> None: format = self._default_timestamp_format if self._use_timestamp_format: - if format_trait := schema.traits.get(TIMESTAMP_FORMAT): - format = TimestampFormat(format_trait.value) + if format_trait := schema.get_trait(TimestampFormatTrait): + format = format_trait.format self._stream.write_document_value(format.serialize(value)) @@ -132,8 +132,8 @@ def before(self, schema: "Schema") -> ShapeSerializer: self._stream.write_more() member_name = schema.expect_member_name() - if self._use_json_name and (json_name := schema.traits.get(JSON_NAME)): - member_name = cast(str, json_name.value) + if self._use_json_name and (json_name := schema.get_trait(JSONNameTrait)): + member_name = json_name.value self._stream.write_key(member_name) return self._parent diff --git a/packages/smithy-json/src/smithy_json/_private/traits.py b/packages/smithy-json/src/smithy_json/_private/traits.py deleted file mode 100644 index a5eecaa6e..000000000 --- a/packages/smithy-json/src/smithy_json/_private/traits.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -from smithy_core.shapes import ShapeID - -JSON_NAME = ShapeID("smithy.api#jsonName") -TIMESTAMP_FORMAT = ShapeID("smithy.api#timestampFormat") diff --git a/packages/smithy-json/tests/unit/__init__.py b/packages/smithy-json/tests/unit/__init__.py index 92c71b7b2..9f16ab878 100644 --- a/packages/smithy-json/tests/unit/__init__.py +++ b/packages/smithy-json/tests/unit/__init__.py @@ -18,11 +18,10 @@ from smithy_core.schemas import Schema from smithy_core.serializers import ShapeSerializer from smithy_core.shapes import ShapeID, ShapeType -from smithy_core.traits import Trait -from smithy_json._private.traits import JSON_NAME, TIMESTAMP_FORMAT +from smithy_core.traits import TimestampFormatTrait, JSONNameTrait, SparseTrait + -SPARSE_TRAIT = Trait(id=ShapeID("smithy.api#sparse")) STRING_LIST_SCHEMA = Schema.collection( id=ShapeID("smithy.example#StringList"), shape_type=ShapeType.LIST, @@ -40,7 +39,7 @@ id=ShapeID("smithy.example#StringList"), shape_type=ShapeType.LIST, members={"member": {"target": STRING, "index": 0}}, - traits=[SPARSE_TRAIT], + traits=[SparseTrait()], ) SPARSE_STRING_MAP_SCHEMA = Schema.collection( id=ShapeID("smithy.example#StringMap"), @@ -49,7 +48,7 @@ "key": {"target": STRING, "index": 0}, "value": {"target": STRING, "index": 1}, }, - traits=[SPARSE_TRAIT], + traits=[SparseTrait()], ) SCHEMA: Schema = Schema.collection( id=ShapeID("smithy.example#SerdeShape"), @@ -61,24 +60,24 @@ "stringMember": {"target": STRING, "index": 4}, "jsonNameMember": { "target": STRING, - "traits": [Trait(id=JSON_NAME, value="jsonName")], + "traits": [JSONNameTrait("jsonName")], "index": 5, }, "blobMember": {"target": BLOB, "index": 6}, "timestampMember": {"target": TIMESTAMP, "index": 7}, "dateTimeMember": { "target": TIMESTAMP, - "traits": [Trait(id=TIMESTAMP_FORMAT, value="date-time")], + "traits": [TimestampFormatTrait("date-time")], "index": 8, }, "httpDateMember": { "target": TIMESTAMP, - "traits": [Trait(id=TIMESTAMP_FORMAT, value="http-date")], + "traits": [TimestampFormatTrait("http-date")], "index": 9, }, "epochSecondsMember": { "target": TIMESTAMP, - "traits": [Trait(id=TIMESTAMP_FORMAT, value="epoch-seconds")], + "traits": [TimestampFormatTrait("epoch-seconds")], "index": 10, }, "documentMember": {"target": DOCUMENT, "index": 11}, From a33dcb92c5aef96d465878d4e2bf4903f232294a Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Fri, 28 Feb 2025 16:03:18 +0100 Subject: [PATCH 3/5] Update generator to use dynamic trait factory --- .../smithy/python/codegen/generators/SchemaGenerator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SchemaGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SchemaGenerator.java index 413cf5e0f..56b6e3021 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SchemaGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SchemaGenerator.java @@ -122,7 +122,7 @@ private void writeTraits(PythonWriter writer, Map> trait writer.putContext("traits", traits); writer.write(""" ${#traits} - Trait(id=ShapeID(${key:S})${?value}, value=${value:N}${/value}), + Trait.new(id=ShapeID(${key:S})${?value}, value=${value:N}${/value}), ${/traits}"""); writer.popState(); } From 17532101454a0d8b24b5ea69535dface6cc58d56 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Tue, 4 Mar 2025 13:58:27 +0100 Subject: [PATCH 4/5] Instantiate previously unknown trait classes --- packages/smithy-core/src/smithy_core/schemas.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/packages/smithy-core/src/smithy_core/schemas.py b/packages/smithy-core/src/smithy_core/schemas.py index a9feb93d3..36d18bca4 100644 --- a/packages/smithy-core/src/smithy_core/schemas.py +++ b/packages/smithy-core/src/smithy_core/schemas.py @@ -137,8 +137,18 @@ def get_trait(self, t: "type[Trait] | ShapeID") -> "Trait | DynamicTrait | None" :returns: A Trait if the trait class is known, a DynamicTrait if it isn't, or None if the trait is not present on the Schema. """ - id = t if isinstance(t, ShapeID) else t.id - return self.traits.get(id) + if isinstance(t, ShapeID): + return self.traits.get(t) + + result = self.traits.get(t.id) + + # If the trait wasn't known when the schema was created, but is known now, go + # ahead and convert it. + if isinstance(result, DynamicTrait): + result = t(result) + self.traits[t.id] = result + + return result @overload def expect_trait[T: "Trait"](self, t: type[T]) -> T: ... From 5dfbbc4d78c2fab960a228800709df812b877e5b Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Tue, 4 Mar 2025 15:43:24 +0100 Subject: [PATCH 5/5] Implement __contains__ for Schema --- .../src/aws_event_stream/_private/__init__.py | 2 +- .../_private/deserializers.py | 5 +--- .../aws_event_stream/_private/serializers.py | 4 +-- .../smithy-core/src/smithy_core/schemas.py | 24 ++++++++++++--- .../smithy-core/tests/unit/test_schemas.py | 30 ++++++++++++++++++- 5 files changed, 53 insertions(+), 12 deletions(-) diff --git a/packages/aws-event-stream/src/aws_event_stream/_private/__init__.py b/packages/aws-event-stream/src/aws_event_stream/_private/__init__.py index fed38cef2..5ec7ff9cd 100644 --- a/packages/aws-event-stream/src/aws_event_stream/_private/__init__.py +++ b/packages/aws-event-stream/src/aws_event_stream/_private/__init__.py @@ -10,6 +10,6 @@ def get_payload_member(schema: Schema) -> Schema | None: for member in schema.members.values(): - if EventPayloadTrait.id in member.traits: + if EventPayloadTrait in member: return member return None diff --git a/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py b/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py index 2e173f42d..864038d57 100644 --- a/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py +++ b/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py @@ -158,10 +158,7 @@ def read_struct( headers_deserializer = EventHeaderDeserializer(self._headers) for key in self._headers.keys(): member_schema = schema.members.get(key) - if ( - member_schema is not None - and EventHeaderTrait.id in member_schema.traits - ): + if member_schema is not None and EventHeaderTrait in member_schema: consumer(member_schema, headers_deserializer) if self._payload_deserializer: diff --git a/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py b/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py index 1914f0e5f..6ea9c1e22 100644 --- a/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py +++ b/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py @@ -102,7 +102,7 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]: headers_encoder = EventHeaderEncoder() - if ErrorTrait.id in schema.traits: + if ErrorTrait in schema: headers_encoder.encode_string(":message-type", "exception") headers_encoder.encode_string( ":exception-type", schema.expect_member_name() @@ -214,7 +214,7 @@ def __init__( self._payload_struct_serializer = payload_struct_serializer def before(self, schema: "Schema") -> ShapeSerializer: - if EventHeaderTrait.id in schema.traits: + if EventHeaderTrait in schema: return self._header_serializer return self._payload_struct_serializer diff --git a/packages/smithy-core/src/smithy_core/schemas.py b/packages/smithy-core/src/smithy_core/schemas.py index 36d18bca4..1b59c733d 100644 --- a/packages/smithy-core/src/smithy_core/schemas.py +++ b/packages/smithy-core/src/smithy_core/schemas.py @@ -2,13 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Mapping from dataclasses import dataclass, field, replace -from typing import TYPE_CHECKING, NotRequired, Required, Self, TypedDict, overload +from typing import NotRequired, Required, Self, TypedDict, overload, Any from .exceptions import ExpectationNotMetException, SmithyException from .shapes import ShapeID, ShapeType - -if TYPE_CHECKING: - from .traits import Trait, DynamicTrait +from .traits import Trait, DynamicTrait @dataclass(kw_only=True, frozen=True, init=False) @@ -164,6 +162,24 @@ def expect_trait(self, t: "type[Trait] | ShapeID") -> "Trait | DynamicTrait": id = t if isinstance(t, ShapeID) else t.id return self.traits[id] + def __contains__(self, item: Any): + """Returns whether the schema has the given member or trait.""" + match item: + case type(): + if issubclass(item, Trait): + return item.id in self.traits + return False + case ShapeID(): + if (member := item.member) is not None: + if self.id.with_member(member) == item: + return member in self.members + return False + return item in self.traits + case str(): + return item in self.members + case _: + return False + @classmethod def collection( cls, diff --git a/packages/smithy-core/tests/unit/test_schemas.py b/packages/smithy-core/tests/unit/test_schemas.py index 9e822b36b..3ccb9ce3f 100644 --- a/packages/smithy-core/tests/unit/test_schemas.py +++ b/packages/smithy-core/tests/unit/test_schemas.py @@ -2,10 +2,16 @@ import pytest +from typing import Any + from smithy_core.exceptions import ExpectationNotMetException from smithy_core.schemas import Schema from smithy_core.shapes import ShapeID, ShapeType -from smithy_core.traits import InternalTrait, DynamicTrait, SensitiveTrait +from smithy_core.traits import ( + InternalTrait, + DynamicTrait, + SensitiveTrait, +) ID: ShapeID = ShapeID("ns.foo#bar") STRING = Schema(id=ShapeID("smithy.api#String"), shape_type=ShapeType.STRING) @@ -143,3 +149,25 @@ def test_member_constructor_asserts_target_is_not_member(): ) with pytest.raises(ExpectationNotMetException): Schema.member(id=ShapeID("smithy.example#Foo$bar"), target=target, index=0) + + +@pytest.mark.parametrize( + "item, contains", + [ + (SensitiveTrait, True), + (SensitiveTrait.id, True), + (InternalTrait, False), + (InternalTrait.id, False), + ("baz", True), + (ID.with_member("baz"), True), + (ID, False), + ], +) +def test_contains(item: Any, contains: bool): + schema = Schema.collection( + id=ID, + members={"baz": {"target": STRING, "index": 0}}, + traits=[SensitiveTrait()], + ) + + assert (item in schema) == contains