Skip to content

Commit

Permalink
Add convenience method for exception reporting (#2792)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Jul 17, 2023
1 parent 31d7ba8 commit 9cbe1fb
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 16 deletions.
52 changes: 41 additions & 11 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from asyncio.futures import Future
from collections import defaultdict, deque
from contextlib import contextmanager, suppress
from functools import partial
from functools import partial, wraps
from inspect import isawaitable
from os import environ
from socket import socket
Expand Down Expand Up @@ -87,7 +87,7 @@
from sanic.response import BaseHTTPResponse, HTTPResponse, ResponseStream
from sanic.router import Router
from sanic.server.websockets.impl import ConnectionClosed
from sanic.signals import Signal, SignalRouter
from sanic.signals import Event, Signal, SignalRouter
from sanic.touchup import TouchUp, TouchUpMeta
from sanic.types.shared_ctx import SharedContext
from sanic.worker.inspector import Inspector
Expand Down Expand Up @@ -605,6 +605,19 @@ async def event(
raise NotFound("Could not find signal %s" % event)
return await wait_for(signal.ctx.event.wait(), timeout=timeout)

def report_exception(
self, handler: Callable[[Sanic, Exception], Coroutine[Any, Any, None]]
):
@wraps(handler)
async def report(exception: Exception) -> None:
await handler(self, exception)

self.add_signal(
handler=report, event=Event.SERVER_EXCEPTION_REPORT.value
)

return report

def enable_websocket(self, enable=True):
"""Enable or disable the support for websocket.
Expand Down Expand Up @@ -876,10 +889,12 @@ async def handle_exception(
:raises ServerError: response 500
"""
response = None
await self.dispatch(
"server.lifecycle.exception",
context={"exception": exception},
)
if not getattr(exception, "__dispatched__", False):
... # DO NOT REMOVE THIS LINE. IT IS NEEDED FOR TOUCHUP.
await self.dispatch(
"server.exception.report",
context={"exception": exception},
)
await self.dispatch(
"http.lifecycle.exception",
inline=True,
Expand Down Expand Up @@ -1310,13 +1325,28 @@ def _prep_task(
app,
loop,
):
if callable(task):
async def do(task):
try:
task = task(app)
except TypeError:
task = task()
if callable(task):
try:
task = task(app)
except TypeError:
task = task()
if isawaitable(task):
await task
except CancelledError:
error_logger.warning(
f"Task {task} was cancelled before it completed."
)
raise
except Exception as e:
await app.dispatch(
"server.exception.report",
context={"exception": e},
)
raise

return task
return do(task)

@classmethod
def _loop_add_task(
Expand Down
9 changes: 5 additions & 4 deletions sanic/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@


class Event(Enum):
SERVER_EXCEPTION_REPORT = "server.exception.report"
SERVER_INIT_AFTER = "server.init.after"
SERVER_INIT_BEFORE = "server.init.before"
SERVER_SHUTDOWN_AFTER = "server.shutdown.after"
SERVER_SHUTDOWN_BEFORE = "server.shutdown.before"
SERVER_LIFECYCLE_EXCEPTION = "server.lifecycle.exception"
HTTP_LIFECYCLE_BEGIN = "http.lifecycle.begin"
HTTP_LIFECYCLE_COMPLETE = "http.lifecycle.complete"
HTTP_LIFECYCLE_EXCEPTION = "http.lifecycle.exception"
Expand All @@ -40,11 +40,11 @@ class Event(Enum):

RESERVED_NAMESPACES = {
"server": (
Event.SERVER_EXCEPTION_REPORT.value,
Event.SERVER_INIT_AFTER.value,
Event.SERVER_INIT_BEFORE.value,
Event.SERVER_SHUTDOWN_AFTER.value,
Event.SERVER_SHUTDOWN_BEFORE.value,
Event.SERVER_LIFECYCLE_EXCEPTION.value,
),
"http": (
Event.HTTP_LIFECYCLE_BEGIN.value,
Expand Down Expand Up @@ -174,11 +174,12 @@ async def _dispatch(
if self.ctx.app.debug and self.ctx.app.state.verbosity >= 1:
error_logger.exception(e)

if event != Event.SERVER_LIFECYCLE_EXCEPTION.value:
if event != Event.SERVER_EXCEPTION_REPORT.value:
await self.dispatch(
Event.SERVER_LIFECYCLE_EXCEPTION.value,
Event.SERVER_EXCEPTION_REPORT.value,
context={"exception": e},
)
setattr(e, "__dispatched__", True)
raise e
finally:
for signal_event in events:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_signal_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_signal_server_lifecycle_exception(app: Sanic):
async def hello_route(request):
return HTTPResponse()

@app.signal(Event.SERVER_LIFECYCLE_EXCEPTION)
@app.signal(Event.SERVER_EXCEPTION_REPORT)
async def test_signal(exception: Exception):
nonlocal trigger
trigger = exception
Expand Down
113 changes: 113 additions & 0 deletions tests/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from enum import Enum
from inspect import isawaitable
from itertools import count

import pytest

from sanic_routing.exceptions import NotFound

from sanic import Blueprint, Sanic, empty
from sanic.exceptions import InvalidSignal, SanicException
from sanic.signals import Event


def test_add_signal(app):
Expand Down Expand Up @@ -427,3 +429,114 @@ def test_signal_reservation(app, event, expected):
app.signal(event)(lambda: ...)
else:
app.signal(event)(lambda: ...)


@pytest.mark.asyncio
async def test_report_exception(app: Sanic):
@app.report_exception
async def catch_any_exception(app: Sanic, exception: Exception):
...

@app.route("/")
async def handler(request):
1 / 0

app.signal_router.finalize()

registered_signal_handlers = [
handler
for handler, *_ in app.signal_router.get(
Event.SERVER_EXCEPTION_REPORT.value
)
]

assert catch_any_exception in registered_signal_handlers


def test_report_exception_runs(app: Sanic):
event = asyncio.Event()

@app.report_exception
async def catch_any_exception(app: Sanic, exception: Exception):
event.set()

@app.route("/")
async def handler(request):
1 / 0

app.test_client.get("/")

assert event.is_set()


def test_report_exception_runs_once_inline(app: Sanic):
event = asyncio.Event()
c = count()

@app.report_exception
async def catch_any_exception(app: Sanic, exception: Exception):
event.set()
next(c)

@app.route("/")
async def handler(request):
...

@app.signal(Event.HTTP_ROUTING_AFTER.value)
async def after_routing(**_):
1 / 0

app.test_client.get("/")

assert event.is_set()
assert next(c) == 1


def test_report_exception_runs_once_custom(app: Sanic):
event = asyncio.Event()
c = count()

@app.report_exception
async def catch_any_exception(app: Sanic, exception: Exception):
event.set()
next(c)

@app.route("/")
async def handler(request):
await app.dispatch("one.two.three")
return empty()

@app.signal("one.two.three")
async def one_two_three(**_):
1 / 0

app.test_client.get("/")

assert event.is_set()
assert next(c) == 1


def test_report_exception_runs_task(app: Sanic):
c = count()

async def task_1():
next(c)

async def task_2(app):
next(c)

@app.report_exception
async def catch_any_exception(app: Sanic, exception: Exception):
next(c)

@app.route("/")
async def handler(request):
app.add_task(task_1)
app.add_task(task_1())
app.add_task(task_2)
app.add_task(task_2(app))
return empty()

app.test_client.get("/")

assert next(c) == 4

0 comments on commit 9cbe1fb

Please sign in to comment.