From f5704da25d3b998ea2a1e44cacad064480d03c0e Mon Sep 17 00:00:00 2001 From: Harshil <37377066+harshil21@users.noreply.github.com> Date: Mon, 5 Feb 2024 05:26:41 -0500 Subject: [PATCH] Drastically simplify `check_param_type` This was achieved by resolving ForwardRefs and changing the method of how type annotations are compared with the official API --- tests/test_official/arg_type_checker.py | 227 ++++++++++-------------- tests/test_official/exceptions.py | 21 ++- tests/test_official/helpers.py | 46 ++++- tests/test_official/scraper.py | 8 +- tests/test_official/test_official.py | 42 +++-- 5 files changed, 172 insertions(+), 172 deletions(-) diff --git a/tests/test_official/arg_type_checker.py b/tests/test_official/arg_type_checker.py index f1dac7401b3..2ccd7808cb5 100644 --- a/tests/test_official/arg_type_checker.py +++ b/tests/test_official/arg_type_checker.py @@ -25,17 +25,25 @@ import re from datetime import datetime from types import FunctionType -from typing import Any, ForwardRef, Sequence, get_args, get_origin +from typing import Any, Sequence -import telegram from telegram._utils.defaultvalue import DefaultValue from telegram._utils.types import FileInput, ODVInput from telegram.ext import Defaults from tests.test_official.exceptions import ParamTypeCheckingExceptions as PTCE from tests.test_official.exceptions import ignored_param_requirements -from tests.test_official.helpers import _extract_words, _get_params_base, _unionizer +from tests.test_official.helpers import ( + _extract_words, + _get_params_base, + _unionizer, + cached_type_hints, + resolve_forward_refs_in_type, + wrap_with_none, +) from tests.test_official.scraper import TelegramParameter +ARRAY_OF_PATTERN = r"Array of(?: Array of)? ([\w\,\s]*)" + # In order to evaluate the type annotation, we need to first have a mapping of the types # specified in the official API to our types. The keys are types in the column of official API. TYPE_MAPPING: dict[str, set[Any]] = { @@ -45,12 +53,22 @@ r"Boolean|True": {bool}, r"Float(?: number)?": {float}, # Distinguishing 1D and 2D Sequences and finding the inner type is done later. - r"Array of (?:Array of )?[\w\,\s]*": {Sequence}, - r"InputFile(?: or String)?": {FileInput}, + ARRAY_OF_PATTERN: {Sequence}, + r"InputFile(?: or String)?": {resolve_forward_refs_in_type(FileInput)}, } ALL_DEFAULTS = inspect.getmembers(Defaults, lambda x: isinstance(x, property)) +DATETIME_REGEX = re.compile( + r"""([_]+|\b) # check for word boundary or underscore + date # check for "date" + [^\w]*\b # optionally check for a word after 'date' + """, + re.VERBOSE, +) + +log = logging.debug + def check_required_param( tg_param: TelegramParameter, param: inspect.Parameter, method_or_obj_name: str @@ -73,8 +91,10 @@ def check_defaults_type(ptb_param: inspect.Parameter) -> bool: def check_param_type( - ptb_param: inspect.Parameter, tg_parameter: TelegramParameter, obj: FunctionType | type -) -> bool: + ptb_param: inspect.Parameter, + tg_parameter: TelegramParameter, + obj: FunctionType | type, +) -> tuple[bool, type]: """This function checks whether the type annotation of the parameter is the same as the one specified in the official API. It also checks for some special cases where we accept more types @@ -84,177 +104,122 @@ def check_param_type( obj: The object (method/class) that we are checking. Returns: - :obj:`bool`: The boolean returned represents whether our parameter's type annotation is the - same as Telegram's or not. + :obj:`tuple`: A tuple containing: + * :obj:`bool`: The boolean returned represents whether our parameter's type annotation + is the same as Telegram's or not. + * :obj:`type`: The expected type annotation of the parameter. """ # PRE-PROCESSING: - # In order to evaluate the type annotation, we need to first have a mapping of the types - # (see TYPE_MAPPING comment defined above) tg_param_type: str = tg_parameter.param_type is_class = inspect.isclass(obj) + ptb_annotation = cached_type_hints(obj, is_class).get(ptb_param.name) + # Let's check for a match: + # In order to evaluate the type annotation, we need to first have a mapping of the types + # (see TYPE_MAPPING comment defined at the top level of this module) mapped: set[type] = _get_params_base(tg_param_type, TYPE_MAPPING) # We should have a maximum of one match. assert len(mapped) <= 1, f"More than one match found for {tg_param_type}" - if not mapped: # no match found, it's from telegram module - # it could be a list of objects, so let's check that: - objs = _extract_words(tg_param_type) - # We want to store both string version of class and the class obj itself. e.g. "InputMedia" - # and InputMedia because some annotations might be ForwardRefs. - if len(objs) >= 2: # We have to unionize the objects - mapped_type: tuple[Any, ...] = (_unionizer(objs, False), _unionizer(objs, True)) - else: - mapped_type = ( - getattr(telegram, tg_param_type), # This will fail if it's not from telegram mod - ForwardRef(tg_param_type), - tg_param_type, # for some reason, some annotations are just a string. - ) - elif len(mapped) == 1: - mapped_type = mapped.pop() - - # Resolve nested annotations to get inner types. - if (ptb_annotation := list(get_args(ptb_param.annotation))) == []: - ptb_annotation = ptb_param.annotation # if it's not nested, just use the annotation - - if isinstance(ptb_annotation, list): - # Some cleaning: - # Remove 'Optional[...]' from the annotation if it's present. We do it this way since: 1) - # we already check if argument should be optional or not + type checkers will complain. - # 2) we want to check if our `obj` is same as API's `obj`, and since python evaluates - # `Optional[obj] != obj` we have to remove the Optional, so that we can compare the two. - if type(None) in ptb_annotation: - ptb_annotation.remove(type(None)) - - # Cleaning done... now let's put it back together. - # Join all the annotations back (i.e. Union) - ptb_annotation = _unionizer(ptb_annotation, False) + # it may be a list of objects, so let's extract them using _extract_words: + mapped_type = _unionizer(_extract_words(tg_param_type)) if not mapped else mapped.pop() + # If the parameter is not required by TG, `None` should be added to `mapped_type` + mapped_type = wrap_with_none(tg_parameter, mapped_type, obj) - # Last step, we need to use get_origin to get the original type, since using get_args - # above will strip that out. - wrapped = get_origin(ptb_param.annotation) - if wrapped is not None: - # collections.abc.Sequence -> typing.Sequence - if "collections.abc.Sequence" in str(wrapped): - wrapped = Sequence - ptb_annotation = wrapped[ptb_annotation] - # We have put back our annotation together after removing the NoneType! - - logging.debug( + log( "At the end of PRE-PROCESSING, the values of variables are:\n" "Parameter name: %s\n" "ptb_annotation= %s\n" "mapped_type= %s\n" - "tg_param_type= %s\n", + "tg_param_type= %s\n" + "tg_parameter.param_required= %s\n", ptb_param.name, ptb_annotation, mapped_type, tg_param_type, + tg_parameter.param_required, ) # CHECKING: - # Each branch may have exits in the form of return statements. If the annotation is found to be - # correct, the function will return True. If not, it will return False. + # Each branch manipulates the `mapped_type` (except for 4) ) to match the `ptb_annotation`. # 1) HANDLING ARRAY TYPES: # Now let's do the checking, starting with "Array of ..." types. if "Array of " in tg_param_type: - logging.debug("Array of type found in `%s`\n", tg_param_type) - assert mapped_type is Sequence # For exceptions just check if they contain the annotation if ptb_param.name in PTCE.ARRAY_OF_EXCEPTIONS: - return PTCE.ARRAY_OF_EXCEPTIONS[ptb_param.name] in str(ptb_annotation) + return PTCE.ARRAY_OF_EXCEPTIONS[ptb_param.name] in str(ptb_annotation), Sequence - pattern = r"Array of(?: Array of)? ([\w\,\s]*)" - obj_match: re.Match | None = re.search(pattern, tg_param_type) # extract obj from string + obj_match: re.Match | None = re.search(ARRAY_OF_PATTERN, tg_param_type) if obj_match is None: raise AssertionError(f"Array of {tg_param_type} not found in {ptb_param.name}") obj_str: str = obj_match.group(1) # is obj a regular type like str? - array_of_mapped: set[type] = _get_params_base(obj_str, TYPE_MAPPING) + array_map: set[type] = _get_params_base(obj_str, TYPE_MAPPING) - if len(array_of_mapped) == 0: # no match found, it's from telegram module - # it could be a list of objects, so let's check that: - objs = _extract_words(obj_str) - # let's unionize all the objects, with and without ForwardRefs. - unionized_objs: list[type] = [_unionizer(objs, True), _unionizer(objs, False)] - else: - unionized_objs = [array_of_mapped.pop()] + mapped_type = _unionizer(_extract_words(obj_str)) if not array_map else array_map.pop() - # This means it is Array of Array of [obj] if "Array of Array of" in tg_param_type: - return any(Sequence[Sequence[o]] == ptb_annotation for o in unionized_objs) - - # This means it is Array of [obj] - return any(mapped_type[o] == ptb_annotation for o in unionized_objs) - - # 2) HANDLING DEFAULTS PARAMETERS: - # Classes whose parameters are all ODVInput should be converted and checked. - if obj.__name__ in PTCE.IGNORED_DEFAULTS_CLASSES: - logging.debug("Checking that `%s`'s param is ODVInput:\n", obj.__name__) - parsed = ODVInput[mapped_type] - return (ptb_annotation | None) == parsed # We have to add back None in our annotation - if not ( - # Defaults checking should not be done for: - # 1. Parameters that have name conflict with `Defaults.name` - is_class - and obj.__name__ in ("ReplyParameters", "Message", "ExternalReplyInfo") - and ptb_param.name in PTCE.IGNORED_DEFAULTS_PARAM_NAMES - ): - # Now let's check if the parameter is a Defaults parameter, it should be - for name, _ in ALL_DEFAULTS: - if name == ptb_param.name or "parse_mode" in ptb_param.name: - logging.debug("Checking that `%s` is a Defaults parameter!\n", ptb_param.name) - # mapped_type should not be a tuple since we need to check for equality: - # This can happen when the Defaults parameter is a class, e.g. LinkPreviewOptions - if isinstance(mapped_type, tuple): - mapped_type = mapped_type[1] # We select the ForwardRef - # Assert if it's ODVInput by checking equality: - parsed = ODVInput[mapped_type] - if (ptb_annotation | None) == parsed: # We have to add back None in our annotation - return True - return False + log("Array of Array of type found in `%s`\n", tg_param_type) + mapped_type = Sequence[Sequence[mapped_type]] + else: + log("Array of type found in `%s`\n", tg_param_type) + mapped_type = Sequence[mapped_type] - # 3) HANDLING OTHER TYPES: + # 2) HANDLING OTHER TYPES: # Special case for send_* methods where we accept more types than the official API: - if ( - ptb_param.name in PTCE.ADDITIONAL_TYPES - and not isinstance(mapped_type, tuple) - and obj.__name__.startswith("send") - ): - logging.debug("Checking that `%s` has an additional argument!\n", ptb_param.name) + elif ptb_param.name in PTCE.ADDITIONAL_TYPES and obj.__name__.startswith("send"): + log("Checking that `%s` has an additional argument!\n", ptb_param.name) mapped_type = mapped_type | PTCE.ADDITIONAL_TYPES[ptb_param.name] - # 4) HANDLING DATETIMES: - if ( + # 3) HANDLING DATETIMES: + elif ( re.search( - r"""([_]+|\b) # check for word boundary or underscore - date # check for "date" - [^\w]*\b # optionally check for a word after 'date' - """, + DATETIME_REGEX, ptb_param.name, - re.VERBOSE, ) or "Unix time" in tg_parameter.param_description ): - logging.debug("Checking that `%s` is a datetime!\n", ptb_param.name) + log("Checking that `%s` is a datetime!\n", ptb_param.name) if ptb_param.name in PTCE.DATETIME_EXCEPTIONS: - return True + return True, mapped_type # If it's a class, we only accept datetime as the parameter mapped_type = datetime if is_class else mapped_type | datetime - # RESULTS: ALL OTHER BASIC TYPES- - # Some types are too complicated, so we replace them with a simpler type: - for (param_name, expected_class), exception_type in PTCE.COMPLEX_TYPES.items(): - if ptb_param.name == param_name and is_class is expected_class: - logging.debug("Converting `%s` to a simpler type!\n", ptb_param.name) - ptb_annotation = exception_type + # 4) COMPLEX TYPES: + # Some types are too complicated, so we replace our annotation with a simpler type: + elif any(ptb_param.name in key for key in PTCE.COMPLEX_TYPES): + log("Converting `%s` to a simpler type!\n", ptb_param.name) + for (param_name, is_expected_class), exception_type in PTCE.COMPLEX_TYPES.items(): + if ptb_param.name == param_name and is_class is is_expected_class: + ptb_annotation = wrap_with_none(tg_parameter, exception_type, obj) - # Final check, if the annotation is a tuple, we need to check if any of the types in the tuple - # match the mapped type. - if isinstance(mapped_type, tuple) and any(ptb_annotation == t for t in mapped_type): - return True - - # If the annotation is not a tuple, we can just check if it's equal to the mapped type. - return mapped_type == ptb_annotation + # 5) HANDLING DEFAULTS PARAMETERS: + # Classes whose parameters are all ODVInput should be converted and checked. + elif obj.__name__ in PTCE.IGNORED_DEFAULTS_CLASSES: + log("Checking that `%s`'s param is ODVInput:\n", obj.__name__) + mapped_type = ODVInput[mapped_type] + elif not ( + # Defaults checking should not be done for: + # 1. Parameters that have name conflict with `Defaults.name` + is_class + and obj.__name__ in ("ReplyParameters", "Message", "ExternalReplyInfo") + and ptb_param.name in PTCE.IGNORED_DEFAULTS_PARAM_NAMES + ): + # Now let's check if the parameter is a Defaults parameter, it should be + for name, _ in ALL_DEFAULTS: + if name == ptb_param.name or "parse_mode" in ptb_param.name: + log("Checking that `%s` is a Defaults parameter!\n", ptb_param.name) + mapped_type = ODVInput[mapped_type] + break + + # RESULTS:- + mapped_type = wrap_with_none(tg_parameter, mapped_type, obj) + mapped_type = resolve_forward_refs_in_type(mapped_type) + log( + "At RESULTS, we are comparing:\nptb_annotation= %s\nmapped_type= %s\n", + ptb_annotation, + mapped_type, + ) + return mapped_type == ptb_annotation, mapped_type diff --git a/tests/test_official/exceptions.py b/tests/test_official/exceptions.py index abf713f3bce..e554b0888ad 100644 --- a/tests/test_official/exceptions.py +++ b/tests/test_official/exceptions.py @@ -19,8 +19,7 @@ """This module contains exceptions to our API compared to the official API.""" -from typing import ForwardRef - +from telegram import Animation, Audio, Document, PhotoSize, Sticker, Video, VideoNote, Voice from tests.test_official.helpers import _get_params_base IGNORED_OBJECTS = ("ResponseParameters",) @@ -38,14 +37,14 @@ class ParamTypeCheckingExceptions: # Types for certain parameters accepted by PTB but not in the official API ADDITIONAL_TYPES = { - "photo": ForwardRef("PhotoSize"), - "video": ForwardRef("Video"), - "video_note": ForwardRef("VideoNote"), - "audio": ForwardRef("Audio"), - "document": ForwardRef("Document"), - "animation": ForwardRef("Animation"), - "voice": ForwardRef("Voice"), - "sticker": ForwardRef("Sticker"), + "photo": PhotoSize, + "video": Video, + "video_note": VideoNote, + "audio": Audio, + "document": Document, + "animation": Animation, + "voice": Voice, + "sticker": Sticker, } # Exceptions to the "Array of" types, where we accept more types than the official API @@ -56,7 +55,7 @@ class ParamTypeCheckingExceptions: "keyboard": "KeyboardButton", # + sequence[sequence[str]] "reaction": "ReactionType", # + str # TODO: Deprecated and will be corrected (and removed) in next major PTB version: - "file_hashes": "list[str]", + "file_hashes": "List[str]", } # Special cases for other parameters that accept more types than the official API, and are diff --git a/tests/test_official/helpers.py b/tests/test_official/helpers.py index 2620e6c2e0a..6851bf85fa2 100644 --- a/tests/test_official/helpers.py +++ b/tests/test_official/helpers.py @@ -18,12 +18,23 @@ # along with this program. If not, see [http://www.gnu.org/licenses/]. """This module contains helper functions for the official API tests used in the other modules.""" +import functools import re -from typing import Any, ForwardRef, Sequence +from typing import TYPE_CHECKING, Any, Sequence, _eval_type, get_type_hints from bs4 import PageElement, Tag import telegram +import telegram._utils.defaultvalue +import telegram._utils.types + +if TYPE_CHECKING: + from tests.test_official.scraper import TelegramParameter + + +tg_objects = vars(telegram) +tg_objects.update(vars(telegram._utils.types)) +tg_objects.update(vars(telegram._utils.defaultvalue)) def _get_params_base(object_name: str, search_dict: dict[str, set[Any]]) -> set[Any]: @@ -48,14 +59,11 @@ def _extract_words(text: str) -> set[str]: return set(re.sub(r"[^\w\s]", "", text).split()) - {"and", "or"} -def _unionizer(annotation: Sequence[Any] | set[Any], forward_ref: bool) -> Any: - """Returns a union of all the types in the annotation. If forward_ref is True, it wraps the - annotation in a ForwardRef and then unionizes.""" +def _unionizer(annotation: Sequence[Any] | set[Any]) -> Any: + """Returns a union of all the types in the annotation. Also imports objects from lib.""" union = None for t in annotation: - if forward_ref: - t = ForwardRef(t) # noqa: PLW2901 - elif not forward_ref and isinstance(t, str): # we have to import objects from lib + if isinstance(t, str): # we have to import objects from lib t = getattr(telegram, t) # noqa: PLW2901 union = t if union is None else union | t return union @@ -71,7 +79,7 @@ def find_next_sibling_until(tag: Tag, name: str, until: Tag) -> PageElement | No def is_pascal_case(s): - # Check if the string starts with a capital letter and contains only alphanumeric characters + "PascalCase. Starts with a capital letter and has no spaces. Useful for identifying classes." return bool(re.match(r"^[A-Z][a-zA-Z\d]*$", s)) @@ -79,3 +87,25 @@ def is_parameter_required_by_tg(field: str) -> bool: if field in {"Required", "Yes"}: return True return field.split(".", 1)[0] != "Optional" # splits the sentence and extracts first word + + +def wrap_with_none(tg_parameter: "TelegramParameter", mapped_type: Any, obj: object) -> type: + """Adds `None` to type annotation if the parameter isn't required. Respects ignored params.""" + # have to import here to avoid circular imports + from tests.test_official.exceptions import ignored_param_requirements + + if tg_parameter.param_name in ignored_param_requirements(obj.__name__): + return mapped_type | type(None) + return mapped_type | type(None) if not tg_parameter.param_required else mapped_type + + +@functools.cache +def cached_type_hints(obj: Any, is_class: bool) -> dict[str, Any]: + """Returns type hints of a class, method, or function, with forward refs evaluated.""" + return get_type_hints(obj.__init__ if is_class else obj, localns=tg_objects) + + +@functools.cache +def resolve_forward_refs_in_type(obj: type) -> type: + """Resolves forward references in a type hint.""" + return _eval_type(obj, localns=tg_objects, globalns=None) diff --git a/tests/test_official/scraper.py b/tests/test_official/scraper.py index 13ffa029618..1da83a87a90 100644 --- a/tests/test_official/scraper.py +++ b/tests/test_official/scraper.py @@ -75,12 +75,12 @@ async def make_request(self) -> None: self.soup = BeautifulSoup(self.request.text, "html.parser") @overload - def parse_docs(self, doc_type: Literal["method"]) -> tuple[list[TelegramMethod], list[str]]: - ... + def parse_docs( + self, doc_type: Literal["method"] + ) -> tuple[list[TelegramMethod], list[str]]: ... @overload - def parse_docs(self, doc_type: Literal["class"]) -> tuple[list[TelegramClass], list[str]]: - ... + def parse_docs(self, doc_type: Literal["class"]) -> tuple[list[TelegramClass], list[str]]: ... def parse_docs(self, doc_type): argvalues = [] diff --git a/tests/test_official/test_official.py b/tests/test_official/test_official.py index a4e36f2523c..5ad1d8b5686 100644 --- a/tests/test_official/test_official.py +++ b/tests/test_official/test_official.py @@ -56,9 +56,9 @@ def test_check_method(tg_method: TelegramMethod) -> None: - Method existence - Parameter existence + - Parameter requirement correctness - Parameter type annotation existence - Parameter type annotation correctness - - Parameter requirement correctness - Parameter default value correctness - No unexpected parameters - Extra parameters should be keyword only @@ -77,20 +77,25 @@ def test_check_method(tg_method: TelegramMethod) -> None: ptb_param is not None ), f"Parameter {tg_parameter.param_name} not found in {ptb_method.__name__}" + # Now check if the parameter is required or not + assert check_required_param( + tg_parameter, ptb_param, ptb_method.__name__ + ), f"Param {ptb_param.name!r} of {ptb_method.__name__!r} requirement mismatch" + # Check if type annotation is present assert ( ptb_param.annotation is not inspect.Parameter.empty ), f"Param {ptb_param.name!r} of {ptb_method.__name__!r} should have a type annotation!" # Check if type annotation is correct - assert check_param_type(ptb_param, tg_parameter, ptb_method), ( - f"Param {ptb_param.name!r} of {ptb_method.__name__!r} should be " - f"{tg_parameter.param_type} or something else!" + correct_type_hint, expected_type_hint = check_param_type( + ptb_param, + tg_parameter, + ptb_method, + ) + assert correct_type_hint, ( + f"Type hint of param {ptb_param.name!r} of {ptb_method.__name__!r} should be " + f"{expected_type_hint!r} or something else!" ) - - # Now check if the parameter is required or not - assert check_required_param( - tg_parameter, ptb_param, ptb_method.__name__ - ), f"Param {ptb_param.name!r} of method {ptb_method.__name__!r} requirement mismatch!" # Now we will check that we don't pass default values if the parameter is not required. if ptb_param.default is not inspect.Parameter.empty: # If there is a default argument... @@ -126,9 +131,9 @@ def test_check_object(tg_class: TelegramClass) -> None: - Class existence - Parameter existence + - Parameter requirement correctness - Parameter type annotation existence - Parameter type annotation correctness - - Parameter requirement correctness - Parameter default value correctness - No unexpected parameters """ @@ -152,22 +157,23 @@ def test_check_object(tg_class: TelegramClass) -> None: ptb_param = sig.parameters.get(field) assert ptb_param is not None, f"Attribute {field} not found in {obj.__name__}" + # Now check if the parameter is required or not + assert check_required_param( + tg_parameter, ptb_param, obj.__name__ + ), f"Param {ptb_param.name!r} of {obj.__name__!r} requirement mismatch" + # Check if type annotation is present assert ( ptb_param.annotation is not inspect.Parameter.empty ), f"Param {ptb_param.name!r} of {obj.__name__!r} should have a type annotation" # Check if type annotation is correct - assert check_param_type(ptb_param, tg_parameter, obj), ( - f"Param {ptb_param.name!r} of {obj.__name__!r} should be {tg_parameter.param_type} or " - "something else" + correct_type_hint, expected_type_hint = check_param_type(ptb_param, tg_parameter, obj) + assert correct_type_hint, ( + f"Type hint of param {ptb_param.name!r} of {obj.__name__!r} should be " + f"{expected_type_hint!r} or something else!" ) - # Now check if the parameter is required or not - assert check_required_param( - tg_parameter, ptb_param, obj.__name__ - ), f"Param {ptb_param.name!r} of {obj.__name__!r} requirement mismatch" - # Now we will check that we don't pass default values if the parameter is not required. if ptb_param.default is not inspect.Parameter.empty: # If there is a default argument... default_arg_none = check_defaults_type(ptb_param) # check if its None