Skip to content

Commit

Permalink
Make DiscoveryProtocol part of the public API
Browse files Browse the repository at this point in the history
This makes it easier to perform different type of discovery strategies by the library users, e.g., to send out discoveries to multiple targets using a single discovery protocol instance.
  • Loading branch information
rytilahti committed Oct 5, 2023
1 parent 84a501b commit 46199e3
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 29 deletions.
3 changes: 2 additions & 1 deletion kasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from importlib.metadata import version

from kasa.credentials import Credentials
from kasa.discover import Discover
from kasa.discover import Discover, DiscoveryProtocol
from kasa.emeterstatus import EmeterStatus
from kasa.exceptions import (
AuthenticationException,
Expand All @@ -34,6 +34,7 @@

__all__ = [
"Discover",
"DiscoveryProtocol",
"TPLinkSmartHomeProtocol",
"SmartBulb",
"SmartBulbPreset",
Expand Down
29 changes: 13 additions & 16 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
DeviceDict = Dict[str, SmartDevice]


class _DiscoverProtocol(asyncio.DatagramProtocol):
class DiscoveryProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler.
This is internal class, use :func:`Discover.discover`: instead.
Expand All @@ -40,7 +40,6 @@ def __init__(
self,
*,
on_discovered: Optional[OnDiscoveredCallable] = None,
target: str = "255.255.255.255",
discovery_packets: int = 3,
interface: Optional[str] = None,
on_unsupported: Optional[Callable[[Dict], Awaitable[None]]] = None,
Expand All @@ -53,8 +52,6 @@ def __init__(
self.interface = interface
self.on_discovered = on_discovered
self.discovery_port = port or Discover.DISCOVERY_PORT
self.target = (target, self.discovery_port)
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
self.discovered_devices = {}
self.unsupported_devices: Dict = {}
self.invalid_device_exceptions: Dict = {}
Expand All @@ -78,16 +75,16 @@ def connection_made(self, transport) -> None:
socket.SOL_SOCKET, socket.SO_BINDTODEVICE, self.interface.encode()
)

self.do_discover()

def do_discover(self) -> None:
def do_discover(self, host: str) -> None:
"""Send number of discovery datagrams."""
req = json_dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
target = (host, self.discovery_port)
target_2 = (host, Discover.DISCOVERY_PORT_2)
_LOGGER.debug("[DISCOVERY] %s >> %s", host, Discover.DISCOVERY_QUERY)
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
for i in range(self.discovery_packets):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
self.transport.sendto(encrypted_req[4:], target) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, target_2) # type: ignore

def datagram_received(self, data, addr) -> None:
"""Handle discovery responses."""
Expand Down Expand Up @@ -219,8 +216,7 @@ async def discover(
"""
loop = asyncio.get_event_loop()
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(
target=target,
lambda: DiscoveryProtocol(
on_discovered=on_discovered,
discovery_packets=discovery_packets,
interface=interface,
Expand All @@ -229,7 +225,8 @@ async def discover(
),
local_addr=("0.0.0.0", 0),
)
protocol = cast(_DiscoverProtocol, protocol)
protocol = cast(DiscoveryProtocol, protocol)
protocol.do_discover(target)

try:
_LOGGER.debug("Waiting %s seconds for responses...", timeout)
Expand Down Expand Up @@ -258,12 +255,12 @@ async def discover_single(
loop = asyncio.get_event_loop()
event = asyncio.Event()
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(
target=host, port=port, discovered_event=event, credentials=credentials
lambda: DiscoveryProtocol(port=port, discovered_event=event, credentials=credentials
),
local_addr=("0.0.0.0", 0),
)
protocol = cast(_DiscoverProtocol, protocol)
protocol = cast(DiscoveryProtocol, protocol)
protocol.do_discover(host)

try:
_LOGGER.debug("Waiting a total of %s seconds for responses...", timeout)
Expand Down
23 changes: 11 additions & 12 deletions kasa/tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342

from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
from kasa.discover import _DiscoverProtocol, json_dumps
from kasa.discover import DiscoveryProtocol, json_dumps
from kasa.exceptions import UnsupportedDeviceException

from .conftest import bulb, dimmer, lightstrip, plug, strip
Expand Down Expand Up @@ -59,13 +59,13 @@ async def test_discover_single(discovery_data: dict, mocker, custom_port):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"

def mock_discover(self):
def mock_discover(self, host):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:],
(host, custom_port or 9999),
)

mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch.object(DiscoveryProtocol, "do_discover", mock_discover)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)

x = await Discover.discover_single(host, port=custom_port)
Expand Down Expand Up @@ -100,15 +100,15 @@ async def test_discover_single_unsupported(mocker):
"""Make sure that discover_single handles unsupported devices correctly."""
host = "127.0.0.1"

def mock_discover(self):
def mock_discover(self, host):
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
self.datagram_received(data, (host, 20002))

mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch.object(DiscoveryProtocol, "do_discover", mock_discover)

# Test with a valid unsupported response
discovery_data = UNSUPPORTED
Expand Down Expand Up @@ -141,30 +141,29 @@ async def test_discover_invalid_info(msg, data, mocker):
"""Make sure that invalid discovery information raises an exception."""
host = "127.0.0.1"

def mock_discover(self):
def mock_discover(self, host):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(data))[4:], (host, 9999)
)

mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch.object(DiscoveryProtocol, "do_discover", mock_discover)

with pytest.raises(SmartDeviceException, match=msg):
await Discover.discover_single(host)


async def test_discover_send(mocker):
"""Test discovery parameters."""
proto = _DiscoverProtocol()
proto = DiscoveryProtocol(port=9999)
assert proto.discovery_packets == 3
assert proto.target == ("255.255.255.255", 9999)
transport = mocker.patch.object(proto, "transport")
proto.do_discover()
proto.do_discover("255.255.255.255")
assert transport.sendto.call_count == proto.discovery_packets * 2


async def test_discover_datagram_received(mocker, discovery_data):
"""Verify that datagram received fills discovered_devices."""
proto = _DiscoverProtocol()
proto = DiscoveryProtocol()
mocker.patch("kasa.discover.json_loads", return_value=discovery_data)
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
Expand All @@ -186,7 +185,7 @@ async def test_discover_datagram_received(mocker, discovery_data):
@pytest.mark.parametrize("msg, data", INVALIDS)
async def test_discover_invalid_responses(msg, data, mocker):
"""Verify that we don't crash whole discovery if some devices in the network are sending unexpected data."""
proto = _DiscoverProtocol()
proto = DiscoveryProtocol()
mocker.patch("kasa.discover.json_loads", return_value=data)
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
Expand Down

0 comments on commit 46199e3

Please sign in to comment.