Skip to content

Commit

Permalink
Drastically simplify check_param_type
Browse files Browse the repository at this point in the history
This was achieved by resolving ForwardRefs and changing the method of how type annotations are compared with the official API
  • Loading branch information
harshil21 committed Feb 5, 2024
1 parent 37ac865 commit f5704da
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 172 deletions.
227 changes: 96 additions & 131 deletions tests/test_official/arg_type_checker.py
Expand Up @@ -25,17 +25,25 @@
import re
from datetime import datetime
from types import FunctionType
from typing import Any, ForwardRef, Sequence, get_args, get_origin
from typing import Any, Sequence

import telegram
from telegram._utils.defaultvalue import DefaultValue
from telegram._utils.types import FileInput, ODVInput
from telegram.ext import Defaults
from tests.test_official.exceptions import ParamTypeCheckingExceptions as PTCE
from tests.test_official.exceptions import ignored_param_requirements
from tests.test_official.helpers import _extract_words, _get_params_base, _unionizer
from tests.test_official.helpers import (
_extract_words,
_get_params_base,
_unionizer,
cached_type_hints,
resolve_forward_refs_in_type,
wrap_with_none,
)
from tests.test_official.scraper import TelegramParameter

ARRAY_OF_PATTERN = r"Array of(?: Array of)? ([\w\,\s]*)"

# In order to evaluate the type annotation, we need to first have a mapping of the types
# specified in the official API to our types. The keys are types in the column of official API.
TYPE_MAPPING: dict[str, set[Any]] = {
Expand All @@ -45,12 +53,22 @@
r"Boolean|True": {bool},
r"Float(?: number)?": {float},
# Distinguishing 1D and 2D Sequences and finding the inner type is done later.
r"Array of (?:Array of )?[\w\,\s]*": {Sequence},
r"InputFile(?: or String)?": {FileInput},
ARRAY_OF_PATTERN: {Sequence},
r"InputFile(?: or String)?": {resolve_forward_refs_in_type(FileInput)},
}

ALL_DEFAULTS = inspect.getmembers(Defaults, lambda x: isinstance(x, property))

DATETIME_REGEX = re.compile(
r"""([_]+|\b) # check for word boundary or underscore
date # check for "date"
[^\w]*\b # optionally check for a word after 'date'
""",
re.VERBOSE,
)

log = logging.debug


def check_required_param(
tg_param: TelegramParameter, param: inspect.Parameter, method_or_obj_name: str
Expand All @@ -73,8 +91,10 @@ def check_defaults_type(ptb_param: inspect.Parameter) -> bool:


def check_param_type(
ptb_param: inspect.Parameter, tg_parameter: TelegramParameter, obj: FunctionType | type
) -> bool:
ptb_param: inspect.Parameter,
tg_parameter: TelegramParameter,
obj: FunctionType | type,
) -> tuple[bool, type]:
"""This function checks whether the type annotation of the parameter is the same as the one
specified in the official API. It also checks for some special cases where we accept more types
Expand All @@ -84,177 +104,122 @@ def check_param_type(
obj: The object (method/class) that we are checking.
Returns:
:obj:`bool`: The boolean returned represents whether our parameter's type annotation is the
same as Telegram's or not.
:obj:`tuple`: A tuple containing:
* :obj:`bool`: The boolean returned represents whether our parameter's type annotation
is the same as Telegram's or not.
* :obj:`type`: The expected type annotation of the parameter.
"""
# PRE-PROCESSING:
# In order to evaluate the type annotation, we need to first have a mapping of the types
# (see TYPE_MAPPING comment defined above)
tg_param_type: str = tg_parameter.param_type
is_class = inspect.isclass(obj)
ptb_annotation = cached_type_hints(obj, is_class).get(ptb_param.name)

# Let's check for a match:
# In order to evaluate the type annotation, we need to first have a mapping of the types
# (see TYPE_MAPPING comment defined at the top level of this module)
mapped: set[type] = _get_params_base(tg_param_type, TYPE_MAPPING)

# We should have a maximum of one match.
assert len(mapped) <= 1, f"More than one match found for {tg_param_type}"

