Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove pin on SQLAlchemy version < 2.0 #45

Merged
merged 5 commits into from Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -35,7 +35,7 @@ include_package_data = true
python_requires = >= 3.8
install_requires =
WTForms>=3.1
SQLAlchemy>=0.7.10,<2
SQLAlchemy>=1.4

[flake8]
# B = bugbear
Expand Down
16 changes: 8 additions & 8 deletions tests/test_main.py
Expand Up @@ -8,13 +8,13 @@
from sqlalchemy.dialects.postgresql import INET
from sqlalchemy.dialects.postgresql import MACADDR
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import registry
from sqlalchemy.orm import relationship
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import Column
from sqlalchemy.schema import ColumnDefault
from sqlalchemy.schema import MetaData
from sqlalchemy.schema import Table
from wtforms import fields
from wtforms import Form
Expand Down Expand Up @@ -51,18 +51,18 @@ class AnotherInteger(sqla_types.Integer):

class TestBase(TestCase):
def _do_tables(self, mapper, engine):
metadata = MetaData()
mapper_registry = registry()

test_table = Table(
"test",
metadata,
mapper_registry.metadata,
Column("id", sqla_types.Integer, primary_key=True, nullable=False),
Column("name", sqla_types.String, nullable=False),
)

pk_test_table = Table(
"pk_test",
metadata,
mapper_registry.metadata,
Column("foobar", sqla_types.String, primary_key=True, nullable=False),
Column("baz", sqla_types.String, nullable=False),
)
Expand All @@ -74,12 +74,12 @@ def _do_tables(self, mapper, engine):
{"__unicode__": lambda x: x.baz, "__str__": lambda x: x.baz},
)

mapper(Test, test_table)
mapper(PKTest, pk_test_table)
mapper_registry.map_imperatively(Test, test_table)
mapper_registry.map_imperatively(PKTest, pk_test_table)
self.Test = Test
self.PKTest = PKTest

metadata.create_all(bind=engine)
mapper_registry.metadata.create_all(bind=engine)

def _fill(self, sess):
for i, n in [(1, "apple"), (2, "banana")]:
Expand Down
15 changes: 10 additions & 5 deletions wtforms_sqlalchemy/orm.py
Expand Up @@ -3,6 +3,7 @@
"""
import inspect

from sqlalchemy import inspect as sainspect
from wtforms import fields as wtforms_fields
from wtforms import validators
from wtforms.form import Form
Expand Down Expand Up @@ -211,19 +212,23 @@ def conv_MSYear(self, field_args, **extra):
field_args["validators"].append(validators.NumberRange(min=1901, max=2155))
return wtforms_fields.StringField(**field_args)

@converts("dialects.postgresql.base.INET")
@converts("dialects.postgresql.types.INET", "dialects.postgresql.base.INET")
def conv_PGInet(self, field_args, **extra):
field_args.setdefault("label", "IP Address")
field_args["validators"].append(validators.IPAddress())
return wtforms_fields.StringField(**field_args)

@converts("dialects.postgresql.base.MACADDR")
@converts("dialects.postgresql.types.MACADDR", "dialects.postgresql.base.MACADDR")
def conv_PGMacaddr(self, field_args, **extra):
field_args.setdefault("label", "MAC Address")
field_args["validators"].append(validators.MacAddress())
return wtforms_fields.StringField(**field_args)

@converts("dialects.postgresql.base.UUID")
@converts(
"sql.sqltypes.UUID",
"dialects.postgresql.types.UUID",
"dialects.postgresql.base.UUID",
)
def conv_PGUuid(self, field_args, **extra):
field_args.setdefault("label", "UUID")
field_args["validators"].append(validators.UUID())
Expand Down Expand Up @@ -253,12 +258,12 @@ def model_fields(

See `model_form` docstring for description of parameters.
"""
mapper = model._sa_class_manager.mapper
mapper = sainspect(model)
converter = converter or ModelConverter()
field_args = field_args or {}
properties = []

for prop in mapper.iterate_properties:
for prop in mapper.attrs.values():
if getattr(prop, "columns", None):
if exclude_fk and prop.columns[0].foreign_keys:
continue
Expand Down