Skip to content

Commit

Permalink
Add Parameter pattern to PreCheckoutQueryHandler and `filters.Suc…
Browse files Browse the repository at this point in the history
…cessfulPayment` (#4005)
  • Loading branch information
aelkheir committed Jan 2, 2024
1 parent 7fcfad4 commit f3479cd
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 8 deletions.
38 changes: 35 additions & 3 deletions telegram/ext/_precheckoutqueryhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@
"""This module contains the PreCheckoutQueryHandler class."""


import re
from typing import Optional, Pattern, TypeVar, Union

from telegram import Update
from telegram._utils.defaultvalue import DEFAULT_TRUE
from telegram._utils.types import DVType
from telegram.ext._basehandler import BaseHandler
from telegram.ext._utils.types import CCT
from telegram.ext._utils.types import CCT, HandlerCallback

RT = TypeVar("RT")


class PreCheckoutQueryHandler(BaseHandler[Update, CCT]):
Expand All @@ -48,14 +55,32 @@ async def callback(update: Update, context: CallbackContext)
:meth:`telegram.ext.Application.process_update`. Defaults to :obj:`True`.
.. seealso:: :wiki:`Concurrency`
pattern (:obj:`str` | :func:`re.Pattern <re.compile>`, optional): Optional. Regex pattern
to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against.
.. versionadded:: NEXT.VERSION
Attributes:
callback (:term:`coroutine function`): The callback function for this handler.
block (:obj:`bool`): Determines whether the callback will run in a blocking way..
pattern (:obj:`str` | :func:`re.Pattern <re.compile>`, optional): Optional. Regex pattern
to test :attr:`telegram.PreCheckoutQuery.invoice_payload` against.
.. versionadded:: NEXT.VERSION
"""

__slots__ = ()
__slots__ = ("pattern",)

def __init__(
self,
callback: HandlerCallback[Update, CCT, RT],
block: DVType[bool] = DEFAULT_TRUE,
pattern: Optional[Union[str, Pattern[str]]] = None,
):
super().__init__(callback, block=block)

self.pattern: Optional[Pattern[str]] = re.compile(pattern) if pattern is not None else None

def check_update(self, update: object) -> bool:
"""Determines whether an update should be passed to this handler's :attr:`callback`.
Expand All @@ -67,4 +92,11 @@ def check_update(self, update: object) -> bool:
:obj:`bool`
"""
return isinstance(update, Update) and bool(update.pre_checkout_query)
if isinstance(update, Update) and update.pre_checkout_query:
invoice_payload = update.pre_checkout_query.invoice_payload
if self.pattern:
if self.pattern.match(invoice_payload):
return True
else:
return True
return False
40 changes: 36 additions & 4 deletions telegram/ext/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
"Sticker",
"STORY",
"SUCCESSFUL_PAYMENT",
"SuccessfulPayment",
"SenderChat",
"StatusUpdate",
"TEXT",
Expand Down Expand Up @@ -2265,14 +2266,45 @@ def filter(self, message: Message) -> bool:
"""


class _SuccessfulPayment(MessageFilter):
__slots__ = ()
class SuccessfulPayment(MessageFilter):
"""Successful Payment Messages. If a list of invoice payloads is passed, it filters
messages to only allow those whose `invoice_payload` is appearing in the given list.
Examples:
`MessageHandler(filters.SuccessfulPayment(['Custom-Payload']), callback_method)`
.. seealso::
:attr:`telegram.ext.filters.SUCCESSFUL_PAYMENT`
Args:
invoice_payloads (List[:obj:`str`] | Tuple[:obj:`str`], optional): Which
invoice payloads to allow. Only exact matches are allowed. If not
specified, will allow any invoice payload.
.. versionadded:: NEXT.VERSION
"""

__slots__ = ("invoice_payloads",)

def __init__(self, invoice_payloads: Optional[Union[List[str], Tuple[str, ...]]] = None):
self.invoice_payloads: Optional[Sequence[str]] = invoice_payloads
super().__init__(
name=f"filters.SuccessfulPayment({invoice_payloads})"
if invoice_payloads
else "filters.SUCCESSFUL_PAYMENT"
)

def filter(self, message: Message) -> bool:
return bool(message.successful_payment)
if self.invoice_payloads is None:
return bool(message.successful_payment)
return (
payment.invoice_payload in self.invoice_payloads
if (payment := message.successful_payment)
else False
)


SUCCESSFUL_PAYMENT = _SuccessfulPayment(name="filters.SUCCESSFUL_PAYMENT")
SUCCESSFUL_PAYMENT = SuccessfulPayment()
"""Messages that contain :attr:`telegram.Message.successful_payment`."""


Expand Down
19 changes: 19 additions & 0 deletions tests/ext/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Message,
MessageEntity,
Sticker,
SuccessfulPayment,
Update,
User,
)
Expand Down Expand Up @@ -1877,6 +1878,24 @@ def test_filters_successful_payment(self, update):
update.message.successful_payment = "test"
assert filters.SUCCESSFUL_PAYMENT.check_update(update)

def test_filters_successful_payment_payloads(self, update):
assert not filters.SuccessfulPayment(("custom-payload",)).check_update(update)
assert not filters.SuccessfulPayment().check_update(update)

update.message.successful_payment = SuccessfulPayment(
"USD", 100, "custom-payload", "123", "123"
)
assert filters.SuccessfulPayment(("custom-payload",)).check_update(update)
assert filters.SuccessfulPayment().check_update(update)
assert not filters.SuccessfulPayment(["test1"]).check_update(update)

def test_filters_successful_payment_repr(self):
f = filters.SuccessfulPayment()
assert str(f) == "filters.SUCCESSFUL_PAYMENT"

f = filters.SuccessfulPayment(["payload1", "payload2"])
assert str(f) == "filters.SuccessfulPayment(['payload1', 'payload2'])"

def test_filters_passport_data(self, update):
assert not filters.PASSPORT_DATA.check_update(update)
update.message.passport_data = "test"
Expand Down
23 changes: 22 additions & 1 deletion tests/ext/test_precheckoutqueryhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import asyncio
import re

import pytest

Expand Down Expand Up @@ -69,12 +70,15 @@ def false_update(request):

@pytest.fixture(scope="class")
def pre_checkout_query():
return Update(
update = Update(
1,
pre_checkout_query=PreCheckoutQuery(
"id", User(1, "test user", False), "EUR", 223, "invoice_payload"
),
)
update._unfreeze()
update.pre_checkout_query._unfreeze()
return update


class TestPreCheckoutQueryHandler:
Expand Down Expand Up @@ -103,6 +107,23 @@ async def callback(self, update, context):
and isinstance(update.pre_checkout_query, PreCheckoutQuery)
)

def test_with_pattern(self, pre_checkout_query):
handler = PreCheckoutQueryHandler(self.callback, pattern=".*voice.*")

assert handler.check_update(pre_checkout_query)

pre_checkout_query.pre_checkout_query.invoice_payload = "nothing here"
assert not handler.check_update(pre_checkout_query)

def test_with_compiled_pattern(self, pre_checkout_query):
handler = PreCheckoutQueryHandler(self.callback, pattern=re.compile(r".*payload"))

pre_checkout_query.pre_checkout_query.invoice_payload = "invoice_payload"
assert handler.check_update(pre_checkout_query)

pre_checkout_query.pre_checkout_query.invoice_payload = "nothing here"
assert not handler.check_update(pre_checkout_query)

def test_other_update_types(self, false_update):
handler = PreCheckoutQueryHandler(self.callback)
assert not handler.check_update(false_update)
Expand Down

0 comments on commit f3479cd

Please sign in to comment.