Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test and Document Exception.__cause__ on NetworkError #3792

Merged
merged 9 commits into from Jul 18, 2023
25 changes: 22 additions & 3 deletions telegram/error.py
Expand Up @@ -118,11 +118,30 @@ def __init__(self, message: Optional[str] = None) -> None:
class NetworkError(TelegramError):
"""Base class for exceptions due to networking errors.

Args:
cause (:obj:`Exception`, optional): The original exception that caused this error as issued
by the library used for network communication.

.. versionadded:: NEXT.VERSION

Attributes:
cause (:obj:`Exception`): Optional. The original exception that caused this error as issued
by the library used for network communication.

.. versionadded:: NEXT.VERSION

Examples:
:any:`Raw API Bot <examples.rawapibot>`
"""

__slots__ = ()
__slots__ = ("cause",)

def __init__(self, message: str, cause: Optional[Exception] = None) -> None:
super().__init__(message)
self.cause: Optional[Exception] = cause

def __reduce__(self) -> Tuple[type, Tuple[str, Optional[Exception]]]: # type: ignore[override]
return self.__class__, (self.message, self.cause)


class BadRequest(NetworkError):
Expand All @@ -142,8 +161,8 @@ class TimedOut(NetworkError):

__slots__ = ()

def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or "Timed out")
def __init__(self, message: Optional[str] = None, cause: Optional[Exception] = None) -> None:
super().__init__(message=message or "Timed out", cause=cause)


class ChatMigrated(TelegramError):
Expand Down
4 changes: 3 additions & 1 deletion telegram/request/_baserequest.py
Expand Up @@ -287,7 +287,9 @@ async def _request_wrapper(
except TelegramError as exc:
raise exc
except Exception as exc:
raise NetworkError(f"Unknown error in HTTP implementation: {exc!r}") from exc
raise NetworkError(
f"Unknown error in HTTP implementation: {exc!r}", cause=exc
) from exc

if HTTPStatus.OK <= code <= 299:
# 200-299 range are HTTP success statuses
Expand Down
7 changes: 4 additions & 3 deletions telegram/request/_httpxrequest.py
Expand Up @@ -228,15 +228,16 @@ async def do_request(
"Pool timeout: All connections in the connection pool are occupied. "
"Request was *not* sent to Telegram. Consider adjusting the connection "
"pool size or the pool timeout."
)
),
cause=err,
) from err
raise TimedOut from err
raise TimedOut(cause=err) from err
except httpx.HTTPError as err:
# HTTPError must come last as its the base httpx exception class
# TODO p4: do something smart here; for now just raise NetworkError

# We include the class name for easier debugging. Especially useful if the error
# message of `err` is empty.
raise NetworkError(f"httpx.{err.__class__.__name__}: {err}") from err
raise NetworkError(f"httpx.{err.__class__.__name__}: {err}", cause=err) from err

return res.status_code, res.content
29 changes: 19 additions & 10 deletions tests/request/test_request.py
Expand Up @@ -295,7 +295,7 @@ async def test_special_errors(
(TelegramError("TelegramError"), TelegramError, "TelegramError"),
(
RuntimeError("CustomError"),
Exception,
NetworkError,
r"HTTP implementation: RuntimeError\('CustomError'\)",
),
],
Expand All @@ -312,9 +312,12 @@ async def do_request(*args, **kwargs):
do_request,
)

with pytest.raises(catch_class, match=match):
with pytest.raises(catch_class, match=match) as exc_info:
await httpx_request.post(None, None, None)

if catch_class is NetworkError:
assert exc_info.value.cause is exception

async def test_retrieve(self, monkeypatch, httpx_request):
"""Here we just test that retrieve gives us the raw bytes instead of trying to parse them
as json
Expand Down Expand Up @@ -571,43 +574,49 @@ async def make_assertion(self, method, url, headers, timeout, files, data):
assert content == b"content"

@pytest.mark.parametrize(
("raised_class", "expected_class", "expected_message"),
("raised_exception", "expected_class", "expected_message"),
[
(httpx.TimeoutException, TimedOut, "Timed out"),
(httpx.ReadError, NetworkError, "httpx.ReadError: message"),
(httpx.TimeoutException("timeout"), TimedOut, "Timed out"),
(httpx.ReadError("read_error"), NetworkError, "httpx.ReadError: read_error"),
],
)
async def test_do_request_exceptions(
self, monkeypatch, httpx_request, raised_class, expected_class, expected_message
self, monkeypatch, httpx_request, raised_exception, expected_class, expected_message
):
async def make_assertion(self, method, url, headers, timeout, files, data):
raise raised_class("message")
raise raised_exception

monkeypatch.setattr(httpx.AsyncClient, "request", make_assertion)

with pytest.raises(expected_class, match=expected_message):
with pytest.raises(expected_class, match=expected_message) as exc_info:
await httpx_request.do_request(
"method",
"url",
)

assert exc_info.value.cause is raised_exception

async def test_do_request_pool_timeout(self, monkeypatch):
pool_timeout = httpx.PoolTimeout("pool timeout")

async def request(_, **kwargs):
if self.test_flag is None:
self.test_flag = True
else:
raise httpx.PoolTimeout("pool timeout")
raise pool_timeout
return httpx.Response(HTTPStatus.OK)

monkeypatch.setattr(httpx.AsyncClient, "request", request)

async with HTTPXRequest(pool_timeout=0.02) as httpx_request:
with pytest.raises(TimedOut, match="Pool timeout"):
with pytest.raises(TimedOut, match="Pool timeout") as exc_info:
await asyncio.gather(
httpx_request.do_request(method="GET", url="URL"),
httpx_request.do_request(method="GET", url="URL"),
)

assert exc_info.value.cause is pool_timeout


@pytest.mark.skipif(not TEST_WITH_OPT_DEPS, reason="No need to run this twice")
class TestHTTPXRequestWithRequest:
Expand Down
10 changes: 7 additions & 3 deletions tests/test_error.py
Expand Up @@ -63,6 +63,10 @@ def test_invalid_token(self):
raise InvalidToken

def test_network_error(self):
cause = Exception("test cause")
error = NetworkError("test message", cause)
assert error.cause is cause

with pytest.raises(NetworkError, match="test message"):
raise NetworkError("test message")
with pytest.raises(NetworkError, match="^Test message$"):
Expand Down Expand Up @@ -105,9 +109,9 @@ def test_conflict(self):
(TelegramError("test message"), ["message"]),
(Forbidden("test message"), ["message"]),
(InvalidToken(), ["message"]),
(NetworkError("test message"), ["message"]),
(NetworkError("test message"), ["message", "cause"]),
(BadRequest("test message"), ["message"]),
(TimedOut(), ["message"]),
(TimedOut(), ["message", "cause"]),
(ChatMigrated(1234), ["message", "new_chat_id"]),
(RetryAfter(12), ["message", "retry_after"]),
(Conflict("test message"), ["message"]),
Expand All @@ -130,7 +134,7 @@ def test_errors_pickling(self, exception, attributes):
(TelegramError("test message")),
(Forbidden("test message")),
(InvalidToken()),
(NetworkError("test message")),
(NetworkError("test message", Exception("test cause"))),
(BadRequest("test message")),
(TimedOut()),
(ChatMigrated(1234)),
Expand Down