diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 0000000..8ace4a7
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# formatted using black & isort
+b2692e213c7ef62882a1b9b7c95affff3246b036
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..a7de200
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,9 @@
+repos:
+- repo: https://github.com/psf/black
+  rev: 23.1.0
+  hooks:
+  - id: black
+- repo: https://github.com/PyCQA/isort
+  rev: 5.12.0
+  hooks:
+  - id: isort
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..4b923d9
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,11 @@
+PRE_COMMIT = pre-commit
+PRE_COMMIT_RUN_ARGS = --all-files
+PRE_COMMIT_INSTALL_ARGS = --install-hooks
+
+.PHONY: lint
+lint:
+	$(PRE_COMMIT) run $(PRE_COMMIT_RUN_ARGS)
+
+.PHONY: pre-commit-install
+pre-commit-install:
+	$(PRE_COMMIT) install $(PRE_COMMIT_INSTALL_ARGS)
\ No newline at end of file
diff --git a/examples/consumer.py b/examples/consumer.py
index 0ce801f..1ddd034 100644
--- a/examples/consumer.py
+++ b/examples/consumer.py
@@ -1,5 +1,6 @@
 import asyncio
-from memphis import Memphis, MemphisError, MemphisConnectError, MemphisHeaderError
+
+from memphis import Memphis, MemphisConnectError, MemphisError, MemphisHeaderError
 
 
 async def main():
@@ -17,11 +18,18 @@ async def msg_handler(msgs, error, context):
     try:
         memphis = Memphis()
-        await memphis.connect(host="", username="", connection_token="")
+        await memphis.connect(
+            host="",
+            username="",
+            connection_token="",
+        )
 
         consumer = await memphis.consumer(
-            station_name="", consumer_name="", consumer_group="")
-
+            station_name="",
+            consumer_name="",
+            consumer_group="",
+        )
+
         consumer.set_context({"key": "value"})
         consumer.consume(msg_handler)
         # Keep your main thread alive so the consumer will keep receiving data
@@ -33,5 +41,6 @@ async def msg_handler(msgs, error, context):
     finally:
         await memphis.close()
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     asyncio.run(main())
diff --git a/examples/producer.py b/examples/producer.py
index eebf91d..5f3d95a 100644
--- a/examples/producer.py
+++ b/examples/producer.py
@@ -1,24 +1,46 @@
 import asyncio
-from memphis import Memphis, Headers, MemphisError, MemphisConnectError, MemphisHeaderError, MemphisSchemaError
+
+from memphis import (
+    Headers,
+    Memphis,
+    MemphisConnectError,
+    MemphisError,
+    MemphisHeaderError,
+    MemphisSchemaError,
+)
 
 
 async def main():
     try:
         memphis = Memphis()
-        await memphis.connect(host="", username="", connection_token="")
+        await memphis.connect(
+            host="",
+            username="",
+            connection_token="",
+        )
 
         producer = await memphis.producer(
-            station_name="", producer_name="")
+            station_name="", producer_name=""
+        )
         headers = Headers()
-        headers.add("key", "value") 
+        headers.add("key", "value")
 
         for i in range(5):
-            await producer.produce(bytearray('Message #'+str(i)+': Hello world', 'utf-8'), headers=headers) # you can send the message parameter as dict as well
+            await producer.produce(
+                bytearray("Message #" + str(i) + ": Hello world", "utf-8"),
+                headers=headers,
+            )  # you can send the message parameter as dict as well
 
-    except (MemphisError, MemphisConnectError, MemphisHeaderError, MemphisSchemaError) as e:
+    except (
+        MemphisError,
+        MemphisConnectError,
+        MemphisHeaderError,
+        MemphisSchemaError,
+    ) as e:
         print(e)
 
     finally:
         await memphis.close()
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     asyncio.run(main())
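The inline comment in examples/producer.py notes that produce() also accepts a dict. A minimal, hypothetical sketch of that variant (not part of the diff; host, credentials, and station names are placeholders):

```python
# Hypothetical sketch: producing a dict payload, which the SDK serializes
# as JSON. On a schema-validated station the dict is validated first.
import asyncio

from memphis import Memphis


async def main():
    memphis = Memphis()
    await memphis.connect(host="localhost", username="root", connection_token="memphis")
    producer = await memphis.producer(station_name="events", producer_name="dict-producer")
    await producer.produce({"id": 1, "greeting": "Hello world"})
    await memphis.close()


if __name__ == "__main__":
    asyncio.run(main())
```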
diff --git a/memphis/__init__.py b/memphis/__init__.py
index 475d518..ff8bd5f 100644
--- a/memphis/__init__.py
+++ b/memphis/__init__.py
@@ -12,6 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from memphis.memphis import Memphis, Headers, MemphisError, MemphisConnectError, MemphisSchemaError, MemphisHeaderError
 import memphis.retention_types
 import memphis.storage_types
+from memphis.memphis import (
+    Headers,
+    Memphis,
+    MemphisConnectError,
+    MemphisError,
+    MemphisHeaderError,
+    MemphisSchemaError,
+)
diff --git a/memphis/memphis.py b/memphis/memphis.py
index 2c80742..f2e6c31 100644
--- a/memphis/memphis.py
+++ b/memphis/memphis.py
@@ -1,4 +1,24 @@
 from __future__ import annotations
+
+import asyncio
+import json
+import random
+import ssl
+import time
+from threading import Timer
+from typing import Callable, Iterable, Union
+
+import graphql
+import memphis.retention_types as retention_types
+import memphis.storage_types as storage_types
+import nats as broker
+from google.protobuf import descriptor_pb2, descriptor_pool
+from google.protobuf.message_factory import MessageFactory
+from graphql import build_schema as build_graphql_schema
+from graphql import parse as parse_graphql
+from graphql import validate as validate_graphql
+from jsonschema import validate
+
 # Credit for The NATS.IO Authors
 # Copyright 2021-2022 The Memphis Authors
 # Licensed under the Apache License, Version 2.0 (the “License”);
@@ -13,41 +33,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Iterable, Callable, Union
-import random
-import json
-import ssl
-import time
-from graphql import build_schema as build_graphql_schema, parse as parse_graphql, validate as validate_graphql
-import graphql
-
-import nats as broker
-from threading import Timer
-import asyncio
+schemaVFailAlertType = "schema_validation_fail_alert"
 
-from jsonschema import validate
-from google.protobuf import descriptor_pb2, descriptor_pool
-from google.protobuf.message_factory import MessageFactory
-import memphis.retention_types as retention_types
-import memphis.storage_types as storage_types
-
-schemaVFailAlertType = 'schema_validation_fail_alert'
-
-
-class set_interval():
+class set_interval:
     def __init__(self, func: Callable, sec: int):
         def func_wrapper():
             self.t = Timer(sec, func_wrapper)
             self.t.start()
             func()
+
         self.t = Timer(sec, func_wrapper)
         self.t.start()
 
     def cancel(self):
         self.t.cancel()
 
+
 class Headers:
     def __init__(self):
         self.headers = {}
@@ -63,8 +66,7 @@ def add(self, key, value):
         if not key.startswith("$memphis"):
             self.headers[key] = value
         else:
-            raise MemphisHeaderError(
-                "Keys in headers should not start with $memphis")
+            raise MemphisHeaderError("Keys in headers should not start with $memphis")
 
 
 class Memphis:
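The Headers hunk above shows that user-supplied keys may not collide with the SDK's reserved "$memphis" prefix. A small usage sketch (values are placeholders):

```python
# Hypothetical usage sketch, not part of the diff.
from memphis import Headers, MemphisHeaderError

headers = Headers()
headers.add("trace-id", "abc123")  # stored as-is
try:
    headers.add("$memphis_internal", "x")  # reserved prefix is rejected
except MemphisHeaderError as err:
    print(err)  # "Keys in headers should not start with $memphis"
```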
@@ -88,26 +90,45 @@ async def get_msgs_update_configurations(self, iterable):
             async for msg in iterable:
                 message = msg.data.decode("utf-8")
                 data = json.loads(message)
-                if data['type'] == 'send_notification':
-                    self.cluster_configurations[data['type']] = data['update']
-                elif data['type'] == 'schemaverse_to_dls':
-                    self.station_schemaverse_to_dls[data['station_name']
-                                                    ] = data['update']
+                if data["type"] == "send_notification":
+                    self.cluster_configurations[data["type"]] = data["update"]
+                elif data["type"] == "schemaverse_to_dls":
+                    self.station_schemaverse_to_dls[data["station_name"]] = data[
+                        "update"
+                    ]
         except Exception as err:
             raise MemphisError(err)
 
     async def configurations_listener(self):
         try:
-            sub = await self.broker_manager.subscribe("$memphis_sdk_configurations_updates")
+            sub = await self.broker_manager.subscribe(
+                "$memphis_sdk_configurations_updates"
+            )
             self.update_configurations_sub = sub
             loop = asyncio.get_event_loop()
-            task = loop.create_task(self.get_msgs_update_configurations(
-                self.update_configurations_sub.messages))
+            task = loop.create_task(
+                self.get_msgs_update_configurations(
+                    self.update_configurations_sub.messages
+                )
+            )
             self.configuration_tasks = task
         except Exception as err:
             raise MemphisError(err)
 
-    async def connect(self, host: str, username: str, connection_token: str, port: int = 6666, reconnect: bool = True, max_reconnect: int = 10, reconnect_interval_ms: int = 1500, timeout_ms: int = 15000, cert_file: str = "", key_file: str = "", ca_file: str = ""):
+    async def connect(
+        self,
+        host: str,
+        username: str,
+        connection_token: str,
+        port: int = 6666,
+        reconnect: bool = True,
+        max_reconnect: int = 10,
+        reconnect_interval_ms: int = 1500,
+        timeout_ms: int = 15000,
+        cert_file: str = "",
+        key_file: str = "",
+        ca_file: str = "",
+    ):
         """Creates connection with Memphis.
         Args:
             host (str): memphis host.
@@ -132,34 +153,37 @@ async def connect(self, host: str, username: str, connection_token: str, port: i
         self.timeout_ms = timeout_ms
         self.connection_id = self.__generateConnectionID()
         try:
-            if (cert_file != '' or key_file != '' or ca_file != ''):
-                if cert_file == '':
+            if cert_file != "" or key_file != "" or ca_file != "":
+                if cert_file == "":
                     raise MemphisConnectError("Must provide a TLS cert file")
-                if key_file == '':
+                if key_file == "":
                     raise MemphisConnectError("Must provide a TLS key file")
-                if ca_file == '':
+                if ca_file == "":
                     raise MemphisConnectError("Must provide a TLS ca file")
-                ssl_ctx = ssl.create_default_context(
-                    purpose=ssl.Purpose.SERVER_AUTH)
+                ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
                 ssl_ctx.load_verify_locations(ca_file)
                 ssl_ctx.load_cert_chain(certfile=cert_file, keyfile=key_file)
-                self.broker_manager = await broker.connect(servers=self.host+":"+str(self.port),
-                                                           allow_reconnect=self.reconnect,
-                                                           reconnect_time_wait=self.reconnect_interval_ms/1000,
-                                                           connect_timeout=self.timeout_ms/1000,
-                                                           max_reconnect_attempts=self.max_reconnect,
-                                                           token=self.connection_token,
-                                                           name=self.connection_id + "::" + self.username,
-                                                           tls=ssl_ctx,
-                                                           tls_hostname=self.host)
+                self.broker_manager = await broker.connect(
+                    servers=self.host + ":" + str(self.port),
+                    allow_reconnect=self.reconnect,
+                    reconnect_time_wait=self.reconnect_interval_ms / 1000,
+                    connect_timeout=self.timeout_ms / 1000,
+                    max_reconnect_attempts=self.max_reconnect,
+                    token=self.connection_token,
+                    name=self.connection_id + "::" + self.username,
+                    tls=ssl_ctx,
+                    tls_hostname=self.host,
+                )
             else:
-                self.broker_manager = await broker.connect(servers=self.host+":"+str(self.port),
-                                                           allow_reconnect=self.reconnect,
-                                                           reconnect_time_wait=self.reconnect_interval_ms/1000,
-                                                           connect_timeout=self.timeout_ms/1000,
-                                                           max_reconnect_attempts=self.max_reconnect,
-                                                           token=self.connection_token,
-                                                           name=self.connection_id + "::" + self.username)
+                self.broker_manager = await broker.connect(
+                    servers=self.host + ":" + str(self.port),
+                    allow_reconnect=self.reconnect,
+                    reconnect_time_wait=self.reconnect_interval_ms / 1000,
+                    connect_timeout=self.timeout_ms / 1000,
+                    max_reconnect_attempts=self.max_reconnect,
+                    token=self.connection_token,
+                    name=self.connection_id + "::" + self.username,
+                )
 
             await self.configurations_listener()
             self.broker_connection = self.broker_manager.jetstream()
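The reformatted connect() above accepts optional cert_file/key_file/ca_file arguments and raises MemphisConnectError when only some of them are given. A hedged sketch of a TLS connection, assuming the broker is reachable at the placeholder host and the PEM files exist:

```python
# Hypothetical mutual-TLS connection sketch; paths and credentials are
# placeholders, not values from the diff.
import asyncio

from memphis import Memphis


async def main():
    memphis = Memphis()
    await memphis.connect(
        host="memphis.example.com",
        username="root",
        connection_token="memphis",
        cert_file="client-cert.pem",  # all three TLS files must be provided together
        key_file="client-key.pem",
        ca_file="rootCA.pem",
    )
    await memphis.close()


asyncio.run(main())
```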
@@ -168,17 +192,22 @@ async def connect(self, host: str, username: str, connection_token: str, port: i
             raise MemphisConnectError(str(e)) from e
 
     async def send_notification(self, title, msg, failedMsg, type):
-        msg = {
-            "title": title,
-            "msg": msg,
-            "type": type,
-            "code": failedMsg
-        }
-        msgToSend = json.dumps(msg).encode('utf-8')
+        msg = {"title": title, "msg": msg, "type": type, "code": failedMsg}
+        msgToSend = json.dumps(msg).encode("utf-8")
         await self.broker_manager.publish("$memphis_notifications", msgToSend)
 
-    async def station(self, name: str,
-                      retention_type: str = retention_types.MAX_MESSAGE_AGE_SECONDS, retention_value: int = 604800, storage_type: str = storage_types.DISK, replicas: int = 1, idempotency_window_ms: int = 120000, schema_name: str = "", send_poison_msg_to_dls: bool = True, send_schema_failed_msg_to_dls: bool = True,):
+    async def station(
+        self,
+        name: str,
+        retention_type: str = retention_types.MAX_MESSAGE_AGE_SECONDS,
+        retention_value: int = 604800,
+        storage_type: str = storage_types.DISK,
+        replicas: int = 1,
+        idempotency_window_ms: int = 120000,
+        schema_name: str = "",
+        send_poison_msg_to_dls: bool = True,
+        send_schema_failed_msg_to_dls: bool = True,
+    ):
         """Creates a station.
         Args:
             name (str): station name.
@@ -205,13 +234,16 @@ async def station(self, name: str,
                 "schema_name": schema_name,
                 "dls_configuration": {
                     "poison": send_poison_msg_to_dls,
-                    "Schemaverse": send_schema_failed_msg_to_dls
+                    "Schemaverse": send_schema_failed_msg_to_dls,
                 },
-                "username": self.username
+                "username": self.username,
             }
-            create_station_req_bytes = json.dumps(
-                createStationReq, indent=2).encode('utf-8')
-            err_msg = await self.broker_manager.request("$memphis_station_creations", create_station_req_bytes, timeout=5)
+            create_station_req_bytes = json.dumps(createStationReq, indent=2).encode(
+                "utf-8"
+            )
+            err_msg = await self.broker_manager.request(
+                "$memphis_station_creations", create_station_req_bytes, timeout=5
+            )
             err_msg = err_msg.data.decode("utf-8")
 
             if err_msg != "":
@@ -219,7 +251,7 @@
             return Station(self, name)
         except Exception as e:
-            if str(e).find('already exist') != -1:
+            if str(e).find("already exist") != -1:
                 return Station(self, name.lower())
             else:
                 raise MemphisError(str(e)) from e
@@ -233,15 +265,13 @@ async def attach_schema(self, name, stationName):
            Exception: _description_
         """
         try:
-            if name == '' or stationName == '':
+            if name == "" or stationName == "":
                 raise MemphisError("name and station name can not be empty")
-            msg = {
-                "name": name,
-                "station_name": stationName,
-                "username": self.username
-            }
-            msgToSend = json.dumps(msg).encode('utf-8')
-            err_msg = await self.broker_manager.request("$memphis_schema_attachments", msgToSend, timeout=5)
+            msg = {"name": name, "station_name": stationName, "username": self.username}
+            msgToSend = json.dumps(msg).encode("utf-8")
+            err_msg = await self.broker_manager.request(
+                "$memphis_schema_attachments", msgToSend, timeout=5
+            )
             err_msg = err_msg.data.decode("utf-8")
 
             if err_msg != "":
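station() and attach_schema(), reformatted above, can be combined as in the following sketch. This is illustrative only; the station and schema names are hypothetical, and attach_schema assumes the named schema already exists in the cluster:

```python
# Hypothetical sketch of station creation plus schema attachment.
import asyncio

from memphis import Memphis, retention_types, storage_types


async def main():
    memphis = Memphis()
    await memphis.connect(host="localhost", username="root", connection_token="memphis")
    await memphis.station(
        name="orders",
        retention_type=retention_types.MAX_MESSAGE_AGE_SECONDS,
        retention_value=604800,  # one week, the default
        storage_type=storage_types.DISK,
        replicas=1,
    )
    await memphis.attach_schema("orders-schema", "orders")
    await memphis.close()


asyncio.run(main())
```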
@@ -257,14 +287,13 @@ async def detach_schema(self, stationName):
            Exception: _description_
         """
         try:
-            if stationName == '':
+            if stationName == "":
                 raise MemphisError("station name is missing")
-            msg = {
-                "station_name": stationName,
-                "username": self.username
-            }
-            msgToSend = json.dumps(msg).encode('utf-8')
-            err_msg = await self.broker_manager.request("$memphis_schema_detachments", msgToSend, timeout=5)
+            msg = {"station_name": stationName, "username": self.username}
+            msgToSend = json.dumps(msg).encode("utf-8")
+            err_msg = await self.broker_manager.request(
+                "$memphis_schema_detachments", msgToSend, timeout=5
+            )
             err_msg = err_msg.data.decode("utf-8")
 
             if err_msg != "":
@@ -273,8 +302,7 @@
             raise MemphisError(str(e)) from e
 
     async def close(self):
-        """Close Memphis connection.
-        """
+        """Close Memphis connection."""
         try:
             if self.is_connection_active:
                 await self.broker_manager.close()
@@ -310,14 +338,19 @@ def __generateConnectionID(self):
         return random_bytes(24)
 
     def __normalize_host(self, host):
-        if (host.startswith("http://")):
+        if host.startswith("http://"):
             return host.split("http://")[1]
-        elif (host.startswith("https://")):
+        elif host.startswith("https://"):
             return host.split("https://")[1]
         else:
             return host
 
-    async def producer(self, station_name: str, producer_name: str, generate_random_suffix: bool =False):
+    async def producer(
+        self,
+        station_name: str,
+        producer_name: str,
+        generate_random_suffix: bool = False,
+    ):
         """Creates a producer.
         Args:
             station_name (str): station name to produce messages into.
@@ -340,33 +373,51 @@ async def producer(self, station_name: str, producer_name: str, generate_random_
             "connection_id": self.connection_id,
             "producer_type": "application",
             "req_version": 1,
-            "username": self.username
+            "username": self.username,
         }
-        create_producer_req_bytes = json.dumps(
-            createProducerReq, indent=2).encode('utf-8')
-        create_res = await self.broker_manager.request("$memphis_producer_creations", create_producer_req_bytes, timeout=5)
+        create_producer_req_bytes = json.dumps(createProducerReq, indent=2).encode(
+            "utf-8"
+        )
+        create_res = await self.broker_manager.request(
+            "$memphis_producer_creations", create_producer_req_bytes, timeout=5
+        )
         create_res = create_res.data.decode("utf-8")
         create_res = json.loads(create_res)
-        if create_res['error'] != "":
-            raise MemphisError(create_res['error'])
+        if create_res["error"] != "":
+            raise MemphisError(create_res["error"])
 
         station_name_internal = get_internal_name(station_name)
-        self.station_schemaverse_to_dls[station_name_internal] = create_res['schemaverse_to_dls']
-        self.cluster_configurations['send_notification'] = create_res['send_notification']
-        await self.start_listen_for_schema_updates(station_name_internal, create_res['schema_update'])
+        self.station_schemaverse_to_dls[station_name_internal] = create_res[
+            "schemaverse_to_dls"
+        ]
+        self.cluster_configurations["send_notification"] = create_res[
+            "send_notification"
+        ]
+        await self.start_listen_for_schema_updates(
+            station_name_internal, create_res["schema_update"]
+        )
 
         if self.schema_updates_data[station_name_internal] != {}:
-            if self.schema_updates_data[station_name_internal]['type'] == "protobuf":
+            if (
+                self.schema_updates_data[station_name_internal]["type"]
+                == "protobuf"
+            ):
                 self.parse_descriptor(station_name_internal)
-            if self.schema_updates_data[station_name_internal]['type'] == "json":
-                schema = self.schema_updates_data[station_name_internal]['active_version']['schema_content']
-                self.json_schemas[station_name_internal] = json.loads(
-                    schema)
-            elif self.schema_updates_data[station_name_internal]['type'] == "graphql":
+            if self.schema_updates_data[station_name_internal]["type"] == "json":
+                schema = self.schema_updates_data[station_name_internal][
+                    "active_version"
+                ]["schema_content"]
+                self.json_schemas[station_name_internal] = json.loads(schema)
+            elif (
+                self.schema_updates_data[station_name_internal]["type"] == "graphql"
+            ):
                 self.graphql_schemas[station_name_internal] = build_graphql_schema(
-                    self.schema_updates_data[station_name_internal]['active_version']['schema_content'])
+                    self.schema_updates_data[station_name_internal][
+                        "active_version"
+                    ]["schema_content"]
+                )
 
         producer = Producer(self, producer_name, station_name, real_name)
-        map_key = station_name_internal+"_"+real_name
+        map_key = station_name_internal + "_" + real_name
         self.producers_map[map_key] = producer
         return producer
@@ -377,17 +428,21 @@ async def get_msg_schema_updates(self, station_name_internal, iterable):
         async for msg in iterable:
             message = msg.data.decode("utf-8")
             message = json.loads(message)
-            if message['init']['schema_name'] == "":
+            if message["init"]["schema_name"] == "":
                 data = {}
             else:
-                data = message['init']
+                data = message["init"]
             self.schema_updates_data[station_name_internal] = data
             self.parse_descriptor(station_name_internal)
 
     def parse_descriptor(self, station_name):
         try:
-            descriptor = self.schema_updates_data[station_name]['active_version']['descriptor']
-            msg_struct_name = self.schema_updates_data[station_name]['active_version']['message_struct_name']
+            descriptor = self.schema_updates_data[station_name]["active_version"][
+                "descriptor"
+            ]
+            msg_struct_name = self.schema_updates_data[station_name]["active_version"][
+                "message_struct_name"
+            ]
             desc_set = descriptor_pb2.FileDescriptorSet()
             descriptor_bytes = str.encode(descriptor)
             desc_set.ParseFromString(descriptor_bytes)
@@ -398,7 +453,8 @@ def parse_descriptor(self, station_name):
             if pkg_name != "":
                 msg_name = desc_set.file[0].package + "." + msg_struct_name
                 proto_msg = MessageFactory(pool).GetPrototype(
-                    pool.FindMessageTypeByName(msg_name))
+                    pool.FindMessageTypeByName(msg_name)
+                )
                 proto = proto_msg()
                 self.proto_msgs[station_name] = proto
@@ -408,7 +464,7 @@ def parse_descriptor(self, station_name):
 
     async def start_listen_for_schema_updates(self, station_name, schema_update_data):
         schema_updates_subject = "$memphis_schema_updates_" + station_name
-        empty = schema_update_data['schema_name'] == ''
+        empty = schema_update_data["schema_name"] == ""
         if empty:
             self.schema_updates_data[station_name] = {}
         else:
@@ -424,11 +480,27 @@ async def start_listen_for_schema_updates(self, station_name, schema_update_data
         task_exists = self.schema_tasks.get(station_name)
         if not task_exists:
             loop = asyncio.get_event_loop()
-            task = loop.create_task(self.get_msg_schema_updates(
-                station_name, self.schema_updates_subs[station_name].messages))
+            task = loop.create_task(
+                self.get_msg_schema_updates(
+                    station_name, self.schema_updates_subs[station_name].messages
+                )
+            )
             self.schema_tasks[station_name] = task
 
-    async def consumer(self, station_name: str, consumer_name: str, consumer_group: str ="", pull_interval_ms: int = 1000, batch_size: int = 10, batch_max_time_to_wait_ms: int =5000, max_ack_time_ms: int=30000, max_msg_deliveries: int=10, generate_random_suffix: bool=False, start_consume_from_sequence: int=1, last_messages: int=-1):
+    async def consumer(
+        self,
+        station_name: str,
+        consumer_name: str,
+        consumer_group: str = "",
+        pull_interval_ms: int = 1000,
+        batch_size: int = 10,
+        batch_max_time_to_wait_ms: int = 5000,
+        max_ack_time_ms: int = 30000,
+        max_msg_deliveries: int = 10,
+        generate_random_suffix: bool = False,
+        start_consume_from_sequence: int = 1,
+        last_messages: int = -1,
+    ):
         """Creates a consumer.
         Args:.
            station_name (str): station name to consume messages from.
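A consumer built against the signature above, mirroring examples/consumer.py; the names and the sleep interval are placeholders:

```python
# Hypothetical consumer sketch based on the reformatted API above.
import asyncio

from memphis import Memphis


async def msg_handler(msgs, error, context):
    if error:
        print(error)
        return
    for msg in msgs:
        print("got:", msg.get_data(), "context:", context)
        await msg.ack()


async def main():
    memphis = Memphis()
    await memphis.connect(host="localhost", username="root", connection_token="memphis")
    consumer = await memphis.consumer(
        station_name="orders",
        consumer_name="orders-consumer",
        consumer_group="orders-cg",
        batch_size=10,
    )
    consumer.set_context({"tenant": "demo"})
    consumer.consume(msg_handler)
    await asyncio.sleep(30)  # keep the event loop alive while messages arrive
    await memphis.close()


asyncio.run(main())
```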
@@ -454,40 +526,69 @@ async def consumer(self, station_name: str, consumer_name: str, consumer_group:
             cg = consumer_name if not consumer_group else consumer_group
 
             if start_consume_from_sequence <= 0:
-                raise MemphisError("start_consume_from_sequence has to be a positive number")
+                raise MemphisError(
+                    "start_consume_from_sequence has to be a positive number"
+                )
 
             if last_messages < -1:
                 raise MemphisError("min value for last_messages is -1")
 
-            if start_consume_from_sequence > 1 and last_messages > -1 :
-                raise MemphisError("Consumer creation options can't contain both start_consume_from_sequence and last_messages")
+            if start_consume_from_sequence > 1 and last_messages > -1:
+                raise MemphisError(
+                    "Consumer creation options can't contain both start_consume_from_sequence and last_messages"
+                )
 
             createConsumerReq = {
-                'name': consumer_name,
+                "name": consumer_name,
                 "station_name": station_name,
                 "connection_id": self.connection_id,
-                "consumer_type": 'application',
+                "consumer_type": "application",
                 "consumers_group": consumer_group,
                 "max_ack_time_ms": max_ack_time_ms,
                 "max_msg_deliveries": max_msg_deliveries,
                 "start_consume_from_sequence": start_consume_from_sequence,
                 "last_messages": last_messages,
                 "req_version": 1,
-                "username": self.username
+                "username": self.username,
             }
-            create_consumer_req_bytes = json.dumps(
-                createConsumerReq, indent=2).encode('utf-8')
-            err_msg = await self.broker_manager.request("$memphis_consumer_creations", create_consumer_req_bytes, timeout=5)
+            create_consumer_req_bytes = json.dumps(createConsumerReq, indent=2).encode(
+                "utf-8"
+            )
+            err_msg = await self.broker_manager.request(
+                "$memphis_consumer_creations", create_consumer_req_bytes, timeout=5
+            )
             err_msg = err_msg.data.decode("utf-8")
 
             if err_msg != "":
                 raise MemphisError(err_msg)
 
-            return Consumer(self, station_name, consumer_name, cg, pull_interval_ms, batch_size, batch_max_time_to_wait_ms, max_ack_time_ms, max_msg_deliveries, start_consume_from_sequence=start_consume_from_sequence, last_messages=last_messages)
+            return Consumer(
+                self,
+                station_name,
+                consumer_name,
+                cg,
+                pull_interval_ms,
+                batch_size,
+                batch_max_time_to_wait_ms,
+                max_ack_time_ms,
+                max_msg_deliveries,
+                start_consume_from_sequence=start_consume_from_sequence,
+                last_messages=last_messages,
+            )
         except Exception as e:
             raise MemphisError(str(e)) from e
 
-    async def produce(self, station_name: str, producer_name: str, message, generate_random_suffix: bool =False, ack_wait_sec: int = 15, headers: Union[Headers, None] = None, async_produce: bool=False, msg_id: Union[str, None]= None):
+    async def produce(
+        self,
+        station_name: str,
+        producer_name: str,
+        message,
+        generate_random_suffix: bool = False,
+        ack_wait_sec: int = 15,
+        headers: Union[Headers, None] = None,
+        async_produce: bool = False,
+        msg_id: Union[str, None] = None,
+    ):
         """Produces a message into a station without the need to create a producer.
         Args:
             station_name (str): station name to produce messages into.
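The connection-level produce() above creates a producer on first use and caches it in producers_map. A hedged one-shot sketch (names and payload are placeholders):

```python
# Hypothetical sketch of connection-level produce(); the msg_id header
# enables broker-side deduplication within the station's idempotency window.
import asyncio

from memphis import Memphis


async def main():
    memphis = Memphis()
    await memphis.connect(host="localhost", username="root", connection_token="memphis")
    await memphis.produce(
        station_name="orders",
        producer_name="checkout-svc",
        message={"order_id": 42},
        msg_id="order-42",
    )
    await memphis.close()


asyncio.run(main())
```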
@@ -503,13 +604,23 @@ async def produce(self, station_name: str, producer_name: str, message, generate
         """
         try:
             station_name_internal = get_internal_name(station_name)
-            map_key = station_name_internal+"_"+producer_name.lower()
+            map_key = station_name_internal + "_" + producer_name.lower()
             producer = None
             if map_key in self.producers_map:
                 producer = self.producers_map[map_key]
             else:
-                producer = await self.producer(station_name=station_name, producer_name=producer_name, generate_random_suffix=generate_random_suffix)
-            await producer.produce(message=message, ack_wait_sec=ack_wait_sec, headers=headers, async_produce= async_produce, msg_id=msg_id)
+                producer = await self.producer(
+                    station_name=station_name,
+                    producer_name=producer_name,
+                    generate_random_suffix=generate_random_suffix,
+                )
+            await producer.produce(
+                message=message,
+                ack_wait_sec=ack_wait_sec,
+                headers=headers,
+                async_produce=async_produce,
+                msg_id=msg_id,
+            )
         except Exception as e:
             raise MemphisError(str(e)) from e
@@ -523,22 +634,19 @@ def __init__(self, connection, name: str):
         self.name = name.lower()
 
     async def destroy(self):
-        """Destroy the station.
-        """
+        """Destroy the station."""
         try:
-            nameReq = {
-                "station_name": self.name,
-                "username": self.connection.username
-            }
-            station_name = json.dumps(nameReq, indent=2).encode('utf-8')
-            res = await self.connection.broker_manager.request('$memphis_station_destructions', station_name, timeout=5)
-            error = res.data.decode('utf-8')
+            nameReq = {"station_name": self.name, "username": self.connection.username}
+            station_name = json.dumps(nameReq, indent=2).encode("utf-8")
+            res = await self.connection.broker_manager.request(
+                "$memphis_station_destructions", station_name, timeout=5
+            )
+            error = res.data.decode("utf-8")
             if error != "" and not "not exist" in error:
                 raise MemphisError(error)
 
             station_name_internal = get_internal_name(self.name)
-            sub = self.connection.schema_updates_subs.get(
-                station_name_internal)
+            sub = self.connection.schema_updates_subs.get(station_name_internal)
             task = self.connection.schema_tasks.get(station_name_internal)
             if station_name_internal in self.connection.schema_updates_data:
                 del self.connection.schema_updates_data[station_name_internal]
@@ -553,7 +661,11 @@ async def destroy(self):
             if sub is not None:
                 await sub.unsubscribe()
 
-            self.connection.producers_map = {k: v for k, v in self.connection.producers_map.items() if self.name not in k}
+            self.connection.producers_map = {
+                k: v
+                for k, v in self.connection.producers_map.items()
+                if self.name not in k
+            }
 
         except Exception as e:
             raise MemphisError(str(e)) from e
@@ -565,7 +677,9 @@ def get_internal_name(name: str) -> str:
 
 
 class Producer:
-    def __init__(self, connection, producer_name: str, station_name: str, real_name: str):
+    def __init__(
+        self, connection, producer_name: str, station_name: str, real_name: str
+    ):
         self.connection = connection
         self.producer_name = producer_name.lower()
         self.station_name = station_name
@@ -575,7 +689,9 @@ def __init__(self, connection, producer_name: str, station_name: str, real_name:
 
     async def validate_msg(self, message):
         if self.connection.schema_updates_data[self.internal_station_name] != {}:
-            schema_type = self.connection.schema_updates_data[self.internal_station_name]['type']
+            schema_type = self.connection.schema_updates_data[
+                self.internal_station_name
+            ]["type"]
             if schema_type == "protobuf":
                 message = self.validate_protobuf(message)
                 return message
@@ -585,11 +701,11 @@ def __init__(self, connection, producer_name: str, station_name: str, real_name:
             elif schema_type == "graphql":
                 message = self.validate_graphql(message)
                 return message
-            elif not isinstance(message, bytearray) and not isinstance(message, dict) :
+            elif not isinstance(message, bytearray) and not isinstance(message, dict):
                 raise MemphisSchemaError("Unsupported message type")
         else:
             if isinstance(message, dict):
-                message = bytearray(json.dumps(message).encode('utf-8'))
+                message = bytearray(json.dumps(message).encode("utf-8"))
             return message
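validate_msg() above dispatches to a per-type validator; the json path (validate_json_schema, further down) delegates to jsonschema.validate, which raises on mismatch and is wrapped in MemphisSchemaError. A standalone sketch of that underlying call, with an illustrative schema:

```python
# Standalone jsonschema sketch; schema and payloads are illustrative only.
from jsonschema import validate
from jsonschema.exceptions import ValidationError

order_schema = {
    "type": "object",
    "properties": {"order_id": {"type": "integer"}},
    "required": ["order_id"],
}

validate(instance={"order_id": 42}, schema=order_schema)  # passes silently

try:
    validate(instance={"order_id": "oops"}, schema=order_schema)
except ValidationError as err:
    print("Schema validation has failed:", err.message)
```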
"graphql": message = self.validate_graphql(message) return message - elif not isinstance(message, bytearray) and not isinstance(message, dict) : + elif not isinstance(message, bytearray) and not isinstance(message, dict): raise MemphisSchemaError("Unsupported message type") else: if isinstance(message, dict): - message = bytearray(json.dumps(message).encode('utf-8')) + message = bytearray(json.dumps(message).encode("utf-8")) return message def validate_protobuf(self, message): @@ -603,8 +719,8 @@ def validate_protobuf(self, message): proto_msg.SerializeToString() msgToSend = msgToSend.decode("utf-8") except Exception as e: - if 'parsing message' in str(e): - e = 'Invalid message format, expecting protobuf' + if "parsing message" in str(e): + e = "Invalid message format, expecting protobuf" raise MemphisSchemaError(str(e)) return message elif hasattr(message, "SerializeToString"): @@ -628,12 +744,14 @@ def validate_json_schema(self, message): raise Exception("Expecting Json format: " + str(e)) elif isinstance(message, dict): message_obj = message - message = bytearray(json.dumps(message_obj).encode('utf-8')) + message = bytearray(json.dumps(message_obj).encode("utf-8")) else: raise Exception("Unsupported message type") - validate(instance=message_obj, - schema=self.connection.json_schemas[self.internal_station_name]) + validate( + instance=message_obj, + schema=self.connection.json_schemas[self.internal_station_name], + ) return message except Exception as e: raise MemphisSchemaError("Schema validation has failed: " + str(e)) @@ -645,28 +763,36 @@ def validate_graphql(self, message): msg = parse_graphql(msg) elif isinstance(message, str): msg = parse_graphql(message) - message = message.encode('utf-8') + message = message.encode("utf-8") elif isinstance(message, graphql.language.ast.DocumentNode): msg = message message = str(msg.loc.source.body) - message = message.encode('utf-8') + message = message.encode("utf-8") else: raise MemphisError("Unsupported message type") validate_res = validate_graphql( - schema=self.connection.graphql_schemas[self.internal_station_name], document_ast=msg) + schema=self.connection.graphql_schemas[self.internal_station_name], + document_ast=msg, + ) if len(validate_res) > 0: - raise Exception( - "Schema validation has failed: " + str(validate_res)) + raise Exception("Schema validation has failed: " + str(validate_res)) return message except Exception as e: - if 'Syntax Error' in str(e): + if "Syntax Error" in str(e): e = "Invalid message format, expected GraphQL" raise Exception("Schema validation has failed: " + str(e)) def get_dls_msg_id(self, station_name: str, producer_name: str, unix_time: str): - return station_name + '~' + producer_name + '~0~' + unix_time - - async def produce(self, message, ack_wait_sec: int = 15, headers: Union[Headers, None] = None, async_produce: bool=False, msg_id: Union[str, None]= None): + return station_name + "~" + producer_name + "~0~" + unix_time + + async def produce( + self, + message, + ack_wait_sec: int = 15, + headers: Union[Headers, None] = None, + async_produce: bool = False, + msg_id: Union[str, None] = None, + ): """Produces a message into a station. 
 
     def get_dls_msg_id(self, station_name: str, producer_name: str, unix_time: str):
-        return station_name + '~' + producer_name + '~0~' + unix_time
-
-    async def produce(self, message, ack_wait_sec: int = 15, headers: Union[Headers, None] = None, async_produce: bool=False, msg_id: Union[str, None]= None):
+        return station_name + "~" + producer_name + "~0~" + unix_time
+
+    async def produce(
+        self,
+        message,
+        ack_wait_sec: int = 15,
+        headers: Union[Headers, None] = None,
+        async_produce: bool = False,
+        msg_id: Union[str, None] = None,
+    ):
         """Produces a message into a station.
         Args:
             message (bytearray/dict): message to send into the station - bytearray/protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
@@ -682,7 +808,8 @@ async def produce(self, message, ack_wait_sec: int = 15, headers: Union[Headers,
 
             memphis_headers = {
                 "$memphis_producedBy": self.producer_name,
-                "$memphis_connectionId": self.connection.connection_id}
+                "$memphis_connectionId": self.connection.connection_id,
+            }
 
             if msg_id is not None:
                 memphis_headers["msg-id"] = msg_id
@@ -695,32 +822,52 @@ async def produce(self, message, ack_wait_sec: int = 15, headers: Union[Headers,
 
             if async_produce:
                 try:
-                    self.loop.create_task(self.connection.broker_connection.publish(
-                        self.internal_station_name + ".final", message, timeout=ack_wait_sec, headers=headers))
+                    self.loop.create_task(
+                        self.connection.broker_connection.publish(
+                            self.internal_station_name + ".final",
+                            message,
+                            timeout=ack_wait_sec,
+                            headers=headers,
+                        )
+                    )
                     await asyncio.sleep(1)
                 except Exception as e:
                     raise MemphisError(e)
             else:
-                await self.connection.broker_connection.publish(self.internal_station_name + ".final", message, timeout=ack_wait_sec, headers=headers)
+                await self.connection.broker_connection.publish(
+                    self.internal_station_name + ".final",
+                    message,
+                    timeout=ack_wait_sec,
+                    headers=headers,
+                )
         except Exception as e:
-            if hasattr(e, 'status_code') and e.status_code == '503':
+            if hasattr(e, "status_code") and e.status_code == "503":
                 raise MemphisError(
-                    "Produce operation has failed, please check whether Station/Producer are still exist")
+                    "Produce operation has failed, please check whether Station/Producer still exist"
+                )
             else:
-                if ("Schema validation has failed" in str(e) or "Unsupported message type" in str(e)):
+                if "Schema validation has failed" in str(
+                    e
+                ) or "Unsupported message type" in str(e):
                     msgToSend = ""
                     if isinstance(message, bytearray):
-                        msgToSend = str(message, 'utf-8')
+                        msgToSend = str(message, "utf-8")
                     elif hasattr(message, "SerializeToString"):
                         msgToSend = message.SerializeToString().decode("utf-8")
-                    if self.connection.station_schemaverse_to_dls[self.internal_station_name]:
+                    if self.connection.station_schemaverse_to_dls[
+                        self.internal_station_name
+                    ]:
                         unix_time = int(time.time())
                         id = self.get_dls_msg_id(
-                            self.internal_station_name, self.producer_name, str(unix_time))
+                            self.internal_station_name,
+                            self.producer_name,
+                            str(unix_time),
+                        )
 
                         memphis_headers = {
                             "$memphis_producedBy": self.producer_name,
-                            "$memphis_connectionId": self.connection.connection_id}
+                            "$memphis_connectionId": self.connection.connection_id,
+                        }
 
                         if headers != {}:
                             headers = headers.headers
@@ -728,51 +875,72 @@ async def produce(self, message, ack_wait_sec: int = 15, headers: Union[Headers,
                         else:
                             headers = memphis_headers
 
-                        msgToSendEncoded = msgToSend.encode('utf-8')
+                        msgToSendEncoded = msgToSend.encode("utf-8")
                         msgHex = msgToSendEncoded.hex()
                         buf = {
                             "_id": id,
                             "station_name": self.internal_station_name,
                             "producer": {
                                 "name": self.producer_name,
-                                "connection_id": self.connection.connection_id
+                                "connection_id": self.connection.connection_id,
                             },
                             "creation_unix": unix_time,
                             "message": {
                                 "data": msgHex,
                                 "headers": headers,
-                            }
+                            },
                         }
-                        buf = json.dumps(buf).encode('utf-8')
-                        await self.connection.broker_connection.publish('$memphis-' + self.internal_station_name + '-dls.schema.' + id, buf)
-                        if self.connection.cluster_configurations.get('send_notification'):
-                            await self.connection.send_notification('Schema validation has failed', 'Station: ' + self.station_name + '\nProducer: ' + self.producer_name + '\nError:' + str(e), msgToSend, schemaVFailAlertType)
+                        buf = json.dumps(buf).encode("utf-8")
+                        await self.connection.broker_connection.publish(
+                            "$memphis-"
+                            + self.internal_station_name
+                            + "-dls.schema."
+                            + id,
+                            buf,
+                        )
+                        if self.connection.cluster_configurations.get(
+                            "send_notification"
+                        ):
+                            await self.connection.send_notification(
+                                "Schema validation has failed",
+                                "Station: "
+                                + self.station_name
+                                + "\nProducer: "
+                                + self.producer_name
+                                + "\nError:"
+                                + str(e),
+                                msgToSend,
+                                schemaVFailAlertType,
+                            )
             raise MemphisError(str(e)) from e
 
     async def destroy(self):
-        """Destroy the producer.
-        """
+        """Destroy the producer."""
         try:
             destroyProducerReq = {
                 "name": self.producer_name,
                 "station_name": self.station_name,
-                "username": self.connection.username
+                "username": self.connection.username,
             }
-            producer_name = json.dumps(destroyProducerReq).encode('utf-8')
-            res = await self.connection.broker_manager.request('$memphis_producer_destructions', producer_name, timeout=5)
-            error = res.data.decode('utf-8')
+            producer_name = json.dumps(destroyProducerReq).encode("utf-8")
+            res = await self.connection.broker_manager.request(
+                "$memphis_producer_destructions", producer_name, timeout=5
+            )
+            error = res.data.decode("utf-8")
             if error != "" and not "not exist" in error:
                 raise Exception(error)
 
             station_name_internal = get_internal_name(self.station_name)
-            producer_number = self.connection.producers_per_station.get(
-                station_name_internal) - 1
-            self.connection.producers_per_station[station_name_internal] = producer_number
+            producer_number = (
+                self.connection.producers_per_station.get(station_name_internal) - 1
+            )
+            self.connection.producers_per_station[
+                station_name_internal
+            ] = producer_number
 
             if producer_number == 0:
-                sub = self.connection.schema_updates_subs.get(
-                    station_name_internal)
+                sub = self.connection.schema_updates_subs.get(station_name_internal)
                 task = self.connection.schema_tasks.get(station_name_internal)
                 if station_name_internal in self.connection.schema_updates_data:
                     del self.connection.schema_updates_data[station_name_internal]
@@ -785,7 +953,7 @@ async def destroy(self):
             if sub is not None:
                 await sub.unsubscribe()
 
-            map_key = station_name_internal+"_"+self.real_name
+            map_key = station_name_internal + "_" + self.real_name
             del self.connection.producers_map[map_key]
 
         except Exception as e:
@@ -797,7 +965,21 @@ async def default_error_handler(e):
 
 
 class Consumer:
-    def __init__(self, connection, station_name: str, consumer_name, consumer_group, pull_interval_ms: int, batch_size: int, batch_max_time_to_wait_ms: int, max_ack_time_ms: int, max_msg_deliveries: int=10, error_callback=None, start_consume_from_sequence: int=1, last_messages: int=-1):
+    def __init__(
+        self,
+        connection,
+        station_name: str,
+        consumer_name,
+        consumer_group,
+        pull_interval_ms: int,
+        batch_size: int,
+        batch_max_time_to_wait_ms: int,
+        max_ack_time_ms: int,
+        max_msg_deliveries: int = 10,
+        error_callback=None,
+        start_consume_from_sequence: int = 1,
+        last_messages: int = -1,
+    ):
         self.connection = connection
         self.station_name = station_name.lower()
         self.consumer_name = consumer_name.lower()
@@ -813,17 +995,15 @@ def __init__(self, connection, station_name: str, consumer_name, consumer_group,
         self.t_ping = asyncio.create_task(self.__ping_consumer(error_callback))
         self.start_consume_from_sequence = start_consume_from_sequence
-        self.last_messages= last_messages
+        self.last_messages = last_messages
         self.context = {}
 
     def set_context(self, context):
-        """Set a context (dict) that will be passed to each message handler call.
-        """
+        """Set a context (dict) that will be passed to each message handler call."""
         self.context = context
 
     def consume(self, callback):
-        """Consume events.
-        """
+        """Consume events."""
         self.t_consume = asyncio.create_task(self.__consume(callback))
         self.t_dls = asyncio.create_task(self.__consume_dls(callback))
@@ -831,7 +1011,8 @@ async def __consume(self, callback):
         subject = get_internal_name(self.station_name)
         consumer_group = get_internal_name(self.consumer_group)
         self.psub = await self.connection.broker_connection.pull_subscribe(
-            subject + ".final", durable=consumer_group)
+            subject + ".final", durable=consumer_group
+        )
         while True:
             if self.connection.is_connection_active and self.pull_interval_ms:
                 try:
@@ -839,12 +1020,15 @@ async def __consume(self, callback):
                     msgs = await self.psub.fetch(self.batch_size)
                     for msg in msgs:
                         memphis_messages.append(
-                            Message(msg, self.connection, self.consumer_group))
+                            Message(msg, self.connection, self.consumer_group)
+                        )
                     await callback(memphis_messages, None, self.context)
-                    await asyncio.sleep(self.pull_interval_ms/1000)
+                    await asyncio.sleep(self.pull_interval_ms / 1000)
 
                 except asyncio.TimeoutError:
-                    await callback([], MemphisError("Memphis: TimeoutError"), self.context)
+                    await callback(
+                        [], MemphisError("Memphis: TimeoutError"), self.context
+                    )
                     continue
                 except Exception as e:
                     if self.connection.is_connection_active:
@@ -858,10 +1042,16 @@ async def __consume_dls(self, callback):
         subject = get_internal_name(self.station_name)
         consumer_group = get_internal_name(self.consumer_group)
         try:
-            subscription_name = "$memphis_dls_"+subject+"_"+consumer_group
-            self.consumer_dls = await self.connection.broker_manager.subscribe(subscription_name, subscription_name)
+            subscription_name = "$memphis_dls_" + subject + "_" + consumer_group
+            self.consumer_dls = await self.connection.broker_manager.subscribe(
+                subscription_name, subscription_name
+            )
             async for msg in self.consumer_dls.messages:
-                await callback([Message(msg, self.connection, self.consumer_group)], None, self.context)
+                await callback(
+                    [Message(msg, self.connection, self.consumer_group)],
+                    None,
+                    self.context,
+                )
         except Exception as e:
             print("dls", e)
             await callback([], MemphisError(str(e)))
@@ -870,16 +1060,17 @@ async def __consume_dls(self, callback):
     async def __ping_consumer(self, callback):
         while True:
             try:
-                await asyncio.sleep(self.ping_consumer_invterval_ms/1000)
+                await asyncio.sleep(self.ping_consumer_invterval_ms / 1000)
                 consumer_group = get_internal_name(self.consumer_group)
-                await self.connection.broker_connection.consumer_info(self.station_name, consumer_group, timeout=30)
+                await self.connection.broker_connection.consumer_info(
+                    self.station_name, consumer_group, timeout=30
+                )
 
             except Exception as e:
                 await callback(e)
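Consumer.destroy() below cancels the consume/dls/ping tasks started above. A hedged teardown sketch (names are placeholders and the handler is a stub):

```python
# Hypothetical graceful-shutdown sketch for a consumer.
import asyncio

from memphis import Memphis


async def noop_handler(msgs, error, context):
    pass  # a real handler would process msgs and inspect error


async def main():
    memphis = Memphis()
    await memphis.connect(host="localhost", username="root", connection_token="memphis")
    consumer = await memphis.consumer(station_name="orders", consumer_name="orders-consumer")
    try:
        consumer.consume(noop_handler)
        await asyncio.sleep(10)
    finally:
        await consumer.destroy()  # cancels the background consume/dls/ping tasks
        await memphis.close()


asyncio.run(main())
```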
- """ + """Destroy the consumer.""" if self.t_consume is not None: self.t_consume.cancel() if self.t_dls is not None: @@ -891,12 +1082,13 @@ async def destroy(self): destroyConsumerReq = { "name": self.consumer_name, "station_name": self.station_name, - "username": self.connection.username + "username": self.connection.username, } - consumer_name = json.dumps( - destroyConsumerReq, indent=2).encode('utf-8') - res = await self.connection.broker_manager.request('$memphis_consumer_destructions', consumer_name, timeout=5) - error = res.data.decode('utf-8') + consumer_name = json.dumps(destroyConsumerReq, indent=2).encode("utf-8") + res = await self.connection.broker_manager.request( + "$memphis_consumer_destructions", consumer_name, timeout=5 + ) + error = res.data.decode("utf-8") if error != "" and not "not exist" in error: raise MemphisError(error) except Exception as e: @@ -910,19 +1102,24 @@ def __init__(self, message, connection, cg_name): self.cg_name = cg_name async def ack(self): - """Ack a message is done processing. - """ + """Ack a message is done processing.""" try: await self.message.ack() except Exception as e: - if ("$memphis_pm_id" in self.message.headers & "$memphis_pm_sequence" in self.message.headers): + if ( + "$memphis_pm_id" + in self.message.headers & "$memphis_pm_sequence" + in self.message.headers + ): try: msg = { "id": self.message.headers["$memphis_pm_id"], "sequence": self.message.headers["$memphis_pm_sequence"], } - msgToAck = json.dumps(msg).encode('utf-8') - await self.connection.broker_manager.publish("$memphis_pm_acks", msgToAck) + msgToAck = json.dumps(msg).encode("utf-8") + await self.connection.broker_manager.publish( + "$memphis_pm_acks", msgToAck + ) except Exception as er: raise MemphisConnectError(str(er)) from er else: @@ -930,31 +1127,29 @@ async def ack(self): return def get_data(self): - """Receive the message. - """ + """Receive the message.""" try: return bytearray(self.message.data) except: return def get_headers(self): - """Receive the headers. - """ + """Receive the headers.""" try: return self.message.headers except: return def get_sequence_number(self): - """Get message sequence number. 
- """ + """Get message sequence number.""" try: return self.message.metadata.sequence.stream except: return + def random_bytes(amount: int) -> str: - lst = [random.choice('0123456789abcdef') for n in range(amount)] + lst = [random.choice("0123456789abcdef") for n in range(amount)] s = "".join(lst) return s diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0d8e5f4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[tool.isort] +profile = "black" +known_third_party = ["memphis"] \ No newline at end of file diff --git a/setup.py b/setup.py index 38f249e..2fd1b73 100644 --- a/setup.py +++ b/setup.py @@ -1,40 +1,35 @@ +from pathlib import Path from setuptools import setup -from pathlib import Path + this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() setup( - name='memphis-py', - packages=['memphis'], - version='0.2.9', - license='Apache-2.0', - description='A powerful messaging platform for modern developers', + name="memphis-py", + packages=["memphis"], + version="0.2.9", + license="Apache-2.0", + description="A powerful messaging platform for modern developers", long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", readme="README.md", - author='Memphis.dev', - author_email='team@memphis.dev', - url='https://github.com/memphisdev/memphis.py', - download_url='https://github.com/memphisdev/memphis.py/archive/refs/tags/v0.2.9.tar.gz', - keywords=['message broker', 'devtool', 'streaming', 'data'], - install_requires=[ - 'asyncio', - 'nats-py', - 'protobuf', - 'jsonschema', - 'graphql-core' - ], + author="Memphis.dev", + author_email="team@memphis.dev", + url="https://github.com/memphisdev/memphis.py", + download_url="https://github.com/memphisdev/memphis.py/archive/refs/tags/v0.2.9.tar.gz", + keywords=["message broker", "devtool", "streaming", "data"], + install_requires=["asyncio", "nats-py", "protobuf", "jsonschema", "graphql-core"], classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Topic :: Software Development', - 'License :: OSI Approved :: GNU General Public License (GPL)', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Topic :: Software Development", + "License :: OSI Approved :: GNU General Public License (GPL)", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ], )