Skip to content

Commit

Permalink
Add Support for Unix Sockets to Updater.start_webhook (#3986)
Browse files Browse the repository at this point in the history
Co-authored-by: Bibo-Joshi <22366557+Bibo-Joshi@users.noreply.github.com>
  • Loading branch information
Poolitzer and Bibo-Joshi committed Dec 14, 2023
1 parent cc45f49 commit 2345bfb
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 38 deletions.
16 changes: 16 additions & 0 deletions telegram/_utils/defaultvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,19 @@ def get_value(obj: Union[OT, "DefaultValue[OT]"]) -> OT:
.. versionadded:: 20.0
"""


DEFAULT_20: DefaultValue[int] = DefaultValue(20)
""":class:`DefaultValue`: Default :obj:`20`"""

DEFAULT_IP: DefaultValue[str] = DefaultValue("127.0.0.1")
""":class:`DefaultValue`: Default :obj:`127.0.0.1`
.. versionadded:: NEXT.VERSION
"""

DEFAULT_80: DefaultValue[int] = DefaultValue(80)
""":class:`DefaultValue`: Default :obj:`80`
.. versionadded:: NEXT.VERSION
"""
24 changes: 21 additions & 3 deletions telegram/ext/_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@
)

from telegram._update import Update
from telegram._utils.defaultvalue import DEFAULT_NONE, DEFAULT_TRUE, DefaultValue
from telegram._utils.defaultvalue import (
DEFAULT_80,
DEFAULT_IP,
DEFAULT_NONE,
DEFAULT_TRUE,
DefaultValue,
)
from telegram._utils.logging import get_logger
from telegram._utils.repr import build_repr_with_selected_attrs
from telegram._utils.types import SCT, DVType, ODVInput
Expand Down Expand Up @@ -834,8 +840,8 @@ def error_callback(exc: TelegramError) -> None:

def run_webhook(
self,
listen: str = "127.0.0.1",
port: int = 80,
listen: DVType[str] = DEFAULT_IP,
port: DVType[int] = DEFAULT_80,
url_path: str = "",
cert: Optional[Union[str, Path]] = None,
key: Optional[Union[str, Path]] = None,
Expand All @@ -848,6 +854,7 @@ def run_webhook(
close_loop: bool = True,
stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE,
secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> None:
"""Convenience method that takes care of initializing and starting the app,
listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
Expand Down Expand Up @@ -940,6 +947,16 @@ def run_webhook(
header isn't set or it is set to a wrong token.
.. versionadded:: 20.0
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
does not need to exist, in which case the file will be created.
Caution:
This parameter is a replacement for the default TCP bind. Therefore, it is
mutually exclusive with :paramref:`listen` and :paramref:`port`. When using
this param, you must also run a reverse proxy to the unix socket and set the
appropriate :paramref:`webhook_url`.
.. versionadded:: NEXT.VERSION
"""
if not self.updater:
raise RuntimeError(
Expand All @@ -960,6 +977,7 @@ def run_webhook(
ip_address=ip_address,
max_connections=max_connections,
secret_token=secret_token,
unix=unix,
),
close_loop=close_loop,
stop_signals=stop_signals,
Expand Down
45 changes: 36 additions & 9 deletions telegram/ext/_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@
Union,
)

from telegram._utils.defaultvalue import DEFAULT_NONE
from telegram._utils.defaultvalue import DEFAULT_80, DEFAULT_IP, DEFAULT_NONE, DefaultValue
from telegram._utils.logging import get_logger
from telegram._utils.repr import build_repr_with_selected_attrs
from telegram._utils.types import ODVInput
from telegram._utils.types import DVType, ODVInput
from telegram.error import InvalidToken, RetryAfter, TelegramError, TimedOut

try:
Expand Down Expand Up @@ -456,8 +456,8 @@ async def _get_updates_cleanup() -> None:

async def start_webhook(
self,
listen: str = "127.0.0.1",
port: int = 80,
listen: DVType[str] = DEFAULT_IP,
port: DVType[int] = DEFAULT_80,
url_path: str = "",
cert: Optional[Union[str, Path]] = None,
key: Optional[Union[str, Path]] = None,
Expand All @@ -468,6 +468,7 @@ async def start_webhook(
ip_address: Optional[str] = None,
max_connections: int = 40,
secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> "asyncio.Queue[object]":
"""
Starts a small http server to listen for updates via webhook. If :paramref:`cert`
Expand Down Expand Up @@ -536,6 +537,16 @@ async def start_webhook(
header isn't set or it is set to a wrong token.
.. versionadded:: 20.0
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
does not need to exist, in which case the file will be created.
Caution:
This parameter is a replacement for the default TCP bind. Therefore, it is
mutually exclusive with :paramref:`listen` and :paramref:`port`. When using
this param, you must also run a reverse proxy to the unix socket and set the
appropriate :paramref:`webhook_url`.
.. versionadded:: NEXT.VERSION
Returns:
:class:`queue.Queue`: The update queue that can be filled from the main thread.
Expand All @@ -547,6 +558,21 @@ async def start_webhook(
"To use `start_webhook`, PTB must be installed via `pip install "
'"python-telegram-bot[webhooks]"`.'
)
# unix has special requirements what must and mustn't be set when using it
if unix:
error_msg = (
"You can not pass unix and {0}, only use one. Unix if you want to "
"initialize a unix socket, or {0} for a standard TCP server."
)
if not isinstance(listen, DefaultValue):
raise RuntimeError(error_msg.format("listen"))
if not isinstance(port, DefaultValue):
raise RuntimeError(error_msg.format("port"))
if not webhook_url:
raise RuntimeError(
"Since you set unix, you also need to set the URL to the webhook "
"of the proxy you run in front of the unix socket."
)

async with self.__lock:
if self.running:
Expand All @@ -561,8 +587,8 @@ async def start_webhook(
webhook_ready = asyncio.Event()

await self._start_webhook(
listen=listen,
port=port,
listen=DefaultValue.get_value(listen),
port=DefaultValue.get_value(port),
url_path=url_path,
cert=cert,
key=key,
Expand All @@ -574,6 +600,7 @@ async def start_webhook(
ip_address=ip_address,
max_connections=max_connections,
secret_token=secret_token,
unix=unix,
)

_LOGGER.debug("Waiting for webhook server to start")
Expand Down Expand Up @@ -601,6 +628,7 @@ async def _start_webhook(
ip_address: Optional[str] = None,
max_connections: int = 40,
secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> None:
_LOGGER.debug("Updater thread started (webhook)")

Expand All @@ -625,14 +653,13 @@ async def _start_webhook(
raise TelegramError("Invalid SSL Certificate") from exc
else:
ssl_ctx = None

# Create and start server
self._httpd = WebhookServer(listen, port, app, ssl_ctx)
self._httpd = WebhookServer(listen, port, app, ssl_ctx, unix)

if not webhook_url:
webhook_url = self._gen_webhook_url(
protocol="https" if ssl_ctx else "http",
listen=listen,
listen=DefaultValue.get_value(listen),
port=port,
url_path=url_path,
)
Expand Down
27 changes: 24 additions & 3 deletions telegram/ext/_utils/webhookhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,23 @@
import asyncio
import json
from http import HTTPStatus
from pathlib import Path
from ssl import SSLContext
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Type
from typing import TYPE_CHECKING, Optional, Type, Union

# Instead of checking for ImportError here, we do that in `updater.py`, where we import from
# this module. Doing it here would be tricky, as the classes below subclass tornado classes
import tornado.web
from tornado.httpserver import HTTPServer

try:
from tornado.netutil import bind_unix_socket

UNIX_AVAILABLE = True
except ImportError:
UNIX_AVAILABLE = False

from telegram import Update
from telegram._utils.logging import get_logger
from telegram.ext._extbot import ExtBot
Expand All @@ -50,21 +58,34 @@ class WebhookServer:
"is_running",
"_server_lock",
"_shutdown_lock",
"unix",
)

def __init__(
self, listen: str, port: int, webhook_app: "WebhookAppClass", ssl_ctx: Optional[SSLContext]
self,
listen: str,
port: int,
webhook_app: "WebhookAppClass",
ssl_ctx: Optional[SSLContext],
unix: Optional[Union[str, Path]] = None,
):
if unix and not UNIX_AVAILABLE:
raise RuntimeError("This OS does not support binding unix sockets.")
self._http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx)
self.listen = listen
self.port = port
self.is_running = False
self.unix = unix
self._server_lock = asyncio.Lock()
self._shutdown_lock = asyncio.Lock()

async def serve_forever(self, ready: Optional[asyncio.Event] = None) -> None:
async with self._server_lock:
self._http_server.listen(self.port, address=self.listen)
if self.unix:
socket = bind_unix_socket(str(self.unix))
self._http_server.add_socket(socket)
else:
self._http_server.listen(self.port, address=self.listen)

self.is_running = True
if ready is not None:
Expand Down
8 changes: 6 additions & 2 deletions tests/auxil/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
#
# You should have received a copy of the GNU Lesser Public License
# along with this program. If not, see [http://www.gnu.org/licenses/].
from pathlib import Path
from typing import Optional

import pytest
from httpx import AsyncClient, Response
from httpx import AsyncClient, AsyncHTTPTransport, Response

from telegram._utils.defaultvalue import DEFAULT_NONE
from telegram._utils.types import ODVInput
Expand Down Expand Up @@ -90,6 +91,7 @@ async def send_webhook_message(
content_type: str = "application/json",
get_method: Optional[str] = None,
secret_token: Optional[str] = None,
unix: Optional[Path] = None,
) -> Response:
headers = {
"content-type": content_type,
Expand All @@ -111,7 +113,9 @@ async def send_webhook_message(

url = f"http://{ip}:{port}/{url_path}"

async with AsyncClient() as client:
transport = AsyncHTTPTransport(uds=unix) if unix else None

async with AsyncClient(transport=transport) as client:
return await client.request(
url=url, method=get_method or "POST", data=payload, headers=headers
)

0 comments on commit 2345bfb

Please sign in to comment.