Skip to content

Commit

Permalink
Cleanup discovery & add tests (#212)
Browse files Browse the repository at this point in the history
* Cleanup discovery & add tests

* discovered_devices_raw is not anymore available, as that can be accessed directly from the device objects
* test most of the discovery code paths
* some minor cleanups to test handling
* update discovery docs

* Move category check to be after the definitions

* skip a couple of tests requiring asyncmock not available on py37

* Remove return_raw usage from cli.discover
  • Loading branch information
rytilahti committed Sep 24, 2021
1 parent bdb07a7 commit acb221b
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 72 deletions.
6 changes: 2 additions & 4 deletions kasa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,11 @@ async def discover(ctx, timeout, discover_only, dump_raw):
"""Discover devices in the network."""
target = ctx.parent.params["target"]
click.echo(f"Discovering devices on {target} for {timeout} seconds")
found_devs = await Discover.discover(
target=target, timeout=timeout, return_raw=dump_raw
)
found_devs = await Discover.discover(target=target, timeout=timeout)
if not discover_only:
for ip, dev in found_devs.items():
if dump_raw:
click.echo(dev)
click.echo(dev.sys_info)
continue
ctx.obj = dev
await ctx.invoke(state)
Expand Down
56 changes: 16 additions & 40 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import socket
from typing import Awaitable, Callable, Dict, Mapping, Optional, Type, Union, cast
from typing import Awaitable, Callable, Dict, Optional, Type, cast

from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb
Expand All @@ -17,6 +17,7 @@


OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
DeviceDict = Dict[str, SmartDevice]


class _DiscoverProtocol(asyncio.DatagramProtocol):
Expand All @@ -25,8 +26,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
This is internal class, use :func:`Discover.discover`: instead.
"""

discovered_devices: Dict[str, SmartDevice]
discovered_devices_raw: Dict[str, Dict]
discovered_devices: DeviceDict

def __init__(
self,
Expand All @@ -43,7 +43,6 @@ def __init__(
self.protocol = TPLinkSmartHomeProtocol()
self.target = (target, Discover.DISCOVERY_PORT)
self.discovered_devices = {}
self.discovered_devices_raw = {}

def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
Expand Down Expand Up @@ -80,13 +79,9 @@ def datagram_received(self, data, addr) -> None:
device.update_from_discover_info(info)

self.discovered_devices[ip] = device
self.discovered_devices_raw[ip] = info

if device_class is not None:
if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))
else:
_LOGGER.error("Received invalid response: %s", info)
if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))

def error_received(self, ex):
"""Handle asyncio.Protocol errors."""
Expand Down Expand Up @@ -144,27 +139,26 @@ async def discover(
on_discovered=None,
timeout=5,
discovery_packets=3,
return_raw=False,
interface=None,
) -> Mapping[str, Union[SmartDevice, Dict]]:
) -> DeviceDict:
"""Discover supported devices.
Sends discovery message to 255.255.255.255:9999 in order
to detect available supported devices in the local network,
and waits for given timeout for answers from devices.
If you have multiple interfaces, you can use target parameter to specify the network for discovery.
If given, `on_discovered` coroutine will get passed with the :class:`SmartDevice`-derived object as parameter.
If given, `on_discovered` coroutine will get awaited with a :class:`SmartDevice`-derived object as parameter.
The results of the discovery are returned either as a list of :class:`SmartDevice`-derived objects
or as raw response dictionaries objects (if `return_raw` is True).
The results of the discovery are returned as a dict of :class:`SmartDevice`-derived objects keyed with IP addresses.
The devices are already initialized and all but emeter-related properties can be accessed directly.
:param target: The target address where to send the broadcast discovery queries if multi-homing (e.g. 192.168.xxx.255).
:param on_discovered: coroutine to execute on discovery
:param timeout: How long to wait for responses, defaults to 5
:param discovery_packets: Number of discovery packets are broadcasted.
:param return_raw: True to return JSON objects instead of Devices.
:return:
:param discovery_packets: Number of discovery packets to broadcast
:param interface: Bind to specific interface
:return: dictionary with discovered devices
"""
loop = asyncio.get_event_loop()
transport, protocol = await loop.create_datagram_endpoint(
Expand All @@ -186,9 +180,6 @@ async def discover(

_LOGGER.debug("Discovered %s devices", len(protocol.discovered_devices))

if return_raw:
return protocol.discovered_devices_raw

return protocol.discovered_devices

@staticmethod
Expand All @@ -204,12 +195,10 @@ async def discover_single(host: str) -> SmartDevice:
info = await protocol.query(host, Discover.DISCOVERY_QUERY)

device_class = Discover._get_device_class(info)
if device_class is not None:
dev = device_class(host)
await dev.update()
return dev
dev = device_class(host)
await dev.update()

raise SmartDeviceException("Unable to discover device, received: %s" % info)
return dev

@staticmethod
def _get_device_class(info: dict) -> Type[SmartDevice]:
Expand Down Expand Up @@ -237,17 +226,4 @@ def _get_device_class(info: dict) -> Type[SmartDevice]:

return SmartBulb

raise SmartDeviceException("Unknown device type: %s", type_)


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
loop = asyncio.get_event_loop()

async def _on_device(dev):
await dev.update()
_LOGGER.info("Got device: %s", dev)

devices = loop.run_until_complete(Discover.discover(on_discovered=_on_device))
for ip, dev in devices.items():
print(f"[{ip}] {dev}")
raise SmartDeviceException("Unknown device type: %s" % type_)
67 changes: 39 additions & 28 deletions kasa/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ def filter_model(desc, filter):


def parametrize(desc, devices, ids=None):
# if ids is None:
# ids = ["on", "off"]
return pytest.mark.parametrize(
"dev", filter_model(desc, devices), indirect=True, ids=ids
)
Expand All @@ -63,32 +61,11 @@ def parametrize(desc, devices, ids=None):
has_emeter = parametrize("has emeter", WITH_EMETER)
no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER)


def name_for_filename(x):
from os.path import basename

return basename(x)


bulb = parametrize("bulbs", BULBS, ids=name_for_filename)
plug = parametrize("plugs", PLUGS, ids=name_for_filename)
strip = parametrize("strips", STRIPS, ids=name_for_filename)
dimmer = parametrize("dimmers", DIMMERS, ids=name_for_filename)
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=name_for_filename)

# This ensures that every single file inside fixtures/ is being placed in some category
categorized_fixtures = set(
dimmer.args[1] + strip.args[1] + plug.args[1] + bulb.args[1] + lightstrip.args[1]
)
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
if diff:
for file in diff:
print(
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
% file
)
raise Exception("Missing category for %s" % diff)

bulb = parametrize("bulbs", BULBS, ids=basename)
plug = parametrize("plugs", PLUGS, ids=basename)
strip = parametrize("strips", STRIPS, ids=basename)
dimmer = parametrize("dimmers", DIMMERS, ids=basename)
lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename)

# bulb types
dimmable = parametrize("dimmable", DIMMABLE)
Expand All @@ -98,6 +75,28 @@ def name_for_filename(x):
color_bulb = parametrize("color bulbs", COLOR_BULBS)
non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS)


def check_categories():
"""Check that every fixture file is categorized."""
categorized_fixtures = set(
dimmer.args[1]
+ strip.args[1]
+ plug.args[1]
+ bulb.args[1]
+ lightstrip.args[1]
)
diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures)
if diff:
for file in diff:
print(
"No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)"
% file
)
raise Exception("Missing category for %s" % diff)


check_categories()

# Parametrize tests to run with device both on and off
turn_on = pytest.mark.parametrize("turn_on", [True, False])

Expand Down Expand Up @@ -174,6 +173,18 @@ def dev(request):
return get_device_for_file(file)


@pytest.fixture(params=SUPPORTED_DEVICES, scope="session")
def discovery_data(request):
"""Return raw discovery file contents as JSON. Used for discovery tests."""
file = request.param
p = Path(file)
if not p.is_absolute():
p = Path(__file__).parent / "fixtures" / file

with open(p) as f:
return json.load(f)


def pytest_addoption(parser):
parser.addoption(
"--ip", action="store", default=None, help="run against device on given ip"
Expand Down
57 changes: 57 additions & 0 deletions kasa/tests/test_discovery.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# type: ignore
import sys

import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342

from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException
from kasa.discover import _DiscoverProtocol

from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip

Expand Down Expand Up @@ -47,3 +50,57 @@ async def test_type_unknown():
invalid_info = {"system": {"get_sysinfo": {"type": "nosuchtype"}}}
with pytest.raises(SmartDeviceException):
Discover._get_device_class(invalid_info)


@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
async def test_discover_single(discovery_data: dict, mocker):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)
x = await Discover.discover_single("127.0.0.1")
assert issubclass(x.__class__, SmartDevice)
assert x._sys_info is not None


INVALIDS = [
("No 'system' or 'get_sysinfo' in response", {"no": "data"}),
(
"Unable to find the device type field",
{"system": {"get_sysinfo": {"missing_type": 1}}},
),
("Unknown device type: foo", {"system": {"get_sysinfo": {"type": "foo"}}}),
]


@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock")
@pytest.mark.parametrize("msg, data", INVALIDS)
async def test_discover_invalid_info(msg, data, mocker):
"""Make sure that invalid discovery information raises an exception."""
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=data)
with pytest.raises(SmartDeviceException, match=msg):
await Discover.discover_single("127.0.0.1")


async def test_discover_send(mocker):
"""Test discovery parameters."""
proto = _DiscoverProtocol()
assert proto.discovery_packets == 3
assert proto.target == ("255.255.255.255", 9999)
sendto = mocker.patch.object(proto, "transport")
proto.do_discover()
assert sendto.sendto.call_count == proto.discovery_packets


async def test_discover_datagram_received(mocker, discovery_data):
"""Verify that datagram received fills discovered_devices."""
proto = _DiscoverProtocol()
mocker.patch("json.loads", return_value=discovery_data)
mocker.patch.object(proto, "protocol")

addr = "127.0.0.1"
proto.datagram_received("<placeholder data>", (addr, 1234))

# Check that device in discovered_devices is initialized correctly
assert len(proto.discovered_devices) == 1
dev = proto.discovered_devices[addr]
assert issubclass(dev.__class__, SmartDevice)
assert dev.host == addr

0 comments on commit acb221b

Please sign in to comment.