Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ApplicationBuilder.(get_updates_)socket_options #3943

Merged
merged 7 commits into from Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions telegram/_utils/types.py
Expand Up @@ -95,3 +95,9 @@
CorrectOptionID = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

MarkdownVersion = Literal[1, 2]

SocketOpt = Union[
Tuple[int, int, int],
Tuple[int, int, Union[bytes, bytearray]],
Tuple[int, int, None, int],
]
54 changes: 53 additions & 1 deletion telegram/ext/_applicationbuilder.py
Expand Up @@ -23,6 +23,7 @@
TYPE_CHECKING,
Any,
Callable,
Collection,
Coroutine,
Dict,
Generic,
Expand All @@ -36,7 +37,7 @@

from telegram._bot import Bot
from telegram._utils.defaultvalue import DEFAULT_FALSE, DEFAULT_NONE, DefaultValue
from telegram._utils.types import DVInput, DVType, FilePathInput, HTTPVersion, ODVInput
from telegram._utils.types import DVInput, DVType, FilePathInput, HTTPVersion, ODVInput, SocketOpt
from telegram._utils.warnings import warn
from telegram.ext._application import Application
from telegram.ext._baseupdateprocessor import BaseUpdateProcessor, SimpleUpdateProcessor
Expand Down Expand Up @@ -71,13 +72,15 @@
("get_updates_request", "get_updates_request instance"),
("connection_pool_size", "connection_pool_size"),
("proxy", "proxy"),
("socket_options", "socket_options"),
("pool_timeout", "pool_timeout"),
("connect_timeout", "connect_timeout"),
("read_timeout", "read_timeout"),
("write_timeout", "write_timeout"),
("http_version", "http_version"),
("get_updates_connection_pool_size", "get_updates_connection_pool_size"),
("get_updates_proxy", "get_updates_proxy"),
("get_updates_socket_options", "get_updates_socket_options"),
("get_updates_pool_timeout", "get_updates_pool_timeout"),
("get_updates_connect_timeout", "get_updates_connect_timeout"),
("get_updates_read_timeout", "get_updates_read_timeout"),
Expand Down Expand Up @@ -143,6 +146,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
"_get_updates_proxy",
"_get_updates_read_timeout",
"_get_updates_request",
"_get_updates_socket_options",
"_get_updates_write_timeout",
"_get_updates_http_version",
"_job_queue",
Expand All @@ -157,6 +161,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
"_rate_limiter",
"_read_timeout",
"_request",
"_socket_options",
"_token",
"_update_queue",
"_updater",
Expand All @@ -171,13 +176,15 @@ def __init__(self: "InitApplicationBuilder"):
self._base_file_url: DVType[str] = DefaultValue("https://api.telegram.org/file/bot")
self._connection_pool_size: DVInput[int] = DEFAULT_NONE
self._proxy: DVInput[Union[str, httpx.Proxy, httpx.URL]] = DEFAULT_NONE
self._socket_options: DVInput[Collection[SocketOpt]] = DEFAULT_NONE
self._connect_timeout: ODVInput[float] = DEFAULT_NONE
self._read_timeout: ODVInput[float] = DEFAULT_NONE
self._write_timeout: ODVInput[float] = DEFAULT_NONE
self._pool_timeout: ODVInput[float] = DEFAULT_NONE
self._request: DVInput[BaseRequest] = DEFAULT_NONE
self._get_updates_connection_pool_size: DVInput[int] = DEFAULT_NONE
self._get_updates_proxy: DVInput[Union[str, httpx.Proxy, httpx.URL]] = DEFAULT_NONE
self._get_updates_socket_options: DVInput[Collection[SocketOpt]] = DEFAULT_NONE
self._get_updates_connect_timeout: ODVInput[float] = DEFAULT_NONE
self._get_updates_read_timeout: ODVInput[float] = DEFAULT_NONE
self._get_updates_write_timeout: ODVInput[float] = DEFAULT_NONE
Expand Down Expand Up @@ -219,6 +226,7 @@ def _build_request(self, get_updates: bool) -> BaseRequest:
return getattr(self, f"{prefix}request")

proxy = DefaultValue.get_value(getattr(self, f"{prefix}proxy"))
socket_options = DefaultValue.get_value(getattr(self, f"{prefix}socket_options"))
if get_updates:
connection_pool_size = (
DefaultValue.get_value(getattr(self, f"{prefix}connection_pool_size")) or 1
Expand All @@ -245,6 +253,7 @@ def _build_request(self, get_updates: bool) -> BaseRequest:
connection_pool_size=connection_pool_size,
proxy=proxy,
http_version=http_version, # type: ignore[arg-type]
socket_options=socket_options,
**effective_timeouts,
)

