Skip to content

Commit

Permalink
Optimize device control
Browse files Browse the repository at this point in the history
* There's is a delay now before it sends the other payload.
* Revert back status method instead of wait for detect dps.
  • Loading branch information
xZetsubou committed Apr 21, 2024
1 parent 6d57740 commit 07fa10a
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 82 deletions.
23 changes: 11 additions & 12 deletions custom_components/localtuya/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def _make_connection(self):

self.debug("Retrieving initial state")
# Usually we use status instead of detect_available_dps, but some device doesn't reports all dps when ask for status.
status = await self._interface.detect_available_dps(cid=self._node_id)
status = await self._interface.status(cid=self._node_id)
if status is None: # and not self.is_subdevice
raise Exception("Failed to retrieve status")

Expand All @@ -218,6 +218,8 @@ async def _make_connection(self):
e = "Sub device is not connected" if self.is_subdevice else e
self.warning(f"Connect to {host} failed: {e}")
await self.abort_connect()
if self.is_subdevice:
update_localkey = True
except:
if self._fake_gateway:
self.warning(f"Failed to use {name} as gateway.")
Expand Down Expand Up @@ -247,17 +249,14 @@ def _new_entity_handler(entity_id):
self._connect_task = None
self.debug(f"Success: connected to {host}", force=True)
if self._sub_devices:
connect_sub_devices = [
device.async_connect() for device in self._sub_devices.values()
]
await asyncio.gather(*connect_sub_devices)
for subdevice in self._sub_devices.values():
self._hass.async_create_task(subdevice.async_connect())

if not self._status and "0" in self._device_config.manual_dps.split(","):
self.status_updated(RESTORE_STATES)

if self._pending_status:
await self.set_dps(self._pending_status)
self._pending_status = {}
await self.set_status()

# If not connected try to handle the errors.
if not self._interface:
Expand Down Expand Up @@ -349,15 +348,15 @@ async def update_local_key(self):
)
self.info(f"local_key updated for device {name}.")

async def set_values(self):
async def set_status(self):
"""Send self._pending_status payload to device."""
await self.check_connection()
if self._interface and self._pending_status:
payload, self._pending_status = self._pending_status.copy(), {}
try:
await self._interface.set_dps(payload, cid=self._node_id)
except Exception: # pylint: disable=broad-except
self.debug(f"Failed to set values {payload}", force=True)
except Exception as ex: # pylint: disable=broad-except
self.debug(f"Failed to set values {payload} --> {ex}", force=True)
elif not self._interface:
self.error(f"Device is not connected.")

Expand All @@ -366,7 +365,7 @@ async def set_dp(self, state, dp_index):
if self._interface is not None:
self._pending_status.update({dp_index: state})
await asyncio.sleep(0.001)
await self.set_values()
await self.set_status()
else:
if self.is_sleep:
return self._pending_status.update({str(dp_index): state})
Expand All @@ -376,7 +375,7 @@ async def set_dps(self, states):
if self._interface is not None:
self._pending_status.update(states)
await asyncio.sleep(0.001)
await self.set_values()
await self.set_status()
else:
if self.is_sleep:
return self._pending_status.update(states)
Expand Down
147 changes: 77 additions & 70 deletions custom_components/localtuya/core/pytuya/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,19 +616,19 @@ def add_data(self, data):
"""Add new data to the buffer and try to parse messages."""
self.buffer += data

header_len_55AA = struct.calcsize(MESSAGE_HEADER_FMT_55AA)
header_len_55AA = struct.calcsize(MESSAGE_RECV_HEADER_FMT)
header_len_6699 = struct.calcsize(MESSAGE_HEADER_FMT_6699)
header_len = header_len_55AA

header_len = header_len_55AA
prefix_len = len(PREFIX_55AA_BIN)

while self.buffer:
prefix_offset_55AA = self.buffer.find(PREFIX_55AA_BIN)
prefix_offset_6699 = self.buffer.find(PREFIX_6699_BIN)

