Skip to content

Commit

Permalink
Fix Bug With Parameter message_thread_id of Message.reply_* (#4207)
Browse files Browse the repository at this point in the history
  • Loading branch information
Bibo-Joshi committed Apr 15, 2024
1 parent 42b68f1 commit fed8d88
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 74 deletions.
65 changes: 38 additions & 27 deletions telegram/_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,7 +1462,7 @@ def build_reply_arguments(
quote_index: Optional[int] = None,
target_chat_id: Optional[Union[int, str]] = None,
allow_sending_without_reply: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
) -> _ReplyKwargs:
"""
Builds a dictionary with the keys ``chat_id`` and ``reply_parameters``. This dictionary can
Expand Down Expand Up @@ -1587,11 +1587,22 @@ async def _parse_quote_arguments(
def _parse_message_thread_id(
self,
chat_id: Union[str, int],
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
) -> Optional[int]:
return message_thread_id or (
self.message_thread_id if chat_id in {self.chat_id, self.chat.username} else None
)
# values set by user have the highest priority
if not isinstance(message_thread_id, DefaultValue):
return message_thread_id

# self.message_thread_id can be used for send_*.param.message_thread_id only if the
# thread is a forum topic. It does not work if the thread is a chain of replies to a
# message in a normal group. In that case, self.message_thread_id is just the message_id
# of the first message in the chain.
if not self.is_topic_message:
return None

# Setting message_thread_id=self.message_thread_id only makes sense if we're replying in
# the same chat.
return self.message_thread_id if chat_id in {self.chat_id, self.chat.username} else None

async def reply_text(
self,
Expand All @@ -1601,7 +1612,7 @@ async def reply_text(
reply_markup: Optional[ReplyMarkup] = None,
entities: Optional[Sequence["MessageEntity"]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
link_preview_options: ODVInput["LinkPreviewOptions"] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
Expand Down Expand Up @@ -1677,7 +1688,7 @@ async def reply_markdown(
reply_markup: Optional[ReplyMarkup] = None,
entities: Optional[Sequence["MessageEntity"]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
link_preview_options: ODVInput["LinkPreviewOptions"] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
Expand Down Expand Up @@ -1759,7 +1770,7 @@ async def reply_markdown_v2(
reply_markup: Optional[ReplyMarkup] = None,
entities: Optional[Sequence["MessageEntity"]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
link_preview_options: ODVInput["LinkPreviewOptions"] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
Expand Down Expand Up @@ -1837,7 +1848,7 @@ async def reply_html(
reply_markup: Optional[ReplyMarkup] = None,
entities: Optional[Sequence["MessageEntity"]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
link_preview_options: ODVInput["LinkPreviewOptions"] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
Expand Down Expand Up @@ -1915,7 +1926,7 @@ async def reply_media_group(
],
disable_notification: ODVInput[bool] = DEFAULT_NONE,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
reply_to_message_id: Optional[int] = None,
Expand Down Expand Up @@ -1994,7 +2005,7 @@ async def reply_photo(
parse_mode: ODVInput[str] = DEFAULT_NONE,
caption_entities: Optional[Sequence["MessageEntity"]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
has_spoiler: Optional[bool] = None,
reply_parameters: Optional["ReplyParameters"] = None,
*,
Expand Down Expand Up @@ -2076,7 +2087,7 @@ async def reply_audio(
parse_mode: ODVInput[str] = DEFAULT_NONE,
caption_entities: Optional[Sequence["MessageEntity"]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
thumbnail: Optional[FileInput] = None,
reply_parameters: Optional["ReplyParameters"] = None,
*,
Expand Down Expand Up @@ -2159,7 +2170,7 @@ async def reply_document(
disable_content_type_detection: Optional[bool] = None,
caption_entities: Optional[Sequence["MessageEntity"]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
thumbnail: Optional[FileInput] = None,
reply_parameters: Optional["ReplyParameters"] = None,
*,
Expand Down Expand Up @@ -2242,7 +2253,7 @@ async def reply_animation(
reply_markup: Optional[ReplyMarkup] = None,
caption_entities: Optional[Sequence["MessageEntity"]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
has_spoiler: Optional[bool] = None,
thumbnail: Optional[FileInput] = None,
reply_parameters: Optional["ReplyParameters"] = None,
Expand Down Expand Up @@ -2323,7 +2334,7 @@ async def reply_sticker(
disable_notification: ODVInput[bool] = DEFAULT_NONE,
reply_markup: Optional[ReplyMarkup] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
emoji: Optional[str] = None,
reply_parameters: Optional["ReplyParameters"] = None,
*,
Expand Down Expand Up @@ -2401,7 +2412,7 @@ async def reply_video(
supports_streaming: Optional[bool] = None,
caption_entities: Optional[Sequence["MessageEntity"]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
has_spoiler: Optional[bool] = None,
thumbnail: Optional[FileInput] = None,
reply_parameters: Optional["ReplyParameters"] = None,
Expand Down Expand Up @@ -2485,7 +2496,7 @@ async def reply_video_note(
disable_notification: ODVInput[bool] = DEFAULT_NONE,
reply_markup: Optional[ReplyMarkup] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
thumbnail: Optional[FileInput] = None,
reply_parameters: Optional["ReplyParameters"] = None,
*,
Expand Down Expand Up @@ -2564,7 +2575,7 @@ async def reply_voice(
parse_mode: ODVInput[str] = DEFAULT_NONE,
caption_entities: Optional[Sequence["MessageEntity"]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
reply_to_message_id: Optional[int] = None,
Expand Down Expand Up @@ -2644,7 +2655,7 @@ async def reply_location(
heading: Optional[int] = None,
proximity_alert_radius: Optional[int] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
reply_to_message_id: Optional[int] = None,
Expand Down Expand Up @@ -2727,7 +2738,7 @@ async def reply_venue(
google_place_id: Optional[str] = None,
google_place_type: Optional[str] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
reply_to_message_id: Optional[int] = None,
Expand Down Expand Up @@ -2808,7 +2819,7 @@ async def reply_contact(
reply_markup: Optional[ReplyMarkup] = None,
vcard: Optional[str] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
reply_to_message_id: Optional[int] = None,
Expand Down Expand Up @@ -2893,7 +2904,7 @@ async def reply_poll(
close_date: Optional[Union[int, datetime.datetime]] = None,
explanation_entities: Optional[Sequence["MessageEntity"]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
reply_to_message_id: Optional[int] = None,
Expand Down Expand Up @@ -2973,7 +2984,7 @@ async def reply_dice(
reply_markup: Optional[ReplyMarkup] = None,
emoji: Optional[str] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
reply_to_message_id: Optional[int] = None,
Expand Down Expand Up @@ -3039,7 +3050,7 @@ async def reply_dice(
async def reply_chat_action(
self,
action: str,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
*,
read_timeout: ODVInput[float] = DEFAULT_NONE,
write_timeout: ODVInput[float] = DEFAULT_NONE,
Expand Down Expand Up @@ -3086,7 +3097,7 @@ async def reply_game(
disable_notification: ODVInput[bool] = DEFAULT_NONE,
reply_markup: Optional["InlineKeyboardMarkup"] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
reply_to_message_id: Optional[int] = None,
Expand Down Expand Up @@ -3177,7 +3188,7 @@ async def reply_invoice(
max_tip_amount: Optional[int] = None,
suggested_tip_amounts: Optional[Sequence[int]] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
reply_to_message_id: Optional[int] = None,
Expand Down Expand Up @@ -3387,7 +3398,7 @@ async def reply_copy(
disable_notification: ODVInput[bool] = DEFAULT_NONE,
reply_markup: Optional[ReplyMarkup] = None,
protect_content: ODVInput[bool] = DEFAULT_NONE,
message_thread_id: Optional[int] = None,
message_thread_id: ODVInput[int] = DEFAULT_NONE,
reply_parameters: Optional["ReplyParameters"] = None,
*,
reply_to_message_id: Optional[int] = None,
Expand Down
29 changes: 27 additions & 2 deletions tests/auxil/bot_method_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import functools
import inspect
import re
from typing import Any, Callable, Dict, Iterable, List, Optional
from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Tuple

import pytest

Expand Down Expand Up @@ -59,6 +59,7 @@ def check_shortcut_signature(
bot_method: Callable,
shortcut_kwargs: List[str],
additional_kwargs: List[str],
annotation_overrides: Optional[Dict[str, Tuple[Any, Any]]] = None,
) -> bool:
"""
Checks that the signature of a shortcut matches the signature of the underlying bot method.
Expand All @@ -69,10 +70,14 @@ def check_shortcut_signature(
shortcut_kwargs: The kwargs passed by the shortcut directly, e.g. ``chat_id``
additional_kwargs: Additional kwargs of the shortcut that the bot method doesn't have, e.g.
``quote``.
annotation_overrides: A dictionary of exceptions for the annotation comparison. The key is
the name of the argument, the value is a tuple of the expected annotation and
the default value. E.g. ``{'parse_mode': (str, 'None')}``.
Returns:
:obj:`bool`: Whether or not the signature matches.
"""
annotation_overrides = annotation_overrides or {}

def resolve_class(class_name: str) -> Optional[type]:
"""Attempts to resolve a PTB class (telegram module only) from a ForwardRef.
Expand Down Expand Up @@ -117,6 +122,14 @@ def resolve_class(class_name: str) -> Optional[type]:
if shortcut_sig.parameters[kwarg].kind != expected_kind:
raise Exception(f"Argument {kwarg} must be of kind {expected_kind}.")

if kwarg in annotation_overrides:
if shortcut_sig.parameters[kwarg].annotation != annotation_overrides[kwarg][0]:
raise Exception(
f"For argument {kwarg} I expected {annotation_overrides[kwarg]}, "
f"but got {shortcut_sig.parameters[kwarg].annotation}"
)
continue

if bot_sig.parameters[kwarg].annotation != shortcut_sig.parameters[kwarg].annotation:
if FORWARD_REF_PATTERN.search(str(shortcut_sig.parameters[kwarg])):
# If a shortcut signature contains a ForwardRef, the simple comparison of
Expand Down Expand Up @@ -155,6 +168,13 @@ def resolve_class(class_name: str) -> Optional[type]:
bot_method_sig = inspect.signature(bot_method)
shortcut_sig = inspect.signature(shortcut)
for arg in expected_args:
if arg in annotation_overrides:
if shortcut_sig.parameters[arg].default == annotation_overrides[arg][1]:
continue
raise Exception(
f"For argument {arg} I expected default {annotation_overrides[arg][1]}, "
f"but got {shortcut_sig.parameters[arg].default}"
)
if not shortcut_sig.parameters[arg].default == bot_method_sig.parameters[arg].default:
raise Exception(
f"Default for argument {arg} does not match the default of the Bot method."
Expand Down Expand Up @@ -525,6 +545,7 @@ async def check_defaults_handling(
method: Callable,
bot: Bot,
return_value=None,
no_default_kwargs: Collection[str] = frozenset(),
) -> bool:
"""
Checks that tg.ext.Defaults are handled correctly.
Expand All @@ -536,6 +557,8 @@ async def check_defaults_handling(
return_value: Optional. The return value of Bot._post that the method expects. Defaults to
None. get_file is automatically handled. If this is a `TelegramObject`, Bot._post will
return the `to_dict` representation of it.
no_default_kwargs: Optional. A collection of keyword arguments that should not have default
values. Defaults to an empty frozenset.
"""
raw_bot = not isinstance(bot, ExtBot)
Expand All @@ -545,7 +568,9 @@ async def check_defaults_handling(
kwargs_need_default = {
kwarg
for kwarg, value in shortcut_signature.parameters.items()
if isinstance(value.default, DefaultValue) and not kwarg.endswith("_timeout")
if isinstance(value.default, DefaultValue)
and not kwarg.endswith("_timeout")
and kwarg not in no_default_kwargs
}
# We tested this for a long time, but Bot API 7.0 deprecated it in favor of
# reply_parameters. In the transition phase, both exist in a mutually exclusive
Expand Down
Loading

0 comments on commit fed8d88

Please sign in to comment.