Skip to content

Commit

Permalink
Merge branch 'feature/issue-12-array-is-a-complex-type' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
smn committed Feb 13, 2015
2 parents 3fbb0a8 + e35e781 commit 11abb9b
Show file tree
Hide file tree
Showing 16 changed files with 265 additions and 135 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ install:
- pip install -r requirements-dev.txt
- pip install -e .
- pip install coveralls
- pip install flake8
script:
- flake8 elasticgit
- py.test elasticgit -s --cov ./elasticgit
after_success:
- coveralls
Expand Down
13 changes: 0 additions & 13 deletions elasticgit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
import pkg_resources
import sys

from elasticgit.workspace import EG, F, Q

__all__ = ['EG', 'F', 'Q']
__version__ = pkg_resources.require('elastic-git')[0].version

version_info = {
'language': 'python',
'language_version_string': sys.version,
'language_version': '%d.%d.%d' % (
sys.version_info.major,
sys.version_info.minor,
sys.version_info.micro,
),
'package': 'elastic-git',
'package_version': __version__
}
95 changes: 52 additions & 43 deletions elasticgit/commands/avro.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import absolute_import

from jinja2 import Environment, PackageLoader
from functools import partial
import argparse
Expand All @@ -7,21 +9,23 @@

from datetime import datetime

from elasticgit import version_info
import avro.schema

from elasticgit.models import (
Model, IntegerField, TextField, ModelVersionField, FloatField,
BooleanField, ListField, DictField, UUIDField)
Model, IntegerField, TextField, FloatField,
BooleanField, ListField, DictField, UUIDField,
version_info)

from elasticgit.commands.base import (
ToolCommand, ToolCommandError, CommandArgument)
from elasticgit.utils import load_class


def deserialize(schema, field_mapping={}, module_name=None):
def deserialize(data, field_mapping={}, module_name=None):
"""
Deserialize an Avro schema and define it within a module (if specified)
:param dict schema:
:param dict data:
The Avro schema
:param dict field_mapping:
Optional mapping to override the default mapping.
Expand All @@ -47,6 +51,7 @@ def deserialize(schema, field_mapping={}, module_name=None):
"""
schema_loader = SchemaLoader()
schema = avro.schema.make_avsc_object(data, avro.schema.Names()).to_json()
model_code = schema_loader.generate_model(schema)
model_name = schema['name']

Expand Down Expand Up @@ -165,13 +170,11 @@ class SchemaLoader(ToolCommand):
action='append', type=RenameType),
)

mapping = {
core_mapping = {
'int': IntegerField,
'string': TextField,
'float': FloatField,
'boolean': BooleanField,
'array': ListField,
'record': DictField,
}

def run(self, schema_files, field_mappings=None, model_renames=None):
Expand Down Expand Up @@ -212,15 +215,20 @@ def field_class_for(self, field, field_mapping):

if isinstance(field_type, dict):
return self.field_class_for_complex_type(field)
return self.mapping[field_type].__name__
return self.core_mapping[field_type].__name__

def field_class_for_complex_type(self, field):
field_type = field['type']
if (field_type['name'] == 'ModelVersionField' and
field_type['namespace'] == 'elasticgit.models'):
return ModelVersionField.__name__
handler = getattr(
self, 'field_class_for_complex_%(type)s_type' % field_type)
return handler(field)

def field_class_for_complex_record_type(self, field):
return DictField.__name__

def field_class_for_complex_array_type(self, field):
return ListField.__name__

def default_value(self, field):
return pprint.pformat(field['default'], indent=8)

Expand Down Expand Up @@ -274,6 +282,9 @@ def generate_model(self, schema, field_mapping={}, model_renames={},
env.globals['field_class_for'] = partial(
self.field_class_for, field_mapping=field_mapping)
env.globals['default_value'] = self.default_value
env.globals['is_complex'] = (
lambda field: isinstance(field['type'], dict))
env.globals['core_mapping'] = self.core_mapping

template = env.get_template('model_generator.py.txt')
return template.render(
Expand Down Expand Up @@ -301,41 +312,13 @@ class SchemaDumper(ToolCommand):
CommandArgument('class_path', help='python path to Class.'),
)

mapping = {
# How model fields map to types
core_field_mappings = {
IntegerField: 'int',
TextField: 'string',
FloatField: 'float',
BooleanField: 'boolean',
ListField: 'array',
DictField: 'record',
UUIDField: 'string',
ModelVersionField: {
'type': 'record',
'name': 'ModelVersionField',
'namespace': 'elasticgit.models',
'fields': [
{
'name': 'language',
'type': 'string',
},
{
'name': 'language_version_string',
'type': 'string',
},
{
'name': 'language_version',
'type': 'string',
},
{
'name': 'package',
'type': 'string',
},
{
'name': 'package_version',
'type': 'string',
}
]
}
}

def run(self, class_path):
Expand Down Expand Up @@ -370,6 +353,32 @@ def dump_schema(self, model_class):
for name, field in model_class._fields.items()],
}, indent=2)

