Skip to content

Commit

Permalink
Merge pull request #7 from pawl/clean_up_conversions
Browse files Browse the repository at this point in the history
clean up converter arguments
  • Loading branch information
pawl committed Mar 27, 2016
2 parents 05b51be + c2b4a45 commit 6e4e771
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 73 deletions.
127 changes: 96 additions & 31 deletions tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import unicode_literals, absolute_import

from sqlalchemy import create_engine, ForeignKey
from sqlalchemy import create_engine, ForeignKey, types as sqla_types
from sqlalchemy.schema import MetaData, Table, Column, ColumnDefault
from sqlalchemy.types import String, Integer, Numeric, Date, Text, Enum, Boolean, DateTime
from sqlalchemy.orm import sessionmaker, relationship, backref
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.dialects.postgresql import INET, MACADDR, UUID
from sqlalchemy.dialects.mysql import YEAR
from sqlalchemy.dialects.mssql import BIT

from unittest import TestCase

Expand All @@ -27,7 +29,7 @@ def __init__(self, **kwargs):
setattr(self, k, v)


class AnotherInteger(Integer):
class AnotherInteger(sqla_types.Integer):
"""Use me to test if MRO works like we want"""


Expand All @@ -37,14 +39,14 @@ def _do_tables(self, mapper, engine):

test_table = Table(
'test', metadata,
Column('id', Integer, primary_key=True, nullable=False),
Column('name', String, nullable=False),
Column('id', sqla_types.Integer, primary_key=True, nullable=False),
Column('name', sqla_types.String, nullable=False),
)

pk_test_table = Table(
'pk_test', metadata,
Column('foobar', String, primary_key=True, nullable=False),
Column('baz', String, nullable=False),
Column('foobar', sqla_types.String, primary_key=True, nullable=False),
Column('baz', sqla_types.String, nullable=False),
)

Test = type(str('Test'), (Base, ), {})
Expand Down Expand Up @@ -188,33 +190,33 @@ def setUp(self):

student_course = Table(
'student_course', Model.metadata,
Column('student_id', Integer, ForeignKey('student.id')),
Column('course_id', Integer, ForeignKey('course.id'))
Column('student_id', sqla_types.Integer, ForeignKey('student.id')),
Column('course_id', sqla_types.Integer, ForeignKey('course.id'))
)

class Course(Model):
__tablename__ = "course"
id = Column(Integer, primary_key=True)
name = Column(String(255), nullable=False)
id = Column(sqla_types.Integer, primary_key=True)
name = Column(sqla_types.String(255), nullable=False)
# These are for better model form testing
cost = Column(Numeric(5, 2), nullable=False)
description = Column(Text, nullable=False)
level = Column(Enum('Primary', 'Secondary'))
has_prereqs = Column(Boolean, nullable=False)
started = Column(DateTime, nullable=False)
cost = Column(sqla_types.Numeric(5, 2), nullable=False)
description = Column(sqla_types.Text, nullable=False)
level = Column(sqla_types.Enum('Primary', 'Secondary'))
has_prereqs = Column(sqla_types.Boolean, nullable=False)
started = Column(sqla_types.DateTime, nullable=False)
grade = Column(AnotherInteger, nullable=False)

class School(Model):
__tablename__ = "school"
id = Column(Integer, primary_key=True)
name = Column(String(255), nullable=False)
id = Column(sqla_types.Integer, primary_key=True)
name = Column(sqla_types.String(255), nullable=False)

class Student(Model):
__tablename__ = "student"
id = Column(Integer, primary_key=True)
full_name = Column(String(255), nullable=False, unique=True)
dob = Column(Date(), nullable=True)
current_school_id = Column(Integer, ForeignKey(School.id), nullable=False)
id = Column(sqla_types.Integer, primary_key=True)
full_name = Column(sqla_types.String(255), nullable=False, unique=True)
dob = Column(sqla_types.Date(), nullable=True)
current_school_id = Column(sqla_types.Integer, ForeignKey(School.id), nullable=False)

