Skip to content

Commit

Permalink
Inlude url & endpoint in RequestData protect the connection pool with…
Browse files Browse the repository at this point in the history
… a semaphore
  • Loading branch information
Bibo-Joshi committed Jan 2, 2022
1 parent 41e7010 commit 4d19a16
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 41 deletions.
5 changes: 3 additions & 2 deletions telegram/_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,12 @@ async def _post(
# We don't do this earlier so that _insert_defaults (see above) has a chance to convert
# to the default timezone in case this is called by ExtBot
request_data = RequestData(
[RequestParameter.from_input(key, value) for key, value in data.items()]
base_url=self.base_url,
endpoint=endpoint,
parameters=[RequestParameter.from_input(key, value) for key, value in data.items()],
)

return await self.request.post(
f'{self.base_url}/{endpoint}',
request_data=request_data,
read_timeout=read_timeout,
write_timeout=write_timeout,
Expand Down
2 changes: 1 addition & 1 deletion telegram/ext/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ async def dispatch_error(
await self.run_async(callback, args=(update, context), update=update)
else:
try:
await callback(update, context)
await run_non_blocking(func=callback, args=(update, context))
except DispatcherHandlerStop:
return True
except Exception as exc:
Expand Down
9 changes: 1 addition & 8 deletions telegram/request/_baserequest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ async def shutdown(self) -> None:

async def post(
self,
url: str,
request_data: RequestData = None,
connect_timeout: float = None,
read_timeout: float = None,
Expand All @@ -130,7 +129,6 @@ async def post(
Args:
url (:obj:`str`): The web location we want to retrieve.
request_data (:class:`telegram.request.RequestData`, optional): An object describing
any parameters and files to upload for the request.
connect_timeout (:obj:`float`, optional): If passed, specifies the maximum amount of
Expand All @@ -153,7 +151,6 @@ async def post(
"""
result = await self._request_wrapper(
method='POST',
url=url,
request_data=request_data,
read_timeout=read_timeout,
write_timeout=write_timeout,
Expand Down Expand Up @@ -191,7 +188,7 @@ async def retrieve(
"""
return await self._request_wrapper(
method='GET',
url=url,
request_data=RequestData(base_url=url),
read_timeout=read_timeout,
write_timeout=write_timeout,
connect_timeout=connect_timeout,
Expand All @@ -201,7 +198,6 @@ async def retrieve(
async def _request_wrapper(
self,
method: str,
url: str,
request_data: RequestData = None,
read_timeout: float = None,
connect_timeout: float = None,
Expand Down Expand Up @@ -234,7 +230,6 @@ async def _request_wrapper(
try:
code, payload = await self.do_request(
method,
url,
request_data=request_data,
read_timeout=read_timeout,
write_timeout=write_timeout,
Expand Down Expand Up @@ -308,7 +303,6 @@ def _parse_json_response(json_payload: bytes) -> JSONDict:
async def do_request(
self,
method: str,
url: str,
request_data: RequestData = None,
connect_timeout: float = None,
read_timeout: float = None,
Expand All @@ -323,7 +317,6 @@ async def do_request(
Args:
method (:obj:`str`): HTTP method (i.e. ``'POST'``, ``'GET'``, etc.).
url (:obj:`str`): The request's URL.
request_data (:class:`telegram.request.RequestData`, optional): An object describing
any parameters and files to upload for the request.
read_timeout (:obj:`float`, optional): If this value is specified, use it as the read
Expand Down
79 changes: 54 additions & 25 deletions telegram/request/_httpxrequest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2021
# Copyright (C) 2015-2022
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
Expand All @@ -15,25 +15,9 @@
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].

#
# A library that provides a Python interface to the Telegram Bot API
# Copyright (C) 2015-2021
# Leandro Toledo de Souza <devs@python-telegram-bot.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser Public License for more details.
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
"""This module contains methods to make POST and GET requests using the httpx library."""
import asyncio
import logging
from typing import Tuple, Optional

import httpx
Expand All @@ -48,14 +32,20 @@
# That also works with socks5. Just pass `--mode socks5` to mitmproxy and the `verify` argument to
# AsyncProxyTransport.from_url

_logger = logging.getLogger(__name__)


class HTTPXRequest(BaseRequest):
"""Implementation of :class`BaseRequest` using the library
`httpx <https://www.python-httpx.org>`_`.
Args:
connection_pool_size (:obj:`int`, optional): Number of connections to keep in the
connection pool. Default to :obj:`1`.
connection pool. Defaults to :obj:`1`.
Note:
Independent of the value, one additional connection will be reserved for
:meth:`telegram.Bot.get_updates`.
proxy_url (:obj:`str`, optional): The URL to the proxy server. For example
``'http://127.0.0.1:3128'`` or ``'socks5://127.0.0.1:3128'``. Defaults to :obj:`None`.
Expand Down Expand Up @@ -83,7 +73,7 @@ class HTTPXRequest(BaseRequest):
"""

__slots__ = ('_client', '_connection_pool_size')
__slots__ = ('_client', '_connection_pool_size', '__pool_semaphore')

def __init__(
self,
Expand All @@ -94,15 +84,18 @@ def __init__(
write_timeout: Optional[float] = 5.0,
pool_timeout: Optional[float] = 1.0,
):
self.__pool_semaphore = asyncio.BoundedSemaphore(connection_pool_size)

timeout = httpx.Timeout(
connect=connect_timeout,
read=read_timeout,
write=write_timeout,
pool=pool_timeout,
)
self._connection_pool_size = connection_pool_size
self._connection_pool_size = connection_pool_size + 1
limits = httpx.Limits(
max_connections=connection_pool_size, max_keepalive_connections=connection_pool_size
max_connections=self.connection_pool_size,
max_keepalive_connections=self.connection_pool_size,
)

# Handle socks5 proxies
Expand Down Expand Up @@ -146,14 +139,43 @@ async def shutdown(self) -> None:
async def do_request(
self,
method: str,
url: str,
request_data: RequestData = None,
connect_timeout: float = None,
read_timeout: float = None,
write_timeout: float = None,
pool_timeout: float = None,
) -> Tuple[int, bytes]:
"""See :meth:`BaseRequest.do_request`."""
if request_data.endpoint == 'getUpdates':
return await self._do_request(
method=method,
request_data=request_data,
connect_timeout=connect_timeout,
read_timeout=read_timeout,
write_timeout=write_timeout,
pool_timeout=pool_timeout,
)

async with self.__pool_semaphore:
out = await self._do_request(
method=method,
request_data=request_data,
connect_timeout=connect_timeout,
read_timeout=read_timeout,
write_timeout=write_timeout,
pool_timeout=pool_timeout,
)
return out

async def _do_request(
self,
method: str,
request_data: RequestData = None,
connect_timeout: float = None,
read_timeout: float = None,
write_timeout: float = None,
pool_timeout: float = None,
) -> Tuple[int, bytes]:
timeout = httpx.Timeout(
connect=self._client.timeout.connect,
read=self._client.timeout.read,
Expand All @@ -177,16 +199,23 @@ async def do_request(

files = request_data.multipart_data if request_data else None
data = request_data.json_parameters if request_data else None

try:
res = await self._client.request(
method=method,
url=url,
url=request_data.url,
headers={'User-Agent': self.USER_AGENT},
timeout=timeout,
files=files,
data=data,
)
except httpx.TimeoutException as err:
if isinstance(err, httpx.PoolTimeout):
_logger.critical(
'All connections in the connection pool are occupied. Request was *not* sent '
'to Telegram. Adjust connection pool size!',
# exc_info=err,
)
raise TimedOut() from err
except httpx.HTTPError as err:
# HTTPError must come last as its the base httpx exception class
Expand Down
21 changes: 16 additions & 5 deletions telegram/request/_requestdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@


class RequestData:
"""Instances of this class represent a collection of parameters and files to be sent along
with a request to the Bot API.
"""Instances of this class collect the data needed for one request to the Bot API, including
all parameters and files to be sent along with the request.
.. versionadded:: 14.0
Expand All @@ -43,17 +43,28 @@ class RequestData:
Attributes:
contains_files (:obj:`bool`): Whether this object contains files to be uploaded via
``multipart/form-data``.
endpoint (:obj:`str`): Optional. The endpoint of the Bot API that this request is directed
at. E.g. ``'sendMessage'``.
"""

__slots__ = ('_parameters', 'contains_files')
__slots__ = ('_parameters', 'contains_files', 'endpoint', '_base_url')

def __init__(
self,
base_url: str,
endpoint: str = None,
parameters: List[RequestParameter] = None,
):
self._base_url = base_url
self.endpoint = endpoint
self._parameters = parameters or []
self.contains_files = any(param.input_files for param in self._parameters)

@property
def url(self) -> str:
"""The URL for this request, including the endpoint, but excluding the parameters."""
return f'{self._base_url}/{self.endpoint}'

@property
def parameters(self) -> Dict[str, Union[str, int, List, Dict]]:
"""Gives the parameters as mapping of parameter name to the parameter value, which can be
Expand Down Expand Up @@ -81,7 +92,7 @@ def url_encoded_parameters(self, encode_kwargs: Dict[str, Any] = None) -> str:
return urlencode(self.json_parameters, **encode_kwargs)
return urlencode(self.json_parameters)

def build_parametrized_url(self, url: str, encode_kwargs: Dict[str, Any] = None) -> str:
def parametrized_url(self, url: str, encode_kwargs: Dict[str, Any] = None) -> str:
"""Shortcut for attaching the return value of :meth:`url_encoded_parameters` to the
:attr:`url`.
Expand All @@ -90,7 +101,7 @@ def build_parametrized_url(self, url: str, encode_kwargs: Dict[str, Any] = None)
along to :meth:`urllib.parse.urlencode`.
"""
url_parameters = self.url_encoded_parameters(encode_kwargs=encode_kwargs)
return f'{url}?{url_parameters}'
return f'{self.url}?{url_parameters}'

@property
def json_payload(self) -> bytes:
Expand Down

0 comments on commit 4d19a16

Please sign in to comment.