Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add invoice_payload filtering #4005

Merged
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 38 additions & 3 deletions telegram/ext/_precheckoutqueryhandler.py
Expand Up @@ -19,9 +19,16 @@
"""This module contains the PreCheckoutQueryHandler class."""


import re
from typing import Optional, Pattern, TypeVar, Union

from telegram import Update
from telegram._utils.defaultvalue import DEFAULT_TRUE
from telegram._utils.types import DVType
from telegram.ext._basehandler import BaseHandler
from telegram.ext._utils.types import CCT
from telegram.ext._utils.types import CCT, HandlerCallback

RT = TypeVar("RT")


class PreCheckoutQueryHandler(BaseHandler[Update, CCT]):
Expand All @@ -43,6 +50,10 @@ async def callback(update: Update, context: CallbackContext)

The return value of the callback is usually ignored except for the special case of
:class:`telegram.ext.ConversationHandler`.
pattern (:obj:`str` | :func:`re.Pattern <re.compile>`, optional): Optional. Regex pattern
to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against.

.. versionadded:: NEXT.VERSION
block (:obj:`bool`, optional): Determines whether the return value of the callback should
be awaited before processing the next handler in
:meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`.
Expand All @@ -51,11 +62,28 @@ async def callback(update: Update, context: CallbackContext)

Attributes:
callback (:term:`coroutine function`): The callback function for this handler.
pattern (:obj:`str` | :func:`re.Pattern <re.compile>`, optional): Optional. Regex pattern
to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against.

.. versionadded:: NEXT.VERSION
block (:obj:`bool`): Determines whether the callback will run in a blocking way..

"""

__slots__ = ()
__slots__ = ("pattern",)

def __init__(
self,
callback: HandlerCallback[Update, CCT, RT],
pattern: Optional[Union[str, Pattern[str]]] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please move the argument to the end of the list. Otherwise this would be a breaking change as PCQH(callback, False) would no longer work as expected. Please also move the docstring entries for argument and attribute above.

block: DVType[bool] = DEFAULT_TRUE,
):
super().__init__(callback, block=block)

if isinstance(pattern, str):
pattern = re.compile(pattern)

self.pattern: Optional[Union[str, Pattern[str]]] = pattern
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(pattern, str):
pattern = re.compile(pattern)
self.pattern: Optional[Union[str, Pattern[str]]] = pattern
self.pattern: Optional[Pattern[str]] = re.compile(pattern) if pattern is not None else None

Just a bit shorter and also gives the attribute a more narrow type.
This way we can also use self.pattern.match(…) in check_update below :)


def check_update(self, update: object) -> bool:
"""Determines whether an update should be passed to this handler's :attr:`callback`.
Expand All @@ -67,4 +95,11 @@ def check_update(self, update: object) -> bool:
:obj:`bool`

"""
return isinstance(update, Update) and bool(update.pre_checkout_query)
if isinstance(update, Update) and update.pre_checkout_query:
invoice_payload = update.pre_checkout_query.invoice_payload
if self.pattern:
if re.match(self.pattern, invoice_payload):
return True
else:
return True
return False
40 changes: 36 additions & 4 deletions telegram/ext/filters.py
Expand Up @@ -75,6 +75,7 @@
"Sticker",
"STORY",
"SUCCESSFUL_PAYMENT",
"SuccessfulPayment",
"SenderChat",
"StatusUpdate",
"TEXT",
Expand Down Expand Up @@ -2265,14 +2266,45 @@ def filter(self, message: Message) -> bool:
"""


class _SuccessfulPayment(MessageFilter):
__slots__ = ()
class SuccessfulPayment(MessageFilter):
"""Successful Payment Messages. If a list of invoice payloads is passed, it filters
messages to only allow those whose `invoice_payload` is appearing in the given list.

Examples:
`MessageHandler(filters.SuccessfulPayment(['Custom-Payload']), callback_method)`

.. seealso::
:attr:`telegram.ext.filters.SUCCESSFUL_PAYMENT`

Args:
invoice_payloads (List[:obj:`str`] | Tuple[:obj:`str`], optional): Which
invoice payloads to allow. Only exact matches are allowed. If not
specified, will allow any invoice payload.

.. versionadded:: NEXT.VERSION
"""

__slots__ = ("invoice_payloads",)

def __init__(self, invoice_payloads: Optional[Union[List[str], Tuple[str, ...]]] = None):
self.invoice_payloads: Optional[Sequence[str]] = invoice_payloads
super().__init__(
name=f"filters.SuccessfulPayment({invoice_payloads})"
if invoice_payloads
else "filters.SUCCESSFUL_PAYMENT"
)

def filter(self, message: Message) -> bool:
return bool(message.successful_payment)
if self.invoice_payloads is None:
return bool(message.successful_payment)
return (
payment.invoice_payload in self.invoice_payloads
if (payment := message.successful_payment)
else False
)


SUCCESSFUL_PAYMENT = _SuccessfulPayment(name="filters.SUCCESSFUL_PAYMENT")
SUCCESSFUL_PAYMENT = SuccessfulPayment()
"""Messages that contain :attr:`telegram.Message.successful_payment`."""


Expand Down
8 changes: 8 additions & 0 deletions tests/ext/test_filters.py
Expand Up @@ -31,6 +31,7 @@
Message,
MessageEntity,
Sticker,
SuccessfulPayment,
Update,
User,
)
Expand Down Expand Up @@ -1877,6 +1878,13 @@ def test_filters_successful_payment(self, update):
update.message.successful_payment = "test"
assert filters.SUCCESSFUL_PAYMENT.check_update(update)

def test_filters_successful_payment_payloads(self, update):
update.message.successful_payment = SuccessfulPayment(
"USD", 100, "custom-payload", "123", "123"
)
assert filters.SuccessfulPayment(("custom-payload",)).check_update(update)
assert not filters.SuccessfulPayment(["test1"]).check_update(update)
Bibo-Joshi marked this conversation as resolved.
Show resolved Hide resolved

def test_filters_passport_data(self, update):
assert not filters.PASSPORT_DATA.check_update(update)
update.message.passport_data = "test"
Expand Down
13 changes: 12 additions & 1 deletion tests/ext/test_precheckoutqueryhandler.py
Expand Up @@ -69,12 +69,15 @@ def false_update(request):

@pytest.fixture(scope="class")
def pre_checkout_query():
return Update(
update = Update(
1,
pre_checkout_query=PreCheckoutQuery(
"id", User(1, "test user", False), "EUR", 223, "invoice_payload"
),
)
update._unfreeze()
update.pre_checkout_query._unfreeze()
return update


class TestPreCheckoutQueryHandler:
Expand Down Expand Up @@ -103,6 +106,14 @@ async def callback(self, update, context):
and isinstance(update.pre_checkout_query, PreCheckoutQuery)
)

def test_with_pattern(self, pre_checkout_query):
handler = PreCheckoutQueryHandler(self.callback, pattern=".*voice.*")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please also test passing a compiled pattern :)


assert handler.check_update(pre_checkout_query)

pre_checkout_query.pre_checkout_query.invoice_payload = "nothing here"
assert not handler.check_update(pre_checkout_query)

def test_other_update_types(self, false_update):
handler = PreCheckoutQueryHandler(self.callback)
assert not handler.check_update(false_update)
Expand Down