Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ await memphis.produce(station_name='test_station_py', producer_name='prod_py',
headers=headers, # default to {}
nonblocking=False, #defaults to false
msg_id="123",
producer_partition_key="key" #default to None
producer_partition_key="key", #default to None
producer_partition_number=-1, #default to -1
)
```

Expand Down Expand Up @@ -287,7 +288,18 @@ Use any string to produce messages to a specific partition
```python
await producer.produce(
message='bytearray/protobuf class/dict/string/graphql.language.ast.DocumentNode', # bytearray / protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
producer_partition_key="key") #default to None
producer_partition_key="key", #default to None
)
```

### Produce using partition number
Use a partition number to produce messages to a specific partition

```python
await producer.produce(
message='bytearray/protobuf class/dict/string/graphql.language.ast.DocumentNode', # bytearray / protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
producer_partition_number=-1 #default to -1
)
```

### Non-blocking Produce with Task Limits
Expand Down Expand Up @@ -366,6 +378,15 @@ consumer.consume(msg_handler,
)
```

### Consume using a partition number
The partition number will be used to consume from a specific partition

```python
consumer.consume(msg_handler,
consumer_partition_number = -1 #consume from a specific partition
)
```

### Fetch a single batch of messages
```python
msgs = await memphis.fetch_messages(
Expand All @@ -378,7 +399,8 @@ msgs = await memphis.fetch_messages(
max_msg_deliveries=10, # defaults to 10
start_consume_from_sequence=1, # start consuming from a specific sequence. defaults to 1
last_messages=-1, # consume the last N messages, defaults to -1 (all messages in the station))
consumer_partition_key="key" # used to consume from a specific partition, default to None
consumer_partition_key="key", # used to consume from a specific partition, default to None
consumer_partition_number=-1 # used to consume from a specific partition, default to -1
)
```

Expand Down
61 changes: 39 additions & 22 deletions memphis/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def set_context(self, context):
"""Set a context (dict) that will be passed to each message handler call."""
self.context = context

def consume(self, callback, consumer_partition_key: str = None, consumer_partition_number: int = -1):
    """
    Start consuming messages from the station and invoke `callback` for each batch received.

    Args:
        callback (function): an async function invoked as ``callback(messages, error, context)``:
            - `messages`: list of received `Message` objects (empty on error)
            - `error`: a `MemphisError` if consuming failed, otherwise None
            - `context`: the dict previously set via `set_context()`
        consumer_partition_key (str, optional): key used to select the partition to consume
            from. Defaults to None (round-robin across partitions).
        consumer_partition_number (int, optional): explicit partition number to consume from.
            Defaults to -1 (round-robin). Mutually exclusive with `consumer_partition_key`;
            supplying both makes the consume loop raise a MemphisError.
    """
    self.dls_callback_func = callback
    # Run the polling loop as a background task; it keeps fetching batches until cancelled.
    self.t_consume = asyncio.create_task(
        self.__consume(
            callback,
            partition_key=consumer_partition_key,
            consumer_partition_number=consumer_partition_number,
        )
    )

async def __consume(self, callback, partition_key: str = None):
async def __consume(self, callback, partition_key: str = None, consumer_partition_number: int = -1):
partition_number = 1

if partition_key is not None:
if consumer_partition_number > 0 and partition_key is not None:
raise MemphisError('Can not use both partition number and partition key')
elif partition_key is not None:
partition_number = self.get_partition_from_key(partition_key)
elif consumer_partition_number > 0:
self.validate_partition_number(consumer_partition_number, self.inner_station_name)
partition_number = consumer_partition_number

while True:
if self.connection.is_connection_active and self.pull_interval_ms:
Expand Down Expand Up @@ -165,7 +170,7 @@ async def __consume_dls(self):
await self.dls_callback_func([], MemphisError(str(e)), self.context)
return

async def fetch(self, batch_size: int = 10, consumer_partition_key: str = None, prefetch: bool = False):
async def fetch(self, batch_size: int = 10, consumer_partition_key: str = None, consumer_partition_number: int = -1, prefetch: bool = False):
"""
Fetch a batch of messages.

Expand Down Expand Up @@ -206,13 +211,25 @@ async def main(host, username, password, station):

"""
messages = []
partition_number = 1
if len(self.subscriptions) > 1:
if consumer_partition_number > 0 and consumer_partition_key is not None:
raise MemphisError('Can not use both partition number and partition key')
elif consumer_partition_key is not None:
partition_number = self.get_partition_from_key(consumer_partition_key)
elif consumer_partition_number > 0:
self.validate_partition_number(consumer_partition_number, self.inner_station_name)
partition_number = consumer_partition_number
else:
partition_number = next(self.partition_generator)


