diff --git a/python-packages/smithy-http/smithy_http/aio/crt.py b/python-packages/smithy-http/smithy_http/aio/crt.py index 500d932d9..b064bab5c 100644 --- a/python-packages/smithy-http/smithy_http/aio/crt.py +++ b/python-packages/smithy-http/smithy_http/aio/crt.py @@ -25,6 +25,7 @@ HAS_CRT = False # type: ignore from smithy_core import interfaces as core_interfaces +from smithy_core.aio.types import AsyncBytesReader from smithy_core.exceptions import MissingDependencyException from .. import Field, Fields @@ -187,6 +188,7 @@ def __init__( self._tls_ctx = crt_io.ClientTlsContext(crt_io.TlsContextOptions()) self._socket_options = crt_io.SocketOptions() self._connections: ConnectionPoolDict = {} + self._async_reads: set[asyncio.Task[Any]] = set() async def send( self, @@ -293,12 +295,42 @@ async def _marshal_request( path = self._render_path(request.destination) headers = crt_http.HttpHeaders(headers_list) - body = BytesIO(await request.consume_body_async()) + + body = request.body + if isinstance(body, bytes | bytearray): + # If the body is already directly in memory, wrap in a BytesIO to hand + # off to CRT. + crt_body = BytesIO(body) + else: + # If the body is async, or potentially very large, start up a task to read + # it into the BytesIO object that CRT needs. By using asyncio.create_task + # we'll start the coroutine without having to explicitly await it. + crt_body = BytesIO() + if not isinstance(body, AsyncIterable): + # If the body isn't already an async iterable, wrap it in one. Objects + # with read methods will be read in chunks so as not to exhaust memory. + body = AsyncBytesReader(body) + + # Start the read task in the background. + read_task = asyncio.create_task(self._consume_body_async(body, crt_body)) + + # Keep track of the read task so that it doesn't get garbage collected, + # and stop tracking it once it's done. 
+ self._async_reads.add(read_task) + read_task.add_done_callback(self._async_reads.discard) crt_request = crt_http.HttpRequest( method=request.method, path=path, headers=headers, - body_stream=body, + body_stream=crt_body, ) return crt_request + + async def _consume_body_async( + self, source: AsyncIterable[bytes], dest: BytesIO + ) -> None: + async for chunk in source: + dest.write(chunk) + # TODO: Should we call close here, or would that keep the CRT from reading the + # last chunk? NOTE(review): write() moves dest's position; confirm reads are OK.