Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3.11] gh-114563: C decimal falls back to pydecimal for unsupported format strings (GH-114879) #115384

Merged
merged 1 commit into from Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 22 additions & 0 deletions Lib/test/test_decimal.py
Expand Up @@ -1121,6 +1121,13 @@ def test_formatting(self):
('z>z6.1f', '-0.', 'zzz0.0'),
('x>z6.1f', '-0.', 'xxx0.0'),
('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'), # multi-byte fill char
('\x00>z6.1f', '-0.', '\x00\x00\x000.0'), # null fill char

# issue 114563 ('z' format on F type in cdecimal)
('z3,.10F', '-6.24E-323', '0.0000000000'),

# issue 91060 ('#' format in cdecimal)
('#', '0', '0.'),

# issue 6850
('a=-7.0', '0.12345', 'aaaa0.1'),
Expand Down Expand Up @@ -5712,6 +5719,21 @@ def test_c_signaldict_segfault(self):
with self.assertRaisesRegex(ValueError, err_msg):
sd.copy()

def test_format_fallback_capitals(self):
# Fallback to _pydecimal formatting (triggered by `#` format which
# is unsupported by mpdecimal) should honor the current context.
x = C.Decimal('6.09e+23')
self.assertEqual(format(x, '#'), '6.09E+23')
with C.localcontext(capitals=0):
self.assertEqual(format(x, '#'), '6.09e+23')

def test_format_fallback_rounding(self):
y = C.Decimal('6.09')
self.assertEqual(format(y, '#.1f'), '6.1')
with C.localcontext(rounding=C.ROUND_DOWN):
self.assertEqual(format(y, '#.1f'), '6.0')


@requires_docstrings
@requires_cdecimal
class SignatureTest(unittest.TestCase):
Expand Down
@@ -0,0 +1,4 @@
Fix several :func:`format()` bugs when using the C implementation of :class:`~decimal.Decimal`:
* memory leak in some rare cases when using the ``z`` format option (coerce negative 0)
* incorrect output when applying the ``z`` format option to type ``F`` (fixed-point with capital ``NAN`` / ``INF``)
* incorrect output when applying the ``#`` format option (alternate form)
182 changes: 60 additions & 122 deletions Modules/_decimal/_decimal.c
Expand Up @@ -144,6 +144,8 @@ static PyObject *default_context_template = NULL;
static PyObject *basic_context_template = NULL;
static PyObject *extended_context_template = NULL;

/* Invariant: NULL or pointer to _pydecimal.Decimal */
static PyObject *PyDecimal = NULL;

/* Error codes for functions that return signals or conditions */
#define DEC_INVALID_SIGNALS (MPD_Max_status+1U)
Expand Down Expand Up @@ -3245,56 +3247,6 @@ dotsep_as_utf8(const char *s)
return utf8;
}

/* copy of libmpdec _mpd_round() */
static void
_mpd_round(mpd_t *result, const mpd_t *a, mpd_ssize_t prec,
const mpd_context_t *ctx, uint32_t *status)
{
mpd_ssize_t exp = a->exp + a->digits - prec;

if (prec <= 0) {
mpd_seterror(result, MPD_Invalid_operation, status);
return;
}
if (mpd_isspecial(a) || mpd_iszero(a)) {
mpd_qcopy(result, a, status);
return;
}

mpd_qrescale_fmt(result, a, exp, ctx, status);
if (result->digits > prec) {
mpd_qrescale_fmt(result, result, exp+1, ctx, status);
}
}

/* Locate negative zero "z" option within a UTF-8 format spec string.
* Returns pointer to "z", else NULL.
* The portion of the spec we're working with is [[fill]align][sign][z] */
static const char *
format_spec_z_search(char const *fmt, Py_ssize_t size) {
char const *pos = fmt;
char const *fmt_end = fmt + size;
/* skip over [[fill]align] (fill may be multi-byte character) */
pos += 1;
while (pos < fmt_end && *pos & 0x80) {
pos += 1;
}
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
pos += 1;
} else {
/* fill not present-- skip over [align] */
pos = fmt;
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
pos += 1;
}
}
/* skip over [sign] */
if (pos < fmt_end && strchr("+- ", *pos) != NULL) {
pos += 1;
}
return pos < fmt_end && *pos == 'z' ? pos : NULL;
}

