From 1349f72c93d3bddff281491eb3c44f53ecee6097 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 27 Dec 2017 15:55:15 -0500 Subject: [PATCH 1/4] bpo-32630: Use contextvars in the decimal module --- Lib/_pydecimal.py | 20 ++- .../2018-01-23-01-57-36.bpo-32630.6KRHBs.rst | 1 + Modules/_decimal/_decimal.c | 121 +++++------------- 3 files changed, 44 insertions(+), 98 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst diff --git a/Lib/_pydecimal.py b/Lib/_pydecimal.py index a1662bbd671360..359690003fe160 100644 --- a/Lib/_pydecimal.py +++ b/Lib/_pydecimal.py @@ -433,13 +433,11 @@ class FloatOperation(DecimalException, TypeError): # The getcontext() and setcontext() function manage access to a thread-local # current context. -import threading +import contextvars -local = threading.local() -if hasattr(local, '__decimal_context__'): - del local.__decimal_context__ +_current_context_var = contextvars.ContextVar('decimal_context') -def getcontext(_local=local): +def getcontext(): """Returns this thread's context. If this thread does not yet have a context, returns @@ -447,20 +445,20 @@ def getcontext(_local=local): New contexts are copies of DefaultContext. """ try: - return _local.__decimal_context__ - except AttributeError: + return _current_context_var.get() + except LookupError: context = Context() - _local.__decimal_context__ = context + _current_context_var.set(context) return context -def setcontext(context, _local=local): +def setcontext(context): """Set this thread's context to context.""" if context in (DefaultContext, BasicContext, ExtendedContext): context = context.copy() context.clear_flags() - _local.__decimal_context__ = context + _current_context_var.set(context) -del threading, local # Don't contaminate the namespace +del contextvars # Don't contaminate the namespace def localcontext(ctx=None): """Return a context manager for a copy of the supplied context diff --git a/Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst b/Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst new file mode 100644 index 00000000000000..1bbcbb173eb9ed --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-01-23-01-57-36.bpo-32630.6KRHBs.rst @@ -0,0 +1 @@ +Refactor decimal module to use contextvars to store decimal context. diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index 18fa2e4fa5e475..afe12de3ff551f 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -122,10 +122,8 @@ incr_false(void) } -/* Key for thread state dictionary */ -static PyObject *tls_context_key = NULL; /* Invariant: NULL or the most recently accessed thread local context */ -static PyDecContextObject *cached_context = NULL; +static PyContextVar *current_context_var; /* Template for creating new thread contexts, calling Context() without * arguments and initializing the module_context on first access. */ @@ -1220,10 +1218,6 @@ context_new(PyTypeObject *type, PyObject *args UNUSED, PyObject *kwds UNUSED) static void context_dealloc(PyDecContextObject *self) { - if (self == cached_context) { - cached_context = NULL; - } - Py_XDECREF(self->traps); Py_XDECREF(self->flags); Py_TYPE(self)->tp_free(self); @@ -1498,69 +1492,42 @@ static PyGetSetDef context_getsets [] = * operation. */ -/* Get the context from the thread state dictionary. */ static PyObject * -current_context_from_dict(void) +init_current_context(void) { - PyObject *dict; - PyObject *tl_context; - PyThreadState *tstate; - - dict = PyThreadState_GetDict(); - if (dict == NULL) { - PyErr_SetString(PyExc_RuntimeError, - "cannot get thread state"); + PyObject *tl_context = context_copy(default_context_template, NULL); + if (tl_context == NULL) { return NULL; } + CTX(tl_context)->status = 0; - tl_context = PyDict_GetItemWithError(dict, tls_context_key); - if (tl_context != NULL) { - /* We already have a thread local context. */ - CONTEXT_CHECK(tl_context); - } - else { - if (PyErr_Occurred()) { - return NULL; - } - - /* Set up a new thread local context. */ - tl_context = context_copy(default_context_template, NULL); - if (tl_context == NULL) { - return NULL; - } - CTX(tl_context)->status = 0; - - if (PyDict_SetItem(dict, tls_context_key, tl_context) < 0) { - Py_DECREF(tl_context); - return NULL; - } + PyContextToken *tok = PyContextVar_Set(current_context_var, tl_context); + if (tok == NULL) { Py_DECREF(tl_context); + return NULL; } - - /* Cache the context of the current thread, assuming that it - * will be accessed several times before a thread switch. */ - tstate = PyThreadState_GET(); - if (tstate) { - cached_context = (PyDecContextObject *)tl_context; - cached_context->tstate = tstate; - } + Py_DECREF(tok); /* Borrowed reference with refcount==1 */ return tl_context; } -/* Return borrowed reference to thread local context. */ -static PyObject * + +/* Get the context from the thread state dictionary. */ +static inline PyObject * current_context(void) { - PyThreadState *tstate; + PyObject *tl_context; + if (PyContextVar_Get(current_context_var, NULL, &tl_context) < 0) { + return NULL; + } - tstate = PyThreadState_GET(); - if (cached_context && cached_context->tstate == tstate) { - return (PyObject *)cached_context; + if (tl_context != NULL) { + /* We already have a thread local context. */ + return tl_context; } - return current_context_from_dict(); + return init_current_context(); } /* ctxobj := borrowed reference to the current context */ @@ -1568,47 +1535,22 @@ current_context(void) ctxobj = current_context(); \ if (ctxobj == NULL) { \ return NULL; \ - } - -/* ctx := pointer to the mpd_context_t struct of the current context */ -#define CURRENT_CONTEXT_ADDR(ctx) { \ - PyObject *_c_t_x_o_b_j = current_context(); \ - if (_c_t_x_o_b_j == NULL) { \ - return NULL; \ - } \ - ctx = CTX(_c_t_x_o_b_j); \ -} + } \ + Py_DECREF(ctxobj); /* Return a new reference to the current context */ static PyObject * PyDec_GetCurrentContext(PyObject *self UNUSED, PyObject *args UNUSED) { - PyObject *context; - - context = current_context(); - if (context == NULL) { - return NULL; - } - - Py_INCREF(context); - return context; + return current_context(); } /* Set the thread local context to a new context, decrement old reference */ static PyObject * PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v) { - PyObject *dict; - CONTEXT_CHECK(v); - dict = PyThreadState_GetDict(); - if (dict == NULL) { - PyErr_SetString(PyExc_RuntimeError, - "cannot get thread state"); - return NULL; - } - /* If the new context is one of the templates, make a copy. * This is the current behavior of decimal.py. */ if (v == default_context_template || @@ -1624,13 +1566,13 @@ PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v) Py_INCREF(v); } - cached_context = NULL; - if (PyDict_SetItem(dict, tls_context_key, v) < 0) { - Py_DECREF(v); + PyContextToken *tok = PyContextVar_Set(current_context_var, v); + Py_DECREF(v); + if (tok == NULL) { return NULL; } + Py_DECREF(tok); - Py_DECREF(v); Py_RETURN_NONE; } @@ -4458,6 +4400,7 @@ _dec_hash(PyDecObject *v) if (context == NULL) { return -1; } + Py_DECREF(context); if (mpd_isspecial(MPD(v))) { if (mpd_issnan(MPD(v))) { @@ -5599,6 +5542,11 @@ PyInit__decimal(void) mpd_free = PyMem_Free; mpd_setminalloc(_Py_DEC_MINALLOC); + /* Init context variable */ + current_context_var = PyContextVar_New("decimal_context", NULL); + if (current_context_var == NULL) { + goto error; + } /* Init external C-API functions */ _py_long_multiply = PyLong_Type.tp_as_number->nb_multiply; @@ -5768,7 +5716,6 @@ PyInit__decimal(void) CHECK_INT(PyModule_AddObject(m, "DefaultContext", default_context_template)); - ASSIGN_PTR(tls_context_key, PyUnicode_FromString("___DECIMAL_CTX__")); Py_INCREF(Py_True); CHECK_INT(PyModule_AddObject(m, "HAVE_THREADS", Py_True)); @@ -5827,9 +5774,9 @@ PyInit__decimal(void) Py_CLEAR(SignalTuple); /* GCOV_NOT_REACHED */ Py_CLEAR(DecimalTuple); /* GCOV_NOT_REACHED */ Py_CLEAR(default_context_template); /* GCOV_NOT_REACHED */ - Py_CLEAR(tls_context_key); /* GCOV_NOT_REACHED */ Py_CLEAR(basic_context_template); /* GCOV_NOT_REACHED */ Py_CLEAR(extended_context_template); /* GCOV_NOT_REACHED */ + Py_CLEAR(current_context_var); /* GCOV_NOT_REACHED */ Py_CLEAR(m); /* GCOV_NOT_REACHED */ return NULL; /* GCOV_NOT_REACHED */ From bce636ba6be994157e3b7ab90362765c3635f45d Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 23 Jan 2018 11:37:48 -0500 Subject: [PATCH 2/4] Drop outdated comments --- Modules/_decimal/_decimal.c | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index afe12de3ff551f..c96d20f945b5df 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -122,7 +122,6 @@ incr_false(void) } -/* Invariant: NULL or the most recently accessed thread local context */ static PyContextVar *current_context_var; /* Template for creating new thread contexts, calling Context() without @@ -1512,8 +1511,6 @@ init_current_context(void) return tl_context; } - -/* Get the context from the thread state dictionary. */ static inline PyObject * current_context(void) { @@ -1523,7 +1520,6 @@ current_context(void) } if (tl_context != NULL) { - /* We already have a thread local context. */ return tl_context; } From cb668b7660f05e2c08a169eabf0eeeccc8731296 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Fri, 26 Jan 2018 17:46:46 -0500 Subject: [PATCH 3/4] Add an asyncio test to make sure decimal context works with tasks --- Lib/test/test_asyncio/test_context.py | 29 +++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 Lib/test/test_asyncio/test_context.py diff --git a/Lib/test/test_asyncio/test_context.py b/Lib/test/test_asyncio/test_context.py new file mode 100644 index 00000000000000..6abddd9f2515e1 --- /dev/null +++ b/Lib/test/test_asyncio/test_context.py @@ -0,0 +1,29 @@ +import asyncio +import decimal +import unittest + + +class DecimalContextTest(unittest.TestCase): + + def test_asyncio_task_decimal_context(self): + async def fractions(t, precision, x, y): + with decimal.localcontext() as ctx: + ctx.prec = precision + a = decimal.Decimal(x) / decimal.Decimal(y) + await asyncio.sleep(t) + b = decimal.Decimal(x) / decimal.Decimal(y ** 2) + return a, b + + async def main(): + r1, r2 = await asyncio.gather( + fractions(0.1, 3, 1, 3), fractions(0.2, 6, 1, 3)) + + return r1, r2 + + r1, r2 = asyncio.run(main()) + + self.assertEqual(str(r1[0]), '0.333') + self.assertEqual(str(r1[1]), '0.111') + + self.assertEqual(str(r2[0]), '0.333333') + self.assertEqual(str(r2[1]), '0.111111') From 74c8d8e336408cfb0f15e00c521d98891b1d41f3 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Sat, 27 Jan 2018 12:42:24 -0500 Subject: [PATCH 4/4] Remove outdated comment --- Modules/_decimal/_decimal.c | 1 - 1 file changed, 1 deletion(-) diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index c96d20f945b5df..fddb39ef652abc 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -1507,7 +1507,6 @@ init_current_context(void) } Py_DECREF(tok); - /* Borrowed reference with refcount==1 */ return tl_context; }