diff --git a/custom_components/localtuya/common.py b/custom_components/localtuya/common.py index 88ea4ade2..56c8535be 100644 --- a/custom_components/localtuya/common.py +++ b/custom_components/localtuya/common.py @@ -57,21 +57,6 @@ _LOGGER = logging.getLogger(__name__) -def prepare_setup_entities(hass, config_entry, platform): - """Prepare ro setup entities for a platform.""" - entities_to_setup = [ - entity - for entity in config_entry.data[CONF_ENTITIES] - if entity[CONF_PLATFORM] == platform - ] - if not entities_to_setup: - return None, None - - tuyainterface = [] - - return tuyainterface, entities_to_setup - - async def async_setup_entry( domain, entity_class, @@ -179,7 +164,7 @@ def __init__( self._status = {} self.dps_to_request = {} self._is_closing = False - self._connect_task: bool | None = None + self._connect_task: asyncio.Task | None = None self._disconnect_task: Callable[[], None] | None = None self._unsub_interval: Callable[[], None] = None self._entities = [] @@ -241,27 +226,29 @@ async def async_connect(self, _now=None) -> None: if not self._is_closing and not self.is_connecting and not self.connected: try: - await asyncio.wait_for(self._make_connection(), 5) - except TimeoutError: + self._connect_task = self._hass.async_create_task( + asyncio.wait_for(self._make_connection(), 5) + ) + await self._connect_task + except (TimeoutError, asyncio.CancelledError): ... - # self._connect_task = asyncio.create_task(self._make_connection()) async def _make_connection(self): """Subscribe localtuya entity events.""" - self._connect_task = True name = self._device_config.get(CONF_FRIENDLY_NAME) host = name if self.is_subdevice else self._device_config.get(CONF_HOST) retry = 0 + self.debug(f"Trying to connect to {host}...", force=True) while retry < self._connect_max_tries: retry += 1 try: - self.debug(f"Trying to connect to {host}...", force=True) if self.is_subdevice: await self.get_gateway() gateway = self._gwateway - if gateway and not gateway.connected or gateway.is_connecting: - self._connect_task = None - return + # if not gateway or not (gateway.connected and gateway.is_connecting): + # return await self.abort_connect() + if gateway and gateway.is_connecting: + await gateway._connect_task self._interface = gateway._interface else: self._interface = await pytuya.connect( @@ -294,7 +281,7 @@ async def _make_connection(self): self.debug("Retrieving initial state") status = await self._interface.status(cid=self._node_id) - if status is None and not self.is_subdevice: + if status is None: # and not self.is_subdevice raise Exception("Failed to retrieve status") if not self._interface.heartbeater: self._interface.start_heartbeat() @@ -338,12 +325,12 @@ def _new_entity_handler(entity_id): ) self._is_closing = False + self._connect_task = None self.debug(f"Success: connected to {host}", force=True) if self._sub_devices: connect_sub_devices = [ device.async_connect() for device in self._sub_devices.values() ] - self._connect_task = None await asyncio.gather(*connect_sub_devices) self._connect_task = None @@ -352,11 +339,35 @@ async def abort_connect(self): """Abort the connect process to the interface[device]""" if self.is_subdevice: self._interface = None + self._connect_task = None if self._interface is not None: await self._interface.close() self._interface = None - self._connect_task = None + + async def check_connection(self): + """Ensure that the device is not still connecting; if it is, wait for it.""" + if self._connect_task: + await self._connect_task + if self._gwateway and self._gwateway._connect_task: + await self._gwateway._connect_task + + async def close(self): + """Close connection and stop re-connect loop.""" + self._is_closing = True + if self._connect_task is not None: + self._connect_task.cancel() + await self._connect_task + self._connect_task = None + if self._interface is not None: + await self._interface.close() + self._interface = None + if self._disconnect_task is not None: + self._disconnect_task() + self.debug( + f"Closed connection with {self._device_config[CONF_FRIENDLY_NAME]}", + force=True, + ) async def update_local_key(self): """Retrieve updated local_key from Cloud API and update the config_entry.""" @@ -380,25 +391,9 @@ async def _async_refresh(self, _now): self.debug("Refreshing dps for device") await self._interface.update_dps(cid=self._node_id) - async def close(self): - """Close connection and stop re-connect loop.""" - self._is_closing = True - if self._connect_task is not None: - # self._connect_task.cancel() - # await self._connect_task - self._connect_task = None - if self._interface is not None: - await self._interface.close() - self._interface = None - if self._disconnect_task is not None: - self._disconnect_task() - self.debug( - f"Closed connection with {self._device_config[CONF_FRIENDLY_NAME]}", - force=True, - ) - async def set_dp(self, state, dp_index): """Change value of a DP of the Tuya device.""" + await self.check_connection() if self._interface is not None: try: await self._interface.set_dp(state, dp_index, cid=self._node_id) @@ -411,6 +406,7 @@ async def set_dp(self, state, dp_index): async def set_dps(self, states): """Change value of a DPs of the Tuya device.""" + await self.check_connection() if self._interface is not None: try: await self._interface.set_dps(states, cid=self._node_id) @@ -421,23 +417,11 @@ async def set_dps(self, states): f"Not connected to device {self._device_config[CONF_FRIENDLY_NAME]}" ) - @callback - def status_updated(self, status: dict): - """Device updated status.""" - if self._fake_gateway: - # Fake gateways are only used to pass commands no need to update status. - return - cid = self._node_id - status = status.get(cid, {}) if cid else status.get("parent", {}) - self._handle_event(self._status, status) - self._status.update(status) - self._dispatch_status() - def _dispatch_status(self): signal = f"localtuya_{self._device_config[CONF_DEVICE_ID]}" async_dispatcher_send(self._hass, signal, self._status) - def _handle_event(self, old_status, new_status, deviceID=None): + def _handle_event(self, old_status: dict, new_status: dict, deviceID=None): """Handle events in HA when devices updated.""" def fire_event(event, data: dict): @@ -473,11 +457,28 @@ def fire_event(event, data: dict): data = {"dp": dpid_trigger, "value": dpid_value} fire_event(event, data) + @callback + def status_updated(self, status: dict): + """Device updated status.""" + if self._fake_gateway: + # Fake gateways are only used to pass commands no need to update status. + return + cid = self._node_id + status = status.get(cid, {}) if cid else status.get("parent", {}) + self._handle_event(self._status, status) + self._status.update(status) + self._dispatch_status() + @callback def disconnected(self): """Device disconnected.""" - signal = f"localtuya_{self._device_config[CONF_DEVICE_ID]}" - async_dispatcher_send(self._hass, signal, None) + + def shutdown_entities(now=None): + """Shutdown device entities""" + if not self.connected: + signal = f"localtuya_{self._device_config[CONF_DEVICE_ID]}" + async_dispatcher_send(self._hass, signal, None) + if self._unsub_interval is not None: self._unsub_interval() self._unsub_interval = None @@ -488,15 +489,17 @@ def disconnected(self): sub_dev.disconnected() if self._connect_task is not None: - # self._connect_task.cancel() + self._connect_task.cancel() self._connect_task = None - # If it's disconnect by unexpected error. + # If it disconnects unexpectedly. if self._is_closing is not True and not self.is_subdevice: self.debug(f"Disconnected - waiting for discovery broadcast", force=True) # Try to quickly reconnect. self._is_closing = False self._config_entry.async_create_task(self._hass, self.async_connect()) + if not self._is_closing: + async_call_later(self._hass, 5, shutdown_entities) class LocalTuyaEntity(RestoreEntity, pytuya.ContextualLogger):