Add typing and enforce checking via tox/CI
This uses the same mypy settings as wsproto.

The Sentinel values are problematic, but I've found no good solution
that also preserves the property that type(Sentinel) is Sentinel, so
this should suffice for now.

Whilst I've been lax about typing the tests, I've mostly avoided
type: ignore comments in the main code. This should ensure that mypyc
can be used if desired.
pgjones committed Nov 7, 2021
1 parent f643a8e commit eedd7bd
Showing 21 changed files with 650 additions and 319 deletions.
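
For context on the Sentinel remark above: the diff below continues to import a make_sentinel helper from h11._util, which builds objects that are their own type. The following minimal sketch assumes an implementation along these lines (h11's actual code may differ in detail) and illustrates the type(Sentinel) is Sentinel property that mypy struggles to express:

class _SentinelBase(type):
    def __repr__(self) -> str:
        return self.__name__


def make_sentinel(name: str) -> "_SentinelBase":
    # Build a class, then point its __class__ back at itself so that
    # type(NEED_DATA) is NEED_DATA holds at runtime.
    cls = _SentinelBase(name, (_SentinelBase,), {})
    cls.__class__ = cls  # valid at runtime, but hard to model statically
    return cls


NEED_DATA = make_sentinel("NEED_DATA")
assert type(NEED_DATA) is NEED_DATA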
59 changes: 50 additions & 9 deletions h11/__init__.py
@@ -6,16 +6,57 @@
# semantics to check that what you're asking to write to the wire is sensible,
# but at least it gets you out of dealing with the wire itself.

from ._connection import *
from ._events import *
from ._state import *
from ._util import LocalProtocolError, ProtocolError, RemoteProtocolError
from ._version import __version__
from h11._connection import Connection, NEED_DATA, PAUSED
from h11._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from h11._state import (
CLIENT,
CLOSED,
DONE,
ERROR,
IDLE,
MIGHT_SWITCH_PROTOCOL,
MUST_CLOSE,
SEND_BODY,
SEND_RESPONSE,
SERVER,
SWITCHED_PROTOCOL,
)
from h11._util import LocalProtocolError, ProtocolError, RemoteProtocolError
from h11._version import __version__

PRODUCT_ID = "python-h11/" + __version__


__all__ = ["ProtocolError", "LocalProtocolError", "RemoteProtocolError"]
__all__ += _events.__all__
__all__ += _connection.__all__
__all__ += _state.__all__
__all__ = (
"Connection",
"NEED_DATA",
"PAUSED",
"ConnectionClosed",
"Data",
"EndOfMessage",
"Event",
"InformationalResponse",
"Request",
"Response",
"CLIENT",
"CLOSED",
"DONE",
"ERROR",
"IDLE",
"MUST_CLOSE",
"SEND_BODY",
"SEND_RESPONSE",
"SERVER",
"SWITCHED_PROTOCOL",
"ProtocolError",
"LocalProtocolError",
"RemoteProtocolError",
)
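
With the star imports replaced by explicit imports and a literal __all__ above, every public name in h11/__init__.py is visible to static analysis. A small hypothetical downstream snippet (not part of this diff) shows what that buys a type checker:

import h11

conn = h11.Connection(our_role=h11.CLIENT)
data = conn.send(
    h11.Request(
        method="GET",
        target="/",
        headers=[("Host", "example.com"), ("Connection", "close")],
    )
)
# mypy can resolve h11.Connection, h11.Request and h11.CLIENT without
# chasing star imports, and infers data as Optional[bytes] from the new
# annotation on Connection.send below.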
118 changes: 76 additions & 42 deletions h11/_connection.py
@@ -1,18 +1,38 @@
# This contains the main Connection class. Everything in h11 revolves around
# this.

from ._events import * # Import all event types
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

from ._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from ._headers import get_comma_header, has_expect_100_continue, set_comma_header
from ._readers import READERS
from ._readers import READERS, ReadersType
from ._receivebuffer import ReceiveBuffer
from ._state import * # Import all state sentinels
from ._state import _SWITCH_CONNECT, _SWITCH_UPGRADE, ConnectionState
from ._state import (
_SWITCH_CONNECT,
_SWITCH_UPGRADE,
CLIENT,
ConnectionState,
DONE,
ERROR,
MIGHT_SWITCH_PROTOCOL,
SEND_BODY,
SERVER,
SWITCHED_PROTOCOL,
)
from ._util import ( # Import the internal things we need
LocalProtocolError,
make_sentinel,
RemoteProtocolError,
Sentinel,
)
from ._writers import WRITERS
from ._writers import WRITERS, WritersType

# Everything in __all__ gets re-exported as part of the h11 public API.
__all__ = ["Connection", "NEED_DATA", "PAUSED"]
@@ -44,7 +64,7 @@
# our rule is:
# - If someone says Connection: close, we will close
# - If someone uses HTTP/1.0, we will close.
def _keep_alive(event):
def _keep_alive(event: Union[Request, Response]) -> bool:
connection = get_comma_header(event.headers, b"connection")
if b"close" in connection:
return False
@@ -53,7 +73,9 @@ def _keep_alive(event):
return True


def _body_framing(request_method, event):
def _body_framing(
request_method: bytes, event: Union[Request, Response]
) -> Tuple[str, Union[Tuple[()], Tuple[int]]]:
# Called when we enter SEND_BODY to figure out framing information for
# this body.
#
@@ -126,8 +148,10 @@ class Connection:
"""

def __init__(
self, our_role, max_incomplete_event_size=DEFAULT_MAX_INCOMPLETE_EVENT_SIZE
):
self,
our_role: Sentinel,
max_incomplete_event_size: int = DEFAULT_MAX_INCOMPLETE_EVENT_SIZE,
) -> None:
self._max_incomplete_event_size = max_incomplete_event_size
# State and role tracking
if our_role not in (CLIENT, SERVER):
@@ -155,14 +179,14 @@ def __init__(
# These two are only used to interpret framing headers for figuring
# out how to read/write response bodies. their_http_version is also
# made available as a convenient public API.
self.their_http_version = None
self._request_method = None
self.their_http_version: Optional[bytes] = None
self._request_method: Optional[bytes] = None
# This is pure flow-control and doesn't at all affect the set of legal
# transitions, so no need to bother ConnectionState with it:
self.client_is_waiting_for_100_continue = False

@property
def states(self):
def states(self) -> Dict[Sentinel, Sentinel]:
"""A dictionary like::
{CLIENT: <client state>, SERVER: <server state>}
@@ -173,24 +197,24 @@ def states(self):
return dict(self._cstate.states)

@property
def our_state(self):
def our_state(self) -> Sentinel:
"""The current state of whichever role we are playing. See
:ref:`state-machine` for details.
"""
return self._cstate.states[self.our_role]

@property
def their_state(self):
def their_state(self) -> Sentinel:
"""The current state of whichever role we are NOT playing. See
:ref:`state-machine` for details.
"""
return self._cstate.states[self.their_role]

@property
def they_are_waiting_for_100_continue(self):
def they_are_waiting_for_100_continue(self) -> bool:
return self.their_role is CLIENT and self.client_is_waiting_for_100_continue

def start_next_cycle(self):
def start_next_cycle(self) -> None:
"""Attempt to reset our connection state for a new request/response
cycle.
@@ -210,12 +234,12 @@ def start_next_cycle(self):
assert not self.client_is_waiting_for_100_continue
self._respond_to_state_changes(old_states)

def _process_error(self, role):
def _process_error(self, role: Sentinel) -> None:
old_states = dict(self._cstate.states)
self._cstate.process_error(role)
self._respond_to_state_changes(old_states)

def _server_switch_event(self, event):
def _server_switch_event(self, event: Event) -> Optional[Sentinel]:
if type(event) is InformationalResponse and event.status_code == 101:
return _SWITCH_UPGRADE
if type(event) is Response:
@@ -227,7 +251,7 @@ def _server_switch_event(self, event):
return None

# All events go through here
def _process_event(self, role, event):
def _process_event(self, role: Sentinel, event: Event) -> None:
# First, pass the event through the state machine to make sure it
# succeeds.
old_states = dict(self._cstate.states)
@@ -243,16 +267,15 @@ def _process_event(self, role, event):

# Then perform the updates triggered by it.

# self._request_method
if type(event) is Request:
self._request_method = event.method

# self.their_http_version
if role is self.their_role and type(event) in (
Request,
Response,
InformationalResponse,
):
event = cast(Union[Request, Response, InformationalResponse], event)
self.their_http_version = event.http_version

# Keep alive handling
@@ -261,7 +284,9 @@ def _process_event(self, role, event):
# shows up on a 1xx InformationalResponse. I think the idea is that
# this is not supposed to happen. In any case, if it does happen, we
# ignore it.
if type(event) in (Request, Response) and not _keep_alive(event):
if type(event) in (Request, Response) and not _keep_alive(
cast(Union[Request, Response], event)
):
self._cstate.process_keep_alive_disabled()

# 100-continue
@@ -274,30 +299,39 @@ def _process_event(self, role, event):

self._respond_to_state_changes(old_states, event)

def _get_io_object(self, role, event, io_dict):
def _get_io_object(
self,
role: Sentinel,
event: Optional[Event],
io_dict: Union[ReadersType, WritersType],
) -> Optional[Callable[..., Any]]:
# event may be None; it's only used when entering SEND_BODY
state = self._cstate.states[role]
if state is SEND_BODY:
# Special case: the io_dict has a dict of reader/writer factories
# that depend on the request/response framing.
framing_type, args = _body_framing(self._request_method, event)
return io_dict[SEND_BODY][framing_type](*args)
framing_type, args = _body_framing(
cast(bytes, self._request_method), cast(Union[Request, Response], event)
)
return io_dict[SEND_BODY][framing_type](*args) # type: ignore[index]
else:
# General case: the io_dict just has the appropriate reader/writer
# for this state
return io_dict.get((role, state))
return io_dict.get((role, state)) # type: ignore[return-value]

# This must be called after any action that might have caused
# self._cstate.states to change.
def _respond_to_state_changes(self, old_states, event=None):
def _respond_to_state_changes(
self, old_states: Dict[Sentinel, Sentinel], event: Optional[Event] = None
) -> None:
# Update reader/writer
if self.our_state != old_states[self.our_role]:
self._writer = self._get_io_object(self.our_role, event, WRITERS)
if self.their_state != old_states[self.their_role]:
self._reader = self._get_io_object(self.their_role, event, READERS)

@property
def trailing_data(self):
def trailing_data(self) -> Tuple[bytes, bool]:
"""Data that has been received, but not yet processed, represented as
a tuple with two elements, where the first is a byte-string containing
the unprocessed data itself, and the second is a bool that is True if
@@ -307,7 +341,7 @@ def trailing_data(self):
"""
return (bytes(self._receive_buffer), self._receive_buffer_closed)

def receive_data(self, data):
def receive_data(self, data: bytes) -> None:
"""Add data to our internal receive buffer.
This does not actually do any processing on the data, just stores
@@ -353,7 +387,7 @@ def receive_data(self, data):
else:
self._receive_buffer_closed = True

def _extract_next_receive_event(self):
def _extract_next_receive_event(self) -> Union[Event, Sentinel]:
state = self.their_state
# We don't pause immediately when they enter DONE, because even in
# DONE state we can still process a ConnectionClosed() event. But
@@ -372,14 +406,14 @@ def _extract_next_receive_event(self):
# return that event, and then the state will change and we'll
# get called again to generate the actual ConnectionClosed().
if hasattr(self._reader, "read_eof"):
event = self._reader.read_eof()
event = self._reader.read_eof() # type: ignore[attr-defined]
else:
event = ConnectionClosed()
if event is None:
event = NEED_DATA
return event
return event # type: ignore[no-any-return]

def next_event(self):
def next_event(self) -> Union[Event, Sentinel]:
"""Parse the next event out of our receive buffer, update our internal
state, and return it.
@@ -424,7 +458,7 @@ def next_event(self):
try:
event = self._extract_next_receive_event()
if event not in [NEED_DATA, PAUSED]:
self._process_event(self.their_role, event)
self._process_event(self.their_role, cast(Event, event))
if event is NEED_DATA:
if len(self._receive_buffer) > self._max_incomplete_event_size:
# 431 is "Request header fields too large" which is pretty
@@ -444,7 +478,7 @@ def next_event(self):
else:
raise

def send(self, event):
def send(self, event: Event) -> Optional[bytes]:
"""Convert a high-level event into bytes that can be sent to the peer,
while updating our internal state machine.
@@ -471,7 +505,7 @@ def send(self, event):
else:
return b"".join(data_list)

def send_with_data_passthrough(self, event):
def send_with_data_passthrough(self, event: Event) -> Optional[List[bytes]]:
"""Identical to :meth:`send`, except that in situations where
:meth:`send` returns a single :term:`bytes-like object`, this instead
returns a list of them -- and when sending a :class:`Data` event, this
@@ -497,14 +531,14 @@ def send_with_data_passthrough(self, event):
# In any situation where writer is None, process_event should
# have raised ProtocolError
assert writer is not None
data_list = []
data_list: List[bytes] = []
writer(event, data_list.append)
return data_list
except:
self._process_error(self.our_role)
raise

def send_failed(self):
def send_failed(self) -> None:
"""Notify the state machine that we failed to send the data it gave
us.
@@ -529,7 +563,7 @@ def send_failed(self):
# This function's *only* responsibility is making sure headers are set up
# right -- everything downstream just looks at the headers. There are no
# side channels.
def _clean_up_response_headers_for_sending(self, response):
def _clean_up_response_headers_for_sending(self, response: Response) -> Response:
assert type(response) is Response

headers = response.headers
@@ -542,7 +576,7 @@ def _clean_up_response_headers_for_sending(self, response):
# we're allowed to leave out the framing headers -- see
# https://tools.ietf.org/html/rfc7231#section-4.3.2 . But it's just as
# easy to get them right.)
method_for_choosing_headers = self._request_method
method_for_choosing_headers = cast(bytes, self._request_method)
if method_for_choosing_headers == b"HEAD":
method_for_choosing_headers = b"GET"
framing_type, _ = _body_framing(method_for_choosing_headers, response)
@@ -572,7 +606,7 @@ def _clean_up_response_headers_for_sending(self, response):
if self._request_method != b"HEAD":
need_close = True
else:
headers = set_comma_header(headers, b"transfer-encoding", ["chunked"])
headers = set_comma_header(headers, b"transfer-encoding", [b"chunked"])

if not self._cstate.keep_alive or need_close:
# Make sure Connection: close is set
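
The new annotations also document that next_event returns Union[Event, Sentinel], so callers are expected to rule out NEED_DATA and PAUSED before treating the result as a concrete event. A hypothetical server-side read loop (not part of this diff) illustrates that pattern, using cast the same way the diff does internally:

import socket
from typing import cast, List

import h11


def read_request(sock: socket.socket) -> List[h11.Event]:
    conn = h11.Connection(our_role=h11.SERVER)
    events: List[h11.Event] = []
    while True:
        event = conn.next_event()
        if event is h11.NEED_DATA:
            # An empty recv() tells h11 that the peer closed the connection.
            conn.receive_data(sock.recv(4096))
            continue
        if event is h11.PAUSED:
            break
        # The sentinels are ruled out, so this is a concrete Event instance.
        events.append(cast(h11.Event, event))
        if isinstance(event, (h11.EndOfMessage, h11.ConnectionClosed)):
            break
    return events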
