Skip to content

Commit

Permalink
Add direct coverage for event.py (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmulcahey committed Apr 1, 2024
1 parent 325a772 commit 4e98384
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 13 deletions.
185 changes: 185 additions & 0 deletions tests/test_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""Event tests for ZHA."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest

from zha.application.gateway import Gateway
from zha.event import EventBase


class EventGenerator(EventBase):
"""Event generator for testing."""


class Event:
"""Event class for testing."""

event = "test"
event_type = "testing"


def test_event_base_unsubs():
"""Test event base class."""
event = EventGenerator()
assert not event._listeners
assert not event._golbal_listeners

callback = MagicMock()

unsub = event.on_event("test", callback)
assert event._listeners == {"test": [callback]}
unsub()
assert event._listeners == {"test": []}

unsub = event.on_all_events(callback)
assert event._golbal_listeners == [callback]
unsub()
assert not event._golbal_listeners

unsub = event.once("test", callback)
assert "test" in event._listeners
assert len(event._listeners["test"]) == 1
unsub()
assert event._listeners == {"test": []}


def test_event_base_emit():
"""Test event base class."""
event = EventGenerator()
assert not event._listeners
assert not event._golbal_listeners

callback = MagicMock()

event.once("test", callback)
event.emit("test")
assert callback.called

callback.reset_mock()
event.emit("test")
assert not callback.called

unsub = event.on_event("test", callback)
event.emit("test")
assert callback.called
unsub()

callback.reset_mock()
unsub = event.on_all_events(callback)
event.emit("test")
assert callback.called
unsub()

assert "test" in event._listeners
assert event._listeners == {"test": []}
assert not event._golbal_listeners


def test_event_base_emit_data():
"""Test event base class."""
event = EventGenerator()
assert not event._listeners
assert not event._golbal_listeners

callback = MagicMock()

event.once("test", callback)
event.emit("test", "data")
assert callback.called
assert callback.call_args[0] == ("data",)

callback.reset_mock()
event.emit("test", "data")
assert not callback.called

unsub = event.on_event("test", callback)
event.emit("test", "data")
assert callback.called
assert callback.call_args[0] == ("data",)
unsub()

callback.reset_mock()
unsub = event.on_all_events(callback)
event.emit("test", "data")
assert callback.called
assert callback.call_args[0] == ("data",)
unsub()

assert "test" in event._listeners
assert event._listeners == {"test": []}
assert not event._golbal_listeners


async def test_event_base_emit_coro(zha_gateway: Gateway):
"""Test event base class."""
event = EventGenerator()
assert not event._listeners
assert not event._golbal_listeners

callback = AsyncMock()

event.once("test", callback)
event.emit("test", "data")
await zha_gateway.async_block_till_done()
await zha_gateway.async_block_till_done()
assert callback.await_count == 1
assert callback.await_args[0] == ("data",)

callback.reset_mock()

unsub = event.on_event("test", callback)
event.emit("test", "data")
await zha_gateway.async_block_till_done()
await zha_gateway.async_block_till_done()
assert callback.await_count == 1
assert callback.await_args[0] == ("data",)
unsub()

callback.reset_mock()

unsub = event.on_all_events(callback)
event.emit("test", "data")
await zha_gateway.async_block_till_done()
await zha_gateway.async_block_till_done()
assert callback.await_count == 1
assert callback.await_args[0] == ("data",)
unsub()

test_event = Event()
event.on_event(test_event.event, event._handle_event_protocol)
event.handle_test = AsyncMock()

event.emit(test_event.event, test_event)
await zha_gateway.async_block_till_done()
await zha_gateway.async_block_till_done()

assert event.handle_test.await_count == 1
assert event.handle_test.await_args[0] == (test_event,)


def test_handle_event_protocol():
"""Test event base class."""

event_handler = EventGenerator()
event_handler.handle_test = MagicMock()
event_handler.on_event("test", event_handler._handle_event_protocol)

event = Event()
event_handler.emit(event.event, event)

assert event_handler.handle_test.called
assert event_handler.handle_test.call_args[0] == (event,)


def test_handle_event_protocol_no_event(caplog: pytest.LogCaptureFixture):
"""Test event base class."""

event_handler = EventGenerator()
event_handler.on_event("not_test", event_handler._handle_event_protocol)
event = Event()
event_handler.emit("not_test", event)

assert "Received unknown event:" in caplog.text
27 changes: 14 additions & 13 deletions zha/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,31 +50,32 @@ def unsubscribe() -> None:

def once(self, event_name: str, callback: Callable) -> Callable:
"""Listen for an event exactly once."""
if inspect.iscoroutinefunction(callback):

async def async_event_listener(data: dict) -> None:
unsub()
task = asyncio.create_task(callback(data))
self._event_tasks.append(task)
task.add_done_callback(self._event_tasks.remove)

unsub = self.on_event(event_name, async_event_listener)
return unsub

def event_listener(data: dict) -> None:
unsub()
callback(data)

unsub = self.on_event(event_name, event_listener)

return unsub

def emit(self, event_name: str, data=None) -> None:
"""Run all callbacks for an event."""
for listener in [*self._listeners.get(event_name, []), *self._golbal_listeners]:
if inspect.iscoroutinefunction(listener):
if data is None:
task = asyncio.create_task(listener())
self._event_tasks.append(task)
task.add_done_callback(self._event_tasks.remove)
else:
task = asyncio.create_task(listener(data))
self._event_tasks.append(task)
task.add_done_callback(self._event_tasks.remove)
elif data is None:
listener()
else:
listener(data)
task = asyncio.create_task(listener(data))
self._event_tasks.append(task)
task.add_done_callback(self._event_tasks.remove)
listener(data)

def _handle_event_protocol(self, event) -> None:
"""Process an event based on event protocol."""
Expand Down

0 comments on commit 4e98384

Please sign in to comment.