From 06fb01c069fdd940689fe5ed0b53232b753ea9b7 Mon Sep 17 00:00:00 2001
From: Avitaltrifsik <107035359+Avitaltrifsik@users.noreply.github.com>
Date: Thu, 23 Feb 2023 10:51:00 +0200
Subject: [PATCH 01/12] Update README.md (#115)
---
README.md | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/README.md b/README.md
index 8030f14..d08e61f 100644
--- a/README.md
+++ b/README.md
@@ -38,12 +38,12 @@
-**[Memphis](https://memphis.dev)** is a next-generation message broker.
+**[Memphis](https://memphis.dev)** is a next-generation alternative to traditional message brokers.
A simple, robust, and durable cloud-native message broker wrapped with
-an entire ecosystem that enables fast and reliable development of next-generation event-driven use cases.
-Memphis enables building modern applications that require large volumes of streamed and enriched data,
-modern protocols, zero ops, rapid development, extreme cost reduction,
-and a significantly lower amount of dev time for data-oriented developers and data engineers.
+an entire ecosystem that enables cost-effective, fast, and reliable development of modern queue-based use cases.
+Memphis enables the building of modern queue-based applications that require
+large volumes of streamed and enriched data, modern protocols, zero ops, rapid development,
+extreme cost reduction, and significantly less dev time for data-oriented developers and data engineers.
## Installation
@@ -328,4 +328,4 @@ consumer.destroy()
```python
memphis.is_connected()
-```
\ No newline at end of file
+```
From de07a2ffd953b70056884e8e475cc853354dbdd5 Mon Sep 17 00:00:00 2001
From: Shay Bratslavsky
Date: Tue, 28 Feb 2023 09:15:03 +0200
Subject: [PATCH 02/12] Fetch messages (#116)
* add fetch_messages function
* try fetch dls
* fetch messages + fetch dls messages
* bug fix + erase on close/destroy
* enable change batch size on same consumer in fetch
* add readme for consumer.fetch()
* allow change batch size in existing consumer
---
README.md | 22 +++++
memphis/memphis.py | 219 ++++++++++++++++++++++++++++++++++-----------
2 files changed, 187 insertions(+), 54 deletions(-)
diff --git a/README.md b/README.md
index d08e61f..0b14e1e 100644
--- a/README.md
+++ b/README.md
@@ -295,6 +295,28 @@ async def msg_handler(msgs, error, context):
consumer.consume(msg_handler)
```
+### Fetch a single batch of messages
+```python
+msgs = await memphis.fetch_messages(
+ station_name="",
+ consumer_name="",
+ consumer_group="", # defaults to the consumer name
+ batch_size=10, # defaults to 10
+ batch_max_time_to_wait_ms=5000, # defaults to 5000
+ max_ack_time_ms=30000, # defaults to 30000
+ max_msg_deliveries=10, # defaults to 10
+    generate_random_suffix=False,
+    start_consume_from_sequence=1, # start consuming from a specific sequence. defaults to 1
+    last_messages=-1 # consume the last N messages, defaults to -1 (all messages in the station)
+)
+```
+
+### Fetch a single batch of messages after creating a consumer
+```python
+msgs = await consumer.fetch(batch_size=10) # defaults to 10
+```
+
+
### Acknowledge a message
Acknowledging a message indicates to the Memphis server that it should not re-send the same message to the same consumer / consumer group
diff --git a/memphis/memphis.py b/memphis/memphis.py
index c8fee2a..831bc6f 100644
--- a/memphis/memphis.py
+++ b/memphis/memphis.py
@@ -83,6 +83,7 @@ def __init__(self):
self.update_configurations_sub = {}
self.configuration_tasks = {}
self.producers_map = dict()
+ self.consumers_map = dict()
async def get_msgs_update_configurations(self, iterable: Iterable):
try:
@@ -329,6 +330,9 @@ async def close(self):
if self.update_configurations_sub is not None:
await self.update_configurations_sub.unsubscribe()
self.producers_map.clear()
+ for consumer in self.consumers_map:
+ consumer.dls_messages.clear()
+ self.consumers_map.clear()
except:
return
@@ -387,45 +391,45 @@ async def producer(
if create_res["error"] != "":
raise MemphisError(create_res["error"])
- station_name_internal = get_internal_name(station_name)
- self.station_schemaverse_to_dls[station_name_internal] = create_res[
+ internal_station_name = get_internal_name(station_name)
+ self.station_schemaverse_to_dls[internal_station_name] = create_res[
"schemaverse_to_dls"
]
self.cluster_configurations["send_notification"] = create_res[
"send_notification"
]
await self.start_listen_for_schema_updates(
- station_name_internal, create_res["schema_update"]
+ internal_station_name, create_res["schema_update"]
)
- if self.schema_updates_data[station_name_internal] != {}:
+ if self.schema_updates_data[internal_station_name] != {}:
if (
- self.schema_updates_data[station_name_internal]["type"]
+ self.schema_updates_data[internal_station_name]["type"]
== "protobuf"
):
- self.parse_descriptor(station_name_internal)
- if self.schema_updates_data[station_name_internal]["type"] == "json":
- schema = self.schema_updates_data[station_name_internal][
+ self.parse_descriptor(internal_station_name)
+ if self.schema_updates_data[internal_station_name]["type"] == "json":
+ schema = self.schema_updates_data[internal_station_name][
"active_version"
]["schema_content"]
- self.json_schemas[station_name_internal] = json.loads(schema)
+ self.json_schemas[internal_station_name] = json.loads(schema)
elif (
- self.schema_updates_data[station_name_internal]["type"] == "graphql"
+ self.schema_updates_data[internal_station_name]["type"] == "graphql"
):
- self.graphql_schemas[station_name_internal] = build_graphql_schema(
- self.schema_updates_data[station_name_internal][
+ self.graphql_schemas[internal_station_name] = build_graphql_schema(
+ self.schema_updates_data[internal_station_name][
"active_version"
]["schema_content"]
)
producer = Producer(self, producer_name, station_name, real_name)
- map_key = station_name_internal + "_" + real_name
+ map_key = internal_station_name + "_" + real_name
self.producers_map[map_key] = producer
return producer
except Exception as e:
raise MemphisError(str(e)) from e
- async def get_msg_schema_updates(self, station_name_internal, iterable):
+ async def get_msg_schema_updates(self, internal_station_name, iterable):
async for msg in iterable:
message = msg.data.decode("utf-8")
message = json.loads(message)
@@ -433,8 +437,8 @@ async def get_msg_schema_updates(self, station_name_internal, iterable):
data = {}
else:
data = message["init"]
- self.schema_updates_data[station_name_internal] = data
- self.parse_descriptor(station_name_internal)
+ self.schema_updates_data[internal_station_name] = data
+ self.parse_descriptor(internal_station_name)
def parse_descriptor(self, station_name):
try:
@@ -521,7 +525,7 @@ async def consumer(
try:
if not self.is_connection_active:
raise MemphisError("Connection is dead")
-
+ real_name = consumer_name.lower()
if generate_random_suffix:
consumer_name = self.__generateRandomSuffix(consumer_name)
cg = consumer_name if not consumer_group else consumer_group
@@ -563,7 +567,9 @@ async def consumer(
if err_msg != "":
raise MemphisError(err_msg)
- return Consumer(
+ internal_station_name = get_internal_name(station_name)
+ map_key = internal_station_name + "_" + real_name
+ consumer = Consumer(
self,
station_name,
consumer_name,
@@ -576,6 +582,8 @@ async def consumer(
start_consume_from_sequence=start_consume_from_sequence,
last_messages=last_messages,
)
+ self.consumers_map[map_key] = consumer
+ return consumer
except Exception as e:
raise MemphisError(str(e)) from e
@@ -604,8 +612,8 @@ async def produce(
Exception: _description_
"""
try:
- station_name_internal = get_internal_name(station_name)
- map_key = station_name_internal + "_" + producer_name.lower()
+ internal_station_name = get_internal_name(station_name)
+ map_key = internal_station_name + "_" + producer_name.lower()
producer = None
if map_key in self.producers_map:
producer = self.producers_map[map_key]
@@ -625,6 +633,52 @@ async def produce(
except Exception as e:
raise MemphisError(str(e)) from e
+
+ async def fetch_messages(
+ self,
+ station_name: str,
+ consumer_name: str,
+ consumer_group: str = "",
+ batch_size: int = 10,
+ batch_max_time_to_wait_ms: int = 5000,
+ max_ack_time_ms: int = 30000,
+ max_msg_deliveries: int = 10,
+ generate_random_suffix: bool = False,
+ start_consume_from_sequence: int = 1,
+ last_messages: int = -1
+ ):
+ """Consume a batch of messages.
+        Args:
+            station_name (str): station name to consume messages from.
+            consumer_name (str): name for the consumer.
+            consumer_group (str, optional): consumer group name. Defaults to the consumer name.
+            batch_size (int, optional): pull batch size. Defaults to 10.
+            batch_max_time_to_wait_ms (int, optional): max time in milliseconds to wait between pulls. Defaults to 5000.
+            max_ack_time_ms (int, optional): max time in milliseconds to ack a message; if a message is not acked within this period, the Memphis broker will resend it. Defaults to 30000.
+            max_msg_deliveries (int, optional): max number of message deliveries. Defaults to 10.
+            generate_random_suffix (bool, optional): False by default; if True, concatenates a random suffix to the consumer's name.
+            start_consume_from_sequence (int, optional): start consuming from a specific sequence. Defaults to 1.
+            last_messages (int, optional): consume the last N messages. Defaults to -1 (all messages in the station).
+        Returns:
+            list: a list of Message objects.
+ """
+ try:
+ consumer = None
+ if not self.is_connection_active:
+                raise MemphisError("Can't fetch messages without being connected!")
+ internal_station_name = get_internal_name(station_name)
+ consumer_map_key = internal_station_name + "_" + consumer_name.lower()
+ if consumer_map_key in self.consumers_map:
+ consumer = self.consumers_map[consumer_map_key]
+ else:
+ consumer = await self.consumer(station_name=station_name, consumer_name=consumer_name, consumer_group=consumer_group, batch_size=batch_size, batch_max_time_to_wait_ms=batch_max_time_to_wait_ms, max_ack_time_ms=max_ack_time_ms, max_msg_deliveries=max_msg_deliveries, generate_random_suffix=generate_random_suffix, start_consume_from_sequence=start_consume_from_sequence, last_messages=last_messages)
+ messages = await consumer.fetch(batch_size)
+ if messages == None:
+ messages = []
+ return messages
+ except Exception as e:
+ raise MemphisError(str(e)) from e
+
def is_connected(self):
return self.broker_manager.is_connected
@@ -646,17 +700,17 @@ async def destroy(self):
if error != "" and not "not exist" in error:
raise MemphisError(error)
- station_name_internal = get_internal_name(self.name)
- sub = self.connection.schema_updates_subs.get(station_name_internal)
- task = self.connection.schema_tasks.get(station_name_internal)
- if station_name_internal in self.connection.schema_updates_data:
- del self.connection.schema_updates_data[station_name_internal]
- if station_name_internal in self.connection.schema_updates_subs:
- del self.connection.schema_updates_subs[station_name_internal]
- if station_name_internal in self.connection.producers_per_station:
- del self.connection.producers_per_station[station_name_internal]
- if station_name_internal in self.connection.schema_tasks:
- del self.connection.schema_tasks[station_name_internal]
+ internal_station_name = get_internal_name(self.name)
+ sub = self.connection.schema_updates_subs.get(internal_station_name)
+ task = self.connection.schema_tasks.get(internal_station_name)
+ if internal_station_name in self.connection.schema_updates_data:
+ del self.connection.schema_updates_data[internal_station_name]
+ if internal_station_name in self.connection.schema_updates_subs:
+ del self.connection.schema_updates_subs[internal_station_name]
+ if internal_station_name in self.connection.producers_per_station:
+ del self.connection.producers_per_station[internal_station_name]
+ if internal_station_name in self.connection.schema_tasks:
+ del self.connection.schema_tasks[internal_station_name]
if task is not None:
task.cancel()
if sub is not None:
@@ -668,6 +722,12 @@ async def destroy(self):
if self.name not in k
}
+ self.connection.consumers_map = {
+ k: v
+ for k, v in self.connection.consumers_map.items()
+ if self.name not in k
+ }
+
except Exception as e:
raise MemphisError(str(e)) from e
@@ -932,29 +992,29 @@ async def destroy(self):
if error != "" and not "not exist" in error:
raise Exception(error)
- station_name_internal = get_internal_name(self.station_name)
+ internal_station_name = get_internal_name(self.station_name)
producer_number = (
- self.connection.producers_per_station.get(station_name_internal) - 1
+ self.connection.producers_per_station.get(internal_station_name) - 1
)
self.connection.producers_per_station[
- station_name_internal
+ internal_station_name
] = producer_number
if producer_number == 0:
- sub = self.connection.schema_updates_subs.get(station_name_internal)
- task = self.connection.schema_tasks.get(station_name_internal)
- if station_name_internal in self.connection.schema_updates_data:
- del self.connection.schema_updates_data[station_name_internal]
- if station_name_internal in self.connection.schema_updates_subs:
- del self.connection.schema_updates_subs[station_name_internal]
- if station_name_internal in self.connection.schema_tasks:
- del self.connection.schema_tasks[station_name_internal]
+ sub = self.connection.schema_updates_subs.get(internal_station_name)
+ task = self.connection.schema_tasks.get(internal_station_name)
+ if internal_station_name in self.connection.schema_updates_data:
+ del self.connection.schema_updates_data[internal_station_name]
+ if internal_station_name in self.connection.schema_updates_subs:
+ del self.connection.schema_updates_subs[internal_station_name]
+ if internal_station_name in self.connection.schema_tasks:
+ del self.connection.schema_tasks[internal_station_name]
if task is not None:
task.cancel()
if sub is not None:
await sub.unsubscribe()
- map_key = station_name_internal + "_" + self.real_name
+ map_key = internal_station_name + "_" + self.real_name
del self.connection.producers_map[map_key]
except Exception as e:
@@ -965,6 +1025,7 @@ def default_error_handler(e):
print("ping exception raised", e)
+
class Consumer:
def __init__(
self,
@@ -995,9 +1056,13 @@ def __init__(
error_callback = default_error_handler
self.t_ping = asyncio.create_task(self.__ping_consumer(error_callback))
self.start_consume_from_sequence = start_consume_from_sequence
-
self.last_messages = last_messages
self.context = {}
+ self.dls_messages = []
+ self.dls_current_index = 0
+ self.dls_callback_func = None
+ self.t_dls = asyncio.create_task(self.__consume_dls())
+
def set_context(self, context):
"""Set a context (dict) that will be passed to each message handler call."""
@@ -1005,8 +1070,8 @@ def set_context(self, context):
def consume(self, callback):
"""Consume events."""
+ self.dls_callback_func = callback
self.t_consume = asyncio.create_task(self.__consume(callback))
- self.t_dls = asyncio.create_task(self.__consume_dls(callback))
async def __consume(self, callback):
subject = get_internal_name(self.station_name)
@@ -1039,7 +1104,7 @@ async def __consume(self, callback):
else:
break
- async def __consume_dls(self, callback):
+ async def __consume_dls(self):
subject = get_internal_name(self.station_name)
consumer_group = get_internal_name(self.consumer_group)
try:
@@ -1048,16 +1113,58 @@ async def __consume_dls(self, callback):
subscription_name, subscription_name
)
async for msg in self.consumer_dls.messages:
- await callback(
- [Message(msg, self.connection, self.consumer_group)],
- None,
- self.context,
- )
+ index_to_insert = self.dls_current_index
+ if index_to_insert>=10000:
+ index_to_insert%=10000
+ self.dls_messages.insert(index_to_insert, Message(msg, self.connection, self.consumer_group))
+ self.dls_current_index+=1
+ if self.dls_callback_func != None:
+ await self.dls_callback_func(
+ [Message(msg, self.connection, self.consumer_group)],
+ None,
+ self.context,
+ )
except Exception as e:
- print("dls", e)
- await callback([], MemphisError(str(e)), self.context)
+ await self.dls_callback_func([], MemphisError(str(e)), self.context)
return
+ async def fetch(self, batch_size: int = 10):
+ """Fetch a batch of messages."""
+ messages = []
+ if self.connection.is_connection_active:
+ try:
+ self.batch_size = batch_size
+ if len(self.dls_messages)>0:
+ if len(self.dls_messages) <= batch_size:
+ messages = self.dls_messages
+ self.dls_messages = []
+ self.dls_current_index = 0
+ else:
+ messages = self.dls_messages[0:batch_size]
+ del self.dls_messages[0:batch_size]
+ self.dls_current_index -= len(messages)
+ return messages
+
+ durableName = ""
+ if self.consumer_group != "":
+ durableName = get_internal_name(self.consumer_group)
+ else:
+ durableName = get_internal_name(self.consumer_name)
+ subject = get_internal_name(self.station_name)
+ consumer_group = get_internal_name(self.consumer_group)
+ self.psub = await self.connection.broker_connection.pull_subscribe(
+ subject + ".final", durable=durableName
+ )
+ msgs = await self.psub.fetch(batch_size)
+ for msg in msgs:
+ messages.append(Message(msg, self.connection, self.consumer_group))
+ return messages
+ except Exception as e:
+ if not "timeout" in str(e):
+ raise MemphisError(str(e)) from e
+ else:
+ return messages
+
async def __ping_consumer(self, callback):
while True:
try:
@@ -1092,6 +1199,10 @@ async def destroy(self):
error = res.data.decode("utf-8")
if error != "" and not "not exist" in error:
raise MemphisError(error)
+ self.dls_messages.clear()
+ internal_station_name = get_internal_name(self.station_name)
+ map_key = internal_station_name + "_" + self.consumer_name.lower()
+ del self.connection.consumers_map[map_key]
except Exception as e:
raise MemphisError(str(e)) from e
@@ -1109,7 +1220,7 @@ async def ack(self):
except Exception as e:
if (
"$memphis_pm_id"
- in self.message.headers & "$memphis_pm_sequence"
+ in self.message.headers and "$memphis_pm_sequence"
in self.message.headers
):
try:
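Taken together, patch 02's public surface is small: `memphis.fetch_messages(...)` for one-shot fetches and `consumer.fetch(...)` on an existing consumer. Below is a minimal end-to-end sketch of that API; the host, credentials, and station/consumer names are placeholders, not part of the patch:

```python
import asyncio
from memphis import Memphis

async def main():
    memphis = Memphis()
    # Placeholder connection details; adjust to your deployment.
    await memphis.connect(host="localhost", username="root", connection_token="memphis")
    try:
        # One-shot batch fetch; repeat calls reuse the cached consumer
        # thanks to the new consumers_map.
        msgs = await memphis.fetch_messages(
            station_name="my-station",
            consumer_name="my-consumer",
            batch_size=5,
        )
        for msg in msgs:
            print(msg.get_data())
            await msg.ack()
    finally:
        await memphis.close()

asyncio.run(main())
```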
From 4d25dc9bfd8a8aa9e9de215c85fb152f2ab7e755 Mon Sep 17 00:00:00 2001
From: Alon David <37884564+alon-david@users.noreply.github.com>
Date: Tue, 28 Feb 2023 12:59:51 +0200
Subject: [PATCH 03/12] refactor storage and retention variables to be enum
based (#118)
---
README.md | 16 ++++++++--------
examples/consumer.py | 1 +
examples/producer.py | 1 +
memphis/__init__.py | 2 --
memphis/memphis.py | 16 +++++++---------
memphis/storage_types.py | 16 ----------------
memphis/{retention_types.py => types.py} | 15 ++++++++++++---
7 files changed, 29 insertions(+), 38 deletions(-)
delete mode 100644 memphis/storage_types.py
rename memphis/{retention_types.py => types.py} (75%)
diff --git a/README.md b/README.md
index 0b14e1e..28f5e69 100644
--- a/README.md
+++ b/README.md
@@ -55,7 +55,7 @@ $ pip3 install memphis-py
```python
from memphis import Memphis, Headers
-from memphis import retention_types, storage_types
+from memphis.types import Retention, Storage
```
### Connecting to Memphis
@@ -108,9 +108,9 @@ _If a station already exists nothing happens, the new configuration will not be
station = memphis.station(
name="",
schema_name="",
- retention_type=retention_types.MAX_MESSAGE_AGE_SECONDS, # MAX_MESSAGE_AGE_SECONDS/MESSAGES/BYTES. Defaults to MAX_MESSAGE_AGE_SECONDS
+ retention_type=Retention.MAX_MESSAGE_AGE_SECONDS, # MAX_MESSAGE_AGE_SECONDS/MESSAGES/BYTES. Defaults to MAX_MESSAGE_AGE_SECONDS
retention_value=604800, # defaults to 604800
- storage_type=storage_types.DISK, # storage_types.DISK/storage_types.MEMORY. Defaults to DISK
+ storage_type=Storage.DISK, # Storage.DISK/Storage.MEMORY. Defaults to DISK
replicas=1, # defaults to 1
idempotency_window_ms=120000, # defaults to 2 minutes
send_poison_msg_to_dls=True, # defaults to true
@@ -124,19 +124,19 @@ station = memphis.station(
Memphis currently supports the following types of retention:
```python
-memphis.retention_types.MAX_MESSAGE_AGE_SECONDS
+memphis.types.Retention.MAX_MESSAGE_AGE_SECONDS
```
Means that every message persists for the value set in retention value field (in seconds)
```python
-memphis.retention_types.MESSAGES
+memphis.types.Retention.MESSAGES
```
Means that after max amount of saved messages (set in retention value), the oldest messages will be deleted
```python
-memphis.retention_types.BYTES
+memphis.types.Retention.BYTES
```
Means that after max amount of saved bytes (set in retention value), the oldest messages will be deleted
@@ -146,13 +146,13 @@ Means that after max amount of saved bytes (set in retention value), the oldest
Memphis currently supports the following types of messages storage:
```python
-memphis.storage_types.DISK
+memphis.types.Storage.DISK
```
Means that messages persist on disk
```python
-memphis.storage_types.MEMORY
+memphis.types.Storage.MEMORY
```
Means that messages persist on the main memory
diff --git a/examples/consumer.py b/examples/consumer.py
index 4febb93..dfebbbe 100644
--- a/examples/consumer.py
+++ b/examples/consumer.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import asyncio
from memphis import Memphis, MemphisConnectError, MemphisError, MemphisHeaderError
diff --git a/examples/producer.py b/examples/producer.py
index e5a79e0..01dd381 100644
--- a/examples/producer.py
+++ b/examples/producer.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import asyncio
from memphis import (
diff --git a/memphis/__init__.py b/memphis/__init__.py
index ff8bd5f..d6048f1 100644
--- a/memphis/__init__.py
+++ b/memphis/__init__.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import memphis.retention_types
-import memphis.storage_types
from memphis.memphis import (
Headers,
Memphis,
diff --git a/memphis/memphis.py b/memphis/memphis.py
index 831bc6f..b5abd7b 100644
--- a/memphis/memphis.py
+++ b/memphis/memphis.py
@@ -22,8 +22,6 @@
from typing import Callable, Iterable, Union
import graphql
-import memphis.retention_types as retention_types
-import memphis.storage_types as storage_types
import nats as broker
from google.protobuf import descriptor_pb2, descriptor_pool
from google.protobuf.message_factory import MessageFactory
@@ -31,7 +29,7 @@
from graphql import parse as parse_graphql
from graphql import validate as validate_graphql
from jsonschema import validate
-
+from memphis.types import Retention, Storage
schemaVFailAlertType = "schema_validation_fail_alert"
@@ -199,22 +197,22 @@ async def send_notification(self, title, msg, failedMsg, type):
async def station(
self,
name: str,
- retention_type: str = retention_types.MAX_MESSAGE_AGE_SECONDS,
+ retention_type: Retention = Retention.MAX_MESSAGE_AGE_SECONDS,
retention_value: int = 604800,
- storage_type: str = storage_types.DISK,
+ storage_type: Storage = Storage.DISK,
replicas: int = 1,
idempotency_window_ms: int = 120000,
schema_name: str = "",
send_poison_msg_to_dls: bool = True,
send_schema_failed_msg_to_dls: bool = True,
- tiered_storage_enabled: bool = False
+ tiered_storage_enabled: bool = False,
):
"""Creates a station.
Args:
name (str): station name.
- retention_type (str, optional): retention type: message_age_sec/messages/bytes . Defaults to "message_age_sec".
+            retention_type (Retention, optional): retention type: message_age_sec/messages/bytes. Defaults to "message_age_sec".
retention_value (int, optional): number which represents the retention based on the retention_type. Defaults to 604800.
- storage_type (str, optional): persistance storage for messages of the station: disk/memory. Defaults to "disk".
+            storage_type (Storage, optional): persistence storage for the station's messages: disk/memory. Defaults to "disk".
            replicas (int, optional): number of replicas for the messages of the data. Defaults to 1.
            idempotency_window_ms (int, optional): time frame in which idempotent messages will be tracked, based on the message ID. Defaults to 120000.
schema_name (str): schema name.
@@ -238,7 +236,7 @@ async def station(
"Schemaverse": send_schema_failed_msg_to_dls,
},
"username": self.username,
- "tiered_storage_enabled": tiered_storage_enabled
+ "tiered_storage_enabled": tiered_storage_enabled,
}
create_station_req_bytes = json.dumps(createStationReq, indent=2).encode(
"utf-8"
diff --git a/memphis/storage_types.py b/memphis/storage_types.py
deleted file mode 100644
index 8c1d570..0000000
--- a/memphis/storage_types.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Credit for The NATS.IO Authors
-# Copyright 2021-2022 The Memphis Authors
-# Licensed under the Apache License, Version 2.0 (the “License”);
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http:#www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an “AS IS” BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-DISK = "file"
-MEMORY = "memory"
diff --git a/memphis/retention_types.py b/memphis/types.py
similarity index 75%
rename from memphis/retention_types.py
rename to memphis/types.py
index 6fb09ae..cfd6a70 100644
--- a/memphis/retention_types.py
+++ b/memphis/types.py
@@ -12,6 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-MAX_MESSAGE_AGE_SECONDS = "message_age_sec"
-MESSAGES = "messages"
-BYTES = "bytes"
+from enum import Enum
+
+
+class Retention(Enum):
+ MAX_MESSAGE_AGE_SECONDS = "message_age_sec"
+ MESSAGES = "messages"
+ BYTES = "bytes"
+
+
+class Storage(Enum):
+ DISK = "file"
+ MEMORY = "memory"
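For reference, a before/after sketch of the enum migration this patch performs; the station settings are illustrative only:

```python
from memphis import Memphis
from memphis.types import Retention, Storage  # replaces retention_types / storage_types

async def create_station(memphis: Memphis):
    # Old API: retention_type=retention_types.MAX_MESSAGE_AGE_SECONDS,
    #          storage_type=storage_types.DISK
    return await memphis.station(
        name="logs",
        retention_type=Retention.MAX_MESSAGE_AGE_SECONDS,
        retention_value=604800,   # one week, in seconds
        storage_type=Storage.DISK,
    )
```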
From f22cda0fe020f3779ac9e2689a06c100a01d1f9f Mon Sep 17 00:00:00 2001
From: Alon David <37884564+alon-david@users.noreply.github.com>
Date: Wed, 1 Mar 2023 11:00:23 +0200
Subject: [PATCH 04/12] Feature/file modularity (#119)
* refactor storage and retention variables to be enum based
* split the big memphis.py file into multiple logical files for easier tracking and maintenance of the code
---
memphis/__init__.py | 7 +-
memphis/consumer.py | 189 ++++++++++++
memphis/exceptions.py | 23 ++
memphis/headers.py | 19 ++
memphis/memphis.py | 675 ++----------------------------------------
memphis/message.py | 57 ++++
memphis/producer.py | 298 +++++++++++++++++++
memphis/station.py | 53 ++++
memphis/utils.py | 32 ++
9 files changed, 699 insertions(+), 654 deletions(-)
create mode 100644 memphis/consumer.py
create mode 100644 memphis/exceptions.py
create mode 100644 memphis/headers.py
create mode 100644 memphis/message.py
create mode 100644 memphis/producer.py
create mode 100644 memphis/station.py
create mode 100644 memphis/utils.py
diff --git a/memphis/__init__.py b/memphis/__init__.py
index d6048f1..aa44e30 100644
--- a/memphis/__init__.py
+++ b/memphis/__init__.py
@@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from memphis.memphis import (
- Headers,
- Memphis,
+from memphis.exceptions import (
MemphisConnectError,
MemphisError,
MemphisHeaderError,
MemphisSchemaError,
)
+from memphis.headers import Headers
+from memphis.memphis import Memphis
diff --git a/memphis/consumer.py b/memphis/consumer.py
new file mode 100644
index 0000000..aad4625
--- /dev/null
+++ b/memphis/consumer.py
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+import asyncio
+import json
+
+from memphis.exceptions import MemphisError
+from memphis.utils import default_error_handler, get_internal_name
+
+
+class Consumer:
+ def __init__(
+ self,
+ connection,
+ station_name: str,
+ consumer_name,
+ consumer_group,
+ pull_interval_ms: int,
+ batch_size: int,
+ batch_max_time_to_wait_ms: int,
+ max_ack_time_ms: int,
+ max_msg_deliveries: int = 10,
+ error_callback=None,
+ start_consume_from_sequence: int = 1,
+ last_messages: int = -1,
+ ):
+ self.connection = connection
+ self.station_name = station_name.lower()
+ self.consumer_name = consumer_name.lower()
+ self.consumer_group = consumer_group.lower()
+ self.pull_interval_ms = pull_interval_ms
+ self.batch_size = batch_size
+ self.batch_max_time_to_wait_ms = batch_max_time_to_wait_ms
+ self.max_ack_time_ms = max_ack_time_ms
+ self.max_msg_deliveries = max_msg_deliveries
+ self.ping_consumer_invterval_ms = 30000
+ if error_callback is None:
+ error_callback = default_error_handler
+ self.t_ping = asyncio.create_task(self.__ping_consumer(error_callback))
+ self.start_consume_from_sequence = start_consume_from_sequence
+ self.last_messages = last_messages
+ self.context = {}
+ self.dls_messages = []
+ self.dls_current_index = 0
+ self.dls_callback_func = None
+ self.t_dls = asyncio.create_task(self.__consume_dls())
+
+ def set_context(self, context):
+ """Set a context (dict) that will be passed to each message handler call."""
+ self.context = context
+
+ def consume(self, callback):
+ """Consume events."""
+ self.dls_callback_func = callback
+ self.t_consume = asyncio.create_task(self.__consume(callback))
+
+ async def __consume(self, callback):
+ subject = get_internal_name(self.station_name)
+ consumer_group = get_internal_name(self.consumer_group)
+ self.psub = await self.connection.broker_connection.pull_subscribe(
+ subject + ".final", durable=consumer_group
+ )
+ while True:
+ if self.connection.is_connection_active and self.pull_interval_ms:
+ try:
+ memphis_messages = []
+ msgs = await self.psub.fetch(self.batch_size)
+ for msg in msgs:
+ memphis_messages.append(
+ Message(msg, self.connection, self.consumer_group)
+ )
+ await callback(memphis_messages, None, self.context)
+ await asyncio.sleep(self.pull_interval_ms / 1000)
+
+ except asyncio.TimeoutError:
+ await callback(
+ [], MemphisError("Memphis: TimeoutError"), self.context
+ )
+ continue
+ except Exception as e:
+ if self.connection.is_connection_active:
+ raise MemphisError(str(e)) from e
+ else:
+ return
+ else:
+ break
+
+ async def __consume_dls(self):
+ subject = get_internal_name(self.station_name)
+ consumer_group = get_internal_name(self.consumer_group)
+ try:
+ subscription_name = "$memphis_dls_" + subject + "_" + consumer_group
+ self.consumer_dls = await self.connection.broker_manager.subscribe(
+ subscription_name, subscription_name
+ )
+ async for msg in self.consumer_dls.messages:
+ index_to_insert = self.dls_current_index
+ if index_to_insert >= 10000:
+ index_to_insert %= 10000
+ self.dls_messages.insert(
+ index_to_insert, Message(msg, self.connection, self.consumer_group)
+ )
+ self.dls_current_index += 1
+ if self.dls_callback_func != None:
+ await self.dls_callback_func(
+ [Message(msg, self.connection, self.consumer_group)],
+ None,
+ self.context,
+ )
+ except Exception as e:
+ await self.dls_callback_func([], MemphisError(str(e)), self.context)
+ return
+
+ async def fetch(self, batch_size: int = 10):
+ """Fetch a batch of messages."""
+ messages = []
+ if self.connection.is_connection_active:
+ try:
+ self.batch_size = batch_size
+ if len(self.dls_messages) > 0:
+ if len(self.dls_messages) <= batch_size:
+ messages = self.dls_messages
+ self.dls_messages = []
+ self.dls_current_index = 0
+ else:
+ messages = self.dls_messages[0:batch_size]
+ del self.dls_messages[0:batch_size]
+ self.dls_current_index -= len(messages)
+ return messages
+
+ durableName = ""
+ if self.consumer_group != "":
+ durableName = get_internal_name(self.consumer_group)
+ else:
+ durableName = get_internal_name(self.consumer_name)
+ subject = get_internal_name(self.station_name)
+ consumer_group = get_internal_name(self.consumer_group)
+ self.psub = await self.connection.broker_connection.pull_subscribe(
+ subject + ".final", durable=durableName
+ )
+ msgs = await self.psub.fetch(batch_size)
+ for msg in msgs:
+ messages.append(Message(msg, self.connection, self.consumer_group))
+ return messages
+ except Exception as e:
+ if not "timeout" in str(e):
+ raise MemphisError(str(e)) from e
+ else:
+ return messages
+
+ async def __ping_consumer(self, callback):
+ while True:
+ try:
+ await asyncio.sleep(self.ping_consumer_invterval_ms / 1000)
+ consumer_group = get_internal_name(self.consumer_group)
+ await self.connection.broker_connection.consumer_info(
+ self.station_name, consumer_group, timeout=30
+ )
+
+ except Exception as e:
+ callback(e)
+
+ async def destroy(self):
+ """Destroy the consumer."""
+ if self.t_consume is not None:
+ self.t_consume.cancel()
+ if self.t_dls is not None:
+ self.t_dls.cancel()
+ if self.t_ping is not None:
+ self.t_ping.cancel()
+ self.pull_interval_ms = None
+ try:
+ destroyConsumerReq = {
+ "name": self.consumer_name,
+ "station_name": self.station_name,
+ "username": self.connection.username,
+ }
+ consumer_name = json.dumps(destroyConsumerReq, indent=2).encode("utf-8")
+ res = await self.connection.broker_manager.request(
+ "$memphis_consumer_destructions", consumer_name, timeout=5
+ )
+ error = res.data.decode("utf-8")
+ if error != "" and not "not exist" in error:
+ raise MemphisError(error)
+ self.dls_messages.clear()
+ internal_station_name = get_internal_name(self.station_name)
+ map_key = internal_station_name + "_" + self.consumer_name.lower()
+ del self.connection.consumers_map[map_key]
+ except Exception as e:
+ raise MemphisError(str(e)) from e
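For orientation, a short sketch of how the relocated Consumer class is driven in practice, assuming an already-connected `memphis` instance; the names are placeholders and the callback signature follows the README's `msg_handler(msgs, error, context)` convention:

```python
consumer = await memphis.consumer(station_name="my-station", consumer_name="worker-1")
consumer.set_context({"source": "example"})  # passed to every handler call

async def msg_handler(msgs, error, context):
    if error:
        print("consume error:", error)
        return
    for msg in msgs:
        await msg.ack()

consumer.consume(msg_handler)  # background task; also delivers buffered DLS messages
# ... later:
await consumer.destroy()
```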
diff --git a/memphis/exceptions.py b/memphis/exceptions.py
new file mode 100644
index 0000000..65c8c25
--- /dev/null
+++ b/memphis/exceptions.py
@@ -0,0 +1,23 @@
+class MemphisError(Exception):
+ def __init__(self, message):
+ message = message.replace("nats", "memphis")
+ message = message.replace("NATS", "memphis")
+ message = message.replace("Nats", "memphis")
+ message = message.replace("NatsError", "MemphisError")
+ self.message = message
+ if message.startswith("memphis:"):
+ super().__init__(self.message)
+ else:
+ super().__init__("memphis: " + self.message)
+
+
+class MemphisConnectError(MemphisError):
+ pass
+
+
+class MemphisSchemaError(MemphisError):
+ pass
+
+
+class MemphisHeaderError(MemphisError):
+ pass
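The wrapper's behavior is easiest to see with a few quick cases: NATS-flavored text is rewritten to "memphis", and the "memphis: " prefix is added only when it is missing:

```python
from memphis.exceptions import MemphisError

print(str(MemphisError("nats: timeout")))       # -> "memphis: timeout"
print(str(MemphisError("broken pipe")))         # -> "memphis: broken pipe"
print(str(MemphisError("memphis: bad input")))  # -> "memphis: bad input" (no double prefix)
```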
diff --git a/memphis/headers.py b/memphis/headers.py
new file mode 100644
index 0000000..1d25e14
--- /dev/null
+++ b/memphis/headers.py
@@ -0,0 +1,19 @@
+from memphis.exceptions import MemphisHeaderError
+
+
+class Headers:
+ def __init__(self):
+ self.headers = {}
+
+ def add(self, key, value):
+ """Add a header.
+ Args:
+ key (string): header key.
+ value (string): header value.
+ Raises:
+            MemphisHeaderError: if the key starts with "$memphis".
+ """
+ if not key.startswith("$memphis"):
+ self.headers[key] = value
+ else:
+ raise MemphisHeaderError("Keys in headers should not start with $memphis")
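Usage is unchanged by the move; the guard simply reserves the `$memphis` prefix for internal headers:

```python
from memphis import Headers
from memphis.exceptions import MemphisHeaderError

headers = Headers()
headers.add("trace-id", "abc-123")  # user header: accepted

try:
    headers.add("$memphis_internal", "x")  # reserved prefix: rejected
except MemphisHeaderError as e:
    print(e)  # memphis: Keys in headers should not start with $memphis
```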
diff --git a/memphis/memphis.py b/memphis/memphis.py
index b5abd7b..0d15ccd 100644
--- a/memphis/memphis.py
+++ b/memphis/memphis.py
@@ -29,43 +29,16 @@
from graphql import parse as parse_graphql
from graphql import validate as validate_graphql
from jsonschema import validate
+from memphis.consumer import Consumer
+from memphis.exceptions import MemphisConnectError, MemphisError, MemphisHeaderError
+from memphis.headers import Headers
+from memphis.producer import Producer
from memphis.types import Retention, Storage
+from memphis.utils import get_internal_name
schemaVFailAlertType = "schema_validation_fail_alert"
-class set_interval:
- def __init__(self, func: Callable, sec: int):
- def func_wrapper():
- self.t = Timer(sec, func_wrapper)
- self.t.start()
- func()
-
- self.t = Timer(sec, func_wrapper)
- self.t.start()
-
- def cancel(self):
- self.t.cancel()
-
-
-class Headers:
- def __init__(self):
- self.headers = {}
-
- def add(self, key, value):
- """Add a header.
- Args:
- key (string): header key.
- value (string): header value.
- Raises:
- Exception: _description_
- """
- if not key.startswith("$memphis"):
- self.headers[key] = value
- else:
- raise MemphisHeaderError("Keys in headers should not start with $memphis")
-
-
class Memphis:
def __init__(self):
self.is_connection_active = False
@@ -135,8 +108,8 @@ async def connect(
port (int, optional): port. Defaults to 6666.
reconnect (bool, optional): whether to do reconnect while connection is lost. Defaults to True.
max_reconnect (int, optional): The reconnect attempt. Defaults to 3.
- reconnect_interval_ms (int, optional): Interval in miliseconds between reconnect attempts. Defaults to 200.
- timeout_ms (int, optional): connection timeout in miliseconds. Defaults to 15000.
+ reconnect_interval_ms (int, optional): Interval in milliseconds between reconnect attempts. Defaults to 200.
+ timeout_ms (int, optional): connection timeout in milliseconds. Defaults to 15000.
key_file (string): path to tls key file.
cert_file (string): path to tls cert file.
ca_file (string): path to tls ca file.
@@ -509,10 +482,10 @@ async def consumer(
station_name (str): station name to consume messages from.
consumer_name (str): name for the consumer.
consumer_group (str, optional): consumer group name. Defaults to the consumer name.
- pull_interval_ms (int, optional): interval in miliseconds between pulls. Defaults to 1000.
+ pull_interval_ms (int, optional): interval in milliseconds between pulls. Defaults to 1000.
batch_size (int, optional): pull batch size. Defaults to 10.
- batch_max_time_to_wait_ms (int, optional): max time in miliseconds to wait between pulls. Defaults to 5000.
- max_ack_time_ms (int, optional): max time for ack a message in miliseconds, in case a message not acked in this time period the Memphis broker will resend it. Defaults to 30000.
+ batch_max_time_to_wait_ms (int, optional): max time in milliseconds to wait between pulls. Defaults to 5000.
+            max_ack_time_ms (int, optional): max time in milliseconds to ack a message; if a message is not acked within this period, the Memphis broker will resend it. Defaults to 30000.
max_msg_deliveries (int, optional): max number of message deliveries, by default is 10.
generate_random_suffix (bool): false by default, if true concatenate a random suffix to consumer's name
start_consume_from_sequence(int, optional): start consuming from a specific sequence. defaults to 1.
@@ -605,7 +578,7 @@ async def produce(
ack_wait_sec (int, optional): max time in seconds to wait for an ack from memphis. Defaults to 15.
headers (dict, optional): Message headers, defaults to {}.
async_produce (boolean, optional): produce operation won't wait for broker acknowledgement
- msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotency
+ msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotence
Raises:
Exception: _description_
"""
@@ -631,7 +604,6 @@ async def produce(
except Exception as e:
raise MemphisError(str(e)) from e
-
async def fetch_messages(
self,
station_name: str,
@@ -643,8 +615,8 @@ async def fetch_messages(
max_msg_deliveries: int = 10,
generate_random_suffix: bool = False,
start_consume_from_sequence: int = 1,
- last_messages: int = -1
- ):
+ last_messages: int = -1,
+ ):
"""Consume a batch of messages.
        Args:
station_name (str): station name to consume messages from.
@@ -669,7 +641,18 @@ async def fetch_messages(
if consumer_map_key in self.consumers_map:
consumer = self.consumers_map[consumer_map_key]
else:
- consumer = await self.consumer(station_name=station_name, consumer_name=consumer_name, consumer_group=consumer_group, batch_size=batch_size, batch_max_time_to_wait_ms=batch_max_time_to_wait_ms, max_ack_time_ms=max_ack_time_ms, max_msg_deliveries=max_msg_deliveries, generate_random_suffix=generate_random_suffix, start_consume_from_sequence=start_consume_from_sequence, last_messages=last_messages)
+ consumer = await self.consumer(
+ station_name=station_name,
+ consumer_name=consumer_name,
+ consumer_group=consumer_group,
+ batch_size=batch_size,
+ batch_max_time_to_wait_ms=batch_max_time_to_wait_ms,
+ max_ack_time_ms=max_ack_time_ms,
+ max_msg_deliveries=max_msg_deliveries,
+ generate_random_suffix=generate_random_suffix,
+ start_consume_from_sequence=start_consume_from_sequence,
+ last_messages=last_messages,
+ )
messages = await consumer.fetch(batch_size)
if messages == None:
messages = []
@@ -679,611 +662,3 @@ async def fetch_messages(
def is_connected(self):
return self.broker_manager.is_connected
-
-
-class Station:
- def __init__(self, connection, name: str):
- self.connection = connection
- self.name = name.lower()
-
- async def destroy(self):
- """Destroy the station."""
- try:
- nameReq = {"station_name": self.name, "username": self.connection.username}
- station_name = json.dumps(nameReq, indent=2).encode("utf-8")
- res = await self.connection.broker_manager.request(
- "$memphis_station_destructions", station_name, timeout=5
- )
- error = res.data.decode("utf-8")
- if error != "" and not "not exist" in error:
- raise MemphisError(error)
-
- internal_station_name = get_internal_name(self.name)
- sub = self.connection.schema_updates_subs.get(internal_station_name)
- task = self.connection.schema_tasks.get(internal_station_name)
- if internal_station_name in self.connection.schema_updates_data:
- del self.connection.schema_updates_data[internal_station_name]
- if internal_station_name in self.connection.schema_updates_subs:
- del self.connection.schema_updates_subs[internal_station_name]
- if internal_station_name in self.connection.producers_per_station:
- del self.connection.producers_per_station[internal_station_name]
- if internal_station_name in self.connection.schema_tasks:
- del self.connection.schema_tasks[internal_station_name]
- if task is not None:
- task.cancel()
- if sub is not None:
- await sub.unsubscribe()
-
- self.connection.producers_map = {
- k: v
- for k, v in self.connection.producers_map.items()
- if self.name not in k
- }
-
- self.connection.consumers_map = {
- k: v
- for k, v in self.connection.consumers_map.items()
- if self.name not in k
- }
-
- except Exception as e:
- raise MemphisError(str(e)) from e
-
-
-def get_internal_name(name: str) -> str:
- name = name.lower()
- return name.replace(".", "#")
-
-
-class Producer:
- def __init__(
- self, connection, producer_name: str, station_name: str, real_name: str
- ):
- self.connection = connection
- self.producer_name = producer_name.lower()
- self.station_name = station_name
- self.internal_station_name = get_internal_name(self.station_name)
- self.loop = asyncio.get_running_loop()
- self.real_name = real_name
-
- async def validate_msg(self, message):
- if self.connection.schema_updates_data[self.internal_station_name] != {}:
- schema_type = self.connection.schema_updates_data[
- self.internal_station_name
- ]["type"]
- if schema_type == "protobuf":
- message = self.validate_protobuf(message)
- return message
- elif schema_type == "json":
- message = self.validate_json_schema(message)
- return message
- elif schema_type == "graphql":
- message = self.validate_graphql(message)
- return message
- elif not isinstance(message, bytearray) and not isinstance(message, dict):
- raise MemphisSchemaError("Unsupported message type")
- else:
- if isinstance(message, dict):
- message = bytearray(json.dumps(message).encode("utf-8"))
- return message
-
- def validate_protobuf(self, message):
- proto_msg = self.connection.proto_msgs[self.internal_station_name]
- msgToSend = ""
- try:
- if isinstance(message, bytearray):
- msgToSend = bytes(message)
- try:
- proto_msg.ParseFromString(msgToSend)
- proto_msg.SerializeToString()
- msgToSend = msgToSend.decode("utf-8")
- except Exception as e:
- if "parsing message" in str(e):
- e = "Invalid message format, expecting protobuf"
- raise MemphisSchemaError(str(e))
- return message
- elif hasattr(message, "SerializeToString"):
- msgToSend = message.SerializeToString()
- proto_msg.ParseFromString(msgToSend)
- proto_msg.SerializeToString()
- return msgToSend
-
- else:
- raise MemphisSchemaError("Unsupported message type")
-
- except Exception as e:
- raise MemphisSchemaError("Schema validation has failed: " + str(e))
-
- def validate_json_schema(self, message):
- try:
- if isinstance(message, bytearray):
- try:
- message_obj = json.loads(message)
- except Exception as e:
- raise Exception("Expecting Json format: " + str(e))
- elif isinstance(message, dict):
- message_obj = message
- message = bytearray(json.dumps(message_obj).encode("utf-8"))
- else:
- raise Exception("Unsupported message type")
-
- validate(
- instance=message_obj,
- schema=self.connection.json_schemas[self.internal_station_name],
- )
- return message
- except Exception as e:
- raise MemphisSchemaError("Schema validation has failed: " + str(e))
-
- def validate_graphql(self, message):
- try:
- if isinstance(message, bytearray):
- msg = message.decode("utf-8")
- msg = parse_graphql(msg)
- elif isinstance(message, str):
- msg = parse_graphql(message)
- message = message.encode("utf-8")
- elif isinstance(message, graphql.language.ast.DocumentNode):
- msg = message
- message = str(msg.loc.source.body)
- message = message.encode("utf-8")
- else:
- raise MemphisError("Unsupported message type")
- validate_res = validate_graphql(
- schema=self.connection.graphql_schemas[self.internal_station_name],
- document_ast=msg,
- )
- if len(validate_res) > 0:
- raise Exception("Schema validation has failed: " + str(validate_res))
- return message
- except Exception as e:
- if "Syntax Error" in str(e):
- e = "Invalid message format, expected GraphQL"
- raise Exception("Schema validation has failed: " + str(e))
-
- def get_dls_msg_id(self, station_name: str, producer_name: str, unix_time: str):
- return station_name + "~" + producer_name + "~0~" + unix_time
-
- async def produce(
- self,
- message,
- ack_wait_sec: int = 15,
- headers: Union[Headers, None] = None,
- async_produce: bool = False,
- msg_id: Union[str, None] = None,
- ):
- """Produces a message into a station.
- Args:
- message (bytearray/dict): message to send into the station - bytearray/protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
- ack_wait_sec (int, optional): max time in seconds to wait for an ack from memphis. Defaults to 15.
- headers (dict, optional): Message headers, defaults to {}.
- async_produce (boolean, optional): produce operation won't wait for broker acknowledgement
- msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotency
- Raises:
- Exception: _description_
- """
- try:
- message = await self.validate_msg(message)
-
- memphis_headers = {
- "$memphis_producedBy": self.producer_name,
- "$memphis_connectionId": self.connection.connection_id,
- }
-
- if msg_id is not None:
- memphis_headers["msg-id"] = msg_id
-
- if headers is not None:
- headers = headers.headers
- headers.update(memphis_headers)
- else:
- headers = memphis_headers
-
- if async_produce:
- try:
- self.loop.create_task(
- self.connection.broker_connection.publish(
- self.internal_station_name + ".final",
- message,
- timeout=ack_wait_sec,
- headers=headers,
- )
- )
- await asyncio.sleep(1)
- except Exception as e:
- raise MemphisError(e)
- else:
- await self.connection.broker_connection.publish(
- self.internal_station_name + ".final",
- message,
- timeout=ack_wait_sec,
- headers=headers,
- )
- except Exception as e:
- if hasattr(e, "status_code") and e.status_code == "503":
- raise MemphisError(
- "Produce operation has failed, please check whether Station/Producer are still exist"
- )
- else:
- if "Schema validation has failed" in str(
- e
- ) or "Unsupported message type" in str(e):
- msgToSend = ""
- if isinstance(message, bytearray):
- msgToSend = str(message, "utf-8")
- elif hasattr(message, "SerializeToString"):
- msgToSend = message.SerializeToString().decode("utf-8")
- if self.connection.station_schemaverse_to_dls[
- self.internal_station_name
- ]:
- unix_time = int(time.time())
- id = self.get_dls_msg_id(
- self.internal_station_name,
- self.producer_name,
- str(unix_time),
- )
-
- memphis_headers = {
- "$memphis_producedBy": self.producer_name,
- "$memphis_connectionId": self.connection.connection_id,
- }
-
- if headers != {}:
- headers = headers.headers
- headers.update(memphis_headers)
- else:
- headers = memphis_headers
-
- msgToSendEncoded = msgToSend.encode("utf-8")
- msgHex = msgToSendEncoded.hex()
- buf = {
- "_id": id,
- "station_name": self.internal_station_name,
- "producer": {
- "name": self.producer_name,
- "connection_id": self.connection.connection_id,
- },
- "creation_unix": unix_time,
- "message": {
- "data": msgHex,
- "headers": headers,
- },
- }
- buf = json.dumps(buf).encode("utf-8")
- await self.connection.broker_connection.publish(
- "$memphis-"
- + self.internal_station_name
- + "-dls.schema."
- + id,
- buf,
- )
- if self.connection.cluster_configurations.get(
- "send_notification"
- ):
- await self.connection.send_notification(
- "Schema validation has failed",
- "Station: "
- + self.station_name
- + "\nProducer: "
- + self.producer_name
- + "\nError:"
- + str(e),
- msgToSend,
- schemaVFailAlertType,
- )
- raise MemphisError(str(e)) from e
-
- async def destroy(self):
- """Destroy the producer."""
- try:
- destroyProducerReq = {
- "name": self.producer_name,
- "station_name": self.station_name,
- "username": self.connection.username,
- }
-
- producer_name = json.dumps(destroyProducerReq).encode("utf-8")
- res = await self.connection.broker_manager.request(
- "$memphis_producer_destructions", producer_name, timeout=5
- )
- error = res.data.decode("utf-8")
- if error != "" and not "not exist" in error:
- raise Exception(error)
-
- internal_station_name = get_internal_name(self.station_name)
- producer_number = (
- self.connection.producers_per_station.get(internal_station_name) - 1
- )
- self.connection.producers_per_station[
- internal_station_name
- ] = producer_number
-
- if producer_number == 0:
- sub = self.connection.schema_updates_subs.get(internal_station_name)
- task = self.connection.schema_tasks.get(internal_station_name)
- if internal_station_name in self.connection.schema_updates_data:
- del self.connection.schema_updates_data[internal_station_name]
- if internal_station_name in self.connection.schema_updates_subs:
- del self.connection.schema_updates_subs[internal_station_name]
- if internal_station_name in self.connection.schema_tasks:
- del self.connection.schema_tasks[internal_station_name]
- if task is not None:
- task.cancel()
- if sub is not None:
- await sub.unsubscribe()
-
- map_key = internal_station_name + "_" + self.real_name
- del self.connection.producers_map[map_key]
-
- except Exception as e:
- raise Exception(e)
-
-
-def default_error_handler(e):
- print("ping exception raised", e)
-
-
-
-class Consumer:
- def __init__(
- self,
- connection,
- station_name: str,
- consumer_name,
- consumer_group,
- pull_interval_ms: int,
- batch_size: int,
- batch_max_time_to_wait_ms: int,
- max_ack_time_ms: int,
- max_msg_deliveries: int = 10,
- error_callback=None,
- start_consume_from_sequence: int = 1,
- last_messages: int = -1,
- ):
- self.connection = connection
- self.station_name = station_name.lower()
- self.consumer_name = consumer_name.lower()
- self.consumer_group = consumer_group.lower()
- self.pull_interval_ms = pull_interval_ms
- self.batch_size = batch_size
- self.batch_max_time_to_wait_ms = batch_max_time_to_wait_ms
- self.max_ack_time_ms = max_ack_time_ms
- self.max_msg_deliveries = max_msg_deliveries
- self.ping_consumer_invterval_ms = 30000
- if error_callback is None:
- error_callback = default_error_handler
- self.t_ping = asyncio.create_task(self.__ping_consumer(error_callback))
- self.start_consume_from_sequence = start_consume_from_sequence
- self.last_messages = last_messages
- self.context = {}
- self.dls_messages = []
- self.dls_current_index = 0
- self.dls_callback_func = None
- self.t_dls = asyncio.create_task(self.__consume_dls())
-
-
- def set_context(self, context):
- """Set a context (dict) that will be passed to each message handler call."""
- self.context = context
-
- def consume(self, callback):
- """Consume events."""
- self.dls_callback_func = callback
- self.t_consume = asyncio.create_task(self.__consume(callback))
-
- async def __consume(self, callback):
- subject = get_internal_name(self.station_name)
- consumer_group = get_internal_name(self.consumer_group)
- self.psub = await self.connection.broker_connection.pull_subscribe(
- subject + ".final", durable=consumer_group
- )
- while True:
- if self.connection.is_connection_active and self.pull_interval_ms:
- try:
- memphis_messages = []
- msgs = await self.psub.fetch(self.batch_size)
- for msg in msgs:
- memphis_messages.append(
- Message(msg, self.connection, self.consumer_group)
- )
- await callback(memphis_messages, None, self.context)
- await asyncio.sleep(self.pull_interval_ms / 1000)
-
- except asyncio.TimeoutError:
- await callback(
- [], MemphisError("Memphis: TimeoutError"), self.context
- )
- continue
- except Exception as e:
- if self.connection.is_connection_active:
- raise MemphisError(str(e)) from e
- else:
- return
- else:
- break
-
- async def __consume_dls(self):
- subject = get_internal_name(self.station_name)
- consumer_group = get_internal_name(self.consumer_group)
- try:
- subscription_name = "$memphis_dls_" + subject + "_" + consumer_group
- self.consumer_dls = await self.connection.broker_manager.subscribe(
- subscription_name, subscription_name
- )
- async for msg in self.consumer_dls.messages:
- index_to_insert = self.dls_current_index
- if index_to_insert>=10000:
- index_to_insert%=10000
- self.dls_messages.insert(index_to_insert, Message(msg, self.connection, self.consumer_group))
- self.dls_current_index+=1
- if self.dls_callback_func != None:
- await self.dls_callback_func(
- [Message(msg, self.connection, self.consumer_group)],
- None,
- self.context,
- )
- except Exception as e:
- await self.dls_callback_func([], MemphisError(str(e)), self.context)
- return
-
- async def fetch(self, batch_size: int = 10):
- """Fetch a batch of messages."""
- messages = []
- if self.connection.is_connection_active:
- try:
- self.batch_size = batch_size
- if len(self.dls_messages)>0:
- if len(self.dls_messages) <= batch_size:
- messages = self.dls_messages
- self.dls_messages = []
- self.dls_current_index = 0
- else:
- messages = self.dls_messages[0:batch_size]
- del self.dls_messages[0:batch_size]
- self.dls_current_index -= len(messages)
- return messages
-
- durableName = ""
- if self.consumer_group != "":
- durableName = get_internal_name(self.consumer_group)
- else:
- durableName = get_internal_name(self.consumer_name)
- subject = get_internal_name(self.station_name)
- consumer_group = get_internal_name(self.consumer_group)
- self.psub = await self.connection.broker_connection.pull_subscribe(
- subject + ".final", durable=durableName
- )
- msgs = await self.psub.fetch(batch_size)
- for msg in msgs:
- messages.append(Message(msg, self.connection, self.consumer_group))
- return messages
- except Exception as e:
- if not "timeout" in str(e):
- raise MemphisError(str(e)) from e
- else:
- return messages
-
- async def __ping_consumer(self, callback):
- while True:
- try:
- await asyncio.sleep(self.ping_consumer_invterval_ms / 1000)
- consumer_group = get_internal_name(self.consumer_group)
- await self.connection.broker_connection.consumer_info(
- self.station_name, consumer_group, timeout=30
- )
-
- except Exception as e:
- callback(e)
-
- async def destroy(self):
- """Destroy the consumer."""
- if self.t_consume is not None:
- self.t_consume.cancel()
- if self.t_dls is not None:
- self.t_dls.cancel()
- if self.t_ping is not None:
- self.t_ping.cancel()
- self.pull_interval_ms = None
- try:
- destroyConsumerReq = {
- "name": self.consumer_name,
- "station_name": self.station_name,
- "username": self.connection.username,
- }
- consumer_name = json.dumps(destroyConsumerReq, indent=2).encode("utf-8")
- res = await self.connection.broker_manager.request(
- "$memphis_consumer_destructions", consumer_name, timeout=5
- )
- error = res.data.decode("utf-8")
- if error != "" and not "not exist" in error:
- raise MemphisError(error)
- self.dls_messages.clear()
- internal_station_name = get_internal_name(self.station_name)
- map_key = internal_station_name + "_" + self.consumer_name.lower()
- del self.connection.consumers_map[map_key]
- except Exception as e:
- raise MemphisError(str(e)) from e
-
-
-class Message:
- def __init__(self, message, connection, cg_name):
- self.message = message
- self.connection = connection
- self.cg_name = cg_name
-
- async def ack(self):
- """Ack a message is done processing."""
- try:
- await self.message.ack()
- except Exception as e:
- if (
- "$memphis_pm_id"
- in self.message.headers and "$memphis_pm_sequence"
- in self.message.headers
- ):
- try:
- msg = {
- "id": self.message.headers["$memphis_pm_id"],
- "sequence": self.message.headers["$memphis_pm_sequence"],
- }
- msgToAck = json.dumps(msg).encode("utf-8")
- await self.connection.broker_manager.publish(
- "$memphis_pm_acks", msgToAck
- )
- except Exception as er:
- raise MemphisConnectError(str(er)) from er
- else:
- raise MemphisConnectError(str(e)) from e
- return
-
- def get_data(self):
- """Receive the message."""
- try:
- return bytearray(self.message.data)
- except:
- return
-
- def get_headers(self):
- """Receive the headers."""
- try:
- return self.message.headers
- except:
- return
-
- def get_sequence_number(self):
- """Get message sequence number."""
- try:
- return self.message.metadata.sequence.stream
- except:
- return
-
-
-def random_bytes(amount: int) -> str:
- lst = [random.choice("0123456789abcdef") for n in range(amount)]
- s = "".join(lst)
- return s
-
-
-class MemphisError(Exception):
- def __init__(self, message):
- message = message.replace("nats", "memphis")
- message = message.replace("NATS", "memphis")
- message = message.replace("Nats", "memphis")
- message = message.replace("NatsError", "MemphisError")
- self.message = message
- if message.startswith("memphis:"):
- super().__init__(self.message)
- else:
- super().__init__("memphis: " + self.message)
-
-
-class MemphisConnectError(MemphisError):
- pass
-
-
-class MemphisSchemaError(MemphisError):
- pass
-
-
-class MemphisHeaderError(MemphisError):
- pass
diff --git a/memphis/message.py b/memphis/message.py
new file mode 100644
index 0000000..3c8f30d
--- /dev/null
+++ b/memphis/message.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import json
+
+from memphis.exceptions import MemphisConnectError
+
+
+class Message:
+ def __init__(self, message, connection, cg_name):
+ self.message = message
+ self.connection = connection
+ self.cg_name = cg_name
+
+ async def ack(self):
+ """Ack a message is done processing."""
+ try:
+ await self.message.ack()
+ except Exception as e:
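+            # Messages re-sent from the dead-letter station carry $memphis_pm_*
+            # headers; those are acked via the $memphis_pm_acks subject instead
+            # of a regular broker ack.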
+ if (
+ "$memphis_pm_id" in self.message.headers
+ and "$memphis_pm_sequence" in self.message.headers
+ ):
+ try:
+ msg = {
+ "id": self.message.headers["$memphis_pm_id"],
+ "sequence": self.message.headers["$memphis_pm_sequence"],
+ }
+ msgToAck = json.dumps(msg).encode("utf-8")
+ await self.connection.broker_manager.publish(
+ "$memphis_pm_acks", msgToAck
+ )
+ except Exception as er:
+ raise MemphisConnectError(str(er)) from er
+ else:
+ raise MemphisConnectError(str(e)) from e
+ return
+
+ def get_data(self):
+ """Receive the message."""
+ try:
+ return bytearray(self.message.data)
+        except Exception:
+ return
+
+ def get_headers(self):
+ """Receive the headers."""
+ try:
+ return self.message.headers
+        except Exception:
+ return
+
+ def get_sequence_number(self):
+ """Get message sequence number."""
+ try:
+ return self.message.metadata.sequence.stream
+        except Exception:
+ return
diff --git a/memphis/producer.py b/memphis/producer.py
new file mode 100644
index 0000000..1c8eded
--- /dev/null
+++ b/memphis/producer.py
@@ -0,0 +1,298 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import time
+from typing import Union
+
+import graphql
+from graphql import parse as parse_graphql
+from graphql import validate as validate_graphql
+from jsonschema import validate
+from memphis.exceptions import MemphisError, MemphisSchemaError
+from memphis.headers import Headers
+from memphis.utils import get_internal_name
+
+
+class Producer:
+ def __init__(
+ self, connection, producer_name: str, station_name: str, real_name: str
+ ):
+ self.connection = connection
+ self.producer_name = producer_name.lower()
+ self.station_name = station_name
+ self.internal_station_name = get_internal_name(self.station_name)
+ self.loop = asyncio.get_running_loop()
+ self.real_name = real_name
+
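+    # Dispatches the message to the validator matching the station's active
+    # schema type; schemaless stations accept bytearray or dict (dicts are
+    # serialized to JSON).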
+ async def validate_msg(self, message):
+ if self.connection.schema_updates_data[self.internal_station_name] != {}:
+ schema_type = self.connection.schema_updates_data[
+ self.internal_station_name
+ ]["type"]
+ if schema_type == "protobuf":
+ message = self.validate_protobuf(message)
+ return message
+ elif schema_type == "json":
+ message = self.validate_json_schema(message)
+ return message
+ elif schema_type == "graphql":
+ message = self.validate_graphql(message)
+ return message
+ elif not isinstance(message, bytearray) and not isinstance(message, dict):
+ raise MemphisSchemaError("Unsupported message type")
+ else:
+ if isinstance(message, dict):
+ message = bytearray(json.dumps(message).encode("utf-8"))
+ return message
+
+ def validate_protobuf(self, message):
+ proto_msg = self.connection.proto_msgs[self.internal_station_name]
+ msgToSend = ""
+ try:
+ if isinstance(message, bytearray):
+ msgToSend = bytes(message)
+ try:
+ proto_msg.ParseFromString(msgToSend)
+ proto_msg.SerializeToString()
+ msgToSend = msgToSend.decode("utf-8")
+ except Exception as e:
+ if "parsing message" in str(e):
+ e = "Invalid message format, expecting protobuf"
+ raise MemphisSchemaError(str(e))
+ return message
+ elif hasattr(message, "SerializeToString"):
+ msgToSend = message.SerializeToString()
+ proto_msg.ParseFromString(msgToSend)
+ proto_msg.SerializeToString()
+ return msgToSend
+
+ else:
+ raise MemphisSchemaError("Unsupported message type")
+
+ except Exception as e:
+ raise MemphisSchemaError("Schema validation has failed: " + str(e))
+
+ def validate_json_schema(self, message):
+ try:
+ if isinstance(message, bytearray):
+ try:
+ message_obj = json.loads(message)
+ except Exception as e:
+                    raise Exception("Expecting JSON format: " + str(e))
+ elif isinstance(message, dict):
+ message_obj = message
+ message = bytearray(json.dumps(message_obj).encode("utf-8"))
+ else:
+ raise Exception("Unsupported message type")
+
+ validate(
+ instance=message_obj,
+ schema=self.connection.json_schemas[self.internal_station_name],
+ )
+ return message
+ except Exception as e:
+ raise MemphisSchemaError("Schema validation has failed: " + str(e))
+
+ def validate_graphql(self, message):
+ try:
+ if isinstance(message, bytearray):
+ msg = message.decode("utf-8")
+ msg = parse_graphql(msg)
+ elif isinstance(message, str):
+ msg = parse_graphql(message)
+ message = message.encode("utf-8")
+ elif isinstance(message, graphql.language.ast.DocumentNode):
+ msg = message
+ message = str(msg.loc.source.body)
+ message = message.encode("utf-8")
+ else:
+                raise MemphisSchemaError("Unsupported message type")
+ validate_res = validate_graphql(
+ schema=self.connection.graphql_schemas[self.internal_station_name],
+ document_ast=msg,
+ )
+ if len(validate_res) > 0:
+                raise MemphisSchemaError(str(validate_res))
+ return message
+ except Exception as e:
+ if "Syntax Error" in str(e):
+ e = "Invalid message format, expected GraphQL"
+            raise MemphisSchemaError("Schema validation has failed: " + str(e))
+
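+    # Dead-letter message ids follow the "<station>~<producer>~0~<unix_time>" convention.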
+ def get_dls_msg_id(self, station_name: str, producer_name: str, unix_time: str):
+ return station_name + "~" + producer_name + "~0~" + unix_time
+
+ async def produce(
+ self,
+ message,
+ ack_wait_sec: int = 15,
+ headers: Union[Headers, None] = None,
+ async_produce: bool = False,
+ msg_id: Union[str, None] = None,
+ ):
+ """Produces a message into a station.
+ Args:
+ message (bytearray/dict): message to send into the station - bytearray/protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
+ ack_wait_sec (int, optional): max time in seconds to wait for an ack from memphis. Defaults to 15.
+ headers (dict, optional): Message headers, defaults to {}.
+ async_produce (boolean, optional): produce operation won't wait for broker acknowledgement
+ msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotency
+ Raises:
+            MemphisError: if schema validation fails or the message cannot be produced.
+ """
+ try:
+ message = await self.validate_msg(message)
+
+ memphis_headers = {
+ "$memphis_producedBy": self.producer_name,
+ "$memphis_connectionId": self.connection.connection_id,
+ }
+
+ if msg_id is not None:
+ memphis_headers["msg-id"] = msg_id
+
+ if headers is not None:
+ headers = headers.headers
+ headers.update(memphis_headers)
+ else:
+ headers = memphis_headers
+
+ if async_produce:
+ try:
+ self.loop.create_task(
+ self.connection.broker_connection.publish(
+ self.internal_station_name + ".final",
+ message,
+ timeout=ack_wait_sec,
+ headers=headers,
+ )
+ )
+ await asyncio.sleep(1)
+ except Exception as e:
+ raise MemphisError(e)
+ else:
+ await self.connection.broker_connection.publish(
+ self.internal_station_name + ".final",
+ message,
+ timeout=ack_wait_sec,
+ headers=headers,
+ )
+ except Exception as e:
+ if hasattr(e, "status_code") and e.status_code == "503":
+ raise MemphisError(
+                    "Produce operation has failed, please check whether the station/producer still exist"
+ )
+ else:
+ if "Schema validation has failed" in str(
+ e
+ ) or "Unsupported message type" in str(e):
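+                    # Schema-validation failures may be forwarded (hex-encoded)
+                    # to the station's dead-letter station (DLS) when enabled,
+                    # and can trigger a notification.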
+ msgToSend = ""
+ if isinstance(message, bytearray):
+ msgToSend = str(message, "utf-8")
+ elif hasattr(message, "SerializeToString"):
+ msgToSend = message.SerializeToString().decode("utf-8")
+ if self.connection.station_schemaverse_to_dls[
+ self.internal_station_name
+ ]:
+ unix_time = int(time.time())
+ id = self.get_dls_msg_id(
+ self.internal_station_name,
+ self.producer_name,
+ str(unix_time),
+ )
+
+ memphis_headers = {
+ "$memphis_producedBy": self.producer_name,
+ "$memphis_connectionId": self.connection.connection_id,
+ }
+
+ if headers != {}:
+ headers = headers.headers
+ headers.update(memphis_headers)
+ else:
+ headers = memphis_headers
+
+ msgToSendEncoded = msgToSend.encode("utf-8")
+ msgHex = msgToSendEncoded.hex()
+ buf = {
+ "_id": id,
+ "station_name": self.internal_station_name,
+ "producer": {
+ "name": self.producer_name,
+ "connection_id": self.connection.connection_id,
+ },
+ "creation_unix": unix_time,
+ "message": {
+ "data": msgHex,
+ "headers": headers,
+ },
+ }
+ buf = json.dumps(buf).encode("utf-8")
+ await self.connection.broker_connection.publish(
+ "$memphis-"
+ + self.internal_station_name
+ + "-dls.schema."
+ + id,
+ buf,
+ )
+ if self.connection.cluster_configurations.get(
+ "send_notification"
+ ):
+ await self.connection.send_notification(
+ "Schema validation has failed",
+ "Station: "
+ + self.station_name
+ + "\nProducer: "
+ + self.producer_name
+                            + "\nError: "
+ + str(e),
+ msgToSend,
+ schemaVFailAlertType,
+ )
+ raise MemphisError(str(e)) from e
+
+ async def destroy(self):
+ """Destroy the producer."""
+ try:
+ destroyProducerReq = {
+ "name": self.producer_name,
+ "station_name": self.station_name,
+ "username": self.connection.username,
+ }
+
+ producer_name = json.dumps(destroyProducerReq).encode("utf-8")
+ res = await self.connection.broker_manager.request(
+ "$memphis_producer_destructions", producer_name, timeout=5
+ )
+ error = res.data.decode("utf-8")
+            if error != "" and "not exist" not in error:
+ raise Exception(error)
+
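+            # Decrement the station's producer count; once it hits zero, tear
+            # down the schema-update subscription and task for that station.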
+ internal_station_name = get_internal_name(self.station_name)
+ producer_number = (
+ self.connection.producers_per_station.get(internal_station_name) - 1
+ )
+ self.connection.producers_per_station[
+ internal_station_name
+ ] = producer_number
+
+ if producer_number == 0:
+ sub = self.connection.schema_updates_subs.get(internal_station_name)
+ task = self.connection.schema_tasks.get(internal_station_name)
+ if internal_station_name in self.connection.schema_updates_data:
+ del self.connection.schema_updates_data[internal_station_name]
+ if internal_station_name in self.connection.schema_updates_subs:
+ del self.connection.schema_updates_subs[internal_station_name]
+ if internal_station_name in self.connection.schema_tasks:
+ del self.connection.schema_tasks[internal_station_name]
+ if task is not None:
+ task.cancel()
+ if sub is not None:
+ await sub.unsubscribe()
+
+ map_key = internal_station_name + "_" + self.real_name
+ del self.connection.producers_map[map_key]
+
+ except Exception as e:
+ raise Exception(e)
diff --git a/memphis/station.py b/memphis/station.py
new file mode 100644
index 0000000..05aac6d
--- /dev/null
+++ b/memphis/station.py
@@ -0,0 +1,53 @@
+import json
+
+from memphis.exceptions import MemphisError
+from memphis.utils import get_internal_name
+
+
+class Station:
+ def __init__(self, connection, name: str):
+ self.connection = connection
+ self.name = name.lower()
+
+ async def destroy(self):
+ """Destroy the station."""
+ try:
+ nameReq = {"station_name": self.name, "username": self.connection.username}
+ station_name = json.dumps(nameReq, indent=2).encode("utf-8")
+ res = await self.connection.broker_manager.request(
+ "$memphis_station_destructions", station_name, timeout=5
+ )
+ error = res.data.decode("utf-8")
+            if error != "" and "not exist" not in error:
+ raise MemphisError(error)
+
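+            # Clean up all locally cached state (schema data, subscriptions,
+            # tasks, producers, consumers) that references the destroyed station.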
+ internal_station_name = get_internal_name(self.name)
+ sub = self.connection.schema_updates_subs.get(internal_station_name)
+ task = self.connection.schema_tasks.get(internal_station_name)
+ if internal_station_name in self.connection.schema_updates_data:
+ del self.connection.schema_updates_data[internal_station_name]
+ if internal_station_name in self.connection.schema_updates_subs:
+ del self.connection.schema_updates_subs[internal_station_name]
+ if internal_station_name in self.connection.producers_per_station:
+ del self.connection.producers_per_station[internal_station_name]
+ if internal_station_name in self.connection.schema_tasks:
+ del self.connection.schema_tasks[internal_station_name]
+ if task is not None:
+ task.cancel()
+ if sub is not None:
+ await sub.unsubscribe()
+
+ self.connection.producers_map = {
+ k: v
+ for k, v in self.connection.producers_map.items()
+ if self.name not in k
+ }
+
+ self.connection.consumers_map = {
+ k: v
+ for k, v in self.connection.consumers_map.items()
+ if self.name not in k
+ }
+
+ except Exception as e:
+ raise MemphisError(str(e)) from e
diff --git a/memphis/utils.py b/memphis/utils.py
new file mode 100644
index 0000000..10512e9
--- /dev/null
+++ b/memphis/utils.py
@@ -0,0 +1,32 @@
+import random
+from threading import Timer
+from typing import Callable
+
+
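+# Re-arms a threading.Timer every `sec` seconds and invokes `func` on each
+# tick until cancel() is called.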
+class set_interval:
+ def __init__(self, func: Callable, sec: int):
+ def func_wrapper():
+ self.t = Timer(sec, func_wrapper)
+ self.t.start()
+ func()
+
+ self.t = Timer(sec, func_wrapper)
+ self.t.start()
+
+ def cancel(self):
+ self.t.cancel()
+
+
+def default_error_handler(e):
+ print("ping exception raised", e)
+
+
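+# Internal names are lowercased with '.' replaced by '#', since the broker
+# treats '.' as a subject-hierarchy separator.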
+def get_internal_name(name: str) -> str:
+ name = name.lower()
+ return name.replace(".", "#")
+
+
+def random_bytes(amount: int) -> str:
+ lst = [random.choice("0123456789abcdef") for n in range(amount)]
+ s = "".join(lst)
+ return s
From b107f8fa7d49aeaeb28effc47fd45c875719284d Mon Sep 17 00:00:00 2001
From: Shay Bratslavsky
Date: Thu, 2 Mar 2023 13:00:27 +0200
Subject: [PATCH 05/12] fix imports (#121)
---
memphis/consumer.py | 3 ++-
memphis/memphis.py | 9 ++++-----
memphis/producer.py | 2 ++
3 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/memphis/consumer.py b/memphis/consumer.py
index aad4625..598155c 100644
--- a/memphis/consumer.py
+++ b/memphis/consumer.py
@@ -5,6 +5,7 @@
from memphis.exceptions import MemphisError
from memphis.utils import default_error_handler, get_internal_name
+from memphis.message import Message
class Consumer:
@@ -157,7 +158,7 @@ async def __ping_consumer(self, callback):
)
except Exception as e:
- callback(e)
+ callback(MemphisError(str(e)))
async def destroy(self):
"""Destroy the consumer."""
diff --git a/memphis/memphis.py b/memphis/memphis.py
index 0d15ccd..cc0d91b 100644
--- a/memphis/memphis.py
+++ b/memphis/memphis.py
@@ -33,10 +33,9 @@
from memphis.exceptions import MemphisConnectError, MemphisError, MemphisHeaderError
from memphis.headers import Headers
from memphis.producer import Producer
+from memphis.station import Station
from memphis.types import Retention, Storage
-from memphis.utils import get_internal_name
-
-schemaVFailAlertType = "schema_validation_fail_alert"
+from memphis.utils import get_internal_name, random_bytes
class Memphis:
@@ -198,9 +197,9 @@ async def station(
createStationReq = {
"name": name,
- "retention_type": retention_type,
+ "retention_type": retention_type.value,
"retention_value": retention_value,
- "storage_type": storage_type,
+ "storage_type": storage_type.value,
"replicas": replicas,
"idempotency_window_in_ms": idempotency_window_ms,
"schema_name": schema_name,
diff --git a/memphis/producer.py b/memphis/producer.py
index 1c8eded..204c6d2 100644
--- a/memphis/producer.py
+++ b/memphis/producer.py
@@ -13,6 +13,8 @@
from memphis.headers import Headers
from memphis.utils import get_internal_name
+schemaVFailAlertType = "schema_validation_fail_alert"
+
class Producer:
def __init__(
From 31a4ffa6a46e5e0ea65de3e670c037c9af994cc3 Mon Sep 17 00:00:00 2001
From: Bruno Bandeira
Date: Sun, 5 Mar 2023 14:16:30 +0000
Subject: [PATCH 06/12] adding explanation about retention values (#122)
---
README.md | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/README.md b/README.md
index 28f5e69..00f1a62 100644
--- a/README.md
+++ b/README.md
@@ -141,6 +141,17 @@ memphis.types.Retention.BYTES
Means that after max amount of saved bytes (set in retention value), the oldest messages will be deleted
+
+### Retention Values
+
+The `retention values` set the thresholds for the `retention types` described above; how a value is interpreted depends on the chosen retention type.
+
+All retention values are of type `int` but with different representations as follows:
+
+`memphis.types.Retention.MAX_MESSAGE_AGE_SECONDS` is expressed **in seconds**, `memphis.types.Retention.MESSAGES` in a **number of messages**, and `memphis.types.Retention.BYTES` in a **number of bytes**.
+
+Once the configured limit is reached, the oldest messages are deleted.
+
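+A minimal sketch of pairing a type with its value (assuming a connected `memphis` instance and the `memphis.station()` creation call):
+
+```python
+station = await memphis.station(
+    name="test_station_py",
+    retention_type=memphis.types.Retention.MAX_MESSAGE_AGE_SECONDS,
+    retention_value=3600,  # with this type: delete messages older than one hour
+)
+```
+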
### Storage types
Memphis currently supports the following types of messages storage:
From 9891a0784b04f7e6decd01211e05214961276f4f Mon Sep 17 00:00:00 2001
From: shohamroditimemphis
<108217318+shohamroditimemphis@users.noreply.github.com>
Date: Mon, 6 Mar 2023 14:59:42 +0200
Subject: [PATCH 07/12] delete station (#124)
* delete station
* fix issues
---
memphis/memphis.py | 36 +++++++++++++++++++++++++++++++-----
1 file changed, 31 insertions(+), 5 deletions(-)
diff --git a/memphis/memphis.py b/memphis/memphis.py
index cc0d91b..c1235af 100644
--- a/memphis/memphis.py
+++ b/memphis/memphis.py
@@ -55,7 +55,7 @@ def __init__(self):
self.producers_map = dict()
self.consumers_map = dict()
- async def get_msgs_update_configurations(self, iterable: Iterable):
+ async def get_msgs_sdk_clients_updates(self, iterable: Iterable):
try:
async for msg in iterable:
message = msg.data.decode("utf-8")
@@ -66,18 +66,21 @@ async def get_msgs_update_configurations(self, iterable: Iterable):
self.station_schemaverse_to_dls[data["station_name"]] = data[
"update"
]
+ elif data["type"] == "remove_station":
+ self.unset_cached_producer_station(data['station_name'])
+ self.unset_cached_consumer_station(data['station_name'])
except Exception as err:
raise MemphisError(err)
- async def configurations_listener(self):
+ async def sdk_client_updates_listener(self):
try:
sub = await self.broker_manager.subscribe(
- "$memphis_sdk_configurations_updates"
+ "$memphis_sdk_clients_updates"
)
self.update_configurations_sub = sub
loop = asyncio.get_event_loop()
task = loop.create_task(
- self.get_msgs_update_configurations(
+ self.get_msgs_sdk_clients_updates(
self.update_configurations_sub.messages
)
)
@@ -155,7 +158,7 @@ async def connect(
name=self.connection_id + "::" + self.username,
)
- await self.configurations_listener()
+ await self.sdk_client_updates_listener()
self.broker_connection = self.broker_manager.jetstream()
self.is_connection_active = True
except Exception as e:
@@ -661,3 +664,26 @@ async def fetch_messages(
def is_connected(self):
return self.broker_manager.is_connected
+
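+    # Invoked on "remove_station" broadcasts so locally cached producers and
+    # consumers of a deleted station are dropped.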
+ def unset_cached_producer_station(self, station_name):
+ try:
+ internal_station_name = get_internal_name(station_name)
+ for key in list(self.producers_map):
+ producer = self.producers_map[key]
+ if producer.internal_station_name == internal_station_name:
+ del self.producers_map[key]
+ except Exception as e:
+ raise e
+
+
+ def unset_cached_consumer_station(self, station_name):
+ try:
+ internal_station_name = get_internal_name(station_name)
+ for key in list(self.consumers_map):
+ consumer = self.consumers_map[key]
+ consumer_station_name_internal = get_internal_name(consumer.station_name)
+ if consumer_station_name_internal == internal_station_name:
+ del self.consumers_map[key]
+ except Exception as e:
+ raise e
+
From 61778ef2054b428e5c107d72c8180d61a3a722e4 Mon Sep 17 00:00:00 2001
From: idanasulinStrech
Date: Tue, 14 Mar 2023 14:08:56 +0200
Subject: [PATCH 08/12] add comment
---
memphis/producer.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/memphis/producer.py b/memphis/producer.py
index 204c6d2..5c2e4e1 100644
--- a/memphis/producer.py
+++ b/memphis/producer.py
@@ -170,7 +170,7 @@ async def produce(
headers=headers,
)
)
- await asyncio.sleep(1)
+ await asyncio.sleep(1) # TODO - check why we need sleep in here
except Exception as e:
raise MemphisError(e)
else:
From d30b6d46e0ceea50311a9049ea83b6d55e9b0461 Mon Sep 17 00:00:00 2001
From: RJ Nowling
Date: Tue, 14 Mar 2023 08:23:47 -0500
Subject: [PATCH 09/12] Fix formatting error (#125)
A backtick was missing in one of the Markdown code blocks, causing the code snippet to fail to render.
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 00f1a62..3dffa9e 100644
--- a/README.md
+++ b/README.md
@@ -339,7 +339,7 @@ await message.ack()
### Get headers
Get headers per message
-``python
+```python
headers = message.get_headers()
```
From 2c95431ab27ba934de638438f71244c170af4d8d Mon Sep 17 00:00:00 2001
From: idonaaman123 <127736311+idonaaman123@users.noreply.github.com>
Date: Wed, 15 Mar 2023 16:30:32 +0200
Subject: [PATCH 10/12] change MAINTAINERS.md file (#126)
Co-authored-by: ido
---
MAINTAINERS.md | 4 ++--
README.md | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/MAINTAINERS.md b/MAINTAINERS.md
index 254758f..173aa67 100644
--- a/MAINTAINERS.md
+++ b/MAINTAINERS.md
@@ -5,5 +5,5 @@
- Yaniv Ben Hemo [@yanivbh1](https://github.com/yanivbh1)
- Sveta Gimpelson [@SvetaMemphis](https://github.com/SvetaMemphis)
- Shay Bratslavsky [@shay23b](https://github.com/shay23b)
- - Or Grinberg [@orgrMmphs](https://github.com/orgrMmphs)
- - Shoham Roditi [@shohamroditimemphis](https://github.com/shohamroditimemphis)
\ No newline at end of file
+ - Shoham Roditi [@shohamroditimemphis](https://github.com/shohamroditimemphis)
+ - Ido Naaman [@idonaaman123](https://github.com/idonaaman123)
\ No newline at end of file
diff --git a/README.md b/README.md
index 3dffa9e..12bc974 100644
--- a/README.md
+++ b/README.md
@@ -224,9 +224,9 @@ await memphis.produce(station_name='test_station_py', producer_name='prod_py',
```
-Creating a producer first
+Using a previously created producer
```python
-await prod.produce(
+await producer.produce(
message='bytearray/protobuf class/dict/string/graphql.language.ast.DocumentNode', # bytearray / protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
ack_wait_sec=15) # defaults to 15
```
From 6eef73228c715965df2d8f1329c2b6c425e7bba9 Mon Sep 17 00:00:00 2001
From: Idan Asulin <74712806+idanasulinmemphis@users.noreply.github.com>
Date: Thu, 16 Mar 2023 11:10:16 +0200
Subject: [PATCH 11/12] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 12bc974..67086b7 100644
--- a/README.md
+++ b/README.md
@@ -69,7 +69,7 @@ async def main():
await memphis.connect(
host="",
username="",
- connection_token="",
+        connection_token="", # received when creating an application-type user
port="", # defaults to 6666
reconnect=True, # defaults to True
max_reconnect=3, # defaults to 3
From 1b46f51eab985eb923394be18a1009c63408d8a1 Mon Sep 17 00:00:00 2001
From: idanasulinStrech
Date: Thu, 16 Mar 2023 11:53:13 +0200
Subject: [PATCH 12/12] update version
---
setup.py | 4 ++--
version.conf | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/setup.py b/setup.py
index d53182d..5b15795 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
setup(
name="memphis-py",
packages=["memphis"],
- version="0.3.2",
+ version="0.3.3",
license="Apache-2.0",
description="A powerful messaging platform for modern developers",
long_description=long_description,
@@ -17,7 +17,7 @@
author="Memphis.dev",
author_email="team@memphis.dev",
url="https://github.com/memphisdev/memphis.py",
- download_url="https://github.com/memphisdev/memphis.py/archive/refs/tags/v0.3.2.tar.gz",
+ download_url="https://github.com/memphisdev/memphis.py/archive/refs/tags/v0.3.3.tar.gz",
keywords=["message broker", "devtool", "streaming", "data"],
install_requires=["asyncio", "nats-py", "protobuf", "jsonschema", "graphql-core"],
classifiers=[
diff --git a/version.conf b/version.conf
index 9fc80f9..87a0871 100644
--- a/version.conf
+++ b/version.conf
@@ -1 +1 @@
-0.3.2
\ No newline at end of file
+0.3.3
\ No newline at end of file