26 changes: 23 additions & 3 deletions README.md
@@ -243,7 +243,8 @@ await memphis.produce(station_name='test_station_py', producer_name='prod_py',
ack_wait_sec=15, # defaults to 15
headers=headers, # default to {}
nonblocking=False, # defaults to False
msg_id="123"
msg_id="123",
producer_partition_key="key" # defaults to None
)
```

@@ -274,6 +275,15 @@ await producer.produce(
headers={}, nonblocking=True)
```

### Produce using a partition key
Use any string as a partition key to route messages to a specific partition; the same key always maps to the same partition.

```python
await producer.produce(
message='bytearray/protobuf class/dict/string/graphql.language.ast.DocumentNode', # bytearray / protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
    producer_partition_key="key") # defaults to None
```
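
Under the hood (see `memphis/producer.py` in this PR), the key is murmur3-hashed with the client's fixed seed and mapped onto the station's partition list, so a given key always routes to the same partition. A minimal sketch of that mapping, assuming a hypothetical three-partition station:

```python
import mmh3  # dependency added by this PR

SEED = 1234  # Memphis.SEED, defined in memphis/memphis.py

def partition_for_key(key: str, partitions_list: list) -> int:
    # Same mapping the client applies in get_partition_from_key():
    # unsigned murmur3 hash of the key, modulo the partition count.
    index = mmh3.hash(key, SEED, signed=False) % len(partitions_list)
    return partitions_list[index]

partition_for_key("key", [1, 2, 3])  # always yields the same partition for "key"
```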

### Non-blocking Produce with Task Limits
For better performance, the client won't block requests while waiting for an acknowledgment.
If you are producing a large number of messages and see timeout errors, then you may need to
@@ -341,6 +351,15 @@ async def msg_handler(msgs, error, context):
consumer.consume(msg_handler)
```

### Consume using a partition key
The key is hashed to select the specific partition to consume from.

```python
consumer.consume(msg_handler,
consumer_partition_key="key" # consume from a specific partition
)
```
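
Putting it together, a minimal end-to-end sketch (host, credentials, and station/consumer names are placeholders):

```python
import asyncio
from memphis import Memphis

async def msg_handler(msgs, error, context):
    for msg in msgs:
        print(msg.get_data())
        await msg.ack()

async def main():
    memphis = Memphis()
    await memphis.connect(host="<memphis-host>", username="<app-user>", password="<password>")
    consumer = await memphis.consumer(station_name="test_station_py", consumer_name="consumer_py")
    # Pin the consumer to the partition derived from "key" instead of
    # cycling round-robin across all partitions:
    consumer.consume(msg_handler, consumer_partition_key="key")
    await asyncio.sleep(5)  # give the consume task time to pull messages
    await memphis.close()

asyncio.run(main())
```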

### Fetch a single batch of messages
```python
msgs = await memphis.fetch_messages(
@@ -351,8 +370,9 @@ msgs = await memphis.fetch_messages(
batch_max_time_to_wait_ms=5000, # defaults to 5000
max_ack_time_ms=30000, # defaults to 30000
max_msg_deliveries=10, # defaults to 10
start_consume_from_sequence=1 # start consuming from a specific sequence. defaults to 1
last_messages=-1 # consume the last N messages, defaults to -1 (all messages in the station))
start_consume_from_sequence=1, # start consuming from a specific sequence. defaults to 1
last_messages=-1, # consume the last N messages, defaults to -1 (all messages in the station)
consumer_partition_key="key" # used to consume from a specific partition, defaults to None
)
```

62 changes: 28 additions & 34 deletions memphis/consumer.py
@@ -2,6 +2,7 @@

import asyncio
import json
import mmh3

from memphis.exceptions import MemphisError
from memphis.utils import default_error_handler, get_internal_name
@@ -26,6 +27,8 @@ def __init__(
error_callback=None,
start_consume_from_sequence: int = 1,
last_messages: int = -1,
partition_generator: PartitionGenerator = None,
subscriptions: dict = None,
):
self.connection = connection
self.station_name = station_name.lower()
@@ -49,16 +52,15 @@ def __init__(
self.t_dls = asyncio.create_task(self.__consume_dls())
self.t_consume = None
self.inner_station_name = get_internal_name(self.station_name)
self.subscriptions = {}
if self.inner_station_name in connection.partition_consumers_updates_data:
self.partition_generator = PartitionGenerator(connection.partition_consumers_updates_data[self.inner_station_name]["partitions_list"])
self.subscriptions = subscriptions
self.partition_generator = partition_generator


def set_context(self, context):
"""Set a context (dict) that will be passed to each message handler call."""
self.context = context

def consume(self, callback):
def consume(self, callback, consumer_partition_key: str = None):
"""
This method starts consuming events from the specified station and invokes the provided callback function for each batch of messages received.

@@ -68,6 +70,7 @@ def consume(self, callback):
- `messages`: A list of `Message` objects representing the batch of messages received.
- `error`: An optional `MemphisError` object if there was an error while consuming the messages.
- `context`: A dictionary representing the context that was set using the `set_context()` method.
consumer_partition_key (str, optional): A key used to determine which partition to consume from. If not provided, the consumer works through the partitions in a round-robin fashion.

Example:
import asyncio
@@ -94,28 +97,20 @@ async def main():
asyncio.run(main())
"""
self.dls_callback_func = callback
self.t_consume = asyncio.create_task(self.__consume(callback))

async def __consume(self, callback):
if self.inner_station_name not in self.connection.partition_consumers_updates_data:
subject = self.inner_station_name + ".final"
consumer_group = get_internal_name(self.consumer_group)
psub = await self.connection.broker_connection.pull_subscribe(subject, durable=consumer_group)
self.subscriptions[1] = psub
else:
for p in self.connection.partition_consumers_updates_data[self.inner_station_name]["partitions_list"]:
subject = f"{self.inner_station_name}${str(p)}.final"
consumer_group = get_internal_name(self.consumer_group)
psub = await self.connection.broker_connection.pull_subscribe(subject, durable=consumer_group)
self.subscriptions[p] = psub
self.t_consume = asyncio.create_task(self.__consume(callback, partition_key=consumer_partition_key))

async def __consume(self, callback, partition_key: str = None):
partition_number = 1

if partition_key is not None:
partition_number = self.get_partition_from_key(partition_key)

while True:
if self.connection.is_connection_active and self.pull_interval_ms:
try:
if len(self.subscriptions) > 1:
partition_number = next(self.partition_generator)
if partition_key is None:
partition_number = next(self.partition_generator)

memphis_messages = []
msgs = await self.subscriptions[partition_number].fetch(self.batch_size)
@@ -167,7 +162,7 @@ async def __consume_dls(self):
await self.dls_callback_func([], MemphisError(str(e)), self.context)
return

async def fetch(self, batch_size: int = 10):
async def fetch(self, batch_size: int = 10, consumer_partition_key: str = None):
"""
Fetch a batch of messages.

@@ -225,22 +220,14 @@ async def main(host, username, password, station):
self.dls_current_index -= len(messages)
return messages

subject = get_internal_name(self.station_name)
if len(self.subscriptions) == 0:
if self.inner_station_name not in self.connection.partition_consumers_updates_data:
subject = self.inner_station_name + ".final"
consumer_group = get_internal_name(self.consumer_group)
psub = await self.connection.broker_connection.pull_subscribe(subject, durable=consumer_group)
self.subscriptions[1] = psub
else:
for p in self.connection.partition_consumers_updates_data[self.inner_station_name]["partitions_list"]:
subject = f"{self.inner_station_name}${str(p)}.final"
consumer_group = get_internal_name(self.consumer_group)
psub = await self.connection.broker_connection.pull_subscribe(subject, durable=consumer_group)
self.subscriptions[p] = psub
partition_number = 1

if len(self.subscriptions) > 1:
partition_number = next(self.partition_generator)
if consumer_partition_key is not None:
partition_number = self.get_partition_from_key(consumer_partition_key)
else:
partition_number = next(self.partition_generator)

msgs = await self.subscriptions[partition_number].fetch(batch_size)
for msg in msgs:
messages.append(
@@ -305,3 +292,10 @@ async def destroy(self):
del self.connection.consumers_map[map_key]
except Exception as e:
raise MemphisError(str(e)) from e

def get_partition_from_key(self, key):
try:
index = mmh3.hash(key, self.connection.SEED, signed=False) % len(self.subscriptions)
return self.connection.partition_consumers_updates_data[self.inner_station_name]["partitions_list"][index]
except Exception as e:
raise e
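
Both consumer and producer fall back to `PartitionGenerator` (from `memphis/partition_generator.py`, not shown in this diff) when no key is supplied. Judging only by how the diff constructs it with a `partitions_list` and advances it with `next()`, a minimal sketch could look like this:

```python
from itertools import cycle

class PartitionGenerator:
    # Assumed behavior: yield the station's partition numbers
    # in round-robin order, one per next() call.
    def __init__(self, partitions_list):
        self._partitions = cycle(partitions_list)

    def __next__(self):
        return next(self._partitions)
```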
32 changes: 31 additions & 1 deletion memphis/memphis.py
@@ -34,12 +34,14 @@
from memphis.station import Station
from memphis.types import Retention, Storage
from memphis.utils import get_internal_name, random_bytes
from memphis.partition_generator import PartitionGenerator

app_id = str(uuid.uuid4())

class Memphis:
MAX_BATCH_SIZE = 5000
MEMPHIS_GLOBAL_ACCOUNT_NAME = "$memphis"
SEED = 1234

def __init__(self):
self.is_connection_active = False
@@ -646,6 +648,27 @@ async def consumer(
except:
raise MemphisError(creation_res)

inner_station_name = get_internal_name(station_name.lower())

partition_generator = None

if inner_station_name in self.partition_consumers_updates_data:
partition_generator = PartitionGenerator(self.partition_consumers_updates_data[inner_station_name]["partitions_list"])

consumer_group = get_internal_name(cg.lower())
subscriptions = {}

if inner_station_name not in self.partition_consumers_updates_data:
subject = inner_station_name + ".final"
psub = await self.broker_connection.pull_subscribe(subject, durable=consumer_group)
subscriptions[1] = psub
else:
for p in self.partition_consumers_updates_data[inner_station_name]["partitions_list"]:
subject = f"{inner_station_name}${str(p)}.final"
psub = await self.broker_connection.pull_subscribe(subject, durable=consumer_group)
subscriptions[p] = psub


internal_station_name = get_internal_name(station_name)
map_key = internal_station_name + "_" + real_name
consumer = Consumer(
@@ -660,6 +683,8 @@
max_msg_deliveries,
start_consume_from_sequence=start_consume_from_sequence,
last_messages=last_messages,
partition_generator=partition_generator,
subscriptions=subscriptions,
)
self.consumers_map[map_key] = consumer
return consumer
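
For reference, the `subscriptions` map built above holds one pull subscription per partition number, using the `{station}${partition}.final` subject pattern from the loop. Conceptually, for a hypothetical three-partition station named `orders`:

```python
# Conceptual shape only; the real values are nats-py pull-subscription objects.
subscriptions = {
    1: "pull subscription on 'orders$1.final'",
    2: "pull subscription on 'orders$2.final'",
    3: "pull subscription on 'orders$3.final'",
}
# A non-partitioned station gets the single entry {1: <sub on 'orders.final'>}.
```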
@@ -676,6 +701,7 @@ async def produce(
headers: Union[Headers, None] = None,
async_produce: bool = False,
msg_id: Union[str, None] = None,
producer_partition_key: Union[str, None] = None
):
"""Produces a message into a station without the need to create a producer.
Args:
@@ -687,6 +713,7 @@
headers (dict, optional): Message headers, defaults to {}.
async_produce (boolean, optional): produce operation won't wait for broker acknowledgement
msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotence
producer_partition_key (string, optional): produce to a specific partition using the partition key
Raises:
Exception: _description_
"""
@@ -708,6 +735,7 @@
headers=headers,
async_produce=async_produce,
msg_id=msg_id,
producer_partition_key=producer_partition_key
)
except Exception as e:
raise MemphisError(str(e)) from e
@@ -724,6 +752,7 @@
generate_random_suffix: bool = False,
start_consume_from_sequence: int = 1,
last_messages: int = -1,
consumer_partition_key: str = None,
):
"""Consume a batch of messages.
Args:
@@ -737,6 +766,7 @@
generate_random_suffix (bool): Deprecated: will no longer be supported after November 1st, 2023. False by default; if True, a random suffix is concatenated to the consumer's name.
start_consume_from_sequence (int, optional): start consuming from a specific sequence. Defaults to 1.
last_messages (int, optional): consume the last N messages. Defaults to -1 (all messages in the station).
consumer_partition_key (str, optional): consume from a specific partition using the partition key. Defaults to None.
Returns:
list: Message
"""
@@ -765,7 +795,7 @@
start_consume_from_sequence=start_consume_from_sequence,
last_messages=last_messages,
)
messages = await consumer.fetch(batch_size)
messages = await consumer.fetch(batch_size, consumer_partition_key=consumer_partition_key)
if messages == None:
messages = []
return messages
15 changes: 14 additions & 1 deletion memphis/producer.py
@@ -12,6 +12,7 @@
from jsonschema import validate
import google.protobuf.json_format as protobuf_json_format
import fastavro
import mmh3
from memphis.exceptions import MemphisError, MemphisSchemaError
from memphis.headers import Headers
from memphis.utils import get_internal_name
@@ -191,7 +192,8 @@ async def produce(
async_produce: Union[bool, None] = None,
nonblocking: bool = False,
msg_id: Union[str, None] = None,
concurrent_task_limit: Union[int, None] = None
concurrent_task_limit: Union[int, None] = None,
producer_partition_key: Union[str, None] = None
):
"""Produces a message into a station.
Args:
@@ -215,6 +217,7 @@
tasks. Calls with nonblocking=True will block
if the limit is hit and will wait until the
buffer drains halfway down.
producer_partition_key (string, optional): Produce messages to a specific partition using the partition key.
Raises:
Exception: _description_
"""
@@ -239,6 +242,9 @@
partition_name = self.internal_station_name
elif len(self.connection.partition_producers_updates_data[self.internal_station_name]['partitions_list']) == 1:
partition_name = f"{self.internal_station_name}${self.connection.partition_producers_updates_data[self.internal_station_name]['partitions_list'][0]}"
elif producer_partition_key is not None:
partition_number = self.get_partition_from_key(producer_partition_key)
partition_name = f"{self.internal_station_name}${str(partition_number)}"
else:
partition_name = f"{self.internal_station_name}${str(next(self.partition_generator))}"

@@ -395,3 +401,10 @@ async def destroy(self):

except Exception as e:
raise Exception(e)

def get_partition_from_key(self, key):
try:
index = mmh3.hash(key, self.connection.SEED, signed=False) % len(self.connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"])
return self.connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"][index]
except Exception as e:
raise e
2 changes: 1 addition & 1 deletion setup.py
@@ -19,7 +19,7 @@
url="https://github.com/memphisdev/memphis.py",
download_url="https://github.com/memphisdev/memphis.py/archive/refs/tags/1.1.2.tar.gz",
keywords=["message broker", "devtool", "streaming", "data"],
install_requires=["asyncio", "nats-py", "protobuf", "jsonschema", "graphql-core", "fastavro"],
install_requires=["asyncio", "nats-py", "protobuf", "jsonschema", "graphql-core", "fastavro", "mmh3"],
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",