Skip to content

Commit

Permalink
Add Parameter media_write_timeout to HTTPXRequest and Method `App…
Browse files Browse the repository at this point in the history
…licationBuilder.media_write_timeout` (#4120)
  • Loading branch information
Bibo-Joshi committed Feb 26, 2024
1 parent 277031c commit 9c263fb
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 8 deletions.
33 changes: 32 additions & 1 deletion telegram/ext/_applicationbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
("connect_timeout", "connect_timeout"),
("read_timeout", "read_timeout"),
("write_timeout", "write_timeout"),
("media_write_timeout", "media_write_timeout"),
("http_version", "http_version"),
("get_updates_connection_pool_size", "get_updates_connection_pool_size"),
("get_updates_proxy", "get_updates_proxy"),
Expand Down Expand Up @@ -152,6 +153,7 @@ class ApplicationBuilder(Generic[BT, CCT, UD, CD, BD, JQ]):
"_http_version",
"_job_queue",
"_local_mode",
"_media_write_timeout",
"_persistence",
"_pool_timeout",
"_post_init",
Expand Down Expand Up @@ -181,6 +183,7 @@ def __init__(self: "InitApplicationBuilder"):
self._connect_timeout: ODVInput[float] = DEFAULT_NONE
self._read_timeout: ODVInput[float] = DEFAULT_NONE
self._write_timeout: ODVInput[float] = DEFAULT_NONE
self._media_write_timeout: ODVInput[float] = DEFAULT_NONE
self._pool_timeout: ODVInput[float] = DEFAULT_NONE
self._request: DVInput[BaseRequest] = DEFAULT_NONE
self._get_updates_connection_pool_size: DVInput[int] = DEFAULT_NONE
Expand Down Expand Up @@ -243,6 +246,10 @@ def _build_request(self, get_updates: bool) -> BaseRequest:
"write_timeout": getattr(self, f"{prefix}write_timeout"),
"pool_timeout": getattr(self, f"{prefix}pool_timeout"),
}

if not get_updates:
timeouts["media_write_timeout"] = self._media_write_timeout