def map_field_to_type(self, field):
if field.__class__ in self.core_field_mappings:
return self.core_field_mappings[field.__class__]

handler = getattr(self, 'map_%s_type' % (field.__class__.__name__,))
return handler(field)

def map_ListField_type(self, field):
return {
'type': 'array',
'name': field.name,
'namespace': field.__class__.__module__,
'items': [self.map_field_to_type(fld) for fld in field.fields],
}

def map_DictField_type(self, field):
return {
'type': 'record',
'name': field.name,
'namespace': field.__class__.__module__,
'fields': [{
'name': fld.name,
'type': self.map_field_to_type(fld),
} for fld in field.fields],
}

def get_field_info(self, name, field):
"""
Return the Avro field object for an
Expand All @@ -383,7 +392,7 @@ def get_field_info(self, name, field):
"""
return {
'name': name,
'type': self.mapping[field.__class__],
'type': self.map_field_to_type(field),
'doc': field.doc,
'default': field.default,
'aliases': [fallback.field_name for fallback in field.fallbacks]
Expand Down
5 changes: 4 additions & 1 deletion elasticgit/commands/gitmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def guess_type(self, value):
float: 'float',
str: 'string',
unicode: 'string',
list: 'array',
list: {
'type': 'array',
'items': ['string'],
},
None: 'null',
}[None if value is None else type(value)]
67 changes: 49 additions & 18 deletions elasticgit/commands/tests/test_avro.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from elasticgit import models
from elasticgit.tests.base import ToolBaseTest

import elasticgit


class TestDumpSchemaTool(ToolBaseTest):

Expand Down Expand Up @@ -47,6 +45,18 @@ class TestModel(models.Model):
age = self.get_field(schema, 'age')
self.assertEqual(age['aliases'], ['length'])

def test_dump_array(self):
class TestModel(models.Model):
tags = models.ListField('The tags',
fields=(models.IntegerField('doc'),))

schema_dumper = self.mk_schema_dumper()
schema = json.loads(schema_dumper.dump_schema(TestModel))
tags = self.get_field(schema, 'tags')
field_type = tags['type']
self.assertEqual(field_type['type'], 'array')
self.assertEqual(field_type['items'], ['int'])


class TestLoadSchemaTool(ToolBaseTest):

Expand Down Expand Up @@ -124,15 +134,25 @@ def test_boolean_field(self):
def test_array_field(self):
self.assertFieldCreation({
'name': 'array',
'type': 'array',
'type': {
'type': 'array',
'items': ['string'],
},
'doc': 'The Array',
'default': ['foo', 'bar', 'baz']
}, models.ListField)

