From 950f00b6ba9f8aa939a64ac9a8b6a2d49f0f7a07 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Mon, 21 Oct 2024 14:07:38 +0200 Subject: [PATCH] Generate RestJson event stream implementation This updates generic event stream generation with recently introduced changes and also introduces the concrete implementation for RestJson. Testing for all of this will be done via protocol tests, and in the early days manual testing. Since a lot of this is effectively throwaway code, I was more liberal with type ignoring and using Any types than I otherwise would be. The request pipeline is going to be moving to pure python soon^tm, and the typing issues will be resolved at that time. --- .../python/codegen/ClientGenerator.java | 233 ++++++++++++++---- .../codegen/SmithyPythonDependency.java | 13 + .../integration/ProtocolGenerator.java | 22 ++ .../RestJsonProtocolGenerator.java | 42 +++- 4 files changed, 260 insertions(+), 50 deletions(-) diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 930687687..a56dea3cd 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -26,6 +26,7 @@ import software.amazon.smithy.model.knowledge.TopDownIndex; import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.Shape; import software.amazon.smithy.model.traits.DocumentationTrait; import software.amazon.smithy.model.traits.StringTrait; import software.amazon.smithy.python.codegen.integration.PythonIntegration; @@ -123,6 +124,16 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None): } private void generateOperationExecutor(PythonWriter writer) { + writer.pushState(); + + var hasStreaming = hasEventStream(); + writer.putContext("hasEventStream", hasStreaming); + if (hasStreaming) { + writer.addImports("smithy_core.deserializers", Set.of( + "ShapeDeserializer", "DeserializeableShape")); + writer.addStdlibImport("typing", "Any"); + } + var transportRequest = context.applicationProtocol().requestType(); var transportResponse = context.applicationProtocol().responseType(); var errorSymbol = CodegenUtils.getServiceError(context.settings()); @@ -191,10 +202,18 @@ async def _execute_operation( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, + ${?hasEventStream} + has_input_stream: bool = False, + event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, + event_response_deserializer: DeserializeableShape | None = None, + ${/hasEventStream} ) -> Output: try: return await self._handle_execution( - input, plugins, serialize, deserialize, config, operation_name + input, plugins, serialize, deserialize, config, operation_name, + ${?hasEventStream} + has_input_stream, event_deserializer, event_response_deserializer, + ${/hasEventStream} ) except Exception as e: # Make sure every exception that we throw is an instance of $4T so @@ -211,6 +230,11 @@ async def _handle_execution( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, + ${?hasEventStream} + has_input_stream: bool = False, + event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, + event_response_deserializer: DeserializeableShape | None = None, + ${/hasEventStream} ) -> Output: logger.debug(f"Making request for operation {operation_name} with parameters: {input}") context: InterceptorContext[Input, None, None, None] = InterceptorContext( @@ -326,7 +350,16 @@ await sleep(retry_token.retry_delay) execution_context = cast( InterceptorContext[Input, Output, $2T | None, $3T | None], context ) + ${^hasEventStream} return await self._finalize_execution(interceptors, execution_context) + ${/hasEventStream} + ${?hasEventStream} + operation_output = await self._finalize_execution(interceptors, execution_context) + if has_input_stream or event_deserializer is not None: + ${6C|} + else: + return operation_output + ${/hasEventStream} async def _handle_attempt( self, @@ -342,7 +375,8 @@ async def _handle_attempt( for interceptor in interceptors: interceptor.read_before_attempt(context) - """, pluginSymbol, transportRequest, transportResponse, errorSymbol, configSymbol); + """, pluginSymbol, transportRequest, transportResponse, errorSymbol, configSymbol, + writer.consumer(w -> context.protocolGenerator().wrapEventStream(context, w))); boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty(); writer.pushState(new ResolveIdentitySection()); @@ -604,6 +638,18 @@ async def _finalize_execution( return context.response """, transportRequest, transportResponse); writer.dedent(); + writer.popState(); + } + + private boolean hasEventStream() { + var streamIndex = EventStreamIndex.of(context.model()); + var topDownIndex = TopDownIndex.of(context.model()); + for (OperationShape operation : topDownIndex.getContainedOperations(context.settings().service())) { + if (streamIndex.getInputInfo(operation).isPresent() || streamIndex.getOutputInfo(operation).isPresent()) { + return true; + } + } + return false; } private void initializeHttpAuthParameters(PythonWriter writer) { @@ -649,40 +695,7 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { writer.openBlock("async def $L(self, input: $T, plugins: list[$T] | None = None) -> $T:", "", operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol, () -> { - writer.writeDocs(() -> { - var docs = operation.getTrait(DocumentationTrait.class) - .map(StringTrait::getValue) - .orElse(String.format("Invokes the %s operation.", operation.getId().getName())); - - var inputDocs = input.getTrait(DocumentationTrait.class) - .map(StringTrait::getValue) - .orElse("The operation's input."); - - writer.write(""" - $L - - :param input: $L - - :param plugins: A list of callables that modify the configuration dynamically. - Changes made by these plugins only apply for the duration of the operation - execution and will not affect any other operation invocations.""", docs, inputDocs); - }); - - var defaultPlugins = new LinkedHashSet(); - for (PythonIntegration integration : context.integrations()) { - for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) { - if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) { - runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); - } - } - } - writer.write(""" - operation_plugins: list[Plugin] = [ - $C - ] - if plugins: - operation_plugins.extend(plugins) - """, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); + writeSharedOperationInit(writer, operation, input); if (context.protocolGenerator() == null) { writer.write("raise NotImplementedError()"); @@ -704,16 +717,55 @@ private void generateOperation(PythonWriter writer, OperationShape operation) { }); } + private void writeSharedOperationInit(PythonWriter writer, OperationShape operation, Shape input) { + writer.writeDocs(() -> { + var docs = operation.getTrait(DocumentationTrait.class) + .map(StringTrait::getValue) + .orElse(String.format("Invokes the %s operation.", operation.getId().getName())); + + var inputDocs = input.getTrait(DocumentationTrait.class) + .map(StringTrait::getValue) + .orElse("The operation's input."); + + writer.write(""" + $L + + :param input: $L + + :param plugins: A list of callables that modify the configuration dynamically. + Changes made by these plugins only apply for the duration of the operation + execution and will not affect any other operation invocations.""", docs, inputDocs); + }); + + var defaultPlugins = new LinkedHashSet(); + for (PythonIntegration integration : context.integrations()) { + for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) { + if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) { + runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); + } + } + } + writer.write(""" + operation_plugins: list[Plugin] = [ + $C + ] + if plugins: + operation_plugins.extend(plugins) + """, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); + + } + private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) { writer.pushState(); writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM); - writer.addImports("smithy_event_stream.aio.interfaces", Set.of( - "EventStream", "InputEventStream", "OutputEventStream")); var operationSymbol = context.symbolProvider().toSymbol(operation); + writer.putContext("operationName", operationSymbol.getName()); var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings()); + writer.putContext("plugin", pluginSymbol); var input = context.model().expectShape(operation.getInputShape()); var inputSymbol = context.symbolProvider().toSymbol(input); + writer.putContext("input", inputSymbol); var eventStreamIndex = EventStreamIndex.of(context.model()); var inputStreamSymbol = eventStreamIndex.getInputInfo(operation) @@ -724,22 +776,107 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op var output = context.model().expectShape(operation.getOutputShape()); var outputSymbol = context.symbolProvider().toSymbol(output); + writer.putContext("output", outputSymbol); + var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation) .map(EventStreamInfo::getEventStreamTarget) .map(target -> context.symbolProvider().toSymbol(target)) .orElse(null); writer.putContext("outputStream", outputStreamSymbol); - writer.write(""" - async def $L(self, input: $T, plugins: list[$T] | None = None) -> EventStream[ - ${?inputStream}InputEventStream[${inputStream:T}]${/inputStream}\ - ${^inputStream}None${/inputStream}, - ${?outputStream}OutputEventStream[${outputStream:T}]${/outputStream}\ - ${^outputStream}None${/outputStream}, - $T - ]: - raise NotImplementedError() - """, operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol); + writer.putContext("hasProtocol", context.protocolGenerator() != null); + if (context.protocolGenerator() != null) { + var serSymbol = context.protocolGenerator().getSerializationFunction(context, operation); + writer.putContext("serSymbol", serSymbol); + var deserSymbol = context.protocolGenerator().getDeserializationFunction(context, operation); + writer.putContext("deserSymbol", deserSymbol); + } else { + writer.putContext("serSymbol", null); + writer.putContext("deserSymbol", null); + } + + if (inputStreamSymbol != null) { + if (outputStreamSymbol != null) { + writer.addImport("smithy_event_stream.aio.interfaces", "DuplexEventStream"); + writer.write(""" + async def ${operationName:L}( + self, + input: ${input:T}, + plugins: list[${plugin:T}] | None = None + ) -> DuplexEventStream[${inputStream:T}, ${outputStream:T}, ${output:T}]: + ${C|} + ${^hasProtocol} + raise NotImplementedError() + ${/hasProtocol} + ${?hasProtocol} + return await self._execute_operation( + input=input, + plugins=operation_plugins, + serialize=${serSymbol:T}, + deserialize=${deserSymbol:T}, + config=self._config, + operation_name=${operationName:S}, + has_input_stream=True, + event_deserializer=$T().deserialize, + event_response_deserializer=${output:T}, + ) # type: ignore + ${/hasProtocol} + """, + writer.consumer(w -> writeSharedOperationInit(w, operation, input)), + outputStreamSymbol.expectProperty(SymbolProperties.DESERIALIZER)); + } else { + writer.addImport("smithy_event_stream.aio.interfaces", "InputEventStream"); + writer.write(""" + async def ${operationName:L}( + self, + input: ${input:T}, + plugins: list[${plugin:T}] | None = None + ) -> InputEventStream[${inputStream:T}, ${output:T}]: + ${C|} + ${^hasProtocol} + raise NotImplementedError() + ${/hasProtocol} + ${?hasProtocol} + return await self._execute_operation( + input=input, + plugins=operation_plugins, + serialize=${serSymbol:T}, + deserialize=${deserSymbol:T}, + config=self._config, + operation_name=${operationName:S}, + has_input_stream=True, + ) # type: ignore + ${/hasProtocol} + """, writer.consumer(w -> writeSharedOperationInit(w, operation, input))); + } + } else { + writer.addImport("smithy_event_stream.aio.interfaces", "OutputEventStream"); + writer.write(""" + async def ${operationName:L}( + self, + input: ${input:T}, + plugins: list[${plugin:T}] | None = None + ) -> OutputEventStream[${outputStream:T}, ${output:T}]: + ${C|} + ${^hasProtocol} + raise NotImplementedError() + ${/hasProtocol} + ${?hasProtocol} + return await self._execute_operation( + input=input, + plugins=operation_plugins, + serialize=${serSymbol:T}, + deserialize=${deserSymbol:T}, + config=self._config, + operation_name=${operationName:S}, + event_deserializer=$T().deserialize, + event_response_deserializer=${output:T}, + ) # type: ignore + ${/hasProtocol} + """, + writer.consumer(w -> writeSharedOperationInit(w, operation, input)), + outputStreamSymbol.expectProperty(SymbolProperties.DESERIALIZER)); + } writer.popState(); } diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java index d1af2a4d5..954c207d0 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java @@ -64,6 +64,9 @@ public final class SmithyPythonDependency { false ); + /** + * Core interfaces for event streams. + */ public static final PythonDependency SMITHY_EVENT_STREAM = new PythonDependency( "smithy_event_stream", "==0.0.1", @@ -71,6 +74,16 @@ public final class SmithyPythonDependency { false ); + /** + * EventStream implementations for application/vnd.amazon.eventstream. + */ + public static final PythonDependency AWS_EVENT_STREAM = new PythonDependency( + "aws_event_stream", + "==0.0.1", + Type.DEPENDENCY, + false + ); + /** * testing framework used in generated functional tests. */ diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/ProtocolGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/ProtocolGenerator.java index 5349ee6c6..78257ec2e 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/ProtocolGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/ProtocolGenerator.java @@ -22,6 +22,7 @@ import software.amazon.smithy.model.shapes.ToShapeId; import software.amazon.smithy.python.codegen.ApplicationProtocol; import software.amazon.smithy.python.codegen.GenerationContext; +import software.amazon.smithy.python.codegen.PythonWriter; import software.amazon.smithy.utils.CaseUtils; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -167,4 +168,25 @@ default void generateSharedDeserializerComponents(GenerationContext context) { */ default void generateProtocolTests(GenerationContext context) { } + + /** + * Generates the code to wrap an operation output into an event stream. + * + *

Important context variables are: + *

    + *
  • execution_context - Has the context, including the transport input and output.
  • + *
  • operation_output - The deserialized operation output.
  • + *
  • has_input_stream - Whether or not there is an input stream.
  • + *
  • event_deserializer - The deserialize method for output events, or None for no output stream.
  • + *
  • event_response_deserializer - A DeserializeableShape representing the operation's output shape, + * or None for no output stream. This is used when the operation sends the initial response over the + * event stream. + *
  • + *
+ * + * @param context Generation context. + * @param writer The writer to write to. + */ + default void wrapEventStream(GenerationContext context, PythonWriter writer) { + } } diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java index 7e566c6a6..e40e16ce3 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java @@ -167,8 +167,9 @@ protected void serializePayloadBody( // or a blob, meaning it's some potentially big collection of bytes. // See also: https://smithy.io/2.0/spec/streaming.html#smithy-api-streaming-trait if (payloadBinding.getMember().getMemberTrait(context.model(), StreamingTrait.class).isPresent()) { - // TODO: support event streams if (target.isUnionShape()) { + writer.addImport("smithy_core.aio.types", "AsyncBytesProvider"); + writer.write("body = AsyncBytesProvider()"); return; } @@ -306,7 +307,6 @@ protected void deserializePayloadBody( Shape operationOrError, HttpBinding payloadBinding ) { - writer.addDependency(SmithyPythonDependency.SMITHY_JSON); writer.addDependency(SmithyPythonDependency.SMITHY_CORE); writer.addImport("smithy_json", "JSONCodec"); @@ -379,4 +379,42 @@ protected void resolveErrorCodeAndMessage(GenerationContext context, } writer.write(")"); } + + @Override + public void wrapEventStream(GenerationContext context, PythonWriter writer) { + writer.addDependency(SmithyPythonDependency.SMITHY_JSON); + writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM); + writer.addDependency(SmithyPythonDependency.SMITHY_CORE); + writer.addImports("aws_event_stream.aio", Set.of( + "AWSDuplexEventStream", "AWSInputEventStream", "AWSOutputEventStream")); + writer.addImport("smithy_json", "JSONCodec"); + writer.addImport("smithy_core.types", "TimestampFormat"); + writer.addStdlibImport("typing", "Any"); + + writer.write(""" + codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) + if has_input_stream: + if event_deserializer is not None: + return AWSDuplexEventStream[Any, Any, Any]( + payload_codec=codec, + initial_response=operation_output, + async_writer=execution_context.transport_request.body, # type: ignore + async_reader=execution_context.transport_response.body, # type: ignore + deserializer=event_deserializer, # type: ignore + ) + else: + return AWSInputEventStream[Any, Any]( + payload_codec=codec, + initial_response=operation_output, + async_writer=execution_context.transport_request.body, # type: ignore + ) + else: + return AWSOutputEventStream[Any, Any]( + payload_codec=codec, + initial_response=operation_output, + async_reader=execution_context.transport_response.body, # type: ignore + deserializer=event_deserializer, # type: ignore + ) + """); + } }