Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions plugboard/library/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .file_io import FileReader, FileWriter
from .llm import LLMChat
from .sql_io import SQLReader, SQLWriter
from .websocket_io import WebsocketReader, WebsocketWriter


__all__ = [
Expand All @@ -15,4 +16,6 @@
"FileWriter",
"SQLReader",
"SQLWriter",
"WebsocketReader",
"WebsocketWriter",
]
149 changes: 149 additions & 0 deletions plugboard/library/websocket_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Provides `WebsocketReader` and `WebsocketWriter` realtime data in Plugboard."""

from contextlib import AsyncExitStack
import typing as _t

import msgspec.json as json

from plugboard.component import Component, IOController
from plugboard.utils import depends_on_optional


try:
from websockets.asyncio.client import connect
from websockets.asyncio.connection import Connection
from websockets.exceptions import ConnectionClosed
except ImportError:
pass


class WebsocketReader(Component):
"""Reads data from a websocket connection."""

io = IOController(outputs=["message"])

@depends_on_optional("websockets")
def __init__(
self,
name: str,
uri: str,
connect_args: dict[str, _t.Any] | None = None,
initial_message: _t.Any | None = None,
parse_json: bool = False,
*args: _t.Any,
**kwargs: _t.Any,
) -> None:
"""Instantiates the `WebsocketReader`.

See https://websockets.readthedocs.io/en/stable/reference/asyncio/client.html for possible
connection arguments that can be passed using `connect_args`. This `WebsocketReader` will
run until interrupted, and automatically reconnect if the server connection is lost.

Args:
name: The name of the `WebsocketReader`.
uri: The URI of the WebSocket server.
connect_args: Optional; Additional arguments to pass to the WebSocket connection.
initial_message: Optional; The initial message to send to the WebSocket server on
connection. Can be used to subscribe to a specific topic.
parse_json: Whether to parse the received data as JSON.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__(name, *args, **kwargs)
self._uri = uri
self._connect_args = connect_args if connect_args else {}
if initial_message is not None:
self._initial_message = json.encode(initial_message) if parse_json else initial_message
else:
self._initial_message = None
self._parse_json = parse_json
self._ctx = AsyncExitStack()
self._conn: Connection | None = None

async def init(self) -> None:
"""Initializes the websocket connection."""
self._conn_iter = aiter(connect(self._uri, **self._connect_args))
self._conn = await self._get_conn()
self._logger.info(f"Connected to {self._uri}")

async def _get_conn(self) -> Connection:
conn = await self._ctx.enter_async_context(await anext(self._conn_iter))
if self._initial_message is not None:
self._logger.info(f"Sending initial message", message=self._initial_message)
await conn.send(self._initial_message)
return conn

async def step(self) -> None:
"""Reads a message from the websocket connection."""
if not self._conn:
self._conn = await self._get_conn()
try:
message = await self._conn.recv()
self.message = json.decode(message) if self._parse_json else message
except ConnectionClosed:
self._logger.warning(f"Connection to {self._uri} closed, will reconnect...")
self._conn = None

async def destroy(self) -> None:
"""Closes the websocket connection."""
await self._ctx.aclose()


class WebsocketWriter(Component):
"""Writes data to a websocket connection."""

io = IOController(inputs=["message"])