def test_dict_field(self):
self.assertFieldCreation({
'name': 'obj',
'type': 'record',
'type': {
'type': 'record',
'items': ['string'],
'fields': [{
'name': 'hello',
'type': 'string',
}]
},
'doc': 'The Object',
'default': {'hello': 'world'},
}, models.DictField)
Expand All @@ -143,7 +163,12 @@ def test_complex_field(self):
'type': {
'namespace': 'foo.bar',
'name': 'ItIsComplicated',
'type': 'record'
'type': 'record',
'items': ['string'],
'fields': [{
'name': 'foo',
'type': 'string',
}]
},
'doc': 'Super Complex',
'default': {},
Expand All @@ -154,12 +179,13 @@ def test_version_field(self):
'name': 'version',
'type': {
'namespace': 'elasticgit.models',
'name': 'ModelVersionField',
'name': 'version',
'type': 'record',
'items': ['string'],
},
'doc': 'The Model Version',
'default': elasticgit.version_info,
}, models.ModelVersionField)
'default': models.version_info,
}, models.DictField)

def test_mapping_hints(self):
self.assertFieldCreation({
Expand All @@ -177,8 +203,12 @@ class DumpAndLoadModel(models.Model):
integer = models.IntegerField('the integer')
float_ = models.FloatField('the float')
boolean = models.BooleanField('the boolean')
list_ = models.ListField('the list')
dict_ = models.DictField('the dict')
list_ = models.ListField('the list', fields=(
models.IntegerField('the int'),
))
dict_ = models.DictField('the dict', fields=(
models.TextField('hello', name='hello'),
))


class TestDumpAndLoad(ToolBaseTest):
Expand All @@ -189,6 +219,7 @@ def test_two_way(self):
schema_loader = self.mk_schema_loader()

schema = schema_dumper.dump_schema(DumpAndLoadModel)

generated_code = schema_loader.generate_model(json.loads(schema))

GeneratedModel = self.load_class(generated_code, 'DumpAndLoadModel')
Expand All @@ -198,7 +229,7 @@ def test_two_way(self):
'integer': 1,
'float': 1.1,
'boolean': False,
'list': ['1', '2', '3'],
'list_': [1, 2, 3],
'dict_': {'hello': 'world'}
}
record1 = DumpAndLoadModel(data)
Expand All @@ -220,8 +251,8 @@ def test_two_way_dict_ints(self):
'integer': 1,
'float': 1.1,
'boolean': False,
'list': ['1', '2', '3'],
'dict_': {'hello': 1}
'list_': [1, 2, 3],
'dict_': {'hello': '1'}
}
record1 = DumpAndLoadModel(data)
record2 = GeneratedModel(data)
Expand All @@ -242,7 +273,7 @@ def test_two_way_list_ints(self):
'integer': 1,
'float': 1.1,
'boolean': False,
'list': [1, 2, 3],
'list_': [1, 2, 3],
'dict_': {'hello': '1'}
}
record1 = DumpAndLoadModel(data)
Expand All @@ -264,7 +295,7 @@ def test_two_way_list_unicode(self):
'integer': 1,
'float': 1.1,
'boolean': False,
'list': [1, 2, 3],
'list_': [1, 2, 3],
'dict_': {'hello': '1'}
}
record1 = DumpAndLoadModel(data)
Expand Down Expand Up @@ -326,7 +357,7 @@ def test_load_older_version(self):
class Foo(models.Model):
pass

old_version_info = elasticgit.version_info.copy()
old_version_info = models.version_info.copy()
old_version_info['package_version'] = '0.0.1'

f = Foo({
Expand All @@ -341,9 +372,9 @@ class Foo(models.Model):
pass

major, minor, micro = map(
int, elasticgit.version_info['package_version'].split('.'))
int, models.version_info['package_version'].split('.'))

new_version = elasticgit.version_info.copy()
new_version = models.version_info.copy()
new_version['package_version'] = '%d.%d.%d' % (
major + 1,
minor,
Expand Down

0 comments on commit 11abb9b

Please sign in to comment.