Skip to content

Commit

Permalink
Keep connection open and lock to prevent duplicate requests (#213)
Browse files Browse the repository at this point in the history
* Keep connection open and lock to prevent duplicate requests

* option to not update children

* tweaks

* typing

* tweaks

* run tests in the same event loop

* memorize model

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* Update kasa/protocol.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* dry

* tweaks

* warn when the event loop gets switched out from under us

* raise on unable to connect multiple times

* fix patch target

* tweaks

* isrot

* reconnect test

* prune

* fix mocking

* fix mocking

* fix test under python 3.7

* fix test under python 3.7

* less patching

* isort

* use mocker to patch

* disable on old python since mocking doesnt work

* avoid disconnect/reconnect cycles

* isort

* Fix hue validation

* Fix latitude_i/longitude_i units

Co-authored-by: Teemu R. <tpr@iki.fi>
  • Loading branch information
bdraco and rytilahti committed Sep 24, 2021
1 parent f1b28e7 commit e31cc66
Show file tree
Hide file tree
Showing 11 changed files with 238 additions and 93 deletions.
17 changes: 11 additions & 6 deletions devtools/dump_devinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,17 @@ def cli(host, debug):
),
]

protocol = TPLinkSmartHomeProtocol()

successes = []

for test_call in items:

async def _run_query():
protocol = TPLinkSmartHomeProtocol(host)
return await protocol.query({test_call.module: {test_call.method: None}})

try:
click.echo(f"Testing {test_call}..", nl=False)
info = asyncio.run(
protocol.query(host, {test_call.module: {test_call.method: None}})
)
info = asyncio.run(_run_query())
resp = info[test_call.module]
except Exception as ex:
click.echo(click.style(f"FAIL {ex}", fg="red"))
Expand All @@ -107,8 +108,12 @@ def cli(host, debug):

final = default_to_regular(final)

async def _run_final_query():
protocol = TPLinkSmartHomeProtocol(host)
return await protocol.query(final_query)