current_school = relationship(School, backref=backref('students'))
courses = relationship(
Expand Down Expand Up @@ -282,9 +284,6 @@ def test_convert_basic(self):
form_class = model_form(self.Course, exclude=['students'])
form = form_class()
self.assertEqual(len(list(form)), 7)
assert isinstance(form.cost, fields.DecimalField)
assert isinstance(form.has_prereqs, fields.BooleanField)
assert isinstance(form.started, fields.DateTimeField)

def test_only(self):
desired_fields = ['id', 'cost', 'description']
Expand All @@ -311,16 +310,16 @@ def default_score():

class StudentDefaultScoreCallable(Model):
__tablename__ = "course"
id = Column(Integer, primary_key=True)
name = Column(String(255), nullable=False)
score = Column(Integer, default=default_score, nullable=False)
id = Column(sqla_types.Integer, primary_key=True)
name = Column(sqla_types.String(255), nullable=False)
score = Column(sqla_types.Integer, default=default_score, nullable=False)

class StudentDefaultScoreScalar(Model):
__tablename__ = "school"
id = Column(Integer, primary_key=True)
name = Column(String(255), nullable=False)
id = Column(sqla_types.Integer, primary_key=True)
name = Column(sqla_types.String(255), nullable=False)
# Default scalar value
score = Column(Integer, default=10, nullable=False)
score = Column(sqla_types.Integer, default=10, nullable=False)

self.StudentDefaultScoreCallable = StudentDefaultScoreCallable
self.StudentDefaultScoreScalar = StudentDefaultScoreScalar
Expand All @@ -339,3 +338,69 @@ def test_column_default_scalar(self):
student_form = model_form(self.StudentDefaultScoreScalar, self.sess)()
assert not isinstance(student_form._fields['score'].default, ColumnDefault)
self.assertEqual(student_form._fields['score'].default, 10)


class ModelFormTest(TestCase):
def setUp(self):
Model = declarative_base()

class AllTypesModel(Model):
__tablename__ = "course"
id = Column(sqla_types.Integer, primary_key=True)
string = Column(sqla_types.String)
unicode = Column(sqla_types.Unicode)
varchar = Column(sqla_types.VARCHAR)
integer = Column(sqla_types.Integer)
biginteger = Column(sqla_types.BigInteger)
smallinteger = Column(sqla_types.SmallInteger)
numeric = Column(sqla_types.Numeric)
float = Column(sqla_types.Float)
text = Column(sqla_types.Text)
binary = Column(sqla_types.Binary)
largebinary = Column(sqla_types.LargeBinary)
unicodetext = Column(sqla_types.UnicodeText)
enum = Column(sqla_types.Enum('Primary', 'Secondary'))
boolean = Column(sqla_types.Boolean)
datetime = Column(sqla_types.DateTime)
timestamp = Column(sqla_types.TIMESTAMP)
date = Column(sqla_types.Date)
postgres_inet = Column(INET)
postgres_macaddr = Column(MACADDR)
postgres_uuid = Column(UUID)
mysql_year = Column(YEAR)
mssql_bit = Column(BIT)

self.AllTypesModel = AllTypesModel

def test_convert_types(self):
form = model_form(self.AllTypesModel)()

assert isinstance(form.string, fields.StringField)
assert isinstance(form.unicode, fields.StringField)
assert isinstance(form.varchar, fields.StringField)
assert isinstance(form.postgres_inet, fields.StringField)
assert isinstance(form.postgres_macaddr, fields.StringField)
assert isinstance(form.postgres_uuid, fields.StringField)
assert isinstance(form.mysql_year, fields.StringField)

assert isinstance(form.integer, fields.IntegerField)
assert isinstance(form.biginteger, fields.IntegerField)
assert isinstance(form.smallinteger, fields.IntegerField)

assert isinstance(form.numeric, fields.DecimalField)
assert isinstance(form.float, fields.DecimalField)

assert isinstance(form.text, fields.TextAreaField)
assert isinstance(form.binary, fields.TextAreaField)
assert isinstance(form.largebinary, fields.TextAreaField)
assert isinstance(form.unicodetext, fields.TextAreaField)

assert isinstance(form.enum, fields.SelectField)

assert isinstance(form.boolean, fields.BooleanField)
assert isinstance(form.mssql_bit, fields.BooleanField)

assert isinstance(form.datetime, fields.DateTimeField)
assert isinstance(form.timestamp, fields.DateTimeField)

assert isinstance(form.date, fields.DateField)
91 changes: 49 additions & 42 deletions wtforms_sqlalchemy/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

import inspect

from wtforms import fields as f
from wtforms import validators
from wtforms import validators, fields as wtforms_fields
from wtforms.form import Form
from .fields import QuerySelectField, QuerySelectMultipleField

Expand Down Expand Up @@ -42,6 +41,34 @@ def __init__(self, converters, use_mro=True):

self.converters = converters

def get_converter(self, column):
"""
Searches `self.converters` for a converter method with an argument that
matches the column's type.
"""
if self.use_mro:
types = inspect.getmro(type(column.type))
else:
types = [type(column.type)]

# Search by module + name
for col_type in types:
type_string = '%s.%s' % (col_type.__module__, col_type.__name__)

# remove the 'sqlalchemy.' prefix for sqlalchemy <0.7 compatibility
if type_string.startswith('sqlalchemy'):
type_string = type_string[11:]

if type_string in self.converters:
return self.converters[type_string]

# Search by name
for col_type in types:
if col_type.__name__ in self.converters:
return self.converters[col_type.__name__]

raise ModelConversionError('Could not find field converter for %s (%r).' % (prop.key, types[0]))

def convert(self, model, mapper, prop, field_args, db_session=None):
if not hasattr(prop, 'columns') and not hasattr(prop, 'direction'):
return
Expand All @@ -65,7 +92,6 @@ def convert(self, model, mapper, prop, field_args, db_session=None):

converter = None
column = None
types = None

if not hasattr(prop, 'direction'):
column = prop.columns[0]
Expand All @@ -91,26 +117,7 @@ def convert(self, model, mapper, prop, field_args, db_session=None):
else:
kwargs['validators'].append(validators.Required())

if self.use_mro:
types = inspect.getmro(type(column.type))
else:
types = [type(column.type)]

for col_type in types:
type_string = '%s.%s' % (col_type.__module__, col_type.__name__)
if type_string.startswith('sqlalchemy'):
type_string = type_string[11:]

if type_string in self.converters:
converter = self.converters[type_string]
break
else:
for col_type in types:
if col_type.__name__ in self.converters:
converter = self.converters[col_type.__name__]
break
else:
raise ModelConversionError('Could not find field converter for %s (%r).' % (prop.key, types[0]))
converter = self.get_converter(column)
else:
# We have a property with a direction.
if not db_session:
Expand Down Expand Up @@ -148,68 +155,68 @@ def _string_common(cls, column, field_args, **extra):
if isinstance(column.type.length, int) and column.type.length:
field_args['validators'].append(validators.Length(max=column.type.length))

@converts('String', 'Unicode')
@converts('String') # includes Unicode
def conv_String(self, field_args, **extra):
self._string_common(field_args=field_args, **extra)
return f.TextField(**field_args)
return wtforms_fields.StringField(**field_args)

@converts('types.Text', 'UnicodeText', 'types.LargeBinary', 'types.Binary', 'sql.sqltypes.Text')
@converts('Text', 'LargeBinary', 'Binary') # includes UnicodeText
def conv_Text(self, field_args, **extra):
self._string_common(field_args=field_args, **extra)
return f.TextAreaField(**field_args)
return wtforms_fields.TextAreaField(**field_args)

@converts('Boolean')
@converts('Boolean', 'dialects.mssql.base.BIT')
def conv_Boolean(self, field_args, **extra):
return f.BooleanField(**field_args)
return wtforms_fields.BooleanField(**field_args)

@converts('Date')
def conv_Date(self, field_args, **extra):
return f.DateField(**field_args)
return wtforms_fields.DateField(**field_args)

@converts('DateTime')
def conv_DateTime(self, field_args, **extra):
return f.DateTimeField(**field_args)
return wtforms_fields.DateTimeField(**field_args)

@converts('Enum')
def conv_Enum(self, column, field_args, **extra):
field_args['choices'] = [(e, e) for e in column.type.enums]
return f.SelectField(**field_args)
return wtforms_fields.SelectField(**field_args)

@converts('Integer', 'SmallInteger')
@converts('Integer') # includes BigInteger and SmallInteger
def handle_integer_types(self, column, field_args, **extra):
unsigned = getattr(column.type, 'unsigned', False)
if unsigned:
field_args['validators'].append(validators.NumberRange(min=0))
return f.IntegerField(**field_args)
return wtforms_fields.IntegerField(**field_args)

@converts('Numeric', 'Float')
@converts('Numeric') # includes DECIMAL, Float/FLOAT, REAL, and DOUBLE
def handle_decimal_types(self, column, field_args, **extra):
# override default decimal places limit, use database defaults instead
field_args.setdefault('places', None)
return f.DecimalField(**field_args)
return wtforms_fields.DecimalField(**field_args)

@converts('databases.mysql.MSYear', 'dialects.mysql.base.YEAR')
@converts('dialects.mysql.base.YEAR')
def conv_MSYear(self, field_args, **extra):
field_args['validators'].append(validators.NumberRange(min=1901, max=2155))
return f.TextField(**field_args)
return wtforms_fields.StringField(**field_args)

@converts('databases.postgres.PGInet', 'dialects.postgresql.base.INET')
@converts('dialects.postgresql.base.INET')
def conv_PGInet(self, field_args, **extra):
field_args.setdefault('label', 'IP Address')
field_args['validators'].append(validators.IPAddress())
return f.TextField(**field_args)
return wtforms_fields.StringField(**field_args)

@converts('dialects.postgresql.base.MACADDR')
def conv_PGMacaddr(self, field_args, **extra):
field_args.setdefault('label', 'MAC Address')
field_args['validators'].append(validators.MacAddress())
return f.TextField(**field_args)
return wtforms_fields.StringField(**field_args)

@converts('dialects.postgresql.base.UUID')
def conv_PGUuid(self, field_args, **extra):
field_args.setdefault('label', 'UUID')
field_args['validators'].append(validators.UUID())
return f.TextField(**field_args)
return wtforms_fields.StringField(**field_args)

@converts('MANYTOONE')
def conv_ManyToOne(self, field_args, **extra):
Expand Down

0 comments on commit 6e4e771

Please sign in to comment.