Skip to content

Commit

Permalink
Various test code cleanups (#725)
Browse files Browse the repository at this point in the history
* Separate fake protocols for iot and smart

* Move control_child impl into its own method

* Organize schemas into correct places

* Add test_childdevice

* Add missing return for _handle_control_child
  • Loading branch information
rytilahti committed Jan 29, 2024
1 parent 1e26434 commit 9e6896a
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 305 deletions.
7 changes: 4 additions & 3 deletions kasa/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from kasa.tapo import TapoBulb, TapoPlug
from kasa.xortransport import XorEncryption

from .newfakes import FakeSmartProtocol, FakeTransportProtocol
from .fakeprotocol_iot import FakeIotProtocol
from .fakeprotocol_smart import FakeSmartProtocol

SUPPORTED_IOT_DEVICES = [
(device, "IOT")
Expand Down Expand Up @@ -410,7 +411,7 @@ def load_file():
if protocol == "SMART":
d.protocol = FakeSmartProtocol(sysinfo)
else:
d.protocol = FakeTransportProtocol(sysinfo)
d.protocol = FakeIotProtocol(sysinfo)
await _update_and_close(d)
return d

Expand Down Expand Up @@ -521,7 +522,7 @@ def mock_discover(self):
if "component_nego" in dm.query_data:
proto = FakeSmartProtocol(dm.query_data)
else:
proto = FakeTransportProtocol(dm.query_data)
proto = FakeIotProtocol(dm.query_data)

async def _query(request, retry_count: int = 3):
return await proto.query(request)
Expand Down
294 changes: 3 additions & 291 deletions kasa/tests/newfakes.py → kasa/tests/fakeprotocol_iot.py
Original file line number Diff line number Diff line change
@@ -1,185 +1,13 @@
import base64
import copy
import logging
import re
import warnings
from json import loads as json_loads

from voluptuous import (
REMOVE_EXTRA,
All,
Any,
Coerce, # type: ignore
Invalid,
Optional,
Range,
Schema,
)

from ..credentials import Credentials

from ..deviceconfig import DeviceConfig
from ..exceptions import SmartDeviceException
from ..iotprotocol import IotProtocol
from ..protocol import BaseTransport
from ..smartprotocol import SmartProtocol
from ..xortransport import XorTransport

_LOGGER = logging.getLogger(__name__)


def check_int_bool(x):
if x != 0 and x != 1:
raise Invalid(x)
return x


def check_mac(x):
if re.match("[0-9a-f]{2}([-:])[0-9a-f]{2}(\\1[0-9a-f]{2}){4}$", x.lower()):
return x
raise Invalid(x)


def check_mode(x):
if x in ["schedule", "none", "count_down"]:
return x

raise Invalid(f"invalid mode {x}")


def lb_dev_state(x):
if x in ["normal"]:
return x

raise Invalid(f"Invalid dev_state {x}")


TZ_SCHEMA = Schema(
{"zone_str": str, "dst_offset": int, "index": All(int, Range(min=0)), "tz_str": str}
)

CURRENT_CONSUMPTION_SCHEMA = Schema(
Any(
{
"voltage": Any(All(float, Range(min=0, max=300)), None),
"power": Any(Coerce(float, Range(min=0)), None),
"total": Any(Coerce(float, Range(min=0)), None),
"current": Any(All(float, Range(min=0)), None),
"voltage_mv": Any(
All(float, Range(min=0, max=300000)), int, None
), # TODO can this be int?
"power_mw": Any(Coerce(float, Range(min=0)), None),
"total_wh": Any(Coerce(float, Range(min=0)), None),
"current_ma": Any(
All(float, Range(min=0)), int, None
), # TODO can this be int?
"slot_id": Any(Coerce(int, Range(min=0)), None),
},
None,
)
)

# these schemas should go to the mainlib as
# they can be useful when adding support for new features/devices
# as well as to check that faked devices are operating properly.
PLUG_SCHEMA = Schema(
{
"active_mode": check_mode,
"alias": str,
"dev_name": str,
"deviceId": str,
"feature": str,
"fwId": str,
"hwId": str,
"hw_ver": str,
"icon_hash": str,
"led_off": check_int_bool,
"latitude": Any(All(float, Range(min=-90, max=90)), 0, None),
"latitude_i": Any(
All(int, Range(min=-900000, max=900000)),
All(float, Range(min=-900000, max=900000)),
0,
None,
),
"longitude": Any(All(float, Range(min=-180, max=180)), 0, None),
"longitude_i": Any(
All(int, Range(min=-18000000, max=18000000)),
All(float, Range(min=-18000000, max=18000000)),
0,
None,
),
"mac": check_mac,
"model": str,
"oemId": str,
"on_time": int,
"relay_state": int,
"rssi": Any(int, None), # rssi can also be positive, see #54
"sw_ver": str,
"type": str,
"mic_type": str,
"updating": check_int_bool,
# these are available on hs220
"brightness": int,
"preferred_state": [
{"brightness": All(int, Range(min=0, max=100)), "index": int}
],
"next_action": {"type": int},
"child_num": Optional(Any(None, int)), # TODO fix hs300 checks
"children": Optional(list), # TODO fix hs300
# TODO some tplink simulator entries contain invalid (mic_mac, _i variants for lat/lon)
# Therefore we add REMOVE_EXTRA..
# "INVALIDmac": Optional,
# "INVALIDlatitude": Optional,
# "INVALIDlongitude": Optional,
},
extra=REMOVE_EXTRA,
)

LIGHT_STATE_SCHEMA = Schema(
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": int,
"hue": All(int, Range(min=0, max=360)),
"mode": str,
"on_off": check_int_bool,
"saturation": All(int, Range(min=0, max=100)),
"dft_on_state": Optional(
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": All(int, Range(min=0, max=9000)),
"hue": All(int, Range(min=0, max=360)),
"mode": str,
"saturation": All(int, Range(min=0, max=100)),
}
),
"err_code": int,
}
)

