Skip to content

Commit

Permalink
DNS-over-HTTP3 (#1048)
Browse files Browse the repository at this point in the history
* Implement DNS-over-HTTP3 using aioquic directly.

* Add h3 support for DoHNameserver.
  • Loading branch information
rthalley committed Feb 24, 2024
1 parent 6695d80 commit e9d58f2
Show file tree
Hide file tree
Showing 10 changed files with 585 additions and 65 deletions.
120 changes: 110 additions & 10 deletions dns/asyncquery.py
Expand Up @@ -19,9 +19,11 @@

import base64
import contextlib
import random
import socket
import struct
import time
import urllib.parse
from typing import Any, Dict, Optional, Tuple, Union

import dns.asyncbackend
Expand All @@ -40,6 +42,7 @@
NoDOH,
NoDOQ,
UDPMode,
_check_status,
_compute_times,
_make_dot_ssl_context,
_matches_destination,
Expand Down Expand Up @@ -500,6 +503,20 @@ async def tls(
return response


def _maybe_get_resolver(
resolver: Optional["dns.asyncresolver.Resolver"],
) -> "dns.asyncresolver.Resolver":
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver

resolver = dns.asyncresolver.Resolver()
return resolver


async def https(
q: dns.message.Message,
where: str,
Expand All @@ -515,7 +532,8 @@ async def https(
verify: Union[bool, str] = True,
bootstrap_address: Optional[str] = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None,
family: Optional[int] = socket.AF_UNSPEC,
family: int = socket.AF_UNSPEC,
h3: bool = False,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
Expand All @@ -529,18 +547,10 @@ async def https(
parameters, exceptions, and return type of this method.
"""

if not have_doh:
raise NoDOH # pragma: no cover
if client and not isinstance(client, httpx.AsyncClient):
raise ValueError("session parameter must be an httpx.AsyncClient")

wire = q.to_wire()
try:
af = dns.inet.af_for_address(where)
except ValueError:
af = None
transport = None
headers = {"accept": "application/dns-message"}
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = "https://{}:{}{}".format(where, port, path)
Expand All @@ -549,6 +559,39 @@ async def https(
else:
url = where

if h3:
if bootstrap_address is None:
parsed = urllib.parse.urlparse(url)
resolver = _maybe_get_resolver(resolver)
if parsed.hostname is None:
raise ValueError("no hostname in URL")
answers = await resolver.resolve_name(parsed.hostname, family)
bootstrap_address = random.choice(list(answers.addresses()))
if parsed.port is not None:
port = parsed.port
return await _http3(
q,
bootstrap_address,
url,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
verify=verify,
post=post,
)

if not have_doh:
raise NoDOH # pragma: no cover
if client and not isinstance(client, httpx.AsyncClient):
raise ValueError("session parameter must be an httpx.AsyncClient")

wire = q.to_wire()
transport = None
headers = {"accept": "application/dns-message"}

backend = dns.asyncbackend.get_default_backend()

if source is None:
Expand Down Expand Up @@ -617,6 +660,57 @@ async def https(
return r


async def _http3(
q: dns.message.Message,
where: str,
url: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
hostname: Optional[str] = None,
post: bool = True,
) -> dns.message.Message:
if not dns.quic.have_quic:
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover

url_parts = urllib.parse.urlparse(url)
hostname = url_parts.hostname

q.id = 0
wire = q.to_wire()
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)

async with cfactory() as context:
async with mfactory(
context, verify_mode=verify, server_name=hostname, h3=True
) as the_manager:
the_connection = the_manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
async with stream:
# note that send_h3() does not need await
stream.send_h3(url, wire, post)
wire = await stream.receive(_remaining(expiration))
_check_status(stream.headers(), where, wire)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r


async def inbound_xfr(
where: str,
txn_manager: dns.transaction.TransactionManager,
Expand Down Expand Up @@ -730,6 +824,7 @@ async def quic(
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
hostname: Optional[str] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
Expand All @@ -745,6 +840,9 @@ async def quic(
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover

if server_hostname is not None and hostname is None:
hostname = server_hostname

q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
Expand All @@ -757,7 +855,9 @@ async def quic(

async with cfactory() as context:
async with mfactory(
context, verify_mode=verify, server_name=server_hostname
context,
verify_mode=verify,
server_name=server_hostname,
) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
Expand Down
4 changes: 4 additions & 0 deletions dns/nameserver.py
Expand Up @@ -168,12 +168,14 @@ def __init__(
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
want_get: bool = False,
h3: bool = False,
):
super().__init__()
self.url = url
self.bootstrap_address = bootstrap_address
self.verify = verify
self.want_get = want_get
self.h3 = h3

def kind(self):
return "DoH"
Expand Down Expand Up @@ -214,6 +216,7 @@ def query(
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
h3=self.h3,
)

async def async_query(
Expand All @@ -238,6 +241,7 @@ async def async_query(
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
h3=self.h3,
)


Expand Down

0 comments on commit e9d58f2

Please sign in to comment.