diff --git a/.DS_Store b/.DS_Store index df48a7b..6a953a7 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/.gitignore b/.gitignore index 9ab6fe8..867d557 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,5 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +#Mac +.DS_Store diff --git a/README.md b/README.md index 1dec449..3b0e92a 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,7 @@ station = memphis.station( send_poison_msg_to_dls=True, # defaults to true send_schema_failed_msg_to_dls=True, # defaults to true tiered_storage_enabled=False # defaults to false + partitions_number=1 # defaults to 1 ) ``` @@ -172,6 +173,11 @@ memphis.types.Storage.MEMORY Means that messages persist on the main memory +### Station partitions + +Memphis station is created with 1 partition by default +You can change the partitions number as you wish in order to scale your stations + ### Destroying a Station Destroying a station will remove all its resources (producers/consumers) @@ -219,6 +225,8 @@ Memphis messages are payload agnostic. Payloads are `bytearray`. In order to stop getting messages, you have to call `consumer.destroy()`. Destroy will terminate regardless of whether there are messages in flight for the client. 
+If a station is created with more than one partition, produce and consume will be performed in a Round Robin fashion + ### Creating a Producer ```python diff --git a/memphis/consumer.py b/memphis/consumer.py index c1731ab..d0a3875 100644 --- a/memphis/consumer.py +++ b/memphis/consumer.py @@ -6,6 +6,7 @@ from memphis.exceptions import MemphisError from memphis.utils import default_error_handler, get_internal_name from memphis.message import Message +from memphis.partition_generator import PartitionGenerator class Consumer: @@ -47,6 +48,11 @@ def __init__( self.dls_callback_func = None self.t_dls = asyncio.create_task(self.__consume_dls()) self.t_consume = None + self.inner_station_name = get_internal_name(self.station_name) + self.subscriptions = {} + if self.inner_station_name in connection.partition_consumers_updates_data: + self.partition_generator = PartitionGenerator(connection.partition_consumers_updates_data[self.inner_station_name]["partitions_list"]) + def set_context(self, context): """Set a context (dict) that will be passed to each message handler call.""" @@ -91,16 +97,29 @@ async def main(): self.t_consume = asyncio.create_task(self.__consume(callback)) async def __consume(self, callback): - subject = get_internal_name(self.station_name) - consumer_group = get_internal_name(self.consumer_group) - self.psub = await self.connection.broker_connection.pull_subscribe( - subject + ".final", durable=consumer_group - ) + if self.inner_station_name not in self.connection.partition_consumers_updates_data: + subject = self.inner_station_name + ".final" + consumer_group = get_internal_name(self.consumer_group) + psub = await self.connection.broker_connection.pull_subscribe(subject, durable=consumer_group) + self.subscriptions[1] = psub + else: + for p in self.connection.partition_consumers_updates_data[self.inner_station_name]["partitions_list"]: + subject = f"{self.inner_station_name}${str(p)}.final" + consumer_group = get_internal_name(self.consumer_group) + psub 
= await self.connection.broker_connection.pull_subscribe(subject, durable=consumer_group) + self.subscriptions[p] = psub + + partition_number = 1 + while True: if self.connection.is_connection_active and self.pull_interval_ms: try: + if len(self.subscriptions) > 1: + partition_number = next(self.partition_generator) + memphis_messages = [] - msgs = await self.psub.fetch(self.batch_size) + msgs = await self.subscriptions[partition_number].fetch(self.batch_size) + for msg in msgs: memphis_messages.append( Message(msg, self.connection, self.consumer_group) @@ -230,10 +249,19 @@ async def __ping_consumer(self, callback): while True: try: await asyncio.sleep(self.ping_consumer_interval_ms / 1000) + station_inner = get_internal_name(self.station_name) consumer_group = get_internal_name(self.consumer_group) - await self.connection.broker_connection.consumer_info( - self.station_name, consumer_group, timeout=30 - ) + if self.inner_station_name not in self.connection.partition_consumers_updates_data: + for p in self.connection.partition_consumers_updates_data[station_inner]["partitions_list"]: + stream_name = f"{station_inner}${str(p)}.final" + await self.connection.broker_connection.consumer_info( + stream_name, consumer_group, timeout=30 + ) + else: + stream_name = f"{station_inner}.final" + await self.connection.broker_connection.consumer_info( + stream_name, consumer_group, timeout=30 + ) except Exception as e: callback(MemphisError(str(e))) diff --git a/memphis/memphis.py b/memphis/memphis.py index 0807f22..82c02eb 100644 --- a/memphis/memphis.py +++ b/memphis/memphis.py @@ -42,6 +42,8 @@ class Memphis: def __init__(self): self.is_connection_active = False self.schema_updates_data = {} + self.partition_producers_updates_data = {} + self.partition_consumers_updates_data = {} self.schema_updates_subs = {} self.producers_per_station = {} self.schema_tasks = {} @@ -230,6 +232,7 @@ async def station( send_poison_msg_to_dls: bool = True, send_schema_failed_msg_to_dls: bool = 
True, tiered_storage_enabled: bool = False, + partitions_number = 1, ): """Creates a station. Args: @@ -246,6 +249,8 @@ async def station( try: if not self.is_connection_active: raise MemphisError("Connection is dead") + if partitions_number == 0: + partitions_number = 1 create_station_req = { "name": name, @@ -260,7 +265,8 @@ async def station( "Schemaverse": send_schema_failed_msg_to_dls, }, "username": self.username, - "tiered_storage_enabled": tiered_storage_enabled + "tiered_storage_enabled": tiered_storage_enabled, + "partitions_number" : partitions_number } create_station_req_bytes = json.dumps(create_station_req, indent=2).encode( "utf-8" @@ -420,6 +426,12 @@ async def producer( raise MemphisError(create_res["error"]) internal_station_name = get_internal_name(station_name) + + if create_res["partitions_update"]["partitions_list"] is not None: + self.partition_producers_updates_data[internal_station_name] = create_res[ + "partitions_update" + ] + self.station_schemaverse_to_dls[internal_station_name] = create_res[ "schemaverse_to_dls" ] @@ -599,15 +611,21 @@ async def consumer( create_consumer_req_bytes = json.dumps(create_consumer_req, indent=2).encode( "utf-8" ) - err_msg = await self.broker_manager.request( + creation_res = await self.broker_manager.request( "$memphis_consumer_creations", create_consumer_req_bytes, timeout=5 ) - err_msg = err_msg.data.decode("utf-8") + creation_res = creation_res.data.decode("utf-8") + creation_res = json.loads(creation_res) - if err_msg != "": - raise MemphisError(err_msg) + if creation_res["error"] != "": + raise MemphisError(creation_res["error"]) internal_station_name = get_internal_name(station_name) + + if creation_res["partitions_update"]["partitions_list"] is not None: + self.partition_consumers_updates_data[internal_station_name] = creation_res["partitions_update"] + + map_key = internal_station_name + "_" + real_name consumer = Consumer( self, diff --git a/memphis/partition_generator.py 
b/memphis/partition_generator.py new file mode 100644 index 0000000..57b9ad3 --- /dev/null +++ b/memphis/partition_generator.py @@ -0,0 +1,14 @@ +# The PartitionGenerator class provides a round-robin generator over a station's partitions. +# It receives a list of partitions; each call to next() returns the next partition in rotation. + +class PartitionGenerator: + def __init__(self, partitions_list): + self.partitions_list = partitions_list + self.current = 0 + self.length = len(partitions_list) + + def __next__(self): + partition_to_return = self.partitions_list[self.current] + self.current = (self.current + 1) % self.length + return partition_to_return + \ No newline at end of file diff --git a/memphis/producer.py b/memphis/producer.py index ab4bd34..282684e 100644 --- a/memphis/producer.py +++ b/memphis/producer.py @@ -15,6 +15,7 @@ from memphis.exceptions import MemphisError, MemphisSchemaError from memphis.headers import Headers from memphis.utils import get_internal_name +from memphis.partition_generator import PartitionGenerator schemaverse_fail_alert_type = "schema_validation_fail_alert" @@ -30,6 +31,8 @@ def __init__( self.loop = asyncio.get_running_loop() self.real_name = real_name self.background_tasks = set() + if self.internal_station_name in connection.partition_producers_updates_data: + self.partition_generator = PartitionGenerator(connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"]) async def validate_msg(self, message): if self.connection.schema_updates_data[self.internal_station_name] != {}: @@ -232,6 +235,11 @@ async def produce( else: headers = memphis_headers + if self.internal_station_name not in self.connection.partition_producers_updates_data: + partition_name = self.internal_station_name + else: + partition_name = f"{self.internal_station_name}${str(next(self.partition_generator))}" + if async_produce: nonblocking = True warnings.warn("The argument async_produce is deprecated. 
" + \ @@ -241,7 +249,7 @@ async def produce( try: task = self.loop.create_task( self.connection.broker_connection.publish( - self.internal_station_name + ".final", + partition_name + ".final", message, timeout=ack_wait_sec, headers=headers, @@ -262,7 +270,7 @@ async def produce( raise MemphisError(e) else: await self.connection.broker_connection.publish( - self.internal_station_name + ".final", + partition_name + ".final", message, timeout=ack_wait_sec, headers=headers,