Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
f4c87e9
feat: do not supress callback exceptions raised from calling callbacks
o-santi Jul 1, 2025
5b8cde0
fix: add explicit clauses for ConnectionClosedOK in send and heartbeat
o-santi Jul 1, 2025
82e5812
Merge remote-tracking branch 'origin/main' into o-santi/only-catch-co…
o-santi Jul 1, 2025
444b4a6
fix: change import order
o-santi Jul 1, 2025
c45fcc9
fix: apply ruff format
o-santi Jul 1, 2025
7b684d5
Merge branch 'main' into o-santi-main
o-santi Jul 3, 2025
cb9e1df
feat: introduce flake.nix leveraging pyproject.nix
o-santi Jul 3, 2025
5c235ce
chore: make `flake.nix` aware of dev dependencies, including ruff and…
o-santi Jul 4, 2025
db9a455
chore: refactor `dependencies-for`
o-santi Jul 4, 2025
52fe664
feat: add nix develop CI job
o-santi Jul 7, 2025
c2242e2
fix: change job name to be more compatible with the other actions
o-santi Jul 7, 2025
536e60b
Merge branch 'main' into o-santi-main
o-santi Jul 7, 2025
9c1f2d9
feat: setup infra and run `python -m pytest` to run tests in nix setu…
o-santi Jul 7, 2025
51caefc
fix: do not rely on `npx` as its not on nix develop
o-santi Jul 7, 2025
c30f9b2
fix: add `--command` typo
o-santi Jul 7, 2025
9f07df1
chore: change name of nix develop command part
o-santi Jul 7, 2025
9f726c9
feat: add coverage information and upload it to coveralls
o-santi Jul 7, 2025
ea2aba7
chore: run nix setup tests in both ubuntu and macos
o-santi Jul 7, 2025
c09ccea
chore: undo macos latest as it apparently does not work
o-santi Jul 7, 2025
a3c32e9
chore: switch from pylsp + mypy to pyright (basepyright)
o-santi Jul 7, 2025
0e766d2
feat: add `basedpyright` to `pyproject.toml` instead of flake.nix only
o-santi Jul 7, 2025
1a60bbd
fix: change back to python-lsp-server with pylsp-mypy
o-santi Jul 8, 2025
d64b494
Merge branch 'main' into o-santi-main
o-santi Jul 8, 2025
54c7c9c
chore: improve type definitions files to make mypy happy
o-santi Jul 9, 2025
bedb062
chore: add mypy to CI as a step
o-santi Jul 9, 2025
1bf7be0
fix: run mypy through poetry
o-santi Jul 9, 2025
800ad97
fix: check for phx_ref_prev before calling del
o-santi Jul 9, 2025
c1b961c
fix: use `python -m mypy` instead of `mypy` directly
o-santi Jul 9, 2025
5af58b4
fix: run type check after `make run_tests` so that `poetry install` i…
o-santi Jul 9, 2025
f8a1dc1
fix: `StrEnum` does not exist in python 3.9
o-santi Jul 9, 2025
6fce253
fix: remove `phx_ref` from presence dict
o-santi Jul 9, 2025
f548403
fix: add deprecation warning for calling send with dicts
o-santi Jul 9, 2025
dbc03f1
fix: fix config payload, improve more types
o-santi Jul 10, 2025
491d349
fix: make typing definitions compatible with python 3.9
o-santi Jul 10, 2025
8264569
fix: import annotations for 3.9 to not complain about type error
o-santi Jul 10, 2025
940a94c
fix: move mypy to Makefile `run_tests`, run it before actual tests
o-santi Jul 10, 2025
7819eff
format: apply ruff reformating
o-santi Jul 10, 2025
98628ad
fix: try explicitly annotating `Callback` as a `TypeAlias`
o-santi Jul 10, 2025
c684704
fix: change `dict` to `Dict` type due to 3.9
o-santi Jul 10, 2025
7ef3cfb
fix: finally, import annotations to stop runtime from breaking
o-santi Jul 10, 2025
66c44a6
format: apply ruff format one last time
o-santi Jul 10, 2025
e8d7495
format: reorder import order in test_connection
o-santi Jul 10, 2025
f671af6
format: trim whitespace
o-santi Jul 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ jobs:
github_access_token: ${{ secrets.GITHUB_TOKEN }}
- name: Clone Repository
uses: actions/checkout@v4
- name: Type check
run: nix develop --command mypy ./realtime
- name: Start Supabase local development setup
run: nix develop --command supabase start --workdir infra -x studio,mailpit,edge-runtime,logflare,vector,supavisor,imgproxy,storage-api
- name: Run python tests through nix
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ install_poetry:
curl -sSL https://install.python-poetry.org | python -
poetry install

tests: install tests_only tests_pre_commit
tests: install run_mypy tests_only tests_pre_commit

tests_pre_commit:
poetry run pre-commit run --all-files

run_mypy:
poetry run mypy ./realtime

run_infra:
npx supabase start --workdir infra -x studio,mailpit,edge-runtime,logflare,vector,supavisor,imgproxy,storage-api

