diff --git a/src/realtime/src/realtime/_async/channel.py b/src/realtime/src/realtime/_async/channel.py index f217f7d0..b6171098 100644 --- a/src/realtime/src/realtime/_async/channel.py +++ b/src/realtime/src/realtime/_async/channel.py @@ -85,7 +85,7 @@ def __init__( else { "config": { "broadcast": {"ack": False, "self": False}, - "presence": {"key": ""}, + "presence": {"key": "", "enabled": False}, "private": False, } } @@ -191,9 +191,16 @@ async def subscribe( else: config: RealtimeChannelConfig = self.params["config"] broadcast = config.get("broadcast") - presence = config.get("presence") + presence = config.get("presence") or RealtimeChannelPresenceConfig( + key="", enabled=False + ) private = config.get("private", False) + presence_enabled = self.presence._has_callback_attached or presence.get( + "enabled", False + ) + presence["enabled"] = presence_enabled + config_payload: Dict[str, Any] = { "config": { "broadcast": broadcast, @@ -429,6 +436,13 @@ def on_presence_sync(self, callback: Callable[[], None]) -> AsyncRealtimeChannel :return: The Channel instance for method chaining. """ self.presence.on_sync(callback) + + if self.is_joined: + logger.info( + f"channel {self.topic} resubscribe due to change in presence callbacks on joined channel" + ) + asyncio.create_task(self._resubscribe()) + return self def on_presence_join( @@ -441,6 +455,12 @@ def on_presence_join( :return: The Channel instance for method chaining. """ self.presence.on_join(callback) + if self.is_joined: + logger.info( + f"channel {self.topic} resubscribe due to change in presence callbacks on joined channel" + ) + asyncio.create_task(self._resubscribe()) + return self def on_presence_leave( @@ -453,6 +473,11 @@ def on_presence_leave( :return: The Channel instance for method chaining. """ self.presence.on_leave(callback) + if self.is_joined: + logger.info( + f"channel {self.topic} resubscribe due to change in presence callbacks on joined channel" + ) + asyncio.create_task(self._resubscribe()) return self # Broadcast methods @@ -469,6 +494,11 @@ async def send_broadcast(self, event: str, data: Any) -> None: ) # Internal methods + + async def _resubscribe(self) -> None: + await self.unsubscribe() + await self.subscribe() + def _broadcast_endpoint_url(self): return f"{http_endpoint_url(self.socket.http_endpoint)}/api/broadcast" diff --git a/src/realtime/src/realtime/_async/presence.py b/src/realtime/src/realtime/_async/presence.py index b33ae496..ad376702 100644 --- a/src/realtime/src/realtime/_async/presence.py +++ b/src/realtime/src/realtime/_async/presence.py @@ -21,6 +21,14 @@ class AsyncRealtimePresence: + @property + def _has_callback_attached(self) -> bool: + return ( + self.on_join_callback is not None + or self.on_leave_callback is not None + or self.on_sync_callback is not None + ) + def __init__(self): self.state: RealtimePresenceState = {} self.on_join_callback: Optional[PresenceOnJoinCallback] = None diff --git a/src/realtime/src/realtime/types.py b/src/realtime/src/realtime/types.py index 8274e240..75f8d4a6 100644 --- a/src/realtime/src/realtime/types.py +++ b/src/realtime/src/realtime/types.py @@ -179,6 +179,7 @@ class RealtimeChannelBroadcastConfig(TypedDict): class RealtimeChannelPresenceConfig(TypedDict): key: str + enabled: bool class RealtimeChannelConfig(TypedDict): diff --git a/src/realtime/tests/test_presence.py b/src/realtime/tests/test_presence.py index d7727316..c0e8c62b 100644 --- a/src/realtime/tests/test_presence.py +++ b/src/realtime/tests/test_presence.py @@ -173,3 +173,102 @@ def test_transform_state_additional_fields(): result = AsyncRealtimePresence._transform_state(state_with_additional_fields) assert result == expected_output + + +def test_presence_has_callback_attached(): + """Test that _has_callback_attached property correctly detects presence callbacks.""" + presence = AsyncRealtimePresence() + + # Initially no callbacks should be attached + assert not presence._has_callback_attached + + # After setting sync callback + presence.on_sync(lambda: None) + assert presence._has_callback_attached + + # Reset and test with join callback + presence = AsyncRealtimePresence() + presence.on_join(lambda key, current, new: None) + assert presence._has_callback_attached + + # Reset and test with leave callback + presence = AsyncRealtimePresence() + presence.on_leave(lambda key, current, left: None) + assert presence._has_callback_attached + + +def test_presence_config_includes_enabled_field(): + """Test that presence config correctly includes enabled flag.""" + from realtime.types import RealtimeChannelPresenceConfig + + # Test creating presence config with enabled field + config: RealtimeChannelPresenceConfig = {"key": "user123", "enabled": True} + assert config["key"] == "user123" + assert config["enabled"] == True + + # Test with enabled False + config_disabled: RealtimeChannelPresenceConfig = {"key": "", "enabled": False} + assert config_disabled["key"] == "" + assert config_disabled["enabled"] == False + + +@pytest.mark.asyncio +async def test_presence_enabled_when_callbacks_attached(): + """Test that presence.enabled is set correctly based on callback attachment.""" + from unittest.mock import AsyncMock, Mock + + socket = AsyncRealtimeClient(f"{URL}/realtime/v1", ANON_KEY) + channel = socket.channel("test") + + # Mock the join_push to capture the payload + mock_join_push = Mock() + mock_join_push.receive = Mock(return_value=mock_join_push) + mock_join_push.update_payload = Mock() + mock_join_push.resend = AsyncMock() + channel.join_push = mock_join_push + + # Mock socket connection by setting _ws_connection + mock_ws = Mock() + socket._ws_connection = mock_ws + socket._leave_open_topic = AsyncMock() + + # Add presence callback before subscription + channel.on_presence_sync(lambda: None) + + await channel.subscribe() + + # Verify that update_payload was called + assert mock_join_push.update_payload.called + + # Get the payload that was passed to update_payload + call_args = mock_join_push.update_payload.call_args + payload = call_args[0][0] + + # Verify presence.enabled is True because callback is attached + assert payload["config"]["presence"]["enabled"] == True + + +@pytest.mark.asyncio +async def test_resubscribe_on_presence_callback_addition(): + """Test that channel resubscribes when presence callbacks are added after joining.""" + import asyncio + from unittest.mock import AsyncMock + + socket = AsyncRealtimeClient(f"{URL}/realtime/v1", ANON_KEY) + channel = socket.channel("test") + + # Mock the channel as joined + channel.state = "joined" + channel._joined_once = True + + # Mock resubscribe method + channel._resubscribe = AsyncMock() + + # Add presence callbacks after joining + channel.on_presence_sync(lambda: None) + + # Wait a bit for async tasks to complete + await asyncio.sleep(0.1) + + # Verify resubscribe was called + assert channel._resubscribe.call_count == 1