Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions memphis/memphis.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def __init__(self):
self.schema_updates_data = {}
self.partition_producers_updates_data = {}
self.partition_consumers_updates_data = {}
self.functions_updates_data = {}
self.functions_updates_subs = {}
self.functions_tasks = {}
self.functions_clients_per_station = {}
self.schema_updates_subs = {}
self.clients_per_station = {}
self.schema_tasks = {}
Expand Down Expand Up @@ -468,6 +472,10 @@ async def producer(

self.update_schema_data(station_name)

if "station_version" in create_res:
if create_res["station_version"] > 0:
await self.start_listen_for_functions_updates(internal_station_name, create_res["station_partitions_first_functions"])

producer = Producer(self, producer_name, station_name, real_name)
map_key = internal_station_name + "_" + real_name
self.producers_map[map_key] = producer
Expand Down Expand Up @@ -545,6 +553,39 @@ def parse_descriptor(self, station_name):
except Exception as e:
raise MemphisError(str(e)) from e

async def start_listen_for_functions_updates(self, station_name, first_functions):
#first_functions should contain the dict of the first function of each partition key: partition number, value: first function id

if station_name in self.functions_updates_subs:
self.functions_clients_per_station[station_name] += 1
return
else:
self.functions_clients_per_station[station_name] = 1

functions_updates_subject = "$memphis_functions_updates_" + station_name

if len(first_functions) == 0:
self.functions_updates_data[station_name] = {}
else:
self.functions_updates_data[station_name] = first_functions

sub = await self.broker_manager.subscribe(functions_updates_subject)
self.functions_updates_subs[station_name] = sub

loop = asyncio.get_event_loop()
task = loop.create_task(
self.get_msg_functions_updates(
station_name, self.functions_updates_subs[station_name].messages
)
)
self.functions_tasks[station_name] = task
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please cancel these tasks and stop all functions operation where needed, when destroying producers/stations and on connection close, like we are doing in Schemaverse

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I handle the destruction in producer.py line: 264


async def get_msg_functions_updates(self, station_name, iterable):
async for msg in iterable:
message = msg.data.decode("utf-8")
message = json.loads(message)
self.functions_updates_data[station_name] = message["functions"]

async def start_listen_for_schema_updates(self, station_name, schema_update_data):
schema_updates_subject = "$memphis_schema_updates_" + station_name

Expand Down
30 changes: 28 additions & 2 deletions memphis/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ async def produce(
else:
partition_name = f"{self.internal_station_name}${str(next(self.partition_generator))}"

if self.internal_station_name in self.connection.functions_updates_data:
partition_number = partition_name.split("$")[1]
if partition_number in self.connection.functions_updates_data[self.internal_station_name]:
full_subject_name = f"{partition_name}.functions.{self.connection.functions_updates_data[self.internal_station_name][partition_number]}"
else:
full_subject_name = f"{partition_name}.final"
else:
full_subject_name = f"{partition_name}.final"

if async_produce:
nonblocking = True
warnings.warn("The argument async_produce is deprecated. " + \
Expand All @@ -112,7 +121,7 @@ async def produce(
try:
task = self.loop.create_task(
self.connection.broker_connection.publish(
partition_name + ".final",
full_subject_name,
message,
timeout=ack_wait_sec,
headers=headers,
Expand All @@ -133,7 +142,7 @@ async def produce(
raise MemphisError(e)
else:
await self.connection.broker_connection.publish(
partition_name + ".final",
full_subject_name,
message,
timeout=ack_wait_sec,
headers=headers,
Expand Down Expand Up @@ -251,6 +260,23 @@ async def destroy(self):
if sub is not None:
await sub.unsubscribe()


self.connection.functions_clients_per_station[internal_station_name] -= 1
if self.connection.functions_clients_per_station[internal_station_name] == 0:
if internal_station_name in self.connection.functions_updates_data:
del self.connection.functions_updates_data[internal_station_name]
if internal_station_name in self.connection.functions_updates_subs:
sub = self.connection.functions_updates_subs.get(internal_station_name)
if sub is not None:
await sub.unsubscribe()
del self.connection.functions_updates_subs[internal_station_name]
if internal_station_name in self.connection.functions_tasks:
task = self.connection.functions_tasks.get(internal_station_name)
if task is not None:
task.cancel()
del self.connection.functions_tasks[internal_station_name]


map_key = internal_station_name + "_" + self.real_name
del self.connection.producers_map[map_key]

Expand Down
16 changes: 16 additions & 0 deletions memphis/station.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,22 @@ async def destroy(self):
if sub is not None:
await sub.unsubscribe()


if internal_station_name in self.connection.functions_clients_per_station:
del self.connection.functions_clients_per_station[internal_station_name]
if internal_station_name in self.connection.functions_updates_data:
del self.connection.functions_updates_data[internal_station_name]
if internal_station_name in self.connection.functions_updates_subs:
function_sub = self.connection.functions_updates_subs.get(internal_station_name)
if function_sub is not None:
await function_sub.unsubscribe()
del self.connection.functions_updates_subs[internal_station_name]
if internal_station_name in self.connection.functions_tasks:
task = self.connection.functions_tasks.get(internal_station_name)
if task is not None:
task.cancel()
del self.connection.functions_tasks[internal_station_name]

self.connection.producers_map = {
k: v
for k, v in self.connection.producers_map.items()
Expand Down