diff --git a/replit_river/message_buffer.py b/replit_river/message_buffer.py index b7b91ea5..e07ff56e 100644 --- a/replit_river/message_buffer.py +++ b/replit_river/message_buffer.py @@ -8,6 +8,10 @@ logger = logging.getLogger(__name__) +class MessageBufferClosedError(Exception): + """Raised when a message buffer is closed and is not accepting new messages.""" + + class MessageBuffer: """A buffer to store messages and support current updates""" @@ -15,6 +19,8 @@ def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE): self.max_size = max_num_messages self.buffer: list[TransportMessage] = [] self._lock = asyncio.Lock() + self._space_available_cond = asyncio.Condition(lock=self._lock) + self._closed = False async def empty(self) -> bool: """Check if the buffer is empty""" @@ -22,11 +28,17 @@ async def empty(self) -> bool: return len(self.buffer) == 0 async def put(self, message: TransportMessage) -> None: - """Add a message to the buffer""" - async with self._lock: - if len(self.buffer) >= self.max_size: - logger.error("Buffer is full, dropping message") - raise ValueError("Buffer is full") + """Add a message to the buffer. Blocks until there is space in the buffer. + + Raises: + MessageBufferClosedError: if the buffer is closed. + """ + async with self._space_available_cond: + await self._space_available_cond.wait_for( + lambda: len(self.buffer) < self.max_size or self._closed + ) + if self._closed: + raise MessageBufferClosedError("message buffer is closed") self.buffer.append(message) async def peek(self) -> Optional[TransportMessage]: @@ -40,3 +52,12 @@ async def remove_old_messages(self, min_seq: int) -> None: """Remove messages in the buffer with a seq number less than min_seq.""" async with self._lock: self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq] + self._space_available_cond.notify_all() + + async def close(self) -> None: + """ + Closes the message buffer and rejects any pending put operations.
+ """ + async with self._lock: + self._closed = True + self._space_available_cond.notify_all() diff --git a/replit_river/session.py b/replit_river/session.py index 44893930..38d09723 100644 --- a/replit_river/session.py +++ b/replit_river/session.py @@ -8,7 +8,7 @@ from aiochannel import Channel, ChannelClosed from websockets.exceptions import ConnectionClosed -from replit_river.message_buffer import MessageBuffer +from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError from replit_river.messages import ( FailedSendingMessageException, WebsocketClosedException, @@ -386,10 +386,8 @@ async def send_message( async with self._msg_lock: try: await self._buffer.put(msg) - except Exception: - # We should close the session when there are too many messages in - # buffer - await self.close() + except MessageBufferClosedError: + # The session is closed and is no longer accepting new messages. return async with self._ws_lock: if not await self._ws_wrapper.is_open(): @@ -542,6 +540,8 @@ async def close(self) -> None: await self.close_websocket(self._ws_wrapper, should_retry=False) + await self._buffer.close() + # Clear the session in transports await self._close_session_callback(self) diff --git a/replit_river/transport_options.py b/replit_river/transport_options.py index 6a4222a2..47032bac 100644 --- a/replit_river/transport_options.py +++ b/replit_river/transport_options.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -MAX_MESSAGE_BUFFER_SIZE = 1024 +MAX_MESSAGE_BUFFER_SIZE = 128 class ConnectionRetryOptions(BaseModel): diff --git a/tests/test_communication.py b/tests/test_communication.py index dfeca489..c18d357b 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -5,6 +5,7 @@ from replit_river.client import Client from replit_river.error_schema import RiverError +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE from tests.conftest import deserialize_error, deserialize_response, serialize_request @@ -41,6 
+42,27 @@ async def upload_data() -> AsyncGenerator[str, None]: assert response == "Uploaded: Initial Data, Data 1, Data 2, Data 3" +@pytest.mark.asyncio +async def test_upload_more_than_send_buffer_max(client: Client) -> None: + iterations = MAX_MESSAGE_BUFFER_SIZE * 2 + + async def upload_data() -> AsyncGenerator[str, None]: + for _ in range(0, iterations): + yield "Data" + + response = await client.send_upload( + "test_service", + "upload_method", + "Initial Data", + upload_data(), + serialize_request, + serialize_request, + deserialize_response, + deserialize_response, + ) # type: ignore + assert response == "Uploaded: Initial Data" + (", Data" * iterations) + + @pytest.mark.asyncio async def test_upload_empty(client: Client) -> None: async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]: diff --git a/tests/test_message_buffer.py b/tests/test_message_buffer.py new file mode 100644 index 00000000..3c7c6a93 --- /dev/null +++ b/tests/test_message_buffer.py @@ -0,0 +1,69 @@ +import asyncio + +import pytest + +from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError +from replit_river.rpc import TransportMessage + + +def mock_transport_message(seq: int) -> TransportMessage: + return TransportMessage( + seq=seq, + id="test", + ack=0, + from_="test", + to="test", + streamId="test", + controlFlags=0, + payload=0, + model_config={}, + ) + + +@pytest.mark.asyncio +async def test_message_buffer_backpressure() -> None: + """ + Tests that MessageBuffer.put blocks until there is space in the buffer, + creating back pressure in the client. + """ + buffer = MessageBuffer(max_num_messages=1) + + iterations = 100 + + # We use a queue to synchronize the background put task with the + # assertions made in the body of this test.
+ sync_events: asyncio.Queue[None] = asyncio.Queue() + + async def put_messages() -> None: + for i in range(0, iterations): + await buffer.put(mock_transport_message(seq=i)) + await sync_events.put(None) + + background_puts = asyncio.create_task(put_messages()) + + for i in range(1, iterations): + # Wait for the put call to return. + await sync_events.get() + assert len(buffer.buffer) == 1 + await buffer.remove_old_messages(i) + + await background_puts + + +@pytest.mark.asyncio +async def test_message_buffer_close() -> None: + """ + Tests that MessageBuffer.put raises an exception when the buffer + is closed while the put operation is waiting for space in the buffer. + """ + buffer = MessageBuffer(max_num_messages=1) + await buffer.put(mock_transport_message(seq=1)) + background_put = asyncio.create_task(buffer.put(mock_transport_message(seq=1))) + # Yield to the event loop once so the background put starts and blocks on + # the full buffer before close() is called. + await asyncio.sleep(0) + await buffer.close() + with pytest.raises(MessageBufferClosedError): + await background_put + with pytest.raises(MessageBufferClosedError): + await buffer.put(mock_transport_message(seq=1))