static int
dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const char **valuestr)
{
Expand All @@ -3320,6 +3272,48 @@ dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const
return 0;
}

/*
* Fallback _pydecimal formatting for new format specifiers that mpdecimal does
* not yet support. As documented, libmpdec follows the PEP-3101 format language:
* https://www.bytereef.org/mpdecimal/doc/libmpdec/assign-convert.html#to-string
*/
static PyObject *
pydec_format(PyObject *dec, PyObject *context, PyObject *fmt)
{
PyObject *result;
PyObject *pydec;
PyObject *u;

if (PyDecimal == NULL) {
PyDecimal = _PyImport_GetModuleAttrString("_pydecimal", "Decimal");
if (PyDecimal == NULL) {
return NULL;
}
}

u = dec_str(dec);
if (u == NULL) {
return NULL;
}

pydec = PyObject_CallOneArg(PyDecimal, u);
Py_DECREF(u);
if (pydec == NULL) {
return NULL;
}

result = PyObject_CallMethod(pydec, "__format__", "(OO)", fmt, context);
Py_DECREF(pydec);

if (result == NULL && PyErr_ExceptionMatches(PyExc_ValueError)) {
/* Do not confuse users with the _pydecimal exception */
PyErr_Clear();
PyErr_SetString(PyExc_ValueError, "invalid format string");
}

return result;
}

/* Formatted representation of a PyDecObject. */
static PyObject *
dec_format(PyObject *dec, PyObject *args)
Expand All @@ -3332,16 +3326,11 @@ dec_format(PyObject *dec, PyObject *args)
PyObject *fmtarg;
PyObject *context;
mpd_spec_t spec;
char const *fmt;
char *fmt_copy = NULL;
char *fmt;
char *decstring = NULL;
uint32_t status = 0;
int replace_fillchar = 0;
int no_neg_0 = 0;
Py_ssize_t size;
mpd_t *mpd = MPD(dec);
mpd_uint_t dt[MPD_MINALLOC_MAX];
mpd_t tmp = {MPD_STATIC|MPD_STATIC_DATA,0,0,0,MPD_MINALLOC_MAX,dt};


CURRENT_CONTEXT(context);
Expand All @@ -3350,39 +3339,20 @@ dec_format(PyObject *dec, PyObject *args)
}

