diff --git a/MAINTAINERS.md b/MAINTAINERS.md
index 254758f..173aa67 100644
--- a/MAINTAINERS.md
+++ b/MAINTAINERS.md
@@ -5,5 +5,5 @@
- Yaniv Ben Hemo [@yanivbh1](https://github.com/yanivbh1)
- Sveta Gimpelson [@SvetaMemphis](https://github.com/SvetaMemphis)
- Shay Bratslavsky [@shay23b](https://github.com/shay23b)
- - Or Grinberg [@orgrMmphs](https://github.com/orgrMmphs)
- - Shoham Roditi [@shohamroditimemphis](https://github.com/shohamroditimemphis)
\ No newline at end of file
+ - Shoham Roditi [@shohamroditimemphis](https://github.com/shohamroditimemphis)
+ - Ido Naaman [@idonaaman123](https://github.com/idonaaman123)
\ No newline at end of file
diff --git a/README.md b/README.md
index 8030f14..67086b7 100644
--- a/README.md
+++ b/README.md
@@ -38,12 +38,12 @@

-**[Memphis](https://memphis.dev)** is a next-generation message broker.
+**[Memphis](https://memphis.dev)** is a next-generation alternative to traditional message brokers.

A simple, robust, and durable cloud-native message broker wrapped with
-an entire ecosystem that enables fast and reliable development of next-generation event-driven use cases.

-Memphis enables building modern applications that require large volumes of streamed and enriched data,
-modern protocols, zero ops, rapid development, extreme cost reduction,
-and a significantly lower amount of dev time for data-oriented developers and data engineers.
+an entire ecosystem that enables cost-effective, fast, and reliable development of modern queue-based use cases.

+Memphis enables the building of modern queue-based applications that require
+large volumes of streamed and enriched data, modern protocols, zero ops, rapid development,
+extreme cost reduction, and a significantly lower amount of dev time for data-oriented developers and data engineers.

## Installation
@@ -55,7 +55,7 @@
$ pip3 install memphis-py
```

```python
from memphis import Memphis, Headers
-from memphis import retention_types, storage_types
+from memphis.types import Retention, Storage
```

### Connecting to Memphis
@@ -69,7 +69,7 @@ async def main():
    await memphis.connect(
        host="",
        username="",
-        connection_token="",
+        connection_token="", # you will get it when creating an application-type user
        port="", # defaults to 6666
        reconnect=True, # defaults to True
        max_reconnect=3, # defaults to 3
@@ -108,9 +108,9 @@ _If a station already exists nothing happens, the new configuration will not be
station = memphis.station(
    name="",
    schema_name="",
-    retention_type=retention_types.MAX_MESSAGE_AGE_SECONDS, # MAX_MESSAGE_AGE_SECONDS/MESSAGES/BYTES. Defaults to MAX_MESSAGE_AGE_SECONDS
+    retention_type=Retention.MAX_MESSAGE_AGE_SECONDS, # MAX_MESSAGE_AGE_SECONDS/MESSAGES/BYTES. Defaults to MAX_MESSAGE_AGE_SECONDS
    retention_value=604800, # defaults to 604800
-    storage_type=storage_types.DISK, # storage_types.DISK/storage_types.MEMORY. Defaults to DISK
+    storage_type=Storage.DISK, # Storage.DISK/Storage.MEMORY. Defaults to DISK
    replicas=1, # defaults to 1
    idempotency_window_ms=120000, # defaults to 2 minutes
    send_poison_msg_to_dls=True, # defaults to true
@@ -124,35 +124,46 @@ station = memphis.station(
Memphis currently supports the following types of retention:

```python
-memphis.retention_types.MAX_MESSAGE_AGE_SECONDS
+memphis.types.Retention.MAX_MESSAGE_AGE_SECONDS
```

Means that every message persists for the value set in retention value field (in seconds)

```python
-memphis.retention_types.MESSAGES
+memphis.types.Retention.MESSAGES
```

Means that after max amount of saved messages (set in retention value), the oldest messages will be deleted

```python
-memphis.retention_types.BYTES
+memphis.types.Retention.BYTES
```

Means that after max amount of saved bytes (set in retention value), the oldest messages will be deleted
+
+### Retention Values
+
+The `retention values` are directly related to the `retention types` mentioned above: what the value means depends on the chosen retention type.
+
+All retention values are of type `int`, but with different units, as follows:
+
+`memphis.types.Retention.MAX_MESSAGE_AGE_SECONDS` is expressed **in seconds**, `memphis.types.Retention.MESSAGES` in a **number of messages**, and finally `memphis.types.Retention.BYTES` in a **number of bytes**.
+
+After these limits are reached, the oldest messages will be deleted.
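Editor's illustration (not part of this diff): a short sketch of how a retention type pairs with its value when creating a station. The station names are placeholders, and `memphis` is assumed to be an already-connected `Memphis` instance:

```python
from memphis.types import Retention

# Keep at most 1000 messages; once the count is exceeded, the oldest are deleted.
await memphis.station(
    name="orders",  # hypothetical station name
    retention_type=Retention.MESSAGES,
    retention_value=1000,  # interpreted as a number of messages
)

# Keep every message for one hour; the value is interpreted as seconds.
await memphis.station(
    name="clicks",  # hypothetical station name
    retention_type=Retention.MAX_MESSAGE_AGE_SECONDS,
    retention_value=3600,
)
```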
+
### Storage types

Memphis currently supports the following types of messages storage:

```python
-memphis.storage_types.DISK
+memphis.types.Storage.DISK
```

Means that messages persist on disk

```python
-memphis.storage_types.MEMORY
+memphis.types.Storage.MEMORY
```

Means that messages persist on the main memory
@@ -213,9 +224,9 @@
await memphis.produce(station_name='test_station_py', producer_name='prod_py',
```

-Creating a producer first
+After creating a producer
```python
-await prod.produce(
+await producer.produce(
    message='bytearray/protobuf class/dict/string/graphql.language.ast.DocumentNode', # bytearray / protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
    ack_wait_sec=15) # defaults to 15
```
@@ -295,6 +306,28 @@ async def msg_handler(msgs, error, context):
consumer.consume(msg_handler)
```

+### Fetch a single batch of messages
+```python
+msgs = await memphis.fetch_messages(
+    station_name="",
+    consumer_name="",
+    consumer_group="", # defaults to the consumer name
+    batch_size=10, # defaults to 10
+    batch_max_time_to_wait_ms=5000, # defaults to 5000
+    max_ack_time_ms=30000, # defaults to 30000
+    max_msg_deliveries=10, # defaults to 10
+    generate_random_suffix=False,
+    start_consume_from_sequence=1, # start consuming from a specific sequence. defaults to 1
+    last_messages=-1 # consume the last N messages, defaults to -1 (all messages in the station)
+)
+```
+
+### Fetch a single batch of messages after creating a consumer
+```python
+msgs = await consumer.fetch(batch_size=10) # defaults to 10
+```
+
+
### Acknowledge a message

Acknowledging a message indicates to the Memphis server that it should not re-send the same message again to the same consumer / consumers group
@@ -306,7 +339,7 @@ await message.ack()
```

### Get headers

Get headers per message

-``python
+```python
headers = message.get_headers()
```
@@ -328,4 +361,4 @@ consumer.destroy()

```python
memphis.is_connected()
-```
\ No newline at end of file
+```
diff --git a/examples/consumer.py b/examples/consumer.py
index 4febb93..dfebbbe 100644
--- a/examples/consumer.py
+++ b/examples/consumer.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import asyncio

from memphis import Memphis, MemphisConnectError, MemphisError, MemphisHeaderError
diff --git a/examples/producer.py b/examples/producer.py
index e5a79e0..01dd381 100644
--- a/examples/producer.py
+++ b/examples/producer.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import asyncio

from memphis import (
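Editor's note on the restructuring that follows: the package root keeps re-exporting the connection class, headers, and exceptions, while the old `retention_types`/`storage_types` constant modules are replaced by enums in `memphis.types`. A quick illustrative sketch of the resulting import surface (the printed values are the enum wire values defined in `memphis/types.py`):

```python
# Still available from the package root after this release:
from memphis import Memphis, Headers, MemphisError

# Enums replace the removed constant modules:
from memphis.types import Retention, Storage

print(Retention.MESSAGES.value)  # "messages"
print(Storage.DISK.value)        # "file" - disk-backed storage maps to the broker's "file" type
```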
diff --git a/memphis/__init__.py b/memphis/__init__.py
index ff8bd5f..aa44e30 100644
--- a/memphis/__init__.py
+++ b/memphis/__init__.py
@@ -11,14 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import memphis.retention_types
-import memphis.storage_types
-from memphis.memphis import (
-    Headers,
-    Memphis,
+from memphis.exceptions import (
     MemphisConnectError,
     MemphisError,
     MemphisHeaderError,
     MemphisSchemaError,
 )
+from memphis.headers import Headers
+from memphis.memphis import Memphis
diff --git a/memphis/consumer.py b/memphis/consumer.py
new file mode 100644
index 0000000..598155c
--- /dev/null
+++ b/memphis/consumer.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import asyncio
+import json
+
+from memphis.exceptions import MemphisError
+from memphis.utils import default_error_handler, get_internal_name
+from memphis.message import Message
+
+
+class Consumer:
+    def __init__(
+        self,
+        connection,
+        station_name: str,
+        consumer_name,
+        consumer_group,
+        pull_interval_ms: int,
+        batch_size: int,
+        batch_max_time_to_wait_ms: int,
+        max_ack_time_ms: int,
+        max_msg_deliveries: int = 10,
+        error_callback=None,
+        start_consume_from_sequence: int = 1,
+        last_messages: int = -1,
+    ):
+        self.connection = connection
+        self.station_name = station_name.lower()
+        self.consumer_name = consumer_name.lower()
+        self.consumer_group = consumer_group.lower()
+        self.pull_interval_ms = pull_interval_ms
+        self.batch_size = batch_size
+        self.batch_max_time_to_wait_ms = batch_max_time_to_wait_ms
+        self.max_ack_time_ms = max_ack_time_ms
+        self.max_msg_deliveries = max_msg_deliveries
+        self.ping_consumer_interval_ms = 30000
+        if error_callback is None:
+            error_callback = default_error_handler
+        self.t_ping = asyncio.create_task(self.__ping_consumer(error_callback))
+        self.start_consume_from_sequence = start_consume_from_sequence
+        self.last_messages = last_messages
+        self.context = {}
+        self.dls_messages = []
+        self.dls_current_index = 0
+        self.dls_callback_func = None
+        self.t_dls = asyncio.create_task(self.__consume_dls())
+
+    def set_context(self, context):
+        """Set a context (dict) that will be passed to each message handler call."""
+        self.context = context
+
+    def consume(self, callback):
+        """Consume events."""
+        self.dls_callback_func = callback
+        self.t_consume = asyncio.create_task(self.__consume(callback))
+
+    async def __consume(self, callback):
+        subject = get_internal_name(self.station_name)
+        consumer_group = get_internal_name(self.consumer_group)
+        self.psub = await self.connection.broker_connection.pull_subscribe(
+            subject + ".final", durable=consumer_group
+        )
+        while True:
+            if self.connection.is_connection_active and self.pull_interval_ms:
+                try:
+                    memphis_messages = []
+                    msgs = await self.psub.fetch(self.batch_size)
+                    for msg in msgs:
+                        memphis_messages.append(
+                            Message(msg, self.connection, self.consumer_group)
+                        )
+                    await callback(memphis_messages, None, self.context)
+                    await asyncio.sleep(self.pull_interval_ms / 1000)
+
+                except asyncio.TimeoutError:
+                    await callback(
+                        [], MemphisError("Memphis: TimeoutError"), self.context
+                    )
+                    continue
+                except Exception as e:
+                    if self.connection.is_connection_active:
+                        raise MemphisError(str(e)) from e
+                    else:
+                        return
+            else:
+                break
+
+    async def __consume_dls(self):
+        subject = get_internal_name(self.station_name)
+        consumer_group = get_internal_name(self.consumer_group)
+        try:
+            subscription_name = "$memphis_dls_" + subject + "_" + consumer_group
+            self.consumer_dls = await self.connection.broker_manager.subscribe(
+                subscription_name, subscription_name
+            )
+            async for msg in self.consumer_dls.messages:
+                index_to_insert = self.dls_current_index
+                if index_to_insert >= 10000:
+                    index_to_insert %= 10000
+                self.dls_messages.insert(
+                    index_to_insert, Message(msg, self.connection, self.consumer_group)
+                )
+                self.dls_current_index += 1
+                if self.dls_callback_func is not None:
+                    await self.dls_callback_func(
+                        [Message(msg, self.connection, self.consumer_group)],
+                        None,
+                        self.context,
+                    )
+        except Exception as e:
+            if self.dls_callback_func is not None:
+                await self.dls_callback_func([], MemphisError(str(e)), self.context)
+            return
+
+    async def fetch(self, batch_size: int = 10):
+        """Fetch a batch of messages."""
+        messages = []
+        if self.connection.is_connection_active:
+            try:
+                self.batch_size = batch_size
+                if len(self.dls_messages) > 0:
+                    if len(self.dls_messages) <= batch_size:
+                        messages = self.dls_messages
+                        self.dls_messages = []
+                        self.dls_current_index = 0
+                    else:
+                        messages = self.dls_messages[0:batch_size]
+                        del self.dls_messages[0:batch_size]
+                        self.dls_current_index -= len(messages)
+                    return messages
+
+                durableName = ""
+                if self.consumer_group != "":
+                    durableName = get_internal_name(self.consumer_group)
+                else:
+                    durableName = get_internal_name(self.consumer_name)
+                subject = get_internal_name(self.station_name)
+                consumer_group = get_internal_name(self.consumer_group)
+                self.psub = await self.connection.broker_connection.pull_subscribe(
+                    subject + ".final", durable=durableName
+                )
+                msgs = await self.psub.fetch(batch_size)
+                for msg in msgs:
+                    messages.append(Message(msg, self.connection, self.consumer_group))
+                return messages
+            except Exception as e:
+                if "timeout" not in str(e):
+                    raise MemphisError(str(e)) from e
+                else:
+                    return messages
+
+    async def __ping_consumer(self, callback):
+        while True:
+            try:
+                await asyncio.sleep(self.ping_consumer_interval_ms / 1000)
+                consumer_group = get_internal_name(self.consumer_group)
+                await self.connection.broker_connection.consumer_info(
+                    self.station_name, consumer_group, timeout=30
+                )
+
+            except Exception as e:
+                callback(MemphisError(str(e)))
+
+    async def destroy(self):
+        """Destroy the consumer."""
+        if self.t_consume is not None:
+            self.t_consume.cancel()
+        if self.t_dls is not None:
+            self.t_dls.cancel()
+        if self.t_ping is not None:
+            self.t_ping.cancel()
+        self.pull_interval_ms = None
+        try:
+            destroyConsumerReq = {
+                "name": self.consumer_name,
+                "station_name": self.station_name,
+                "username": self.connection.username,
+            }
+            consumer_name = json.dumps(destroyConsumerReq, indent=2).encode("utf-8")
+            res = await self.connection.broker_manager.request(
+                "$memphis_consumer_destructions", consumer_name, timeout=5
+            )
+            error = res.data.decode("utf-8")
+            if error != "" and "not exist" not in error:
+                raise MemphisError(error)
+            self.dls_messages.clear()
+            internal_station_name = get_internal_name(self.station_name)
+            map_key = internal_station_name + "_" + self.consumer_name.lower()
+            del self.connection.consumers_map[map_key]
+        except Exception as e:
+            raise MemphisError(str(e)) from e
diff --git a/memphis/exceptions.py b/memphis/exceptions.py
new file mode 100644
index 0000000..65c8c25
--- /dev/null
+++ b/memphis/exceptions.py
@@ -0,0 +1,23 @@
+class MemphisError(Exception):
+    def __init__(self, message):
+        message = message.replace("nats", "memphis")
+        message = message.replace("NATS", "memphis")
+        message = message.replace("Nats", "memphis")
+        message = message.replace("NatsError", "MemphisError")
+        self.message = message
+        if message.startswith("memphis:"):
+            super().__init__(self.message)
+        else:
+            super().__init__("memphis: " + self.message)
+
+
+class MemphisConnectError(MemphisError):
+    pass
+
+
+class MemphisSchemaError(MemphisError):
+    pass
+
+
+class MemphisHeaderError(MemphisError):
+    pass
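Since every specific exception in the new `memphis/exceptions.py` subclasses `MemphisError`, callers can catch narrowly first and broadly second. An illustrative sketch (the connection details are placeholders, not part of this diff):

```python
from memphis import Memphis, MemphisConnectError, MemphisError

async def connect_safely() -> Memphis:
    memphis = Memphis()
    try:
        # Placeholder credentials - substitute real connection details.
        await memphis.connect(host="localhost", username="root", connection_token="<token>")
        return memphis
    except MemphisConnectError:
        # Connection-specific failures (bad credentials, unreachable host, ...).
        raise
    except MemphisError as err:
        # Any other SDK error; messages are normalized to start with "memphis:".
        raise RuntimeError(f"unexpected broker error: {err}") from err
```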
diff --git a/memphis/headers.py b/memphis/headers.py
new file mode 100644
index 0000000..1d25e14
--- /dev/null
+++ b/memphis/headers.py
@@ -0,0 +1,19 @@
+from memphis.exceptions import MemphisHeaderError
+
+
+class Headers:
+    def __init__(self):
+        self.headers = {}
+
+    def add(self, key, value):
+        """Add a header.
+        Args:
+            key (string): header key.
+            value (string): header value.
+        Raises:
+            MemphisHeaderError: if the key starts with "$memphis".
+        """
+        if not key.startswith("$memphis"):
+            self.headers[key] = value
+        else:
+            raise MemphisHeaderError("Keys in headers should not start with $memphis")
diff --git a/memphis/memphis.py b/memphis/memphis.py
index c8fee2a..c1235af 100644
--- a/memphis/memphis.py
+++ b/memphis/memphis.py
@@ -22,8 +22,6 @@
from typing import Callable, Iterable, Union

import graphql
-import memphis.retention_types as retention_types
-import memphis.storage_types as storage_types
import nats as broker
from google.protobuf import descriptor_pb2, descriptor_pool
from google.protobuf.message_factory import MessageFactory
@@ -31,41 +29,13 @@
from graphql import parse as parse_graphql
from graphql import validate as validate_graphql
from jsonschema import validate
-
-
-schemaVFailAlertType = "schema_validation_fail_alert"
-
-
-class set_interval:
-    def __init__(self, func: Callable, sec: int):
-        def func_wrapper():
-            self.t = Timer(sec, func_wrapper)
-            self.t.start()
-            func()
-
-        self.t = Timer(sec, func_wrapper)
-        self.t.start()
-
-    def cancel(self):
-        self.t.cancel()
-
-
-class Headers:
-    def __init__(self):
-        self.headers = {}
-
-    def add(self, key, value):
+from memphis.consumer import Consumer
+from memphis.exceptions import MemphisConnectError, MemphisError, MemphisHeaderError
+from memphis.headers import Headers
+from memphis.producer import Producer
+from memphis.station import Station
+from memphis.types import Retention, Storage
+from memphis.utils import get_internal_name, random_bytes


class Memphis:
@@ -83,8 +53,9 @@ def __init__(self):
        self.update_configurations_sub = {}
        self.configuration_tasks = {}
        self.producers_map = dict()
+        self.consumers_map = dict()

-    async def get_msgs_update_configurations(self, iterable: Iterable):
+    async def get_msgs_sdk_clients_updates(self, iterable: Iterable):
        try:
            async for msg in iterable:
                message = msg.data.decode("utf-8")
@@ -95,18 +66,21 @@
                    self.station_schemaverse_to_dls[data["station_name"]] = data[
                        "update"
                    ]
+                elif data["type"] == "remove_station":
+                    self.unset_cached_producer_station(data['station_name'])
+                    self.unset_cached_consumer_station(data['station_name'])
        except Exception as err:
            raise MemphisError(err)

-    async def configurations_listener(self):
+    async def sdk_client_updates_listener(self):
        try:
            sub = await self.broker_manager.subscribe(
-                "$memphis_sdk_configurations_updates"
+                "$memphis_sdk_clients_updates"
            )
            self.update_configurations_sub = sub
            loop = asyncio.get_event_loop()
            task = loop.create_task(
-                self.get_msgs_update_configurations(
+                self.get_msgs_sdk_clients_updates(
                    self.update_configurations_sub.messages
                )
            )
@@ -136,8 +110,8 @@
            port (int, optional): port. Defaults to 6666.
            reconnect (bool, optional): whether to do reconnect while connection is lost. Defaults to True.
            max_reconnect (int, optional): The reconnect attempt. Defaults to 3.
-            reconnect_interval_ms (int, optional): Interval in miliseconds between reconnect attempts. Defaults to 200.
-            timeout_ms (int, optional): connection timeout in miliseconds. Defaults to 15000.
+            reconnect_interval_ms (int, optional): Interval in milliseconds between reconnect attempts. Defaults to 200.
+            timeout_ms (int, optional): connection timeout in milliseconds. Defaults to 15000.
            key_file (string): path to tls key file.
            cert_file (string): path to tls cert file.
            ca_file (string): path to tls ca file.
@@ -184,7 +158,7 @@
                name=self.connection_id + "::" + self.username,
            )

-            await self.configurations_listener()
+            await self.sdk_client_updates_listener()
            self.broker_connection = self.broker_manager.jetstream()
            self.is_connection_active = True
        except Exception as e:
@@ -198,22 +172,22 @@ async def send_notification(self, title, msg, failedMsg, type):
    async def station(
        self,
        name: str,
-        retention_type: str = retention_types.MAX_MESSAGE_AGE_SECONDS,
+        retention_type: Retention = Retention.MAX_MESSAGE_AGE_SECONDS,
        retention_value: int = 604800,
-        storage_type: str = storage_types.DISK,
+        storage_type: Storage = Storage.DISK,
        replicas: int = 1,
        idempotency_window_ms: int = 120000,
        schema_name: str = "",
        send_poison_msg_to_dls: bool = True,
        send_schema_failed_msg_to_dls: bool = True,
-        tiered_storage_enabled: bool = False
+        tiered_storage_enabled: bool = False,
    ):
        """Creates a station.
        Args:
            name (str): station name.
-            retention_type (str, optional): retention type: message_age_sec/messages/bytes . Defaults to "message_age_sec".
+            retention_type (Retention, optional): retention type: message_age_sec/messages/bytes. Defaults to "message_age_sec".
            retention_value (int, optional): number which represents the retention based on the retention_type. Defaults to 604800.
-            storage_type (str, optional): persistance storage for messages of the station: disk/memory. Defaults to "disk".
+            storage_type (Storage, optional): persistence storage for messages of the station: disk/memory. Defaults to "disk".
            replicas (int, optional):number of replicas for the messages of the data. Defaults to 1.
            idempotency_window_ms (int, optional): time frame in which idempotent messages will be tracked, happens based on message ID Defaults to 120000.
            schema_name (str): schema name.
@@ -226,9 +200,9 @@ async def station(
        createStationReq = {
            "name": name,
-            "retention_type": retention_type,
+            "retention_type": retention_type.value,
            "retention_value": retention_value,
-            "storage_type": storage_type,
+            "storage_type": storage_type.value,
            "replicas": replicas,
            "idempotency_window_in_ms": idempotency_window_ms,
            "schema_name": schema_name,
@@ -237,7 +211,7 @@
                "Schemaverse": send_schema_failed_msg_to_dls,
            },
            "username": self.username,
-            "tiered_storage_enabled": tiered_storage_enabled
+            "tiered_storage_enabled": tiered_storage_enabled,
        }
        create_station_req_bytes = json.dumps(createStationReq, indent=2).encode(
            "utf-8"
@@ -329,6 +303,9 @@ async def close(self):
            if self.update_configurations_sub is not None:
                await self.update_configurations_sub.unsubscribe()
            self.producers_map.clear()
+            for consumer in self.consumers_map.values():
+                consumer.dls_messages.clear()
+            self.consumers_map.clear()
        except:
            return
@@ -387,45 +364,45 @@ async def producer(
            if create_res["error"] != "":
                raise MemphisError(create_res["error"])

-            station_name_internal = get_internal_name(station_name)
-            self.station_schemaverse_to_dls[station_name_internal] = create_res[
+            internal_station_name = get_internal_name(station_name)
+            self.station_schemaverse_to_dls[internal_station_name] = create_res[
                "schemaverse_to_dls"
            ]
            self.cluster_configurations["send_notification"] = create_res[
                "send_notification"
            ]
            await self.start_listen_for_schema_updates(
-                station_name_internal, create_res["schema_update"]
+                internal_station_name, create_res["schema_update"]
            )

-            if self.schema_updates_data[station_name_internal] != {}:
+            if self.schema_updates_data[internal_station_name] != {}:
                if (
-                    self.schema_updates_data[station_name_internal]["type"]
+                    self.schema_updates_data[internal_station_name]["type"]
                    == "protobuf"
                ):
-                    self.parse_descriptor(station_name_internal)
-                if self.schema_updates_data[station_name_internal]["type"] == "json":
-                    schema = self.schema_updates_data[station_name_internal][
+                    self.parse_descriptor(internal_station_name)
+                if self.schema_updates_data[internal_station_name]["type"] == "json":
+                    schema = self.schema_updates_data[internal_station_name][
                        "active_version"
                    ]["schema_content"]
-                    self.json_schemas[station_name_internal] = json.loads(schema)
                elif (
-                    self.schema_updates_data[station_name_internal]["type"] == "graphql"
+                    self.schema_updates_data[internal_station_name]["type"] == "graphql"
                ):
-                    self.graphql_schemas[station_name_internal] = build_graphql_schema(
-                        self.schema_updates_data[station_name_internal][
+                    self.json_schemas[internal_station_name] = json.loads(schema)
+                    self.graphql_schemas[internal_station_name] = build_graphql_schema(
+                        self.schema_updates_data[internal_station_name][
                            "active_version"
                        ]["schema_content"]
                    )

            producer = Producer(self, producer_name, station_name, real_name)
-            map_key = station_name_internal + "_" + real_name
+            map_key = internal_station_name + "_" + real_name
            self.producers_map[map_key] = producer
            return producer
        except Exception as e:
            raise MemphisError(str(e)) from e

-    async def get_msg_schema_updates(self, station_name_internal, iterable):
+    async def get_msg_schema_updates(self, internal_station_name, iterable):
        async for msg in iterable:
            message = msg.data.decode("utf-8")
            message = json.loads(message)
@@ -433,8 +410,8 @@
            if message["init"] == "":
                data = {}
            else:
                data = message["init"]
-            self.schema_updates_data[station_name_internal] = data
-            self.parse_descriptor(station_name_internal)
+
self.schema_updates_data[internal_station_name] = data + self.parse_descriptor(internal_station_name) def parse_descriptor(self, station_name): try: @@ -507,10 +484,10 @@ async def consumer( station_name (str): station name to consume messages from. consumer_name (str): name for the consumer. consumer_group (str, optional): consumer group name. Defaults to the consumer name. - pull_interval_ms (int, optional): interval in miliseconds between pulls. Defaults to 1000. + pull_interval_ms (int, optional): interval in milliseconds between pulls. Defaults to 1000. batch_size (int, optional): pull batch size. Defaults to 10. - batch_max_time_to_wait_ms (int, optional): max time in miliseconds to wait between pulls. Defaults to 5000. - max_ack_time_ms (int, optional): max time for ack a message in miliseconds, in case a message not acked in this time period the Memphis broker will resend it. Defaults to 30000. + batch_max_time_to_wait_ms (int, optional): max time in milliseconds to wait between pulls. Defaults to 5000. + max_ack_time_ms (int, optional): max time for ack a message in milliseconds, in case a message not acked in this time period the Memphis broker will resend it. Defaults to 30000. max_msg_deliveries (int, optional): max number of message deliveries, by default is 10. generate_random_suffix (bool): false by default, if true concatenate a random suffix to consumer's name start_consume_from_sequence(int, optional): start consuming from a specific sequence. defaults to 1. @@ -521,7 +498,7 @@ async def consumer( try: if not self.is_connection_active: raise MemphisError("Connection is dead") - + real_name = consumer_name.lower() if generate_random_suffix: consumer_name = self.__generateRandomSuffix(consumer_name) cg = consumer_name if not consumer_group else consumer_group @@ -563,7 +540,9 @@ async def consumer( if err_msg != "": raise MemphisError(err_msg) - return Consumer( + internal_station_name = get_internal_name(station_name) + map_key = internal_station_name + "_" + real_name + consumer = Consumer( self, station_name, consumer_name, @@ -576,6 +555,8 @@ async def consumer( start_consume_from_sequence=start_consume_from_sequence, last_messages=last_messages, ) + self.consumers_map[map_key] = consumer + return consumer except Exception as e: raise MemphisError(str(e)) from e @@ -599,13 +580,13 @@ async def produce( ack_wait_sec (int, optional): max time in seconds to wait for an ack from memphis. Defaults to 15. headers (dict, optional): Message headers, defaults to {}. 
async_produce (boolean, optional): produce operation won't wait for broker acknowledgement - msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotency + msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotence Raises: Exception: _description_ """ try: - station_name_internal = get_internal_name(station_name) - map_key = station_name_internal + "_" + producer_name.lower() + internal_station_name = get_internal_name(station_name) + map_key = internal_station_name + "_" + producer_name.lower() producer = None if map_key in self.producers_map: producer = self.producers_map[map_key] @@ -625,556 +606,84 @@ async def produce( except Exception as e: raise MemphisError(str(e)) from e - def is_connected(self): - return self.broker_manager.is_connected - - -class Station: - def __init__(self, connection, name: str): - self.connection = connection - self.name = name.lower() - - async def destroy(self): - """Destroy the station.""" - try: - nameReq = {"station_name": self.name, "username": self.connection.username} - station_name = json.dumps(nameReq, indent=2).encode("utf-8") - res = await self.connection.broker_manager.request( - "$memphis_station_destructions", station_name, timeout=5 - ) - error = res.data.decode("utf-8") - if error != "" and not "not exist" in error: - raise MemphisError(error) - - station_name_internal = get_internal_name(self.name) - sub = self.connection.schema_updates_subs.get(station_name_internal) - task = self.connection.schema_tasks.get(station_name_internal) - if station_name_internal in self.connection.schema_updates_data: - del self.connection.schema_updates_data[station_name_internal] - if station_name_internal in self.connection.schema_updates_subs: - del self.connection.schema_updates_subs[station_name_internal] - if station_name_internal in self.connection.producers_per_station: - del self.connection.producers_per_station[station_name_internal] - if station_name_internal in self.connection.schema_tasks: - del self.connection.schema_tasks[station_name_internal] - if task is not None: - task.cancel() - if sub is not None: - await sub.unsubscribe() - - self.connection.producers_map = { - k: v - for k, v in self.connection.producers_map.items() - if self.name not in k - } - - except Exception as e: - raise MemphisError(str(e)) from e - - -def get_internal_name(name: str) -> str: - name = name.lower() - return name.replace(".", "#") - - -class Producer: - def __init__( - self, connection, producer_name: str, station_name: str, real_name: str - ): - self.connection = connection - self.producer_name = producer_name.lower() - self.station_name = station_name - self.internal_station_name = get_internal_name(self.station_name) - self.loop = asyncio.get_running_loop() - self.real_name = real_name - - async def validate_msg(self, message): - if self.connection.schema_updates_data[self.internal_station_name] != {}: - schema_type = self.connection.schema_updates_data[ - self.internal_station_name - ]["type"] - if schema_type == "protobuf": - message = self.validate_protobuf(message) - return message - elif schema_type == "json": - message = self.validate_json_schema(message) - return message - elif schema_type == "graphql": - message = self.validate_graphql(message) - return message - elif not isinstance(message, bytearray) and not isinstance(message, dict): - raise MemphisSchemaError("Unsupported message type") - else: - if isinstance(message, dict): - message = 
bytearray(json.dumps(message).encode("utf-8")) - return message - - def validate_protobuf(self, message): - proto_msg = self.connection.proto_msgs[self.internal_station_name] - msgToSend = "" - try: - if isinstance(message, bytearray): - msgToSend = bytes(message) - try: - proto_msg.ParseFromString(msgToSend) - proto_msg.SerializeToString() - msgToSend = msgToSend.decode("utf-8") - except Exception as e: - if "parsing message" in str(e): - e = "Invalid message format, expecting protobuf" - raise MemphisSchemaError(str(e)) - return message - elif hasattr(message, "SerializeToString"): - msgToSend = message.SerializeToString() - proto_msg.ParseFromString(msgToSend) - proto_msg.SerializeToString() - return msgToSend - - else: - raise MemphisSchemaError("Unsupported message type") - - except Exception as e: - raise MemphisSchemaError("Schema validation has failed: " + str(e)) - - def validate_json_schema(self, message): - try: - if isinstance(message, bytearray): - try: - message_obj = json.loads(message) - except Exception as e: - raise Exception("Expecting Json format: " + str(e)) - elif isinstance(message, dict): - message_obj = message - message = bytearray(json.dumps(message_obj).encode("utf-8")) - else: - raise Exception("Unsupported message type") - - validate( - instance=message_obj, - schema=self.connection.json_schemas[self.internal_station_name], - ) - return message - except Exception as e: - raise MemphisSchemaError("Schema validation has failed: " + str(e)) - - def validate_graphql(self, message): - try: - if isinstance(message, bytearray): - msg = message.decode("utf-8") - msg = parse_graphql(msg) - elif isinstance(message, str): - msg = parse_graphql(message) - message = message.encode("utf-8") - elif isinstance(message, graphql.language.ast.DocumentNode): - msg = message - message = str(msg.loc.source.body) - message = message.encode("utf-8") - else: - raise MemphisError("Unsupported message type") - validate_res = validate_graphql( - schema=self.connection.graphql_schemas[self.internal_station_name], - document_ast=msg, - ) - if len(validate_res) > 0: - raise Exception("Schema validation has failed: " + str(validate_res)) - return message - except Exception as e: - if "Syntax Error" in str(e): - e = "Invalid message format, expected GraphQL" - raise Exception("Schema validation has failed: " + str(e)) - - def get_dls_msg_id(self, station_name: str, producer_name: str, unix_time: str): - return station_name + "~" + producer_name + "~0~" + unix_time - - async def produce( + async def fetch_messages( self, - message, - ack_wait_sec: int = 15, - headers: Union[Headers, None] = None, - async_produce: bool = False, - msg_id: Union[str, None] = None, - ): - """Produces a message into a station. - Args: - message (bytearray/dict): message to send into the station - bytearray/protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema) - ack_wait_sec (int, optional): max time in seconds to wait for an ack from memphis. Defaults to 15. - headers (dict, optional): Message headers, defaults to {}. 
- async_produce (boolean, optional): produce operation won't wait for broker acknowledgement - msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotency - Raises: - Exception: _description_ - """ - try: - message = await self.validate_msg(message) - - memphis_headers = { - "$memphis_producedBy": self.producer_name, - "$memphis_connectionId": self.connection.connection_id, - } - - if msg_id is not None: - memphis_headers["msg-id"] = msg_id - - if headers is not None: - headers = headers.headers - headers.update(memphis_headers) - else: - headers = memphis_headers - - if async_produce: - try: - self.loop.create_task( - self.connection.broker_connection.publish( - self.internal_station_name + ".final", - message, - timeout=ack_wait_sec, - headers=headers, - ) - ) - await asyncio.sleep(1) - except Exception as e: - raise MemphisError(e) - else: - await self.connection.broker_connection.publish( - self.internal_station_name + ".final", - message, - timeout=ack_wait_sec, - headers=headers, - ) - except Exception as e: - if hasattr(e, "status_code") and e.status_code == "503": - raise MemphisError( - "Produce operation has failed, please check whether Station/Producer are still exist" - ) - else: - if "Schema validation has failed" in str( - e - ) or "Unsupported message type" in str(e): - msgToSend = "" - if isinstance(message, bytearray): - msgToSend = str(message, "utf-8") - elif hasattr(message, "SerializeToString"): - msgToSend = message.SerializeToString().decode("utf-8") - if self.connection.station_schemaverse_to_dls[ - self.internal_station_name - ]: - unix_time = int(time.time()) - id = self.get_dls_msg_id( - self.internal_station_name, - self.producer_name, - str(unix_time), - ) - - memphis_headers = { - "$memphis_producedBy": self.producer_name, - "$memphis_connectionId": self.connection.connection_id, - } - - if headers != {}: - headers = headers.headers - headers.update(memphis_headers) - else: - headers = memphis_headers - - msgToSendEncoded = msgToSend.encode("utf-8") - msgHex = msgToSendEncoded.hex() - buf = { - "_id": id, - "station_name": self.internal_station_name, - "producer": { - "name": self.producer_name, - "connection_id": self.connection.connection_id, - }, - "creation_unix": unix_time, - "message": { - "data": msgHex, - "headers": headers, - }, - } - buf = json.dumps(buf).encode("utf-8") - await self.connection.broker_connection.publish( - "$memphis-" - + self.internal_station_name - + "-dls.schema." 
- + id, - buf, - ) - if self.connection.cluster_configurations.get( - "send_notification" - ): - await self.connection.send_notification( - "Schema validation has failed", - "Station: " - + self.station_name - + "\nProducer: " - + self.producer_name - + "\nError:" - + str(e), - msgToSend, - schemaVFailAlertType, - ) - raise MemphisError(str(e)) from e - - async def destroy(self): - """Destroy the producer.""" - try: - destroyProducerReq = { - "name": self.producer_name, - "station_name": self.station_name, - "username": self.connection.username, - } - - producer_name = json.dumps(destroyProducerReq).encode("utf-8") - res = await self.connection.broker_manager.request( - "$memphis_producer_destructions", producer_name, timeout=5 - ) - error = res.data.decode("utf-8") - if error != "" and not "not exist" in error: - raise Exception(error) - - station_name_internal = get_internal_name(self.station_name) - producer_number = ( - self.connection.producers_per_station.get(station_name_internal) - 1 - ) - self.connection.producers_per_station[ - station_name_internal - ] = producer_number - - if producer_number == 0: - sub = self.connection.schema_updates_subs.get(station_name_internal) - task = self.connection.schema_tasks.get(station_name_internal) - if station_name_internal in self.connection.schema_updates_data: - del self.connection.schema_updates_data[station_name_internal] - if station_name_internal in self.connection.schema_updates_subs: - del self.connection.schema_updates_subs[station_name_internal] - if station_name_internal in self.connection.schema_tasks: - del self.connection.schema_tasks[station_name_internal] - if task is not None: - task.cancel() - if sub is not None: - await sub.unsubscribe() - - map_key = station_name_internal + "_" + self.real_name - del self.connection.producers_map[map_key] - - except Exception as e: - raise Exception(e) - - -def default_error_handler(e): - print("ping exception raised", e) - - -class Consumer: - def __init__( - self, - connection, station_name: str, - consumer_name, - consumer_group, - pull_interval_ms: int, - batch_size: int, - batch_max_time_to_wait_ms: int, - max_ack_time_ms: int, + consumer_name: str, + consumer_group: str = "", + batch_size: int = 10, + batch_max_time_to_wait_ms: int = 5000, + max_ack_time_ms: int = 30000, max_msg_deliveries: int = 10, - error_callback=None, + generate_random_suffix: bool = False, start_consume_from_sequence: int = 1, last_messages: int = -1, ): - self.connection = connection - self.station_name = station_name.lower() - self.consumer_name = consumer_name.lower() - self.consumer_group = consumer_group.lower() - self.pull_interval_ms = pull_interval_ms - self.batch_size = batch_size - self.batch_max_time_to_wait_ms = batch_max_time_to_wait_ms - self.max_ack_time_ms = max_ack_time_ms - self.max_msg_deliveries = max_msg_deliveries - self.ping_consumer_invterval_ms = 30000 - if error_callback is None: - error_callback = default_error_handler - self.t_ping = asyncio.create_task(self.__ping_consumer(error_callback)) - self.start_consume_from_sequence = start_consume_from_sequence - - self.last_messages = last_messages - self.context = {} - - def set_context(self, context): - """Set a context (dict) that will be passed to each message handler call.""" - self.context = context - - def consume(self, callback): - """Consume events.""" - self.t_consume = asyncio.create_task(self.__consume(callback)) - self.t_dls = asyncio.create_task(self.__consume_dls(callback)) - - async def __consume(self, callback): - subject = 
get_internal_name(self.station_name) - consumer_group = get_internal_name(self.consumer_group) - self.psub = await self.connection.broker_connection.pull_subscribe( - subject + ".final", durable=consumer_group - ) - while True: - if self.connection.is_connection_active and self.pull_interval_ms: - try: - memphis_messages = [] - msgs = await self.psub.fetch(self.batch_size) - for msg in msgs: - memphis_messages.append( - Message(msg, self.connection, self.consumer_group) - ) - await callback(memphis_messages, None, self.context) - await asyncio.sleep(self.pull_interval_ms / 1000) - - except asyncio.TimeoutError: - await callback( - [], MemphisError("Memphis: TimeoutError"), self.context - ) - continue - except Exception as e: - if self.connection.is_connection_active: - raise MemphisError(str(e)) from e - else: - return - else: - break - - async def __consume_dls(self, callback): - subject = get_internal_name(self.station_name) - consumer_group = get_internal_name(self.consumer_group) + """Consume a batch of messages. + Args:. + station_name (str): station name to consume messages from. + consumer_name (str): name for the consumer. + consumer_group (str, optional): consumer group name. Defaults to the consumer name. + batch_size (int, optional): pull batch size. Defaults to 10. + batch_max_time_to_wait_ms (int, optional): max time in miliseconds to wait between pulls. Defaults to 5000. + max_ack_time_ms (int, optional): max time for ack a message in miliseconds, in case a message not acked in this time period the Memphis broker will resend it. Defaults to 30000. + max_msg_deliveries (int, optional): max number of message deliveries, by default is 10. + generate_random_suffix (bool): false by default, if true concatenate a random suffix to consumer's name + start_consume_from_sequence(int, optional): start consuming from a specific sequence. defaults to 1. + last_messages: consume the last N messages, defaults to -1 (all messages in the station). 
+ Returns: + list: Message + """ try: - subscription_name = "$memphis_dls_" + subject + "_" + consumer_group - self.consumer_dls = await self.connection.broker_manager.subscribe( - subscription_name, subscription_name - ) - async for msg in self.consumer_dls.messages: - await callback( - [Message(msg, self.connection, self.consumer_group)], - None, - self.context, - ) - except Exception as e: - print("dls", e) - await callback([], MemphisError(str(e)), self.context) - return - - async def __ping_consumer(self, callback): - while True: - try: - await asyncio.sleep(self.ping_consumer_invterval_ms / 1000) - consumer_group = get_internal_name(self.consumer_group) - await self.connection.broker_connection.consumer_info( - self.station_name, consumer_group, timeout=30 + consumer = None + if not self.is_connection_active: + raise MemphisError("Cant fetch messages without being connected!") + internal_station_name = get_internal_name(station_name) + consumer_map_key = internal_station_name + "_" + consumer_name.lower() + if consumer_map_key in self.consumers_map: + consumer = self.consumers_map[consumer_map_key] + else: + consumer = await self.consumer( + station_name=station_name, + consumer_name=consumer_name, + consumer_group=consumer_group, + batch_size=batch_size, + batch_max_time_to_wait_ms=batch_max_time_to_wait_ms, + max_ack_time_ms=max_ack_time_ms, + max_msg_deliveries=max_msg_deliveries, + generate_random_suffix=generate_random_suffix, + start_consume_from_sequence=start_consume_from_sequence, + last_messages=last_messages, ) - - except Exception as e: - callback(e) - - async def destroy(self): - """Destroy the consumer.""" - if self.t_consume is not None: - self.t_consume.cancel() - if self.t_dls is not None: - self.t_dls.cancel() - if self.t_ping is not None: - self.t_ping.cancel() - self.pull_interval_ms = None - try: - destroyConsumerReq = { - "name": self.consumer_name, - "station_name": self.station_name, - "username": self.connection.username, - } - consumer_name = json.dumps(destroyConsumerReq, indent=2).encode("utf-8") - res = await self.connection.broker_manager.request( - "$memphis_consumer_destructions", consumer_name, timeout=5 - ) - error = res.data.decode("utf-8") - if error != "" and not "not exist" in error: - raise MemphisError(error) + messages = await consumer.fetch(batch_size) + if messages == None: + messages = [] + return messages except Exception as e: raise MemphisError(str(e)) from e - -class Message: - def __init__(self, message, connection, cg_name): - self.message = message - self.connection = connection - self.cg_name = cg_name - - async def ack(self): - """Ack a message is done processing.""" + def is_connected(self): + return self.broker_manager.is_connected + + def unset_cached_producer_station(self, station_name): try: - await self.message.ack() + internal_station_name = get_internal_name(station_name) + for key in list(self.producers_map): + producer = self.producers_map[key] + if producer.internal_station_name == internal_station_name: + del self.producers_map[key] except Exception as e: - if ( - "$memphis_pm_id" - in self.message.headers & "$memphis_pm_sequence" - in self.message.headers - ): - try: - msg = { - "id": self.message.headers["$memphis_pm_id"], - "sequence": self.message.headers["$memphis_pm_sequence"], - } - msgToAck = json.dumps(msg).encode("utf-8") - await self.connection.broker_manager.publish( - "$memphis_pm_acks", msgToAck - ) - except Exception as er: - raise MemphisConnectError(str(er)) from er - else: - raise 
MemphisConnectError(str(e)) from e - return + raise e + - def get_data(self): - """Receive the message.""" + def unset_cached_consumer_station(self, station_name): try: - return bytearray(self.message.data) - except: - return - - def get_headers(self): - """Receive the headers.""" - try: - return self.message.headers - except: - return - - def get_sequence_number(self): - """Get message sequence number.""" - try: - return self.message.metadata.sequence.stream - except: - return - - -def random_bytes(amount: int) -> str: - lst = [random.choice("0123456789abcdef") for n in range(amount)] - s = "".join(lst) - return s - - -class MemphisError(Exception): - def __init__(self, message): - message = message.replace("nats", "memphis") - message = message.replace("NATS", "memphis") - message = message.replace("Nats", "memphis") - message = message.replace("NatsError", "MemphisError") - self.message = message - if message.startswith("memphis:"): - super().__init__(self.message) - else: - super().__init__("memphis: " + self.message) - - -class MemphisConnectError(MemphisError): - pass - - -class MemphisSchemaError(MemphisError): - pass - + internal_station_name = get_internal_name(station_name) + for key in list(self.consumers_map): + consumer = self.consumers_map[key] + consumer_station_name_internal = get_internal_name(consumer.station_name) + if consumer_station_name_internal == internal_station_name: + del self.consumers_map[key] + except Exception as e: + raise e -class MemphisHeaderError(MemphisError): - pass diff --git a/memphis/message.py b/memphis/message.py new file mode 100644 index 0000000..3c8f30d --- /dev/null +++ b/memphis/message.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import json + +from memphis.exceptions import MemphisConnectError + + +class Message: + def __init__(self, message, connection, cg_name): + self.message = message + self.connection = connection + self.cg_name = cg_name + + async def ack(self): + """Ack a message is done processing.""" + try: + await self.message.ack() + except Exception as e: + if ( + "$memphis_pm_id" in self.message.headers + and "$memphis_pm_sequence" in self.message.headers + ): + try: + msg = { + "id": self.message.headers["$memphis_pm_id"], + "sequence": self.message.headers["$memphis_pm_sequence"], + } + msgToAck = json.dumps(msg).encode("utf-8") + await self.connection.broker_manager.publish( + "$memphis_pm_acks", msgToAck + ) + except Exception as er: + raise MemphisConnectError(str(er)) from er + else: + raise MemphisConnectError(str(e)) from e + return + + def get_data(self): + """Receive the message.""" + try: + return bytearray(self.message.data) + except: + return + + def get_headers(self): + """Receive the headers.""" + try: + return self.message.headers + except: + return + + def get_sequence_number(self): + """Get message sequence number.""" + try: + return self.message.metadata.sequence.stream + except: + return diff --git a/memphis/producer.py b/memphis/producer.py new file mode 100644 index 0000000..5c2e4e1 --- /dev/null +++ b/memphis/producer.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import asyncio +import json +import time +from typing import Union + +import graphql +from graphql import parse as parse_graphql +from graphql import validate as validate_graphql +from jsonschema import validate +from memphis.exceptions import MemphisError, MemphisSchemaError +from memphis.headers import Headers +from memphis.utils import get_internal_name + +schemaVFailAlertType = "schema_validation_fail_alert" + + 
+class Producer: + def __init__( + self, connection, producer_name: str, station_name: str, real_name: str + ): + self.connection = connection + self.producer_name = producer_name.lower() + self.station_name = station_name + self.internal_station_name = get_internal_name(self.station_name) + self.loop = asyncio.get_running_loop() + self.real_name = real_name + + async def validate_msg(self, message): + if self.connection.schema_updates_data[self.internal_station_name] != {}: + schema_type = self.connection.schema_updates_data[ + self.internal_station_name + ]["type"] + if schema_type == "protobuf": + message = self.validate_protobuf(message) + return message + elif schema_type == "json": + message = self.validate_json_schema(message) + return message + elif schema_type == "graphql": + message = self.validate_graphql(message) + return message + elif not isinstance(message, bytearray) and not isinstance(message, dict): + raise MemphisSchemaError("Unsupported message type") + else: + if isinstance(message, dict): + message = bytearray(json.dumps(message).encode("utf-8")) + return message + + def validate_protobuf(self, message): + proto_msg = self.connection.proto_msgs[self.internal_station_name] + msgToSend = "" + try: + if isinstance(message, bytearray): + msgToSend = bytes(message) + try: + proto_msg.ParseFromString(msgToSend) + proto_msg.SerializeToString() + msgToSend = msgToSend.decode("utf-8") + except Exception as e: + if "parsing message" in str(e): + e = "Invalid message format, expecting protobuf" + raise MemphisSchemaError(str(e)) + return message + elif hasattr(message, "SerializeToString"): + msgToSend = message.SerializeToString() + proto_msg.ParseFromString(msgToSend) + proto_msg.SerializeToString() + return msgToSend + + else: + raise MemphisSchemaError("Unsupported message type") + + except Exception as e: + raise MemphisSchemaError("Schema validation has failed: " + str(e)) + + def validate_json_schema(self, message): + try: + if isinstance(message, bytearray): + try: + message_obj = json.loads(message) + except Exception as e: + raise Exception("Expecting Json format: " + str(e)) + elif isinstance(message, dict): + message_obj = message + message = bytearray(json.dumps(message_obj).encode("utf-8")) + else: + raise Exception("Unsupported message type") + + validate( + instance=message_obj, + schema=self.connection.json_schemas[self.internal_station_name], + ) + return message + except Exception as e: + raise MemphisSchemaError("Schema validation has failed: " + str(e)) + + def validate_graphql(self, message): + try: + if isinstance(message, bytearray): + msg = message.decode("utf-8") + msg = parse_graphql(msg) + elif isinstance(message, str): + msg = parse_graphql(message) + message = message.encode("utf-8") + elif isinstance(message, graphql.language.ast.DocumentNode): + msg = message + message = str(msg.loc.source.body) + message = message.encode("utf-8") + else: + raise MemphisError("Unsupported message type") + validate_res = validate_graphql( + schema=self.connection.graphql_schemas[self.internal_station_name], + document_ast=msg, + ) + if len(validate_res) > 0: + raise Exception("Schema validation has failed: " + str(validate_res)) + return message + except Exception as e: + if "Syntax Error" in str(e): + e = "Invalid message format, expected GraphQL" + raise Exception("Schema validation has failed: " + str(e)) + + def get_dls_msg_id(self, station_name: str, producer_name: str, unix_time: str): + return station_name + "~" + producer_name + "~0~" + unix_time + + async 
def produce( + self, + message, + ack_wait_sec: int = 15, + headers: Union[Headers, None] = None, + async_produce: bool = False, + msg_id: Union[str, None] = None, + ): + """Produces a message into a station. + Args: + message (bytearray/dict): message to send into the station - bytearray/protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema) + ack_wait_sec (int, optional): max time in seconds to wait for an ack from memphis. Defaults to 15. + headers (dict, optional): Message headers, defaults to {}. + async_produce (boolean, optional): produce operation won't wait for broker acknowledgement + msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotency + Raises: + Exception: _description_ + """ + try: + message = await self.validate_msg(message) + + memphis_headers = { + "$memphis_producedBy": self.producer_name, + "$memphis_connectionId": self.connection.connection_id, + } + + if msg_id is not None: + memphis_headers["msg-id"] = msg_id + + if headers is not None: + headers = headers.headers + headers.update(memphis_headers) + else: + headers = memphis_headers + + if async_produce: + try: + self.loop.create_task( + self.connection.broker_connection.publish( + self.internal_station_name + ".final", + message, + timeout=ack_wait_sec, + headers=headers, + ) + ) + await asyncio.sleep(1) # TODO - check why we need sleep in here + except Exception as e: + raise MemphisError(e) + else: + await self.connection.broker_connection.publish( + self.internal_station_name + ".final", + message, + timeout=ack_wait_sec, + headers=headers, + ) + except Exception as e: + if hasattr(e, "status_code") and e.status_code == "503": + raise MemphisError( + "Produce operation has failed, please check whether Station/Producer are still exist" + ) + else: + if "Schema validation has failed" in str( + e + ) or "Unsupported message type" in str(e): + msgToSend = "" + if isinstance(message, bytearray): + msgToSend = str(message, "utf-8") + elif hasattr(message, "SerializeToString"): + msgToSend = message.SerializeToString().decode("utf-8") + if self.connection.station_schemaverse_to_dls[ + self.internal_station_name + ]: + unix_time = int(time.time()) + id = self.get_dls_msg_id( + self.internal_station_name, + self.producer_name, + str(unix_time), + ) + + memphis_headers = { + "$memphis_producedBy": self.producer_name, + "$memphis_connectionId": self.connection.connection_id, + } + + if headers != {}: + headers = headers.headers + headers.update(memphis_headers) + else: + headers = memphis_headers + + msgToSendEncoded = msgToSend.encode("utf-8") + msgHex = msgToSendEncoded.hex() + buf = { + "_id": id, + "station_name": self.internal_station_name, + "producer": { + "name": self.producer_name, + "connection_id": self.connection.connection_id, + }, + "creation_unix": unix_time, + "message": { + "data": msgHex, + "headers": headers, + }, + } + buf = json.dumps(buf).encode("utf-8") + await self.connection.broker_connection.publish( + "$memphis-" + + self.internal_station_name + + "-dls.schema." 
+ + id, + buf, + ) + if self.connection.cluster_configurations.get( + "send_notification" + ): + await self.connection.send_notification( + "Schema validation has failed", + "Station: " + + self.station_name + + "\nProducer: " + + self.producer_name + + "\nError:" + + str(e), + msgToSend, + schemaVFailAlertType, + ) + raise MemphisError(str(e)) from e + + async def destroy(self): + """Destroy the producer.""" + try: + destroyProducerReq = { + "name": self.producer_name, + "station_name": self.station_name, + "username": self.connection.username, + } + + producer_name = json.dumps(destroyProducerReq).encode("utf-8") + res = await self.connection.broker_manager.request( + "$memphis_producer_destructions", producer_name, timeout=5 + ) + error = res.data.decode("utf-8") + if error != "" and not "not exist" in error: + raise Exception(error) + + internal_station_name = get_internal_name(self.station_name) + producer_number = ( + self.connection.producers_per_station.get(internal_station_name) - 1 + ) + self.connection.producers_per_station[ + internal_station_name + ] = producer_number + + if producer_number == 0: + sub = self.connection.schema_updates_subs.get(internal_station_name) + task = self.connection.schema_tasks.get(internal_station_name) + if internal_station_name in self.connection.schema_updates_data: + del self.connection.schema_updates_data[internal_station_name] + if internal_station_name in self.connection.schema_updates_subs: + del self.connection.schema_updates_subs[internal_station_name] + if internal_station_name in self.connection.schema_tasks: + del self.connection.schema_tasks[internal_station_name] + if task is not None: + task.cancel() + if sub is not None: + await sub.unsubscribe() + + map_key = internal_station_name + "_" + self.real_name + del self.connection.producers_map[map_key] + + except Exception as e: + raise Exception(e) diff --git a/memphis/station.py b/memphis/station.py new file mode 100644 index 0000000..05aac6d --- /dev/null +++ b/memphis/station.py @@ -0,0 +1,53 @@ +import json + +from memphis.exceptions import MemphisError +from memphis.utils import get_internal_name + + +class Station: + def __init__(self, connection, name: str): + self.connection = connection + self.name = name.lower() + + async def destroy(self): + """Destroy the station.""" + try: + nameReq = {"station_name": self.name, "username": self.connection.username} + station_name = json.dumps(nameReq, indent=2).encode("utf-8") + res = await self.connection.broker_manager.request( + "$memphis_station_destructions", station_name, timeout=5 + ) + error = res.data.decode("utf-8") + if error != "" and not "not exist" in error: + raise MemphisError(error) + + internal_station_name = get_internal_name(self.name) + sub = self.connection.schema_updates_subs.get(internal_station_name) + task = self.connection.schema_tasks.get(internal_station_name) + if internal_station_name in self.connection.schema_updates_data: + del self.connection.schema_updates_data[internal_station_name] + if internal_station_name in self.connection.schema_updates_subs: + del self.connection.schema_updates_subs[internal_station_name] + if internal_station_name in self.connection.producers_per_station: + del self.connection.producers_per_station[internal_station_name] + if internal_station_name in self.connection.schema_tasks: + del self.connection.schema_tasks[internal_station_name] + if task is not None: + task.cancel() + if sub is not None: + await sub.unsubscribe() + + self.connection.producers_map = { + k: v + for k, v in 
self.connection.producers_map.items() + if self.name not in k + } + + self.connection.consumers_map = { + k: v + for k, v in self.connection.consumers_map.items() + if self.name not in k + } + + except Exception as e: + raise MemphisError(str(e)) from e diff --git a/memphis/storage_types.py b/memphis/storage_types.py deleted file mode 100644 index 8c1d570..0000000 --- a/memphis/storage_types.py +++ /dev/null @@ -1,16 +0,0 @@ -# Credit for The NATS.IO Authors -# Copyright 2021-2022 The Memphis Authors -# Licensed under the Apache License, Version 2.0 (the “License”); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http:#www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an “AS IS” BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -DISK = "file" -MEMORY = "memory" diff --git a/memphis/retention_types.py b/memphis/types.py similarity index 75% rename from memphis/retention_types.py rename to memphis/types.py index 6fb09ae..cfd6a70 100644 --- a/memphis/retention_types.py +++ b/memphis/types.py @@ -12,6 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -MAX_MESSAGE_AGE_SECONDS = "message_age_sec" -MESSAGES = "messages" -BYTES = "bytes" +from enum import Enum + + +class Retention(Enum): + MAX_MESSAGE_AGE_SECONDS = "message_age_sec" + MESSAGES = "messages" + BYTES = "bytes" + + +class Storage(Enum): + DISK = "file" + MEMORY = "memory" diff --git a/memphis/utils.py b/memphis/utils.py new file mode 100644 index 0000000..10512e9 --- /dev/null +++ b/memphis/utils.py @@ -0,0 +1,32 @@ +import random +from threading import Timer +from typing import Callable + + +class set_interval: + def __init__(self, func: Callable, sec: int): + def func_wrapper(): + self.t = Timer(sec, func_wrapper) + self.t.start() + func() + + self.t = Timer(sec, func_wrapper) + self.t.start() + + def cancel(self): + self.t.cancel() + + +def default_error_handler(e): + print("ping exception raised", e) + + +def get_internal_name(name: str) -> str: + name = name.lower() + return name.replace(".", "#") + + +def random_bytes(amount: int) -> str: + lst = [random.choice("0123456789abcdef") for n in range(amount)] + s = "".join(lst) + return s diff --git a/setup.py b/setup.py index d53182d..5b15795 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name="memphis-py", packages=["memphis"], - version="0.3.2", + version="0.3.3", license="Apache-2.0", description="A powerful messaging platform for modern developers", long_description=long_description, @@ -17,7 +17,7 @@ author="Memphis.dev", author_email="team@memphis.dev", url="https://github.com/memphisdev/memphis.py", - download_url="https://github.com/memphisdev/memphis.py/archive/refs/tags/v0.3.2.tar.gz", + download_url="https://github.com/memphisdev/memphis.py/archive/refs/tags/v0.3.3.tar.gz", keywords=["message broker", "devtool", "streaming", "data"], install_requires=["asyncio", "nats-py", "protobuf", "jsonschema", "graphql-core"], classifiers=[ diff --git a/version.conf b/version.conf index 9fc80f9..87a0871 100644 --- a/version.conf +++ b/version.conf @@ -1 +1 @@ -0.3.2 \ No newline at end of file +0.3.3 \ No newline at end of file
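Taken together, the user-visible headline of v0.3.3 is single-batch consumption. A minimal end-to-end sketch of the two new entry points, assuming a reachable broker; the credentials, station, and consumer names are placeholders:

```python
import asyncio

from memphis import Memphis, MemphisError

async def main():
    memphis = Memphis()
    try:
        # Placeholder connection details.
        await memphis.connect(host="localhost", username="root", connection_token="<token>")

        # One-shot fetch through the connection; a consumer is created and cached on first use.
        msgs = await memphis.fetch_messages(
            station_name="test_station_py",
            consumer_name="batch_consumer",
            batch_size=10,
        )
        for msg in msgs:
            print(msg.get_data())
            await msg.ack()

        # Or fetch repeatedly through an explicitly created consumer.
        consumer = await memphis.consumer(
            station_name="test_station_py",
            consumer_name="batch_consumer_2",
        )
        msgs = await consumer.fetch(batch_size=5)
    except MemphisError as err:
        print(err)
    finally:
        await memphis.close()

if __name__ == "__main__":
    asyncio.run(main())
```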