diff --git a/memphis/memphis.py b/memphis/memphis.py index cc0d91b..c1235af 100644 --- a/memphis/memphis.py +++ b/memphis/memphis.py @@ -55,7 +55,7 @@ def __init__(self): self.producers_map = dict() self.consumers_map = dict() - async def get_msgs_update_configurations(self, iterable: Iterable): + async def get_msgs_sdk_clients_updates(self, iterable: Iterable): try: async for msg in iterable: message = msg.data.decode("utf-8") @@ -66,18 +66,21 @@ async def get_msgs_update_configurations(self, iterable: Iterable): self.station_schemaverse_to_dls[data["station_name"]] = data[ "update" ] + elif data["type"] == "remove_station": + self.unset_cached_producer_station(data['station_name']) + self.unset_cached_consumer_station(data['station_name']) except Exception as err: raise MemphisError(err) - async def configurations_listener(self): + async def sdk_client_updates_listener(self): try: sub = await self.broker_manager.subscribe( - "$memphis_sdk_configurations_updates" + "$memphis_sdk_clients_updates" ) self.update_configurations_sub = sub loop = asyncio.get_event_loop() task = loop.create_task( - self.get_msgs_update_configurations( + self.get_msgs_sdk_clients_updates( self.update_configurations_sub.messages ) ) @@ -155,7 +158,7 @@ async def connect( name=self.connection_id + "::" + self.username, ) - await self.configurations_listener() + await self.sdk_client_updates_listener() self.broker_connection = self.broker_manager.jetstream() self.is_connection_active = True except Exception as e: @@ -661,3 +664,26 @@ async def fetch_messages( def is_connected(self): return self.broker_manager.is_connected + + def unset_cached_producer_station(self, station_name): + try: + internal_station_name = get_internal_name(station_name) + for key in list(self.producers_map): + producer = self.producers_map[key] + if producer.internal_station_name == internal_station_name: + del self.producers_map[key] + except Exception as e: + raise e + + + def unset_cached_consumer_station(self, station_name): + try: + internal_station_name = get_internal_name(station_name) + for key in list(self.consumers_map): + consumer = self.consumers_map[key] + consumer_station_name_internal = get_internal_name(consumer.station_name) + if consumer_station_name_internal == internal_station_name: + del self.consumers_map[key] + except Exception as e: + raise e +