diff --git a/README.md b/README.md index 5d915c0..0aec24f 100644 --- a/README.md +++ b/README.md @@ -639,16 +639,16 @@ To add message headers to the message, use the headers parameter. Headers can he ```python memphis = Memphis() - await memphis.connect(...) + await memphis.connect(...) - await memphis.produce( - station_name = "some_station", - producer_name = "temp_producer", - message = {'some':'message'}, - headers = { - 'trace_header': 'track_me_123' - } - ) + await memphis.produce( + station_name = "some_station", + producer_name = "temp_producer", + message = {'some':'message'}, + headers = { + 'trace_header': 'track_me_123' + } + ) ``` ### Producing to a partition @@ -658,28 +658,28 @@ Lastly, memphis can produce to a specific partition in a station. To do so, use ```python memphis = Memphis() - await memphis.connect(...) + await memphis.connect(...) - await memphis.produce( - station_name = "some_station", - producer_name = "temp_producer", - message = {'some':'message'}, - producer_partition_key = "2nd_partition" - ) + await memphis.produce( + station_name = "some_station", + producer_name = "temp_producer", + message = {'some':'message'}, + producer_partition_key = "2nd_partition" + ) ``` Or, alternatively, use the producer_partition_number parameter: ```python memphis = Memphis() - await memphis.connect(...) + await memphis.connect(...) - await memphis.produce( + await memphis.produce( station_name = "some_station", producer_name = "temp_producer", message = {'some':'message'}, producer_partition_number = 2 - ) + ) ``` ### Non-blocking Produce with Task Limits @@ -695,6 +695,39 @@ await producer.produce( You may read more about this [here](https://memphis.dev/blog/producing-messages-at-warp-speed-best-practices-for-optimizing-your-producers/) on the memphis.dev blog. +### Produce to multiple stations + +Producing to multiple stations can be done by creating a producer with multiple stations and then calling produce on that producer. + +```python +memphis = Memphis() + +await memphis.connect(...) + +producer = await memphis.producer( + station_name = ["station_1", "station_2"], + producer_name = "new_producer" +) + +await producer.produce( + message = "some message" +) +``` + +Alternatively, it also possible to produce to multiple stations using the connection: + +```python +memphis = Memphis() + +await memphis.connect(...) + +await memphis.produce( + station_name = ["station_1", "station_2"], + producer_name = "new_producer", + message = "some message" +) +``` + ### Destroying a Producer ```python diff --git a/memphis/consumer.py b/memphis/consumer.py index 824267b..0d168ca 100644 --- a/memphis/consumer.py +++ b/memphis/consumer.py @@ -308,9 +308,11 @@ async def destroy(self, timeout_retries=5): } consumer_name = json.dumps( destroy_consumer_req, indent=2).encode("utf-8") + # pylint: disable=protected-access res = await self.connection._request( "$memphis_consumer_destructions", consumer_name, 5, timeout_retries ) + # pylint: enable=protected-access error = res.data.decode("utf-8") if error != "" and not "not exist" in error: raise MemphisError(error) diff --git a/memphis/memphis.py b/memphis/memphis.py index 6b82e54..7c35f0d 100644 --- a/memphis/memphis.py +++ b/memphis/memphis.py @@ -17,7 +17,7 @@ import copy import json import ssl -from typing import Iterable, Union +from typing import Iterable, Union, List import uuid import base64 import re @@ -37,7 +37,7 @@ from memphis.partition_generator import PartitionGenerator app_id = str(uuid.uuid4()) - +# pylint: disable=too-many-lines class Memphis: MAX_BATCH_SIZE = 5000 MEMPHIS_GLOBAL_ACCOUNT_NAME = "$memphis" @@ -416,14 +416,14 @@ async def _request(self, subject, payload, timeout, timeout_retries=5): async def producer( self, - station_name: str, + station_name: Union[str, List[str]], producer_name: str, generate_random_suffix: bool = False, timeout_retries=5, ): """Creates a producer. Args: - station_name (str): station name to produce messages into. + station_name (Union[str, List[str]]): station name to produce messages into. producer_name (str): name for the producer. generate_random_suffix (bool): Deprecated: will be stopped to be supported after November 1'st, 2023. false by default, if true concatenate a random suffix to producer's name Raises: @@ -434,16 +434,32 @@ async def producer( try: if not self.is_connection_active: raise MemphisError("Connection is dead") + if not isinstance(station_name, str) and not isinstance(station_name, list): + raise MemphisError("station_name should be either string or list of strings") real_name = producer_name.lower() - internal_station_name = get_internal_name(station_name) if generate_random_suffix: warnings.warn("Deprecation warning: generate_random_suffix will be stopped to be supported after November 1'st, 2023.") producer_name = self.__generate_random_suffix(producer_name) + if isinstance(station_name, str): + return await self._single_station_producer(station_name, producer_name, real_name, timeout_retries) else: - map_key = internal_station_name + "_" + producer_name.lower() - producer = None - if map_key in self.producers_map: - return self.producers_map[map_key] + return await self._multi_station_producer(station_name, producer_name, real_name) + except Exception as e: + raise MemphisError(str(e)) from e + + async def _single_station_producer( + self, + station_name: str, + producer_name: str, + real_name: str, + timeout_retries: int, + ): + try: + internal_station_name = get_internal_name(station_name) + map_key = internal_station_name + "_" + producer_name.lower() + producer = None + if map_key in self.producers_map: + return self.producers_map[map_key] create_producer_req = { "name": producer_name, @@ -495,6 +511,15 @@ async def producer( except Exception as e: raise MemphisError(str(e)) from e + + async def _multi_station_producer( + self, + station_names: List[str], + producer_name: str, + real_name: str + ): + return Producer(self, producer_name, station_names, real_name) + def update_schema_data(self, station_name): internal_station_name = get_internal_name(station_name) if self.schema_updates_data[internal_station_name] != {}: @@ -763,7 +788,7 @@ async def consumer( async def produce( self, - station_name: str, + station_name: Union[str, List[str]], producer_name: str, message, generate_random_suffix: bool = False, @@ -789,6 +814,30 @@ async def produce( Raises: Exception: _description_ """ + try: + if not isinstance(station_name, str) and not isinstance(station_name, list): + raise MemphisError("station_name should be either string or list of strings") + if isinstance(station_name, str): + await self._single_station_produce(station_name, producer_name, message, generate_random_suffix, ack_wait_sec, headers, async_produce, msg_id, producer_partition_key, producer_partition_number) + else: + await self._multi_station_produce(station_name, producer_name, message, generate_random_suffix, ack_wait_sec, headers, async_produce, msg_id, producer_partition_key, producer_partition_number) + except Exception as e: + raise MemphisError(str(e)) from e + + + async def _single_station_produce( + self, + station_name: str, + producer_name: str, + message, + generate_random_suffix: bool = False, + ack_wait_sec: int = 15, + headers: Union[Headers, None] = None, + async_produce: bool = False, + msg_id: Union[str, None] = None, + producer_partition_key: Union[str, None] = None, + producer_partition_number: Union[int, -1] = -1 + ): try: internal_station_name = get_internal_name(station_name) map_key = internal_station_name + "_" + producer_name.lower() @@ -813,6 +862,38 @@ async def produce( except Exception as e: raise MemphisError(str(e)) from e + async def _multi_station_produce( + self, + station_names: List[str], + producer_name: str, + message, + generate_random_suffix: bool = False, + ack_wait_sec: int = 15, + headers: Union[Headers, None] = None, + async_produce: bool = False, + msg_id: Union[str, None] = None, + producer_partition_key: Union[str, None] = None, + producer_partition_number: Union[int, -1] = -1 + ): + try: + producer = await self.producer( + station_name=station_names, + producer_name=producer_name, + generate_random_suffix=generate_random_suffix, + ) + await producer.produce( + message=message, + ack_wait_sec=ack_wait_sec, + headers=headers, + async_produce=async_produce, + msg_id=msg_id, + producer_partition_key=producer_partition_key, + producer_partition_number=producer_partition_number + ) + except Exception as e: + raise MemphisError(str(e)) from e + + async def fetch_messages( self, station_name: str, diff --git a/memphis/producer.py b/memphis/producer.py index b662089..55936d5 100644 --- a/memphis/producer.py +++ b/memphis/producer.py @@ -3,7 +3,7 @@ import asyncio import json import time -from typing import Union +from typing import Union, List import warnings import mmh3 @@ -18,20 +18,26 @@ class Producer: def __init__( - self, connection, producer_name: str, station_name: str, real_name: str + self, connection, producer_name: str, station_name: Union[str, List[str]] , real_name: str ): self.connection = connection self.producer_name = producer_name.lower() self.station_name = station_name + self.real_name = real_name + self.background_tasks = set() + if isinstance(station_name, list): + self.is_multi_station_producer = True + return + else: + self.is_multi_station_producer = False + self.station = Station(connection, station_name) self.internal_station_name = get_internal_name(self.station_name) self.loop = asyncio.get_running_loop() - self.real_name = real_name - self.background_tasks = set() if self.internal_station_name in connection.partition_producers_updates_data: self.partition_generator = PartitionGenerator(connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"]) - # pylint: disable=R0913 + # pylint: disable=too-many-arguments async def produce( self, message, @@ -43,6 +49,68 @@ async def produce( concurrent_task_limit: Union[int, None] = None, producer_partition_key: Union[str, None] = None, producer_partition_number: Union[int, -1] = -1 + ): + """Produces a message into a station. + Args: + message (bytearray/dict): message to send into the station + - bytearray/protobuf class + (schema validated station - protobuf) + - bytearray/dict (schema validated station - json schema) + - string/bytearray/graphql.language.ast.DocumentNode + (schema validated station - graphql schema) + - bytearray/dict (schema validated station - avro schema) + ack_wait_sec (int, optional): max time in seconds to wait for an ack from the broker. + Defaults to 15 sec. + headers (dict, optional): message headers, defaults to {}. + async_produce (boolean, optional): produce operation won't block (wait) on message send. + This argument is deprecated. Please use the + `nonblocking` arguemnt instead. + nonblocking (boolean, optional): produce operation won't block (wait) on message send. + msg_id (string, optional): Attach msg-id header to the message in order to + achieve idempotency. + concurrent_task_limit (int, optional): Limit the number of outstanding async produce + tasks. Calls with nonblocking=True will block + if the limit is hit and will wait until the + buffer drains halfway down. + producer_partition_key (string, optional): Produce messages to a specific partition using the partition key. + producer_partition_number (int, optional): Produce messages to a specific partition using the partition number. + Raises: + Exception: _description_ + """ + if self.is_multi_station_producer: + await self._multi_station_produce( + message, + ack_wait_sec=ack_wait_sec, + headers=headers, + async_produce=async_produce, + msg_id=msg_id, + producer_partition_key=producer_partition_key, + producer_partition_number=producer_partition_number + ) + else: + await self._single_station_produce( + message, + ack_wait_sec=ack_wait_sec, + headers=headers, + async_produce=async_produce, + nonblocking=nonblocking, + msg_id=msg_id, + concurrent_task_limit=concurrent_task_limit, + producer_partition_key=producer_partition_key, + producer_partition_number=producer_partition_number + ) + + async def _single_station_produce( + self, + message, + ack_wait_sec: int = 15, + headers: Union[Headers, None] = None, + async_produce: Union[bool, None] = None, + nonblocking: bool = False, + msg_id: Union[str, None] = None, + concurrent_task_limit: Union[int, None] = None, + producer_partition_key: Union[str, None] = None, + producer_partition_number: Union[int, -1] = -1 ): """Produces a message into a station. Args: @@ -213,7 +281,38 @@ async def produce( ) raise MemphisError(str(e)) from e + async def _multi_station_produce( + self, + message, + ack_wait_sec: int = 15, + headers: Union[Headers, None] = None, + async_produce: Union[bool, None] = None, + msg_id: Union[str, None] = None, + producer_partition_key: Union[str, None] = None, + producer_partition_number: Union[int, -1] = -1 + ): + for sn in self.station_name: + await self.connection.produce( + sn, + self.producer_name, + message, + ack_wait_sec=ack_wait_sec, + headers=headers, + async_produce=async_produce, + msg_id=msg_id, + producer_partition_key=producer_partition_key, + producer_partition_number=producer_partition_number + ) + + # pylint: enable=too-many-arguments async def destroy(self, timeout_retries=5): + """Destroy the producer.""" + if self.is_multi_station_producer: + await self._destroy_multi_station_producer(timeout_retries=timeout_retries) + else: + await self._destroy_single_station_producer(timeout_retries=timeout_retries) + + async def _destroy_single_station_producer(self, timeout_retries=5): """Destroy the producer.""" try: # drain buffered async messages @@ -229,9 +328,11 @@ async def destroy(self, timeout_retries=5): } producer_name = json.dumps(destroy_producer_req).encode("utf-8") + # pylint: disable=protected-access res = await self.connection._request( "$memphis_producer_destructions", producer_name, 5, timeout_retries ) + # pylint: enable=protected-access error = res.data.decode("utf-8") if error != "" and not "not exist" in error: raise Exception(error) @@ -283,6 +384,15 @@ async def destroy(self, timeout_retries=5): except Exception as e: raise Exception(e) + async def _destroy_multi_station_producer(self, timeout_retries=5): + internal_station_name_list = [get_internal_name(station_name) for station_name in self.station_name] + producer_keys = [f"{internal_station_name}_{self.real_name}" for internal_station_name in internal_station_name_list] + producers = [self.connection.producers_map.get(producer_key) for producer_key in producer_keys] + producers = [producer for producer in producers if producer is not None] + for producer in producers: + await producer.destroy(timeout_retries) + + def get_partition_from_key(self, key): try: index = mmh3.hash(key, self.connection.SEED, signed=False) % len(self.connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"]) diff --git a/memphis/station.py b/memphis/station.py index 06d10f2..36f29e5 100644 --- a/memphis/station.py +++ b/memphis/station.py @@ -168,9 +168,11 @@ async def destroy(self, timeout_retries=5): try: name_req = {"station_name": self.name, "username": self.connection.username} station_name = json.dumps(name_req, indent=2).encode("utf-8") + # pylint: disable=protected-access res = await self.connection._request( "$memphis_station_destructions", station_name, 5, timeout_retries ) + # pylint: enable=protected-access error = res.data.decode("utf-8") if error != "" and not "not exist" in error: raise MemphisError(error)