Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,20 @@ async def msg_handler(msgs, error, context):
print(error)
consumer.consume(msg_handler)
```
#### Processing schema deserialized messages
To get messages deserialized, use `msg.get_data_deserialized()`.

```python
async def msg_handler(msgs, error, context):
for msg in msgs:
print("message: ", await msg.get_data_deserialized())
await msg.ack()
if error:
print(error)
consumer.consume(msg_handler)
```

if you have ingested data into station in one format, afterwards you apply a schema on the station, the consumer won't deserialize the previously ingested data. For example, you have ingested string into the station and attached a protobuf schema on the station. In this case, consumer won't deserialize the string.

### Consume using a partition key
The key will be used to consume from a specific partition
Expand Down
10 changes: 8 additions & 2 deletions memphis/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

import json

from memphis.exceptions import MemphisConnectError, MemphisError
from memphis.exceptions import MemphisConnectError, MemphisError, MemphisSchemaError
from memphis.station import Station

class Message:
def __init__(self, message, connection, cg_name, internal_station_name):
self.message = message
self.connection = connection
self.cg_name = cg_name
self.internal_station_name = internal_station_name
self.station = Station(connection, internal_station_name)

async def ack(self):
"""Ack a message is done processing."""
Expand Down Expand Up @@ -42,13 +44,17 @@ def get_data(self):
except Exception:
return

def get_data_deserialized(self):
async def get_data_deserialized(self):
"""Receive the message."""
try:
if self.connection.schema_updates_data and self.connection.schema_updates_data[self.internal_station_name] != {}:
schema_type = self.connection.schema_updates_data[
self.internal_station_name
]["type"]
try:
await self.station.validate_msg(bytearray(self.message.data))
except Exception as e:
raise MemphisSchemaError("Deserialization has been failed since the message format does not align with the currently attached schema: " + str(e))
if schema_type == "protobuf":
proto_msg = self.connection.proto_msgs[self.internal_station_name]
proto_msg.ParseFromString(self.message.data)
Expand Down
158 changes: 3 additions & 155 deletions memphis/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,12 @@
from typing import Union
import warnings

import graphql
from graphql import parse as parse_graphql
from graphql import validate as validate_graphql
from jsonschema import validate
import google.protobuf.json_format as protobuf_json_format
import fastavro
import mmh3
from memphis.exceptions import MemphisError, MemphisSchemaError
from memphis.headers import Headers
from memphis.utils import get_internal_name
from memphis.partition_generator import PartitionGenerator
from memphis.station import Station

schemaverse_fail_alert_type = "schema_validation_fail_alert"

Expand All @@ -28,161 +23,14 @@ def __init__(
self.connection = connection
self.producer_name = producer_name.lower()
self.station_name = station_name
self.station = Station(connection, station_name)
self.internal_station_name = get_internal_name(self.station_name)
self.loop = asyncio.get_running_loop()
self.real_name = real_name
self.background_tasks = set()
if self.internal_station_name in connection.partition_producers_updates_data:
self.partition_generator = PartitionGenerator(connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"])

async def validate_msg(self, message):
if self.connection.schema_updates_data[self.internal_station_name] != {}:
schema_type = self.connection.schema_updates_data[
self.internal_station_name
]["type"]
if schema_type == "protobuf":
message = self.validate_protobuf(message)
return message
if schema_type == "json":
message = self.validate_json_schema(message)
return message
if schema_type == "graphql":
message = self.validate_graphql(message)
return message
if schema_type == "avro":
message = self.validate_avro_schema(message)
return message
if hasattr(message, "SerializeToString"):
msg_to_send = message.SerializeToString()
return msg_to_send
elif isinstance(message, str):
message = message.encode("utf-8")
return message
elif isinstance(message, graphql.language.ast.DocumentNode):
msg = message
message = str(msg.loc.source.body)
message = message.encode("utf-8")
return message
elif hasattr(message, "SerializeToString"):
msg_to_send = message.SerializeToString()
return msg_to_send
elif not isinstance(message, bytearray) and not isinstance(message, dict):
raise MemphisSchemaError("Unsupported message type")
else:
if isinstance(message, dict):
message = bytearray(json.dumps(message).encode("utf-8"))
return message

def validate_protobuf(self, message):
proto_msg = self.connection.proto_msgs[self.internal_station_name]
msg_to_send = ""
try:
if isinstance(message, bytearray):
msg_to_send = bytes(message)
try:
proto_msg.ParseFromString(msg_to_send)
proto_msg.SerializeToString()
msg_to_send = msg_to_send.decode("utf-8")
except Exception as e:
if "parsing message" in str(e):
e = "Invalid message format, expecting protobuf"
raise MemphisSchemaError(str(e))
return message
if hasattr(message, "SerializeToString"):
msg_to_send = message.SerializeToString()
proto_msg.ParseFromString(msg_to_send)
proto_msg.SerializeToString()
try:
proto_msg.ParseFromString(msg_to_send)
proto_msg.SerializeToString()
except Exception as e:
if "parsing message" in str(e):
e = "Error parsing protobuf message"
raise MemphisSchemaError(str(e))
return msg_to_send
elif isinstance(message, dict):
try:
protobuf_json_format.ParseDict(message, proto_msg)
msg_to_send = proto_msg.SerializeToString()
return msg_to_send
except Exception as e:
raise MemphisSchemaError(str(e))
else:
raise MemphisSchemaError("Unsupported message type")

except Exception as e:
raise MemphisSchemaError("Schema validation has failed: " + str(e))

def validate_json_schema(self, message):
try:
if isinstance(message, bytearray):
try:
message_obj = json.loads(message)
except Exception as e:
raise Exception("Expecting Json format: " + str(e))
elif isinstance(message, dict):
message_obj = message
message = bytearray(json.dumps(message_obj).encode("utf-8"))
else:
raise Exception("Unsupported message type")

validate(
instance=message_obj,
schema=self.connection.json_schemas[self.internal_station_name],
)
return message
except Exception as e:
raise MemphisSchemaError("Schema validation has failed: " + str(e))

def validate_graphql(self, message):
try:
if isinstance(message, bytearray):
msg = message.decode("utf-8")
msg = parse_graphql(msg)
elif isinstance(message, str):
msg = parse_graphql(message)
message = message.encode("utf-8")
elif isinstance(message, graphql.language.ast.DocumentNode):
msg = message
message = str(msg.loc.source.body)
message = message.encode("utf-8")
else:
raise Exception("Unsupported message type")
validate_res = validate_graphql(
schema=self.connection.graphql_schemas[self.internal_station_name],
document_ast=msg,
)
if len(validate_res) > 0:
raise Exception(
"Schema validation has failed: " + str(validate_res))
return message
except Exception as e:
if "Syntax Error" in str(e):
e = "Invalid message format, expected GraphQL"
raise MemphisSchemaError("Schema validation has failed: " + str(e))


def validate_avro_schema(self, message):
try:
if isinstance(message, bytearray):
try:
message_obj = json.loads(message)
except Exception as e:
raise Exception("Expecting Avro format: " + str(e))
elif isinstance(message, dict):
message_obj = message
message = bytearray(json.dumps(message_obj).encode("utf-8"))
else:
raise Exception("Unsupported message type")

fastavro.validate(
message_obj,
self.connection.avro_schemas[self.internal_station_name],
)
return message
except fastavro.validation.ValidationError as e:
raise MemphisSchemaError("Schema validation has failed: " + str(e))

# pylint: disable=R0913
async def produce(
self,
Expand Down Expand Up @@ -224,7 +72,7 @@ async def produce(
Exception: _description_
"""
try:
message = await self.validate_msg(message)
message = await self.station.validate_msg(message)

memphis_headers = {
"$memphis_producedBy": self.producer_name,
Expand Down
Loading