diff --git a/memphis/memphis.py b/memphis/memphis.py index a4c27b3..c1292da 100644 --- a/memphis/memphis.py +++ b/memphis/memphis.py @@ -48,6 +48,10 @@ def __init__(self): self.schema_updates_data = {} self.partition_producers_updates_data = {} self.partition_consumers_updates_data = {} + self.functions_updates_data = {} + self.functions_updates_subs = {} + self.functions_tasks = {} + self.functions_clients_per_station = {} self.schema_updates_subs = {} self.clients_per_station = {} self.schema_tasks = {} @@ -468,6 +472,10 @@ async def producer( self.update_schema_data(station_name) + if "station_version" in create_res: + if create_res["station_version"] > 0: + await self.start_listen_for_functions_updates(internal_station_name, create_res["station_partitions_first_functions"]) + producer = Producer(self, producer_name, station_name, real_name) map_key = internal_station_name + "_" + real_name self.producers_map[map_key] = producer @@ -545,6 +553,39 @@ def parse_descriptor(self, station_name): except Exception as e: raise MemphisError(str(e)) from e + async def start_listen_for_functions_updates(self, station_name, first_functions): + #first_functions should contain the dict of the first function of each partition key: partition number, value: first function id + + if station_name in self.functions_updates_subs: + self.functions_clients_per_station[station_name] += 1 + return + else: + self.functions_clients_per_station[station_name] = 1 + + functions_updates_subject = "$memphis_functions_updates_" + station_name + + if len(first_functions) == 0: + self.functions_updates_data[station_name] = {} + else: + self.functions_updates_data[station_name] = first_functions + + sub = await self.broker_manager.subscribe(functions_updates_subject) + self.functions_updates_subs[station_name] = sub + + loop = asyncio.get_event_loop() + task = loop.create_task( + self.get_msg_functions_updates( + station_name, self.functions_updates_subs[station_name].messages + ) + ) + self.functions_tasks[station_name] = task + + async def get_msg_functions_updates(self, station_name, iterable): + async for msg in iterable: + message = msg.data.decode("utf-8") + message = json.loads(message) + self.functions_updates_data[station_name] = message["functions"] + async def start_listen_for_schema_updates(self, station_name, schema_update_data): schema_updates_subject = "$memphis_schema_updates_" + station_name diff --git a/memphis/producer.py b/memphis/producer.py index 7d970fe..a912c33 100644 --- a/memphis/producer.py +++ b/memphis/producer.py @@ -103,6 +103,15 @@ async def produce( else: partition_name = f"{self.internal_station_name}${str(next(self.partition_generator))}" + if self.internal_station_name in self.connection.functions_updates_data: + partition_number = partition_name.split("$")[1] + if partition_number in self.connection.functions_updates_data[self.internal_station_name]: + full_subject_name = f"{partition_name}.functions.{self.connection.functions_updates_data[self.internal_station_name][partition_number]}" + else: + full_subject_name = f"{partition_name}.final" + else: + full_subject_name = f"{partition_name}.final" + if async_produce: nonblocking = True warnings.warn("The argument async_produce is deprecated. " + \ @@ -112,7 +121,7 @@ async def produce( try: task = self.loop.create_task( self.connection.broker_connection.publish( - partition_name + ".final", + full_subject_name, message, timeout=ack_wait_sec, headers=headers, @@ -133,7 +142,7 @@ async def produce( raise MemphisError(e) else: await self.connection.broker_connection.publish( - partition_name + ".final", + full_subject_name, message, timeout=ack_wait_sec, headers=headers, @@ -251,6 +260,23 @@ async def destroy(self): if sub is not None: await sub.unsubscribe() + + self.connection.functions_clients_per_station[internal_station_name] -= 1 + if self.connection.functions_clients_per_station[internal_station_name] == 0: + if internal_station_name in self.connection.functions_updates_data: + del self.connection.functions_updates_data[internal_station_name] + if internal_station_name in self.connection.functions_updates_subs: + sub = self.connection.functions_updates_subs.get(internal_station_name) + if sub is not None: + await sub.unsubscribe() + del self.connection.functions_updates_subs[internal_station_name] + if internal_station_name in self.connection.functions_tasks: + task = self.connection.functions_tasks.get(internal_station_name) + if task is not None: + task.cancel() + del self.connection.functions_tasks[internal_station_name] + + map_key = internal_station_name + "_" + self.real_name del self.connection.producers_map[map_key] diff --git a/memphis/station.py b/memphis/station.py index 26dec82..b2bc572 100644 --- a/memphis/station.py +++ b/memphis/station.py @@ -191,6 +191,22 @@ async def destroy(self): if sub is not None: await sub.unsubscribe() + + if internal_station_name in self.connection.functions_clients_per_station: + del self.connection.functions_clients_per_station[internal_station_name] + if internal_station_name in self.connection.functions_updates_data: + del self.connection.functions_updates_data[internal_station_name] + if internal_station_name in self.connection.functions_updates_subs: + function_sub = self.connection.functions_updates_subs.get(internal_station_name) + if function_sub is not None: + await function_sub.unsubscribe() + del self.connection.functions_updates_subs[internal_station_name] + if internal_station_name in self.connection.functions_tasks: + task = self.connection.functions_tasks.get(internal_station_name) + if task is not None: + task.cancel() + del self.connection.functions_tasks[internal_station_name] + self.connection.producers_map = { k: v for k, v in self.connection.producers_map.items()