if (PyUnicode_Check(fmtarg)) {
fmt = PyUnicode_AsUTF8AndSize(fmtarg, &size);
fmt = (char *)PyUnicode_AsUTF8AndSize(fmtarg, &size);
if (fmt == NULL) {
return NULL;
}
/* NOTE: If https://github.com/python/cpython/pull/29438 lands, the
* format string manipulation below can be eliminated by enhancing
* the forked mpd_parse_fmt_str(). */

if (size > 0 && fmt[0] == '\0') {
/* NUL fill character: must be replaced with a valid UTF-8 char
before calling mpd_parse_fmt_str(). */
replace_fillchar = 1;
fmt = fmt_copy = dec_strdup(fmt, size);
if (fmt_copy == NULL) {
fmt = dec_strdup(fmt, size);
if (fmt == NULL) {
return NULL;
}
fmt_copy[0] = '_';
}
/* Strip 'z' option, which isn't understood by mpd_parse_fmt_str().
* NOTE: fmt is always null terminated by PyUnicode_AsUTF8AndSize() */
char const *z_position = format_spec_z_search(fmt, size);
if (z_position != NULL) {
no_neg_0 = 1;
size_t z_index = z_position - fmt;
if (fmt_copy == NULL) {
fmt = fmt_copy = dec_strdup(fmt, size);
if (fmt_copy == NULL) {
return NULL;
}
}
/* Shift characters (including null terminator) left,
overwriting the 'z' option. */
memmove(fmt_copy + z_index, fmt_copy + z_index + 1, size - z_index);
size -= 1;
fmt[0] = '_';
}
}
else {
Expand All @@ -3392,10 +3362,13 @@ dec_format(PyObject *dec, PyObject *args)
}

if (!mpd_parse_fmt_str(&spec, fmt, CtxCaps(context))) {
PyErr_SetString(PyExc_ValueError,
"invalid format string");
goto finish;
if (replace_fillchar) {
PyMem_Free(fmt);
}

return pydec_format(dec, context, fmtarg);
}

if (replace_fillchar) {
/* In order to avoid clobbering parts of UTF-8 thousands separators or
decimal points when the substitution is reversed later, the actual
Expand Down Expand Up @@ -3448,45 +3421,8 @@ dec_format(PyObject *dec, PyObject *args)
}
}

if (no_neg_0 && mpd_isnegative(mpd) && !mpd_isspecial(mpd)) {
/* Round into a temporary (carefully mirroring the rounding
of mpd_qformat_spec()), and check if the result is negative zero.
If so, clear the sign and format the resulting positive zero. */
mpd_ssize_t prec;
mpd_qcopy(&tmp, mpd, &status);
if (spec.prec >= 0) {
switch (spec.type) {
case 'f':
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
break;
case '%':
tmp.exp += 2;
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
break;
case 'g':
prec = (spec.prec == 0) ? 1 : spec.prec;
if (tmp.digits > prec) {
_mpd_round(&tmp, &tmp, prec, CTX(context), &status);
}
break;
case 'e':
if (!mpd_iszero(&tmp)) {
_mpd_round(&tmp, &tmp, spec.prec+1, CTX(context), &status);
}
break;
}
}
if (status & MPD_Errors) {
PyErr_SetString(PyExc_ValueError, "unexpected error when rounding");
goto finish;
}
if (mpd_iszero(&tmp)) {
mpd_set_positive(&tmp);
mpd = &tmp;
}
}

decstring = mpd_qformat_spec(mpd, &spec, CTX(context), &status);
decstring = mpd_qformat_spec(MPD(dec), &spec, CTX(context), &status);
if (decstring == NULL) {
if (status & MPD_Malloc_error) {
PyErr_NoMemory();
Expand All @@ -3509,7 +3445,7 @@ dec_format(PyObject *dec, PyObject *args)
Py_XDECREF(grouping);
Py_XDECREF(sep);
Py_XDECREF(dot);
if (fmt_copy) PyMem_Free(fmt_copy);
if (replace_fillchar) PyMem_Free(fmt);
if (decstring) mpd_free(decstring);
return result;
}
Expand Down Expand Up @@ -5944,6 +5880,8 @@ PyInit__decimal(void)
/* Create the module */
ASSIGN_PTR(m, PyModule_Create(&_decimal_module));

/* For format specifiers not yet supported by libmpdec */
PyDecimal = NULL;

/* Add types to the module */
CHECK_INT(PyModule_AddObjectRef(m, "Decimal", (PyObject *)&PyDec_Type));
Expand Down
1 change: 1 addition & 0 deletions Tools/c-analyzer/cpython/globals-to-fix.tsv
Expand Up @@ -1273,6 +1273,7 @@ Modules/_decimal/_decimal.c - basic_context_template -
Modules/_decimal/_decimal.c - current_context_var -
Modules/_decimal/_decimal.c - default_context_template -
Modules/_decimal/_decimal.c - extended_context_template -
Modules/_decimal/_decimal.c - PyDecimal -
Modules/_decimal/_decimal.c - round_map -
Modules/_decimal/_decimal.c - Rational -
Modules/_decimal/_decimal.c - SignalTuple -
Expand Down