Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions Lib/test/test_xml_etree_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import struct
from test import support
from test.support.import_helper import import_fresh_module
from test.support import threading_helper
import types
import unittest
import threading

cET = import_fresh_module('xml.etree.ElementTree',
fresh=['_elementtree'])
Expand Down Expand Up @@ -256,6 +258,45 @@ def test_element_with_children(self):
self.check_sizeof(e, self.elementsize + self.extra +
struct.calcsize('8P'))

@unittest.skipUnless(cET, 'requires _elementtree')
@threading_helper.requires_working_threading()
class TestElementTreeFreeThreading(unittest.TestCase):
def test_element_extra_race(self):
#Race len(), .attrib, and .clear() to verify fix for gh-149861.
root = cET.Element('root')
children = [cET.Element(f'child-{i}') for i in range(5)]

stop_event = threading.Event()

def reader_task():
while not stop_event.is_set():
len(root)
try:
_ = root.attrib
except AttributeError:
# In a race where clear() just ran, this is expected
# because of the PyErr_SetString is added in C.
pass

def writer_task():
while not stop_event.is_set():
# Test element_add_subelement / extend
root.extend(children)
# Test clear_extra
root.clear()

threads = []
for _ in range(4):
threads.append(threading.Thread(target=reader_task))
for _ in range(2):
threads.append(threading.Thread(target=writer_task))

with threading_helper.start_threads(threads):
import time
time.sleep(1.0)
stop_event.set()



def install_tests():
# Test classes should have __module__ referring to this module.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fixes race conditions in :mod:`xml.etree.ElementTree` where the ``extra``
pointer dereference and ``attrib`` access was unsynchronized . This prevents
potential crashes (core dumps) in the free-threading build caused by
concurrent modification of an element's internal structure.
89 changes: 65 additions & 24 deletions Modules/_elementtree.c
Original file line number Diff line number Diff line change
Expand Up @@ -271,19 +271,31 @@ typedef struct {
LOCAL(int)
create_extra(ElementObject* self, PyObject* attrib)
{
self->extra = PyMem_Malloc(sizeof(ElementObjectExtra));
if (!self->extra) {
int res;
Py_BEGIN_CRITICAL_SECTION(self);

if (self->extra != NULL) {
res = 0;
goto end;
}
ElementObjectExtra* extra = PyMem_Malloc(sizeof(ElementObjectExtra));
if (!extra) {
PyErr_NoMemory();
return -1;
res = -1;
goto end;
}

self->extra->attrib = Py_XNewRef(attrib);
extra->attrib = Py_XNewRef(attrib);

self->extra->length = 0;
self->extra->allocated = STATIC_CHILDREN;
self->extra->children = self->extra->_children;
extra->length = 0;
extra->allocated = STATIC_CHILDREN;
extra->children = extra->_children;
self->extra = extra;
res = 0;

return 0;
end: ;
Py_END_CRITICAL_SECTION();
return res;
}

LOCAL(void)
Expand All @@ -297,7 +309,7 @@ dealloc_extra(ElementObjectExtra *extra)
Py_XDECREF(extra->attrib);

for (i = 0; i < extra->length; i++)
Py_DECREF(extra->children[i]);
Py_XDECREF(extra->children[i]);

if (extra->children != extra->_children) {
PyMem_Free(extra->children);
Expand All @@ -311,15 +323,16 @@ clear_extra(ElementObject* self)
{
ElementObjectExtra *myextra;

if (!self->extra)
return;
Py_BEGIN_CRITICAL_SECTION(self);

/* Avoid DECREFs calling into this code again (cycles, etc.)
*/
myextra = self->extra;
self->extra = NULL;

dealloc_extra(myextra);
Py_END_CRITICAL_SECTION();
if (myextra) {
dealloc_extra(myextra);
}
}

/* Convenience internal function to create new Element objects with the given
Expand Down Expand Up @@ -544,29 +557,42 @@ element_add_subelement(elementtreestate *st, ElementObject *self,
return -1;
}

Py_BEGIN_CRITICAL_SECTION(self);
if (element_resize(self, 1) < 0)
return -1;

self->extra->children[self->extra->length] = Py_NewRef(element);

self->extra->length++;

Py_END_CRITICAL_SECTION();
return 0;
}

LOCAL(PyObject*)
element_get_attrib(ElementObject* self)
{
/* return borrowed reference to attrib dictionary */
/* return new reference to attrib dictionary */
/* note: this function assumes that the extra section exists */

PyObject* res = self->extra->attrib;
PyObject *res = NULL;
Py_BEGIN_CRITICAL_SECTION(self);
if (self->extra == NULL) {
PyErr_SetString(PyExc_AttributeError, "extra section does not exist");
goto end;
}
res = self->extra->attrib;

if (!res) {
/* create missing dictionary */
res = self->extra->attrib = PyDict_New();
res = PyDict_New();
if (res) {
self->extra->attrib = res;
}
}

Py_XINCREF(res);
end: ;
Py_END_CRITICAL_SECTION();
return res;
}

Expand Down Expand Up @@ -667,20 +693,23 @@ element_gc_traverse(PyObject *op, visitproc visit, void *arg)
Py_VISIT(JOIN_OBJ(self->text));
Py_VISIT(JOIN_OBJ(self->tail));

Py_BEGIN_CRITICAL_SECTION(self);
if (self->extra) {
Py_ssize_t i;
Py_VISIT(self->extra->attrib);

for (i = 0; i < self->extra->length; ++i)
Py_VISIT(self->extra->children[i]);
}
Py_END_CRITICAL_SECTION();
return 0;
}

static int
element_gc_clear(PyObject *op)
{
ElementObject *self = _Element_CAST(op);
Py_BEGIN_CRITICAL_SECTION(self);
Py_CLEAR(self->tag);
_clear_joined_ptr(&self->text);
_clear_joined_ptr(&self->tail);
Expand All @@ -689,6 +718,7 @@ element_gc_clear(PyObject *op)
* so fully deallocate it.
*/
clear_extra(self);
Py_END_CRITICAL_SECTION();
return 0;
}

Expand Down Expand Up @@ -1625,10 +1655,14 @@ static Py_ssize_t
element_length(PyObject *op)
{
ElementObject *self = _Element_CAST(op);
if (!self->extra)
return 0;
Py_ssize_t res = 0;
Py_BEGIN_CRITICAL_SECTION(self);
if (self->extra) {
res = self->extra->length;
}
Py_END_CRITICAL_SECTION();

return self->extra->length;
return res;
}

/*[clinic input]
Expand Down Expand Up @@ -1766,9 +1800,12 @@ _elementtree_Element_set_impl(ElementObject *self, PyObject *key,
if (!attrib)
return NULL;

if (PyDict_SetItem(attrib, key, value) < 0)
if (PyDict_SetItem(attrib, key, value) < 0) {
Py_DECREF(attrib);
return NULL;
}

Py_DECREF(attrib);
Py_RETURN_NONE;
}

Expand Down Expand Up @@ -2077,14 +2114,18 @@ element_tail_getter(PyObject *op, void *closure)
static PyObject*
element_attrib_getter(PyObject *op, void *closure)
{
PyObject *res;
PyObject *res = NULL;
ElementObject *self = _Element_CAST(op);
Py_BEGIN_CRITICAL_SECTION(self);
if (!self->extra) {
if (create_extra(self, NULL) < 0)
return NULL;
goto end;
}
res = element_get_attrib(self);
return Py_XNewRef(res);

end: ;
Py_END_CRITICAL_SECTION();
return res;
}

/* macro for setter validation */
Expand Down
Loading