From 5b2c1646a373d53ae74a91c4bf2763d0e1e8ea00 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Thu, 17 Oct 2024 14:14:56 +0200 Subject: [PATCH 1/6] Add EventStream interfaces This adds in the EventStream interfaces that operations will use as their return types. --- .../smithy_event_stream/aio/interfaces.py | 177 ++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py diff --git a/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py new file mode 100644 index 000000000..abf672ed8 --- /dev/null +++ b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py @@ -0,0 +1,177 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Protocol, Self + +from smithy_core.deserializers import DeserializeableShape +from smithy_core.serializers import SerializeableShape + + +class InputEventStream[E: SerializeableShape](Protocol): + """Asynchronously sends events to a service. + + This may be used as a context manager to ensure the stream is closed before exiting. + """ + + async def send(self, event: E) -> None: + """Sends an event to the service. + + :param event: The event to send. + """ + ... + + def close(self) -> None: + """Closes the event stream.""" + ... + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): + self.close() + + +class OutputEventStream[E: DeserializeableShape](Protocol): + """Asynchronously receives events from a service. + + Events may be recived via the ``receive`` method or by using this class as + an async iterable. + + This may also be used as a context manager to ensure the stream is closed before + exiting. + """ + + async def receive(self) -> E | None: + """Receive a single event from the service. + + :returns: An event or None. None indicates that no more events will be sent by + the service. + """ + ... + + def close(self) -> None: + """Closes the event stream.""" + ... + + async def __anext__(self) -> E: + result = await self.receive() + if result is None: + self.close() + raise StopAsyncIteration + return result + + def __aiter__(self) -> Self: + return self + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): + self.close() + + +class EventStream[I: InputEventStream[Any] | None, O: OutputEventStream[Any] | None, R]( + Protocol +): + """A unidirectional or bidirectional event stream. + + To ensure that streams are closed upon exiting, this class may be used as an async + context manager. + + .. code-block:: python + + async def main(): + client = ChatClient() + input = StreamMessagesInput(chat_room="aws-python-sdk", username="hunter7") + + async with client.stream_messages(input=input) as stream: + stream.input_stream.send(MessageStreamMessage("Chat logger starting up.")) + response_handler = handle_output(stream) + stream.input_stream.send(MessageStreamMessage("Chat logger active.")) + await response_handler + + async def handle_output(stream: EventStream) -> None: + _, output_stream = await stream.await_output() + async for event in output_stream: + match event: + case MessageStreamMessage(): + print(event.value) + case MessageStreamShutdown(): + return + case _: + stream.input_stream.send( + MessageStreamMessage("Unknown message type received. Shutting down.") + ) + return + """ + + input_stream: I + """An event stream that sends events to the service. + + This value will be None if the operation has no input stream. + """ + + output_stream: O | None = None + """An event stream that receives events from the service. + + This value may be None until ``await_output`` has been called. + + This value will also be None if the operation has no output stream. + """ + + response: R | None = None + """The initial response from the service. + + This value may be None until ``await_output`` has been called. + + This may include context necessary to interpret output events or prepare + input events. It will always be available before any events. + """ + + async def await_output(self) -> tuple[R, O]: + """Await the operation's output. + + The EventStream will be returned as soon as the input stream is ready to + receive events, which may be before the intitial response has been received + and the service is ready to send events. + + Awaiting this method will wait until the initial response was received and the + service is ready to send events. The initial response and output stream will + be returned by this operation and also cached in ``response`` and + ``output_stream``, respectively. + + The default implementation of this method performs the caching behavior, + delegating to the abstract ``_await_output`` method to actually retrieve the + initial response and output stream. + + :returns: A tuple containing the initial response and output stream. If the + operation has no output stream, the second value will be None. + """ + if self.response is not None: + self.response, self.output_stream = await self._await_output() + + return self._response, self._output_stream # type: ignore + + async def _await_output(self) -> tuple[R, O]: + """Await the operation's output without caching. + + This method is meant to be used with the default implementation of await_output. + It should return the output directly without caching. + """ + ... + + async def close(self) -> None: + """Closes the event stream. + + This closes both the input and output streams. + """ + if self.output_stream is None: + _, self.output_stream = await self.await_output() + + if self.output_stream is not None: + self.output_stream.close() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): + await self.close() From e92a423d32702ae091e3deba82aebd0c45f2a0e5 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Wed, 16 Oct 2024 15:14:52 +0200 Subject: [PATCH 2/6] Add sample stream to test package This adds a bidirectional event stream to the test package so that new code generation for event streams can be examined. --- .../model/main.smithy | 74 ++++++++++++++++++- 1 file changed, 71 insertions(+), 3 deletions(-) diff --git a/codegen/smithy-python-codegen-test/model/main.smithy b/codegen/smithy-python-codegen-test/model/main.smithy index cdc85811f..39ee49502 100644 --- a/codegen/smithy-python-codegen-test/model/main.smithy +++ b/codegen/smithy-python-codegen-test/model/main.smithy @@ -20,11 +20,14 @@ service Weather { operations: [ GetCurrentTime TestUnionListOperation + StreamAtmosphericConditions ] } resource City { - identifiers: { cityId: CityId } + identifiers: { + cityId: CityId + } read: GetCity list: ListCities resources: [ @@ -56,12 +59,16 @@ union UnionListMember { } resource Forecast { - identifiers: { cityId: CityId } + identifiers: { + cityId: CityId + } read: GetForecast } resource CityImage { - identifiers: { cityId: CityId } + identifiers: { + cityId: CityId + } read: GetCityImage } @@ -622,6 +629,67 @@ union Precipitation { baz: example.weather.nested.more#Baz } +@http(method: "POST", uri: "/cities/{cityId}/atmosphere") +operation StreamAtmosphericConditions { + input := { + @required + @httpLabel + cityId: CityId + + @required + @httpPayload + stream: AtmosphericConditions + } + + output := { + @required + @httpHeader("x-initial-sample-rate") + initialSampleRate: Double + + @required + @httpPayload + stream: CollectionDirectives + } +} + +@streaming +union AtmosphericConditions { + humidity: HumiditySample + pressure: PressureSample + temperature: TemperatureSample +} + +@mixin +structure Sample { + @required + collectionTime: Timestamp +} + +structure HumiditySample with [Sample] { + @required + humidity: Double +} + +structure PressureSample with [Sample] { + @required + pressure: Double +} + +structure TemperatureSample with [Sample] { + @required + temperature: Double +} + +@streaming +union CollectionDirectives { + sampleRate: SampleRate +} + +structure SampleRate { + @required + samplesPerMinute: Double +} + structure OtherStructure {} enum StringYesNo { From 87f30ea6ebe76a967de97a6d4317431243a1f386 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Wed, 16 Oct 2024 16:01:11 +0200 Subject: [PATCH 3/6] Generate event stream operation method signatures This updates operation generation to generate event stream operations with the EventStream type as its return value. It also updates union generation so that unions contain their own deserialize functions. This is needed to make the them pass the type check, but also it is best to have them own as much of that as possible so that the deserializer function can be left to only dispatch duty. --- .../python/codegen/ClientGenerator.java | 106 +++++++++++++----- .../codegen/SmithyPythonDependency.java | 7 ++ .../smithy/python/codegen/UnionGenerator.java | 33 ++++-- 3 files changed, 105 insertions(+), 41 deletions(-) diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 9ea91c73a..a3e91d4a8 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -20,6 +20,8 @@ import java.util.LinkedHashSet; import java.util.Set; import software.amazon.smithy.codegen.core.SymbolReference; +import software.amazon.smithy.model.knowledge.EventStreamIndex; +import software.amazon.smithy.model.knowledge.EventStreamInfo; import software.amazon.smithy.model.knowledge.ServiceIndex; import software.amazon.smithy.model.knowledge.TopDownIndex; import software.amazon.smithy.model.shapes.OperationShape; @@ -104,8 +106,14 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None): """, configSymbol, pluginSymbol, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); var topDownIndex = TopDownIndex.of(context.model()); + var eventStreamIndex = EventStreamIndex.of(context.model()); for (OperationShape operation : topDownIndex.getContainedOperations(service)) { - generateOperation(writer, operation); + if (eventStreamIndex.getInputInfo(operation).isPresent() + || eventStreamIndex.getOutputInfo(operation).isPresent()) { + generateEventStreamOperation(writer, operation); + } else { + generateOperation(writer, operation); + } } }); @@ -348,7 +356,7 @@ async def _handle_attempt( ) """, CodegenUtils.getHttpAuthParamsSymbol(context.settings()), - writer.consumer(this::initializeHttpAuthParameters)); + writer.consumer(this::initializeHttpAuthParameters)); writer.popState(); writer.addDependency(SmithyPythonDependency.SMITHY_CORE); @@ -641,16 +649,16 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { writer.openBlock("async def $L(self, input: $T, plugins: list[$T] | None = None) -> $T:", "", operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol, () -> { - writer.writeDocs(() -> { - var docs = operation.getTrait(DocumentationTrait.class) - .map(StringTrait::getValue) - .orElse(String.format("Invokes the %s operation.", operation.getId().getName())); + writer.writeDocs(() -> { + var docs = operation.getTrait(DocumentationTrait.class) + .map(StringTrait::getValue) + .orElse(String.format("Invokes the %s operation.", operation.getId().getName())); - var inputDocs = input.getTrait(DocumentationTrait.class) - .map(StringTrait::getValue) - .orElse("The operation's input."); + var inputDocs = input.getTrait(DocumentationTrait.class) + .map(StringTrait::getValue) + .orElse("The operation's input."); - writer.write(""" + writer.write(""" $L :param input: $L @@ -658,17 +666,17 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { :param plugins: A list of callables that modify the configuration dynamically. Changes made by these plugins only apply for the duration of the operation execution and will not affect any other operation invocations.""", docs, inputDocs); - }); - - var defaultPlugins = new LinkedHashSet(); - for (PythonIntegration integration : context.integrations()) { - for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) { - if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) { - runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); + }); + + var defaultPlugins = new LinkedHashSet(); + for (PythonIntegration integration : context.integrations()) { + for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) { + if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) { + runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); + } + } } - } - } - writer.write(""" + writer.write(""" operation_plugins: list[Plugin] = [ $C ] @@ -676,13 +684,13 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { operation_plugins.extend(plugins) """, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); - if (context.protocolGenerator() == null) { - writer.write("raise NotImplementedError()"); - } else { - var protocolGenerator = context.protocolGenerator(); - var serSymbol = protocolGenerator.getSerializationFunction(context, operation); - var deserSymbol = protocolGenerator.getDeserializationFunction(context, operation); - writer.write(""" + if (context.protocolGenerator() == null) { + writer.write("raise NotImplementedError()"); + } else { + var protocolGenerator = context.protocolGenerator(); + var serSymbol = protocolGenerator.getSerializationFunction(context, operation); + var deserSymbol = protocolGenerator.getDeserializationFunction(context, operation); + writer.write(""" return await self._execute_operation( input=input, plugins=operation_plugins, @@ -692,7 +700,47 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { operation_name=$S, ) """, serSymbol, deserSymbol, operation.getId().getName()); - } - }); + } + }); + } + + private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) { + writer.pushState(); + writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM); + writer.addImports("smithy_event_stream.aio.interfaces", Set.of( + "EventStream", "InputEventStream", "OutputEventStream")); + var operationSymbol = context.symbolProvider().toSymbol(operation); + var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings()); + + var input = context.model().expectShape(operation.getInputShape()); + var inputSymbol = context.symbolProvider().toSymbol(input); + + var eventStreamIndex = EventStreamIndex.of(context.model()); + var inputStreamSymbol = eventStreamIndex.getInputInfo(operation) + .map(EventStreamInfo::getEventStreamTarget) + .map(target -> context.symbolProvider().toSymbol(target)) + .orElse(null); + writer.putContext("inputStream", inputStreamSymbol); + + var output = context.model().expectShape(operation.getOutputShape()); + var outputSymbol = context.symbolProvider().toSymbol(output); + var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation) + .map(EventStreamInfo::getEventStreamTarget) + .map(target -> context.symbolProvider().toSymbol(target)) + .orElse(null); + writer.putContext("outputStream", outputStreamSymbol); + + writer.write(""" + async def $L(self, input: $T, plugins: list[$T] | None = None) -> EventStream[ + ${?inputStream}InputEventStream[${inputStream:T}]${/inputStream}\ + ${^inputStream}None${/inputStream}, + ${?outputStream}OutputEventStream[${outputStream:T}]${/outputStream}\ + ${^outputStream}None${/outputStream}, + $T + ]: + raise NotImplementedError() + """, operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol); + + writer.popState(); } } diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java index e96ef2517..d1af2a4d5 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java @@ -64,6 +64,13 @@ public final class SmithyPythonDependency { false ); + public static final PythonDependency SMITHY_EVENT_STREAM = new PythonDependency( + "smithy_event_stream", + "==0.0.1", + Type.DEPENDENCY, + false + ); + /** * testing framework used in generated functional tests. */ diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/UnionGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/UnionGenerator.java index 19b25ab90..eba7c85f7 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/UnionGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/UnionGenerator.java @@ -72,16 +72,20 @@ public void run() { writer.write(""" @dataclass - class $L: - ${C|} + class $1L: + ${2C|} - value: $T + value: $3T def serialize(self, serializer: ShapeSerializer): - serializer.write_struct($T, self) + serializer.write_struct($4T, self) def serialize_members(self, serializer: ShapeSerializer): - ${C|} + ${5C|} + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + return cls(value=${6C|}) """, memberSymbol.getName(), @@ -90,7 +94,11 @@ def serialize_members(self, serializer: ShapeSerializer): targetSymbol, schemaSymbol, writer.consumer(w -> target.accept( - new MemberSerializerGenerator(context, w, member, "serializer")))); + new MemberSerializerGenerator(context, w, member, "serializer"))), + writer.consumer(w -> target.accept( + new MemberDeserializerGenerator(context, w, member, "deserializer"))) + + ); } // Note that the unknown variant doesn't implement __eq__. This is because @@ -118,11 +126,15 @@ raise SmithyException("Unknown union variants may not be serialized.") def serialize_members(self, serializer: ShapeSerializer): raise SmithyException("Unknown union variants may not be serialized.") + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + raise NotImplementedError() + """, unknownSymbol.getName()); memberNames.add(unknownSymbol.getName()); - shape.getTrait(DocumentationTrait.class).ifPresent(trait -> writer.writeComment(trait.getValue())); writer.write("type $L = $L\n", parentName, String.join(" | ", memberNames)); + shape.getTrait(DocumentationTrait.class).ifPresent(trait -> writer.writeDocs(trait.getValue())); generateDeserializer(); writer.popState(); @@ -173,13 +185,10 @@ raise SmithyException("Unions must have exactly one value, but found more than o private void deserializeMembers() { int index = 0; for (MemberShape member : shape.members()) { - var target = model.expectShape(member.getTarget()); writer.write(""" case $L: - self._set_result($T(${C|})) - """, index++, symbolProvider.toSymbol(member), writer.consumer(w -> - target.accept(new MemberDeserializerGenerator(context, writer, member, "de")) - )); + self._set_result($T.deserialize(de)) + """, index++, symbolProvider.toSymbol(member)); } } } From ae059daa0009a51c9a8868a3240b4e87a9053b41 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Thu, 17 Oct 2024 15:22:33 +0200 Subject: [PATCH 4/6] Undo unintential indentation changes --- .../python/codegen/ClientGenerator.java | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index a3e91d4a8..930687687 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -356,7 +356,7 @@ async def _handle_attempt( ) """, CodegenUtils.getHttpAuthParamsSymbol(context.settings()), - writer.consumer(this::initializeHttpAuthParameters)); + writer.consumer(this::initializeHttpAuthParameters)); writer.popState(); writer.addDependency(SmithyPythonDependency.SMITHY_CORE); @@ -649,16 +649,16 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { writer.openBlock("async def $L(self, input: $T, plugins: list[$T] | None = None) -> $T:", "", operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol, () -> { - writer.writeDocs(() -> { - var docs = operation.getTrait(DocumentationTrait.class) - .map(StringTrait::getValue) - .orElse(String.format("Invokes the %s operation.", operation.getId().getName())); + writer.writeDocs(() -> { + var docs = operation.getTrait(DocumentationTrait.class) + .map(StringTrait::getValue) + .orElse(String.format("Invokes the %s operation.", operation.getId().getName())); - var inputDocs = input.getTrait(DocumentationTrait.class) - .map(StringTrait::getValue) - .orElse("The operation's input."); + var inputDocs = input.getTrait(DocumentationTrait.class) + .map(StringTrait::getValue) + .orElse("The operation's input."); - writer.write(""" + writer.write(""" $L :param input: $L @@ -666,17 +666,17 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { :param plugins: A list of callables that modify the configuration dynamically. Changes made by these plugins only apply for the duration of the operation execution and will not affect any other operation invocations.""", docs, inputDocs); - }); - - var defaultPlugins = new LinkedHashSet(); - for (PythonIntegration integration : context.integrations()) { - for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) { - if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) { - runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); - } - } + }); + + var defaultPlugins = new LinkedHashSet(); + for (PythonIntegration integration : context.integrations()) { + for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) { + if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) { + runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); } - writer.write(""" + } + } + writer.write(""" operation_plugins: list[Plugin] = [ $C ] @@ -684,13 +684,13 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { operation_plugins.extend(plugins) """, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); - if (context.protocolGenerator() == null) { - writer.write("raise NotImplementedError()"); - } else { - var protocolGenerator = context.protocolGenerator(); - var serSymbol = protocolGenerator.getSerializationFunction(context, operation); - var deserSymbol = protocolGenerator.getDeserializationFunction(context, operation); - writer.write(""" + if (context.protocolGenerator() == null) { + writer.write("raise NotImplementedError()"); + } else { + var protocolGenerator = context.protocolGenerator(); + var serSymbol = protocolGenerator.getSerializationFunction(context, operation); + var deserSymbol = protocolGenerator.getDeserializationFunction(context, operation); + writer.write(""" return await self._execute_operation( input=input, plugins=operation_plugins, @@ -700,8 +700,8 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { operation_name=$S, ) """, serSymbol, deserSymbol, operation.getId().getName()); - } - }); + } + }); } private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) { From 4639dc3b62b056b2f5504f626876e9dadf0b615a Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Fri, 18 Oct 2024 17:57:15 +0200 Subject: [PATCH 5/6] Use async context managers for event streams --- .../smithy_event_stream/aio/interfaces.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py index abf672ed8..acfb2bc3a 100644 --- a/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py +++ b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py @@ -19,15 +19,15 @@ async def send(self, event: E) -> None: """ ... - def close(self) -> None: + async def close(self) -> None: """Closes the event stream.""" ... - def __enter__(self) -> Self: + async def __aenter__(self) -> Self: return self - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): - self.close() + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): + await self.close() class OutputEventStream[E: DeserializeableShape](Protocol): @@ -48,25 +48,25 @@ async def receive(self) -> E | None: """ ... - def close(self) -> None: + async def close(self) -> None: """Closes the event stream.""" ... async def __anext__(self) -> E: result = await self.receive() if result is None: - self.close() + await self.close() raise StopAsyncIteration return result def __aiter__(self) -> Self: return self - def __enter__(self) -> Self: + async def __enter__(self) -> Self: return self - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): - self.close() + async def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): + await self.close() class EventStream[I: InputEventStream[Any] | None, O: OutputEventStream[Any] | None, R]( @@ -85,7 +85,7 @@ async def main(): async with client.stream_messages(input=input) as stream: stream.input_stream.send(MessageStreamMessage("Chat logger starting up.")) - response_handler = handle_output(stream) + response_task = asyncio.create_task(handle_output(stream)) stream.input_stream.send(MessageStreamMessage("Chat logger active.")) await response_handler @@ -168,7 +168,7 @@ async def close(self) -> None: _, self.output_stream = await self.await_output() if self.output_stream is not None: - self.output_stream.close() + await self.output_stream.close() async def __aenter__(self) -> Self: return self From 996da8a01e1ef2e8393863249105059826eab62d Mon Sep 17 00:00:00 2001 From: Jordon Phillips Date: Fri, 18 Oct 2024 19:01:42 +0200 Subject: [PATCH 6/6] Apply suggestions from code review Fix comment typos Co-authored-by: Nate Prewitt --- .../smithy-event-stream/smithy_event_stream/aio/interfaces.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py index acfb2bc3a..5708f143f 100644 --- a/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py +++ b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py @@ -33,7 +33,7 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): class OutputEventStream[E: DeserializeableShape](Protocol): """Asynchronously receives events from a service. - Events may be recived via the ``receive`` method or by using this class as + Events may be received via the ``receive`` method or by using this class as an async iterable. This may also be used as a context manager to ensure the stream is closed before @@ -131,7 +131,7 @@ async def await_output(self) -> tuple[R, O]: """Await the operation's output. The EventStream will be returned as soon as the input stream is ready to - receive events, which may be before the intitial response has been received + receive events, which may be before the initial response has been received and the service is ready to send events. Awaiting this method will wait until the initial response was received and the