-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
* There's is a delay now before it sends the other payload. * Revert back status method instead of wait for detect dps.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -616,19 +616,19 @@ def add_data(self, data): | |
"""Add new data to the buffer and try to parse messages.""" | ||
self.buffer += data | ||
|
||
header_len_55AA = struct.calcsize(MESSAGE_HEADER_FMT_55AA) | ||
header_len_55AA = struct.calcsize(MESSAGE_RECV_HEADER_FMT) | ||
header_len_6699 = struct.calcsize(MESSAGE_HEADER_FMT_6699) | ||
header_len = header_len_55AA | ||
|
||
header_len = header_len_55AA | ||
prefix_len = len(PREFIX_55AA_BIN) | ||
|
||
while self.buffer: | ||
prefix_offset_55AA = self.buffer.find(PREFIX_55AA_BIN) | ||
prefix_offset_6699 = self.buffer.find(PREFIX_6699_BIN) | ||
|
||
if prefix_offset_55AA < 0 and prefix_offset_6699 < 0: | ||
self.buffer = self.buffer[1 - prefix_len :] | ||
header_len = header_len_55AA | ||
self.buffer = self.buffer[1 - prefix_len :] | ||
else: | ||
header_len = header_len_6699 | ||
prefix_offset = ( | ||
|
@@ -655,64 +655,35 @@ def add_data(self, data): | |
|
||
def _dispatch(self, msg): | ||
"""Dispatch a message to someone that is listening.""" | ||
# ON devices >= 3.4 the seqno get conflict with the waited seqno. | ||
# The devices sends cmds 8 and 9 usually before NEW_CONTROL which increase the seqno. | ||
# ^ This needs to be handle in better way, The fix atm is just workaround. | ||
|
||
self.debug("Dispatching message CMD %r %s", msg.cmd, msg) | ||
if msg.seqno in self.listeners and msg.cmd != STATUS: | ||
# self.debug("Dispatching sequence number %d", msg.seqno) | ||
sem = self.listeners[msg.seqno] | ||
if isinstance(sem, asyncio.Semaphore): | ||
self.listeners[msg.seqno] = msg | ||
sem.release() | ||
else: | ||
self.debug("Got additional message without request - skipping: %s", sem) | ||
elif msg.cmd == HEART_BEAT: | ||
|
||
if msg.seqno in self.listeners: | ||
self.debug("Dispatching sequence number %d", msg.seqno) | ||
self._release_listener(msg.seqno, msg) | ||
|
||
if msg.cmd == HEART_BEAT: | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
xZetsubou
Author
Owner
|
||
self.debug("Got heartbeat response") | ||
if self.HEARTBEAT_SEQNO in self.listeners: | ||
sem = self.listeners[self.HEARTBEAT_SEQNO] | ||
self.listeners[self.HEARTBEAT_SEQNO] = msg | ||
sem.release() | ||
self._release_listener(self.HEARTBEAT_SEQNO, msg) | ||
elif msg.cmd == UPDATEDPS: | ||
self.debug("Got normal updatedps response") | ||
if self.RESET_SEQNO in self.listeners: | ||
sem = self.listeners[self.RESET_SEQNO] | ||
if isinstance(sem, asyncio.Semaphore): | ||
self.listeners[self.RESET_SEQNO] = msg | ||
sem.release() | ||
else: | ||
self.debug( | ||
"Got additional updatedps message without request - skipping: %s", | ||
sem, | ||
) | ||
self._release_listener(self.RESET_SEQNO, msg) | ||
if self.RESET_SEQNO not in self.listeners: | ||
self.debug( | ||
This comment has been minimized.
Sorry, something went wrong.
Lurker00
|
||
"Got additional updatedps message without request - skipping: %s", | ||
sem, | ||
) | ||
elif msg.cmd == SESS_KEY_NEG_RESP: | ||
self.debug("Got key negotiation response") | ||
if self.SESS_KEY_SEQNO in self.listeners: | ||
sem = self.listeners[self.SESS_KEY_SEQNO] | ||
self.listeners[self.SESS_KEY_SEQNO] = msg | ||
sem.release() | ||
self._release_listener(self.SESS_KEY_SEQNO, msg) | ||
elif msg.cmd == STATUS: | ||
if self.RESET_SEQNO in self.listeners: | ||
self.debug("Got reset status update") | ||
self._release_listener(self.RESET_SEQNO, msg) | ||
sem = self.listeners[self.RESET_SEQNO] | ||
if isinstance(sem, asyncio.Semaphore): | ||
self.listeners[self.RESET_SEQNO] = msg | ||
sem.release() | ||
else: | ||
self.debug( | ||
"Got additional reset message without request - skipping: %s", | ||
sem, | ||
) | ||
else: | ||
self.debug("Got status update") | ||
self.callback_status_update(msg) | ||
# workdaround for >= v3.4 devices until find prper way to wait seqno correctly. | ||
if msg.seqno in self.listeners: | ||
sem = self.listeners[msg.seqno] | ||
if isinstance(sem, asyncio.Semaphore): | ||
self.listeners[msg.seqno] = msg | ||
sem.release() | ||
else: | ||
if msg.cmd == CONTROL_NEW: | ||
self.debug("Got ACK message for command %d: will ignore it", msg.cmd) | ||
|
@@ -724,6 +695,17 @@ def _dispatch(self, msg): | |
msg, | ||
) | ||
|
||
def _release_listener(self, seqno, msg): | ||
if seqno not in self.listeners: | ||
return | ||
|
||
sem = self.listeners[seqno] | ||
if isinstance(sem, asyncio.Semaphore): | ||
self.listeners[seqno] = msg | ||
sem.release() | ||
else: | ||
self.debug("Got additional message without request - skipping: %s", sem) | ||
|
||
|
||
class TuyaListener(ABC): | ||
"""Listener interface for Tuya device changes.""" | ||
|
@@ -752,6 +734,8 @@ def disconnected(self): | |
class TuyaProtocol(asyncio.Protocol, ContextualLogger): | ||
"""Implementation of the Tuya protocol.""" | ||
|
||
HEARTBEAT_SKIP = 5 | ||
|
||
def __init__( | ||
self, | ||
dev_id: str, | ||
|
@@ -800,6 +784,7 @@ def __init__( | |
self.remote_nonce = b"" | ||
self.dps_whitelist = UPDATE_DPS_WHITELIST | ||
self.dispatched_dps = {} # Store payload so we can trigger an event in HA. | ||
self._last_command_sent = 1 | ||
|
||
def set_version(self, protocol_version): | ||
"""Set the device version and eventually start available DPs detection.""" | ||
|
@@ -827,35 +812,33 @@ def error_json(self, number=None, payload=None): | |
|
||
return json.loads('{ "Error":"%s", "Err":"%s", "Payload":%s }' % vals) | ||
|
||
def _setup_dispatcher(self, enable_debug): | ||
def _setup_dispatcher(self, enable_debug) -> MessageDispatcher: | ||
def _status_update(msg): | ||
if msg.seqno > 0: | ||
self.seqno = msg.seqno + 1 | ||
decoded_message: dict = self._decode_payload(msg.payload) | ||
new_states = {} | ||
cid = None | ||
|
||
if "dps" in decoded_message: | ||
if "dps" not in decoded_message: | ||
return | ||
|
||
if dps_payload := decoded_message.get("dps"): | ||
if cid := decoded_message.get("cid"): | ||
if cid in self.dps_cache: | ||
self.dps_cache[cid].update(decoded_message["dps"]) | ||
else: | ||
self.dps_cache.update({cid: decoded_message["dps"]}) | ||
self.dps_cache.setdefault(cid, {}) | ||
self.dps_cache[cid].update(dps_payload) | ||
else: | ||
if "parent" in self.dps_cache: | ||
self.dps_cache["parent"].update(decoded_message["dps"]) | ||
else: | ||
self.dps_cache.update({"parent": decoded_message["dps"]}) | ||
self.dps_cache.setdefault("parent", {}) | ||
self.dps_cache["parent"].update(dps_payload) | ||
|
||
listener = self.listener and self.listener() | ||
if listener is not None: | ||
if cid: | ||
listener = listener._sub_devices.get(cid, listener) | ||
new_states = self.dps_cache.get(cid) | ||
device = self.dps_cache.get(cid, {}) | ||
else: | ||
new_states = self.dps_cache.get("parent", {}) | ||
device = self.dps_cache.get("parent", {}) | ||
|
||
listener.status_updated(new_states) | ||
listener.status_updated(device) | ||
|
||
return MessageDispatcher( | ||
self.id, _status_update, self.version, self.local_key, enable_debug | ||
|
@@ -866,6 +849,22 @@ def connection_made(self, transport): | |
self.transport = transport | ||
self.on_connected.set_result(True) | ||
|
||
async def transport_write(self, data, command_delay=True): | ||
"""Write data on transport, The 'command_delay' will ensure that no massive requests happen all at once.""" | ||
wait = 0 | ||
while command_delay and self.last_command_sent < 0.050: | ||
await asyncio.sleep(0.060) | ||
wait += 1 | ||
if wait == 10: | ||
break | ||
|
||
try: | ||
self._last_command_sent = time.time() | ||
self.transport.write(data) | ||
except Exception: | ||
await self.close() | ||
raise | ||
|
||
def start_heartbeat(self): | ||
"""Start the heartbeat transmissions with the device.""" | ||
|
||
|
@@ -874,7 +873,8 @@ async def heartbeat_loop(): | |
self.debug("Started heartbeat loop") | ||
while True: | ||
try: | ||
await self.heartbeat() | ||
if self.last_command_sent > self.HEARTBEAT_SKIP: | ||
await self.heartbeat() | ||
await asyncio.sleep(HEARTBEAT_INTERVAL) | ||
except asyncio.CancelledError: | ||
self.debug("Stopped heartbeat loop") | ||
|
@@ -944,10 +944,9 @@ async def exchange_quick(self, payload, recv_retries): | |
# self.debug("Quick-dispatching message %s, seqno %s", binascii.hexlify(enc_payload), self.seqno) | ||
|
||
try: | ||
self.transport.write(enc_payload) | ||
await self.transport_write(enc_payload) | ||
except Exception: | ||
# self._check_socket_close(True) | ||
self.close() | ||
await self.close() | ||
return None | ||
while recv_retries: | ||
try: | ||
|
@@ -972,13 +971,15 @@ async def exchange_quick(self, payload, recv_retries): | |
) | ||
return None | ||
|
||
async def exchange(self, command, dps=None, nodeID=None): | ||
async def exchange(self, command, dps=None, nodeID=None, delay=True): | ||
"""Send and receive a message, returning response from device.""" | ||
if self.version >= 3.4 and self.real_local_key == self.local_key: | ||
self.debug("3.4 or 3.5 device: negotiating a new session key") | ||
await self._negotiate_session_key() | ||
|
||
self.debug("Sending command %s (device type: %s)", command, self.dev_type) | ||
self.debug( | ||
"Sending command %s (device type: %s) DPS: %s", command, self.dev_type, dps | ||
) | ||
payload = self._generate_payload(command, dps, nodeId=nodeID) | ||
real_cmd = payload.cmd | ||
dev_type = self.dev_type | ||
|
@@ -993,7 +994,8 @@ async def exchange(self, command, dps=None, nodeID=None): | |
seqno = MessageDispatcher.RESET_SEQNO | ||
|
||
enc_payload = self._encode_message(payload) | ||
self.transport.write(enc_payload) | ||
|
||
await self.transport_write(enc_payload, delay) | ||
msg = await self.dispatcher.wait_for(seqno, payload.cmd) | ||
if msg is None: | ||
self.debug("Wait was aborted for seqno %d", seqno) | ||
|
@@ -1020,7 +1022,7 @@ async def exchange(self, command, dps=None, nodeID=None): | |
|
||
async def status(self, cid=None): | ||
"""Return device status.""" | ||
status: dict = await self.exchange(command=DP_QUERY, nodeID=cid) | ||
status: dict = await self.exchange(command=DP_QUERY, nodeID=cid, delay=False) | ||
|
||
if status: | ||
if cid and "dps" in status: | ||
|
@@ -1067,7 +1069,7 @@ async def update_dps(self, dps=None, cid=None): | |
dps = list(set(dps).intersection(set(self.dps_whitelist))) | ||
payload = self._generate_payload(UPDATEDPS, dps, nodeId=cid) | ||
enc_payload = self._encode_message(payload) | ||
self.transport.write(enc_payload) | ||
await self.transport_write(enc_payload) | ||
return True | ||
|
||
async def set_dp(self, value, dp_index, cid=None): | ||
|
@@ -1459,6 +1461,11 @@ def deepcopy_dict(_dict: dict): | |
|
||
return MessagePayload(command_override, payload) | ||
|
||
@property | ||
def last_command_sent(self): | ||
"""Return last command sent by seconds""" | ||
return time.time() - self._last_command_sent | ||
|
||
def __repr__(self): | ||
"""Return internal string representation of object.""" | ||
return self.id | ||
|
Now
self._release_listener
(and thensem.release()
) is called twice for all known commands. Is it good forSemaphore
? You may believe that, after the firstrelease()
call, the waiting thread would delete it faster and always faster, but that's called "race condition".Moreover,
self.listeners.pop(seqno)
in thewait_for
may happen between the check and get here: