diff --git a/docs/source/telegram.ext.baseupdateprocessor.rst b/docs/source/telegram.ext.baseupdateprocessor.rst new file mode 100644 index 00000000000..7155adf7191 --- /dev/null +++ b/docs/source/telegram.ext.baseupdateprocessor.rst @@ -0,0 +1,6 @@ +BaseUpdateProcessor +=================== + +.. autoclass:: telegram.ext.BaseUpdateProcessor + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/telegram.ext.rst b/docs/source/telegram.ext.rst index e776128ecb8..ab9efc1b353 100644 --- a/docs/source/telegram.ext.rst +++ b/docs/source/telegram.ext.rst @@ -9,12 +9,14 @@ telegram.ext package telegram.ext.application telegram.ext.applicationbuilder telegram.ext.applicationhandlerstop + telegram.ext.baseupdateprocessor telegram.ext.callbackcontext telegram.ext.contexttypes telegram.ext.defaults telegram.ext.extbot telegram.ext.job telegram.ext.jobqueue + telegram.ext.simpleupdateprocessor telegram.ext.updater telegram.ext.handlers-tree.rst telegram.ext.persistence-tree.rst diff --git a/docs/source/telegram.ext.simpleupdateprocessor.rst b/docs/source/telegram.ext.simpleupdateprocessor.rst new file mode 100644 index 00000000000..1e30c27566c --- /dev/null +++ b/docs/source/telegram.ext.simpleupdateprocessor.rst @@ -0,0 +1,6 @@ +SimpleUpdateProcessor +===================== + +.. autoclass:: telegram.ext.SimpleUpdateProcessor + :members: + :show-inheritance: \ No newline at end of file diff --git a/telegram/ext/__init__.py b/telegram/ext/__init__.py index b37d0fc70ca..a6abdb974e9 100644 --- a/telegram/ext/__init__.py +++ b/telegram/ext/__init__.py @@ -26,6 +26,7 @@ "BaseHandler", "BasePersistence", "BaseRateLimiter", + "BaseUpdateProcessor", "CallbackContext", "CallbackDataCache", "CallbackQueryHandler", @@ -51,6 +52,7 @@ "PreCheckoutQueryHandler", "PrefixHandler", "ShippingQueryHandler", + "SimpleUpdateProcessor", "StringCommandHandler", "StringRegexHandler", "TypeHandler", @@ -63,6 +65,7 @@ from ._applicationbuilder import ApplicationBuilder from ._basepersistence import BasePersistence, PersistenceInput from ._baseratelimiter import BaseRateLimiter +from ._baseupdateprocessor import BaseUpdateProcessor, SimpleUpdateProcessor from ._callbackcontext import CallbackContext from ._callbackdatacache import CallbackDataCache, InvalidCallbackData from ._callbackqueryhandler import CallbackQueryHandler diff --git a/telegram/ext/_application.py b/telegram/ext/_application.py index 7452e4c9c52..15178d9176f 100644 --- a/telegram/ext/_application.py +++ b/telegram/ext/_application.py @@ -57,6 +57,7 @@ from telegram._utils.warnings import warn from telegram.error import TelegramError from telegram.ext._basepersistence import BasePersistence +from telegram.ext._baseupdateprocessor import BaseUpdateProcessor from telegram.ext._contexttypes import ContextTypes from telegram.ext._extbot import ExtBot from telegram.ext._handler import BaseHandler @@ -228,12 +229,11 @@ class Application(Generic[BT, CCT, UD, CD, BD, JQ], AsyncContextManager["Applica "_chat_data", "_chat_ids_to_be_deleted_in_persistence", "_chat_ids_to_be_updated_in_persistence", - "_concurrent_updates", - "_concurrent_updates_sem", "_conversation_handler_conversations", "_initialized", "_job_queue", "_running", + "_update_processor", "_user_data", "_user_ids_to_be_deleted_in_persistence", "_user_ids_to_be_updated_in_persistence", @@ -259,7 +259,7 @@ def __init__( update_queue: "asyncio.Queue[object]", updater: Optional[Updater], job_queue: JQ, - concurrent_updates: Union[bool, int], + update_processor: 
"BaseUpdateProcessor", persistence: Optional[BasePersistence[UD, CD, BD]], context_types: ContextTypes[CCT, UD, CD, BD], post_init: Optional[ @@ -297,14 +297,7 @@ def __init__( self.post_stop: Optional[ Callable[["Application[BT, CCT, UD, CD, BD, JQ]"], Coroutine[Any, Any, None]] ] = post_stop - - if isinstance(concurrent_updates, int) and concurrent_updates < 0: - raise ValueError("`concurrent_updates` must be a non-negative integer!") - if concurrent_updates is True: - concurrent_updates = 256 - self._concurrent_updates_sem = asyncio.BoundedSemaphore(concurrent_updates or 1) - self._concurrent_updates: int = concurrent_updates or 0 - + self._update_processor = update_processor self.bot_data: BD = self.context_types.bot_data() self._user_data: DefaultDict[int, UD] = defaultdict(self.context_types.user_data) self._chat_data: DefaultDict[int, CD] = defaultdict(self.context_types.chat_data) @@ -359,9 +352,13 @@ def concurrent_updates(self) -> int: """:obj:`int`: The number of concurrent updates that will be processed in parallel. A value of ``0`` indicates updates are *not* being processed concurrently. + .. versionchanged:: NEXT.VERSION + This is now just a shortcut to :attr:`update_processor.max_concurrent_updates + `. + .. seealso:: :wiki:`Concurrency` """ - return self._concurrent_updates + return self._update_processor.max_concurrent_updates @property def job_queue(self) -> Optional["JobQueue[CCT]"]: @@ -379,12 +376,25 @@ def job_queue(self) -> Optional["JobQueue[CCT]"]: ) return self._job_queue + @property + def update_processor(self) -> "BaseUpdateProcessor": + """:class:`telegram.ext.BaseUpdateProcessor`: The update processor used by this + application. + + .. seealso:: :wiki:`Concurrency` + + .. versionadded:: NEXT.VERSION + """ + return self._update_processor + async def initialize(self) -> None: """Initializes the Application by initializing: * The :attr:`bot`, by calling :meth:`telegram.Bot.initialize`. * The :attr:`updater`, by calling :meth:`telegram.ext.Updater.initialize`. * The :attr:`persistence`, by loading persistent conversations and data. + * The :attr:`update_processor` by calling + :meth:`telegram.ext.BaseUpdateProcessor.initialize`. Does *not* call :attr:`post_init` - that is only done by :meth:`run_polling` and :meth:`run_webhook`. @@ -397,6 +407,8 @@ async def initialize(self) -> None: return await self.bot.initialize() + await self._update_processor.initialize() + if self.updater: await self.updater.initialize() @@ -429,6 +441,7 @@ async def shutdown(self) -> None: * :attr:`updater` by calling :meth:`telegram.ext.Updater.shutdown` * :attr:`persistence` by calling :meth:`update_persistence` and :meth:`BasePersistence.flush` + * :attr:`update_processor` by calling :meth:`telegram.ext.BaseUpdateProcessor.shutdown` Does *not* call :attr:`post_shutdown` - that is only done by :meth:`run_polling` and :meth:`run_webhook`. 
@@ -447,6 +460,8 @@ async def shutdown(self) -> None: return await self.bot.shutdown() + await self._update_processor.shutdown() + if self.updater: await self.updater.shutdown() @@ -1060,11 +1075,15 @@ async def _update_fetcher(self) -> None: _LOGGER.debug("Processing update %s", update) - if self._concurrent_updates: + if self._update_processor.max_concurrent_updates > 1: # We don't await the below because it has to be run concurrently - self.create_task(self.__process_update_wrapper(update), update=update) + self.create_task( + self.__process_update_wrapper(update), + update=update, + ) else: await self.__process_update_wrapper(update) + except asyncio.CancelledError: # This may happen if the application is manually run via application.start() and # then a KeyboardInterrupt is sent. We must prevent this loop to die since @@ -1075,9 +1094,8 @@ async def _update_fetcher(self) -> None: ) async def __process_update_wrapper(self, update: object) -> None: - async with self._concurrent_updates_sem: - await self.process_update(update) - self.update_queue.task_done() + await self._update_processor.process_update(update, self.process_update(update)) + self.update_queue.task_done() async def process_update(self, update: object) -> None: """Processes a single update and marks the update to be updated by the persistence later. diff --git a/telegram/ext/_applicationbuilder.py b/telegram/ext/_applicationbuilder.py index 8188105b7f1..cd3b7ad35c0 100644 --- a/telegram/ext/_applicationbuilder.py +++ b/telegram/ext/_applicationbuilder.py @@ -36,6 +36,7 @@ from telegram._utils.defaultvalue import DEFAULT_FALSE, DEFAULT_NONE, DefaultValue from telegram._utils.types import DVInput, DVType, FilePathInput, ODVInput from telegram.ext._application import Application +from telegram.ext._baseupdateprocessor import BaseUpdateProcessor, SimpleUpdateProcessor from telegram.ext._contexttypes import ContextTypes from telegram.ext._extbot import ExtBot from telegram.ext._jobqueue import JobQueue @@ -127,7 +128,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]): "_base_file_url", "_base_url", "_bot", - "_concurrent_updates", + "_update_processor", "_connect_timeout", "_connection_pool_size", "_context_types", @@ -198,7 +199,9 @@ def __init__(self: "InitApplicationBuilder"): self._context_types: DVType[ContextTypes] = DefaultValue(ContextTypes()) self._application_class: DVType[Type[Application]] = DefaultValue(Application) self._application_kwargs: Dict[str, object] = {} - self._concurrent_updates: Union[int, DefaultValue[bool]] = DEFAULT_FALSE + self._update_processor: "BaseUpdateProcessor" = SimpleUpdateProcessor( + max_concurrent_updates=1 + ) self._updater: ODVInput[Updater] = DEFAULT_NONE self._post_init: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None self._post_shutdown: Optional[Callable[[Application], Coroutine[Any, Any, None]]] = None @@ -306,7 +309,7 @@ def build( bot=bot, update_queue=update_queue, updater=updater, - concurrent_updates=DefaultValue.get_value(self._concurrent_updates), + update_processor=self._update_processor, job_queue=job_queue, persistence=persistence, context_types=DefaultValue.get_value(self._context_types), @@ -902,7 +905,9 @@ def update_queue(self: BuilderType, update_queue: "Queue[object]") -> BuilderTyp self._update_queue = update_queue return self - def concurrent_updates(self: BuilderType, concurrent_updates: Union[bool, int]) -> BuilderType: + def concurrent_updates( + self: BuilderType, concurrent_updates: Union[bool, int, "BaseUpdateProcessor"] + 
) -> BuilderType: """Specifies if and how many updates may be processed concurrently instead of one by one. If not called, updates will be processed one by one. @@ -917,14 +922,34 @@ def concurrent_updates(self: BuilderType, concurrent_updates: Union[bool, int]) .. seealso:: :attr:`telegram.ext.Application.concurrent_updates` Args: - concurrent_updates (:obj:`bool` | :obj:`int`): Passing :obj:`True` will allow for - ``256`` updates to be processed concurrently. Pass an integer to specify a - different number of updates that may be processed concurrently. + concurrent_updates (:obj:`bool` | :obj:`int` | :class:`BaseUpdateProcessor`): Passing + :obj:`True` will allow for ``256`` updates to be processed concurrently using + :class:`telegram.ext.SimpleUpdateProcessor`. Pass an integer to specify a different + number of updates that may be processed concurrently. Pass an instance of + :class:`telegram.ext.BaseUpdateProcessor` to use that instance for handling updates + concurrently. + + .. versionchanged:: NEXT.VERSION + Now accepts :class:`BaseUpdateProcessor` instances. Returns: :class:`ApplicationBuilder`: The same builder with the updated argument. """ - self._concurrent_updates = concurrent_updates + # Check if concurrent updates is bool and convert to integer + if concurrent_updates is True: + concurrent_updates = 256 + elif concurrent_updates is False: + concurrent_updates = 1 + + # If `concurrent_updates` is an integer, wrap it in a `SimpleUpdateProcessor` + # instance with that value; its constructor raises a ValueError if the value + # is not a positive integer + if isinstance(concurrent_updates, int): + concurrent_updates = SimpleUpdateProcessor(concurrent_updates) + + # At this point `concurrent_updates` is a `BaseUpdateProcessor` instance, + # either passed by the caller or created above + self._update_processor: BaseUpdateProcessor = concurrent_updates # type: ignore[no-redef] return self def job_queue( diff --git a/telegram/ext/_baseupdateprocessor.py b/telegram/ext/_baseupdateprocessor.py new file mode 100644 index 00000000000..a3b59d4fb92 --- /dev/null +++ b/telegram/ext/_baseupdateprocessor.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2023 +# Leandro Toledo de Souza <devs@python-telegram-bot.org> +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/]. +"""This module contains the BaseUpdateProcessor class.""" +from abc import ABC, abstractmethod +from asyncio import BoundedSemaphore +from types import TracebackType +from typing import Any, Awaitable, Optional, Type + + +class BaseUpdateProcessor(ABC): + """An abstract base class for update processors. You can use this class to implement + your own update processor. + + .. seealso:: :wiki:`Concurrency` + + .. versionadded:: NEXT.VERSION + + Args: + max_concurrent_updates (:obj:`int`): The maximum number of updates to be processed + concurrently.
If this number is exceeded, new updates will be queued until the number + of currently processed updates decreases. + + Raises: + :exc:`ValueError`: If :paramref:`max_concurrent_updates` is a non-positive integer. + """ + + __slots__ = ("_max_concurrent_updates", "_semaphore") + + def __init__(self, max_concurrent_updates: int): + self._max_concurrent_updates = max_concurrent_updates + if self.max_concurrent_updates < 1: + raise ValueError("`max_concurrent_updates` must be a positive integer!") + self._semaphore = BoundedSemaphore(self.max_concurrent_updates) + + @property + def max_concurrent_updates(self) -> int: + """:obj:`int`: The maximum number of updates that can be processed concurrently.""" + return self._max_concurrent_updates + + @abstractmethod + async def do_process_update( + self, + update: object, + coroutine: "Awaitable[Any]", + ) -> None: + """Custom implementation of how to process an update. Must be implemented by a subclass. + + Warning: + This method will be called by :meth:`process_update`. It should *not* be called + manually. + + Args: + update (:obj:`object`): The update to be processed. + coroutine (:term:`Awaitable`): The coroutine that will be awaited to process the + update. + """ + + @abstractmethod + async def initialize(self) -> None: + """Initializes the processor so resources can be allocated. Must be implemented by a + subclass. + + .. seealso:: + :meth:`shutdown` + """ + + @abstractmethod + async def shutdown(self) -> None: + """Shutdown the processor so resources can be freed. Must be implemented by a subclass. + + .. seealso:: + :meth:`initialize` + """ + + async def process_update( + self, + update: object, + coroutine: "Awaitable[Any]", + ) -> None: + """Calls :meth:`do_process_update` with a semaphore to limit the number of concurrent + updates. + + Args: + update (:obj:`object`): The update to be processed. + coroutine (:term:`Awaitable`): The coroutine that will be awaited to process the + update. + """ + async with self._semaphore: + await self.do_process_update(update, coroutine) + + async def __aenter__(self) -> "BaseUpdateProcessor": + """Simple context manager which initializes the Processor.""" + try: + await self.initialize() + return self + except Exception as exc: + await self.shutdown() + raise exc + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Shutdown the Processor from the context manager.""" + await self.shutdown() + + +class SimpleUpdateProcessor(BaseUpdateProcessor): + """Instance of :class:`telegram.ext.BaseUpdateProcessor` that immediately awaits the + coroutine, i.e. does not apply any additional processing. This is used by default when + :attr:`telegram.ext.ApplicationBuilder.concurrent_updates` is :obj:`int`. + + .. versionadded:: NEXT.VERSION + """ + + __slots__ = () + + async def do_process_update( + self, + update: object, + coroutine: "Awaitable[Any]", + ) -> None: + """Immediately awaits the coroutine, i.e. does not apply any additional processing. + + Args: + update (:obj:`object`): The update to be processed. + coroutine (:term:`Awaitable`): The coroutine that will be awaited to process the + update. 
+ """ + await coroutine + + async def initialize(self) -> None: + """Does nothing.""" + + async def shutdown(self) -> None: + """Does nothing.""" diff --git a/tests/ext/test_application.py b/tests/ext/test_application.py index db52aab01a7..00001524853 100644 --- a/tests/ext/test_application.py +++ b/tests/ext/test_application.py @@ -49,6 +49,7 @@ JobQueue, MessageHandler, PicklePersistence, + SimpleUpdateProcessor, TypeHandler, Updater, filters, @@ -134,7 +135,7 @@ def test_manual_init_warning(self, recwarn, updater): persistence=None, context_types=ContextTypes(), updater=updater, - concurrent_updates=False, + update_processor=False, post_init=None, post_shutdown=None, post_stop=None, @@ -147,15 +148,13 @@ def test_manual_init_warning(self, recwarn, updater): assert recwarn[0].category is PTBUserWarning assert recwarn[0].filename == __file__, "stacklevel is incorrect!" - @pytest.mark.parametrize( - ("concurrent_updates", "expected"), [(0, 0), (4, 4), (False, 0), (True, 256)] - ) @pytest.mark.filterwarnings("ignore: `Application` instances should") - def test_init(self, one_time_bot, concurrent_updates, expected): + def test_init(self, one_time_bot): update_queue = asyncio.Queue() job_queue = JobQueue() persistence = PicklePersistence("file_path") context_types = ContextTypes() + update_processor = SimpleUpdateProcessor(1) updater = Updater(bot=one_time_bot, update_queue=update_queue) async def post_init(application: Application) -> None: @@ -174,7 +173,7 @@ async def post_stop(application: Application) -> None: persistence=persistence, context_types=context_types, updater=updater, - concurrent_updates=concurrent_updates, + update_processor=update_processor, post_init=post_init, post_shutdown=post_shutdown, post_stop=post_stop, @@ -187,7 +186,7 @@ async def post_stop(application: Application) -> None: assert app.updater is updater assert app.update_queue is updater.update_queue assert app.bot is updater.bot - assert app.concurrent_updates == expected + assert app.update_processor is update_processor assert app.post_init is post_init assert app.post_shutdown is post_shutdown assert app.post_stop is post_stop @@ -201,20 +200,6 @@ async def post_stop(application: Application) -> None: assert isinstance(app.chat_data[1], dict) assert isinstance(app.user_data[1], dict) - with pytest.raises(ValueError, match="must be a non-negative"): - Application( - bot=one_time_bot, - update_queue=update_queue, - job_queue=job_queue, - persistence=persistence, - context_types=context_types, - updater=updater, - concurrent_updates=-1, - post_init=None, - post_shutdown=None, - post_stop=None, - ) - def test_job_queue(self, one_time_bot, app, recwarn): expected_warning = ( "No `JobQueue` set up. 
To use `JobQueue`, you must install PTB via " @@ -250,23 +235,39 @@ async def test_initialize(self, one_time_bot, monkeypatch, updater): async def after_initialize_bot(*args, **kwargs): self.test_flag.add("bot") + async def after_initialize_update_processor(*args, **kwargs): + self.test_flag.add("update_processor") + async def after_initialize_updater(*args, **kwargs): self.test_flag.add("updater") + update_processor = SimpleUpdateProcessor(1) monkeypatch.setattr(Bot, "initialize", call_after(Bot.initialize, after_initialize_bot)) + monkeypatch.setattr( + SimpleUpdateProcessor, + "initialize", + call_after(SimpleUpdateProcessor.initialize, after_initialize_update_processor), + ) monkeypatch.setattr( Updater, "initialize", call_after(Updater.initialize, after_initialize_updater) ) - if updater: - app = ApplicationBuilder().bot(one_time_bot).build() + app = ( + ApplicationBuilder().bot(one_time_bot).concurrent_updates(update_processor).build() + ) await app.initialize() - assert self.test_flag == {"bot", "updater"} + assert self.test_flag == {"bot", "update_processor", "updater"} await app.shutdown() else: - app = ApplicationBuilder().bot(one_time_bot).updater(None).build() + app = ( + ApplicationBuilder() + .bot(one_time_bot) + .updater(None) + .concurrent_updates(update_processor) + .build() + ) await app.initialize() - assert self.test_flag == {"bot"} + assert self.test_flag == {"bot", "update_processor"} await app.shutdown() @pytest.mark.parametrize("updater", [True, False]) @@ -277,22 +278,35 @@ async def test_shutdown(self, one_time_bot, monkeypatch, updater): def after_bot_shutdown(*args, **kwargs): self.test_flag.add("bot") + def after_shutdown_update_processor(*args, **kwargs): + self.test_flag.add("update_processor") + def after_updater_shutdown(*args, **kwargs): self.test_flag.add("updater") + update_processor = SimpleUpdateProcessor(1) monkeypatch.setattr(Bot, "shutdown", call_after(Bot.shutdown, after_bot_shutdown)) + monkeypatch.setattr( + SimpleUpdateProcessor, + "shutdown", + call_after(SimpleUpdateProcessor.shutdown, after_shutdown_update_processor), + ) monkeypatch.setattr( Updater, "shutdown", call_after(Updater.shutdown, after_updater_shutdown) ) if updater: - async with ApplicationBuilder().bot(one_time_bot).build(): + async with ApplicationBuilder().bot(one_time_bot).concurrent_updates( + update_processor + ).build(): pass - assert self.test_flag == {"bot", "updater"} + assert self.test_flag == {"bot", "update_processor", "updater"} else: - async with ApplicationBuilder().bot(one_time_bot).updater(None).build(): + async with ApplicationBuilder().bot(one_time_bot).updater(None).concurrent_updates( + update_processor + ).build(): pass - assert self.test_flag == {"bot"} + assert self.test_flag == {"bot", "update_processor"} async def test_multiple_inits_and_shutdowns(self, app, monkeypatch): self.received = defaultdict(int) @@ -1309,7 +1323,7 @@ def gen(): await app.create_task(gen()) assert event.is_set() - async def test_no_concurrent_updates(self, app): + async def test_no_update_processor(self, app): queue = asyncio.Queue() event_1 = asyncio.Event() event_2 = asyncio.Event() @@ -1337,14 +1351,14 @@ async def callback(u, c): await app.stop() - @pytest.mark.parametrize("concurrent_updates", [15, 50, 100]) - async def test_concurrent_updates(self, one_time_bot, concurrent_updates): + @pytest.mark.parametrize("update_processor", [15, 50, 100]) + async def test_update_processor(self, one_time_bot, update_processor): # We don't test with `True` since the large number of 
parallel coroutines quickly leads # to test instabilities - app = ( - Application.builder().bot(one_time_bot).concurrent_updates(concurrent_updates).build() - ) - events = {i: asyncio.Event() for i in range(app.concurrent_updates + 10)} + app = Application.builder().bot(one_time_bot).concurrent_updates(update_processor).build() + events = { + i: asyncio.Event() for i in range(app.update_processor.max_concurrent_updates + 10) + } queue = asyncio.Queue() for event in events.values(): await queue.put(event) @@ -1356,25 +1370,28 @@ async def callback(u, c): app.add_handler(TypeHandler(object, callback)) async with app: await app.start() - for i in range(app.concurrent_updates + 10): + for i in range(app.update_processor.max_concurrent_updates + 10): await app.update_queue.put(i) - for i in range(app.concurrent_updates + 10): + for i in range(app.update_processor.max_concurrent_updates + 10): assert not events[i].is_set() await asyncio.sleep(0.9) - for i in range(app.concurrent_updates): + for i in range(app.update_processor.max_concurrent_updates): assert events[i].is_set() - for i in range(app.concurrent_updates, app.concurrent_updates + 10): + for i in range( + app.update_processor.max_concurrent_updates, + app.update_processor.max_concurrent_updates + 10, + ): assert not events[i].is_set() await asyncio.sleep(0.5) - for i in range(app.concurrent_updates + 10): + for i in range(app.update_processor.max_concurrent_updates + 10): assert events[i].is_set() await app.stop() - async def test_concurrent_updates_done_on_shutdown(self, one_time_bot): + async def test_update_processor_done_on_shutdown(self, one_time_bot): app = Application.builder().bot(one_time_bot).concurrent_updates(True).build() event = asyncio.Event() diff --git a/tests/ext/test_applicationbuilder.py b/tests/ext/test_applicationbuilder.py index b17d30451fb..0f9eb29ad7f 100644 --- a/tests/ext/test_applicationbuilder.py +++ b/tests/ext/test_applicationbuilder.py @@ -35,6 +35,7 @@ Updater, ) from telegram.ext._applicationbuilder import _BOT_CHECKS +from telegram.ext._baseupdateprocessor import SimpleUpdateProcessor from telegram.request import HTTPXRequest from tests.auxil.constants import PRIVATE_KEY from tests.auxil.envvars import TEST_WITH_OPT_DEPS @@ -96,7 +97,8 @@ class Client: app = builder.token(bot.token).build() assert isinstance(app, Application) - assert app.concurrent_updates == 0 + assert isinstance(app.update_processor, SimpleUpdateProcessor) + assert app.update_processor.max_concurrent_updates == 1 assert isinstance(app.bot, ExtBot) assert isinstance(app.bot.request, HTTPXRequest) @@ -367,12 +369,21 @@ def __init__(self, arg, **kwargs): assert isinstance(app, CustomApplication) assert app.arg == 2 - def test_all_application_args_custom(self, builder, bot, monkeypatch): + @pytest.mark.parametrize( + ("concurrent_updates", "expected"), + [ + (4, SimpleUpdateProcessor(4)), + (False, SimpleUpdateProcessor(1)), + (True, SimpleUpdateProcessor(256)), + ], + ) + def test_all_application_args_custom( + self, builder, bot, monkeypatch, concurrent_updates, expected + ): job_queue = JobQueue() persistence = PicklePersistence("file_path") update_queue = asyncio.Queue() context_types = ContextTypes() - concurrent_updates = 123 async def post_init(app: Application) -> None: pass @@ -395,6 +406,7 @@ async def post_stop(app: Application) -> None: .post_stop(post_stop) .arbitrary_callback_data(True) ).build() + assert app.job_queue is job_queue assert app.job_queue.application is app assert app.persistence is persistence @@ -403,7 
+415,9 @@ async def post_stop(app: Application) -> None: assert app.updater.update_queue is update_queue assert app.updater.bot is app.bot assert app.context_types is context_types - assert app.concurrent_updates == concurrent_updates + assert isinstance(app.update_processor, SimpleUpdateProcessor) + assert app.update_processor.max_concurrent_updates == expected.max_concurrent_updates + assert app.concurrent_updates == app.update_processor.max_concurrent_updates assert app.post_init is post_init assert app.post_shutdown is post_shutdown assert app.post_stop is post_stop @@ -414,6 +428,19 @@ async def post_stop(app: Application) -> None: assert app.updater is updater assert app.bot is updater.bot assert app.update_queue is updater.update_queue + app = ( + builder.token(bot.token) + .job_queue(job_queue) + .persistence(persistence) + .update_queue(update_queue) + .context_types(context_types) + .concurrent_updates(expected) + .post_init(post_init) + .post_shutdown(post_shutdown) + .post_stop(post_stop) + .arbitrary_callback_data(True) + ).build() + assert app.update_processor is expected @pytest.mark.parametrize("input_type", ["bytes", "str", "Path"]) def test_all_private_key_input_types(self, builder, bot, input_type): diff --git a/tests/ext/test_baseupdateprocessor.py b/tests/ext/test_baseupdateprocessor.py new file mode 100644 index 00000000000..3ae10d2dd16 --- /dev/null +++ b/tests/ext/test_baseupdateprocessor.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python +# +# A library that provides a Python interface to the Telegram Bot API +# Copyright (C) 2015-2023 +# Leandro Toledo de Souza <devs@python-telegram-bot.org> +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser Public License for more details. +# +# You should have received a copy of the GNU Lesser Public License +# along with this program. If not, see [http://www.gnu.org/licenses/].
+"""Here we run tests directly with SimpleUpdateProcessor because that's easier than providing dummy +implementations for SimpleUpdateProcessor and we want to test SimpleUpdateProcessor anyway.""" +import asyncio + +import pytest + +from telegram import Update +from telegram.ext import SimpleUpdateProcessor +from tests.auxil.asyncio_helpers import call_after +from tests.auxil.slots import mro_slots + + +@pytest.fixture() +def mock_processor(): + class MockProcessor(SimpleUpdateProcessor): + test_flag = False + + async def do_process_update(self, update, coroutine): + await coroutine + self.test_flag = True + + return MockProcessor(5) + + +class TestSimpleUpdateProcessor: + def test_slot_behaviour(self): + inst = SimpleUpdateProcessor(1) + for attr in inst.__slots__: + assert getattr(inst, attr, "err") != "err", f"got extra slot '{attr}'" + assert len(mro_slots(inst)) == len(set(mro_slots(inst))), "duplicate slot" + + @pytest.mark.parametrize("concurrent_updates", [0, -1]) + def test_init(self, concurrent_updates): + processor = SimpleUpdateProcessor(3) + assert processor.max_concurrent_updates == 3 + with pytest.raises(ValueError, match="must be a positive integer"): + SimpleUpdateProcessor(concurrent_updates) + + async def test_process_update(self, mock_processor): + """Test that process_update calls do_process_update.""" + update = Update(1) + + async def coroutine(): + pass + + await mock_processor.process_update(update, coroutine()) + # This flag is set in the mock processor in do_process_update, telling us that + # do_process_update was called. + assert mock_processor.test_flag + + async def test_do_process_update(self): + """Test that do_process_update calls the coroutine.""" + processor = SimpleUpdateProcessor(1) + update = Update(1) + test_flag = False + + async def coroutine(): + nonlocal test_flag + test_flag = True + + await processor.do_process_update(update, coroutine()) + assert test_flag + + async def test_max_concurrent_updates_enforcement(self, mock_processor): + """Test that max_concurrent_updates is enforced, i.e. that the processor will run + at most max_concurrent_updates coroutines at the same time.""" + count = 2 * mock_processor.max_concurrent_updates + events = {i: asyncio.Event() for i in range(count)} + queue = asyncio.Queue() + for event in events.values(): + await queue.put(event) + + async def callback(): + await asyncio.sleep(0.5) + (await queue.get()).set() + + # We start several calls to `process_update` at the same time, each of them taking + # 0.5 seconds to complete. We know that they are completed when the corresponding + # event is set. + tasks = [ + asyncio.create_task(mock_processor.process_update(update=_, coroutine=callback())) + for _ in range(count) + ] + + # Right now we expect no event to be set + for i in range(count): + assert not events[i].is_set() + + # After 0.5 seconds (+ some buffer), we expect that exactly max_concurrent_updates + # events are set. + await asyncio.sleep(0.75) + for i in range(mock_processor.max_concurrent_updates): + assert events[i].is_set() + for i in range( + mock_processor.max_concurrent_updates, + count, + ): + assert not events[i].is_set() + + # After wating another 0.5 seconds, we expect that the next max_concurrent_updates + # events are set. + await asyncio.sleep(0.5) + for i in range(count): + assert events[i].is_set() + + # Sanity check: we expect that all tasks are completed. 
+ await asyncio.gather(*tasks) + + async def test_context_manager(self, monkeypatch, mock_processor): + self.test_flag = set() + + async def after_initialize(*args, **kwargs): + self.test_flag.add("initialize") + + async def after_shutdown(*args, **kwargs): + self.test_flag.add("stop") + + monkeypatch.setattr( + SimpleUpdateProcessor, + "initialize", + call_after(SimpleUpdateProcessor.initialize, after_initialize), + ) + monkeypatch.setattr( + SimpleUpdateProcessor, + "shutdown", + call_after(SimpleUpdateProcessor.shutdown, after_shutdown), + ) + + async with mock_processor: + pass + + assert self.test_flag == {"initialize", "stop"} + + async def test_context_manager_exception_on_init(self, monkeypatch, mock_processor): + async def initialize(*args, **kwargs): + raise RuntimeError("initialize") + + async def shutdown(*args, **kwargs): + self.test_flag = "shutdown" + + monkeypatch.setattr(SimpleUpdateProcessor, "initialize", initialize) + monkeypatch.setattr(SimpleUpdateProcessor, "shutdown", shutdown) + + with pytest.raises(RuntimeError, match="initialize"): + async with mock_processor: + pass + + assert self.test_flag == "shutdown"
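To round off the picture, here is a hedged sketch of how a downstream user might plug a custom processor into the new extension point introduced by this patch; the class name, the limit of 10 and the token are illustrative placeholders, not part of the diff:

    from typing import Any, Awaitable

    from telegram.ext import ApplicationBuilder, BaseUpdateProcessor


    class LoggingUpdateProcessor(BaseUpdateProcessor):
        """Hypothetical processor that logs each update before running it."""

        async def initialize(self) -> None:
            # Allocate resources (e.g. open connections) here; nothing needed for this sketch.
            pass

        async def shutdown(self) -> None:
            # Free whatever initialize() acquired.
            pass

        async def do_process_update(self, update: object, coroutine: Awaitable[Any]) -> None:
            print(f"Processing update: {update}")
            await coroutine


    app = (
        ApplicationBuilder()
        .token("123456:dummy-token")
        .concurrent_updates(LoggingUpdateProcessor(max_concurrent_updates=10))
        .build()
    )
    # Application.process_update is now funneled through the processor's semaphore,
    # so at most 10 updates are handled concurrently.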