From df80be07263ef76ce06522938138f1e86bf92043 Mon Sep 17 00:00:00 2001 From: big-vi Date: Sun, 3 Sep 2023 09:00:39 +0530 Subject: [PATCH 1/5] consumer deserialize data --- memphis/consumer.py | 9 ++++--- memphis/memphis.py | 66 ++++++++++++++++++++++++++------------------- memphis/message.py | 30 ++++++++++++++++++--- 3 files changed, 70 insertions(+), 35 deletions(-) diff --git a/memphis/consumer.py b/memphis/consumer.py index c6dbdc5..21b7ec2 100644 --- a/memphis/consumer.py +++ b/memphis/consumer.py @@ -29,6 +29,7 @@ def __init__( ): self.connection = connection self.station_name = station_name.lower() + self.internal_station_name = get_internal_name(self.station_name) self.consumer_name = consumer_name.lower() self.consumer_group = consumer_group.lower() self.pull_interval_ms = pull_interval_ms @@ -122,7 +123,7 @@ async def __consume(self, callback): for msg in msgs: memphis_messages.append( - Message(msg, self.connection, self.consumer_group) + Message(msg, self.connection, self.consumer_group, self.internal_station_name) ) await callback(memphis_messages, None, self.context) await asyncio.sleep(self.pull_interval_ms / 1000) @@ -153,12 +154,12 @@ async def __consume_dls(self): index_to_insert %= 10000 self.dls_messages.insert( index_to_insert, Message( - msg, self.connection, self.consumer_group) + msg, self.connection, self.consumer_group, self.internal_station_name) ) self.dls_current_index += 1 if self.dls_callback_func != None: await self.dls_callback_func( - [Message(msg, self.connection, self.consumer_group)], + [Message(msg, self.connection, self.consumer_group, self.internal_station_name)], None, self.context, ) @@ -237,7 +238,7 @@ async def main(host, username, password, station): msgs = await self.psub.fetch(batch_size) for msg in msgs: messages.append( - Message(msg, self.connection, self.consumer_group)) + Message(msg, self.connection, self.consumer_group, self.internal_station_name)) return messages except Exception as e: if "timeout" not in 
str(e).lower(): diff --git a/memphis/memphis.py b/memphis/memphis.py index 6afe30f..4d31e6c 100644 --- a/memphis/memphis.py +++ b/memphis/memphis.py @@ -457,34 +457,8 @@ async def producer( internal_station_name, create_res["schema_update"] ) - if self.schema_updates_data[internal_station_name] != {}: - if ( - self.schema_updates_data[internal_station_name]["type"] - == "protobuf" - ): - self.parse_descriptor(internal_station_name) - if self.schema_updates_data[internal_station_name]["type"] == "json": - schema = self.schema_updates_data[internal_station_name][ - "active_version" - ]["schema_content"] - self.json_schemas[internal_station_name] = json.loads( - schema) - elif ( - self.schema_updates_data[internal_station_name]["type"] == "graphql" - ): - self.graphql_schemas[internal_station_name] = build_graphql_schema( - self.schema_updates_data[internal_station_name][ - "active_version" - ]["schema_content"] - ) - elif ( - self.schema_updates_data[internal_station_name]["type"] == "avro" - ): - schema = self.schema_updates_data[internal_station_name][ - "active_version" - ]["schema_content"] - self.avro_schemas[internal_station_name] = json.loads( - schema) + self.update_schema_data(station_name) + producer = Producer(self, producer_name, station_name, real_name) map_key = internal_station_name + "_" + real_name self.producers_map[map_key] = producer @@ -493,6 +467,37 @@ async def producer( except Exception as e: raise MemphisError(str(e)) from e + def update_schema_data(self, station_name): + internal_station_name = get_internal_name(station_name) + if self.schema_updates_data[internal_station_name] != {}: + if ( + self.schema_updates_data[internal_station_name]["type"] + == "protobuf" + ): + self.parse_descriptor(internal_station_name) + if self.schema_updates_data[internal_station_name]["type"] == "json": + schema = self.schema_updates_data[internal_station_name][ + "active_version" + ]["schema_content"] + self.json_schemas[internal_station_name] = json.loads( 
+ schema) + elif ( + self.schema_updates_data[internal_station_name]["type"] == "graphql" + ): + self.graphql_schemas[internal_station_name] = build_graphql_schema( + self.schema_updates_data[internal_station_name][ + "active_version" + ]["schema_content"] + ) + elif ( + self.schema_updates_data[internal_station_name]["type"] == "avro" + ): + schema = self.schema_updates_data[internal_station_name][ + "active_version" + ]["schema_content"] + self.avro_schemas[internal_station_name] = json.loads( + schema) + async def get_msg_schema_updates(self, internal_station_name, iterable): async for msg in iterable: message = msg.data.decode("utf-8") @@ -556,6 +561,8 @@ async def start_listen_for_schema_updates(self, station_name, schema_update_data ) self.schema_tasks[station_name] = task + self.update_schema_data(station_name) + async def consumer( self, station_name: str, @@ -647,6 +654,9 @@ async def consumer( internal_station_name = get_internal_name(station_name) map_key = internal_station_name + "_" + real_name + await self.start_listen_for_schema_updates( + internal_station_name, creation_res["schema_update"] + ) consumer = Consumer( self, station_name, diff --git a/memphis/message.py b/memphis/message.py index 7d08db4..7336a0a 100644 --- a/memphis/message.py +++ b/memphis/message.py @@ -4,12 +4,12 @@ from memphis.exceptions import MemphisConnectError, MemphisError - class Message: - def __init__(self, message, connection, cg_name): + def __init__(self, message, connection, cg_name, station_name): self.message = message self.connection = connection self.cg_name = cg_name + self.internal_station_name = station_name async def ack(self): """Ack a message is done processing.""" @@ -38,7 +38,31 @@ async def ack(self): def get_data(self): """Receive the message.""" try: - return bytearray(self.message.data) + return json.loads(bytearray(self.message.data)) + except Exception: + return + + def get_data_deserialized(self): + """Receive the message.""" + try: + if 
self.connection.schema_updates_data[self.internal_station_name] != {}: + schema_type = self.connection.schema_updates_data[ + self.internal_station_name + ]["type"] + if schema_type == "protobuf": + proto_msg = self.connection.proto_msgs[self.internal_station_name] + proto_msg.ParseFromString(self.message.data) + return proto_msg + if schema_type == "avro": + return json.loads(bytearray(self.message.data)) + if schema_type == "json": + return json.loads(bytearray(self.message.data)) + if schema_type == "graphql": + message = self.message.data + decoded_str = message.decode("utf-8") + return decoded_str + else: + return self.message.data except Exception: return From c3485363d0047a90adb1b900b49773ff52aa4e09 Mon Sep 17 00:00:00 2001 From: big-vi Date: Wed, 20 Sep 2023 18:28:04 +0530 Subject: [PATCH 2/5] producer counter fix --- memphis/consumer.py | 23 +++++++++++++++++++++++ memphis/memphis.py | 10 +++++----- memphis/message.py | 6 +++--- memphis/producer.py | 4 ++-- memphis/station.py | 4 ++-- 5 files changed, 35 insertions(+), 12 deletions(-) diff --git a/memphis/consumer.py b/memphis/consumer.py index 21b7ec2..b8de3ad 100644 --- a/memphis/consumer.py +++ b/memphis/consumer.py @@ -295,6 +295,29 @@ async def destroy(self): raise MemphisError(error) self.dls_messages.clear() internal_station_name = get_internal_name(self.station_name) + producer_number = ( + self.connection.clients_per_station.get( + internal_station_name) - 1 + ) + self.connection.clients_per_station[ + internal_station_name + ] = producer_number + + if producer_number == 0: + sub = self.connection.schema_updates_subs.get( + internal_station_name) + task = self.connection.schema_tasks.get(internal_station_name) + if internal_station_name in self.connection.schema_updates_data: + del self.connection.schema_updates_data[internal_station_name] + if internal_station_name in self.connection.schema_updates_subs: + del self.connection.schema_updates_subs[internal_station_name] + if internal_station_name 
in self.connection.schema_tasks: + del self.connection.schema_tasks[internal_station_name] + if task is not None: + task.cancel() + if sub is not None: + await sub.unsubscribe() + map_key = internal_station_name + "_" + self.consumer_name.lower() del self.connection.consumers_map[map_key] except Exception as e: diff --git a/memphis/memphis.py b/memphis/memphis.py index 4d31e6c..e21ef6f 100644 --- a/memphis/memphis.py +++ b/memphis/memphis.py @@ -47,7 +47,7 @@ def __init__(self): self.partition_producers_updates_data = {} self.partition_consumers_updates_data = {} self.schema_updates_subs = {} - self.producers_per_station = {} + self.clients_per_station = {} self.schema_tasks = {} self.proto_msgs = {} self.graphql_schemas = {} @@ -364,8 +364,8 @@ async def close(self): del self.schema_updates_data[key] if key in self.schema_updates_subs: del self.schema_updates_subs[key] - if key in self.producers_per_station: - del self.producers_per_station[key] + if key in self.clients_per_station: + del self.clients_per_station[key] if key in self.schema_tasks: del self.schema_tasks[key] if task is not None: @@ -546,10 +546,10 @@ async def start_listen_for_schema_updates(self, station_name, schema_update_data schema_exists = self.schema_updates_subs.get(station_name) if schema_exists: - self.producers_per_station[station_name] += 1 + self.clients_per_station[station_name] += 1 else: sub = await self.broker_manager.subscribe(schema_updates_subject) - self.producers_per_station[station_name] = 1 + self.clients_per_station[station_name] = 1 self.schema_updates_subs[station_name] = sub task_exists = self.schema_tasks.get(station_name) if not task_exists: diff --git a/memphis/message.py b/memphis/message.py index 7336a0a..3966cdb 100644 --- a/memphis/message.py +++ b/memphis/message.py @@ -5,11 +5,11 @@ from memphis.exceptions import MemphisConnectError, MemphisError class Message: - def __init__(self, message, connection, cg_name, station_name): + def __init__(self, message, 
connection, cg_name, internal_station_name): self.message = message self.connection = connection self.cg_name = cg_name - self.internal_station_name = station_name + self.internal_station_name = internal_station_name async def ack(self): """Ack a message is done processing.""" @@ -38,7 +38,7 @@ async def ack(self): def get_data(self): """Receive the message.""" try: - return json.loads(bytearray(self.message.data)) + return bytearray(self.message.data) except Exception: return diff --git a/memphis/producer.py b/memphis/producer.py index 6310793..d933c80 100644 --- a/memphis/producer.py +++ b/memphis/producer.py @@ -368,10 +368,10 @@ async def destroy(self): internal_station_name = get_internal_name(self.station_name) producer_number = ( - self.connection.producers_per_station.get( + self.connection.clients_per_station.get( internal_station_name) - 1 ) - self.connection.producers_per_station[ + self.connection.clients_per_station[ internal_station_name ] = producer_number diff --git a/memphis/station.py b/memphis/station.py index c921f30..dbbff86 100644 --- a/memphis/station.py +++ b/memphis/station.py @@ -28,8 +28,8 @@ async def destroy(self): del self.connection.schema_updates_data[internal_station_name] if internal_station_name in self.connection.schema_updates_subs: del self.connection.schema_updates_subs[internal_station_name] - if internal_station_name in self.connection.producers_per_station: - del self.connection.producers_per_station[internal_station_name] + if internal_station_name in self.connection.clients_per_station: + del self.connection.clients_per_station[internal_station_name] if internal_station_name in self.connection.schema_tasks: del self.connection.schema_tasks[internal_station_name] if task is not None: From e5ceb708bf61616a6d3ca903b274f1e2f22cec72 Mon Sep 17 00:00:00 2001 From: big-vi Date: Fri, 22 Sep 2023 18:45:23 +0530 Subject: [PATCH 3/5] use case where broker is not upgraded and sdk is --- memphis/consumer.py | 47 
+++++++++++++++++++++++---------------------- memphis/memphis.py | 10 +++++----- memphis/message.py | 4 ++-- memphis/producer.py | 6 +++--- 4 files changed, 34 insertions(+), 33 deletions(-) diff --git a/memphis/consumer.py b/memphis/consumer.py index b8de3ad..3064b29 100644 --- a/memphis/consumer.py +++ b/memphis/consumer.py @@ -294,29 +294,30 @@ async def destroy(self): if error != "" and not "not exist" in error: raise MemphisError(error) self.dls_messages.clear() - internal_station_name = get_internal_name(self.station_name) - producer_number = ( - self.connection.clients_per_station.get( - internal_station_name) - 1 - ) - self.connection.clients_per_station[ - internal_station_name - ] = producer_number - - if producer_number == 0: - sub = self.connection.schema_updates_subs.get( - internal_station_name) - task = self.connection.schema_tasks.get(internal_station_name) - if internal_station_name in self.connection.schema_updates_data: - del self.connection.schema_updates_data[internal_station_name] - if internal_station_name in self.connection.schema_updates_subs: - del self.connection.schema_updates_subs[internal_station_name] - if internal_station_name in self.connection.schema_tasks: - del self.connection.schema_tasks[internal_station_name] - if task is not None: - task.cancel() - if sub is not None: - await sub.unsubscribe() + if self.connection.schema_updates_data != {}: + internal_station_name = get_internal_name(self.station_name) + clients_number = ( + self.connection.clients_per_station.get( + internal_station_name) - 1 + ) + self.connection.clients_per_station[ + internal_station_name + ] = clients_number + + if clients_number == 0: + sub = self.connection.schema_updates_subs.get( + internal_station_name) + task = self.connection.schema_tasks.get(internal_station_name) + if internal_station_name in self.connection.schema_updates_data: + del self.connection.schema_updates_data[internal_station_name] + if internal_station_name in 
self.connection.schema_updates_subs: + del self.connection.schema_updates_subs[internal_station_name] + if internal_station_name in self.connection.schema_tasks: + del self.connection.schema_tasks[internal_station_name] + if task is not None: + task.cancel() + if sub is not None: + await sub.unsubscribe() map_key = internal_station_name + "_" + self.consumer_name.lower() del self.connection.consumers_map[map_key] diff --git a/memphis/memphis.py b/memphis/memphis.py index e21ef6f..820bc50 100644 --- a/memphis/memphis.py +++ b/memphis/memphis.py @@ -561,8 +561,6 @@ async def start_listen_for_schema_updates(self, station_name, schema_update_data ) self.schema_tasks[station_name] = task - self.update_schema_data(station_name) - async def consumer( self, station_name: str, @@ -654,9 +652,11 @@ async def consumer( internal_station_name = get_internal_name(station_name) map_key = internal_station_name + "_" + real_name - await self.start_listen_for_schema_updates( - internal_station_name, creation_res["schema_update"] - ) + if "schema_update" in creation_res: + await self.start_listen_for_schema_updates( + internal_station_name, creation_res["schema_update"] + ) + self.update_schema_data(station_name) consumer = Consumer( self, station_name, diff --git a/memphis/message.py b/memphis/message.py index 3966cdb..374a581 100644 --- a/memphis/message.py +++ b/memphis/message.py @@ -45,7 +45,7 @@ def get_data(self): def get_data_deserialized(self): """Receive the message.""" try: - if self.connection.schema_updates_data[self.internal_station_name] != {}: + if self.connection.schema_updates_data and self.connection.schema_updates_data[self.internal_station_name] != {}: schema_type = self.connection.schema_updates_data[ self.internal_station_name ]["type"] @@ -62,7 +62,7 @@ def get_data_deserialized(self): decoded_str = message.decode("utf-8") return decoded_str else: - return self.message.data + return bytearray(self.message.data) except Exception: return diff --git 
a/memphis/producer.py b/memphis/producer.py index d933c80..36780d2 100644 --- a/memphis/producer.py +++ b/memphis/producer.py @@ -367,15 +367,15 @@ async def destroy(self): raise Exception(error) internal_station_name = get_internal_name(self.station_name) - producer_number = ( + clients_number = ( self.connection.clients_per_station.get( internal_station_name) - 1 ) self.connection.clients_per_station[ internal_station_name - ] = producer_number + ] = clients_number - if producer_number == 0: + if clients_number == 0: sub = self.connection.schema_updates_subs.get( internal_station_name) task = self.connection.schema_tasks.get(internal_station_name) From 20edf38b0a3d7cdd8e5ce172dae6f21c87c6852c Mon Sep 17 00:00:00 2001 From: big-vi Date: Fri, 22 Sep 2023 18:49:35 +0530 Subject: [PATCH 4/5] use case where broker is not upgraded and sdk is --- memphis/consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/memphis/consumer.py b/memphis/consumer.py index 3064b29..0c9c79e 100644 --- a/memphis/consumer.py +++ b/memphis/consumer.py @@ -294,8 +294,8 @@ async def destroy(self): if error != "" and not "not exist" in error: raise MemphisError(error) self.dls_messages.clear() + internal_station_name = get_internal_name(self.station_name) if self.connection.schema_updates_data != {}: - internal_station_name = get_internal_name(self.station_name) clients_number = ( self.connection.clients_per_station.get( internal_station_name) - 1 From 68d2ff0e4a820cf4f32caf842c546ff03372434a Mon Sep 17 00:00:00 2001 From: big-vi Date: Sat, 7 Oct 2023 11:54:42 +0530 Subject: [PATCH 5/5] code refactoring and readme update --- README.md | 14 ++++ memphis/message.py | 10 ++- memphis/producer.py | 158 +------------------------------------------- memphis/station.py | 156 ++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 180 insertions(+), 158 deletions(-) diff --git a/README.md b/README.md index b4f66a5..f393c93 100644 --- a/README.md +++ b/README.md @@ -356,6 
+356,20 @@ async def msg_handler(msgs, error, context): print(error) consumer.consume(msg_handler) ``` +#### Processing schema deserialized messages +To get messages deserialized, use `msg.get_data_deserialized()`. + +```python +async def msg_handler(msgs, error, context): + for msg in msgs: + print("message: ", await msg.get_data_deserialized()) + await msg.ack() + if error: + print(error) +consumer.consume(msg_handler) +``` + +If you ingested data into a station in one format and afterwards attach a schema to the station, the consumer won't deserialize the previously ingested data. For example, if you ingested a string into the station and then attached a protobuf schema to it, the consumer won't deserialize that string. ### Consume using a partition key The key will be used to consume from a specific partition diff --git a/memphis/message.py b/memphis/message.py index 374a581..54441cb 100644 --- a/memphis/message.py +++ b/memphis/message.py @@ -2,7 +2,8 @@ import json -from memphis.exceptions import MemphisConnectError, MemphisError +from memphis.exceptions import MemphisConnectError, MemphisError, MemphisSchemaError +from memphis.station import Station class Message: def __init__(self, message, connection, cg_name, internal_station_name): @@ -10,6 +11,7 @@ def __init__(self, message, connection, cg_name, internal_station_name): self.connection = connection self.cg_name = cg_name self.internal_station_name = internal_station_name + self.station = Station(connection, internal_station_name) async def ack(self): """Ack a message is done processing.""" @@ -42,13 +44,17 @@ def get_data(self): except Exception: return - def get_data_deserialized(self): + async def get_data_deserialized(self): """Receive the message.""" try: if self.connection.schema_updates_data and self.connection.schema_updates_data[self.internal_station_name] != {}: schema_type = self.connection.schema_updates_data[ self.internal_station_name ]["type"] + try: + await 
self.station.validate_msg(bytearray(self.message.data)) + except Exception as e: + raise MemphisSchemaError("Deserialization has been failed since the message format does not align with the currently attached schema: " + str(e)) if schema_type == "protobuf": proto_msg = self.connection.proto_msgs[self.internal_station_name] proto_msg.ParseFromString(self.message.data) diff --git a/memphis/producer.py b/memphis/producer.py index eba9c52..5beced7 100644 --- a/memphis/producer.py +++ b/memphis/producer.py @@ -6,17 +6,12 @@ from typing import Union import warnings -import graphql -from graphql import parse as parse_graphql -from graphql import validate as validate_graphql -from jsonschema import validate -import google.protobuf.json_format as protobuf_json_format -import fastavro import mmh3 from memphis.exceptions import MemphisError, MemphisSchemaError from memphis.headers import Headers from memphis.utils import get_internal_name from memphis.partition_generator import PartitionGenerator +from memphis.station import Station schemaverse_fail_alert_type = "schema_validation_fail_alert" @@ -28,6 +23,7 @@ def __init__( self.connection = connection self.producer_name = producer_name.lower() self.station_name = station_name + self.station = Station(connection, station_name) self.internal_station_name = get_internal_name(self.station_name) self.loop = asyncio.get_running_loop() self.real_name = real_name @@ -35,154 +31,6 @@ def __init__( if self.internal_station_name in connection.partition_producers_updates_data: self.partition_generator = PartitionGenerator(connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"]) - async def validate_msg(self, message): - if self.connection.schema_updates_data[self.internal_station_name] != {}: - schema_type = self.connection.schema_updates_data[ - self.internal_station_name - ]["type"] - if schema_type == "protobuf": - message = self.validate_protobuf(message) - return message - if schema_type == 
"json": - message = self.validate_json_schema(message) - return message - if schema_type == "graphql": - message = self.validate_graphql(message) - return message - if schema_type == "avro": - message = self.validate_avro_schema(message) - return message - if hasattr(message, "SerializeToString"): - msg_to_send = message.SerializeToString() - return msg_to_send - elif isinstance(message, str): - message = message.encode("utf-8") - return message - elif isinstance(message, graphql.language.ast.DocumentNode): - msg = message - message = str(msg.loc.source.body) - message = message.encode("utf-8") - return message - elif hasattr(message, "SerializeToString"): - msg_to_send = message.SerializeToString() - return msg_to_send - elif not isinstance(message, bytearray) and not isinstance(message, dict): - raise MemphisSchemaError("Unsupported message type") - else: - if isinstance(message, dict): - message = bytearray(json.dumps(message).encode("utf-8")) - return message - - def validate_protobuf(self, message): - proto_msg = self.connection.proto_msgs[self.internal_station_name] - msg_to_send = "" - try: - if isinstance(message, bytearray): - msg_to_send = bytes(message) - try: - proto_msg.ParseFromString(msg_to_send) - proto_msg.SerializeToString() - msg_to_send = msg_to_send.decode("utf-8") - except Exception as e: - if "parsing message" in str(e): - e = "Invalid message format, expecting protobuf" - raise MemphisSchemaError(str(e)) - return message - if hasattr(message, "SerializeToString"): - msg_to_send = message.SerializeToString() - proto_msg.ParseFromString(msg_to_send) - proto_msg.SerializeToString() - try: - proto_msg.ParseFromString(msg_to_send) - proto_msg.SerializeToString() - except Exception as e: - if "parsing message" in str(e): - e = "Error parsing protobuf message" - raise MemphisSchemaError(str(e)) - return msg_to_send - elif isinstance(message, dict): - try: - protobuf_json_format.ParseDict(message, proto_msg) - msg_to_send = 
proto_msg.SerializeToString() - return msg_to_send - except Exception as e: - raise MemphisSchemaError(str(e)) - else: - raise MemphisSchemaError("Unsupported message type") - - except Exception as e: - raise MemphisSchemaError("Schema validation has failed: " + str(e)) - - def validate_json_schema(self, message): - try: - if isinstance(message, bytearray): - try: - message_obj = json.loads(message) - except Exception as e: - raise Exception("Expecting Json format: " + str(e)) - elif isinstance(message, dict): - message_obj = message - message = bytearray(json.dumps(message_obj).encode("utf-8")) - else: - raise Exception("Unsupported message type") - - validate( - instance=message_obj, - schema=self.connection.json_schemas[self.internal_station_name], - ) - return message - except Exception as e: - raise MemphisSchemaError("Schema validation has failed: " + str(e)) - - def validate_graphql(self, message): - try: - if isinstance(message, bytearray): - msg = message.decode("utf-8") - msg = parse_graphql(msg) - elif isinstance(message, str): - msg = parse_graphql(message) - message = message.encode("utf-8") - elif isinstance(message, graphql.language.ast.DocumentNode): - msg = message - message = str(msg.loc.source.body) - message = message.encode("utf-8") - else: - raise Exception("Unsupported message type") - validate_res = validate_graphql( - schema=self.connection.graphql_schemas[self.internal_station_name], - document_ast=msg, - ) - if len(validate_res) > 0: - raise Exception( - "Schema validation has failed: " + str(validate_res)) - return message - except Exception as e: - if "Syntax Error" in str(e): - e = "Invalid message format, expected GraphQL" - raise MemphisSchemaError("Schema validation has failed: " + str(e)) - - - def validate_avro_schema(self, message): - try: - if isinstance(message, bytearray): - try: - message_obj = json.loads(message) - except Exception as e: - raise Exception("Expecting Avro format: " + str(e)) - elif isinstance(message, dict): 
- message_obj = message - message = bytearray(json.dumps(message_obj).encode("utf-8")) - else: - raise Exception("Unsupported message type") - - fastavro.validate( - message_obj, - self.connection.avro_schemas[self.internal_station_name], - ) - return message - except fastavro.validation.ValidationError as e: - raise MemphisSchemaError("Schema validation has failed: " + str(e)) - # pylint: disable=R0913 async def produce( self, @@ -222,7 +70,7 @@ async def produce( Exception: _description_ """ try: - message = await self.validate_msg(message) + message = await self.station.validate_msg(message) memphis_headers = { "$memphis_producedBy": self.producer_name, diff --git a/memphis/station.py b/memphis/station.py index dbbff86..26dec82 100644 --- a/memphis/station.py +++ b/memphis/station.py @@ -1,6 +1,12 @@ import json +import graphql +from graphql import parse as parse_graphql +from graphql import validate as validate_graphql +from jsonschema import validate +import google.protobuf.json_format as protobuf_json_format +import fastavro -from memphis.exceptions import MemphisError +from memphis.exceptions import MemphisError, MemphisSchemaError from memphis.utils import get_internal_name @@ -8,6 +14,154 @@ class Station: def __init__(self, connection, name: str): self.connection = connection self.name = name.lower() + self.internal_station_name = get_internal_name(self.name) + + async def validate_msg(self, message): + if self.connection.schema_updates_data[self.internal_station_name] != {}: + schema_type = self.connection.schema_updates_data[ + self.internal_station_name + ]["type"] + if schema_type == "protobuf": + message = self.validate_protobuf(message) + return message + if schema_type == "json": + message = self.validate_json_schema(message) + return message + if schema_type == "graphql": + message = self.validate_graphql(message) + return message + if schema_type == "avro": + message = self.validate_avro_schema(message) + return message + if hasattr(message, 
"SerializeToString"): + msg_to_send = message.SerializeToString() + return msg_to_send + elif isinstance(message, str): + message = message.encode("utf-8") + return message + elif isinstance(message, graphql.language.ast.DocumentNode): + msg = message + message = str(msg.loc.source.body) + message = message.encode("utf-8") + return message + elif hasattr(message, "SerializeToString"): + msg_to_send = message.SerializeToString() + return msg_to_send + elif not isinstance(message, bytearray) and not isinstance(message, dict): + raise MemphisSchemaError("Unsupported message type") + else: + if isinstance(message, dict): + message = bytearray(json.dumps(message).encode("utf-8")) + return message + + def validate_protobuf(self, message): + proto_msg = self.connection.proto_msgs[self.internal_station_name] + msg_to_send = "" + try: + if isinstance(message, bytearray): + msg_to_send = bytes(message) + try: + proto_msg.ParseFromString(msg_to_send) + proto_msg.SerializeToString() + msg_to_send = msg_to_send.decode("utf-8") + except Exception as e: + if "parsing message" in str(e): + e = "Invalid message format, expecting protobuf" + raise MemphisSchemaError(str(e)) + return message + if hasattr(message, "SerializeToString"): + msg_to_send = message.SerializeToString() + proto_msg.ParseFromString(msg_to_send) + proto_msg.SerializeToString() + try: + proto_msg.ParseFromString(msg_to_send) + proto_msg.SerializeToString() + except Exception as e: + if "parsing message" in str(e): + e = "Error parsing protobuf message" + raise MemphisSchemaError(str(e)) + return msg_to_send + elif isinstance(message, dict): + try: + protobuf_json_format.ParseDict(message, proto_msg) + msg_to_send = proto_msg.SerializeToString() + return msg_to_send + except Exception as e: + raise MemphisSchemaError(str(e)) + else: + raise MemphisSchemaError("Unsupported message type") + + except Exception as e: + raise MemphisSchemaError("Schema validation has failed: " + str(e)) + + def 
validate_json_schema(self, message): + try: + if isinstance(message, bytearray): + try: + message_obj = json.loads(message) + except Exception as e: + raise Exception("Expecting Json format: " + str(e)) + elif isinstance(message, dict): + message_obj = message + message = bytearray(json.dumps(message_obj).encode("utf-8")) + else: + raise Exception("Unsupported message type") + + validate( + instance=message_obj, + schema=self.connection.json_schemas[self.internal_station_name], + ) + return message + except Exception as e: + raise MemphisSchemaError("Schema validation has failed: " + str(e)) + + def validate_graphql(self, message): + try: + if isinstance(message, bytearray): + msg = message.decode("utf-8") + msg = parse_graphql(msg) + elif isinstance(message, str): + msg = parse_graphql(message) + message = message.encode("utf-8") + elif isinstance(message, graphql.language.ast.DocumentNode): + msg = message + message = str(msg.loc.source.body) + message = message.encode("utf-8") + else: + raise Exception("Unsupported message type") + validate_res = validate_graphql( + schema=self.connection.graphql_schemas[self.internal_station_name], + document_ast=msg, + ) + if len(validate_res) > 0: + raise Exception( + "Schema validation has failed: " + str(validate_res)) + return message + except Exception as e: + if "Syntax Error" in str(e): + e = "Invalid message format, expected GraphQL" + raise MemphisSchemaError("Schema validation has failed: " + str(e)) + + def validate_avro_schema(self, message): + try: + if isinstance(message, bytearray): + try: + message_obj = json.loads(message) + except Exception as e: + raise Exception("Expecting Avro format: " + str(e)) + elif isinstance(message, dict): + message_obj = message + message = bytearray(json.dumps(message_obj).encode("utf-8")) + else: + raise Exception("Unsupported message type") + + fastavro.validate( + message_obj, + self.connection.avro_schemas[self.internal_station_name], + ) + return message + except 
fastavro.validation.ValidationError as e: + raise MemphisSchemaError("Schema validation has failed: " + str(e)) async def destroy(self): """Destroy the station."""