Skip to content

Commit

Permalink
Add ApplicationBuilder.(get_updates_)socket_options (#3943)
Browse files Browse the repository at this point in the history
  • Loading branch information
Bibo-Joshi committed Oct 31, 2023
1 parent c71612f commit 616b0b5
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 9 deletions.
6 changes: 6 additions & 0 deletions telegram/_utils/types.py
Expand Up @@ -95,3 +95,9 @@
CorrectOptionID = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

MarkdownVersion = Literal[1, 2]

SocketOpt = Union[
Tuple[int, int, int],
Tuple[int, int, Union[bytes, bytearray]],
Tuple[int, int, None, int],
]
54 changes: 53 additions & 1 deletion telegram/ext/_applicationbuilder.py
Expand Up @@ -23,6 +23,7 @@
TYPE_CHECKING,
Any,
Callable,
Collection,
Coroutine,
Dict,
Generic,
Expand All @@ -36,7 +37,7 @@

from telegram._bot import Bot
from telegram._utils.defaultvalue import DEFAULT_FALSE, DEFAULT_NONE, DefaultValue
from telegram._utils.types import DVInput, DVType, FilePathInput, HTTPVersion, ODVInput
from telegram._utils.types import DVInput, DVType, FilePathInput, HTTPVersion, ODVInput, SocketOpt
from telegram._utils.warnings import warn
from telegram.ext._application import Application
from telegram.ext._baseupdateprocessor import BaseUpdateProcessor, SimpleUpdateProcessor
Expand Down Expand Up @@ -71,13 +72,15 @@
("get_updates_request", "get_updates_request instance"),
("connection_pool_size", "connection_pool_size"),
("proxy", "proxy"),
("socket_options", "socket_options"),
("pool_timeout", "pool_timeout"),
("connect_timeout", "connect_timeout"),
("read_timeout", "read_timeout"),
("write_timeout", "write_timeout"),
("http_version", "http_version"),
("get_updates_connection_pool_size", "get_updates_connection_pool_size"),
("get_updates_proxy", "get_updates_proxy"),
("get_updates_socket_options", "get_updates_socket_options"),
("get_updates_pool_timeout", "get_updates_pool_timeout"),
("get_updates_connect_timeout", "get_updates_connect_timeout"),
("get_updates_read_timeout", "get_updates_read_timeout"),
Expand Down Expand Up @@ -143,6 +146,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
"_get_updates_proxy",
"_get_updates_read_timeout",
"_get_updates_request",
"_get_updates_socket_options",
"_get_updates_write_timeout",
"_get_updates_http_version",
"_job_queue",
Expand All @@ -157,6 +161,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
"_rate_limiter",
"_read_timeout",
"_request",
"_socket_options",
"_token",
"_update_queue",
"_updater",
Expand All @@ -171,13 +176,15 @@ def __init__(self: "InitApplicationBuilder"):
self._base_file_url: DVType[str] = DefaultValue("https://api.telegram.org/file/bot")
self._connection_pool_size: DVInput[int] = DEFAULT_NONE
self._proxy: DVInput[Union[str, httpx.Proxy, httpx.URL]] = DEFAULT_NONE
self._socket_options: DVInput[Collection[SocketOpt]] = DEFAULT_NONE
self._connect_timeout: ODVInput[float] = DEFAULT_NONE
self._read_timeout: ODVInput[float] = DEFAULT_NONE
self._write_timeout: ODVInput[float] = DEFAULT_NONE
self._pool_timeout: ODVInput[float] = DEFAULT_NONE
self._request: DVInput[BaseRequest] = DEFAULT_NONE
self._get_updates_connection_pool_size: DVInput[int] = DEFAULT_NONE
self._get_updates_proxy: DVInput[Union[str, httpx.Proxy, httpx.URL]] = DEFAULT_NONE
self._get_updates_socket_options: DVInput[Collection[SocketOpt]] = DEFAULT_NONE
self._get_updates_connect_timeout: ODVInput[float] = DEFAULT_NONE
self._get_updates_read_timeout: ODVInput[float] = DEFAULT_NONE
self._get_updates_write_timeout: ODVInput[float] = DEFAULT_NONE
Expand Down Expand Up @@ -219,6 +226,7 @@ def _build_request(self, get_updates: bool) -> BaseRequest:
return getattr(self, f"{prefix}request")

proxy = DefaultValue.get_value(getattr(self, f"{prefix}proxy"))
socket_options = DefaultValue.get_value(getattr(self, f"{prefix}socket_options"))
if get_updates:
connection_pool_size = (
DefaultValue.get_value(getattr(self, f"{prefix}connection_pool_size")) or 1
Expand All @@ -245,6 +253,7 @@ def _build_request(self, get_updates: bool) -> BaseRequest:
connection_pool_size=connection_pool_size,
proxy=proxy,
http_version=http_version, # type: ignore[arg-type]
socket_options=socket_options,
**effective_timeouts,
)

Expand Down Expand Up @@ -426,6 +435,9 @@ def _request_check(self, get_updates: bool) -> None:
if not isinstance(getattr(self, f"_{prefix}proxy"), DefaultValue):
raise RuntimeError(_TWO_ARGS_REQ.format(name, "proxy"))

if not isinstance(getattr(self, f"_{prefix}socket_options"), DefaultValue):
raise RuntimeError(_TWO_ARGS_REQ.format(name, "socket_options"))

if not isinstance(getattr(self, f"_{prefix}http_version"), DefaultValue):
raise RuntimeError(_TWO_ARGS_REQ.format(name, "http_version"))

Expand Down Expand Up @@ -531,6 +543,25 @@ def proxy(self: BuilderType, proxy: Union[str, httpx.Proxy, httpx.URL]) -> Build
self._proxy = proxy
return self

def socket_options(self: BuilderType, socket_options: Collection[SocketOpt]) -> BuilderType:
"""Sets the options for the :paramref:`~telegram.request.HTTPXRequest.socket_options`
parameter of :attr:`telegram.Bot.request`. Defaults to :obj:`None`.
.. seealso:: :meth:`get_updates_socket_options`
.. versionadded:: NEXT.VERSION
Args:
socket_options (Collection[:obj:`tuple`], optional): Socket options. See
:paramref:`telegram.request.HTTPXRequest.socket_options` for more information.
Returns:
:class:`ApplicationBuilder`: The same builder with the updated argument.
"""
self._request_param_check(name="socket_options", get_updates=False)
self._socket_options = socket_options
return self

def connect_timeout(self: BuilderType, connect_timeout: Optional[float]) -> BuilderType:
"""Sets the connection attempt timeout for the
:paramref:`~telegram.request.HTTPXRequest.connect_timeout` parameter of
Expand Down Expand Up @@ -726,6 +757,27 @@ def get_updates_proxy(
self._get_updates_proxy = get_updates_proxy
return self

def get_updates_socket_options(
self: BuilderType, get_updates_socket_options: Collection[SocketOpt]
) -> BuilderType:
"""Sets the options for the :paramref:`~telegram.request.HTTPXRequest.socket_options`
parameter of :paramref:`telegram.Bot.get_updates_request`. Defaults to :obj:`None`.
.. seealso:: :meth:`socket_options`
.. versionadded:: NEXT.VERSION
Args:
get_updates_socket_options (Collection[:obj:`tuple`], optional): Socket options. See
:paramref:`telegram.request.HTTPXRequest.socket_options` for more information.
Returns:
:class:`ApplicationBuilder`: The same builder with the updated argument.
"""
self._request_param_check(name="socket_options", get_updates=True)
self._get_updates_socket_options = get_updates_socket_options
return self

def get_updates_connect_timeout(
self: BuilderType, get_updates_connect_timeout: Optional[float]
) -> BuilderType:
Expand Down
10 changes: 2 additions & 8 deletions telegram/request/_httpxrequest.py
Expand Up @@ -23,7 +23,7 @@

from telegram._utils.defaultvalue import DefaultValue
from telegram._utils.logging import get_logger
from telegram._utils.types import HTTPVersion, ODVInput
from telegram._utils.types import HTTPVersion, ODVInput, SocketOpt
from telegram._utils.warnings import warn
from telegram.error import NetworkError, TimedOut
from telegram.request._baserequest import BaseRequest
Expand All @@ -37,12 +37,6 @@

_LOGGER = get_logger(__name__, "HTTPXRequest")

_SocketOpt = Union[
Tuple[int, int, int],
Tuple[int, int, Union[bytes, bytearray]],
Tuple[int, int, None, int],
]


class HTTPXRequest(BaseRequest):
"""Implementation of :class:`~telegram.request.BaseRequest` using the library
Expand Down Expand Up @@ -132,7 +126,7 @@ def __init__(
connect_timeout: Optional[float] = 5.0,
pool_timeout: Optional[float] = 1.0,
http_version: HTTPVersion = "1.1",
socket_options: Optional[Collection[_SocketOpt]] = None,
socket_options: Optional[Collection[SocketOpt]] = None,
proxy: Optional[Union[str, httpx.Proxy, httpx.URL]] = None,
):
if proxy_url is not None and proxy is not None:
Expand Down
67 changes: 67 additions & 0 deletions tests/ext/test_applicationbuilder.py
Expand Up @@ -17,11 +17,13 @@
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
import asyncio
import inspect
from dataclasses import dataclass

import httpx
import pytest

from telegram import Bot
from telegram.ext import (
AIORateLimiter,
Application,
Expand Down Expand Up @@ -65,6 +67,34 @@ def test_slot_behaviour(self, builder):
assert getattr(builder, attr, "err") != "err", f"got extra slot '{attr}'"
assert len(mro_slots(builder)) == len(set(mro_slots(builder))), "duplicate slot"

@pytest.mark.parametrize("get_updates", [True, False])
def test_all_methods_request(self, builder, get_updates):
arguments = inspect.signature(HTTPXRequest.__init__).parameters.keys()
prefix = "get_updates_" if get_updates else ""
for argument in arguments:
if argument == "self":
continue
assert hasattr(builder, prefix + argument), f"missing method {prefix}{argument}"

@pytest.mark.parametrize("bot_class", [Bot, ExtBot])
def test_all_methods_bot(self, builder, bot_class):
arguments = inspect.signature(bot_class.__init__).parameters.keys()
for argument in arguments:
if argument == "self":
continue
if argument == "private_key_password":
argument = "private_key" # noqa: PLW2901
assert hasattr(builder, argument), f"missing method {argument}"

def test_all_methods_application(self, builder):
arguments = inspect.signature(Application.__init__).parameters.keys()
for argument in arguments:
if argument == "self":
continue
if argument == "update_processor":
argument = "concurrent_updates" # noqa: PLW2901
assert hasattr(builder, argument), f"missing method {argument}"

def test_job_queue_init_exception(self, monkeypatch):
def init_raises_runtime_error(*args, **kwargs):
raise RuntimeError("RuntimeError")
Expand Down Expand Up @@ -172,6 +202,7 @@ def test_mutually_exclusive_for_bot(self, builder, method, description):
"write_timeout",
"proxy",
"proxy_url",
"socket_options",
"bot",
"updater",
"http_version",
Expand Down Expand Up @@ -201,6 +232,7 @@ def test_mutually_exclusive_for_request(self, builder, method):
"get_updates_write_timeout",
"get_updates_proxy",
"get_updates_proxy_url",
"get_updates_socket_options",
"get_updates_http_version",
"bot",
"updater",
Expand Down Expand Up @@ -231,6 +263,7 @@ def test_mutually_exclusive_for_get_updates_request(self, builder, method):
"get_updates_write_timeout",
"get_updates_proxy_url",
"get_updates_proxy",
"get_updates_socket_options",
"get_updates_http_version",
"connection_pool_size",
"connect_timeout",
Expand All @@ -239,6 +272,7 @@ def test_mutually_exclusive_for_get_updates_request(self, builder, method):
"write_timeout",
"proxy",
"proxy_url",
"socket_options",
"http_version",
"bot",
"update_queue",
Expand Down Expand Up @@ -273,6 +307,7 @@ def test_mutually_exclusive_for_updater(self, builder, method):
"get_updates_write_timeout",
"get_updates_proxy",
"get_updates_proxy_url",
"get_updates_socket_options",
"get_updates_http_version",
"connection_pool_size",
"connect_timeout",
Expand All @@ -281,6 +316,7 @@ def test_mutually_exclusive_for_updater(self, builder, method):
"write_timeout",
"proxy",
"proxy_url",
"socket_options",
"bot",
"http_version",
]
Expand All @@ -306,6 +342,7 @@ def test_mutually_non_exclusive_for_updater(self, builder, method):
def test_all_bot_args_custom(
self, builder, bot, monkeypatch, proxy_method, get_updates_proxy_method
):
# Only socket_options is tested in a standalone test, since that's easier
defaults = Defaults()
request = HTTPXRequest()
get_updates_request = HTTPXRequest()
Expand Down Expand Up @@ -379,6 +416,36 @@ class Client:
assert client.http1 is True
assert client.http2 is False

def test_custom_socket_options(self, builder, monkeypatch, bot):
httpx_request_kwargs = []
httpx_request_init = HTTPXRequest.__init__

def init_transport(*args, **kwargs):
nonlocal httpx_request_kwargs
# This is called once for request and once for get_updates_request, so we make
# it a list
httpx_request_kwargs.append(kwargs.copy())
httpx_request_init(*args, **kwargs)

monkeypatch.setattr(HTTPXRequest, "__init__", init_transport)

builder.token(bot.token).build()
assert httpx_request_kwargs[0].get("socket_options") is None
assert httpx_request_kwargs[1].get("socket_options") is None

httpx_request_kwargs = []
ApplicationBuilder().token(bot.token).socket_options(((1, 2, 3),)).connection_pool_size(
"request"
).get_updates_socket_options(((4, 5, 6),)).get_updates_connection_pool_size(
"get_updates"
).build()

for kwargs in httpx_request_kwargs:
if kwargs.get("connection_pool_size") == "request":
assert kwargs.get("socket_options") == ((1, 2, 3),)
else:
assert kwargs.get("socket_options") == ((4, 5, 6),)

def test_custom_application_class(self, bot, builder):
class CustomApplication(Application):
def __init__(self, arg, **kwargs):
Expand Down

0 comments on commit 616b0b5

Please sign in to comment.