Skip to content

Commit

Permalink
Introduce "session_closed" handler callback
Browse files Browse the repository at this point in the history
Introduce "session_closed" handler callback, and add support for
CLOSE_WEBTRANSPORT_SESSION capsule. There is no conforming
clients so it is not tested: I ran existing WPTs and confirmed
this didn't break them.

Also introduce WebTransportSession.dict_for_handlers to allow
handlers to put arbitrary data to the associated session. This is
different from Stash because Stash outlives sessions.
  • Loading branch information
yutakahirano committed Sep 24, 2021
1 parent 68c75f6 commit 92d8d96
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 16 deletions.
17 changes: 16 additions & 1 deletion tools/webtransport/h3/handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Optional, Tuple

from .webtransport_h3_server import WebTransportSession

Expand Down Expand Up @@ -46,3 +46,18 @@ def datagram_received(session: WebTransportSession, data: bytes) -> None:
:param data: The received data.
"""
pass


def session_closed(session: WebTransportSession,
close_info: Optional[Tuple[int, bytes]],
abruptly: bool) -> None:
"""
Called when a WebTransport session is closed.
:param session: A WebTransport session.
:param close_info: The code and reason attached to the
CLOSE_WEBTRANSPORT_SESSION capsule.
:param abruptly: True when the session is closed forcibly
(by a CLOSE_CONNECTION QUIC frame for example).
"""
pass
124 changes: 109 additions & 15 deletions tools/webtransport/h3/webtransport_h3_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ssl
import threading
import traceback
from enum import IntEnum
from urllib.parse import urlparse
from typing import Any, Dict, List, Optional, Tuple

Expand All @@ -14,7 +15,7 @@
from aioquic.h3.events import H3Event, HeadersReceived, WebTransportStreamDataReceived, DatagramReceived # type: ignore
from aioquic.quic.configuration import QuicConfiguration # type: ignore
from aioquic.quic.connection import stream_is_unidirectional # type: ignore
from aioquic.quic.events import QuicEvent, ProtocolNegotiated # type: ignore
from aioquic.quic.events import QuicEvent, ProtocolNegotiated, ConnectionTerminated # type: ignore
from aioquic.tls import SessionTicket # type: ignore
from aioquic.quic.packet import QuicErrorCode # type: ignore

Expand All @@ -35,11 +36,59 @@
_doc_root: str = ""


# TODO(yutakahirano): Use aioquic's H3Capsule when
# https://github.com/aiortc/aioquic/pull/229 is accepted.
class CapsuleType(IntEnum):
# Defined in
# https://www.ietf.org/archive/id/draft-ietf-masque-h3-datagram-03.html.
DATAGRAM = 0xff37a0
REGISTER_DATAGRAM_CONTEXT = 0xff37a1
REGISTER_DATAGRAM_NO_CONTEXT = 0xff37a2
CLOSE_DATAGRAM_CONTEXT = 0xff37a3
# Defined in
# https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-01.html.
CLOSE_WEBTRANSPORT_SESSION = 0x2843


# TODO(yutakahirano): Use aioquic's H3Capsule when
# https://github.com/aiortc/aioquic/pull/229 is accepted.
class H3Capsule:
"""
Represents the Capsule concept defined in
https://ietf-wg-masque.github.io/draft-ietf-masque-h3-datagram/draft-ietf-masque-h3-datagram.html#name-capsules.
"""
def __init__(self, type: int, data: bytes) -> None:
self.type = type
self.data = data

@staticmethod
def decode(data: bytes) -> any:
"""
Returns an H3Capsule representing the given bytes.
"""
buffer = Buffer(data=data)
type = buffer.pull_uint_var()
length = buffer.pull_uint_var()
return H3Capsule(type, buffer.pull_bytes(length))

def encode(self) -> bytes:
"""
Encodes this H3Connection and return the bytes.
"""
buffer = Buffer(capacity=len(self.data) + 2 * UINT_VAR_MAX_SIZE)
buffer.push_uint_var(self.type)
buffer.push_uint_var(len(self.data))
buffer.push_bytes(self.data)
return buffer.data


class WebTransportH3Protocol(QuicConnectionProtocol):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._handler: Optional[Any] = None
self._http: Optional[H3Connection] = None
self._connect_stream_id: Optional[int] = None
self._connect_stream_data: bytes = b""

def quic_event_received(self, event: QuicEvent) -> None:
if isinstance(event, ProtocolNegotiated):
Expand All @@ -49,6 +98,9 @@ def quic_event_received(self, event: QuicEvent) -> None:
for http_event in self._http.handle_event(event):
self._h3_event_received(http_event)

if isinstance(event, ConnectionTerminated) and self._handler:
self._handler.session_closed(close_info=None, abruptly=True)

def _h3_event_received(self, event: H3Event) -> None:
if isinstance(event, HeadersReceived):
# Convert from List[Tuple[bytes, bytes]] to Dict[bytes, bytes].
Expand All @@ -64,12 +116,29 @@ def _h3_event_received(self, event: H3Event) -> None:
self._handshake_webtransport(event, headers)
else:
self._send_error_response(event.stream_id, 400)

if self._handler is not None:
self._session_stream_id = event.stream_id

if self._session_stream_id == event.stream_id and\
isinstance(event, WebTransportStreamDataReceived):
self._connect_stream_data += event.data
if self._handler is not None and event.stream_ended:
close_info = None
if len(self._connect_stream_data) > 0:
capsule = H3Capsule.decode(self._connect_stream_data)
close_info = (0, "")
if capsule.type == CapsuleType.CLOSE_WEBTRANSPORT_SESSION:
buffer = Buffer(capsule.data)
close_info[0] = buffer.pull_uint32()
reason = buffer.data()
# TODO(yutakahirano): Make sure `reason` is a
# UTF-8 text.
self._handler.session_closed(session, close_info, abruptly=False)
elif self._handler is not None:
if isinstance(event, WebTransportStreamDataReceived):
self._handler.stream_data_received(stream_id=event.stream_id,
data=event.data,
stream_ended=event.stream_ended)
self._handler.stream_data_received(
stream_id=event.stream_id,
data=event.data,
stream_ended=event.stream_ended)
elif isinstance(event, DatagramReceived):
self._handler.datagram_received(data=event.data)

Expand Down Expand Up @@ -146,6 +215,7 @@ def __init__(self, protocol: WebTransportH3Protocol, session_id: int,
# WebTransport sessions can access the same store easily.
self._stash_path = '/webtransport/handlers'
self._stash: Optional[stash.Stash] = None
self._dict_for_handlers: Dict[str, Any] = {}

@property
def stash(self) -> stash.Stash:
Expand All @@ -155,24 +225,39 @@ def stash(self) -> stash.Stash:
self._stash = stash.Stash(self._stash_path, address, authkey)
return self._stash

@property
def dict_for_handlers(self) -> Dict[str, Any]:
"""A dictionary that handlers can attach arbitrary data."""
return self._dict_for_handlers

def stream_is_unidirectional(self, stream_id: int) -> bool:
"""Return True if the stream is unidirectional."""
return stream_is_unidirectional(stream_id)

def close(self,
error_code: int = QuicErrorCode.NO_ERROR,
reason_phrase: str = "") -> None:
def close(self, close_info: Optional[Tuple[int, bytes]]) -> None:
"""
Close the session.
:param error_code: An error code indicating why the session is
being closed.
:param reason_phrase: A human readable explanation of why the
session is being closed.
:param close_info The close information to send.
"""
self._http._quic.close(error_code=error_code,
reason_phrase=reason_phrase)
assert self._session_stream_id is not None
if close_info is not None:
code = close_info.code
reason = close_info.reason
buffer = Buffer(capacity=len(reason) + 4)
buffer.push_uint32(code)
buffer.push_bytes(reason)
capsule = H3Capsule(CLOSE_WEBTRANSPORT_SESSION, buffer.data)
self.send_stream_data(self._session_stream_id, capsule.encode())

self.send_stream_data(self._session_stream_id, b'', end_stream=True)
self._protocol.transmit()
# TODO(yutakahirano): Reset all other streams.
# TODO(yutakahirano): Reject future stream open requests
# We need to wait for the stream data to arrive at the client, and then
# we need to close the connection. At this moment we're relying on the
# client's behavior.
# TODO(yutakahirano): Implement the above.

def create_unidirectional_stream(self) -> int:
"""
Expand Down Expand Up @@ -225,6 +310,7 @@ def __init__(self, session: WebTransportSession,
callbacks: Dict[str, Any]) -> None:
self._session = session
self._callbacks = callbacks
self._called_session_closed = False

def _run_callback(self, callback_name: str,
*args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -252,6 +338,14 @@ def stream_data_received(self, stream_id: int, data: bytes,
def datagram_received(self, data: bytes) -> None:
self._run_callback("datagram_received", self._session, data)

def session_closed(
self, close_info: Optional[Tuple[int, bytes]], abruptly: bool):
if _called_session_closed:
return
_called_session_closed = True
self._run_callback(
"session_closed", self._session, close_info, abruptly=abruptly)


class SessionTicketStore:
"""
Expand Down

0 comments on commit 92d8d96

Please sign in to comment.