if prefetch and len(self.cached_messages) > 0:
if len(self.cached_messages) >= batch_size:
messages = self.cached_messages[:batch_size]
self.cached_messages = self.cached_messages[batch_size:]
number_of_messages_to_prefetch = batch_size * 2 - batch_size # calculated for clarity
self.load_messages_to_cache(number_of_messages_to_prefetch)
self.load_messages_to_cache(number_of_messages_to_prefetch, partition_number)
return messages
else:
messages = self.cached_messages
Expand All @@ -235,22 +252,13 @@ async def main(host, username, password, station):
del self.dls_messages[0:batch_size]
self.dls_current_index -= len(messages)
return messages

partition_number = 1

if len(self.subscriptions) > 1:
if consumer_partition_key is not None:
partition_number = self.get_partition_from_key(consumer_partition_key)
else:
partition_number = next(self.partition_generator)

msgs = await self.subscriptions[partition_number].fetch(batch_size)
for msg in msgs:
messages.append(
Message(msg, self.connection, self.consumer_group, self.internal_station_name))
if prefetch:
number_of_messages_to_prefetch = batch_size * 2
self.load_messages_to_cache(number_of_messages_to_prefetch)
self.load_messages_to_cache(number_of_messages_to_prefetch, partition_number)
return messages
except Exception as e:
if "timeout" not in str(e).lower():
Expand Down Expand Up @@ -344,12 +352,21 @@ def get_partition_from_key(self, key):
except Exception as e:
raise e

def load_messages_to_cache(self, batch_size):
def validate_partition_number(self, partition_number, station_name):
    """
    Validate that `partition_number` is an existing partition of `station_name`.

    Args:
        partition_number (int): the partition number supplied by the caller.
        station_name (str): internal station name used to look up partition metadata.

    Raises:
        MemphisError: if the station has no partition metadata, or the number is not
            one of the station's partitions.
    """
    partitions_list = self.connection.partition_consumers_updates_data[station_name]["partitions_list"]
    # `partitions_list` holds the actual (1-based) partition numbers, so membership is the
    # only correct check: the previous `partition_number >= len(partitions_list)` range test
    # wrongly rejected the highest-numbered partition (e.g. partition 3 of [1, 2, 3]).
    if partitions_list is None or partition_number not in partitions_list:
        raise MemphisError(f"Partition {str(partition_number)} does not exist in station {station_name}")

def load_messages_to_cache(self, batch_size, partition_number):
    """
    Schedule a background task that prefetches `batch_size` messages from
    `partition_number` into `self.cached_messages`.

    Args:
        batch_size (int): number of messages to prefetch.
        partition_number (int): partition to prefetch from.

    The task is only scheduled when no previous loading thread is still alive,
    so at most one prefetch runs at a time.
    """
    if not self.loading_thread or not self.loading_thread.is_alive():
        asyncio.ensure_future(self.__load_messages(batch_size, partition_number))

async def __load_messages(self, batch_size, partition_number):
    """
    Fetch `batch_size` messages from `partition_number` and append them to the
    local prefetch cache (`self.cached_messages`).
    """
    new_messages = await self.fetch(batch_size, consumer_partition_number=partition_number)
    # fetch() may return None on failure/timeout; only extend the cache on success.
    if new_messages is not None:
        self.cached_messages.extend(new_messages)