Expand Down
163 changes: 162 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ requires-python = ">=3.9"
dependencies = [
"websockets >=11,<16",
"typing-extensions >=4.14.0",
"pydantic (>=2.11.7,<3.0.0)",
]

[tool.poetry.group.dev.dependencies]
Expand Down Expand Up @@ -77,3 +78,13 @@ keep-runtime-typing = true
[tool.pytest.ini_options]
asyncio_mode = "strict"
asyncio_default_fixture_loop_scope = "function"

[tool.mypy]
python_version = "3.9"
check_untyped_defs = true
allow_redefinition = true

warn_return_any = true
warn_unused_configs = true
warn_redundant_casts = true
warn_unused_ignores = true
87 changes: 49 additions & 38 deletions realtime/_async/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import json
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional

from realtime.types import (
Binding,
Expand All @@ -16,6 +16,7 @@
RealtimeSubscribeStates,
)

from ..message import Message
from ..transformers import http_endpoint_url
from .presence import (
AsyncRealtimePresence,
Expand Down Expand Up @@ -52,23 +53,27 @@ def __init__(
:param params: Optional parameters for connection.
"""
self.socket = socket
self.params = params or {}
if self.params.get("config") is None:
self.params["config"] = {
"broadcast": {"ack": False, "self": False},
"presence": {"key": ""},
"private": False,
self.params: RealtimeChannelOptions = (
params
if params
else {
"config": {
"broadcast": {"ack": False, "self": False},
"presence": {"key": ""},
"private": False,
}
}
)

self.topic = topic
self._joined_once = False
self.bindings: Dict[str, List[Binding]] = {}
self.bindings: dict[str, list[Binding]] = {}
self.presence = AsyncRealtimePresence(self)
self.state = ChannelStates.CLOSED
self._push_buffer: List[AsyncPush] = []
self._push_buffer: list[AsyncPush] = []
self.timeout = self.socket.timeout

self.join_push = AsyncPush(self, ChannelEvents.join, self.params)
self.join_push: AsyncPush = AsyncPush(self, ChannelEvents.join, self.params)
self.rejoin_timer = AsyncTimer(
self._rejoin_until_connected, lambda tries: 2**tries
)
Expand Down Expand Up @@ -111,8 +116,9 @@ def on_error(payload, *args):
self._on("close", on_close)
self._on("error", on_error)

def on_reply(payload, ref):
self._trigger(self._reply_event_name(ref), payload)
def on_reply(payload: Dict[str, Any], ref: Optional[str]):
if ref:
self._trigger(self._reply_event_name(ref), payload)

self._on(ChannelEvents.reply, on_reply)

Expand Down Expand Up @@ -169,22 +175,24 @@ async def subscribe(
presence = config.get("presence", {})
private = config.get("private", False)

access_token_payload = {}
config = {
"broadcast": broadcast,
"presence": presence,
"private": private,
"postgres_changes": list(
map(lambda x: x.filter, self.bindings.get("postgres_changes", []))
),
config_payload: Dict[str, Any] = {
"config": {
"broadcast": broadcast,
"presence": presence,
"private": private,
"postgres_changes": list(
map(
lambda x: x.filter,
self.bindings.get("postgres_changes", []),
)
),
}
}

if self.socket.access_token:
access_token_payload["access_token"] = self.socket.access_token
config_payload["access_token"] = self.socket.access_token

self.join_push.update_payload(
{**{"config": config}, **access_token_payload}
)
self.join_push.update_payload(config_payload)
self._joined_once = True

def on_join_push_ok(payload: Dict[str, Any]):
Expand Down Expand Up @@ -253,7 +261,7 @@ def on_join_push_timeout(*args):

return self

async def unsubscribe(self):
async def unsubscribe(self) -> None:
"""
Unsubscribe from the channel and leave the topic.
Sets channel state to LEAVING and cleans up timers and pushes.
Expand All @@ -263,9 +271,9 @@ async def unsubscribe(self):
self.rejoin_timer.reset()
self.join_push.destroy()

def _close(*args):
def _close(*args) -> None:
logger.info(f"channel {self.topic} leave")
self._trigger(ChannelEvents.close, "leave")
self._trigger(ChannelEvents.close, {})

leave_push = AsyncPush(self, ChannelEvents.leave, {})
leave_push.receive("ok", _close).receive("timeout", _close)
Expand Down Expand Up @@ -310,21 +318,24 @@ async def join(self) -> AsyncRealtimeChannel:
:return: Channel
"""
try:
await self.socket.send(
{
"topic": self.topic,
"event": "phx_join",
"payload": {"config": self.params},
"ref": None,
}
message = Message(
topic=self.topic,
event=ChannelEvents.join,
payload={"config": self.params},
ref=None,
)
await self.socket.send(message)
return self
except Exception as e:
print(e)
return self

# Event handling methods
def _on(
self, type: str, callback: Callback, filter: Optional[Dict[str, Any]] = None
self,
type: str,
callback: Callback[[Dict[str, Any], Optional[str]], None],
filter: Optional[Dict[str, Any]] = None,
) -> AsyncRealtimeChannel:
"""
Set up a listener for a specific event.
Expand Down Expand Up @@ -411,7 +422,7 @@ def on_postgres_changes(
)

def on_system(
self, callback: Callable[[Dict[str, Any], None]]
self, callback: Callable[[Dict[str, Any]], None]
) -> AsyncRealtimeChannel:
"""
Set up a listener for system events.
Expand Down Expand Up @@ -508,7 +519,7 @@ def _can_push(self):
async def send_presence(self, event: str, data: Any) -> None:
await self.push(ChannelEvents.presence, {"event": event, "payload": data})

def _trigger(self, type: str, payload: Optional[Any], ref: Optional[str] = None):
def _trigger(self, type: str, payload: Dict[str, Any], ref: Optional[str] = None):
type_lowercase = type.lower()
events = [
ChannelEvents.close,
Expand Down Expand Up @@ -562,7 +573,7 @@ def _trigger(self, type: str, payload: Optional[Any], ref: Optional[str] = None)
elif binding.type == type_lowercase:
binding.callback(payload, ref)

def _reply_event_name(self, ref: str):
def _reply_event_name(self, ref: str) -> str:
return f"chan_reply_{ref}"

async def _rejoin_until_connected(self):
Expand Down
44 changes: 22 additions & 22 deletions realtime/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import logging
import re
from functools import wraps
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import urlencode, urlparse, urlunparse

import websockets
from websockets import connect
from websockets.client import ClientProtocol
from websockets.asyncio.client import ClientConnection

from ..exceptions import NotConnectedError
from ..message import Message
Expand Down Expand Up @@ -62,7 +62,7 @@ def __init__(
:param timeout: Connection timeout in seconds. Defaults to DEFAULT_TIMEOUT.
"""
if not is_ws_url(url):
ValueError("url must be a valid WebSocket URL or HTTP URL string")
raise ValueError("url must be a valid WebSocket URL or HTTP URL string")
self.url = f"{re.sub(r'https://', 'wss://', re.sub(r'http://', 'ws://', url, flags=re.IGNORECASE), flags=re.IGNORECASE)}/websocket"
if token:
self.url += f"?apikey={token}"
Expand All @@ -72,7 +72,7 @@ def __init__(
self.access_token = token
self.send_buffer: List[Callable] = []
self.hb_interval = hb_interval
self._ws_connection: Optional[ClientProtocol] = None
self._ws_connection: Optional[ClientConnection] = None
self.ref = 0
self.auto_reconnect = auto_reconnect
self.channels: Dict[str, AsyncRealtimeChannel] = {}
Expand All @@ -97,13 +97,15 @@ async def _listen(self) -> None:

try:
async for msg in self._ws_connection:
logger.info(f"receive: {msg}")
logger.info(f"receive: {msg!r}")

msg = Message(**json.loads(msg))
channel = self.channels.get(msg.topic)
message = Message.model_validate_json(msg)
channel = self.channels.get(message.topic)

if channel:
channel._trigger(msg.event, msg.payload, msg.ref)
channel._trigger(
message.event, dict(**message.payload), message.ref
)
except websockets.exceptions.ConnectionClosedError as e:
await self._on_connect_error(e)

Expand Down Expand Up @@ -236,7 +238,7 @@ async def _heartbeat(self) -> None:

while self.is_connected:
try:
data = dict(
data = Message(
topic=PHOENIX_CHANNEL,
event=ChannelEvents.heartbeat,
payload={},
Expand Down Expand Up @@ -294,14 +296,6 @@ async def remove_all_channels(self) -> None:

await self.close()

def summary(self) -> None:
"""
Prints a list of topics and event the socket is listening to
:return: None
"""
for topic, channel in self.channels.items():
print(f"Topic: {topic} | Events: {[e for e, _ in channel.listeners]}]")

async def set_auth(self, token: Optional[str]) -> None:
"""
Set the authentication token for the connection and update all joined channels.
Expand All @@ -325,7 +319,7 @@ def _make_ref(self) -> str:
self.ref += 1
return f"{self.ref}"

async def send(self, message: Dict[str, Any]) -> None:
async def send(self, message: Union[Message, Dict[str, Any]]) -> None:
"""
Send a message through the WebSocket connection.

Expand All @@ -340,16 +334,22 @@ async def send(self, message: Dict[str, Any]) -> None:
Returns:
None
"""

message = json.dumps(message)
logger.info(f"send: {message}")
if isinstance(message, Message):
msg = message
else:
logger.warning(
"Warning: calling AsyncRealtimeClient.send with a dictionary is deprecated. Please call it with a Message object instead. This will be a hard error in the future."
)
msg = Message(**message)
message_str = msg.model_dump_json()
logger.info(f"send: {message_str}")

async def send_message():
if not self._ws_connection:
raise NotConnectedError("_send")

try:
await self._ws_connection.send(message)
await self._ws_connection.send(message_str)
except websockets.exceptions.ConnectionClosedError as e:
await self._on_connect_error(e)
except websockets.exceptions.ConnectionClosedOK:
Expand Down
Loading