From 3659b286dff74055008f69095158592658a116cf Mon Sep 17 00:00:00 2001 From: shay23b Date: Mon, 3 Apr 2023 15:13:44 +0300 Subject: [PATCH] fix send to dls and allow produce more types --- memphis/producer.py | 115 +++++++++++++++++++++++++------------------- 1 file changed, 65 insertions(+), 50 deletions(-) diff --git a/memphis/producer.py b/memphis/producer.py index 2420410..eabbd81 100644 --- a/memphis/producer.py +++ b/memphis/producer.py @@ -41,6 +41,17 @@ async def validate_msg(self, message): elif schema_type == "graphql": message = self.validate_graphql(message) return message + elif hasattr(message, "SerializeToString"): + msgToSend = message.SerializeToString() + return msgToSend + elif isinstance(message, str): + message = message.encode("utf-8") + return message + elif isinstance(message, graphql.language.ast.DocumentNode): + msg = message + message = str(msg.loc.source.body) + message = message.encode("utf-8") + return message elif not isinstance(message, bytearray) and not isinstance(message, dict): raise MemphisSchemaError("Unsupported message type") else: @@ -181,63 +192,67 @@ async def produce( except Exception as e: if hasattr(e, "status_code") and e.status_code == "503": raise MemphisError( - "Produce operation has failed, please check whether Station/Producer are still exist" + "Produce operation has failed, please check whether Station/Producer still exist" ) else: if "Schema validation has failed" in str( e ) or "Unsupported message type" in str(e): - msgToSend = "" - if hasattr(message, "SerializeToString"): - msgToSend = message.SerializeToString().decode("utf-8") - elif isinstance(message, bytearray): - msgToSend = str(message, "utf-8") - else: - msgToSend = str(message) - if self.connection.station_schemaverse_to_dls[ - self.internal_station_name - ]: - memphis_headers = { - "$memphis_producedBy": self.producer_name, - "$memphis_connectionId": self.connection.connection_id, - } - - if headers != {}: - headers = headers.headers - headers.update(memphis_headers) + if self.connection.schema_updates_data[self.internal_station_name] != {}: + msgToSend = "" + if hasattr(message, "SerializeToString"): + msgToSend = message.SerializeToString().decode("utf-8") + elif isinstance(message, bytearray): + msgToSend = str(message, "utf-8") else: - headers = memphis_headers + msgToSend = str(message) + if self.connection.station_schemaverse_to_dls[ + self.internal_station_name + ]: + unix_time = int(time.time()) + + memphis_headers = { + "$memphis_producedBy": self.producer_name, + "$memphis_connectionId": self.connection.connection_id, + } + + if headers != {}: + headers = headers.headers + headers.update(memphis_headers) + else: + headers = memphis_headers - msgToSendEncoded = msgToSend.encode("utf-8") - msgHex = msgToSendEncoded.hex() - buf = { - "station_name": self.internal_station_name, - "producer": { - "name": self.producer_name, - "connection_id": self.connection.connection_id, - }, - "message": { - "data": msgHex, - "headers": headers, - }, - "validation_error": str(e), - } - buf = json.dumps(buf).encode("utf-8") - await self.connection.broker_manager.publish("$memphis_schemaverse_dls", buf) - if self.connection.cluster_configurations.get( - "send_notification" - ): - await self.connection.send_notification( - "Schema validation has failed", - "Station: " - + self.station_name - + "\nProducer: " - + self.producer_name - + "\nError:" - + str(e), - msgToSend, - schemaVFailAlertType, - ) + msgToSendEncoded = msgToSend.encode("utf-8") + msgHex = msgToSendEncoded.hex() + buf = { + "station_name": self.internal_station_name, + "producer": { + "name": self.producer_name, + "connection_id": self.connection.connection_id, + }, + "creation_unix": unix_time, + "message": { + "data": msgHex, + "headers": headers, + }, + "validation_error": str(e), + } + buf = json.dumps(buf).encode("utf-8") + await self.connection.broker_manager.publish("$memphis_schemaverse_dls", buf) + if self.connection.cluster_configurations.get( + "send_notification" + ): + await self.connection.send_notification( + "Schema validation has failed", + "Station: " + + self.station_name + + "\nProducer: " + + self.producer_name + + "\nError:" + + str(e), + msgToSend, + schemaVFailAlertType, + ) raise MemphisError(str(e)) from e async def destroy(self):