Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions Lib/_pydecimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,34 +433,32 @@ class FloatOperation(DecimalException, TypeError):
# The getcontext() and setcontext() function manage access to a thread-local
# current context.

import threading
import contextvars

local = threading.local()
if hasattr(local, '__decimal_context__'):
del local.__decimal_context__
_current_context_var = contextvars.ContextVar('decimal_context')

def getcontext(_local=local):
def getcontext():
"""Returns this thread's context.

If this thread does not yet have a context, returns
a new context and sets this thread's context.
New contexts are copies of DefaultContext.
"""
try:
return _local.__decimal_context__
except AttributeError:
return _current_context_var.get()
except LookupError:
context = Context()
_local.__decimal_context__ = context
_current_context_var.set(context)
return context

def setcontext(context, _local=local):
def setcontext(context):
"""Set this thread's context to context."""
if context in (DefaultContext, BasicContext, ExtendedContext):
context = context.copy()
context.clear_flags()
_local.__decimal_context__ = context
_current_context_var.set(context)

del threading, local # Don't contaminate the namespace
del contextvars # Don't contaminate the namespace

def localcontext(ctx=None):
"""Return a context manager for a copy of the supplied context
Expand Down
29 changes: 29 additions & 0 deletions Lib/test/test_asyncio/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import asyncio
import decimal
import unittest


class DecimalContextTest(unittest.TestCase):

def test_asyncio_task_decimal_context(self):
async def fractions(t, precision, x, y):
with decimal.localcontext() as ctx:
ctx.prec = precision
a = decimal.Decimal(x) / decimal.Decimal(y)
await asyncio.sleep(t)
b = decimal.Decimal(x) / decimal.Decimal(y ** 2)
return a, b

async def main():
r1, r2 = await asyncio.gather(
fractions(0.1, 3, 1, 3), fractions(0.2, 6, 1, 3))

return r1, r2

r1, r2 = asyncio.run(main())

self.assertEqual(str(r1[0]), '0.333')
self.assertEqual(str(r1[1]), '0.111')

self.assertEqual(str(r2[0]), '0.333333')
self.assertEqual(str(r2[1]), '0.111111')
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor decimal module to use contextvars to store decimal context.
120 changes: 31 additions & 89 deletions Modules/_decimal/_decimal.c
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,7 @@ incr_false(void)
}


/* Key for thread state dictionary */
static PyObject *tls_context_key = NULL;
/* Invariant: NULL or the most recently accessed thread local context */
static PyDecContextObject *cached_context = NULL;
static PyContextVar *current_context_var;

/* Template for creating new thread contexts, calling Context() without
* arguments and initializing the module_context on first access. */
Expand Down Expand Up @@ -1220,10 +1217,6 @@ context_new(PyTypeObject *type, PyObject *args UNUSED, PyObject *kwds UNUSED)
static void
context_dealloc(PyDecContextObject *self)
{
if (self == cached_context) {
cached_context = NULL;
}

Py_XDECREF(self->traps);
Py_XDECREF(self->flags);
Py_TYPE(self)->tp_free(self);
Expand Down Expand Up @@ -1498,117 +1491,61 @@ static PyGetSetDef context_getsets [] =
* operation.
*/

/* Get the context from the thread state dictionary. */
static PyObject *
current_context_from_dict(void)
init_current_context(void)
{
PyObject *dict;
PyObject *tl_context;
PyThreadState *tstate;

dict = PyThreadState_GetDict();
if (dict == NULL) {
PyErr_SetString(PyExc_RuntimeError,
"cannot get thread state");
PyObject *tl_context = context_copy(default_context_template, NULL);
if (tl_context == NULL) {
return NULL;
}
CTX(tl_context)->status = 0;

tl_context = PyDict_GetItemWithError(dict, tls_context_key);
if (tl_context != NULL) {
/* We already have a thread local context. */
CONTEXT_CHECK(tl_context);
}
else {
if (PyErr_Occurred()) {
return NULL;
}

/* Set up a new thread local context. */
tl_context = context_copy(default_context_template, NULL);
if (tl_context == NULL) {
return NULL;
}
CTX(tl_context)->status = 0;

if (PyDict_SetItem(dict, tls_context_key, tl_context) < 0) {
Py_DECREF(tl_context);
return NULL;
}
PyContextToken *tok = PyContextVar_Set(current_context_var, tl_context);
if (tok == NULL) {
Py_DECREF(tl_context);
return NULL;
}
Py_DECREF(tok);

/* Cache the context of the current thread, assuming that it
* will be accessed several times before a thread switch. */
tstate = PyThreadState_GET();
if (tstate) {
cached_context = (PyDecContextObject *)tl_context;
cached_context->tstate = tstate;
}

/* Borrowed reference with refcount==1 */
return tl_context;
}

/* Return borrowed reference to thread local context. */
static PyObject *
static inline PyObject *
current_context(void)
{
PyThreadState *tstate;
PyObject *tl_context;
if (PyContextVar_Get(current_context_var, NULL, &tl_context) < 0) {
return NULL;
}

tstate = PyThreadState_GET();
if (cached_context && cached_context->tstate == tstate) {
return (PyObject *)cached_context;
if (tl_context != NULL) {
return tl_context;
}

return current_context_from_dict();
return init_current_context();
}

/* ctxobj := borrowed reference to the current context */
#define CURRENT_CONTEXT(ctxobj) \
ctxobj = current_context(); \
if (ctxobj == NULL) { \
return NULL; \
}

/* ctx := pointer to the mpd_context_t struct of the current context */
#define CURRENT_CONTEXT_ADDR(ctx) { \
PyObject *_c_t_x_o_b_j = current_context(); \
if (_c_t_x_o_b_j == NULL) { \
return NULL; \
} \
ctx = CTX(_c_t_x_o_b_j); \
}
} \
Py_DECREF(ctxobj);

/* Return a new reference to the current context */
static PyObject *
PyDec_GetCurrentContext(PyObject *self UNUSED, PyObject *args UNUSED)
{
PyObject *context;

context = current_context();
if (context == NULL) {
return NULL;
}

Py_INCREF(context);
return context;
return current_context();
}

/* Set the thread local context to a new context, decrement old reference */
static PyObject *
PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
{
PyObject *dict;

CONTEXT_CHECK(v);

dict = PyThreadState_GetDict();
if (dict == NULL) {
PyErr_SetString(PyExc_RuntimeError,
"cannot get thread state");
return NULL;
}

/* If the new context is one of the templates, make a copy.
* This is the current behavior of decimal.py. */
if (v == default_context_template ||
Expand All @@ -1624,13 +1561,13 @@ PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
Py_INCREF(v);
}

cached_context = NULL;
if (PyDict_SetItem(dict, tls_context_key, v) < 0) {
Py_DECREF(v);
PyContextToken *tok = PyContextVar_Set(current_context_var, v);
Py_DECREF(v);
if (tok == NULL) {
return NULL;
}
Py_DECREF(tok);

Py_DECREF(v);
Py_RETURN_NONE;
}

Expand Down Expand Up @@ -4458,6 +4395,7 @@ _dec_hash(PyDecObject *v)
if (context == NULL) {
return -1;
}
Py_DECREF(context);

if (mpd_isspecial(MPD(v))) {
if (mpd_issnan(MPD(v))) {
Expand Down Expand Up @@ -5599,6 +5537,11 @@ PyInit__decimal(void)
mpd_free = PyMem_Free;
mpd_setminalloc(_Py_DEC_MINALLOC);

/* Init context variable */
current_context_var = PyContextVar_New("decimal_context", NULL);
if (current_context_var == NULL) {
goto error;
}

/* Init external C-API functions */
_py_long_multiply = PyLong_Type.tp_as_number->nb_multiply;
Expand Down Expand Up @@ -5768,7 +5711,6 @@ PyInit__decimal(void)
CHECK_INT(PyModule_AddObject(m, "DefaultContext",
default_context_template));

ASSIGN_PTR(tls_context_key, PyUnicode_FromString("___DECIMAL_CTX__"));
Py_INCREF(Py_True);
CHECK_INT(PyModule_AddObject(m, "HAVE_THREADS", Py_True));

Expand Down Expand Up @@ -5827,9 +5769,9 @@ PyInit__decimal(void)
Py_CLEAR(SignalTuple); /* GCOV_NOT_REACHED */
Py_CLEAR(DecimalTuple); /* GCOV_NOT_REACHED */
Py_CLEAR(default_context_template); /* GCOV_NOT_REACHED */
Py_CLEAR(tls_context_key); /* GCOV_NOT_REACHED */
Py_CLEAR(basic_context_template); /* GCOV_NOT_REACHED */
Py_CLEAR(extended_context_template); /* GCOV_NOT_REACHED */
Py_CLEAR(current_context_var); /* GCOV_NOT_REACHED */
Py_CLEAR(m); /* GCOV_NOT_REACHED */

return NULL; /* GCOV_NOT_REACHED */
Expand Down