@depends_on_optional("websockets")
def __init__(
self,
name: str,
uri: str,
connect_args: dict[str, _t.Any] | None = None,
parse_json: bool = False,
*args: _t.Any,
**kwargs: _t.Any,
) -> None:
"""Instantiates the `WebsocketWriter`.

See https://websockets.readthedocs.io/en/stable/reference/asyncio/client.html for possible
connection arguments that can be passed using `connect_args`.

Args:
name: The name of the `WebsocketWriter`.
uri: The URI of the WebSocket server.
connect_args: Optional; Additional arguments to pass to the websocket connection.
parse_json: Whether to convert the data to JSON before sending.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__(name, *args, **kwargs)
self._uri = uri
self._connect_args = connect_args if connect_args else {}
self._parse_json = parse_json
self._ctx = AsyncExitStack()

async def init(self) -> None:
"""Initializes the websocket connection."""
self._conn_iter = aiter(connect(self._uri, **self._connect_args))
self._conn = await self._get_conn()
self._logger.info(f"Connected to {self._uri}")

async def _get_conn(self) -> Connection:
return await self._ctx.enter_async_context(await anext(self._conn_iter))

async def step(self) -> None:
"""Writes a message to the websocket connection."""
message = json.encode(self.message) if self._parse_json else self.message
while True:
try:
await self._conn.send(message)
break
except ConnectionClosed:
self._logger.warning(f"Connection to {self._uri} closed, will reconnect...")
await self._ctx.aclose()
self._conn = await self._get_conn()

async def destroy(self) -> None:
"""Closes the websocket connection."""
await self._ctx.aclose()
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ ray = ["pyzmq~=26.2", "ray>=2.42.1"]
llm = [
"llama-index>=0.12.11",
]
websockets = [
"websockets>=14.2",
]

[dependency-groups]
dev = [
Expand All @@ -46,7 +49,7 @@ dev = [
test = [
"aiofile~=3.9",
"aiosqlite~=0.20",
"anyio>=4.3.0,<4.4.0", # FIXME: Pinned due to hanging tests when running with anyio==4.4.0 on 2024-07-01
"anyio>=4.3.0,<4.4.0", # FIXME: Pinned due to hanging tests when running with anyio==4.4.0 on 2024-07-01
"llama-index>=0.12.11",
"moto[server]~=5.0",
"openai-responses>=0.11.4",
Expand All @@ -57,7 +60,8 @@ test = [
"pytest-rerunfailures~=14.0",
"ray>=2.40.0",
"s3fs>=2024.9.0",
"time-machine~=2.15"
"time-machine~=2.15",
"websockets>=14.2",
]
docs = [
"mike~=2.1",
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/test_websocket_reader_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Unit tests for the websocket components."""

import json
import typing as _t

import pytest
from websockets.asyncio.client import ClientConnection, connect
from websockets.asyncio.server import ServerConnection, serve

from plugboard.library.websocket_io import WebsocketReader, WebsocketWriter


HOST = "localhost"
PORT = 8767
CLIENTS = set()


async def _handler(websocket: ServerConnection) -> None:
"""Broadcasts incoming messages to all connected clients."""
CLIENTS.add(websocket)
try:
async for message in websocket:
for client in CLIENTS:
await client.send(message)
finally:
CLIENTS.remove(websocket)


@pytest.fixture
async def connected_client() -> _t.AsyncIterable[ClientConnection]:
"""Returns a client to a websocket broadcast server."""
async with serve(_handler, HOST, PORT):
async with connect(f"ws://{HOST}:{PORT}") as client:
yield client


@pytest.mark.asyncio
@pytest.mark.parametrize(
"parse_json,initial_message",
[(True, None), (False, None), (True, {"msg": "hello!"}), (False, "G'day!")],
)
async def test_websocket_reader(
connected_client: ClientConnection, parse_json: bool, initial_message: _t.Any
) -> None:
"""Tests the `WebsocketReader`."""
reader = WebsocketReader(
name="test-websocket",
uri=f"ws://{HOST}:{PORT}",
parse_json=parse_json,
initial_message=initial_message,
)
await reader.init()
# Send some messages to the server for broadcast to the reader
messages = [{"test-msg": x} for x in range(5)]
for message in messages:
await connected_client.send(json.dumps(message))

# If initial message set, it should be received first
if initial_message is not None:
await reader.step()
assert initial_message == reader.message
# Check that the reader receives the messages, correctly parsed
for message in messages:
await reader.step()
assert message == reader.message if parse_json else json.loads(reader.message)

await reader.destroy()


@pytest.mark.asyncio
@pytest.mark.parametrize("parse_json", [True, False])
async def test_websocket_writer(connected_client: ClientConnection, parse_json: bool) -> None:
"""Tests the `WebsocketWriter`."""
writer = WebsocketWriter(
name="test-websocket",
uri=f"ws://{HOST}:{PORT}",
parse_json=parse_json,
)
await writer.init()
messages = [{"test-msg": x} for x in range(5)]
for message in messages:
writer.message = message if parse_json else json.dumps(message)
await writer.step()
# Now retrieve the message from the broadcast
response = await connected_client.recv()
assert message == json.loads(response) if parse_json else response

await writer.destroy()
37 changes: 37 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.