Expand Down Expand Up @@ -426,6 +435,9 @@ def _request_check(self, get_updates: bool) -> None:
if not isinstance(getattr(self, f"_{prefix}proxy"), DefaultValue):
raise RuntimeError(_TWO_ARGS_REQ.format(name, "proxy"))

if not isinstance(getattr(self, f"_{prefix}socket_options"), DefaultValue):
raise RuntimeError(_TWO_ARGS_REQ.format(name, "socket_options"))

if not isinstance(getattr(self, f"_{prefix}http_version"), DefaultValue):
raise RuntimeError(_TWO_ARGS_REQ.format(name, "http_version"))

Expand Down Expand Up @@ -531,6 +543,25 @@ def proxy(self: BuilderType, proxy: Union[str, httpx.Proxy, httpx.URL]) -> Build
self._proxy = proxy
return self

def socket_options(self: BuilderType, socket_options: Collection[SocketOpt]) -> BuilderType:
"""Sets the options for the :paramref:`~telegram.request.HTTPXRequest.socket_options`
parameter of :attr:`telegram.Bot.request`. Defaults to :obj:`None`.

.. seealso:: :meth:`get_updates_socket_options`

.. versionadded:: NEXT.VERSION

Args:
socket_options (Collection[:obj:`tuple`], optional): Socket options. See
:paramref:`telegram.request.HTTPXRequest.socket_options` for more information.

Returns:
:class:`ApplicationBuilder`: The same builder with the updated argument.
"""
self._request_param_check(name="socket_options", get_updates=False)
self._socket_options = socket_options
return self

def connect_timeout(self: BuilderType, connect_timeout: Optional[float]) -> BuilderType:
"""Sets the connection attempt timeout for the
:paramref:`~telegram.request.HTTPXRequest.connect_timeout` parameter of
Expand Down Expand Up @@ -726,6 +757,27 @@ def get_updates_proxy(
self._get_updates_proxy = get_updates_proxy
return self

def get_updates_socket_options(
self: BuilderType, get_updates_socket_options: Collection[SocketOpt]
) -> BuilderType:
"""Sets the options for the :paramref:`~telegram.request.HTTPXRequest.socket_options`
parameter of :paramref:`telegram.Bot.get_updates_request`. Defaults to :obj:`None`.

.. seealso:: :meth:`socket_options`

.. versionadded:: NEXT.VERSION

Args:
get_updates_socket_options (Collection[:obj:`tuple`], optional): Socket options. See
:paramref:`telegram.request.HTTPXRequest.socket_options` for more information.

Returns:
:class:`ApplicationBuilder`: The same builder with the updated argument.
"""
self._request_param_check(name="socket_options", get_updates=True)
self._get_updates_socket_options = get_updates_socket_options
return self

def get_updates_connect_timeout(
self: BuilderType, get_updates_connect_timeout: Optional[float]
) -> BuilderType:
Expand Down
10 changes: 2 additions & 8 deletions telegram/request/_httpxrequest.py
Expand Up @@ -23,7 +23,7 @@

from telegram._utils.defaultvalue import DefaultValue
from telegram._utils.logging import get_logger
from telegram._utils.types import HTTPVersion, ODVInput
from telegram._utils.types import HTTPVersion, ODVInput, SocketOpt
from telegram._utils.warnings import warn
from telegram.error import NetworkError, TimedOut
from telegram.request._baserequest import BaseRequest
Expand All @@ -37,12 +37,6 @@

_LOGGER = get_logger(__name__, "HTTPXRequest")

_SocketOpt = Union[
Tuple[int, int, int],
Tuple[int, int, Union[bytes, bytearray]],
Tuple[int, int, None, int],
]


class HTTPXRequest(BaseRequest):
"""Implementation of :class:`~telegram.request.BaseRequest` using the library
Expand Down Expand Up @@ -132,7 +126,7 @@ def __init__(
connect_timeout: Optional[float] = 5.0,
pool_timeout: Optional[float] = 1.0,
http_version: HTTPVersion = "1.1",
socket_options: Optional[Collection[_SocketOpt]] = None,
socket_options: Optional[Collection[SocketOpt]] = None,
proxy: Optional[Union[str, httpx.Proxy, httpx.URL]] = None,
):
if proxy_url is not None and proxy is not None:
Expand Down
67 changes: 67 additions & 0 deletions tests/ext/test_applicationbuilder.py
Expand Up @@ -17,11 +17,13 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import asyncio
import inspect
from dataclasses import dataclass

import httpx
import pytest

