diff --git a/MAINTAINERS.md b/MAINTAINERS.md
index 254758f..173aa67 100644
--- a/MAINTAINERS.md
+++ b/MAINTAINERS.md
@@ -5,5 +5,5 @@
- Yaniv Ben Hemo [@yanivbh1](https://github.com/yanivbh1)
- Sveta Gimpelson [@SvetaMemphis](https://github.com/SvetaMemphis)
- Shay Bratslavsky [@shay23b](https://github.com/shay23b)
- - Or Grinberg [@orgrMmphs](https://github.com/orgrMmphs)
- - Shoham Roditi [@shohamroditimemphis](https://github.com/shohamroditimemphis)
\ No newline at end of file
+ - Shoham Roditi [@shohamroditimemphis](https://github.com/shohamroditimemphis)
+ - Ido Naaman [@idonaaman123](https://github.com/idonaaman123)
\ No newline at end of file
diff --git a/README.md b/README.md
index 8030f14..67086b7 100644
--- a/README.md
+++ b/README.md
@@ -38,12 +38,12 @@
-**[Memphis](https://memphis.dev)** is a next-generation message broker.
+**[Memphis](https://memphis.dev)** is a next-generation alternative to traditional message brokers.
A simple, robust, and durable cloud-native message broker wrapped with
-an entire ecosystem that enables fast and reliable development of next-generation event-driven use cases.
-Memphis enables building modern applications that require large volumes of streamed and enriched data,
-modern protocols, zero ops, rapid development, extreme cost reduction,
-and a significantly lower amount of dev time for data-oriented developers and data engineers.
+an entire ecosystem that enables cost-effective, fast, and reliable development of modern queue-based use cases.
+Memphis enables building modern queue-based applications that require
+large volumes of streamed and enriched data, modern protocols, zero ops, rapid development,
+extreme cost reduction, and significantly less dev time for data-oriented developers and data engineers.
## Installation
@@ -55,7 +55,7 @@ $ pip3 install memphis-py
```python
from memphis import Memphis, Headers
-from memphis import retention_types, storage_types
+from memphis.types import Retention, Storage
```
### Connecting to Memphis
@@ -69,7 +69,7 @@ async def main():
await memphis.connect(
host="",
username="",
- connection_token="",
+ connection_token="", # received when creating an application-type user
port="", # defaults to 6666
reconnect=True, # defaults to True
max_reconnect=3, # defaults to 3
@@ -108,9 +108,9 @@ _If a station already exists nothing happens, the new configuration will not be
station = memphis.station(
name="",
schema_name="",
- retention_type=retention_types.MAX_MESSAGE_AGE_SECONDS, # MAX_MESSAGE_AGE_SECONDS/MESSAGES/BYTES. Defaults to MAX_MESSAGE_AGE_SECONDS
+ retention_type=Retention.MAX_MESSAGE_AGE_SECONDS, # MAX_MESSAGE_AGE_SECONDS/MESSAGES/BYTES. Defaults to MAX_MESSAGE_AGE_SECONDS
retention_value=604800, # defaults to 604800
- storage_type=storage_types.DISK, # storage_types.DISK/storage_types.MEMORY. Defaults to DISK
+ storage_type=Storage.DISK, # Storage.DISK/Storage.MEMORY. Defaults to DISK
replicas=1, # defaults to 1
idempotency_window_ms=120000, # defaults to 2 minutes
send_poison_msg_to_dls=True, # defaults to true
@@ -124,35 +124,46 @@ station = memphis.station(
Memphis currently supports the following types of retention:
```python
-memphis.retention_types.MAX_MESSAGE_AGE_SECONDS
+memphis.types.Retention.MAX_MESSAGE_AGE_SECONDS
```
Means that every message persists for the value set in retention value field (in seconds)
```python
-memphis.retention_types.MESSAGES
+memphis.types.Retention.MESSAGES
```
Means that after max amount of saved messages (set in retention value), the oldest messages will be deleted
```python
-memphis.retention_types.BYTES
+memphis.types.Retention.BYTES
```
Means that after max amount of saved bytes (set in retention value), the oldest messages will be deleted
+
+### Retention Values
+
+The `retention values` correspond to the `retention types` listed above: how a value is interpreted depends on the retention type chosen.
+
+All retention values are of type `int`, but their units differ:
+
+`memphis.types.Retention.MAX_MESSAGE_AGE_SECONDS` is measured **in seconds**, `memphis.types.Retention.MESSAGES` in a **number of messages**, and `memphis.types.Retention.BYTES` in a **number of bytes**.
+
+Once a limit is reached, the oldest messages are deleted.
+
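+For example, a minimal sketch of a station that keeps each message for one day (assuming an open `memphis` connection; the station name is hypothetical):
+
+```python
+station = await memphis.station(
+ name="logs", # hypothetical station name
+ retention_type=Retention.MAX_MESSAGE_AGE_SECONDS,
+ retention_value=86400 # one day, in seconds
+)
+```
+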
### Storage types
Memphis currently supports the following types of messages storage:
```python
-memphis.storage_types.DISK
+memphis.types.Storage.DISK
```
Means that messages persist on disk
```python
-memphis.storage_types.MEMORY
+memphis.types.Storage.MEMORY
```
Means that messages persist on the main memory
@@ -213,9 +224,9 @@ await memphis.produce(station_name='test_station_py', producer_name='prod_py',
```
-Creating a producer first
+Using a previously created producer
```python
-await prod.produce(
+await producer.produce(
message='bytearray/protobuf class/dict/string/graphql.language.ast.DocumentNode', # bytearray / protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
ack_wait_sec=15) # defaults to 15
```
@@ -295,6 +306,28 @@ async def msg_handler(msgs, error, context):
consumer.consume(msg_handler)
```
+### Fetch a single batch of messages
+```python
+msgs = await memphis.fetch_messages(
+ station_name="",
+ consumer_name="",
+ consumer_group="", # defaults to the consumer name
+ batch_size=10, # defaults to 10
+ batch_max_time_to_wait_ms=5000, # defaults to 5000
+ max_ack_time_ms=30000, # defaults to 30000
+ max_msg_deliveries=10, # defaults to 10
+ generate_random_suffix=False, # defaults to False
+ start_consume_from_sequence=1, # start consuming from a specific sequence. defaults to 1
+ last_messages=-1 # consume the last N messages, defaults to -1 (all messages in the station)
+)
+```
+
+### Fetch a single batch of messages after creating a consumer
+```python
+msgs = await consumer.fetch(batch_size=10) # defaults to 10
+```
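+
+A typical way to handle a fetched batch (illustrative) is to read and acknowledge each message:
+```python
+for msg in msgs:
+ print(msg.get_data())
+ await msg.ack()
+```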
+
+
### Acknowledge a message
Acknowledge a message indicates the Memphis server to not re-send the same message again to the same consumer / consumers group
@@ -306,7 +339,7 @@ await message.ack()
### Get headers
Get headers per message
-``python
+```python
headers = message.get_headers()
```
@@ -328,4 +361,4 @@ consumer.destroy()
```python
memphis.is_connected()
-```
\ No newline at end of file
+```
diff --git a/examples/consumer.py b/examples/consumer.py
index 4febb93..dfebbbe 100644
--- a/examples/consumer.py
+++ b/examples/consumer.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import asyncio
from memphis import Memphis, MemphisConnectError, MemphisError, MemphisHeaderError
diff --git a/examples/producer.py b/examples/producer.py
index e5a79e0..01dd381 100644
--- a/examples/producer.py
+++ b/examples/producer.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import asyncio
from memphis import (
diff --git a/memphis/__init__.py b/memphis/__init__.py
index ff8bd5f..aa44e30 100644
--- a/memphis/__init__.py
+++ b/memphis/__init__.py
@@ -11,14 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import memphis.retention_types
-import memphis.storage_types
-from memphis.memphis import (
- Headers,
- Memphis,
+from memphis.exceptions import (
MemphisConnectError,
MemphisError,
MemphisHeaderError,
MemphisSchemaError,
)
+from memphis.headers import Headers
+from memphis.memphis import Memphis
diff --git a/memphis/consumer.py b/memphis/consumer.py
new file mode 100644
index 0000000..598155c
--- /dev/null
+++ b/memphis/consumer.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import asyncio
+import json
+
+from memphis.exceptions import MemphisError
+from memphis.utils import default_error_handler, get_internal_name
+from memphis.message import Message
+
+
+class Consumer:
+ def __init__(
+ self,
+ connection,
+ station_name: str,
+ consumer_name,
+ consumer_group,
+ pull_interval_ms: int,
+ batch_size: int,
+ batch_max_time_to_wait_ms: int,
+ max_ack_time_ms: int,
+ max_msg_deliveries: int = 10,
+ error_callback=None,
+ start_consume_from_sequence: int = 1,
+ last_messages: int = -1,
+ ):
+ self.connection = connection
+ self.station_name = station_name.lower()
+ self.consumer_name = consumer_name.lower()
+ self.consumer_group = consumer_group.lower()
+ self.pull_interval_ms = pull_interval_ms
+ self.batch_size = batch_size
+ self.batch_max_time_to_wait_ms = batch_max_time_to_wait_ms
+ self.max_ack_time_ms = max_ack_time_ms
+ self.max_msg_deliveries = max_msg_deliveries
+ self.ping_consumer_interval_ms = 30000
+ if error_callback is None:
+ error_callback = default_error_handler
+ self.t_ping = asyncio.create_task(self.__ping_consumer(error_callback))
+ self.start_consume_from_sequence = start_consume_from_sequence
+ self.last_messages = last_messages
+ self.context = {}
+ self.dls_messages = []
+ self.dls_current_index = 0
+ self.dls_callback_func = None
+ self.t_consume = None # set by consume(); destroy() checks it before cancelling
+ self.t_dls = asyncio.create_task(self.__consume_dls())
+
+ def set_context(self, context):
+ """Set a context (dict) that will be passed to each message handler call."""
+ self.context = context
+
+ def consume(self, callback):
+ """Consume events."""
+ self.dls_callback_func = callback
+ self.t_consume = asyncio.create_task(self.__consume(callback))
+
+ async def __consume(self, callback):
+ subject = get_internal_name(self.station_name)
+ consumer_group = get_internal_name(self.consumer_group)
+ self.psub = await self.connection.broker_connection.pull_subscribe(
+ subject + ".final", durable=consumer_group
+ )
+ while True:
+ if self.connection.is_connection_active and self.pull_interval_ms:
+ try:
+ memphis_messages = []
+ msgs = await self.psub.fetch(self.batch_size)
+ for msg in msgs:
+ memphis_messages.append(
+ Message(msg, self.connection, self.consumer_group)
+ )
+ await callback(memphis_messages, None, self.context)
+ await asyncio.sleep(self.pull_interval_ms / 1000)
+
+ except asyncio.TimeoutError:
+ await callback(
+ [], MemphisError("Memphis: TimeoutError"), self.context
+ )
+ continue
+ except Exception as e:
+ if self.connection.is_connection_active:
+ raise MemphisError(str(e)) from e
+ else:
+ return
+ else:
+ break
+
+ async def __consume_dls(self):
+ subject = get_internal_name(self.station_name)
+ consumer_group = get_internal_name(self.consumer_group)
+ try:
+ subscription_name = "$memphis_dls_" + subject + "_" + consumer_group
+ self.consumer_dls = await self.connection.broker_manager.subscribe(
+ subscription_name, subscription_name
+ )
+ async for msg in self.consumer_dls.messages:
+ index_to_insert = self.dls_current_index
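+ # cap the in-memory DLS buffer at 10,000 messages by wrapping the insert index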
+ if index_to_insert >= 10000:
+ index_to_insert %= 10000
+ self.dls_messages.insert(
+ index_to_insert, Message(msg, self.connection, self.consumer_group)
+ )
+ self.dls_current_index += 1
+ if self.dls_callback_func is not None:
+ await self.dls_callback_func(
+ [Message(msg, self.connection, self.consumer_group)],
+ None,
+ self.context,
+ )
+ except Exception as e:
+ if self.dls_callback_func is not None:
+ await self.dls_callback_func([], MemphisError(str(e)), self.context)
+ return
+
+ async def fetch(self, batch_size: int = 10):
+ """Fetch a batch of messages."""
+ messages = []
+ if self.connection.is_connection_active:
+ try:
+ self.batch_size = batch_size
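+ # serve messages already buffered from the DLS subscription before pulling from the broker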
+ if len(self.dls_messages) > 0:
+ if len(self.dls_messages) <= batch_size:
+ messages = self.dls_messages
+ self.dls_messages = []
+ self.dls_current_index = 0
+ else:
+ messages = self.dls_messages[0:batch_size]
+ del self.dls_messages[0:batch_size]
+ self.dls_current_index -= len(messages)
+ return messages
+
+ durableName = ""
+ if self.consumer_group != "":
+ durableName = get_internal_name(self.consumer_group)
+ else:
+ durableName = get_internal_name(self.consumer_name)
+ subject = get_internal_name(self.station_name)
+ consumer_group = get_internal_name(self.consumer_group)
+ self.psub = await self.connection.broker_connection.pull_subscribe(
+ subject + ".final", durable=durableName
+ )
+ msgs = await self.psub.fetch(batch_size)
+ for msg in msgs:
+ messages.append(Message(msg, self.connection, self.consumer_group))
+ return messages
+ except Exception as e:
+ if not "timeout" in str(e):
+ raise MemphisError(str(e)) from e
+ else:
+ return messages
+
+ async def __ping_consumer(self, callback):
+ while True:
+ try:
+ await asyncio.sleep(self.ping_consumer_interval_ms / 1000)
+ consumer_group = get_internal_name(self.consumer_group)
+ await self.connection.broker_connection.consumer_info(
+ self.station_name, consumer_group, timeout=30
+ )
+
+ except Exception as e:
+ callback(MemphisError(str(e)))
+
+ async def destroy(self):
+ """Destroy the consumer."""
+ if self.t_consume is not None:
+ self.t_consume.cancel()
+ if self.t_dls is not None:
+ self.t_dls.cancel()
+ if self.t_ping is not None:
+ self.t_ping.cancel()
+ self.pull_interval_ms = None
+ try:
+ destroyConsumerReq = {
+ "name": self.consumer_name,
+ "station_name": self.station_name,
+ "username": self.connection.username,
+ }
+ consumer_name = json.dumps(destroyConsumerReq, indent=2).encode("utf-8")
+ res = await self.connection.broker_manager.request(
+ "$memphis_consumer_destructions", consumer_name, timeout=5
+ )
+ error = res.data.decode("utf-8")
+ if error != "" and not "not exist" in error:
+ raise MemphisError(error)
+ self.dls_messages.clear()
+ internal_station_name = get_internal_name(self.station_name)
+ map_key = internal_station_name + "_" + self.consumer_name.lower()
+ del self.connection.consumers_map[map_key]
+ except Exception as e:
+ raise MemphisError(str(e)) from e
diff --git a/memphis/exceptions.py b/memphis/exceptions.py
new file mode 100644
index 0000000..65c8c25
--- /dev/null
+++ b/memphis/exceptions.py
@@ -0,0 +1,23 @@
+class MemphisError(Exception):
+ def __init__(self, message):
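+ # Memphis runs on top of NATS; rewrite transport names so raised errors read as Memphis errors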
+ message = message.replace("nats", "memphis")
+ message = message.replace("NATS", "memphis")
+ message = message.replace("Nats", "memphis")
+ message = message.replace("NatsError", "MemphisError")
+ self.message = message
+ if message.startswith("memphis:"):
+ super().__init__(self.message)
+ else:
+ super().__init__("memphis: " + self.message)
+
+
+class MemphisConnectError(MemphisError):
+ pass
+
+
+class MemphisSchemaError(MemphisError):
+ pass
+
+
+class MemphisHeaderError(MemphisError):
+ pass
diff --git a/memphis/headers.py b/memphis/headers.py
new file mode 100644
index 0000000..1d25e14
--- /dev/null
+++ b/memphis/headers.py
@@ -0,0 +1,19 @@
+from memphis.exceptions import MemphisHeaderError
+
+
+class Headers:
+ def __init__(self):
+ self.headers = {}
+
+ def add(self, key, value):
+ """Add a header.
+ Args:
+ key (string): header key.
+ value (string): header value.
+ Raises:
+ MemphisHeaderError: if the key starts with $memphis.
+ """
+ if not key.startswith("$memphis"):
+ self.headers[key] = value
+ else:
+ raise MemphisHeaderError("Keys in headers should not start with $memphis")
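+
+
+# Illustrative usage (hypothetical key/value):
+# headers = Headers()
+# headers.add("content-type", "application/json")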
diff --git a/memphis/memphis.py b/memphis/memphis.py
index c8fee2a..c1235af 100644
--- a/memphis/memphis.py
+++ b/memphis/memphis.py
@@ -22,8 +22,6 @@
from typing import Callable, Iterable, Union
import graphql
-import memphis.retention_types as retention_types
-import memphis.storage_types as storage_types
import nats as broker
from google.protobuf import descriptor_pb2, descriptor_pool
from google.protobuf.message_factory import MessageFactory
@@ -31,41 +29,13 @@
from graphql import parse as parse_graphql
from graphql import validate as validate_graphql
from jsonschema import validate
-
-
-schemaVFailAlertType = "schema_validation_fail_alert"
-
-
-class set_interval:
- def __init__(self, func: Callable, sec: int):
- def func_wrapper():
- self.t = Timer(sec, func_wrapper)
- self.t.start()
- func()
-
- self.t = Timer(sec, func_wrapper)
- self.t.start()
-
- def cancel(self):
- self.t.cancel()
-
-
-class Headers:
- def __init__(self):
- self.headers = {}
-
- def add(self, key, value):
- """Add a header.
- Args:
- key (string): header key.
- value (string): header value.
- Raises:
- Exception: _description_
- """
- if not key.startswith("$memphis"):
- self.headers[key] = value
- else:
- raise MemphisHeaderError("Keys in headers should not start with $memphis")
+from memphis.consumer import Consumer
+from memphis.exceptions import MemphisConnectError, MemphisError, MemphisHeaderError
+from memphis.headers import Headers
+from memphis.producer import Producer
+from memphis.station import Station
+from memphis.types import Retention, Storage
+from memphis.utils import get_internal_name, random_bytes
class Memphis:
@@ -83,8 +53,9 @@ def __init__(self):
self.update_configurations_sub = {}
self.configuration_tasks = {}
self.producers_map = dict()
+ self.consumers_map = dict()
- async def get_msgs_update_configurations(self, iterable: Iterable):
+ async def get_msgs_sdk_clients_updates(self, iterable: Iterable):
try:
async for msg in iterable:
message = msg.data.decode("utf-8")
@@ -95,18 +66,21 @@ async def get_msgs_update_configurations(self, iterable: Iterable):
self.station_schemaverse_to_dls[data["station_name"]] = data[
"update"
]
+ elif data["type"] == "remove_station":
+ self.unset_cached_producer_station(data['station_name'])
+ self.unset_cached_consumer_station(data['station_name'])
except Exception as err:
raise MemphisError(err)
- async def configurations_listener(self):
+ async def sdk_client_updates_listener(self):
try:
sub = await self.broker_manager.subscribe(
- "$memphis_sdk_configurations_updates"
+ "$memphis_sdk_clients_updates"
)
self.update_configurations_sub = sub
loop = asyncio.get_event_loop()
task = loop.create_task(
- self.get_msgs_update_configurations(
+ self.get_msgs_sdk_clients_updates(
self.update_configurations_sub.messages
)
)
@@ -136,8 +110,8 @@ async def connect(
port (int, optional): port. Defaults to 6666.
reconnect (bool, optional): whether to do reconnect while connection is lost. Defaults to True.
max_reconnect (int, optional): The reconnect attempt. Defaults to 3.
- reconnect_interval_ms (int, optional): Interval in miliseconds between reconnect attempts. Defaults to 200.
- timeout_ms (int, optional): connection timeout in miliseconds. Defaults to 15000.
+ reconnect_interval_ms (int, optional): Interval in milliseconds between reconnect attempts. Defaults to 200.
+ timeout_ms (int, optional): connection timeout in milliseconds. Defaults to 15000.
key_file (string): path to tls key file.
cert_file (string): path to tls cert file.
ca_file (string): path to tls ca file.
@@ -184,7 +158,7 @@ async def connect(
name=self.connection_id + "::" + self.username,
)
- await self.configurations_listener()
+ await self.sdk_client_updates_listener()
self.broker_connection = self.broker_manager.jetstream()
self.is_connection_active = True
except Exception as e:
@@ -198,22 +172,22 @@ async def send_notification(self, title, msg, failedMsg, type):
async def station(
self,
name: str,
- retention_type: str = retention_types.MAX_MESSAGE_AGE_SECONDS,
+ retention_type: Retention = Retention.MAX_MESSAGE_AGE_SECONDS,
retention_value: int = 604800,
- storage_type: str = storage_types.DISK,
+ storage_type: Storage = Storage.DISK,
replicas: int = 1,
idempotency_window_ms: int = 120000,
schema_name: str = "",
send_poison_msg_to_dls: bool = True,
send_schema_failed_msg_to_dls: bool = True,
- tiered_storage_enabled: bool = False
+ tiered_storage_enabled: bool = False,
):
"""Creates a station.
Args:
name (str): station name.
- retention_type (str, optional): retention type: message_age_sec/messages/bytes . Defaults to "message_age_sec".
+ retention_type (Retention, optional): retention type: message_age_sec/messages/bytes. Defaults to "message_age_sec".
retention_value (int, optional): number which represents the retention based on the retention_type. Defaults to 604800.
- storage_type (str, optional): persistance storage for messages of the station: disk/memory. Defaults to "disk".
+ storage_type (Storage, optional): persistence storage for messages of the station: disk/memory. Defaults to "disk".
replicas (int, optional):number of replicas for the messages of the data. Defaults to 1.
idempotency_window_ms (int, optional): time frame in which idempotent messages will be tracked, happens based on message ID Defaults to 120000.
schema_name (str): schema name.
@@ -226,9 +200,9 @@ async def station(
createStationReq = {
"name": name,
- "retention_type": retention_type,
+ "retention_type": retention_type.value,
"retention_value": retention_value,
- "storage_type": storage_type,
+ "storage_type": storage_type.value,
"replicas": replicas,
"idempotency_window_in_ms": idempotency_window_ms,
"schema_name": schema_name,
@@ -237,7 +211,7 @@ async def station(
"Schemaverse": send_schema_failed_msg_to_dls,
},
"username": self.username,
- "tiered_storage_enabled": tiered_storage_enabled
+ "tiered_storage_enabled": tiered_storage_enabled,
}
create_station_req_bytes = json.dumps(createStationReq, indent=2).encode(
"utf-8"
@@ -329,6 +303,9 @@ async def close(self):
if self.update_configurations_sub is not None:
await self.update_configurations_sub.unsubscribe()
self.producers_map.clear()
+ for consumer in self.consumers_map:
+ consumer.dls_messages.clear()
+ self.consumers_map.clear()
except:
return
@@ -387,45 +364,45 @@ async def producer(
if create_res["error"] != "":
raise MemphisError(create_res["error"])
- station_name_internal = get_internal_name(station_name)
- self.station_schemaverse_to_dls[station_name_internal] = create_res[
+ internal_station_name = get_internal_name(station_name)
+ self.station_schemaverse_to_dls[internal_station_name] = create_res[
"schemaverse_to_dls"
]
self.cluster_configurations["send_notification"] = create_res[
"send_notification"
]
await self.start_listen_for_schema_updates(
- station_name_internal, create_res["schema_update"]
+ internal_station_name, create_res["schema_update"]
)
- if self.schema_updates_data[station_name_internal] != {}:
+ if self.schema_updates_data[internal_station_name] != {}:
if (
- self.schema_updates_data[station_name_internal]["type"]
+ self.schema_updates_data[internal_station_name]["type"]
== "protobuf"
):
- self.parse_descriptor(station_name_internal)
- if self.schema_updates_data[station_name_internal]["type"] == "json":
- schema = self.schema_updates_data[station_name_internal][
+ self.parse_descriptor(internal_station_name)
+ if self.schema_updates_data[internal_station_name]["type"] == "json":
+ schema = self.schema_updates_data[internal_station_name][
"active_version"
]["schema_content"]
- self.json_schemas[station_name_internal] = json.loads(schema)
+ self.json_schemas[internal_station_name] = json.loads(schema)
elif (
- self.schema_updates_data[station_name_internal]["type"] == "graphql"
+ self.schema_updates_data[internal_station_name]["type"] == "graphql"
):
- self.graphql_schemas[station_name_internal] = build_graphql_schema(
- self.schema_updates_data[station_name_internal][
+ self.graphql_schemas[internal_station_name] = build_graphql_schema(
+ self.schema_updates_data[internal_station_name][
"active_version"
]["schema_content"]
)
producer = Producer(self, producer_name, station_name, real_name)
- map_key = station_name_internal + "_" + real_name
+ map_key = internal_station_name + "_" + real_name
self.producers_map[map_key] = producer
return producer
except Exception as e:
raise MemphisError(str(e)) from e
- async def get_msg_schema_updates(self, station_name_internal, iterable):
+ async def get_msg_schema_updates(self, internal_station_name, iterable):
async for msg in iterable:
message = msg.data.decode("utf-8")
message = json.loads(message)
@@ -433,8 +410,8 @@ async def get_msg_schema_updates(self, station_name_internal, iterable):
data = {}
else:
data = message["init"]
- self.schema_updates_data[station_name_internal] = data
- self.parse_descriptor(station_name_internal)
+ self.schema_updates_data[internal_station_name] = data
+ self.parse_descriptor(internal_station_name)
def parse_descriptor(self, station_name):
try:
@@ -507,10 +484,10 @@ async def consumer(
station_name (str): station name to consume messages from.
consumer_name (str): name for the consumer.
consumer_group (str, optional): consumer group name. Defaults to the consumer name.
- pull_interval_ms (int, optional): interval in miliseconds between pulls. Defaults to 1000.
+ pull_interval_ms (int, optional): interval in milliseconds between pulls. Defaults to 1000.
batch_size (int, optional): pull batch size. Defaults to 10.
- batch_max_time_to_wait_ms (int, optional): max time in miliseconds to wait between pulls. Defaults to 5000.
- max_ack_time_ms (int, optional): max time for ack a message in miliseconds, in case a message not acked in this time period the Memphis broker will resend it. Defaults to 30000.
+ batch_max_time_to_wait_ms (int, optional): max time in milliseconds to wait between pulls. Defaults to 5000.
+ max_ack_time_ms (int, optional): max time for ack a message in milliseconds, in case a message not acked in this time period the Memphis broker will resend it. Defaults to 30000.
max_msg_deliveries (int, optional): max number of message deliveries, by default is 10.
generate_random_suffix (bool): false by default, if true concatenate a random suffix to consumer's name
start_consume_from_sequence(int, optional): start consuming from a specific sequence. defaults to 1.
@@ -521,7 +498,7 @@ async def consumer(
try:
if not self.is_connection_active:
raise MemphisError("Connection is dead")
-
+ real_name = consumer_name.lower()
if generate_random_suffix:
consumer_name = self.__generateRandomSuffix(consumer_name)
cg = consumer_name if not consumer_group else consumer_group
@@ -563,7 +540,9 @@ async def consumer(
if err_msg != "":
raise MemphisError(err_msg)
- return Consumer(
+ internal_station_name = get_internal_name(station_name)
+ map_key = internal_station_name + "_" + real_name
+ consumer = Consumer(
self,
station_name,
consumer_name,
@@ -576,6 +555,8 @@ async def consumer(
start_consume_from_sequence=start_consume_from_sequence,
last_messages=last_messages,
)
+ self.consumers_map[map_key] = consumer
+ return consumer
except Exception as e:
raise MemphisError(str(e)) from e
@@ -599,13 +580,13 @@ async def produce(
ack_wait_sec (int, optional): max time in seconds to wait for an ack from memphis. Defaults to 15.
headers (dict, optional): Message headers, defaults to {}.
async_produce (boolean, optional): produce operation won't wait for broker acknowledgement
- msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotency
+ msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotence
Raises:
Exception: _description_
"""
try:
- station_name_internal = get_internal_name(station_name)
- map_key = station_name_internal + "_" + producer_name.lower()
+ internal_station_name = get_internal_name(station_name)
+ map_key = internal_station_name + "_" + producer_name.lower()
producer = None
if map_key in self.producers_map:
producer = self.producers_map[map_key]
@@ -625,556 +606,84 @@ async def produce(
except Exception as e:
raise MemphisError(str(e)) from e
- def is_connected(self):
- return self.broker_manager.is_connected
-
-
-class Station:
- def __init__(self, connection, name: str):
- self.connection = connection
- self.name = name.lower()
-
- async def destroy(self):
- """Destroy the station."""
- try:
- nameReq = {"station_name": self.name, "username": self.connection.username}
- station_name = json.dumps(nameReq, indent=2).encode("utf-8")
- res = await self.connection.broker_manager.request(
- "$memphis_station_destructions", station_name, timeout=5
- )
- error = res.data.decode("utf-8")
- if error != "" and not "not exist" in error:
- raise MemphisError(error)
-
- station_name_internal = get_internal_name(self.name)
- sub = self.connection.schema_updates_subs.get(station_name_internal)
- task = self.connection.schema_tasks.get(station_name_internal)
- if station_name_internal in self.connection.schema_updates_data:
- del self.connection.schema_updates_data[station_name_internal]
- if station_name_internal in self.connection.schema_updates_subs:
- del self.connection.schema_updates_subs[station_name_internal]
- if station_name_internal in self.connection.producers_per_station:
- del self.connection.producers_per_station[station_name_internal]
- if station_name_internal in self.connection.schema_tasks:
- del self.connection.schema_tasks[station_name_internal]
- if task is not None:
- task.cancel()
- if sub is not None:
- await sub.unsubscribe()
-
- self.connection.producers_map = {
- k: v
- for k, v in self.connection.producers_map.items()
- if self.name not in k
- }
-
- except Exception as e:
- raise MemphisError(str(e)) from e
-
-
-def get_internal_name(name: str) -> str:
- name = name.lower()
- return name.replace(".", "#")
-
-
-class Producer:
- def __init__(
- self, connection, producer_name: str, station_name: str, real_name: str
- ):
- self.connection = connection
- self.producer_name = producer_name.lower()
- self.station_name = station_name
- self.internal_station_name = get_internal_name(self.station_name)
- self.loop = asyncio.get_running_loop()
- self.real_name = real_name
-
- async def validate_msg(self, message):
- if self.connection.schema_updates_data[self.internal_station_name] != {}:
- schema_type = self.connection.schema_updates_data[
- self.internal_station_name
- ]["type"]
- if schema_type == "protobuf":
- message = self.validate_protobuf(message)
- return message
- elif schema_type == "json":
- message = self.validate_json_schema(message)
- return message
- elif schema_type == "graphql":
- message = self.validate_graphql(message)
- return message
- elif not isinstance(message, bytearray) and not isinstance(message, dict):
- raise MemphisSchemaError("Unsupported message type")
- else:
- if isinstance(message, dict):
- message = bytearray(json.dumps(message).encode("utf-8"))
- return message
-
- def validate_protobuf(self, message):
- proto_msg = self.connection.proto_msgs[self.internal_station_name]
- msgToSend = ""
- try:
- if isinstance(message, bytearray):
- msgToSend = bytes(message)
- try:
- proto_msg.ParseFromString(msgToSend)
- proto_msg.SerializeToString()
- msgToSend = msgToSend.decode("utf-8")
- except Exception as e:
- if "parsing message" in str(e):
- e = "Invalid message format, expecting protobuf"
- raise MemphisSchemaError(str(e))
- return message
- elif hasattr(message, "SerializeToString"):
- msgToSend = message.SerializeToString()
- proto_msg.ParseFromString(msgToSend)
- proto_msg.SerializeToString()
- return msgToSend
-
- else:
- raise MemphisSchemaError("Unsupported message type")
-
- except Exception as e:
- raise MemphisSchemaError("Schema validation has failed: " + str(e))
-
- def validate_json_schema(self, message):
- try:
- if isinstance(message, bytearray):
- try:
- message_obj = json.loads(message)
- except Exception as e:
- raise Exception("Expecting Json format: " + str(e))
- elif isinstance(message, dict):
- message_obj = message
- message = bytearray(json.dumps(message_obj).encode("utf-8"))
- else:
- raise Exception("Unsupported message type")
-
- validate(
- instance=message_obj,
- schema=self.connection.json_schemas[self.internal_station_name],
- )
- return message
- except Exception as e:
- raise MemphisSchemaError("Schema validation has failed: " + str(e))
-
- def validate_graphql(self, message):
- try:
- if isinstance(message, bytearray):
- msg = message.decode("utf-8")
- msg = parse_graphql(msg)
- elif isinstance(message, str):
- msg = parse_graphql(message)
- message = message.encode("utf-8")
- elif isinstance(message, graphql.language.ast.DocumentNode):
- msg = message
- message = str(msg.loc.source.body)
- message = message.encode("utf-8")
- else:
- raise MemphisError("Unsupported message type")
- validate_res = validate_graphql(
- schema=self.connection.graphql_schemas[self.internal_station_name],
- document_ast=msg,
- )
- if len(validate_res) > 0:
- raise Exception("Schema validation has failed: " + str(validate_res))
- return message
- except Exception as e:
- if "Syntax Error" in str(e):
- e = "Invalid message format, expected GraphQL"
- raise Exception("Schema validation has failed: " + str(e))
-
- def get_dls_msg_id(self, station_name: str, producer_name: str, unix_time: str):
- return station_name + "~" + producer_name + "~0~" + unix_time
-
- async def produce(
+ async def fetch_messages(
self,
- message,
- ack_wait_sec: int = 15,
- headers: Union[Headers, None] = None,
- async_produce: bool = False,
- msg_id: Union[str, None] = None,
- ):
- """Produces a message into a station.
- Args:
- message (bytearray/dict): message to send into the station - bytearray/protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
- ack_wait_sec (int, optional): max time in seconds to wait for an ack from memphis. Defaults to 15.
- headers (dict, optional): Message headers, defaults to {}.
- async_produce (boolean, optional): produce operation won't wait for broker acknowledgement
- msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotency
- Raises:
- Exception: _description_
- """
- try:
- message = await self.validate_msg(message)
-
- memphis_headers = {
- "$memphis_producedBy": self.producer_name,
- "$memphis_connectionId": self.connection.connection_id,
- }
-
- if msg_id is not None:
- memphis_headers["msg-id"] = msg_id
-
- if headers is not None:
- headers = headers.headers
- headers.update(memphis_headers)
- else:
- headers = memphis_headers
-
- if async_produce:
- try:
- self.loop.create_task(
- self.connection.broker_connection.publish(
- self.internal_station_name + ".final",
- message,
- timeout=ack_wait_sec,
- headers=headers,
- )
- )
- await asyncio.sleep(1)
- except Exception as e:
- raise MemphisError(e)
- else:
- await self.connection.broker_connection.publish(
- self.internal_station_name + ".final",
- message,
- timeout=ack_wait_sec,
- headers=headers,
- )
- except Exception as e:
- if hasattr(e, "status_code") and e.status_code == "503":
- raise MemphisError(
- "Produce operation has failed, please check whether Station/Producer are still exist"
- )
- else:
- if "Schema validation has failed" in str(
- e
- ) or "Unsupported message type" in str(e):
- msgToSend = ""
- if isinstance(message, bytearray):
- msgToSend = str(message, "utf-8")
- elif hasattr(message, "SerializeToString"):
- msgToSend = message.SerializeToString().decode("utf-8")
- if self.connection.station_schemaverse_to_dls[
- self.internal_station_name
- ]:
- unix_time = int(time.time())
- id = self.get_dls_msg_id(
- self.internal_station_name,
- self.producer_name,
- str(unix_time),
- )
-
- memphis_headers = {
- "$memphis_producedBy": self.producer_name,
- "$memphis_connectionId": self.connection.connection_id,
- }
-
- if headers != {}:
- headers = headers.headers
- headers.update(memphis_headers)
- else:
- headers = memphis_headers
-
- msgToSendEncoded = msgToSend.encode("utf-8")
- msgHex = msgToSendEncoded.hex()
- buf = {
- "_id": id,
- "station_name": self.internal_station_name,
- "producer": {
- "name": self.producer_name,
- "connection_id": self.connection.connection_id,
- },
- "creation_unix": unix_time,
- "message": {
- "data": msgHex,
- "headers": headers,
- },
- }
- buf = json.dumps(buf).encode("utf-8")
- await self.connection.broker_connection.publish(
- "$memphis-"
- + self.internal_station_name
- + "-dls.schema."
- + id,
- buf,
- )
- if self.connection.cluster_configurations.get(
- "send_notification"
- ):
- await self.connection.send_notification(
- "Schema validation has failed",
- "Station: "
- + self.station_name
- + "\nProducer: "
- + self.producer_name
- + "\nError:"
- + str(e),
- msgToSend,
- schemaVFailAlertType,
- )
- raise MemphisError(str(e)) from e
-
- async def destroy(self):
- """Destroy the producer."""
- try:
- destroyProducerReq = {
- "name": self.producer_name,
- "station_name": self.station_name,
- "username": self.connection.username,
- }
-
- producer_name = json.dumps(destroyProducerReq).encode("utf-8")
- res = await self.connection.broker_manager.request(
- "$memphis_producer_destructions", producer_name, timeout=5
- )
- error = res.data.decode("utf-8")
- if error != "" and not "not exist" in error:
- raise Exception(error)
-
- station_name_internal = get_internal_name(self.station_name)
- producer_number = (
- self.connection.producers_per_station.get(station_name_internal) - 1
- )
- self.connection.producers_per_station[
- station_name_internal
- ] = producer_number
-
- if producer_number == 0:
- sub = self.connection.schema_updates_subs.get(station_name_internal)
- task = self.connection.schema_tasks.get(station_name_internal)
- if station_name_internal in self.connection.schema_updates_data:
- del self.connection.schema_updates_data[station_name_internal]
- if station_name_internal in self.connection.schema_updates_subs:
- del self.connection.schema_updates_subs[station_name_internal]
- if station_name_internal in self.connection.schema_tasks:
- del self.connection.schema_tasks[station_name_internal]
- if task is not None:
- task.cancel()
- if sub is not None:
- await sub.unsubscribe()
-
- map_key = station_name_internal + "_" + self.real_name
- del self.connection.producers_map[map_key]
-
- except Exception as e:
- raise Exception(e)
-
-
-def default_error_handler(e):
- print("ping exception raised", e)
-
-
-class Consumer:
- def __init__(
- self,
- connection,
station_name: str,
- consumer_name,
- consumer_group,
- pull_interval_ms: int,
- batch_size: int,
- batch_max_time_to_wait_ms: int,
- max_ack_time_ms: int,
+ consumer_name: str,
+ consumer_group: str = "",
+ batch_size: int = 10,
+ batch_max_time_to_wait_ms: int = 5000,
+ max_ack_time_ms: int = 30000,
max_msg_deliveries: int = 10,
- error_callback=None,
+ generate_random_suffix: bool = False,
start_consume_from_sequence: int = 1,
last_messages: int = -1,
):
- self.connection = connection
- self.station_name = station_name.lower()
- self.consumer_name = consumer_name.lower()
- self.consumer_group = consumer_group.lower()
- self.pull_interval_ms = pull_interval_ms
- self.batch_size = batch_size
- self.batch_max_time_to_wait_ms = batch_max_time_to_wait_ms
- self.max_ack_time_ms = max_ack_time_ms
- self.max_msg_deliveries = max_msg_deliveries
- self.ping_consumer_invterval_ms = 30000
- if error_callback is None:
- error_callback = default_error_handler
- self.t_ping = asyncio.create_task(self.__ping_consumer(error_callback))
- self.start_consume_from_sequence = start_consume_from_sequence
-
- self.last_messages = last_messages
- self.context = {}
-
- def set_context(self, context):
- """Set a context (dict) that will be passed to each message handler call."""
- self.context = context
-
- def consume(self, callback):
- """Consume events."""
- self.t_consume = asyncio.create_task(self.__consume(callback))
- self.t_dls = asyncio.create_task(self.__consume_dls(callback))
-
- async def __consume(self, callback):
- subject = get_internal_name(self.station_name)
- consumer_group = get_internal_name(self.consumer_group)
- self.psub = await self.connection.broker_connection.pull_subscribe(
- subject + ".final", durable=consumer_group
- )
- while True:
- if self.connection.is_connection_active and self.pull_interval_ms:
- try:
- memphis_messages = []
- msgs = await self.psub.fetch(self.batch_size)
- for msg in msgs:
- memphis_messages.append(
- Message(msg, self.connection, self.consumer_group)
- )
- await callback(memphis_messages, None, self.context)
- await asyncio.sleep(self.pull_interval_ms / 1000)
-
- except asyncio.TimeoutError:
- await callback(
- [], MemphisError("Memphis: TimeoutError"), self.context
- )
- continue
- except Exception as e:
- if self.connection.is_connection_active:
- raise MemphisError(str(e)) from e
- else:
- return
- else:
- break
-
- async def __consume_dls(self, callback):
- subject = get_internal_name(self.station_name)
- consumer_group = get_internal_name(self.consumer_group)
+ """Consume a batch of messages.
+ Args:
+ station_name (str): station name to consume messages from.
+ consumer_name (str): name for the consumer.
+ consumer_group (str, optional): consumer group name. Defaults to the consumer name.
+ batch_size (int, optional): pull batch size. Defaults to 10.
+ batch_max_time_to_wait_ms (int, optional): max time in milliseconds to wait between pulls. Defaults to 5000.
+ max_ack_time_ms (int, optional): max time in milliseconds to ack a message; if a message is not acked within this period, the Memphis broker will resend it. Defaults to 30000.
+ max_msg_deliveries (int, optional): max number of message deliveries, by default is 10.
+ generate_random_suffix (bool): false by default, if true concatenate a random suffix to consumer's name
+ start_consume_from_sequence(int, optional): start consuming from a specific sequence. defaults to 1.
+ last_messages: consume the last N messages, defaults to -1 (all messages in the station).
+ Returns:
+ list: Message
+ """
try:
- subscription_name = "$memphis_dls_" + subject + "_" + consumer_group
- self.consumer_dls = await self.connection.broker_manager.subscribe(
- subscription_name, subscription_name
- )
- async for msg in self.consumer_dls.messages:
- await callback(
- [Message(msg, self.connection, self.consumer_group)],
- None,
- self.context,
- )
- except Exception as e:
- print("dls", e)
- await callback([], MemphisError(str(e)), self.context)
- return
-
- async def __ping_consumer(self, callback):
- while True:
- try:
- await asyncio.sleep(self.ping_consumer_invterval_ms / 1000)
- consumer_group = get_internal_name(self.consumer_group)
- await self.connection.broker_connection.consumer_info(
- self.station_name, consumer_group, timeout=30
+ consumer = None
+ if not self.is_connection_active:
+ raise MemphisError("Can't fetch messages without being connected!")
+ internal_station_name = get_internal_name(station_name)
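+ # reuse a cached consumer for this station/consumer pair when one exists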
+ consumer_map_key = internal_station_name + "_" + consumer_name.lower()
+ if consumer_map_key in self.consumers_map:
+ consumer = self.consumers_map[consumer_map_key]
+ else:
+ consumer = await self.consumer(
+ station_name=station_name,
+ consumer_name=consumer_name,
+ consumer_group=consumer_group,
+ batch_size=batch_size,
+ batch_max_time_to_wait_ms=batch_max_time_to_wait_ms,
+ max_ack_time_ms=max_ack_time_ms,
+ max_msg_deliveries=max_msg_deliveries,
+ generate_random_suffix=generate_random_suffix,
+ start_consume_from_sequence=start_consume_from_sequence,
+ last_messages=last_messages,
)
-
- except Exception as e:
- callback(e)
-
- async def destroy(self):
- """Destroy the consumer."""
- if self.t_consume is not None:
- self.t_consume.cancel()
- if self.t_dls is not None:
- self.t_dls.cancel()
- if self.t_ping is not None:
- self.t_ping.cancel()
- self.pull_interval_ms = None
- try:
- destroyConsumerReq = {
- "name": self.consumer_name,
- "station_name": self.station_name,
- "username": self.connection.username,
- }
- consumer_name = json.dumps(destroyConsumerReq, indent=2).encode("utf-8")
- res = await self.connection.broker_manager.request(
- "$memphis_consumer_destructions", consumer_name, timeout=5
- )
- error = res.data.decode("utf-8")
- if error != "" and not "not exist" in error:
- raise MemphisError(error)
+ messages = await consumer.fetch(batch_size)
+ if messages is None:
+ messages = []
+ return messages
except Exception as e:
raise MemphisError(str(e)) from e
-
-class Message:
- def __init__(self, message, connection, cg_name):
- self.message = message
- self.connection = connection
- self.cg_name = cg_name
-
- async def ack(self):
- """Ack a message is done processing."""
+ def is_connected(self):
+ return self.broker_manager.is_connected
+
+ def unset_cached_producer_station(self, station_name):
try:
- await self.message.ack()
+ internal_station_name = get_internal_name(station_name)
+ for key in list(self.producers_map):
+ producer = self.producers_map[key]
+ if producer.internal_station_name == internal_station_name:
+ del self.producers_map[key]
except Exception as e:
- if (
- "$memphis_pm_id"
- in self.message.headers & "$memphis_pm_sequence"
- in self.message.headers
- ):
- try:
- msg = {
- "id": self.message.headers["$memphis_pm_id"],
- "sequence": self.message.headers["$memphis_pm_sequence"],
- }
- msgToAck = json.dumps(msg).encode("utf-8")
- await self.connection.broker_manager.publish(
- "$memphis_pm_acks", msgToAck
- )
- except Exception as er:
- raise MemphisConnectError(str(er)) from er
- else:
- raise MemphisConnectError(str(e)) from e
- return
+ raise e
+
- def get_data(self):
- """Receive the message."""
+ def unset_cached_consumer_station(self, station_name):
try:
- return bytearray(self.message.data)
- except:
- return
-
- def get_headers(self):
- """Receive the headers."""
- try:
- return self.message.headers
- except:
- return
-
- def get_sequence_number(self):
- """Get message sequence number."""
- try:
- return self.message.metadata.sequence.stream
- except:
- return
-
-
-def random_bytes(amount: int) -> str:
- lst = [random.choice("0123456789abcdef") for n in range(amount)]
- s = "".join(lst)
- return s
-
-
-class MemphisError(Exception):
- def __init__(self, message):
- message = message.replace("nats", "memphis")
- message = message.replace("NATS", "memphis")
- message = message.replace("Nats", "memphis")
- message = message.replace("NatsError", "MemphisError")
- self.message = message
- if message.startswith("memphis:"):
- super().__init__(self.message)
- else:
- super().__init__("memphis: " + self.message)
-
-
-class MemphisConnectError(MemphisError):
- pass
-
-
-class MemphisSchemaError(MemphisError):
- pass
-
+ internal_station_name = get_internal_name(station_name)
+ for key in list(self.consumers_map):
+ consumer = self.consumers_map[key]
+ consumer_station_name_internal = get_internal_name(consumer.station_name)
+ if consumer_station_name_internal == internal_station_name:
+ del self.consumers_map[key]
+ except Exception as e:
+ raise e
-class MemphisHeaderError(MemphisError):
- pass
diff --git a/memphis/message.py b/memphis/message.py
new file mode 100644
index 0000000..3c8f30d
--- /dev/null
+++ b/memphis/message.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import json
+
+from memphis.exceptions import MemphisConnectError
+
+
+class Message:
+ def __init__(self, message, connection, cg_name):
+ self.message = message
+ self.connection = connection
+ self.cg_name = cg_name
+
+ async def ack(self):
+ """Ack a message is done processing."""
+ try:
+ await self.message.ack()
+ except Exception as e:
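+ # messages replayed from the dead-letter station carry $memphis_pm_* headers; ack those by publishing to a dedicated subject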
+ if (
+ "$memphis_pm_id" in self.message.headers
+ and "$memphis_pm_sequence" in self.message.headers
+ ):
+ try:
+ msg = {
+ "id": self.message.headers["$memphis_pm_id"],
+ "sequence": self.message.headers["$memphis_pm_sequence"],
+ }
+ msgToAck = json.dumps(msg).encode("utf-8")
+ await self.connection.broker_manager.publish(
+ "$memphis_pm_acks", msgToAck
+ )
+ except Exception as er:
+ raise MemphisConnectError(str(er)) from er
+ else:
+ raise MemphisConnectError(str(e)) from e
+ return
+
+ def get_data(self):
+ """Receive the message."""
+ try:
+ return bytearray(self.message.data)
+ except:
+ return
+
+ def get_headers(self):
+ """Receive the headers."""
+ try:
+ return self.message.headers
+ except:
+ return
+
+ def get_sequence_number(self):
+ """Get message sequence number."""
+ try:
+ return self.message.metadata.sequence.stream
+ except:
+ return
diff --git a/memphis/producer.py b/memphis/producer.py
new file mode 100644
index 0000000..5c2e4e1
--- /dev/null
+++ b/memphis/producer.py
@@ -0,0 +1,300 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import time
+from typing import Union
+
+import graphql
+from graphql import parse as parse_graphql
+from graphql import validate as validate_graphql
+from jsonschema import validate
+from memphis.exceptions import MemphisError, MemphisSchemaError
+from memphis.headers import Headers
+from memphis.utils import get_internal_name
+
+schemaVFailAlertType = "schema_validation_fail_alert"
+
+
+class Producer:
+ def __init__(
+ self, connection, producer_name: str, station_name: str, real_name: str
+ ):
+ self.connection = connection
+ self.producer_name = producer_name.lower()
+ self.station_name = station_name
+ self.internal_station_name = get_internal_name(self.station_name)
+ self.loop = asyncio.get_running_loop()
+ self.real_name = real_name
+
+ async def validate_msg(self, message):
+ if self.connection.schema_updates_data[self.internal_station_name] != {}:
+ schema_type = self.connection.schema_updates_data[
+ self.internal_station_name
+ ]["type"]
+ if schema_type == "protobuf":
+ message = self.validate_protobuf(message)
+ return message
+ elif schema_type == "json":
+ message = self.validate_json_schema(message)
+ return message
+ elif schema_type == "graphql":
+ message = self.validate_graphql(message)
+ return message
+ elif not isinstance(message, bytearray) and not isinstance(message, dict):
+ raise MemphisSchemaError("Unsupported message type")
+ else:
+ if isinstance(message, dict):
+ message = bytearray(json.dumps(message).encode("utf-8"))
+ return message
+
+ def validate_protobuf(self, message):
+ proto_msg = self.connection.proto_msgs[self.internal_station_name]
+ msgToSend = ""
+ try:
+ if isinstance(message, bytearray):
+ msgToSend = bytes(message)
+ try:
+ proto_msg.ParseFromString(msgToSend)
+ proto_msg.SerializeToString()
+ msgToSend = msgToSend.decode("utf-8")
+ except Exception as e:
+ if "parsing message" in str(e):
+ e = "Invalid message format, expecting protobuf"
+ raise MemphisSchemaError(str(e))
+ return message
+ elif hasattr(message, "SerializeToString"):
+ msgToSend = message.SerializeToString()
+ proto_msg.ParseFromString(msgToSend)
+ proto_msg.SerializeToString()
+ return msgToSend
+
+ else:
+ raise MemphisSchemaError("Unsupported message type")
+
+ except Exception as e:
+ raise MemphisSchemaError("Schema validation has failed: " + str(e))
+
+ def validate_json_schema(self, message):
+ try:
+ if isinstance(message, bytearray):
+ try:
+ message_obj = json.loads(message)
+ except Exception as e:
+ raise Exception("Expecting Json format: " + str(e))
+ elif isinstance(message, dict):
+ message_obj = message
+ message = bytearray(json.dumps(message_obj).encode("utf-8"))
+ else:
+ raise Exception("Unsupported message type")
+
+ validate(
+ instance=message_obj,
+ schema=self.connection.json_schemas[self.internal_station_name],
+ )
+ return message
+ except Exception as e:
+ raise MemphisSchemaError("Schema validation has failed: " + str(e))
+
+ def validate_graphql(self, message):
+ try:
+ if isinstance(message, bytearray):
+ msg = message.decode("utf-8")
+ msg = parse_graphql(msg)
+ elif isinstance(message, str):
+ msg = parse_graphql(message)
+ message = message.encode("utf-8")
+ elif isinstance(message, graphql.language.ast.DocumentNode):
+ msg = message
+ message = str(msg.loc.source.body)
+ message = message.encode("utf-8")
+ else:
+ raise MemphisError("Unsupported message type")
+ validate_res = validate_graphql(
+ schema=self.connection.graphql_schemas[self.internal_station_name],
+ document_ast=msg,
+ )
+ if len(validate_res) > 0:
+ raise Exception("Schema validation has failed: " + str(validate_res))
+ return message
+ except Exception as e:
+ if "Syntax Error" in str(e):
+ e = "Invalid message format, expected GraphQL"
+ raise Exception("Schema validation has failed: " + str(e))
+
+ def get_dls_msg_id(self, station_name: str, producer_name: str, unix_time: str):
+ return station_name + "~" + producer_name + "~0~" + unix_time
+
+ async def produce(
+ self,
+ message,
+ ack_wait_sec: int = 15,
+ headers: Union[Headers, None] = None,
+ async_produce: bool = False,
+ msg_id: Union[str, None] = None,
+ ):
+ """Produces a message into a station.
+ Args:
+ message (bytearray/dict): message to send into the station - bytearray/protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
+ ack_wait_sec (int, optional): max time in seconds to wait for an ack from memphis. Defaults to 15.
+ headers (dict, optional): Message headers, defaults to {}.
+ async_produce (boolean, optional): produce operation won't wait for broker acknowledgement
+ msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotency
+ Raises:
+ MemphisError: if schema validation or the publish fails.
+ """
+ try:
+ message = await self.validate_msg(message)
+
+ memphis_headers = {
+ "$memphis_producedBy": self.producer_name,
+ "$memphis_connectionId": self.connection.connection_id,
+ }
+
+ if msg_id is not None:
+ memphis_headers["msg-id"] = msg_id
+
+ if headers is not None:
+ headers = headers.headers
+ headers.update(memphis_headers)
+ else:
+ headers = memphis_headers
+
+ if async_produce:
+ try:
+ self.loop.create_task(
+ self.connection.broker_connection.publish(
+ self.internal_station_name + ".final",
+ message,
+ timeout=ack_wait_sec,
+ headers=headers,
+ )
+ )
+ await asyncio.sleep(1) # TODO - check why we need to sleep here
+ except Exception as e:
+ raise MemphisError(e)
+ else:
+ await self.connection.broker_connection.publish(
+ self.internal_station_name + ".final",
+ message,
+ timeout=ack_wait_sec,
+ headers=headers,
+ )
+ except Exception as e:
+ if hasattr(e, "status_code") and e.status_code == "503":
+ raise MemphisError(
+ "Produce operation has failed, please check whether Station/Producer are still exist"
+ )
+ else:
+ if "Schema validation has failed" in str(
+ e
+ ) or "Unsupported message type" in str(e):
+ msgToSend = ""
+ if isinstance(message, bytearray):
+ msgToSend = str(message, "utf-8")
+ elif hasattr(message, "SerializeToString"):
+ msgToSend = message.SerializeToString().decode("utf-8")
+ if self.connection.station_schemaverse_to_dls[
+ self.internal_station_name
+ ]:
+ unix_time = int(time.time())
+ id = self.get_dls_msg_id(
+ self.internal_station_name,
+ self.producer_name,
+ str(unix_time),
+ )
+
+ memphis_headers = {
+ "$memphis_producedBy": self.producer_name,
+ "$memphis_connectionId": self.connection.connection_id,
+ }
+
+ if headers != {}:
+ headers = headers.headers
+ headers.update(memphis_headers)
+ else:
+ headers = memphis_headers
+
+ msgToSendEncoded = msgToSend.encode("utf-8")
+ msgHex = msgToSendEncoded.hex()
+ buf = {
+ "_id": id,
+ "station_name": self.internal_station_name,
+ "producer": {
+ "name": self.producer_name,
+ "connection_id": self.connection.connection_id,
+ },
+ "creation_unix": unix_time,
+ "message": {
+ "data": msgHex,
+ "headers": headers,
+ },
+ }
+ buf = json.dumps(buf).encode("utf-8")
+ await self.connection.broker_connection.publish(
+ "$memphis-"
+ + self.internal_station_name
+ + "-dls.schema."
+ + id,
+ buf,
+ )
+ if self.connection.cluster_configurations.get(
+ "send_notification"
+ ):
+ await self.connection.send_notification(
+ "Schema validation has failed",
+ "Station: "
+ + self.station_name
+ + "\nProducer: "
+ + self.producer_name
+ + "\nError:"
+ + str(e),
+ msgToSend,
+ schemaVFailAlertType,
+ )
+ raise MemphisError(str(e)) from e
+
+ async def destroy(self):
+ """Destroy the producer."""
+ try:
+ destroyProducerReq = {
+ "name": self.producer_name,
+ "station_name": self.station_name,
+ "username": self.connection.username,
+ }
+
+ producer_name = json.dumps(destroyProducerReq).encode("utf-8")
+ res = await self.connection.broker_manager.request(
+ "$memphis_producer_destructions", producer_name, timeout=5
+ )
+ error = res.data.decode("utf-8")
+ if error != "" and not "not exist" in error:
+ raise Exception(error)
+
+ internal_station_name = get_internal_name(self.station_name)
+ producer_number = (
+ self.connection.producers_per_station.get(internal_station_name) - 1
+ )
+ self.connection.producers_per_station[
+ internal_station_name
+ ] = producer_number
+
+ if producer_number == 0:
+ sub = self.connection.schema_updates_subs.get(internal_station_name)
+ task = self.connection.schema_tasks.get(internal_station_name)
+ if internal_station_name in self.connection.schema_updates_data:
+ del self.connection.schema_updates_data[internal_station_name]
+ if internal_station_name in self.connection.schema_updates_subs:
+ del self.connection.schema_updates_subs[internal_station_name]
+ if internal_station_name in self.connection.schema_tasks:
+ del self.connection.schema_tasks[internal_station_name]
+ if task is not None:
+ task.cancel()
+ if sub is not None:
+ await sub.unsubscribe()
+
+ map_key = internal_station_name + "_" + self.real_name
+ del self.connection.producers_map[map_key]
+
+ except Exception as e:
+ raise Exception(e)
diff --git a/memphis/station.py b/memphis/station.py
new file mode 100644
index 0000000..05aac6d
--- /dev/null
+++ b/memphis/station.py
@@ -0,0 +1,53 @@
+import json
+
+from memphis.exceptions import MemphisError
+from memphis.utils import get_internal_name
+
+
+class Station:
+ def __init__(self, connection, name: str):
+ self.connection = connection
+ self.name = name.lower()
+
+ async def destroy(self):
+ """Destroy the station."""
+ try:
+ nameReq = {"station_name": self.name, "username": self.connection.username}
+ station_name = json.dumps(nameReq, indent=2).encode("utf-8")
+ res = await self.connection.broker_manager.request(
+ "$memphis_station_destructions", station_name, timeout=5
+ )
+ error = res.data.decode("utf-8")
+ if error != "" and not "not exist" in error:
+ raise MemphisError(error)
+
+ internal_station_name = get_internal_name(self.name)
+ sub = self.connection.schema_updates_subs.get(internal_station_name)
+ task = self.connection.schema_tasks.get(internal_station_name)
+ if internal_station_name in self.connection.schema_updates_data:
+ del self.connection.schema_updates_data[internal_station_name]
+ if internal_station_name in self.connection.schema_updates_subs:
+ del self.connection.schema_updates_subs[internal_station_name]
+ if internal_station_name in self.connection.producers_per_station:
+ del self.connection.producers_per_station[internal_station_name]
+ if internal_station_name in self.connection.schema_tasks:
+ del self.connection.schema_tasks[internal_station_name]
+ if task is not None:
+ task.cancel()
+ if sub is not None:
+ await sub.unsubscribe()
+
+ self.connection.producers_map = {
+ k: v
+ for k, v in self.connection.producers_map.items()
+ if self.name not in k
+ }
+
+ self.connection.consumers_map = {
+ k: v
+ for k, v in self.connection.consumers_map.items()
+ if self.name not in k
+ }
+
+ except Exception as e:
+ raise MemphisError(str(e)) from e
diff --git a/memphis/storage_types.py b/memphis/storage_types.py
deleted file mode 100644
index 8c1d570..0000000
--- a/memphis/storage_types.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Credit for The NATS.IO Authors
-# Copyright 2021-2022 The Memphis Authors
-# Licensed under the Apache License, Version 2.0 (the “License”);
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http:#www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an “AS IS” BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-DISK = "file"
-MEMORY = "memory"
diff --git a/memphis/retention_types.py b/memphis/types.py
similarity index 75%
rename from memphis/retention_types.py
rename to memphis/types.py
index 6fb09ae..cfd6a70 100644
--- a/memphis/retention_types.py
+++ b/memphis/types.py
@@ -12,6 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-MAX_MESSAGE_AGE_SECONDS = "message_age_sec"
-MESSAGES = "messages"
-BYTES = "bytes"
+from enum import Enum
+
+
+class Retention(Enum):
+ MAX_MESSAGE_AGE_SECONDS = "message_age_sec"
+ MESSAGES = "messages"
+ BYTES = "bytes"
+
+
+class Storage(Enum):
+ DISK = "file"
+ MEMORY = "memory"
diff --git a/memphis/utils.py b/memphis/utils.py
new file mode 100644
index 0000000..10512e9
--- /dev/null
+++ b/memphis/utils.py
@@ -0,0 +1,32 @@
+import random
+from threading import Timer
+from typing import Callable
+
+
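+# Run func every sec seconds on a chained threading.Timer until cancel() is called.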
+class set_interval:
+ def __init__(self, func: Callable, sec: int):
+ def func_wrapper():
+ self.t = Timer(sec, func_wrapper)
+ self.t.start()
+ func()
+
+ self.t = Timer(sec, func_wrapper)
+ self.t.start()
+
+ def cancel(self):
+ self.t.cancel()
+
+
+def default_error_handler(e):
+ print("ping exception raised", e)
+
+
+def get_internal_name(name: str) -> str:
+ name = name.lower()
+ return name.replace(".", "#")
+
+
+def random_bytes(amount: int) -> str:
+ lst = [random.choice("0123456789abcdef") for _ in range(amount)]
+ s = "".join(lst)
+ return s
diff --git a/setup.py b/setup.py
index d53182d..5b15795 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
setup(
name="memphis-py",
packages=["memphis"],
- version="0.3.2",
+ version="0.3.3",
license="Apache-2.0",
description="A powerful messaging platform for modern developers",
long_description=long_description,
@@ -17,7 +17,7 @@
author="Memphis.dev",
author_email="team@memphis.dev",
url="https://github.com/memphisdev/memphis.py",
- download_url="https://github.com/memphisdev/memphis.py/archive/refs/tags/v0.3.2.tar.gz",
+ download_url="https://github.com/memphisdev/memphis.py/archive/refs/tags/v0.3.3.tar.gz",
keywords=["message broker", "devtool", "streaming", "data"],
install_requires=["asyncio", "nats-py", "protobuf", "jsonschema", "graphql-core"],
classifiers=[
diff --git a/version.conf b/version.conf
index 9fc80f9..87a0871 100644
--- a/version.conf
+++ b/version.conf
@@ -1 +1 @@
-0.3.2
\ No newline at end of file
+0.3.3
\ No newline at end of file