diff --git a/README.md b/README.md index 798f6bc..845366d 100644 --- a/README.md +++ b/README.md @@ -368,6 +368,20 @@ async def msg_handler(msgs, error, context): print(error) consumer.consume(msg_handler) ``` +#### Processing schema deserialized messages +To get messages deserialized, use `msg.get_data_deserialized()`. + +```python +async def msg_handler(msgs, error, context): + for msg in msgs: + print("message: ", await msg.get_data_deserialized()) + await msg.ack() + if error: + print(error) +consumer.consume(msg_handler) +``` + +If you ingested data into a station in one format and afterwards attached a schema to the station, the consumer won't deserialize the previously ingested data. For example, if you ingested strings into the station and then attached a protobuf schema, the consumer won't deserialize those strings. ### Consume using a partition key The key will be used to consume from a specific partition diff --git a/memphis/message.py b/memphis/message.py index 374a581..54441cb 100644 --- a/memphis/message.py +++ b/memphis/message.py @@ -2,7 +2,8 @@ import json -from memphis.exceptions import MemphisConnectError, MemphisError +from memphis.exceptions import MemphisConnectError, MemphisError, MemphisSchemaError +from memphis.station import Station class Message: def __init__(self, message, connection, cg_name, internal_station_name): @@ -10,6 +11,7 @@ def __init__(self, message, connection, cg_name, internal_station_name): self.connection = connection self.cg_name = cg_name self.internal_station_name = internal_station_name + self.station = Station(connection, internal_station_name) async def ack(self): """Ack a message is done processing.""" @@ -42,13 +44,17 @@ def get_data(self): except Exception: return - def get_data_deserialized(self): + async def get_data_deserialized(self): """Receive the message.""" try: if self.connection.schema_updates_data and self.connection.schema_updates_data[self.internal_station_name] != {}: schema_type 
= self.connection.schema_updates_data[ self.internal_station_name ]["type"] + try: + await self.station.validate_msg(bytearray(self.message.data)) + except Exception as e: + raise MemphisSchemaError("Deserialization has been failed since the message format does not align with the currently attached schema: " + str(e)) if schema_type == "protobuf": proto_msg = self.connection.proto_msgs[self.internal_station_name] proto_msg.ParseFromString(self.message.data) diff --git a/memphis/producer.py b/memphis/producer.py index 464dd42..7d970fe 100644 --- a/memphis/producer.py +++ b/memphis/producer.py @@ -6,17 +6,12 @@ from typing import Union import warnings -import graphql -from graphql import parse as parse_graphql -from graphql import validate as validate_graphql -from jsonschema import validate -import google.protobuf.json_format as protobuf_json_format -import fastavro import mmh3 from memphis.exceptions import MemphisError, MemphisSchemaError from memphis.headers import Headers from memphis.utils import get_internal_name from memphis.partition_generator import PartitionGenerator +from memphis.station import Station schemaverse_fail_alert_type = "schema_validation_fail_alert" @@ -28,6 +23,7 @@ def __init__( self.connection = connection self.producer_name = producer_name.lower() self.station_name = station_name + self.station = Station(connection, station_name) self.internal_station_name = get_internal_name(self.station_name) self.loop = asyncio.get_running_loop() self.real_name = real_name @@ -35,154 +31,6 @@ def __init__( if self.internal_station_name in connection.partition_producers_updates_data: self.partition_generator = PartitionGenerator(connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"]) - async def validate_msg(self, message): - if self.connection.schema_updates_data[self.internal_station_name] != {}: - schema_type = self.connection.schema_updates_data[ - self.internal_station_name - ]["type"] - if schema_type == 
"protobuf": - message = self.validate_protobuf(message) - return message - if schema_type == "json": - message = self.validate_json_schema(message) - return message - if schema_type == "graphql": - message = self.validate_graphql(message) - return message - if schema_type == "avro": - message = self.validate_avro_schema(message) - return message - if hasattr(message, "SerializeToString"): - msg_to_send = message.SerializeToString() - return msg_to_send - elif isinstance(message, str): - message = message.encode("utf-8") - return message - elif isinstance(message, graphql.language.ast.DocumentNode): - msg = message - message = str(msg.loc.source.body) - message = message.encode("utf-8") - return message - elif hasattr(message, "SerializeToString"): - msg_to_send = message.SerializeToString() - return msg_to_send - elif not isinstance(message, bytearray) and not isinstance(message, dict): - raise MemphisSchemaError("Unsupported message type") - else: - if isinstance(message, dict): - message = bytearray(json.dumps(message).encode("utf-8")) - return message - - def validate_protobuf(self, message): - proto_msg = self.connection.proto_msgs[self.internal_station_name] - msg_to_send = "" - try: - if isinstance(message, bytearray): - msg_to_send = bytes(message) - try: - proto_msg.ParseFromString(msg_to_send) - proto_msg.SerializeToString() - msg_to_send = msg_to_send.decode("utf-8") - except Exception as e: - if "parsing message" in str(e): - e = "Invalid message format, expecting protobuf" - raise MemphisSchemaError(str(e)) - return message - if hasattr(message, "SerializeToString"): - msg_to_send = message.SerializeToString() - proto_msg.ParseFromString(msg_to_send) - proto_msg.SerializeToString() - try: - proto_msg.ParseFromString(msg_to_send) - proto_msg.SerializeToString() - except Exception as e: - if "parsing message" in str(e): - e = "Error parsing protobuf message" - raise MemphisSchemaError(str(e)) - return msg_to_send - elif isinstance(message, dict): - try: - 
protobuf_json_format.ParseDict(message, proto_msg) - msg_to_send = proto_msg.SerializeToString() - return msg_to_send - except Exception as e: - raise MemphisSchemaError(str(e)) - else: - raise MemphisSchemaError("Unsupported message type") - - except Exception as e: - raise MemphisSchemaError("Schema validation has failed: " + str(e)) - - def validate_json_schema(self, message): - try: - if isinstance(message, bytearray): - try: - message_obj = json.loads(message) - except Exception as e: - raise Exception("Expecting Json format: " + str(e)) - elif isinstance(message, dict): - message_obj = message - message = bytearray(json.dumps(message_obj).encode("utf-8")) - else: - raise Exception("Unsupported message type") - - validate( - instance=message_obj, - schema=self.connection.json_schemas[self.internal_station_name], - ) - return message - except Exception as e: - raise MemphisSchemaError("Schema validation has failed: " + str(e)) - - def validate_graphql(self, message): - try: - if isinstance(message, bytearray): - msg = message.decode("utf-8") - msg = parse_graphql(msg) - elif isinstance(message, str): - msg = parse_graphql(message) - message = message.encode("utf-8") - elif isinstance(message, graphql.language.ast.DocumentNode): - msg = message - message = str(msg.loc.source.body) - message = message.encode("utf-8") - else: - raise Exception("Unsupported message type") - validate_res = validate_graphql( - schema=self.connection.graphql_schemas[self.internal_station_name], - document_ast=msg, - ) - if len(validate_res) > 0: - raise Exception( - "Schema validation has failed: " + str(validate_res)) - return message - except Exception as e: - if "Syntax Error" in str(e): - e = "Invalid message format, expected GraphQL" - raise MemphisSchemaError("Schema validation has failed: " + str(e)) - - - def validate_avro_schema(self, message): - try: - if isinstance(message, bytearray): - try: - message_obj = json.loads(message) - except Exception as e: - raise 
Exception("Expecting Avro format: " + str(e)) - elif isinstance(message, dict): - message_obj = message - message = bytearray(json.dumps(message_obj).encode("utf-8")) - else: - raise Exception("Unsupported message type") - - fastavro.validate( - message_obj, - self.connection.avro_schemas[self.internal_station_name], - ) - return message - except fastavro.validation.ValidationError as e: - raise MemphisSchemaError("Schema validation has failed: " + str(e)) - # pylint: disable=R0913 async def produce( self, @@ -224,7 +72,7 @@ async def produce( Exception: _description_ """ try: - message = await self.validate_msg(message) + message = await self.station.validate_msg(message) memphis_headers = { "$memphis_producedBy": self.producer_name, diff --git a/memphis/station.py b/memphis/station.py index dbbff86..26dec82 100644 --- a/memphis/station.py +++ b/memphis/station.py @@ -1,6 +1,12 @@ import json +import graphql +from graphql import parse as parse_graphql +from graphql import validate as validate_graphql +from jsonschema import validate +import google.protobuf.json_format as protobuf_json_format +import fastavro -from memphis.exceptions import MemphisError +from memphis.exceptions import MemphisError, MemphisSchemaError from memphis.utils import get_internal_name @@ -8,6 +14,154 @@ class Station: def __init__(self, connection, name: str): self.connection = connection self.name = name.lower() + self.internal_station_name = get_internal_name(self.name) + + async def validate_msg(self, message): + if self.connection.schema_updates_data[self.internal_station_name] != {}: + schema_type = self.connection.schema_updates_data[ + self.internal_station_name + ]["type"] + if schema_type == "protobuf": + message = self.validate_protobuf(message) + return message + if schema_type == "json": + message = self.validate_json_schema(message) + return message + if schema_type == "graphql": + message = self.validate_graphql(message) + return message + if schema_type == "avro": + message 
= self.validate_avro_schema(message) + return message + if hasattr(message, "SerializeToString"): + msg_to_send = message.SerializeToString() + return msg_to_send + elif isinstance(message, str): + message = message.encode("utf-8") + return message + elif isinstance(message, graphql.language.ast.DocumentNode): + msg = message + message = str(msg.loc.source.body) + message = message.encode("utf-8") + return message + elif hasattr(message, "SerializeToString"): + msg_to_send = message.SerializeToString() + return msg_to_send + elif not isinstance(message, bytearray) and not isinstance(message, dict): + raise MemphisSchemaError("Unsupported message type") + else: + if isinstance(message, dict): + message = bytearray(json.dumps(message).encode("utf-8")) + return message + + def validate_protobuf(self, message): + proto_msg = self.connection.proto_msgs[self.internal_station_name] + msg_to_send = "" + try: + if isinstance(message, bytearray): + msg_to_send = bytes(message) + try: + proto_msg.ParseFromString(msg_to_send) + proto_msg.SerializeToString() + msg_to_send = msg_to_send.decode("utf-8") + except Exception as e: + if "parsing message" in str(e): + e = "Invalid message format, expecting protobuf" + raise MemphisSchemaError(str(e)) + return message + if hasattr(message, "SerializeToString"): + msg_to_send = message.SerializeToString() + proto_msg.ParseFromString(msg_to_send) + proto_msg.SerializeToString() + try: + proto_msg.ParseFromString(msg_to_send) + proto_msg.SerializeToString() + except Exception as e: + if "parsing message" in str(e): + e = "Error parsing protobuf message" + raise MemphisSchemaError(str(e)) + return msg_to_send + elif isinstance(message, dict): + try: + protobuf_json_format.ParseDict(message, proto_msg) + msg_to_send = proto_msg.SerializeToString() + return msg_to_send + except Exception as e: + raise MemphisSchemaError(str(e)) + else: + raise MemphisSchemaError("Unsupported message type") + + except Exception as e: + raise 
MemphisSchemaError("Schema validation has failed: " + str(e)) + + def validate_json_schema(self, message): + try: + if isinstance(message, bytearray): + try: + message_obj = json.loads(message) + except Exception as e: + raise Exception("Expecting Json format: " + str(e)) + elif isinstance(message, dict): + message_obj = message + message = bytearray(json.dumps(message_obj).encode("utf-8")) + else: + raise Exception("Unsupported message type") + + validate( + instance=message_obj, + schema=self.connection.json_schemas[self.internal_station_name], + ) + return message + except Exception as e: + raise MemphisSchemaError("Schema validation has failed: " + str(e)) + + def validate_graphql(self, message): + try: + if isinstance(message, bytearray): + msg = message.decode("utf-8") + msg = parse_graphql(msg) + elif isinstance(message, str): + msg = parse_graphql(message) + message = message.encode("utf-8") + elif isinstance(message, graphql.language.ast.DocumentNode): + msg = message + message = str(msg.loc.source.body) + message = message.encode("utf-8") + else: + raise Exception("Unsupported message type") + validate_res = validate_graphql( + schema=self.connection.graphql_schemas[self.internal_station_name], + document_ast=msg, + ) + if len(validate_res) > 0: + raise Exception( + "Schema validation has failed: " + str(validate_res)) + return message + except Exception as e: + if "Syntax Error" in str(e): + e = "Invalid message format, expected GraphQL" + raise MemphisSchemaError("Schema validation has failed: " + str(e)) + + def validate_avro_schema(self, message): + try: + if isinstance(message, bytearray): + try: + message_obj = json.loads(message) + except Exception as e: + raise Exception("Expecting Avro format: " + str(e)) + elif isinstance(message, dict): + message_obj = message + message = bytearray(json.dumps(message_obj).encode("utf-8")) + else: + raise Exception("Unsupported message type") + + fastavro.validate( + message_obj, + 
self.connection.avro_schemas[self.internal_station_name], + ) + return message + except fastavro.validation.ValidationError as e: + raise MemphisSchemaError("Schema validation has failed: " + str(e)) async def destroy(self): """Destroy the station."""