if not mapped: # no match found, it's from telegram module
# it could be a list of objects, so let's check that:
objs = _extract_words(tg_param_type)
# We want to store both string version of class and the class obj itself. e.g. "InputMedia"
# and InputMedia because some annotations might be ForwardRefs.
if len(objs) >= 2: # We have to unionize the objects
mapped_type: tuple[Any, ...] = (_unionizer(objs, False), _unionizer(objs, True))
else:
mapped_type = (
getattr(telegram, tg_param_type), # This will fail if it's not from telegram mod
ForwardRef(tg_param_type),
tg_param_type, # for some reason, some annotations are just a string.
)
elif len(mapped) == 1:
mapped_type = mapped.pop()

# Resolve nested annotations to get inner types.
if (ptb_annotation := list(get_args(ptb_param.annotation))) == []:
ptb_annotation = ptb_param.annotation # if it's not nested, just use the annotation

if isinstance(ptb_annotation, list):
# Some cleaning:
# Remove 'Optional[...]' from the annotation if it's present. We do it this way since: 1)
# we already check if argument should be optional or not + type checkers will complain.
# 2) we want to check if our `obj` is same as API's `obj`, and since python evaluates
# `Optional[obj] != obj` we have to remove the Optional, so that we can compare the two.
if type(None) in ptb_annotation:
ptb_annotation.remove(type(None))

# Cleaning done... now let's put it back together.
# Join all the annotations back (i.e. Union)
ptb_annotation = _unionizer(ptb_annotation, False)
# it may be a list of objects, so let's extract them using _extract_words:
mapped_type = _unionizer(_extract_words(tg_param_type)) if not mapped else mapped.pop()
# If the parameter is not required by TG, `None` should be added to `mapped_type`
mapped_type = wrap_with_none(tg_parameter, mapped_type, obj)

# Last step, we need to use get_origin to get the original type, since using get_args
# above will strip that out.
wrapped = get_origin(ptb_param.annotation)
if wrapped is not None:
# collections.abc.Sequence -> typing.Sequence
if "collections.abc.Sequence" in str(wrapped):
wrapped = Sequence
ptb_annotation = wrapped[ptb_annotation]
# We have put back our annotation together after removing the NoneType!

