Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/realtime/src/realtime/_async/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
else {
"config": {
"broadcast": {"ack": False, "self": False},
"presence": {"key": ""},
"presence": {"key": "", "enabled": False},
"private": False,
}
}
Expand Down Expand Up @@ -191,9 +191,16 @@ async def subscribe(
else:
config: RealtimeChannelConfig = self.params["config"]
broadcast = config.get("broadcast")
presence = config.get("presence")
presence = config.get("presence") or RealtimeChannelPresenceConfig(
key="", enabled=False
)
private = config.get("private", False)

presence_enabled = self.presence._has_callback_attached or presence.get(
"enabled", False
)
presence["enabled"] = presence_enabled

config_payload: Dict[str, Any] = {
"config": {
"broadcast": broadcast,
Expand Down Expand Up @@ -429,6 +436,13 @@ def on_presence_sync(self, callback: Callable[[], None]) -> AsyncRealtimeChannel
:return: The Channel instance for method chaining.
"""
self.presence.on_sync(callback)

if self.is_joined:
logger.info(
f"channel {self.topic} resubscribe due to change in presence callbacks on joined channel"
)
asyncio.create_task(self._resubscribe())

return self

def on_presence_join(
Expand All @@ -441,6 +455,12 @@ def on_presence_join(
:return: The Channel instance for method chaining.
"""
self.presence.on_join(callback)
if self.is_joined:
logger.info(
f"channel {self.topic} resubscribe due to change in presence callbacks on joined channel"
)
asyncio.create_task(self._resubscribe())

return self

def on_presence_leave(
Expand All @@ -453,6 +473,11 @@ def on_presence_leave(
:return: The Channel instance for method chaining.
"""
self.presence.on_leave(callback)
if self.is_joined:
logger.info(
f"channel {self.topic} resubscribe due to change in presence callbacks on joined channel"
)
asyncio.create_task(self._resubscribe())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@o-santi not sure if this create_task is appropriate here, should I hold a reference for the task to make sure it is finished/cancelled?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, we can hold a reference but there's no real way to make sure its finished without randomly throwing errors, as we cannot raise something out of thin air.

If we really want to make sure it's finished, we should spawn a local loop and run it in there ... which is finicky. By the nature of unsubscribe and subscribe, we are not sure that we've got any state change, as the only thing those two functions do is send the message.

In order to properly handle this, these functions should've been async in the first place, but that can only be changed in the V3.

return self

# Broadcast methods
Expand All @@ -469,6 +494,11 @@ async def send_broadcast(self, event: str, data: Any) -> None:
)

# Internal methods

async def _resubscribe(self) -> None:
await self.unsubscribe()
await self.subscribe()

def _broadcast_endpoint_url(self):
return f"{http_endpoint_url(self.socket.http_endpoint)}/api/broadcast"

Expand Down
8 changes: 8 additions & 0 deletions src/realtime/src/realtime/_async/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@


class AsyncRealtimePresence:
@property
def _has_callback_attached(self) -> bool:
return (
self.on_join_callback is not None
or self.on_leave_callback is not None
or self.on_sync_callback is not None
)

def __init__(self):
self.state: RealtimePresenceState = {}
self.on_join_callback: Optional[PresenceOnJoinCallback] = None
Expand Down
1 change: 1 addition & 0 deletions src/realtime/src/realtime/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class RealtimeChannelBroadcastConfig(TypedDict):

class RealtimeChannelPresenceConfig(TypedDict):
key: str
enabled: bool


class RealtimeChannelConfig(TypedDict):
Expand Down
99 changes: 99 additions & 0 deletions src/realtime/tests/test_presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,102 @@ def test_transform_state_additional_fields():

result = AsyncRealtimePresence._transform_state(state_with_additional_fields)
assert result == expected_output


def test_presence_has_callback_attached():
"""Test that _has_callback_attached property correctly detects presence callbacks."""
presence = AsyncRealtimePresence()

# Initially no callbacks should be attached
assert not presence._has_callback_attached

# After setting sync callback
presence.on_sync(lambda: None)
assert presence._has_callback_attached

# Reset and test with join callback
presence = AsyncRealtimePresence()
presence.on_join(lambda key, current, new: None)
assert presence._has_callback_attached

# Reset and test with leave callback
presence = AsyncRealtimePresence()
presence.on_leave(lambda key, current, left: None)
assert presence._has_callback_attached


def test_presence_config_includes_enabled_field():
"""Test that presence config correctly includes enabled flag."""
from realtime.types import RealtimeChannelPresenceConfig

# Test creating presence config with enabled field
config: RealtimeChannelPresenceConfig = {"key": "user123", "enabled": True}
assert config["key"] == "user123"
assert config["enabled"] == True

# Test with enabled False
config_disabled: RealtimeChannelPresenceConfig = {"key": "", "enabled": False}
assert config_disabled["key"] == ""
assert config_disabled["enabled"] == False


@pytest.mark.asyncio
async def test_presence_enabled_when_callbacks_attached():
"""Test that presence.enabled is set correctly based on callback attachment."""
from unittest.mock import AsyncMock, Mock

socket = AsyncRealtimeClient(f"{URL}/realtime/v1", ANON_KEY)
channel = socket.channel("test")

# Mock the join_push to capture the payload
mock_join_push = Mock()
mock_join_push.receive = Mock(return_value=mock_join_push)
mock_join_push.update_payload = Mock()
mock_join_push.resend = AsyncMock()
channel.join_push = mock_join_push

# Mock socket connection by setting _ws_connection
mock_ws = Mock()
socket._ws_connection = mock_ws
socket._leave_open_topic = AsyncMock()

# Add presence callback before subscription
channel.on_presence_sync(lambda: None)

await channel.subscribe()

# Verify that update_payload was called
assert mock_join_push.update_payload.called

# Get the payload that was passed to update_payload
call_args = mock_join_push.update_payload.call_args
payload = call_args[0][0]

# Verify presence.enabled is True because callback is attached
assert payload["config"]["presence"]["enabled"] == True


@pytest.mark.asyncio
async def test_resubscribe_on_presence_callback_addition():
"""Test that channel resubscribes when presence callbacks are added after joining."""
import asyncio
from unittest.mock import AsyncMock

socket = AsyncRealtimeClient(f"{URL}/realtime/v1", ANON_KEY)
channel = socket.channel("test")

# Mock the channel as joined
channel.state = "joined"
channel._joined_once = True

# Mock resubscribe method
channel._resubscribe = AsyncMock()

# Add presence callbacks after joining
channel.on_presence_sync(lambda: None)

# Wait a bit for async tasks to complete
await asyncio.sleep(0.1)

# Verify resubscribe was called
assert channel._resubscribe.call_count == 1