diff --git a/README.md b/README.md
index b4f66a5..798f6bc 100644
--- a/README.md
+++ b/README.md
@@ -250,7 +250,8 @@ await memphis.produce(station_name='test_station_py', producer_name='prod_py',
   headers=headers, # default to {}
   nonblocking=False, #defaults to false
   msg_id="123",
-  producer_partition_key="key" #default to None
+  producer_partition_key="key", #default to None
+  producer_partition_number=-1, #default to -1
 )
 ```

@@ -287,7 +288,18 @@ Use any string to produce messages to a specific partition
 ```python
 await producer.produce(
   message='bytearray/protobuf class/dict/string/graphql.language.ast.DocumentNode', # bytearray / protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
-  producer_partition_key="key") #default to None
+  producer_partition_key="key", #default to None
+)
+```
+
+### Produce using partition number
+Use a partition number to produce messages to a specific partition
+
+```python
+await producer.produce(
+  message='bytearray/protobuf class/dict/string/graphql.language.ast.DocumentNode', # bytearray / protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
+  producer_partition_number=-1 #default to -1
+)
 ```
 
 ### Non-blocking Produce with Task Limits
@@ -366,6 +378,15 @@ consumer.consume(msg_handler,
                 )
 ```
 
+### Consume using a partition number
+Use a partition number to consume from a specific partition
+
+```python
+consumer.consume(msg_handler,
+                 consumer_partition_number = -1 #consume from a specific partition
+                )
+```
+
 ### Fetch a single batch of messages
 ```python
 msgs = await memphis.fetch_messages(
@@ -378,7 +399,8 @@ msgs = await memphis.fetch_messages(
   max_msg_deliveries=10, # defaults to 10
   start_consume_from_sequence=1, # start consuming from a specific sequence. defaults to 1
   last_messages=-1, # consume the last N messages, defaults to -1 (all messages in the station))
-  consumer_partition_key="key" # used to consume from a specific partition, default to None
+  consumer_partition_key="key", # used to consume from a specific partition, default to None
+  consumer_partition_number=-1 # used to consume from a specific partition, default to -1
 )
 ```

diff --git a/memphis/consumer.py b/memphis/consumer.py
index 1423700..1a70697 100644
--- a/memphis/consumer.py
+++ b/memphis/consumer.py
@@ -63,7 +63,7 @@ def set_context(self, context):
         """Set a context (dict) that will be passed to each message handler call."""
         self.context = context
 
-    def consume(self, callback, consumer_partition_key: str = None):
+    def consume(self, callback, consumer_partition_key: str = None, consumer_partition_number: int = -1):
         """
         This method starts consuming events from the specified station and invokes the provided callback function for each batch of messages received.
 
@@ -74,6 +74,7 @@ def consume(self, callback, consumer_partition_key: str = None):
                 - `error`: An optional `MemphisError` object if there was an error while consuming the messages.
                 - `context`: A dictionary representing the context that was set using the `set_context()` method.
             consumer_partition_key (str): A string that will be used to determine the partition to consume from. If not provided, the consume will work in a Round Robin fashion.
+            consumer_partition_number (int): An integer that will be used to determine the partition to consume from. If not provided, the consume will work in a Round Robin fashion.
         Example:
             import asyncio
 
@@ -100,13 +101,17 @@ async def main():
         asyncio.run(main())
         """
         self.dls_callback_func = callback
-        self.t_consume = asyncio.create_task(self.__consume(callback, partition_key=consumer_partition_key))
+        self.t_consume = asyncio.create_task(self.__consume(callback, partition_key=consumer_partition_key, consumer_partition_number=consumer_partition_number))
 
-    async def __consume(self, callback, partition_key: str = None):
+    async def __consume(self, callback, partition_key: str = None, consumer_partition_number: int = -1):
         partition_number = 1
-
-        if partition_key is not None:
+        if consumer_partition_number > 0 and partition_key is not None:
+            raise MemphisError('Cannot use both partition number and partition key')
+        elif partition_key is not None:
             partition_number = self.get_partition_from_key(partition_key)
+        elif consumer_partition_number > 0:
+            self.validate_partition_number(consumer_partition_number, self.inner_station_name)
+            partition_number = consumer_partition_number
 
         while True:
             if self.connection.is_connection_active and self.pull_interval_ms:
@@ -165,7 +170,7 @@ async def __consume_dls(self):
                 await self.dls_callback_func([], MemphisError(str(e)), self.context)
                 return
 
-    async def fetch(self, batch_size: int = 10, consumer_partition_key: str = None, prefetch: bool = False):
+    async def fetch(self, batch_size: int = 10, consumer_partition_key: str = None, consumer_partition_number: int = -1, prefetch: bool = False):
         """
         Fetch a batch of messages.
 
@@ -206,13 +211,25 @@ async def main(host, username, password, station):
         """
         messages = []
+        partition_number = 1
+        if len(self.subscriptions) > 1:
+            if consumer_partition_number > 0 and consumer_partition_key is not None:
+                raise MemphisError('Cannot use both partition number and partition key')
+            elif consumer_partition_key is not None:
+                partition_number = self.get_partition_from_key(consumer_partition_key)
+            elif consumer_partition_number > 0:
+                self.validate_partition_number(consumer_partition_number, self.inner_station_name)
+                partition_number = consumer_partition_number
+            else:
+                partition_number = next(self.partition_generator)
+
         if prefetch and len(self.cached_messages) > 0:
             if len(self.cached_messages) >= batch_size:
                 messages = self.cached_messages[:batch_size]
                 self.cached_messages = self.cached_messages[batch_size:]
                 number_of_messages_to_prefetch = batch_size * 2 - batch_size # calculated for clarity
-                self.load_messages_to_cache(number_of_messages_to_prefetch)
+                self.load_messages_to_cache(number_of_messages_to_prefetch, partition_number)
                 return messages
             else:
                 messages = self.cached_messages
@@ -235,22 +252,13 @@ async def main(host, username, password, station):
                     del self.dls_messages[0:batch_size]
                     self.dls_current_index -= len(messages)
                     return messages
-
-                partition_number = 1
-
-                if len(self.subscriptions) > 1:
-                    if consumer_partition_key is not None:
-                        partition_number = self.get_partition_from_key(consumer_partition_key)
-                    else:
-                        partition_number = next(self.partition_generator)
-
                 msgs = await self.subscriptions[partition_number].fetch(batch_size)
                 for msg in msgs:
                     messages.append(
                         Message(msg, self.connection, self.consumer_group, self.internal_station_name))
                 if prefetch:
                     number_of_messages_to_prefetch = batch_size * 2
-                    self.load_messages_to_cache(number_of_messages_to_prefetch)
+                    self.load_messages_to_cache(number_of_messages_to_prefetch, partition_number)
                 return messages
             except Exception as e:
                 if "timeout" not in str(e).lower():
@@ -344,12 +352,21 @@ def get_partition_from_key(self, key):
         except Exception as e:
             raise e
 
-    def load_messages_to_cache(self, batch_size):
+    def validate_partition_number(self, partition_number, station_name):
+        partitions_list = self.connection.partition_consumers_updates_data[station_name]["partitions_list"]
+        if partitions_list is not None:
+            if partition_number < 0 or partition_number >= len(partitions_list):
+                raise MemphisError("Partition number is out of range")
+            elif partition_number not in partitions_list:
+                raise MemphisError(f"Partition {str(partition_number)} does not exist in station {station_name}")
+        else:
+            raise MemphisError(f"Partition {str(partition_number)} does not exist in station {station_name}")
+
+    def load_messages_to_cache(self, batch_size, partition_number):
         if not self.loading_thread or not self.loading_thread.is_alive():
-            asyncio.ensure_future(self.__load_messages(batch_size))
-
+            asyncio.ensure_future(self.__load_messages(batch_size, partition_number))
 
-    async def __load_messages(self, batch_size):
-        new_messages = await self.fetch(batch_size)
+    async def __load_messages(self, batch_size, partition_number):
+        new_messages = await self.fetch(batch_size, consumer_partition_number=partition_number)
         if new_messages is not None:
             self.cached_messages.extend(new_messages)
diff --git a/memphis/memphis.py b/memphis/memphis.py
index bc7a1b4..a4c27b3 100644
--- a/memphis/memphis.py
+++ b/memphis/memphis.py
@@ -718,7 +718,8 @@ async def produce(
         headers: Union[Headers, None] = None,
         async_produce: bool = False,
         msg_id: Union[str, None] = None,
-        producer_partition_key: Union[str, None] = None
+        producer_partition_key: Union[str, None] = None,
+        producer_partition_number: int = -1
     ):
         """Produces a message into a station without the need to create a producer.
         Args:
@@ -731,6 +732,7 @@ async def produce(
            async_produce (boolean, optional): produce operation won't wait for broker acknowledgement
            msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotence
            producer_partition_key (string, optional): produce to a specific partition using the partition key
+           producer_partition_number (int, optional): produce to a specific partition using the partition number
         Raises:
             Exception: _description_
         """
@@ -752,7 +754,8 @@ async def produce(
                 headers=headers,
                 async_produce=async_produce,
                 msg_id=msg_id,
-                producer_partition_key=producer_partition_key
+                producer_partition_key=producer_partition_key,
+                producer_partition_number=producer_partition_number
             )
         except Exception as e:
             raise MemphisError(str(e)) from e
@@ -770,6 +773,7 @@ async def fetch_messages(
         start_consume_from_sequence: int = 1,
         last_messages: int = -1,
         consumer_partition_key: str = None,
+        consumer_partition_number: int = -1,
         prefetch: bool = False,
     ):
         """Consume a batch of messages.
@@ -784,7 +788,8 @@ async def fetch_messages(
             generate_random_suffix (bool): Deprecated: will be stopped to be supported after November 1'st, 2023. false by default, if true concatenate a random suffix to consumer's name
             start_consume_from_sequence(int, optional): start consuming from a specific sequence. defaults to 1.
             last_messages: consume the last N messages, defaults to -1 (all messages in the station).
-            consumer_partition_key (str): consume from a specific partition using the partition key
+            consumer_partition_key (str): consume from a specific partition using the partition key.
+            consumer_partition_number (int): consume from a specific partition using the partition number.
             prefetch: false by default, if true then fetch messages from local cache (if exists) and load more messages into the cache.
         Returns:
             list: Message
@@ -814,7 +819,7 @@ async def fetch_messages(
             start_consume_from_sequence=start_consume_from_sequence,
             last_messages=last_messages,
         )
-        messages = await consumer.fetch(batch_size, consumer_partition_key=consumer_partition_key, prefetch=prefetch)
+        messages = await consumer.fetch(batch_size, consumer_partition_key=consumer_partition_key, consumer_partition_number=consumer_partition_number, prefetch=prefetch)
         if messages == None:
             messages = []
         return messages
@@ -899,3 +904,4 @@ def unset_cached_consumer_station(self, station_name):
                     del self.consumers_map[key]
         except Exception as e:
             raise e
+
\ No newline at end of file
diff --git a/memphis/producer.py b/memphis/producer.py
index eba9c52..3e5fac9 100644
--- a/memphis/producer.py
+++ b/memphis/producer.py
@@ -193,7 +193,8 @@ async def produce(
         nonblocking: bool = False,
         msg_id: Union[str, None] = None,
         concurrent_task_limit: Union[int, None] = None,
-        producer_partition_key: Union[str, None] = None
+        producer_partition_key: Union[str, None] = None,
+        producer_partition_number: int = -1
     ):
         """Produces a message into a station.
         Args:
@@ -218,6 +219,7 @@ async def produce(
                 if the limit is hit and will wait until the buffer drains halfway down.
             producer_partition_key (string, optional): Produce messages to a specific partition using the partition key.
+            producer_partition_number (int, optional): Produce messages to a specific partition using the partition number.
         Raises:
             Exception: _description_
         """
@@ -242,9 +244,14 @@ async def produce(
                 partition_name = self.internal_station_name
             elif len(self.connection.partition_producers_updates_data[self.internal_station_name]['partitions_list']) == 1:
                 partition_name = f"{self.internal_station_name}${self.connection.partition_producers_updates_data[self.internal_station_name]['partitions_list'][0]}"
+            elif producer_partition_number > 0 and producer_partition_key is not None:
+                raise MemphisError('Cannot use both partition number and partition key')
             elif producer_partition_key is not None:
                 partition_number = self.get_partition_from_key(producer_partition_key)
                 partition_name = f"{self.internal_station_name}${str(partition_number)}"
+            elif producer_partition_number > 0:
+                self.validate_partition_number(producer_partition_number, self.internal_station_name)
+                partition_name = f"{self.internal_station_name}${str(producer_partition_number)}"
             else:
                 partition_name = f"{self.internal_station_name}${str(next(self.partition_generator))}"
 
@@ -408,3 +415,13 @@ def get_partition_from_key(self, key):
             return self.connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"][index]
         except Exception as e:
             raise e
+
+    def validate_partition_number(self, partition_number, station_name):
+        partitions_list = self.connection.partition_producers_updates_data[station_name]["partitions_list"]
+        if partitions_list is not None:
+            if partition_number < 0 or partition_number >= len(partitions_list):
+                raise MemphisError("Partition number is out of range")
+            elif partition_number not in partitions_list:
+                raise MemphisError(f"Partition {str(partition_number)} does not exist in station {station_name}")
+        else:
+            raise MemphisError(f"Partition {str(partition_number)} does not exist in station {station_name}")
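
Below is a minimal end-to-end usage sketch of the parameters introduced in this diff (`producer_partition_number` and `consumer_partition_number`). It is not part of the patch: the host, credentials, station/producer/consumer names, and the partition number `1` are placeholders, and it assumes a station that already has at least one partition. Passing `-1` (the default) keeps the round-robin behavior, and combining a partition key with a partition number raises a `MemphisError` per the checks added above.

```python
import asyncio

from memphis import Memphis, MemphisError


async def main():
    memphis = Memphis()
    try:
        # Placeholder connection details -- replace with your own deployment.
        await memphis.connect(host="localhost", username="root", password="memphis")

        producer = await memphis.producer(station_name="test_station_py", producer_name="prod_py")
        # Produce to partition 1 explicitly; partition numbers must be > 0 to take effect,
        # and mixing producer_partition_key with producer_partition_number raises MemphisError.
        await producer.produce(
            message=bytearray("hello partition 1", "utf-8"),
            producer_partition_number=1,
        )

        # Fetch a batch from the same partition; -1 (the default) falls back to round robin.
        msgs = await memphis.fetch_messages(
            station_name="test_station_py",
            consumer_name="cons_py",
            consumer_partition_number=1,
        )
        for msg in msgs:
            print(msg.get_data())
            await msg.ack()
    except MemphisError as err:
        print(err)
    finally:
        await memphis.close()


if __name__ == "__main__":
    asyncio.run(main())
```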