try:
final = asyncio.run(protocol.query(host, final_query))
final = asyncio.run(_run_final_query())
except Exception as ex:
click.echo(
click.style(
Expand Down
9 changes: 4 additions & 5 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(
self.discovery_packets = discovery_packets
self.interface = interface
self.on_discovered = on_discovered
self.protocol = TPLinkSmartHomeProtocol()
self.target = (target, Discover.DISCOVERY_PORT)
self.discovered_devices = {}

Expand All @@ -61,7 +60,7 @@ def do_discover(self) -> None:
"""Send number of discovery datagrams."""
req = json.dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = self.protocol.encrypt(req)
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
for i in range(self.discovery_packets):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore

Expand All @@ -71,7 +70,7 @@ def datagram_received(self, data, addr) -> None:
if ip in self.discovered_devices:
return

info = json.loads(self.protocol.decrypt(data))
info = json.loads(TPLinkSmartHomeProtocol.decrypt(data))
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)

device_class = Discover._get_device_class(info)
Expand Down Expand Up @@ -190,9 +189,9 @@ async def discover_single(host: str) -> SmartDevice:
:rtype: SmartDevice
:return: Object for querying/controlling found device.
"""
protocol = TPLinkSmartHomeProtocol()
protocol = TPLinkSmartHomeProtocol(host)

info = await protocol.query(host, Discover.DISCOVERY_QUERY)
info = await protocol.query(Discover.DISCOVERY_QUERY)

device_class = Discover._get_device_class(info)
dev = device_class(host)
Expand Down
140 changes: 104 additions & 36 deletions kasa/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
http://www.apache.org/licenses/LICENSE-2.0
"""
import asyncio
import contextlib
import json
import logging
import struct
from pprint import pformat as pf
from typing import Dict, Union
from typing import Dict, Optional, Union

from .exceptions import SmartDeviceException

Expand All @@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol:
DEFAULT_PORT = 9999
DEFAULT_TIMEOUT = 5

@staticmethod
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
BLOCK_SIZE = 4

def __init__(self, host: str) -> None:
"""Create a protocol object."""
self.host = host
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self.query_lock: Optional[asyncio.Lock] = None
self.loop: Optional[asyncio.AbstractEventLoop] = None

def _detect_event_loop_change(self) -> None:
"""Check if this object has been reused betwen event loops."""
loop = asyncio.get_running_loop()
if not self.loop:
self.loop = loop
elif self.loop != loop:
_LOGGER.warning("Detected protocol reuse between different event loop")
self._reset()

async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Request information from a TP-Link SmartHome Device.
:param str host: host name or ip address of the device
Expand All @@ -38,57 +57,106 @@ async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> D
:param retry_count: how many retries to do in case of failure
:return: response dict
"""
self._detect_event_loop_change()

if not self.query_lock:
self.query_lock = asyncio.Lock()

if isinstance(request, dict):
request = json.dumps(request)
assert isinstance(request, str)

timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
writer = None

async with self.query_lock:
return await self._query(request, retry_count, timeout)

async def _connect(self, timeout: int) -> bool:
"""Try to connect or reconnect to the device."""
if self.writer:
return True

with contextlib.suppress(Exception):
self.reader = self.writer = None
task = asyncio.open_connection(
self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
)
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
return True

return False

async def _execute_query(self, request: str) -> Dict:
"""Execute a query on the device and wait for the response."""
assert self.writer is not None
assert self.reader is not None

_LOGGER.debug("> (%i) %s", len(request), request)
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await self.writer.drain()

packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
length = struct.unpack(">I", packed_block_size)[0]

buffer = await self.reader.readexactly(length)
response = TPLinkSmartHomeProtocol.decrypt(buffer)
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
return json_payload

async def close(self):
"""Close the connection."""
writer = self.writer
self._reset()
if writer:
writer.close()
with contextlib.suppress(Exception):
await writer.wait_closed()

def _reset(self):
"""Clear any varibles that should not survive between loops."""
self.writer = None
self.reader = None
self.query_lock = None
self.loop = None

async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
"""Try to query a device."""
for retry in range(retry_count + 1):
if not await self._connect(timeout):
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}"
)
continue

try:
task = asyncio.open_connection(
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
assert self.reader is not None
assert self.writer is not None
return await asyncio.wait_for(
self._execute_query(request), timeout=timeout
)
reader, writer = await asyncio.wait_for(task, timeout=timeout)
_LOGGER.debug("> (%i) %s", len(request), request)
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
await writer.drain()

buffer = bytes()
# Some devices send responses with a length header of 0 and
# terminate with a zero size chunk. Others send the length and
# will hang if we attempt to read more data.
length = -1
while True:
chunk = await reader.read(4096)
if length == -1:
length = struct.unpack(">I", chunk[0:4])[0]
buffer += chunk
if (length > 0 and len(buffer) >= length + 4) or not chunk:
break

response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
json_payload = json.loads(response)
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))

return json_payload

except Exception as ex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up after %s retries", retry)
raise SmartDeviceException(
"Unable to query the device: %s" % ex
f"Unable to query the device: {ex}"
) from ex

_LOGGER.debug("Unable to query the device, retrying: %s", ex)

finally:
if writer:
writer.close()
await writer.wait_closed()

# make mypy happy, this should never be reached..
await self.close()
raise SmartDeviceException("Query reached somehow to unreachable")

def __del__(self):
if self.writer and self.loop and self.loop.is_running():
self.writer.close()
self._reset()

@staticmethod
def _xor_payload(unencrypted):
key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR
Expand Down
14 changes: 7 additions & 7 deletions kasa/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def __init__(self, host: str) -> None:
"""
self.host = host

self.protocol = TPLinkSmartHomeProtocol()
self.protocol = TPLinkSmartHomeProtocol(host)
self.emeter_type = "emeter"
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
self._device_type = DeviceType.Unknown
Expand Down Expand Up @@ -234,7 +234,7 @@ async def _query_helper(
request = self._create_request(target, cmd, arg, child_ids)

try:
response = await self.protocol.query(host=self.host, request=request)
response = await self.protocol.query(request=request)
except Exception as ex:
raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex

Expand Down Expand Up @@ -272,7 +272,7 @@ async def get_sys_info(self) -> Dict[str, Any]:
"""Retrieve system information."""
return await self._query_helper("system", "get_sysinfo")

async def update(self):
async def update(self, update_children: bool = True):
"""Query the device to update the data.
Needed for properties that are decorated with `requires_update`.
Expand All @@ -285,7 +285,7 @@ async def update(self):
# See #105, #120, #161
if self._last_update is None:
_LOGGER.debug("Performing the initial update to obtain sysinfo")
self._last_update = await self.protocol.query(self.host, req)
self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"]
# If the device has no emeter, we are done for the initial update
# Otherwise we will follow the regular code path to also query
Expand All @@ -299,7 +299,7 @@ async def update(self):
)
req.update(self._create_emeter_request())

self._last_update = await self.protocol.query(self.host, req)
self._last_update = await self.protocol.query(req)
self._sys_info = self._last_update["system"]["get_sysinfo"]

def update_from_discover_info(self, info):
Expand Down Expand Up @@ -383,8 +383,8 @@ def location(self) -> Dict:
loc["latitude"] = sys_info["latitude"]
loc["longitude"] = sys_info["longitude"]
elif "latitude_i" in sys_info and "longitude_i" in sys_info:
loc["latitude"] = sys_info["latitude_i"]
loc["longitude"] = sys_info["longitude_i"]
loc["latitude"] = sys_info["latitude_i"] / 10000
loc["longitude"] = sys_info["longitude_i"] / 10000
else:
_LOGGER.warning("Unsupported device location.")

Expand Down
10 changes: 5 additions & 5 deletions kasa/smartstrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ def is_on(self) -> bool:
"""Return if any of the outlets are on."""
return any(plug.is_on for plug in self.children)

async def update(self):
async def update(self, update_children: bool = True):
"""Update some of the attributes.
Needed for methods that are decorated with `requires_update`.
"""
await super().update()
await super().update(update_children)

# Initialize the child devices during the first update.
if not self.children:
Expand All @@ -103,7 +103,7 @@ async def update(self):
SmartStripPlug(self.host, parent=self, child_id=child["id"])
)

if self.has_emeter:
if update_children and self.has_emeter:
for plug in self.children:
await plug.update()

Expand Down Expand Up @@ -243,13 +243,13 @@ def __init__(self, host: str, parent: "SmartStrip", child_id: str) -> None:
self._sys_info = parent._sys_info
self._device_type = DeviceType.StripSocket

async def update(self):
async def update(self, update_children: bool = True):
"""Query the device to update the data.
Needed for properties that are decorated with `requires_update`.
"""
self._last_update = await self.parent.protocol.query(
self.host, self._create_emeter_request()
self._create_emeter_request()
)

def _create_request(
Expand Down

0 comments on commit e31cc66

Please sign in to comment.