Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,17 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
# possibly constants if the annotate function uses them directly). We then
# convert each of those into a string to get an approximation of the
# original source.

# Attempt to call with VALUE_WITH_FAKE_GLOBALS to check if it is implemented
# See: https://github.com/python/cpython/issues/138764
# Only fail on NotImplementedError
try:
annotate(Format.VALUE_WITH_FAKE_GLOBALS)
except NotImplementedError:
raise
except Exception:
pass

globals = _StringifierDict({}, format=format)
is_class = isinstance(owner, type)
closure = _build_closure(
Expand Down Expand Up @@ -722,6 +733,10 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
)
try:
result = func(Format.VALUE_WITH_FAKE_GLOBALS)
except NotImplementedError:
# If NotImplementedError is raised, don't try to call again with
# no globals.
raise
except Exception:
pass
else:
Expand Down
80 changes: 80 additions & 0 deletions Lib/test/test_annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import itertools
import pickle
from string.templatelib import Template
import types
import typing
import unittest
import unittest.mock
from annotationlib import (
Format,
ForwardRef,
Expand Down Expand Up @@ -1206,6 +1208,84 @@ def evaluate(format, exc=NotImplementedError):
)


class TestCallAnnotateFunction(unittest.TestCase):
def _annotate_mock(self):
def annotate(format, /):
if format == Format.VALUE:
return {"x": str}
else:
raise NotImplementedError(format)

annotate_mock = unittest.mock.MagicMock(
wraps=annotate
)

# Add missing magic attributes needed
required_magic = [
"__builtins__",
"__closure__",
"__code__",
"__defaults__",
"__globals__",
"__kwdefaults__",
]

for attrib in required_magic:
setattr(annotate_mock, attrib, getattr(annotate, attrib))

return annotate_mock

def test_user_annotate_value(self):
annotate = self._annotate_mock()

annotations = annotationlib.call_annotate_function(
annotate,
Format.VALUE,
)

self.assertEqual(annotations, {"x": str})
annotate.assert_called_once_with(Format.VALUE)

def test_user_annotate_forwardref(self):
annotate = self._annotate_mock()

new_annotate = None
functype = types.FunctionType

def functiontype_mock(*args, **kwargs):
nonlocal new_annotate
new_func = unittest.mock.MagicMock(wraps=functype(*args, **kwargs))
new_annotate = new_func
return new_func

with unittest.mock.patch("types.FunctionType", new=functiontype_mock):
with self.assertRaises(NotImplementedError):
annotations = annotationlib.call_annotate_function(
annotate,
Format.FORWARDREF,
)

# Test the direct call
annotate.assert_called_once_with(Format.FORWARDREF)

# Test the call on the function with fake globals
new_annotate.assert_called_once_with(Format.VALUE_WITH_FAKE_GLOBALS)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a fan of these tests as they're asserting details of the implementation, not the observable behavior. Is there a way to write a test that's closer to your use case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose this could instead test multiple annotate functions that provide only those formats?

Let me know if you prefer this or #138803 and I'll modify the tests for the relevant PR.


def test_user_annotate_string(self):
annotate = self._annotate_mock()

with self.assertRaises(NotImplementedError):
annotations = annotationlib.call_annotate_function(
annotate,
Format.STRING,
)

annotate.assert_has_calls([
unittest.mock.call(Format.STRING),
unittest.mock.call(Format.VALUE_WITH_FAKE_GLOBALS),
])


class MetaclassTests(unittest.TestCase):
def test_annotated_meta(self):
class Meta(type):
Expand Down
Loading