if prefix_offset_55AA < 0 and prefix_offset_6699 < 0:
self.buffer = self.buffer[1 - prefix_len :]
header_len = header_len_55AA
self.buffer = self.buffer[1 - prefix_len :]
else:
header_len = header_len_6699
prefix_offset = (
Expand All @@ -655,64 +655,35 @@ def add_data(self, data):

def _dispatch(self, msg):
"""Dispatch a message to someone that is listening."""
# ON devices >= 3.4 the seqno get conflict with the waited seqno.
# The devices sends cmds 8 and 9 usually before NEW_CONTROL which increase the seqno.
# ^ This needs to be handle in better way, The fix atm is just workaround.

self.debug("Dispatching message CMD %r %s", msg.cmd, msg)
if msg.seqno in self.listeners and msg.cmd != STATUS:
# self.debug("Dispatching sequence number %d", msg.seqno)
sem = self.listeners[msg.seqno]
if isinstance(sem, asyncio.Semaphore):
self.listeners[msg.seqno] = msg
sem.release()
else:
self.debug("Got additional message without request - skipping: %s", sem)
elif msg.cmd == HEART_BEAT:

if msg.seqno in self.listeners:
self.debug("Dispatching sequence number %d", msg.seqno)
self._release_listener(msg.seqno, msg)

if msg.cmd == HEART_BEAT:

This comment has been minimized.

Copy link
@Lurker00

Lurker00 Apr 27, 2024

Now self._release_listener (and then sem.release()) is called twice for all known commands. Is it good for Semaphore? You may believe that, after the first release() call, the waiting thread would delete it faster and always faster, but that's called "race condition".

Moreover, self.listeners.pop(seqno) in thewait_for may happen between the check and get here:

        if seqno not in self.listeners:
            return

        sem = self.listeners[seqno]

This comment has been minimized.

Copy link
@xZetsubou

xZetsubou Apr 27, 2024

Author Owner

Why do you think it will be called twice? these commands has special cases for sequence number it won't shows in the messages, we releases them by there own static sequence number depending on the command type, for example heartbeat command will always release -100 which is not the same msg sequence.

This comment has been minimized.

Copy link
@Lurker00

Lurker00 Apr 28, 2024

Sorry, I misunderstood the protocol. To be sure, I've collected a log with "Dispatching sequence number" messages and there are only positive numbers there.

Sorry for false alarms (also below)!

self.debug("Got heartbeat response")
if self.HEARTBEAT_SEQNO in self.listeners:
sem = self.listeners[self.HEARTBEAT_SEQNO]
self.listeners[self.HEARTBEAT_SEQNO] = msg
sem.release()
self._release_listener(self.HEARTBEAT_SEQNO, msg)
elif msg.cmd == UPDATEDPS:
self.debug("Got normal updatedps response")
if self.RESET_SEQNO in self.listeners:
sem = self.listeners[self.RESET_SEQNO]
if isinstance(sem, asyncio.Semaphore):
self.listeners[self.RESET_SEQNO] = msg
sem.release()
else:
self.debug(
"Got additional updatedps message without request - skipping: %s",
sem,
)
self._release_listener(self.RESET_SEQNO, msg)
if self.RESET_SEQNO not in self.listeners:
self.debug(

This comment has been minimized.

Copy link
@Lurker00

Lurker00 Apr 27, 2024

After two subsequent calls to _release_listener, there is a big chance that the condition is true. Moreover, it can be true if this call is the only one: it depends on multithreading performance. So, this check looks obsolete anyway.

"Got additional updatedps message without request - skipping: %s",
sem,
)
elif msg.cmd == SESS_KEY_NEG_RESP:
self.debug("Got key negotiation response")
if self.SESS_KEY_SEQNO in self.listeners:
sem = self.listeners[self.SESS_KEY_SEQNO]
self.listeners[self.SESS_KEY_SEQNO] = msg
sem.release()
self._release_listener(self.SESS_KEY_SEQNO, msg)
elif msg.cmd == STATUS:
if self.RESET_SEQNO in self.listeners:
self.debug("Got reset status update")
self._release_listener(self.RESET_SEQNO, msg)
sem = self.listeners[self.RESET_SEQNO]
if isinstance(sem, asyncio.Semaphore):
self.listeners[self.RESET_SEQNO] = msg
sem.release()
else:
self.debug(
"Got additional reset message without request - skipping: %s",
sem,
)
else:
self.debug("Got status update")
self.callback_status_update(msg)
# workdaround for >= v3.4 devices until find prper way to wait seqno correctly.
if msg.seqno in self.listeners:
sem = self.listeners[msg.seqno]
if isinstance(sem, asyncio.Semaphore):
self.listeners[msg.seqno] = msg
sem.release()
else:
if msg.cmd == CONTROL_NEW:
self.debug("Got ACK message for command %d: will ignore it", msg.cmd)
Expand All @@ -724,6 +695,17 @@ def _dispatch(self, msg):
msg,
)

