From e1d782447d468a4e4fcd737a05c3c9e566ec7a2a Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Thu, 13 Mar 2025 11:49:52 +0100 Subject: [PATCH 1/3] Return immediately on streaming inputs --- .../aws/codegen/AwsAuthIntegration.java | 12 +- .../python/aws/codegen/AwsConfiguration.java | 3 +- .../aws/codegen/AwsPythonDependency.java | 3 +- ...sStandardRegionalEndpointsIntegration.java | 5 +- .../python/codegen/ClientGenerator.java | 147 +++++++++++++----- .../codegen/generators/ProtocolGenerator.java | 24 +-- .../RestJsonProtocolGenerator.java | 76 +++++---- .../src/aws_event_stream/aio/__init__.py | 72 +++------ 8 files changed, 184 insertions(+), 158 deletions(-) diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java index 707db1930..1a54d2727 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java @@ -4,13 +4,12 @@ */ package software.amazon.smithy.python.aws.codegen; -import java.util.Collections; +import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION; + import java.util.List; import software.amazon.smithy.aws.traits.auth.SigV4Trait; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.model.shapes.ShapeId; -import software.amazon.smithy.model.traits.HttpApiKeyAuthTrait; -import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION; import software.amazon.smithy.python.codegen.ApplicationProtocol; import software.amazon.smithy.python.codegen.CodegenUtils; import software.amazon.smithy.python.codegen.ConfigProperty; @@ -61,8 +60,7 @@ public List getClientPlugins(GenerationContext context) { .build()) .addConfigProperty(REGION) .authScheme(new Sigv4AuthScheme()) - .build() - ); + .build()); } @Override @@ -129,11 +127,9 @@ public List getAuthProperties() { .source(DerivedProperty.Source.CONFIG) .type(Symbol.builder().name("str").build()) .sourcePropertyName("region") - .build() - ); + .build()); } - @Override public Symbol getAuthOptionGenerator(GenerationContext context) { var resolver = CodegenUtils.getHttpAuthSchemeResolverSymbol(context.settings()); diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsConfiguration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsConfiguration.java index 7a349772e..5d3ba035f 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsConfiguration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsConfiguration.java @@ -13,8 +13,7 @@ */ @SmithyUnstableApi public final class AwsConfiguration { - private AwsConfiguration() { - } + private AwsConfiguration() {} public static final ConfigProperty REGION = ConfigProperty.builder() .name("region") diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsPythonDependency.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsPythonDependency.java index 738d66072..d1959628b 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsPythonDependency.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsPythonDependency.java @@ -13,8 +13,7 @@ @SmithyUnstableApi public class AwsPythonDependency { - private AwsPythonDependency() { - } + private AwsPythonDependency() {} /** * The core aws smithy runtime python package. diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsStandardRegionalEndpointsIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsStandardRegionalEndpointsIntegration.java index 3b5957cbd..886abcd76 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsStandardRegionalEndpointsIntegration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsStandardRegionalEndpointsIntegration.java @@ -4,12 +4,11 @@ */ package software.amazon.smithy.python.aws.codegen; +import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION; + import java.util.List; import software.amazon.smithy.aws.traits.ServiceTrait; -import software.amazon.smithy.codegen.core.Symbol; -import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION; import software.amazon.smithy.python.codegen.CodegenUtils; -import software.amazon.smithy.python.codegen.ConfigProperty; import software.amazon.smithy.python.codegen.GenerationContext; import software.amazon.smithy.python.codegen.integrations.PythonIntegration; import software.amazon.smithy.python.codegen.integrations.RuntimeClientPlugin; diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 6a728f5e1..e4d3ea49f 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -120,10 +120,7 @@ private void generateOperationExecutor(PythonWriter writer) { var hasStreaming = hasEventStream(); writer.putContext("hasEventStream", hasStreaming); if (hasStreaming) { - writer.addImports("smithy_core.deserializers", - Set.of( - "ShapeDeserializer", - "DeserializeableShape")); + writer.addImport("smithy_core.deserializers", "ShapeDeserializer"); writer.addStdlibImport("typing", "Any"); } @@ -137,7 +134,8 @@ private void generateOperationExecutor(PythonWriter writer) { writer.addStdlibImport("typing", "Awaitable"); writer.addStdlibImport("typing", "cast"); writer.addStdlibImport("copy", "deepcopy"); - writer.addStdlibImport("asyncio", "sleep"); + writer.addStdlibImport("asyncio"); + writer.addStdlibImports("asyncio", Set.of("sleep", "Future")); writer.addDependency(SmithyPythonDependency.SMITHY_CORE); writer.addImport("smithy_core.exceptions", "SmithyRetryException"); @@ -187,6 +185,75 @@ def _classify_error( """); writer.dedent(); + if (hasStreaming) { + writer.addStdlibImports("typing", Set.of("Any", "Awaitable")); + writer.addStdlibImport("asyncio"); + writer.write( + """ + async def _input_stream( + self, + input: Input, + plugins: list[$1T], + serialize: Callable[[Input, $4T], Awaitable[$2T]], + deserialize: Callable[[$3T, $4T], Awaitable[Output]], + config: $4T, + operation_name: str, + ) -> Any: + request_future = Future[$2T]() + awaitable_output = asyncio.create_task(self._execute_operation( + input, plugins, serialize, deserialize, config, operation_name, + request_future=request_future + )) + transport_request = await request_future + ${5C|} + + async def _output_stream( + self, + input: Input, + plugins: list[$1T], + serialize: Callable[[Input, $4T], Awaitable[$2T]], + deserialize: Callable[[$3T, $4T], Awaitable[Output]], + config: $4T, + operation_name: str, + event_deserializer: Callable[[ShapeDeserializer], Any], + ) -> Any: + response_future = Future[$3T]() + output = await self._execute_operation( + input, plugins, serialize, deserialize, config, operation_name, + response_future=response_future + ) + transport_response = await response_future + ${6C|} + + async def _duplex_stream( + self, + input: Input, + plugins: list[$1T], + serialize: Callable[[Input, $4T], Awaitable[$2T]], + deserialize: Callable[[$3T, $4T], Awaitable[Output]], + config: $4T, + operation_name: str, + event_deserializer: Callable[[ShapeDeserializer], Any], + ) -> Any: + request_future = Future[$2T]() + response_future = Future[$3T]() + awaitable_output = asyncio.create_task(self._execute_operation( + input, plugins, serialize, deserialize, config, operation_name, + request_future=request_future, + response_future=response_future + )) + transport_request = await request_future + ${7C|} + """, + pluginSymbol, + transportRequest, + transportResponse, + configSymbol, + writer.consumer(w -> context.protocolGenerator().wrapInputStream(context, w)), + writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w)), + writer.consumer(w -> context.protocolGenerator().wrapDuplexStream(context, w))); + } + writer.write( """ async def _execute_operation( @@ -197,25 +264,25 @@ async def _execute_operation( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, - ${?hasEventStream} - has_input_stream: bool = False, - event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, - event_response_deserializer: DeserializeableShape | None = None, - ${/hasEventStream} + request_future: Future[$2T] | None = None, + response_future: Future[$3T] | None = None, ) -> Output: try: return await self._handle_execution( input, plugins, serialize, deserialize, config, operation_name, - ${?hasEventStream} - has_input_stream, event_deserializer, event_response_deserializer, - ${/hasEventStream} + request_future, response_future, ) except Exception as e: + if request_future is not None and not request_future.done: + request_future.set_exception($4T(e)) + if response_future is not None and not response_future.done: + response_future.set_exception($4T(e)) + # Make sure every exception that we throw is an instance of $4T so # customers can reliably catch everything we throw. if not isinstance(e, $4T): raise $4T(e) from e - raise e + raise async def _handle_execution( self, @@ -225,11 +292,8 @@ async def _handle_execution( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, - ${?hasEventStream} - has_input_stream: bool = False, - event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, - event_response_deserializer: DeserializeableShape | None = None, - ${/hasEventStream} + request_future: Future[$2T] | None, + response_future: Future[$3T] | None, ) -> Output: logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input) context: InterceptorContext[Input, None, None, None] = InterceptorContext( @@ -307,6 +371,7 @@ async def _handle_execution( context_with_transport_request.copy(), config, operation_name, + request_future, ) # We perform this type-ignored re-assignment because `context` needs @@ -342,6 +407,10 @@ await seek(0) else: # Step 8: Invoke record_success retry_strategy.record_success(token=retry_token) + if response_future is not None: + response_future.set_result( + context_with_response.response, # type: ignore + ) break except Exception as e: if context.response is not None: @@ -355,16 +424,7 @@ await seek(0) execution_context = cast( InterceptorContext[Input, Output, $2T | None, $3T | None], context ) - ${^hasEventStream} return await self._finalize_execution(interceptors, execution_context) - ${/hasEventStream} - ${?hasEventStream} - operation_output = await self._finalize_execution(interceptors, execution_context) - if has_input_stream or event_deserializer is not None: - ${6C|} - else: - return operation_output - ${/hasEventStream} async def _handle_attempt( self, @@ -373,6 +433,7 @@ async def _handle_attempt( context: InterceptorContext[Input, None, $2T, None], config: $5T, operation_name: str, + request_future: Future[$2T] | None, ) -> InterceptorContext[Input, Output, $2T, $3T | None]: try: # assert config.interceptors is not None @@ -385,8 +446,7 @@ async def _handle_attempt( transportRequest, transportResponse, errorSymbol, - configSymbol, - writer.consumer(w -> context.protocolGenerator().wrapEventStream(context, w))); + configSymbol); boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty(); writer.pushState(new ResolveIdentitySection()); @@ -533,10 +593,19 @@ async def _handle_attempt( ) logger.debug("HTTP request config: %s", request_config) logger.debug("Sending HTTP request: %s", context_with_response.transport_request) - context_with_response._transport_response = await config.http_client.send( - request=context_with_response.transport_request, - request_config=request_config, - ) + + if request_future is not None: + response_task = asyncio.create_task(config.http_client.send( + request=context_with_response.transport_request, + request_config=request_config, + )) + request_future.set_result(context_with_response.transport_request) + context_with_response._transport_response = await response_task + else: + context_with_response._transport_response = await config.http_client.send( + request=context_with_response.transport_request, + request_config=request_config, + ) logger.debug("Received HTTP response: %s", context_with_response.transport_response) """, transportRequest, transportResponse); @@ -834,16 +903,14 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op raise NotImplementedError() ${/hasProtocol} ${?hasProtocol} - return await self._execute_operation( + return await self._duplex_stream( input=input, plugins=operation_plugins, serialize=${serSymbol:T}, deserialize=${deserSymbol:T}, config=self._config, operation_name=${operationName:S}, - has_input_stream=True, event_deserializer=$T().deserialize, - event_response_deserializer=${output:T}, ) # type: ignore ${/hasProtocol} """, @@ -862,14 +929,13 @@ raise NotImplementedError() raise NotImplementedError() ${/hasProtocol} ${?hasProtocol} - return await self._execute_operation( + return await self._input_stream( input=input, plugins=operation_plugins, serialize=${serSymbol:T}, deserialize=${deserSymbol:T}, config=self._config, operation_name=${operationName:S}, - has_input_stream=True, ) # type: ignore ${/hasProtocol} """, writer.consumer(w -> writeSharedOperationInit(w, operation, input))); @@ -887,7 +953,7 @@ raise NotImplementedError() raise NotImplementedError() ${/hasProtocol} ${?hasProtocol} - return await self._execute_operation( + return await self._output_stream( input=input, plugins=operation_plugins, serialize=${serSymbol:T}, @@ -895,7 +961,6 @@ raise NotImplementedError() config=self._config, operation_name=${operationName:S}, event_deserializer=$T().deserialize, - event_response_deserializer=${output:T}, ) # type: ignore ${/hasProtocol} """, diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ProtocolGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ProtocolGenerator.java index 664c08c14..0006e498f 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ProtocolGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ProtocolGenerator.java @@ -154,23 +154,9 @@ default void generateSharedDeserializerComponents(GenerationContext context) {} */ default void generateProtocolTests(GenerationContext context) {} - /** - * Generates the code to wrap an operation output into an event stream. - * - *

Important context variables are: - *

    - *
  • execution_context - Has the context, including the transport input and output.
  • - *
  • operation_output - The deserialized operation output.
  • - *
  • has_input_stream - Whether or not there is an input stream.
  • - *
  • event_deserializer - The deserialize method for output events, or None for no output stream.
  • - *
  • event_response_deserializer - A DeserializeableShape representing the operation's output shape, - * or None for no output stream. This is used when the operation sends the initial response over the - * event stream. - *
  • - *
- * - * @param context Generation context. - * @param writer The writer to write to. - */ - default void wrapEventStream(GenerationContext context, PythonWriter writer) {} + default void wrapInputStream(GenerationContext context, PythonWriter writer) {} + + default void wrapOutputStream(GenerationContext context, PythonWriter writer) {} + + default void wrapDuplexStream(GenerationContext context, PythonWriter writer) {} } diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java index ca994e184..0424959b3 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java @@ -390,48 +390,64 @@ protected void resolveErrorCodeAndMessage( } @Override - public void wrapEventStream(GenerationContext context, PythonWriter writer) { + public void wrapInputStream(GenerationContext context, PythonWriter writer) { writer.addDependency(SmithyPythonDependency.SMITHY_JSON); writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM); - writer.addDependency(SmithyPythonDependency.SMITHY_CORE); - writer.addImports("aws_event_stream.aio", - Set.of( - "AWSDuplexEventStream", - "AWSInputEventStream", - "AWSOutputEventStream")); writer.addImport("smithy_json", "JSONCodec"); writer.addImport("smithy_core.aio.types", "AsyncBytesReader"); writer.addImport("smithy_core.types", "TimestampFormat"); - writer.addStdlibImport("typing", "Any"); + writer.addImport("aws_event_stream.aio", "AWSInputEventStream"); + writer.write( + """ + codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) + return AWSInputEventStream[Any, Any]( + payload_codec=codec, + awaitable_output=awaitable_output, + async_writer=transport_request.body, # type: ignore + ) + """); + } - writer.write(""" - codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) - if has_input_stream: - if event_deserializer is not None: - return AWSDuplexEventStream[Any, Any, Any]( + @Override + public void wrapOutputStream(GenerationContext context, PythonWriter writer) { + writer.addDependency(SmithyPythonDependency.SMITHY_JSON); + writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM); + writer.addImport("smithy_json", "JSONCodec"); + writer.addImport("smithy_core.aio.types", "AsyncBytesReader"); + writer.addImport("smithy_core.types", "TimestampFormat"); + writer.addImport("aws_event_stream.aio", "AWSOutputEventStream"); + writer.write( + """ + codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) + return AWSOutputEventStream[Any, Any]( payload_codec=codec, - initial_response=operation_output, - async_writer=execution_context.transport_request.body, # type: ignore + initial_response=output, async_reader=AsyncBytesReader( - execution_context.transport_response.body # type: ignore + transport_response.body # type: ignore ), deserializer=event_deserializer, # type: ignore ) - else: - return AWSInputEventStream[Any, Any]( + """); + } + + @Override + public void wrapDuplexStream(GenerationContext context, PythonWriter writer) { + writer.addDependency(SmithyPythonDependency.SMITHY_JSON); + writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM); + writer.addImport("smithy_json", "JSONCodec"); + writer.addImport("smithy_core.aio.types", "AsyncBytesReader"); + writer.addImport("smithy_core.types", "TimestampFormat"); + writer.addImport("aws_event_stream.aio", "AWSDuplexEventStream"); + writer.write( + """ + codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) + return AWSDuplexEventStream[Any, Any, Any]( payload_codec=codec, - initial_response=operation_output, - async_writer=execution_context.transport_request.body, # type: ignore + async_writer=transport_request.body, # type: ignore + awaitable_output=awaitable_output, + awaitable_response=response_future, + deserializer=event_deserializer, # type: ignore ) - else: - return AWSOutputEventStream[Any, Any]( - payload_codec=codec, - initial_response=operation_output, - async_reader=AsyncBytesReader( - execution_context.transport_response.body # type: ignore - ), - deserializer=event_deserializer, # type: ignore - ) - """); + """); } } diff --git a/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py b/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py index 5057ea688..76943f52e 100644 --- a/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py +++ b/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py @@ -1,10 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import asyncio from collections.abc import Callable -from typing import Self +from typing import Self, Awaitable -from smithy_core.aio.interfaces import AsyncByteStream, AsyncWriter +from smithy_core.aio.interfaces import AsyncByteStream, AsyncWriter, Response +from smithy_core.aio.types import AsyncBytesReader from smithy_core.codecs import Codec from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer from smithy_core.serializers import SerializeableShape @@ -33,8 +33,8 @@ def __init__( payload_codec: Codec, async_writer: AsyncWriter, deserializer: Callable[[ShapeDeserializer], O], - async_reader: AsyncByteStream | None = None, - initial_response: R | None = None, + awaitable_response: Awaitable[Response], + awaitable_output: Awaitable[R], deserializeable_response: type[R] | None = None, signer: Signer | None = None, is_client_mode: bool = True, @@ -68,36 +68,14 @@ def __init__( self._deserializer = deserializer self._payload_codec = payload_codec self._is_client_mode = is_client_mode + self._deserializeable_response = deserializeable_response - # Create a future to allow awaiting the reader - loop = asyncio.get_event_loop() - self._reader_future: asyncio.Future[AsyncByteStream] = loop.create_future() - if async_reader is not None: - self._reader_future.set_result(async_reader) - - # Create a future to allow awaiting the initial response - self._response = initial_response - self._deserializerable_response = deserializeable_response - self._response_future: asyncio.Future[R] = loop.create_future() - - @property - def response(self) -> R | None: - return self._response - - @response.setter - def response(self, value: R) -> None: - self._response_future.set_result(value) - self._response = value - - def set_reader(self, value: AsyncByteStream) -> None: - """Sets the object to read events from. - - :param value: An async readable object to read event bytes from. - """ - self._reader_future.set_result(value) + self._awaitable_response = awaitable_response + self._awaitable_output = awaitable_output + self.response: R | None = None async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]: - async_reader = await self._reader_future + async_reader = AsyncBytesReader((await self._awaitable_response).body) if self.output_stream is None: self.output_stream = _AWSEventReceiver[O]( payload_codec=self._payload_codec, @@ -107,13 +85,13 @@ async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]: ) if self.response is None: - if self._deserializerable_response is None: - initial_response = await self._response_future + if self._deserializeable_response is None: + initial_response = await self._awaitable_output else: initial_response_stream = _AWSEventReceiver( payload_codec=self._payload_codec, source=async_reader, - deserializer=self._deserializerable_response.deserialize, + deserializer=self._deserializeable_response.deserialize, is_client_mode=self._is_client_mode, ) initial_response = await initial_response_stream.receive() @@ -133,7 +111,7 @@ def __init__( self, payload_codec: Codec, async_writer: AsyncWriter, - initial_response: R | None = None, + awaitable_output: Awaitable[R], signer: Signer | None = None, is_client_mode: bool = True, ) -> None: @@ -147,13 +125,8 @@ def __init__( :param is_client_mode: Whether the stream is being constructed for a client or server implementation. """ - self._response = initial_response - - # Create a future to allow awaiting the initial response. - loop = asyncio.get_event_loop() - self._response_future: asyncio.Future[R] = loop.create_future() - if initial_response is not None: - self._response_future.set_result(initial_response) + self.response: R | None = None + self._awaitable_response = awaitable_output self.input_stream = _AWSEventPublisher( payload_codec=payload_codec, @@ -162,17 +135,10 @@ def __init__( is_client_mode=is_client_mode, ) - @property - def response(self) -> R | None: - return self._response - - @response.setter - def response(self, value: R) -> None: - self._response_future.set_result(value) - self._response = value - async def await_output(self) -> R: - return await self._response_future + if self.response is None: + self.response = await self._awaitable_response + return self.response class AWSOutputEventStream[O: DeserializeableShape, R: DeserializeableShape]( From ce663f381e6ce59a61a1117ac2c60b854689546d Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Thu, 13 Mar 2025 18:53:35 +0100 Subject: [PATCH 2/3] send correct future --- .../software/amazon/smithy/python/codegen/ClientGenerator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index e4d3ea49f..4513a8fea 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -409,7 +409,7 @@ await seek(0) retry_strategy.record_success(token=retry_token) if response_future is not None: response_future.set_result( - context_with_response.response, # type: ignore + context_with_response.transport_response, # type: ignore ) break except Exception as e: From d0c5fbef983c43129055cc413b2dee427c18f646 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Thu, 13 Mar 2025 19:10:38 +0100 Subject: [PATCH 3/3] give context to input streams --- .../smithy/python/codegen/ClientGenerator.java | 18 +++++++++--------- .../RestJsonProtocolGenerator.java | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 4513a8fea..399b6c479 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -199,12 +199,12 @@ async def _input_stream( config: $4T, operation_name: str, ) -> Any: - request_future = Future[$2T]() + request_future = Future[InterceptorContext[Any, Any, $2T, Any]]() awaitable_output = asyncio.create_task(self._execute_operation( input, plugins, serialize, deserialize, config, operation_name, request_future=request_future )) - transport_request = await request_future + request_context = await request_future ${5C|} async def _output_stream( @@ -235,14 +235,14 @@ async def _duplex_stream( operation_name: str, event_deserializer: Callable[[ShapeDeserializer], Any], ) -> Any: - request_future = Future[$2T]() + request_future = Future[InterceptorContext[Any, Any, $2T, Any]]() response_future = Future[$3T]() awaitable_output = asyncio.create_task(self._execute_operation( input, plugins, serialize, deserialize, config, operation_name, request_future=request_future, response_future=response_future )) - transport_request = await request_future + request_context = await request_future ${7C|} """, pluginSymbol, @@ -264,7 +264,7 @@ async def _execute_operation( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, - request_future: Future[$2T] | None = None, + request_future: Future[InterceptorContext[Any, Any, $2T, Any]] | None = None, response_future: Future[$3T] | None = None, ) -> Output: try: @@ -292,7 +292,7 @@ async def _handle_execution( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, - request_future: Future[$2T] | None, + request_future: Future[InterceptorContext[Any, Any, $2T, Any]] | None, response_future: Future[$3T] | None, ) -> Output: logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input) @@ -409,7 +409,7 @@ await seek(0) retry_strategy.record_success(token=retry_token) if response_future is not None: response_future.set_result( - context_with_response.transport_response, # type: ignore + context_with_response.transport_response # type: ignore ) break except Exception as e: @@ -433,7 +433,7 @@ async def _handle_attempt( context: InterceptorContext[Input, None, $2T, None], config: $5T, operation_name: str, - request_future: Future[$2T] | None, + request_future: Future[InterceptorContext[Any, Any, $2T, Any]] | None, ) -> InterceptorContext[Input, Output, $2T, $3T | None]: try: # assert config.interceptors is not None @@ -599,7 +599,7 @@ async def _handle_attempt( request=context_with_response.transport_request, request_config=request_config, )) - request_future.set_result(context_with_response.transport_request) + request_future.set_result(context_with_response) context_with_response._transport_response = await response_task else: context_with_response._transport_response = await config.http_client.send( diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java index 0424959b3..ef55e15d7 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java @@ -403,7 +403,7 @@ public void wrapInputStream(GenerationContext context, PythonWriter writer) { return AWSInputEventStream[Any, Any]( payload_codec=codec, awaitable_output=awaitable_output, - async_writer=transport_request.body, # type: ignore + async_writer=request_context.transport_request.body, # type: ignore ) """); } @@ -443,7 +443,7 @@ public void wrapDuplexStream(GenerationContext context, PythonWriter writer) { codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) return AWSDuplexEventStream[Any, Any, Any]( payload_codec=codec, - async_writer=transport_request.body, # type: ignore + async_writer=request_context.transport_request.body, # type: ignore awaitable_output=awaitable_output, awaitable_response=response_future, deserializer=event_deserializer, # type: ignore