logging.debug(
log(
"At the end of PRE-PROCESSING, the values of variables are:\n"
"Parameter name: %s\n"
"ptb_annotation= %s\n"
"mapped_type= %s\n"
"tg_param_type= %s\n",
"tg_param_type= %s\n"
"tg_parameter.param_required= %s\n",
ptb_param.name,
ptb_annotation,
mapped_type,
tg_param_type,
tg_parameter.param_required,
)

# CHECKING:
# Each branch may have exits in the form of return statements. If the annotation is found to be
# correct, the function will return True. If not, it will return False.
# Each branch manipulates the `mapped_type` (except for 4) ) to match the `ptb_annotation`.

# 1) HANDLING ARRAY TYPES:
# Now let's do the checking, starting with "Array of ..." types.
if "Array of " in tg_param_type:
logging.debug("Array of type found in `%s`\n", tg_param_type)
assert mapped_type is Sequence
# For exceptions just check if they contain the annotation
if ptb_param.name in PTCE.ARRAY_OF_EXCEPTIONS:
return PTCE.ARRAY_OF_EXCEPTIONS[ptb_param.name] in str(ptb_annotation)
return PTCE.ARRAY_OF_EXCEPTIONS[ptb_param.name] in str(ptb_annotation), Sequence

pattern = r"Array of(?: Array of)? ([\w\,\s]*)"
obj_match: re.Match | None = re.search(pattern, tg_param_type) # extract obj from string
obj_match: re.Match | None = re.search(ARRAY_OF_PATTERN, tg_param_type)
if obj_match is None:
raise AssertionError(f"Array of {tg_param_type} not found in {ptb_param.name}")
obj_str: str = obj_match.group(1)
# is obj a regular type like str?
array_of_mapped: set[type] = _get_params_base(obj_str, TYPE_MAPPING)
array_map: set[type] = _get_params_base(obj_str, TYPE_MAPPING)

if len(array_of_mapped) == 0: # no match found, it's from telegram module
# it could be a list of objects, so let's check that:
objs = _extract_words(obj_str)
# let's unionize all the objects, with and without ForwardRefs.
unionized_objs: list[type] = [_unionizer(objs, True), _unionizer(objs, False)]
else:
unionized_objs = [array_of_mapped.pop()]
mapped_type = _unionizer(_extract_words(obj_str)) if not array_map else array_map.pop()

# This means it is Array of Array of [obj]
if "Array of Array of" in tg_param_type:
return any(Sequence[Sequence[o]] == ptb_annotation for o in unionized_objs)

# This means it is Array of [obj]
return any(mapped_type[o] == ptb_annotation for o in unionized_objs)

# 2) HANDLING DEFAULTS PARAMETERS:
# Classes whose parameters are all ODVInput should be converted and checked.
if obj.__name__ in PTCE.IGNORED_DEFAULTS_CLASSES:
logging.debug("Checking that `%s`'s param is ODVInput:\n", obj.__name__)
parsed = ODVInput[mapped_type]
return (ptb_annotation | None) == parsed # We have to add back None in our annotation
if not (
# Defaults checking should not be done for:
# 1. Parameters that have name conflict with `Defaults.name`
is_class
and obj.__name__ in ("ReplyParameters", "Message", "ExternalReplyInfo")
and ptb_param.name in PTCE.IGNORED_DEFAULTS_PARAM_NAMES
):
# Now let's check if the parameter is a Defaults parameter, it should be
for name, _ in ALL_DEFAULTS:
if name == ptb_param.name or "parse_mode" in ptb_param.name:
logging.debug("Checking that `%s` is a Defaults parameter!\n", ptb_param.name)
# mapped_type should not be a tuple since we need to check for equality:
# This can happen when the Defaults parameter is a class, e.g. LinkPreviewOptions
if isinstance(mapped_type, tuple):
mapped_type = mapped_type[1] # We select the ForwardRef
# Assert if it's ODVInput by checking equality:
parsed = ODVInput[mapped_type]
if (ptb_annotation | None) == parsed: # We have to add back None in our annotation
return True
return False
log("Array of Array of type found in `%s`\n", tg_param_type)
mapped_type = Sequence[Sequence[mapped_type]]
else:
log("Array of type found in `%s`\n", tg_param_type)
mapped_type = Sequence[mapped_type]

# 3) HANDLING OTHER TYPES:
# 2) HANDLING OTHER TYPES:
# Special case for send_* methods where we accept more types than the official API:
if (
ptb_param.name in PTCE.ADDITIONAL_TYPES
and not isinstance(mapped_type, tuple)
and obj.__name__.startswith("send")
):
logging.debug("Checking that `%s` has an additional argument!\n", ptb_param.name)
elif ptb_param.name in PTCE.ADDITIONAL_TYPES and obj.__name__.startswith("send"):
log("Checking that `%s` has an additional argument!\n", ptb_param.name)
mapped_type = mapped_type | PTCE.ADDITIONAL_TYPES[ptb_param.name]

# 4) HANDLING DATETIMES:
if (
# 3) HANDLING DATETIMES:
elif (
re.search(
r"""([_]+|\b) # check for word boundary or underscore
date # check for "date"
[^\w]*\b # optionally check for a word after 'date'
""",
DATETIME_REGEX,
ptb_param.name,
re.VERBOSE,
)
or "Unix time" in tg_parameter.param_description
):
logging.debug("Checking that `%s` is a datetime!\n", ptb_param.name)
log("Checking that `%s` is a datetime!\n", ptb_param.name)
if ptb_param.name in PTCE.DATETIME_EXCEPTIONS:
return True
return True, mapped_type
# If it's a class, we only accept datetime as the parameter
mapped_type = datetime if is_class else mapped_type | datetime

# RESULTS: ALL OTHER BASIC TYPES-
# Some types are too complicated, so we replace them with a simpler type:
for (param_name, expected_class), exception_type in PTCE.COMPLEX_TYPES.items():
if ptb_param.name == param_name and is_class is expected_class:
logging.debug("Converting `%s` to a simpler type!\n", ptb_param.name)
ptb_annotation = exception_type
# 4) COMPLEX TYPES:
# Some types are too complicated, so we replace our annotation with a simpler type:
elif any(ptb_param.name in key for key in PTCE.COMPLEX_TYPES):
log("Converting `%s` to a simpler type!\n", ptb_param.name)
for (param_name, is_expected_class), exception_type in PTCE.COMPLEX_TYPES.items():
if ptb_param.name == param_name and is_class is is_expected_class:
ptb_annotation = wrap_with_none(tg_parameter, exception_type, obj)

# Final check, if the annotation is a tuple, we need to check if any of the types in the tuple
# match the mapped type.
if isinstance(mapped_type, tuple) and any(ptb_annotation == t for t in mapped_type):
return True

# If the annotation is not a tuple, we can just check if it's equal to the mapped type.
return mapped_type == ptb_annotation
# 5) HANDLING DEFAULTS PARAMETERS:
# Classes whose parameters are all ODVInput should be converted and checked.
elif obj.__name__ in PTCE.IGNORED_DEFAULTS_CLASSES:
log("Checking that `%s`'s param is ODVInput:\n", obj.__name__)
mapped_type = ODVInput[mapped_type]
elif not (
# Defaults checking should not be done for:
# 1. Parameters that have name conflict with `Defaults.name`
is_class
and obj.__name__ in ("ReplyParameters", "Message", "ExternalReplyInfo")
and ptb_param.name in PTCE.IGNORED_DEFAULTS_PARAM_NAMES
):
# Now let's check if the parameter is a Defaults parameter, it should be
for name, _ in ALL_DEFAULTS:
if name == ptb_param.name or "parse_mode" in ptb_param.name:
log("Checking that `%s` is a Defaults parameter!\n", ptb_param.name)
mapped_type = ODVInput[mapped_type]
break

# RESULTS:-
mapped_type = wrap_with_none(tg_parameter, mapped_type, obj)
mapped_type = resolve_forward_refs_in_type(mapped_type)
log(
"At RESULTS, we are comparing:\nptb_annotation= %s\nmapped_type= %s\n",
ptb_annotation,
mapped_type,
)
return mapped_type == ptb_annotation, mapped_type
21 changes: 10 additions & 11 deletions tests/test_official/exceptions.py
Expand Up @@ -19,8 +19,7 @@
"""This module contains exceptions to our API compared to the official API."""


from typing import ForwardRef

from telegram import Animation, Audio, Document, PhotoSize, Sticker, Video, VideoNote, Voice
from tests.test_official.helpers import _get_params_base

IGNORED_OBJECTS = ("ResponseParameters",)
Expand All @@ -38,14 +37,14 @@
class ParamTypeCheckingExceptions:
# Types for certain parameters accepted by PTB but not in the official API
ADDITIONAL_TYPES = {
"photo": ForwardRef("PhotoSize"),
"video": ForwardRef("Video"),
"video_note": ForwardRef("VideoNote"),
"audio": ForwardRef("Audio"),
"document": ForwardRef("Document"),
"animation": ForwardRef("Animation"),
"voice": ForwardRef("Voice"),
"sticker": ForwardRef("Sticker"),
"photo": PhotoSize,
"video": Video,
"video_note": VideoNote,
"audio": Audio,
"document": Document,
"animation": Animation,
"voice": Voice,
"sticker": Sticker,
}

# Exceptions to the "Array of" types, where we accept more types than the official API
Expand All @@ -56,7 +55,7 @@ class ParamTypeCheckingExceptions:
"keyboard": "KeyboardButton", # + sequence[sequence[str]]
"reaction": "ReactionType", # + str
# TODO: Deprecated and will be corrected (and removed) in next major PTB version:
"file_hashes": "list[str]",
"file_hashes": "List[str]",
}

# Special cases for other parameters that accept more types than the official API, and are
Expand Down

0 comments on commit f5704da

Please sign in to comment.