diff --git a/README.md b/README.md index f2e859a..b6bf82b 100644 --- a/README.md +++ b/README.md @@ -243,7 +243,8 @@ await memphis.produce(station_name='test_station_py', producer_name='prod_py', ack_wait_sec=15, # defaults to 15 headers=headers, # default to {} nonblocking=False, #defaults to false - msg_id="123" + msg_id="123", + producer_partition_key="key" # defaults to None ) ``` @@ -274,6 +275,15 @@ await producer.produce( headers={}, nonblocking=True) ``` +### Produce using a partition key +Use any string as a partition key to produce messages to a specific partition; messages produced with the same key are routed to the same partition. + +```python +await producer.produce( + message='bytearray/protobuf class/dict/string/graphql.language.ast.DocumentNode', # bytearray / protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema) + producer_partition_key="key") # defaults to None +``` + ### Non-blocking Produce with Task Limits For better performance, the client won't block requests while waiting for an acknowledgment. If you are producing a large number of messages and see timeout errors, then you may need to @@ -341,6 +351,15 @@ async def msg_handler(msgs, error, context): consumer.consume(msg_handler) ``` +### Consume using a partition key +The key is used to select the specific partition to consume from. + +```python +consumer.consume(msg_handler, + consumer_partition_key = "key" # consume from a specific partition + ) +``` + ### Fetch a single batch of messages ```python msgs = await memphis.fetch_messages( @@ -351,8 +370,9 @@ msgs = await memphis.fetch_messages( batch_max_time_to_wait_ms=5000, # defaults to 5000 max_ack_time_ms=30000, # defaults to 30000 max_msg_deliveries=10, # defaults to 10 - start_consume_from_sequence=1 # start consuming from a specific sequence. 
defaults to 1 - last_messages=-1 # consume the last N messages, defaults to -1 (all messages in the station)) + start_consume_from_sequence=1, # start consuming from a specific sequence. defaults to 1 + last_messages=-1, # consume the last N messages, defaults to -1 (all messages in the station) + consumer_partition_key="key" # used to consume from a specific partition, defaults to None ) ``` diff --git a/memphis/consumer.py b/memphis/consumer.py index 693512b..9952b01 100644 --- a/memphis/consumer.py +++ b/memphis/consumer.py @@ -2,6 +2,7 @@ import asyncio import json +import mmh3 from memphis.exceptions import MemphisError from memphis.utils import default_error_handler, get_internal_name @@ -26,6 +27,8 @@ def __init__( error_callback=None, start_consume_from_sequence: int = 1, last_messages: int = -1, + partition_generator: PartitionGenerator = None, + subscriptions: dict = None, ): self.connection = connection self.station_name = station_name.lower() @@ -49,16 +52,15 @@ def __init__( self.t_dls = asyncio.create_task(self.__consume_dls()) self.t_consume = None self.inner_station_name = get_internal_name(self.station_name) - self.subscriptions = {} - if self.inner_station_name in connection.partition_consumers_updates_data: - self.partition_generator = PartitionGenerator(connection.partition_consumers_updates_data[self.inner_station_name]["partitions_list"]) + self.subscriptions = subscriptions + self.partition_generator = partition_generator def set_context(self, context): """Set a context (dict) that will be passed to each message handler call.""" self.context = context - def consume(self, callback): + def consume(self, callback, consumer_partition_key: str = None): """ This method starts consuming events from the specified station and invokes the provided callback function for each batch of messages received. @@ -68,6 +70,7 @@ def consume(self, callback): - `messages`: A list of `Message` objects representing the batch of messages received. 
- `error`: An optional `MemphisError` object if there was an error while consuming the messages. - `context`: A dictionary representing the context that was set using the `set_context()` method. + consumer_partition_key (str): A string that will be used to determine the partition to consume from. If not provided, the consume will work in a Round Robin fashion. Example: import asyncio @@ -94,28 +97,20 @@ async def main(): asyncio.run(main()) """ self.dls_callback_func = callback - self.t_consume = asyncio.create_task(self.__consume(callback)) - - async def __consume(self, callback): - if self.inner_station_name not in self.connection.partition_consumers_updates_data: - subject = self.inner_station_name + ".final" - consumer_group = get_internal_name(self.consumer_group) - psub = await self.connection.broker_connection.pull_subscribe(subject, durable=consumer_group) - self.subscriptions[1] = psub - else: - for p in self.connection.partition_consumers_updates_data[self.inner_station_name]["partitions_list"]: - subject = f"{self.inner_station_name}${str(p)}.final" - consumer_group = get_internal_name(self.consumer_group) - psub = await self.connection.broker_connection.pull_subscribe(subject, durable=consumer_group) - self.subscriptions[p] = psub + self.t_consume = asyncio.create_task(self.__consume(callback, partition_key=consumer_partition_key)) + async def __consume(self, callback, partition_key: str = None): partition_number = 1 + if partition_key is not None: + partition_number = self.get_partition_from_key(partition_key) + while True: if self.connection.is_connection_active and self.pull_interval_ms: try: if len(self.subscriptions) > 1: - partition_number = next(self.partition_generator) + if partition_key is None: + partition_number = next(self.partition_generator) memphis_messages = [] msgs = await self.subscriptions[partition_number].fetch(self.batch_size) @@ -167,7 +162,7 @@ async def __consume_dls(self): await self.dls_callback_func([], MemphisError(str(e)), 
self.context) return - async def fetch(self, batch_size: int = 10): + async def fetch(self, batch_size: int = 10, consumer_partition_key: str = None): """ Fetch a batch of messages. @@ -225,22 +220,14 @@ async def main(host, username, password, station): self.dls_current_index -= len(messages) return messages - subject = get_internal_name(self.station_name) - if len(self.subscriptions) == 0: - if self.inner_station_name not in self.connection.partition_consumers_updates_data: - subject = self.inner_station_name + ".final" - consumer_group = get_internal_name(self.consumer_group) - psub = await self.connection.broker_connection.pull_subscribe(subject, durable=consumer_group) - self.subscriptions[1] = psub - else: - for p in self.connection.partition_consumers_updates_data[self.inner_station_name]["partitions_list"]: - subject = f"{self.inner_station_name}${str(p)}.final" - consumer_group = get_internal_name(self.consumer_group) - psub = await self.connection.broker_connection.pull_subscribe(subject, durable=consumer_group) - self.subscriptions[p] = psub partition_number = 1 + if len(self.subscriptions) > 1: - partition_number = next(self.partition_generator) + if consumer_partition_key is not None: + partition_number = self.get_partition_from_key(consumer_partition_key) + else: + partition_number = next(self.partition_generator) + msgs = await self.subscriptions[partition_number].fetch(batch_size) for msg in msgs: messages.append( @@ -305,3 +292,10 @@ async def destroy(self): del self.connection.consumers_map[map_key] except Exception as e: raise MemphisError(str(e)) from e + + def get_partition_from_key(self, key): + try: + index = mmh3.hash(key, self.connection.SEED, signed=False) % len(self.subscriptions) + return self.connection.partition_consumers_updates_data[self.inner_station_name]["partitions_list"][index] + except Exception as e: + raise e diff --git a/memphis/memphis.py b/memphis/memphis.py index a519ab1..2d5c578 100644 --- a/memphis/memphis.py +++ 
b/memphis/memphis.py @@ -34,12 +34,14 @@ from memphis.station import Station from memphis.types import Retention, Storage from memphis.utils import get_internal_name, random_bytes +from memphis.partition_generator import PartitionGenerator app_id = str(uuid.uuid4()) class Memphis: MAX_BATCH_SIZE = 5000 MEMPHIS_GLOBAL_ACCOUNT_NAME = "$memphis" + SEED = 1234 def __init__(self): self.is_connection_active = False @@ -646,6 +648,27 @@ async def consumer( except: raise MemphisError(creation_res) + inner_station_name = get_internal_name(station_name.lower()) + + partition_generator = None + + if inner_station_name in self.partition_consumers_updates_data: + partition_generator = PartitionGenerator(self.partition_consumers_updates_data[inner_station_name]["partitions_list"]) + + consumer_group = get_internal_name(cg.lower()) + subscriptions = {} + + if inner_station_name not in self.partition_consumers_updates_data: + subject = inner_station_name + ".final" + psub = await self.broker_connection.pull_subscribe(subject, durable=consumer_group) + subscriptions[1] = psub + else: + for p in self.partition_consumers_updates_data[inner_station_name]["partitions_list"]: + subject = f"{inner_station_name}${str(p)}.final" + psub = await self.broker_connection.pull_subscribe(subject, durable=consumer_group) + subscriptions[p] = psub + + internal_station_name = get_internal_name(station_name) map_key = internal_station_name + "_" + real_name consumer = Consumer( @@ -660,6 +683,8 @@ async def consumer( max_msg_deliveries, start_consume_from_sequence=start_consume_from_sequence, last_messages=last_messages, + partition_generator=partition_generator, + subscriptions=subscriptions, ) self.consumers_map[map_key] = consumer return consumer @@ -676,6 +701,7 @@ async def produce( headers: Union[Headers, None] = None, async_produce: bool = False, msg_id: Union[str, None] = None, + producer_partition_key: Union[str, None] = None ): """Produces a message into a station without the need to create 
a producer. Args: @@ -687,6 +713,7 @@ async def produce( headers (dict, optional): Message headers, defaults to {}. async_produce (boolean, optional): produce operation won't wait for broker acknowledgement msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotence + producer_partition_key (string, optional): produce to a specific partition using the partition key Raises: Exception: _description_ """ @@ -708,6 +735,7 @@ async def produce( headers=headers, async_produce=async_produce, msg_id=msg_id, + producer_partition_key=producer_partition_key ) except Exception as e: raise MemphisError(str(e)) from e @@ -724,6 +752,7 @@ async def fetch_messages( generate_random_suffix: bool = False, start_consume_from_sequence: int = 1, last_messages: int = -1, + consumer_partition_key: str = None, ): """Consume a batch of messages. Args:. @@ -737,6 +766,7 @@ async def fetch_messages( generate_random_suffix (bool): Deprecated: will be stopped to be supported after November 1'st, 2023. false by default, if true concatenate a random suffix to consumer's name start_consume_from_sequence(int, optional): start consuming from a specific sequence. defaults to 1. last_messages: consume the last N messages, defaults to -1 (all messages in the station). 
+ consumer_partition_key (str): consume from a specific partition using the partition key Returns: list: Message """ @@ -765,7 +795,7 @@ async def fetch_messages( start_consume_from_sequence=start_consume_from_sequence, last_messages=last_messages, ) - messages = await consumer.fetch(batch_size) + messages = await consumer.fetch(batch_size, consumer_partition_key=consumer_partition_key) if messages == None: messages = [] return messages diff --git a/memphis/producer.py b/memphis/producer.py index 6310793..33729f2 100644 --- a/memphis/producer.py +++ b/memphis/producer.py @@ -12,6 +12,7 @@ from jsonschema import validate import google.protobuf.json_format as protobuf_json_format import fastavro +import mmh3 from memphis.exceptions import MemphisError, MemphisSchemaError from memphis.headers import Headers from memphis.utils import get_internal_name @@ -191,7 +192,8 @@ async def produce( async_produce: Union[bool, None] = None, nonblocking: bool = False, msg_id: Union[str, None] = None, - concurrent_task_limit: Union[int, None] = None + concurrent_task_limit: Union[int, None] = None, + producer_partition_key: Union[str, None] = None ): """Produces a message into a station. Args: @@ -215,6 +217,7 @@ async def produce( tasks. Calls with nonblocking=True will block if the limit is hit and will wait until the buffer drains halfway down. + producer_partition_key (string, optional): Produce messages to a specific partition using the partition key. 
Raises: Exception: _description_ """ @@ -239,6 +242,9 @@ async def produce( partition_name = self.internal_station_name elif len(self.connection.partition_producers_updates_data[self.internal_station_name]['partitions_list']) == 1: partition_name = f"{self.internal_station_name}${self.connection.partition_producers_updates_data[self.internal_station_name]['partitions_list'][0]}" + elif producer_partition_key is not None: + partition_number = self.get_partition_from_key(producer_partition_key) + partition_name = f"{self.internal_station_name}${str(partition_number)}" else: partition_name = f"{self.internal_station_name}${str(next(self.partition_generator))}" @@ -395,3 +401,10 @@ async def destroy(self): except Exception as e: raise Exception(e) + + def get_partition_from_key(self, key): + try: + index = mmh3.hash(key, self.connection.SEED, signed=False) % len(self.connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"]) + return self.connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"][index] + except Exception as e: + raise e diff --git a/setup.py b/setup.py index c29f766..b485690 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ url="https://github.com/memphisdev/memphis.py", download_url="https://github.com/memphisdev/memphis.py/archive/refs/tags/1.1.2.tar.gz", keywords=["message broker", "devtool", "streaming", "data"], - install_requires=["asyncio", "nats-py", "protobuf", "jsonschema", "graphql-core", "fastavro"], + install_requires=["asyncio", "nats-py", "protobuf", "jsonschema", "graphql-core", "fastavro", "mmh3"], classifiers=[ "Development Status :: 4 - Beta", "Intended Audience :: Developers",