diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java index 707db1930..1a54d2727 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java @@ -4,13 +4,12 @@ */ package software.amazon.smithy.python.aws.codegen; -import java.util.Collections; +import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION; + import java.util.List; import software.amazon.smithy.aws.traits.auth.SigV4Trait; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.model.shapes.ShapeId; -import software.amazon.smithy.model.traits.HttpApiKeyAuthTrait; -import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION; import software.amazon.smithy.python.codegen.ApplicationProtocol; import software.amazon.smithy.python.codegen.CodegenUtils; import software.amazon.smithy.python.codegen.ConfigProperty; @@ -61,8 +60,7 @@ public List getClientPlugins(GenerationContext context) { .build()) .addConfigProperty(REGION) .authScheme(new Sigv4AuthScheme()) - .build() - ); + .build()); } @Override @@ -129,11 +127,9 @@ public List getAuthProperties() { .source(DerivedProperty.Source.CONFIG) .type(Symbol.builder().name("str").build()) .sourcePropertyName("region") - .build() - ); + .build()); } - @Override public Symbol getAuthOptionGenerator(GenerationContext context) { var resolver = CodegenUtils.getHttpAuthSchemeResolverSymbol(context.settings()); diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsConfiguration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsConfiguration.java index 7a349772e..5d3ba035f 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsConfiguration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsConfiguration.java @@ -13,8 +13,7 @@ */ @SmithyUnstableApi public final class AwsConfiguration { - private AwsConfiguration() { - } + private AwsConfiguration() {} public static final ConfigProperty REGION = ConfigProperty.builder() .name("region") diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsPythonDependency.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsPythonDependency.java index 738d66072..d1959628b 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsPythonDependency.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsPythonDependency.java @@ -13,8 +13,7 @@ @SmithyUnstableApi public class AwsPythonDependency { - private AwsPythonDependency() { - } + private AwsPythonDependency() {} /** * The core aws smithy runtime python package. diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsStandardRegionalEndpointsIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsStandardRegionalEndpointsIntegration.java index 3b5957cbd..886abcd76 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsStandardRegionalEndpointsIntegration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsStandardRegionalEndpointsIntegration.java @@ -4,12 +4,11 @@ */ package software.amazon.smithy.python.aws.codegen; +import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION; + import java.util.List; import software.amazon.smithy.aws.traits.ServiceTrait; -import software.amazon.smithy.codegen.core.Symbol; -import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION; import software.amazon.smithy.python.codegen.CodegenUtils; -import software.amazon.smithy.python.codegen.ConfigProperty; import software.amazon.smithy.python.codegen.GenerationContext; import software.amazon.smithy.python.codegen.integrations.PythonIntegration; import software.amazon.smithy.python.codegen.integrations.RuntimeClientPlugin; diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 6a728f5e1..399b6c479 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -120,10 +120,7 @@ private void generateOperationExecutor(PythonWriter writer) { var hasStreaming = hasEventStream(); writer.putContext("hasEventStream", hasStreaming); if (hasStreaming) { - writer.addImports("smithy_core.deserializers", - Set.of( - "ShapeDeserializer", - "DeserializeableShape")); + writer.addImport("smithy_core.deserializers", "ShapeDeserializer"); writer.addStdlibImport("typing", "Any"); } @@ -137,7 +134,8 @@ private void generateOperationExecutor(PythonWriter writer) { writer.addStdlibImport("typing", "Awaitable"); writer.addStdlibImport("typing", "cast"); writer.addStdlibImport("copy", "deepcopy"); - writer.addStdlibImport("asyncio", "sleep"); + writer.addStdlibImport("asyncio"); + writer.addStdlibImports("asyncio", Set.of("sleep", "Future")); writer.addDependency(SmithyPythonDependency.SMITHY_CORE); writer.addImport("smithy_core.exceptions", "SmithyRetryException"); @@ -187,6 +185,75 @@ def _classify_error( """); writer.dedent(); + if (hasStreaming) { + writer.addStdlibImports("typing", Set.of("Any", "Awaitable")); + writer.addStdlibImport("asyncio"); + writer.write( + """ + async def _input_stream( + self, + input: Input, + plugins: list[$1T], + serialize: Callable[[Input, $4T], Awaitable[$2T]], + deserialize: Callable[[$3T, $4T], Awaitable[Output]], + config: $4T, + operation_name: str, + ) -> Any: + request_future = Future[InterceptorContext[Any, Any, $2T, Any]]() + awaitable_output = asyncio.create_task(self._execute_operation( + input, plugins, serialize, deserialize, config, operation_name, + request_future=request_future + )) + request_context = await request_future + ${5C|} + + async def _output_stream( + self, + input: Input, + plugins: list[$1T], + serialize: Callable[[Input, $4T], Awaitable[$2T]], + deserialize: Callable[[$3T, $4T], Awaitable[Output]], + config: $4T, + operation_name: str, + event_deserializer: Callable[[ShapeDeserializer], Any], + ) -> Any: + response_future = Future[$3T]() + output = await self._execute_operation( + input, plugins, serialize, deserialize, config, operation_name, + response_future=response_future + ) + transport_response = await response_future + ${6C|} + + async def _duplex_stream( + self, + input: Input, + plugins: list[$1T], + serialize: Callable[[Input, $4T], Awaitable[$2T]], + deserialize: Callable[[$3T, $4T], Awaitable[Output]], + config: $4T, + operation_name: str, + event_deserializer: Callable[[ShapeDeserializer], Any], + ) -> Any: + request_future = Future[InterceptorContext[Any, Any, $2T, Any]]() + response_future = Future[$3T]() + awaitable_output = asyncio.create_task(self._execute_operation( + input, plugins, serialize, deserialize, config, operation_name, + request_future=request_future, + response_future=response_future + )) + request_context = await request_future + ${7C|} + """, + pluginSymbol, + transportRequest, + transportResponse, + configSymbol, + writer.consumer(w -> context.protocolGenerator().wrapInputStream(context, w)), + writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w)), + writer.consumer(w -> context.protocolGenerator().wrapDuplexStream(context, w))); + } + writer.write( """ async def _execute_operation( @@ -197,25 +264,25 @@ async def _execute_operation( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, - ${?hasEventStream} - has_input_stream: bool = False, - event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, - event_response_deserializer: DeserializeableShape | None = None, - ${/hasEventStream} + request_future: Future[InterceptorContext[Any, Any, $2T, Any]] | None = None, + response_future: Future[$3T] | None = None, ) -> Output: try: return await self._handle_execution( input, plugins, serialize, deserialize, config, operation_name, - ${?hasEventStream} - has_input_stream, event_deserializer, event_response_deserializer, - ${/hasEventStream} + request_future, response_future, ) except Exception as e: + if request_future is not None and not request_future.done: + request_future.set_exception($4T(e)) + if response_future is not None and not response_future.done: + response_future.set_exception($4T(e)) + # Make sure every exception that we throw is an instance of $4T so # customers can reliably catch everything we throw. if not isinstance(e, $4T): raise $4T(e) from e - raise e + raise async def _handle_execution( self, @@ -225,11 +292,8 @@ async def _handle_execution( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, - ${?hasEventStream} - has_input_stream: bool = False, - event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, - event_response_deserializer: DeserializeableShape | None = None, - ${/hasEventStream} + request_future: Future[InterceptorContext[Any, Any, $2T, Any]] | None, + response_future: Future[$3T] | None, ) -> Output: logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input) context: InterceptorContext[Input, None, None, None] = InterceptorContext( @@ -307,6 +371,7 @@ async def _handle_execution( context_with_transport_request.copy(), config, operation_name, + request_future, ) # We perform this type-ignored re-assignment because `context` needs @@ -342,6 +407,10 @@ await seek(0) else: # Step 8: Invoke record_success retry_strategy.record_success(token=retry_token) + if response_future is not None: + response_future.set_result( + context_with_response.transport_response # type: ignore + ) break except Exception as e: if context.response is not None: @@ -355,16 +424,7 @@ await seek(0) execution_context = cast( InterceptorContext[Input, Output, $2T | None, $3T | None], context ) - ${^hasEventStream} return await self._finalize_execution(interceptors, execution_context) - ${/hasEventStream} - ${?hasEventStream} - operation_output = await self._finalize_execution(interceptors, execution_context) - if has_input_stream or event_deserializer is not None: - ${6C|} - else: - return operation_output - ${/hasEventStream} async def _handle_attempt( self, @@ -373,6 +433,7 @@ async def _handle_attempt( context: InterceptorContext[Input, None, $2T, None], config: $5T, operation_name: str, + request_future: Future[InterceptorContext[Any, Any, $2T, Any]] | None, ) -> InterceptorContext[Input, Output, $2T, $3T | None]: try: # assert config.interceptors is not None @@ -385,8 +446,7 @@ async def _handle_attempt( transportRequest, transportResponse, errorSymbol, - configSymbol, - writer.consumer(w -> context.protocolGenerator().wrapEventStream(context, w))); + configSymbol); boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty(); writer.pushState(new ResolveIdentitySection()); @@ -533,10 +593,19 @@ async def _handle_attempt( ) logger.debug("HTTP request config: %s", request_config) logger.debug("Sending HTTP request: %s", context_with_response.transport_request) - context_with_response._transport_response = await config.http_client.send( - request=context_with_response.transport_request, - request_config=request_config, - ) + + if request_future is not None: + response_task = asyncio.create_task(config.http_client.send( + request=context_with_response.transport_request, + request_config=request_config, + )) + request_future.set_result(context_with_response) + context_with_response._transport_response = await response_task + else: + context_with_response._transport_response = await config.http_client.send( + request=context_with_response.transport_request, + request_config=request_config, + ) logger.debug("Received HTTP response: %s", context_with_response.transport_response) """, transportRequest, transportResponse); @@ -834,16 +903,14 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op raise NotImplementedError() ${/hasProtocol} ${?hasProtocol} - return await self._execute_operation( + return await self._duplex_stream( input=input, plugins=operation_plugins, serialize=${serSymbol:T}, deserialize=${deserSymbol:T}, config=self._config, operation_name=${operationName:S}, - has_input_stream=True, event_deserializer=$T().deserialize, - event_response_deserializer=${output:T}, ) # type: ignore ${/hasProtocol} """, @@ -862,14 +929,13 @@ raise NotImplementedError() raise NotImplementedError() ${/hasProtocol} ${?hasProtocol} - return await self._execute_operation( + return await self._input_stream( input=input, plugins=operation_plugins, serialize=${serSymbol:T}, deserialize=${deserSymbol:T}, config=self._config, operation_name=${operationName:S}, - has_input_stream=True, ) # type: ignore ${/hasProtocol} """, writer.consumer(w -> writeSharedOperationInit(w, operation, input))); @@ -887,7 +953,7 @@ raise NotImplementedError() raise NotImplementedError() ${/hasProtocol} ${?hasProtocol} - return await self._execute_operation( + return await self._output_stream( input=input, plugins=operation_plugins, serialize=${serSymbol:T}, @@ -895,7 +961,6 @@ raise NotImplementedError() config=self._config, operation_name=${operationName:S}, event_deserializer=$T().deserialize, - event_response_deserializer=${output:T}, ) # type: ignore ${/hasProtocol} """, diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ProtocolGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ProtocolGenerator.java index 664c08c14..0006e498f 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ProtocolGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ProtocolGenerator.java @@ -154,23 +154,9 @@ default void generateSharedDeserializerComponents(GenerationContext context) {} */ default void generateProtocolTests(GenerationContext context) {} - /** - * Generates the code to wrap an operation output into an event stream. - * - *

Important context variables are: - *

- * - * @param context Generation context. - * @param writer The writer to write to. - */ - default void wrapEventStream(GenerationContext context, PythonWriter writer) {} + default void wrapInputStream(GenerationContext context, PythonWriter writer) {} + + default void wrapOutputStream(GenerationContext context, PythonWriter writer) {} + + default void wrapDuplexStream(GenerationContext context, PythonWriter writer) {} } diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java index ca994e184..ef55e15d7 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java @@ -390,48 +390,64 @@ protected void resolveErrorCodeAndMessage( } @Override - public void wrapEventStream(GenerationContext context, PythonWriter writer) { + public void wrapInputStream(GenerationContext context, PythonWriter writer) { writer.addDependency(SmithyPythonDependency.SMITHY_JSON); writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM); - writer.addDependency(SmithyPythonDependency.SMITHY_CORE); - writer.addImports("aws_event_stream.aio", - Set.of( - "AWSDuplexEventStream", - "AWSInputEventStream", - "AWSOutputEventStream")); writer.addImport("smithy_json", "JSONCodec"); writer.addImport("smithy_core.aio.types", "AsyncBytesReader"); writer.addImport("smithy_core.types", "TimestampFormat"); - writer.addStdlibImport("typing", "Any"); + writer.addImport("aws_event_stream.aio", "AWSInputEventStream"); + writer.write( + """ + codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) + return AWSInputEventStream[Any, Any]( + payload_codec=codec, + awaitable_output=awaitable_output, + async_writer=request_context.transport_request.body, # type: ignore + ) + """); + } - writer.write(""" - codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) - if has_input_stream: - if event_deserializer is not None: - return AWSDuplexEventStream[Any, Any, Any]( + @Override + public void wrapOutputStream(GenerationContext context, PythonWriter writer) { + writer.addDependency(SmithyPythonDependency.SMITHY_JSON); + writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM); + writer.addImport("smithy_json", "JSONCodec"); + writer.addImport("smithy_core.aio.types", "AsyncBytesReader"); + writer.addImport("smithy_core.types", "TimestampFormat"); + writer.addImport("aws_event_stream.aio", "AWSOutputEventStream"); + writer.write( + """ + codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) + return AWSOutputEventStream[Any, Any]( payload_codec=codec, - initial_response=operation_output, - async_writer=execution_context.transport_request.body, # type: ignore + initial_response=output, async_reader=AsyncBytesReader( - execution_context.transport_response.body # type: ignore + transport_response.body # type: ignore ), deserializer=event_deserializer, # type: ignore ) - else: - return AWSInputEventStream[Any, Any]( + """); + } + + @Override + public void wrapDuplexStream(GenerationContext context, PythonWriter writer) { + writer.addDependency(SmithyPythonDependency.SMITHY_JSON); + writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM); + writer.addImport("smithy_json", "JSONCodec"); + writer.addImport("smithy_core.aio.types", "AsyncBytesReader"); + writer.addImport("smithy_core.types", "TimestampFormat"); + writer.addImport("aws_event_stream.aio", "AWSDuplexEventStream"); + writer.write( + """ + codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) + return AWSDuplexEventStream[Any, Any, Any]( payload_codec=codec, - initial_response=operation_output, - async_writer=execution_context.transport_request.body, # type: ignore + async_writer=request_context.transport_request.body, # type: ignore + awaitable_output=awaitable_output, + awaitable_response=response_future, + deserializer=event_deserializer, # type: ignore ) - else: - return AWSOutputEventStream[Any, Any]( - payload_codec=codec, - initial_response=operation_output, - async_reader=AsyncBytesReader( - execution_context.transport_response.body # type: ignore - ), - deserializer=event_deserializer, # type: ignore - ) - """); + """); } } diff --git a/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py b/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py index 5057ea688..76943f52e 100644 --- a/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py +++ b/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py @@ -1,10 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import asyncio from collections.abc import Callable -from typing import Self +from typing import Self, Awaitable -from smithy_core.aio.interfaces import AsyncByteStream, AsyncWriter +from smithy_core.aio.interfaces import AsyncByteStream, AsyncWriter, Response +from smithy_core.aio.types import AsyncBytesReader from smithy_core.codecs import Codec from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer from smithy_core.serializers import SerializeableShape @@ -33,8 +33,8 @@ def __init__( payload_codec: Codec, async_writer: AsyncWriter, deserializer: Callable[[ShapeDeserializer], O], - async_reader: AsyncByteStream | None = None, - initial_response: R | None = None, + awaitable_response: Awaitable[Response], + awaitable_output: Awaitable[R], deserializeable_response: type[R] | None = None, signer: Signer | None = None, is_client_mode: bool = True, @@ -68,36 +68,14 @@ def __init__( self._deserializer = deserializer self._payload_codec = payload_codec self._is_client_mode = is_client_mode + self._deserializeable_response = deserializeable_response - # Create a future to allow awaiting the reader - loop = asyncio.get_event_loop() - self._reader_future: asyncio.Future[AsyncByteStream] = loop.create_future() - if async_reader is not None: - self._reader_future.set_result(async_reader) - - # Create a future to allow awaiting the initial response - self._response = initial_response - self._deserializerable_response = deserializeable_response - self._response_future: asyncio.Future[R] = loop.create_future() - - @property - def response(self) -> R | None: - return self._response - - @response.setter - def response(self, value: R) -> None: - self._response_future.set_result(value) - self._response = value - - def set_reader(self, value: AsyncByteStream) -> None: - """Sets the object to read events from. - - :param value: An async readable object to read event bytes from. - """ - self._reader_future.set_result(value) + self._awaitable_response = awaitable_response + self._awaitable_output = awaitable_output + self.response: R | None = None async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]: - async_reader = await self._reader_future + async_reader = AsyncBytesReader((await self._awaitable_response).body) if self.output_stream is None: self.output_stream = _AWSEventReceiver[O]( payload_codec=self._payload_codec, @@ -107,13 +85,13 @@ async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]: ) if self.response is None: - if self._deserializerable_response is None: - initial_response = await self._response_future + if self._deserializeable_response is None: + initial_response = await self._awaitable_output else: initial_response_stream = _AWSEventReceiver( payload_codec=self._payload_codec, source=async_reader, - deserializer=self._deserializerable_response.deserialize, + deserializer=self._deserializeable_response.deserialize, is_client_mode=self._is_client_mode, ) initial_response = await initial_response_stream.receive() @@ -133,7 +111,7 @@ def __init__( self, payload_codec: Codec, async_writer: AsyncWriter, - initial_response: R | None = None, + awaitable_output: Awaitable[R], signer: Signer | None = None, is_client_mode: bool = True, ) -> None: @@ -147,13 +125,8 @@ def __init__( :param is_client_mode: Whether the stream is being constructed for a client or server implementation. """ - self._response = initial_response - - # Create a future to allow awaiting the initial response. - loop = asyncio.get_event_loop() - self._response_future: asyncio.Future[R] = loop.create_future() - if initial_response is not None: - self._response_future.set_result(initial_response) + self.response: R | None = None + self._awaitable_response = awaitable_output self.input_stream = _AWSEventPublisher( payload_codec=payload_codec, @@ -162,17 +135,10 @@ def __init__( is_client_mode=is_client_mode, ) - @property - def response(self) -> R | None: - return self._response - - @response.setter - def response(self, value: R) -> None: - self._response_future.set_result(value) - self._response = value - async def await_output(self) -> R: - return await self._response_future + if self.response is None: + self.response = await self._awaitable_response + return self.response class AWSOutputEventStream[O: DeserializeableShape, R: DeserializeableShape](