14 changes: 10 additions & 4 deletions memphis/memphis.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,8 @@ async def produce(
headers: Union[Headers, None] = None,
async_produce: bool = False,
msg_id: Union[str, None] = None,
producer_partition_key: Union[str, None] = None
producer_partition_key: Union[str, None] = None,
producer_partition_number: Union[int, -1] = -1
):
"""Produces a message into a station without the need to create a producer.
Args:
Expand All @@ -731,6 +732,7 @@ async def produce(
async_produce (boolean, optional): produce operation won't wait for broker acknowledgement
msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotence
producer_partition_key (string, optional): produce to a specific partition using the partition key
producer_partition_number (int, optional): produce to a specific partition using the partition number
Raises:
Exception: _description_
"""
Expand All @@ -752,7 +754,8 @@ async def produce(
headers=headers,
async_produce=async_produce,
msg_id=msg_id,
producer_partition_key=producer_partition_key
producer_partition_key=producer_partition_key,
producer_partition_number=producer_partition_number
)
except Exception as e:
raise MemphisError(str(e)) from e
Expand All @@ -770,6 +773,7 @@ async def fetch_messages(
start_consume_from_sequence: int = 1,
last_messages: int = -1,
consumer_partition_key: str = None,
consumer_partition_number: int = -1,
prefetch: bool = False,
):
"""Consume a batch of messages.
Expand All @@ -784,7 +788,8 @@ async def fetch_messages(
generate_random_suffix (bool): Deprecated: will be stopped to be supported after November 1'st, 2023. false by default, if true concatenate a random suffix to consumer's name
start_consume_from_sequence(int, optional): start consuming from a specific sequence. defaults to 1.
last_messages: consume the last N messages, defaults to -1 (all messages in the station).
consumer_partition_key (str): consume from a specific partition using the partition key
consumer_partition_key (str): consume from a specific partition using the partition key.
consumer_partition_number (int): consume from a specific partition using the partition number.
prefetch: false by default, if true then fetch messages from local cache (if exists) and load more messages into the cache.
Returns:
list: Message
Expand Down Expand Up @@ -814,7 +819,7 @@ async def fetch_messages(
start_consume_from_sequence=start_consume_from_sequence,
last_messages=last_messages,
)
messages = await consumer.fetch(batch_size, consumer_partition_key=consumer_partition_key, prefetch=prefetch)
messages = await consumer.fetch(batch_size, consumer_partition_key=consumer_partition_key, consumer_partition_number=consumer_partition_number, prefetch=prefetch)
if messages == None:
messages = []
return messages
Expand Down Expand Up @@ -899,3 +904,4 @@ def unset_cached_consumer_station(self, station_name):
del self.consumers_map[key]
except Exception as e:
raise e

19 changes: 18 additions & 1 deletion memphis/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ async def produce(
nonblocking: bool = False,
msg_id: Union[str, None] = None,
concurrent_task_limit: Union[int, None] = None,
producer_partition_key: Union[str, None] = None
producer_partition_key: Union[str, None] = None,
producer_partition_number: Union[int, -1] = -1
):
"""Produces a message into a station.
Args:
Expand All @@ -218,6 +219,7 @@ async def produce(
if the limit is hit and will wait until the
buffer drains halfway down.
producer_partition_key (string, optional): Produce messages to a specific partition using the partition key.
producer_partition_number (int, optional): Produce messages to a specific partition using the partition number.
Raises:
Exception: _description_
"""
Expand All @@ -242,9 +244,14 @@ async def produce(
partition_name = self.internal_station_name
elif len(self.connection.partition_producers_updates_data[self.internal_station_name]['partitions_list']) == 1:
partition_name = f"{self.internal_station_name}${self.connection.partition_producers_updates_data[self.internal_station_name]['partitions_list'][0]}"
elif producer_partition_number > 0 and producer_partition_key is not None:
raise MemphisError('Can not use both partition number and partition key')
elif producer_partition_key is not None:
partition_number = self.get_partition_from_key(producer_partition_key)
partition_name = f"{self.internal_station_name}${str(partition_number)}"
elif producer_partition_number > 0:
self.validate_partition_number(producer_partition_number, self.internal_station_name)
partition_name = f"{self.internal_station_name}${str(producer_partition_number)}"
else:
partition_name = f"{self.internal_station_name}${str(next(self.partition_generator))}"

Expand Down Expand Up @@ -408,3 +415,13 @@ def get_partition_from_key(self, key):
return self.connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"][index]
except Exception as e:
raise e

def validate_partition_number(self, partition_number, station_name):
    """
    Validate that `partition_number` is an existing partition of `station_name`.

    Args:
        partition_number (int): the partition number supplied by the caller.
        station_name (str): internal station name used to look up partition metadata.

    Raises:
        MemphisError: if the station has no partition metadata, or the number is not
            one of the station's partitions.
    """
    # Producers must read the producer-side metadata map (as get_partition_from_key and
    # produce do); the original read partition_consumers_updates_data — a copy-paste slip
    # from the consumer implementation.
    partitions_list = self.connection.partition_producers_updates_data[station_name]["partitions_list"]
    # `partitions_list` holds the actual (1-based) partition numbers, so membership is the
    # correct check; an index-style `>= len(partitions_list)` test would wrongly reject the
    # highest-numbered partition.
    if partitions_list is None or partition_number not in partitions_list:
        raise MemphisError(f"Partition {str(partition_number)} does not exist in station {station_name}")