Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions memphis/memphis.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self):
self.producers_map = dict()
self.consumers_map = dict()

async def get_msgs_update_configurations(self, iterable: Iterable):
async def get_msgs_sdk_clients_updates(self, iterable: Iterable):
try:
async for msg in iterable:
message = msg.data.decode("utf-8")
Expand All @@ -66,18 +66,21 @@ async def get_msgs_update_configurations(self, iterable: Iterable):
self.station_schemaverse_to_dls[data["station_name"]] = data[
"update"
]
elif data["type"] == "remove_station":
self.unset_cached_producer_station(data['station_name'])
self.unset_cached_consumer_station(data['station_name'])
except Exception as err:
raise MemphisError(err)

async def configurations_listener(self):
async def sdk_client_updates_listener(self):
try:
sub = await self.broker_manager.subscribe(
"$memphis_sdk_configurations_updates"
"$memphis_sdk_clients_updates"
)
self.update_configurations_sub = sub
loop = asyncio.get_event_loop()
task = loop.create_task(
self.get_msgs_update_configurations(
self.get_msgs_sdk_clients_updates(
self.update_configurations_sub.messages
)
)
Expand Down Expand Up @@ -155,7 +158,7 @@ async def connect(
name=self.connection_id + "::" + self.username,
)

await self.configurations_listener()
await self.sdk_client_updates_listener()
self.broker_connection = self.broker_manager.jetstream()
self.is_connection_active = True
except Exception as e:
Expand Down Expand Up @@ -661,3 +664,26 @@ async def fetch_messages(

def is_connected(self):
return self.broker_manager.is_connected

def unset_cached_producer_station(self, station_name):
try:
internal_station_name = get_internal_name(station_name)
for key in list(self.producers_map):
producer = self.producers_map[key]
if producer.internal_station_name == internal_station_name:
del self.producers_map[key]
except Exception as e:
raise e


def unset_cached_consumer_station(self, station_name):
try:
internal_station_name = get_internal_name(station_name)
for key in list(self.consumers_map):
consumer = self.consumers_map[key]
consumer_station_name_internal = get_internal_name(consumer.station_name)
if consumer_station_name_internal == internal_station_name:
del self.consumers_map[key]
except Exception as e:
raise e