diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 2166dbff0ee70c..a2496f64ddb63f 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -728,13 +728,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): annotate, owner, is_class, globals, allow_evaluation=False ) func = types.FunctionType( - annotate.__code__, + _get_annotate_attr(annotate, "__code__"), globals, closure=closure, - argdefs=annotate.__defaults__, - kwdefaults=annotate.__kwdefaults__, + argdefs=_get_annotate_attr(annotate, "__defaults__", None), + kwdefaults=_get_annotate_attr(annotate, "__kwdefaults__", None), ) - annos = func(Format.VALUE_WITH_FAKE_GLOBALS) + annos = _direct_call_annotate(func, annotate, Format.VALUE_WITH_FAKE_GLOBALS) if _is_evaluate: return _stringify_single(annos) return { @@ -759,11 +759,21 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): # reconstruct the source. But in the dictionary that we eventually return, we # want to return objects with more user-friendly behavior, such as an __eq__ # that returns a bool and an defined set of attributes. - namespace = {**annotate.__builtins__, **annotate.__globals__} + + # Grab and store all the annotate function attributes that we might need to access + # multiple times as variables, as this could be a bit expensive for non-functions. + annotate_globals = _get_annotate_attr(annotate, "__globals__", {}) + annotate_code = _get_annotate_attr(annotate, "__code__") + annotate_defaults = _get_annotate_attr(annotate, "__defaults__", None) + annotate_kwdefaults = _get_annotate_attr(annotate, "__kwdefaults__", None) + namespace = { + **_get_annotate_attr(annotate, "__builtins__", {}), + **annotate_globals + } is_class = isinstance(owner, type) globals = _StringifierDict( namespace, - globals=annotate.__globals__, + globals=annotate_globals, owner=owner, is_class=is_class, format=format, @@ -772,14 +782,14 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): annotate, owner, is_class, globals, allow_evaluation=True ) func = types.FunctionType( - annotate.__code__, + annotate_code, globals, closure=closure, - argdefs=annotate.__defaults__, - kwdefaults=annotate.__kwdefaults__, + argdefs=annotate_defaults, + kwdefaults=annotate_kwdefaults, ) try: - result = func(Format.VALUE_WITH_FAKE_GLOBALS) + result = _direct_call_annotate(func, annotate, Format.VALUE_WITH_FAKE_GLOBALS) except NotImplementedError: # FORWARDREF and VALUE_WITH_FAKE_GLOBALS not supported, fall back to VALUE return annotate(Format.VALUE) @@ -793,7 +803,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): # a value in certain cases where an exception gets raised during evaluation. globals = _StringifierDict( {}, - globals=annotate.__globals__, + globals=annotate_globals, owner=owner, is_class=is_class, format=format, @@ -802,13 +812,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): annotate, owner, is_class, globals, allow_evaluation=False ) func = types.FunctionType( - annotate.__code__, + annotate_code, globals, closure=closure, - argdefs=annotate.__defaults__, - kwdefaults=annotate.__kwdefaults__, + argdefs=annotate_defaults, + kwdefaults=annotate_kwdefaults, ) - result = func(Format.VALUE_WITH_FAKE_GLOBALS) + result = _direct_call_annotate(func, annotate, Format.VALUE_WITH_FAKE_GLOBALS) globals.transmogrify(cell_dict) if _is_evaluate: if isinstance(result, ForwardRef): @@ -833,12 +843,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): def _build_closure(annotate, owner, is_class, stringifier_dict, *, allow_evaluation): - if not annotate.__closure__: + closure = _get_annotate_attr(annotate, "__closure__", None) + if not closure: return None, None - freevars = annotate.__code__.co_freevars + freevars = _get_annotate_attr(annotate, "__code__", None).co_freevars new_closure = [] cell_dict = {} - for i, cell in enumerate(annotate.__closure__): + for i, cell in enumerate(closure): if i < len(freevars): name = freevars[i] else: @@ -857,7 +868,7 @@ def _build_closure(annotate, owner, is_class, stringifier_dict, *, allow_evaluat name, cell=cell, owner=owner, - globals=annotate.__globals__, + globals=_get_annotate_attr(annotate, "__globals__", {}), is_class=is_class, stringifier_dict=stringifier_dict, ) @@ -879,6 +890,141 @@ def _stringify_single(anno): return repr(anno) +def _get_annotate_attr(annotate, attr, default=_sentinel): + # Try to get the attr on the annotate function. If it doesn't exist, we might + # need to look in other places on the object. If all of those fail, we can + # return the default at the end. + if hasattr(annotate, attr): + return getattr(annotate, attr) + + # Redirect method attribute access to the underlying function. The C code + # verifies that the __func__ attribute is some kind of callable, so we need + # to look for attributes recursively. + if isinstance(annotate, types.MethodType): + return _get_annotate_attr(annotate.__func__, attr, default) + + # Python generics are callable. Usually, the __init__ method sets attributes. + # However, typing._BaseGenericAlias overrides the __init__ method, so we need + # to use the original class method for fake globals and the like. + # _BaseGenericAlias also override __call__, so let's handle this earlier than + # other class construction. + if ( + (typing := sys.modules.get("typing", None)) + and isinstance(annotate, typing._BaseGenericAlias) + ): + return _get_annotate_attr(annotate.__origin__.__init__, attr, default) + + # If annotate is a class instance, its __call__ is the relevant function. + # However, __call__ Could be a method, a function descriptor, or any other callable. + # Normal functions have a __call__ property which is a useless method wrapper, + # ignore these. + if ( + (call := getattr(annotate, "__call__", None)) and + not isinstance(call, types.MethodWrapperType) + ): + return _get_annotate_attr(annotate.__call__, attr, default) + + # Classes and generics are callable. Usually the __init__ method sets attributes, + # so let's access this method for fake globals and the like. + # Technically __init__ can be any callable object, so we recurse. + if isinstance(annotate, type) or isinstance(annotate, types.GenericAlias): + return _get_annotate_attr(annotate.__init__, attr, default) + + # Most 'wrapped' functions, including functools.cache and staticmethod, need us + # to manually, recursively unwrap. For partial.update_wrapper functions, the + # attribute is accessible on the function itself, so we never get this far. + if hasattr(annotate, "__wrapped__"): + return _get_annotate_attr(annotate.__wrapped__, attr, default) + + # Partial functions and methods both store their underlying function as a + # func attribute. They can wrap any callable, so we need to recursively unwrap. + if ( + (functools := sys.modules.get("functools", None)) + and isinstance(annotate, functools.partial) + ): + return _get_annotate_attr(annotate.func, attr, default) + + if default is _sentinel: + raise TypeError(f"annotate function missing {attr!r} attribute") + return default + +def _direct_call_annotate(func, annotate, *args): + # If annotate is a method, we need to pass self as the first param. + if ( + hasattr(annotate, "__func__") and + (self := getattr(annotate, "__self__", None)) + ): + # We don't know what type of callable will be in the __func__ attribute, + # so let's try again with knowledge of that type, including self as the first + # argument. + return _direct_call_annotate(func, annotate.__func__, self, *args) + + # Python generics (typing._BaseGenericAlias) override __call__, so let's handle + # them earlier than other class construction. + if ( + (typing := sys.modules.get("typing", None)) + and isinstance(annotate, typing._BaseGenericAlias) + ): + inst = annotate.__new__(annotate.__origin__) + func(inst, *args) + # Try to set the original class on the instance, if possible. + # This is the same logic used in typing for custom generics. + try: + inst.__orig_class__ = annotate + except Exception: + pass + return inst + + # If annotate is a class instance, its __call__ is the function. + # __call__ Could be a method, a function descriptor, or any other callable. + # Normal functions have a __call__ property which is a useless method wrapper, + # ignore these. + if ( + (call := getattr(annotate, "__call__", None)) and + not isinstance(call, types.MethodWrapperType) + ): + return _direct_call_annotate(func, annotate.__call__, *args) + + # If annotate is a class, `func` is the __init__ method, so we still need to call + # __new__() to create the instance + if isinstance(annotate, type): + inst = annotate.__new__(annotate) + # func might refer to some non-function object. + _direct_call_annotate(func, annotate.__init__, inst, *args) + return inst + + # Generic instantiation is slightly different. Since we want to give + # __call__ priority, the custom logic for builtin generics is here. + if isinstance(annotate, types.GenericAlias): + inst = annotate.__new__(annotate.__origin__) + # func might refer to some non-function object. + _direct_call_annotate(func, annotate.__init__, inst, *args) + # Try to set the original class on the instance, if possible. + # This is the same logic used in typing for custom generics. + try: + inst.__orig_class__ = annotate + except Exception: + pass + return inst + + if functools := sys.modules.get("functools", None): + # If annotate is a partial function, re-create it with the new function object. + # We could call the function directly, but then we'd have to handle placeholders, + # and this way should be more robust for future changes. + if isinstance(annotate, functools.partial): + # Partial methods + if self := getattr(annotate, "__self__", None): + return functools.partial(func, self, *annotate.args, **annotate.keywords)(*args) + return functools.partial(func, *annotate.args, **annotate.keywords)(*args) + + # If annotate is a cached function, we've now updated the function data, so + # let's not use the old cache. Furthermore, we're about to call the function + # and never use it again, so let's not bother trying to cache it. + + # Or, if it's a normal function or unsupported callable, we should just call it. + return func(*args) + + def get_annotate_from_class_namespace(obj): """Retrieve the annotate function from a class namespace dictionary. diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index 9f3275d5071484..9d09374222350f 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -8,6 +8,8 @@ import itertools import pickle from string.templatelib import Template, Interpolation +import random +import types import typing import sys import unittest @@ -1132,6 +1134,26 @@ def __annotate__(self): {"x": "int"}, ) + def test_non_function_annotate(self): + class AnnotateCallable: + def __call__(self, format, /): + if format > 2: + raise NotImplementedError + return {"x": int} + + class OnlyAnnotate: + @property + def __annotate__(self): + return AnnotateCallable() + + oa = OnlyAnnotate() + self.assertEqual(get_annotations(oa, format=Format.VALUE), {"x": int}) + self.assertEqual(get_annotations(oa, format=Format.FORWARDREF), {"x": int}) + self.assertEqual( + get_annotations(oa, format=Format.STRING), + {"x": "int"}, + ) + def test_non_dict_annotate(self): class WeirdAnnotate: def __annotate__(self, *args, **kwargs): @@ -1492,6 +1514,114 @@ def annotate(format, /, __Format=Format, __NotImplementedError=NotImplementedErr self.assertEqual(annotations, {"x": int}) + def test_callable_class_annotate_forwardref_fakeglobals(self): + # Calling the class will construct a new instance and call its __init__ function + # as an annotate function, returning the instance. This is fine as long as + # the class inherits from dict. + class Annotate(dict): + def __init__(self, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + super().__init__({'x': str}) + elif format == __Format.VALUE_WITH_FAKE_GLOBALS: + super().__init__({'x': int}) + else: + raise __NotImplementedError(format) + + annotations = annotationlib.call_annotate_function( + Annotate, + Format.FORWARDREF + ) + + self.assertEqual(annotations, {"x": int}) + + def test_callable_class_custom_init_annotate_forwardref_fakeglobals(self): + # Calling the class will construct a new instance and call its __init__ function + # as an annotate function, except this __init__ is not a method, + # but a partial function. + def custom_init(self, second, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + super(type(self), self).__init__({"x": str}) + elif format == __Format.VALUE_WITH_FAKE_GLOBALS: + super(type(self), self).__init__({"x": second}) + else: + raise __NotImplementedError(format) + + class Annotate(dict): + pass + + Annotate.__init__ = functools.partial(custom_init, functools.Placeholder, int) + + annotations = annotationlib.call_annotate_function( + Annotate, + Format.FORWARDREF + ) + + self.assertEqual(annotations, {"x": int}) + + def test_callable_generic_class_annotate_forwardref_fakeglobals(self): + # Subscripted generic classes are types.GenericAlias instances + # for dict subclasses. Check that they are still + # callable as annotate functions, just like regular classes. + class Annotate[K, V](dict[K, V]): + def __init__(self, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + super().__init__({'x': str}) + elif format == __Format.VALUE_WITH_FAKE_GLOBALS: + super().__init__({'x': int}) + else: + raise __NotImplementedError(format) + + annotations = annotationlib.call_annotate_function( + Annotate[str, type], + Format.FORWARDREF + ) + + self.assertEqual(annotations, {"x": int}) + + # We manually set the __orig_class__ for this special-case, check this too. + self.assertEqual(annotations.__orig_class__, Annotate[str, type]) + + def test_callable_typing_generic_class_annotate_forwardref_fakeglobals(self): + # Normally, generics are 'typing._GenericAlias' objects. These are implemented + # in Python with a __call__ method (in _typing.BaseGenericAlias), but this + # needs to be bypassed so we can inject fake globals into the origin class' + # __init__ method. + class Annotate[T]: + def __init__(self, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + self.data = {'x': str} + elif format == __Format.VALUE_WITH_FAKE_GLOBALS: + self.data = {"x": int} + else: + raise __NotImplementedError(format) + def __getitem__(self, item): + return self.data[item] + def __iter__(self): + return iter(self.data) + def __len__(self): + return len(self.data) + def __getattr__(self, attr): + val = getattr(collections.abc.Mapping, attr) + if isinstance(val, types.FunctionType): + return types.MethodType(val, self) + return val + def __eq__(self, other): + return dict(self.items()) == dict(other.items()) + + annotations = annotationlib.call_annotate_function( + Annotate[int], + Format.FORWARDREF, + ) + + self.assertEqual(annotations, {"x": int}) + + # We manually set the __orig_class__ for this special-case, check this too. + self.assertEqual(annotations.__orig_class__, Annotate[int]) + def test_user_annotate_forwardref_value_fallback(self): # If Format.FORWARDREF and Format.VALUE_WITH_FAKE_GLOBALS are not supported # use Format.VALUE @@ -1545,6 +1675,50 @@ def annotate(format, /, __Format=Format, __NotImplementedError=NotImplementedErr self.assertEqual(annotations, {"x": "int"}) + def test_callable_generic_class_annotate_string_fakeglobals(self): + # If a generic class uses slots, we may not be able to set + # its __orig_class__ attr. + class Annotate[T]: + __slots__ = "data", + + # If Format.STRING is not supported but Format.VALUE_WITH_FAKE_GLOBALS is, + # prefer that over Format.VALUE + def __init__(self, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + self.data = {"x": str} + elif format == __Format.VALUE_WITH_FAKE_GLOBALS: + self.data = {"x": int} + else: + raise __NotImplementedError(format) + def __getitem__(self, item): + return self.data[item] + def __iter__(self): + return iter(self.data) + def __len__(self): + return len(self.data) + def __getattr__(self, attr): + val = getattr(collections.abc.Mapping, attr) + if isinstance(val, types.FunctionType): + return types.MethodType(val, self) + return val + def __eq__(self, other): + return dict(self.items()) == dict(other.items()) + + # Subscripting a user-created class will usually return a typing._GenericAlias. + # We want to check that types.GenericAlias objects are still interpreted properly, + # so manually create it with the documented constructor. + annotations = annotationlib.call_annotate_function( + types.GenericAlias(Annotate, (int,)), + Format.STRING, + ) + + self.assertEqual(annotations, {"x": "int"}) + + # A __slots__ class can't have __orig_class__ set unless already specified. + # Ensure that the error passes silently, as is the case in typing. + self.assertNotHasAttr(annotations, "__orig_class__") + def test_user_annotate_string_value_fallback(self): # If Format.STRING and Format.VALUE_WITH_FAKE_GLOBALS are not # supported fall back to Format.VALUE and convert to strings @@ -1561,6 +1735,328 @@ def annotate(format, /, __Format=Format, __NotImplementedError=NotImplementedErr self.assertEqual(annotations, {"x": "str"}) + def test_callable_object_annotate(self): + class Annotate: + def __call__(self, format, /): + return {"x": str} + + # Check that all formats work with a standard callable object as an + # annotate function. + for fmt in [Format.VALUE, Format.FORWARDREF, Format.STRING]: + self.assertEqual( + annotationlib.call_annotate_function(Annotate(), format=fmt), + {"x": str} + ) + + def test_callable_method_annotate_forwardref_value_fallback(self): + # Calling a method requires call_annotate_function() to add the self param. + class Annotate: + def format(self, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": str} + else: + raise __NotImplementedError(format) + + annotations = annotationlib.call_annotate_function( + Annotate().format, + Format.FORWARDREF, + ) + + self.assertEqual(annotations, {"x": str}) + + def test_callable_object_annotate_forwardref_value_fallback(self): + # Calling an object is special-cased in call_annotate_function() + # to call its __call__ method. + class Annotate: + def __call__(self, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": str} + else: + raise __NotImplementedError(format) + + annotations = annotationlib.call_annotate_function( + Annotate(), + Format.FORWARDREF, + ) + + self.assertEqual(annotations, {"x": str}) + + def test_callable_custom_method_annotate_forwardref_value_fallback(self): + class Annotate(dict): + def __init__(inst, self, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + super().__init__({"x": str}) + else: + raise __NotImplementedError(format) + + # This wouldn't happen on a normal class, but it's technically legal. + # Ensure that methods (which are special-cased) can wrap class construction + # (which is also special-cased). + custom_method = types.MethodType(Annotate, Annotate) + + annotations = annotationlib.call_annotate_function( + custom_method, + Format.FORWARDREF, + ) + + self.assertEqual(annotations, {"x": str}) + + def test_callable_classmethod_annotate_forwardref_value_fallback(self): + # @classmethod returns a descriptor to a method. + # Ensure that the class itself is correctly bound to the cls param. + class Annotate: + @classmethod + def format(cls, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE and cls is Annotate: + return {"x": str} + else: + raise __NotImplementedError(format) + + annotations = annotationlib.call_annotate_function( + Annotate.format, + Format.FORWARDREF, + ) + + self.assertEqual(annotations, {"x": str}) + + def test_callable_staticmethod_annotate_forwardref_value_fallback(self): + # @staticmethod returns a descriptor which means that Annotate.format + # should be a normal function object. + class Annotate: + @staticmethod + def format(format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": str} + else: + raise __NotImplementedError(format) + + annotations = annotationlib.call_annotate_function( + Annotate.format, + Format.FORWARDREF, + ) + self.assertEqual(annotations, {"x": str}) + + # But if we access via __dict__, the underlying staticmethod object is returned. + # Ensure that call_annotate_function() can handle this special case. + annotations = annotationlib.call_annotate_function( + Annotate.__dict__["format"], + Format.FORWARDREF, + ) + self.assertEqual(annotations, {"x": str}) + + + def test_callable_object_custom_call_annotate_forwardref_value_fallback(self): + class AnnotateClass(dict): + def __init__(self, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + super().__init__({"x": int}) + else: + raise __NotImplementedError(format) + + class Annotate: + # In this case, calling the instance returns a callable class, instead of + # the usual method. + __call__ = AnnotateClass + + annotations = annotationlib.call_annotate_function( + Annotate(), + Format.FORWARDREF, + ) + + self.assertEqual(annotations, {"x": int}) + + def test_callable_partial_annotate_forwardref_value_fallback(self): + # functools.partial is implemented in C. Ensure that the annotate function + # is extracted and called correctly, particularly with Placeholder args. + def format(format, second, /, *, third, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": format * second * third} + else: + raise __NotImplementedError(format) + + annotations = annotationlib.call_annotate_function( + functools.partial(format, functools.Placeholder, 5, third=6), + Format.FORWARDREF, + ) + + self.assertEqual(annotations, {"x": Format.VALUE * 5 * 6}) + + def test_callable_partialmethod_annotate_forwardref_value_fallback(self): + # partialmethod is a Python wrapper around functools.partial, + # ensure that self is passed in and partial works as usual. + class Annotate: + def _internal_format(self, format, second, /, *, third, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": format * second * third} + else: + raise __NotImplementedError(format) + + format = functools.partialmethod( + _internal_format, + functools.Placeholder, + 5, + third=6 + ) + + annotations = annotationlib.call_annotate_function( + Annotate().format, + Format.FORWARDREF, + ) + + self.assertEqual(annotations, {"x": Format.VALUE * 5 * 6}) + + def test_callable_cache_annotate_forwardref_value_fallback(self): + # lru cache is a C wrapper around functions, ensure that the underlying + # function is accessed correctly. + @functools.cache + def format(format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": random.random(), "y": str} + else: + raise __NotImplementedError(format) + + annotations = annotationlib.call_annotate_function( + format, + Format.FORWARDREF, + ) + + self.assertIsInstance(annotations, dict) + self.assertIn("x", annotations) + self.assertIsInstance(annotations["x"], float) + self.assertIs(annotations["y"], str) + + # Check annotations again to ensure that the result is still cached. + new_anns = annotationlib.call_annotate_function(format, Format.FORWARDREF) + self.assertEqual(annotations, new_anns) + + def test_callable_double_wrapped_annotate_forwardref_value_fallback(self): + # The raw staticmethod object returns a 'wrapped' function, and so does + # @functools.cache. Here we test that functions unwrap recursively, + # allowing annotate functions which wrap already wrapped functions. + class Annotate: + @staticmethod + @functools.cache + def format(format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": random.random(), "y": str} + else: + raise __NotImplementedError(format) + + # Access the raw staticmethod object which wraps the cached function. + annotations = annotationlib.call_annotate_function( + Annotate.__dict__["format"], + Format.FORWARDREF, + ) + + self.assertIsInstance(annotations, dict) + self.assertIn("x", annotations) + self.assertIsInstance(annotations["x"], float) + self.assertIs(annotations["y"], str) + + # Check annotations again to ensure that the result is still cached. + new_anns = annotationlib.call_annotate_function(Annotate.format, Format.FORWARDREF) + self.assertEqual(annotations, new_anns) + + def test_callable_wrapped_annotate_forwardref_value_fallback(self): + # Test unwrapping of @functools.wraps functions, similar to @functools.cache. + def multiple_format(fn): + inputs = {"x": int} + @functools.wraps(fn) + def format(format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {**inputs, **fn()} + else: + raise __NotImplementedError(format) + + return format + + annotations = annotationlib.call_annotate_function( + multiple_format(lambda: {"y": str}), + Format.FORWARDREF, + ) + + self.assertEqual(annotations, {"x": int, "y": str}) + + def test_callable_singledispatch_annotate_forwardref_value_fallback(self): + # Ensure that the correct singledispatch function is used when calling + # a singledispatch annotate function. + @functools.singledispatch + def format(format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": str} + else: + raise __NotImplementedError(format) + + @format.register(float) + def _(format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": float} + else: + raise __NotImplementedError(format) + + @format.register(int) + def _(format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": int} + else: + raise __NotImplementedError(format) + + annotations = annotationlib.call_annotate_function( + format, + Format.FORWARDREF, + ) + + self.assertEqual(annotations, {"x": int}) + + def test_callable_singledispatchmethod_annotate_forwardref_value_fallback(self): + # Ensure that the correct singledispatch method is used, along with the self + # parameter when calling a singledispatchmethod annotate function. + class Annotate: + @functools.singledispatchmethod + def format(self, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": str} + else: + raise __NotImplementedError(format) + + @format.register(float) + def _(self, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": float} + else: + raise __NotImplementedError(format) + + @format.register(int) + def _(self, format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": int} + else: + raise __NotImplementedError(format) + + annotations = annotationlib.call_annotate_function( + Annotate().format, + Format.FORWARDREF, + ) + + self.assertEqual(annotations, {"x": int}) + def test_condition_not_stringified(self): # Make sure the first condition isn't evaluated as True by being converted # to a _Stringifier @@ -1606,6 +2102,48 @@ def annotate(format, /): with self.assertRaises(DemoException): annotationlib.call_annotate_function(annotate, format=fmt) + def test_callable_object_error_from_value_raised(self): + # Test that the error from format.VALUE is raised + # if all formats fail + + class DemoException(Exception): ... + + class Annotate: + def __call__(self, format, /): + if format == Format.VALUE: + raise DemoException() + else: + raise NotImplementedError(format) + + for fmt in [Format.VALUE, Format.FORWARDREF, Format.STRING]: + with self.assertRaises(DemoException): + annotationlib.call_annotate_function(Annotate(), format=fmt) + + def test_unsupported_callable_object_fakeglobals_error(self): + # Test that a readable error is raised when an unsupported callable + # type is used as an annotate function with fake globals. + + def annotate(format, /, __Format=Format, + __NotImplementedError=NotImplementedError): + if format == __Format.VALUE: + return {"x": int} + elif format == __Format.VALUE_WITH_FAKE_GLOBALS: + return {"x": str} + else: + raise __NotImplementedError(format) + + annotations = annotationlib.call_annotate_function( + annotate.__call__, + Format.VALUE + ) + self.assertEqual(annotations, {"x": int}) + + for fmt in (Format.FORWARDREF, Format.STRING): + with self.assertRaisesRegex( + TypeError, "annotate function missing '__code__' attribute" + ): + annotationlib.call_annotate_function(annotate.__call__, fmt) + class MetaclassTests(unittest.TestCase): def test_annotated_meta(self): diff --git a/Misc/NEWS.d/next/Library/2025-11-18-16-46-58.gh-issue-141388.V5UBkb.rst b/Misc/NEWS.d/next/Library/2025-11-18-16-46-58.gh-issue-141388.V5UBkb.rst new file mode 100644 index 00000000000000..8f895f1facbe37 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-11-18-16-46-58.gh-issue-141388.V5UBkb.rst @@ -0,0 +1 @@ +Support arbitrary callables as annotate functions.