From 1fa78338faa2227f32092bf59bf1d6a1c6f7cb5a Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Thu, 17 Oct 2024 13:58:47 -0400 Subject: [PATCH 01/12] Convert dataclasses to pydantic models --- pyproject.toml | 1 + tests/test_gateway.py | 21 ++---- zha/application/gateway.py | 66 +++++++---------- zha/application/helpers.py | 69 ++++++++---------- zha/application/platforms/__init__.py | 45 +++++------- .../platforms/alarm_control_panel/__init__.py | 2 - .../platforms/binary_sensor/__init__.py | 2 - zha/application/platforms/button/__init__.py | 3 - zha/application/platforms/climate/__init__.py | 2 - zha/application/platforms/fan/__init__.py | 2 - zha/application/platforms/light/__init__.py | 6 +- zha/application/platforms/number/__init__.py | 5 +- zha/application/platforms/select.py | 2 - zha/application/platforms/sensor/__init__.py | 26 +++---- zha/application/platforms/siren.py | 2 - zha/application/platforms/switch.py | 2 - zha/application/platforms/update.py | 2 - zha/model.py | 62 ++++++++++++++++ zha/zigbee/cluster_handlers/__init__.py | 56 +++++++------- zha/zigbee/cluster_handlers/general.py | 9 +-- zha/zigbee/cluster_handlers/security.py | 11 ++- zha/zigbee/device.py | 73 ++++++++----------- zha/zigbee/group.py | 16 ++-- 23 files changed, 237 insertions(+), 248 deletions(-) create mode 100644 zha/model.py diff --git a/pyproject.toml b/pyproject.toml index 18f002edd..f3b9ed562 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "zha-quirks==0.0.124", "pyserial==3.5", "pyserial-asyncio-fast", + "pydantic==2.9.2" ] [tool.setuptools.packages.find] diff --git a/tests/test_gateway.py b/tests/test_gateway.py index bd873aa3f..eb7f45abf 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -24,12 +24,7 @@ join_zigpy_device, ) from zha.application import Platform -from zha.application.const import ( - CONF_USE_THREAD, - ZHA_GW_MSG, - ZHA_GW_MSG_CONNECTION_LOST, - RadioType, -) +from zha.application.const import CONF_USE_THREAD, ZHA_GW_MSG_CONNECTION_LOST, RadioType from zha.application.gateway import ( ConnectionLostEvent, DeviceJoinedDeviceInfo, @@ -91,7 +86,7 @@ async def coordinator(zha_gateway: Gateway) -> Device: } }, ieee="00:15:8d:00:02:32:4f:32", - nwk=0x0000, + nwk=zigpy.types.NWK(0x0000), node_descriptor=zdo_t.NodeDescriptor( logical_type=zdo_t.LogicalType.Coordinator, complex_descriptor_available=0, @@ -535,7 +530,7 @@ async def test_startup_concurrency_limit( } }, ieee=f"11:22:33:44:{i:08x}", - nwk=0x1234 + i, + nwk=zigpy.types.NWK(0x1234 + i), ) zigpy_dev.node_desc.mac_capability_flags |= ( zigpy.zdo.types.NodeDescriptor.MACCapabilityFlags.MainsPowered @@ -645,7 +640,7 @@ def test_gateway_raw_device_initialized( RawDeviceInitializedEvent( device_info=RawDeviceInitializedDeviceInfo( ieee=zigpy.types.EUI64.convert("00:0d:6f:00:0a:90:69:e7"), - nwk=0xB79C, + nwk=zigpy.types.NWK(0xB79C), pairing_status=DevicePairingStatus.INTERVIEW_COMPLETE, model="FakeModel", manufacturer="FakeManufacturer", @@ -676,9 +671,7 @@ def test_gateway_raw_device_initialized( } }, }, - ), - event_type="zha_gateway_message", - event="raw_device_initialized", + ) ), ) @@ -698,7 +691,7 @@ def test_gateway_device_joined( DeviceJoinedEvent( device_info=DeviceJoinedDeviceInfo( ieee=zigpy.types.EUI64.convert("00:0d:6f:00:0a:90:69:e7"), - nwk=0xB79C, + nwk=zigpy.types.NWK(0xB79C), pairing_status=DevicePairingStatus.PAIRED, ) ), @@ -717,8 +710,6 @@ def test_gateway_connection_lost(zha_gateway: Gateway) -> None: ZHA_GW_MSG_CONNECTION_LOST, ConnectionLostEvent( exception=exception, - event=ZHA_GW_MSG_CONNECTION_LOST, - event_type=ZHA_GW_MSG, ), ) diff --git a/zha/application/gateway.py b/zha/application/gateway.py index 60ab3ca05..561451f8c 100644 --- a/zha/application/gateway.py +++ b/zha/application/gateway.py @@ -4,12 +4,11 @@ import asyncio from contextlib import suppress -from dataclasses import dataclass from datetime import timedelta from enum import Enum import logging import time -from typing import Any, Final, Self, TypeVar, cast +from typing import Any, Final, Literal, Self, TypeVar, cast from zhaquirks import setup as setup_quirks from zigpy.application import ControllerApplication @@ -25,14 +24,13 @@ import zigpy.group from zigpy.quirks.v2 import UNBUILT_QUIRK_BUILDERS from zigpy.state import State -from zigpy.types.named import EUI64 +from zigpy.types.named import EUI64, NWK from zha.application import discovery from zha.application.const import ( CONF_USE_THREAD, UNKNOWN_MANUFACTURER, UNKNOWN_MODEL, - ZHA_GW_MSG, ZHA_GW_MSG_CONNECTION_LOST, ZHA_GW_MSG_DEVICE_FULL_INIT, ZHA_GW_MSG_DEVICE_JOINED, @@ -52,6 +50,7 @@ gather_with_limited_concurrency, ) from zha.event import EventBase +from zha.model import BaseEvent, BaseModel from zha.zigbee.device import Device, DeviceInfo, DeviceStatus, ExtendedDeviceInfo from zha.zigbee.group import Group, GroupInfo, GroupMemberReference @@ -69,58 +68,51 @@ class DevicePairingStatus(Enum): INITIALIZED = 4 -@dataclass(kw_only=True, frozen=True) class DeviceInfoWithPairingStatus(DeviceInfo): """Information about a device with pairing status.""" pairing_status: DevicePairingStatus -@dataclass(kw_only=True, frozen=True) class ExtendedDeviceInfoWithPairingStatus(ExtendedDeviceInfo): """Information about a device with pairing status.""" pairing_status: DevicePairingStatus -@dataclass(kw_only=True, frozen=True) -class DeviceJoinedDeviceInfo: +class DeviceJoinedDeviceInfo(BaseModel): """Information about a device.""" - ieee: str - nwk: int + ieee: EUI64 + nwk: NWK pairing_status: DevicePairingStatus -@dataclass(kw_only=True, frozen=True) -class ConnectionLostEvent: +class ConnectionLostEvent(BaseEvent): """Event to signal that the connection to the radio has been lost.""" - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_CONNECTION_LOST + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["connection_lost"] = "connection_lost" exception: Exception | None = None -@dataclass(kw_only=True, frozen=True) -class DeviceJoinedEvent: +class DeviceJoinedEvent(BaseEvent): """Event to signal that a device has joined the network.""" device_info: DeviceJoinedDeviceInfo - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_DEVICE_JOINED + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["device_joined"] = "device_joined" -@dataclass(kw_only=True, frozen=True) -class DeviceLeftEvent: +class DeviceLeftEvent(BaseEvent): """Event to signal that a device has left the network.""" ieee: EUI64 - nwk: int - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_DEVICE_LEFT + nwk: NWK + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["device_left"] = "device_left" -@dataclass(kw_only=True, frozen=True) class RawDeviceInitializedDeviceInfo(DeviceJoinedDeviceInfo): """Information about a device that has been initialized without quirks loaded.""" @@ -129,41 +121,37 @@ class RawDeviceInitializedDeviceInfo(DeviceJoinedDeviceInfo): signature: dict[str, Any] -@dataclass(kw_only=True, frozen=True) -class RawDeviceInitializedEvent: +class RawDeviceInitializedEvent(BaseEvent): """Event to signal that a device has been initialized without quirks loaded.""" device_info: RawDeviceInitializedDeviceInfo - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_RAW_INIT + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["raw_device_initialized"] = "raw_device_initialized" -@dataclass(kw_only=True, frozen=True) -class DeviceFullInitEvent: +class DeviceFullInitEvent(BaseEvent): """Event to signal that a device has been fully initialized.""" device_info: ExtendedDeviceInfoWithPairingStatus new_join: bool = False - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_DEVICE_FULL_INIT + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["device_fully_initialized"] = "device_fully_initialized" -@dataclass(kw_only=True, frozen=True) -class GroupEvent: +class GroupEvent(BaseEvent): """Event to signal a group event.""" event: str group_info: GroupInfo - event_type: Final[str] = ZHA_GW_MSG + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" -@dataclass(kw_only=True, frozen=True) -class DeviceRemovedEvent: +class DeviceRemovedEvent(BaseEvent): """Event to signal that a device has been removed.""" device_info: ExtendedDeviceInfo - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_DEVICE_REMOVED + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["device_removed"] = "device_removed" class Gateway(AsyncUtilMixin, EventBase): diff --git a/zha/application/helpers.py b/zha/application/helpers.py index 300de0078..b690c17c0 100644 --- a/zha/application/helpers.py +++ b/zha/application/helpers.py @@ -14,6 +14,7 @@ import re from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar +from pydantic import Field import voluptuous as vol import zigpy.exceptions import zigpy.types @@ -31,6 +32,7 @@ ) from zha.async_ import gather_with_limited_concurrency from zha.decorators import periodic +from zha.model import BaseModel if TYPE_CHECKING: from zha.application.gateway import Gateway @@ -261,81 +263,74 @@ def qr_to_install_code(qr_code: str) -> tuple[zigpy.types.EUI64, zigpy.types.Key raise vol.Invalid(f"couldn't convert qr code: {qr_code}") -@dataclass(kw_only=True, slots=True) -class LightOptions: +class LightOptions(BaseModel): """ZHA light options.""" - default_light_transition: float = dataclasses.field(default=0) - enable_enhanced_light_transition: bool = dataclasses.field(default=False) - enable_light_transitioning_flag: bool = dataclasses.field(default=True) - always_prefer_xy_color_mode: bool = dataclasses.field(default=True) - group_members_assume_state: bool = dataclasses.field(default=True) + default_light_transition: float = Field(default=0) + enable_enhanced_light_transition: bool = Field(default=False) + enable_light_transitioning_flag: bool = Field(default=True) + always_prefer_xy_color_mode: bool = Field(default=True) + group_members_assume_state: bool = Field(default=True) -@dataclass(kw_only=True, slots=True) -class DeviceOptions: +class DeviceOptions(BaseModel): """ZHA device options.""" - enable_identify_on_join: bool = dataclasses.field(default=True) - consider_unavailable_mains: int = dataclasses.field( + enable_identify_on_join: bool = Field(default=True) + consider_unavailable_mains: int = Field( default=CONF_DEFAULT_CONSIDER_UNAVAILABLE_MAINS ) - consider_unavailable_battery: int = dataclasses.field( + consider_unavailable_battery: int = Field( default=CONF_DEFAULT_CONSIDER_UNAVAILABLE_BATTERY ) - enable_mains_startup_polling: bool = dataclasses.field(default=True) + enable_mains_startup_polling: bool = Field(default=True) -@dataclass(kw_only=True, slots=True) -class AlarmControlPanelOptions: +class AlarmControlPanelOptions(BaseModel): """ZHA alarm control panel options.""" - master_code: str = dataclasses.field(default="1234") - failed_tries: int = dataclasses.field(default=3) - arm_requires_code: bool = dataclasses.field(default=False) + master_code: str = Field(default="1234") + failed_tries: int = Field(default=3) + arm_requires_code: bool = Field(default=False) -@dataclass(kw_only=True, slots=True) -class CoordinatorConfiguration: +class CoordinatorConfiguration(BaseModel): """ZHA coordinator configuration.""" path: str - baudrate: int = dataclasses.field(default=115200) - flow_control: str = dataclasses.field(default="hardware") - radio_type: str = dataclasses.field(default="ezsp") + baudrate: int = Field(default=115200) + flow_control: str = Field(default="hardware") + radio_type: str = Field(default="ezsp") -@dataclass(kw_only=True, slots=True) -class QuirksConfiguration: +class QuirksConfiguration(BaseModel): """ZHA quirks configuration.""" - enabled: bool = dataclasses.field(default=True) - custom_quirks_path: str | None = dataclasses.field(default=None) + enabled: bool = Field(default=True) + custom_quirks_path: str | None = Field(default=None) -@dataclass(kw_only=True, slots=True) -class DeviceOverridesConfiguration: +class DeviceOverridesConfiguration(BaseModel): """ZHA device overrides configuration.""" type: Platform -@dataclass(kw_only=True, slots=True) -class ZHAConfiguration: +class ZHAConfiguration(BaseModel): """ZHA configuration.""" - coordinator_configuration: CoordinatorConfiguration = dataclasses.field( + coordinator_configuration: CoordinatorConfiguration = Field( default_factory=CoordinatorConfiguration ) - quirks_configuration: QuirksConfiguration = dataclasses.field( + quirks_configuration: QuirksConfiguration = Field( default_factory=QuirksConfiguration ) - device_overrides: dict[str, DeviceOverridesConfiguration] = dataclasses.field( + device_overrides: dict[str, DeviceOverridesConfiguration] = Field( default_factory=dict ) - light_options: LightOptions = dataclasses.field(default_factory=LightOptions) - device_options: DeviceOptions = dataclasses.field(default_factory=DeviceOptions) - alarm_control_panel_options: AlarmControlPanelOptions = dataclasses.field( + light_options: LightOptions = Field(default_factory=LightOptions) + device_options: DeviceOptions = Field(default_factory=DeviceOptions) + alarm_control_panel_options: AlarmControlPanelOptions = Field( default_factory=AlarmControlPanelOptions ) diff --git a/zha/application/platforms/__init__.py b/zha/application/platforms/__init__.py index b0aedf75b..8aaee54cd 100644 --- a/zha/application/platforms/__init__.py +++ b/zha/application/platforms/__init__.py @@ -5,11 +5,10 @@ from abc import abstractmethod import asyncio from contextlib import suppress -import dataclasses from enum import StrEnum from functools import cached_property import logging -from typing import TYPE_CHECKING, Any, Final, Optional, final +from typing import TYPE_CHECKING, Any, Literal, Optional, final from zigpy.quirks.v2 import EntityMetadata, EntityType from zigpy.types.named import EUI64 @@ -19,6 +18,7 @@ from zha.debounce import Debouncer from zha.event import EventBase from zha.mixins import LogMixin +from zha.model import BaseEvent, BaseModel from zha.zigbee.cluster_handlers import ClusterHandlerInfo if TYPE_CHECKING: @@ -44,13 +44,11 @@ class EntityCategory(StrEnum): DIAGNOSTIC = "diagnostic" -@dataclasses.dataclass(frozen=True, kw_only=True) -class BaseEntityInfo: +class BaseEntityInfo(BaseModel): """Information about a base entity.""" - fallback_name: str + platform: Platform unique_id: str - platform: str class_name: str translation_key: str | None device_class: str | None @@ -58,6 +56,7 @@ class BaseEntityInfo: entity_category: str | None entity_registry_enabled_default: bool enabled: bool = True + fallback_name: str | None # For platform entities cluster_handlers: list[ClusterHandlerInfo] @@ -69,15 +68,13 @@ class BaseEntityInfo: group_id: int | None -@dataclasses.dataclass(frozen=True, kw_only=True) -class BaseIdentifiers: +class BaseIdentifiers(BaseModel): """Identifiers for the base entity.""" unique_id: str - platform: str + platform: Platform -@dataclasses.dataclass(frozen=True, kw_only=True) class PlatformEntityIdentifiers(BaseIdentifiers): """Identifiers for the platform entity.""" @@ -85,20 +82,18 @@ class PlatformEntityIdentifiers(BaseIdentifiers): endpoint_id: int -@dataclasses.dataclass(frozen=True, kw_only=True) class GroupEntityIdentifiers(BaseIdentifiers): """Identifiers for the group entity.""" group_id: int -@dataclasses.dataclass(frozen=True, kw_only=True) -class EntityStateChangedEvent: +class EntityStateChangedEvent(BaseEvent): """Event for when an entity state changes.""" - event_type: Final[str] = "entity" - event: Final[str] = STATE_CHANGED - platform: str + event_type: Literal["entity"] = "entity" + event: Literal["state_changed"] = "state_changed" + platform: Platform unique_id: str device_ieee: Optional[EUI64] = None endpoint_id: Optional[int] = None @@ -375,12 +370,13 @@ def identifiers(self) -> PlatformEntityIdentifiers: @cached_property def info_object(self) -> BaseEntityInfo: """Return a representation of the platform entity.""" - return dataclasses.replace( - super().info_object, - cluster_handlers=[ch.info_object for ch in self._cluster_handlers], - device_ieee=self._device.ieee, - endpoint_id=self._endpoint.id, - available=self.available, + return super().info_object.model_copy( + update={ + "cluster_handlers": [ch.info_object for ch in self._cluster_handlers], + "device_ieee": self._device.ieee, + "endpoint_id": self._endpoint.id, + "available": self.available, + } ) @property @@ -456,10 +452,7 @@ def identifiers(self) -> GroupEntityIdentifiers: @cached_property def info_object(self) -> BaseEntityInfo: """Return a representation of the group.""" - return dataclasses.replace( - super().info_object, - group_id=self.group_id, - ) + return super().info_object.model_copy(update={"group_id": self.group_id}) @property def state(self) -> dict[str, Any]: diff --git a/zha/application/platforms/alarm_control_panel/__init__.py b/zha/application/platforms/alarm_control_panel/__init__.py index f1716a4e6..0dcb004e3 100644 --- a/zha/application/platforms/alarm_control_panel/__init__.py +++ b/zha/application/platforms/alarm_control_panel/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations -from dataclasses import dataclass import functools import logging from typing import TYPE_CHECKING, Any @@ -42,7 +41,6 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) class AlarmControlPanelEntityInfo(BaseEntityInfo): """Alarm control panel entity info.""" diff --git a/zha/application/platforms/binary_sensor/__init__.py b/zha/application/platforms/binary_sensor/__init__.py index c35b2b624..f26f14dfe 100644 --- a/zha/application/platforms/binary_sensor/__init__.py +++ b/zha/application/platforms/binary_sensor/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations -from dataclasses import dataclass import functools import logging from typing import TYPE_CHECKING @@ -46,7 +45,6 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) class BinarySensorEntityInfo(BaseEntityInfo): """Binary sensor entity info.""" diff --git a/zha/application/platforms/button/__init__.py b/zha/application/platforms/button/__init__.py index fa0d6271d..432d12163 100644 --- a/zha/application/platforms/button/__init__.py +++ b/zha/application/platforms/button/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations -from dataclasses import dataclass import functools import logging from typing import TYPE_CHECKING, Any, Self @@ -30,7 +29,6 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) class CommandButtonEntityInfo(BaseEntityInfo): """Command button entity info.""" @@ -39,7 +37,6 @@ class CommandButtonEntityInfo(BaseEntityInfo): kwargs: dict[str, Any] -@dataclass(frozen=True, kw_only=True) class WriteAttributeButtonEntityInfo(BaseEntityInfo): """Write attribute button entity info.""" diff --git a/zha/application/platforms/climate/__init__.py b/zha/application/platforms/climate/__init__.py index c0ba9851b..24b185997 100644 --- a/zha/application/platforms/climate/__init__.py +++ b/zha/application/platforms/climate/__init__.py @@ -3,7 +3,6 @@ from __future__ import annotations from asyncio import Task -from dataclasses import dataclass import datetime as dt import functools from typing import TYPE_CHECKING, Any @@ -56,7 +55,6 @@ MULTI_MATCH = functools.partial(PLATFORM_ENTITIES.multipass_match, Platform.CLIMATE) -@dataclass(frozen=True, kw_only=True) class ThermostatEntityInfo(BaseEntityInfo): """Thermostat entity info.""" diff --git a/zha/application/platforms/fan/__init__.py b/zha/application/platforms/fan/__init__.py index b3270a1a9..7a88d5610 100644 --- a/zha/application/platforms/fan/__init__.py +++ b/zha/application/platforms/fan/__init__.py @@ -3,7 +3,6 @@ from __future__ import annotations from abc import abstractmethod -from dataclasses import dataclass import functools import math from typing import TYPE_CHECKING, Any @@ -59,7 +58,6 @@ MULTI_MATCH = functools.partial(PLATFORM_ENTITIES.multipass_match, Platform.FAN) -@dataclass(frozen=True, kw_only=True) class FanEntityInfo(BaseEntityInfo): """Fan entity info.""" diff --git a/zha/application/platforms/light/__init__.py b/zha/application/platforms/light/__init__.py index 9dbdfc3eb..2057662d8 100644 --- a/zha/application/platforms/light/__init__.py +++ b/zha/application/platforms/light/__init__.py @@ -9,13 +9,12 @@ from collections import Counter from collections.abc import Callable import contextlib -import dataclasses -from dataclasses import dataclass import functools import itertools import logging from typing import TYPE_CHECKING, Any +from pydantic import Field from zigpy.zcl.clusters.general import Identify, LevelControl, OnOff from zigpy.zcl.clusters.lighting import Color from zigpy.zcl.foundation import Status @@ -87,11 +86,10 @@ GROUP_MATCH = functools.partial(PLATFORM_ENTITIES.group_match, Platform.LIGHT) -@dataclass(frozen=True, kw_only=True) class LightEntityInfo(BaseEntityInfo): """Light entity info.""" - effect_list: list[str] | None = dataclasses.field(default=None) + effect_list: list[str] | None = Field(default=None) supported_features: LightEntityFeature min_mireds: int max_mireds: int diff --git a/zha/application/platforms/number/__init__.py b/zha/application/platforms/number/__init__.py index 21817a7b9..f8647a117 100644 --- a/zha/application/platforms/number/__init__.py +++ b/zha/application/platforms/number/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations -from dataclasses import dataclass import functools import logging from typing import TYPE_CHECKING, Any, Self @@ -48,18 +47,16 @@ ) -@dataclass(frozen=True, kw_only=True) class NumberEntityInfo(BaseEntityInfo): """Number entity info.""" engineering_units: int - application_type: int + application_type: int | None min_value: float | None max_value: float | None step: float | None -@dataclass(frozen=True, kw_only=True) class NumberConfigurationEntityInfo(BaseEntityInfo): """Number configuration entity info.""" diff --git a/zha/application/platforms/select.py b/zha/application/platforms/select.py index 101296652..c7ca4fe01 100644 --- a/zha/application/platforms/select.py +++ b/zha/application/platforms/select.py @@ -2,7 +2,6 @@ from __future__ import annotations -from dataclasses import dataclass from enum import Enum import functools import logging @@ -48,7 +47,6 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) class EnumSelectInfo(BaseEntityInfo): """Enum select entity info.""" diff --git a/zha/application/platforms/sensor/__init__.py b/zha/application/platforms/sensor/__init__.py index 08140929a..131a3f528 100644 --- a/zha/application/platforms/sensor/__init__.py +++ b/zha/application/platforms/sensor/__init__.py @@ -3,7 +3,6 @@ from __future__ import annotations from asyncio import Task -from dataclasses import dataclass from datetime import UTC, date, datetime import enum import functools @@ -37,6 +36,7 @@ ) from zha.application.registries import PLATFORM_ENTITIES from zha.decorators import periodic +from zha.model import BaseModel from zha.units import ( CONCENTRATION_MICROGRAMS_PER_CUBIC_METER, CONCENTRATION_PARTS_PER_BILLION, @@ -113,24 +113,22 @@ ) -@dataclass(frozen=True, kw_only=True) class SensorEntityInfo(BaseEntityInfo): """Sensor entity info.""" - attribute: str decimals: int divisor: int multiplier: int + attribute: str | None = None # LQI and RSSI have no attribute unit: str | None = None device_class: SensorDeviceClass | None = None state_class: SensorStateClass | None = None -@dataclass(frozen=True, kw_only=True) class DeviceCounterEntityInfo(BaseEntityInfo): """Device counter entity info.""" - device_ieee: str + device_ieee: types.EUI64 available: bool counter: str counter_value: int @@ -138,11 +136,10 @@ class DeviceCounterEntityInfo(BaseEntityInfo): counter_group: str -@dataclass(frozen=True, kw_only=True) class DeviceCounterSensorIdentifiers(BaseIdentifiers): """Device counter sensor identifiers.""" - device_ieee: str + device_ieee: types.EUI64 class Sensor(PlatformEntity): @@ -426,8 +423,13 @@ def identifiers(self) -> DeviceCounterSensorIdentifiers: @functools.cached_property def info_object(self) -> DeviceCounterEntityInfo: """Return a representation of the platform entity.""" + data = super().info_object.__dict__ + data.pop("device_ieee") + data.pop("available") return DeviceCounterEntityInfo( - **super().info_object.__dict__, + **data, + device_ieee=self._device.ieee, + available=self._device.available, counter=self._zigpy_counter.name, counter_value=self._zigpy_counter.value, counter_groups=self._zigpy_counter_groups, @@ -782,9 +784,8 @@ def formatter(self, value: int) -> int | None: return round(pow(10, ((value - 1) / 10000))) -@dataclass(frozen=True, kw_only=True) -class SmartEnergyMeteringEntityDescription: - """Dataclass that describes a Zigbee smart energy metering entity.""" +class SmartEnergyMeteringEntityDescription(BaseModel): + """Model that describes a Zigbee smart energy metering entity.""" key: str = "instantaneous_demand" state_class: SensorStateClass | None = SensorStateClass.MEASUREMENT @@ -907,9 +908,8 @@ def formatter(self, value: int) -> int | float: return self._cluster_handler.demand_formatter(value) -@dataclass(frozen=True, kw_only=True) class SmartEnergySummationEntityDescription(SmartEnergyMeteringEntityDescription): - """Dataclass that describes a Zigbee smart energy summation entity.""" + """Model that describes a Zigbee smart energy summation entity.""" key: str = "summation_delivered" state_class: SensorStateClass | None = SensorStateClass.TOTAL_INCREASING diff --git a/zha/application/platforms/siren.py b/zha/application/platforms/siren.py index b5ab76b17..793f11490 100644 --- a/zha/application/platforms/siren.py +++ b/zha/application/platforms/siren.py @@ -4,7 +4,6 @@ import asyncio import contextlib -from dataclasses import dataclass from enum import IntFlag import functools from typing import TYPE_CHECKING, Any, Final, cast @@ -54,7 +53,6 @@ class SirenEntityFeature(IntFlag): DURATION = 16 -@dataclass(frozen=True, kw_only=True) class SirenEntityInfo(BaseEntityInfo): """Siren entity info.""" diff --git a/zha/application/platforms/switch.py b/zha/application/platforms/switch.py index b5f536109..59b7b0a15 100644 --- a/zha/application/platforms/switch.py +++ b/zha/application/platforms/switch.py @@ -3,7 +3,6 @@ from __future__ import annotations from abc import ABC -from dataclasses import dataclass import functools import logging from typing import TYPE_CHECKING, Any, Self, cast @@ -50,7 +49,6 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) class ConfigurableAttributeSwitchInfo(BaseEntityInfo): """Switch configuration entity info.""" diff --git a/zha/application/platforms/update.py b/zha/application/platforms/update.py index d9f58916f..0225f72e9 100644 --- a/zha/application/platforms/update.py +++ b/zha/application/platforms/update.py @@ -2,7 +2,6 @@ from __future__ import annotations -from dataclasses import dataclass from enum import IntFlag, StrEnum import functools import itertools @@ -64,7 +63,6 @@ class UpdateEntityFeature(IntFlag): ATTR_VERSION: Final = "version" -@dataclass(frozen=True, kw_only=True) class UpdateEntityInfo(BaseEntityInfo): """Update entity info.""" diff --git a/zha/model.py b/zha/model.py new file mode 100644 index 000000000..0b446eccc --- /dev/null +++ b/zha/model.py @@ -0,0 +1,62 @@ +"""Shared models for ZHA.""" + +import logging +from typing import Any, Literal, Optional, Union + +from pydantic import ( + BaseModel as PydanticBaseModel, + ConfigDict, + field_serializer, + field_validator, +) +from zigpy.types.named import EUI64 + +_LOGGER = logging.getLogger(__name__) + + +def convert_to_ieee(ieee: Optional[Union[str, EUI64, list]]) -> Optional[EUI64]: + """Convert ieee to EUI64.""" + if ieee is None: + return None + if isinstance(ieee, EUI64): + return ieee + if isinstance(ieee, str): + return EUI64.convert(ieee) + if isinstance(ieee, list): + return EUI64.deserialize(ieee)[0] + return ieee + + +class BaseModel(PydanticBaseModel): + """Base model for ZHA models.""" + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + @field_validator("ieee", "device_ieee", mode="before", check_fields=False) + @classmethod + def convert_ieee(cls, ieee: Optional[Union[str, EUI64, list]]) -> Optional[EUI64]: + """Convert ieee to EUI64.""" + return convert_to_ieee(ieee) + + @field_serializer("ieee", "device_ieee", check_fields=False) + def serialize_ieee(self, ieee): + """Customize how ieee is serialized.""" + if isinstance(ieee, EUI64): + return str(ieee) + return ieee + + @classmethod + def _get_value(cls, *args, **kwargs) -> Any: + """Convert EUI64 to string.""" + value = args[0] + if isinstance(value, EUI64): + return str(value) + return PydanticBaseModel._get_value(cls, *args, **kwargs) + + +class BaseEvent(BaseModel): + """Base model for ZHA events.""" + + message_type: Literal["event"] = "event" + event_type: str + event: str diff --git a/zha/zigbee/cluster_handlers/__init__.py b/zha/zigbee/cluster_handlers/__init__.py index 321b9e194..3860ed2ae 100644 --- a/zha/zigbee/cluster_handlers/__init__.py +++ b/zha/zigbee/cluster_handlers/__init__.py @@ -4,11 +4,10 @@ from collections.abc import Awaitable, Callable, Coroutine, Iterator import contextlib -from dataclasses import dataclass -from enum import Enum +from enum import StrEnum import functools import logging -from typing import TYPE_CHECKING, Any, Final, ParamSpec, TypedDict +from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypedDict import zigpy.exceptions import zigpy.util @@ -18,10 +17,10 @@ ConfigureReportingResponseRecord, Status, ZCLAttributeDef, + ZCLCommandDef, ) from zha.application.const import ( - ZHA_CLUSTER_HANDLER_MSG, ZHA_CLUSTER_HANDLER_MSG_BIND, ZHA_CLUSTER_HANDLER_MSG_CFG_RPT, ) @@ -29,13 +28,13 @@ from zha.event import EventBase from zha.exceptions import ZHAException from zha.mixins import LogMixin +from zha.model import BaseEvent, BaseModel from zha.zigbee.cluster_handlers.const import ( ARGS, ATTRIBUTE_ID, ATTRIBUTE_NAME, ATTRIBUTE_VALUE, CLUSTER_HANDLER_ATTRIBUTE_UPDATED, - CLUSTER_HANDLER_EVENT, CLUSTER_HANDLER_ZDO, CLUSTER_ID, CLUSTER_READS_PER_REQ, @@ -114,16 +113,15 @@ def parse_and_log_command(cluster_handler, tsn, command_id, args): return name -class ClusterHandlerStatus(Enum): +class ClusterHandlerStatus(StrEnum): """Status of a cluster handler.""" - CREATED = 1 - CONFIGURED = 2 - INITIALIZED = 3 + CREATED = "created" + CONFIGURED = "configured" + INITIALIZED = "initialized" -@dataclass(kw_only=True, frozen=True) -class ClusterAttributeUpdatedEvent: +class ClusterAttributeUpdatedEvent(BaseEvent): """Event to signal that a cluster attribute has been updated.""" attribute_id: int @@ -131,51 +129,51 @@ class ClusterAttributeUpdatedEvent: attribute_value: Any cluster_handler_unique_id: str cluster_id: int - event_type: Final[str] = CLUSTER_HANDLER_EVENT - event: Final[str] = CLUSTER_HANDLER_ATTRIBUTE_UPDATED + event_type: Literal["cluster_handler_event"] = "cluster_handler_event" + event: Literal["cluster_handler_attribute_updated"] = ( + "cluster_handler_attribute_updated" + ) -@dataclass(kw_only=True, frozen=True) -class ClusterBindEvent: +class ClusterBindEvent(BaseEvent): """Event generated when the cluster is bound.""" cluster_name: str cluster_id: int success: bool cluster_handler_unique_id: str - event_type: Final[str] = ZHA_CLUSTER_HANDLER_MSG - event: Final[str] = ZHA_CLUSTER_HANDLER_MSG_BIND + event_type: Literal["zha_channel_message"] = "zha_channel_message" + event: Literal["zha_channel_bind"] = "zha_channel_bind" -@dataclass(kw_only=True, frozen=True) -class ClusterConfigureReportingEvent: +class ClusterConfigureReportingEvent(BaseEvent): """Event generates when a cluster configures attribute reporting.""" cluster_name: str cluster_id: int attributes: dict[str, dict[str, Any]] cluster_handler_unique_id: str - event_type: Final[str] = ZHA_CLUSTER_HANDLER_MSG - event: Final[str] = ZHA_CLUSTER_HANDLER_MSG_CFG_RPT + event_type: Literal["zha_channel_message"] = "zha_channel_message" + event: Literal["zha_channel_configure_reporting"] = ( + "zha_channel_configure_reporting" + ) -@dataclass(kw_only=True, frozen=True) -class ClusterInfo: +class ClusterInfo(BaseModel): """Cluster information.""" id: int name: str type: str - commands: dict[int, str] + commands: list[ZCLCommandDef] -@dataclass(kw_only=True, frozen=True) -class ClusterHandlerInfo: +class ClusterHandlerInfo(BaseModel): """Cluster handler information.""" class_name: str generic_id: str - endpoint_id: str + endpoint_id: int cluster: ClusterInfo id: str unique_id: str @@ -232,7 +230,7 @@ def info_object(self) -> ClusterHandlerInfo: ), id=self._id, unique_id=self._unique_id, - status=self._status.name, + status=self._status, value_attribute=getattr(self, "value_attribute", None), ) @@ -547,7 +545,7 @@ async def async_update(self) -> None: def _get_attribute_name(self, attrid: int) -> str | int: if attrid not in self.cluster.attributes: - return attrid + return "Unknown" return self.cluster.attributes[attrid].name diff --git a/zha/zigbee/cluster_handlers/general.py b/zha/zigbee/cluster_handlers/general.py index e103f1199..d9ce799f2 100644 --- a/zha/zigbee/cluster_handlers/general.py +++ b/zha/zigbee/cluster_handlers/general.py @@ -4,9 +4,8 @@ import asyncio from collections.abc import Coroutine -from dataclasses import dataclass from datetime import datetime -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Literal from zhaquirks.quirk_ids import TUYA_PLUG_ONOFF import zigpy.exceptions @@ -45,6 +44,7 @@ from zigpy.zcl.foundation import Status from zha.exceptions import ZHAException +from zha.model import BaseEvent from zha.zigbee.cluster_handlers import ( AttrReportConfig, ClientClusterHandler, @@ -69,13 +69,12 @@ from zha.zigbee.endpoint import Endpoint -@dataclass(frozen=True, kw_only=True) -class LevelChangeEvent: +class LevelChangeEvent(BaseEvent): """Event to signal that a cluster attribute has been updated.""" level: int event: str - event_type: Final[str] = "cluster_handler_event" + event_type: Literal["cluster_handler_event"] = "cluster_handler_event" @registries.CLUSTER_HANDLER_REGISTRY.register(Alarms.cluster_id) diff --git a/zha/zigbee/cluster_handlers/security.py b/zha/zigbee/cluster_handlers/security.py index ea9d364c4..cef213e02 100644 --- a/zha/zigbee/cluster_handlers/security.py +++ b/zha/zigbee/cluster_handlers/security.py @@ -3,8 +3,7 @@ from __future__ import annotations from collections.abc import Callable -import dataclasses -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Literal import zigpy.zcl from zigpy.zcl.clusters.security import ( @@ -18,6 +17,7 @@ ) from zha.exceptions import ZHAException +from zha.model import BaseEvent from zha.zigbee.cluster_handlers import ClusterHandler, ClusterHandlerStatus, registries from zha.zigbee.cluster_handlers.const import CLUSTER_HANDLER_STATE_CHANGED @@ -28,12 +28,11 @@ SIGNAL_ALARM_TRIGGERED = "zha_armed_triggered" -@dataclasses.dataclass(frozen=True, kw_only=True) -class ClusterHandlerStateChangedEvent: +class ClusterHandlerStateChangedEvent(BaseEvent): """Event to signal that a cluster attribute has been updated.""" - event_type: Final[str] = "cluster_handler_event" - event: Final[str] = "cluster_handler_state_changed" + event_type: Literal["cluster_handler_event"] = "cluster_handler_event" + event: Literal["cluster_handler_state_changed"] = "cluster_handler_state_changed" @registries.CLUSTER_HANDLER_REGISTRY.register(AceCluster.cluster_id) diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index c86a6c3aa..be1f87afc 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -5,12 +5,11 @@ from __future__ import annotations import asyncio -from dataclasses import dataclass -from enum import Enum +from enum import StrEnum from functools import cached_property import logging import time -from typing import TYPE_CHECKING, Any, Final, Self +from typing import TYPE_CHECKING, Any, Literal, Self from zigpy.device import Device as ZigpyDevice import zigpy.exceptions @@ -55,7 +54,6 @@ UNKNOWN_MANUFACTURER, UNKNOWN_MODEL, ZHA_CLUSTER_HANDLER_CFG_DONE, - ZHA_CLUSTER_HANDLER_MSG, ZHA_EVENT, ) from zha.application.helpers import convert_to_zcl_values @@ -63,6 +61,7 @@ from zha.event import EventBase from zha.exceptions import ZHAException from zha.mixins import LogMixin +from zha.model import BaseEvent, BaseModel from zha.zigbee.cluster_handlers import ClusterHandler, ZDOClusterHandler from zha.zigbee.endpoint import Endpoint @@ -83,46 +82,42 @@ def get_device_automation_triggers( } -@dataclass(frozen=True, kw_only=True) -class ClusterBinding: - """Describes a cluster binding.""" - - name: str - type: str - id: int - endpoint_id: int - - -class DeviceStatus(Enum): +class DeviceStatus(StrEnum): """Status of a device.""" - CREATED = 1 - INITIALIZED = 2 + CREATED = "created" + INITIALIZED = "initialized" -@dataclass(kw_only=True, frozen=True) -class ZHAEvent: +class ZHAEvent(BaseEvent): """Event generated when a device wishes to send an arbitrary event.""" device_ieee: EUI64 unique_id: str data: dict[str, Any] - event_type: Final[str] = ZHA_EVENT - event: Final[str] = ZHA_EVENT + event_type: Literal["zha_event"] = "zha_event" + event: Literal["zha_event"] = "zha_event" -@dataclass(kw_only=True, frozen=True) -class ClusterHandlerConfigurationComplete: +class ClusterHandlerConfigurationComplete(BaseEvent): """Event generated when all cluster handlers are configured.""" device_ieee: EUI64 unique_id: str - event_type: Final[str] = ZHA_CLUSTER_HANDLER_MSG - event: Final[str] = ZHA_CLUSTER_HANDLER_CFG_DONE + event_type: Literal["zha_channel_message"] = "zha_channel_message" + event: Literal["zha_channel_cfg_done"] = "zha_channel_cfg_done" + + +class ClusterBinding(BaseModel): + """Describes a cluster binding.""" + + name: str + type: str + id: int + endpoint_id: int -@dataclass(kw_only=True, frozen=True) -class DeviceInfo: +class DeviceInfo(BaseModel): """Describes a device.""" ieee: EUI64 @@ -135,16 +130,15 @@ class DeviceInfo: quirk_id: str | None manufacturer_code: int | None power_source: str - lqi: int - rssi: int + lqi: int | None + rssi: int | None last_seen: str available: bool device_type: str signature: dict[str, Any] -@dataclass(kw_only=True, frozen=True) -class NeighborInfo: +class NeighborInfo(BaseModel): """Describes a neighbor.""" device_type: _NeighborEnums.DeviceType @@ -158,8 +152,7 @@ class NeighborInfo: lqi: uint8_t -@dataclass(kw_only=True, frozen=True) -class RouteInfo: +class RouteInfo(BaseModel): """Describes a route.""" dest_nwk: NWK @@ -170,14 +163,12 @@ class RouteInfo: next_hop: NWK -@dataclass(kw_only=True, frozen=True) -class EndpointNameInfo: +class EndpointNameInfo(BaseModel): """Describes an endpoint name.""" name: str -@dataclass(kw_only=True, frozen=True) class ExtendedDeviceInfo(DeviceInfo): """Describes a ZHA device.""" @@ -567,11 +558,11 @@ async def _check_available(self, *_: Any) -> None: "Attempting to checkin with device - missed checkins: %s", self._checkins_missed_count, ) - if not self.basic_ch: + if not self._basic_ch: self.debug("does not have a mandatory basic cluster") self.update_available(False) return - res = await self.basic_ch.get_attribute_value( + res = await self._basic_ch.get_attribute_value( ATTR_MANUFACTURER, from_cache=False ) if res is not None: @@ -732,7 +723,7 @@ async def async_configure(self) -> None: ZHA_CLUSTER_HANDLER_CFG_DONE, ClusterHandlerConfigurationComplete( device_ieee=self.ieee, - unique_id=self.ieee, + unique_id=self.unique_id, ), ) @@ -740,10 +731,10 @@ async def async_configure(self) -> None: if ( should_identify - and self.identify_ch is not None + and self._identify_ch is not None and not self.skip_configuration ): - await self.identify_ch.trigger_effect( + await self._identify_ch.trigger_effect( effect_id=Identify.EffectIdentifier.Okay, effect_variant=Identify.EffectVariant.Default, ) diff --git a/zha/zigbee/group.py b/zha/zigbee/group.py index 4ec96a7f2..057b4d984 100644 --- a/zha/zigbee/group.py +++ b/zha/zigbee/group.py @@ -4,7 +4,6 @@ import asyncio from collections.abc import Callable -from dataclasses import dataclass from functools import cached_property import logging from typing import TYPE_CHECKING, Any @@ -19,6 +18,7 @@ ) from zha.const import STATE_CHANGED from zha.mixins import LogMixin +from zha.model import BaseModel from zha.zigbee.device import ExtendedDeviceInfo if TYPE_CHECKING: @@ -31,25 +31,22 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) -class GroupMemberReference: +class GroupMemberReference(BaseModel): """Describes a group member.""" ieee: EUI64 endpoint_id: int -@dataclass(frozen=True, kw_only=True) -class GroupEntityReference: +class GroupEntityReference(BaseModel): """Reference to a group entity.""" - entity_id: int + entity_id: str name: str | None = None original_name: str | None = None -@dataclass(frozen=True, kw_only=True) -class GroupMemberInfo: +class GroupMemberInfo(BaseModel): """Describes a group member.""" ieee: EUI64 @@ -58,8 +55,7 @@ class GroupMemberInfo: entities: dict[str, BaseEntityInfo] -@dataclass(frozen=True, kw_only=True) -class GroupInfo: +class GroupInfo(BaseModel): """Describes a group.""" group_id: int From 2e1e1bb7f3be015bd5a7ee17c183d153f6393120 Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Thu, 17 Oct 2024 16:12:54 -0400 Subject: [PATCH 02/12] clean up base models and add test --- tests/test_model.py | 85 +++++++++++++++++++++++++++++++++++++++++++++ zha/model.py | 47 ++++++++++--------------- 2 files changed, 103 insertions(+), 29 deletions(-) create mode 100644 tests/test_model.py diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 000000000..604cf9d00 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,85 @@ +"""Tests for the ZHA model module.""" + +from zigpy.types import NWK +from zigpy.types.named import EUI64 + +from zha.zigbee.device import DeviceInfo, ZHAEvent + + +def test_ser_deser_zha_event(): + """Test serializing and deserializing ZHA events.""" + + zha_event = ZHAEvent( + device_ieee="00:00:00:00:00:00:00:00", + unique_id="00:00:00:00:00:00:00:00", + data={"key": "value"}, + ) + + assert isinstance(zha_event.device_ieee, EUI64) + assert zha_event.device_ieee == EUI64.convert("00:00:00:00:00:00:00:00") + assert zha_event.unique_id == "00:00:00:00:00:00:00:00" + assert zha_event.data == {"key": "value"} + + assert zha_event.model_dump() == { + "message_type": "event", + "event_type": "zha_event", + "event": "zha_event", + "device_ieee": "00:00:00:00:00:00:00:00", + "unique_id": "00:00:00:00:00:00:00:00", + "data": {"key": "value"}, + } + + assert ( + zha_event.model_dump_json() + == '{"message_type":"event","event_type":"zha_event","event":"zha_event",' + '"device_ieee":"00:00:00:00:00:00:00:00","unique_id":"00:00:00:00:00:00:00:00","data":{"key":"value"}}' + ) + + device_info = DeviceInfo( + ieee="00:00:00:00:00:00:00:00", + nwk=0x0000, + manufacturer="test", + model="test", + name="test", + quirk_applied=True, + quirk_class="test", + quirk_id="test", + manufacturer_code=0x0000, + power_source="test", + lqi=1, + rssi=2, + last_seen="", + available=True, + device_type="test", + signature={"foo": "bar"}, + ) + + assert isinstance(device_info.ieee, EUI64) + assert device_info.ieee == EUI64.convert("00:00:00:00:00:00:00:00") + assert isinstance(device_info.nwk, NWK) + + assert device_info.model_dump() == { + "ieee": "00:00:00:00:00:00:00:00", + "nwk": 0, + "manufacturer": "test", + "model": "test", + "name": "test", + "quirk_applied": True, + "quirk_class": "test", + "quirk_id": "test", + "manufacturer_code": 0, + "power_source": "test", + "lqi": 1, + "rssi": 2, + "last_seen": "", + "available": True, + "device_type": "test", + "signature": {"foo": "bar"}, + } + + assert device_info.model_dump_json() == ( + '{"ieee":"00:00:00:00:00:00:00:00","nwk":0,' + '"manufacturer":"test","model":"test","name":"test","quirk_applied":true,' + '"quirk_class":"test","quirk_id":"test","manufacturer_code":0,"power_source":"test",' + '"lqi":1,"rssi":2,"last_seen":"","available":true,"device_type":"test","signature":{"foo":"bar"}}' + ) diff --git a/zha/model.py b/zha/model.py index 0b446eccc..eb366603d 100644 --- a/zha/model.py +++ b/zha/model.py @@ -1,7 +1,7 @@ """Shared models for ZHA.""" import logging -from typing import Any, Literal, Optional, Union +from typing import Literal, Optional, Union from pydantic import ( BaseModel as PydanticBaseModel, @@ -9,24 +9,11 @@ field_serializer, field_validator, ) -from zigpy.types.named import EUI64 +from zigpy.types.named import EUI64, NWK _LOGGER = logging.getLogger(__name__) -def convert_to_ieee(ieee: Optional[Union[str, EUI64, list]]) -> Optional[EUI64]: - """Convert ieee to EUI64.""" - if ieee is None: - return None - if isinstance(ieee, EUI64): - return ieee - if isinstance(ieee, str): - return EUI64.convert(ieee) - if isinstance(ieee, list): - return EUI64.deserialize(ieee)[0] - return ieee - - class BaseModel(PydanticBaseModel): """Base model for ZHA models.""" @@ -34,24 +21,26 @@ class BaseModel(PydanticBaseModel): @field_validator("ieee", "device_ieee", mode="before", check_fields=False) @classmethod - def convert_ieee(cls, ieee: Optional[Union[str, EUI64, list]]) -> Optional[EUI64]: + def convert_ieee(cls, ieee: Optional[Union[str, EUI64]]) -> Optional[EUI64]: """Convert ieee to EUI64.""" - return convert_to_ieee(ieee) - - @field_serializer("ieee", "device_ieee", check_fields=False) - def serialize_ieee(self, ieee): - """Customize how ieee is serialized.""" - if isinstance(ieee, EUI64): - return str(ieee) + if ieee is None: + return None + if isinstance(ieee, str): + return EUI64.convert(ieee) return ieee + @field_validator("nwk", mode="before", check_fields=False) @classmethod - def _get_value(cls, *args, **kwargs) -> Any: - """Convert EUI64 to string.""" - value = args[0] - if isinstance(value, EUI64): - return str(value) - return PydanticBaseModel._get_value(cls, *args, **kwargs) + def convert_nwk(cls, nwk: Optional[Union[int, NWK]]) -> Optional[NWK]: + """Convert int to NWK.""" + if isinstance(nwk, int) and not isinstance(nwk, NWK): + return NWK(nwk) + return nwk + + @field_serializer("ieee", "device_ieee", check_fields=False) + def serialize_ieee(self, ieee: EUI64): + """Customize how ieee is serialized.""" + return str(ieee) class BaseEvent(BaseModel): From 406d13d47da1fed8fdd6f01f626df580de1a035c Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Fri, 18 Oct 2024 09:52:07 -0400 Subject: [PATCH 03/12] make validators shareable --- zha/model.py | 66 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/zha/model.py b/zha/model.py index eb366603d..347f01fa6 100644 --- a/zha/model.py +++ b/zha/model.py @@ -1,7 +1,9 @@ """Shared models for ZHA.""" +from collections.abc import Callable +from enum import Enum import logging -from typing import Literal, Optional, Union +from typing import Any, Literal, Optional, Union from pydantic import ( BaseModel as PydanticBaseModel, @@ -14,28 +16,56 @@ _LOGGER = logging.getLogger(__name__) +def convert_ieee(ieee: Optional[Union[str, EUI64]]) -> Optional[EUI64]: + """Convert ieee to EUI64.""" + if ieee is None: + return None + if isinstance(ieee, str): + return EUI64.convert(ieee) + return ieee + + +def convert_nwk(nwk: Optional[Union[int, NWK]]) -> Optional[NWK]: + """Convert int to NWK.""" + if isinstance(nwk, int) and not isinstance(nwk, NWK): + return NWK(nwk) + return nwk + + +def convert_enum(enum_type: Enum) -> Callable[[str | Enum], Enum]: + """Convert enum name to enum instance.""" + + def _convert_enum(enum_name_or_instance: str | Enum) -> Enum: + """Convert extended_pan_id to ExtendedPanId.""" + if isinstance(enum_name_or_instance, str): + return enum_type(enum_name_or_instance) # type: ignore + return enum_name_or_instance + + return _convert_enum + + +def convert_int(zigpy_type: type) -> Any: + """Convert int to zigpy type.""" + + def _convert_int(value: int) -> Any: + """Convert int to zigpy type.""" + return zigpy_type(value) + + return _convert_int + + class BaseModel(PydanticBaseModel): """Base model for ZHA models.""" model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") - @field_validator("ieee", "device_ieee", mode="before", check_fields=False) - @classmethod - def convert_ieee(cls, ieee: Optional[Union[str, EUI64]]) -> Optional[EUI64]: - """Convert ieee to EUI64.""" - if ieee is None: - return None - if isinstance(ieee, str): - return EUI64.convert(ieee) - return ieee - - @field_validator("nwk", mode="before", check_fields=False) - @classmethod - def convert_nwk(cls, nwk: Optional[Union[int, NWK]]) -> Optional[NWK]: - """Convert int to NWK.""" - if isinstance(nwk, int) and not isinstance(nwk, NWK): - return NWK(nwk) - return nwk + _convert_ieee = field_validator( + "ieee", "device_ieee", mode="before", check_fields=False + )(convert_ieee) + + _convert_nwk = field_validator( + "nwk", "dest_nwk", "next_hop", mode="before", check_fields=False + )(convert_nwk) @field_serializer("ieee", "device_ieee", check_fields=False) def serialize_ieee(self, ieee: EUI64): From a83ad3f25984ad7cfe32c85ebc0c22d4609ce759 Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Fri, 18 Oct 2024 09:52:43 -0400 Subject: [PATCH 04/12] add validators and serializers for device models --- tests/test_device.py | 81 +++++++++++++++++++++++++++++++++++++++++++- zha/zigbee/device.py | 80 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 157 insertions(+), 4 deletions(-) diff --git a/tests/test_device.py b/tests/test_device.py index 7e9c0dfcd..24fdcf0a6 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -36,7 +36,13 @@ from zha.application.platforms.sensor import LQISensor, RSSISensor from zha.application.platforms.switch import Switch from zha.exceptions import ZHAException -from zha.zigbee.device import ClusterBinding, Device, get_device_automation_triggers +from zha.zigbee.device import ( + ClusterBinding, + Device, + NeighborInfo, + RouteInfo, + get_device_automation_triggers, +) from zha.zigbee.group import Group @@ -848,3 +854,76 @@ async def test_device_properties( assert zha_device.is_router is None assert zha_device.is_end_device is None assert zha_device.is_coordinator is None + + +def test_neighbor_info_ser_deser() -> None: + """Test the serialization and deserialization of the neighbor info.""" + + neighbor_info = NeighborInfo( + ieee="00:0d:6f:00:0a:90:69:e7", + nwk=0x1234, + extended_pan_id="00:0d:6f:00:0a:90:69:e7", + lqi=255, + relationship=zdo_t._NeighborEnums.Relationship.Child.name, + depth=0, + device_type=zdo_t._NeighborEnums.DeviceType.Router.name, + rx_on_when_idle=zdo_t._NeighborEnums.RxOnWhenIdle.On.name, + permit_joining=zdo_t._NeighborEnums.PermitJoins.Accepting.name, + ) + + assert isinstance(neighbor_info.ieee, zigpy.types.EUI64) + assert isinstance(neighbor_info.nwk, zigpy.types.NWK) + assert isinstance(neighbor_info.extended_pan_id, zigpy.types.EUI64) + assert isinstance(neighbor_info.relationship, zdo_t._NeighborEnums.Relationship) + assert isinstance(neighbor_info.device_type, zdo_t._NeighborEnums.DeviceType) + assert isinstance(neighbor_info.rx_on_when_idle, zdo_t._NeighborEnums.RxOnWhenIdle) + assert isinstance(neighbor_info.permit_joining, zdo_t._NeighborEnums.PermitJoins) + + assert neighbor_info.model_dump() == { + "ieee": "00:0d:6f:00:0a:90:69:e7", + "nwk": 0x1234, + "extended_pan_id": "00:0d:6f:00:0a:90:69:e7", + "lqi": 255, + "relationship": zdo_t._NeighborEnums.Relationship.Child.name, + "depth": 0, + "device_type": zdo_t._NeighborEnums.DeviceType.Router.name, + "rx_on_when_idle": zdo_t._NeighborEnums.RxOnWhenIdle.On.name, + "permit_joining": zdo_t._NeighborEnums.PermitJoins.Accepting.name, + } + + assert neighbor_info.model_dump_json() == ( + '{"device_type":"Router","rx_on_when_idle":"On","relationship":"Child",' + '"extended_pan_id":"00:0d:6f:00:0a:90:69:e7","ieee":"00:0d:6f:00:0a:90:69:e7","nwk":4660,' + '"permit_joining":"Accepting","depth":0,"lqi":255}' + ) + + +def test_route_info_ser_deser() -> None: + """Test the serialization and deserialization of the route info.""" + + route_info = RouteInfo( + dest_nwk=0x1234, + next_hop=0x5678, + route_status=zdo_t.RouteStatus.Active.name, + memory_constrained=0, + many_to_one=1, + route_record_required=1, + ) + + assert isinstance(route_info.dest_nwk, zigpy.types.NWK) + assert isinstance(route_info.next_hop, zigpy.types.NWK) + assert isinstance(route_info.route_status, zdo_t.RouteStatus) + + assert route_info.model_dump() == { + "dest_nwk": 0x1234, + "next_hop": 0x5678, + "route_status": zdo_t.RouteStatus.Active.name, + "memory_constrained": 0, + "many_to_one": 1, + "route_record_required": 1, + } + + assert route_info.model_dump_json() == ( + '{"dest_nwk":4660,"route_status":"Active","memory_constrained":0,"many_to_one":1,' + '"route_record_required":1,"next_hop":22136}' + ) diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index be1f87afc..7c2168522 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -5,12 +5,13 @@ from __future__ import annotations import asyncio -from enum import StrEnum +from enum import Enum, StrEnum from functools import cached_property import logging import time -from typing import TYPE_CHECKING, Any, Literal, Self +from typing import TYPE_CHECKING, Any, Literal, Self, Union +from pydantic import field_serializer, field_validator from zigpy.device import Device as ZigpyDevice import zigpy.exceptions from zigpy.profiles import PROFILES @@ -61,7 +62,7 @@ from zha.event import EventBase from zha.exceptions import ZHAException from zha.mixins import LogMixin -from zha.model import BaseEvent, BaseModel +from zha.model import BaseEvent, BaseModel, convert_enum, convert_int from zha.zigbee.cluster_handlers import ClusterHandler, ZDOClusterHandler from zha.zigbee.endpoint import Endpoint @@ -151,6 +152,55 @@ class NeighborInfo(BaseModel): depth: uint8_t lqi: uint8_t + _convert_device_type = field_validator( + "device_type", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.DeviceType)) + + _convert_rx_on_when_idle = field_validator( + "rx_on_when_idle", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.RxOnWhenIdle)) + + _convert_relationship = field_validator( + "relationship", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.Relationship)) + + _convert_permit_joining = field_validator( + "permit_joining", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.PermitJoins)) + + _convert_depth = field_validator("depth", mode="before", check_fields=False)( + convert_int(uint8_t) + ) + _convert_lqi = field_validator("lqi", mode="before", check_fields=False)( + convert_int(uint8_t) + ) + + @field_validator("extended_pan_id", mode="before", check_fields=False) + @classmethod + def convert_extended_pan_id( + cls, extended_pan_id: Union[str, ExtendedPanId] + ) -> ExtendedPanId: + """Convert extended_pan_id to ExtendedPanId.""" + if isinstance(extended_pan_id, str): + return ExtendedPanId.convert(extended_pan_id) + return extended_pan_id + + @field_serializer("extended_pan_id", check_fields=False) + def serialize_extended_pan_id(self, extended_pan_id: ExtendedPanId): + """Customize how extended_pan_id is serialized.""" + return str(extended_pan_id) + + @field_serializer( + "device_type", + "rx_on_when_idle", + "relationship", + "permit_joining", + check_fields=False, + ) + def serialize_enums(self, enum_value: Enum): + """Serialize enums by name.""" + return enum_value.name + class RouteInfo(BaseModel): """Describes a route.""" @@ -162,6 +212,30 @@ class RouteInfo(BaseModel): route_record_required: uint1_t next_hop: NWK + _convert_route_status = field_validator( + "route_status", mode="before", check_fields=False + )(convert_enum(RouteStatus)) + + _convert_memory_constrained = field_validator( + "memory_constrained", mode="before", check_fields=False + )(convert_int(uint1_t)) + + _convert_many_to_one = field_validator( + "many_to_one", mode="before", check_fields=False + )(convert_int(uint1_t)) + + _convert_route_record_required = field_validator( + "route_record_required", mode="before", check_fields=False + )(convert_int(uint1_t)) + + @field_serializer( + "route_status", + check_fields=False, + ) + def serialize_route_status(self, route_status: RouteStatus): + """Serialize route_status as name.""" + return route_status.name + class EndpointNameInfo(BaseModel): """Describes an endpoint name.""" From e7b65fee7add775903055b45dfe1881613d9ed3e Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Fri, 18 Oct 2024 10:20:09 -0400 Subject: [PATCH 05/12] use hex repr for nwk --- tests/test_device.py | 14 +++++++------- tests/test_model.py | 6 +++--- zha/model.py | 9 ++++++++- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/tests/test_device.py b/tests/test_device.py index 24fdcf0a6..958b2d9fe 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -861,7 +861,7 @@ def test_neighbor_info_ser_deser() -> None: neighbor_info = NeighborInfo( ieee="00:0d:6f:00:0a:90:69:e7", - nwk=0x1234, + nwk="0x1234", extended_pan_id="00:0d:6f:00:0a:90:69:e7", lqi=255, relationship=zdo_t._NeighborEnums.Relationship.Child.name, @@ -881,7 +881,7 @@ def test_neighbor_info_ser_deser() -> None: assert neighbor_info.model_dump() == { "ieee": "00:0d:6f:00:0a:90:69:e7", - "nwk": 0x1234, + "nwk": "0x1234", "extended_pan_id": "00:0d:6f:00:0a:90:69:e7", "lqi": 255, "relationship": zdo_t._NeighborEnums.Relationship.Child.name, @@ -893,7 +893,7 @@ def test_neighbor_info_ser_deser() -> None: assert neighbor_info.model_dump_json() == ( '{"device_type":"Router","rx_on_when_idle":"On","relationship":"Child",' - '"extended_pan_id":"00:0d:6f:00:0a:90:69:e7","ieee":"00:0d:6f:00:0a:90:69:e7","nwk":4660,' + '"extended_pan_id":"00:0d:6f:00:0a:90:69:e7","ieee":"00:0d:6f:00:0a:90:69:e7","nwk":"0x1234",' '"permit_joining":"Accepting","depth":0,"lqi":255}' ) @@ -915,8 +915,8 @@ def test_route_info_ser_deser() -> None: assert isinstance(route_info.route_status, zdo_t.RouteStatus) assert route_info.model_dump() == { - "dest_nwk": 0x1234, - "next_hop": 0x5678, + "dest_nwk": "0x1234", + "next_hop": "0x5678", "route_status": zdo_t.RouteStatus.Active.name, "memory_constrained": 0, "many_to_one": 1, @@ -924,6 +924,6 @@ def test_route_info_ser_deser() -> None: } assert route_info.model_dump_json() == ( - '{"dest_nwk":4660,"route_status":"Active","memory_constrained":0,"many_to_one":1,' - '"route_record_required":1,"next_hop":22136}' + '{"dest_nwk":"0x1234","route_status":"Active","memory_constrained":0,"many_to_one":1,' + '"route_record_required":1,"next_hop":"0x5678"}' ) diff --git a/tests/test_model.py b/tests/test_model.py index 604cf9d00..64a9fb09e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -37,7 +37,7 @@ def test_ser_deser_zha_event(): device_info = DeviceInfo( ieee="00:00:00:00:00:00:00:00", - nwk=0x0000, + nwk="0x0000", manufacturer="test", model="test", name="test", @@ -60,7 +60,7 @@ def test_ser_deser_zha_event(): assert device_info.model_dump() == { "ieee": "00:00:00:00:00:00:00:00", - "nwk": 0, + "nwk": "0x0000", "manufacturer": "test", "model": "test", "name": "test", @@ -78,7 +78,7 @@ def test_ser_deser_zha_event(): } assert device_info.model_dump_json() == ( - '{"ieee":"00:00:00:00:00:00:00:00","nwk":0,' + '{"ieee":"00:00:00:00:00:00:00:00","nwk":"0x0000",' '"manufacturer":"test","model":"test","name":"test","quirk_applied":true,' '"quirk_class":"test","quirk_id":"test","manufacturer_code":0,"power_source":"test",' '"lqi":1,"rssi":2,"last_seen":"","available":true,"device_type":"test","signature":{"foo":"bar"}}' diff --git a/zha/model.py b/zha/model.py index 347f01fa6..b70b5284c 100644 --- a/zha/model.py +++ b/zha/model.py @@ -25,10 +25,12 @@ def convert_ieee(ieee: Optional[Union[str, EUI64]]) -> Optional[EUI64]: return ieee -def convert_nwk(nwk: Optional[Union[int, NWK]]) -> Optional[NWK]: +def convert_nwk(nwk: Optional[Union[int, str, NWK]]) -> Optional[NWK]: """Convert int to NWK.""" if isinstance(nwk, int) and not isinstance(nwk, NWK): return NWK(nwk) + if isinstance(nwk, str): + return NWK(int(nwk, base=16)) return nwk @@ -72,6 +74,11 @@ def serialize_ieee(self, ieee: EUI64): """Customize how ieee is serialized.""" return str(ieee) + @field_serializer("nwk", "dest_nwk", "next_hop", check_fields=False) + def serialize_nwk(self, nwk: NWK): + """Serialize nwk as hex string.""" + return repr(nwk) + class BaseEvent(BaseModel): """Base model for ZHA events.""" From c3c52bd448afa28e60bf59ab2bce332eb3eb2940 Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Fri, 18 Oct 2024 10:27:46 -0400 Subject: [PATCH 06/12] only use nwk hex repr for json dump --- tests/test_device.py | 6 +++--- tests/test_model.py | 2 +- zha/model.py | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_device.py b/tests/test_device.py index 958b2d9fe..93fce3311 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -881,7 +881,7 @@ def test_neighbor_info_ser_deser() -> None: assert neighbor_info.model_dump() == { "ieee": "00:0d:6f:00:0a:90:69:e7", - "nwk": "0x1234", + "nwk": 0x1234, "extended_pan_id": "00:0d:6f:00:0a:90:69:e7", "lqi": 255, "relationship": zdo_t._NeighborEnums.Relationship.Child.name, @@ -915,8 +915,8 @@ def test_route_info_ser_deser() -> None: assert isinstance(route_info.route_status, zdo_t.RouteStatus) assert route_info.model_dump() == { - "dest_nwk": "0x1234", - "next_hop": "0x5678", + "dest_nwk": 0x1234, + "next_hop": 0x5678, "route_status": zdo_t.RouteStatus.Active.name, "memory_constrained": 0, "many_to_one": 1, diff --git a/tests/test_model.py b/tests/test_model.py index 64a9fb09e..bea0b679c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -60,7 +60,7 @@ def test_ser_deser_zha_event(): assert device_info.model_dump() == { "ieee": "00:00:00:00:00:00:00:00", - "nwk": "0x0000", + "nwk": 0x0000, "manufacturer": "test", "model": "test", "name": "test", diff --git a/zha/model.py b/zha/model.py index b70b5284c..5cd582efa 100644 --- a/zha/model.py +++ b/zha/model.py @@ -74,7 +74,9 @@ def serialize_ieee(self, ieee: EUI64): """Customize how ieee is serialized.""" return str(ieee) - @field_serializer("nwk", "dest_nwk", "next_hop", check_fields=False) + @field_serializer( + "nwk", "dest_nwk", "next_hop", when_used="json", check_fields=False + ) def serialize_nwk(self, nwk: NWK): """Serialize nwk as hex string.""" return repr(nwk) From 93b27b9940547dff47a6db9fe6ebc593f6e8a648 Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Fri, 18 Oct 2024 11:10:49 -0400 Subject: [PATCH 07/12] coverage --- tests/test_device.py | 14 ++++++++++++++ tests/test_model.py | 19 +++++++++++++++++++ zha/model.py | 2 +- 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/test_device.py b/tests/test_device.py index 93fce3311..0c6f0cc96 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -927,3 +927,17 @@ def test_route_info_ser_deser() -> None: '{"dest_nwk":"0x1234","route_status":"Active","memory_constrained":0,"many_to_one":1,' '"route_record_required":1,"next_hop":"0x5678"}' ) + + +def test_convert_extended_pan_id() -> None: + """Test conversion of extended panid.""" + + extended_pan_id = zigpy.types.ExtendedPanId.convert("00:0d:6f:00:0a:90:69:e7") + + assert NeighborInfo.convert_extended_pan_id(extended_pan_id) == extended_pan_id + + converted_extended_pan_id = NeighborInfo.convert_extended_pan_id( + "00:0d:6f:00:0a:90:69:e7" + ) + assert isinstance(converted_extended_pan_id, zigpy.types.ExtendedPanId) + assert converted_extended_pan_id == extended_pan_id diff --git a/tests/test_model.py b/tests/test_model.py index bea0b679c..9203959f0 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,8 +1,12 @@ """Tests for the ZHA model module.""" +from collections.abc import Callable +from enum import Enum + from zigpy.types import NWK from zigpy.types.named import EUI64 +from zha.model import convert_enum from zha.zigbee.device import DeviceInfo, ZHAEvent @@ -83,3 +87,18 @@ def test_ser_deser_zha_event(): '"quirk_class":"test","quirk_id":"test","manufacturer_code":0,"power_source":"test",' '"lqi":1,"rssi":2,"last_seen":"","available":true,"device_type":"test","signature":{"foo":"bar"}}' ) + + +def test_convert_enum() -> None: + """Test the convert enum method.""" + + class TestEnum(Enum): + """Test enum.""" + + VALUE = 1 + + convert_test_enum: Callable[[str | Enum], Enum] = convert_enum(TestEnum) + + assert convert_test_enum(TestEnum.VALUE.name) == TestEnum.VALUE + assert isinstance(convert_test_enum(TestEnum.VALUE.name), TestEnum) + assert convert_test_enum(TestEnum.VALUE) == TestEnum.VALUE diff --git a/zha/model.py b/zha/model.py index 5cd582efa..0edfd8d66 100644 --- a/zha/model.py +++ b/zha/model.py @@ -40,7 +40,7 @@ def convert_enum(enum_type: Enum) -> Callable[[str | Enum], Enum]: def _convert_enum(enum_name_or_instance: str | Enum) -> Enum: """Convert extended_pan_id to ExtendedPanId.""" if isinstance(enum_name_or_instance, str): - return enum_type(enum_name_or_instance) # type: ignore + return enum_type[enum_name_or_instance] # type: ignore return enum_name_or_instance return _convert_enum From 4de4801ff3f0ba4abb9c42ed6de88bf5c6004484 Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Sat, 19 Oct 2024 15:40:38 -0400 Subject: [PATCH 08/12] ensure we can serialize ExtendedDeviceInfo --- ...entralite-3320-l-extended-device-info.json | 1 + tests/test_device.py | 26 +++++++++++++++++++ zha/zigbee/cluster_handlers/__init__.py | 26 +++++++++++++++++++ zha/zigbee/device.py | 17 +++++++++--- 4 files changed, 66 insertions(+), 4 deletions(-) create mode 100644 tests/data/serialization_data/centralite-3320-l-extended-device-info.json diff --git a/tests/data/serialization_data/centralite-3320-l-extended-device-info.json b/tests/data/serialization_data/centralite-3320-l-extended-device-info.json new file mode 100644 index 000000000..f52e1d153 --- /dev/null +++ b/tests/data/serialization_data/centralite-3320-l-extended-device-info.json @@ -0,0 +1 @@ +{"ieee":"00:0d:6f:00:0f:3a:e3:69","nwk":"0x970A","manufacturer":"CentraLite","model":"3320-L","name":"CentraLite 3320-L","quirk_applied":true,"quirk_class":"zhaquirks.centralite.ias.CentraLiteIASSensor","quirk_id":null,"manufacturer_code":49887,"power_source":"Battery or Unknown","lqi":null,"rssi":null,"available":true,"device_type":"EndDevice","signature":{"node_descriptor":{"logical_type":2,"complex_descriptor_available":0,"user_descriptor_available":0,"reserved":0,"aps_flags":0,"frequency_band":8,"mac_capability_flags":128,"manufacturer_code":49887,"maximum_buffer_size":82,"maximum_incoming_transfer_size":82,"server_mask":0,"maximum_outgoing_transfer_size":82,"descriptor_capability_field":0},"endpoints":{"1":{"profile_id":"0x0104","device_type":"0x0402","input_clusters":["0x0000","0x0001","0x0003","0x0020","0x0402","0x0500","0x0b05"],"output_clusters":["0x0019"]},"2":{"profile_id":"0xc2df","device_type":"0x000c","input_clusters":["0x0000","0x0003","0x0b05","0xfc0f"],"output_clusters":["0x0003"]}},"manufacturer":"CentraLite","model":"3320-L"},"active_coordinator":false,"entities":{"binary_sensor,00:0d:6f:00:0f:3a:e3:69-1-1280":{"platform":"binary_sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1280","class_name":"IASZone","translation_key":null,"device_class":"opening","state_class":null,"entity_category":null,"entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"IASZoneClusterHandler","generic_id":"cluster_handler_0x0500","endpoint_id":1,"cluster":{"id":1280,"name":"IAS Zone","type":"server","commands":[{"id":0,"name":"enroll_response","schema":{"command":"enroll_response","fields":[{"name":"enroll_response_code","type":"EnrollResponse","optional":false},{"name":"zone_id","type":"uint8_t","optional":false}]},"direction":1,"is_manufacturer_specific":null},{"id":1,"name":"init_normal_op_mode","schema":{"command":"init_normal_op_mode","fields":[]},"direction":0,"is_manufacturer_specific":null},{"id":2,"name":"init_test_mode","schema":{"command":"init_test_mode","fields":[{"name":"test_mode_duration","type":"uint8_t","optional":false},{"name":"current_zone_sensitivity_level","type":"uint8_t","optional":false}]},"direction":0,"is_manufacturer_specific":null}]},"id":"1:0x0500","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0500","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"button,00:0d:6f:00:0f:3a:e3:69-1-3":{"platform":"button","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-3","class_name":"IdentifyButton","translation_key":null,"device_class":"identify","state_class":null,"entity_category":"diagnostic","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"IdentifyClusterHandler","generic_id":"cluster_handler_0x0003","endpoint_id":1,"cluster":{"id":3,"name":"Identify","type":"server","commands":[{"id":0,"name":"identify","schema":{"command":"identify","fields":[{"name":"identify_time","type":"uint16_t","optional":false}]},"direction":0,"is_manufacturer_specific":null},{"id":1,"name":"identify_query","schema":{"command":"identify_query","fields":[]},"direction":0,"is_manufacturer_specific":null},{"id":64,"name":"trigger_effect","schema":{"command":"trigger_effect","fields":[{"name":"effect_id","type":"EffectIdentifier","optional":false},{"name":"effect_variant","type":"EffectVariant","optional":false}]},"direction":0,"is_manufacturer_specific":null}]},"id":"1:0x0003","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0003","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"sensor,00:0d:6f:00:0f:3a:e3:69-1-1":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1","class_name":"Battery","translation_key":null,"device_class":"battery","state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"PowerConfigurationClusterHandler","generic_id":"cluster_handler_0x0001","endpoint_id":1,"cluster":{"id":1,"name":"Power Configuration","type":"server","commands":[]},"id":"1:0x0001","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0001","status":"initialized","value_attribute":"battery_voltage"}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"sensor,00:0d:6f:00:0f:3a:e3:69-1-1026":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1026","class_name":"Temperature","translation_key":null,"device_class":"temperature","state_class":"measurement","entity_category":null,"entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"TemperatureMeasurementClusterHandler","generic_id":"cluster_handler_0x0402","endpoint_id":1,"cluster":{"id":1026,"name":"Temperature Measurement","type":"server","commands":[]},"id":"1:0x0402","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0402","status":"initialized","value_attribute":"measured_value"}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"sensor,00:0d:6f:00:0f:3a:e3:69-1-0-rssi":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-0-rssi","class_name":"RSSISensor","translation_key":"rssi","device_class":"signal_strength","state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":false,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"BasicClusterHandler","generic_id":"cluster_handler_0x0000","endpoint_id":1,"cluster":{"id":0,"name":"Basic","type":"server","commands":[{"id":0,"name":"reset_fact_default","schema":{"command":"reset_fact_default","fields":[]},"direction":0,"is_manufacturer_specific":null}]},"id":"1:0x0000","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0000","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"sensor,00:0d:6f:00:0f:3a:e3:69-1-0-lqi":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-0-lqi","class_name":"LQISensor","translation_key":"lqi","device_class":null,"state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":false,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"BasicClusterHandler","generic_id":"cluster_handler_0x0000","endpoint_id":1,"cluster":{"id":0,"name":"Basic","type":"server","commands":[{"id":0,"name":"reset_fact_default","schema":{"command":"reset_fact_default","fields":[]},"direction":0,"is_manufacturer_specific":null}]},"id":"1:0x0000","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0000","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"update,00:0d:6f:00:0f:3a:e3:69-1-25-firmware_update":{"platform":"update","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-25-firmware_update","class_name":"FirmwareUpdateEntity","translation_key":null,"device_class":"firmware","state_class":null,"entity_category":"config","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"OtaClientClusterHandler","generic_id":"cluster_handler_0x0019","endpoint_id":1,"cluster":{"id":25,"name":"Ota","type":"client","commands":[{"id":3,"name":"image_block","schema":{"command":"image_block","fields":[{"name":"field_control","type":"FieldControl","optional":false},{"name":"manufacturer_code","type":"uint16_t","optional":false},{"name":"image_type","type":"uint16_t","optional":false},{"name":"file_version","type":"uint32_t","optional":false},{"name":"file_offset","type":"uint32_t","optional":false},{"name":"maximum_data_size","type":"uint8_t","optional":false},{"name":"request_node_addr","type":"EUI64","optional":false},{"name":"minimum_block_period","type":"uint16_t","optional":false}]},"direction":0,"is_manufacturer_specific":null},{"id":4,"name":"image_page","schema":{"command":"image_page","fields":[{"name":"field_control","type":"FieldControl","optional":false},{"name":"manufacturer_code","type":"uint16_t","optional":false},{"name":"image_type","type":"uint16_t","optional":false},{"name":"file_version","type":"uint32_t","optional":false},{"name":"file_offset","type":"uint32_t","optional":false},{"name":"maximum_data_size","type":"uint8_t","optional":false},{"name":"page_size","type":"uint16_t","optional":false},{"name":"response_spacing","type":"uint16_t","optional":false},{"name":"request_node_addr","type":"EUI64","optional":false}]},"direction":0,"is_manufacturer_specific":null},{"id":1,"name":"query_next_image","schema":{"command":"query_next_image","fields":[{"name":"field_control","type":"FieldControl","optional":false},{"name":"manufacturer_code","type":"uint16_t","optional":false},{"name":"image_type","type":"uint16_t","optional":false},{"name":"current_file_version","type":"uint32_t","optional":false},{"name":"hardware_version","type":"uint16_t","optional":false}]},"direction":0,"is_manufacturer_specific":null},{"id":8,"name":"query_specific_file","schema":{"command":"query_specific_file","fields":[{"name":"request_node_addr","type":"EUI64","optional":false},{"name":"manufacturer_code","type":"uint16_t","optional":false},{"name":"image_type","type":"uint16_t","optional":false},{"name":"file_version","type":"uint32_t","optional":false},{"name":"current_zigbee_stack_version","type":"uint16_t","optional":false}]},"direction":0,"is_manufacturer_specific":null},{"id":6,"name":"upgrade_end","schema":{"command":"upgrade_end","fields":[{"name":"status","type":"Status","optional":false},{"name":"manufacturer_code","type":"uint16_t","optional":false},{"name":"image_type","type":"uint16_t","optional":false},{"name":"file_version","type":"uint32_t","optional":false}]},"direction":0,"is_manufacturer_specific":null}]},"id":"1:0x0019","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0019","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null}},"neighbors":[],"routes":[],"endpoint_names":[{"name":"IAS_ZONE"},{"name":"unknown 12 device_type of 0xc2df profile id"}],"device_automation_triggers":{"device_offline,device_offline":{"device_event_type":"device_offline"}}} \ No newline at end of file diff --git a/tests/test_device.py b/tests/test_device.py index 0c6f0cc96..486d65e4b 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -22,6 +22,7 @@ SIG_EP_TYPE, create_mock_zigpy_device, join_zigpy_device, + zigpy_device_from_json, ) from zha.application import Platform from zha.application.const import ( @@ -941,3 +942,28 @@ def test_convert_extended_pan_id() -> None: ) assert isinstance(converted_extended_pan_id, zigpy.types.ExtendedPanId) assert converted_extended_pan_id == extended_pan_id + + +async def test_extended_device_info_ser_deser(zha_gateway: Gateway) -> None: + """Test the serialization and deserialization of the extended device info.""" + + zigpy_dev = await zigpy_device_from_json( + zha_gateway.application_controller, "tests/data/devices/centralite-3320-l.json" + ) + zha_device = await join_zigpy_device(zha_gateway, zigpy_dev) + assert zha_device is not None + + assert isinstance(zha_device.extended_device_info.ieee, zigpy.types.EUI64) + assert isinstance(zha_device.extended_device_info.nwk, zigpy.types.NWK) + + # last_seen changes so we exclude it from the comparison + json = zha_device.extended_device_info.model_dump_json(exclude=["last_seen"]) + + # load the json from a file as string + with open( + "tests/data/serialization_data/centralite-3320-l-extended-device-info.json", + encoding="UTF-8", + ) as file: + expected_json = file.read() + + assert json == expected_json diff --git a/zha/zigbee/cluster_handlers/__init__.py b/zha/zigbee/cluster_handlers/__init__.py index 3860ed2ae..940bf6a41 100644 --- a/zha/zigbee/cluster_handlers/__init__.py +++ b/zha/zigbee/cluster_handlers/__init__.py @@ -9,6 +9,7 @@ import logging from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypedDict +from pydantic import field_serializer import zigpy.exceptions import zigpy.util import zigpy.zcl @@ -167,6 +168,31 @@ class ClusterInfo(BaseModel): type: str commands: list[ZCLCommandDef] + @field_serializer("commands", when_used="json-unless-none", check_fields=False) + def serialize_commands(self, commands: list[ZCLCommandDef]): + """Serialize commands.""" + converted_commands = [] + for command in commands: + converted_command = { + "id": command.id, + "name": command.name, + "schema": { + "command": command.schema.command.name, + "fields": [ + { + "name": f.name, + "type": f.type.__name__, + "optional": f.optional, + } + for f in command.schema.fields + ], + }, + "direction": command.direction, + "is_manufacturer_specific": command.is_manufacturer_specific, + } + converted_commands.append(converted_command) + return converted_commands + class ClusterHandlerInfo(BaseModel): """Cluster handler information.""" diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index 7c2168522..abf9262e3 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -138,6 +138,13 @@ class DeviceInfo(BaseModel): device_type: str signature: dict[str, Any] + @field_serializer("signature", when_used="json-unless-none", check_fields=False) + def serialize_signature(self, signature: dict[str, Any]): + """Serialize signature.""" + if "node_descriptor" in signature: + signature["node_descriptor"] = signature["node_descriptor"].as_dict() + return signature + class NeighborInfo(BaseModel): """Describes a neighbor.""" @@ -247,10 +254,11 @@ class ExtendedDeviceInfo(DeviceInfo): """Describes a ZHA device.""" active_coordinator: bool - entities: dict[str, BaseEntityInfo] + entities: dict[tuple[Platform, str], BaseEntityInfo] neighbors: list[NeighborInfo] routes: list[RouteInfo] endpoint_names: list[EndpointNameInfo] + device_automation_triggers: dict[tuple[str, str], dict[str, Any]] class Device(LogMixin, EventBase): @@ -471,7 +479,7 @@ def device_automation_commands(self) -> dict[str, list[tuple[str, str]]]: return commands @cached_property - def device_automation_triggers(self) -> dict[tuple[str, str], dict[str, str]]: + def device_automation_triggers(self) -> dict[tuple[str, str], dict[str, Any]]: """Return the device automation triggers for this device.""" return get_device_automation_triggers(self._zigpy_device) @@ -745,8 +753,8 @@ def extended_device_info(self) -> ExtendedDeviceInfo: **self.device_info.__dict__, active_coordinator=self.is_active_coordinator, entities={ - platform_entity.unique_id: platform_entity.info_object - for platform_entity in self.platform_entities.values() + platform_entity_key: platform_entity.info_object + for platform_entity_key, platform_entity in self.platform_entities.items() }, neighbors=[ NeighborInfo( @@ -774,6 +782,7 @@ def extended_device_info(self) -> ExtendedDeviceInfo: for route in topology.routes[self.ieee] ], endpoint_names=names, + device_automation_triggers=self.device_automation_triggers, ) async def async_configure(self) -> None: From 4a3e1abdd4da86fb2e2bacecb44d6fb419552cc3 Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Fri, 18 Oct 2024 20:41:04 -0400 Subject: [PATCH 09/12] Add websocket functionality --- pyproject.toml | 4 +- tests/conftest.py | 37 +- tests/test_websocket_server_client.py | 58 ++ zha/application/helpers.py | 9 + zha/websocket/__init__.py | 1 + zha/websocket/client/__init__.py | 1 + zha/websocket/client/__main__.py | 9 + zha/websocket/client/client.py | 271 +++++++++ zha/websocket/client/controller.py | 228 ++++++++ zha/websocket/client/helpers.py | 301 ++++++++++ zha/websocket/client/model/__init__.py | 1 + zha/websocket/client/model/commands.py | 200 +++++++ zha/websocket/client/model/events.py | 263 +++++++++ zha/websocket/client/model/messages.py | 67 +++ zha/websocket/client/model/types.py | 760 +++++++++++++++++++++++++ zha/websocket/client/proxy.py | 114 ++++ zha/websocket/const.py | 170 ++++++ zha/websocket/server/__init__.py | 1 + zha/websocket/server/api/__init__.py | 31 + zha/websocket/server/api/decorators.py | 72 +++ zha/websocket/server/api/model.py | 65 +++ zha/websocket/server/api/types.py | 15 + zha/websocket/server/client.py | 294 ++++++++++ zha/websocket/server/gateway.py | 144 +++++ zha/websocket/server/gateway_api.py | 474 +++++++++++++++ 25 files changed, 3586 insertions(+), 4 deletions(-) create mode 100644 tests/test_websocket_server_client.py create mode 100644 zha/websocket/__init__.py create mode 100644 zha/websocket/client/__init__.py create mode 100644 zha/websocket/client/__main__.py create mode 100644 zha/websocket/client/client.py create mode 100644 zha/websocket/client/controller.py create mode 100644 zha/websocket/client/helpers.py create mode 100644 zha/websocket/client/model/__init__.py create mode 100644 zha/websocket/client/model/commands.py create mode 100644 zha/websocket/client/model/events.py create mode 100644 zha/websocket/client/model/messages.py create mode 100644 zha/websocket/client/model/types.py create mode 100644 zha/websocket/client/proxy.py create mode 100644 zha/websocket/const.py create mode 100644 zha/websocket/server/__init__.py create mode 100644 zha/websocket/server/api/__init__.py create mode 100644 zha/websocket/server/api/decorators.py create mode 100644 zha/websocket/server/api/model.py create mode 100644 zha/websocket/server/api/types.py create mode 100644 zha/websocket/server/client.py create mode 100644 zha/websocket/server/gateway.py create mode 100644 zha/websocket/server/gateway_api.py diff --git a/pyproject.toml b/pyproject.toml index f3b9ed562..f5a1d4a9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,9 @@ dependencies = [ "zha-quirks==0.0.124", "pyserial==3.5", "pyserial-asyncio-fast", - "pydantic==2.9.2" + "pydantic==2.9.2", + "websockets", + "aiohttp" ] [tool.setuptools.packages.find] diff --git a/tests/conftest.py b/tests/conftest.py index 397d9124e..81290f427 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ """Test configuration for the ZHA component.""" import asyncio -from collections.abc import Callable, Generator +from collections.abc import AsyncGenerator, Callable, Generator from contextlib import contextmanager import logging import os @@ -10,6 +10,7 @@ from types import TracebackType from unittest.mock import AsyncMock, MagicMock, patch +import aiohttp.test_utils import pytest import zigpy from zigpy.application import ControllerApplication @@ -28,10 +29,13 @@ AlarmControlPanelOptions, CoordinatorConfiguration, LightOptions, + ServerConfiguration, ZHAConfiguration, ZHAData, ) from zha.async_ import ZHAJob +from zha.websocket.client.controller import Controller +from zha.websocket.server.gateway import WebSocketGateway FIXTURE_GRP_ID = 0x1001 FIXTURE_GRP_NAME = "fixture group" @@ -253,7 +257,7 @@ def caplog_fixture(caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture @pytest.fixture(name="zha_data") def zha_data_fixture() -> ZHAData: """Fixture representing zha configuration data.""" - + port = aiohttp.test_utils.unused_port() return ZHAData( config=ZHAConfiguration( coordinator_configuration=CoordinatorConfiguration( @@ -269,7 +273,12 @@ def zha_data_fixture() -> ZHAData: master_code="4321", failed_tries=2, ), - ) + ), + server_config=ServerConfiguration( + host="localhost", + port=port, + network_auto_start=False, + ), ) @@ -299,6 +308,28 @@ async def __aexit__( await asyncio.sleep(0) +@pytest.fixture +async def connected_client_and_server( + zha_data: ZHAData, + zigpy_app_controller: ControllerApplication, +) -> AsyncGenerator[tuple[Controller, WebSocketGateway], None]: + """Return the connected client and server fixture.""" + + application_controller_patch = patch( + "bellows.zigbee.application.ControllerApplication.new", + return_value=zigpy_app_controller, + ) + + with application_controller_patch: + ws_gateway = await WebSocketGateway.async_from_config(zha_data) + async with ( + ws_gateway as gateway, + Controller(f"ws://localhost:{zha_data.server_config.port}") as controller, + ): + await controller.clients.listen() + yield controller, gateway + + @pytest.fixture async def zha_gateway( zha_data: ZHAData, diff --git a/tests/test_websocket_server_client.py b/tests/test_websocket_server_client.py new file mode 100644 index 000000000..5ca9ad0ce --- /dev/null +++ b/tests/test_websocket_server_client.py @@ -0,0 +1,58 @@ +"""Tests for the server and client.""" + +from __future__ import annotations + +from zha.application.helpers import ZHAData +from zha.websocket.client.client import Client +from zha.websocket.client.controller import Controller +from zha.websocket.server.gateway import StopServerCommand, WebSocketGateway + + +async def test_server_client_connect_disconnect( + zha_data: ZHAData, +) -> None: + """Tests basic connect/disconnect logic.""" + + async with WebSocketGateway(zha_data) as gateway: + assert gateway.is_serving + assert gateway._ws_server is not None + + async with Client(f"ws://localhost:{zha_data.server_config.port}") as client: + assert client.connected + assert "connected" in repr(client) + + # The client does not begin listening immediately + assert client._listen_task is None + await client.listen() + assert client._listen_task is not None + + # The listen task is automatically stopped when we disconnect + assert client._listen_task is None + assert "not connected" in repr(client) + assert not client.connected + + assert not gateway.is_serving + assert gateway._ws_server is None + + +async def test_client_message_id_uniqueness( + connected_client_and_server: tuple[Controller, WebSocketGateway], +) -> None: + """Tests that client message IDs are unique.""" + controller, gateway = connected_client_and_server + + ids = [controller.client.new_message_id() for _ in range(1000)] + assert len(ids) == len(set(ids)) + + +async def test_client_stop_server( + connected_client_and_server: tuple[Controller, WebSocketGateway], +) -> None: + """Tests that the client can stop the server.""" + controller, gateway = connected_client_and_server + + assert gateway.is_serving + await controller.client.async_send_command_no_wait(StopServerCommand()) + await controller.disconnect() + await gateway.wait_closed() + assert not gateway.is_serving diff --git a/zha/application/helpers.py b/zha/application/helpers.py index b690c17c0..037c84f3f 100644 --- a/zha/application/helpers.py +++ b/zha/application/helpers.py @@ -316,6 +316,14 @@ class DeviceOverridesConfiguration(BaseModel): type: Platform +class ServerConfiguration(BaseModel): + """Server configuration for zhaws.""" + + host: str = "0.0.0.0" + port: int = 8001 + network_auto_start: bool = False + + class ZHAConfiguration(BaseModel): """ZHA configuration.""" @@ -340,6 +348,7 @@ class ZHAData: """ZHA data stored in `gateway.data`.""" config: ZHAConfiguration + server_config: ServerConfiguration | None = None zigpy_config: dict[str, Any] = dataclasses.field(default_factory=dict) platforms: collections.defaultdict[Platform, list] = dataclasses.field( default_factory=lambda: collections.defaultdict(list) diff --git a/zha/websocket/__init__.py b/zha/websocket/__init__.py new file mode 100644 index 000000000..88196b389 --- /dev/null +++ b/zha/websocket/__init__.py @@ -0,0 +1 @@ +"""Websocket module for Zigbee Home Automation.""" diff --git a/zha/websocket/client/__init__.py b/zha/websocket/client/__init__.py new file mode 100644 index 000000000..656fa0b69 --- /dev/null +++ b/zha/websocket/client/__init__.py @@ -0,0 +1 @@ +"""Client for the ZHAWSS server.""" diff --git a/zha/websocket/client/__main__.py b/zha/websocket/client/__main__.py new file mode 100644 index 000000000..221ac60db --- /dev/null +++ b/zha/websocket/client/__main__.py @@ -0,0 +1,9 @@ +"""Main module for zhawss.""" + +from websockets.__main__ import main as websockets_cli + +if __name__ == "__main__": + # "Importing this module enables command line editing using GNU readline." + import readline # noqa: F401 + + websockets_cli() diff --git a/zha/websocket/client/client.py b/zha/websocket/client/client.py new file mode 100644 index 000000000..ec8fd3ef4 --- /dev/null +++ b/zha/websocket/client/client.py @@ -0,0 +1,271 @@ +"""Client implementation for the zhaws.client.""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +import pprint +from types import TracebackType +from typing import Any + +from aiohttp import ClientSession, ClientWebSocketResponse, client_exceptions +from aiohttp.http_websocket import WSMsgType +from async_timeout import timeout + +from zha.event import EventBase +from zha.websocket.client.model.commands import CommandResponse, ErrorResponse +from zha.websocket.client.model.messages import Message +from zha.websocket.server.api.model import WebSocketCommand + +SIZE_PARSE_JSON_EXECUTOR = 8192 +_LOGGER = logging.getLogger(__package__) + + +class Client(EventBase): + """Class to manage the IoT connection.""" + + def __init__( + self, + ws_server_url: str, + aiohttp_session: ClientSession | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + """Initialize the Client class.""" + super().__init__(*args, **kwargs) + self.ws_server_url = ws_server_url + + # Create a session if none is provided + if aiohttp_session is None: + self.aiohttp_session = ClientSession() + self._close_aiohttp_session: bool = True + else: + self.aiohttp_session = aiohttp_session + self._close_aiohttp_session = False + + # The WebSocket client + self._client: ClientWebSocketResponse | None = None + self._loop = asyncio.get_running_loop() + self._result_futures: dict[int, asyncio.Future] = {} + self._listen_task: asyncio.Task | None = None + + self._message_id = 0 + + def __repr__(self) -> str: + """Return the representation.""" + prefix = "" if self.connected else "not " + return f"{type(self).__name__}(ws_server_url={self.ws_server_url!r}, {prefix}connected)" + + @property + def connected(self) -> bool: + """Return if we're currently connected.""" + return self._client is not None and not self._client.closed + + def new_message_id(self) -> int: + """Create a new message ID. + + XXX: JSON doesn't define limits for integers but JavaScript itself internally + uses double precision floats for numbers (including in `JSON.parse`), setting + a hard limit of `Number.MAX_SAFE_INTEGER == 2^53 - 1`. We can be more + conservative and just restrict it to the maximum value of a 32-bit signed int. + """ + self._message_id = (self._message_id + 1) % 0x80000000 + return self._message_id + + async def async_send_command( + self, + command: WebSocketCommand, + ) -> CommandResponse: + """Send a command and get a response.""" + future: asyncio.Future[CommandResponse] = self._loop.create_future() + message_id = command.message_id = self.new_message_id() + self._result_futures[message_id] = future + + try: + async with timeout(20): + await self._send_json_message( + command.model_dump_json(exclude_none=True) + ) + return await future + except TimeoutError: + _LOGGER.exception("Timeout waiting for response") + return CommandResponse.model_validate( + {"message_id": message_id, "success": False} + ) + except Exception as err: + _LOGGER.exception("Error sending command", exc_info=err) + return CommandResponse.model_validate( + {"message_id": message_id, "success": False} + ) + finally: + self._result_futures.pop(message_id) + + async def async_send_command_no_wait(self, command: WebSocketCommand) -> None: + """Send a command without waiting for the response.""" + command.message_id = self.new_message_id() + await self._send_json_message(command.model_dump_json(exclude_none=True)) + + async def connect(self) -> None: + """Connect to the websocket server.""" + + _LOGGER.debug("Trying to connect") + try: + self._client = await self.aiohttp_session.ws_connect( + self.ws_server_url, + heartbeat=55, + compress=15, + max_msg_size=0, + ) + except client_exceptions.ClientError as err: + _LOGGER.exception("Error connecting to server", exc_info=err) + raise err + + async def listen_loop(self) -> None: + """Listen to the websocket.""" + assert self._client is not None + while not self._client.closed: + data = await self._receive_json_or_raise() + self._handle_incoming_message(data) + + async def listen(self) -> None: + """Start listening to the websocket.""" + if not self.connected: + raise Exception("Not connected when start listening") # noqa: TRY002 + + assert self._client + + assert self._listen_task is None + self._listen_task = asyncio.create_task(self.listen_loop()) + + async def disconnect(self) -> None: + """Disconnect the client.""" + _LOGGER.debug("Closing client connection") + + if self._listen_task is not None: + self._listen_task.cancel() + + with contextlib.suppress(asyncio.CancelledError): + await self._listen_task + + self._listen_task = None + + assert self._client is not None + await self._client.close() + + if self._close_aiohttp_session: + await self.aiohttp_session.close() + + _LOGGER.debug("Listen completed. Cleaning up") + + for future in self._result_futures.values(): + future.cancel() + + self._result_futures.clear() + + async def _receive_json_or_raise(self) -> dict: + """Receive json or raise.""" + assert self._client + msg = await self._client.receive() + + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): + raise Exception("Connection was closed.") # noqa: TRY002 + + if msg.type == WSMsgType.ERROR: + raise Exception() # noqa: TRY002 + + if msg.type != WSMsgType.TEXT: + raise Exception(f"Received non-Text message: {msg.type}") # noqa: TRY002 + + try: + if len(msg.data) > SIZE_PARSE_JSON_EXECUTOR: + data: dict = await self._loop.run_in_executor(None, msg.json) + else: + data = msg.json() + except ValueError as err: + raise Exception("Received invalid JSON.") from err # noqa: TRY002 + + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug("Received message:\n%s\n", pprint.pformat(msg)) + + return data + + def _handle_incoming_message(self, msg: dict) -> None: + """Handle incoming message. + + Run all async tasks in a wrapper to log appropriately. + """ + + try: + message = Message.model_validate(msg).root + except Exception as err: + _LOGGER.exception("Error parsing message: %s", msg, exc_info=err) + if msg["message_type"] == "result": + future = self._result_futures.get(msg["message_id"]) + if future is not None: + future.set_exception(err) + return + return + + if message.message_type == "result": + future = self._result_futures.get(message.message_id) + + if future is None: + # no listener for this result + return + + if message.success or isinstance(message, ErrorResponse): + future.set_result(message) + return + + if msg["error_code"] != "zigbee_error": + error = Exception(msg["message_id"], msg["error_code"]) + else: + error = Exception( + msg["message_id"], + msg["zigbee_error_code"], + msg["zigbee_error_message"], + ) + + future.set_exception(error) + return + + if message.message_type != "event": + # Can't handle + _LOGGER.debug( + "Received message with unknown type '%s': %s", + msg["message_type"], + msg, + ) + return + + try: + self.emit(message.event_type, message) + except Exception as err: + _LOGGER.exception("Error handling event", exc_info=err) + + async def _send_json_message(self, message: str) -> None: + """Send a message. + + Raises NotConnected if client not connected. + """ + if not self.connected: + raise Exception() # noqa: TRY002 + + _LOGGER.debug("Publishing message:\n%s\n", pprint.pformat(message)) + + assert self._client + assert "message_id" in message + + await self._client.send_str(message) + + async def __aenter__(self) -> Client: + """Connect to the websocket.""" + await self.connect() + return self + + async def __aexit__( + self, exc_type: Exception, exc_value: str, traceback: TracebackType + ) -> None: + """Disconnect from the websocket.""" + await self.disconnect() diff --git a/zha/websocket/client/controller.py b/zha/websocket/client/controller.py new file mode 100644 index 000000000..717632301 --- /dev/null +++ b/zha/websocket/client/controller.py @@ -0,0 +1,228 @@ +"""Controller implementation for the zhaws.client.""" + +from __future__ import annotations + +import logging +from types import TracebackType + +from aiohttp import ClientSession +from async_timeout import timeout +from zigpy.types.named import EUI64 + +from zha.event import EventBase +from zha.websocket.client.client import Client +from zha.websocket.client.helpers import ( + ClientHelper, + DeviceHelper, + GroupHelper, + NetworkHelper, + ServerHelper, +) +from zha.websocket.client.model.commands import CommandResponse +from zha.websocket.client.model.events import ( + DeviceConfiguredEvent, + DeviceFullyInitializedEvent, + DeviceJoinedEvent, + DeviceLeftEvent, + DeviceRemovedEvent, + GroupAddedEvent, + GroupMemberAddedEvent, + GroupMemberRemovedEvent, + GroupRemovedEvent, + PlatformEntityStateChangedEvent, + RawDeviceInitializedEvent, + ZHAEvent, +) +from zha.websocket.client.proxy import DeviceProxy, GroupProxy +from zha.websocket.const import ControllerEvents, EventTypes +from zha.websocket.server.api.model import WebSocketCommand + +CONNECT_TIMEOUT = 10 + +_LOGGER = logging.getLogger(__name__) + + +class Controller(EventBase): + """Controller implementation.""" + + def __init__( + self, ws_server_url: str, aiohttp_session: ClientSession | None = None + ): + """Initialize the controller.""" + super().__init__() + self._ws_server_url: str = ws_server_url + self._client: Client = Client(ws_server_url, aiohttp_session) + self._devices: dict[EUI64, DeviceProxy] = {} + self._groups: dict[int, GroupProxy] = {} + + self.clients: ClientHelper = ClientHelper(self._client) + self.groups_helper: GroupHelper = GroupHelper(self._client) + self.devices_helper: DeviceHelper = DeviceHelper(self._client) + self.network: NetworkHelper = NetworkHelper(self._client) + self.server_helper: ServerHelper = ServerHelper(self._client) + + # subscribe to event types we care about + self._client.on_event( + EventTypes.PLATFORM_ENTITY_EVENT, self._handle_event_protocol + ) + self._client.on_event(EventTypes.DEVICE_EVENT, self._handle_event_protocol) + self._client.on_event(EventTypes.CONTROLLER_EVENT, self._handle_event_protocol) + + @property + def client(self) -> Client: + """Return the client.""" + return self._client + + @property + def devices(self) -> dict[EUI64, DeviceProxy]: + """Return the devices.""" + return self._devices + + @property + def groups(self) -> dict[int, GroupProxy]: + """Return the groups.""" + return self._groups + + async def connect(self) -> None: + """Connect to the websocket server.""" + _LOGGER.debug("Connecting to websocket server at: %s", self._ws_server_url) + try: + async with timeout(CONNECT_TIMEOUT): + await self._client.connect() + except Exception as err: + _LOGGER.exception("Unable to connect to the ZHA wss", exc_info=err) + raise err + + await self._client.listen() + + async def disconnect(self) -> None: + """Disconnect from the websocket server.""" + await self._client.disconnect() + + async def __aenter__(self) -> Controller: + """Connect to the websocket server.""" + await self.connect() + return self + + async def __aexit__( + self, exc_type: Exception, exc_value: str, traceback: TracebackType + ) -> None: + """Disconnect from the websocket server.""" + await self.disconnect() + + async def send_command(self, command: WebSocketCommand) -> CommandResponse: + """Send a command and get a response.""" + return await self._client.async_send_command(command) + + async def load_devices(self) -> None: + """Load devices from the websocket server.""" + response_devices = await self.devices_helper.get_devices() + for ieee, device in response_devices.items(): + self._devices[ieee] = DeviceProxy(device, self, self._client) + + async def load_groups(self) -> None: + """Load groups from the websocket server.""" + response_groups = await self.groups_helper.get_groups() + for group_id, group in response_groups.items(): + self._groups[group_id] = GroupProxy(group, self, self._client) + + def handle_platform_entity_state_changed( + self, event: PlatformEntityStateChangedEvent + ) -> None: + """Handle a platform_entity_event from the websocket server.""" + _LOGGER.debug("platform_entity_event: %s", event) + if event.device: + device = self.devices.get(event.device.ieee) + if device is None: + _LOGGER.warning("Received event from unknown device: %s", event) + return + device.emit_platform_entity_event(event) + elif event.group: + group = self.groups.get(event.group.id) + if not group: + _LOGGER.warning("Received event from unknown group: %s", event) + return + group.emit_platform_entity_event(event) + + def handle_zha_event(self, event: ZHAEvent) -> None: + """Handle a zha_event from the websocket server.""" + _LOGGER.debug("zha_event: %s", event) + device = self.devices.get(event.device.ieee) + if device is None: + _LOGGER.warning("Received zha_event from unknown device: %s", event) + return + device.emit("zha_event", event) + + def handle_device_joined(self, event: DeviceJoinedEvent) -> None: + """Handle device joined. + + At this point, no information about the device is known other than its + address + """ + _LOGGER.info("Device %s - %s joined", event.ieee, event.nwk) + self.emit(ControllerEvents.DEVICE_JOINED, event) + + def handle_raw_device_initialized(self, event: RawDeviceInitializedEvent) -> None: + """Handle a device initialization without quirks loaded.""" + _LOGGER.info("Device %s - %s raw device initialized", event.ieee, event.nwk) + self.emit(ControllerEvents.RAW_DEVICE_INITIALIZED, event) + + def handle_device_configured(self, event: DeviceConfiguredEvent) -> None: + """Handle device configured event.""" + device = event.device + _LOGGER.info("Device %s - %s configured", device.ieee, device.nwk) + self.emit(ControllerEvents.DEVICE_CONFIGURED, event) + + def handle_device_fully_initialized( + self, event: DeviceFullyInitializedEvent + ) -> None: + """Handle device joined and basic information discovered.""" + device_model = event.device + _LOGGER.info("Device %s - %s initialized", device_model.ieee, device_model.nwk) + if device_model.ieee in self.devices: + self.devices[device_model.ieee].device_model = device_model + else: + self._devices[device_model.ieee] = DeviceProxy( + device_model, self, self._client + ) + self.emit(ControllerEvents.DEVICE_FULLY_INITIALIZED, event) + + def handle_device_left(self, event: DeviceLeftEvent) -> None: + """Handle device leaving the network.""" + _LOGGER.info("Device %s - %s left", event.ieee, event.nwk) + self.emit(ControllerEvents.DEVICE_LEFT, event) + + def handle_device_removed(self, event: DeviceRemovedEvent) -> None: + """Handle device being removed from the network.""" + device = event.device + _LOGGER.info( + "Device %s - %s has been removed from the network", device.ieee, device.nwk + ) + self._devices.pop(device.ieee, None) + self.emit(ControllerEvents.DEVICE_REMOVED, event) + + def handle_group_member_removed(self, event: GroupMemberRemovedEvent) -> None: + """Handle group member removed event.""" + if event.group.id in self.groups: + self.groups[event.group.id].group_model = event.group + self.emit(ControllerEvents.GROUP_MEMBER_REMOVED, event) + + def handle_group_member_added(self, event: GroupMemberAddedEvent) -> None: + """Handle group member added event.""" + if event.group.id in self.groups: + self.groups[event.group.id].group_model = event.group + self.emit(ControllerEvents.GROUP_MEMBER_ADDED, event) + + def handle_group_added(self, event: GroupAddedEvent) -> None: + """Handle group added event.""" + if event.group.id in self.groups: + self.groups[event.group.id].group_model = event.group + else: + self.groups[event.group.id] = GroupProxy(event.group, self, self._client) + self.emit(ControllerEvents.GROUP_ADDED, event) + + def handle_group_removed(self, event: GroupRemovedEvent) -> None: + """Handle group removed event.""" + if event.group.id in self.groups: + self.groups.pop(event.group.id) + self.emit(ControllerEvents.GROUP_REMOVED, event) diff --git a/zha/websocket/client/helpers.py b/zha/websocket/client/helpers.py new file mode 100644 index 000000000..f3d519c7c --- /dev/null +++ b/zha/websocket/client/helpers.py @@ -0,0 +1,301 @@ +"""Helper classes for zhaws.client.""" + +from __future__ import annotations + +from typing import Any, cast + +from zigpy.types.named import EUI64 + +from zha.application.discovery import Platform +from zha.websocket.client.client import Client +from zha.websocket.client.model.commands import ( + CommandResponse, + GetDevicesResponse, + GroupsResponse, + PermitJoiningResponse, + ReadClusterAttributesResponse, + UpdateGroupResponse, + WriteClusterAttributeResponse, +) +from zha.websocket.client.model.types import ( + BaseEntity, + BasePlatformEntity, + Device, + Group, +) +from zha.websocket.server.client import ( + ClientDisconnectCommand, + ClientListenCommand, + ClientListenRawZCLCommand, +) +from zha.websocket.server.gateway import StopServerCommand +from zha.websocket.server.gateway_api import ( + AddGroupMembersCommand, + CreateGroupCommand, + GetDevicesCommand, + GetGroupsCommand, + PermitJoiningCommand, + ReadClusterAttributesCommand, + ReconfigureDeviceCommand, + RemoveDeviceCommand, + RemoveGroupMembersCommand, + RemoveGroupsCommand, + StartNetworkCommand, + StopNetworkCommand, + UpdateTopologyCommand, + WriteClusterAttributeCommand, +) + + +def ensure_platform_entity(entity: BaseEntity, platform: Platform) -> None: + """Ensure an entity exists and is from the specified platform.""" + if entity is None or entity.platform != platform: + raise ValueError( + f"entity must be provided and it must be a {platform} platform entity" + ) + + +class ClientHelper: + """Helper to send client specific commands.""" + + def __init__(self, client: Client): + """Initialize the client helper.""" + self._client: Client = client + + async def listen(self) -> CommandResponse: + """Listen for incoming messages.""" + command = ClientListenCommand() + return await self._client.async_send_command(command) + + async def listen_raw_zcl(self) -> CommandResponse: + """Listen for incoming raw ZCL messages.""" + command = ClientListenRawZCLCommand() + return await self._client.async_send_command(command) + + async def disconnect(self) -> CommandResponse: + """Disconnect this client from the server.""" + command = ClientDisconnectCommand() + return await self._client.async_send_command(command) + + +class GroupHelper: + """Helper to send group commands.""" + + def __init__(self, client: Client): + """Initialize the group helper.""" + self._client: Client = client + + async def get_groups(self) -> dict[int, Group]: + """Get the groups.""" + response = cast( + GroupsResponse, + await self._client.async_send_command(GetGroupsCommand()), + ) + return response.groups + + async def create_group( + self, + name: str, + unique_id: int | None = None, + members: list[BasePlatformEntity] | None = None, + ) -> Group: + """Create a new group.""" + request_data: dict[str, Any] = { + "group_name": name, + "group_id": unique_id, + } + if members is not None: + request_data["members"] = [ + {"ieee": member.device_ieee, "endpoint_id": member.endpoint_id} + for member in members + ] + + command = CreateGroupCommand(**request_data) + response = cast( + UpdateGroupResponse, + await self._client.async_send_command(command), + ) + return response.group + + async def remove_groups(self, groups: list[Group]) -> dict[int, Group]: + """Remove groups.""" + request: dict[str, Any] = { + "group_ids": [group.id for group in groups], + } + command = RemoveGroupsCommand(**request) + response = cast( + GroupsResponse, + await self._client.async_send_command(command), + ) + return response.groups + + async def add_group_members( + self, group: Group, members: list[BasePlatformEntity] + ) -> Group: + """Add members to a group.""" + request_data: dict[str, Any] = { + "group_id": group.id, + "members": [ + {"ieee": member.device_ieee, "endpoint_id": member.endpoint_id} + for member in members + ], + } + + command = AddGroupMembersCommand(**request_data) + response = cast( + UpdateGroupResponse, + await self._client.async_send_command(command), + ) + return response.group + + async def remove_group_members( + self, group: Group, members: list[BasePlatformEntity] + ) -> Group: + """Remove members from a group.""" + request_data: dict[str, Any] = { + "group_id": group.id, + "members": [ + {"ieee": member.device_ieee, "endpoint_id": member.endpoint_id} + for member in members + ], + } + + command = RemoveGroupMembersCommand(**request_data) + response = cast( + UpdateGroupResponse, + await self._client.async_send_command(command), + ) + return response.group + + +class DeviceHelper: + """Helper to send device commands.""" + + def __init__(self, client: Client): + """Initialize the device helper.""" + self._client: Client = client + + async def get_devices(self) -> dict[EUI64, Device]: + """Get the groups.""" + response = cast( + GetDevicesResponse, + await self._client.async_send_command(GetDevicesCommand()), + ) + return response.devices + + async def reconfigure_device(self, device: Device) -> None: + """Reconfigure a device.""" + await self._client.async_send_command( + ReconfigureDeviceCommand(ieee=device.ieee) + ) + + async def remove_device(self, device: Device) -> None: + """Remove a device.""" + await self._client.async_send_command(RemoveDeviceCommand(ieee=device.ieee)) + + async def read_cluster_attributes( + self, + device: Device, + cluster_id: int, + cluster_type: str, + endpoint_id: int, + attributes: list[str], + manufacturer_code: int | None = None, + ) -> ReadClusterAttributesResponse: + """Read cluster attributes.""" + response = cast( + ReadClusterAttributesResponse, + await self._client.async_send_command( + ReadClusterAttributesCommand( + ieee=device.ieee, + endpoint_id=endpoint_id, + cluster_id=cluster_id, + cluster_type=cluster_type, + attributes=attributes, + manufacturer_code=manufacturer_code, + ) + ), + ) + return response + + async def write_cluster_attribute( + self, + device: Device, + cluster_id: int, + cluster_type: str, + endpoint_id: int, + attribute: str, + value: Any, + manufacturer_code: int | None = None, + ) -> WriteClusterAttributeResponse: + """Set the value for a cluster attribute.""" + response = cast( + WriteClusterAttributeResponse, + await self._client.async_send_command( + WriteClusterAttributeCommand( + ieee=device.ieee, + endpoint_id=endpoint_id, + cluster_id=cluster_id, + cluster_type=cluster_type, + attribute=attribute, + value=value, + manufacturer_code=manufacturer_code, + ) + ), + ) + return response + + +class NetworkHelper: + """Helper for network commands.""" + + def __init__(self, client: Client): + """Initialize the device helper.""" + self._client: Client = client + + async def permit_joining( + self, duration: int = 255, device: Device | None = None + ) -> bool: + """Permit joining for a specified duration.""" + # TODO add permit with code support + request_data: dict[str, Any] = { + "duration": duration, + } + if device is not None: + if device.device_type == "EndDevice": + raise ValueError("Device is not a coordinator or router") + request_data["ieee"] = device.ieee + command = PermitJoiningCommand(**request_data) + response = cast( + PermitJoiningResponse, + await self._client.async_send_command(command), + ) + return response.success + + async def update_topology(self) -> None: + """Update the network topology.""" + await self._client.async_send_command(UpdateTopologyCommand()) + + async def start_network(self) -> bool: + """Start the Zigbee network.""" + command = StartNetworkCommand() + response = await self._client.async_send_command(command) + return response.success + + async def stop_network(self) -> bool: + """Stop the Zigbee network.""" + response = await self._client.async_send_command(StopNetworkCommand()) + return response.success + + +class ServerHelper: + """Helper for server commands.""" + + def __init__(self, client: Client): + """Initialize the helper.""" + self._client: Client = client + + async def stop_server(self) -> bool: + """Stop the websocket server.""" + response = await self._client.async_send_command(StopServerCommand()) + return response.success diff --git a/zha/websocket/client/model/__init__.py b/zha/websocket/client/model/__init__.py new file mode 100644 index 000000000..9f32bfa2f --- /dev/null +++ b/zha/websocket/client/model/__init__.py @@ -0,0 +1 @@ +"""Models for the websocket client module for zha.""" diff --git a/zha/websocket/client/model/commands.py b/zha/websocket/client/model/commands.py new file mode 100644 index 000000000..9d0eb878e --- /dev/null +++ b/zha/websocket/client/model/commands.py @@ -0,0 +1,200 @@ +"""Models that represent commands and command responses.""" + +from typing import Annotated, Any, Literal, Optional, Union + +from pydantic import field_validator +from pydantic.fields import Field +from zigpy.types.named import EUI64 + +from zha.model import BaseModel +from zha.websocket.client.model.events import MinimalCluster, MinimalDevice +from zha.websocket.client.model.types import Device, Group + + +class CommandResponse(BaseModel): + """Command response model.""" + + message_type: Literal["result"] = "result" + message_id: int + success: bool + + +class ErrorResponse(CommandResponse): + """Error response model.""" + + success: bool = False + error_code: str + error_message: str + zigbee_error_code: Optional[str] + command: Literal[ + "error.start_network", + "error.stop_network", + "error.remove_device", + "error.stop_server", + "error.light_turn_on", + "error.light_turn_off", + "error.switch_turn_on", + "error.switch_turn_off", + "error.lock_lock", + "error.lock_unlock", + "error.lock_set_user_lock_code", + "error.lock_clear_user_lock_code", + "error.lock_disable_user_lock_code", + "error.lock_enable_user_lock_code", + "error.fan_turn_on", + "error.fan_turn_off", + "error.fan_set_percentage", + "error.fan_set_preset_mode", + "error.cover_open", + "error.cover_close", + "error.cover_set_position", + "error.cover_stop", + "error.climate_set_fan_mode", + "error.climate_set_hvac_mode", + "error.climate_set_preset_mode", + "error.climate_set_temperature", + "error.button_press", + "error.alarm_control_panel_disarm", + "error.alarm_control_panel_arm_home", + "error.alarm_control_panel_arm_away", + "error.alarm_control_panel_arm_night", + "error.alarm_control_panel_trigger", + "error.select_select_option", + "error.siren_turn_on", + "error.siren_turn_off", + "error.number_set_value", + "error.platform_entity_refresh_state", + "error.client_listen", + "error.client_listen_raw_zcl", + "error.client_disconnect", + "error.reconfigure_device", + "error.UpdateNetworkTopologyCommand", + ] + + +class DefaultResponse(CommandResponse): + """Default command response.""" + + command: Literal[ + "start_network", + "stop_network", + "remove_device", + "stop_server", + "light_turn_on", + "light_turn_off", + "switch_turn_on", + "switch_turn_off", + "lock_lock", + "lock_unlock", + "lock_set_user_lock_code", + "lock_clear_user_lock_code", + "lock_disable_user_lock_code", + "lock_enable_user_lock_code", + "fan_turn_on", + "fan_turn_off", + "fan_set_percentage", + "fan_set_preset_mode", + "cover_open", + "cover_close", + "cover_set_position", + "cover_stop", + "climate_set_fan_mode", + "climate_set_hvac_mode", + "climate_set_preset_mode", + "climate_set_temperature", + "button_press", + "alarm_control_panel_disarm", + "alarm_control_panel_arm_home", + "alarm_control_panel_arm_away", + "alarm_control_panel_arm_night", + "alarm_control_panel_trigger", + "select_select_option", + "siren_turn_on", + "siren_turn_off", + "number_set_value", + "platform_entity_refresh_state", + "client_listen", + "client_listen_raw_zcl", + "client_disconnect", + "reconfigure_device", + "UpdateNetworkTopologyCommand", + ] + + +class PermitJoiningResponse(CommandResponse): + """Get devices response.""" + + command: Literal["permit_joining"] = "permit_joining" + duration: int + + +class GetDevicesResponse(CommandResponse): + """Get devices response.""" + + command: Literal["get_devices"] = "get_devices" + devices: dict[EUI64, Device] + + @field_validator("devices", mode="before", check_fields=False) + @classmethod + def convert_devices_device_ieee( + cls, devices: dict[str, dict] + ) -> dict[EUI64, Device]: + """Convert device ieee to EUI64.""" + return {EUI64.convert(k): Device(**v) for k, v in devices.items()} + + +class ReadClusterAttributesResponse(CommandResponse): + """Read cluster attributes response.""" + + command: Literal["read_cluster_attributes"] = "read_cluster_attributes" + device: MinimalDevice + cluster: MinimalCluster + manufacturer_code: Optional[int] + succeeded: dict[str, Any] + failed: dict[str, Any] + + +class AttributeStatus(BaseModel): + """Attribute status.""" + + attribute: str + status: str + + +class WriteClusterAttributeResponse(CommandResponse): + """Write cluster attribute response.""" + + command: Literal["write_cluster_attribute"] = "write_cluster_attribute" + device: MinimalDevice + cluster: MinimalCluster + manufacturer_code: Optional[int] + response: AttributeStatus + + +class GroupsResponse(CommandResponse): + """Get groups response.""" + + command: Literal["get_groups", "remove_groups"] + groups: dict[int, Group] + + +class UpdateGroupResponse(CommandResponse): + """Update group response.""" + + command: Literal["create_group", "add_group_members", "remove_group_members"] + group: Group + + +CommandResponses = Annotated[ + Union[ + DefaultResponse, + ErrorResponse, + GetDevicesResponse, + GroupsResponse, + PermitJoiningResponse, + UpdateGroupResponse, + ReadClusterAttributesResponse, + WriteClusterAttributeResponse, + ], + Field(discriminator="command"), # noqa: F821 +] diff --git a/zha/websocket/client/model/events.py b/zha/websocket/client/model/events.py new file mode 100644 index 000000000..03496addc --- /dev/null +++ b/zha/websocket/client/model/events.py @@ -0,0 +1,263 @@ +"""Event models for zhawss. + +Events are unprompted messages from the server -> client and they contain only the data that is necessary to +handle the event. +""" + +from typing import Annotated, Any, Literal, Optional, Union + +from pydantic.fields import Field +from zigpy.types.named import EUI64 + +from zha.model import BaseEvent, BaseModel +from zha.websocket.client.model.types import ( + BaseDevice, + BatteryState, + BooleanState, + CoverState, + Device, + DeviceSignature, + DeviceTrackerState, + ElectricalMeasurementState, + FanState, + GenericState, + Group, + LightState, + LockState, + ShadeState, + SmareEnergyMeteringState, + SwitchState, + ThermostatState, +) + + +class MinimalPlatformEntity(BaseModel): + """Platform entity model.""" + + unique_id: str + platform: str + + +class MinimalEndpoint(BaseModel): + """Minimal endpoint model.""" + + id: int + unique_id: str + + +class MinimalDevice(BaseModel): + """Minimal device model.""" + + ieee: EUI64 + + +class Attribute(BaseModel): + """Attribute model.""" + + id: int + name: str + value: Any = None + + +class MinimalCluster(BaseModel): + """Minimal cluster model.""" + + id: int + endpoint_attribute: str + name: str + endpoint_id: int + + +class MinimalClusterHandler(BaseModel): + """Minimal cluster handler model.""" + + unique_id: str + cluster: MinimalCluster + + +class MinimalGroup(BaseModel): + """Minimal group model.""" + + id: int + + +class PlatformEntityStateChangedEvent(BaseEvent): + """Platform entity event.""" + + event_type: Literal["platform_entity_event"] = "platform_entity_event" + event: Literal["platform_entity_state_changed"] = "platform_entity_state_changed" + platform_entity: MinimalPlatformEntity + endpoint: Optional[MinimalEndpoint] = None + device: Optional[MinimalDevice] = None + group: Optional[MinimalGroup] = None + state: Annotated[ + Optional[ + Union[ + DeviceTrackerState, + CoverState, + ShadeState, + FanState, + LockState, + BatteryState, + ElectricalMeasurementState, + LightState, + SwitchState, + SmareEnergyMeteringState, + GenericState, + BooleanState, + ThermostatState, + ] + ], + Field(discriminator="class_name"), # noqa: F821 + ] + + +class ZCLAttributeUpdatedEvent(BaseEvent): + """ZCL attribute updated event.""" + + event_type: Literal["raw_zcl_event"] = "raw_zcl_event" + event: Literal["attribute_updated"] = "attribute_updated" + device: MinimalDevice + cluster_handler: MinimalClusterHandler + attribute: Attribute + endpoint: MinimalEndpoint + + +class ControllerEvent(BaseEvent): + """Controller event.""" + + event_type: Literal["controller_event"] = "controller_event" + + +class DevicePairingEvent(ControllerEvent): + """Device pairing event.""" + + pairing_status: str + + +class DeviceJoinedEvent(DevicePairingEvent): + """Device joined event.""" + + event: Literal["device_joined"] = "device_joined" + ieee: EUI64 + nwk: str + + +class RawDeviceInitializedEvent(DevicePairingEvent): + """Raw device initialized event.""" + + event: Literal["raw_device_initialized"] = "raw_device_initialized" + ieee: EUI64 + nwk: str + manufacturer: str + model: str + signature: DeviceSignature + + +class DeviceFullyInitializedEvent(DevicePairingEvent): + """Device fully initialized event.""" + + event: Literal["device_fully_initialized"] = "device_fully_initialized" + device: Device + new_join: bool + + +class DeviceConfiguredEvent(DevicePairingEvent): + """Device configured event.""" + + event: Literal["device_configured"] = "device_configured" + device: BaseDevice + + +class DeviceLeftEvent(ControllerEvent): + """Device left event.""" + + event: Literal["device_left"] = "device_left" + ieee: EUI64 + nwk: str + + +class DeviceRemovedEvent(ControllerEvent): + """Device removed event.""" + + event: Literal["device_removed"] = "device_removed" + device: Device + + +class DeviceOfflineEvent(BaseEvent): + """Device offline event.""" + + event: Literal["device_offline"] = "device_offline" + event_type: Literal["device_event"] = "device_event" + device: MinimalDevice + + +class DeviceOnlineEvent(BaseEvent): + """Device online event.""" + + event: Literal["device_online"] = "device_online" + event_type: Literal["device_event"] = "device_event" + device: MinimalDevice + + +class ZHAEvent(BaseEvent): + """ZHA event.""" + + event: Literal["zha_event"] = "zha_event" + event_type: Literal["device_event"] = "device_event" + device: MinimalDevice + cluster_handler: MinimalClusterHandler + endpoint: MinimalEndpoint + command: str + args: Union[list, dict] + params: dict[str, Any] + + +class GroupRemovedEvent(ControllerEvent): + """Group removed event.""" + + event: Literal["group_removed"] = "group_removed" + group: Group + + +class GroupAddedEvent(ControllerEvent): + """Group added event.""" + + event: Literal["group_added"] = "group_added" + group: Group + + +class GroupMemberAddedEvent(ControllerEvent): + """Group member added event.""" + + event: Literal["group_member_added"] = "group_member_added" + group: Group + + +class GroupMemberRemovedEvent(ControllerEvent): + """Group member removed event.""" + + event: Literal["group_member_removed"] = "group_member_removed" + group: Group + + +Events = Annotated[ + Union[ + PlatformEntityStateChangedEvent, + ZCLAttributeUpdatedEvent, + DeviceJoinedEvent, + RawDeviceInitializedEvent, + DeviceFullyInitializedEvent, + DeviceConfiguredEvent, + DeviceLeftEvent, + DeviceRemovedEvent, + GroupRemovedEvent, + GroupAddedEvent, + GroupMemberAddedEvent, + GroupMemberRemovedEvent, + DeviceOfflineEvent, + DeviceOnlineEvent, + ZHAEvent, + ], + Field(discriminator="event"), # noqa: F821 +] diff --git a/zha/websocket/client/model/messages.py b/zha/websocket/client/model/messages.py new file mode 100644 index 000000000..9e5149bd4 --- /dev/null +++ b/zha/websocket/client/model/messages.py @@ -0,0 +1,67 @@ +"""Models that represent messages in zhawss.""" + +from typing import Annotated, Any, Optional, Union + +from pydantic import RootModel, field_serializer, field_validator +from pydantic.fields import Field +from zigpy.types.named import EUI64 + +from zha.websocket.client.model.commands import CommandResponses +from zha.websocket.client.model.events import Events + + +class Message(RootModel): + """Response model.""" + + root: Annotated[ + Union[CommandResponses, Events], + Field(discriminator="message_type"), # noqa: F821 + ] + + @field_validator("ieee", mode="before", check_fields=False) + @classmethod + def convert_ieee(cls, ieee: Optional[Union[str, EUI64]]) -> Optional[EUI64]: + """Convert ieee to EUI64.""" + if ieee is None: + return None + if isinstance(ieee, str): + return EUI64.convert(ieee) + if isinstance(ieee, list) and not isinstance(ieee, EUI64): + return EUI64.deserialize(ieee)[0] + return ieee + + @field_serializer("ieee", check_fields=False) + def serialize_ieee(self, ieee): + """Customize how ieee is serialized.""" + if isinstance(ieee, EUI64): + return str(ieee) + return ieee + + @field_validator("device_ieee", mode="before", check_fields=False) + @classmethod + def convert_device_ieee( + cls, device_ieee: Optional[Union[str, EUI64]] + ) -> Optional[EUI64]: + """Convert device ieee to EUI64.""" + if device_ieee is None: + return None + if isinstance(device_ieee, str): + return EUI64.convert(device_ieee) + if isinstance(device_ieee, list) and not isinstance(device_ieee, EUI64): + return EUI64.deserialize(device_ieee)[0] + return device_ieee + + @field_serializer("device_ieee", check_fields=False) + def serialize_device_ieee(self, device_ieee): + """Customize how device_ieee is serialized.""" + if isinstance(device_ieee, EUI64): + return str(device_ieee) + return device_ieee + + @classmethod + def _get_value(cls, *args, **kwargs) -> Any: + """Convert EUI64 to string.""" + value = args[0] + if isinstance(value, EUI64): + return str(value) + return RootModel._get_value(cls, *args, **kwargs) diff --git a/zha/websocket/client/model/types.py b/zha/websocket/client/model/types.py new file mode 100644 index 000000000..83d3b8c15 --- /dev/null +++ b/zha/websocket/client/model/types.py @@ -0,0 +1,760 @@ +"""Models that represent types for the zhaws.client. + +Types are representations of the objects that exist in zhawss. +""" + +from typing import Annotated, Any, Literal, Optional, Union + +from pydantic import ValidationInfo, field_serializer, field_validator +from pydantic.fields import Field +from zigpy.types.named import EUI64, NWK +from zigpy.zdo.types import NodeDescriptor as ZigpyNodeDescriptor + +from zha.event import EventBase +from zha.model import BaseModel + + +class BaseEventedModel(EventBase, BaseModel): + """Base evented model.""" + + +class Cluster(BaseModel): + """Cluster model.""" + + id: int + endpoint_attribute: str + name: str + endpoint_id: int + type: str + commands: list[str] + + +class ClusterHandler(BaseModel): + """Cluster handler model.""" + + unique_id: str + cluster: Cluster + class_name: str + generic_id: str + endpoint_id: int + id: str + status: str + + +class Endpoint(BaseModel): + """Endpoint model.""" + + id: int + unique_id: str + + +class GenericState(BaseModel): + """Default state model.""" + + class_name: Literal[ + "ZHAAlarmControlPanel", + "Number", + "DefaultToneSelectEntity", + "DefaultSirenLevelSelectEntity", + "DefaultStrobeLevelSelectEntity", + "DefaultStrobeSelectEntity", + "AnalogInput", + "Humidity", + "SoilMoisture", + "LeafWetness", + "Illuminance", + "Pressure", + "Temperature", + "CarbonDioxideConcentration", + "CarbonMonoxideConcentration", + "VOCLevel", + "PPBVOCLevel", + "FormaldehydeConcentration", + "ThermostatHVACAction", + "SinopeHVACAction", + "RSSISensor", + "LQISensor", + "LastSeenSensor", + ] + state: Union[str, bool, int, float, None] = None + + +class DeviceCounterSensorState(BaseModel): + """Device counter sensor state model.""" + + class_name: Literal["DeviceCounterSensor"] = "DeviceCounterSensor" + state: int + + +class DeviceTrackerState(BaseModel): + """Device tracker state model.""" + + class_name: Literal["DeviceTracker"] = "DeviceTracker" + connected: bool + battery_level: Optional[float] = None + + +class BooleanState(BaseModel): + """Boolean value state model.""" + + class_name: Literal[ + "Accelerometer", + "Occupancy", + "Opening", + "BinaryInput", + "Motion", + "IASZone", + "Siren", + ] + state: bool + + +class CoverState(BaseModel): + """Cover state model.""" + + class_name: Literal["Cover"] = "Cover" + current_position: int + state: Optional[str] = None + is_opening: bool + is_closing: bool + is_closed: bool + + +class ShadeState(BaseModel): + """Cover state model.""" + + class_name: Literal["Shade", "KeenVent"] + current_position: Optional[int] = ( + None # TODO: how should we represent this when it is None? + ) + is_closed: bool + state: Optional[str] = None + + +class FanState(BaseModel): + """Fan state model.""" + + class_name: Literal["Fan", "FanGroup"] + preset_mode: Optional[str] = ( + None # TODO: how should we represent these when they are None? + ) + percentage: Optional[int] = ( + None # TODO: how should we represent these when they are None? + ) + is_on: bool + speed: Optional[str] = None + + +class LockState(BaseModel): + """Lock state model.""" + + class_name: Literal["Lock"] = "Lock" + is_locked: bool + + +class BatteryState(BaseModel): + """Battery state model.""" + + class_name: Literal["Battery"] = "Battery" + state: Optional[Union[str, float, int]] = None + battery_size: Optional[str] = None + battery_quantity: Optional[int] = None + battery_voltage: Optional[float] = None + + +class ElectricalMeasurementState(BaseModel): + """Electrical measurement state model.""" + + class_name: Literal[ + "ElectricalMeasurement", + "ElectricalMeasurementApparentPower", + "ElectricalMeasurementRMSCurrent", + "ElectricalMeasurementRMSVoltage", + ] + state: Optional[Union[str, float, int]] = None + measurement_type: Optional[str] = None + active_power_max: Optional[str] = None + rms_current_max: Optional[str] = None + rms_voltage_max: Optional[str] = None + + +class LightState(BaseModel): + """Light state model.""" + + class_name: Literal["Light", "HueLight", "ForceOnLight", "LightGroup"] + on: bool + brightness: Optional[int] = None + hs_color: Optional[tuple[float, float]] = None + color_temp: Optional[int] = None + effect: Optional[str] = None + off_brightness: Optional[int] = None + + +class ThermostatState(BaseModel): + """Thermostat state model.""" + + class_name: Literal[ + "Thermostat", + "SinopeTechnologiesThermostat", + "ZenWithinThermostat", + "MoesThermostat", + "BecaThermostat", + ] + current_temperature: Optional[float] = None + target_temperature: Optional[float] = None + target_temperature_low: Optional[float] = None + target_temperature_high: Optional[float] = None + hvac_action: Optional[str] = None + hvac_mode: Optional[str] = None + preset_mode: Optional[str] = None + fan_mode: Optional[str] = None + + +class SwitchState(BaseModel): + """Switch state model.""" + + class_name: Literal["Switch", "SwitchGroup"] + state: bool + + +class SmareEnergyMeteringState(BaseModel): + """Smare energy metering state model.""" + + class_name: Literal["SmartEnergyMetering", "SmartEnergySummation"] + state: Optional[Union[str, float, int]] = None + device_type: Optional[str] = None + status: Optional[str] = None + + +class BaseEntity(BaseEventedModel): + """Base platform entity model.""" + + unique_id: str + platform: str + class_name: str + fallback_name: str | None = None + translation_key: str | None = None + device_class: str | None = None + state_class: str | None = None + entity_category: str | None = None + entity_registry_enabled_default: bool + enabled: bool + + +class BasePlatformEntity(BaseEntity): + """Base platform entity model.""" + + device_ieee: EUI64 + endpoint_id: int + + +class LockEntity(BasePlatformEntity): + """Lock entity model.""" + + class_name: Literal["Lock"] + state: LockState + + +class DeviceTrackerEntity(BasePlatformEntity): + """Device tracker entity model.""" + + class_name: Literal["DeviceTracker"] + state: DeviceTrackerState + + +class CoverEntity(BasePlatformEntity): + """Cover entity model.""" + + class_name: Literal["Cover"] + state: CoverState + + +class ShadeEntity(BasePlatformEntity): + """Shade entity model.""" + + class_name: Literal["Shade", "KeenVent"] + state: ShadeState + + +class BinarySensorEntity(BasePlatformEntity): + """Binary sensor model.""" + + class_name: Literal[ + "Accelerometer", "Occupancy", "Opening", "BinaryInput", "Motion", "IASZone" + ] + attribute_name: str + state: BooleanState + + +class BaseSensorEntity(BasePlatformEntity): + """Sensor model.""" + + attribute: Optional[str] + decimals: int + divisor: int + multiplier: Union[int, float] + unit: Optional[int | str] + + +class SensorEntity(BaseSensorEntity): + """Sensor entity model.""" + + class_name: Literal[ + "AnalogInput", + "Humidity", + "SoilMoisture", + "LeafWetness", + "Illuminance", + "Pressure", + "Temperature", + "CarbonDioxideConcentration", + "CarbonMonoxideConcentration", + "VOCLevel", + "PPBVOCLevel", + "FormaldehydeConcentration", + "ThermostatHVACAction", + "SinopeHVACAction", + "RSSISensor", + "LQISensor", + "LastSeenSensor", + ] + state: GenericState + + +class DeviceCounterSensorEntity(BaseEntity): + """Device counter sensor model.""" + + class_name: Literal["DeviceCounterSensor"] + counter: str + counter_value: int + counter_groups: str + counter_group: str + state: DeviceCounterSensorState + + @field_validator("state", mode="before", check_fields=False) + @classmethod + def convert_state( + cls, state: dict | int | None, validation_info: ValidationInfo + ) -> DeviceCounterSensorState: + """Convert counter value to counter_value.""" + if state is not None: + if isinstance(state, int): + return DeviceCounterSensorState(state=state) + if isinstance(state, dict): + if "state" in state: + return DeviceCounterSensorState(state=state["state"]) + else: + return DeviceCounterSensorState( + state=validation_info.data["counter_value"] + ) + return DeviceCounterSensorState(state=validation_info.data["counter_value"]) + + +class BatteryEntity(BaseSensorEntity): + """Battery entity model.""" + + class_name: Literal["Battery"] + state: BatteryState + + +class ElectricalMeasurementEntity(BaseSensorEntity): + """Electrical measurement entity model.""" + + class_name: Literal[ + "ElectricalMeasurement", + "ElectricalMeasurementApparentPower", + "ElectricalMeasurementRMSCurrent", + "ElectricalMeasurementRMSVoltage", + ] + state: ElectricalMeasurementState + + +class SmartEnergyMeteringEntity(BaseSensorEntity): + """Smare energy metering entity model.""" + + class_name: Literal["SmartEnergyMetering", "SmartEnergySummation"] + state: SmareEnergyMeteringState + + +class AlarmControlPanelEntity(BasePlatformEntity): + """Alarm control panel model.""" + + class_name: Literal["ZHAAlarmControlPanel"] + supported_features: int + code_required_arm_actions: bool + max_invalid_tries: int + state: GenericState + + +class ButtonEntity(BasePlatformEntity): + """Button model.""" + + class_name: Literal["IdentifyButton"] + command: str + + +class FanEntity(BasePlatformEntity): + """Fan model.""" + + class_name: Literal["Fan"] + preset_modes: list[str] + supported_features: int + speed_count: int + speed_list: list[str] + percentage_step: float + state: FanState + + +class LightEntity(BasePlatformEntity): + """Light model.""" + + class_name: Literal["Light", "HueLight", "ForceOnLight"] + supported_features: int + min_mireds: int + max_mireds: int + effect_list: Optional[list[str]] + state: LightState + + +class NumberEntity(BasePlatformEntity): + """Number entity model.""" + + class_name: Literal["Number"] + engineering_units: Optional[ + int + ] # TODO: how should we represent this when it is None? + application_type: Optional[ + int + ] # TODO: how should we represent this when it is None? + step: Optional[float] # TODO: how should we represent this when it is None? + min_value: float + max_value: float + state: GenericState + + +class SelectEntity(BasePlatformEntity): + """Select entity model.""" + + class_name: Literal[ + "DefaultToneSelectEntity", + "DefaultSirenLevelSelectEntity", + "DefaultStrobeLevelSelectEntity", + "DefaultStrobeSelectEntity", + ] + enum: str + options: list[str] + state: GenericState + + +class ThermostatEntity(BasePlatformEntity): + """Thermostat entity model.""" + + class_name: Literal[ + "Thermostat", + "SinopeTechnologiesThermostat", + "ZenWithinThermostat", + "MoesThermostat", + "BecaThermostat", + ] + state: ThermostatState + hvac_modes: tuple[str, ...] + fan_modes: Optional[list[str]] + preset_modes: Optional[list[str]] + + +class SirenEntity(BasePlatformEntity): + """Siren entity model.""" + + class_name: Literal["Siren"] + available_tones: Optional[Union[list[Union[int, str]], dict[int, str]]] + supported_features: int + state: BooleanState + + +class SwitchEntity(BasePlatformEntity): + """Switch entity model.""" + + class_name: Literal["Switch"] + state: SwitchState + + +class DeviceSignatureEndpoint(BaseModel): + """Device signature endpoint model.""" + + profile_id: Optional[str] = None + device_type: Optional[str] = None + input_clusters: list[str] + output_clusters: list[str] + + @field_validator("profile_id", mode="before", check_fields=False) + @classmethod + def convert_profile_id(cls, profile_id: int | str) -> str: + """Convert profile_id.""" + if isinstance(profile_id, int): + return f"0x{profile_id:04x}" + return profile_id + + @field_validator("device_type", mode="before", check_fields=False) + @classmethod + def convert_device_type(cls, device_type: int | str) -> str: + """Convert device_type.""" + if isinstance(device_type, int): + return f"0x{device_type:04x}" + return device_type + + @field_validator("input_clusters", mode="before", check_fields=False) + @classmethod + def convert_input_clusters(cls, input_clusters: list[int | str]) -> list[str]: + """Convert input_clusters.""" + clusters = [] + for cluster_id in input_clusters: + if isinstance(cluster_id, int): + clusters.append(f"0x{cluster_id:04x}") + else: + clusters.append(cluster_id) + return clusters + + @field_validator("output_clusters", mode="before", check_fields=False) + @classmethod + def convert_output_clusters(cls, output_clusters: list[int | str]) -> list[str]: + """Convert output_clusters.""" + clusters = [] + for cluster_id in output_clusters: + if isinstance(cluster_id, int): + clusters.append(f"0x{cluster_id:04x}") + else: + clusters.append(cluster_id) + return clusters + + +class NodeDescriptor(BaseModel): + """Node descriptor model.""" + + logical_type: int + complex_descriptor_available: bool + user_descriptor_available: bool + reserved: int + aps_flags: int + frequency_band: int + mac_capability_flags: int + manufacturer_code: int + maximum_buffer_size: int + maximum_incoming_transfer_size: int + server_mask: int + maximum_outgoing_transfer_size: int + descriptor_capability_field: int + + +class DeviceSignature(BaseModel): + """Device signature model.""" + + node_descriptor: Optional[NodeDescriptor] = None + manufacturer: Optional[str] = None + model: Optional[str] = None + endpoints: dict[int, DeviceSignatureEndpoint] + + @field_validator("node_descriptor", mode="before", check_fields=False) + @classmethod + def convert_node_descriptor( + cls, node_descriptor: ZigpyNodeDescriptor + ) -> NodeDescriptor: + """Convert node descriptor.""" + if isinstance(node_descriptor, ZigpyNodeDescriptor): + return node_descriptor.as_dict() + return node_descriptor + + +class BaseDevice(BaseModel): + """Base device model.""" + + ieee: EUI64 + nwk: str + manufacturer: str + model: str + name: str + quirk_applied: bool + quirk_class: Union[str, None] = None + manufacturer_code: int + power_source: str + lqi: Union[int, None] = None + rssi: Union[int, None] = None + last_seen: str + available: bool + device_type: Literal["Coordinator", "Router", "EndDevice"] + signature: DeviceSignature + + @field_validator("nwk", mode="before", check_fields=False) + @classmethod + def convert_nwk(cls, nwk: NWK) -> str: + """Convert nwk to hex.""" + if isinstance(nwk, NWK): + return repr(nwk) + return nwk + + @field_serializer("ieee") + def serialize_ieee(self, ieee): + """Customize how ieee is serialized.""" + if isinstance(ieee, EUI64): + return str(ieee) + return ieee + + +class Device(BaseDevice): + """Device model.""" + + entities: dict[ + str, + Annotated[ + Union[ + SirenEntity, + SelectEntity, + NumberEntity, + LightEntity, + FanEntity, + ButtonEntity, + AlarmControlPanelEntity, + SensorEntity, + BinarySensorEntity, + DeviceTrackerEntity, + ShadeEntity, + CoverEntity, + LockEntity, + SwitchEntity, + BatteryEntity, + ElectricalMeasurementEntity, + SmartEnergyMeteringEntity, + ThermostatEntity, + DeviceCounterSensorEntity, + ], + Field(discriminator="class_name"), # noqa: F821 + ], + ] + neighbors: list[Any] + device_automation_triggers: dict[str, dict[str, Any]] + + @field_validator("entities", mode="before", check_fields=False) + @classmethod + def convert_entities(cls, entities: dict) -> dict: + """Convert entities keys from tuple to string.""" + if all(isinstance(k, tuple) for k in entities): + return {f"{k[0]}.{k[1]}": v for k, v in entities.items()} + assert all(isinstance(k, str) for k in entities) + return entities + + @field_validator("device_automation_triggers", mode="before", check_fields=False) + @classmethod + def convert_device_automation_triggers(cls, triggers: dict) -> dict: + """Convert device automation triggers keys from tuple to string.""" + if all(isinstance(k, tuple) for k in triggers): + return {f"{k[0]}~{k[1]}": v for k, v in triggers.items()} + return triggers + + +class GroupEntity(BaseEntity): + """Group entity model.""" + + group_id: int + state: Any + + +class LightGroupEntity(GroupEntity): + """Group entity model.""" + + class_name: Literal["LightGroup"] + state: LightState + + +class FanGroupEntity(GroupEntity): + """Group entity model.""" + + class_name: Literal["FanGroup"] + state: FanState + + +class SwitchGroupEntity(GroupEntity): + """Group entity model.""" + + class_name: Literal["SwitchGroup"] + state: SwitchState + + +class GroupMember(BaseModel): + """Group member model.""" + + ieee: EUI64 + endpoint_id: int + device: Device = Field(alias="device_info") + entities: dict[ + str, + Annotated[ + Union[ + SirenEntity, + SelectEntity, + NumberEntity, + LightEntity, + FanEntity, + ButtonEntity, + AlarmControlPanelEntity, + SensorEntity, + BinarySensorEntity, + DeviceTrackerEntity, + ShadeEntity, + CoverEntity, + LockEntity, + SwitchEntity, + BatteryEntity, + ElectricalMeasurementEntity, + SmartEnergyMeteringEntity, + ThermostatEntity, + ], + Field(discriminator="class_name"), # noqa: F821 + ], + ] + + +class Group(BaseModel): + """Group model.""" + + name: str + id: int + members: dict[EUI64, GroupMember] + entities: dict[ + str, + Annotated[ + Union[LightGroupEntity, FanGroupEntity, SwitchGroupEntity], + Field(discriminator="class_name"), # noqa: F821 + ], + ] + + @field_validator("members", mode="before", check_fields=False) + @classmethod + def convert_members(cls, members: dict | list[dict]) -> dict: + """Convert members.""" + + converted_members = {} + if isinstance(members, dict): + return {EUI64.convert(k): v for k, v in members.items()} + for member in members: + if "device" in member: + ieee = member["device"]["ieee"] + else: + ieee = member["device_info"]["ieee"] + if isinstance(ieee, str): + ieee = EUI64.convert(ieee) + elif isinstance(ieee, list) and not isinstance(ieee, EUI64): + ieee = EUI64.deserialize(ieee)[0] + converted_members[ieee] = member + return converted_members + + @field_serializer("members") + def serialize_members(self, members): + """Customize how members are serialized.""" + data = {str(k): v.model_dump(by_alias=True) for k, v in members.items()} + return data + + +class GroupMemberReference(BaseModel): + """Group member reference model.""" + + ieee: EUI64 + endpoint_id: int diff --git a/zha/websocket/client/proxy.py b/zha/websocket/client/proxy.py new file mode 100644 index 000000000..92db0e20e --- /dev/null +++ b/zha/websocket/client/proxy.py @@ -0,0 +1,114 @@ +"""Proxy object for the client side objects.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from zha.event import EventBase +from zha.websocket.client.model.events import PlatformEntityStateChangedEvent +from zha.websocket.client.model.types import ( + ButtonEntity, + Device as DeviceModel, + Group as GroupModel, +) + +if TYPE_CHECKING: + from zha.websocket.client.client import Client + from zha.websocket.client.controller import Controller + + +class BaseProxyObject(EventBase): + """BaseProxyObject for the zhaws.client.""" + + def __init__(self, controller: Controller, client: Client): + """Initialize the BaseProxyObject class.""" + super().__init__() + self._controller: Controller = controller + self._client: Client = client + self._proxied_object: GroupModel | DeviceModel + + @property + def controller(self) -> Controller: + """Return the controller.""" + return self._controller + + @property + def client(self) -> Client: + """Return the client.""" + return self._client + + def emit_platform_entity_event( + self, event: PlatformEntityStateChangedEvent + ) -> None: + """Proxy the firing of an entity event.""" + entity = self._proxied_object.entities.get( + f"{event.platform_entity.platform}.{event.platform_entity.unique_id}" + if event.group is None + else event.platform_entity.unique_id + ) + if entity is None: + if isinstance(self._proxied_object, DeviceModel): + raise ValueError( + f"Entity not found: {event.platform_entity.unique_id}", + ) + return # group entities are updated to get state when created so we may not have the entity yet + if not isinstance(entity, ButtonEntity): + entity.state = event.state + self.emit(f"{event.platform_entity.unique_id}_{event.event}", event) + + +class GroupProxy(BaseProxyObject): + """Group proxy for the zhaws.client.""" + + def __init__(self, group_model: GroupModel, controller: Controller, client: Client): + """Initialize the GroupProxy class.""" + super().__init__(controller, client) + self._proxied_object: GroupModel = group_model + + @property + def group_model(self) -> GroupModel: + """Return the group model.""" + return self._proxied_object + + @group_model.setter + def group_model(self, group_model: GroupModel) -> None: + """Set the group model.""" + self._proxied_object = group_model + + def __repr__(self) -> str: + """Return the string representation of the group proxy.""" + return self._proxied_object.__repr__() + + +class DeviceProxy(BaseProxyObject): + """Device proxy for the zhaws.client.""" + + def __init__( + self, device_model: DeviceModel, controller: Controller, client: Client + ): + """Initialize the DeviceProxy class.""" + super().__init__(controller, client) + self._proxied_object: DeviceModel = device_model + + @property + def device_model(self) -> DeviceModel: + """Return the device model.""" + return self._proxied_object + + @device_model.setter + def device_model(self, device_model: DeviceModel) -> None: + """Set the device model.""" + self._proxied_object = device_model + + @property + def device_automation_triggers(self) -> dict[tuple[str, str], dict[str, Any]]: + """Return the device automation triggers.""" + model_triggers = self._proxied_object.device_automation_triggers + return { + (key.split("~")[0], key.split("~")[1]): value + for key, value in model_triggers.items() + } + + def __repr__(self) -> str: + """Return the string representation of the device proxy.""" + return self._proxied_object.__repr__() diff --git a/zha/websocket/const.py b/zha/websocket/const.py new file mode 100644 index 000000000..a5c6eca03 --- /dev/null +++ b/zha/websocket/const.py @@ -0,0 +1,170 @@ +"""Constants.""" + +from enum import StrEnum +from typing import Final + + +class APICommands(StrEnum): + """WS API commands.""" + + # Device commands + GET_DEVICES = "get_devices" + REMOVE_DEVICE = "remove_device" + RECONFIGURE_DEVICE = "reconfigure_device" + READ_CLUSTER_ATTRIBUTES = "read_cluster_attributes" + WRITE_CLUSTER_ATTRIBUTE = "write_cluster_attribute" + + # Zigbee API commands + PERMIT_JOINING = "permit_joining" + START_NETWORK = "start_network" + STOP_NETWORK = "stop_network" + UPDATE_NETWORK_TOPOLOGY = "update_network_topology" + + # Group commands + GET_GROUPS = "get_groups" + CREATE_GROUP = "create_group" + REMOVE_GROUPS = "remove_groups" + ADD_GROUP_MEMBERS = "add_group_members" + REMOVE_GROUP_MEMBERS = "remove_group_members" + + # Server API commands + STOP_SERVER = "stop_server" + + # Light API commands + LIGHT_TURN_ON = "light_turn_on" + LIGHT_TURN_OFF = "light_turn_off" + + # Switch API commands + SWITCH_TURN_ON = "switch_turn_on" + SWITCH_TURN_OFF = "switch_turn_off" + + SIREN_TURN_ON = "siren_turn_on" + SIREN_TURN_OFF = "siren_turn_off" + + LOCK_UNLOCK = "lock_unlock" + LOCK_LOCK = "lock_lock" + LOCK_SET_USER_CODE = "lock_set_user_lock_code" + LOCK_ENAABLE_USER_CODE = "lock_enable_user_lock_code" + LOCK_DISABLE_USER_CODE = "lock_disable_user_lock_code" + LOCK_CLEAR_USER_CODE = "lock_clear_user_lock_code" + + CLIMATE_SET_TEMPERATURE = "climate_set_temperature" + CLIMATE_SET_HVAC_MODE = "climate_set_hvac_mode" + CLIMATE_SET_FAN_MODE = "climate_set_fan_mode" + CLIMATE_SET_PRESET_MODE = "climate_set_preset_mode" + + COVER_OPEN = "cover_open" + COVER_CLOSE = "cover_close" + COVER_STOP = "cover_stop" + COVER_SET_POSITION = "cover_set_position" + + FAN_TURN_ON = "fan_turn_on" + FAN_TURN_OFF = "fan_turn_off" + FAN_SET_PERCENTAGE = "fan_set_percentage" + FAN_SET_PRESET_MODE = "fan_set_preset_mode" + + BUTTON_PRESS = "button_press" + + ALARM_CONTROL_PANEL_DISARM = "alarm_control_panel_disarm" + ALARM_CONTROL_PANEL_ARM_HOME = "alarm_control_panel_arm_home" + ALARM_CONTROL_PANEL_ARM_AWAY = "alarm_control_panel_arm_away" + ALARM_CONTROL_PANEL_ARM_NIGHT = "alarm_control_panel_arm_night" + ALARM_CONTROL_PANEL_TRIGGER = "alarm_control_panel_trigger" + + SELECT_SELECT_OPTION = "select_select_option" + + NUMBER_SET_VALUE = "number_set_value" + + PLATFORM_ENTITY_REFRESH_STATE = "platform_entity_refresh_state" + + CLIENT_LISTEN = "client_listen" + CLIENT_LISTEN_RAW_ZCL = "client_listen_raw_zcl" + CLIENT_DISCONNECT = "client_disconnect" + + +class MessageTypes(StrEnum): + """WS message types.""" + + EVENT = "event" + RESULT = "result" + + +class EventTypes(StrEnum): + """WS event types.""" + + CONTROLLER_EVENT = "controller_event" + PLATFORM_ENTITY_EVENT = "platform_entity_event" + RAW_ZCL_EVENT = "raw_zcl_event" + DEVICE_EVENT = "device_event" + + +class ControllerEvents(StrEnum): + """WS controller events.""" + + DEVICE_JOINED = "device_joined" + RAW_DEVICE_INITIALIZED = "raw_device_initialized" + DEVICE_REMOVED = "device_removed" + DEVICE_LEFT = "device_left" + DEVICE_FULLY_INITIALIZED = "device_fully_initialized" + DEVICE_CONFIGURED = "device_configured" + GROUP_MEMBER_ADDED = "group_member_added" + GROUP_MEMBER_REMOVED = "group_member_removed" + GROUP_ADDED = "group_added" + GROUP_REMOVED = "group_removed" + + +class PlatformEntityEvents(StrEnum): + """WS platform entity events.""" + + PLATFORM_ENTITY_STATE_CHANGED = "platform_entity_state_changed" + + +class RawZCLEvents(StrEnum): + """WS raw ZCL events.""" + + ATTRIBUTE_UPDATED = "attribute_updated" + + +class DeviceEvents(StrEnum): + """Events that devices can broadcast.""" + + DEVICE_OFFLINE = "device_offline" + DEVICE_ONLINE = "device_online" + ZHA_EVENT = "zha_event" + + +ATTR_UNIQUE_ID: Final[str] = "unique_id" +COMMAND: Final[str] = "command" +CONF_BAUDRATE: Final[str] = "baudrate" +CONF_CUSTOM_QUIRKS_PATH: Final[str] = "custom_quirks_path" +CONF_DATABASE: Final[str] = "database_path" +CONF_DEFAULT_LIGHT_TRANSITION: Final[str] = "default_light_transition" +CONF_DEVICE_CONFIG: Final[str] = "device_config" +CONF_ENABLE_IDENTIFY_ON_JOIN: Final[str] = "enable_identify_on_join" +CONF_ENABLE_QUIRKS: Final[str] = "enable_quirks" +CONF_FLOWCONTROL: Final[str] = "flow_control" +CONF_RADIO_TYPE: Final[str] = "radio_type" +CONF_USB_PATH: Final[str] = "usb_path" +CONF_ZIGPY: Final[str] = "zigpy_config" + +DEVICE: Final[str] = "device" + +EVENT: Final[str] = "event" +EVENT_TYPE: Final[str] = "event_type" + +MESSAGE_TYPE: Final[str] = "message_type" + +IEEE: Final[str] = "ieee" +NWK: Final[str] = "nwk" +PAIRING_STATUS: Final[str] = "pairing_status" + + +DEVICES: Final[str] = "devices" +GROUPS: Final[str] = "groups" +DURATION: Final[str] = "duration" +ERROR_CODE: Final[str] = "error_code" +ERROR_MESSAGE: Final[str] = "error_message" +MESSAGE_ID: Final[str] = "message_id" +SUCCESS: Final[str] = "success" +WEBSOCKET_API: Final[str] = "websocket_api" +ZIGBEE_ERROR_CODE: Final[str] = "zigbee_error_code" diff --git a/zha/websocket/server/__init__.py b/zha/websocket/server/__init__.py new file mode 100644 index 000000000..5732f7f2c --- /dev/null +++ b/zha/websocket/server/__init__.py @@ -0,0 +1 @@ +"""Websocket server module for Zigbee Home Automation.""" diff --git a/zha/websocket/server/api/__init__.py b/zha/websocket/server/api/__init__.py new file mode 100644 index 000000000..052e0e7df --- /dev/null +++ b/zha/websocket/server/api/__init__.py @@ -0,0 +1,31 @@ +"""Websocket api for zha.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from zha.websocket.const import WEBSOCKET_API +from zha.websocket.server.api.model import WebSocketCommand +from zha.websocket.server.api.types import WebSocketCommandHandler + +if TYPE_CHECKING: + from zha.websocket.server.gateway import WebSocketGateway + + +def register_api_command( + server: WebSocketGateway, + command_or_handler: str | WebSocketCommandHandler, + handler: WebSocketCommandHandler | None = None, + model: type[WebSocketCommand] | None = None, +) -> None: + """Register a websocket command.""" + # pylint: disable=protected-access + if handler is None: + handler = cast(WebSocketCommandHandler, command_or_handler) + command = handler._ws_command # type: ignore[attr-defined] + model = handler._ws_command_model # type: ignore[attr-defined] + else: + command = command_or_handler + if (handlers := server.data.get(WEBSOCKET_API)) is None: + handlers = server.data[WEBSOCKET_API] = {} + handlers[command] = (handler, model) diff --git a/zha/websocket/server/api/decorators.py b/zha/websocket/server/api/decorators.py new file mode 100644 index 000000000..42903f379 --- /dev/null +++ b/zha/websocket/server/api/decorators.py @@ -0,0 +1,72 @@ +"""Decorators for the Websocket API.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable +from functools import wraps +import logging +from typing import TYPE_CHECKING + +from zha.websocket.server.api.model import WebSocketCommand + +if TYPE_CHECKING: + from zha.websocket.server.api.types import ( + AsyncWebSocketCommandHandler, + T_WebSocketCommand, + WebSocketCommandHandler, + ) + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway + +_LOGGER = logging.getLogger(__name__) + + +async def _handle_async_response( + func: AsyncWebSocketCommandHandler, + server: WebSocketGateway, + client: Client, + msg: T_WebSocketCommand, +) -> None: + """Create a response and handle exception.""" + try: + await func(server, client, msg) + except Exception as err: # pylint: disable=broad-except + # TODO fix this to send a real error code and message + _LOGGER.exception("Error handling message", exc_info=err) + client.send_result_error(msg, "API_COMMAND_HANDLER_ERROR", str(err)) + + +def async_response( + func: AsyncWebSocketCommandHandler, +) -> WebSocketCommandHandler: + """Decorate an async function to handle WebSocket API messages.""" + + @wraps(func) + def schedule_handler( + server: WebSocketGateway, client: Client, msg: T_WebSocketCommand + ) -> None: + """Schedule the handler.""" + # As the webserver is now started before the start + # event we do not want to block for websocket responders + server.track_ws_task( + asyncio.create_task(_handle_async_response(func, server, client, msg)) + ) + + return schedule_handler + + +def websocket_command( + ws_command: type[WebSocketCommand], +) -> Callable[[WebSocketCommandHandler], WebSocketCommandHandler]: + """Tag a function as a websocket command.""" + command = ws_command.model_fields["command"].default + + def decorate(func: WebSocketCommandHandler) -> WebSocketCommandHandler: + """Decorate ws command function.""" + # pylint: disable=protected-access + func._ws_command_model = ws_command # type: ignore[attr-defined] + func._ws_command = command # type: ignore[attr-defined] + return func + + return decorate diff --git a/zha/websocket/server/api/model.py b/zha/websocket/server/api/model.py new file mode 100644 index 000000000..370b2e249 --- /dev/null +++ b/zha/websocket/server/api/model.py @@ -0,0 +1,65 @@ +"""Models for the websocket API.""" + +from typing import Literal + +from zha.model import BaseModel +from zha.websocket.const import APICommands + + +class WebSocketCommand(BaseModel): + """Command for the websocket API.""" + + message_id: int = 1 + command: Literal[ + APICommands.STOP_SERVER, + APICommands.CLIENT_LISTEN_RAW_ZCL, + APICommands.CLIENT_DISCONNECT, + APICommands.CLIENT_LISTEN, + APICommands.BUTTON_PRESS, + APICommands.PLATFORM_ENTITY_REFRESH_STATE, + APICommands.ALARM_CONTROL_PANEL_DISARM, + APICommands.ALARM_CONTROL_PANEL_ARM_HOME, + APICommands.ALARM_CONTROL_PANEL_ARM_AWAY, + APICommands.ALARM_CONTROL_PANEL_ARM_NIGHT, + APICommands.ALARM_CONTROL_PANEL_TRIGGER, + APICommands.START_NETWORK, + APICommands.STOP_NETWORK, + APICommands.UPDATE_NETWORK_TOPOLOGY, + APICommands.RECONFIGURE_DEVICE, + APICommands.GET_DEVICES, + APICommands.GET_GROUPS, + APICommands.PERMIT_JOINING, + APICommands.ADD_GROUP_MEMBERS, + APICommands.REMOVE_GROUP_MEMBERS, + APICommands.CREATE_GROUP, + APICommands.REMOVE_GROUPS, + APICommands.REMOVE_DEVICE, + APICommands.READ_CLUSTER_ATTRIBUTES, + APICommands.WRITE_CLUSTER_ATTRIBUTE, + APICommands.SIREN_TURN_ON, + APICommands.SIREN_TURN_OFF, + APICommands.SELECT_SELECT_OPTION, + APICommands.NUMBER_SET_VALUE, + APICommands.LOCK_CLEAR_USER_CODE, + APICommands.LOCK_SET_USER_CODE, + APICommands.LOCK_ENAABLE_USER_CODE, + APICommands.LOCK_DISABLE_USER_CODE, + APICommands.LOCK_LOCK, + APICommands.LOCK_UNLOCK, + APICommands.LIGHT_TURN_OFF, + APICommands.LIGHT_TURN_ON, + APICommands.FAN_SET_PERCENTAGE, + APICommands.FAN_SET_PRESET_MODE, + APICommands.FAN_TURN_ON, + APICommands.FAN_TURN_OFF, + APICommands.COVER_STOP, + APICommands.COVER_SET_POSITION, + APICommands.COVER_OPEN, + APICommands.COVER_CLOSE, + APICommands.CLIMATE_SET_TEMPERATURE, + APICommands.CLIMATE_SET_HVAC_MODE, + APICommands.CLIMATE_SET_FAN_MODE, + APICommands.CLIMATE_SET_PRESET_MODE, + APICommands.SWITCH_TURN_ON, + APICommands.SWITCH_TURN_OFF, + ] diff --git a/zha/websocket/server/api/types.py b/zha/websocket/server/api/types.py new file mode 100644 index 000000000..5819a91ca --- /dev/null +++ b/zha/websocket/server/api/types.py @@ -0,0 +1,15 @@ +"""Type information for the websocket api module.""" + +from __future__ import annotations + +from collections.abc import Callable, Coroutine +from typing import Any, TypeVar + +from zha.websocket.server.api.model import WebSocketCommand + +T_WebSocketCommand = TypeVar("T_WebSocketCommand", bound=WebSocketCommand) + +AsyncWebSocketCommandHandler = Callable[ + [Any, Any, T_WebSocketCommand], Coroutine[Any, Any, None] +] +WebSocketCommandHandler = Callable[[Any, Any, T_WebSocketCommand], None] diff --git a/zha/websocket/server/client.py b/zha/websocket/server/client.py new file mode 100644 index 000000000..f6b4ff879 --- /dev/null +++ b/zha/websocket/server/client.py @@ -0,0 +1,294 @@ +"""Client classes for zhawss.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable +import json +import logging +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import BaseModel, ValidationError +from websockets.server import WebSocketServerProtocol + +from zha.websocket.const import ( + COMMAND, + ERROR_CODE, + ERROR_MESSAGE, + EVENT_TYPE, + MESSAGE_ID, + MESSAGE_TYPE, + SUCCESS, + WEBSOCKET_API, + ZIGBEE_ERROR_CODE, + APICommands, + EventTypes, + MessageTypes, +) +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.model import WebSocketCommand + +if TYPE_CHECKING: + from zha.websocket.server.gateway import WebSocketGateway + +_LOGGER = logging.getLogger(__name__) + + +class Client: + """ZHAWSS client implementation.""" + + def __init__( + self, + websocket: WebSocketServerProtocol, + client_manager: ClientManager, + ): + """Initialize the client.""" + self._websocket: WebSocketServerProtocol = websocket + self._client_manager: ClientManager = client_manager + self.receive_events: bool = False + self.receive_raw_zcl_events: bool = False + + @property + def is_connected(self) -> bool: + """Return True if the websocket connection is connected.""" + return self._websocket.open + + def disconnect(self) -> None: + """Disconnect this client and close the websocket.""" + self._client_manager.server.track_ws_task( + asyncio.create_task(self._websocket.close()) + ) + + def send_event(self, message: dict[str, Any]) -> None: + """Send event data to this client.""" + message[MESSAGE_TYPE] = MessageTypes.EVENT + self._send_data(message) + + def send_result_success( + self, command: WebSocketCommand, data: dict[str, Any] | None = None + ) -> None: + """Send success result prompted by a client request.""" + message = { + SUCCESS: True, + MESSAGE_ID: command.message_id, + MESSAGE_TYPE: MessageTypes.RESULT, + COMMAND: command.command, + } + if data: + message.update(data) + self._send_data(message) + + def send_result_error( + self, + command: WebSocketCommand, + error_code: str, + error_message: str, + data: dict[str, Any] | None = None, + ) -> None: + """Send error result prompted by a client request.""" + message = { + SUCCESS: False, + MESSAGE_ID: command.message_id, + MESSAGE_TYPE: MessageTypes.RESULT, + COMMAND: f"error.{command.command}", + ERROR_CODE: error_code, + ERROR_MESSAGE: error_message, + } + if data: + message.update(data) + self._send_data(message) + + def send_result_zigbee_error( + self, + command: WebSocketCommand, + error_message: str, + zigbee_error_code: str, + ) -> None: + """Send zigbee error result prompted by a client zigbee request.""" + self.send_result_error( + command, + error_code="zigbee_error", + error_message=error_message, + data={ZIGBEE_ERROR_CODE: zigbee_error_code}, + ) + + def _send_data(self, message: dict[str, Any] | BaseModel) -> None: + """Send data to this client.""" + try: + if isinstance(message, BaseModel): + message_json = message.model_dump_json() + else: + message_json = json.dumps(message) + except ValueError as exc: + _LOGGER.exception("Couldn't serialize data: %s", message, exc_info=exc) + raise exc + else: + self._client_manager.server.track_ws_task( + asyncio.create_task(self._websocket.send(message_json)) + ) + + async def _handle_incoming_message(self, message: str | bytes) -> None: + """Handle an incoming message.""" + _LOGGER.info("Message received: %s", message) + handlers: dict[str, tuple[Callable, WebSocketCommand]] = ( + self._client_manager.server.data[WEBSOCKET_API] + ) + + try: + msg = WebSocketCommand.model_validate_json(message) + except ValidationError as exception: + _LOGGER.exception( + "Received invalid command[unable to parse command]: %s on websocket: %s", + message, + self._websocket.id, + exc_info=exception, + ) + return + + if msg.command not in handlers: + _LOGGER.error( + "Received invalid command[command not registered]: %s", message + ) + return + + handler, model = handlers[msg.command] + + try: + handler( + self._client_manager.server, self, model.model_validate_json(message) + ) + except Exception as err: # pylint: disable=broad-except + # TODO Fix this - make real error codes with error messages + _LOGGER.exception("Error handling message: %s", message, exc_info=err) + self.send_result_error(message, "INTERNAL_ERROR", f"Internal error: {err}") + + async def listen(self) -> None: + """Listen for incoming messages.""" + async for message in self._websocket: + self._client_manager.server.track_ws_task( + asyncio.create_task(self._handle_incoming_message(message)) + ) + + def will_accept_message(self, message: dict[str, Any]) -> bool: + """Determine if client accepts this type of message.""" + if not self.receive_events: + return False + + if ( + message[EVENT_TYPE] == EventTypes.RAW_ZCL_EVENT + and not self.receive_raw_zcl_events + ): + _LOGGER.info( + "Client %s not accepting raw ZCL events: %s", + self._websocket.id, + message, + ) + return False + + return True + + +class ClientListenRawZCLCommand(WebSocketCommand): + """Listen to raw ZCL data.""" + + command: Literal[APICommands.CLIENT_LISTEN_RAW_ZCL] = ( + APICommands.CLIENT_LISTEN_RAW_ZCL + ) + + +class ClientListenCommand(WebSocketCommand): + """Listen for zhawss messages.""" + + command: Literal[APICommands.CLIENT_LISTEN] = APICommands.CLIENT_LISTEN + + +class ClientDisconnectCommand(WebSocketCommand): + """Disconnect this client.""" + + command: Literal[APICommands.CLIENT_DISCONNECT] = APICommands.CLIENT_DISCONNECT + + +@decorators.websocket_command(ClientListenRawZCLCommand) +@decorators.async_response +async def listen_raw_zcl( + server: WebSocketGateway, client: Client, command: WebSocketCommand +) -> None: + """Listen for raw ZCL events.""" + client.receive_raw_zcl_events = True + client.send_result_success(command) + + +@decorators.websocket_command(ClientListenCommand) +@decorators.async_response +async def listen( + server: WebSocketGateway, client: Client, command: WebSocketCommand +) -> None: + """Listen for events.""" + client.receive_events = True + client.send_result_success(command) + + +@decorators.websocket_command(ClientDisconnectCommand) +@decorators.async_response +async def disconnect( + server: WebSocketGateway, client: Client, command: WebSocketCommand +) -> None: + """Disconnect the client.""" + client.disconnect() + server.client_manager.remove_client(client) + + +def load_api(server: WebSocketGateway) -> None: + """Load the api command handlers.""" + register_api_command(server, listen_raw_zcl) + register_api_command(server, listen) + register_api_command(server, disconnect) + + +class ClientManager: + """ZHAWSS client manager implementation.""" + + def __init__(self, server: WebSocketGateway): + """Initialize the client.""" + self._server: WebSocketGateway = server + self._clients: list[Client] = [] + + @property + def server(self) -> WebSocketGateway: + """Return the server this ClientManager belongs to.""" + return self._server + + async def add_client(self, websocket: WebSocketServerProtocol) -> None: + """Add a new client to the client manager.""" + client: Client = Client(websocket, self) + self._clients.append(client) + await client.listen() + + def remove_client(self, client: Client) -> None: + """Remove a client from the client manager.""" + client.disconnect() + self._clients.remove(client) + + def broadcast(self, message: dict[str, Any]) -> None: + """Broadcast a message to all connected clients.""" + clients_to_remove = [] + + for client in self._clients: + if not client.is_connected: + # XXX: We cannot remove elements from `_clients` while iterating over it + clients_to_remove.append(client) + continue + + if not client.will_accept_message(message): + continue + + _LOGGER.info( + "Broadcasting message: %s to client: %s", + message, + client._websocket.id, + ) + # TODO use the receive flags on the client to determine if the client should receive the message + client.send_event(message) + + for client in clients_to_remove: + self.remove_client(client) diff --git a/zha/websocket/server/gateway.py b/zha/websocket/server/gateway.py new file mode 100644 index 000000000..9d9dec7b7 --- /dev/null +++ b/zha/websocket/server/gateway.py @@ -0,0 +1,144 @@ +"""ZHAWSS websocket server.""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +from types import TracebackType +from typing import TYPE_CHECKING, Any, Final, Literal + +import websockets + +from zha.application.discovery import PLATFORMS +from zha.application.gateway import Gateway +from zha.application.helpers import ZHAData +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.model import WebSocketCommand +from zha.websocket.server.client import ClientManager + +if TYPE_CHECKING: + from zha.websocket.client import Client + +BLOCK_LOG_TIMEOUT: Final[int] = 60 +_LOGGER = logging.getLogger(__name__) + + +class WebSocketGateway(Gateway): + """ZHAWSS server implementation.""" + + def __init__(self, config: ZHAData) -> None: + """Initialize the websocket gateway.""" + super().__init__(config) + self._ws_server: websockets.WebSocketServer | None = None + self._client_manager: ClientManager = ClientManager(self) + self._stopped_event: asyncio.Event = asyncio.Event() + self._tracked_ws_tasks: set[asyncio.Task] = set() + self.data: dict[Any, Any] = {} + for platform in PLATFORMS: + self.data.setdefault(platform, []) + self._register_api_commands() + + @property + def is_serving(self) -> bool: + """Return whether or not the websocket server is serving.""" + return self._ws_server is not None and self._ws_server.is_serving + + @property + def client_manager(self) -> ClientManager: + """Return the zigbee application controller.""" + return self._client_manager + + async def start_server(self) -> None: + """Start the websocket server.""" + assert self._ws_server is None + self._stopped_event.clear() + self._ws_server = await websockets.serve( + self.client_manager.add_client, + self.config.server_config.host, + self.config.server_config.port, + logger=_LOGGER, + ) + if self.config.server_config.network_auto_start: + await self.async_initialize() + self.on_all_events(self.client_manager.broadcast) + await self.async_initialize_devices_and_entities() + + async def stop_server(self) -> None: + """Stop the websocket server.""" + if self._ws_server is None: + self._stopped_event.set() + return + + assert self._ws_server is not None + + await self.shutdown() + + self._ws_server.close() + await self._ws_server.wait_closed() + self._ws_server = None + + self._stopped_event.set() + + async def wait_closed(self) -> None: + """Wait until the server is not running.""" + await self._stopped_event.wait() + _LOGGER.info("Server stopped. Completing remaining tasks...") + tasks = [t for t in self._tracked_ws_tasks if not (t.done() or t.cancelled())] + for task in tasks: + _LOGGER.debug("Cancelling task: %s", task) + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather(*tasks, return_exceptions=True) + + tasks = [ + t + for t in self._tracked_completable_tasks + if not (t.done() or t.cancelled()) + ] + for task in tasks: + _LOGGER.debug("Cancelling task: %s", task) + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather(*tasks, return_exceptions=True) + + def track_ws_task(self, task: asyncio.Task) -> None: + """Create a tracked ws task.""" + self._tracked_ws_tasks.add(task) + task.add_done_callback(self._tracked_ws_tasks.remove) + + async def __aenter__(self) -> WebSocketGateway: + """Enter the context manager.""" + await self.start_server() + return self + + async def __aexit__( + self, exc_type: Exception, exc_value: str, traceback: TracebackType + ) -> None: + """Exit the context manager.""" + await self.stop_server() + await self.wait_closed() + + def _register_api_commands(self) -> None: + """Load server API commands.""" + from zha.websocket.server.client import load_api as load_client_api + + register_api_command(self, stop_server) + load_client_api(self) + + +class StopServerCommand(WebSocketCommand): + """Stop the server.""" + + command: Literal[APICommands.STOP_SERVER] = APICommands.STOP_SERVER + + +@decorators.websocket_command(StopServerCommand) +@decorators.async_response +async def stop_server( + server: WebSocketGateway, client: Client, command: WebSocketCommand +) -> None: + """Stop the Zigbee network.""" + client.send_result_success(command) + await server.stop_server() diff --git a/zha/websocket/server/gateway_api.py b/zha/websocket/server/gateway_api.py new file mode 100644 index 000000000..122d42c95 --- /dev/null +++ b/zha/websocket/server/gateway_api.py @@ -0,0 +1,474 @@ +"""Websocket API for zhawss.""" + +from __future__ import annotations + +import asyncio +import dataclasses +import logging +from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeVar, Union, cast + +from pydantic import Field +from zigpy.types.named import EUI64 + +from zha.websocket.client.model.types import ( + Device as DeviceModel, + Group as GroupModel, + GroupMemberReference, +) +from zha.websocket.const import DEVICES, DURATION, GROUPS, APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.model import WebSocketCommand +from zha.zigbee.device import Device +from zha.zigbee.group import Group + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway + +GROUP = "group" +MFG_CLUSTER_ID_START = 0xFC00 + +_LOGGER = logging.getLogger(__name__) + +T = TypeVar("T") + + +def ensure_list(value: T | None) -> list[T] | list[Any]: + """Wrap value in list if it is not one.""" + if value is None: + return [] + return cast("list[T]", value) if isinstance(value, list) else [value] + + +class StartNetworkCommand(WebSocketCommand): + """Start the Zigbee network.""" + + command: Literal[APICommands.START_NETWORK] = APICommands.START_NETWORK + + +@decorators.websocket_command(StartNetworkCommand) +@decorators.async_response +async def start_network( + gateway: WebSocketGateway, client: Client, command: StartNetworkCommand +) -> None: + """Start the Zigbee network.""" + await gateway.start_network() + client.send_result_success(command) + + +class StopNetworkCommand(WebSocketCommand): + """Stop the Zigbee network.""" + + command: Literal[APICommands.STOP_NETWORK] = APICommands.STOP_NETWORK + + +@decorators.websocket_command(StopNetworkCommand) +@decorators.async_response +async def stop_network( + gateway: WebSocketGateway, client: Client, command: StopNetworkCommand +) -> None: + """Stop the Zigbee network.""" + await gateway.stop_network() + client.send_result_success(command) + + +class UpdateTopologyCommand(WebSocketCommand): + """Stop the Zigbee network.""" + + command: Literal[APICommands.UPDATE_NETWORK_TOPOLOGY] = ( + APICommands.UPDATE_NETWORK_TOPOLOGY + ) + + +@decorators.websocket_command(UpdateTopologyCommand) +@decorators.async_response +async def update_topology( + gateway: WebSocketGateway, client: Client, command: WebSocketCommand +) -> None: + """Update the Zigbee network topology.""" + await gateway.application_controller.topology.scan() + client.send_result_success(command) + + +class GetDevicesCommand(WebSocketCommand): + """Get all Zigbee devices.""" + + command: Literal[APICommands.GET_DEVICES] = APICommands.GET_DEVICES + + +@decorators.websocket_command(GetDevicesCommand) +@decorators.async_response +async def get_devices( + gateway: WebSocketGateway, client: Client, command: GetDevicesCommand +) -> None: + """Get Zigbee devices.""" + try: + response_devices: dict[str, dict] = { + str(ieee): DeviceModel.model_validate( + dataclasses.asdict(device.extended_device_info) + ).model_dump() + for ieee, device in gateway.devices.items() + } + _LOGGER.info("devices: %s", response_devices) + client.send_result_success(command, {DEVICES: response_devices}) + except Exception as e: + _LOGGER.exception("Error getting devices", exc_info=e) + client.send_result_error(command, "Error getting devices", str(e)) + + +class ReconfigureDeviceCommand(WebSocketCommand): + """Reconfigure a zigbee device.""" + + command: Literal[APICommands.RECONFIGURE_DEVICE] = APICommands.RECONFIGURE_DEVICE + ieee: EUI64 + + +@decorators.websocket_command(ReconfigureDeviceCommand) +@decorators.async_response +async def reconfigure_device( + gateway: WebSocketGateway, client: Client, command: ReconfigureDeviceCommand +) -> None: + """Reconfigure a zigbee device.""" + device = gateway.devices.get(command.ieee) + if device: + await device.async_configure() + client.send_result_success(command) + + +class GetGroupsCommand(WebSocketCommand): + """Get all Zigbee devices.""" + + command: Literal[APICommands.GET_GROUPS] = APICommands.GET_GROUPS + + +@decorators.websocket_command(GetGroupsCommand) +@decorators.async_response +async def get_groups( + gateway: WebSocketGateway, client: Client, command: GetGroupsCommand +) -> None: + """Get Zigbee groups.""" + groups: dict[int, Any] = {} + for group_id, group in gateway.groups.items(): + group_data = dataclasses.asdict(group.info_object) + group_data["id"] = group_id + groups[group_id] = GroupModel.model_validate(group_data).model_dump() + _LOGGER.info("groups: %s", groups) + client.send_result_success(command, {GROUPS: groups}) + + +class PermitJoiningCommand(WebSocketCommand): + """Permit joining.""" + + command: Literal[APICommands.PERMIT_JOINING] = APICommands.PERMIT_JOINING + duration: Annotated[int, Field(ge=1, le=254)] = 60 + ieee: Union[EUI64, None] = None + + +@decorators.websocket_command(PermitJoiningCommand) +@decorators.async_response +async def permit_joining( + gateway: WebSocketGateway, client: Client, command: PermitJoiningCommand +) -> None: + """Permit joining devices to the Zigbee network.""" + # TODO add permit with code support + await gateway.application_controller.permit(command.duration, command.ieee) + client.send_result_success( + command, + {DURATION: command.duration}, + ) + + +class RemoveDeviceCommand(WebSocketCommand): + """Remove device command.""" + + command: Literal[APICommands.REMOVE_DEVICE] = APICommands.REMOVE_DEVICE + ieee: EUI64 + + +@decorators.websocket_command(RemoveDeviceCommand) +@decorators.async_response +async def remove_device( + gateway: WebSocketGateway, client: Client, command: RemoveDeviceCommand +) -> None: + """Permit joining devices to the Zigbee network.""" + await gateway.async_remove_device(command.ieee) + client.send_result_success(command) + + +class ReadClusterAttributesCommand(WebSocketCommand): + """Read cluster attributes command.""" + + command: Literal[APICommands.READ_CLUSTER_ATTRIBUTES] = ( + APICommands.READ_CLUSTER_ATTRIBUTES + ) + ieee: EUI64 + endpoint_id: int + cluster_id: int + cluster_type: Literal["in", "out"] + attributes: list[str] + manufacturer_code: Union[int, None] = None + + +@decorators.websocket_command(ReadClusterAttributesCommand) +@decorators.async_response +async def read_cluster_attributes( + gateway: WebSocketGateway, client: Client, command: ReadClusterAttributesCommand +) -> None: + """Read the specified cluster attributes.""" + device: Device = gateway.devices[command.ieee] + if not device: + client.send_result_error( + command, + "Device not found", + f"Device with ieee: {command.ieee} not found", + ) + return + endpoint_id = command.endpoint_id + cluster_id = command.cluster_id + cluster_type = command.cluster_type + attributes = command.attributes + manufacturer = command.manufacturer_code + if cluster_id >= MFG_CLUSTER_ID_START and manufacturer is None: + manufacturer = device.manufacturer_code + cluster = device.async_get_cluster( + endpoint_id, cluster_id, cluster_type=cluster_type + ) + if not cluster: + client.send_result_error( + command, + "Cluster not found", + f"Cluster: {endpoint_id}:{command.cluster_id} not found on device with ieee: {str(command.ieee)} not found", + ) + return + success, failure = await cluster.read_attributes( + attributes, allow_cache=False, only_cache=False, manufacturer=manufacturer + ) + client.send_result_success( + command, + { + "device": { + "ieee": command.ieee, + }, + "cluster": { + "id": cluster.cluster_id, + "endpoint_id": cluster.endpoint.endpoint_id, + "name": cluster.name, + "endpoint_attribute": cluster.ep_attribute, + }, + "manufacturer_code": manufacturer, + "succeeded": success, + "failed": failure, + }, + ) + + +class WriteClusterAttributeCommand(WebSocketCommand): + """Write cluster attribute command.""" + + command: Literal[APICommands.WRITE_CLUSTER_ATTRIBUTE] = ( + APICommands.WRITE_CLUSTER_ATTRIBUTE + ) + ieee: EUI64 + endpoint_id: int + cluster_id: int + cluster_type: Literal["in", "out"] + attribute: str + value: Union[str, int, float, bool] + manufacturer_code: Union[int, None] = None + + +@decorators.websocket_command(WriteClusterAttributeCommand) +@decorators.async_response +async def write_cluster_attribute( + gateway: WebSocketGateway, client: Client, command: WriteClusterAttributeCommand +) -> None: + """Set the value of the specific cluster attribute.""" + device: Device = gateway.devices[command.ieee] + if not device: + client.send_result_error( + command, + "Device not found", + f"Device with ieee: {command.ieee} not found", + ) + return + endpoint_id = command.endpoint_id + cluster_id = command.cluster_id + cluster_type = command.cluster_type + attribute = command.attribute + value = command.value + manufacturer = command.manufacturer_code + if cluster_id >= MFG_CLUSTER_ID_START and manufacturer is None: + manufacturer = device.manufacturer_code + cluster = device.async_get_cluster( + endpoint_id, cluster_id, cluster_type=cluster_type + ) + if not cluster: + client.send_result_error( + command, + "Cluster not found", + f"Cluster: {endpoint_id}:{command.cluster_id} not found on device with ieee: {str(command.ieee)} not found", + ) + return + response = await device.write_zigbee_attribute( + endpoint_id, + cluster_id, + attribute, + value, + cluster_type=cluster_type, + manufacturer=manufacturer, + ) + client.send_result_success( + command, + { + "device": { + "ieee": str(command.ieee), + }, + "cluster": { + "id": cluster.cluster_id, + "endpoint_id": cluster.endpoint.endpoint_id, + "name": cluster.name, + "endpoint_attribute": cluster.ep_attribute, + }, + "manufacturer_code": manufacturer, + "response": { + "attribute": attribute, + "status": response[0][0].status.name, # type: ignore + }, # TODO there has to be a better way to do this + }, + ) + + +class CreateGroupCommand(WebSocketCommand): + """Create group command.""" + + command: Literal[APICommands.CREATE_GROUP] = APICommands.CREATE_GROUP + group_name: str + members: list[GroupMemberReference] + group_id: Union[int, None] = None + + +@decorators.websocket_command(CreateGroupCommand) +@decorators.async_response +async def create_group( + gateway: WebSocketGateway, client: Client, command: CreateGroupCommand +) -> None: + """Create a new group.""" + group_name = command.group_name + members = command.members + group_id = command.group_id + group: Group = await gateway.async_create_zigpy_group(group_name, members, group_id) + ret_group = dataclasses.asdict(group.info_object) + ret_group["id"] = ret_group["group_id"] + ret_group = GroupModel.model_validate(ret_group).model_dump() + client.send_result_success(command, {"group": ret_group}) + + +class RemoveGroupsCommand(WebSocketCommand): + """Remove groups command.""" + + command: Literal[APICommands.REMOVE_GROUPS] = APICommands.REMOVE_GROUPS + group_ids: list[int] + + +@decorators.websocket_command(RemoveGroupsCommand) +@decorators.async_response +async def remove_groups( + gateway: WebSocketGateway, client: Client, command: RemoveGroupsCommand +) -> None: + """Remove the specified groups.""" + group_ids = command.group_ids + + if len(group_ids) > 1: + tasks = [] + for group_id in group_ids: + tasks.append(gateway.async_remove_zigpy_group(group_id)) + await asyncio.gather(*tasks) + else: + await gateway.async_remove_zigpy_group(group_ids[0]) + groups: dict[int, Any] = {} + for id, group in gateway.groups.items(): + group_data = dataclasses.asdict(group.info_object) + group_data["id"] = group_data["group_id"] + groups[id] = GroupModel.model_validate(group_data).model_dump() + _LOGGER.info("groups: %s", groups) + client.send_result_success(command, {GROUPS: groups}) + + +class AddGroupMembersCommand(WebSocketCommand): + """Add group members command.""" + + command: Literal[ + APICommands.ADD_GROUP_MEMBERS, APICommands.REMOVE_GROUP_MEMBERS + ] = APICommands.ADD_GROUP_MEMBERS + group_id: int + members: list[GroupMemberReference] + + +@decorators.websocket_command(AddGroupMembersCommand) +@decorators.async_response +async def add_group_members( + gateway: WebSocketGateway, client: Client, command: AddGroupMembersCommand +) -> None: + """Add members to a ZHA group.""" + group_id = command.group_id + members = command.members + group = None + + if group_id in gateway.groups: + group = gateway.groups[group_id] + await group.async_add_members(members) + if not group: + client.send_result_error(command, "G1", "ZHA Group not found") + return + ret_group = dataclasses.asdict(group.info_object) + ret_group["id"] = ret_group["group_id"] + ret_group = GroupModel.model_validate(ret_group).model_dump() + client.send_result_success(command, {GROUP: ret_group}) + + +class RemoveGroupMembersCommand(AddGroupMembersCommand): + """Remove group members command.""" + + command: Literal[APICommands.REMOVE_GROUP_MEMBERS] = ( + APICommands.REMOVE_GROUP_MEMBERS + ) + + +@decorators.websocket_command(RemoveGroupMembersCommand) +@decorators.async_response +async def remove_group_members( + gateway: WebSocketGateway, client: Client, command: RemoveGroupMembersCommand +) -> None: + """Remove members from a ZHA group.""" + group_id = command.group_id + members = command.members + group = None + + if group_id in gateway.groups: + group = gateway.groups[group_id] + await group.async_remove_members(members) + if not group: + client.send_result_error(command, "G1", "ZHA Group not found") + return + ret_group = dataclasses.asdict(group.info_object) + ret_group["id"] = ret_group["group_id"] + ret_group = GroupModel.model_validate(ret_group).model_dump() + client.send_result_success(command, {GROUP: ret_group}) + + +def load_api(gateway: WebSocketGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, start_network) + register_api_command(gateway, stop_network) + register_api_command(gateway, get_devices) + register_api_command(gateway, reconfigure_device) + register_api_command(gateway, get_groups) + register_api_command(gateway, create_group) + register_api_command(gateway, remove_groups) + register_api_command(gateway, add_group_members) + register_api_command(gateway, remove_group_members) + register_api_command(gateway, permit_joining) + register_api_command(gateway, remove_device) + register_api_command(gateway, update_topology) + register_api_command(gateway, read_cluster_attributes) + register_api_command(gateway, write_cluster_attribute) From 3c2299bd6af95deb252afbc4ccc2e04681b2ebf2 Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Sun, 20 Oct 2024 16:38:58 -0400 Subject: [PATCH 10/12] restructure, add entity APIs back and remove duplicate models --- tests/common.py | 64 ++ tests/conftest.py | 36 +- ...entralite-3320-l-extended-device-info.json | 2 +- tests/test_gateway.py | 8 +- tests/test_model.py | 4 +- tests/websocket/__init__.py | 1 + tests/websocket/test_binary_sensor.py | 124 +++ tests/websocket/test_button.py | 76 ++ tests/websocket/test_client_controller.py | 396 +++++++++ tests/websocket/test_number.py | 119 +++ tests/websocket/test_siren.py | 177 ++++ tests/websocket/test_switch.py | 363 +++++++++ .../test_websocket_server_client.py | 0 zha/application/gateway.py | 191 ++--- zha/application/model.py | 144 ++++ zha/application/platforms/__init__.py | 107 +-- .../platforms/alarm_control_panel/__init__.py | 2 + zha/application/platforms/model.py | 730 +++++++++++++++++ zha/application/platforms/number/__init__.py | 2 +- zha/const.py | 2 +- zha/model.py | 14 +- zha/websocket/client/client.py | 19 +- zha/websocket/client/controller.py | 119 +-- zha/websocket/client/helpers.py | 706 +++++++++++++++- zha/websocket/client/model/commands.py | 200 ----- zha/websocket/client/model/events.py | 263 ------ zha/websocket/client/model/messages.py | 3 +- zha/websocket/client/model/types.py | 760 ------------------ zha/websocket/client/proxy.py | 64 +- zha/websocket/const.py | 2 +- zha/websocket/server/api/model.py | 236 +++++- .../server/api/platforms/__init__.py | 19 + .../platforms/alarm_control_panel/__init__.py | 3 + .../api/platforms/alarm_control_panel/api.py | 117 +++ zha/websocket/server/api/platforms/api.py | 124 +++ .../server/api/platforms/button/__init__.py | 3 + .../server/api/platforms/button/api.py | 34 + .../server/api/platforms/climate/__init__.py | 3 + .../server/api/platforms/climate/api.py | 128 +++ .../server/api/platforms/cover/__init__.py | 3 + .../server/api/platforms/cover/api.py | 86 ++ .../server/api/platforms/fan/__init__.py | 3 + zha/websocket/server/api/platforms/fan/api.py | 94 +++ .../server/api/platforms/light/__init__.py | 3 + .../server/api/platforms/light/api.py | 85 ++ .../server/api/platforms/lock/__init__.py | 3 + .../server/api/platforms/lock/api.py | 136 ++++ .../server/api/platforms/number/__init__.py | 3 + .../server/api/platforms/number/api.py | 40 + .../server/api/platforms/select/__init__.py | 3 + .../server/api/platforms/select/api.py | 41 + .../server/api/platforms/siren/__init__.py | 3 + .../server/api/platforms/siren/api.py | 54 ++ .../server/api/platforms/switch/__init__.py | 3 + .../server/api/platforms/switch/api.py | 51 ++ zha/websocket/server/client.py | 38 +- zha/websocket/server/gateway.py | 41 +- zha/websocket/server/gateway_api.py | 125 ++- zha/zigbee/cluster_handlers/__init__.py | 110 +-- zha/zigbee/cluster_handlers/general.py | 12 +- zha/zigbee/cluster_handlers/model.py | 83 ++ zha/zigbee/device.py | 203 +---- zha/zigbee/group.py | 46 +- zha/zigbee/model.py | 329 ++++++++ 64 files changed, 4990 insertions(+), 1973 deletions(-) create mode 100644 tests/websocket/__init__.py create mode 100644 tests/websocket/test_binary_sensor.py create mode 100644 tests/websocket/test_button.py create mode 100644 tests/websocket/test_client_controller.py create mode 100644 tests/websocket/test_number.py create mode 100644 tests/websocket/test_siren.py create mode 100644 tests/websocket/test_switch.py rename tests/{ => websocket}/test_websocket_server_client.py (100%) create mode 100644 zha/application/model.py create mode 100644 zha/application/platforms/model.py delete mode 100644 zha/websocket/client/model/commands.py delete mode 100644 zha/websocket/client/model/events.py delete mode 100644 zha/websocket/client/model/types.py create mode 100644 zha/websocket/server/api/platforms/__init__.py create mode 100644 zha/websocket/server/api/platforms/alarm_control_panel/__init__.py create mode 100644 zha/websocket/server/api/platforms/alarm_control_panel/api.py create mode 100644 zha/websocket/server/api/platforms/api.py create mode 100644 zha/websocket/server/api/platforms/button/__init__.py create mode 100644 zha/websocket/server/api/platforms/button/api.py create mode 100644 zha/websocket/server/api/platforms/climate/__init__.py create mode 100644 zha/websocket/server/api/platforms/climate/api.py create mode 100644 zha/websocket/server/api/platforms/cover/__init__.py create mode 100644 zha/websocket/server/api/platforms/cover/api.py create mode 100644 zha/websocket/server/api/platforms/fan/__init__.py create mode 100644 zha/websocket/server/api/platforms/fan/api.py create mode 100644 zha/websocket/server/api/platforms/light/__init__.py create mode 100644 zha/websocket/server/api/platforms/light/api.py create mode 100644 zha/websocket/server/api/platforms/lock/__init__.py create mode 100644 zha/websocket/server/api/platforms/lock/api.py create mode 100644 zha/websocket/server/api/platforms/number/__init__.py create mode 100644 zha/websocket/server/api/platforms/number/api.py create mode 100644 zha/websocket/server/api/platforms/select/__init__.py create mode 100644 zha/websocket/server/api/platforms/select/api.py create mode 100644 zha/websocket/server/api/platforms/siren/__init__.py create mode 100644 zha/websocket/server/api/platforms/siren/api.py create mode 100644 zha/websocket/server/api/platforms/switch/__init__.py create mode 100644 zha/websocket/server/api/platforms/switch/api.py create mode 100644 zha/zigbee/cluster_handlers/model.py create mode 100644 zha/zigbee/model.py diff --git a/tests/common.py b/tests/common.py index bff7c862e..6cee2a9fd 100644 --- a/tests/common.py +++ b/tests/common.py @@ -542,3 +542,67 @@ def create_mock_zigpy_device( cluster._attr_cache[attr_id] = value return device + + +def find_entity_id( + domain: str, zha_device: Device, qualifier: Optional[str] = None +) -> Optional[str]: + """Find the entity id under the testing. + + This is used to get the entity id in order to get the state from the state + machine so that we can test state changes. + """ + entities = find_entity_ids(domain, zha_device) + if not entities: + return None + if qualifier: + for entity_id in entities: + if qualifier in entity_id: + return entity_id + return None + else: + return entities[0] + + +def find_entity_ids( + domain: str, zha_device: Device, omit: Optional[list[str]] = None +) -> list[str]: + """Find the entity ids under the testing. + + This is used to get the entity id in order to get the state from the state + machine so that we can test state changes. + """ + head = f"{domain}.{str(zha_device.ieee)}" + + entity_ids = [ + f"{entity.PLATFORM}.{entity.unique_id}" + for entity in zha_device.platform_entities.values() + ] + + matches = [] + res = [] + for entity_id in entity_ids: + if entity_id.startswith(head): + matches.append(entity_id) + + if omit: + for entity_id in matches: + skip = False + for o in omit: + if o in entity_id: + skip = True + break + if not skip: + res.append(entity_id) + else: + res = matches + return res + + +def async_find_group_entity_id(domain: str, group: Group) -> Optional[str]: + """Find the group entity id under test.""" + entity_id = f"{domain}_zha_group_0x{group.group_id:04x}" + + if entity_id in group.group_entities: + return entity_id + return None diff --git a/tests/conftest.py b/tests/conftest.py index 81290f427..c2726ef3b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -234,7 +234,21 @@ async def zigpy_app_controller_fixture(): # Create a fake coordinator device dev = app.add_device(nwk=app.state.node_info.nwk, ieee=app.state.node_info.ieee) - dev.node_desc = zdo_t.NodeDescriptor() + dev.node_desc = zdo_t.NodeDescriptor( + logical_type=zdo_t.LogicalType.Coordinator, + complex_descriptor_available=0, + user_descriptor_available=0, + reserved=0, + aps_flags=0, + frequency_band=zdo_t.NodeDescriptor.FrequencyBand.Freq2400MHz, + mac_capability_flags=zdo_t.NodeDescriptor.MACCapabilityFlags.AllocateAddress, + manufacturer_code=0x1234, + maximum_buffer_size=127, + maximum_incoming_transfer_size=100, + server_mask=10752, + maximum_outgoing_transfer_size=100, + descriptor_capability_field=zdo_t.NodeDescriptor.DescriptorCapability.NONE, + ) dev.node_desc.logical_type = zdo_t.LogicalType.Coordinator dev.manufacturer = "Coordinator Manufacturer" dev.model = "Coordinator Model" @@ -312,16 +326,24 @@ async def __aexit__( async def connected_client_and_server( zha_data: ZHAData, zigpy_app_controller: ControllerApplication, + caplog: pytest.LogCaptureFixture, # pylint: disable=unused-argument ) -> AsyncGenerator[tuple[Controller, WebSocketGateway], None]: """Return the connected client and server fixture.""" - application_controller_patch = patch( - "bellows.zigbee.application.ControllerApplication.new", - return_value=zigpy_app_controller, - ) - - with application_controller_patch: + with ( + patch( + "bellows.zigbee.application.ControllerApplication.new", + return_value=zigpy_app_controller, + ), + patch( + "bellows.zigbee.application.ControllerApplication", + return_value=zigpy_app_controller, + ), + ): ws_gateway = await WebSocketGateway.async_from_config(zha_data) + await ws_gateway.async_initialize() + await ws_gateway.async_block_till_done() + await ws_gateway.async_initialize_devices_and_entities() async with ( ws_gateway as gateway, Controller(f"ws://localhost:{zha_data.server_config.port}") as controller, diff --git a/tests/data/serialization_data/centralite-3320-l-extended-device-info.json b/tests/data/serialization_data/centralite-3320-l-extended-device-info.json index f52e1d153..c50de9b65 100644 --- a/tests/data/serialization_data/centralite-3320-l-extended-device-info.json +++ b/tests/data/serialization_data/centralite-3320-l-extended-device-info.json @@ -1 +1 @@ -{"ieee":"00:0d:6f:00:0f:3a:e3:69","nwk":"0x970A","manufacturer":"CentraLite","model":"3320-L","name":"CentraLite 3320-L","quirk_applied":true,"quirk_class":"zhaquirks.centralite.ias.CentraLiteIASSensor","quirk_id":null,"manufacturer_code":49887,"power_source":"Battery or Unknown","lqi":null,"rssi":null,"available":true,"device_type":"EndDevice","signature":{"node_descriptor":{"logical_type":2,"complex_descriptor_available":0,"user_descriptor_available":0,"reserved":0,"aps_flags":0,"frequency_band":8,"mac_capability_flags":128,"manufacturer_code":49887,"maximum_buffer_size":82,"maximum_incoming_transfer_size":82,"server_mask":0,"maximum_outgoing_transfer_size":82,"descriptor_capability_field":0},"endpoints":{"1":{"profile_id":"0x0104","device_type":"0x0402","input_clusters":["0x0000","0x0001","0x0003","0x0020","0x0402","0x0500","0x0b05"],"output_clusters":["0x0019"]},"2":{"profile_id":"0xc2df","device_type":"0x000c","input_clusters":["0x0000","0x0003","0x0b05","0xfc0f"],"output_clusters":["0x0003"]}},"manufacturer":"CentraLite","model":"3320-L"},"active_coordinator":false,"entities":{"binary_sensor,00:0d:6f:00:0f:3a:e3:69-1-1280":{"platform":"binary_sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1280","class_name":"IASZone","translation_key":null,"device_class":"opening","state_class":null,"entity_category":null,"entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"IASZoneClusterHandler","generic_id":"cluster_handler_0x0500","endpoint_id":1,"cluster":{"id":1280,"name":"IAS Zone","type":"server","commands":[{"id":0,"name":"enroll_response","schema":{"command":"enroll_response","fields":[{"name":"enroll_response_code","type":"EnrollResponse","optional":false},{"name":"zone_id","type":"uint8_t","optional":false}]},"direction":1,"is_manufacturer_specific":null},{"id":1,"name":"init_normal_op_mode","schema":{"command":"init_normal_op_mode","fields":[]},"direction":0,"is_manufacturer_specific":null},{"id":2,"name":"init_test_mode","schema":{"command":"init_test_mode","fields":[{"name":"test_mode_duration","type":"uint8_t","optional":false},{"name":"current_zone_sensitivity_level","type":"uint8_t","optional":false}]},"direction":0,"is_manufacturer_specific":null}]},"id":"1:0x0500","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0500","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"button,00:0d:6f:00:0f:3a:e3:69-1-3":{"platform":"button","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-3","class_name":"IdentifyButton","translation_key":null,"device_class":"identify","state_class":null,"entity_category":"diagnostic","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"IdentifyClusterHandler","generic_id":"cluster_handler_0x0003","endpoint_id":1,"cluster":{"id":3,"name":"Identify","type":"server","commands":[{"id":0,"name":"identify","schema":{"command":"identify","fields":[{"name":"identify_time","type":"uint16_t","optional":false}]},"direction":0,"is_manufacturer_specific":null},{"id":1,"name":"identify_query","schema":{"command":"identify_query","fields":[]},"direction":0,"is_manufacturer_specific":null},{"id":64,"name":"trigger_effect","schema":{"command":"trigger_effect","fields":[{"name":"effect_id","type":"EffectIdentifier","optional":false},{"name":"effect_variant","type":"EffectVariant","optional":false}]},"direction":0,"is_manufacturer_specific":null}]},"id":"1:0x0003","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0003","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"sensor,00:0d:6f:00:0f:3a:e3:69-1-1":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1","class_name":"Battery","translation_key":null,"device_class":"battery","state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"PowerConfigurationClusterHandler","generic_id":"cluster_handler_0x0001","endpoint_id":1,"cluster":{"id":1,"name":"Power Configuration","type":"server","commands":[]},"id":"1:0x0001","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0001","status":"initialized","value_attribute":"battery_voltage"}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"sensor,00:0d:6f:00:0f:3a:e3:69-1-1026":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1026","class_name":"Temperature","translation_key":null,"device_class":"temperature","state_class":"measurement","entity_category":null,"entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"TemperatureMeasurementClusterHandler","generic_id":"cluster_handler_0x0402","endpoint_id":1,"cluster":{"id":1026,"name":"Temperature Measurement","type":"server","commands":[]},"id":"1:0x0402","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0402","status":"initialized","value_attribute":"measured_value"}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"sensor,00:0d:6f:00:0f:3a:e3:69-1-0-rssi":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-0-rssi","class_name":"RSSISensor","translation_key":"rssi","device_class":"signal_strength","state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":false,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"BasicClusterHandler","generic_id":"cluster_handler_0x0000","endpoint_id":1,"cluster":{"id":0,"name":"Basic","type":"server","commands":[{"id":0,"name":"reset_fact_default","schema":{"command":"reset_fact_default","fields":[]},"direction":0,"is_manufacturer_specific":null}]},"id":"1:0x0000","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0000","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"sensor,00:0d:6f:00:0f:3a:e3:69-1-0-lqi":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-0-lqi","class_name":"LQISensor","translation_key":"lqi","device_class":null,"state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":false,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"BasicClusterHandler","generic_id":"cluster_handler_0x0000","endpoint_id":1,"cluster":{"id":0,"name":"Basic","type":"server","commands":[{"id":0,"name":"reset_fact_default","schema":{"command":"reset_fact_default","fields":[]},"direction":0,"is_manufacturer_specific":null}]},"id":"1:0x0000","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0000","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null},"update,00:0d:6f:00:0f:3a:e3:69-1-25-firmware_update":{"platform":"update","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-25-firmware_update","class_name":"FirmwareUpdateEntity","translation_key":null,"device_class":"firmware","state_class":null,"entity_category":"config","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"cluster_handlers":[{"class_name":"OtaClientClusterHandler","generic_id":"cluster_handler_0x0019","endpoint_id":1,"cluster":{"id":25,"name":"Ota","type":"client","commands":[{"id":3,"name":"image_block","schema":{"command":"image_block","fields":[{"name":"field_control","type":"FieldControl","optional":false},{"name":"manufacturer_code","type":"uint16_t","optional":false},{"name":"image_type","type":"uint16_t","optional":false},{"name":"file_version","type":"uint32_t","optional":false},{"name":"file_offset","type":"uint32_t","optional":false},{"name":"maximum_data_size","type":"uint8_t","optional":false},{"name":"request_node_addr","type":"EUI64","optional":false},{"name":"minimum_block_period","type":"uint16_t","optional":false}]},"direction":0,"is_manufacturer_specific":null},{"id":4,"name":"image_page","schema":{"command":"image_page","fields":[{"name":"field_control","type":"FieldControl","optional":false},{"name":"manufacturer_code","type":"uint16_t","optional":false},{"name":"image_type","type":"uint16_t","optional":false},{"name":"file_version","type":"uint32_t","optional":false},{"name":"file_offset","type":"uint32_t","optional":false},{"name":"maximum_data_size","type":"uint8_t","optional":false},{"name":"page_size","type":"uint16_t","optional":false},{"name":"response_spacing","type":"uint16_t","optional":false},{"name":"request_node_addr","type":"EUI64","optional":false}]},"direction":0,"is_manufacturer_specific":null},{"id":1,"name":"query_next_image","schema":{"command":"query_next_image","fields":[{"name":"field_control","type":"FieldControl","optional":false},{"name":"manufacturer_code","type":"uint16_t","optional":false},{"name":"image_type","type":"uint16_t","optional":false},{"name":"current_file_version","type":"uint32_t","optional":false},{"name":"hardware_version","type":"uint16_t","optional":false}]},"direction":0,"is_manufacturer_specific":null},{"id":8,"name":"query_specific_file","schema":{"command":"query_specific_file","fields":[{"name":"request_node_addr","type":"EUI64","optional":false},{"name":"manufacturer_code","type":"uint16_t","optional":false},{"name":"image_type","type":"uint16_t","optional":false},{"name":"file_version","type":"uint32_t","optional":false},{"name":"current_zigbee_stack_version","type":"uint16_t","optional":false}]},"direction":0,"is_manufacturer_specific":null},{"id":6,"name":"upgrade_end","schema":{"command":"upgrade_end","fields":[{"name":"status","type":"Status","optional":false},{"name":"manufacturer_code","type":"uint16_t","optional":false},{"name":"image_type","type":"uint16_t","optional":false},{"name":"file_version","type":"uint32_t","optional":false}]},"direction":0,"is_manufacturer_specific":null}]},"id":"1:0x0019","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0019","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null}},"neighbors":[],"routes":[],"endpoint_names":[{"name":"IAS_ZONE"},{"name":"unknown 12 device_type of 0xc2df profile id"}],"device_automation_triggers":{"device_offline,device_offline":{"device_event_type":"device_offline"}}} \ No newline at end of file +{"ieee":"00:0d:6f:00:0f:3a:e3:69","nwk":"0x970A","manufacturer":"CentraLite","model":"3320-L","name":"CentraLite 3320-L","quirk_applied":true,"quirk_class":"zhaquirks.centralite.ias.CentraLiteIASSensor","quirk_id":null,"manufacturer_code":49887,"power_source":"Battery or Unknown","lqi":null,"rssi":null,"available":true,"device_type":"EndDevice","signature":{"node_descriptor":{"logical_type":2,"complex_descriptor_available":0,"user_descriptor_available":0,"reserved":0,"aps_flags":0,"frequency_band":8,"mac_capability_flags":128,"manufacturer_code":49887,"maximum_buffer_size":82,"maximum_incoming_transfer_size":82,"server_mask":0,"maximum_outgoing_transfer_size":82,"descriptor_capability_field":0},"endpoints":{"1":{"profile_id":"0x0104","device_type":"0x0402","input_clusters":["0x0000","0x0001","0x0003","0x0020","0x0402","0x0500","0x0b05"],"output_clusters":["0x0019"]},"2":{"profile_id":"0xc2df","device_type":"0x000c","input_clusters":["0x0000","0x0003","0x0b05","0xfc0f"],"output_clusters":["0x0003"]}},"manufacturer":"CentraLite","model":"3320-L"},"active_coordinator":false,"entities":{"binary_sensor,00:0d:6f:00:0f:3a:e3:69-1-1280":{"platform":"binary_sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1280","class_name":"IASZone","translation_key":null,"device_class":"opening","state_class":null,"entity_category":null,"entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"state":{"class_name":"IASZone","state":false,"available":true},"cluster_handlers":[{"class_name":"IASZoneClusterHandler","generic_id":"cluster_handler_0x0500","endpoint_id":1,"cluster":{"id":1280,"name":"IAS Zone","type":"server","endpoint_id":1,"endpoint_attribute":"ias_zone"},"id":"1:0x0500","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0500","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"attribute_name":"zone_status"},"button,00:0d:6f:00:0f:3a:e3:69-1-3":{"platform":"button","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-3","class_name":"IdentifyButton","translation_key":null,"device_class":"identify","state_class":null,"entity_category":"diagnostic","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"state":{"class_name":"IdentifyButton","available":true,"state":null},"cluster_handlers":[{"class_name":"IdentifyClusterHandler","generic_id":"cluster_handler_0x0003","endpoint_id":1,"cluster":{"id":3,"name":"Identify","type":"server","endpoint_id":1,"endpoint_attribute":"identify"},"id":"1:0x0003","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0003","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"command":"identify","attribute_name":null,"attribute_value":null,"args":[5],"kwargs":{}},"sensor,00:0d:6f:00:0f:3a:e3:69-1-1":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1","class_name":"Battery","translation_key":null,"device_class":"battery","state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"state":{"class_name":"Battery","state":100,"battery_size":"Other","battery_quantity":1,"battery_voltage":2.8,"available":true},"cluster_handlers":[{"class_name":"PowerConfigurationClusterHandler","generic_id":"cluster_handler_0x0001","endpoint_id":1,"cluster":{"id":1,"name":"Power Configuration","type":"server","endpoint_id":1,"endpoint_attribute":"power"},"id":"1:0x0001","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0001","status":"initialized","value_attribute":"battery_voltage"}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"attribute":"battery_percentage_remaining","decimals":1,"divisor":1,"multiplier":1,"unit":"%"},"sensor,00:0d:6f:00:0f:3a:e3:69-1-1026":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1026","class_name":"Temperature","translation_key":null,"device_class":"temperature","state_class":"measurement","entity_category":null,"entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"state":{"class_name":"Temperature","available":true,"state":20.2},"cluster_handlers":[{"class_name":"TemperatureMeasurementClusterHandler","generic_id":"cluster_handler_0x0402","endpoint_id":1,"cluster":{"id":1026,"name":"Temperature Measurement","type":"server","endpoint_id":1,"endpoint_attribute":"temperature"},"id":"1:0x0402","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0402","status":"initialized","value_attribute":"measured_value"}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"attribute":"measured_value","decimals":1,"divisor":100,"multiplier":1,"unit":"°C"},"sensor,00:0d:6f:00:0f:3a:e3:69-1-0-rssi":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-0-rssi","class_name":"RSSISensor","translation_key":"rssi","device_class":"signal_strength","state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":false,"enabled":true,"fallback_name":null,"state":{"class_name":"RSSISensor","available":true,"state":null},"cluster_handlers":[{"class_name":"BasicClusterHandler","generic_id":"cluster_handler_0x0000","endpoint_id":1,"cluster":{"id":0,"name":"Basic","type":"server","endpoint_id":1,"endpoint_attribute":"basic"},"id":"1:0x0000","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0000","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"attribute":null,"decimals":1,"divisor":1,"multiplier":1,"unit":"dBm"},"sensor,00:0d:6f:00:0f:3a:e3:69-1-0-lqi":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-0-lqi","class_name":"LQISensor","translation_key":"lqi","device_class":null,"state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":false,"enabled":true,"fallback_name":null,"state":{"class_name":"LQISensor","available":true,"state":null},"cluster_handlers":[{"class_name":"BasicClusterHandler","generic_id":"cluster_handler_0x0000","endpoint_id":1,"cluster":{"id":0,"name":"Basic","type":"server","endpoint_id":1,"endpoint_attribute":"basic"},"id":"1:0x0000","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0000","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"attribute":null,"decimals":1,"divisor":1,"multiplier":1,"unit":null},"update,00:0d:6f:00:0f:3a:e3:69-1-25-firmware_update":{"platform":"update","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-25-firmware_update","class_name":"FirmwareUpdateEntity","translation_key":null,"device_class":"firmware","state_class":null,"entity_category":"config","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"state":{"class_name":"FirmwareUpdateEntity","available":true,"installed_version":null,"in_progress":false,"progress":0,"latest_version":null,"release_summary":null,"release_notes":null,"release_url":null},"cluster_handlers":[{"class_name":"OtaClientClusterHandler","generic_id":"cluster_handler_0x0019","endpoint_id":1,"cluster":{"id":25,"name":"Ota","type":"client","endpoint_id":1,"endpoint_attribute":"ota"},"id":"1:0x0019","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0019","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"supported_features":7}},"neighbors":[],"routes":[],"endpoint_names":[{"name":"IAS_ZONE"},{"name":"unknown 12 device_type of 0xc2df profile id"}],"device_automation_triggers":{"device_offline,device_offline":{"device_event_type":"device_offline"}}} \ No newline at end of file diff --git a/tests/test_gateway.py b/tests/test_gateway.py index eb7f45abf..a0c19bec6 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -647,7 +647,7 @@ def test_gateway_raw_device_initialized( signature={ "manufacturer": "FakeManufacturer", "model": "FakeModel", - "node_desc": { + "node_descriptor": { "logical_type": LogicalType.EndDevice, "complex_descriptor_available": 0, "user_descriptor_available": 0, @@ -664,9 +664,9 @@ def test_gateway_raw_device_initialized( }, "endpoints": { 1: { - "profile_id": 260, - "device_type": zha.DeviceType.ON_OFF_SWITCH, - "input_clusters": [0], + "profile_id": "0x0104", + "device_type": "0x0000", + "input_clusters": ["0x0000"], "output_clusters": [], } }, diff --git a/tests/test_model.py b/tests/test_model.py index 9203959f0..7f9f63258 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -26,7 +26,7 @@ def test_ser_deser_zha_event(): assert zha_event.model_dump() == { "message_type": "event", - "event_type": "zha_event", + "event_type": "device_event", "event": "zha_event", "device_ieee": "00:00:00:00:00:00:00:00", "unique_id": "00:00:00:00:00:00:00:00", @@ -35,7 +35,7 @@ def test_ser_deser_zha_event(): assert ( zha_event.model_dump_json() - == '{"message_type":"event","event_type":"zha_event","event":"zha_event",' + == '{"message_type":"event","event_type":"device_event","event":"zha_event",' '"device_ieee":"00:00:00:00:00:00:00:00","unique_id":"00:00:00:00:00:00:00:00","data":{"key":"value"}}' ) diff --git a/tests/websocket/__init__.py b/tests/websocket/__init__.py new file mode 100644 index 000000000..a766f6adb --- /dev/null +++ b/tests/websocket/__init__.py @@ -0,0 +1 @@ +"""Websocket tests modules.""" diff --git a/tests/websocket/test_binary_sensor.py b/tests/websocket/test_binary_sensor.py new file mode 100644 index 000000000..bbc66bd73 --- /dev/null +++ b/tests/websocket/test_binary_sensor.py @@ -0,0 +1,124 @@ +"""Test zhaws binary sensor.""" + +from collections.abc import Awaitable, Callable +from typing import Optional + +import pytest +import zigpy.profiles.zha +from zigpy.zcl.clusters import general, measurement, security + +from zha.application.discovery import Platform +from zha.application.platforms.model import BasePlatformEntity, BinarySensorEntity +from zha.websocket.client.controller import Controller +from zha.websocket.client.proxy import DeviceProxy +from zha.websocket.server.gateway import WebSocketGateway as Server + +from ..common import ( + SIG_EP_INPUT, + SIG_EP_OUTPUT, + SIG_EP_PROFILE, + SIG_EP_TYPE, + create_mock_zigpy_device, + join_zigpy_device, + send_attributes_report, + update_attribute_cache, +) + + +def find_entity( + device_proxy: DeviceProxy, platform: Platform +) -> Optional[BasePlatformEntity]: + """Find an entity for the specified platform on the given device.""" + for entity in device_proxy.device_model.entities.values(): + if entity.platform == platform: + return entity + return None + + +DEVICE_IAS = { + 1: { + SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID, + SIG_EP_TYPE: zigpy.profiles.zha.DeviceType.IAS_ZONE, + SIG_EP_INPUT: [security.IasZone.cluster_id], + SIG_EP_OUTPUT: [], + } +} + + +DEVICE_OCCUPANCY = { + 1: { + SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID, + SIG_EP_TYPE: zigpy.profiles.zha.DeviceType.OCCUPANCY_SENSOR, + SIG_EP_INPUT: [measurement.OccupancySensing.cluster_id], + SIG_EP_OUTPUT: [], + } +} + + +async def async_test_binary_sensor_on_off( + server: Server, cluster: general.OnOff, entity: BinarySensorEntity +) -> None: + """Test getting on and off messages for binary sensors.""" + # binary sensor on + await send_attributes_report(server, cluster, {1: 0, 0: 1, 2: 2}) + assert entity.state.state is True + + # binary sensor off + await send_attributes_report(server, cluster, {1: 1, 0: 0, 2: 2}) + assert entity.state.state is False + + +async def async_test_iaszone_on_off( + server: Server, cluster: security.IasZone, entity: BinarySensorEntity +) -> None: + """Test getting on and off messages for iaszone binary sensors.""" + # binary sensor on + cluster.listener_event("cluster_command", 1, 0, [1]) + await server.async_block_till_done() + assert entity.state.state is True + + # binary sensor off + cluster.listener_event("cluster_command", 1, 0, [0]) + await server.async_block_till_done() + assert entity.state.state is False + + +@pytest.mark.parametrize( + "device, on_off_test, cluster_name, reporting", + [ + (DEVICE_IAS, async_test_iaszone_on_off, "ias_zone", (0,)), + (DEVICE_OCCUPANCY, async_test_binary_sensor_on_off, "occupancy", (1,)), + ], +) +async def test_binary_sensor( + connected_client_and_server: tuple[Controller, Server], + device: dict, + on_off_test: Callable[..., Awaitable[None]], + cluster_name: str, + reporting: tuple, +) -> None: + """Test ZHA binary_sensor platform.""" + controller, server = connected_client_and_server + zigpy_device = create_mock_zigpy_device(server, device) + zhaws_device = await join_zigpy_device(server, zigpy_device) + + await server.async_block_till_done() + + client_device: Optional[DeviceProxy] = controller.devices.get(zhaws_device.ieee) + assert client_device is not None + entity: BinarySensorEntity = find_entity(client_device, Platform.BINARY_SENSOR) # type: ignore + assert entity is not None + assert isinstance(entity, BinarySensorEntity) + assert entity.state.state is False + + # test getting messages that trigger and reset the sensors + cluster = getattr(zigpy_device.endpoints[1], cluster_name) + await on_off_test(server, cluster, entity) + + # test refresh + if cluster_name == "ias_zone": + cluster.PLUGGED_ATTR_READS = {"zone_status": 0} + update_attribute_cache(cluster) + await controller.entities.refresh_state(entity) + await server.async_block_till_done() + assert entity.state.state is False diff --git a/tests/websocket/test_button.py b/tests/websocket/test_button.py new file mode 100644 index 000000000..8c38a7573 --- /dev/null +++ b/tests/websocket/test_button.py @@ -0,0 +1,76 @@ +"""Test ZHA button.""" + +from typing import Optional +from unittest.mock import patch + +from zigpy.const import SIG_EP_PROFILE +from zigpy.profiles import zha +from zigpy.zcl.clusters import general, security +import zigpy.zcl.foundation as zcl_f + +from zha.application.discovery import Platform +from zha.application.platforms.model import BasePlatformEntity, ButtonEntity +from zha.websocket.client.controller import Controller +from zha.websocket.client.proxy import DeviceProxy +from zha.websocket.server.gateway import WebSocketGateway as Server + +from ..common import ( + SIG_EP_INPUT, + SIG_EP_OUTPUT, + SIG_EP_TYPE, + create_mock_zigpy_device, + join_zigpy_device, + mock_coro, +) + + +def find_entity( + device_proxy: DeviceProxy, platform: Platform +) -> Optional[BasePlatformEntity]: + """Find an entity for the specified platform on the given device.""" + for entity in device_proxy.device_model.entities.values(): + if entity.platform == platform: + return entity + return None + + +async def test_button( + connected_client_and_server: tuple[Controller, Server], +) -> None: + """Test zha button platform.""" + controller, server = connected_client_and_server + zigpy_device = create_mock_zigpy_device( + server, + { + 1: { + SIG_EP_INPUT: [ + general.Basic.cluster_id, + general.Identify.cluster_id, + security.IasZone.cluster_id, + ], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.IAS_ZONE, + SIG_EP_PROFILE: zha.PROFILE_ID, + } + }, + ) + zhaws_device = await join_zigpy_device(server, zigpy_device) + cluster = zigpy_device.endpoints[1].identify + + assert cluster is not None + client_device: Optional[DeviceProxy] = controller.devices.get(zhaws_device.ieee) + assert client_device is not None + entity: ButtonEntity = find_entity(client_device, Platform.BUTTON) # type: ignore + assert entity is not None + assert isinstance(entity, ButtonEntity) + + with patch( + "zigpy.zcl.Cluster.request", + return_value=mock_coro([0x00, zcl_f.Status.SUCCESS]), + ): + await controller.buttons.press(entity) + await server.async_block_till_done() + assert len(cluster.request.mock_calls) == 1 + assert cluster.request.call_args[0][0] is False + assert cluster.request.call_args[0][1] == 0 + assert cluster.request.call_args[0][3] == 5 # duration in seconds diff --git a/tests/websocket/test_client_controller.py b/tests/websocket/test_client_controller.py new file mode 100644 index 000000000..76dc487a6 --- /dev/null +++ b/tests/websocket/test_client_controller.py @@ -0,0 +1,396 @@ +"""Test zha switch.""" + +import logging +from typing import Optional +from unittest.mock import AsyncMock, MagicMock, call + +import pytest +from zigpy.device import Device as ZigpyDevice +from zigpy.profiles import zha +from zigpy.types.named import EUI64 +from zigpy.zcl.clusters import general + +from zha.application.discovery import Platform +from zha.application.gateway import ( + DeviceJoinedDeviceInfo, + DevicePairingStatus, + RawDeviceInitializedDeviceInfo, + RawDeviceInitializedEvent, +) +from zha.application.model import DeviceJoinedEvent, DeviceLeftEvent +from zha.application.platforms.model import ( + BasePlatformEntity, + SwitchEntity, + SwitchGroupEntity, +) +from zha.websocket.client.controller import Controller +from zha.websocket.client.proxy import DeviceProxy, GroupProxy +from zha.websocket.const import ControllerEvents +from zha.websocket.server.api.model import ( + ReadClusterAttributesResponse, + WriteClusterAttributeResponse, +) +from zha.websocket.server.gateway import WebSocketGateway as Server +from zha.zigbee.device import Device +from zha.zigbee.group import Group, GroupMemberReference +from zha.zigbee.model import GroupInfo + +from ..common import ( + SIG_EP_INPUT, + SIG_EP_OUTPUT, + SIG_EP_PROFILE, + SIG_EP_TYPE, + async_find_group_entity_id, + create_mock_zigpy_device, + find_entity_id, + join_zigpy_device, + update_attribute_cache, +) + +ON = 1 +OFF = 0 +IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8" +IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8" +_LOGGER = logging.getLogger(__name__) + + +@pytest.fixture +def zigpy_device(connected_client_and_server: tuple[Controller, Server]) -> ZigpyDevice: + """Device tracker zigpy device.""" + _, server = connected_client_and_server + endpoints = { + 1: { + SIG_EP_INPUT: [general.Basic.cluster_id, general.OnOff.cluster_id], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.ON_OFF_SWITCH, + SIG_EP_PROFILE: zha.PROFILE_ID, + } + } + return create_mock_zigpy_device(server, endpoints) + + +@pytest.fixture +async def device_switch_1( + connected_client_and_server: tuple[Controller, Server], +) -> Device: + """Test zha switch platform.""" + + _, server = connected_client_and_server + + zigpy_device = create_mock_zigpy_device( + server, + { + 1: { + SIG_EP_INPUT: [general.OnOff.cluster_id, general.Groups.cluster_id], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.ON_OFF_SWITCH, + SIG_EP_PROFILE: zha.PROFILE_ID, + } + }, + ieee=IEEE_GROUPABLE_DEVICE, + ) + zha_device = await join_zigpy_device(server, zigpy_device) + zha_device.available = True + return zha_device + + +def get_entity(zha_dev: DeviceProxy, entity_id: str) -> BasePlatformEntity: + """Get entity.""" + entities = { + entity.platform + "." + entity.unique_id: entity + for entity in zha_dev.device_model.entities.values() + } + return entities[entity_id] + + +def get_group_entity( + group_proxy: GroupProxy, entity_id: str +) -> Optional[SwitchGroupEntity]: + """Get entity.""" + + return group_proxy.group_model.entities.get(entity_id) + + +@pytest.fixture +async def device_switch_2( + connected_client_and_server: tuple[Controller, Server], +) -> Device: + """Test zha switch platform.""" + + controller, server = connected_client_and_server + zigpy_device = create_mock_zigpy_device( + server, + { + 1: { + SIG_EP_INPUT: [general.OnOff.cluster_id, general.Groups.cluster_id], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.ON_OFF_SWITCH, + SIG_EP_PROFILE: zha.PROFILE_ID, + } + }, + ieee=IEEE_GROUPABLE_DEVICE2, + ) + zha_device = await join_zigpy_device(server, zigpy_device) + zha_device.available = True + return zha_device + + +async def test_controller_devices( + zigpy_device: ZigpyDevice, + connected_client_and_server: tuple[Controller, Server], +) -> None: + """Test client controller device related functionality.""" + controller, server = connected_client_and_server + zha_device = await join_zigpy_device(server, zigpy_device) + entity_id = find_entity_id(Platform.SWITCH, zha_device) + assert entity_id is not None + + client_device: Optional[DeviceProxy] = controller.devices.get(zha_device.ieee) + assert client_device is not None + entity: SwitchEntity = get_entity(client_device, entity_id) + assert entity is not None + + assert isinstance(entity, SwitchEntity) + + assert entity.state.state is False + + await controller.load_devices() + devices: dict[EUI64, DeviceProxy] = controller.devices + assert len(devices) == 2 + assert zha_device.ieee in devices + + # test client -> server + server.application_controller.remove = AsyncMock( + wraps=server.application_controller.remove + ) + await controller.devices_helper.remove_device(client_device.device_model) + assert server.application_controller.remove.await_count == 1 + assert server.application_controller.remove.await_args == call( + client_device.device_model.ieee + ) + + # test server -> client + server.device_removed(zigpy_device) + await server.async_block_till_done() + assert len(controller.devices) == 1 + + # rejoin the device + zha_device = await join_zigpy_device(server, zigpy_device) + await server.async_block_till_done() + assert len(controller.devices) == 2 + + # test rejoining the same device + zha_device = await join_zigpy_device(server, zigpy_device) + await server.async_block_till_done() + assert len(controller.devices) == 2 + + # we removed and joined the device again so lets get the entity again + client_device = controller.devices.get(zha_device.ieee) + assert client_device is not None + entity: SwitchEntity = get_entity(client_device, entity_id) # type: ignore + assert entity is not None + + # test device reconfigure + zha_device.async_configure = AsyncMock(wraps=zha_device.async_configure) + await controller.devices_helper.reconfigure_device(client_device.device_model) + await server.async_block_till_done() + assert zha_device.async_configure.call_count == 1 + assert zha_device.async_configure.await_count == 1 + assert zha_device.async_configure.call_args == call() + + # test read cluster attribute + cluster = zigpy_device.endpoints.get(1).on_off + assert cluster is not None + cluster.PLUGGED_ATTR_READS = {general.OnOff.AttributeDefs.on_off.name: 1} + update_attribute_cache(cluster) + await controller.entities.refresh_state(entity) + await server.async_block_till_done() + read_response: ReadClusterAttributesResponse = ( + await controller.devices_helper.read_cluster_attributes( + client_device.device_model, + general.OnOff.cluster_id, + "in", + 1, + [general.OnOff.AttributeDefs.on_off.name], + ) + ) + await server.async_block_till_done() + assert read_response is not None + assert read_response.success is True + assert len(read_response.succeeded) == 1 + assert len(read_response.failed) == 0 + assert read_response.succeeded[general.OnOff.AttributeDefs.on_off.name] == 1 + assert read_response.cluster.id == general.OnOff.cluster_id + assert read_response.cluster.endpoint_id == 1 + assert ( + read_response.cluster.endpoint_attribute + == general.OnOff.AttributeDefs.on_off.name + ) + assert read_response.cluster.name == general.OnOff.name + assert entity.state.state is True + + # test write cluster attribute + write_response: WriteClusterAttributeResponse = ( + await controller.devices_helper.write_cluster_attribute( + client_device.device_model, + general.OnOff.cluster_id, + "in", + 1, + general.OnOff.AttributeDefs.on_off.name, + 0, + ) + ) + assert write_response is not None + assert write_response.success is True + assert write_response.cluster.id == general.OnOff.cluster_id + assert write_response.cluster.endpoint_id == 1 + assert ( + write_response.cluster.endpoint_attribute + == general.OnOff.AttributeDefs.on_off.name + ) + assert write_response.cluster.name == general.OnOff.name + + await controller.entities.refresh_state(entity) + await server.async_block_till_done() + assert entity.state.state is False + + # test controller events + listener = MagicMock() + + # test device joined + controller.on_event(ControllerEvents.DEVICE_JOINED, listener) + device_joined_event = DeviceJoinedEvent( + device_info=DeviceJoinedDeviceInfo( + pairing_status=DevicePairingStatus.PAIRED, + ieee=zigpy_device.ieee, + nwk=zigpy_device.nwk, + ) + ) + server.device_joined(zigpy_device) + await server.async_block_till_done() + assert listener.call_count == 1 + assert listener.call_args == call(device_joined_event) + + # test device left + listener.reset_mock() + controller.on_event(ControllerEvents.DEVICE_LEFT, listener) + server.device_left(zigpy_device) + await server.async_block_till_done() + assert listener.call_count == 1 + assert listener.call_args == call( + DeviceLeftEvent( + ieee=zigpy_device.ieee, + nwk=str(zigpy_device.nwk).lower(), + ) + ) + + # test raw device initialized + listener.reset_mock() + controller.on_event(ControllerEvents.RAW_DEVICE_INITIALIZED, listener) + server.raw_device_initialized(zigpy_device) + await server.async_block_till_done() + assert listener.call_count == 1 + assert listener.call_args == call( + RawDeviceInitializedEvent( + device_info=RawDeviceInitializedDeviceInfo( + pairing_status=DevicePairingStatus.INTERVIEW_COMPLETE, + ieee=zigpy_device.ieee, + nwk=zigpy_device.nwk, + manufacturer=client_device.device_model.manufacturer, + model=client_device.device_model.model, + signature=client_device.device_model.signature, + ), + ) + ) + + +async def test_controller_groups( + device_switch_1: Device, + device_switch_2: Device, + connected_client_and_server: tuple[Controller, Server], +) -> None: + """Test client controller group related functionality.""" + controller, server = connected_client_and_server + member_ieee_addresses = [device_switch_1.ieee, device_switch_2.ieee] + members = [ + GroupMemberReference(ieee=device_switch_1.ieee, endpoint_id=1), + GroupMemberReference(ieee=device_switch_2.ieee, endpoint_id=1), + ] + + # test creating a group with 2 members + zha_group: Group = await server.async_create_zigpy_group("Test Group", members) + await server.async_block_till_done() + + assert zha_group is not None + assert len(zha_group.members) == 2 + for member in zha_group.members: + assert member.device.ieee in member_ieee_addresses + assert member.group == zha_group + assert member.endpoint is not None + + entity_id = async_find_group_entity_id(Platform.SWITCH, zha_group) + assert entity_id is not None + + group_proxy: Optional[GroupProxy] = controller.groups.get(zha_group.group_id) + assert group_proxy is not None + + entity: SwitchGroupEntity = get_group_entity(group_proxy, entity_id) # type: ignore + assert entity is not None + + assert isinstance(entity, SwitchGroupEntity) + + assert entity is not None + + await controller.load_groups() + groups: dict[int, GroupProxy] = controller.groups + # the application controller mock starts with a group already created + assert len(groups) == 2 + assert zha_group.group_id in groups + + # test client -> server + await controller.groups_helper.remove_groups([group_proxy.group_model]) + await server.async_block_till_done() + assert len(controller.groups) == 1 + + # test client create group + client_device1: Optional[DeviceProxy] = controller.devices.get(device_switch_1.ieee) + assert client_device1 is not None + entity_id1 = find_entity_id(Platform.SWITCH, device_switch_1) + assert entity_id1 is not None + entity1: SwitchEntity = get_entity(client_device1, entity_id1) + assert entity1 is not None + + client_device2: Optional[DeviceProxy] = controller.devices.get(device_switch_2.ieee) + assert client_device2 is not None + entity_id2 = find_entity_id(Platform.SWITCH, device_switch_2) + assert entity_id2 is not None + entity2: SwitchEntity = get_entity(client_device2, entity_id2) + assert entity2 is not None + + response: GroupInfo = await controller.groups_helper.create_group( + members=[entity1, entity2], name="Test Group Controller" + ) + await server.async_block_till_done() + assert len(controller.groups) == 2 + assert response.group_id in controller.groups + assert response.name == "Test Group Controller" + assert client_device1.device_model.ieee in response.members_by_ieee + assert client_device2.device_model.ieee in response.members_by_ieee + + # test remove member from group from controller + response = await controller.groups_helper.remove_group_members(response, [entity2]) + await server.async_block_till_done() + assert len(controller.groups) == 2 + assert response.group_id in controller.groups + assert response.name == "Test Group Controller" + assert client_device1.device_model.ieee in response.members_by_ieee + assert client_device2.device_model.ieee not in response.members_by_ieee + + # test add member to group from controller + response = await controller.groups_helper.add_group_members(response, [entity2]) + await server.async_block_till_done() + assert len(controller.groups) == 2 + assert response.group_id in controller.groups + assert response.name == "Test Group Controller" + assert client_device1.device_model.ieee in response.members_by_ieee + assert client_device2.device_model.ieee in response.members_by_ieee diff --git a/tests/websocket/test_number.py b/tests/websocket/test_number.py new file mode 100644 index 000000000..eee7e1195 --- /dev/null +++ b/tests/websocket/test_number.py @@ -0,0 +1,119 @@ +"""Test zha analog output.""" + +from typing import Optional +from unittest.mock import call + +from zigpy.profiles import zha +import zigpy.types +from zigpy.zcl.clusters import general + +from zha.application.discovery import Platform +from zha.application.platforms.model import BasePlatformEntity, NumberEntity +from zha.websocket.client.controller import Controller +from zha.websocket.client.proxy import DeviceProxy +from zha.websocket.server.gateway import WebSocketGateway as Server + +from ..common import ( + SIG_EP_INPUT, + SIG_EP_OUTPUT, + SIG_EP_PROFILE, + SIG_EP_TYPE, + create_mock_zigpy_device, + join_zigpy_device, + send_attributes_report, + update_attribute_cache, +) + + +def find_entity( + device_proxy: DeviceProxy, platform: Platform +) -> Optional[BasePlatformEntity]: + """Find an entity for the specified platform on the given device.""" + for entity in device_proxy.device_model.entities.values(): + if entity.platform == platform: + return entity + return None + + +async def test_number( + connected_client_and_server: tuple[Controller, Server], +) -> None: + """Test zha number platform.""" + controller, server = connected_client_and_server + zigpy_device = create_mock_zigpy_device( + server, + { + 1: { + SIG_EP_TYPE: zigpy.profiles.zha.DeviceType.LEVEL_CONTROL_SWITCH, + SIG_EP_INPUT: [ + general.AnalogOutput.cluster_id, + general.Basic.cluster_id, + ], + SIG_EP_OUTPUT: [], + SIG_EP_PROFILE: zha.PROFILE_ID, + } + }, + ) + cluster: general.AnalogOutput = zigpy_device.endpoints.get(1).analog_output + cluster.PLUGGED_ATTR_READS = { + "max_present_value": 100.0, + "min_present_value": 1.0, + "relinquish_default": 50.0, + "resolution": 1.1, + "description": "PWM1", + "engineering_units": 98, + "application_type": 4 * 0x10000, + } + update_attribute_cache(cluster) + cluster.PLUGGED_ATTR_READS["present_value"] = 15.0 + + zha_device = await join_zigpy_device(server, zigpy_device) + # one for present_value and one for the rest configuration attributes + assert cluster.read_attributes.call_count == 3 + attr_reads = set() + for call_args in cluster.read_attributes.call_args_list: + attr_reads |= set(call_args[0][0]) + assert "max_present_value" in attr_reads + assert "min_present_value" in attr_reads + assert "relinquish_default" in attr_reads + assert "resolution" in attr_reads + assert "description" in attr_reads + assert "engineering_units" in attr_reads + assert "application_type" in attr_reads + + client_device: Optional[DeviceProxy] = controller.devices.get(zha_device.ieee) + assert client_device is not None + entity: NumberEntity = find_entity(client_device, Platform.NUMBER) # type: ignore + assert entity is not None + assert isinstance(entity, NumberEntity) + + assert cluster.read_attributes.call_count == 3 + + # test that the state is 15.0 + assert entity.state.state == 15.0 + + # test attributes + assert entity.min_value == 1.0 + assert entity.max_value == 100.0 + assert entity.step == 1.1 + + # change value from device + assert cluster.read_attributes.call_count == 3 + await send_attributes_report(server, cluster, {0x0055: 15}) + await server.async_block_till_done() + assert entity.state.state == 15.0 + + # update value from device + await send_attributes_report(server, cluster, {0x0055: 20}) + await server.async_block_till_done() + assert entity.state.state == 20.0 + + # change value from client + await controller.numbers.set_value(entity, 30.0) + await server.async_block_till_done() + + assert len(cluster.write_attributes.mock_calls) == 1 + assert cluster.write_attributes.call_args == call( + {"present_value": 30.0}, manufacturer=None + ) + assert entity.state.state == 30.0 diff --git a/tests/websocket/test_siren.py b/tests/websocket/test_siren.py new file mode 100644 index 000000000..8115f4d49 --- /dev/null +++ b/tests/websocket/test_siren.py @@ -0,0 +1,177 @@ +"""Test zha siren.""" + +import asyncio +from typing import Optional +from unittest.mock import patch + +import pytest +from zigpy.const import SIG_EP_PROFILE +from zigpy.profiles import zha +from zigpy.zcl.clusters import general, security +import zigpy.zcl.foundation as zcl_f + +from zha.application.discovery import Platform +from zha.application.platforms.model import BasePlatformEntity +from zha.websocket.client.controller import Controller +from zha.websocket.client.proxy import DeviceProxy +from zha.websocket.server.gateway import WebSocketGateway as Server +from zha.zigbee.device import Device + +from ..common import ( + SIG_EP_INPUT, + SIG_EP_OUTPUT, + SIG_EP_TYPE, + create_mock_zigpy_device, + join_zigpy_device, + mock_coro, +) + + +def find_entity( + device_proxy: DeviceProxy, platform: Platform +) -> Optional[BasePlatformEntity]: + """Find an entity for the specified platform on the given device.""" + for entity in device_proxy.device_model.entities.values(): + if entity.platform == platform: + return entity + return None + + +@pytest.fixture +async def siren( + connected_client_and_server: tuple[Controller, Server], +) -> tuple[Device, security.IasWd]: + """Siren fixture.""" + + _, server = connected_client_and_server + zigpy_device = create_mock_zigpy_device( + server, + { + 1: { + SIG_EP_INPUT: [general.Basic.cluster_id, security.IasWd.cluster_id], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.IAS_WARNING_DEVICE, + SIG_EP_PROFILE: zha.PROFILE_ID, + } + }, + ) + + zha_device = await join_zigpy_device(server, zigpy_device) + return zha_device, zigpy_device.endpoints[1].ias_wd + + +async def test_siren( + siren: tuple[Device, security.IasWd], + connected_client_and_server: tuple[Controller, Server], +) -> None: + """Test zha siren platform.""" + + zha_device, cluster = siren + assert cluster is not None + controller, server = connected_client_and_server + + client_device: Optional[DeviceProxy] = controller.devices.get(zha_device.ieee) + assert client_device is not None + entity = find_entity(client_device, Platform.SIREN) + assert entity is not None + + assert entity.state.state is False + + # turn on from client + with patch( + "zigpy.zcl.Cluster.request", + return_value=mock_coro([0x00, zcl_f.Status.SUCCESS]), + ): + await controller.sirens.turn_on(entity) + await server.async_block_till_done() + assert len(cluster.request.mock_calls) == 1 + assert cluster.request.call_args[0][0] is False + assert cluster.request.call_args[0][1] == 0 + assert cluster.request.call_args[0][3] == 50 # bitmask for default args + assert cluster.request.call_args[0][4] == 5 # duration in seconds + assert cluster.request.call_args[0][5] == 0 + assert cluster.request.call_args[0][6] == 2 + cluster.request.reset_mock() + + # test that the state has changed to on + assert entity.state.state is True + + # turn off from client + with patch( + "zigpy.zcl.Cluster.request", + return_value=mock_coro([0x00, zcl_f.Status.SUCCESS]), + ): + await controller.sirens.turn_off(entity) + await server.async_block_till_done() + assert len(cluster.request.mock_calls) == 1 + assert cluster.request.call_args[0][0] is False + assert cluster.request.call_args[0][1] == 0 + assert cluster.request.call_args[0][3] == 2 # bitmask for default args + assert cluster.request.call_args[0][4] == 5 # duration in seconds + assert cluster.request.call_args[0][5] == 0 + assert cluster.request.call_args[0][6] == 2 + cluster.request.reset_mock() + + # test that the state has changed to off + assert entity.state.state is False + + # turn on from client with options + with patch( + "zigpy.zcl.Cluster.request", + return_value=mock_coro([0x00, zcl_f.Status.SUCCESS]), + ): + await controller.sirens.turn_on(entity, duration=100, volume_level=3, tone=3) + await server.async_block_till_done() + assert len(cluster.request.mock_calls) == 1 + assert cluster.request.call_args[0][0] is False + assert cluster.request.call_args[0][1] == 0 + # assert (cluster.request.call_args[0][3] == 51) # bitmask for specified args TODO fix kwargs on siren methods so args are processed correctly + assert cluster.request.call_args[0][4] == 100 # duration in seconds + assert cluster.request.call_args[0][5] == 0 + assert cluster.request.call_args[0][6] == 2 + cluster.request.reset_mock() + + # test that the state has changed to on + assert entity.state.state is True + + +@pytest.mark.looptime +async def test_siren_timed_off( + siren: tuple[Device, security.IasWd], + connected_client_and_server: tuple[Controller, Server], +) -> None: + """Test zha siren platform.""" + zha_device, cluster = siren + assert cluster is not None + controller, server = connected_client_and_server + + client_device: Optional[DeviceProxy] = controller.devices.get(zha_device.ieee) + assert client_device is not None + entity = find_entity(client_device, Platform.SIREN) + assert entity is not None + + assert entity.state.state is False + + # turn on from client + with patch( + "zigpy.zcl.Cluster.request", + return_value=mock_coro([0x00, zcl_f.Status.SUCCESS]), + ): + await controller.sirens.turn_on(entity) + await server.async_block_till_done() + assert len(cluster.request.mock_calls) == 1 + assert cluster.request.call_args[0][0] is False + assert cluster.request.call_args[0][1] == 0 + assert cluster.request.call_args[0][3] == 50 # bitmask for default args + assert cluster.request.call_args[0][4] == 5 # duration in seconds + assert cluster.request.call_args[0][5] == 0 + assert cluster.request.call_args[0][6] == 2 + cluster.request.reset_mock() + + # test that the state has changed to on + assert entity.state.state is True + + await asyncio.sleep(6) + + # test that the state has changed to off from the timer + assert entity.state.state is False diff --git a/tests/websocket/test_switch.py b/tests/websocket/test_switch.py new file mode 100644 index 000000000..95cc0ef6c --- /dev/null +++ b/tests/websocket/test_switch.py @@ -0,0 +1,363 @@ +"""Test zha switch.""" + +import asyncio +import logging +from typing import Optional +from unittest.mock import call, patch + +import pytest +from zigpy.device import Device as ZigpyDevice +from zigpy.profiles import zha +import zigpy.profiles.zha +from zigpy.zcl.clusters import general +import zigpy.zcl.foundation as zcl_f + +from tests.common import mock_coro +from zha.application.discovery import Platform +from zha.application.platforms.model import ( + BasePlatformEntity, + SwitchEntity, + SwitchGroupEntity, +) +from zha.websocket.client.controller import Controller +from zha.websocket.client.proxy import DeviceProxy, GroupProxy +from zha.websocket.server.gateway import WebSocketGateway as Server +from zha.zigbee.device import Device +from zha.zigbee.group import Group, GroupMemberReference + +from ..common import ( + SIG_EP_INPUT, + SIG_EP_OUTPUT, + SIG_EP_PROFILE, + SIG_EP_TYPE, + async_find_group_entity_id, + create_mock_zigpy_device, + join_zigpy_device, + send_attributes_report, + update_attribute_cache, +) + +ON = 1 +OFF = 0 +IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8" +IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8" +_LOGGER = logging.getLogger(__name__) + + +def find_entity( + device_proxy: DeviceProxy, platform: Platform +) -> Optional[BasePlatformEntity]: + """Find an entity for the specified platform on the given device.""" + for entity in device_proxy.device_model.entities.values(): + if entity.platform == platform: + return entity + return None + + +def get_group_entity( + group_proxy: GroupProxy, entity_id: str +) -> Optional[SwitchGroupEntity]: + """Get entity.""" + + return group_proxy.group_model.entities.get(entity_id) + + +@pytest.fixture +def zigpy_device(connected_client_and_server: tuple[Controller, Server]) -> ZigpyDevice: + """Device tracker zigpy device.""" + controller, server = connected_client_and_server + zigpy_device = create_mock_zigpy_device( + server, + { + 1: { + SIG_EP_INPUT: [general.Basic.cluster_id, general.OnOff.cluster_id], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.ON_OFF_SWITCH, + SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID, + } + }, + ) + return zigpy_device + + +@pytest.fixture +async def device_switch_1( + connected_client_and_server: tuple[Controller, Server], +) -> Device: + """Test zha switch platform.""" + + _, server = connected_client_and_server + zigpy_device = create_mock_zigpy_device( + server, + { + 1: { + SIG_EP_INPUT: [general.OnOff.cluster_id, general.Groups.cluster_id], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.ON_OFF_SWITCH, + SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID, + } + }, + ieee=IEEE_GROUPABLE_DEVICE, + ) + zha_device = await join_zigpy_device(server, zigpy_device) + zha_device.available = True + return zha_device + + +@pytest.fixture +async def device_switch_2( + connected_client_and_server: tuple[Controller, Server], +) -> Device: + """Test zha switch platform.""" + + _, server = connected_client_and_server + zigpy_device = create_mock_zigpy_device( + server, + { + 1: { + SIG_EP_INPUT: [general.OnOff.cluster_id, general.Groups.cluster_id], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.ON_OFF_SWITCH, + SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID, + } + }, + ieee=IEEE_GROUPABLE_DEVICE2, + ) + zha_device = await join_zigpy_device(server, zigpy_device) + zha_device.available = True + return zha_device + + +async def test_switch( + zigpy_device: ZigpyDevice, + connected_client_and_server: tuple[Controller, Server], +) -> None: + """Test zha switch platform.""" + controller, server = connected_client_and_server + zha_device = await join_zigpy_device(server, zigpy_device) + cluster = zigpy_device.endpoints.get(1).on_off + + client_device: Optional[DeviceProxy] = controller.devices.get(zha_device.ieee) + assert client_device is not None + entity: SwitchEntity = find_entity(client_device, Platform.SWITCH) + assert entity is not None + + assert isinstance(entity, SwitchEntity) + + assert entity.state.state is False + + # turn on at switch + await send_attributes_report(server, cluster, {1: 0, 0: 1, 2: 2}) + assert entity.state.state is True + + # turn off at switch + await send_attributes_report(server, cluster, {1: 1, 0: 0, 2: 2}) + assert entity.state.state is False + + # turn on from client + with patch( + "zigpy.zcl.Cluster.request", + return_value=[0x00, zcl_f.Status.SUCCESS], + ): + await controller.switches.turn_on(entity) + await server.async_block_till_done() + assert entity.state.state is True + assert len(cluster.request.mock_calls) == 1 + assert cluster.request.call_args == call( + False, + ON, + cluster.commands_by_name["on"].schema, + expect_reply=True, + manufacturer=None, + tsn=None, + ) + + # Fail turn off from client + with patch( + "zigpy.zcl.Cluster.request", + return_value=mock_coro([0x01, zcl_f.Status.FAILURE]), + ): + await controller.switches.turn_off(entity) + await server.async_block_till_done() + assert entity.state.state is True + assert len(cluster.request.mock_calls) == 1 + assert cluster.request.call_args == call( + False, + OFF, + cluster.commands_by_name["off"].schema, + expect_reply=True, + manufacturer=None, + tsn=None, + ) + + # turn off from client + with patch( + "zigpy.zcl.Cluster.request", + return_value=[0x00, zcl_f.Status.SUCCESS], + ): + await controller.switches.turn_off(entity) + await server.async_block_till_done() + assert entity.state.state is False + assert len(cluster.request.mock_calls) == 1 + assert cluster.request.call_args == call( + False, + OFF, + cluster.commands_by_name["off"].schema, + expect_reply=True, + manufacturer=None, + tsn=None, + ) + + # Fail turn on from client + with patch( + "zigpy.zcl.Cluster.request", + return_value=[0x01, zcl_f.Status.FAILURE], + ): + await controller.switches.turn_on(entity) + await server.async_block_till_done() + assert entity.state.state is False + assert len(cluster.request.mock_calls) == 1 + assert cluster.request.call_args == call( + False, + ON, + cluster.commands_by_name["on"].schema, + expect_reply=True, + manufacturer=None, + tsn=None, + ) + + # test updating entity state from client + assert entity.state.state is False + cluster.PLUGGED_ATTR_READS = {"on_off": True} + update_attribute_cache(cluster) + await controller.entities.refresh_state(entity) + await server.async_block_till_done() + assert entity.state.state is True + + +@pytest.mark.looptime +async def test_zha_group_switch_entity( + device_switch_1: Device, + device_switch_2: Device, + connected_client_and_server: tuple[Controller, Server], +) -> None: + """Test the switch entity for a ZHA group.""" + controller, server = connected_client_and_server + member_ieee_addresses = [device_switch_1.ieee, device_switch_2.ieee] + members = [ + GroupMemberReference(ieee=device_switch_1.ieee, endpoint_id=1), + GroupMemberReference(ieee=device_switch_2.ieee, endpoint_id=1), + ] + + # test creating a group with 2 members + zha_group: Group = await server.async_create_zigpy_group("Test Group", members) + await server.async_block_till_done() + + assert zha_group is not None + assert len(zha_group.members) == 2 + for member in zha_group.members: + assert member.device.ieee in member_ieee_addresses + assert member.group == zha_group + assert member.endpoint is not None + + entity_id = async_find_group_entity_id(Platform.SWITCH, zha_group) + assert entity_id is not None + + group_proxy: Optional[GroupProxy] = controller.groups.get(2) + assert group_proxy is not None + + entity: SwitchGroupEntity = get_group_entity(group_proxy, entity_id) # type: ignore + assert entity is not None + + assert isinstance(entity, SwitchGroupEntity) + + group_cluster_on_off = zha_group.zigpy_group.endpoint[general.OnOff.cluster_id] + dev1_cluster_on_off = device_switch_1.device.endpoints[1].on_off + dev2_cluster_on_off = device_switch_2.device.endpoints[1].on_off + + # test that the lights were created and are off + assert entity.state.state is False + + # turn on from HA + with patch( + "zigpy.zcl.Cluster.request", + return_value=[0x00, zcl_f.Status.SUCCESS], + ): + # turn on via UI + await controller.switches.turn_on(entity) + await server.async_block_till_done() + assert len(group_cluster_on_off.request.mock_calls) == 1 + assert group_cluster_on_off.request.call_args == call( + False, + ON, + group_cluster_on_off.commands_by_name["on"].schema, + expect_reply=True, + manufacturer=None, + tsn=None, + ) + assert entity.state.state is True + + # turn off from HA + with patch( + "zigpy.zcl.Cluster.request", + return_value=[0x00, zcl_f.Status.SUCCESS], + ): + # turn off via UI + await controller.switches.turn_off(entity) + await server.async_block_till_done() + assert len(group_cluster_on_off.request.mock_calls) == 1 + assert group_cluster_on_off.request.call_args == call( + False, + OFF, + group_cluster_on_off.commands_by_name["off"].schema, + expect_reply=True, + manufacturer=None, + tsn=None, + ) + assert entity.state.state is False + + # test some of the group logic to make sure we key off states correctly + await send_attributes_report(server, dev1_cluster_on_off, {0: 1}) + await send_attributes_report(server, dev2_cluster_on_off, {0: 1}) + await server.async_block_till_done() + + # group member updates are debounced + assert entity.state.state is False + await asyncio.sleep(1) + await server.async_block_till_done() + + # test that group light is on + assert entity.state.state is True + + await send_attributes_report(server, dev1_cluster_on_off, {0: 0}) + await server.async_block_till_done() + + # test that group light is still on + assert entity.state.state is True + + await send_attributes_report(server, dev2_cluster_on_off, {0: 0}) + await server.async_block_till_done() + + # group member updates are debounced + assert entity.state.state is True + await asyncio.sleep(1) + await server.async_block_till_done() + + # test that group light is now off + assert entity.state.state is False + + await send_attributes_report(server, dev1_cluster_on_off, {0: 1}) + await server.async_block_till_done() + + # group member updates are debounced + assert entity.state.state is False + await asyncio.sleep(1) + await server.async_block_till_done() + + # test that group light is now back on + assert entity.state.state is True + + # test value error calling client api with wrong entity type + with pytest.raises(ValueError): + await controller.sirens.turn_on(entity) + await server.async_block_till_done() diff --git a/tests/test_websocket_server_client.py b/tests/websocket/test_websocket_server_client.py similarity index 100% rename from tests/test_websocket_server_client.py rename to tests/websocket/test_websocket_server_client.py diff --git a/zha/application/gateway.py b/zha/application/gateway.py index 561451f8c..b0807a26e 100644 --- a/zha/application/gateway.py +++ b/zha/application/gateway.py @@ -5,10 +5,9 @@ import asyncio from contextlib import suppress from datetime import timedelta -from enum import Enum import logging import time -from typing import Any, Final, Literal, Self, TypeVar, cast +from typing import Final, Self, TypeVar, cast from zhaquirks import setup as setup_quirks from zigpy.application import ControllerApplication @@ -24,10 +23,17 @@ import zigpy.group from zigpy.quirks.v2 import UNBUILT_QUIRK_BUILDERS from zigpy.state import State -from zigpy.types.named import EUI64, NWK +from zigpy.types.named import EUI64 +from zigpy.zdo import ZDO from zha.application import discovery from zha.application.const import ( + ATTR_DEVICE_TYPE, + ATTR_ENDPOINTS, + ATTR_MANUFACTURER, + ATTR_MODEL, + ATTR_NODE_DESCRIPTOR, + ATTR_PROFILE_ID, CONF_USE_THREAD, UNKNOWN_MANUFACTURER, UNKNOWN_MODEL, @@ -36,124 +42,42 @@ ZHA_GW_MSG_DEVICE_JOINED, ZHA_GW_MSG_DEVICE_LEFT, ZHA_GW_MSG_DEVICE_REMOVED, - ZHA_GW_MSG_GROUP_ADDED, - ZHA_GW_MSG_GROUP_MEMBER_ADDED, - ZHA_GW_MSG_GROUP_MEMBER_REMOVED, - ZHA_GW_MSG_GROUP_REMOVED, ZHA_GW_MSG_RAW_INIT, RadioType, ) from zha.application.helpers import DeviceAvailabilityChecker, GlobalUpdater, ZHAData +from zha.application.model import ( + ConnectionLostEvent, + DeviceFullyInitializedEvent, + DeviceJoinedDeviceInfo, + DeviceJoinedEvent, + DeviceLeftEvent, + DevicePairingStatus, + DeviceRemovedEvent, + ExtendedDeviceInfoWithPairingStatus, + GroupAddedEvent, + GroupMemberAddedEvent, + GroupMemberRemovedEvent, + GroupRemovedEvent, + RawDeviceInitializedDeviceInfo, + RawDeviceInitializedEvent, +) from zha.async_ import ( AsyncUtilMixin, create_eager_task, gather_with_limited_concurrency, ) from zha.event import EventBase -from zha.model import BaseEvent, BaseModel -from zha.zigbee.device import Device, DeviceInfo, DeviceStatus, ExtendedDeviceInfo -from zha.zigbee.group import Group, GroupInfo, GroupMemberReference +from zha.zigbee.device import Device +from zha.zigbee.endpoint import ATTR_IN_CLUSTERS, ATTR_OUT_CLUSTERS +from zha.zigbee.group import Group, GroupMemberReference +from zha.zigbee.model import DeviceStatus BLOCK_LOG_TIMEOUT: Final[int] = 60 _R = TypeVar("_R") _LOGGER = logging.getLogger(__name__) -class DevicePairingStatus(Enum): - """Status of a device.""" - - PAIRED = 1 - INTERVIEW_COMPLETE = 2 - CONFIGURED = 3 - INITIALIZED = 4 - - -class DeviceInfoWithPairingStatus(DeviceInfo): - """Information about a device with pairing status.""" - - pairing_status: DevicePairingStatus - - -class ExtendedDeviceInfoWithPairingStatus(ExtendedDeviceInfo): - """Information about a device with pairing status.""" - - pairing_status: DevicePairingStatus - - -class DeviceJoinedDeviceInfo(BaseModel): - """Information about a device.""" - - ieee: EUI64 - nwk: NWK - pairing_status: DevicePairingStatus - - -class ConnectionLostEvent(BaseEvent): - """Event to signal that the connection to the radio has been lost.""" - - event_type: Literal["zha_gateway_message"] = "zha_gateway_message" - event: Literal["connection_lost"] = "connection_lost" - exception: Exception | None = None - - -class DeviceJoinedEvent(BaseEvent): - """Event to signal that a device has joined the network.""" - - device_info: DeviceJoinedDeviceInfo - event_type: Literal["zha_gateway_message"] = "zha_gateway_message" - event: Literal["device_joined"] = "device_joined" - - -class DeviceLeftEvent(BaseEvent): - """Event to signal that a device has left the network.""" - - ieee: EUI64 - nwk: NWK - event_type: Literal["zha_gateway_message"] = "zha_gateway_message" - event: Literal["device_left"] = "device_left" - - -class RawDeviceInitializedDeviceInfo(DeviceJoinedDeviceInfo): - """Information about a device that has been initialized without quirks loaded.""" - - model: str - manufacturer: str - signature: dict[str, Any] - - -class RawDeviceInitializedEvent(BaseEvent): - """Event to signal that a device has been initialized without quirks loaded.""" - - device_info: RawDeviceInitializedDeviceInfo - event_type: Literal["zha_gateway_message"] = "zha_gateway_message" - event: Literal["raw_device_initialized"] = "raw_device_initialized" - - -class DeviceFullInitEvent(BaseEvent): - """Event to signal that a device has been fully initialized.""" - - device_info: ExtendedDeviceInfoWithPairingStatus - new_join: bool = False - event_type: Literal["zha_gateway_message"] = "zha_gateway_message" - event: Literal["device_fully_initialized"] = "device_fully_initialized" - - -class GroupEvent(BaseEvent): - """Event to signal a group event.""" - - event: str - group_info: GroupInfo - event_type: Literal["zha_gateway_message"] = "zha_gateway_message" - - -class DeviceRemovedEvent(BaseEvent): - """Event to signal that a device has been removed.""" - - device_info: ExtendedDeviceInfo - event_type: Literal["zha_gateway_message"] = "zha_gateway_message" - event: Literal["device_removed"] = "device_removed" - - class Gateway(AsyncUtilMixin, EventBase): """Gateway that handles events that happen on the ZHA Zigbee network.""" @@ -424,7 +348,33 @@ def raw_device_initialized(self, device: zigpy.device.Device) -> None: # pylint manufacturer=device.manufacturer if device.manufacturer else UNKNOWN_MANUFACTURER, - signature=device.get_signature(), + signature={ + ATTR_NODE_DESCRIPTOR: device.node_desc.as_dict(), + ATTR_ENDPOINTS: { + ep_id: { + ATTR_PROFILE_ID: f"0x{endpoint.profile_id:04x}" + if endpoint.profile_id is not None + else "", + ATTR_DEVICE_TYPE: f"0x{endpoint.device_type:04x}" + if endpoint.device_type is not None + else "", + ATTR_IN_CLUSTERS: [ + f"0x{cluster_id:04x}" + for cluster_id in sorted(endpoint.in_clusters) + ], + ATTR_OUT_CLUSTERS: [ + f"0x{cluster_id:04x}" + for cluster_id in sorted(endpoint.out_clusters) + ], + } + for ep_id, endpoint in device.endpoints.items() + if not isinstance(endpoint, ZDO) + }, + ATTR_MANUFACTURER: device.manufacturer + if device.manufacturer + else UNKNOWN_MANUFACTURER, + ATTR_MODEL: device.model if device.model else UNKNOWN_MODEL, + }, ) ), ) @@ -463,7 +413,7 @@ def group_member_removed( zha_group.clear_caches() discovery.GROUP_PROBE.discover_group_entities(zha_group) zha_group.info("group_member_removed - endpoint: %s", endpoint) - self._emit_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_REMOVED) + self._emit_group_gateway_message(zigpy_group, GroupMemberRemovedEvent) def group_member_added( self, zigpy_group: zigpy.group.Group, endpoint: zigpy.endpoint.Endpoint @@ -474,35 +424,38 @@ def group_member_added( zha_group.clear_caches() discovery.GROUP_PROBE.discover_group_entities(zha_group) zha_group.info("group_member_added - endpoint: %s", endpoint) - self._emit_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_ADDED) + self._emit_group_gateway_message(zigpy_group, GroupMemberAddedEvent) def group_added(self, zigpy_group: zigpy.group.Group) -> None: """Handle zigpy group added event.""" zha_group = self.get_or_create_group(zigpy_group) zha_group.info("group_added") # need to dispatch for entity creation here - self._emit_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED) + self._emit_group_gateway_message(zigpy_group, GroupAddedEvent) def group_removed(self, zigpy_group: zigpy.group.Group) -> None: """Handle zigpy group removed event.""" - self._emit_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED) + self._emit_group_gateway_message(zigpy_group, GroupRemovedEvent) zha_group = self._groups.pop(zigpy_group.group_id) zha_group.info("group_removed") def _emit_group_gateway_message( # pylint: disable=unused-argument self, zigpy_group: zigpy.group.Group, - gateway_message_type: str, + gateway_message_type: GroupRemovedEvent + | GroupAddedEvent + | GroupMemberAddedEvent + | GroupMemberRemovedEvent, ) -> None: """Send the gateway event for a zigpy group event.""" zha_group = self._groups.get(zigpy_group.group_id) if zha_group is not None: + response = gateway_message_type( + group_info=zha_group.info_object, + ) self.emit( - gateway_message_type, - GroupEvent( - event=gateway_message_type, - group_info=zha_group.info_object, - ), + response.event, + response, ) def device_removed(self, device: zigpy.device.Device) -> None: @@ -610,7 +563,7 @@ async def async_device_initialized(self, device: zigpy.device.Device) -> None: ) self.emit( ZHA_GW_MSG_DEVICE_FULL_INIT, - DeviceFullInitEvent(device_info=device_info), + DeviceFullyInitializedEvent(device_info=device_info), ) async def _async_device_joined(self, zha_device: Device) -> None: @@ -625,7 +578,7 @@ async def _async_device_joined(self, zha_device: Device) -> None: self.create_platform_entities() self.emit( ZHA_GW_MSG_DEVICE_FULL_INIT, - DeviceFullInitEvent(device_info=device_info, new_join=True), + DeviceFullyInitializedEvent(device_info=device_info, new_join=True), ) async def _async_device_rejoined(self, zha_device: Device) -> None: @@ -643,7 +596,7 @@ async def _async_device_rejoined(self, zha_device: Device) -> None: ) self.emit( ZHA_GW_MSG_DEVICE_FULL_INIT, - DeviceFullInitEvent(device_info=device_info), + DeviceFullyInitializedEvent(device_info=device_info), ) # force async_initialize() to fire so don't explicitly call it zha_device.available = False diff --git a/zha/application/model.py b/zha/application/model.py new file mode 100644 index 000000000..61320667e --- /dev/null +++ b/zha/application/model.py @@ -0,0 +1,144 @@ +"""Models for the ZHA application module.""" + +from enum import Enum +from typing import Any, Literal + +from zigpy.types.named import EUI64, NWK + +from zha.model import BaseEvent, BaseModel +from zha.zigbee.model import DeviceInfo, ExtendedDeviceInfo, GroupInfo + + +class DevicePairingStatus(Enum): + """Status of a device.""" + + PAIRED = 1 + INTERVIEW_COMPLETE = 2 + CONFIGURED = 3 + INITIALIZED = 4 + + +class DeviceInfoWithPairingStatus(DeviceInfo): + """Information about a device with pairing status.""" + + pairing_status: DevicePairingStatus + + +class ExtendedDeviceInfoWithPairingStatus(ExtendedDeviceInfo): + """Information about a device with pairing status.""" + + pairing_status: DevicePairingStatus + + +class DeviceJoinedDeviceInfo(BaseModel): + """Information about a device.""" + + ieee: EUI64 + nwk: NWK + pairing_status: DevicePairingStatus + + +class ConnectionLostEvent(BaseEvent): + """Event to signal that the connection to the radio has been lost.""" + + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["connection_lost"] = "connection_lost" + exception: Exception | None = None + + +class DeviceJoinedEvent(BaseEvent): + """Event to signal that a device has joined the network.""" + + device_info: DeviceJoinedDeviceInfo + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["device_joined"] = "device_joined" + + +class DeviceLeftEvent(BaseEvent): + """Event to signal that a device has left the network.""" + + ieee: EUI64 + nwk: NWK + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["device_left"] = "device_left" + + +class RawDeviceInitializedDeviceInfo(DeviceJoinedDeviceInfo): + """Information about a device that has been initialized without quirks loaded.""" + + model: str + manufacturer: str + signature: dict[str, Any] + + +class RawDeviceInitializedEvent(BaseEvent): + """Event to signal that a device has been initialized without quirks loaded.""" + + device_info: RawDeviceInitializedDeviceInfo + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["raw_device_initialized"] = "raw_device_initialized" + + +class DeviceFullyInitializedEvent(BaseEvent): + """Event to signal that a device has been fully initialized.""" + + device_info: ExtendedDeviceInfoWithPairingStatus + new_join: bool = False + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["device_fully_initialized"] = "device_fully_initialized" + + +class GroupRemovedEvent(BaseEvent): + """Group removed event.""" + + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["group_removed"] = "group_removed" + group_info: GroupInfo + + +class GroupAddedEvent(BaseEvent): + """Group added event.""" + + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["group_added"] = "group_added" + group_info: GroupInfo + + +class GroupMemberAddedEvent(BaseEvent): + """Group member added event.""" + + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["group_member_added"] = "group_member_added" + group_info: GroupInfo + + +class GroupMemberRemovedEvent(BaseEvent): + """Group member removed event.""" + + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["group_member_removed"] = "group_member_removed" + group_info: GroupInfo + + +class DeviceRemovedEvent(BaseEvent): + """Event to signal that a device has been removed.""" + + device_info: ExtendedDeviceInfo + event_type: Literal["zha_gateway_message"] = "zha_gateway_message" + event: Literal["device_removed"] = "device_removed" + + +class DeviceOfflineEvent(BaseEvent): + """Device offline event.""" + + event: Literal["device_offline"] = "device_offline" + event_type: Literal["device_event"] = "device_event" + device: ExtendedDeviceInfo + + +class DeviceOnlineEvent(BaseEvent): + """Device online event.""" + + event: Literal["device_online"] = "device_online" + event_type: Literal["device_event"] = "device_event" + device: ExtendedDeviceInfo diff --git a/zha/application/platforms/__init__.py b/zha/application/platforms/__init__.py index 8aaee54cd..836aba940 100644 --- a/zha/application/platforms/__init__.py +++ b/zha/application/platforms/__init__.py @@ -5,21 +5,25 @@ from abc import abstractmethod import asyncio from contextlib import suppress -from enum import StrEnum from functools import cached_property import logging -from typing import TYPE_CHECKING, Any, Literal, Optional, final +from typing import TYPE_CHECKING, Any, final from zigpy.quirks.v2 import EntityMetadata, EntityType -from zigpy.types.named import EUI64 from zha.application import Platform +from zha.application.platforms.model import ( + BaseEntityInfo, + BaseIdentifiers, + EntityCategory, + EntityStateChangedEvent, + GroupEntityIdentifiers, + PlatformEntityIdentifiers, +) from zha.const import STATE_CHANGED from zha.debounce import Debouncer from zha.event import EventBase from zha.mixins import LogMixin -from zha.model import BaseEvent, BaseModel -from zha.zigbee.cluster_handlers import ClusterHandlerInfo if TYPE_CHECKING: from zha.zigbee.cluster_handlers import ClusterHandler @@ -33,73 +37,6 @@ DEFAULT_UPDATE_GROUP_FROM_CHILD_DELAY: float = 0.5 -class EntityCategory(StrEnum): - """Category of an entity.""" - - # Config: An entity which allows changing the configuration of a device. - CONFIG = "config" - - # Diagnostic: An entity exposing some configuration parameter, - # or diagnostics of a device. - DIAGNOSTIC = "diagnostic" - - -class BaseEntityInfo(BaseModel): - """Information about a base entity.""" - - platform: Platform - unique_id: str - class_name: str - translation_key: str | None - device_class: str | None - state_class: str | None - entity_category: str | None - entity_registry_enabled_default: bool - enabled: bool = True - fallback_name: str | None - - # For platform entities - cluster_handlers: list[ClusterHandlerInfo] - device_ieee: EUI64 | None - endpoint_id: int | None - available: bool | None - - # For group entities - group_id: int | None - - -class BaseIdentifiers(BaseModel): - """Identifiers for the base entity.""" - - unique_id: str - platform: Platform - - -class PlatformEntityIdentifiers(BaseIdentifiers): - """Identifiers for the platform entity.""" - - device_ieee: EUI64 - endpoint_id: int - - -class GroupEntityIdentifiers(BaseIdentifiers): - """Identifiers for the group entity.""" - - group_id: int - - -class EntityStateChangedEvent(BaseEvent): - """Event for when an entity state changes.""" - - event_type: Literal["entity"] = "entity" - event: Literal["state_changed"] = "state_changed" - platform: Platform - unique_id: str - device_ieee: Optional[EUI64] = None - endpoint_id: Optional[int] = None - group_id: Optional[int] = None - - class BaseEntity(LogMixin, EventBase): """Base class for entities.""" @@ -214,6 +151,7 @@ def info_object(self) -> BaseEntityInfo: available=None, # Set by group entities group_id=None, + state=self.state, ) @property @@ -260,7 +198,8 @@ def maybe_emit_state_changed_event(self) -> None: state = self.state if self.__previous_state != state: self.emit( - STATE_CHANGED, EntityStateChangedEvent(**self.identifiers.__dict__) + STATE_CHANGED, + EntityStateChangedEvent(state=self.state, **self.identifiers.__dict__), ) self.__previous_state = state @@ -406,6 +345,17 @@ def state(self) -> dict[str, Any]: state["available"] = self.available return state + def maybe_emit_state_changed_event(self) -> None: + """Send the state of this platform entity.""" + from zha.websocket.server.gateway import WebSocketGateway + + super().maybe_emit_state_changed_event() + if isinstance(self.device.gateway, WebSocketGateway): + self.device.gateway.emit( + STATE_CHANGED, + EntityStateChangedEvent(state=self.state, **self.identifiers.__dict__), + ) + async def async_update(self) -> None: """Retrieve latest state.""" self.debug("polling current state") @@ -479,6 +429,17 @@ def group(self) -> Group: """Return the group.""" return self._group + def maybe_emit_state_changed_event(self) -> None: + """Send the state of this platform entity.""" + from zha.websocket.server.gateway import WebSocketGateway + + super().maybe_emit_state_changed_event() + if isinstance(self.group.gateway, WebSocketGateway): + self.group.gateway.emit( + STATE_CHANGED, + EntityStateChangedEvent(state=self.state, **self.identifiers.__dict__), + ) + def debounced_update(self, _: Any | None = None) -> None: """Debounce updating group entity from member entity updates.""" # Delay to ensure that we get updates from all members before updating the group entity diff --git a/zha/application/platforms/alarm_control_panel/__init__.py b/zha/application/platforms/alarm_control_panel/__init__.py index 0dcb004e3..0f68b9c5a 100644 --- a/zha/application/platforms/alarm_control_panel/__init__.py +++ b/zha/application/platforms/alarm_control_panel/__init__.py @@ -47,6 +47,7 @@ class AlarmControlPanelEntityInfo(BaseEntityInfo): code_arm_required: bool code_format: CodeFormat supported_features: int + max_invalid_tries: int translation_key: str @@ -86,6 +87,7 @@ def info_object(self) -> AlarmControlPanelEntityInfo: code_arm_required=self.code_arm_required, code_format=self.code_format, supported_features=self.supported_features, + max_invalid_tries=self._cluster_handler.max_invalid_tries, ) @property diff --git a/zha/application/platforms/model.py b/zha/application/platforms/model.py new file mode 100644 index 000000000..05f54d719 --- /dev/null +++ b/zha/application/platforms/model.py @@ -0,0 +1,730 @@ +"""Models for the ZHA platforms module.""" + +from datetime import datetime +from enum import StrEnum +from typing import Annotated, Any, Literal, Optional, Union + +from pydantic import Field, ValidationInfo, field_validator +from zigpy.types.named import EUI64 + +from zha.application.discovery import Platform +from zha.event import EventBase +from zha.model import BaseEvent, BaseEventedModel, BaseModel +from zha.zigbee.cluster_handlers.model import ClusterHandlerInfo + + +class EntityCategory(StrEnum): + """Category of an entity.""" + + # Config: An entity which allows changing the configuration of a device. + CONFIG = "config" + + # Diagnostic: An entity exposing some configuration parameter, + # or diagnostics of a device. + DIAGNOSTIC = "diagnostic" + + +class BaseEntityInfo(BaseModel): + """Information about a base entity.""" + + platform: Platform + unique_id: str + class_name: str + translation_key: str | None + device_class: str | None + state_class: str | None + entity_category: str | None + entity_registry_enabled_default: bool + enabled: bool = True + fallback_name: str | None + state: dict[str, Any] + + # For platform entities + cluster_handlers: list[ClusterHandlerInfo] + device_ieee: EUI64 | None + endpoint_id: int | None + available: bool | None + + # For group entities + group_id: int | None + + +class BaseIdentifiers(BaseModel): + """Identifiers for the base entity.""" + + unique_id: str + platform: Platform + + +class PlatformEntityIdentifiers(BaseIdentifiers): + """Identifiers for the platform entity.""" + + device_ieee: EUI64 + endpoint_id: int + + +class GroupEntityIdentifiers(BaseIdentifiers): + """Identifiers for the group entity.""" + + group_id: int + + +class GenericState(BaseModel): + """Default state model.""" + + class_name: Literal[ + "AlarmControlPanel", + "Number", + "MaxHeatSetpointLimit", + "MinHeatSetpointLimit", + "DefaultToneSelectEntity", + "DefaultSirenLevelSelectEntity", + "DefaultStrobeLevelSelectEntity", + "DefaultStrobeSelectEntity", + "AnalogInput", + "Humidity", + "SoilMoisture", + "LeafWetness", + "Illuminance", + "Pressure", + "Temperature", + "CarbonDioxideConcentration", + "CarbonMonoxideConcentration", + "VOCLevel", + "PPBVOCLevel", + "FormaldehydeConcentration", + "ThermostatHVACAction", + "SinopeHVACAction", + "RSSISensor", + "LQISensor", + "LastSeenSensor", + "ElectricalMeasurementFrequency", + "ElectricalMeasurementPowerFactor", + "PolledElectricalMeasurement", + "PiHeatingDemand", + "SetpointChangeSource", + "SetpointChangeSourceTimestamp", + "TimeLeft", + "DeviceTemperature", + "WindowCoveringTypeSensor", + "StartUpCurrentLevelConfigurationEntity", + "StartUpColorTemperatureConfigurationEntity", + "StartupOnOffSelectEntity", + "PM25", + "Sensor", + "OnOffTransitionTimeConfigurationEntity", + "OnLevelConfigurationEntity", + "NumberConfigurationEntity", + "OnTransitionTimeConfigurationEntity", + "OffTransitionTimeConfigurationEntity", + "DefaultMoveRateConfigurationEntity", + "FilterLifeTime", + "IkeaDeviceRunTime", + "IkeaFilterRunTime", + "AqaraSmokeDensityDbm", + "HueV1MotionSensitivity", + "EnumSensor", + "AqaraMonitoringMode", + "AqaraApproachDistance", + "AqaraMotionSensitivity", + "AqaraCurtainMotorPowerSourceSensor", + "AqaraCurtainHookStateSensor", + "AqaraMagnetAC01DetectionDistance", + "AqaraMotionDetectionInterval", + "HueV2MotionSensitivity", + "TiRouterTransmitPower", + "ZCLEnumSelectEntity", + "SmartEnergySummationReceived", + "IdentifyButton", + "FrostLockResetButton", + "Button", + "WriteAttributeButton", + "AqaraSelfTestButton", + "NoPresenceStatusResetButton", + "TimestampSensor", + "DanfossOpenWindowDetection", + "DanfossLoadEstimate", + "DanfossAdaptationRunStatus", + "DanfossPreheatTime", + "DanfossSoftwareErrorCode", + "DanfossMotorStepCounter", + ] + available: Optional[bool] = None + state: Union[str, bool, int, float, datetime, None] = None + + +class DeviceCounterSensorState(BaseModel): + """Device counter sensor state model.""" + + class_name: Literal["DeviceCounterSensor"] = "DeviceCounterSensor" + state: int + + +class DeviceTrackerState(BaseModel): + """Device tracker state model.""" + + class_name: Literal["DeviceScannerEntity"] = "DeviceScannerEntity" + connected: bool + battery_level: Optional[float] = None + + +class BooleanState(BaseModel): + """Boolean value state model.""" + + class_name: Literal[ + "Accelerometer", + "Occupancy", + "Opening", + "BinaryInput", + "Motion", + "IASZone", + "Siren", + "FrostLock", + "BinarySensor", + "ReplaceFilter", + "AqaraLinkageAlarmState", + "HueOccupancy", + "AqaraE1CurtainMotorOpenedByHandBinarySensor", + "DanfossHeatRequired", + "DanfossMountingModeActive", + "DanfossPreheatStatus", + ] + state: bool + + +class CoverState(BaseModel): + """Cover state model.""" + + class_name: Literal["Cover"] = "Cover" + current_position: int | None = None + state: Optional[str] = None + is_opening: bool | None = None + is_closing: bool | None = None + is_closed: bool | None = None + + +class ShadeState(BaseModel): + """Cover state model.""" + + class_name: Literal["Shade", "KeenVent"] + current_position: Optional[int] = ( + None # TODO: how should we represent this when it is None? + ) + is_closed: bool + state: Optional[str] = None + + +class FanState(BaseModel): + """Fan state model.""" + + class_name: Literal["Fan", "FanGroup", "IkeaFan", "KofFan"] + preset_mode: Optional[str] = ( + None # TODO: how should we represent these when they are None? + ) + percentage: Optional[int] = ( + None # TODO: how should we represent these when they are None? + ) + is_on: bool + speed: Optional[str] = None + + +class LockState(BaseModel): + """Lock state model.""" + + class_name: Literal["Lock", "DoorLock"] = "Lock" + is_locked: bool + + +class BatteryState(BaseModel): + """Battery state model.""" + + class_name: Literal["Battery"] = "Battery" + state: Optional[Union[str, float, int]] = None + battery_size: Optional[str] = None + battery_quantity: Optional[int] = None + battery_voltage: Optional[float] = None + + +class ElectricalMeasurementState(BaseModel): + """Electrical measurement state model.""" + + class_name: Literal[ + "ElectricalMeasurement", + "ElectricalMeasurementApparentPower", + "ElectricalMeasurementRMSCurrent", + "ElectricalMeasurementRMSVoltage", + ] + state: Optional[Union[str, float, int]] = None + measurement_type: Optional[str] = None + active_power_max: Optional[str] = None + rms_current_max: Optional[str] = None + rms_voltage_max: Optional[int] = None + + +class LightState(BaseModel): + """Light state model.""" + + class_name: Literal[ + "Light", "HueLight", "ForceOnLight", "LightGroup", "MinTransitionLight" + ] + on: bool + brightness: Optional[int] = None + hs_color: Optional[tuple[float, float]] = None + color_temp: Optional[int] = None + effect: Optional[str] = None + off_brightness: Optional[int] = None + + +class ThermostatState(BaseModel): + """Thermostat state model.""" + + class_name: Literal[ + "Thermostat", + "SinopeTechnologiesThermostat", + "ZenWithinThermostat", + "MoesThermostat", + "BecaThermostat", + "ZONNSMARTThermostat", + ] + current_temperature: Optional[float] = None + target_temperature: Optional[float] = None + target_temperature_low: Optional[float] = None + target_temperature_high: Optional[float] = None + hvac_action: Optional[str] = None + hvac_mode: Optional[str] = None + preset_mode: Optional[str] = None + fan_mode: Optional[str] = None + + +class SwitchState(BaseModel): + """Switch state model.""" + + class_name: Literal[ + "Switch", + "SwitchGroup", + "WindowCoveringInversionSwitch", + "ChildLock", + "DisableLed", + "AqaraHeartbeatIndicator", + "AqaraLinkageAlarm", + "AqaraBuzzerManualMute", + "AqaraBuzzerManualAlarm", + "HueMotionTriggerIndicatorSwitch", + "AqaraE1CurtainMotorHooksLockedSwitch", + "P1MotionTriggerIndicatorSwitch", + "ConfigurableAttributeSwitch", + "OnOffWindowDetectionFunctionConfigurationEntity", + ] + state: bool + + +class SmareEnergyMeteringState(BaseModel): + """Smare energy metering state model.""" + + class_name: Literal["SmartEnergyMetering", "SmartEnergySummation"] + state: Optional[Union[str, float, int]] = None + device_type: Optional[str] = None + status: Optional[str] = None + + +class FirmwareUpdateState(BaseModel): + """Firmware update state model.""" + + class_name: Literal["FirmwareUpdateEntity"] + available: bool + installed_version: str | None + in_progress: bool | None + progress: int | None + latest_version: str | None + release_summary: str | None + release_notes: str | None + release_url: str | None + + +class EntityStateChangedEvent(BaseEvent): + """Event for when an entity state changes.""" + + event_type: Literal["entity"] = "entity" + event: Literal["state_changed"] = "state_changed" + platform: Platform + unique_id: str + device_ieee: Optional[EUI64] = None + endpoint_id: Optional[int] = None + group_id: Optional[int] = None + state: Annotated[ + Optional[ + Union[ + DeviceTrackerState, + CoverState, + ShadeState, + FanState, + LockState, + BatteryState, + ElectricalMeasurementState, + LightState, + SwitchState, + SmareEnergyMeteringState, + GenericState, + BooleanState, + ThermostatState, + FirmwareUpdateState, + DeviceCounterSensorState, + ] + ], + Field(discriminator="class_name"), # noqa: F821 + ] + + +class BasePlatformEntity(EventBase, BaseEntityInfo): + """Base platform entity model.""" + + +class FirmwareUpdateEntity(BasePlatformEntity): + """Firmware update entity model.""" + + class_name: Literal["FirmwareUpdateEntity"] + state: FirmwareUpdateState + + +class LockEntity(BasePlatformEntity): + """Lock entity model.""" + + class_name: Literal["Lock", "DoorLock"] + state: LockState + + +class DeviceTrackerEntity(BasePlatformEntity): + """Device tracker entity model.""" + + class_name: Literal["DeviceScannerEntity"] + state: DeviceTrackerState + + +class CoverEntity(BasePlatformEntity): + """Cover entity model.""" + + class_name: Literal["Cover"] + state: CoverState + + +class ShadeEntity(BasePlatformEntity): + """Shade entity model.""" + + class_name: Literal["Shade", "KeenVent"] + state: ShadeState + + +class BinarySensorEntity(BasePlatformEntity): + """Binary sensor model.""" + + class_name: Literal[ + "Accelerometer", + "Occupancy", + "Opening", + "BinaryInput", + "Motion", + "IASZone", + "FrostLock", + "BinarySensor", + "ReplaceFilter", + "AqaraLinkageAlarmState", + "HueOccupancy", + "AqaraE1CurtainMotorOpenedByHandBinarySensor", + "DanfossHeatRequired", + "DanfossMountingModeActive", + "DanfossPreheatStatus", + ] + attribute_name: str | None = None + state: BooleanState + + +class BaseSensorEntity(BasePlatformEntity): + """Sensor model.""" + + attribute: Optional[str] + decimals: int + divisor: int + multiplier: Union[int, float] + unit: Optional[int | str] + + +class SensorEntity(BaseSensorEntity): + """Sensor entity model.""" + + class_name: Literal[ + "AnalogInput", + "Humidity", + "SoilMoisture", + "LeafWetness", + "Illuminance", + "Pressure", + "Temperature", + "CarbonDioxideConcentration", + "CarbonMonoxideConcentration", + "VOCLevel", + "PPBVOCLevel", + "FormaldehydeConcentration", + "ThermostatHVACAction", + "SinopeHVACAction", + "RSSISensor", + "LQISensor", + "LastSeenSensor", + "ElectricalMeasurementFrequency", + "ElectricalMeasurementPowerFactor", + "PolledElectricalMeasurement", + "PiHeatingDemand", + "SetpointChangeSource", + "SetpointChangeSourceTimestamp", + "TimeLeft", + "DeviceTemperature", + "WindowCoveringTypeSensor", + "PM25", + "Sensor", + "IkeaDeviceRunTime", + "IkeaFilterRunTime", + "AqaraSmokeDensityDbm", + "EnumSensor", + "AqaraCurtainMotorPowerSourceSensor", + "AqaraCurtainHookStateSensor", + "SmartEnergySummationReceived", + "TimestampSensor", + "DanfossOpenWindowDetection", + "DanfossLoadEstimate", + "DanfossAdaptationRunStatus", + "DanfossPreheatTime", + "DanfossSoftwareErrorCode", + "DanfossMotorStepCounter", + ] + state: GenericState + + +class DeviceCounterSensorEntity(BaseEventedModel, BaseEntityInfo): + """Device counter sensor model.""" + + class_name: Literal["DeviceCounterSensor"] + counter: str + counter_value: int + counter_groups: str + counter_group: str + state: DeviceCounterSensorState + + @field_validator("state", mode="before", check_fields=False) + @classmethod + def convert_state( + cls, state: dict | int | None, validation_info: ValidationInfo + ) -> DeviceCounterSensorState: + """Convert counter value to counter_value.""" + if state is not None: + if isinstance(state, int): + return DeviceCounterSensorState(state=state) + if isinstance(state, dict): + if "state" in state: + return DeviceCounterSensorState(state=state["state"]) + else: + return DeviceCounterSensorState( + state=validation_info.data["counter_value"] + ) + return DeviceCounterSensorState(state=validation_info.data["counter_value"]) + + +class BatteryEntity(BaseSensorEntity): + """Battery entity model.""" + + class_name: Literal["Battery"] + state: BatteryState + + +class ElectricalMeasurementEntity(BaseSensorEntity): + """Electrical measurement entity model.""" + + class_name: Literal[ + "ElectricalMeasurement", + "ElectricalMeasurementApparentPower", + "ElectricalMeasurementRMSCurrent", + "ElectricalMeasurementRMSVoltage", + ] + state: ElectricalMeasurementState + + +class SmartEnergyMeteringEntity(BaseSensorEntity): + """Smare energy metering entity model.""" + + class_name: Literal["SmartEnergyMetering", "SmartEnergySummation"] + state: SmareEnergyMeteringState + + +class AlarmControlPanelEntity(BasePlatformEntity): + """Alarm control panel model.""" + + class_name: Literal["AlarmControlPanel"] + supported_features: int + code_arm_required: bool + max_invalid_tries: int + state: GenericState + + +class ButtonEntity( + BasePlatformEntity +): # TODO split into two models CommandButton and WriteAttributeButton + """Button model.""" + + class_name: Literal[ + "IdentifyButton", + "FrostLockResetButton", + "Button", + "WriteAttributeButton", + "AqaraSelfTestButton", + "NoPresenceStatusResetButton", + ] + command: str | None = None + attribute_name: str | None = None + attribute_value: Any | None = None + state: GenericState + + +class FanEntity(BasePlatformEntity): + """Fan model.""" + + class_name: Literal["Fan", "IkeaFan", "KofFan"] + preset_modes: list[str] + supported_features: int + speed_count: int + speed_list: list[str] + percentage_step: float | None = None + state: FanState + + +class LightEntity(BasePlatformEntity): + """Light model.""" + + class_name: Literal["Light", "HueLight", "ForceOnLight", "MinTransitionLight"] + supported_features: int + min_mireds: int + max_mireds: int + effect_list: Optional[list[str]] + state: LightState + + +class NumberEntity(BasePlatformEntity): + """Number entity model.""" + + class_name: Literal[ + "Number", + "MaxHeatSetpointLimit", + "MinHeatSetpointLimit", + "StartUpCurrentLevelConfigurationEntity", + "StartUpColorTemperatureConfigurationEntity", + "OnOffTransitionTimeConfigurationEntity", + "OnLevelConfigurationEntity", + "NumberConfigurationEntity", + "OnTransitionTimeConfigurationEntity", + "OffTransitionTimeConfigurationEntity", + "DefaultMoveRateConfigurationEntity", + "FilterLifeTime", + "AqaraMotionDetectionInterval", + "TiRouterTransmitPower", + ] + engineering_units: int | None = ( + None # TODO: how should we represent this when it is None? + ) + application_type: int | None = ( + None # TODO: how should we represent this when it is None? + ) + step: Optional[float] = None # TODO: how should we represent this when it is None? + min_value: float + max_value: float + state: GenericState + + +class SelectEntity(BasePlatformEntity): + """Select entity model.""" + + class_name: Literal[ + "DefaultToneSelectEntity", + "DefaultSirenLevelSelectEntity", + "DefaultStrobeLevelSelectEntity", + "DefaultStrobeSelectEntity", + "StartupOnOffSelectEntity", + "HueV1MotionSensitivity", + "AqaraMonitoringMode", + "AqaraApproachDistance", + "AqaraMotionSensitivity", + "AqaraMagnetAC01DetectionDistance", + "HueV2MotionSensitivity", + "ZCLEnumSelectEntity", + ] + enum: str + options: list[str] + state: GenericState + + +class ThermostatEntity(BasePlatformEntity): + """Thermostat entity model.""" + + class_name: Literal[ + "Thermostat", + "SinopeTechnologiesThermostat", + "ZenWithinThermostat", + "MoesThermostat", + "BecaThermostat", + "ZONNSMARTThermostat", + ] + state: ThermostatState + hvac_modes: tuple[str, ...] + fan_modes: Optional[list[str]] + preset_modes: Optional[list[str]] + + +class SirenEntity(BasePlatformEntity): + """Siren entity model.""" + + class_name: Literal["Siren"] + available_tones: Optional[Union[list[Union[int, str]], dict[int, str]]] + supported_features: int + state: BooleanState + + +class SwitchEntity(BasePlatformEntity): + """Switch entity model.""" + + class_name: Literal[ + "Switch", + "WindowCoveringInversionSwitch", + "ChildLock", + "DisableLed", + "AqaraHeartbeatIndicator", + "AqaraLinkageAlarm", + "AqaraBuzzerManualMute", + "AqaraBuzzerManualAlarm", + "HueMotionTriggerIndicatorSwitch", + "AqaraE1CurtainMotorHooksLockedSwitch", + "P1MotionTriggerIndicatorSwitch", + "ConfigurableAttributeSwitch", + "OnOffWindowDetectionFunctionConfigurationEntity", + ] + state: SwitchState + + +class GroupEntity(EventBase, BaseEntityInfo): + """Group entity model.""" + + +class LightGroupEntity(GroupEntity): + """Group entity model.""" + + class_name: Literal["LightGroup"] + state: LightState + + +class FanGroupEntity(GroupEntity): + """Group entity model.""" + + class_name: Literal["FanGroup"] + state: FanState + + +class SwitchGroupEntity(GroupEntity): + """Group entity model.""" + + class_name: Literal["SwitchGroup"] + state: SwitchState diff --git a/zha/application/platforms/number/__init__.py b/zha/application/platforms/number/__init__.py index f8647a117..8e642e256 100644 --- a/zha/application/platforms/number/__init__.py +++ b/zha/application/platforms/number/__init__.py @@ -50,7 +50,7 @@ class NumberEntityInfo(BaseEntityInfo): """Number entity info.""" - engineering_units: int + engineering_units: int | None application_type: int | None min_value: float | None max_value: float | None diff --git a/zha/const.py b/zha/const.py index c96c47daf..cab90794d 100644 --- a/zha/const.py +++ b/zha/const.py @@ -13,7 +13,7 @@ class EventTypes(StrEnum): """WS event types.""" - CONTROLLER_EVENT = "controller_event" + CONTROLLER_EVENT = "zha_gateway_message" PLATFORM_ENTITY_EVENT = "platform_entity_event" RAW_ZCL_EVENT = "raw_zcl_event" DEVICE_EVENT = "device_event" diff --git a/zha/model.py b/zha/model.py index 0edfd8d66..d25cbacbd 100644 --- a/zha/model.py +++ b/zha/model.py @@ -13,6 +13,8 @@ ) from zigpy.types.named import EUI64, NWK +from zha.event import EventBase + _LOGGER = logging.getLogger(__name__) @@ -72,14 +74,18 @@ class BaseModel(PydanticBaseModel): @field_serializer("ieee", "device_ieee", check_fields=False) def serialize_ieee(self, ieee: EUI64): """Customize how ieee is serialized.""" - return str(ieee) + if ieee is not None: + return str(ieee) + return ieee @field_serializer( "nwk", "dest_nwk", "next_hop", when_used="json", check_fields=False ) def serialize_nwk(self, nwk: NWK): """Serialize nwk as hex string.""" - return repr(nwk) + if nwk is not None: + return repr(nwk) + return nwk class BaseEvent(BaseModel): @@ -88,3 +94,7 @@ class BaseEvent(BaseModel): message_type: Literal["event"] = "event" event_type: str event: str + + +class BaseEventedModel(EventBase, BaseModel): + """Base evented model.""" diff --git a/zha/websocket/client/client.py b/zha/websocket/client/client.py index ec8fd3ef4..a58c5ea59 100644 --- a/zha/websocket/client/client.py +++ b/zha/websocket/client/client.py @@ -14,9 +14,12 @@ from async_timeout import timeout from zha.event import EventBase -from zha.websocket.client.model.commands import CommandResponse, ErrorResponse from zha.websocket.client.model.messages import Message -from zha.websocket.server.api.model import WebSocketCommand +from zha.websocket.server.api.model import ( + ErrorResponse, + WebSocketCommand, + WebSocketCommandResponse, +) SIZE_PARSE_JSON_EXECUTOR = 8192 _LOGGER = logging.getLogger(__package__) @@ -76,9 +79,9 @@ def new_message_id(self) -> int: async def async_send_command( self, command: WebSocketCommand, - ) -> CommandResponse: + ) -> WebSocketCommandResponse: """Send a command and get a response.""" - future: asyncio.Future[CommandResponse] = self._loop.create_future() + future: asyncio.Future[WebSocketCommandResponse] = self._loop.create_future() message_id = command.message_id = self.new_message_id() self._result_futures[message_id] = future @@ -90,13 +93,13 @@ async def async_send_command( return await future except TimeoutError: _LOGGER.exception("Timeout waiting for response") - return CommandResponse.model_validate( - {"message_id": message_id, "success": False} + return WebSocketCommandResponse.model_validate( + {"message_id": message_id, "success": False, "command": command.command} ) except Exception as err: _LOGGER.exception("Error sending command", exc_info=err) - return CommandResponse.model_validate( - {"message_id": message_id, "success": False} + return WebSocketCommandResponse.model_validate( + {"message_id": message_id, "success": False, "command": command.command} ) finally: self._result_futures.pop(message_id) diff --git a/zha/websocket/client/controller.py b/zha/websocket/client/controller.py index 717632301..a722278ab 100644 --- a/zha/websocket/client/controller.py +++ b/zha/websocket/client/controller.py @@ -9,18 +9,8 @@ from async_timeout import timeout from zigpy.types.named import EUI64 -from zha.event import EventBase -from zha.websocket.client.client import Client -from zha.websocket.client.helpers import ( - ClientHelper, - DeviceHelper, - GroupHelper, - NetworkHelper, - ServerHelper, -) -from zha.websocket.client.model.commands import CommandResponse -from zha.websocket.client.model.events import ( - DeviceConfiguredEvent, +from zha.application.gateway import RawDeviceInitializedEvent +from zha.application.model import ( DeviceFullyInitializedEvent, DeviceJoinedEvent, DeviceLeftEvent, @@ -29,13 +19,33 @@ GroupMemberAddedEvent, GroupMemberRemovedEvent, GroupRemovedEvent, - PlatformEntityStateChangedEvent, - RawDeviceInitializedEvent, - ZHAEvent, +) +from zha.application.platforms.model import EntityStateChangedEvent +from zha.event import EventBase +from zha.websocket.client.client import Client +from zha.websocket.client.helpers import ( + AlarmControlPanelHelper, + ButtonHelper, + ClientHelper, + ClimateHelper, + CoverHelper, + DeviceHelper, + FanHelper, + GroupHelper, + LightHelper, + LockHelper, + NetworkHelper, + NumberHelper, + PlatformEntityHelper, + SelectHelper, + ServerHelper, + SirenHelper, + SwitchHelper, ) from zha.websocket.client.proxy import DeviceProxy, GroupProxy -from zha.websocket.const import ControllerEvents, EventTypes -from zha.websocket.server.api.model import WebSocketCommand +from zha.websocket.const import ControllerEvents +from zha.websocket.server.api.model import WebSocketCommand, WebSocketCommandResponse +from zha.zigbee.model import ZHAEvent CONNECT_TIMEOUT = 10 @@ -55,6 +65,21 @@ def __init__( self._devices: dict[EUI64, DeviceProxy] = {} self._groups: dict[int, GroupProxy] = {} + # set up all of the helper objects + self.lights: LightHelper = LightHelper(self._client) + self.switches: SwitchHelper = SwitchHelper(self._client) + self.sirens: SirenHelper = SirenHelper(self._client) + self.buttons: ButtonHelper = ButtonHelper(self._client) + self.covers: CoverHelper = CoverHelper(self._client) + self.fans: FanHelper = FanHelper(self._client) + self.locks: LockHelper = LockHelper(self._client) + self.numbers: NumberHelper = NumberHelper(self._client) + self.selects: SelectHelper = SelectHelper(self._client) + self.thermostats: ClimateHelper = ClimateHelper(self._client) + self.alarm_control_panels: AlarmControlPanelHelper = AlarmControlPanelHelper( + self._client + ) + self.entities: PlatformEntityHelper = PlatformEntityHelper(self._client) self.clients: ClientHelper = ClientHelper(self._client) self.groups_helper: GroupHelper = GroupHelper(self._client) self.devices_helper: DeviceHelper = DeviceHelper(self._client) @@ -62,11 +87,7 @@ def __init__( self.server_helper: ServerHelper = ServerHelper(self._client) # subscribe to event types we care about - self._client.on_event( - EventTypes.PLATFORM_ENTITY_EVENT, self._handle_event_protocol - ) - self._client.on_event(EventTypes.DEVICE_EVENT, self._handle_event_protocol) - self._client.on_event(EventTypes.CONTROLLER_EVENT, self._handle_event_protocol) + self._client.on_all_events(self._handle_event_protocol) @property def client(self) -> Client: @@ -110,7 +131,7 @@ async def __aexit__( """Disconnect from the websocket server.""" await self.disconnect() - async def send_command(self, command: WebSocketCommand) -> CommandResponse: + async def send_command(self, command: WebSocketCommand) -> WebSocketCommandResponse: """Send a command and get a response.""" return await self._client.async_send_command(command) @@ -126,19 +147,17 @@ async def load_groups(self) -> None: for group_id, group in response_groups.items(): self._groups[group_id] = GroupProxy(group, self, self._client) - def handle_platform_entity_state_changed( - self, event: PlatformEntityStateChangedEvent - ) -> None: + def handle_state_changed(self, event: EntityStateChangedEvent) -> None: """Handle a platform_entity_event from the websocket server.""" _LOGGER.debug("platform_entity_event: %s", event) - if event.device: - device = self.devices.get(event.device.ieee) + if event.device_ieee: + device = self.devices.get(event.device_ieee) if device is None: _LOGGER.warning("Received event from unknown device: %s", event) return device.emit_platform_entity_event(event) - elif event.group: - group = self.groups.get(event.group.id) + elif event.group_id: + group = self.groups.get(event.group_id) if not group: _LOGGER.warning("Received event from unknown group: %s", event) return @@ -159,25 +178,25 @@ def handle_device_joined(self, event: DeviceJoinedEvent) -> None: At this point, no information about the device is known other than its address """ - _LOGGER.info("Device %s - %s joined", event.ieee, event.nwk) + _LOGGER.info( + "Device %s - %s joined", event.device_info.ieee, event.device_info.nwk + ) self.emit(ControllerEvents.DEVICE_JOINED, event) def handle_raw_device_initialized(self, event: RawDeviceInitializedEvent) -> None: """Handle a device initialization without quirks loaded.""" - _LOGGER.info("Device %s - %s raw device initialized", event.ieee, event.nwk) + _LOGGER.info( + "Device %s - %s raw device initialized", + event.device_info.ieee, + event.device_info.nwk, + ) self.emit(ControllerEvents.RAW_DEVICE_INITIALIZED, event) - def handle_device_configured(self, event: DeviceConfiguredEvent) -> None: - """Handle device configured event.""" - device = event.device - _LOGGER.info("Device %s - %s configured", device.ieee, device.nwk) - self.emit(ControllerEvents.DEVICE_CONFIGURED, event) - def handle_device_fully_initialized( self, event: DeviceFullyInitializedEvent ) -> None: """Handle device joined and basic information discovered.""" - device_model = event.device + device_model = event.device_info _LOGGER.info("Device %s - %s initialized", device_model.ieee, device_model.nwk) if device_model.ieee in self.devices: self.devices[device_model.ieee].device_model = device_model @@ -194,7 +213,7 @@ def handle_device_left(self, event: DeviceLeftEvent) -> None: def handle_device_removed(self, event: DeviceRemovedEvent) -> None: """Handle device being removed from the network.""" - device = event.device + device = event.device_info _LOGGER.info( "Device %s - %s has been removed from the network", device.ieee, device.nwk ) @@ -203,26 +222,28 @@ def handle_device_removed(self, event: DeviceRemovedEvent) -> None: def handle_group_member_removed(self, event: GroupMemberRemovedEvent) -> None: """Handle group member removed event.""" - if event.group.id in self.groups: - self.groups[event.group.id].group_model = event.group + if event.group_info.group_id in self.groups: + self.groups[event.group_info.group_id].group_model = event.group_info self.emit(ControllerEvents.GROUP_MEMBER_REMOVED, event) def handle_group_member_added(self, event: GroupMemberAddedEvent) -> None: """Handle group member added event.""" - if event.group.id in self.groups: - self.groups[event.group.id].group_model = event.group + if event.group_info.group_id in self.groups: + self.groups[event.group_info.group_id].group_model = event.group_info self.emit(ControllerEvents.GROUP_MEMBER_ADDED, event) def handle_group_added(self, event: GroupAddedEvent) -> None: """Handle group added event.""" - if event.group.id in self.groups: - self.groups[event.group.id].group_model = event.group + if event.group_info.group_id in self.groups: + self.groups[event.group_info.group_id].group_model = event.group_info else: - self.groups[event.group.id] = GroupProxy(event.group, self, self._client) + self.groups[event.group_info.group_id] = GroupProxy( + event.group_info, self, self._client + ) self.emit(ControllerEvents.GROUP_ADDED, event) def handle_group_removed(self, event: GroupRemovedEvent) -> None: """Handle group removed event.""" - if event.group.id in self.groups: - self.groups.pop(event.group.id) + if event.group_info.group_id in self.groups: + self.groups.pop(event.group_info.group_id) self.emit(ControllerEvents.GROUP_REMOVED, event) diff --git a/zha/websocket/client/helpers.py b/zha/websocket/client/helpers.py index f3d519c7c..be62057d0 100644 --- a/zha/websocket/client/helpers.py +++ b/zha/websocket/client/helpers.py @@ -2,26 +2,74 @@ from __future__ import annotations -from typing import Any, cast +from typing import Any, Literal, cast from zigpy.types.named import EUI64 from zha.application.discovery import Platform +from zha.application.platforms.model import ( + BaseEntityInfo, + BasePlatformEntity, + GroupEntity, +) from zha.websocket.client.client import Client -from zha.websocket.client.model.commands import ( - CommandResponse, +from zha.websocket.server.api.model import ( GetDevicesResponse, GroupsResponse, PermitJoiningResponse, ReadClusterAttributesResponse, UpdateGroupResponse, + WebSocketCommandResponse, WriteClusterAttributeResponse, ) -from zha.websocket.client.model.types import ( - BaseEntity, - BasePlatformEntity, - Device, - Group, +from zha.websocket.server.api.platforms.alarm_control_panel.api import ( + ArmAwayCommand, + ArmHomeCommand, + ArmNightCommand, + DisarmCommand, + TriggerAlarmCommand, +) +from zha.websocket.server.api.platforms.api import PlatformEntityRefreshStateCommand +from zha.websocket.server.api.platforms.button.api import ButtonPressCommand +from zha.websocket.server.api.platforms.climate.api import ( + ClimateSetFanModeCommand, + ClimateSetHVACModeCommand, + ClimateSetPresetModeCommand, + ClimateSetTemperatureCommand, +) +from zha.websocket.server.api.platforms.cover.api import ( + CoverCloseCommand, + CoverOpenCommand, + CoverSetPositionCommand, + CoverStopCommand, +) +from zha.websocket.server.api.platforms.fan.api import ( + FanSetPercentageCommand, + FanSetPresetModeCommand, + FanTurnOffCommand, + FanTurnOnCommand, +) +from zha.websocket.server.api.platforms.light.api import ( + LightTurnOffCommand, + LightTurnOnCommand, +) +from zha.websocket.server.api.platforms.lock.api import ( + LockClearUserLockCodeCommand, + LockDisableUserLockCodeCommand, + LockEnableUserLockCodeCommand, + LockLockCommand, + LockSetUserLockCodeCommand, + LockUnlockCommand, +) +from zha.websocket.server.api.platforms.number.api import NumberSetValueCommand +from zha.websocket.server.api.platforms.select.api import SelectSelectOptionCommand +from zha.websocket.server.api.platforms.siren.api import ( + SirenTurnOffCommand, + SirenTurnOnCommand, +) +from zha.websocket.server.api.platforms.switch.api import ( + SwitchTurnOffCommand, + SwitchTurnOnCommand, ) from zha.websocket.server.client import ( ClientDisconnectCommand, @@ -45,9 +93,10 @@ UpdateTopologyCommand, WriteClusterAttributeCommand, ) +from zha.zigbee.model import ExtendedDeviceInfo, GroupInfo -def ensure_platform_entity(entity: BaseEntity, platform: Platform) -> None: +def ensure_platform_entity(entity: BaseEntityInfo, platform: Platform) -> None: """Ensure an entity exists and is from the specified platform.""" if entity is None or entity.platform != platform: raise ValueError( @@ -55,6 +104,607 @@ def ensure_platform_entity(entity: BaseEntity, platform: Platform) -> None: ) +class LightHelper: + """Helper to issue light commands.""" + + def __init__(self, client: Client): + """Initialize the light helper.""" + self._client: Client = client + + async def turn_on( + self, + light_platform_entity: BasePlatformEntity | GroupEntity, + brightness: int | None = None, + transition: int | None = None, + flash: str | None = None, + effect: str | None = None, + hs_color: tuple | None = None, + color_temp: int | None = None, + ) -> WebSocketCommandResponse: + """Turn on a light.""" + ensure_platform_entity(light_platform_entity, Platform.LIGHT) + command = LightTurnOnCommand( + ieee=light_platform_entity.device_ieee + if not isinstance(light_platform_entity, GroupEntity) + else None, + group_id=light_platform_entity.group_id + if isinstance(light_platform_entity, GroupEntity) + else None, + unique_id=light_platform_entity.unique_id, + brightness=brightness, + transition=transition, + flash=flash, + effect=effect, + hs_color=hs_color, + color_temp=color_temp, + ) + return await self._client.async_send_command(command) + + async def turn_off( + self, + light_platform_entity: BasePlatformEntity | GroupEntity, + transition: int | None = None, + flash: bool | None = None, + ) -> WebSocketCommandResponse: + """Turn off a light.""" + ensure_platform_entity(light_platform_entity, Platform.LIGHT) + command = LightTurnOffCommand( + ieee=light_platform_entity.device_ieee + if not isinstance(light_platform_entity, GroupEntity) + else None, + group_id=light_platform_entity.group_id + if isinstance(light_platform_entity, GroupEntity) + else None, + unique_id=light_platform_entity.unique_id, + transition=transition, + flash=flash, + ) + return await self._client.async_send_command(command) + + +class SwitchHelper: + """Helper to issue switch commands.""" + + def __init__(self, client: Client): + """Initialize the switch helper.""" + self._client: Client = client + + async def turn_on( + self, + switch_platform_entity: BasePlatformEntity | GroupEntity, + ) -> WebSocketCommandResponse: + """Turn on a switch.""" + ensure_platform_entity(switch_platform_entity, Platform.SWITCH) + command = SwitchTurnOnCommand( + ieee=switch_platform_entity.device_ieee + if not isinstance(switch_platform_entity, GroupEntity) + else None, + group_id=switch_platform_entity.group_id + if isinstance(switch_platform_entity, GroupEntity) + else None, + unique_id=switch_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def turn_off( + self, + switch_platform_entity: BasePlatformEntity | GroupEntity, + ) -> WebSocketCommandResponse: + """Turn off a switch.""" + ensure_platform_entity(switch_platform_entity, Platform.SWITCH) + command = SwitchTurnOffCommand( + ieee=switch_platform_entity.device_ieee + if not isinstance(switch_platform_entity, GroupEntity) + else None, + group_id=switch_platform_entity.group_id + if isinstance(switch_platform_entity, GroupEntity) + else None, + unique_id=switch_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + +class SirenHelper: + """Helper to issue siren commands.""" + + def __init__(self, client: Client): + """Initialize the siren helper.""" + self._client: Client = client + + async def turn_on( + self, + siren_platform_entity: BasePlatformEntity, + duration: int | None = None, + volume_level: int | None = None, + tone: int | None = None, + ) -> WebSocketCommandResponse: + """Turn on a siren.""" + ensure_platform_entity(siren_platform_entity, Platform.SIREN) + command = SirenTurnOnCommand( + ieee=siren_platform_entity.device_ieee, + unique_id=siren_platform_entity.unique_id, + duration=duration, + level=volume_level, + tone=tone, + ) + return await self._client.async_send_command(command) + + async def turn_off( + self, siren_platform_entity: BasePlatformEntity + ) -> WebSocketCommandResponse: + """Turn off a siren.""" + ensure_platform_entity(siren_platform_entity, Platform.SIREN) + command = SirenTurnOffCommand( + ieee=siren_platform_entity.device_ieee, + unique_id=siren_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + +class ButtonHelper: + """Helper to issue button commands.""" + + def __init__(self, client: Client): + """Initialize the button helper.""" + self._client: Client = client + + async def press( + self, button_platform_entity: BasePlatformEntity + ) -> WebSocketCommandResponse: + """Press a button.""" + ensure_platform_entity(button_platform_entity, Platform.BUTTON) + command = ButtonPressCommand( + ieee=button_platform_entity.device_ieee, + unique_id=button_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + +class CoverHelper: + """helper to issue cover commands.""" + + def __init__(self, client: Client): + """Initialize the cover helper.""" + self._client: Client = client + + async def open_cover( + self, cover_platform_entity: BasePlatformEntity + ) -> WebSocketCommandResponse: + """Open a cover.""" + ensure_platform_entity(cover_platform_entity, Platform.COVER) + command = CoverOpenCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def close_cover( + self, cover_platform_entity: BasePlatformEntity + ) -> WebSocketCommandResponse: + """Close a cover.""" + ensure_platform_entity(cover_platform_entity, Platform.COVER) + command = CoverCloseCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def stop_cover( + self, cover_platform_entity: BasePlatformEntity + ) -> WebSocketCommandResponse: + """Stop a cover.""" + ensure_platform_entity(cover_platform_entity, Platform.COVER) + command = CoverStopCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def set_cover_position( + self, + cover_platform_entity: BasePlatformEntity, + position: int, + ) -> WebSocketCommandResponse: + """Set a cover position.""" + ensure_platform_entity(cover_platform_entity, Platform.COVER) + command = CoverSetPositionCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + position=position, + ) + return await self._client.async_send_command(command) + + +class FanHelper: + """Helper to issue fan commands.""" + + def __init__(self, client: Client): + """Initialize the fan helper.""" + self._client: Client = client + + async def turn_on( + self, + fan_platform_entity: BasePlatformEntity | GroupEntity, + speed: str | None = None, + percentage: int | None = None, + preset_mode: str | None = None, + ) -> WebSocketCommandResponse: + """Turn on a fan.""" + ensure_platform_entity(fan_platform_entity, Platform.FAN) + command = FanTurnOnCommand( + ieee=fan_platform_entity.device_ieee + if not isinstance(fan_platform_entity, GroupEntity) + else None, + group_id=fan_platform_entity.group_id + if isinstance(fan_platform_entity, GroupEntity) + else None, + unique_id=fan_platform_entity.unique_id, + speed=speed, + percentage=percentage, + preset_mode=preset_mode, + ) + return await self._client.async_send_command(command) + + async def turn_off( + self, + fan_platform_entity: BasePlatformEntity | GroupEntity, + ) -> WebSocketCommandResponse: + """Turn off a fan.""" + ensure_platform_entity(fan_platform_entity, Platform.FAN) + command = FanTurnOffCommand( + ieee=fan_platform_entity.device_ieee + if not isinstance(fan_platform_entity, GroupEntity) + else None, + group_id=fan_platform_entity.group_id + if isinstance(fan_platform_entity, GroupEntity) + else None, + unique_id=fan_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def set_fan_percentage( + self, + fan_platform_entity: BasePlatformEntity | GroupEntity, + percentage: int, + ) -> WebSocketCommandResponse: + """Set a fan percentage.""" + ensure_platform_entity(fan_platform_entity, Platform.FAN) + command = FanSetPercentageCommand( + ieee=fan_platform_entity.device_ieee + if not isinstance(fan_platform_entity, GroupEntity) + else None, + group_id=fan_platform_entity.group_id + if isinstance(fan_platform_entity, GroupEntity) + else None, + unique_id=fan_platform_entity.unique_id, + percentage=percentage, + ) + return await self._client.async_send_command(command) + + async def set_fan_preset_mode( + self, + fan_platform_entity: BasePlatformEntity | GroupEntity, + preset_mode: str, + ) -> WebSocketCommandResponse: + """Set a fan preset mode.""" + ensure_platform_entity(fan_platform_entity, Platform.FAN) + command = FanSetPresetModeCommand( + ieee=fan_platform_entity.device_ieee + if not isinstance(fan_platform_entity, GroupEntity) + else None, + group_id=fan_platform_entity.group_id + if isinstance(fan_platform_entity, GroupEntity) + else None, + unique_id=fan_platform_entity.unique_id, + preset_mode=preset_mode, + ) + return await self._client.async_send_command(command) + + +class LockHelper: + """Helper to issue lock commands.""" + + def __init__(self, client: Client): + """Initialize the lock helper.""" + self._client: Client = client + + async def lock( + self, lock_platform_entity: BasePlatformEntity + ) -> WebSocketCommandResponse: + """Lock a lock.""" + ensure_platform_entity(lock_platform_entity, Platform.LOCK) + command = LockLockCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def unlock( + self, lock_platform_entity: BasePlatformEntity + ) -> WebSocketCommandResponse: + """Unlock a lock.""" + ensure_platform_entity(lock_platform_entity, Platform.LOCK) + command = LockUnlockCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def set_user_lock_code( + self, + lock_platform_entity: BasePlatformEntity, + code_slot: int, + user_code: str, + ) -> WebSocketCommandResponse: + """Set a user lock code.""" + ensure_platform_entity(lock_platform_entity, Platform.LOCK) + command = LockSetUserLockCodeCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + code_slot=code_slot, + user_code=user_code, + ) + return await self._client.async_send_command(command) + + async def clear_user_lock_code( + self, + lock_platform_entity: BasePlatformEntity, + code_slot: int, + ) -> WebSocketCommandResponse: + """Clear a user lock code.""" + ensure_platform_entity(lock_platform_entity, Platform.LOCK) + command = LockClearUserLockCodeCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + code_slot=code_slot, + ) + return await self._client.async_send_command(command) + + async def enable_user_lock_code( + self, + lock_platform_entity: BasePlatformEntity, + code_slot: int, + ) -> WebSocketCommandResponse: + """Enable a user lock code.""" + ensure_platform_entity(lock_platform_entity, Platform.LOCK) + command = LockEnableUserLockCodeCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + code_slot=code_slot, + ) + return await self._client.async_send_command(command) + + async def disable_user_lock_code( + self, + lock_platform_entity: BasePlatformEntity, + code_slot: int, + ) -> WebSocketCommandResponse: + """Disable a user lock code.""" + ensure_platform_entity(lock_platform_entity, Platform.LOCK) + command = LockDisableUserLockCodeCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + code_slot=code_slot, + ) + return await self._client.async_send_command(command) + + +class NumberHelper: + """Helper to issue number commands.""" + + def __init__(self, client: Client): + """Initialize the number helper.""" + self._client: Client = client + + async def set_value( + self, + number_platform_entity: BasePlatformEntity, + value: int | float, + ) -> WebSocketCommandResponse: + """Set a number.""" + ensure_platform_entity(number_platform_entity, Platform.NUMBER) + command = NumberSetValueCommand( + ieee=number_platform_entity.device_ieee, + unique_id=number_platform_entity.unique_id, + value=value, + ) + return await self._client.async_send_command(command) + + +class SelectHelper: + """Helper to issue select commands.""" + + def __init__(self, client: Client): + """Initialize the select helper.""" + self._client: Client = client + + async def select_option( + self, + select_platform_entity: BasePlatformEntity, + option: str | int, + ) -> WebSocketCommandResponse: + """Set a select.""" + ensure_platform_entity(select_platform_entity, Platform.SELECT) + command = SelectSelectOptionCommand( + ieee=select_platform_entity.device_ieee, + unique_id=select_platform_entity.unique_id, + option=option, + ) + return await self._client.async_send_command(command) + + +class ClimateHelper: + """Helper to issue climate commands.""" + + def __init__(self, client: Client): + """Initialize the climate helper.""" + self._client: Client = client + + async def set_hvac_mode( + self, + climate_platform_entity: BasePlatformEntity, + hvac_mode: Literal[ + "heat_cool", "heat", "cool", "auto", "dry", "fan_only", "off" + ], + ) -> WebSocketCommandResponse: + """Set a climate.""" + ensure_platform_entity(climate_platform_entity, Platform.CLIMATE) + command = ClimateSetHVACModeCommand( + ieee=climate_platform_entity.device_ieee, + unique_id=climate_platform_entity.unique_id, + hvac_mode=hvac_mode, + ) + return await self._client.async_send_command(command) + + async def set_temperature( + self, + climate_platform_entity: BasePlatformEntity, + hvac_mode: None + | ( + Literal["heat_cool", "heat", "cool", "auto", "dry", "fan_only", "off"] + ) = None, + temperature: float | None = None, + target_temp_high: float | None = None, + target_temp_low: float | None = None, + ) -> WebSocketCommandResponse: + """Set a climate.""" + ensure_platform_entity(climate_platform_entity, Platform.CLIMATE) + command = ClimateSetTemperatureCommand( + ieee=climate_platform_entity.device_ieee, + unique_id=climate_platform_entity.unique_id, + temperature=temperature, + target_temp_high=target_temp_high, + target_temp_low=target_temp_low, + hvac_mode=hvac_mode, + ) + return await self._client.async_send_command(command) + + async def set_fan_mode( + self, + climate_platform_entity: BasePlatformEntity, + fan_mode: str, + ) -> WebSocketCommandResponse: + """Set a climate.""" + ensure_platform_entity(climate_platform_entity, Platform.CLIMATE) + command = ClimateSetFanModeCommand( + ieee=climate_platform_entity.device_ieee, + unique_id=climate_platform_entity.unique_id, + fan_mode=fan_mode, + ) + return await self._client.async_send_command(command) + + async def set_preset_mode( + self, + climate_platform_entity: BasePlatformEntity, + preset_mode: str, + ) -> WebSocketCommandResponse: + """Set a climate.""" + ensure_platform_entity(climate_platform_entity, Platform.CLIMATE) + command = ClimateSetPresetModeCommand( + ieee=climate_platform_entity.device_ieee, + unique_id=climate_platform_entity.unique_id, + preset_mode=preset_mode, + ) + return await self._client.async_send_command(command) + + +class AlarmControlPanelHelper: + """Helper to issue alarm control panel commands.""" + + def __init__(self, client: Client): + """Initialize the alarm control panel helper.""" + self._client: Client = client + + async def disarm( + self, alarm_control_panel_platform_entity: BasePlatformEntity, code: str + ) -> WebSocketCommandResponse: + """Disarm an alarm control panel.""" + ensure_platform_entity( + alarm_control_panel_platform_entity, Platform.ALARM_CONTROL_PANEL + ) + command = DisarmCommand( + ieee=alarm_control_panel_platform_entity.device_ieee, + unique_id=alarm_control_panel_platform_entity.unique_id, + code=code, + ) + return await self._client.async_send_command(command) + + async def arm_home( + self, alarm_control_panel_platform_entity: BasePlatformEntity, code: str + ) -> WebSocketCommandResponse: + """Arm an alarm control panel in home mode.""" + ensure_platform_entity( + alarm_control_panel_platform_entity, Platform.ALARM_CONTROL_PANEL + ) + command = ArmHomeCommand( + ieee=alarm_control_panel_platform_entity.device_ieee, + unique_id=alarm_control_panel_platform_entity.unique_id, + code=code, + ) + return await self._client.async_send_command(command) + + async def arm_away( + self, alarm_control_panel_platform_entity: BasePlatformEntity, code: str + ) -> WebSocketCommandResponse: + """Arm an alarm control panel in away mode.""" + ensure_platform_entity( + alarm_control_panel_platform_entity, Platform.ALARM_CONTROL_PANEL + ) + command = ArmAwayCommand( + ieee=alarm_control_panel_platform_entity.device_ieee, + unique_id=alarm_control_panel_platform_entity.unique_id, + code=code, + ) + return await self._client.async_send_command(command) + + async def arm_night( + self, alarm_control_panel_platform_entity: BasePlatformEntity, code: str + ) -> WebSocketCommandResponse: + """Arm an alarm control panel in night mode.""" + ensure_platform_entity( + alarm_control_panel_platform_entity, Platform.ALARM_CONTROL_PANEL + ) + command = ArmNightCommand( + ieee=alarm_control_panel_platform_entity.device_ieee, + unique_id=alarm_control_panel_platform_entity.unique_id, + code=code, + ) + return await self._client.async_send_command(command) + + async def trigger( + self, + alarm_control_panel_platform_entity: BasePlatformEntity, + ) -> WebSocketCommandResponse: + """Trigger an alarm control panel alarm.""" + ensure_platform_entity( + alarm_control_panel_platform_entity, Platform.ALARM_CONTROL_PANEL + ) + command = TriggerAlarmCommand( + ieee=alarm_control_panel_platform_entity.device_ieee, + unique_id=alarm_control_panel_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + +class PlatformEntityHelper: + """Helper to send global platform entity commands.""" + + def __init__(self, client: Client): + """Initialize the platform entity helper.""" + self._client: Client = client + + async def refresh_state( + self, platform_entity: BasePlatformEntity + ) -> WebSocketCommandResponse: + """Refresh the state of a platform entity.""" + command = PlatformEntityRefreshStateCommand( + ieee=platform_entity.device_ieee, + unique_id=platform_entity.unique_id, + platform=platform_entity.platform, + ) + return await self._client.async_send_command(command) + + class ClientHelper: """Helper to send client specific commands.""" @@ -62,17 +712,17 @@ def __init__(self, client: Client): """Initialize the client helper.""" self._client: Client = client - async def listen(self) -> CommandResponse: + async def listen(self) -> WebSocketCommandResponse: """Listen for incoming messages.""" command = ClientListenCommand() return await self._client.async_send_command(command) - async def listen_raw_zcl(self) -> CommandResponse: + async def listen_raw_zcl(self) -> WebSocketCommandResponse: """Listen for incoming raw ZCL messages.""" command = ClientListenRawZCLCommand() return await self._client.async_send_command(command) - async def disconnect(self) -> CommandResponse: + async def disconnect(self) -> WebSocketCommandResponse: """Disconnect this client from the server.""" command = ClientDisconnectCommand() return await self._client.async_send_command(command) @@ -85,7 +735,7 @@ def __init__(self, client: Client): """Initialize the group helper.""" self._client: Client = client - async def get_groups(self) -> dict[int, Group]: + async def get_groups(self) -> dict[int, GroupInfo]: """Get the groups.""" response = cast( GroupsResponse, @@ -98,7 +748,7 @@ async def create_group( name: str, unique_id: int | None = None, members: list[BasePlatformEntity] | None = None, - ) -> Group: + ) -> GroupInfo: """Create a new group.""" request_data: dict[str, Any] = { "group_name": name, @@ -117,10 +767,10 @@ async def create_group( ) return response.group - async def remove_groups(self, groups: list[Group]) -> dict[int, Group]: + async def remove_groups(self, groups: list[GroupInfo]) -> dict[int, GroupInfo]: """Remove groups.""" request: dict[str, Any] = { - "group_ids": [group.id for group in groups], + "group_ids": [group.group_id for group in groups], } command = RemoveGroupsCommand(**request) response = cast( @@ -130,11 +780,11 @@ async def remove_groups(self, groups: list[Group]) -> dict[int, Group]: return response.groups async def add_group_members( - self, group: Group, members: list[BasePlatformEntity] - ) -> Group: + self, group: GroupInfo, members: list[BasePlatformEntity] + ) -> GroupInfo: """Add members to a group.""" request_data: dict[str, Any] = { - "group_id": group.id, + "group_id": group.group_id, "members": [ {"ieee": member.device_ieee, "endpoint_id": member.endpoint_id} for member in members @@ -149,11 +799,11 @@ async def add_group_members( return response.group async def remove_group_members( - self, group: Group, members: list[BasePlatformEntity] - ) -> Group: + self, group: GroupInfo, members: list[BasePlatformEntity] + ) -> GroupInfo: """Remove members from a group.""" request_data: dict[str, Any] = { - "group_id": group.id, + "group_id": group.group_id, "members": [ {"ieee": member.device_ieee, "endpoint_id": member.endpoint_id} for member in members @@ -175,7 +825,7 @@ def __init__(self, client: Client): """Initialize the device helper.""" self._client: Client = client - async def get_devices(self) -> dict[EUI64, Device]: + async def get_devices(self) -> dict[EUI64, ExtendedDeviceInfo]: """Get the groups.""" response = cast( GetDevicesResponse, @@ -183,19 +833,19 @@ async def get_devices(self) -> dict[EUI64, Device]: ) return response.devices - async def reconfigure_device(self, device: Device) -> None: + async def reconfigure_device(self, device: ExtendedDeviceInfo) -> None: """Reconfigure a device.""" await self._client.async_send_command( ReconfigureDeviceCommand(ieee=device.ieee) ) - async def remove_device(self, device: Device) -> None: + async def remove_device(self, device: ExtendedDeviceInfo) -> None: """Remove a device.""" await self._client.async_send_command(RemoveDeviceCommand(ieee=device.ieee)) async def read_cluster_attributes( self, - device: Device, + device: ExtendedDeviceInfo, cluster_id: int, cluster_type: str, endpoint_id: int, @@ -220,7 +870,7 @@ async def read_cluster_attributes( async def write_cluster_attribute( self, - device: Device, + device: ExtendedDeviceInfo, cluster_id: int, cluster_type: str, endpoint_id: int, @@ -254,7 +904,7 @@ def __init__(self, client: Client): self._client: Client = client async def permit_joining( - self, duration: int = 255, device: Device | None = None + self, duration: int = 255, device: ExtendedDeviceInfo | None = None ) -> bool: """Permit joining for a specified duration.""" # TODO add permit with code support diff --git a/zha/websocket/client/model/commands.py b/zha/websocket/client/model/commands.py deleted file mode 100644 index 9d0eb878e..000000000 --- a/zha/websocket/client/model/commands.py +++ /dev/null @@ -1,200 +0,0 @@ -"""Models that represent commands and command responses.""" - -from typing import Annotated, Any, Literal, Optional, Union - -from pydantic import field_validator -from pydantic.fields import Field -from zigpy.types.named import EUI64 - -from zha.model import BaseModel -from zha.websocket.client.model.events import MinimalCluster, MinimalDevice -from zha.websocket.client.model.types import Device, Group - - -class CommandResponse(BaseModel): - """Command response model.""" - - message_type: Literal["result"] = "result" - message_id: int - success: bool - - -class ErrorResponse(CommandResponse): - """Error response model.""" - - success: bool = False - error_code: str - error_message: str - zigbee_error_code: Optional[str] - command: Literal[ - "error.start_network", - "error.stop_network", - "error.remove_device", - "error.stop_server", - "error.light_turn_on", - "error.light_turn_off", - "error.switch_turn_on", - "error.switch_turn_off", - "error.lock_lock", - "error.lock_unlock", - "error.lock_set_user_lock_code", - "error.lock_clear_user_lock_code", - "error.lock_disable_user_lock_code", - "error.lock_enable_user_lock_code", - "error.fan_turn_on", - "error.fan_turn_off", - "error.fan_set_percentage", - "error.fan_set_preset_mode", - "error.cover_open", - "error.cover_close", - "error.cover_set_position", - "error.cover_stop", - "error.climate_set_fan_mode", - "error.climate_set_hvac_mode", - "error.climate_set_preset_mode", - "error.climate_set_temperature", - "error.button_press", - "error.alarm_control_panel_disarm", - "error.alarm_control_panel_arm_home", - "error.alarm_control_panel_arm_away", - "error.alarm_control_panel_arm_night", - "error.alarm_control_panel_trigger", - "error.select_select_option", - "error.siren_turn_on", - "error.siren_turn_off", - "error.number_set_value", - "error.platform_entity_refresh_state", - "error.client_listen", - "error.client_listen_raw_zcl", - "error.client_disconnect", - "error.reconfigure_device", - "error.UpdateNetworkTopologyCommand", - ] - - -class DefaultResponse(CommandResponse): - """Default command response.""" - - command: Literal[ - "start_network", - "stop_network", - "remove_device", - "stop_server", - "light_turn_on", - "light_turn_off", - "switch_turn_on", - "switch_turn_off", - "lock_lock", - "lock_unlock", - "lock_set_user_lock_code", - "lock_clear_user_lock_code", - "lock_disable_user_lock_code", - "lock_enable_user_lock_code", - "fan_turn_on", - "fan_turn_off", - "fan_set_percentage", - "fan_set_preset_mode", - "cover_open", - "cover_close", - "cover_set_position", - "cover_stop", - "climate_set_fan_mode", - "climate_set_hvac_mode", - "climate_set_preset_mode", - "climate_set_temperature", - "button_press", - "alarm_control_panel_disarm", - "alarm_control_panel_arm_home", - "alarm_control_panel_arm_away", - "alarm_control_panel_arm_night", - "alarm_control_panel_trigger", - "select_select_option", - "siren_turn_on", - "siren_turn_off", - "number_set_value", - "platform_entity_refresh_state", - "client_listen", - "client_listen_raw_zcl", - "client_disconnect", - "reconfigure_device", - "UpdateNetworkTopologyCommand", - ] - - -class PermitJoiningResponse(CommandResponse): - """Get devices response.""" - - command: Literal["permit_joining"] = "permit_joining" - duration: int - - -class GetDevicesResponse(CommandResponse): - """Get devices response.""" - - command: Literal["get_devices"] = "get_devices" - devices: dict[EUI64, Device] - - @field_validator("devices", mode="before", check_fields=False) - @classmethod - def convert_devices_device_ieee( - cls, devices: dict[str, dict] - ) -> dict[EUI64, Device]: - """Convert device ieee to EUI64.""" - return {EUI64.convert(k): Device(**v) for k, v in devices.items()} - - -class ReadClusterAttributesResponse(CommandResponse): - """Read cluster attributes response.""" - - command: Literal["read_cluster_attributes"] = "read_cluster_attributes" - device: MinimalDevice - cluster: MinimalCluster - manufacturer_code: Optional[int] - succeeded: dict[str, Any] - failed: dict[str, Any] - - -class AttributeStatus(BaseModel): - """Attribute status.""" - - attribute: str - status: str - - -class WriteClusterAttributeResponse(CommandResponse): - """Write cluster attribute response.""" - - command: Literal["write_cluster_attribute"] = "write_cluster_attribute" - device: MinimalDevice - cluster: MinimalCluster - manufacturer_code: Optional[int] - response: AttributeStatus - - -class GroupsResponse(CommandResponse): - """Get groups response.""" - - command: Literal["get_groups", "remove_groups"] - groups: dict[int, Group] - - -class UpdateGroupResponse(CommandResponse): - """Update group response.""" - - command: Literal["create_group", "add_group_members", "remove_group_members"] - group: Group - - -CommandResponses = Annotated[ - Union[ - DefaultResponse, - ErrorResponse, - GetDevicesResponse, - GroupsResponse, - PermitJoiningResponse, - UpdateGroupResponse, - ReadClusterAttributesResponse, - WriteClusterAttributeResponse, - ], - Field(discriminator="command"), # noqa: F821 -] diff --git a/zha/websocket/client/model/events.py b/zha/websocket/client/model/events.py deleted file mode 100644 index 03496addc..000000000 --- a/zha/websocket/client/model/events.py +++ /dev/null @@ -1,263 +0,0 @@ -"""Event models for zhawss. - -Events are unprompted messages from the server -> client and they contain only the data that is necessary to -handle the event. -""" - -from typing import Annotated, Any, Literal, Optional, Union - -from pydantic.fields import Field -from zigpy.types.named import EUI64 - -from zha.model import BaseEvent, BaseModel -from zha.websocket.client.model.types import ( - BaseDevice, - BatteryState, - BooleanState, - CoverState, - Device, - DeviceSignature, - DeviceTrackerState, - ElectricalMeasurementState, - FanState, - GenericState, - Group, - LightState, - LockState, - ShadeState, - SmareEnergyMeteringState, - SwitchState, - ThermostatState, -) - - -class MinimalPlatformEntity(BaseModel): - """Platform entity model.""" - - unique_id: str - platform: str - - -class MinimalEndpoint(BaseModel): - """Minimal endpoint model.""" - - id: int - unique_id: str - - -class MinimalDevice(BaseModel): - """Minimal device model.""" - - ieee: EUI64 - - -class Attribute(BaseModel): - """Attribute model.""" - - id: int - name: str - value: Any = None - - -class MinimalCluster(BaseModel): - """Minimal cluster model.""" - - id: int - endpoint_attribute: str - name: str - endpoint_id: int - - -class MinimalClusterHandler(BaseModel): - """Minimal cluster handler model.""" - - unique_id: str - cluster: MinimalCluster - - -class MinimalGroup(BaseModel): - """Minimal group model.""" - - id: int - - -class PlatformEntityStateChangedEvent(BaseEvent): - """Platform entity event.""" - - event_type: Literal["platform_entity_event"] = "platform_entity_event" - event: Literal["platform_entity_state_changed"] = "platform_entity_state_changed" - platform_entity: MinimalPlatformEntity - endpoint: Optional[MinimalEndpoint] = None - device: Optional[MinimalDevice] = None - group: Optional[MinimalGroup] = None - state: Annotated[ - Optional[ - Union[ - DeviceTrackerState, - CoverState, - ShadeState, - FanState, - LockState, - BatteryState, - ElectricalMeasurementState, - LightState, - SwitchState, - SmareEnergyMeteringState, - GenericState, - BooleanState, - ThermostatState, - ] - ], - Field(discriminator="class_name"), # noqa: F821 - ] - - -class ZCLAttributeUpdatedEvent(BaseEvent): - """ZCL attribute updated event.""" - - event_type: Literal["raw_zcl_event"] = "raw_zcl_event" - event: Literal["attribute_updated"] = "attribute_updated" - device: MinimalDevice - cluster_handler: MinimalClusterHandler - attribute: Attribute - endpoint: MinimalEndpoint - - -class ControllerEvent(BaseEvent): - """Controller event.""" - - event_type: Literal["controller_event"] = "controller_event" - - -class DevicePairingEvent(ControllerEvent): - """Device pairing event.""" - - pairing_status: str - - -class DeviceJoinedEvent(DevicePairingEvent): - """Device joined event.""" - - event: Literal["device_joined"] = "device_joined" - ieee: EUI64 - nwk: str - - -class RawDeviceInitializedEvent(DevicePairingEvent): - """Raw device initialized event.""" - - event: Literal["raw_device_initialized"] = "raw_device_initialized" - ieee: EUI64 - nwk: str - manufacturer: str - model: str - signature: DeviceSignature - - -class DeviceFullyInitializedEvent(DevicePairingEvent): - """Device fully initialized event.""" - - event: Literal["device_fully_initialized"] = "device_fully_initialized" - device: Device - new_join: bool - - -class DeviceConfiguredEvent(DevicePairingEvent): - """Device configured event.""" - - event: Literal["device_configured"] = "device_configured" - device: BaseDevice - - -class DeviceLeftEvent(ControllerEvent): - """Device left event.""" - - event: Literal["device_left"] = "device_left" - ieee: EUI64 - nwk: str - - -class DeviceRemovedEvent(ControllerEvent): - """Device removed event.""" - - event: Literal["device_removed"] = "device_removed" - device: Device - - -class DeviceOfflineEvent(BaseEvent): - """Device offline event.""" - - event: Literal["device_offline"] = "device_offline" - event_type: Literal["device_event"] = "device_event" - device: MinimalDevice - - -class DeviceOnlineEvent(BaseEvent): - """Device online event.""" - - event: Literal["device_online"] = "device_online" - event_type: Literal["device_event"] = "device_event" - device: MinimalDevice - - -class ZHAEvent(BaseEvent): - """ZHA event.""" - - event: Literal["zha_event"] = "zha_event" - event_type: Literal["device_event"] = "device_event" - device: MinimalDevice - cluster_handler: MinimalClusterHandler - endpoint: MinimalEndpoint - command: str - args: Union[list, dict] - params: dict[str, Any] - - -class GroupRemovedEvent(ControllerEvent): - """Group removed event.""" - - event: Literal["group_removed"] = "group_removed" - group: Group - - -class GroupAddedEvent(ControllerEvent): - """Group added event.""" - - event: Literal["group_added"] = "group_added" - group: Group - - -class GroupMemberAddedEvent(ControllerEvent): - """Group member added event.""" - - event: Literal["group_member_added"] = "group_member_added" - group: Group - - -class GroupMemberRemovedEvent(ControllerEvent): - """Group member removed event.""" - - event: Literal["group_member_removed"] = "group_member_removed" - group: Group - - -Events = Annotated[ - Union[ - PlatformEntityStateChangedEvent, - ZCLAttributeUpdatedEvent, - DeviceJoinedEvent, - RawDeviceInitializedEvent, - DeviceFullyInitializedEvent, - DeviceConfiguredEvent, - DeviceLeftEvent, - DeviceRemovedEvent, - GroupRemovedEvent, - GroupAddedEvent, - GroupMemberAddedEvent, - GroupMemberRemovedEvent, - DeviceOfflineEvent, - DeviceOnlineEvent, - ZHAEvent, - ], - Field(discriminator="event"), # noqa: F821 -] diff --git a/zha/websocket/client/model/messages.py b/zha/websocket/client/model/messages.py index 9e5149bd4..e3801cf5e 100644 --- a/zha/websocket/client/model/messages.py +++ b/zha/websocket/client/model/messages.py @@ -6,8 +6,7 @@ from pydantic.fields import Field from zigpy.types.named import EUI64 -from zha.websocket.client.model.commands import CommandResponses -from zha.websocket.client.model.events import Events +from zha.websocket.server.api.model import CommandResponses, Events class Message(RootModel): diff --git a/zha/websocket/client/model/types.py b/zha/websocket/client/model/types.py deleted file mode 100644 index 83d3b8c15..000000000 --- a/zha/websocket/client/model/types.py +++ /dev/null @@ -1,760 +0,0 @@ -"""Models that represent types for the zhaws.client. - -Types are representations of the objects that exist in zhawss. -""" - -from typing import Annotated, Any, Literal, Optional, Union - -from pydantic import ValidationInfo, field_serializer, field_validator -from pydantic.fields import Field -from zigpy.types.named import EUI64, NWK -from zigpy.zdo.types import NodeDescriptor as ZigpyNodeDescriptor - -from zha.event import EventBase -from zha.model import BaseModel - - -class BaseEventedModel(EventBase, BaseModel): - """Base evented model.""" - - -class Cluster(BaseModel): - """Cluster model.""" - - id: int - endpoint_attribute: str - name: str - endpoint_id: int - type: str - commands: list[str] - - -class ClusterHandler(BaseModel): - """Cluster handler model.""" - - unique_id: str - cluster: Cluster - class_name: str - generic_id: str - endpoint_id: int - id: str - status: str - - -class Endpoint(BaseModel): - """Endpoint model.""" - - id: int - unique_id: str - - -class GenericState(BaseModel): - """Default state model.""" - - class_name: Literal[ - "ZHAAlarmControlPanel", - "Number", - "DefaultToneSelectEntity", - "DefaultSirenLevelSelectEntity", - "DefaultStrobeLevelSelectEntity", - "DefaultStrobeSelectEntity", - "AnalogInput", - "Humidity", - "SoilMoisture", - "LeafWetness", - "Illuminance", - "Pressure", - "Temperature", - "CarbonDioxideConcentration", - "CarbonMonoxideConcentration", - "VOCLevel", - "PPBVOCLevel", - "FormaldehydeConcentration", - "ThermostatHVACAction", - "SinopeHVACAction", - "RSSISensor", - "LQISensor", - "LastSeenSensor", - ] - state: Union[str, bool, int, float, None] = None - - -class DeviceCounterSensorState(BaseModel): - """Device counter sensor state model.""" - - class_name: Literal["DeviceCounterSensor"] = "DeviceCounterSensor" - state: int - - -class DeviceTrackerState(BaseModel): - """Device tracker state model.""" - - class_name: Literal["DeviceTracker"] = "DeviceTracker" - connected: bool - battery_level: Optional[float] = None - - -class BooleanState(BaseModel): - """Boolean value state model.""" - - class_name: Literal[ - "Accelerometer", - "Occupancy", - "Opening", - "BinaryInput", - "Motion", - "IASZone", - "Siren", - ] - state: bool - - -class CoverState(BaseModel): - """Cover state model.""" - - class_name: Literal["Cover"] = "Cover" - current_position: int - state: Optional[str] = None - is_opening: bool - is_closing: bool - is_closed: bool - - -class ShadeState(BaseModel): - """Cover state model.""" - - class_name: Literal["Shade", "KeenVent"] - current_position: Optional[int] = ( - None # TODO: how should we represent this when it is None? - ) - is_closed: bool - state: Optional[str] = None - - -class FanState(BaseModel): - """Fan state model.""" - - class_name: Literal["Fan", "FanGroup"] - preset_mode: Optional[str] = ( - None # TODO: how should we represent these when they are None? - ) - percentage: Optional[int] = ( - None # TODO: how should we represent these when they are None? - ) - is_on: bool - speed: Optional[str] = None - - -class LockState(BaseModel): - """Lock state model.""" - - class_name: Literal["Lock"] = "Lock" - is_locked: bool - - -class BatteryState(BaseModel): - """Battery state model.""" - - class_name: Literal["Battery"] = "Battery" - state: Optional[Union[str, float, int]] = None - battery_size: Optional[str] = None - battery_quantity: Optional[int] = None - battery_voltage: Optional[float] = None - - -class ElectricalMeasurementState(BaseModel): - """Electrical measurement state model.""" - - class_name: Literal[ - "ElectricalMeasurement", - "ElectricalMeasurementApparentPower", - "ElectricalMeasurementRMSCurrent", - "ElectricalMeasurementRMSVoltage", - ] - state: Optional[Union[str, float, int]] = None - measurement_type: Optional[str] = None - active_power_max: Optional[str] = None - rms_current_max: Optional[str] = None - rms_voltage_max: Optional[str] = None - - -class LightState(BaseModel): - """Light state model.""" - - class_name: Literal["Light", "HueLight", "ForceOnLight", "LightGroup"] - on: bool - brightness: Optional[int] = None - hs_color: Optional[tuple[float, float]] = None - color_temp: Optional[int] = None - effect: Optional[str] = None - off_brightness: Optional[int] = None - - -class ThermostatState(BaseModel): - """Thermostat state model.""" - - class_name: Literal[ - "Thermostat", - "SinopeTechnologiesThermostat", - "ZenWithinThermostat", - "MoesThermostat", - "BecaThermostat", - ] - current_temperature: Optional[float] = None - target_temperature: Optional[float] = None - target_temperature_low: Optional[float] = None - target_temperature_high: Optional[float] = None - hvac_action: Optional[str] = None - hvac_mode: Optional[str] = None - preset_mode: Optional[str] = None - fan_mode: Optional[str] = None - - -class SwitchState(BaseModel): - """Switch state model.""" - - class_name: Literal["Switch", "SwitchGroup"] - state: bool - - -class SmareEnergyMeteringState(BaseModel): - """Smare energy metering state model.""" - - class_name: Literal["SmartEnergyMetering", "SmartEnergySummation"] - state: Optional[Union[str, float, int]] = None - device_type: Optional[str] = None - status: Optional[str] = None - - -class BaseEntity(BaseEventedModel): - """Base platform entity model.""" - - unique_id: str - platform: str - class_name: str - fallback_name: str | None = None - translation_key: str | None = None - device_class: str | None = None - state_class: str | None = None - entity_category: str | None = None - entity_registry_enabled_default: bool - enabled: bool - - -class BasePlatformEntity(BaseEntity): - """Base platform entity model.""" - - device_ieee: EUI64 - endpoint_id: int - - -class LockEntity(BasePlatformEntity): - """Lock entity model.""" - - class_name: Literal["Lock"] - state: LockState - - -class DeviceTrackerEntity(BasePlatformEntity): - """Device tracker entity model.""" - - class_name: Literal["DeviceTracker"] - state: DeviceTrackerState - - -class CoverEntity(BasePlatformEntity): - """Cover entity model.""" - - class_name: Literal["Cover"] - state: CoverState - - -class ShadeEntity(BasePlatformEntity): - """Shade entity model.""" - - class_name: Literal["Shade", "KeenVent"] - state: ShadeState - - -class BinarySensorEntity(BasePlatformEntity): - """Binary sensor model.""" - - class_name: Literal[ - "Accelerometer", "Occupancy", "Opening", "BinaryInput", "Motion", "IASZone" - ] - attribute_name: str - state: BooleanState - - -class BaseSensorEntity(BasePlatformEntity): - """Sensor model.""" - - attribute: Optional[str] - decimals: int - divisor: int - multiplier: Union[int, float] - unit: Optional[int | str] - - -class SensorEntity(BaseSensorEntity): - """Sensor entity model.""" - - class_name: Literal[ - "AnalogInput", - "Humidity", - "SoilMoisture", - "LeafWetness", - "Illuminance", - "Pressure", - "Temperature", - "CarbonDioxideConcentration", - "CarbonMonoxideConcentration", - "VOCLevel", - "PPBVOCLevel", - "FormaldehydeConcentration", - "ThermostatHVACAction", - "SinopeHVACAction", - "RSSISensor", - "LQISensor", - "LastSeenSensor", - ] - state: GenericState - - -class DeviceCounterSensorEntity(BaseEntity): - """Device counter sensor model.""" - - class_name: Literal["DeviceCounterSensor"] - counter: str - counter_value: int - counter_groups: str - counter_group: str - state: DeviceCounterSensorState - - @field_validator("state", mode="before", check_fields=False) - @classmethod - def convert_state( - cls, state: dict | int | None, validation_info: ValidationInfo - ) -> DeviceCounterSensorState: - """Convert counter value to counter_value.""" - if state is not None: - if isinstance(state, int): - return DeviceCounterSensorState(state=state) - if isinstance(state, dict): - if "state" in state: - return DeviceCounterSensorState(state=state["state"]) - else: - return DeviceCounterSensorState( - state=validation_info.data["counter_value"] - ) - return DeviceCounterSensorState(state=validation_info.data["counter_value"]) - - -class BatteryEntity(BaseSensorEntity): - """Battery entity model.""" - - class_name: Literal["Battery"] - state: BatteryState - - -class ElectricalMeasurementEntity(BaseSensorEntity): - """Electrical measurement entity model.""" - - class_name: Literal[ - "ElectricalMeasurement", - "ElectricalMeasurementApparentPower", - "ElectricalMeasurementRMSCurrent", - "ElectricalMeasurementRMSVoltage", - ] - state: ElectricalMeasurementState - - -class SmartEnergyMeteringEntity(BaseSensorEntity): - """Smare energy metering entity model.""" - - class_name: Literal["SmartEnergyMetering", "SmartEnergySummation"] - state: SmareEnergyMeteringState - - -class AlarmControlPanelEntity(BasePlatformEntity): - """Alarm control panel model.""" - - class_name: Literal["ZHAAlarmControlPanel"] - supported_features: int - code_required_arm_actions: bool - max_invalid_tries: int - state: GenericState - - -class ButtonEntity(BasePlatformEntity): - """Button model.""" - - class_name: Literal["IdentifyButton"] - command: str - - -class FanEntity(BasePlatformEntity): - """Fan model.""" - - class_name: Literal["Fan"] - preset_modes: list[str] - supported_features: int - speed_count: int - speed_list: list[str] - percentage_step: float - state: FanState - - -class LightEntity(BasePlatformEntity): - """Light model.""" - - class_name: Literal["Light", "HueLight", "ForceOnLight"] - supported_features: int - min_mireds: int - max_mireds: int - effect_list: Optional[list[str]] - state: LightState - - -class NumberEntity(BasePlatformEntity): - """Number entity model.""" - - class_name: Literal["Number"] - engineering_units: Optional[ - int - ] # TODO: how should we represent this when it is None? - application_type: Optional[ - int - ] # TODO: how should we represent this when it is None? - step: Optional[float] # TODO: how should we represent this when it is None? - min_value: float - max_value: float - state: GenericState - - -class SelectEntity(BasePlatformEntity): - """Select entity model.""" - - class_name: Literal[ - "DefaultToneSelectEntity", - "DefaultSirenLevelSelectEntity", - "DefaultStrobeLevelSelectEntity", - "DefaultStrobeSelectEntity", - ] - enum: str - options: list[str] - state: GenericState - - -class ThermostatEntity(BasePlatformEntity): - """Thermostat entity model.""" - - class_name: Literal[ - "Thermostat", - "SinopeTechnologiesThermostat", - "ZenWithinThermostat", - "MoesThermostat", - "BecaThermostat", - ] - state: ThermostatState - hvac_modes: tuple[str, ...] - fan_modes: Optional[list[str]] - preset_modes: Optional[list[str]] - - -class SirenEntity(BasePlatformEntity): - """Siren entity model.""" - - class_name: Literal["Siren"] - available_tones: Optional[Union[list[Union[int, str]], dict[int, str]]] - supported_features: int - state: BooleanState - - -class SwitchEntity(BasePlatformEntity): - """Switch entity model.""" - - class_name: Literal["Switch"] - state: SwitchState - - -class DeviceSignatureEndpoint(BaseModel): - """Device signature endpoint model.""" - - profile_id: Optional[str] = None - device_type: Optional[str] = None - input_clusters: list[str] - output_clusters: list[str] - - @field_validator("profile_id", mode="before", check_fields=False) - @classmethod - def convert_profile_id(cls, profile_id: int | str) -> str: - """Convert profile_id.""" - if isinstance(profile_id, int): - return f"0x{profile_id:04x}" - return profile_id - - @field_validator("device_type", mode="before", check_fields=False) - @classmethod - def convert_device_type(cls, device_type: int | str) -> str: - """Convert device_type.""" - if isinstance(device_type, int): - return f"0x{device_type:04x}" - return device_type - - @field_validator("input_clusters", mode="before", check_fields=False) - @classmethod - def convert_input_clusters(cls, input_clusters: list[int | str]) -> list[str]: - """Convert input_clusters.""" - clusters = [] - for cluster_id in input_clusters: - if isinstance(cluster_id, int): - clusters.append(f"0x{cluster_id:04x}") - else: - clusters.append(cluster_id) - return clusters - - @field_validator("output_clusters", mode="before", check_fields=False) - @classmethod - def convert_output_clusters(cls, output_clusters: list[int | str]) -> list[str]: - """Convert output_clusters.""" - clusters = [] - for cluster_id in output_clusters: - if isinstance(cluster_id, int): - clusters.append(f"0x{cluster_id:04x}") - else: - clusters.append(cluster_id) - return clusters - - -class NodeDescriptor(BaseModel): - """Node descriptor model.""" - - logical_type: int - complex_descriptor_available: bool - user_descriptor_available: bool - reserved: int - aps_flags: int - frequency_band: int - mac_capability_flags: int - manufacturer_code: int - maximum_buffer_size: int - maximum_incoming_transfer_size: int - server_mask: int - maximum_outgoing_transfer_size: int - descriptor_capability_field: int - - -class DeviceSignature(BaseModel): - """Device signature model.""" - - node_descriptor: Optional[NodeDescriptor] = None - manufacturer: Optional[str] = None - model: Optional[str] = None - endpoints: dict[int, DeviceSignatureEndpoint] - - @field_validator("node_descriptor", mode="before", check_fields=False) - @classmethod - def convert_node_descriptor( - cls, node_descriptor: ZigpyNodeDescriptor - ) -> NodeDescriptor: - """Convert node descriptor.""" - if isinstance(node_descriptor, ZigpyNodeDescriptor): - return node_descriptor.as_dict() - return node_descriptor - - -class BaseDevice(BaseModel): - """Base device model.""" - - ieee: EUI64 - nwk: str - manufacturer: str - model: str - name: str - quirk_applied: bool - quirk_class: Union[str, None] = None - manufacturer_code: int - power_source: str - lqi: Union[int, None] = None - rssi: Union[int, None] = None - last_seen: str - available: bool - device_type: Literal["Coordinator", "Router", "EndDevice"] - signature: DeviceSignature - - @field_validator("nwk", mode="before", check_fields=False) - @classmethod - def convert_nwk(cls, nwk: NWK) -> str: - """Convert nwk to hex.""" - if isinstance(nwk, NWK): - return repr(nwk) - return nwk - - @field_serializer("ieee") - def serialize_ieee(self, ieee): - """Customize how ieee is serialized.""" - if isinstance(ieee, EUI64): - return str(ieee) - return ieee - - -class Device(BaseDevice): - """Device model.""" - - entities: dict[ - str, - Annotated[ - Union[ - SirenEntity, - SelectEntity, - NumberEntity, - LightEntity, - FanEntity, - ButtonEntity, - AlarmControlPanelEntity, - SensorEntity, - BinarySensorEntity, - DeviceTrackerEntity, - ShadeEntity, - CoverEntity, - LockEntity, - SwitchEntity, - BatteryEntity, - ElectricalMeasurementEntity, - SmartEnergyMeteringEntity, - ThermostatEntity, - DeviceCounterSensorEntity, - ], - Field(discriminator="class_name"), # noqa: F821 - ], - ] - neighbors: list[Any] - device_automation_triggers: dict[str, dict[str, Any]] - - @field_validator("entities", mode="before", check_fields=False) - @classmethod - def convert_entities(cls, entities: dict) -> dict: - """Convert entities keys from tuple to string.""" - if all(isinstance(k, tuple) for k in entities): - return {f"{k[0]}.{k[1]}": v for k, v in entities.items()} - assert all(isinstance(k, str) for k in entities) - return entities - - @field_validator("device_automation_triggers", mode="before", check_fields=False) - @classmethod - def convert_device_automation_triggers(cls, triggers: dict) -> dict: - """Convert device automation triggers keys from tuple to string.""" - if all(isinstance(k, tuple) for k in triggers): - return {f"{k[0]}~{k[1]}": v for k, v in triggers.items()} - return triggers - - -class GroupEntity(BaseEntity): - """Group entity model.""" - - group_id: int - state: Any - - -class LightGroupEntity(GroupEntity): - """Group entity model.""" - - class_name: Literal["LightGroup"] - state: LightState - - -class FanGroupEntity(GroupEntity): - """Group entity model.""" - - class_name: Literal["FanGroup"] - state: FanState - - -class SwitchGroupEntity(GroupEntity): - """Group entity model.""" - - class_name: Literal["SwitchGroup"] - state: SwitchState - - -class GroupMember(BaseModel): - """Group member model.""" - - ieee: EUI64 - endpoint_id: int - device: Device = Field(alias="device_info") - entities: dict[ - str, - Annotated[ - Union[ - SirenEntity, - SelectEntity, - NumberEntity, - LightEntity, - FanEntity, - ButtonEntity, - AlarmControlPanelEntity, - SensorEntity, - BinarySensorEntity, - DeviceTrackerEntity, - ShadeEntity, - CoverEntity, - LockEntity, - SwitchEntity, - BatteryEntity, - ElectricalMeasurementEntity, - SmartEnergyMeteringEntity, - ThermostatEntity, - ], - Field(discriminator="class_name"), # noqa: F821 - ], - ] - - -class Group(BaseModel): - """Group model.""" - - name: str - id: int - members: dict[EUI64, GroupMember] - entities: dict[ - str, - Annotated[ - Union[LightGroupEntity, FanGroupEntity, SwitchGroupEntity], - Field(discriminator="class_name"), # noqa: F821 - ], - ] - - @field_validator("members", mode="before", check_fields=False) - @classmethod - def convert_members(cls, members: dict | list[dict]) -> dict: - """Convert members.""" - - converted_members = {} - if isinstance(members, dict): - return {EUI64.convert(k): v for k, v in members.items()} - for member in members: - if "device" in member: - ieee = member["device"]["ieee"] - else: - ieee = member["device_info"]["ieee"] - if isinstance(ieee, str): - ieee = EUI64.convert(ieee) - elif isinstance(ieee, list) and not isinstance(ieee, EUI64): - ieee = EUI64.deserialize(ieee)[0] - converted_members[ieee] = member - return converted_members - - @field_serializer("members") - def serialize_members(self, members): - """Customize how members are serialized.""" - data = {str(k): v.model_dump(by_alias=True) for k, v in members.items()} - return data - - -class GroupMemberReference(BaseModel): - """Group member reference model.""" - - ieee: EUI64 - endpoint_id: int diff --git a/zha/websocket/client/proxy.py b/zha/websocket/client/proxy.py index 92db0e20e..fdf00aa42 100644 --- a/zha/websocket/client/proxy.py +++ b/zha/websocket/client/proxy.py @@ -2,22 +2,23 @@ from __future__ import annotations +import abc from typing import TYPE_CHECKING, Any -from zha.event import EventBase -from zha.websocket.client.model.events import PlatformEntityStateChangedEvent -from zha.websocket.client.model.types import ( - ButtonEntity, - Device as DeviceModel, - Group as GroupModel, +from zha.application.platforms.model import ( + BasePlatformEntity, + EntityStateChangedEvent, + GroupEntity, ) +from zha.event import EventBase +from zha.zigbee.model import ExtendedDeviceInfo, GroupInfo if TYPE_CHECKING: from zha.websocket.client.client import Client from zha.websocket.client.controller import Controller -class BaseProxyObject(EventBase): +class BaseProxyObject(EventBase, abc.ABC): """BaseProxyObject for the zhaws.client.""" def __init__(self, controller: Controller, client: Client): @@ -25,7 +26,7 @@ def __init__(self, controller: Controller, client: Client): super().__init__() self._controller: Controller = controller self._client: Client = client - self._proxied_object: GroupModel | DeviceModel + self._proxied_object: GroupInfo | ExtendedDeviceInfo @property def controller(self) -> Controller: @@ -37,44 +38,47 @@ def client(self) -> Client: """Return the client.""" return self._client - def emit_platform_entity_event( - self, event: PlatformEntityStateChangedEvent - ) -> None: + @abc.abstractmethod + def _get_entity( + self, event: EntityStateChangedEvent + ) -> BasePlatformEntity | GroupEntity: + """Get the entity for the event.""" + + def emit_platform_entity_event(self, event: EntityStateChangedEvent) -> None: """Proxy the firing of an entity event.""" - entity = self._proxied_object.entities.get( - f"{event.platform_entity.platform}.{event.platform_entity.unique_id}" - if event.group is None - else event.platform_entity.unique_id - ) + entity = self._get_entity(event) if entity is None: - if isinstance(self._proxied_object, DeviceModel): + if isinstance(self._proxied_object, ExtendedDeviceInfo): # type: ignore raise ValueError( f"Entity not found: {event.platform_entity.unique_id}", ) return # group entities are updated to get state when created so we may not have the entity yet - if not isinstance(entity, ButtonEntity): - entity.state = event.state - self.emit(f"{event.platform_entity.unique_id}_{event.event}", event) + entity.state = event.state + self.emit(f"{event.unique_id}_{event.event}", event) class GroupProxy(BaseProxyObject): """Group proxy for the zhaws.client.""" - def __init__(self, group_model: GroupModel, controller: Controller, client: Client): + def __init__(self, group_model: GroupInfo, controller: Controller, client: Client): """Initialize the GroupProxy class.""" super().__init__(controller, client) - self._proxied_object: GroupModel = group_model + self._proxied_object: GroupInfo = group_model @property - def group_model(self) -> GroupModel: + def group_model(self) -> GroupInfo: """Return the group model.""" return self._proxied_object @group_model.setter - def group_model(self, group_model: GroupModel) -> None: + def group_model(self, group_model: GroupInfo) -> None: """Set the group model.""" self._proxied_object = group_model + def _get_entity(self, event: EntityStateChangedEvent) -> GroupEntity: + """Get the entity for the event.""" + return self._proxied_object.entities.get(event.unique_id) # type: ignore + def __repr__(self) -> str: """Return the string representation of the group proxy.""" return self._proxied_object.__repr__() @@ -84,19 +88,19 @@ class DeviceProxy(BaseProxyObject): """Device proxy for the zhaws.client.""" def __init__( - self, device_model: DeviceModel, controller: Controller, client: Client + self, device_model: ExtendedDeviceInfo, controller: Controller, client: Client ): """Initialize the DeviceProxy class.""" super().__init__(controller, client) - self._proxied_object: DeviceModel = device_model + self._proxied_object: ExtendedDeviceInfo = device_model @property - def device_model(self) -> DeviceModel: + def device_model(self) -> ExtendedDeviceInfo: """Return the device model.""" return self._proxied_object @device_model.setter - def device_model(self, device_model: DeviceModel) -> None: + def device_model(self, device_model: ExtendedDeviceInfo) -> None: """Set the device model.""" self._proxied_object = device_model @@ -109,6 +113,10 @@ def device_automation_triggers(self) -> dict[tuple[str, str], dict[str, Any]]: for key, value in model_triggers.items() } + def _get_entity(self, event: EntityStateChangedEvent) -> BasePlatformEntity: + """Get the entity for the event.""" + return self._proxied_object.entities.get((event.platform, event.unique_id)) # type: ignore + def __repr__(self) -> str: """Return the string representation of the device proxy.""" return self._proxied_object.__repr__() diff --git a/zha/websocket/const.py b/zha/websocket/const.py index a5c6eca03..a0670a19a 100644 --- a/zha/websocket/const.py +++ b/zha/websocket/const.py @@ -92,7 +92,7 @@ class MessageTypes(StrEnum): class EventTypes(StrEnum): """WS event types.""" - CONTROLLER_EVENT = "controller_event" + CONTROLLER_EVENT = "zha_gateway_message" PLATFORM_ENTITY_EVENT = "platform_entity_event" RAW_ZCL_EVENT = "raw_zcl_event" DEVICE_EVENT = "device_event" diff --git a/zha/websocket/server/api/model.py b/zha/websocket/server/api/model.py index 370b2e249..04e6e885c 100644 --- a/zha/websocket/server/api/model.py +++ b/zha/websocket/server/api/model.py @@ -1,9 +1,28 @@ """Models for the websocket API.""" -from typing import Literal +from typing import Annotated, Any, Literal, Optional, Union +from pydantic import Field, field_serializer, field_validator +from zigpy.types.named import EUI64 + +from zha.application.model import ( + DeviceFullyInitializedEvent, + DeviceJoinedEvent, + DeviceLeftEvent, + DeviceOfflineEvent, + DeviceOnlineEvent, + DeviceRemovedEvent, + GroupAddedEvent, + GroupMemberAddedEvent, + GroupMemberRemovedEvent, + GroupRemovedEvent, + RawDeviceInitializedEvent, +) +from zha.application.platforms.model import EntityStateChangedEvent from zha.model import BaseModel from zha.websocket.const import APICommands +from zha.zigbee.cluster_handlers.model import ClusterInfo +from zha.zigbee.model import ExtendedDeviceInfo, GroupInfo, ZHAEvent class WebSocketCommand(BaseModel): @@ -63,3 +82,218 @@ class WebSocketCommand(BaseModel): APICommands.SWITCH_TURN_ON, APICommands.SWITCH_TURN_OFF, ] + + +class WebSocketCommandResponse(WebSocketCommand): + """Websocket command response.""" + + message_type: Literal["result"] = "result" + success: bool + + +class ErrorResponse(WebSocketCommandResponse): + """Error response model.""" + + success: bool = False + error_code: str + error_message: str + zigbee_error_code: Optional[str] + command: Literal[ + "error.start_network", + "error.stop_network", + "error.remove_device", + "error.stop_server", + "error.light_turn_on", + "error.light_turn_off", + "error.switch_turn_on", + "error.switch_turn_off", + "error.lock_lock", + "error.lock_unlock", + "error.lock_set_user_lock_code", + "error.lock_clear_user_lock_code", + "error.lock_disable_user_lock_code", + "error.lock_enable_user_lock_code", + "error.fan_turn_on", + "error.fan_turn_off", + "error.fan_set_percentage", + "error.fan_set_preset_mode", + "error.cover_open", + "error.cover_close", + "error.cover_set_position", + "error.cover_stop", + "error.climate_set_fan_mode", + "error.climate_set_hvac_mode", + "error.climate_set_preset_mode", + "error.climate_set_temperature", + "error.button_press", + "error.alarm_control_panel_disarm", + "error.alarm_control_panel_arm_home", + "error.alarm_control_panel_arm_away", + "error.alarm_control_panel_arm_night", + "error.alarm_control_panel_trigger", + "error.select_select_option", + "error.siren_turn_on", + "error.siren_turn_off", + "error.number_set_value", + "error.platform_entity_refresh_state", + "error.client_listen", + "error.client_listen_raw_zcl", + "error.client_disconnect", + "error.reconfigure_device", + "error.UpdateNetworkTopologyCommand", + ] + + +class DefaultResponse(WebSocketCommandResponse): + """Default command response.""" + + command: Literal[ + "start_network", + "stop_network", + "remove_device", + "stop_server", + "light_turn_on", + "light_turn_off", + "switch_turn_on", + "switch_turn_off", + "lock_lock", + "lock_unlock", + "lock_set_user_lock_code", + "lock_clear_user_lock_code", + "lock_disable_user_lock_code", + "lock_enable_user_lock_code", + "fan_turn_on", + "fan_turn_off", + "fan_set_percentage", + "fan_set_preset_mode", + "cover_open", + "cover_close", + "cover_set_position", + "cover_stop", + "climate_set_fan_mode", + "climate_set_hvac_mode", + "climate_set_preset_mode", + "climate_set_temperature", + "button_press", + "alarm_control_panel_disarm", + "alarm_control_panel_arm_home", + "alarm_control_panel_arm_away", + "alarm_control_panel_arm_night", + "alarm_control_panel_trigger", + "select_select_option", + "siren_turn_on", + "siren_turn_off", + "number_set_value", + "platform_entity_refresh_state", + "client_listen", + "client_listen_raw_zcl", + "client_disconnect", + "reconfigure_device", + "UpdateNetworkTopologyCommand", + ] + + +class PermitJoiningResponse(WebSocketCommandResponse): + """Get devices response.""" + + command: Literal["permit_joining"] = "permit_joining" + duration: int + + +class GetDevicesResponse(WebSocketCommandResponse): + """Get devices response.""" + + command: Literal["get_devices"] = "get_devices" + devices: dict[EUI64, ExtendedDeviceInfo] + + @field_serializer("devices", check_fields=False) + def serialize_devices(self, devices: dict[EUI64, ExtendedDeviceInfo]) -> dict: + """Serialize devices.""" + return {str(ieee): device for ieee, device in devices.items()} + + @field_validator("devices", mode="before", check_fields=False) + @classmethod + def convert_devices( + cls, devices: dict[str, ExtendedDeviceInfo] + ) -> dict[EUI64, ExtendedDeviceInfo]: + """Convert devices.""" + if all(isinstance(ieee, str) for ieee in devices): + return {EUI64.convert(ieee): device for ieee, device in devices.items()} + return devices + + +class ReadClusterAttributesResponse(WebSocketCommandResponse): + """Read cluster attributes response.""" + + command: Literal["read_cluster_attributes"] = "read_cluster_attributes" + device: ExtendedDeviceInfo + cluster: ClusterInfo + manufacturer_code: Optional[int] + succeeded: dict[str, Any] + failed: dict[str, Any] + + +class AttributeStatus(BaseModel): + """Attribute status.""" + + attribute: str + status: str + + +class WriteClusterAttributeResponse(WebSocketCommandResponse): + """Write cluster attribute response.""" + + command: Literal["write_cluster_attribute"] = "write_cluster_attribute" + device: ExtendedDeviceInfo + cluster: ClusterInfo + manufacturer_code: Optional[int] + response: AttributeStatus + + +class GroupsResponse(WebSocketCommandResponse): + """Get groups response.""" + + command: Literal["get_groups", "remove_groups"] + groups: dict[int, GroupInfo] + + +class UpdateGroupResponse(WebSocketCommandResponse): + """Update group response.""" + + command: Literal["create_group", "add_group_members", "remove_group_members"] + group: GroupInfo + + +CommandResponses = Annotated[ + Union[ + DefaultResponse, + ErrorResponse, + GetDevicesResponse, + GroupsResponse, + PermitJoiningResponse, + UpdateGroupResponse, + ReadClusterAttributesResponse, + WriteClusterAttributeResponse, + ], + Field(discriminator="command"), +] + + +Events = Annotated[ + Union[ + EntityStateChangedEvent, + DeviceJoinedEvent, + RawDeviceInitializedEvent, + DeviceFullyInitializedEvent, + DeviceLeftEvent, + DeviceRemovedEvent, + GroupRemovedEvent, + GroupAddedEvent, + GroupMemberAddedEvent, + GroupMemberRemovedEvent, + DeviceOfflineEvent, + DeviceOnlineEvent, + ZHAEvent, + ], + Field(discriminator="event"), +] diff --git a/zha/websocket/server/api/platforms/__init__.py b/zha/websocket/server/api/platforms/__init__.py new file mode 100644 index 000000000..1648efcf0 --- /dev/null +++ b/zha/websocket/server/api/platforms/__init__.py @@ -0,0 +1,19 @@ +"""Websocket api platform module for zha.""" + +from __future__ import annotations + +from typing import Union + +from zigpy.types.named import EUI64 + +from zha.application.platforms import Platform +from zha.websocket.server.api.model import WebSocketCommand + + +class PlatformEntityCommand(WebSocketCommand): + """Base class for platform entity commands.""" + + ieee: Union[EUI64, None] = None + group_id: Union[int, None] = None + unique_id: str + platform: Platform diff --git a/zha/websocket/server/api/platforms/alarm_control_panel/__init__.py b/zha/websocket/server/api/platforms/alarm_control_panel/__init__.py new file mode 100644 index 000000000..272c7366e --- /dev/null +++ b/zha/websocket/server/api/platforms/alarm_control_panel/__init__.py @@ -0,0 +1,3 @@ +"""Alarm control panel websocket api for zha.""" + +from __future__ import annotations diff --git a/zha/websocket/server/api/platforms/alarm_control_panel/api.py b/zha/websocket/server/api/platforms/alarm_control_panel/api.py new file mode 100644 index 000000000..2c06ed5a8 --- /dev/null +++ b/zha/websocket/server/api/platforms/alarm_control_panel/api.py @@ -0,0 +1,117 @@ +"""WS api for the alarm control panel platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Union + +from zha.application.discovery import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand +from zha.websocket.server.api.platforms.api import execute_platform_entity_command + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + + +class DisarmCommand(PlatformEntityCommand): + """Disarm command.""" + + command: Literal[APICommands.ALARM_CONTROL_PANEL_DISARM] = ( + APICommands.ALARM_CONTROL_PANEL_DISARM + ) + platform: str = Platform.ALARM_CONTROL_PANEL + code: Union[str, None] + + +@decorators.websocket_command(DisarmCommand) +@decorators.async_response +async def disarm(server: Server, client: Client, command: DisarmCommand) -> None: + """Disarm the alarm control panel.""" + await execute_platform_entity_command(server, client, command, "async_alarm_disarm") + + +class ArmHomeCommand(PlatformEntityCommand): + """Arm home command.""" + + command: Literal[APICommands.ALARM_CONTROL_PANEL_ARM_HOME] = ( + APICommands.ALARM_CONTROL_PANEL_ARM_HOME + ) + platform: str = Platform.ALARM_CONTROL_PANEL + code: Union[str, None] + + +@decorators.websocket_command(ArmHomeCommand) +@decorators.async_response +async def arm_home(server: Server, client: Client, command: ArmHomeCommand) -> None: + """Arm the alarm control panel in home mode.""" + await execute_platform_entity_command( + server, client, command, "async_alarm_arm_home" + ) + + +class ArmAwayCommand(PlatformEntityCommand): + """Arm away command.""" + + command: Literal[APICommands.ALARM_CONTROL_PANEL_ARM_AWAY] = ( + APICommands.ALARM_CONTROL_PANEL_ARM_AWAY + ) + platform: str = Platform.ALARM_CONTROL_PANEL + code: Union[str, None] + + +@decorators.websocket_command(ArmAwayCommand) +@decorators.async_response +async def arm_away(server: Server, client: Client, command: ArmAwayCommand) -> None: + """Arm the alarm control panel in away mode.""" + await execute_platform_entity_command( + server, client, command, "async_alarm_arm_away" + ) + + +class ArmNightCommand(PlatformEntityCommand): + """Arm night command.""" + + command: Literal[APICommands.ALARM_CONTROL_PANEL_ARM_NIGHT] = ( + APICommands.ALARM_CONTROL_PANEL_ARM_NIGHT + ) + platform: str = Platform.ALARM_CONTROL_PANEL + code: Union[str, None] + + +@decorators.websocket_command(ArmNightCommand) +@decorators.async_response +async def arm_night(server: Server, client: Client, command: ArmNightCommand) -> None: + """Arm the alarm control panel in night mode.""" + await execute_platform_entity_command( + server, client, command, "async_alarm_arm_night" + ) + + +class TriggerAlarmCommand(PlatformEntityCommand): + """Trigger alarm command.""" + + command: Literal[APICommands.ALARM_CONTROL_PANEL_TRIGGER] = ( + APICommands.ALARM_CONTROL_PANEL_TRIGGER + ) + platform: str = Platform.ALARM_CONTROL_PANEL + code: Union[str, None] = None + + +@decorators.websocket_command(TriggerAlarmCommand) +@decorators.async_response +async def trigger(server: Server, client: Client, command: TriggerAlarmCommand) -> None: + """Trigger the alarm control panel.""" + await execute_platform_entity_command( + server, client, command, "async_alarm_trigger" + ) + + +def load_api(server: Server) -> None: + """Load the api command handlers.""" + register_api_command(server, disarm) + register_api_command(server, arm_home) + register_api_command(server, arm_away) + register_api_command(server, arm_night) + register_api_command(server, trigger) diff --git a/zha/websocket/server/api/platforms/api.py b/zha/websocket/server/api/platforms/api.py new file mode 100644 index 000000000..537b2e9bc --- /dev/null +++ b/zha/websocket/server/api/platforms/api.py @@ -0,0 +1,124 @@ +"""WS API for common platform entity functionality.""" + +from __future__ import annotations + +import inspect +import logging +from typing import TYPE_CHECKING, Any, Literal + +from zha.websocket.const import ATTR_UNIQUE_ID, IEEE, APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand + +if TYPE_CHECKING: + from zha.websocket.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + +_LOGGER = logging.getLogger(__name__) + + +async def execute_platform_entity_command( + server: Server, + client: Client, + command: PlatformEntityCommand, + method_name: str, +) -> None: + """Get the platform entity and execute a method based on the command.""" + try: + if command.ieee: + _LOGGER.debug("command: %s", command) + device = server.get_device(command.ieee) + platform_entity: Any = device.get_platform_entity( + command.platform, command.unique_id + ) + else: + assert command.group_id + group = server.get_group(command.group_id) + platform_entity = group.group_entities[command.unique_id] + except ValueError as err: + _LOGGER.exception( + "Error executing command: %s method_name: %s", + command, + method_name, + exc_info=err, + ) + client.send_result_error(command, "PLATFORM_ENTITY_COMMAND_ERROR", str(err)) + return None + + try: + action = getattr(platform_entity, method_name) + arg_spec = inspect.getfullargspec(action) + if arg_spec.varkw: # the only argument is self + await action(**command.model_dump(exclude_none=True)) + else: + await action() + + except Exception as err: + _LOGGER.exception("Error executing command: %s", method_name, exc_info=err) + client.send_result_error(command, "PLATFORM_ENTITY_ACTION_ERROR", str(err)) + return + + result: dict[str, Any] = {} + if command.ieee: + result[IEEE] = str(command.ieee) + else: + result["group_id"] = command.group_id + result[ATTR_UNIQUE_ID] = command.unique_id + client.send_result_success(command, result) + + +class PlatformEntityRefreshStateCommand(PlatformEntityCommand): + """Platform entity refresh state command.""" + + command: Literal[APICommands.PLATFORM_ENTITY_REFRESH_STATE] = ( + APICommands.PLATFORM_ENTITY_REFRESH_STATE + ) + + +@decorators.websocket_command(PlatformEntityRefreshStateCommand) +@decorators.async_response +async def refresh_state( + server: Server, client: Client, command: PlatformEntityCommand +) -> None: + """Refresh the state of the platform entity.""" + await execute_platform_entity_command(server, client, command, "async_update") + + +def load_platform_entity_apis(server: Server) -> None: + """Load the ws apis for all platform entities types.""" + from zha.websocket.server.api.platforms.alarm_control_panel.api import ( + load_api as load_alarm_control_panel_api, + ) + from zha.websocket.server.api.platforms.button.api import ( + load_api as load_button_api, + ) + from zha.websocket.server.api.platforms.climate.api import ( + load_api as load_climate_api, + ) + from zha.websocket.server.api.platforms.cover.api import load_api as load_cover_api + from zha.websocket.server.api.platforms.fan.api import load_api as load_fan_api + from zha.websocket.server.api.platforms.light.api import load_api as load_light_api + from zha.websocket.server.api.platforms.lock.api import load_api as load_lock_api + from zha.websocket.server.api.platforms.number.api import ( + load_api as load_number_api, + ) + from zha.websocket.server.api.platforms.select.api import ( + load_api as load_select_api, + ) + from zha.websocket.server.api.platforms.siren.api import load_api as load_siren_api + from zha.websocket.server.api.platforms.switch.api import ( + load_api as load_switch_api, + ) + + register_api_command(server, refresh_state) + load_alarm_control_panel_api(server) + load_button_api(server) + load_climate_api(server) + load_cover_api(server) + load_fan_api(server) + load_light_api(server) + load_lock_api(server) + load_number_api(server) + load_select_api(server) + load_siren_api(server) + load_switch_api(server) diff --git a/zha/websocket/server/api/platforms/button/__init__.py b/zha/websocket/server/api/platforms/button/__init__.py new file mode 100644 index 000000000..1564a7f40 --- /dev/null +++ b/zha/websocket/server/api/platforms/button/__init__.py @@ -0,0 +1,3 @@ +"""Button platform websocket api for zha.""" + +from __future__ import annotations diff --git a/zha/websocket/server/api/platforms/button/api.py b/zha/websocket/server/api/platforms/button/api.py new file mode 100644 index 000000000..3fb6d7f10 --- /dev/null +++ b/zha/websocket/server/api/platforms/button/api.py @@ -0,0 +1,34 @@ +"""WS API for the button platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand +from zha.websocket.server.api.platforms.api import execute_platform_entity_command + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + + +class ButtonPressCommand(PlatformEntityCommand): + """Button press command.""" + + command: Literal[APICommands.BUTTON_PRESS] = APICommands.BUTTON_PRESS + platform: str = Platform.BUTTON + + +@decorators.websocket_command(ButtonPressCommand) +@decorators.async_response +async def press(server: Server, client: Client, command: PlatformEntityCommand) -> None: + """Turn on the button.""" + await execute_platform_entity_command(server, client, command, "async_press") + + +def load_api(server: Server) -> None: + """Load the api command handlers.""" + register_api_command(server, press) diff --git a/zha/websocket/server/api/platforms/climate/__init__.py b/zha/websocket/server/api/platforms/climate/__init__.py new file mode 100644 index 000000000..e1a798eae --- /dev/null +++ b/zha/websocket/server/api/platforms/climate/__init__.py @@ -0,0 +1,3 @@ +"""Climate platform websocket api for zha.""" + +from __future__ import annotations diff --git a/zha/websocket/server/api/platforms/climate/api.py b/zha/websocket/server/api/platforms/climate/api.py new file mode 100644 index 000000000..7b3bb9e82 --- /dev/null +++ b/zha/websocket/server/api/platforms/climate/api.py @@ -0,0 +1,128 @@ +"""WS api for the climate platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Optional, Union + +from zha.application.discovery import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand +from zha.websocket.server.api.platforms.api import execute_platform_entity_command + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + + +class ClimateSetFanModeCommand(PlatformEntityCommand): + """Set fan mode command.""" + + command: Literal[APICommands.CLIMATE_SET_FAN_MODE] = ( + APICommands.CLIMATE_SET_FAN_MODE + ) + platform: str = Platform.CLIMATE + fan_mode: str + + +@decorators.websocket_command(ClimateSetFanModeCommand) +@decorators.async_response +async def set_fan_mode( + server: Server, client: Client, command: ClimateSetFanModeCommand +) -> None: + """Set the fan mode for the climate platform entity.""" + await execute_platform_entity_command(server, client, command, "async_set_fan_mode") + + +class ClimateSetHVACModeCommand(PlatformEntityCommand): + """Set HVAC mode command.""" + + command: Literal[APICommands.CLIMATE_SET_HVAC_MODE] = ( + APICommands.CLIMATE_SET_HVAC_MODE + ) + platform: str = Platform.CLIMATE + hvac_mode: Literal[ + "off", # All activity disabled / Device is off/standby + "heat", # Heating + "cool", # Cooling + "heat_cool", # The device supports heating/cooling to a range + "auto", # The temperature is set based on a schedule, learned behavior, AI or some other related mechanism. User is not able to adjust the temperature + "dry", # Device is in Dry/Humidity mode + "fan_only", # Only the fan is on, not fan and another mode like cool + ] + + +@decorators.websocket_command(ClimateSetHVACModeCommand) +@decorators.async_response +async def set_hvac_mode( + server: Server, client: Client, command: ClimateSetHVACModeCommand +) -> None: + """Set the hvac mode for the climate platform entity.""" + await execute_platform_entity_command( + server, client, command, "async_set_hvac_mode" + ) + + +class ClimateSetPresetModeCommand(PlatformEntityCommand): + """Set preset mode command.""" + + command: Literal[APICommands.CLIMATE_SET_PRESET_MODE] = ( + APICommands.CLIMATE_SET_PRESET_MODE + ) + platform: str = Platform.CLIMATE + preset_mode: str + + +@decorators.websocket_command(ClimateSetPresetModeCommand) +@decorators.async_response +async def set_preset_mode( + server: Server, client: Client, command: ClimateSetPresetModeCommand +) -> None: + """Set the preset mode for the climate platform entity.""" + await execute_platform_entity_command( + server, client, command, "async_set_preset_mode" + ) + + +class ClimateSetTemperatureCommand(PlatformEntityCommand): + """Set temperature command.""" + + command: Literal[APICommands.CLIMATE_SET_TEMPERATURE] = ( + APICommands.CLIMATE_SET_TEMPERATURE + ) + platform: str = Platform.CLIMATE + temperature: Union[float, None] + target_temp_high: Union[float, None] + target_temp_low: Union[float, None] + hvac_mode: Optional[ + ( + Literal[ + "off", # All activity disabled / Device is off/standby + "heat", # Heating + "cool", # Cooling + "heat_cool", # The device supports heating/cooling to a range + "auto", # The temperature is set based on a schedule, learned behavior, AI or some other related mechanism. User is not able to adjust the temperature + "dry", # Device is in Dry/Humidity mode + "fan_only", # Only the fan is on, not fan and another mode like cool + ] + ) + ] + + +@decorators.websocket_command(ClimateSetTemperatureCommand) +@decorators.async_response +async def set_temperature( + server: Server, client: Client, command: ClimateSetTemperatureCommand +) -> None: + """Set the temperature and hvac mode for the climate platform entity.""" + await execute_platform_entity_command( + server, client, command, "async_set_temperature" + ) + + +def load_api(server: Server) -> None: + """Load the api command handlers.""" + register_api_command(server, set_fan_mode) + register_api_command(server, set_hvac_mode) + register_api_command(server, set_preset_mode) + register_api_command(server, set_temperature) diff --git a/zha/websocket/server/api/platforms/cover/__init__.py b/zha/websocket/server/api/platforms/cover/__init__.py new file mode 100644 index 000000000..0b9ac675d --- /dev/null +++ b/zha/websocket/server/api/platforms/cover/__init__.py @@ -0,0 +1,3 @@ +"""Cover platform websocket api for zha.""" + +from __future__ import annotations diff --git a/zha/websocket/server/api/platforms/cover/api.py b/zha/websocket/server/api/platforms/cover/api.py new file mode 100644 index 000000000..1337de241 --- /dev/null +++ b/zha/websocket/server/api/platforms/cover/api.py @@ -0,0 +1,86 @@ +"""WS API for the cover platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand +from zha.websocket.server.api.platforms.api import execute_platform_entity_command + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + + +class CoverOpenCommand(PlatformEntityCommand): + """Cover open command.""" + + command: Literal[APICommands.COVER_OPEN] = APICommands.COVER_OPEN + platform: str = Platform.COVER + + +@decorators.websocket_command(CoverOpenCommand) +@decorators.async_response +async def open_cover(server: Server, client: Client, command: CoverOpenCommand) -> None: + """Open the cover.""" + await execute_platform_entity_command(server, client, command, "async_open_cover") + + +class CoverCloseCommand(PlatformEntityCommand): + """Cover close command.""" + + command: Literal[APICommands.COVER_CLOSE] = APICommands.COVER_CLOSE + platform: str = Platform.COVER + + +@decorators.websocket_command(CoverCloseCommand) +@decorators.async_response +async def close_cover( + server: Server, client: Client, command: CoverCloseCommand +) -> None: + """Close the cover.""" + await execute_platform_entity_command(server, client, command, "async_close_cover") + + +class CoverSetPositionCommand(PlatformEntityCommand): + """Cover set position command.""" + + command: Literal[APICommands.COVER_SET_POSITION] = APICommands.COVER_SET_POSITION + platform: str = Platform.COVER + position: int + + +@decorators.websocket_command(CoverSetPositionCommand) +@decorators.async_response +async def set_position( + server: Server, client: Client, command: CoverSetPositionCommand +) -> None: + """Set the cover position.""" + await execute_platform_entity_command( + server, client, command, "async_set_cover_position" + ) + + +class CoverStopCommand(PlatformEntityCommand): + """Cover stop command.""" + + command: Literal[APICommands.COVER_STOP] = APICommands.COVER_STOP + platform: str = Platform.COVER + + +@decorators.websocket_command(CoverStopCommand) +@decorators.async_response +async def stop_cover(server: Server, client: Client, command: CoverStopCommand) -> None: + """Stop the cover.""" + await execute_platform_entity_command(server, client, command, "async_stop_cover") + + +def load_api(server: Server) -> None: + """Load the api command handlers.""" + register_api_command(server, open_cover) + register_api_command(server, close_cover) + register_api_command(server, set_position) + register_api_command(server, stop_cover) diff --git a/zha/websocket/server/api/platforms/fan/__init__.py b/zha/websocket/server/api/platforms/fan/__init__.py new file mode 100644 index 000000000..ade306f84 --- /dev/null +++ b/zha/websocket/server/api/platforms/fan/__init__.py @@ -0,0 +1,3 @@ +"""Fan platform websocket api for zha.""" + +from __future__ import annotations diff --git a/zha/websocket/server/api/platforms/fan/api.py b/zha/websocket/server/api/platforms/fan/api.py new file mode 100644 index 000000000..4577be21b --- /dev/null +++ b/zha/websocket/server/api/platforms/fan/api.py @@ -0,0 +1,94 @@ +"""WS API for the fan platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated, Literal, Union + +from pydantic import Field + +from zha.application.discovery import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand +from zha.websocket.server.api.platforms.api import execute_platform_entity_command + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + + +class FanTurnOnCommand(PlatformEntityCommand): + """Fan turn on command.""" + + command: Literal[APICommands.FAN_TURN_ON] = APICommands.FAN_TURN_ON + platform: str = Platform.FAN + speed: Union[str, None] + percentage: Union[Annotated[int, Field(ge=0, le=100)], None] + preset_mode: Union[str, None] + + +@decorators.websocket_command(FanTurnOnCommand) +@decorators.async_response +async def turn_on(server: Server, client: Client, command: FanTurnOnCommand) -> None: + """Turn fan on.""" + await execute_platform_entity_command(server, client, command, "async_turn_on") + + +class FanTurnOffCommand(PlatformEntityCommand): + """Fan turn off command.""" + + command: Literal[APICommands.FAN_TURN_OFF] = APICommands.FAN_TURN_OFF + platform: str = Platform.FAN + + +@decorators.websocket_command(FanTurnOffCommand) +@decorators.async_response +async def turn_off(server: Server, client: Client, command: FanTurnOffCommand) -> None: + """Turn fan off.""" + await execute_platform_entity_command(server, client, command, "async_turn_off") + + +class FanSetPercentageCommand(PlatformEntityCommand): + """Fan set percentage command.""" + + command: Literal[APICommands.FAN_SET_PERCENTAGE] = APICommands.FAN_SET_PERCENTAGE + platform: str = Platform.FAN + percentage: Annotated[int, Field(ge=0, le=100)] + + +@decorators.websocket_command(FanSetPercentageCommand) +@decorators.async_response +async def set_percentage( + server: Server, client: Client, command: FanSetPercentageCommand +) -> None: + """Set the fan speed percentage.""" + await execute_platform_entity_command( + server, client, command, "async_set_percentage" + ) + + +class FanSetPresetModeCommand(PlatformEntityCommand): + """Fan set preset mode command.""" + + command: Literal[APICommands.FAN_SET_PRESET_MODE] = APICommands.FAN_SET_PRESET_MODE + platform: str = Platform.FAN + preset_mode: str + + +@decorators.websocket_command(FanSetPresetModeCommand) +@decorators.async_response +async def set_preset_mode( + server: Server, client: Client, command: FanSetPresetModeCommand +) -> None: + """Set the fan preset mode.""" + await execute_platform_entity_command( + server, client, command, "async_set_preset_mode" + ) + + +def load_api(server: Server) -> None: + """Load the api command handlers.""" + register_api_command(server, turn_on) + register_api_command(server, turn_off) + register_api_command(server, set_percentage) + register_api_command(server, set_preset_mode) diff --git a/zha/websocket/server/api/platforms/light/__init__.py b/zha/websocket/server/api/platforms/light/__init__.py new file mode 100644 index 000000000..0a30fdf35 --- /dev/null +++ b/zha/websocket/server/api/platforms/light/__init__.py @@ -0,0 +1,3 @@ +"""Light platform websocket api for zha.""" + +from __future__ import annotations diff --git a/zha/websocket/server/api/platforms/light/api.py b/zha/websocket/server/api/platforms/light/api.py new file mode 100644 index 000000000..237b4a08b --- /dev/null +++ b/zha/websocket/server/api/platforms/light/api.py @@ -0,0 +1,85 @@ +"""WS API for the light platform entity.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Annotated, Literal, Union + +from pydantic import Field, ValidationInfo, field_validator + +from zha.application.discovery import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand +from zha.websocket.server.api.platforms.api import execute_platform_entity_command + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + +_LOGGER = logging.getLogger(__name__) + + +class LightTurnOnCommand(PlatformEntityCommand): + """Light turn on command.""" + + command: Literal[APICommands.LIGHT_TURN_ON] = APICommands.LIGHT_TURN_ON + platform: str = Platform.LIGHT + brightness: Union[Annotated[int, Field(ge=0, le=255)], None] + transition: Union[Annotated[float, Field(ge=0, le=6553)], None] + flash: Union[Literal["short", "long"], None] + effect: Union[str, None] + hs_color: Union[ + None, + ( + tuple[ + Annotated[int, Field(ge=0, le=360)], Annotated[int, Field(ge=0, le=100)] + ] + ), + ] + color_temp: Union[int, None] + + @field_validator("color_temp", mode="before", check_fields=False) + @classmethod + def check_color_setting_exclusivity( + cls, color_temp: int | None, validation_info: ValidationInfo + ) -> int | None: + """Ensure only one color mode is set.""" + if ( + "hs_color" in validation_info.data + and validation_info.data["hs_color"] is not None + and color_temp is not None + ): + raise ValueError('Only one of "hs_color" and "color_temp" can be set') + return color_temp + + +@decorators.websocket_command(LightTurnOnCommand) +@decorators.async_response +async def turn_on(server: Server, client: Client, command: LightTurnOnCommand) -> None: + """Turn on the light.""" + await execute_platform_entity_command(server, client, command, "async_turn_on") + + +class LightTurnOffCommand(PlatformEntityCommand): + """Light turn off command.""" + + command: Literal[APICommands.LIGHT_TURN_OFF] = APICommands.LIGHT_TURN_OFF + platform: str = Platform.LIGHT + transition: Union[Annotated[float, Field(ge=0, le=6553)], None] + flash: Union[Literal["short", "long"], None] + + +@decorators.websocket_command(LightTurnOffCommand) +@decorators.async_response +async def turn_off( + server: Server, client: Client, command: LightTurnOffCommand +) -> None: + """Turn on the light.""" + await execute_platform_entity_command(server, client, command, "async_turn_off") + + +def load_api(server: Server) -> None: + """Load the api command handlers.""" + register_api_command(server, turn_on) + register_api_command(server, turn_off) diff --git a/zha/websocket/server/api/platforms/lock/__init__.py b/zha/websocket/server/api/platforms/lock/__init__.py new file mode 100644 index 000000000..69515fd09 --- /dev/null +++ b/zha/websocket/server/api/platforms/lock/__init__.py @@ -0,0 +1,3 @@ +"""Lock platform websocket api for zha.""" + +from __future__ import annotations diff --git a/zha/websocket/server/api/platforms/lock/api.py b/zha/websocket/server/api/platforms/lock/api.py new file mode 100644 index 000000000..a52ca5002 --- /dev/null +++ b/zha/websocket/server/api/platforms/lock/api.py @@ -0,0 +1,136 @@ +"""WS api for the lock platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand +from zha.websocket.server.api.platforms.api import execute_platform_entity_command + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + + +class LockLockCommand(PlatformEntityCommand): + """Lock lock command.""" + + command: Literal[APICommands.LOCK_LOCK] = APICommands.LOCK_LOCK + platform: str = Platform.LOCK + + +@decorators.websocket_command(LockLockCommand) +@decorators.async_response +async def lock(server: Server, client: Client, command: LockLockCommand) -> None: + """Lock the lock.""" + await execute_platform_entity_command(server, client, command, "async_lock") + + +class LockUnlockCommand(PlatformEntityCommand): + """Lock unlock command.""" + + command: Literal[APICommands.LOCK_UNLOCK] = APICommands.LOCK_UNLOCK + platform: str = Platform.LOCK + + +@decorators.websocket_command(LockUnlockCommand) +@decorators.async_response +async def unlock(server: Server, client: Client, command: LockUnlockCommand) -> None: + """Unlock the lock.""" + await execute_platform_entity_command(server, client, command, "async_unlock") + + +class LockSetUserLockCodeCommand(PlatformEntityCommand): + """Set user lock code command.""" + + command: Literal[APICommands.LOCK_SET_USER_CODE] = APICommands.LOCK_SET_USER_CODE + platform: str = Platform.LOCK + code_slot: int + user_code: str + + +@decorators.websocket_command(LockSetUserLockCodeCommand) +@decorators.async_response +async def set_user_lock_code( + server: Server, client: Client, command: LockSetUserLockCodeCommand +) -> None: + """Set a user lock code in the specified slot for the lock.""" + await execute_platform_entity_command( + server, client, command, "async_set_lock_user_code" + ) + + +class LockEnableUserLockCodeCommand(PlatformEntityCommand): + """Enable user lock code command.""" + + command: Literal[APICommands.LOCK_ENAABLE_USER_CODE] = ( + APICommands.LOCK_ENAABLE_USER_CODE + ) + platform: str = Platform.LOCK + code_slot: int + + +@decorators.websocket_command(LockEnableUserLockCodeCommand) +@decorators.async_response +async def enable_user_lock_code( + server: Server, client: Client, command: LockEnableUserLockCodeCommand +) -> None: + """Enable a user lock code for the lock.""" + await execute_platform_entity_command( + server, client, command, "async_enable_lock_user_code" + ) + + +class LockDisableUserLockCodeCommand(PlatformEntityCommand): + """Disable user lock code command.""" + + command: Literal[APICommands.LOCK_DISABLE_USER_CODE] = ( + APICommands.LOCK_DISABLE_USER_CODE + ) + platform: str = Platform.LOCK + code_slot: int + + +@decorators.websocket_command(LockDisableUserLockCodeCommand) +@decorators.async_response +async def disable_user_lock_code( + server: Server, client: Client, command: LockDisableUserLockCodeCommand +) -> None: + """Disable a user lock code for the lock.""" + await execute_platform_entity_command( + server, client, command, "async_disable_lock_user_code" + ) + + +class LockClearUserLockCodeCommand(PlatformEntityCommand): + """Clear user lock code command.""" + + command: Literal[APICommands.LOCK_CLEAR_USER_CODE] = ( + APICommands.LOCK_CLEAR_USER_CODE + ) + platform: str = Platform.LOCK + code_slot: int + + +@decorators.websocket_command(LockClearUserLockCodeCommand) +@decorators.async_response +async def clear_user_lock_code( + server: Server, client: Client, command: LockClearUserLockCodeCommand +) -> None: + """Clear a user lock code for the lock.""" + await execute_platform_entity_command( + server, client, command, "async_clear_lock_user_code" + ) + + +def load_api(server: Server) -> None: + """Load the api command handlers.""" + register_api_command(server, lock) + register_api_command(server, unlock) + register_api_command(server, set_user_lock_code) + register_api_command(server, enable_user_lock_code) + register_api_command(server, disable_user_lock_code) + register_api_command(server, clear_user_lock_code) diff --git a/zha/websocket/server/api/platforms/number/__init__.py b/zha/websocket/server/api/platforms/number/__init__.py new file mode 100644 index 000000000..24ebd7482 --- /dev/null +++ b/zha/websocket/server/api/platforms/number/__init__.py @@ -0,0 +1,3 @@ +"""Number platform websocket api for zha.""" + +from __future__ import annotations diff --git a/zha/websocket/server/api/platforms/number/api.py b/zha/websocket/server/api/platforms/number/api.py new file mode 100644 index 000000000..c311a92c2 --- /dev/null +++ b/zha/websocket/server/api/platforms/number/api.py @@ -0,0 +1,40 @@ +"""WS api for the number platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand +from zha.websocket.server.api.platforms.api import execute_platform_entity_command + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + +ATTR_VALUE = "value" +COMMAND_SET_VALUE = "number_set_value" + + +class NumberSetValueCommand(PlatformEntityCommand): + """Number set value command.""" + + command: Literal[APICommands.NUMBER_SET_VALUE] = APICommands.NUMBER_SET_VALUE + platform: str = Platform.NUMBER + value: float + + +@decorators.websocket_command(NumberSetValueCommand) +@decorators.async_response +async def set_value( + server: Server, client: Client, command: NumberSetValueCommand +) -> None: + """Select an option.""" + await execute_platform_entity_command(server, client, command, "async_set_value") + + +def load_api(server: Server) -> None: + """Load the api command handlers.""" + register_api_command(server, set_value) diff --git a/zha/websocket/server/api/platforms/select/__init__.py b/zha/websocket/server/api/platforms/select/__init__.py new file mode 100644 index 000000000..17c2e3469 --- /dev/null +++ b/zha/websocket/server/api/platforms/select/__init__.py @@ -0,0 +1,3 @@ +"""Select platform websocket api for zha.""" + +from __future__ import annotations diff --git a/zha/websocket/server/api/platforms/select/api.py b/zha/websocket/server/api/platforms/select/api.py new file mode 100644 index 000000000..c9b2bc8c5 --- /dev/null +++ b/zha/websocket/server/api/platforms/select/api.py @@ -0,0 +1,41 @@ +"""WS api for the select platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand +from zha.websocket.server.api.platforms.api import execute_platform_entity_command + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + + +class SelectSelectOptionCommand(PlatformEntityCommand): + """Select select option command.""" + + command: Literal[APICommands.SELECT_SELECT_OPTION] = ( + APICommands.SELECT_SELECT_OPTION + ) + platform: str = Platform.SELECT + option: str + + +@decorators.websocket_command(SelectSelectOptionCommand) +@decorators.async_response +async def select_option( + server: Server, client: Client, command: SelectSelectOptionCommand +) -> None: + """Select an option.""" + await execute_platform_entity_command( + server, client, command, "async_select_option" + ) + + +def load_api(server: Server) -> None: + """Load the api command handlers.""" + register_api_command(server, select_option) diff --git a/zha/websocket/server/api/platforms/siren/__init__.py b/zha/websocket/server/api/platforms/siren/__init__.py new file mode 100644 index 000000000..dc37d7bc6 --- /dev/null +++ b/zha/websocket/server/api/platforms/siren/__init__.py @@ -0,0 +1,3 @@ +"""Siren platform websocket api for zha.""" + +from __future__ import annotations diff --git a/zha/websocket/server/api/platforms/siren/api.py b/zha/websocket/server/api/platforms/siren/api.py new file mode 100644 index 000000000..dccd3a266 --- /dev/null +++ b/zha/websocket/server/api/platforms/siren/api.py @@ -0,0 +1,54 @@ +"""WS api for the siren platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Union + +from zha.application.discovery import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand +from zha.websocket.server.api.platforms.api import execute_platform_entity_command + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + + +class SirenTurnOnCommand(PlatformEntityCommand): + """Siren turn on command.""" + + command: Literal[APICommands.SIREN_TURN_ON] = APICommands.SIREN_TURN_ON + platform: str = Platform.SIREN + duration: Union[int, None] = None + tone: Union[int, None] = None + level: Union[int, None] = None + + +@decorators.websocket_command(SirenTurnOnCommand) +@decorators.async_response +async def turn_on(server: Server, client: Client, command: SirenTurnOnCommand) -> None: + """Turn on the siren.""" + await execute_platform_entity_command(server, client, command, "async_turn_on") + + +class SirenTurnOffCommand(PlatformEntityCommand): + """Siren turn off command.""" + + command: Literal[APICommands.SIREN_TURN_OFF] = APICommands.SIREN_TURN_OFF + platform: str = Platform.SIREN + + +@decorators.websocket_command(SirenTurnOffCommand) +@decorators.async_response +async def turn_off( + server: Server, client: Client, command: SirenTurnOffCommand +) -> None: + """Turn on the siren.""" + await execute_platform_entity_command(server, client, command, "async_turn_off") + + +def load_api(server: Server) -> None: + """Load the api command handlers.""" + register_api_command(server, turn_on) + register_api_command(server, turn_off) diff --git a/zha/websocket/server/api/platforms/switch/__init__.py b/zha/websocket/server/api/platforms/switch/__init__.py new file mode 100644 index 000000000..1bfc10c74 --- /dev/null +++ b/zha/websocket/server/api/platforms/switch/__init__.py @@ -0,0 +1,3 @@ +"""Switch platform websocket api for zha.""" + +from __future__ import annotations diff --git a/zha/websocket/server/api/platforms/switch/api.py b/zha/websocket/server/api/platforms/switch/api.py new file mode 100644 index 000000000..b14f3cf01 --- /dev/null +++ b/zha/websocket/server/api/platforms/switch/api.py @@ -0,0 +1,51 @@ +"""WS api for the switch platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.platforms import PlatformEntityCommand +from zha.websocket.server.api.platforms.api import execute_platform_entity_command + +if TYPE_CHECKING: + from zha.websocket.server.client import Client + from zha.websocket.server.gateway import WebSocketGateway as Server + + +class SwitchTurnOnCommand(PlatformEntityCommand): + """Switch turn on command.""" + + command: Literal[APICommands.SWITCH_TURN_ON] = APICommands.SWITCH_TURN_ON + platform: str = Platform.SWITCH + + +@decorators.websocket_command(SwitchTurnOnCommand) +@decorators.async_response +async def turn_on(server: Server, client: Client, command: SwitchTurnOnCommand) -> None: + """Turn on the switch.""" + await execute_platform_entity_command(server, client, command, "async_turn_on") + + +class SwitchTurnOffCommand(PlatformEntityCommand): + """Switch turn off command.""" + + command: Literal[APICommands.SWITCH_TURN_OFF] = APICommands.SWITCH_TURN_OFF + platform: str = Platform.SWITCH + + +@decorators.websocket_command(SwitchTurnOffCommand) +@decorators.async_response +async def turn_off( + server: Server, client: Client, command: SwitchTurnOffCommand +) -> None: + """Turn on the switch.""" + await execute_platform_entity_command(server, client, command, "async_turn_off") + + +def load_api(server: Server) -> None: + """Load the api command handlers.""" + register_api_command(server, turn_on) + register_api_command(server, turn_off) diff --git a/zha/websocket/server/client.py b/zha/websocket/server/client.py index f6b4ff879..ccc1c87f8 100644 --- a/zha/websocket/server/client.py +++ b/zha/websocket/server/client.py @@ -11,11 +11,11 @@ from pydantic import BaseModel, ValidationError from websockets.server import WebSocketServerProtocol +from zha.model import BaseEvent from zha.websocket.const import ( COMMAND, ERROR_CODE, ERROR_MESSAGE, - EVENT_TYPE, MESSAGE_ID, MESSAGE_TYPE, SUCCESS, @@ -26,7 +26,7 @@ MessageTypes, ) from zha.websocket.server.api import decorators, register_api_command -from zha.websocket.server.api.model import WebSocketCommand +from zha.websocket.server.api.model import WebSocketCommand, WebSocketCommandResponse if TYPE_CHECKING: from zha.websocket.server.gateway import WebSocketGateway @@ -59,24 +59,28 @@ def disconnect(self) -> None: asyncio.create_task(self._websocket.close()) ) - def send_event(self, message: dict[str, Any]) -> None: + def send_event(self, message: BaseEvent) -> None: """Send event data to this client.""" - message[MESSAGE_TYPE] = MessageTypes.EVENT + message.message_type = MessageTypes.EVENT self._send_data(message) def send_result_success( - self, command: WebSocketCommand, data: dict[str, Any] | None = None + self, command: WebSocketCommand, data: dict[str, Any] | BaseModel | None = None ) -> None: """Send success result prompted by a client request.""" - message = { - SUCCESS: True, - MESSAGE_ID: command.message_id, - MESSAGE_TYPE: MessageTypes.RESULT, - COMMAND: command.command, - } - if data: - message.update(data) - self._send_data(message) + if data and isinstance(data, BaseModel): + self._send_data(data) + else: + if data is None: + data = {} + self._send_data( + WebSocketCommandResponse( + success=True, + message_id=command.message_id, + command=command.command, + **data, + ) + ) def send_result_error( self, @@ -169,13 +173,13 @@ async def listen(self) -> None: asyncio.create_task(self._handle_incoming_message(message)) ) - def will_accept_message(self, message: dict[str, Any]) -> bool: + def will_accept_message(self, message: BaseEvent) -> bool: """Determine if client accepts this type of message.""" if not self.receive_events: return False if ( - message[EVENT_TYPE] == EventTypes.RAW_ZCL_EVENT + message.event_type == EventTypes.RAW_ZCL_EVENT and not self.receive_raw_zcl_events ): _LOGGER.info( @@ -269,7 +273,7 @@ def remove_client(self, client: Client) -> None: client.disconnect() self._clients.remove(client) - def broadcast(self, message: dict[str, Any]) -> None: + def broadcast(self, message: BaseEvent) -> None: """Broadcast a message to all connected clients.""" clients_to_remove = [] diff --git a/zha/websocket/server/gateway.py b/zha/websocket/server/gateway.py index 9d9dec7b7..115e6b2c7 100644 --- a/zha/websocket/server/gateway.py +++ b/zha/websocket/server/gateway.py @@ -5,6 +5,7 @@ import asyncio import contextlib import logging +from time import monotonic from types import TracebackType from typing import TYPE_CHECKING, Any, Final, Literal @@ -16,7 +17,9 @@ from zha.websocket.const import APICommands from zha.websocket.server.api import decorators, register_api_command from zha.websocket.server.api.model import WebSocketCommand +from zha.websocket.server.api.platforms.api import load_platform_entity_apis from zha.websocket.server.client import ClientManager +from zha.websocket.server.gateway_api import load_api as load_zigbee_controller_api if TYPE_CHECKING: from zha.websocket.client import Client @@ -62,9 +65,13 @@ async def start_server(self) -> None: ) if self.config.server_config.network_auto_start: await self.async_initialize() - self.on_all_events(self.client_manager.broadcast) await self.async_initialize_devices_and_entities() + async def async_initialize(self) -> None: + """Initialize controller and connect radio.""" + await super().async_initialize() + self.on_all_events(self.client_manager.broadcast) + async def stop_server(self) -> None: """Stop the websocket server.""" if self._ws_server is None: @@ -108,6 +115,36 @@ def track_ws_task(self, task: asyncio.Task) -> None: self._tracked_ws_tasks.add(task) task.add_done_callback(self._tracked_ws_tasks.remove) + async def async_block_till_done(self, wait_background_tasks=False): + """Block until all pending work is done.""" + # To flush out any call_soon_threadsafe + await asyncio.sleep(0.001) + start_time: float | None = None + + while self._tracked_ws_tasks: + pending = [task for task in self._tracked_ws_tasks if not task.done()] + self._tracked_ws_tasks.clear() + if pending: + await self._await_and_log_pending(pending) + + if start_time is None: + # Avoid calling monotonic() until we know + # we may need to start logging blocked tasks. + start_time = 0 + elif start_time == 0: + # If we have waited twice then we set the start + # time + start_time = monotonic() + elif monotonic() - start_time > BLOCK_LOG_TIMEOUT: + # We have waited at least three loops and new tasks + # continue to block. At this point we start + # logging all waiting tasks. + for task in pending: + _LOGGER.debug("Waiting for task: %s", task) + else: + await asyncio.sleep(0.001) + await super().async_block_till_done(wait_background_tasks=wait_background_tasks) + async def __aenter__(self) -> WebSocketGateway: """Enter the context manager.""" await self.start_server() @@ -125,6 +162,8 @@ def _register_api_commands(self) -> None: from zha.websocket.server.client import load_api as load_client_api register_api_command(self, stop_server) + load_zigbee_controller_api(self) + load_platform_entity_apis(self) load_client_api(self) diff --git a/zha/websocket/server/gateway_api.py b/zha/websocket/server/gateway_api.py index 122d42c95..4e86c8881 100644 --- a/zha/websocket/server/gateway_api.py +++ b/zha/websocket/server/gateway_api.py @@ -3,23 +3,23 @@ from __future__ import annotations import asyncio -import dataclasses import logging from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeVar, Union, cast from pydantic import Field from zigpy.types.named import EUI64 -from zha.websocket.client.model.types import ( - Device as DeviceModel, - Group as GroupModel, - GroupMemberReference, -) -from zha.websocket.const import DEVICES, DURATION, GROUPS, APICommands +from zha.websocket.const import DURATION, GROUPS, APICommands from zha.websocket.server.api import decorators, register_api_command -from zha.websocket.server.api.model import WebSocketCommand +from zha.websocket.server.api.model import ( + GetDevicesResponse, + ReadClusterAttributesResponse, + WebSocketCommand, + WriteClusterAttributeResponse, +) from zha.zigbee.device import Device from zha.zigbee.group import Group +from zha.zigbee.model import GroupMemberReference if TYPE_CHECKING: from zha.websocket.server.client import Client @@ -103,14 +103,16 @@ async def get_devices( ) -> None: """Get Zigbee devices.""" try: - response_devices: dict[str, dict] = { - str(ieee): DeviceModel.model_validate( - dataclasses.asdict(device.extended_device_info) - ).model_dump() - for ieee, device in gateway.devices.items() - } - _LOGGER.info("devices: %s", response_devices) - client.send_result_success(command, {DEVICES: response_devices}) + response = GetDevicesResponse( + success=True, + devices={ + ieee: device.extended_device_info + for ieee, device in gateway.devices.items() + }, + message_id=command.message_id, + ) + _LOGGER.info("response: %s", response) + client.send_result_success(command, response) except Exception as e: _LOGGER.exception("Error getting devices", exc_info=e) client.send_result_error(command, "Error getting devices", str(e)) @@ -149,9 +151,9 @@ async def get_groups( """Get Zigbee groups.""" groups: dict[int, Any] = {} for group_id, group in gateway.groups.items(): - group_data = dataclasses.asdict(group.info_object) - group_data["id"] = group_id - groups[group_id] = GroupModel.model_validate(group_data).model_dump() + groups[int(group_id)] = ( + group.info_object + ) # maybe we should change the group_id type... _LOGGER.info("groups: %s", groups) client.send_result_success(command, {GROUPS: groups}) @@ -243,23 +245,23 @@ async def read_cluster_attributes( success, failure = await cluster.read_attributes( attributes, allow_cache=False, only_cache=False, manufacturer=manufacturer ) - client.send_result_success( - command, - { - "device": { - "ieee": command.ieee, - }, - "cluster": { - "id": cluster.cluster_id, - "endpoint_id": cluster.endpoint.endpoint_id, - "name": cluster.name, - "endpoint_attribute": cluster.ep_attribute, - }, - "manufacturer_code": manufacturer, - "succeeded": success, - "failed": failure, + + response = ReadClusterAttributesResponse( + message_id=command.message_id, + success=True, + device=device.extended_device_info, + cluster={ + "id": cluster.cluster_id, + "name": cluster.name, + "type": cluster.cluster_type, + "endpoint_id": cluster.endpoint.endpoint_id, + "endpoint_attribute": cluster.ep_attribute, }, + manufacturer_code=manufacturer, + succeeded=success, + failed=failure, ) + client.send_result_success(command, response) class WriteClusterAttributeCommand(WebSocketCommand): @@ -317,25 +319,25 @@ async def write_cluster_attribute( cluster_type=cluster_type, manufacturer=manufacturer, ) - client.send_result_success( - command, - { - "device": { - "ieee": str(command.ieee), - }, - "cluster": { - "id": cluster.cluster_id, - "endpoint_id": cluster.endpoint.endpoint_id, - "name": cluster.name, - "endpoint_attribute": cluster.ep_attribute, - }, - "manufacturer_code": manufacturer, - "response": { - "attribute": attribute, - "status": response[0][0].status.name, # type: ignore - }, # TODO there has to be a better way to do this + + api_response = WriteClusterAttributeResponse( + message_id=command.message_id, + success=True, + device=device.extended_device_info, + cluster={ + "id": cluster.cluster_id, + "name": cluster.name, + "type": cluster.cluster_type, + "endpoint_id": cluster.endpoint.endpoint_id, + "endpoint_attribute": cluster.ep_attribute, }, + manufacturer_code=manufacturer, + response={ + "attribute": attribute, + "status": response[0][0].status.name, # type: ignore + }, # TODO there has to be a better way to do this ) + client.send_result_success(command, api_response) class CreateGroupCommand(WebSocketCommand): @@ -357,10 +359,7 @@ async def create_group( members = command.members group_id = command.group_id group: Group = await gateway.async_create_zigpy_group(group_name, members, group_id) - ret_group = dataclasses.asdict(group.info_object) - ret_group["id"] = ret_group["group_id"] - ret_group = GroupModel.model_validate(ret_group).model_dump() - client.send_result_success(command, {"group": ret_group}) + client.send_result_success(command, {"group": group.info_object}) class RemoveGroupsCommand(WebSocketCommand): @@ -386,10 +385,8 @@ async def remove_groups( else: await gateway.async_remove_zigpy_group(group_ids[0]) groups: dict[int, Any] = {} - for id, group in gateway.groups.items(): - group_data = dataclasses.asdict(group.info_object) - group_data["id"] = group_data["group_id"] - groups[id] = GroupModel.model_validate(group_data).model_dump() + for group_id, group in gateway.groups.items(): + groups[int(group_id)] = group.info_object _LOGGER.info("groups: %s", groups) client.send_result_success(command, {GROUPS: groups}) @@ -420,10 +417,7 @@ async def add_group_members( if not group: client.send_result_error(command, "G1", "ZHA Group not found") return - ret_group = dataclasses.asdict(group.info_object) - ret_group["id"] = ret_group["group_id"] - ret_group = GroupModel.model_validate(ret_group).model_dump() - client.send_result_success(command, {GROUP: ret_group}) + client.send_result_success(command, {GROUP: group.info_object}) class RemoveGroupMembersCommand(AddGroupMembersCommand): @@ -450,10 +444,7 @@ async def remove_group_members( if not group: client.send_result_error(command, "G1", "ZHA Group not found") return - ret_group = dataclasses.asdict(group.info_object) - ret_group["id"] = ret_group["group_id"] - ret_group = GroupModel.model_validate(ret_group).model_dump() - client.send_result_success(command, {GROUP: ret_group}) + client.send_result_success(command, {GROUP: group.info_object}) def load_api(gateway: WebSocketGateway) -> None: diff --git a/zha/zigbee/cluster_handlers/__init__.py b/zha/zigbee/cluster_handlers/__init__.py index 940bf6a41..6450c5c54 100644 --- a/zha/zigbee/cluster_handlers/__init__.py +++ b/zha/zigbee/cluster_handlers/__init__.py @@ -4,12 +4,10 @@ from collections.abc import Awaitable, Callable, Coroutine, Iterator import contextlib -from enum import StrEnum import functools import logging -from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypedDict +from typing import TYPE_CHECKING, Any, ParamSpec, TypedDict -from pydantic import field_serializer import zigpy.exceptions import zigpy.util import zigpy.zcl @@ -18,7 +16,6 @@ ConfigureReportingResponseRecord, Status, ZCLAttributeDef, - ZCLCommandDef, ) from zha.application.const import ( @@ -29,7 +26,6 @@ from zha.event import EventBase from zha.exceptions import ZHAException from zha.mixins import LogMixin -from zha.model import BaseEvent, BaseModel from zha.zigbee.cluster_handlers.const import ( ARGS, ATTRIBUTE_ID, @@ -46,6 +42,14 @@ UNIQUE_ID, VALUE, ) +from zha.zigbee.cluster_handlers.model import ( + ClusterAttributeUpdatedEvent, + ClusterBindEvent, + ClusterConfigureReportingEvent, + ClusterHandlerInfo, + ClusterHandlerStatus, + ClusterInfo, +) if TYPE_CHECKING: from zha.zigbee.endpoint import Endpoint @@ -114,99 +118,6 @@ def parse_and_log_command(cluster_handler, tsn, command_id, args): return name -class ClusterHandlerStatus(StrEnum): - """Status of a cluster handler.""" - - CREATED = "created" - CONFIGURED = "configured" - INITIALIZED = "initialized" - - -class ClusterAttributeUpdatedEvent(BaseEvent): - """Event to signal that a cluster attribute has been updated.""" - - attribute_id: int - attribute_name: str - attribute_value: Any - cluster_handler_unique_id: str - cluster_id: int - event_type: Literal["cluster_handler_event"] = "cluster_handler_event" - event: Literal["cluster_handler_attribute_updated"] = ( - "cluster_handler_attribute_updated" - ) - - -class ClusterBindEvent(BaseEvent): - """Event generated when the cluster is bound.""" - - cluster_name: str - cluster_id: int - success: bool - cluster_handler_unique_id: str - event_type: Literal["zha_channel_message"] = "zha_channel_message" - event: Literal["zha_channel_bind"] = "zha_channel_bind" - - -class ClusterConfigureReportingEvent(BaseEvent): - """Event generates when a cluster configures attribute reporting.""" - - cluster_name: str - cluster_id: int - attributes: dict[str, dict[str, Any]] - cluster_handler_unique_id: str - event_type: Literal["zha_channel_message"] = "zha_channel_message" - event: Literal["zha_channel_configure_reporting"] = ( - "zha_channel_configure_reporting" - ) - - -class ClusterInfo(BaseModel): - """Cluster information.""" - - id: int - name: str - type: str - commands: list[ZCLCommandDef] - - @field_serializer("commands", when_used="json-unless-none", check_fields=False) - def serialize_commands(self, commands: list[ZCLCommandDef]): - """Serialize commands.""" - converted_commands = [] - for command in commands: - converted_command = { - "id": command.id, - "name": command.name, - "schema": { - "command": command.schema.command.name, - "fields": [ - { - "name": f.name, - "type": f.type.__name__, - "optional": f.optional, - } - for f in command.schema.fields - ], - }, - "direction": command.direction, - "is_manufacturer_specific": command.is_manufacturer_specific, - } - converted_commands.append(converted_command) - return converted_commands - - -class ClusterHandlerInfo(BaseModel): - """Cluster handler information.""" - - class_name: str - generic_id: str - endpoint_id: int - cluster: ClusterInfo - id: str - unique_id: str - status: ClusterHandlerStatus - value_attribute: str | None = None - - class ClusterHandler(LogMixin, EventBase): """Base cluster handler for a Zigbee cluster.""" @@ -252,7 +163,8 @@ def info_object(self) -> ClusterHandlerInfo: id=self._cluster.cluster_id, name=self._cluster.name, type="client" if self._cluster.is_client else "server", - commands=self._cluster.commands, + endpoint_id=self._cluster.endpoint.endpoint_id, + endpoint_attribute=self._cluster.ep_attribute, ), id=self._id, unique_id=self._unique_id, diff --git a/zha/zigbee/cluster_handlers/general.py b/zha/zigbee/cluster_handlers/general.py index d9ce799f2..60b8f7bee 100644 --- a/zha/zigbee/cluster_handlers/general.py +++ b/zha/zigbee/cluster_handlers/general.py @@ -5,7 +5,7 @@ import asyncio from collections.abc import Coroutine from datetime import datetime -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any from zhaquirks.quirk_ids import TUYA_PLUG_ONOFF import zigpy.exceptions @@ -44,7 +44,6 @@ from zigpy.zcl.foundation import Status from zha.exceptions import ZHAException -from zha.model import BaseEvent from zha.zigbee.cluster_handlers import ( AttrReportConfig, ClientClusterHandler, @@ -64,19 +63,12 @@ SIGNAL_SET_LEVEL, ) from zha.zigbee.cluster_handlers.helpers import is_hue_motion_sensor +from zha.zigbee.cluster_handlers.model import LevelChangeEvent if TYPE_CHECKING: from zha.zigbee.endpoint import Endpoint -class LevelChangeEvent(BaseEvent): - """Event to signal that a cluster attribute has been updated.""" - - level: int - event: str - event_type: Literal["cluster_handler_event"] = "cluster_handler_event" - - @registries.CLUSTER_HANDLER_REGISTRY.register(Alarms.cluster_id) class AlarmsClusterHandler(ClusterHandler): """Alarms cluster handler.""" diff --git a/zha/zigbee/cluster_handlers/model.py b/zha/zigbee/cluster_handlers/model.py new file mode 100644 index 000000000..412775c2d --- /dev/null +++ b/zha/zigbee/cluster_handlers/model.py @@ -0,0 +1,83 @@ +"""Models for the ZHA cluster handlers module.""" + +from enum import StrEnum +from typing import Any, Literal + +from zha.model import BaseEvent, BaseModel + + +class ClusterHandlerStatus(StrEnum): + """Status of a cluster handler.""" + + CREATED = "created" + CONFIGURED = "configured" + INITIALIZED = "initialized" + + +class ClusterAttributeUpdatedEvent(BaseEvent): + """Event to signal that a cluster attribute has been updated.""" + + attribute_id: int + attribute_name: str + attribute_value: Any + cluster_handler_unique_id: str + cluster_id: int + event_type: Literal["cluster_handler_event"] = "cluster_handler_event" + event: Literal["cluster_handler_attribute_updated"] = ( + "cluster_handler_attribute_updated" + ) + + +class ClusterBindEvent(BaseEvent): + """Event generated when the cluster is bound.""" + + cluster_name: str + cluster_id: int + success: bool + cluster_handler_unique_id: str + event_type: Literal["zha_channel_message"] = "zha_channel_message" + event: Literal["zha_channel_bind"] = "zha_channel_bind" + + +class ClusterConfigureReportingEvent(BaseEvent): + """Event generates when a cluster configures attribute reporting.""" + + cluster_name: str + cluster_id: int + attributes: dict[str, dict[str, Any]] + cluster_handler_unique_id: str + event_type: Literal["zha_channel_message"] = "zha_channel_message" + event: Literal["zha_channel_configure_reporting"] = ( + "zha_channel_configure_reporting" + ) + + +class ClusterInfo(BaseModel): + """Cluster information.""" + + id: int + name: str + type: str + endpoint_id: int + endpoint_attribute: str | None = None + + +class ClusterHandlerInfo(BaseModel): + """Cluster handler information.""" + + class_name: str + generic_id: str + endpoint_id: int + cluster: ClusterInfo + id: str + unique_id: str + status: ClusterHandlerStatus + value_attribute: str | None = None + + +class LevelChangeEvent(BaseEvent): + """Event to signal that a cluster attribute has been updated.""" + + level: int + event: str + event_type: Literal["cluster_handler_event"] = "cluster_handler_event" diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index abf9262e3..05c845d1c 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -5,19 +5,17 @@ from __future__ import annotations import asyncio -from enum import Enum, StrEnum from functools import cached_property import logging import time -from typing import TYPE_CHECKING, Any, Literal, Self, Union +from typing import TYPE_CHECKING, Any, Self -from pydantic import field_serializer, field_validator from zigpy.device import Device as ZigpyDevice import zigpy.exceptions from zigpy.profiles import PROFILES import zigpy.quirks -from zigpy.types import uint1_t, uint8_t, uint16_t -from zigpy.types.named import EUI64, NWK, ExtendedPanId +from zigpy.types import uint8_t, uint16_t +from zigpy.types.named import EUI64, NWK from zigpy.zcl.clusters import Cluster from zigpy.zcl.clusters.general import Groups, Identify from zigpy.zcl.foundation import ( @@ -26,7 +24,6 @@ ZCLCommandDef, ) import zigpy.zdo.types as zdo_types -from zigpy.zdo.types import RouteStatus, _NeighborEnums from zha.application import Platform, discovery from zha.application.const import ( @@ -58,13 +55,23 @@ ZHA_EVENT, ) from zha.application.helpers import convert_to_zcl_values -from zha.application.platforms import BaseEntityInfo, PlatformEntity +from zha.application.platforms import PlatformEntity from zha.event import EventBase from zha.exceptions import ZHAException from zha.mixins import LogMixin -from zha.model import BaseEvent, BaseModel, convert_enum, convert_int from zha.zigbee.cluster_handlers import ClusterHandler, ZDOClusterHandler from zha.zigbee.endpoint import Endpoint +from zha.zigbee.model import ( + ClusterBinding, + ClusterHandlerConfigurationComplete, + DeviceInfo, + DeviceStatus, + EndpointNameInfo, + ExtendedDeviceInfo, + NeighborInfo, + RouteInfo, + ZHAEvent, +) if TYPE_CHECKING: from zha.application.gateway import Gateway @@ -83,184 +90,6 @@ def get_device_automation_triggers( } -class DeviceStatus(StrEnum): - """Status of a device.""" - - CREATED = "created" - INITIALIZED = "initialized" - - -class ZHAEvent(BaseEvent): - """Event generated when a device wishes to send an arbitrary event.""" - - device_ieee: EUI64 - unique_id: str - data: dict[str, Any] - event_type: Literal["zha_event"] = "zha_event" - event: Literal["zha_event"] = "zha_event" - - -class ClusterHandlerConfigurationComplete(BaseEvent): - """Event generated when all cluster handlers are configured.""" - - device_ieee: EUI64 - unique_id: str - event_type: Literal["zha_channel_message"] = "zha_channel_message" - event: Literal["zha_channel_cfg_done"] = "zha_channel_cfg_done" - - -class ClusterBinding(BaseModel): - """Describes a cluster binding.""" - - name: str - type: str - id: int - endpoint_id: int - - -class DeviceInfo(BaseModel): - """Describes a device.""" - - ieee: EUI64 - nwk: NWK - manufacturer: str - model: str - name: str - quirk_applied: bool - quirk_class: str - quirk_id: str | None - manufacturer_code: int | None - power_source: str - lqi: int | None - rssi: int | None - last_seen: str - available: bool - device_type: str - signature: dict[str, Any] - - @field_serializer("signature", when_used="json-unless-none", check_fields=False) - def serialize_signature(self, signature: dict[str, Any]): - """Serialize signature.""" - if "node_descriptor" in signature: - signature["node_descriptor"] = signature["node_descriptor"].as_dict() - return signature - - -class NeighborInfo(BaseModel): - """Describes a neighbor.""" - - device_type: _NeighborEnums.DeviceType - rx_on_when_idle: _NeighborEnums.RxOnWhenIdle - relationship: _NeighborEnums.Relationship - extended_pan_id: ExtendedPanId - ieee: EUI64 - nwk: NWK - permit_joining: _NeighborEnums.PermitJoins - depth: uint8_t - lqi: uint8_t - - _convert_device_type = field_validator( - "device_type", mode="before", check_fields=False - )(convert_enum(_NeighborEnums.DeviceType)) - - _convert_rx_on_when_idle = field_validator( - "rx_on_when_idle", mode="before", check_fields=False - )(convert_enum(_NeighborEnums.RxOnWhenIdle)) - - _convert_relationship = field_validator( - "relationship", mode="before", check_fields=False - )(convert_enum(_NeighborEnums.Relationship)) - - _convert_permit_joining = field_validator( - "permit_joining", mode="before", check_fields=False - )(convert_enum(_NeighborEnums.PermitJoins)) - - _convert_depth = field_validator("depth", mode="before", check_fields=False)( - convert_int(uint8_t) - ) - _convert_lqi = field_validator("lqi", mode="before", check_fields=False)( - convert_int(uint8_t) - ) - - @field_validator("extended_pan_id", mode="before", check_fields=False) - @classmethod - def convert_extended_pan_id( - cls, extended_pan_id: Union[str, ExtendedPanId] - ) -> ExtendedPanId: - """Convert extended_pan_id to ExtendedPanId.""" - if isinstance(extended_pan_id, str): - return ExtendedPanId.convert(extended_pan_id) - return extended_pan_id - - @field_serializer("extended_pan_id", check_fields=False) - def serialize_extended_pan_id(self, extended_pan_id: ExtendedPanId): - """Customize how extended_pan_id is serialized.""" - return str(extended_pan_id) - - @field_serializer( - "device_type", - "rx_on_when_idle", - "relationship", - "permit_joining", - check_fields=False, - ) - def serialize_enums(self, enum_value: Enum): - """Serialize enums by name.""" - return enum_value.name - - -class RouteInfo(BaseModel): - """Describes a route.""" - - dest_nwk: NWK - route_status: RouteStatus - memory_constrained: uint1_t - many_to_one: uint1_t - route_record_required: uint1_t - next_hop: NWK - - _convert_route_status = field_validator( - "route_status", mode="before", check_fields=False - )(convert_enum(RouteStatus)) - - _convert_memory_constrained = field_validator( - "memory_constrained", mode="before", check_fields=False - )(convert_int(uint1_t)) - - _convert_many_to_one = field_validator( - "many_to_one", mode="before", check_fields=False - )(convert_int(uint1_t)) - - _convert_route_record_required = field_validator( - "route_record_required", mode="before", check_fields=False - )(convert_int(uint1_t)) - - @field_serializer( - "route_status", - check_fields=False, - ) - def serialize_route_status(self, route_status: RouteStatus): - """Serialize route_status as name.""" - return route_status.name - - -class EndpointNameInfo(BaseModel): - """Describes an endpoint name.""" - - name: str - - -class ExtendedDeviceInfo(DeviceInfo): - """Describes a ZHA device.""" - - active_coordinator: bool - entities: dict[tuple[Platform, str], BaseEntityInfo] - neighbors: list[NeighborInfo] - routes: list[RouteInfo] - endpoint_names: list[EndpointNameInfo] - device_automation_triggers: dict[tuple[str, str], dict[str, Any]] - - class Device(LogMixin, EventBase): """ZHA Zigbee device object.""" @@ -753,7 +582,7 @@ def extended_device_info(self) -> ExtendedDeviceInfo: **self.device_info.__dict__, active_coordinator=self.is_active_coordinator, entities={ - platform_entity_key: platform_entity.info_object + platform_entity_key: platform_entity.info_object.model_dump() for platform_entity_key, platform_entity in self.platform_entities.items() }, neighbors=[ diff --git a/zha/zigbee/group.py b/zha/zigbee/group.py index 057b4d984..7c90d895e 100644 --- a/zha/zigbee/group.py +++ b/zha/zigbee/group.py @@ -11,15 +11,10 @@ import zigpy.exceptions from zigpy.types.named import EUI64 -from zha.application.platforms import ( - BaseEntityInfo, - EntityStateChangedEvent, - PlatformEntity, -) +from zha.application.platforms import EntityStateChangedEvent, PlatformEntity from zha.const import STATE_CHANGED from zha.mixins import LogMixin -from zha.model import BaseModel -from zha.zigbee.device import ExtendedDeviceInfo +from zha.zigbee.model import GroupInfo, GroupMemberInfo, GroupMemberReference if TYPE_CHECKING: from zigpy.group import Group as ZigpyGroup, GroupEndpoint @@ -31,39 +26,6 @@ _LOGGER = logging.getLogger(__name__) -class GroupMemberReference(BaseModel): - """Describes a group member.""" - - ieee: EUI64 - endpoint_id: int - - -class GroupEntityReference(BaseModel): - """Reference to a group entity.""" - - entity_id: str - name: str | None = None - original_name: str | None = None - - -class GroupMemberInfo(BaseModel): - """Describes a group member.""" - - ieee: EUI64 - endpoint_id: int - device_info: ExtendedDeviceInfo - entities: dict[str, BaseEntityInfo] - - -class GroupInfo(BaseModel): - """Describes a group.""" - - group_id: int - name: str - members: list[GroupMemberInfo] - entities: dict[str, BaseEntityInfo] - - class GroupMember(LogMixin): """Composite object that represents a device endpoint in a Zigbee group.""" @@ -101,7 +63,7 @@ def member_info(self) -> GroupMemberInfo: endpoint_id=self.endpoint_id, device_info=self.device.extended_device_info, entities={ - entity.unique_id: entity.info_object + entity.unique_id: entity.info_object.__dict__ for entity in self.associated_entities }, ) @@ -202,7 +164,7 @@ def info_object(self) -> GroupInfo: name=self.name, members=[member.member_info for member in self.members], entities={ - unique_id: entity.info_object + unique_id: entity.info_object.__dict__ for unique_id, entity in self._group_entities.items() }, ) diff --git a/zha/zigbee/model.py b/zha/zigbee/model.py new file mode 100644 index 000000000..c3dfec5a8 --- /dev/null +++ b/zha/zigbee/model.py @@ -0,0 +1,329 @@ +"""Models for the ZHA zigbee module.""" + +from enum import Enum, StrEnum +from typing import Annotated, Any, Literal, Union + +from pydantic import Field, field_serializer, field_validator +from zigpy.types import uint1_t, uint8_t +from zigpy.types.named import EUI64, NWK, ExtendedPanId +from zigpy.zdo.types import RouteStatus, _NeighborEnums + +from zha.application import Platform +from zha.application.platforms.model import ( + AlarmControlPanelEntity, + BatteryEntity, + BinarySensorEntity, + ButtonEntity, + CoverEntity, + DeviceCounterSensorEntity, + DeviceTrackerEntity, + ElectricalMeasurementEntity, + FanEntity, + FanGroupEntity, + FirmwareUpdateEntity, + LightEntity, + LightGroupEntity, + LockEntity, + NumberEntity, + SelectEntity, + SensorEntity, + ShadeEntity, + SirenEntity, + SmartEnergyMeteringEntity, + SwitchEntity, + SwitchGroupEntity, + ThermostatEntity, +) +from zha.model import BaseEvent, BaseModel, convert_enum, convert_int + + +class DeviceStatus(StrEnum): + """Status of a device.""" + + CREATED = "created" + INITIALIZED = "initialized" + + +class ZHAEvent(BaseEvent): + """Event generated when a device wishes to send an arbitrary event.""" + + device_ieee: EUI64 + unique_id: str + data: dict[str, Any] + event_type: Literal["device_event"] = "device_event" + event: Literal["zha_event"] = "zha_event" + + +class ClusterHandlerConfigurationComplete(BaseEvent): + """Event generated when all cluster handlers are configured.""" + + device_ieee: EUI64 + unique_id: str + event_type: Literal["zha_channel_message"] = "zha_channel_message" + event: Literal["zha_channel_cfg_done"] = "zha_channel_cfg_done" + + +class ClusterBinding(BaseModel): + """Describes a cluster binding.""" + + name: str + type: str + id: int + endpoint_id: int + + +class DeviceInfo(BaseModel): + """Describes a device.""" + + ieee: EUI64 + nwk: NWK + manufacturer: str + model: str + name: str + quirk_applied: bool + quirk_class: str + quirk_id: str | None + manufacturer_code: int | None + power_source: str + lqi: int | None + rssi: int | None + last_seen: str + available: bool + device_type: str + signature: dict[str, Any] + + @field_serializer("signature", check_fields=False) + def serialize_signature(self, signature: dict[str, Any]): + """Serialize signature.""" + if "node_descriptor" in signature and not isinstance( + signature["node_descriptor"], dict + ): + signature["node_descriptor"] = signature["node_descriptor"].as_dict() + return signature + + +class NeighborInfo(BaseModel): + """Describes a neighbor.""" + + device_type: _NeighborEnums.DeviceType + rx_on_when_idle: _NeighborEnums.RxOnWhenIdle + relationship: _NeighborEnums.Relationship + extended_pan_id: ExtendedPanId + ieee: EUI64 + nwk: NWK + permit_joining: _NeighborEnums.PermitJoins + depth: uint8_t + lqi: uint8_t + + _convert_device_type = field_validator( + "device_type", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.DeviceType)) + + _convert_rx_on_when_idle = field_validator( + "rx_on_when_idle", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.RxOnWhenIdle)) + + _convert_relationship = field_validator( + "relationship", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.Relationship)) + + _convert_permit_joining = field_validator( + "permit_joining", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.PermitJoins)) + + _convert_depth = field_validator("depth", mode="before", check_fields=False)( + convert_int(uint8_t) + ) + _convert_lqi = field_validator("lqi", mode="before", check_fields=False)( + convert_int(uint8_t) + ) + + @field_validator("extended_pan_id", mode="before", check_fields=False) + @classmethod + def convert_extended_pan_id( + cls, extended_pan_id: Union[str, ExtendedPanId] + ) -> ExtendedPanId: + """Convert extended_pan_id to ExtendedPanId.""" + if isinstance(extended_pan_id, str): + return ExtendedPanId.convert(extended_pan_id) + return extended_pan_id + + @field_serializer("extended_pan_id", check_fields=False) + def serialize_extended_pan_id(self, extended_pan_id: ExtendedPanId): + """Customize how extended_pan_id is serialized.""" + return str(extended_pan_id) + + @field_serializer( + "device_type", + "rx_on_when_idle", + "relationship", + "permit_joining", + check_fields=False, + ) + def serialize_enums(self, enum_value: Enum): + """Serialize enums by name.""" + return enum_value.name + + +class RouteInfo(BaseModel): + """Describes a route.""" + + dest_nwk: NWK + route_status: RouteStatus + memory_constrained: uint1_t + many_to_one: uint1_t + route_record_required: uint1_t + next_hop: NWK + + _convert_route_status = field_validator( + "route_status", mode="before", check_fields=False + )(convert_enum(RouteStatus)) + + _convert_memory_constrained = field_validator( + "memory_constrained", mode="before", check_fields=False + )(convert_int(uint1_t)) + + _convert_many_to_one = field_validator( + "many_to_one", mode="before", check_fields=False + )(convert_int(uint1_t)) + + _convert_route_record_required = field_validator( + "route_record_required", mode="before", check_fields=False + )(convert_int(uint1_t)) + + @field_serializer( + "route_status", + check_fields=False, + ) + def serialize_route_status(self, route_status: RouteStatus): + """Serialize route_status as name.""" + return route_status.name + + +class EndpointNameInfo(BaseModel): + """Describes an endpoint name.""" + + name: str + + +class ExtendedDeviceInfo(DeviceInfo): + """Describes a ZHA device.""" + + active_coordinator: bool + entities: dict[ + tuple[Platform, str], + Annotated[ + Union[ + SirenEntity, + SelectEntity, + NumberEntity, + LightEntity, + FanEntity, + FirmwareUpdateEntity, + ButtonEntity, + AlarmControlPanelEntity, + SensorEntity, + BinarySensorEntity, + DeviceTrackerEntity, + ShadeEntity, + CoverEntity, + LockEntity, + SwitchEntity, + BatteryEntity, + ElectricalMeasurementEntity, + SmartEnergyMeteringEntity, + ThermostatEntity, + DeviceCounterSensorEntity, + ], + Field(discriminator="class_name"), + ], + ] + neighbors: list[NeighborInfo] + routes: list[RouteInfo] + endpoint_names: list[EndpointNameInfo] + device_automation_triggers: dict[tuple[str, str], dict[str, Any]] + + @field_validator( + "device_automation_triggers", "entities", mode="before", check_fields=False + ) + @classmethod + def validate_tuple_keyed_dicts( + cls, + tuple_keyed_dict: dict[tuple[str, str], Any] | dict[str, dict[str, Any]], + ) -> dict[tuple[str, str], Any] | dict[str, dict[str, Any]]: + """Validate device_automation_triggers.""" + if all(isinstance(key, str) for key in tuple_keyed_dict): + return { + tuple(key.split(",")): item for key, item in tuple_keyed_dict.items() + } + return tuple_keyed_dict + + +class GroupMemberReference(BaseModel): + """Describes a group member.""" + + ieee: EUI64 + endpoint_id: int + + +class GroupEntityReference(BaseModel): + """Reference to a group entity.""" + + entity_id: str + name: str | None = None + original_name: str | None = None + + +class GroupMemberInfo(BaseModel): + """Describes a group member.""" + + ieee: EUI64 + endpoint_id: int + device_info: ExtendedDeviceInfo + entities: dict[ + str, + Annotated[ + Union[ + SirenEntity, + SelectEntity, + NumberEntity, + LightEntity, + FanEntity, + ButtonEntity, + AlarmControlPanelEntity, + FirmwareUpdateEntity, + SensorEntity, + BinarySensorEntity, + DeviceTrackerEntity, + ShadeEntity, + CoverEntity, + LockEntity, + SwitchEntity, + BatteryEntity, + ElectricalMeasurementEntity, + SmartEnergyMeteringEntity, + ThermostatEntity, + ], + Field(discriminator="class_name"), + ], + ] + + +class GroupInfo(BaseModel): + """Describes a group.""" + + group_id: int + name: str + members: list[GroupMemberInfo] + entities: dict[ + str, + Annotated[ + Union[LightGroupEntity, FanGroupEntity, SwitchGroupEntity], + Field(discriminator="class_name"), + ], + ] + + @property + def members_by_ieee(self) -> dict[EUI64, GroupMemberInfo]: + """Return members by ieee.""" + return {member.ieee: member for member in self.members} From 54b7e9d740ca7f2ba5c3f7e04703207dc2d29fff Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Mon, 21 Oct 2024 09:27:56 -0400 Subject: [PATCH 11/12] fix imports for typing --- zha/websocket/server/api/platforms/api.py | 7 ++++--- zha/websocket/server/gateway.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/zha/websocket/server/api/platforms/api.py b/zha/websocket/server/api/platforms/api.py index 537b2e9bc..43ffe5df6 100644 --- a/zha/websocket/server/api/platforms/api.py +++ b/zha/websocket/server/api/platforms/api.py @@ -11,7 +11,7 @@ from zha.websocket.server.api.platforms import PlatformEntityCommand if TYPE_CHECKING: - from zha.websocket.client import Client + from zha.websocket.server.client import Client from zha.websocket.server.gateway import WebSocketGateway as Server _LOGGER = logging.getLogger(__name__) @@ -48,10 +48,10 @@ async def execute_platform_entity_command( try: action = getattr(platform_entity, method_name) arg_spec = inspect.getfullargspec(action) - if arg_spec.varkw: # the only argument is self + if arg_spec.varkw: await action(**command.model_dump(exclude_none=True)) else: - await action() + await action() # the only argument is self except Exception as err: _LOGGER.exception("Error executing command: %s", method_name, exc_info=err) @@ -84,6 +84,7 @@ async def refresh_state( await execute_platform_entity_command(server, client, command, "async_update") +# pylint: disable=import-outside-toplevel def load_platform_entity_apis(server: Server) -> None: """Load the ws apis for all platform entities types.""" from zha.websocket.server.api.platforms.alarm_control_panel.api import ( diff --git a/zha/websocket/server/gateway.py b/zha/websocket/server/gateway.py index 115e6b2c7..834129e63 100644 --- a/zha/websocket/server/gateway.py +++ b/zha/websocket/server/gateway.py @@ -22,7 +22,7 @@ from zha.websocket.server.gateway_api import load_api as load_zigbee_controller_api if TYPE_CHECKING: - from zha.websocket.client import Client + from zha.websocket.server.client import Client BLOCK_LOG_TIMEOUT: Final[int] = 60 _LOGGER = logging.getLogger(__name__) @@ -159,6 +159,7 @@ async def __aexit__( def _register_api_commands(self) -> None: """Load server API commands.""" + # pylint: disable=import-outside-toplevel from zha.websocket.server.client import load_api as load_client_api register_api_command(self, stop_server) From f3c1e4872f0cb60848e4a1efa7438c7970ab00ae Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Mon, 21 Oct 2024 10:13:20 -0400 Subject: [PATCH 12/12] add alarm control panel WS tests --- tests/common.py | 3 + tests/websocket/test_alarm_control_panel.py | 245 ++++++++++++++++++ .../platforms/alarm_control_panel/__init__.py | 10 +- 3 files changed, 253 insertions(+), 5 deletions(-) create mode 100644 tests/websocket/test_alarm_control_panel.py diff --git a/tests/common.py b/tests/common.py index 6cee2a9fd..54e6164c0 100644 --- a/tests/common.py +++ b/tests/common.py @@ -504,6 +504,9 @@ def create_mock_zigpy_device( descriptor_capability_field=zdo_t.NodeDescriptor.DescriptorCapability.NONE, ) + if isinstance(node_descriptor, bytes): + node_descriptor = zdo_t.NodeDescriptor.deserialize(node_descriptor)[0] + device.node_desc = node_descriptor device.last_seen = time.time() diff --git a/tests/websocket/test_alarm_control_panel.py b/tests/websocket/test_alarm_control_panel.py new file mode 100644 index 000000000..98f4eb4d1 --- /dev/null +++ b/tests/websocket/test_alarm_control_panel.py @@ -0,0 +1,245 @@ +"""Test zha alarm control panel.""" + +import logging +from typing import Optional +from unittest.mock import AsyncMock, call, patch, sentinel + +from zigpy.profiles import zha +from zigpy.zcl.clusters import security +import zigpy.zcl.foundation as zcl_f + +from zha.application import Platform +from zha.application.platforms.model import AlarmControlPanelEntity +from zha.websocket.client.controller import Controller +from zha.websocket.client.proxy import DeviceProxy +from zha.websocket.server.gateway import WebSocketGateway as Server + +from ..common import ( + SIG_EP_INPUT, + SIG_EP_OUTPUT, + SIG_EP_PROFILE, + SIG_EP_TYPE, + create_mock_zigpy_device, + join_zigpy_device, +) + +_LOGGER = logging.getLogger(__name__) + + +@patch( + "zigpy.zcl.clusters.security.IasAce.client_command", + new=AsyncMock(return_value=[sentinel.data, zcl_f.Status.SUCCESS]), +) +async def test_alarm_control_panel( + connected_client_and_server: tuple[Controller, Server], +) -> None: + """Test zhaws alarm control panel platform.""" + controller, server = connected_client_and_server + + zigpy_device = create_mock_zigpy_device( + server, + { + 1: { + SIG_EP_INPUT: [security.IasAce.cluster_id], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.IAS_ANCILLARY_CONTROL, + SIG_EP_PROFILE: zha.PROFILE_ID, + } + }, + node_descriptor=b"\x02@\x8c\x02\x10RR\x00\x00\x00R\x00\x00", + ) + zhaws_device = await join_zigpy_device(server, zigpy_device) + + cluster: security.IasAce = zigpy_device.endpoints.get(1).ias_ace + client_device: Optional[DeviceProxy] = controller.devices.get(zhaws_device.ieee) + assert client_device is not None + alarm_entity: AlarmControlPanelEntity = client_device.device_model.entities.get( + (Platform.ALARM_CONTROL_PANEL, "00:0d:6f:00:0a:90:69:e7-1") + ) + assert alarm_entity is not None + assert isinstance(alarm_entity, AlarmControlPanelEntity) + + # test that the state is STATE_ALARM_DISARMED + assert alarm_entity.state.state == "disarmed" + + # arm_away + cluster.client_command.reset_mock() + await controller.alarm_control_panels.arm_away(alarm_entity, "4321") + assert cluster.client_command.call_count == 2 + assert cluster.client_command.await_count == 2 + assert cluster.client_command.call_args == call( + 4, + security.IasAce.PanelStatus.Armed_Away, + 0, + security.IasAce.AudibleNotification.Default_Sound, + security.IasAce.AlarmStatus.No_Alarm, + ) + assert alarm_entity.state.state == "armed_away" + + # disarm + await reset_alarm_panel(server, controller, cluster, alarm_entity) + + # trip alarm from faulty code entry. First we need to arm away + cluster.client_command.reset_mock() + await controller.alarm_control_panels.arm_away(alarm_entity, "4321") + await server.async_block_till_done() + assert alarm_entity.state.state == "armed_away" + cluster.client_command.reset_mock() + + # now simulate a faulty code entry sequence + await controller.alarm_control_panels.disarm(alarm_entity, "0000") + await controller.alarm_control_panels.disarm(alarm_entity, "0000") + await controller.alarm_control_panels.disarm(alarm_entity, "0000") + await server.async_block_till_done() + + assert alarm_entity.state.state == "triggered" + assert cluster.client_command.call_count == 6 + assert cluster.client_command.await_count == 6 + assert cluster.client_command.call_args == call( + 4, + security.IasAce.PanelStatus.In_Alarm, + 0, + security.IasAce.AudibleNotification.Default_Sound, + security.IasAce.AlarmStatus.Emergency, + ) + + # reset the panel + await reset_alarm_panel(server, controller, cluster, alarm_entity) + + # arm_home + await controller.alarm_control_panels.arm_home(alarm_entity, "4321") + await server.async_block_till_done() + assert alarm_entity.state.state == "armed_home" + assert cluster.client_command.call_count == 2 + assert cluster.client_command.await_count == 2 + assert cluster.client_command.call_args == call( + 4, + security.IasAce.PanelStatus.Armed_Stay, + 0, + security.IasAce.AudibleNotification.Default_Sound, + security.IasAce.AlarmStatus.No_Alarm, + ) + + # reset the panel + await reset_alarm_panel(server, controller, cluster, alarm_entity) + + # arm_night + await controller.alarm_control_panels.arm_night(alarm_entity, "4321") + await server.async_block_till_done() + assert alarm_entity.state.state == "armed_night" + assert cluster.client_command.call_count == 2 + assert cluster.client_command.await_count == 2 + assert cluster.client_command.call_args == call( + 4, + security.IasAce.PanelStatus.Armed_Night, + 0, + security.IasAce.AudibleNotification.Default_Sound, + security.IasAce.AlarmStatus.No_Alarm, + ) + + # reset the panel + await reset_alarm_panel(server, controller, cluster, alarm_entity) + + # arm from panel + cluster.listener_event( + "cluster_command", 1, 0, [security.IasAce.ArmMode.Arm_All_Zones, "", 0] + ) + await server.async_block_till_done() + assert alarm_entity.state.state == "armed_away" + + # reset the panel + await reset_alarm_panel(server, controller, cluster, alarm_entity) + + # arm day home only from panel + cluster.listener_event( + "cluster_command", 1, 0, [security.IasAce.ArmMode.Arm_Day_Home_Only, "", 0] + ) + await server.async_block_till_done() + assert alarm_entity.state.state == "armed_home" + + # reset the panel + await reset_alarm_panel(server, controller, cluster, alarm_entity) + + # arm night sleep only from panel + cluster.listener_event( + "cluster_command", 1, 0, [security.IasAce.ArmMode.Arm_Night_Sleep_Only, "", 0] + ) + await server.async_block_till_done() + assert alarm_entity.state.state == "armed_night" + + # disarm from panel with bad code + cluster.listener_event( + "cluster_command", 1, 0, [security.IasAce.ArmMode.Disarm, "", 0] + ) + await server.async_block_till_done() + assert alarm_entity.state.state == "armed_night" + + # disarm from panel with bad code for 2nd time trips alarm + cluster.listener_event( + "cluster_command", 1, 0, [security.IasAce.ArmMode.Disarm, "", 0] + ) + await server.async_block_till_done() + assert alarm_entity.state.state == "triggered" + + # disarm from panel with good code + cluster.listener_event( + "cluster_command", 1, 0, [security.IasAce.ArmMode.Disarm, "4321", 0] + ) + await server.async_block_till_done() + assert alarm_entity.state.state == "disarmed" + + # panic from panel + cluster.listener_event("cluster_command", 1, 4, []) + await server.async_block_till_done() + assert alarm_entity.state.state == "triggered" + + # reset the panel + await reset_alarm_panel(server, controller, cluster, alarm_entity) + + # fire from panel + cluster.listener_event("cluster_command", 1, 3, []) + await server.async_block_till_done() + assert alarm_entity.state.state == "triggered" + + # reset the panel + await reset_alarm_panel(server, controller, cluster, alarm_entity) + + # emergency from panel + cluster.listener_event("cluster_command", 1, 2, []) + await server.async_block_till_done() + assert alarm_entity.state.state == "triggered" + + # reset the panel + await reset_alarm_panel(server, controller, cluster, alarm_entity) + assert alarm_entity.state.state == "disarmed" + + await controller.alarm_control_panels.trigger(alarm_entity) + await server.async_block_till_done() + assert alarm_entity.state.state == "triggered" + + # reset the panel + await reset_alarm_panel(server, controller, cluster, alarm_entity) + assert alarm_entity.state.state == "disarmed" + + +async def reset_alarm_panel( + server: Server, + controller: Controller, + cluster: security.IasAce, + entity: AlarmControlPanelEntity, +) -> None: + """Reset the state of the alarm panel.""" + cluster.client_command.reset_mock() + await controller.alarm_control_panels.disarm(entity, "4321") + await server.async_block_till_done() + assert entity.state.state == "disarmed" + assert cluster.client_command.call_count == 2 + assert cluster.client_command.await_count == 2 + assert cluster.client_command.call_args == call( + 4, + security.IasAce.PanelStatus.Panel_Disarmed, + 0, + security.IasAce.AudibleNotification.Default_Sound, + security.IasAce.AlarmStatus.No_Alarm, + ) + cluster.client_command.reset_mock() diff --git a/zha/application/platforms/alarm_control_panel/__init__.py b/zha/application/platforms/alarm_control_panel/__init__.py index 0f68b9c5a..40846a0c7 100644 --- a/zha/application/platforms/alarm_control_panel/__init__.py +++ b/zha/application/platforms/alarm_control_panel/__init__.py @@ -126,27 +126,27 @@ def handle_cluster_handler_state_changed( """Handle state changed on cluster.""" self.maybe_emit_state_changed_event() - async def async_alarm_disarm(self, code: str | None = None) -> None: + async def async_alarm_disarm(self, code: str | None = None, **kwargs) -> None: """Send disarm command.""" self._cluster_handler.arm(IasAce.ArmMode.Disarm, code, 0) self.maybe_emit_state_changed_event() - async def async_alarm_arm_home(self, code: str | None = None) -> None: + async def async_alarm_arm_home(self, code: str | None = None, **kwargs) -> None: """Send arm home command.""" self._cluster_handler.arm(IasAce.ArmMode.Arm_Day_Home_Only, code, 0) self.maybe_emit_state_changed_event() - async def async_alarm_arm_away(self, code: str | None = None) -> None: + async def async_alarm_arm_away(self, code: str | None = None, **kwargs) -> None: """Send arm away command.""" self._cluster_handler.arm(IasAce.ArmMode.Arm_All_Zones, code, 0) self.maybe_emit_state_changed_event() - async def async_alarm_arm_night(self, code: str | None = None) -> None: + async def async_alarm_arm_night(self, code: str | None = None, **kwargs) -> None: """Send arm night command.""" self._cluster_handler.arm(IasAce.ArmMode.Arm_Night_Sleep_Only, code, 0) self.maybe_emit_state_changed_event() - async def async_alarm_trigger(self, code: str | None = None) -> None: # pylint: disable=unused-argument + async def async_alarm_trigger(self, code: str | None = None, **kwargs) -> None: # pylint: disable=unused-argument """Send alarm trigger command.""" self._cluster_handler.panic() self.maybe_emit_state_changed_event()