BULB_SCHEMA = PLUG_SCHEMA.extend(
{
"ctrl_protocols": Optional(dict),
"description": Optional(str), # TODO: LBxxx similar to dev_name
"dev_state": lb_dev_state,
"disco_ver": str,
"heapsize": int,
"is_color": check_int_bool,
"is_dimmable": check_int_bool,
"is_factory": bool,
"is_variable_color_temp": check_int_bool,
"light_state": LIGHT_STATE_SCHEMA,
"preferred_state": [
{
"brightness": All(int, Range(min=0, max=100)),
"color_temp": int,
"hue": All(int, Range(min=0, max=360)),
"index": int,
"saturation": All(int, Range(min=0, max=100)),
}
],
}
)


def get_realtime(obj, x, *args):
return {
"current": 0.268587,
Expand Down Expand Up @@ -294,123 +122,7 @@ def success(res):
}


class FakeSmartProtocol(SmartProtocol):
def __init__(self, info):
super().__init__(
transport=FakeSmartTransport(info),
)

async def query(self, request, retry_count: int = 3):
"""Implement query here so can still patch SmartProtocol.query."""
resp_dict = await self._query(request, retry_count)
return resp_dict


class FakeSmartTransport(BaseTransport):
def __init__(self, info):
super().__init__(
config=DeviceConfig(
"127.0.0.123",
credentials=Credentials(
username="dummy_user",
password="dummy_password", # noqa: S106
),
),
)
self.info = info
self.components = {
comp["id"]: comp["ver_code"]
for comp in self.info["component_nego"]["component_list"]
}

@property
def default_port(self):
"""Default port for the transport."""
return 80

@property
def credentials_hash(self):
"""The hashed credentials used by the transport."""
return self._credentials.username + self._credentials.password + "hash"

FIXTURE_MISSING_MAP = {
"get_wireless_scan_info": ("wireless", {"ap_list": [], "wep_supported": False}),
}

async def send(self, request: str):
request_dict = json_loads(request)
method = request_dict["method"]
params = request_dict["params"]
if method == "multipleRequest":
responses = []
for request in params["requests"]:
response = self._send_request(request) # type: ignore[arg-type]
response["method"] = request["method"] # type: ignore[index]
responses.append(response)
return {"result": {"responses": responses}, "error_code": 0}
else:
return self._send_request(request_dict)

def _send_request(self, request_dict: dict):
method = request_dict["method"]
params = request_dict["params"]

info = self.info
if method == "control_child":
device_id = params.get("device_id")
request_data = params.get("requestData")

child_method = request_data.get("method")
child_params = request_data.get("params")

children = info["get_child_device_list"]["child_device_list"]

for child in children:
if child["device_id"] == device_id:
info = child
break

# We only support get & set device info for now.
if child_method == "get_device_info":
return {"result": info, "error_code": 0}
elif child_method == "set_device_info":
info.update(child_params)
return {"error_code": 0}

raise NotImplementedError(
"Method %s not implemented for children" % child_method
)

if method == "component_nego" or method[:4] == "get_":
if method in info:
return {"result": info[method], "error_code": 0}
elif (
missing_result := self.FIXTURE_MISSING_MAP.get(method)
) and missing_result[0] in self.components:
warnings.warn(
UserWarning(
f"Fixture missing expected method {method}, try to regenerate"
),
stacklevel=1,
)
return {"result": missing_result[1], "error_code": 0}
else:
raise SmartDeviceException(f"Fixture doesn't support {method}")
elif method == "set_qs_info":
return {"error_code": 0}
elif method[:4] == "set_":
target_method = f"get_{method[4:]}"
info[target_method].update(params)
return {"error_code": 0}

async def close(self) -> None:
pass

async def reset(self) -> None:
pass


class FakeTransportProtocol(IotProtocol):
class FakeIotProtocol(IotProtocol):
def __init__(self, info):
super().__init__(
transport=XorTransport(
Expand All @@ -420,7 +132,7 @@ def __init__(self, info):
self.discovery_data = info
self.writer = None
self.reader = None
proto = copy.deepcopy(FakeTransportProtocol.baseproto)
proto = copy.deepcopy(FakeIotProtocol.baseproto)

for target in info:
# print("target %s" % target)
Expand Down

0 comments on commit 9e6896a

Please sign in to comment.