Skip to content

Commit

Permalink
Merge pull request #188 from simplejson/bpo-31505
Browse files Browse the repository at this point in the history
bpo-31505: Fix an assertion failure in json, in case _json.make_encoder() received a bad encoder() argument.
  • Loading branch information
etrepum committed Nov 20, 2017
2 parents 138a2ff + eb9665a commit 529268f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
21 changes: 18 additions & 3 deletions simplejson/_speedups.c
Original file line number Diff line number Diff line change
Expand Up @@ -2811,10 +2811,25 @@ static PyObject *
encoder_encode_string(PyEncoderObject *s, PyObject *obj)
{
/* Return the JSON representation of a string */
if (s->fast_encode)
PyObject *encoded;

if (s->fast_encode) {
return py_encode_basestring_ascii(NULL, obj);
else
return PyObject_CallFunctionObjArgs(s->encoder, obj, NULL);
}
encoded = PyObject_CallFunctionObjArgs(s->encoder, obj, NULL);
if (encoded != NULL &&
#if PY_MAJOR_VERSION < 3
!JSON_ASCII_Check(unicode) &&
#endif /* PY_MAJOR_VERSION < 3 */
!PyUnicode_Check(encoded))
{
PyErr_Format(PyExc_TypeError,
"encoder() must return a string, not %.80s",
Py_TYPE(encoded)->tp_name);
Py_DECREF(encoded);
return NULL;
}
return encoded;
}

static int
Expand Down
24 changes: 24 additions & 0 deletions simplejson/tests/test_speedups.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,30 @@ def test_make_encoder(self):
None
)

@skip_if_speedups_missing
def test_bad_str_encoder(self):
# Issue #31505: There shouldn't be an assertion failure in case
# c_make_encoder() receives a bad encoder() argument.
import decimal
def bad_encoder1(*args):
return None
enc = encoder.c_make_encoder(
None, lambda obj: str(obj),
bad_encoder1, None, ': ', ', ',
False, False, False, {}, False, False, False,
None, None, 'utf-8', False, False, decimal.Decimal, False)
self.assertRaises(TypeError, enc, 'spam', 4)
self.assertRaises(TypeError, enc, {'spam': 42}, 4)

def bad_encoder2(*args):
1/0
enc = encoder.c_make_encoder(
None, lambda obj: str(obj),
bad_encoder2, None, ': ', ', ',
False, False, False, {}, False, False, False,
None, None, 'utf-8', False, False, decimal.Decimal, False)
self.assertRaises(ZeroDivisionError, enc, 'spam', 4)

@skip_if_speedups_missing
def test_bad_bool_args(self):
def test(name):
Expand Down

0 comments on commit 529268f

Please sign in to comment.