diff --git a/.coveragerc b/.coveragerc index 9763e2cf7..9586300fb 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,2 +1,5 @@ [run] source = zigpy + +omit = + zigpy/typing.py diff --git a/README.md b/README.md index e41654083..66ce8377e 100644 --- a/README.md +++ b/README.md @@ -54,13 +54,40 @@ Packages of tagged versions are also released via PyPI - https://pypi.org/project/zigpy-zigate/ - https://pypi.org/project/zigpy-cc/ -## How to contribute +## How to contribute + +You can contribute to this project either as a end-user, a tester (advanced user contributing constructive issue/bug-reports) or as a developer contibuting code. + +### How to contribute as an end-user + +If you think that you are having problems due to a bug then please see the section below on reporting issues as a tester, but be aware that reporting issues put higher responsibility on your active involment on your part as a tester. + +Some developers might be also interested in receiving donations in the form of money or hardware such as Zigbee modules and devices, and even if such donations are most often donated with no strings attached it could in many cases help the developers motivation and indirect improve the development of this project. + +Sometimes it might just be simpler to just donate money earmarked to specifically let an willing developer buy the exact same type Zigbee device that you are having issues with to be able to replicate the issue themselves in order to troubleshoot and hopefully also solve the problem. + +Consider submitting a post on GitHub projects issues tracker about willingness to making a donation (please see section bellow on posing issues). + +### How to report issues or bugs as a tester + +Issues or bugs are normally first be submitted upstream to the software/project that it utilizing zigpy and its radio libraries, (like for example Home Assistant), however if and when the issue is determened to be in the zigpy or underlying radio library then you should continue by submitting a detailed issue/bug report via the GitHub projects issues tracker. + +Always be sure to first check if there is not already an existing issue posted with the same description before posting a new issue. + +- https://help.github.com/en/github/managing-your-work-on-github/creating-an-issue + - https://guides.github.com/features/issues/ + +### How to contribute as a developer + +If you are looking to make a contribution as a developer to this project we suggest that you follow the steps in these guides: -If you are looking to make a contribution to this project we suggest that you follow the steps in these guides: - https://github.com/firstcontributions/first-contributions/blob/master/README.md -- https://github.com/firstcontributions/first-contributions/blob/master/github-desktop-tutorial.md + - https://github.com/firstcontributions/first-contributions/blob/master/github-desktop-tutorial.md + +Code changes or additions can then be submitted to this project on GitHub via pull requests: -Some developers might also be interested in receiving donations in the form of hardware such as Zigbee modules or devices, and even if such donations are most often donated with no strings attached it could in many cases help the developers motivation and indirect improve the development of this project. +- https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests + - https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request ## Developer references diff --git a/setup.py b/setup.py index ebc2235b4..32234b059 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,16 @@ """Setup module for zigpy""" +from os import path + from setuptools import find_packages, setup import zigpy +this_directory = path.join(path.abspath(path.dirname(__file__))) +with open(path.join(this_directory, "README.md"), encoding="utf-8") as f: + long_description = f.read() + setup( - name="zigpy-homeassistant", + name="zigpy", version=zigpy.__version__, description="Library implementing a ZigBee stack", url="http://github.com/zigpy/zigpy", diff --git a/tests/test_appdb.py b/tests/test_appdb.py index 4d498a266..b50c5f2fb 100644 --- a/tests/test_appdb.py +++ b/tests/test_appdb.py @@ -1,9 +1,10 @@ import os -from unittest import mock +from asynctest import CoroutineMock, mock import pytest from zigpy import profiles -from zigpy.application import ControllerApplication +import zigpy.application +from zigpy.config import CONF_DATABASE, ZIGPY_SCHEMA from zigpy.device import Device, Status import zigpy.ota from zigpy.quirks import CustomDevice @@ -13,9 +14,33 @@ from zigpy.zdo import types as zdo_t -def make_app(database_file): - with mock.patch("zigpy.ota.OTA", mock.MagicMock(spec_set=zigpy.ota.OTA)): - app = ControllerApplication(database_file) +async def make_app(database_file): + class App(zigpy.application.ControllerApplication): + async def shutdown(self): + pass + + async def startup(self, auto_form=False): + pass + + async def request( + self, + device, + profile, + cluster, + src_ep, + dst_ep, + sequence, + data, + expect_reply=True, + use_ieee=False, + ): + pass + + async def permit_ncp(self, time_s=60): + pass + + with mock.patch("zigpy.ota.OTA.initialize", CoroutineMock()): + app = await App.new(ZIGPY_SCHEMA({CONF_DATABASE: database_file})) return app @@ -45,12 +70,25 @@ def fake_get_device(device): return device +async def test_no_database(tmpdir): + with mock.patch("zigpy.appdb.PersistingListener") as db_mock: + db_mock.return_value.load.side_effect = CoroutineMock() + await make_app(None) + assert db_mock.return_value.load.call_count == 0 + + db = os.path.join(str(tmpdir), "test.db") + with mock.patch("zigpy.appdb.PersistingListener") as db_mock: + db_mock.return_value.load.side_effect = CoroutineMock() + await make_app(db) + assert db_mock.return_value.load.call_count == 1 + + async def test_database(tmpdir, monkeypatch): monkeypatch.setattr( Device, "schedule_initialize", mock_dev_init(Status.ENDPOINTS_INIT) ) db = os.path.join(str(tmpdir), "test.db") - app = make_app(db) + app = await make_app(db) ieee = make_ieee() relays_1 = [t.NWK(0x1234), t.NWK(0x2345)] relays_2 = [t.NWK(0x3456), t.NWK(0x4567)] @@ -94,7 +132,7 @@ async def test_database(tmpdir, monkeypatch): # Everything should've been saved - check that it re-loads with mock.patch("zigpy.quirks.get_device", fake_get_device): - app2 = make_app(db) + app2 = await make_app(db) dev = app2.get_device(ieee) assert dev.endpoints[1].device_type == profiles.zha.DeviceType.PUMP assert dev.endpoints[2].device_type == 0xFFFD @@ -113,7 +151,7 @@ async def test_database(tmpdir, monkeypatch): app.handle_leave(99, ieee) - app2 = make_app(db) + app2 = await make_app(db) assert ieee in app2.devices async def mockleave(*args, **kwargs): @@ -123,7 +161,7 @@ async def mockleave(*args, **kwargs): await app2.remove(ieee) assert ieee not in app2.devices - app3 = make_app(db) + app3 = await make_app(db) assert ieee not in app3.devices dev = app2.get_device(custom_ieee) assert dev.relays is None @@ -132,9 +170,9 @@ async def mockleave(*args, **kwargs): @mock.patch("zigpy.device.Device.schedule_group_membership_scan", mock.MagicMock()) -def _test_null_padded(tmpdir, test_manufacturer=None, test_model=None): +async def _test_null_padded(tmpdir, test_manufacturer=None, test_model=None): db = os.path.join(str(tmpdir), "test.db") - app = make_app(db) + app = await make_app(db) ieee = make_ieee() with mock.patch( "zigpy.device.Device.schedule_initialize", @@ -156,7 +194,7 @@ def _test_null_padded(tmpdir, test_manufacturer=None, test_model=None): clus.listener_event("zdo_command") # Everything should've been saved - check that it re-loads - app2 = make_app(db) + app2 = await make_app(db) dev = app2.get_device(ieee) assert dev.endpoints[3].device_type == profiles.zha.DeviceType.PUMP assert dev.endpoints[3].in_clusters[0]._attr_cache[4] == test_manufacturer @@ -167,10 +205,10 @@ def _test_null_padded(tmpdir, test_manufacturer=None, test_model=None): return dev -def test_appdb_load_null_padded_manuf(tmpdir): +async def test_appdb_load_null_padded_manuf(tmpdir): manufacturer = b"Mock Manufacturer\x00\x04\\\x00\\\x00\x00\x00\x00\x00\x07" model = b"Mock Model" - dev = _test_null_padded(tmpdir, manufacturer, model) + dev = await _test_null_padded(tmpdir, manufacturer, model) assert dev.manufacturer == "Mock Manufacturer" assert dev.model == "Mock Model" @@ -178,10 +216,10 @@ def test_appdb_load_null_padded_manuf(tmpdir): assert dev.endpoints[3].model == "Mock Model" -def test_appdb_load_null_padded_model(tmpdir): +async def test_appdb_load_null_padded_model(tmpdir): manufacturer = b"Mock Manufacturer" model = b"Mock Model\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - dev = _test_null_padded(tmpdir, manufacturer, model) + dev = await _test_null_padded(tmpdir, manufacturer, model) assert dev.manufacturer == "Mock Manufacturer" assert dev.model == "Mock Model" @@ -189,10 +227,10 @@ def test_appdb_load_null_padded_model(tmpdir): assert dev.endpoints[3].model == "Mock Model" -def test_appdb_load_null_padded_manuf_model(tmpdir): +async def test_appdb_load_null_padded_manuf_model(tmpdir): manufacturer = b"Mock Manufacturer\x00\x04\\\x00\\\x00\x00\x00\x00\x00\x07" model = b"Mock Model\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - dev = _test_null_padded(tmpdir, manufacturer, model) + dev = await _test_null_padded(tmpdir, manufacturer, model) assert dev.manufacturer == "Mock Manufacturer" assert dev.model == "Mock Model" @@ -200,10 +238,10 @@ def test_appdb_load_null_padded_manuf_model(tmpdir): assert dev.endpoints[3].model == "Mock Model" -def test_appdb_str_model(tmpdir): +async def test_appdb_str_model(tmpdir): manufacturer = "Mock Manufacturer" model = "Mock Model" - dev = _test_null_padded(tmpdir, manufacturer, model) + dev = await _test_null_padded(tmpdir, manufacturer, model) assert dev.manufacturer == "Mock Manufacturer" assert dev.model == "Mock Model" @@ -217,7 +255,7 @@ def test_appdb_str_model(tmpdir): ) async def test_node_descriptor_updated(tmpdir, status, success): db = os.path.join(str(tmpdir), "test_nd.db") - app = make_app(db) + app = await make_app(db) nd_ieee = make_ieee(2) with mock.patch.object(Device, "schedule_initialize", new=mock_dev_init(status)): app.handle_join(299, nd_ieee, 0) @@ -242,7 +280,7 @@ async def mock_get_node_descriptor(): assert dev.get_node_descriptor.call_count == 1 - app2 = make_app(db) + app2 = await make_app(db) if success: dev = app2.get_device(nd_ieee) assert dev.status == status @@ -267,7 +305,7 @@ async def mock_request(*args, **kwargs): monkeypatch.setattr(zigpy.zcl.Cluster, "request", mock_request) db = os.path.join(str(tmpdir), "test.db") - app = make_app(db) + app = await make_app(db) ieee = make_ieee() app.handle_join(99, ieee, 0) @@ -296,7 +334,7 @@ async def mock_request(*args, **kwargs): assert group_id in ep.member_of # Everything should've been saved - check that it re-loads - app2 = make_app(db) + app2 = await make_app(db) dev2 = app2.get_device(ieee) assert group_id in app2.groups group = app2.groups[group_id] @@ -310,7 +348,7 @@ async def mock_request(*args, **kwargs): # check member removal await dev_b.remove_from_group(group_id) - app3 = make_app(db) + app3 = await make_app(db) dev3 = app3.get_device(ieee) assert group_id in app3.groups group = app3.groups[group_id] @@ -324,14 +362,14 @@ async def mock_request(*args, **kwargs): # check group removal await dev3.remove_from_group(group_id) - app4 = make_app(db) + app4 = await make_app(db) dev4 = app4.get_device(ieee) assert group_id in app4.groups assert not app4.groups[group_id] assert group_id not in dev4.endpoints[1].member_of app4.groups.pop(group_id) - app5 = make_app(db) + app5 = await make_app(db) assert not app5.groups @@ -339,11 +377,11 @@ async def mock_request(*args, **kwargs): "status, success", ((Status.ENDPOINTS_INIT, True), (Status.ZDO_INIT, False), (Status.NEW, False)), ) -def test_attribute_update(tmpdir, status, success): +async def test_attribute_update(tmpdir, status, success): """Test attribute update for initialized and uninitialized devices.""" db = os.path.join(str(tmpdir), "test.db") - app = make_app(db) + app = await make_app(db) ieee = make_ieee() with mock.patch( "zigpy.device.Device.schedule_initialize", new=mock_dev_init(status) @@ -364,7 +402,7 @@ def test_attribute_update(tmpdir, status, success): app.device_initialized(dev) # Everything should've been saved - check that it re-loads - app2 = make_app(db) + app2 = await make_app(db) if success: dev = app2.get_device(ieee) assert dev.status == status diff --git a/tests/test_application.py b/tests/test_application.py index b8362da63..47637851f 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,10 +1,11 @@ import asyncio -from unittest import mock import asynctest +from asynctest import CoroutineMock, mock import pytest from zigpy import device -from zigpy.application import ControllerApplication +import zigpy.application +from zigpy.config import CONF_DATABASE, ZIGPY_SCHEMA from zigpy.exceptions import DeliveryError import zigpy.ota import zigpy.types as t @@ -14,7 +15,31 @@ @asynctest.patch("zigpy.ota.OTA", asynctest.MagicMock(spec_set=zigpy.ota.OTA)) @asynctest.patch("zigpy.device.Device._initialize", asynctest.CoroutineMock()) def app(): - return ControllerApplication() + class App(zigpy.application.ControllerApplication): + async def shutdown(self): + pass + + async def startup(self, auto_form=False): + pass + + async def request( + self, + device, + profile, + cluster, + src_ep, + dst_ep, + sequence, + data, + expect_reply=True, + use_ieee=False, + ): + pass + + async def permit_ncp(self, time_s=60): + pass + + return App({CONF_DATABASE: None}) @pytest.fixture @@ -22,9 +47,84 @@ def ieee(init=0): return t.EUI64(map(t.uint8_t, range(init, init + 8))) -async def test_startup(app): - with pytest.raises(NotImplementedError): - await app.startup() +async def test_startup(): + class App(zigpy.application.ControllerApplication): + async def shutdown(self): + pass + + async def request( + self, + device, + profile, + cluster, + src_ep, + dst_ep, + sequence, + data, + expect_reply=True, + use_ieee=False, + ): + pass + + async def permit_ncp(self, time_s=60): + pass + + with pytest.raises(TypeError): + await App({}).startup() + + +@mock.patch("zigpy.ota.OTA", spec_set=zigpy.ota.OTA) +async def test_new_exception(ota_mock): + class App(zigpy.application.ControllerApplication): + async def shutdown(self): + pass + + async def startup(self, auto_form=False): + pass + + async def request( + self, + device, + profile, + cluster, + src_ep, + dst_ep, + sequence, + data, + expect_reply=True, + use_ieee=False, + ): + pass + + async def permit_ncp(self, time_s=60): + pass + + p1 = mock.patch.object(App, "_load_db", CoroutineMock()) + p2 = mock.patch.object(App, "startup", CoroutineMock()) + p3 = mock.patch.object(App, "shutdown", CoroutineMock()) + ota_mock.return_value.initialize.side_effect = CoroutineMock() + + with p1 as db_mck, p2 as start_mck, p3 as shut_mck: + await App.new(ZIGPY_SCHEMA({CONF_DATABASE: "/dev/null"})) + assert db_mck.call_count == 1 + assert db_mck.await_count == 1 + assert ota_mock.return_value.initialize.call_count == 1 + assert start_mck.call_count == 1 + assert start_mck.await_count == 1 + assert shut_mck.call_count == 0 + assert shut_mck.await_count == 0 + + start_mck.side_effect = asyncio.TimeoutError + with p1 as db_mck, p2 as start_mck, p3 as shut_mck: + with pytest.raises(asyncio.TimeoutError): + await App.new(ZIGPY_SCHEMA({CONF_DATABASE: "/dev/null"})) + assert db_mck.call_count == 2 + assert db_mck.await_count == 2 + assert ota_mock.return_value.initialize.call_count == 2 + assert start_mck.call_count == 2 + assert start_mck.await_count == 2 + assert shut_mck.call_count == 1 + assert shut_mck.await_count == 1 async def test_form_network(app): @@ -37,14 +137,45 @@ async def test_force_remove(app): await app.force_remove(None) -async def test_request(app): - with pytest.raises(NotImplementedError): - await app.request(None, None, None, None, None, None, None) +async def test_request(): + class App(zigpy.application.ControllerApplication): + async def shutdown(self): + pass + async def startup(self, auto_form=False): + pass -async def test_permit_ncp(app): - with pytest.raises(NotImplementedError): - await app.permit_ncp() + async def permit_ncp(self, time_s=60): + pass + + with pytest.raises(TypeError): + await App({}).request(None, None, None, None, None, None, None) + + +async def test_permit_ncp(): + class App(zigpy.application.ControllerApplication): + async def shutdown(self): + pass + + async def startup(self, auto_form=False): + pass + + async def request( + self, + device, + profile, + cluster, + src_ep, + dst_ep, + sequence, + data, + expect_reply=True, + use_ieee=False, + ): + pass + + with pytest.raises(TypeError): + await App({}).permit_ncp() async def test_permit(app, ieee): @@ -233,8 +364,30 @@ async def test_broadcast(app): ) -async def test_shutdown(app): - await app.shutdown() +async def test_shutdown(): + class App(zigpy.application.ControllerApplication): + async def startup(self, auto_form=False): + pass + + async def request( + self, + device, + profile, + cluster, + src_ep, + dst_ep, + sequence, + data, + expect_reply=True, + use_ieee=False, + ): + pass + + async def permit_ncp(self, time_s=60): + pass + + with pytest.raises(TypeError): + await App({}).shutdown() def test_get_dst_address(app): diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 000000000..1af339b5e --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,42 @@ +"""Test configuration.""" + +import pytest +import voluptuous as vol +import zigpy.config + + +@pytest.mark.parametrize( + "value, result", + [ + (False, False), + (True, True), + ("1", True), + ("yes", True), + ("YeS", True), + ("on", True), + ("oN", True), + ("enable", True), + ("enablE", True), + (0, False), + ("no", False), + ("nO", False), + ("off", False), + ("ofF", False), + ("disable", False), + ("disablE", False), + ], +) +def test_config_validation_bool(value, result): + """Test boolean config validation.""" + assert zigpy.config.cv_boolean(value) is result + + schema = vol.Schema({vol.Required("value"): zigpy.config.cv_boolean}) + validated = schema({"value": value}) + assert validated["value"] is result + + +@pytest.mark.parametrize("value", ["invalid", "not a bool", "something"]) +def test_config_validation_bool_invalid(value): + """Test boolean config validation.""" + with pytest.raises(vol.Invalid): + zigpy.config.cv_boolean(value) diff --git a/tests/test_ota.py b/tests/test_ota.py index 83d38f4f3..2bf6c0d79 100644 --- a/tests/test_ota.py +++ b/tests/test_ota.py @@ -46,20 +46,11 @@ def ota(): async def test_ota_initialize(ota): ota.async_event = CoroutineMock() - await ota._initialize(mock.sentinel.ota_dir) + await ota.initialize() assert ota.async_event.call_count == 1 assert ota.async_event.call_args[0][0] == "initialize_provider" - assert ota.async_event.call_args[0][1] == mock.sentinel.ota_dir - - -async def test_initialize(ota): - ota._initialize = CoroutineMock() - - assert ota.not_initialized - ota.initialize(mock.sentinel.ota_dir) - assert not ota.not_initialized - assert ota._initialize.call_count == 1 + assert ota.not_initialized is False async def test_get_image_empty(ota, image, key): diff --git a/tests/test_ota_provider.py b/tests/test_ota_provider.py index f82255d6b..3c4f0cfb9 100644 --- a/tests/test_ota_provider.py +++ b/tests/test_ota_provider.py @@ -5,6 +5,7 @@ from asynctest import CoroutineMock, patch import pytest +from zigpy.config import CONF_OTA_DIR, CONF_OTA_IKEA, CONF_OTA_LEDVANCE import zigpy.ota import zigpy.ota.image import zigpy.ota.provider as ota_p @@ -79,7 +80,14 @@ def ikea_image(ikea_image_with_version): @pytest.fixture def basic_prov(): - p = ota_p.Basic() + class Prov(ota_p.Basic): + async def initialize_provider(self, ota_config): + return None + + async def refresh_firmware_list(self): + return None + + p = Prov() p.enable() return p @@ -105,11 +113,6 @@ async def test_initialize_provider(basic_prov): await basic_prov.initialize_provider(mock.sentinel.ota_dir) -async def test_basic_refresh_firmware_list(basic_prov): - with pytest.raises(NotImplementedError): - await basic_prov.refresh_firmware_list() - - async def test_basic_get_image(basic_prov, key): image = mock.MagicMock() image.fetch_image = CoroutineMock(return_value=mock.sentinel.image) @@ -146,17 +149,15 @@ async def test_basic_get_image(basic_prov, key): assert image.fetch_image.call_count == 1 -def test_basic_enable_provider(key): - basic_prov = ota_p.Basic() +def test_basic_enable_provider(key, basic_prov): + assert basic_prov.is_enabled is True + basic_prov.disable() assert basic_prov.is_enabled is False basic_prov.enable() assert basic_prov.is_enabled is True - basic_prov.disable() - assert basic_prov.is_enabled is False - async def test_basic_get_image_filtered(basic_prov, key): image = mock.MagicMock() @@ -175,29 +176,11 @@ async def test_basic_get_image_filtered(basic_prov, key): assert image.fetch_image.call_count == 0 -async def test_ikea_init_no_ota_dir(ikea_prov): - ikea_prov.enable = mock.MagicMock() - ikea_prov.refresh_firmware_list = CoroutineMock() - - r = await ikea_prov.initialize_provider(None) - assert r is None - assert ikea_prov.enable.call_count == 0 - assert ikea_prov.refresh_firmware_list.call_count == 0 - - async def test_ikea_init_ota_dir(ikea_prov, tmpdir): ikea_prov.enable = mock.MagicMock() ikea_prov.refresh_firmware_list = CoroutineMock() - r = await ikea_prov.initialize_provider(str(tmpdir)) - assert r is None - assert ikea_prov.enable.call_count == 0 - assert ikea_prov.refresh_firmware_list.call_count == 0 - - # create flag - with open(os.path.join(str(tmpdir), ota_p.ENABLE_IKEA_OTA), mode="w+"): - pass - r = await ikea_prov.initialize_provider(str(tmpdir)) + r = await ikea_prov.initialize_provider({CONF_OTA_IKEA: True}) assert r is None assert ikea_prov.enable.call_count == 1 assert ikea_prov.refresh_firmware_list.call_count == 1 @@ -443,7 +426,7 @@ async def test_filestore_init_provider_success(file_prov): file_prov.refresh_firmware_list = CoroutineMock() file_prov.validate_ota_dir = mock.MagicMock(return_value=mock.sentinel.ota_dir) - r = await file_prov.initialize_provider(mock.sentinel.ota_dir) + r = await file_prov.initialize_provider({CONF_OTA_DIR: mock.sentinel.ota_dir}) assert r is None assert file_prov.validate_ota_dir.call_count == 1 assert file_prov.validate_ota_dir.call_args[0][0] == mock.sentinel.ota_dir @@ -456,7 +439,7 @@ async def test_filestore_init_provider_failure(file_prov): file_prov.refresh_firmware_list = CoroutineMock() file_prov.validate_ota_dir = mock.MagicMock(return_value=None) - r = await file_prov.initialize_provider(mock.sentinel.ota_dir) + r = await file_prov.initialize_provider({CONF_OTA_DIR: mock.sentinel.ota_dir}) assert r is None assert file_prov.validate_ota_dir.call_count == 1 assert file_prov.validate_ota_dir.call_args[0][0] == mock.sentinel.ota_dir @@ -600,29 +583,11 @@ def ledvance_key(): return zigpy.ota.image.ImageKey(LEDVANCE_ID, LEDVANCE_IMAGE_TYPE) -async def test_ledvance_init_no_ota_dir(ledvance_prov): - ledvance_prov.enable = mock.MagicMock() - ledvance_prov.refresh_firmware_list = CoroutineMock() - - r = await ledvance_prov.initialize_provider(None) - assert r is None - assert ledvance_prov.enable.call_count == 0 - assert ledvance_prov.refresh_firmware_list.call_count == 0 - - -async def test_ledvance_init_ota_dir(ledvance_prov, tmpdir): +async def test_ledvance_init_ota_dir(ledvance_prov): ledvance_prov.enable = mock.MagicMock() ledvance_prov.refresh_firmware_list = CoroutineMock() - r = await ledvance_prov.initialize_provider(str(tmpdir)) - assert r is None - assert ledvance_prov.enable.call_count == 0 - assert ledvance_prov.refresh_firmware_list.call_count == 0 - - # create flag - with open(os.path.join(str(tmpdir), ota_p.ENABLE_LEDVANCE_OTA), mode="w+"): - pass - r = await ledvance_prov.initialize_provider(str(tmpdir)) + r = await ledvance_prov.initialize_provider({CONF_OTA_LEDVANCE: True}) assert r is None assert ledvance_prov.enable.call_count == 1 assert ledvance_prov.refresh_firmware_list.call_count == 1 diff --git a/tests/test_zdo_types.py b/tests/test_zdo_types.py index de6e6e958..ff27f4137 100644 --- a/tests/test_zdo_types.py +++ b/tests/test_zdo_types.py @@ -154,9 +154,11 @@ def test_size_prefixed_simple_descriptor(): ser = sd.serialize() assert ser[0] == len(ser) - 1 - sd2, data = types.SizePrefixedSimpleDescriptor.deserialize(ser) + sd2, data = types.SizePrefixedSimpleDescriptor.deserialize(ser + b"extra") assert sd.input_clusters == sd2.input_clusters assert sd.output_clusters == sd2.output_clusters + assert isinstance(sd2, types.SizePrefixedSimpleDescriptor) + assert data == b"extra" def test_empty_size_prefixed_simple_descriptor(): @@ -164,6 +166,11 @@ def test_empty_size_prefixed_simple_descriptor(): assert r == (None, b"") +def test_invalid_size_prefixed_simple_descriptor(): + with pytest.raises(ValueError): + types.SizePrefixedSimpleDescriptor.deserialize(b"\x01") + + def test_status_undef(): data = b"\xff" extra = b"extra" diff --git a/tests/test_zigbee_util.py b/tests/test_zigbee_util.py index 46051f332..9367855b7 100644 --- a/tests/test_zigbee_util.py +++ b/tests/test_zigbee_util.py @@ -1,5 +1,6 @@ import asyncio import logging +import sys from asynctest import CoroutineMock, mock import pytest @@ -65,6 +66,45 @@ def test_log(): log.error("Test error") +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="logging stacklevel kwarg was introduced in 3.8" +) +def test_log_stacklevel(): + class MockHandler(logging.Handler): + emit = mock.Mock() + + handler = MockHandler() + + LOGGER = logging.getLogger("test_log_stacklevel") + LOGGER.setLevel(logging.DEBUG) + LOGGER.addHandler(handler) + + class TestClass(util.LocalLogMixin): + def log(self, lvl, msg, *args, **kwargs): + LOGGER.log(lvl, msg, *args, **kwargs) + + def test_method(self): + self.info("Test1") + LOGGER.info("Test2") + + TestClass().test_method() + + assert handler.emit.call_count == 2 + + indirect_call, direct_call = handler.emit.mock_calls + (indirect,) = indirect_call[1] + (direct,) = direct_call[1] + + assert indirect.message == "Test1" + assert direct.message == "Test2" + assert direct.levelname == indirect.levelname + + assert direct.module == indirect.module + assert direct.filename == indirect.filename + assert direct.funcName == indirect.funcName + assert direct.lineno == indirect.lineno + 1 + + async def _test_retry(exception, retry_exceptions, n): counter = 0 diff --git a/zigpy/__init__.py b/zigpy/__init__.py index bbffbdf61..2f5f4894c 100644 --- a/zigpy/__init__.py +++ b/zigpy/__init__.py @@ -1,6 +1,6 @@ # coding: utf-8 MAJOR_VERSION = 0 -MINOR_VERSION = 19 +MINOR_VERSION = 20 PATCH_VERSION = "0" __short_version__ = "{}.{}".format(MAJOR_VERSION, MINOR_VERSION) __version__ = "{}.{}".format(__short_version__, PATCH_VERSION) diff --git a/zigpy/appdb.py b/zigpy/appdb.py index 4ac134399..3c91787ae 100644 --- a/zigpy/appdb.py +++ b/zigpy/appdb.py @@ -314,14 +314,14 @@ def _scan(self, table, filter=None): return self.execute("SELECT * FROM %s" % (table,)) return self.execute("SELECT * FROM %s WHERE %s" % (table, filter)) - def load(self): + async def load(self) -> None: LOGGER.debug("Loading application state from %s", self._database_file) - self._load_devices() - self._load_node_descriptors() - self._load_endpoints() - self._load_clusters() + await self._load_devices() + await self._load_node_descriptors() + await self._load_endpoints() + await self._load_clusters() - def _load_attributes(filter=None): + async def _load_attributes(filter: str = None) -> None: for (ieee, endpoint_id, cluster, attrid, value) in self._scan( "attributes", filter ): @@ -345,28 +345,28 @@ def _load_attributes(filter=None): else: dev.model = value - _load_attributes("attrid=4 OR attrid=5") + await _load_attributes("attrid=4 OR attrid=5") for device in self._application.devices.values(): device = zigpy.quirks.get_device(device) self._application.devices[device.ieee] = device - _load_attributes() - self._load_groups() - self._load_group_members() - self._load_relays() + await _load_attributes() + await self._load_groups() + await self._load_group_members() + await self._load_relays() - def _load_devices(self): + async def _load_devices(self): for (ieee, nwk, status) in self._scan("devices"): dev = self._application.add_device(ieee, nwk) dev.status = zigpy.device.Status(status) - def _load_node_descriptors(self): + async def _load_node_descriptors(self): for (ieee, value) in self._scan("node_descriptors"): dev = self._application.get_device(ieee) dev.node_desc = zdo_t.NodeDescriptor.deserialize(value)[0] - def _load_endpoints(self): + async def _load_endpoints(self): for (ieee, epid, profile_id, device_type, status) in self._scan("endpoints"): dev = self._application.get_device(ieee) ep = dev.add_endpoint(epid) @@ -378,7 +378,7 @@ def _load_endpoints(self): ep.device_type = zigpy.profiles.zll.DeviceType(device_type) ep.status = zigpy.endpoint.Status(status) - def _load_clusters(self): + async def _load_clusters(self): for (ieee, endpoint_id, cluster) in self._scan("clusters"): dev = self._application.get_device(ieee) ep = dev.endpoints[endpoint_id] @@ -389,18 +389,18 @@ def _load_clusters(self): ep = dev.endpoints[endpoint_id] ep.add_output_cluster(cluster) - def _load_groups(self): + async def _load_groups(self): for (group_id, name) in self._scan("groups"): self._application.groups.add_group(group_id, name, suppress_event=True) - def _load_group_members(self): + async def _load_group_members(self): for (group_id, ieee, ep_id) in self._scan("group_members"): group = self._application.groups[group_id] group.add_member( self._application.get_device(ieee).endpoints[ep_id], suppress_event=True ) - def _load_relays(self): + async def _load_relays(self): for (ieee, value) in self._scan("relays"): dev = self._application.get_device(ieee) dev.relays = t.Relays.deserialize(value)[0] diff --git a/zigpy/application.py b/zigpy/application.py index af7d6c1bb..9f0258a80 100644 --- a/zigpy/application.py +++ b/zigpy/application.py @@ -1,10 +1,10 @@ +import abc import asyncio import logging -import os.path from typing import Dict, Optional -import voluptuous as vol import zigpy.appdb +import zigpy.config import zigpy.device import zigpy.group import zigpy.ota @@ -15,49 +15,65 @@ import zigpy.zdo import zigpy.zdo.types as zdo_types -CONFIG_SCHEMA = vol.Schema({}, extra=vol.ALLOW_EXTRA) DEFAULT_ENDPOINT_ID = 1 LOGGER = logging.getLogger(__name__) -OTA_DIR = "zigpy_ota/" -class ControllerApplication(zigpy.util.ListenableMixin): - def __init__(self, database_file=None, config={}): +class ControllerApplication(zigpy.util.ListenableMixin, abc.ABC): + def __init__(self, config: Dict): self._send_sequence = 0 self.devices: Dict[t.EUI64, zigpy.device.Device] = {} - self._groups = zigpy.group.Groups(self) self._listeners = {} - self._config = CONFIG_SCHEMA(config) self._channel = None self._channels = None + self._config = config + self._dblistener = None self._ext_pan_id = None + self._groups = zigpy.group.Groups(self) self._ieee = None + self._listeners = {} self._nwk = None self._nwk_update_id = None - self._pan_id = None - self._ota = zigpy.ota.OTA(self) - if database_file is None: - ota_dir = None - else: - ota_dir = os.path.dirname(database_file) - ota_dir = os.path.join(ota_dir, OTA_DIR) - self.ota.initialize(ota_dir) + self._pan_id = None + self._send_sequence = 0 - self._dblistener = None - if database_file is not None: - self._dblistener = zigpy.appdb.PersistingListener(database_file, self) - self.add_listener(self._dblistener) - self.groups.add_listener(self._dblistener) - self._dblistener.load() + async def _load_db(self) -> None: + """Restore save state.""" + database_file = self.config[zigpy.config.CONF_DATABASE] + if not database_file: + return + self._dblistener = zigpy.appdb.PersistingListener(database_file, self) + self.add_listener(self._dblistener) + self.groups.add_listener(self._dblistener) + await self._dblistener.load() + + @classmethod + async def new( + cls, config: Dict, auto_form: bool = False, start_radio: bool = True + ) -> "ControllerApplication": + """Create new instance of application controller.""" + app = cls(config) + await app._load_db() + await app.ota.initialize() + if start_radio: + try: + await app.startup(auto_form) + except Exception: + LOGGER.error("Couldn't start application") + await app.shutdown() + raise + + return app + + @abc.abstractmethod async def shutdown(self): """Perform a complete application shutdown.""" - pass + @abc.abstractmethod async def startup(self, auto_form=False): """Perform a complete application startup""" - raise NotImplementedError async def form_network(self, channel=15, pan_id=None, extended_pan_id=None): """Form a new network""" @@ -206,6 +222,7 @@ async def mrequest( """ raise NotImplementedError + @abc.abstractmethod @zigpy.util.retryable_request async def request( self, @@ -233,7 +250,6 @@ async def request( :returns: return a tuple of a status and an error_message. Original requestor has more context to provide a more meaningful error message """ - raise NotImplementedError async def broadcast( self, @@ -264,9 +280,9 @@ async def broadcast( """ raise NotImplementedError + @abc.abstractmethod async def permit_ncp(self, time_s=60): """Permit joining on NCP.""" - raise NotImplementedError async def permit(self, time_s=60, node=None): """Permit joining on a specific node or all router nodes.""" diff --git a/zigpy/config.py b/zigpy/config.py new file mode 100644 index 000000000..1352dac68 --- /dev/null +++ b/zigpy/config.py @@ -0,0 +1,47 @@ +"""Config schemas and validation.""" +from typing import Union + +import voluptuous as vol + +CONF_DATABASE = "database_path" +CONF_DEVICE = "device" +CONF_DEVICE_PATH = "path" +CONF_OTA = "ota" +CONF_OTA_DIR = "otau_directory" +CONF_OTA_IKEA = "ikea_provider" +CONF_OTA_LEDVANCE = "ledvance_provider" + + +def cv_boolean(value: Union[bool, int, str]) -> bool: + """Validate and coerce a boolean value.""" + if isinstance(value, bool): + return value + if isinstance(value, str): + value = value.lower().strip() + if value in ("1", "true", "yes", "on", "enable"): + return True + if value in ("0", "false", "no", "off", "disable"): + return False + elif isinstance(value, int): + return bool(value) + raise vol.Invalid("invalid boolean value {}".format(value)) + + +SCHEMA_DEVICE = vol.Schema({vol.Required(CONF_DEVICE_PATH): str}) +SCHEMA_OTA = { + vol.Optional(CONF_OTA_IKEA, default=False): cv_boolean, + vol.Optional(CONF_OTA_LEDVANCE, default=False): cv_boolean, + vol.Optional(CONF_OTA_DIR, default=None): vol.Any(None, str), +} + +ZIGPY_SCHEMA = vol.Schema( + { + vol.Optional(CONF_DATABASE, default=None): vol.Any(None, str), + vol.Optional(CONF_OTA, default={}): SCHEMA_OTA, + }, + extra=vol.ALLOW_EXTRA, +) + +CONFIG_SCHEMA = ZIGPY_SCHEMA.extend( + {vol.Required(CONF_DEVICE): SCHEMA_DEVICE}, extra=vol.ALLOW_EXTRA +) diff --git a/zigpy/ota/__init__.py b/zigpy/ota/__init__.py index 78237bc7f..04bbeced2 100644 --- a/zigpy/ota/__init__.py +++ b/zigpy/ota/__init__.py @@ -1,12 +1,13 @@ """OTA support for Zigbee devices.""" -import asyncio import datetime import logging from typing import Optional import attr +from zigpy.config import CONF_OTA, CONF_OTA_DIR, CONF_OTA_IKEA, CONF_OTA_LEDVANCE from zigpy.ota.image import ImageKey, OTAImage import zigpy.ota.provider +from zigpy.typing import ControllerApplicationType import zigpy.util LOGGER = logging.getLogger(__name__) @@ -44,22 +45,22 @@ def get_image_block(self, *args, **kwargs) -> bytes: class OTA(zigpy.util.ListenableMixin): """OTA Manager.""" - def __init__(self, app, *args, **kwargs): - self._app = app + def __init__(self, app: ControllerApplicationType, *args, **kwargs): + self._app: ControllerApplicationType = app self._image_cache = {} self._not_initialized = True self._listeners = {} - self.add_listener(zigpy.ota.provider.Trådfri()) - self.add_listener(zigpy.ota.provider.FileStore()) - self.add_listener(zigpy.ota.provider.Ledvance()) - - async def _initialize(self, ota_dir: str) -> None: - LOGGER.debug("Initialize OTA providers") - await self.async_event("initialize_provider", ota_dir) - - def initialize(self, ota_dir: str) -> None: + ota_config = app.config[CONF_OTA] + if ota_config[CONF_OTA_IKEA]: + self.add_listener(zigpy.ota.provider.Trådfri()) + if ota_config[CONF_OTA_DIR]: + self.add_listener(zigpy.ota.provider.FileStore()) + if ota_config[CONF_OTA_LEDVANCE]: + self.add_listener(zigpy.ota.provider.Ledvance()) + + async def initialize(self) -> None: + await self.async_event("initialize_provider", self._app.config[CONF_OTA]) self._not_initialized = False - asyncio.ensure_future(self._initialize(ota_dir)) async def get_ota_image(self, manufacturer_id, image_type) -> Optional[OTAImage]: key = ImageKey(manufacturer_id, image_type) diff --git a/zigpy/ota/provider.py b/zigpy/ota/provider.py index 6274fc029..e3fab3026 100644 --- a/zigpy/ota/provider.py +++ b/zigpy/ota/provider.py @@ -1,14 +1,16 @@ """OTA Firmware providers.""" +from abc import ABC, abstractmethod import asyncio from collections import defaultdict import datetime import logging import os import os.path -from typing import Optional +from typing import Dict, Optional import aiohttp import attr +from zigpy.config import CONF_OTA_DIR from zigpy.ota.image import ImageKey, OTAImage, OTAImageHeader import zigpy.util @@ -20,7 +22,7 @@ SKIP_OTA_FILES = (ENABLE_IKEA_OTA, ENABLE_LEDVANCE_OTA) -class Basic(zigpy.util.LocalLogMixin): +class Basic(zigpy.util.LocalLogMixin, ABC): """Skeleton OTA Firmware provider.""" REFRESH = datetime.timedelta(hours=12) @@ -31,12 +33,13 @@ def __init__(self): self._locks = defaultdict(asyncio.Semaphore) self._last_refresh = None - async def initialize_provider(self, ota_dir: str) -> None: - pass + @abstractmethod + async def initialize_provider(self, ota_config: Dict) -> None: + """Initialize OTA provider.""" + @abstractmethod async def refresh_firmware_list(self) -> None: """Loads list of firmware into memory.""" - raise NotImplementedError async def filter_get_image(self, key: ImageKey) -> bool: """Filter unwanted get_image lookups.""" @@ -139,14 +142,10 @@ class Trådfri(Basic): MANUFACTURER_ID = 4476 HEADERS = {"accept": "application/json;q=0.9,*/*;q=0.8"} - async def initialize_provider(self, ota_dir: str) -> None: - if ota_dir is None: - return - - if os.path.isfile(os.path.join(ota_dir, ENABLE_IKEA_OTA)): - self.info("OTA provider enabled") - await self.refresh_firmware_list() - self.enable() + async def initialize_provider(self, ota_config: Dict) -> None: + self.info("OTA provider enabled") + await self.refresh_firmware_list() + self.enable() async def refresh_firmware_list(self) -> None: if self._locks[LOCK_REFRESH].locked(): @@ -238,14 +237,10 @@ class Ledvance(Basic): UPDATE_URL = "https://api.update.ledvance.com/v1/zigbee/firmwares" HEADERS = {"accept": "application/json"} - async def initialize_provider(self, ota_dir: str) -> None: - if ota_dir is None: - return - - if os.path.isfile(os.path.join(ota_dir, ENABLE_LEDVANCE_OTA)): - self.info("OTA provider enabled") - await self.refresh_firmware_list() - self.enable() + async def initialize_provider(self, ota_config: Dict) -> None: + self.info("OTA provider enabled") + await self.refresh_firmware_list() + self.enable() async def refresh_firmware_list(self) -> None: if self._locks[LOCK_REFRESH].locked(): @@ -323,9 +318,9 @@ def _fetch_image(self) -> Optional[OTAImage]: class FileStore(Basic): - def __init__(self, ota_dir=None): + def __init__(self): super().__init__() - self._ota_dir = self.validate_ota_dir(ota_dir) + self._ota_dir = None @staticmethod def validate_ota_dir(ota_dir: str) -> str: @@ -341,9 +336,9 @@ def validate_ota_dir(ota_dir: str) -> str: LOGGER.debug("OTA image directory '%s' does not exist", ota_dir) return None - async def initialize_provider(self, ota_dir: str) -> None: - if self._ota_dir is None: - self._ota_dir = self.validate_ota_dir(ota_dir) + async def initialize_provider(self, ota_config: Dict) -> None: + ota_dir = ota_config[CONF_OTA_DIR] + self._ota_dir = self.validate_ota_dir(ota_dir) if self._ota_dir is not None: self.enable() diff --git a/zigpy/typing.py b/zigpy/typing.py new file mode 100644 index 000000000..98cf5e703 --- /dev/null +++ b/zigpy/typing.py @@ -0,0 +1,23 @@ +"""Typing helpers for Zigpy.""" + +from typing import TYPE_CHECKING + +import zigpy.application +import zigpy.device +import zigpy.endpoint +import zigpy.zcl +import zigpy.zdo + +# pylint: disable=invalid-name +ClusterType = "Cluster" +ControllerApplicationType = "ControllerApplication" +DeviceType = "Device" +EndpointType = "Endpoint" +ZDOType = "ZDO" + +if TYPE_CHECKING: + ClusterType = zigpy.zcl.Cluster + ControllerApplicationType = zigpy.application.ControllerApplication + DeviceType = zigpy.device.Device + EndpointType = zigpy.endpoint.Endpoint + ZDOType = zigpy.zdo.ZDO diff --git a/zigpy/util.py b/zigpy/util.py index aee087de1..8bf11bdcb 100644 --- a/zigpy/util.py +++ b/zigpy/util.py @@ -3,6 +3,7 @@ import functools import inspect import logging +import sys import traceback from typing import Any, Coroutine, Optional, Tuple, Type, Union @@ -72,20 +73,27 @@ class LocalLogMixin: def log(self, lvl: int, msg: str, *args, **kwargs): # pragma: no cover pass + def _log(self, lvl: int, msg: str, *args, **kwargs): + if sys.version_info >= (3, 8): + # We have to exclude log, _log, and info + return self.log(lvl, msg, *args, stacklevel=4, **kwargs) + + return self.log(lvl, msg, *args, **kwargs) + def exception(self, msg, *args, **kwargs): - return self.log(logging.ERROR, msg, *args, **kwargs) + return self._log(logging.ERROR, msg, *args, **kwargs) def debug(self, msg, *args, **kwargs): - return self.log(logging.DEBUG, msg, *args, **kwargs) + return self._log(logging.DEBUG, msg, *args, **kwargs) def info(self, msg, *args, **kwargs): - return self.log(logging.INFO, msg, *args, **kwargs) + return self._log(logging.INFO, msg, *args, **kwargs) def warning(self, msg, *args, **kwargs): - return self.log(logging.WARNING, msg, *args, **kwargs) + return self._log(logging.WARNING, msg, *args, **kwargs) def error(self, msg, *args, **kwargs): - return self.log(logging.ERROR, msg, *args, **kwargs) + return self._log(logging.ERROR, msg, *args, **kwargs) async def retry(func, retry_exceptions, tries=3, delay=0.1): diff --git a/zigpy/zdo/types.py b/zigpy/zdo/types.py index fec5c0c4c..0b657d68a 100644 --- a/zigpy/zdo/types.py +++ b/zigpy/zdo/types.py @@ -30,7 +30,7 @@ def serialize(self): def deserialize(cls, data): if not data or data[0] == 0: return None, data[1:] - return SimpleDescriptor.deserialize(data[1:]) + return super().deserialize(data[1:]) class LogicalType(t.enum8):