def _release_listener(self, seqno, msg):
if seqno not in self.listeners:
return

sem = self.listeners[seqno]
if isinstance(sem, asyncio.Semaphore):
self.listeners[seqno] = msg
sem.release()
else:
self.debug("Got additional message without request - skipping: %s", sem)


class TuyaListener(ABC):
"""Listener interface for Tuya device changes."""
Expand Down Expand Up @@ -752,6 +734,8 @@ def disconnected(self):
class TuyaProtocol(asyncio.Protocol, ContextualLogger):
"""Implementation of the Tuya protocol."""

HEARTBEAT_SKIP = 5

def __init__(
self,
dev_id: str,
Expand Down Expand Up @@ -800,6 +784,7 @@ def __init__(
self.remote_nonce = b""
self.dps_whitelist = UPDATE_DPS_WHITELIST
self.dispatched_dps = {} # Store payload so we can trigger an event in HA.
self._last_command_sent = 1

def set_version(self, protocol_version):
"""Set the device version and eventually start available DPs detection."""
Expand Down Expand Up @@ -827,35 +812,33 @@ def error_json(self, number=None, payload=None):

return json.loads('{ "Error":"%s", "Err":"%s", "Payload":%s }' % vals)

def _setup_dispatcher(self, enable_debug):
def _setup_dispatcher(self, enable_debug) -> MessageDispatcher:
def _status_update(msg):
if msg.seqno > 0:
self.seqno = msg.seqno + 1
decoded_message: dict = self._decode_payload(msg.payload)
new_states = {}
cid = None

if "dps" in decoded_message:
if "dps" not in decoded_message:
return

if dps_payload := decoded_message.get("dps"):
if cid := decoded_message.get("cid"):
if cid in self.dps_cache:
self.dps_cache[cid].update(decoded_message["dps"])
else:
self.dps_cache.update({cid: decoded_message["dps"]})
self.dps_cache.setdefault(cid, {})
self.dps_cache[cid].update(dps_payload)
else:
if "parent" in self.dps_cache:
self.dps_cache["parent"].update(decoded_message["dps"])
else:
self.dps_cache.update({"parent": decoded_message["dps"]})
self.dps_cache.setdefault("parent", {})
self.dps_cache["parent"].update(dps_payload)

listener = self.listener and self.listener()
if listener is not None:
if cid:
listener = listener._sub_devices.get(cid, listener)
new_states = self.dps_cache.get(cid)
device = self.dps_cache.get(cid, {})
else:
new_states = self.dps_cache.get("parent", {})
device = self.dps_cache.get("parent", {})

listener.status_updated(new_states)
listener.status_updated(device)

return MessageDispatcher(
self.id, _status_update, self.version, self.local_key, enable_debug
Expand All @@ -866,6 +849,22 @@ def connection_made(self, transport):
self.transport = transport
self.on_connected.set_result(True)

async def transport_write(self, data, command_delay=True):
"""Write data on transport, The 'command_delay' will ensure that no massive requests happen all at once."""
wait = 0
while command_delay and self.last_command_sent < 0.050:
await asyncio.sleep(0.060)
wait += 1
if wait == 10:
break

try:
self._last_command_sent = time.time()
self.transport.write(data)
except Exception:
await self.close()
raise

def start_heartbeat(self):
"""Start the heartbeat transmissions with the device."""

Expand All @@ -874,7 +873,8 @@ async def heartbeat_loop():
self.debug("Started heartbeat loop")
while True:
try:
await self.heartbeat()
if self.last_command_sent > self.HEARTBEAT_SKIP:
await self.heartbeat()
await asyncio.sleep(HEARTBEAT_INTERVAL)
except asyncio.CancelledError:
self.debug("Stopped heartbeat loop")
Expand Down Expand Up @@ -944,10 +944,9 @@ async def exchange_quick(self, payload, recv_retries):
# self.debug("Quick-dispatching message %s, seqno %s", binascii.hexlify(enc_payload), self.seqno)

try:
self.transport.write(enc_payload)
await self.transport_write(enc_payload)
except Exception:
# self._check_socket_close(True)
self.close()
await self.close()
return None
while recv_retries:
try:
Expand All @@ -972,13 +971,15 @@ async def exchange_quick(self, payload, recv_retries):
)
return None

async def exchange(self, command, dps=None, nodeID=None):
async def exchange(self, command, dps=None, nodeID=None, delay=True):
"""Send and receive a message, returning response from device."""
if self.version >= 3.4 and self.real_local_key == self.local_key:
self.debug("3.4 or 3.5 device: negotiating a new session key")
await self._negotiate_session_key()

self.debug("Sending command %s (device type: %s)", command, self.dev_type)
self.debug(
"Sending command %s (device type: %s) DPS: %s", command, self.dev_type, dps
)
payload = self._generate_payload(command, dps, nodeId=nodeID)
real_cmd = payload.cmd
dev_type = self.dev_type
Expand All @@ -993,7 +994,8 @@ async def exchange(self, command, dps=None, nodeID=None):
seqno = MessageDispatcher.RESET_SEQNO

enc_payload = self._encode_message(payload)
self.transport.write(enc_payload)

await self.transport_write(enc_payload, delay)
msg = await self.dispatcher.wait_for(seqno, payload.cmd)
if msg is None:
self.debug("Wait was aborted for seqno %d", seqno)
Expand All @@ -1020,7 +1022,7 @@ async def exchange(self, command, dps=None, nodeID=None):

async def status(self, cid=None):
"""Return device status."""
status: dict = await self.exchange(command=DP_QUERY, nodeID=cid)
status: dict = await self.exchange(command=DP_QUERY, nodeID=cid, delay=False)

if status:
if cid and "dps" in status:
Expand Down Expand Up @@ -1067,7 +1069,7 @@ async def update_dps(self, dps=None, cid=None):
dps = list(set(dps).intersection(set(self.dps_whitelist)))
payload = self._generate_payload(UPDATEDPS, dps, nodeId=cid)
enc_payload = self._encode_message(payload)
self.transport.write(enc_payload)
await self.transport_write(enc_payload)
return True

async def set_dp(self, value, dp_index, cid=None):
Expand Down Expand Up @@ -1459,6 +1461,11 @@ def deepcopy_dict(_dict: dict):

return MessagePayload(command_override, payload)

@property
def last_command_sent(self):
"""Return last command sent by seconds"""
return time.time() - self._last_command_sent

def __repr__(self):
"""Return internal string representation of object."""
return self.id
Expand Down

0 comments on commit 07fa10a

Please sign in to comment.