Skip to content

Commit

Permalink
Backport schema mixin.
Browse files Browse the repository at this point in the history
  • Loading branch information
saabeilin committed Oct 15, 2018
1 parent d5fad87 commit b9eb7ef
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 4 deletions.
74 changes: 74 additions & 0 deletions kafkian/serde/avroserdebase.py
@@ -0,0 +1,74 @@
import struct

import avro.io
from confluent_kafka.avro import MessageSerializer as ConfluentMessageSerializer
from confluent_kafka.avro.serializer import SerializerError
from confluent_kafka.avro.serializer.message_serializer import ContextStringIO

# Value of the leading byte expected at the start of every encoded message;
# decode_message raises SerializerError when the first byte differs from it.
MAGIC_BYTE = 0


class HasSchemaMixin:
    """
    Mixin for decoded Avro values that exposes the schema kept on the instance.

    The schema is read from the ``_schema`` attribute, which has to be
    assigned on the instance before ``schema()`` is called.
    """

    def schema(self):
        """
        :return: the Avro schema stored in ``_schema`` for this entity
        :rtype: avro.schema.Schema
        """
        attached_schema = self._schema
        return attached_schema


def _wrap(value, schema):
"""
Wraps a value into subclass with HasSchemaMixin
:param value: a decoded value
:param schema: corresponding Avro schema used to decode value
:return: An instance of a dynamically created class with schema fullname
"""
if hasattr(schema, 'fullname'):
name = schema.fullname
elif hasattr(schema, 'namespace'):
name = "{namespace}.{name}".format(namespace=schema.namespace,
name=schema.name)
elif hasattr(schema, 'name'):
name = schema.name
else:
name = schema.type

new_class = type(str(name), (value.__class__, HasSchemaMixin), {})

wrapped = new_class(value)
wrapped._schema = schema
return wrapped


class AvroSerDeBase(ConfluentMessageSerializer):
    """
    MessageSerializer subclass whose ``decode_message`` returns the decoded
    value wrapped together with the Avro schema it was decoded with.
    """

    def decode_message(self, message):
        """
        Decode a message from kafka that has been encoded for use with
        the schema registry.

        The first five bytes are read as a one-byte magic marker followed by
        a big-endian unsigned 32-bit schema id; the rest is handed to the
        decoder obtained for that schema id.

        :param message: the encoded message, or None
        :return: None when message is None, otherwise the decoded value
                 wrapped by ``_wrap`` with the schema fetched by id
        :raises SerializerError: when the message is five bytes or shorter,
                 or does not start with the magic byte
        """
        if message is None:
            return None

        if len(message) <= 5:
            raise SerializerError("message is too small to decode")

        with ContextStringIO(message) as payload:
            header = payload.read(5)
            magic, schema_id = struct.unpack('>bI', header)
            if magic != MAGIC_BYTE:
                raise SerializerError("message does not start with magic byte")

            decode = self._get_decoder_func(schema_id, payload)
            # Decode first, then look the schema up, keeping the original
            # evaluation order of the two calls.
            decoded_value = decode(payload)
            writer_schema = self.registry_client.get_by_id(schema_id)
            return _wrap(decoded_value, writer_schema)
6 changes: 4 additions & 2 deletions kafkian/serde/deserialization.py
@@ -1,4 +1,6 @@
from confluent_kafka.avro import CachedSchemaRegistryClient, MessageSerializer
from confluent_kafka.avro import CachedSchemaRegistryClient

from .avroserdebase import AvroSerDeBase


class Deserializer:
Expand All @@ -17,7 +19,7 @@ class AvroDeserializer(Deserializer):
def __init__(self, schema_registry_url: str, **kwargs) -> None:
super().__init__(**kwargs)
self.schema_registry = CachedSchemaRegistryClient(schema_registry_url)
self._serializer_impl = MessageSerializer(self.schema_registry)
self._serializer_impl = AvroSerDeBase(self.schema_registry)

def deserialize(self, value, **kwargs):
return self._serializer_impl.decode_message(value)
6 changes: 4 additions & 2 deletions kafkian/serde/serialization.py
@@ -1,7 +1,9 @@
from enum import Enum

from confluent_kafka import avro
from confluent_kafka.avro import CachedSchemaRegistryClient, MessageSerializer
from confluent_kafka.avro import CachedSchemaRegistryClient

from .avroserdebase import AvroSerDeBase


class SubjectNameStrategy(Enum):
Expand Down Expand Up @@ -33,7 +35,7 @@ def __init__(self,
self.schema_registry = CachedSchemaRegistryClient(schema_registry_url)
self.auto_register_schemas = auto_register_schemas
self.subject_name_strategy = subject_name_strategy
self._serializer_impl = MessageSerializer(self.schema_registry)
self._serializer_impl = AvroSerDeBase(self.schema_registry)

def _get_subject(self, topic, schema, is_key=False):
if self.subject_name_strategy == SubjectNameStrategy.TopicNameStrategy:
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/test_avro_serde_base.py
@@ -0,0 +1,41 @@
from confluent_kafka import avro

from kafkian.serde.avroserdebase import HasSchemaMixin, _wrap


# Record schema shared by the tests below: name "basic" in namespace
# "python.test.basic", with a nullable long field and a string field.
SCHEMA = avro.loads("""
{
"name": "basic",
"type": "record",
"doc": "basic schema for tests",
"namespace": "python.test.basic",
"fields": [
{
"name": "number",
"doc": "age",
"type": [
"long",
"null"
]
},
{
"name": "name",
"doc": "a name",
"type": [
"string"
]
}
]
}
""")


def test_schema_mixin_wrapper():
    """Wrapped values keep their value and base type and expose SCHEMA."""
    builtin_types = (int, float, dict, list)
    for builtin_type in builtin_types:
        # Default-constructed instance of each builtin: 0, 0.0, {} and [].
        plain = builtin_type()
        result = _wrap(plain, SCHEMA)

        assert plain == result
        assert isinstance(result, builtin_type)
        assert isinstance(result, HasSchemaMixin)
        assert result.schema() is SCHEMA
        assert result.__class__.__name__ == 'python.test.basic.basic'

0 comments on commit b9eb7ef

Please sign in to comment.