from telegram import Bot
from telegram.ext import (
AIORateLimiter,
Application,
Expand Down Expand Up @@ -65,6 +67,34 @@ def test_slot_behaviour(self, builder):
assert getattr(builder, attr, "err") != "err", f"got extra slot '{attr}'"
assert len(mro_slots(builder)) == len(set(mro_slots(builder))), "duplicate slot"

@pytest.mark.parametrize("get_updates", [True, False])
def test_all_methods_request(self, builder, get_updates):
arguments = inspect.signature(HTTPXRequest.__init__).parameters.keys()
prefix = "get_updates_" if get_updates else ""
for argument in arguments:
if argument == "self":
continue
assert hasattr(builder, prefix + argument), f"missing method {prefix}{argument}"

@pytest.mark.parametrize("bot_class", [Bot, ExtBot])
def test_all_methods_bot(self, builder, bot_class):
arguments = inspect.signature(bot_class.__init__).parameters.keys()
for argument in arguments:
if argument == "self":
continue
if argument == "private_key_password":
argument = "private_key" # noqa: PLW2901
assert hasattr(builder, argument), f"missing method {argument}"

def test_all_methods_application(self, builder):
arguments = inspect.signature(Application.__init__).parameters.keys()
for argument in arguments:
if argument == "self":
continue
if argument == "update_processor":
argument = "concurrent_updates" # noqa: PLW2901
assert hasattr(builder, argument), f"missing method {argument}"

def test_job_queue_init_exception(self, monkeypatch):
def init_raises_runtime_error(*args, **kwargs):
raise RuntimeError("RuntimeError")
Expand Down Expand Up @@ -172,6 +202,7 @@ def test_mutually_exclusive_for_bot(self, builder, method, description):
"write_timeout",
"proxy",
"proxy_url",
"socket_options",
"bot",
"updater",
"http_version",
Expand Down Expand Up @@ -201,6 +232,7 @@ def test_mutually_exclusive_for_request(self, builder, method):
"get_updates_write_timeout",
"get_updates_proxy",
"get_updates_proxy_url",
"get_updates_socket_options",
"get_updates_http_version",
"bot",
"updater",
Expand Down Expand Up @@ -231,6 +263,7 @@ def test_mutually_exclusive_for_get_updates_request(self, builder, method):
"get_updates_write_timeout",
"get_updates_proxy_url",
"get_updates_proxy",
"get_updates_socket_options",
"get_updates_http_version",
"connection_pool_size",
"connect_timeout",
Expand All @@ -239,6 +272,7 @@ def test_mutually_exclusive_for_get_updates_request(self, builder, method):
"write_timeout",
"proxy",
"proxy_url",
"socket_options",
"http_version",
"bot",
"update_queue",
Expand Down Expand Up @@ -273,6 +307,7 @@ def test_mutually_exclusive_for_updater(self, builder, method):
"get_updates_write_timeout",
"get_updates_proxy",
"get_updates_proxy_url",
"get_updates_socket_options",
"get_updates_http_version",
"connection_pool_size",
"connect_timeout",
Expand All @@ -281,6 +316,7 @@ def test_mutually_exclusive_for_updater(self, builder, method):
"write_timeout",
"proxy",
"proxy_url",
"socket_options",
"bot",
"http_version",
]
Expand All @@ -306,6 +342,7 @@ def test_mutually_non_exclusive_for_updater(self, builder, method):
def test_all_bot_args_custom(
self, builder, bot, monkeypatch, proxy_method, get_updates_proxy_method
):
# Only socket_options is tested in a standalone test, since that's easier
defaults = Defaults()
request = HTTPXRequest()
get_updates_request = HTTPXRequest()
Expand Down Expand Up @@ -379,6 +416,36 @@ class Client:
assert client.http1 is True
assert client.http2 is False

def test_custom_socket_options(self, builder, monkeypatch, bot):
httpx_request_kwargs = []
httpx_request_init = HTTPXRequest.__init__

def init_transport(*args, **kwargs):
nonlocal httpx_request_kwargs
# This is called once for request and once for get_updates_request, so we make
# it a list
httpx_request_kwargs.append(kwargs.copy())
httpx_request_init(*args, **kwargs)

monkeypatch.setattr(HTTPXRequest, "__init__", init_transport)

builder.token(bot.token).build()
assert httpx_request_kwargs[0].get("socket_options") is None
assert httpx_request_kwargs[1].get("socket_options") is None

httpx_request_kwargs = []
ApplicationBuilder().token(bot.token).socket_options(((1, 2, 3),)).connection_pool_size(
"request"
).get_updates_socket_options(((4, 5, 6),)).get_updates_connection_pool_size(
"get_updates"
).build()

for kwargs in httpx_request_kwargs:
if kwargs.get("connection_pool_size") == "request":
assert kwargs.get("socket_options") == ((1, 2, 3),)
else:
assert kwargs.get("socket_options") == ((4, 5, 6),)
harshil21 marked this conversation as resolved.
Show resolved Hide resolved

def test_custom_application_class(self, bot, builder):
class CustomApplication(Application):
def __init__(self, arg, **kwargs):
Expand Down