diff --git a/tests/tests.py b/tests/tests.py index a305416..7255bc1 100755 --- a/tests/tests.py +++ b/tests/tests.py @@ -95,6 +95,21 @@ class F(Form): assert not form.validate() self.assertEqual(form.a.errors, ['Not a valid choice']) + form = F(a=sess.query(self.Test).filter_by(name='banana').first()) + another_sess = self.Session() + form.a.query = another_sess.query(self.Test) + #this is the problem. not same pointer. + self.assertEqual(form.a(), [('1', 'apple', False), ('2', 'banana', False)]) + + #So, use coerce and get_pk for test + class F2(Form): + a = QuerySelectField(get_label='name', widget=LazySelect(), get_pk=lambda x: x.id, coerce=int) + form = F2(a=sess.query(self.Test).filter_by(name='banana').first()) + another_sess = self.Session() + form.a.query = another_sess.query(self.Test) + self.assertEqual(form.a(), [('1', 'apple', False), ('2', 'banana', True)]) + + def test_with_query_factory(self): sess = self.Session() self._fill(sess) diff --git a/wtforms_sqlalchemy/fields.py b/wtforms_sqlalchemy/fields.py index 286dada..df3c8be 100644 --- a/wtforms_sqlalchemy/fields.py +++ b/wtforms_sqlalchemy/fields.py @@ -51,14 +51,19 @@ class QuerySelectField(SelectFieldBase): top of the list. Selecting this choice will result in the `data` property being `None`. The label for this blank choice can be set by specifying the `blank_text` parameter. + + If `coerce` and `get_pk` callables are specified, these two callables are + used for equality test. `coerce` should get pk and return value casted that + same type returned by `get_pk`. And selected is True when test passed. """ widget = widgets.Select() def __init__(self, label=None, validators=None, query_factory=None, get_pk=None, get_label=None, allow_blank=False, - blank_text='', **kwargs): + blank_text='', coerce=None, **kwargs): super(QuerySelectField, self).__init__(label, validators, **kwargs) self.query_factory = query_factory + self.coerce = coerce if get_pk is None: if not has_identity_key: @@ -105,7 +110,9 @@ def iter_choices(self): yield ('__None', self.blank_text, self.data is None) for pk, obj in self._get_object_list(): - yield (pk, self.get_label(obj), obj == self.data) + yield (pk, self.get_label(obj), obj == self.data or + (self.coerce is not None and self.get_pk is not None + and self.coerce(pk) == self.get_pk(self.data))) def process_formdata(self, valuelist): if valuelist: