diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index d0ba34803c21e6..f8e51a489bd29b 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -1121,6 +1121,13 @@ def test_formatting(self): ('z>z6.1f', '-0.', 'zzz0.0'), ('x>z6.1f', '-0.', 'xxx0.0'), ('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'), # multi-byte fill char + ('\x00>z6.1f', '-0.', '\x00\x00\x000.0'), # null fill char + + # issue 114563 ('z' format on F type in cdecimal) + ('z3,.10F', '-6.24E-323', '0.0000000000'), + + # issue 91060 ('#' format in cdecimal) + ('#', '0', '0.'), # issue 6850 ('a=-7.0', '0.12345', 'aaaa0.1'), @@ -5712,6 +5719,21 @@ def test_c_signaldict_segfault(self): with self.assertRaisesRegex(ValueError, err_msg): sd.copy() + def test_format_fallback_capitals(self): + # Fallback to _pydecimal formatting (triggered by `#` format which + # is unsupported by mpdecimal) should honor the current context. + x = C.Decimal('6.09e+23') + self.assertEqual(format(x, '#'), '6.09E+23') + with C.localcontext(capitals=0): + self.assertEqual(format(x, '#'), '6.09e+23') + + def test_format_fallback_rounding(self): + y = C.Decimal('6.09') + self.assertEqual(format(y, '#.1f'), '6.1') + with C.localcontext(rounding=C.ROUND_DOWN): + self.assertEqual(format(y, '#.1f'), '6.0') + + @requires_docstrings @requires_cdecimal class SignatureTest(unittest.TestCase): diff --git a/Misc/NEWS.d/next/Library/2024-02-11-20-23-36.gh-issue-114563.RzxNYT.rst b/Misc/NEWS.d/next/Library/2024-02-11-20-23-36.gh-issue-114563.RzxNYT.rst new file mode 100644 index 00000000000000..013b6db8e6dbd7 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-02-11-20-23-36.gh-issue-114563.RzxNYT.rst @@ -0,0 +1,4 @@ +Fix several :func:`format()` bugs when using the C implementation of :class:`~decimal.Decimal`: +* memory leak in some rare cases when using the ``z`` format option (coerce negative 0) +* incorrect output when applying the ``z`` format option to type ``F`` (fixed-point with capital ``NAN`` / ``INF``) +* incorrect output when applying the ``#`` format option (alternate form) diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index a97490d48c90ac..5205abcfd65a53 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -144,6 +144,8 @@ static PyObject *default_context_template = NULL; static PyObject *basic_context_template = NULL; static PyObject *extended_context_template = NULL; +/* Invariant: NULL or pointer to _pydecimal.Decimal */ +static PyObject *PyDecimal = NULL; /* Error codes for functions that return signals or conditions */ #define DEC_INVALID_SIGNALS (MPD_Max_status+1U) @@ -3245,56 +3247,6 @@ dotsep_as_utf8(const char *s) return utf8; } -/* copy of libmpdec _mpd_round() */ -static void -_mpd_round(mpd_t *result, const mpd_t *a, mpd_ssize_t prec, - const mpd_context_t *ctx, uint32_t *status) -{ - mpd_ssize_t exp = a->exp + a->digits - prec; - - if (prec <= 0) { - mpd_seterror(result, MPD_Invalid_operation, status); - return; - } - if (mpd_isspecial(a) || mpd_iszero(a)) { - mpd_qcopy(result, a, status); - return; - } - - mpd_qrescale_fmt(result, a, exp, ctx, status); - if (result->digits > prec) { - mpd_qrescale_fmt(result, result, exp+1, ctx, status); - } -} - -/* Locate negative zero "z" option within a UTF-8 format spec string. - * Returns pointer to "z", else NULL. - * The portion of the spec we're working with is [[fill]align][sign][z] */ -static const char * -format_spec_z_search(char const *fmt, Py_ssize_t size) { - char const *pos = fmt; - char const *fmt_end = fmt + size; - /* skip over [[fill]align] (fill may be multi-byte character) */ - pos += 1; - while (pos < fmt_end && *pos & 0x80) { - pos += 1; - } - if (pos < fmt_end && strchr("<>=^", *pos) != NULL) { - pos += 1; - } else { - /* fill not present-- skip over [align] */ - pos = fmt; - if (pos < fmt_end && strchr("<>=^", *pos) != NULL) { - pos += 1; - } - } - /* skip over [sign] */ - if (pos < fmt_end && strchr("+- ", *pos) != NULL) { - pos += 1; - } - return pos < fmt_end && *pos == 'z' ? pos : NULL; -} - static int dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const char **valuestr) { @@ -3320,6 +3272,48 @@ dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const return 0; } +/* + * Fallback _pydecimal formatting for new format specifiers that mpdecimal does + * not yet support. As documented, libmpdec follows the PEP-3101 format language: + * https://www.bytereef.org/mpdecimal/doc/libmpdec/assign-convert.html#to-string + */ +static PyObject * +pydec_format(PyObject *dec, PyObject *context, PyObject *fmt) +{ + PyObject *result; + PyObject *pydec; + PyObject *u; + + if (PyDecimal == NULL) { + PyDecimal = _PyImport_GetModuleAttrString("_pydecimal", "Decimal"); + if (PyDecimal == NULL) { + return NULL; + } + } + + u = dec_str(dec); + if (u == NULL) { + return NULL; + } + + pydec = PyObject_CallOneArg(PyDecimal, u); + Py_DECREF(u); + if (pydec == NULL) { + return NULL; + } + + result = PyObject_CallMethod(pydec, "__format__", "(OO)", fmt, context); + Py_DECREF(pydec); + + if (result == NULL && PyErr_ExceptionMatches(PyExc_ValueError)) { + /* Do not confuse users with the _pydecimal exception */ + PyErr_Clear(); + PyErr_SetString(PyExc_ValueError, "invalid format string"); + } + + return result; +} + /* Formatted representation of a PyDecObject. */ static PyObject * dec_format(PyObject *dec, PyObject *args) @@ -3332,16 +3326,11 @@ dec_format(PyObject *dec, PyObject *args) PyObject *fmtarg; PyObject *context; mpd_spec_t spec; - char const *fmt; - char *fmt_copy = NULL; + char *fmt; char *decstring = NULL; uint32_t status = 0; int replace_fillchar = 0; - int no_neg_0 = 0; Py_ssize_t size; - mpd_t *mpd = MPD(dec); - mpd_uint_t dt[MPD_MINALLOC_MAX]; - mpd_t tmp = {MPD_STATIC|MPD_STATIC_DATA,0,0,0,MPD_MINALLOC_MAX,dt}; CURRENT_CONTEXT(context); @@ -3350,39 +3339,20 @@ dec_format(PyObject *dec, PyObject *args) } if (PyUnicode_Check(fmtarg)) { - fmt = PyUnicode_AsUTF8AndSize(fmtarg, &size); + fmt = (char *)PyUnicode_AsUTF8AndSize(fmtarg, &size); if (fmt == NULL) { return NULL; } - /* NOTE: If https://github.com/python/cpython/pull/29438 lands, the - * format string manipulation below can be eliminated by enhancing - * the forked mpd_parse_fmt_str(). */ + if (size > 0 && fmt[0] == '\0') { /* NUL fill character: must be replaced with a valid UTF-8 char before calling mpd_parse_fmt_str(). */ replace_fillchar = 1; - fmt = fmt_copy = dec_strdup(fmt, size); - if (fmt_copy == NULL) { + fmt = dec_strdup(fmt, size); + if (fmt == NULL) { return NULL; } - fmt_copy[0] = '_'; - } - /* Strip 'z' option, which isn't understood by mpd_parse_fmt_str(). - * NOTE: fmt is always null terminated by PyUnicode_AsUTF8AndSize() */ - char const *z_position = format_spec_z_search(fmt, size); - if (z_position != NULL) { - no_neg_0 = 1; - size_t z_index = z_position - fmt; - if (fmt_copy == NULL) { - fmt = fmt_copy = dec_strdup(fmt, size); - if (fmt_copy == NULL) { - return NULL; - } - } - /* Shift characters (including null terminator) left, - overwriting the 'z' option. */ - memmove(fmt_copy + z_index, fmt_copy + z_index + 1, size - z_index); - size -= 1; + fmt[0] = '_'; } } else { @@ -3392,10 +3362,13 @@ dec_format(PyObject *dec, PyObject *args) } if (!mpd_parse_fmt_str(&spec, fmt, CtxCaps(context))) { - PyErr_SetString(PyExc_ValueError, - "invalid format string"); - goto finish; + if (replace_fillchar) { + PyMem_Free(fmt); + } + + return pydec_format(dec, context, fmtarg); } + if (replace_fillchar) { /* In order to avoid clobbering parts of UTF-8 thousands separators or decimal points when the substitution is reversed later, the actual @@ -3448,45 +3421,8 @@ dec_format(PyObject *dec, PyObject *args) } } - if (no_neg_0 && mpd_isnegative(mpd) && !mpd_isspecial(mpd)) { - /* Round into a temporary (carefully mirroring the rounding - of mpd_qformat_spec()), and check if the result is negative zero. - If so, clear the sign and format the resulting positive zero. */ - mpd_ssize_t prec; - mpd_qcopy(&tmp, mpd, &status); - if (spec.prec >= 0) { - switch (spec.type) { - case 'f': - mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status); - break; - case '%': - tmp.exp += 2; - mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status); - break; - case 'g': - prec = (spec.prec == 0) ? 1 : spec.prec; - if (tmp.digits > prec) { - _mpd_round(&tmp, &tmp, prec, CTX(context), &status); - } - break; - case 'e': - if (!mpd_iszero(&tmp)) { - _mpd_round(&tmp, &tmp, spec.prec+1, CTX(context), &status); - } - break; - } - } - if (status & MPD_Errors) { - PyErr_SetString(PyExc_ValueError, "unexpected error when rounding"); - goto finish; - } - if (mpd_iszero(&tmp)) { - mpd_set_positive(&tmp); - mpd = &tmp; - } - } - decstring = mpd_qformat_spec(mpd, &spec, CTX(context), &status); + decstring = mpd_qformat_spec(MPD(dec), &spec, CTX(context), &status); if (decstring == NULL) { if (status & MPD_Malloc_error) { PyErr_NoMemory(); @@ -3509,7 +3445,7 @@ dec_format(PyObject *dec, PyObject *args) Py_XDECREF(grouping); Py_XDECREF(sep); Py_XDECREF(dot); - if (fmt_copy) PyMem_Free(fmt_copy); + if (replace_fillchar) PyMem_Free(fmt); if (decstring) mpd_free(decstring); return result; } @@ -5944,6 +5880,8 @@ PyInit__decimal(void) /* Create the module */ ASSIGN_PTR(m, PyModule_Create(&_decimal_module)); + /* For format specifiers not yet supported by libmpdec */ + PyDecimal = NULL; /* Add types to the module */ CHECK_INT(PyModule_AddObjectRef(m, "Decimal", (PyObject *)&PyDec_Type)); diff --git a/Tools/c-analyzer/cpython/globals-to-fix.tsv b/Tools/c-analyzer/cpython/globals-to-fix.tsv index 8fe861576228b3..9d9cc7acc68417 100644 --- a/Tools/c-analyzer/cpython/globals-to-fix.tsv +++ b/Tools/c-analyzer/cpython/globals-to-fix.tsv @@ -1273,6 +1273,7 @@ Modules/_decimal/_decimal.c - basic_context_template - Modules/_decimal/_decimal.c - current_context_var - Modules/_decimal/_decimal.c - default_context_template - Modules/_decimal/_decimal.c - extended_context_template - +Modules/_decimal/_decimal.c - PyDecimal - Modules/_decimal/_decimal.c - round_map - Modules/_decimal/_decimal.c - Rational - Modules/_decimal/_decimal.c - SignalTuple -