Skip to content

Commit

Permalink
Feat: Introduce unix mode for webhook
Browse files Browse the repository at this point in the history
  • Loading branch information
Poolitzer committed Nov 26, 2023
1 parent dd9af64 commit 52febca
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 22 deletions.
11 changes: 9 additions & 2 deletions telegram/ext/_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from telegram.ext._basepersistence import BasePersistence
from telegram.ext._contexttypes import ContextTypes
from telegram.ext._extbot import ExtBot
from telegram.ext._updater import Updater
from telegram.ext._updater import Updater, _DefaultIP
from telegram.ext._utils.stack import was_called_by
from telegram.ext._utils.trackingdict import TrackingDict
from telegram.ext._utils.types import BD, BT, CCT, CD, JQ, RT, UD, ConversationKey, HandlerCallback
Expand Down Expand Up @@ -805,7 +805,7 @@ def error_callback(exc: TelegramError) -> None:

def run_webhook(
self,
listen: str = "127.0.0.1",
listen: DVType[str] = _DefaultIP,
port: int = 80,
url_path: str = "",
cert: Optional[Union[str, Path]] = None,
Expand All @@ -819,6 +819,7 @@ def run_webhook(
close_loop: bool = True,
stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE,
secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> None:
"""Convenience method that takes care of initializing and starting the app,
listening for updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
Expand Down Expand Up @@ -911,6 +912,11 @@ def run_webhook(
header isn't set or it is set to a wrong token.
.. versionadded:: 20.0
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
can be empty in which case the file will be created. When using this param, you
need to set the :paramref:`webhook_url`!
.. versionadded:: NEXT.VERSION
"""
if not self.updater:
raise RuntimeError(
Expand All @@ -931,6 +937,7 @@ def run_webhook(
ip_address=ip_address,
max_connections=max_connections,
secret_token=secret_token,
unix=unix,
),
close_loop=close_loop,
stop_signals=stop_signals,
Expand Down
28 changes: 21 additions & 7 deletions telegram/ext/_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@
Union,
)

from telegram._utils.defaultvalue import DEFAULT_NONE
from telegram._utils.defaultvalue import DEFAULT_NONE, DefaultValue
from telegram._utils.logging import get_logger
from telegram._utils.repr import build_repr_with_selected_attrs
from telegram._utils.types import ODVInput
from telegram._utils.types import DVType, ODVInput
from telegram.error import InvalidToken, RetryAfter, TelegramError, TimedOut

try:
Expand All @@ -53,6 +53,7 @@


_UpdaterType = TypeVar("_UpdaterType", bound="Updater") # pylint: disable=invalid-name
_DefaultIP = DefaultValue("127.0.0.1")
_LOGGER = get_logger(__name__)


Expand Down Expand Up @@ -419,7 +420,7 @@ async def _get_updates_cleanup() -> None:

async def start_webhook(
self,
listen: str = "127.0.0.1",
listen: DVType[str] = _DefaultIP,
port: int = 80,
url_path: str = "",
cert: Optional[Union[str, Path]] = None,
Expand All @@ -431,6 +432,7 @@ async def start_webhook(
ip_address: Optional[str] = None,
max_connections: int = 40,
secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> "asyncio.Queue[object]":
"""
Starts a small http server to listen for updates via webhook. If :paramref:`cert`
Expand Down Expand Up @@ -499,6 +501,11 @@ async def start_webhook(
header isn't set or it is set to a wrong token.
.. versionadded:: 20.0
unix (:class:`pathlib.Path` | :obj:`str`, optional): Path to the unix socket file. Path
can be empty in which case the file will be created. When using this param, you
need to set the :paramref:`webhook_url`!
.. versionadded:: NEXT.VERSION
Returns:
:class:`queue.Queue`: The update queue that can be filled from the main thread.
Expand Down Expand Up @@ -537,6 +544,7 @@ async def start_webhook(
ip_address=ip_address,
max_connections=max_connections,
secret_token=secret_token,
unix=unix,
)

_LOGGER.debug("Waiting for webhook server to start")
Expand All @@ -551,7 +559,7 @@ async def start_webhook(

async def _start_webhook(
self,
listen: str,
listen: DVType[str],
port: int,
url_path: str,
bootstrap_retries: int,
Expand All @@ -564,6 +572,7 @@ async def _start_webhook(
ip_address: Optional[str] = None,
max_connections: int = 40,
secret_token: Optional[str] = None,
unix: Optional[Union[str, Path]] = None,
) -> None:
_LOGGER.debug("Updater thread started (webhook)")

Expand All @@ -588,14 +597,19 @@ async def _start_webhook(
raise TelegramError("Invalid SSL Certificate") from exc
else:
ssl_ctx = None

# If unix is used, the webhook_url can't be generated and thus must be set by the user
if unix and not webhook_url:
raise RuntimeError(
"Since you set unix, you also need to set the URL to the webhook "
"of the proxy you run in front of the unix socket."
)
# Create and start server
self._httpd = WebhookServer(listen, port, app, ssl_ctx)
self._httpd = WebhookServer(listen, port, app, ssl_ctx, unix)

if not webhook_url:
webhook_url = self._gen_webhook_url(
protocol="https" if ssl_ctx else "http",
listen=listen,
listen=DefaultValue.get_value(listen),
port=port,
url_path=url_path,
)
Expand Down
26 changes: 23 additions & 3 deletions telegram/ext/_utils/webhookhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,21 @@
import asyncio
import json
from http import HTTPStatus
from pathlib import Path
from ssl import SSLContext
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Type
from typing import TYPE_CHECKING, Optional, Type, Union

# Instead of checking for ImportError here, we do that in `updater.py`, where we import from
# this module. Doing it here would be tricky, as the classes below subclass tornado classes
import tornado.web
from tornado.httpserver import HTTPServer
from tornado.netutil import bind_unix_socket

from telegram import Update
from telegram._utils.defaultvalue import DefaultValue
from telegram._utils.logging import get_logger
from telegram._utils.types import DVType
from telegram.ext._extbot import ExtBot

if TYPE_CHECKING:
Expand All @@ -50,21 +54,37 @@ class WebhookServer:
"is_running",
"_server_lock",
"_shutdown_lock",
"unix",
)

def __init__(
self, listen: str, port: int, webhook_app: "WebhookAppClass", ssl_ctx: Optional[SSLContext]
self,
listen: DVType[str],
port: int,
webhook_app: "WebhookAppClass",
ssl_ctx: Optional[SSLContext],
unix: Optional[Union[str, Path]] = None,
):
if not isinstance(listen, DefaultValue) and unix:
raise RuntimeError(
"You can not pass unix and listen, only use one. Unix if you want to initialize a"
" unix socket, or listen for a standard TCP server"
)
self._http_server = HTTPServer(webhook_app, ssl_options=ssl_ctx)
self.listen = listen
self.port = port
self.is_running = False
self.unix = unix
self._server_lock = asyncio.Lock()
self._shutdown_lock = asyncio.Lock()

async def serve_forever(self, ready: Optional[asyncio.Event] = None) -> None:
async with self._server_lock:
self._http_server.listen(self.port, address=self.listen)
if self.unix:
socket = bind_unix_socket(str(self.unix))
self._http_server.add_socket(socket)
else:
self._http_server.listen(self.port, address=DefaultValue.get_value(self.listen))

self.is_running = True
if ready is not None:
Expand Down
9 changes: 7 additions & 2 deletions tests/auxil/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Optional

import pytest
from httpx import AsyncClient, Response
from httpx import AsyncClient, AsyncHTTPTransport, Response

from telegram._utils.defaultvalue import DEFAULT_NONE
from telegram._utils.types import ODVInput
Expand Down Expand Up @@ -90,6 +90,7 @@ async def send_webhook_message(
content_type: str = "application/json",
get_method: Optional[str] = None,
secret_token: Optional[str] = None,
unix: Optional[str] = None,
) -> Response:
headers = {
"content-type": content_type,
Expand All @@ -111,7 +112,11 @@ async def send_webhook_message(

url = f"http://{ip}:{port}/{url_path}"

async with AsyncClient() as client:
transport = None
if unix:
transport = AsyncHTTPTransport(uds=unix)

async with AsyncClient(transport=transport) as client:
return await client.request(
url=url, method=get_method or "POST", data=payload, headers=headers
)
67 changes: 59 additions & 8 deletions tests/ext/test_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from telegram.request import HTTPXRequest
from tests.auxil.build_messages import make_message, make_message_update
from tests.auxil.envvars import TEST_WITH_OPT_DEPS
from tests.auxil.files import data_file
from tests.auxil.files import TEST_DATA_PATH, data_file
from tests.auxil.networking import send_webhook_message
from tests.auxil.pytest_classes import PytestBot, make_bot
from tests.auxil.slots import mro_slots
Expand Down Expand Up @@ -73,6 +73,12 @@ def _reset(self):
self.cb_handler_called = None
self.test_flag = False

@pytest.fixture()
def file_path(self) -> str:
path = TEST_DATA_PATH / "test.sock"
yield str(path)
path.unlink(missing_ok=True)

def error_callback(self, error):
self.received = error
self.err_handler_called.set()
Expand Down Expand Up @@ -646,8 +652,9 @@ async def delete_webhook(*args, **kwargs):
@pytest.mark.parametrize("ext_bot", [True, False])
@pytest.mark.parametrize("drop_pending_updates", [True, False])
@pytest.mark.parametrize("secret_token", ["SecretToken", None])
@pytest.mark.parametrize("unix", [None, True])
async def test_webhook_basic(
self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token
self, monkeypatch, updater, drop_pending_updates, ext_bot, secret_token, unix, file_path
):
# Testing with both ExtBot and Bot to make sure any logic in WebhookHandler
# that depends on this distinction works
Expand Down Expand Up @@ -678,37 +685,55 @@ async def set_webhook(*args, **kwargs):
port=port,
url_path="TOKEN",
secret_token=secret_token,
unix=file_path if unix else None,
webhook_url="string",
)
assert return_value is updater.update_queue
assert updater.running

# Now, we send an update to the server
update = make_message_update("Webhook")
await send_webhook_message(
ip, port, update.to_json(), "TOKEN", secret_token=secret_token
ip,
port,
update.to_json(),
"TOKEN",
secret_token=secret_token,
unix=file_path if unix else None,
)
assert (await updater.update_queue.get()).to_dict() == update.to_dict()

# Returns Not Found if path is incorrect
response = await send_webhook_message(ip, port, "123456", "webhook_handler.py")
response = await send_webhook_message(
ip, port, "123456", "webhook_handler.py", unix=file_path if unix else None
)
assert response.status_code == HTTPStatus.NOT_FOUND

# Returns METHOD_NOT_ALLOWED if method is not allowed
response = await send_webhook_message(ip, port, None, "TOKEN", get_method="HEAD")
response = await send_webhook_message(
ip, port, None, "TOKEN", get_method="HEAD", unix=file_path if unix else None
)
assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED

if secret_token:
# Returns Forbidden if no secret token is set
response_text = "<html><title>403: {0}</title><body>403: {0}</body></html>"
response = await send_webhook_message(ip, port, update.to_json(), "TOKEN")
response = await send_webhook_message(
ip, port, update.to_json(), "TOKEN", unix=file_path if unix else None
)
assert response.status_code == HTTPStatus.FORBIDDEN
assert response.text == response_text.format(
"Request did not include the secret token"
)

# Returns Forbidden if the secret token is wrong
response = await send_webhook_message(
ip, port, update.to_json(), "TOKEN", secret_token="NotTheSecretToken"
ip,
port,
update.to_json(),
"TOKEN",
secret_token="NotTheSecretToken",
unix=file_path if unix else None,
)
assert response.status_code == HTTPStatus.FORBIDDEN
assert response.text == response_text.format("Request had the wrong secret token")
Expand All @@ -727,14 +752,40 @@ async def set_webhook(*args, **kwargs):
ip_address=ip,
port=port,
url_path="TOKEN",
unix=file_path if unix else None,
webhook_url="string",
)
assert updater.running
update = make_message_update("Webhook")
await send_webhook_message(ip, port, update.to_json(), "TOKEN")
await send_webhook_message(
ip, port, update.to_json(), "TOKEN", unix=file_path if unix else None
)
assert (await updater.update_queue.get()).to_dict() == update.to_dict()
await updater.stop()
assert not updater.running

async def test_unix_webhook_raises(self, updater):
async with updater:
with pytest.raises(RuntimeError, match="URL"):
await updater.start_webhook(listen="127.0.0.1", unix="DoesntMatter")
with pytest.raises(RuntimeError, match="unix"):
await updater.start_webhook(
listen="127.0.0.1", unix="DoesntMatter", webhook_url="string"
)

async def test_unix_webhook_path(self, updater, monkeypatch, file_path):
async def set_webhook(*args, **kwargs):
return True

monkeypatch.setattr(updater.bot, "set_webhook", set_webhook)

async with updater:
await updater.start_webhook(unix=file_path, url_path="TOKEN", webhook_url="string")

update = make_message_update("Webhook")
await send_webhook_message("ip", 123, update.to_json(), "TOKEN", unix=file_path)
await updater.stop()

async def test_start_webhook_already_running(self, updater, monkeypatch):
async def return_true(*args, **kwargs):
return True
Expand Down

0 comments on commit 52febca

Please sign in to comment.