From c2b4a457d4f87433ea93a84f531d2f71fa8e793b Mon Sep 17 00:00:00 2001 From: Paul Brown Date: Sat, 26 Mar 2016 21:06:18 -0500 Subject: [PATCH] removed unnecessary converter arguments, add tests --- tests/tests.py | 127 ++++++++++++++++++++++++++++---------- wtforms_sqlalchemy/orm.py | 91 ++++++++++++++------------- 2 files changed, 145 insertions(+), 73 deletions(-) diff --git a/tests/tests.py b/tests/tests.py index 3c9c3ec..1e63a7f 100755 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,10 +1,12 @@ from __future__ import unicode_literals, absolute_import -from sqlalchemy import create_engine, ForeignKey +from sqlalchemy import create_engine, ForeignKey, types as sqla_types from sqlalchemy.schema import MetaData, Table, Column, ColumnDefault -from sqlalchemy.types import String, Integer, Numeric, Date, Text, Enum, Boolean, DateTime from sqlalchemy.orm import sessionmaker, relationship, backref from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.dialects.postgresql import INET, MACADDR, UUID +from sqlalchemy.dialects.mysql import YEAR +from sqlalchemy.dialects.mssql import BIT from unittest import TestCase @@ -27,7 +29,7 @@ def __init__(self, **kwargs): setattr(self, k, v) -class AnotherInteger(Integer): +class AnotherInteger(sqla_types.Integer): """Use me to test if MRO works like we want""" @@ -37,14 +39,14 @@ def _do_tables(self, mapper, engine): test_table = Table( 'test', metadata, - Column('id', Integer, primary_key=True, nullable=False), - Column('name', String, nullable=False), + Column('id', sqla_types.Integer, primary_key=True, nullable=False), + Column('name', sqla_types.String, nullable=False), ) pk_test_table = Table( 'pk_test', metadata, - Column('foobar', String, primary_key=True, nullable=False), - Column('baz', String, nullable=False), + Column('foobar', sqla_types.String, primary_key=True, nullable=False), + Column('baz', sqla_types.String, nullable=False), ) Test = type(str('Test'), (Base, ), {}) @@ -188,33 +190,33 @@ def setUp(self): student_course = Table( 'student_course', Model.metadata, - Column('student_id', Integer, ForeignKey('student.id')), - Column('course_id', Integer, ForeignKey('course.id')) + Column('student_id', sqla_types.Integer, ForeignKey('student.id')), + Column('course_id', sqla_types.Integer, ForeignKey('course.id')) ) class Course(Model): __tablename__ = "course" - id = Column(Integer, primary_key=True) - name = Column(String(255), nullable=False) + id = Column(sqla_types.Integer, primary_key=True) + name = Column(sqla_types.String(255), nullable=False) # These are for better model form testing - cost = Column(Numeric(5, 2), nullable=False) - description = Column(Text, nullable=False) - level = Column(Enum('Primary', 'Secondary')) - has_prereqs = Column(Boolean, nullable=False) - started = Column(DateTime, nullable=False) + cost = Column(sqla_types.Numeric(5, 2), nullable=False) + description = Column(sqla_types.Text, nullable=False) + level = Column(sqla_types.Enum('Primary', 'Secondary')) + has_prereqs = Column(sqla_types.Boolean, nullable=False) + started = Column(sqla_types.DateTime, nullable=False) grade = Column(AnotherInteger, nullable=False) class School(Model): __tablename__ = "school" - id = Column(Integer, primary_key=True) - name = Column(String(255), nullable=False) + id = Column(sqla_types.Integer, primary_key=True) + name = Column(sqla_types.String(255), nullable=False) class Student(Model): __tablename__ = "student" - id = Column(Integer, primary_key=True) - full_name = Column(String(255), nullable=False, unique=True) - dob = Column(Date(), nullable=True) - current_school_id = Column(Integer, ForeignKey(School.id), nullable=False) + id = Column(sqla_types.Integer, primary_key=True) + full_name = Column(sqla_types.String(255), nullable=False, unique=True) + dob = Column(sqla_types.Date(), nullable=True) + current_school_id = Column(sqla_types.Integer, ForeignKey(School.id), nullable=False) current_school = relationship(School, backref=backref('students')) courses = relationship( @@ -282,9 +284,6 @@ def test_convert_basic(self): form_class = model_form(self.Course, exclude=['students']) form = form_class() self.assertEqual(len(list(form)), 7) - assert isinstance(form.cost, fields.DecimalField) - assert isinstance(form.has_prereqs, fields.BooleanField) - assert isinstance(form.started, fields.DateTimeField) def test_only(self): desired_fields = ['id', 'cost', 'description'] @@ -311,16 +310,16 @@ def default_score(): class StudentDefaultScoreCallable(Model): __tablename__ = "course" - id = Column(Integer, primary_key=True) - name = Column(String(255), nullable=False) - score = Column(Integer, default=default_score, nullable=False) + id = Column(sqla_types.Integer, primary_key=True) + name = Column(sqla_types.String(255), nullable=False) + score = Column(sqla_types.Integer, default=default_score, nullable=False) class StudentDefaultScoreScalar(Model): __tablename__ = "school" - id = Column(Integer, primary_key=True) - name = Column(String(255), nullable=False) + id = Column(sqla_types.Integer, primary_key=True) + name = Column(sqla_types.String(255), nullable=False) # Default scalar value - score = Column(Integer, default=10, nullable=False) + score = Column(sqla_types.Integer, default=10, nullable=False) self.StudentDefaultScoreCallable = StudentDefaultScoreCallable self.StudentDefaultScoreScalar = StudentDefaultScoreScalar @@ -339,3 +338,69 @@ def test_column_default_scalar(self): student_form = model_form(self.StudentDefaultScoreScalar, self.sess)() assert not isinstance(student_form._fields['score'].default, ColumnDefault) self.assertEqual(student_form._fields['score'].default, 10) + + +class ModelFormTest(TestCase): + def setUp(self): + Model = declarative_base() + + class AllTypesModel(Model): + __tablename__ = "course" + id = Column(sqla_types.Integer, primary_key=True) + string = Column(sqla_types.String) + unicode = Column(sqla_types.Unicode) + varchar = Column(sqla_types.VARCHAR) + integer = Column(sqla_types.Integer) + biginteger = Column(sqla_types.BigInteger) + smallinteger = Column(sqla_types.SmallInteger) + numeric = Column(sqla_types.Numeric) + float = Column(sqla_types.Float) + text = Column(sqla_types.Text) + binary = Column(sqla_types.Binary) + largebinary = Column(sqla_types.LargeBinary) + unicodetext = Column(sqla_types.UnicodeText) + enum = Column(sqla_types.Enum('Primary', 'Secondary')) + boolean = Column(sqla_types.Boolean) + datetime = Column(sqla_types.DateTime) + timestamp = Column(sqla_types.TIMESTAMP) + date = Column(sqla_types.Date) + postgres_inet = Column(INET) + postgres_macaddr = Column(MACADDR) + postgres_uuid = Column(UUID) + mysql_year = Column(YEAR) + mssql_bit = Column(BIT) + + self.AllTypesModel = AllTypesModel + + def test_convert_types(self): + form = model_form(self.AllTypesModel)() + + assert isinstance(form.string, fields.StringField) + assert isinstance(form.unicode, fields.StringField) + assert isinstance(form.varchar, fields.StringField) + assert isinstance(form.postgres_inet, fields.StringField) + assert isinstance(form.postgres_macaddr, fields.StringField) + assert isinstance(form.postgres_uuid, fields.StringField) + assert isinstance(form.mysql_year, fields.StringField) + + assert isinstance(form.integer, fields.IntegerField) + assert isinstance(form.biginteger, fields.IntegerField) + assert isinstance(form.smallinteger, fields.IntegerField) + + assert isinstance(form.numeric, fields.DecimalField) + assert isinstance(form.float, fields.DecimalField) + + assert isinstance(form.text, fields.TextAreaField) + assert isinstance(form.binary, fields.TextAreaField) + assert isinstance(form.largebinary, fields.TextAreaField) + assert isinstance(form.unicodetext, fields.TextAreaField) + + assert isinstance(form.enum, fields.SelectField) + + assert isinstance(form.boolean, fields.BooleanField) + assert isinstance(form.mssql_bit, fields.BooleanField) + + assert isinstance(form.datetime, fields.DateTimeField) + assert isinstance(form.timestamp, fields.DateTimeField) + + assert isinstance(form.date, fields.DateField) diff --git a/wtforms_sqlalchemy/orm.py b/wtforms_sqlalchemy/orm.py index 5acb454..2d6b9d3 100644 --- a/wtforms_sqlalchemy/orm.py +++ b/wtforms_sqlalchemy/orm.py @@ -5,8 +5,7 @@ import inspect -from wtforms import fields as f -from wtforms import validators +from wtforms import validators, fields as wtforms_fields from wtforms.form import Form from .fields import QuerySelectField, QuerySelectMultipleField @@ -42,6 +41,34 @@ def __init__(self, converters, use_mro=True): self.converters = converters + def get_converter(self, column): + """ + Searches `self.converters` for a converter method with an argument that + matches the column's type. + """ + if self.use_mro: + types = inspect.getmro(type(column.type)) + else: + types = [type(column.type)] + + # Search by module + name + for col_type in types: + type_string = '%s.%s' % (col_type.__module__, col_type.__name__) + + # remove the 'sqlalchemy.' prefix for sqlalchemy <0.7 compatibility + if type_string.startswith('sqlalchemy'): + type_string = type_string[11:] + + if type_string in self.converters: + return self.converters[type_string] + + # Search by name + for col_type in types: + if col_type.__name__ in self.converters: + return self.converters[col_type.__name__] + + raise ModelConversionError('Could not find field converter for %s (%r).' % (prop.key, types[0])) + def convert(self, model, mapper, prop, field_args, db_session=None): if not hasattr(prop, 'columns') and not hasattr(prop, 'direction'): return @@ -65,7 +92,6 @@ def convert(self, model, mapper, prop, field_args, db_session=None): converter = None column = None - types = None if not hasattr(prop, 'direction'): column = prop.columns[0] @@ -91,26 +117,7 @@ def convert(self, model, mapper, prop, field_args, db_session=None): else: kwargs['validators'].append(validators.Required()) - if self.use_mro: - types = inspect.getmro(type(column.type)) - else: - types = [type(column.type)] - - for col_type in types: - type_string = '%s.%s' % (col_type.__module__, col_type.__name__) - if type_string.startswith('sqlalchemy'): - type_string = type_string[11:] - - if type_string in self.converters: - converter = self.converters[type_string] - break - else: - for col_type in types: - if col_type.__name__ in self.converters: - converter = self.converters[col_type.__name__] - break - else: - raise ModelConversionError('Could not find field converter for %s (%r).' % (prop.key, types[0])) + converter = self.get_converter(column) else: # We have a property with a direction. if not db_session: @@ -148,68 +155,68 @@ def _string_common(cls, column, field_args, **extra): if isinstance(column.type.length, int) and column.type.length: field_args['validators'].append(validators.Length(max=column.type.length)) - @converts('String', 'Unicode') + @converts('String') # includes Unicode def conv_String(self, field_args, **extra): self._string_common(field_args=field_args, **extra) - return f.TextField(**field_args) + return wtforms_fields.StringField(**field_args) - @converts('types.Text', 'UnicodeText', 'types.LargeBinary', 'types.Binary', 'sql.sqltypes.Text') + @converts('Text', 'LargeBinary', 'Binary') # includes UnicodeText def conv_Text(self, field_args, **extra): self._string_common(field_args=field_args, **extra) - return f.TextAreaField(**field_args) + return wtforms_fields.TextAreaField(**field_args) - @converts('Boolean') + @converts('Boolean', 'dialects.mssql.base.BIT') def conv_Boolean(self, field_args, **extra): - return f.BooleanField(**field_args) + return wtforms_fields.BooleanField(**field_args) @converts('Date') def conv_Date(self, field_args, **extra): - return f.DateField(**field_args) + return wtforms_fields.DateField(**field_args) @converts('DateTime') def conv_DateTime(self, field_args, **extra): - return f.DateTimeField(**field_args) + return wtforms_fields.DateTimeField(**field_args) @converts('Enum') def conv_Enum(self, column, field_args, **extra): field_args['choices'] = [(e, e) for e in column.type.enums] - return f.SelectField(**field_args) + return wtforms_fields.SelectField(**field_args) - @converts('Integer', 'SmallInteger') + @converts('Integer') # includes BigInteger and SmallInteger def handle_integer_types(self, column, field_args, **extra): unsigned = getattr(column.type, 'unsigned', False) if unsigned: field_args['validators'].append(validators.NumberRange(min=0)) - return f.IntegerField(**field_args) + return wtforms_fields.IntegerField(**field_args) - @converts('Numeric', 'Float') + @converts('Numeric') # includes DECIMAL, Float/FLOAT, REAL, and DOUBLE def handle_decimal_types(self, column, field_args, **extra): # override default decimal places limit, use database defaults instead field_args.setdefault('places', None) - return f.DecimalField(**field_args) + return wtforms_fields.DecimalField(**field_args) - @converts('databases.mysql.MSYear', 'dialects.mysql.base.YEAR') + @converts('dialects.mysql.base.YEAR') def conv_MSYear(self, field_args, **extra): field_args['validators'].append(validators.NumberRange(min=1901, max=2155)) - return f.TextField(**field_args) + return wtforms_fields.StringField(**field_args) - @converts('databases.postgres.PGInet', 'dialects.postgresql.base.INET') + @converts('dialects.postgresql.base.INET') def conv_PGInet(self, field_args, **extra): field_args.setdefault('label', 'IP Address') field_args['validators'].append(validators.IPAddress()) - return f.TextField(**field_args) + return wtforms_fields.StringField(**field_args) @converts('dialects.postgresql.base.MACADDR') def conv_PGMacaddr(self, field_args, **extra): field_args.setdefault('label', 'MAC Address') field_args['validators'].append(validators.MacAddress()) - return f.TextField(**field_args) + return wtforms_fields.StringField(**field_args) @converts('dialects.postgresql.base.UUID') def conv_PGUuid(self, field_args, **extra): field_args.setdefault('label', 'UUID') field_args['validators'].append(validators.UUID()) - return f.TextField(**field_args) + return wtforms_fields.StringField(**field_args) @converts('MANYTOONE') def conv_ManyToOne(self, field_args, **extra):