# Get timeouts that were actually set-
effective_timeouts = {
key: value for key, value in timeouts.items() if not isinstance(value, DefaultValue)
Expand Down Expand Up @@ -424,9 +431,13 @@ def _request_check(self, get_updates: bool) -> None:
prefix = "get_updates_" if get_updates else ""
name = prefix + "request"

timeouts = ["connect_timeout", "read_timeout", "write_timeout", "pool_timeout"]
if not get_updates:
timeouts.append("media_write_timeout")

# Code below tests if it's okay to set a Request object. Only okay if no other request args
# or instances containing a Request were set previously
for attr in ("connect_timeout", "read_timeout", "write_timeout", "pool_timeout"):
for attr in timeouts:
if not isinstance(getattr(self, f"_{prefix}{attr}"), DefaultValue):
raise RuntimeError(_TWO_ARGS_REQ.format(name, attr))

Expand Down Expand Up @@ -617,6 +628,26 @@ def write_timeout(self: BuilderType, write_timeout: Optional[float]) -> BuilderT
self._write_timeout = write_timeout
return self

def media_write_timeout(
self: BuilderType, media_write_timeout: Optional[float]
) -> BuilderType:
"""Sets the media write operation timeout for the
:paramref:`~telegram.request.HTTPXRequest.media_write_timeout` parameter of
:attr:`telegram.Bot.request`. Defaults to ``20``.
.. versionadded:: NEXT.VERSION
Args:
media_write_timeout (:obj:`float`): See
:paramref:`telegram.request.HTTPXRequest.media_write_timeout` for more information.
Returns:
:class:`ApplicationBuilder`: The same builder with the updated argument.
"""
self._request_param_check(name="media_write_timeout", get_updates=False)
self._media_write_timeout = media_write_timeout
return self

def pool_timeout(self: BuilderType, pool_timeout: Optional[float]) -> BuilderType:
"""Sets the connection pool's connection freeing timeout for the
:paramref:`~telegram.request.HTTPXRequest.pool_timeout` parameter of
Expand Down
20 changes: 14 additions & 6 deletions telegram/request/_httpxrequest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class HTTPXRequest(BaseRequest):
a network socket; i.e. POSTing a request or uploading a file).
This value is used unless a different value is passed to :meth:`do_request`.
Defaults to ``5``.
Hint:
This timeout is used for all requests except for those that upload media/files.
For the latter, :paramref:`media_write_timeout` is used.
connect_timeout (:obj:`float` | :obj:`None`, optional): If passed, specifies the
maximum amount of time (in seconds) to wait for a connection attempt to a server
to succeed. This value is used unless a different value is passed to
Expand Down Expand Up @@ -112,10 +116,16 @@ class HTTPXRequest(BaseRequest):
.. _the docs of httpx: https://www.python-httpx.org/environment_variables/#proxies
.. versionadded:: 20.7
media_write_timeout (:obj:`float` | :obj:`None`, optional): Like :paramref:`write_timeout`,
but used only for requests that upload media/files. This value is used unless a
different value is passed to :paramref:`do_request.write_timeout` of
:meth:`do_request`. Defaults to ``20`` seconds.
.. versionadded:: NEXT.VERSION
"""

__slots__ = ("_client", "_client_kwargs", "_http_version")
__slots__ = ("_client", "_client_kwargs", "_http_version", "_media_write_timeout")

def __init__(
self,
Expand All @@ -128,6 +138,7 @@ def __init__(
http_version: HTTPVersion = "1.1",
socket_options: Optional[Collection[SocketOpt]] = None,
proxy: Optional[Union[str, httpx.Proxy, httpx.URL]] = None,
media_write_timeout: Optional[float] = 20.0,
):
if proxy_url is not None and proxy is not None:
raise ValueError("The parameters `proxy_url` and `proxy` are mutually exclusive.")
Expand All @@ -142,6 +153,7 @@ def __init__(
)

self._http_version = http_version
self._media_write_timeout = media_write_timeout
timeout = httpx.Timeout(
connect=connect_timeout,
read=read_timeout,
Expand Down Expand Up @@ -251,11 +263,7 @@ async def do_request(
pool_timeout = self._client.timeout.pool

if isinstance(write_timeout, DefaultValue):
# Making the networking backend decide on the proper timeout values instead of doing
# it via the default values of the Bot methods was introduced in version 20.7.
# We hard-code the value here for now until we add additional parameters to this
# class to control the media_write_timeout separately.
write_timeout = self._client.timeout.write if not files else 20
write_timeout = self._client.timeout.write if not files else self._media_write_timeout

timeout = httpx.Timeout(
connect=connect_timeout,
Expand Down
19 changes: 18 additions & 1 deletion tests/ext/test_applicationbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def test_all_methods_request(self, builder, get_updates):
for argument in arguments:
if argument == "self":
continue
if argument == "media_write_timeout" and get_updates:
# get_updates never makes media requests
continue
assert hasattr(builder, prefix + argument), f"missing method {prefix}{argument}"

@pytest.mark.parametrize("bot_class", [Bot, ExtBot])
Expand Down Expand Up @@ -202,6 +205,7 @@ def test_mutually_exclusive_for_bot(self, builder, method, description):
"pool_timeout",
"read_timeout",
"write_timeout",
"media_write_timeout",
"proxy",
"proxy_url",
"socket_options",
Expand Down Expand Up @@ -272,6 +276,7 @@ def test_mutually_exclusive_for_get_updates_request(self, builder, method):
"pool_timeout",
"read_timeout",
"write_timeout",
"media_write_timeout",
"proxy",
"proxy_url",
"socket_options",
Expand Down Expand Up @@ -316,6 +321,7 @@ def test_mutually_exclusive_for_updater(self, builder, method):
"pool_timeout",
"read_timeout",
"write_timeout",
"media_write_timeout",
"proxy",
"proxy_url",
"socket_options",
Expand Down Expand Up @@ -384,12 +390,20 @@ class Client:
http2: object
transport: object = None

original_init = HTTPXRequest.__init__
media_write_timeout = []

def init_httpx_request(self_, *args, **kwargs):
media_write_timeout.append(kwargs.get("media_write_timeout"))
original_init(self_, *args, **kwargs)

monkeypatch.setattr(httpx, "AsyncClient", Client)
monkeypatch.setattr(HTTPXRequest, "__init__", init_httpx_request)

builder = ApplicationBuilder().token(bot.token)
builder.connection_pool_size(1).connect_timeout(2).pool_timeout(3).read_timeout(
4
).write_timeout(5).http_version("1.1")
).write_timeout(5).media_write_timeout(6).http_version("1.1")
getattr(builder, proxy_method)("proxy")
app = builder.build()
client = app.bot.request._client
Expand All @@ -399,7 +413,9 @@ class Client:
assert client.proxy == "proxy"
assert client.http1 is True
assert client.http2 is False
assert media_write_timeout == [6, None]

media_write_timeout.clear()
builder = ApplicationBuilder().token(bot.token)
builder.get_updates_connection_pool_size(1).get_updates_connect_timeout(
2
Expand All @@ -417,6 +433,7 @@ class Client:
assert client.proxy == "get_updates_proxy"
assert client.http1 is True
assert client.http2 is False
assert media_write_timeout == [None, None]

def test_custom_socket_options(self, builder, monkeypatch, bot):
httpx_request_kwargs = []
Expand Down
33 changes: 33 additions & 0 deletions tests/request/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,39 @@ async def request(_, **kwargs):
# other than HTTPXRequest
assert len(recwarn) == 0

@pytest.mark.parametrize("init", [True, False])
async def test_setting_media_write_timeout(
self, monkeypatch, init, input_media_photo, recwarn # noqa: F811
):
httpx_request = HTTPXRequest(media_write_timeout=42) if init else HTTPXRequest()

async def request(_, **kwargs):
self.test_flag = kwargs["timeout"].write
return httpx.Response(HTTPStatus.OK, content=b'{"ok": "True", "result": {}}')

monkeypatch.setattr(httpx.AsyncClient, "request", request)

data = {"string": "string", "int": 1, "float": 1.0, "media": input_media_photo}
request_data = RequestData(
parameters=[RequestParameter.from_input(key, value) for key, value in data.items()],
)

# First make sure that custom timeouts are always respected
await httpx_request.post(
"url",
request_data,
write_timeout=43,
)
assert self.test_flag == 43

# Now also ensure that the init value is respected
await httpx_request.post("url", request_data)
assert self.test_flag == 42 if init else 20

# Just for double-checking, since warnings are issued for implementations of BaseRequest
# other than HTTPXRequest
assert len(recwarn) == 0

async def test_socket_opts(self, monkeypatch):
transport_kwargs = {}
transport_init = AsyncHTTPTransport.__init__
Expand Down

0 comments on commit 9c263fb

Please sign in to comment.