Skip to content

Commit

Permalink
Backport recent change to NamedTuple classes regarding `__set_name_…
Browse files Browse the repository at this point in the history
…_` (#303)
  • Loading branch information
AlexWaygood committed Nov 29, 2023
1 parent 7af82f9 commit 4f91502
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
argument to the `msg` parameter. Patch by Alex Waygood.
- Exclude `__match_args__` from `Protocol` members,
this is a backport of https://github.com/python/cpython/pull/110683
- When creating a `typing_extensions.NamedTuple` class, ensure `__set_name__`
is called on all objects that define `__set_name__` and exist in the values
of the `NamedTuple` class's class dictionary. Patch by Alex Waygood,
backporting https://github.com/python/cpython/pull/111876.

# Release 4.8.0 (September 17, 2023)

Expand Down
135 changes: 135 additions & 0 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@
# versions, but not all
HAS_FORWARD_MODULE = "module" in inspect.signature(typing._type_check).parameters

skip_if_early_py313_alpha = skipIf(
sys.version_info[:4] == (3, 13, 0, 'alpha') and sys.version_info.serial < 3,
"Bugfixes will be released in 3.13.0a3"
)

ANN_MODULE_SOURCE = '''\
from typing import Optional
from functools import wraps
Expand Down Expand Up @@ -5548,6 +5553,136 @@ class GenericNamedTuple(NamedTuple, Generic[T]):

self.assertEqual(CallNamedTuple.__orig_bases__, (NamedTuple,))

@skip_if_early_py313_alpha
def test_setname_called_on_values_in_class_dictionary(self):
class Vanilla:
def __set_name__(self, owner, name):
self.name = name

class Foo(NamedTuple):
attr = Vanilla()

foo = Foo()
self.assertEqual(len(foo), 0)
self.assertNotIn('attr', Foo._fields)
self.assertIsInstance(foo.attr, Vanilla)
self.assertEqual(foo.attr.name, "attr")

class Bar(NamedTuple):
attr: Vanilla = Vanilla()

bar = Bar()
self.assertEqual(len(bar), 1)
self.assertIn('attr', Bar._fields)
self.assertIsInstance(bar.attr, Vanilla)
self.assertEqual(bar.attr.name, "attr")

@skipIf(
TYPING_3_12_0,
"__set_name__ behaviour changed on py312+ to use BaseException.add_note()"
)
def test_setname_raises_the_same_as_on_other_classes_py311_minus(self):
class CustomException(BaseException): pass

class Annoying:
def __set_name__(self, owner, name):
raise CustomException

annoying = Annoying()

with self.assertRaises(RuntimeError) as cm:
class NormalClass:
attr = annoying
normal_exception = cm.exception

with self.assertRaises(RuntimeError) as cm:
class NamedTupleClass(NamedTuple):
attr = annoying
namedtuple_exception = cm.exception

expected_note = (
"Error calling __set_name__ on 'Annoying' instance "
"'attr' in 'NamedTupleClass'"
)

self.assertIs(type(namedtuple_exception), RuntimeError)
self.assertIs(type(namedtuple_exception), type(normal_exception))
self.assertEqual(len(namedtuple_exception.args), len(normal_exception.args))
self.assertEqual(
namedtuple_exception.args[0],
normal_exception.args[0].replace("NormalClass", "NamedTupleClass")
)

self.assertIs(type(namedtuple_exception.__cause__), CustomException)
self.assertIs(
type(namedtuple_exception.__cause__), type(normal_exception.__cause__)
)
self.assertEqual(
namedtuple_exception.__cause__.args, normal_exception.__cause__.args
)

@skipUnless(
TYPING_3_12_0,
"__set_name__ behaviour changed on py312+ to use BaseException.add_note()"
)
@skip_if_early_py313_alpha
def test_setname_raises_the_same_as_on_other_classes_py312_plus(self):
class CustomException(BaseException): pass

class Annoying:
def __set_name__(self, owner, name):
raise CustomException

annoying = Annoying()

with self.assertRaises(CustomException) as cm:
class NormalClass:
attr = annoying
normal_exception = cm.exception

with self.assertRaises(CustomException) as cm:
class NamedTupleClass(NamedTuple):
attr = annoying
namedtuple_exception = cm.exception

expected_note = (
"Error calling __set_name__ on 'Annoying' instance "
"'attr' in 'NamedTupleClass'"
)

self.assertIs(type(namedtuple_exception), CustomException)
self.assertIs(type(namedtuple_exception), type(normal_exception))
self.assertEqual(namedtuple_exception.args, normal_exception.args)

self.assertEqual(len(namedtuple_exception.__notes__), 1)
self.assertEqual(
len(namedtuple_exception.__notes__), len(normal_exception.__notes__)
)

self.assertEqual(namedtuple_exception.__notes__[0], expected_note)
self.assertEqual(
namedtuple_exception.__notes__[0],
normal_exception.__notes__[0].replace("NormalClass", "NamedTupleClass")
)

@skip_if_early_py313_alpha
def test_strange_errors_when_accessing_set_name_itself(self):
class CustomException(Exception): pass

class Meta(type):
def __getattribute__(self, attr):
if attr == "__set_name__":
raise CustomException
return object.__getattribute__(self, attr)

class VeryAnnoying(metaclass=Meta): pass

very_annoying = VeryAnnoying()

with self.assertRaises(CustomException):
class Foo(NamedTuple):
attr = very_annoying


class TypeVarTests(BaseTestCase):
def test_basic_plain(self):
Expand Down
30 changes: 27 additions & 3 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2467,11 +2467,35 @@ def __new__(cls, typename, bases, ns):
class_getitem = typing.Generic.__class_getitem__.__func__
nm_tpl.__class_getitem__ = classmethod(class_getitem)
# update from user namespace without overriding special namedtuple attributes
for key in ns:
for key, val in ns.items():
if key in _prohibited_namedtuple_fields:
raise AttributeError("Cannot overwrite NamedTuple attribute " + key)
elif key not in _special_namedtuple_fields and key not in nm_tpl._fields:
setattr(nm_tpl, key, ns[key])
elif key not in _special_namedtuple_fields:
if key not in nm_tpl._fields:
setattr(nm_tpl, key, ns[key])
try:
set_name = type(val).__set_name__
except AttributeError:
pass
else:
try:
set_name(val, nm_tpl, key)
except BaseException as e:
msg = (
f"Error calling __set_name__ on {type(val).__name__!r} "
f"instance {key!r} in {typename!r}"
)
# BaseException.add_note() existed on py311,
# but the __set_name__ machinery didn't start
# using add_note() until py312.
# Making sure exceptions are raised in the same way
# as in "normal" classes seems most important here.
if sys.version_info >= (3, 12):
e.add_note(msg)
raise
else:
raise RuntimeError(msg) from e

if typing.Generic in bases:
nm_tpl.__init_subclass__()
return nm_tpl
Expand Down

0 comments on commit 4f91502

Please sign in to comment.