Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make DiscoveryProtocol part of the public API #524

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion kasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from importlib.metadata import version

from kasa.credentials import Credentials
from kasa.discover import Discover
from kasa.discover import Discover, DiscoveryProtocol
from kasa.emeterstatus import EmeterStatus
from kasa.exceptions import (
AuthenticationException,
Expand All @@ -34,6 +34,7 @@

__all__ = [
"Discover",
"DiscoveryProtocol",
"TPLinkSmartHomeProtocol",
"SmartBulb",
"SmartBulbPreset",
Expand Down
30 changes: 14 additions & 16 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
DeviceDict = Dict[str, SmartDevice]


class _DiscoverProtocol(asyncio.DatagramProtocol):
class DiscoveryProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler.

This is internal class, use :func:`Discover.discover`: instead.
Expand All @@ -40,7 +40,6 @@
self,
*,
on_discovered: Optional[OnDiscoveredCallable] = None,
target: str = "255.255.255.255",
discovery_packets: int = 3,
interface: Optional[str] = None,
on_unsupported: Optional[Callable[[Dict], Awaitable[None]]] = None,
Expand All @@ -53,8 +52,6 @@
self.interface = interface
self.on_discovered = on_discovered
self.discovery_port = port or Discover.DISCOVERY_PORT
self.target = (target, self.discovery_port)
self.target_2 = (target, Discover.DISCOVERY_PORT_2)
self.discovered_devices = {}
self.unsupported_devices: Dict = {}
self.invalid_device_exceptions: Dict = {}
Expand All @@ -78,16 +75,16 @@
socket.SOL_SOCKET, socket.SO_BINDTODEVICE, self.interface.encode()
)

self.do_discover()

def do_discover(self) -> None:
def do_discover(self, host: str) -> None:
"""Send number of discovery datagrams."""
req = json_dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
target = (host, self.discovery_port)
target_2 = (host, Discover.DISCOVERY_PORT_2)
_LOGGER.debug("[DISCOVERY] %s >> %s", host, Discover.DISCOVERY_QUERY)
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
for i in range(self.discovery_packets):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
self.transport.sendto(encrypted_req[4:], target) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, target_2) # type: ignore

def datagram_received(self, data, addr) -> None:
"""Handle discovery responses."""
Expand Down Expand Up @@ -219,8 +216,7 @@
"""
loop = asyncio.get_event_loop()
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(
target=target,
lambda: DiscoveryProtocol(
on_discovered=on_discovered,
discovery_packets=discovery_packets,
interface=interface,
Expand All @@ -229,7 +225,8 @@
),
local_addr=("0.0.0.0", 0),
)
protocol = cast(_DiscoverProtocol, protocol)
protocol = cast(DiscoveryProtocol, protocol)
protocol.do_discover(target)

Check warning on line 229 in kasa/discover.py

View check run for this annotation

Codecov / codecov/patch

kasa/discover.py#L228-L229

Added lines #L228 - L229 were not covered by tests

try:
_LOGGER.debug("Waiting %s seconds for responses...", timeout)
Expand Down Expand Up @@ -258,12 +255,13 @@
loop = asyncio.get_event_loop()
event = asyncio.Event()
transport, protocol = await loop.create_datagram_endpoint(
lambda: _DiscoverProtocol(
target=host, port=port, discovered_event=event, credentials=credentials
lambda: DiscoveryProtocol(
port=port, discovered_event=event, credentials=credentials
),
local_addr=("0.0.0.0", 0),
)
protocol = cast(_DiscoverProtocol, protocol)
protocol = cast(DiscoveryProtocol, protocol)
protocol.do_discover(host)

try:
_LOGGER.debug("Waiting a total of %s seconds for responses...", timeout)
Expand Down
23 changes: 11 additions & 12 deletions kasa/tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342

from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
from kasa.discover import _DiscoverProtocol, json_dumps
from kasa.discover import DiscoveryProtocol, json_dumps
from kasa.exceptions import UnsupportedDeviceException

from .conftest import bulb, dimmer, lightstrip, plug, strip
Expand Down Expand Up @@ -59,13 +59,13 @@ async def test_discover_single(discovery_data: dict, mocker, custom_port):
"""Make sure that discover_single returns an initialized SmartDevice instance."""
host = "127.0.0.1"

def mock_discover(self):
def mock_discover(self, host):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:],
(host, custom_port or 9999),
)

mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch.object(DiscoveryProtocol, "do_discover", mock_discover)
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)

x = await Discover.discover_single(host, port=custom_port)
Expand Down Expand Up @@ -100,15 +100,15 @@ async def test_discover_single_unsupported(mocker):
"""Make sure that discover_single handles unsupported devices correctly."""
host = "127.0.0.1"

def mock_discover(self):
def mock_discover(self, host):
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
self.datagram_received(data, (host, 20002))

mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch.object(DiscoveryProtocol, "do_discover", mock_discover)

# Test with a valid unsupported response
discovery_data = UNSUPPORTED
Expand Down Expand Up @@ -141,30 +141,29 @@ async def test_discover_invalid_info(msg, data, mocker):
"""Make sure that invalid discovery information raises an exception."""
host = "127.0.0.1"

def mock_discover(self):
def mock_discover(self, host):
self.datagram_received(
protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(data))[4:], (host, 9999)
)

mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover)
mocker.patch.object(DiscoveryProtocol, "do_discover", mock_discover)

with pytest.raises(SmartDeviceException, match=msg):
await Discover.discover_single(host)


async def test_discover_send(mocker):
"""Test discovery parameters."""
proto = _DiscoverProtocol()
proto = DiscoveryProtocol(port=9999)
assert proto.discovery_packets == 3
assert proto.target == ("255.255.255.255", 9999)
transport = mocker.patch.object(proto, "transport")
proto.do_discover()
proto.do_discover("255.255.255.255")
assert transport.sendto.call_count == proto.discovery_packets * 2


async def test_discover_datagram_received(mocker, discovery_data):
"""Verify that datagram received fills discovered_devices."""
proto = _DiscoverProtocol()
proto = DiscoveryProtocol()
mocker.patch("kasa.discover.json_loads", return_value=discovery_data)
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
Expand All @@ -186,7 +185,7 @@ async def test_discover_datagram_received(mocker, discovery_data):
@pytest.mark.parametrize("msg, data", INVALIDS)
async def test_discover_invalid_responses(msg, data, mocker):
"""Verify that we don't crash whole discovery if some devices in the network are sending unexpected data."""
proto = _DiscoverProtocol()
proto = DiscoveryProtocol()
mocker.patch("kasa.discover.json_loads", return_value=data)
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt")
mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt")
Expand Down
Loading