diff --git a/codegen/smithy-python-codegen-test/model/main.smithy b/codegen/smithy-python-codegen-test/model/main.smithy index cdc85811f..39ee49502 100644 --- a/codegen/smithy-python-codegen-test/model/main.smithy +++ b/codegen/smithy-python-codegen-test/model/main.smithy @@ -20,11 +20,14 @@ service Weather { operations: [ GetCurrentTime TestUnionListOperation + StreamAtmosphericConditions ] } resource City { - identifiers: { cityId: CityId } + identifiers: { + cityId: CityId + } read: GetCity list: ListCities resources: [ @@ -56,12 +59,16 @@ union UnionListMember { } resource Forecast { - identifiers: { cityId: CityId } + identifiers: { + cityId: CityId + } read: GetForecast } resource CityImage { - identifiers: { cityId: CityId } + identifiers: { + cityId: CityId + } read: GetCityImage } @@ -622,6 +629,67 @@ union Precipitation { baz: example.weather.nested.more#Baz } +@http(method: "POST", uri: "/cities/{cityId}/atmosphere") +operation StreamAtmosphericConditions { + input := { + @required + @httpLabel + cityId: CityId + + @required + @httpPayload + stream: AtmosphericConditions + } + + output := { + @required + @httpHeader("x-initial-sample-rate") + initialSampleRate: Double + + @required + @httpPayload + stream: CollectionDirectives + } +} + +@streaming +union AtmosphericConditions { + humidity: HumiditySample + pressure: PressureSample + temperature: TemperatureSample +} + +@mixin +structure Sample { + @required + collectionTime: Timestamp +} + +structure HumiditySample with [Sample] { + @required + humidity: Double +} + +structure PressureSample with [Sample] { + @required + pressure: Double +} + +structure TemperatureSample with [Sample] { + @required + temperature: Double +} + +@streaming +union CollectionDirectives { + sampleRate: SampleRate +} + +structure SampleRate { + @required + samplesPerMinute: Double +} + structure OtherStructure {} enum StringYesNo { diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 9ea91c73a..930687687 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -20,6 +20,8 @@ import java.util.LinkedHashSet; import java.util.Set; import software.amazon.smithy.codegen.core.SymbolReference; +import software.amazon.smithy.model.knowledge.EventStreamIndex; +import software.amazon.smithy.model.knowledge.EventStreamInfo; import software.amazon.smithy.model.knowledge.ServiceIndex; import software.amazon.smithy.model.knowledge.TopDownIndex; import software.amazon.smithy.model.shapes.OperationShape; @@ -104,8 +106,14 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None): """, configSymbol, pluginSymbol, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); var topDownIndex = TopDownIndex.of(context.model()); + var eventStreamIndex = EventStreamIndex.of(context.model()); for (OperationShape operation : topDownIndex.getContainedOperations(service)) { - generateOperation(writer, operation); + if (eventStreamIndex.getInputInfo(operation).isPresent() + || eventStreamIndex.getOutputInfo(operation).isPresent()) { + generateEventStreamOperation(writer, operation); + } else { + generateOperation(writer, operation); + } } }); @@ -695,4 +703,44 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { } }); } + + private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) { + writer.pushState(); + writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM); + writer.addImports("smithy_event_stream.aio.interfaces", Set.of( + "EventStream", "InputEventStream", "OutputEventStream")); + var operationSymbol = context.symbolProvider().toSymbol(operation); + var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings()); + + var input = context.model().expectShape(operation.getInputShape()); + var inputSymbol = context.symbolProvider().toSymbol(input); + + var eventStreamIndex = EventStreamIndex.of(context.model()); + var inputStreamSymbol = eventStreamIndex.getInputInfo(operation) + .map(EventStreamInfo::getEventStreamTarget) + .map(target -> context.symbolProvider().toSymbol(target)) + .orElse(null); + writer.putContext("inputStream", inputStreamSymbol); + + var output = context.model().expectShape(operation.getOutputShape()); + var outputSymbol = context.symbolProvider().toSymbol(output); + var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation) + .map(EventStreamInfo::getEventStreamTarget) + .map(target -> context.symbolProvider().toSymbol(target)) + .orElse(null); + writer.putContext("outputStream", outputStreamSymbol); + + writer.write(""" + async def $L(self, input: $T, plugins: list[$T] | None = None) -> EventStream[ + ${?inputStream}InputEventStream[${inputStream:T}]${/inputStream}\ + ${^inputStream}None${/inputStream}, + ${?outputStream}OutputEventStream[${outputStream:T}]${/outputStream}\ + ${^outputStream}None${/outputStream}, + $T + ]: + raise NotImplementedError() + """, operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol); + + writer.popState(); + } } diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java index e96ef2517..d1af2a4d5 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java @@ -64,6 +64,13 @@ public final class SmithyPythonDependency { false ); + public static final PythonDependency SMITHY_EVENT_STREAM = new PythonDependency( + "smithy_event_stream", + "==0.0.1", + Type.DEPENDENCY, + false + ); + /** * testing framework used in generated functional tests. */ diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/UnionGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/UnionGenerator.java index 19b25ab90..eba7c85f7 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/UnionGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/UnionGenerator.java @@ -72,16 +72,20 @@ public void run() { writer.write(""" @dataclass - class $L: - ${C|} + class $1L: + ${2C|} - value: $T + value: $3T def serialize(self, serializer: ShapeSerializer): - serializer.write_struct($T, self) + serializer.write_struct($4T, self) def serialize_members(self, serializer: ShapeSerializer): - ${C|} + ${5C|} + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(value=${6C|}) """, memberSymbol.getName(), @@ -90,7 +94,11 @@ def serialize_members(self, serializer: ShapeSerializer): targetSymbol, schemaSymbol, writer.consumer(w -> target.accept( - new MemberSerializerGenerator(context, w, member, "serializer")))); + new MemberSerializerGenerator(context, w, member, "serializer"))), + writer.consumer(w -> target.accept( + new MemberDeserializerGenerator(context, w, member, "deserializer"))) + + ); } // Note that the unknown variant doesn't implement __eq__. This is because @@ -118,11 +126,15 @@ raise SmithyException("Unknown union variants may not be serialized.") def serialize_members(self, serializer: ShapeSerializer): raise SmithyException("Unknown union variants may not be serialized.") + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + raise NotImplementedError() + """, unknownSymbol.getName()); memberNames.add(unknownSymbol.getName()); - shape.getTrait(DocumentationTrait.class).ifPresent(trait -> writer.writeComment(trait.getValue())); writer.write("type $L = $L\n", parentName, String.join(" | ", memberNames)); + shape.getTrait(DocumentationTrait.class).ifPresent(trait -> writer.writeDocs(trait.getValue())); generateDeserializer(); writer.popState(); @@ -173,13 +185,10 @@ raise SmithyException("Unions must have exactly one value, but found more than o private void deserializeMembers() { int index = 0; for (MemberShape member : shape.members()) { - var target = model.expectShape(member.getTarget()); writer.write(""" case $L: - self._set_result($T(${C|})) - """, index++, symbolProvider.toSymbol(member), writer.consumer(w -> - target.accept(new MemberDeserializerGenerator(context, writer, member, "de")) - )); + self._set_result($T.deserialize(de)) + """, index++, symbolProvider.toSymbol(member)); } } } diff --git a/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py new file mode 100644 index 000000000..5708f143f --- /dev/null +++ b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py @@ -0,0 +1,177 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Protocol, Self + +from smithy_core.deserializers import DeserializeableShape +from smithy_core.serializers import SerializeableShape + + +class InputEventStream[E: SerializeableShape](Protocol): + """Asynchronously sends events to a service. + + This may be used as a context manager to ensure the stream is closed before exiting. + """ + + async def send(self, event: E) -> None: + """Sends an event to the service. + + :param event: The event to send. + """ + ... + + async def close(self) -> None: + """Closes the event stream.""" + ... + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): + await self.close() + + +class OutputEventStream[E: DeserializeableShape](Protocol): + """Asynchronously receives events from a service. + + Events may be received via the ``receive`` method or by using this class as + an async iterable. + + This may also be used as a context manager to ensure the stream is closed before + exiting. + """ + + async def receive(self) -> E | None: + """Receive a single event from the service. + + :returns: An event or None. None indicates that no more events will be sent by + the service. + """ + ... + + async def close(self) -> None: + """Closes the event stream.""" + ... + + async def __anext__(self) -> E: + result = await self.receive() + if result is None: + await self.close() + raise StopAsyncIteration + return result + + def __aiter__(self) -> Self: + return self + + async def __enter__(self) -> Self: + return self + + async def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): + await self.close() + + +class EventStream[I: InputEventStream[Any] | None, O: OutputEventStream[Any] | None, R]( + Protocol +): + """A unidirectional or bidirectional event stream. + + To ensure that streams are closed upon exiting, this class may be used as an async + context manager. + + .. code-block:: python + + async def main(): + client = ChatClient() + input = StreamMessagesInput(chat_room="aws-python-sdk", username="hunter7") + + async with client.stream_messages(input=input) as stream: + stream.input_stream.send(MessageStreamMessage("Chat logger starting up.")) + response_task = asyncio.create_task(handle_output(stream)) + stream.input_stream.send(MessageStreamMessage("Chat logger active.")) + await response_handler + + async def handle_output(stream: EventStream) -> None: + _, output_stream = await stream.await_output() + async for event in output_stream: + match event: + case MessageStreamMessage(): + print(event.value) + case MessageStreamShutdown(): + return + case _: + stream.input_stream.send( + MessageStreamMessage("Unknown message type received. Shutting down.") + ) + return + """ + + input_stream: I + """An event stream that sends events to the service. + + This value will be None if the operation has no input stream. + """ + + output_stream: O | None = None + """An event stream that receives events from the service. + + This value may be None until ``await_output`` has been called. + + This value will also be None if the operation has no output stream. + """ + + response: R | None = None + """The initial response from the service. + + This value may be None until ``await_output`` has been called. + + This may include context necessary to interpret output events or prepare + input events. It will always be available before any events. + """ + + async def await_output(self) -> tuple[R, O]: + """Await the operation's output. + + The EventStream will be returned as soon as the input stream is ready to + receive events, which may be before the initial response has been received + and the service is ready to send events. + + Awaiting this method will wait until the initial response was received and the + service is ready to send events. The initial response and output stream will + be returned by this operation and also cached in ``response`` and + ``output_stream``, respectively. + + The default implementation of this method performs the caching behavior, + delegating to the abstract ``_await_output`` method to actually retrieve the + initial response and output stream. + + :returns: A tuple containing the initial response and output stream. If the + operation has no output stream, the second value will be None. + """ + if self.response is not None: + self.response, self.output_stream = await self._await_output() + + return self._response, self._output_stream # type: ignore + + async def _await_output(self) -> tuple[R, O]: + """Await the operation's output without caching. + + This method is meant to be used with the default implementation of await_output. + It should return the output directly without caching. + """ + ... + + async def close(self) -> None: + """Closes the event stream. + + This closes both the input and output streams. + """ + if self.output_stream is None: + _, self.output_stream = await self.await_output() + + if self.output_stream is not None: + await self.output_stream.close() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): + await self.close()