Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move connect_single to SmartDevice.connect #538

Merged
merged 30 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions kasa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@

from kasa import (
Credentials,
DeviceType,
Discover,
SmartBulb,
SmartDevice,
SmartDimmer,
SmartLightStrip,
SmartPlug,
SmartStrip,
)
from kasa.discover import DEVICE_TYPE_TO_CLASS

try:
from rich import print as _do_echo
Expand All @@ -43,11 +42,9 @@ def wrapper(message=None, *args, **kwargs):
echo = _do_echo

TYPE_TO_CLASS = {
"plug": SmartPlug,
"bulb": SmartBulb,
"dimmer": SmartDimmer,
"strip": SmartStrip,
"lightstrip": SmartLightStrip,
device_type.value: DEVICE_TYPE_TO_CLASS[device_type]
for device_type in DeviceType
if device_type in DEVICE_TYPE_TO_CLASS
}

click.anyio_backend = "asyncio"
Expand Down
38 changes: 26 additions & 12 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from kasa.json import loads as json_loads
from kasa.protocol import TPLinkSmartHomeProtocol
from kasa.smartbulb import SmartBulb
from kasa.smartdevice import SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType, SmartDevice, SmartDeviceException
from kasa.smartdimmer import SmartDimmer
from kasa.smartlightstrip import SmartLightStrip
from kasa.smartplug import SmartPlug
Expand All @@ -27,6 +27,14 @@
OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]]
DeviceDict = Dict[str, SmartDevice]

DEVICE_TYPE_TO_CLASS = {
DeviceType.Plug: SmartPlug,
DeviceType.Bulb: SmartBulb,
DeviceType.Strip: SmartStrip,
DeviceType.Dimmer: SmartDimmer,
DeviceType.LightStrip: SmartLightStrip,
}


class _DiscoverProtocol(asyncio.DatagramProtocol):
"""Implementation of the discovery protocol handler.
Expand Down Expand Up @@ -317,6 +325,7 @@ async def connect_single(
port: Optional[int] = None,
timeout=5,
credentials: Optional[Credentials] = None,
device_type: Optional[DeviceType] = None,
) -> SmartDevice:
"""Connect to a single device by the given IP address.

Expand All @@ -331,20 +340,25 @@ async def connect_single(
The device type is discovered by querying the device.
bdraco marked this conversation as resolved.
Show resolved Hide resolved

:param host: Hostname of device to query
:param device_type: Device type to use for the device.
:rtype: SmartDevice
bdraco marked this conversation as resolved.
Show resolved Hide resolved
:return: Object for querying/controlling found device.
"""
unknown_dev = SmartDevice(
host=host, port=port, credentials=credentials, timeout=timeout
)
await unknown_dev.update()
device_class = Discover._get_device_class(unknown_dev.internal_state)
dev = device_class(
host=host, port=port, credentials=credentials, timeout=timeout
)
# Reuse the connection from the unknown device
# so we don't have to reconnect
dev.protocol = unknown_dev.protocol
if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)):
dev = klass(host=host, port=port, credentials=credentials, timeout=timeout)
bdraco marked this conversation as resolved.
Show resolved Hide resolved
else:
unknown_dev = SmartDevice(
host=host, port=port, credentials=credentials, timeout=timeout
)
await unknown_dev.update()
device_class = Discover._get_device_class(unknown_dev.internal_state)
dev = device_class(
host=host, port=port, credentials=credentials, timeout=timeout
)
# Reuse the connection from the unknown device
# so we don't have to reconnect
dev.protocol = unknown_dev.protocol
await dev.update()
return dev

@staticmethod
Expand Down
26 changes: 18 additions & 8 deletions kasa/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum, auto
from enum import Enum
from typing import Any, Dict, List, Optional, Set

from .credentials import Credentials
Expand All @@ -32,13 +32,23 @@
class DeviceType(Enum):
"""Device type enum."""

Plug = auto()
Bulb = auto()
Strip = auto()
StripSocket = auto()
Dimmer = auto()
LightStrip = auto()
Unknown = -1
# The values match what the cli has historically used

bdraco marked this conversation as resolved.
Show resolved Hide resolved
Plug = "plug"
Bulb = "bulb"
Strip = "strip"
StripSocket = "stripsocket"
Dimmer = "dimmer"
LightStrip = "lightstrip"
Unknown = "unknown"

@staticmethod
def from_value(name: str) -> "DeviceType":
"""Return device type from string value."""
for device_type in DeviceType:
if device_type.value == name:
return device_type
return DeviceType.Unknown


@dataclass
Expand Down
41 changes: 39 additions & 2 deletions kasa/tests/test_discovery.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
# type: ignore
import re
import sys
from typing import Type

import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342

from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol
from kasa import (
DeviceType,
Discover,
SmartBulb,
SmartDevice,
SmartDeviceException,
SmartDimmer,
SmartLightStrip,
SmartPlug,
protocol,
)
from kasa.discover import _DiscoverProtocol, json_dumps
from kasa.exceptions import UnsupportedDeviceException

Expand Down Expand Up @@ -85,6 +95,33 @@ async def test_connect_single(discovery_data: dict, mocker, custom_port):
assert dev.port == custom_port or dev.port == 9999


@pytest.mark.parametrize("custom_port", [123, None])
@pytest.mark.parametrize(
("device_type", "klass"),
(
(DeviceType.Plug, SmartPlug),
(DeviceType.Bulb, SmartBulb),
(DeviceType.Dimmer, SmartDimmer),
(DeviceType.LightStrip, SmartLightStrip),
(DeviceType.Unknown, SmartDevice),
),
)
async def test_connect_single_passed_device_type(
discovery_data: dict,
mocker,
device_type: DeviceType,
klass: Type[SmartDevice],
custom_port,
):
"""Make sure that connect_single with a passed device type."""
host = "127.0.0.1"
mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data)

dev = await Discover.connect_single(host, port=custom_port, device_type=device_type)
assert isinstance(dev, klass)
assert dev.port == custom_port or dev.port == 9999


async def test_connect_single_query_fails(discovery_data: dict, mocker):
"""Make sure that connect_single fails when query fails."""
host = "127.0.0.1"
Expand Down
23 changes: 23 additions & 0 deletions kasa/tests/test_smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import kasa
from kasa import Credentials, SmartDevice, SmartDeviceException
from kasa.smartdevice import DeviceType
from kasa.smartstrip import SmartStripPlug

from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on
Expand Down Expand Up @@ -58,6 +59,28 @@ async def test_initial_update_no_emeter(dev, mocker):
assert spy.call_count == 2


async def test_smart_device_from_value():
"""Make sure that every device type can be created from its value."""
for name in DeviceType:
assert DeviceType.from_value(name.value) is not None

assert DeviceType.from_value("nonexistent") is DeviceType.Unknown
assert DeviceType.from_value("plug") is DeviceType.Plug
assert DeviceType.Plug.value == "plug"

assert DeviceType.from_value("bulb") is DeviceType.Bulb
assert DeviceType.Bulb.value == "bulb"

assert DeviceType.from_value("dimmer") is DeviceType.Dimmer
assert DeviceType.Dimmer.value == "dimmer"

assert DeviceType.from_value("strip") is DeviceType.Strip
assert DeviceType.Strip.value == "strip"

assert DeviceType.from_value("lightstrip") is DeviceType.LightStrip
assert DeviceType.LightStrip.value == "lightstrip"


async def test_query_helper(dev):
with pytest.raises(SmartDeviceException):
await dev._query_helper("test", "testcmd", {})
Expand Down
Loading