Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions replit_river/message_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,37 @@
logger = logging.getLogger(__name__)


class MessageBufferClosedError(BaseException):
"""Raised when a message buffer is closed and is not accepting new messages."""


class MessageBuffer:
"""A buffer to store messages and support current updates"""

def __init__(self, max_num_messages: int = MAX_MESSAGE_BUFFER_SIZE):
Comment thread
cbrewster marked this conversation as resolved.
self.max_size = max_num_messages
self.buffer: list[TransportMessage] = []
self._lock = asyncio.Lock()
self._space_available_cond = asyncio.Condition(lock=self._lock)
self._closed = False

async def empty(self) -> bool:
"""Check if the buffer is empty"""
async with self._lock:
return len(self.buffer) == 0

async def put(self, message: TransportMessage) -> None:
"""Add a message to the buffer"""
async with self._lock:
if len(self.buffer) >= self.max_size:
logger.error("Buffer is full, dropping message")
raise ValueError("Buffer is full")
Comment thread
cbrewster marked this conversation as resolved.
"""Add a message to the buffer. Blocks until there is space in the buffer.

Raises:
MessageBufferClosedError: if the buffer is closed.
"""
async with self._space_available_cond:
await self._space_available_cond.wait_for(
lambda: len(self.buffer) < self.max_size or self._closed
)
if self._closed:
raise MessageBufferClosedError("message buffer is closed")
self.buffer.append(message)

async def peek(self) -> Optional[TransportMessage]:
Expand All @@ -40,3 +52,12 @@ async def remove_old_messages(self, min_seq: int) -> None:
"""Remove messages in the buffer with a seq number less than min_seq."""
async with self._lock:
self.buffer = [msg for msg in self.buffer if msg.seq >= min_seq]
self._space_available_cond.notify_all()

async def close(self) -> None:
"""
Closes the message buffer and rejects any pending put operations.
"""
async with self._lock:
self._closed = True
self._space_available_cond.notify_all()
10 changes: 5 additions & 5 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from aiochannel import Channel, ChannelClosed
from websockets.exceptions import ConnectionClosed

from replit_river.message_buffer import MessageBuffer
from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError
from replit_river.messages import (
FailedSendingMessageException,
WebsocketClosedException,
Expand Down Expand Up @@ -386,10 +386,8 @@ async def send_message(
async with self._msg_lock:
try:
await self._buffer.put(msg)
except Exception:
# We should close the session when there are too many messages in
# buffer
await self.close()
except MessageBufferClosedError:
# The session is closed and is no longer accepting new messages.
return
async with self._ws_lock:
if not await self._ws_wrapper.is_open():
Expand Down Expand Up @@ -542,6 +540,8 @@ async def close(self) -> None:

await self.close_websocket(self._ws_wrapper, should_retry=False)

await self._buffer.close()

# Clear the session in transports
await self._close_session_callback(self)

Expand Down
2 changes: 1 addition & 1 deletion replit_river/transport_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel

MAX_MESSAGE_BUFFER_SIZE = 1024
MAX_MESSAGE_BUFFER_SIZE = 128


class ConnectionRetryOptions(BaseModel):
Expand Down
22 changes: 22 additions & 0 deletions tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from replit_river.client import Client
from replit_river.error_schema import RiverError
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
from tests.conftest import deserialize_error, deserialize_response, serialize_request


Expand Down Expand Up @@ -41,6 +42,27 @@ async def upload_data() -> AsyncGenerator[str, None]:
assert response == "Uploaded: Initial Data, Data 1, Data 2, Data 3"


@pytest.mark.asyncio
async def test_upload_more_than_send_buffer_max(client: Client) -> None:
iterations = MAX_MESSAGE_BUFFER_SIZE * 2

async def upload_data() -> AsyncGenerator[str, None]:
for _ in range(0, iterations):
yield "Data"

response = await client.send_upload(
"test_service",
"upload_method",
"Initial Data",
upload_data(),
serialize_request,
serialize_request,
deserialize_response,
deserialize_response,
) # type: ignore
assert response == "Uploaded: Initial Data" + (", Data" * iterations)


@pytest.mark.asyncio
async def test_upload_empty(client: Client) -> None:
async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
Expand Down
64 changes: 64 additions & 0 deletions tests/test_message_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import asyncio

import pytest

from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError
from replit_river.rpc import TransportMessage


def mock_transport_message(seq: int) -> TransportMessage:
return TransportMessage(
seq=seq,
id="test",
ack=0,
from_="test",
to="test",
streamId="test",
controlFlags=0,
payload=0,
model_config={},
)


async def test_message_buffer_backpressure() -> None:
"""
Tests that MessageBuffer.put blocks until there is space in the buffer,
creating back pressure in the client.
"""
buffer = MessageBuffer(max_num_messages=1)

iterations = 100

# We use a queue as a way to sync our test logic with the background
# task with the testing logic.
sync_events: asyncio.Queue[None] = asyncio.Queue()

async def put_messages() -> None:
for i in range(0, iterations):
await buffer.put(mock_transport_message(seq=i))
await sync_events.put(None)

background_puts = asyncio.create_task(put_messages())

for i in range(1, iterations):
# Wait for the put call to return.
await sync_events.get()
assert len(buffer.buffer) == 1
await buffer.remove_old_messages(i)

await background_puts


async def test_message_buffer_close() -> None:
"""
Tests that MessageBuffer.put raises an exception when the buffer
is closed while the put operation is waiting for space in the buffer.
"""
buffer = MessageBuffer(max_num_messages=1)
await buffer.put(mock_transport_message(seq=1))
background_put = asyncio.create_task(buffer.put(mock_transport_message(seq=1)))
await buffer.close()
with pytest.raises(MessageBufferClosedError):
await background_put
with pytest.raises(MessageBufferClosedError):
await buffer.put(mock_transport_message(seq=1))