Skip to content

Commit

Permalink
Add custom typing to config and ctx (#2785)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Jul 12, 2023
1 parent 929d270 commit dc3c4d1
Show file tree
Hide file tree
Showing 21 changed files with 433 additions and 20 deletions.
19 changes: 19 additions & 0 deletions sanic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from types import SimpleNamespace

from typing_extensions import TypeAlias

from sanic.__version__ import __version__
from sanic.app import Sanic
from sanic.blueprints import Blueprint
from sanic.config import Config
from sanic.constants import HTTPMethod
from sanic.exceptions import (
BadRequest,
Expand Down Expand Up @@ -32,15 +37,29 @@
from sanic.server.websockets.impl import WebsocketImplProtocol as Websocket


DefaultSanic: TypeAlias = "Sanic[Config, SimpleNamespace]"
"""
A type alias for a Sanic app with a default config and namespace.
"""

DefaultRequest: TypeAlias = Request[DefaultSanic, SimpleNamespace]
"""
A type alias for a request with a default Sanic app and namespace.
"""

__all__ = (
"__version__",
# Common objects
"Sanic",
"Config",
"Blueprint",
"HTTPMethod",
"HTTPResponse",
"Request",
"Websocket",
# Common types
"DefaultSanic",
"DefaultRequest",
# Common exceptions
"BadRequest",
"ExpectationFailed",
Expand Down
114 changes: 108 additions & 6 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Coroutine,
Deque,
Dict,
Generic,
Iterable,
Iterator,
List,
Expand All @@ -42,6 +43,8 @@
Type,
TypeVar,
Union,
cast,
overload,
)
from urllib.parse import urlencode, urlunparse

Expand Down Expand Up @@ -103,8 +106,17 @@
if OS_IS_WINDOWS: # no cov
enable_windows_color_support()

ctx_type = TypeVar("ctx_type")
config_type = TypeVar("config_type", bound=Config)

class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):

class Sanic(
Generic[config_type, ctx_type],
StaticHandleMixin,
BaseSanic,
StartupMixin,
metaclass=TouchUpMeta,
):
"""
The main application instance
"""
Expand Down Expand Up @@ -162,11 +174,99 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
_app_registry: ClassVar[Dict[str, "Sanic"]] = {}
test_mode: ClassVar[bool] = False

@overload
def __init__(
self: Sanic[Config, SimpleNamespace],
name: str,
config: None = None,
ctx: None = None,
router: Optional[Router] = None,
signal_router: Optional[SignalRouter] = None,
error_handler: Optional[ErrorHandler] = None,
env_prefix: Optional[str] = SANIC_PREFIX,
request_class: Optional[Type[Request]] = None,
strict_slashes: bool = False,
log_config: Optional[Dict[str, Any]] = None,
configure_logging: bool = True,
dumps: Optional[Callable[..., AnyStr]] = None,
loads: Optional[Callable[..., Any]] = None,
inspector: bool = False,
inspector_class: Optional[Type[Inspector]] = None,
certloader_class: Optional[Type[CertLoader]] = None,
) -> None:
...

@overload
def __init__(
self: Sanic[config_type, SimpleNamespace],
name: str,
config: Optional[config_type] = None,
ctx: None = None,
router: Optional[Router] = None,
signal_router: Optional[SignalRouter] = None,
error_handler: Optional[ErrorHandler] = None,
env_prefix: Optional[str] = SANIC_PREFIX,
request_class: Optional[Type[Request]] = None,
strict_slashes: bool = False,
log_config: Optional[Dict[str, Any]] = None,
configure_logging: bool = True,
dumps: Optional[Callable[..., AnyStr]] = None,
loads: Optional[Callable[..., Any]] = None,
inspector: bool = False,
inspector_class: Optional[Type[Inspector]] = None,
certloader_class: Optional[Type[CertLoader]] = None,
) -> None:
...

@overload
def __init__(
self: Sanic[Config, ctx_type],
name: str,
config: None = None,
ctx: Optional[ctx_type] = None,
router: Optional[Router] = None,
signal_router: Optional[SignalRouter] = None,
error_handler: Optional[ErrorHandler] = None,
env_prefix: Optional[str] = SANIC_PREFIX,
request_class: Optional[Type[Request]] = None,
strict_slashes: bool = False,
log_config: Optional[Dict[str, Any]] = None,
configure_logging: bool = True,
dumps: Optional[Callable[..., AnyStr]] = None,
loads: Optional[Callable[..., Any]] = None,
inspector: bool = False,
inspector_class: Optional[Type[Inspector]] = None,
certloader_class: Optional[Type[CertLoader]] = None,
) -> None:
...

@overload
def __init__(
self: Sanic[config_type, ctx_type],
name: str,
config: Optional[config_type] = None,
ctx: Optional[ctx_type] = None,
router: Optional[Router] = None,
signal_router: Optional[SignalRouter] = None,
error_handler: Optional[ErrorHandler] = None,
env_prefix: Optional[str] = SANIC_PREFIX,
request_class: Optional[Type[Request]] = None,
strict_slashes: bool = False,
log_config: Optional[Dict[str, Any]] = None,
configure_logging: bool = True,
dumps: Optional[Callable[..., AnyStr]] = None,
loads: Optional[Callable[..., Any]] = None,
inspector: bool = False,
inspector_class: Optional[Type[Inspector]] = None,
certloader_class: Optional[Type[CertLoader]] = None,
) -> None:
...

def __init__(
self,
name: Optional[str] = None,
config: Optional[Config] = None,
ctx: Optional[Any] = None,
name: str,
config: Optional[config_type] = None,
ctx: Optional[ctx_type] = None,
router: Optional[Router] = None,
signal_router: Optional[SignalRouter] = None,
error_handler: Optional[ErrorHandler] = None,
Expand Down Expand Up @@ -194,7 +294,9 @@ def __init__(
)

# First setup config
self.config: Config = config or Config(env_prefix=env_prefix)
self.config: config_type = cast(
config_type, config or Config(env_prefix=env_prefix)
)
if inspector:
self.config.INSPECTOR = inspector

Expand All @@ -218,7 +320,7 @@ def __init__(
certloader_class or CertLoader
)
self.configure_logging: bool = configure_logging
self.ctx: Any = ctx or SimpleNamespace()
self.ctx: ctx_type = cast(ctx_type, ctx or SimpleNamespace())
self.error_handler: ErrorHandler = error_handler or ErrorHandler()
self.inspector_class: Type[Inspector] = inspector_class or Inspector
self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list)
Expand Down
2 changes: 1 addition & 1 deletion sanic/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class Blueprint(BaseSanic):

def __init__(
self,
name: str = None,
name: str,
url_prefix: Optional[str] = None,
host: Optional[Union[List[str], str]] = None,
version: Optional[Union[int, str, float]] = None,
Expand Down
2 changes: 1 addition & 1 deletion sanic/errorpages.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def exception_response(
debug: bool,
fallback: str,
base: t.Type[BaseRenderer],
renderer: t.Type[t.Optional[BaseRenderer]] = None,
renderer: t.Optional[t.Type[BaseRenderer]] = None,
) -> HTTPResponse:
"""
Render a response for the default FALLBACK exception handler.
Expand Down
2 changes: 1 addition & 1 deletion sanic/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(

super().__init__(message)

self.status_code = status_code
self.status_code = status_code or self.status_code
self.quiet = quiet
self.headers = headers

Expand Down
2 changes: 1 addition & 1 deletion sanic/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes:

def parse_credentials(
header: Optional[str],
prefixes: Union[List, Tuple, Set] = None,
prefixes: Optional[Union[List, Tuple, Set]] = None,
) -> Tuple[Optional[str], Optional[str]]:
"""Parses any header with the aim to retrieve any credentials from it."""
if not prefixes or not isinstance(prefixes, (list, tuple, set)):
Expand Down
41 changes: 35 additions & 6 deletions sanic/request/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from contextvars import ContextVar
from inspect import isawaitable
from types import SimpleNamespace
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Dict,
Generic,
List,
Optional,
Tuple,
Expand All @@ -15,6 +17,7 @@
)

from sanic_routing.route import Route
from typing_extensions import TypeVar

from sanic.http.constants import HTTP # type: ignore
from sanic.http.stream import Stream
Expand All @@ -23,13 +26,13 @@


if TYPE_CHECKING:
from sanic.server import ConnInfo
from sanic.app import Sanic
from sanic.config import Config
from sanic.server import ConnInfo

import uuid

from collections import defaultdict
from types import SimpleNamespace
from urllib.parse import parse_qs, parse_qsl, urlunparse

from httptools import parse_url
Expand Down Expand Up @@ -68,8 +71,21 @@
except ImportError:
from json import loads as json_loads # type: ignore

if TYPE_CHECKING:
# The default argument of TypeVar is proposed to be added in Python 3.13
# by PEP 696 (https://www.python.org/dev/peps/pep-0696/).
# Therefore, we use typing_extensions.TypeVar for compatibility.
# For more information, see:
# https://discuss.python.org/t/pep-696-type-defaults-for-typevarlikes
sanic_type = TypeVar(
"sanic_type", bound=Sanic, default=Sanic[Config, SimpleNamespace]
)
else:
sanic_type = TypeVar("sanic_type")
ctx_type = TypeVar("ctx_type")


class Request:
class Request(Generic[sanic_type, ctx_type]):
"""
Properties of an HTTP request such as URL, headers, etc.
"""
Expand All @@ -80,6 +96,7 @@ class Request:
__slots__ = (
"__weakref__",
"_cookies",
"_ctx",
"_id",
"_ip",
"_parsed_url",
Expand All @@ -96,7 +113,6 @@ class Request:
"app",
"body",
"conn_info",
"ctx",
"head",
"headers",
"method",
Expand Down Expand Up @@ -125,7 +141,7 @@ def __init__(
version: str,
method: str,
transport: TransportProtocol,
app: Sanic,
app: sanic_type,
head: bytes = b"",
stream_id: int = 0,
):
Expand All @@ -149,7 +165,7 @@ def __init__(
# Init but do not inhale
self.body = b""
self.conn_info: Optional[ConnInfo] = None
self.ctx = SimpleNamespace()
self._ctx: Optional[ctx_type] = None
self.parsed_accept: Optional[AcceptList] = None
self.parsed_args: DefaultDict[
Tuple[bool, bool, str, str], RequestParameters
Expand All @@ -176,6 +192,10 @@ def __repr__(self):
class_name = self.__class__.__name__
return f"<{class_name}: {self.method} {self.path}>"

@staticmethod
def make_context() -> ctx_type:
return cast(ctx_type, SimpleNamespace())

@classmethod
def get_current(cls) -> Request:
"""
Expand Down Expand Up @@ -205,6 +225,15 @@ def get_current(cls) -> Request:
def generate_id(*_):
return uuid.uuid4()

@property
def ctx(self) -> ctx_type:
"""
:return: The current request context
"""
if not self._ctx:
self._ctx = self.make_context()
return self._ctx

@property
def stream_id(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion sanic/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def add( # type: ignore
strict_slashes: bool = False,
stream: bool = False,
ignore_body: bool = False,
version: Union[str, float, int] = None,
version: Optional[Union[str, float, int]] = None,
name: Optional[str] = None,
unquote: bool = False,
static: bool = False,
Expand Down
1 change: 1 addition & 0 deletions sanic/server/websockets/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ async def get(self, timeout: Optional[float] = None) -> Optional[Data]:
If ``timeout`` is set and elapses before a complete message is
received, :meth:`get` returns ``None``.
"""
completed: bool
async with self.read_mutex:
if timeout is not None and timeout <= 0:
if not self.message_complete.is_set():
Expand Down
2 changes: 1 addition & 1 deletion sanic/server/websockets/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


try: # websockets < 11.0
from websockets.connection import Event, State
from websockets.connection import Event, State # type: ignore
from websockets.server import ServerConnection as ServerProtocol
except ImportError: # websockets >= 11.0
from websockets.protocol import Event, State # type: ignore
Expand Down

0 comments on commit dc3c4d1

Please sign in to comment.