-
Notifications
You must be signed in to change notification settings - Fork 93
/
quic.py
125 lines (110 loc) · 4.65 KB
/
quic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from functools import partial
from typing import Awaitable, Callable, Dict, Optional, Tuple
from aioquic.buffer import Buffer
from aioquic.h3.connection import H3_ALPN
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import (
ConnectionIdIssued,
ConnectionIdRetired,
ConnectionTerminated,
ProtocolNegotiated,
)
from aioquic.quic.packet import (
encode_quic_version_negotiation,
PACKET_TYPE_INITIAL,
pull_quic_header,
)
from .h3 import H3Protocol
from ..config import Config
from ..events import Closed, Event, RawData
from ..typing import ASGIFramework, Context
class QuicProtocol:
    """Route raw datagrams to QUIC connections and their HTTP/3 handlers.

    One instance keeps a table of live ``QuicConnection`` objects keyed by
    every connection id that may appear as the destination id of an incoming
    packet, plus a table mapping each connection to the ``H3Protocol`` created
    once the connection emits ``ProtocolNegotiated``.
    """

    def __init__(
        self,
        app: ASGIFramework,
        config: Config,
        context: Context,
        server: Optional[Tuple[str, int]],
        send: Callable[[Event], Awaitable[None]],
    ) -> None:
        """Store collaborators and build the server-side QUIC configuration.

        Args:
            app: Application object handed to each ``H3Protocol``.
            config: Server configuration; ``certfile`` and ``keyfile`` are
                read from it to load the TLS certificate chain.
            context: Provides ``time()``, ``sleep()`` and ``spawn()``.
            server: Address of the listening socket, passed through to
                ``H3Protocol``.
            send: Coroutine function used to emit outgoing ``RawData`` events.
        """
        self.app = app
        self.config = config
        self.context = context
        # Connection id -> connection. A connection is reachable under several ids.
        self.connections: Dict[bytes, QuicConnection] = {}
        # Connection -> HTTP/3 handler, populated on ProtocolNegotiated.
        self.http_connections: Dict[QuicConnection, H3Protocol] = {}
        self.send = send
        self.server = server

        self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False)
        self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile)

    async def handle(self, event: Event) -> None:
        """Process one incoming event.

        ``RawData`` is parsed as a QUIC packet and fed to the matching
        connection, creating one for a valid Initial packet. ``Closed`` is
        currently ignored. Datagrams whose header cannot be parsed, or that
        match no connection and cannot start one, are silently dropped.
        """
        if isinstance(event, RawData):
            try:
                # NOTE(review): 8 presumably matches the connection id length
                # used by QuicConfiguration - confirm if that is ever customised.
                header = pull_quic_header(Buffer(data=event.data), host_cid_length=8)
            except ValueError:
                # Not a parseable QUIC packet; drop it.
                return

            if (
                header.version is not None
                and header.version not in self.quic_config.supported_versions
            ):
                # Unsupported version: reply with a version negotiation packet
                # (source/destination ids are swapped for the reply).
                data = encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=self.quic_config.supported_versions,
                )
                await self.send(RawData(data=data, address=event.address))
                return

            connection = self.connections.get(header.destination_cid)
            if (
                connection is None
                and len(event.data) >= 1200
                and header.packet_type == PACKET_TYPE_INITIAL
            ):
                # Only a sufficiently large Initial packet may open a new
                # connection (QUIC requires client Initial datagrams to be
                # padded to at least 1200 bytes).
                connection = QuicConnection(
                    configuration=self.quic_config,
                    original_destination_connection_id=header.destination_cid,
                )
                # Reachable both via the id the client picked and via our own id.
                self.connections[header.destination_cid] = connection
                self.connections[connection.host_cid] = connection

            if connection is not None:
                connection.receive_datagram(event.data, event.address, now=self.context.time())
                await self._handle_events(connection, event.address)
        elif isinstance(event, Closed):
            pass

    async def send_all(self, connection: QuicConnection) -> None:
        """Send every datagram ``connection`` currently has pending."""
        for data, address in connection.datagrams_to_send(now=self.context.time()):
            await self.send(RawData(data=data, address=address))

    async def _handle_events(
        self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None
    ) -> None:
        """Drain ``connection``'s event queue, flush output and arm its timer.

        Args:
            connection: The connection whose pending events are processed.
            client: Peer address passed to a newly created ``H3Protocol``;
                ``None`` when invoked from the timer.
        """
        event = connection.next_event()
        while event is not None:
            if isinstance(event, ProtocolNegotiated):
                self.http_connections[connection] = H3Protocol(
                    self.app,
                    self.config,
                    self.context,
                    client,
                    self.server,
                    connection,
                    partial(self.send_all, connection),
                )
            elif isinstance(event, ConnectionIdIssued):
                self.connections[event.connection_id] = connection
            elif isinstance(event, ConnectionIdRetired):
                # pop() rather than del: an id that is unknown or was already
                # removed must not abort event processing with a KeyError.
                self.connections.pop(event.connection_id, None)

            http_connection = self.http_connections.get(connection)
            if http_connection is not None:
                await http_connection.handle(event)

            if isinstance(event, ConnectionTerminated):
                # Done only after the HTTP/3 layer has seen the event, so the
                # handler is still registered when the event is forwarded.
                self._forget_connection(connection)

            event = connection.next_event()

        await self.send_all(connection)

        # NOTE(review): a new timer task is spawned on every call and earlier
        # ones are not cancelled - confirm this is acceptable under load.
        timer = connection.get_timer()
        if timer is not None:
            self.context.spawn(self._handle_timer, timer, connection)

    def _forget_connection(self, connection: QuicConnection) -> None:
        """Drop every table entry that refers to a terminated ``connection``."""
        # Without this, terminated connections (and the client-chosen original
        # destination id, which is never retired) would accumulate forever.
        stale_ids = [cid for cid, known in self.connections.items() if known is connection]
        for cid in stale_ids:
            del self.connections[cid]
        self.http_connections.pop(connection, None)

    async def _handle_timer(self, timer: float, connection: QuicConnection) -> None:
        """Sleep until ``timer`` and then let ``connection`` process its timeout."""
        wait = max(0, timer - self.context.time())
        await self.context.sleep(wait)
        # NOTE(review): relies on the private ``_close_at`` attribute of
        # QuicConnection to skip connections with no pending close deadline.
        if connection._close_at is not None:
            connection.handle_timer(now=self.context.time())
            await self._handle_events(connection, None)