diff --git a/src/ComputedAttribute/tests.py b/src/ComputedAttribute/tests.py index 2763449..2ac5fab 100644 --- a/src/ComputedAttribute/tests.py +++ b/src/ComputedAttribute/tests.py @@ -73,6 +73,31 @@ def test_wrapper_support(): import unittest from doctest import DocTestSuite +from ExtensionClass import Base +from ComputedAttribute import ComputedAttribute + + +class TestComputedAttribute(unittest.TestCase): + def _construct_class(self, level): + class X(Base): + def _get_a(self): + return 1 + + a = ComputedAttribute(_get_a, level) + + return X + + def test_computed_attribute_on_class_level0(self): + x = self._construct_class(0)() + self.assertEqual(x.a, 1) + + def test_computed_attribute_on_class_level1(self): + x = self._construct_class(1)() + self.assertIsInstance(x.a, ComputedAttribute) + def test_suite(): - return unittest.TestSuite((DocTestSuite(),)) + suite = unittest.TestSuite() + suite.addTest(DocTestSuite()) + suite.addTest(unittest.makeSuite(TestComputedAttribute)) + return suite diff --git a/src/ExtensionClass/_ExtensionClass.c b/src/ExtensionClass/_ExtensionClass.c index 43d03eb..0750b2e 100644 --- a/src/ExtensionClass/_ExtensionClass.c +++ b/src/ExtensionClass/_ExtensionClass.c @@ -44,36 +44,112 @@ of_get(PyObject *self, PyObject *inst, PyObject *cls) PyObject * Base_getattro(PyObject *obj, PyObject *name) { - int name_is_parent = 0; - PyObject* res = NULL; - PyObject* desc_res = NULL; + PyTypeObject *tp = Py_TYPE(obj); + PyObject *descr = NULL; + PyObject *res = NULL; + descrgetfunc f; + PyObject **dictptr; + + if (!NATIVE_CHECK(name)) { +#ifndef PY3K +#ifdef Py_USING_UNICODE + /* The Unicode to string conversion is done here because the + existing tp_setattro slots expect a string object as name + and we wouldn't want to break those. */ + if (PyUnicode_Check(name)) { + name = PyUnicode_AsEncodedString(name, NULL, NULL); + if (name == NULL) + return NULL; + } + else +#endif +#endif + { + PyErr_Format(PyExc_TypeError, + "attribute name must be string, not '%.200s'", + Py_TYPE(name)->tp_name); + return NULL; + } + } + else + Py_INCREF(name); - res = PyObject_GenericGetAttr(obj, name); - if (res == NULL) { - return NULL; + if (tp->tp_dict == NULL) { + if (PyType_Ready(tp) < 0) + goto done; } - name_is_parent = PyObject_RichCompareBool(name, str__parent__, Py_EQ); - if (name_is_parent == -1) { - Py_DECREF(res); - return NULL; + descr = _PyType_Lookup(tp, name); + Py_XINCREF(descr); + + f = NULL; + if (descr != NULL && HAS_TP_DESCR_GET(descr)) { + f = descr->ob_type->tp_descr_get; + if (f != NULL && PyDescr_IsData(descr)) { + res = f(descr, obj, (PyObject *)obj->ob_type); + Py_DECREF(descr); + goto done; + } } - if (name_is_parent == 1) { - return res; + dictptr = _PyObject_GetDictPtr(obj); + + if (dictptr && *dictptr) { + Py_INCREF(*dictptr); + res = PyDict_GetItem(*dictptr, name); + if (res != NULL) { + Py_INCREF(res); + Py_XDECREF(descr); + Py_DECREF(*dictptr); + + /* CHANGED! If the tp_descr_get of res is of_get, then call it. */ + if (PyObject_TypeCheck(Py_TYPE(res), &ExtensionClassType)) { + if (Py_TYPE(res)->tp_descr_get) { + int name_is_parent = PyObject_RichCompareBool(name, str__parent__, Py_EQ); + + if (name_is_parent == 0) { + PyObject *tres = Py_TYPE(res)->tp_descr_get(res, obj, (PyObject*)Py_TYPE(obj)); + Py_DECREF(res); + res = tres; + } + else if (name_is_parent == -1) { + PyErr_Clear(); + } + } + } + /* End of change. */ + + goto done; + } + Py_DECREF(*dictptr); } - if (!PyObject_TypeCheck(Py_TYPE(res), &ExtensionClassType)) { - return res; + if (f != NULL) { + res = f(descr, obj, (PyObject *)Py_TYPE(obj)); + Py_DECREF(descr); + goto done; } - if (!Py_TYPE(res)->tp_descr_get) { - return res; + if (descr != NULL) { + res = descr; + /* descr was already increfed above */ + goto done; } - desc_res = Py_TYPE(res)->tp_descr_get(res, obj, (PyObject*)Py_TYPE(obj)); - Py_DECREF(res); - return desc_res; +#ifdef PY3K + PyErr_Format(PyExc_AttributeError, + "'%.50s' object has no attribute '%U'", + tp->tp_name, name); +#else + PyErr_Format(PyExc_AttributeError, + "'%.50s' object has no attribute '%.400s'", + tp->tp_name, PyString_AS_STRING(name)); +#endif + + done: + Py_DECREF(name); + return res; + } #include "pickle/pickle.c" diff --git a/src/ExtensionClass/__init__.py b/src/ExtensionClass/__init__.py index d7247bc..f6e302e 100644 --- a/src/ExtensionClass/__init__.py +++ b/src/ExtensionClass/__init__.py @@ -98,6 +98,7 @@ class init called 1 """ +import inspect import os import sys @@ -219,13 +220,38 @@ def __setattr__(self, name, value): # to care or worry about using super(): it's always object. def Base_getattro(self, name): - res = object.__getattribute__(self, name) - # If it's a descriptor for something besides __parent__, call it. - if name != '__parent__' and isinstance(res, Base): - descr_get = getattr(res, '__get__', None) - if descr_get is not None: - res = descr_get(self, type(self)) - return res + descr = None + + for base in type(self).__mro__: + if name in base.__dict__: + descr = base.__dict__[name] + break + + if descr is not None and inspect.isdatadescriptor(base): + return descr.__get__(self, type(self)) + + try: + # Don't do self.__dict__ otherwise you get recursion. + inst_dict = object.__getattribute__(self, '__dict__') + except AttributeError: + pass + else: + if name in inst_dict: + descr = inst_dict[name] + # If the tp_descr_get of res is of_get, then call it. + if name == '__parent__' or not isinstance(descr, Base): + return descr + + if descr is not None: + descr_get = getattr(descr, '__get__', None) + if descr_get is None: + return descr + + return descr_get(self, type(self)) + + raise AttributeError( + "'%.50s' object has not attribute '%s'", + type(self).__name__, name) def _slotnames(self): diff --git a/src/ExtensionClass/_compat.h b/src/ExtensionClass/_compat.h index 9f4cd0b..9249333 100644 --- a/src/ExtensionClass/_compat.h +++ b/src/ExtensionClass/_compat.h @@ -33,6 +33,8 @@ #define INT_CHECK(x) PyLong_Check(x) #define INT_AS_LONG(x) PyLong_AS_LONG(x) +#define HAS_TP_DESCR_GET(ob) 1 + #else #define INTERN PyString_InternFromString #define INTERN_INPLACE PyString_InternInPlace @@ -44,6 +46,8 @@ #define INT_FROM_LONG(x) PyInt_FromLong(x) #define INT_CHECK(x) PyInt_Check(x) #define INT_AS_LONG(x) PyInt_AS_LONG(x) + +#define HAS_TP_DESCR_GET(ob) PyType_HasFeature(Py_TYPE(ob), Py_TPFLAGS_HAVE_CLASS) #endif #endif