Skip to content

Commit

Permalink
Fix discovery cli to print devices not printed during discovery timeo…
Browse files Browse the repository at this point in the history
…ut (#670)

* Fix discovery cli to print devices not printed during discovery

* Fix tests

* Fix print exceptions not being propagated

* Fix tests

* Reduce test discover_send time

* Simplify wait logic

* Add tests

* Remove sleep loop and make auth failed a list
  • Loading branch information
sdb9696 committed Feb 5, 2024
1 parent 0d119e6 commit 215b8d4
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 63 deletions.
7 changes: 5 additions & 2 deletions kasa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,12 +444,12 @@ async def print_discovered(dev: Device):
_echo_discovery_info(dev._discovery_info)
echo()
else:
discovered[dev.host] = dev.internal_state
ctx.parent.obj = dev
await ctx.parent.invoke(state)
discovered[dev.host] = dev.internal_state
echo()

await Discover.discover(
discovered_devices = await Discover.discover(
target=target,
discovery_timeout=discovery_timeout,
on_discovered=print_discovered,
Expand All @@ -459,6 +459,9 @@ async def print_discovered(dev: Device):
credentials=credentials,
)

for device in discovered_devices.values():
await device.protocol.close()

echo(f"Found {len(discovered)} devices")
if unsupported:
echo(f"Found {len(unsupported)} unsupported devices")
Expand Down
61 changes: 39 additions & 22 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ipaddress
import logging
import socket
from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast
from typing import Awaitable, Callable, Dict, List, Optional, Set, Type, cast

# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
Expand Down Expand Up @@ -46,6 +46,8 @@ class _DiscoverProtocol(asyncio.DatagramProtocol):
This is internal class, use :func:`Discover.discover`: instead.
"""

DISCOVERY_START_TIMEOUT = 1

discovered_devices: DeviceDict

def __init__(
Expand All @@ -60,7 +62,6 @@ def __init__(
Callable[[UnsupportedDeviceException], Awaitable[None]]
] = None,
port: Optional[int] = None,
discovered_event: Optional[asyncio.Event] = None,
credentials: Optional[Credentials] = None,
timeout: Optional[int] = None,
) -> None:
Expand All @@ -79,12 +80,32 @@ def __init__(
self.unsupported_device_exceptions: Dict = {}
self.invalid_device_exceptions: Dict = {}
self.on_unsupported = on_unsupported
self.discovered_event = discovered_event
self.credentials = credentials
self.timeout = timeout
self.discovery_timeout = discovery_timeout
self.seen_hosts: Set[str] = set()
self.discover_task: Optional[asyncio.Task] = None
self.callback_tasks: List[asyncio.Task] = []
self.target_discovered: bool = False
self._started_event = asyncio.Event()

def _run_callback_task(self, coro):
task = asyncio.create_task(coro)
self.callback_tasks.append(task)

async def wait_for_discovery_to_complete(self):
"""Wait for the discovery task to complete."""
# Give some time for connection_made event to be received
async with asyncio_timeout(self.DISCOVERY_START_TIMEOUT):
await self._started_event.wait()
try:
await self.discover_task
except asyncio.CancelledError:
# if target_discovered then cancel was called internally
if not self.target_discovered:
raise
# Wait for any pending callbacks to complete
await asyncio.gather(*self.callback_tasks)

def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
Expand All @@ -103,20 +124,20 @@ def connection_made(self, transport) -> None:
)

self.discover_task = asyncio.create_task(self.do_discover())
self._started_event.set()

async def do_discover(self) -> None:
"""Send number of discovery datagrams."""
req = json_dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = XorEncryption.encrypt(req)
sleep_between_packets = self.discovery_timeout / self.discovery_packets
for i in range(self.discovery_packets):
for _ in range(self.discovery_packets):
if self.target in self.seen_hosts: # Stop sending for discover_single
break
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
if i < self.discovery_packets - 1:
await asyncio.sleep(sleep_between_packets)
await asyncio.sleep(sleep_between_packets)

def datagram_received(self, data, addr) -> None:
"""Handle discovery responses."""
Expand Down Expand Up @@ -145,7 +166,7 @@ def datagram_received(self, data, addr) -> None:
_LOGGER.debug("Unsupported device found at %s << %s", ip, udex)
self.unsupported_device_exceptions[ip] = udex
if self.on_unsupported is not None:
asyncio.ensure_future(self.on_unsupported(udex))
self._run_callback_task(self.on_unsupported(udex))
self._handle_discovered_event()
return
except SmartDeviceException as ex:
Expand All @@ -157,16 +178,16 @@ def datagram_received(self, data, addr) -> None:
self.discovered_devices[ip] = device

if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))
self._run_callback_task(self.on_discovered(device))

self._handle_discovered_event()

def _handle_discovered_event(self):
"""If discovered_event is available set it and cancel discover_task."""
if self.discovered_event is not None:
"""If target is in seen_hosts cancel discover_task."""
if self.target in self.seen_hosts:
self.target_discovered = True
if self.discover_task:
self.discover_task.cancel()
self.discovered_event.set()

def error_received(self, ex):
"""Handle asyncio.Protocol errors."""
Expand Down Expand Up @@ -289,7 +310,11 @@ async def discover(

try:
_LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout)
await asyncio.sleep(discovery_timeout)
await protocol.wait_for_discovery_to_complete()
except SmartDeviceException as ex:
for device in protocol.discovered_devices.values():
await device.protocol.close()
raise ex
finally:
transport.close()

Expand Down Expand Up @@ -322,7 +347,6 @@ async def discover_single(
:return: Object for querying/controlling found device.
"""
loop = asyncio.get_event_loop()
event = asyncio.Event()

try:
ipaddress.ip_address(host)
Expand Down Expand Up @@ -352,7 +376,6 @@ async def discover_single(
lambda: _DiscoverProtocol(
target=ip,
port=port,
discovered_event=event,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
Expand All @@ -365,13 +388,7 @@ async def discover_single(
_LOGGER.debug(
"Waiting a total of %s seconds for responses...", discovery_timeout
)

async with asyncio_timeout(discovery_timeout):
await event.wait()
except asyncio.TimeoutError as ex:
raise TimeoutException(
f"Timed out getting discovery response for {host}"
) from ex
await protocol.wait_for_discovery_to_complete()
finally:
transport.close()

Expand All @@ -384,7 +401,7 @@ async def discover_single(
elif ip in protocol.invalid_device_exceptions:
raise protocol.invalid_device_exceptions[ip]
else:
raise SmartDeviceException(f"Unable to get discovery response for {host}")
raise TimeoutException(f"Timed out getting discovery response for {host}")

@staticmethod
def _get_device_class(info: dict) -> Type[Device]:
Expand Down
4 changes: 2 additions & 2 deletions kasa/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ class _DiscoveryMock:
login_version,
)

def mock_discover(self):
async def mock_discover(self):
port = (
dm.port_override
if dm.port_override and dm.discovery_port != 20002
Expand Down Expand Up @@ -561,7 +561,7 @@ def unsupported_device_info(request, mocker):
discovery_data = request.param
host = "127.0.0.1"

def mock_discover(self):
async def mock_discover(self):
if discovery_data:
data = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
Expand Down
22 changes: 21 additions & 1 deletion kasa/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ async def test_brightness(dev):
@device_iot
async def test_json_output(dev: Device, mocker):
"""Test that the json output produces correct output."""
mocker.patch("kasa.Discover.discover", return_value=[dev])
mocker.patch("kasa.Discover.discover", return_value={"127.0.0.1": dev})
runner = CliRunner()
res = await runner.invoke(cli, ["--json", "state"], obj=dev)
assert res.exit_code == 0
Expand Down Expand Up @@ -415,6 +415,26 @@ async def test_discover(discovery_mock, mocker):
assert res.exit_code == 0


async def test_discover_host(discovery_mock, mocker):
"""Test discovery output."""
runner = CliRunner()
res = await runner.invoke(
cli,
[
"--discovery-timeout",
0,
"--host",
"127.0.0.123",
"--username",
"foo",
"--password",
"bar",
"--verbose",
],
)
assert res.exit_code == 0


async def test_discover_unsupported(unsupported_device_info):
"""Test discovery output."""
runner = CliRunner()
Expand Down

0 comments on commit 215b8d4

Please sign in to comment.