diff --git a/dns/asyncquery.py b/dns/asyncquery.py index e3003b1f..f7d4df44 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -41,6 +41,7 @@ BadResponse, NoDOH, NoDOQ, + HTTPVersion, UDPMode, _check_status, _compute_times, @@ -533,7 +534,7 @@ async def https( bootstrap_address: Optional[str] = None, resolver: Optional["dns.asyncresolver.Resolver"] = None, family: int = socket.AF_UNSPEC, - h3: bool = False, + http_version: HTTPVersion = HTTPVersion.DEFAULT, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -559,7 +560,7 @@ async def https( else: url = where - if h3: + if http_version == HTTPVersion.H3 or (http_version == HTTPVersion.DEFAULT and not have_doh): if bootstrap_address is None: parsed = urllib.parse.urlparse(url) resolver = _maybe_get_resolver(resolver) @@ -595,6 +596,9 @@ async def https( transport = None headers = {"accept": "application/dns-message"} + h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT) + h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT) + backend = dns.asyncbackend.get_default_backend() if source is None: @@ -605,8 +609,8 @@ async def https( local_port = source_port transport = backend.get_transport_class()( local_address=local_address, - http1=True, - http2=True, + http1=h1, + http2=h2, verify=verify, local_port=local_port, bootstrap_address=bootstrap_address, @@ -618,7 +622,7 @@ async def https( cm: contextlib.AbstractAsyncContextManager = NullContext(client) else: cm = httpx.AsyncClient( - http1=True, http2=True, verify=verify, transport=transport + http1=h1, http2=h2, verify=verify, transport=transport ) async with cm as the_client: diff --git a/dns/nameserver.py b/dns/nameserver.py index e8068e7e..b02a239b 100644 --- a/dns/nameserver.py +++ b/dns/nameserver.py @@ -168,14 +168,14 @@ def __init__( bootstrap_address: Optional[str] = None, verify: Union[bool, str] = True, want_get: bool = False, - h3: bool = False, + http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT, ): super().__init__() self.url = url self.bootstrap_address = bootstrap_address self.verify = verify self.want_get = want_get - self.h3 = h3 + self.http_version = http_version def kind(self): return "DoH" @@ -216,7 +216,7 @@ def query( ignore_trailing=ignore_trailing, verify=self.verify, post=(not self.want_get), - h3=self.h3, + http_version=self.http_version, ) async def async_query( @@ -241,7 +241,7 @@ async def async_query( ignore_trailing=ignore_trailing, verify=self.verify, post=(not self.want_get), - h3=self.h3, + http_version=self.http_version, ) diff --git a/dns/query.py b/dns/query.py index bfd6908c..f3907c6f 100644 --- a/dns/query.py +++ b/dns/query.py @@ -351,6 +351,22 @@ def _maybe_get_resolver( return resolver +class HTTPVersion(enum.IntEnum): + """Which version of HTTP should be used? + + DEFAULT will select the first version from the list [2, 1.1, 3] that + is available. + """ + + DEFAULT = 0 + HTTP_1 = 1 + H1 = 1 + HTTP_2 = 2 + H2 = 2 + HTTP_3 = 3 + H3 = 3 + + def https( q: dns.message.Message, where: str, @@ -367,7 +383,7 @@ def https( verify: Union[bool, str] = True, resolver: Optional["dns.resolver.Resolver"] = None, family: int = socket.AF_UNSPEC, - h3: bool = False, + http_version: HTTPVersion = HTTPVersion.DEFAULT, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -417,7 +433,7 @@ def https( *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A and AAAA records will be retrieved. - *h3*, a ``bool``. If ``True``, use HTTP/3 otherwise use HTTP/2 or HTTP/1.1. + *http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use. Returns a ``dns.message.Message``. """ @@ -433,7 +449,7 @@ def https( else: url = where - if h3: + if http_version == HTTPVersion.H3 or (http_version == HTTPVersion.DEFAULT and not have_doh): if bootstrap_address is None: parsed = urllib.parse.urlparse(url) resolver = _maybe_get_resolver(resolver) @@ -469,6 +485,9 @@ def https( transport = None headers = {"accept": "application/dns-message"} + h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT) + h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT) + # set source port and source address if the_source is None: @@ -479,8 +498,8 @@ def https( local_port = the_source[1] transport = _HTTPTransport( local_address=local_address, - http1=True, - http2=True, + http1=h1, + http2=h2, verify=verify, local_port=local_port, bootstrap_address=bootstrap_address, @@ -491,7 +510,7 @@ def https( if session: cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) else: - cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport) + cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport) with cm as session: # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples diff --git a/tests/test_async.py b/tests/test_async.py index e1cb8610..f0c227de 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -570,7 +570,7 @@ async def run(): post=False, timeout=4, family=family, - h3=True, + http_version=dns.asyncquery.HTTPVersion.H3, ) self.assertTrue(q.is_response(r)) @@ -587,7 +587,7 @@ async def run(): post=True, timeout=4, family=family, - h3=True, + http_version=dns.asyncquery.HTTPVersion.H3, ) self.assertTrue(q.is_response(r)) @@ -603,7 +603,7 @@ async def run(): nameserver_ip, post=False, timeout=4, - h3=True, + http_version=dns.asyncquery.HTTPVersion.H3, ) self.assertTrue(q.is_response(r)) diff --git a/tests/test_doh.py b/tests/test_doh.py index 692b2d67..900a3fae 100644 --- a/tests/test_doh.py +++ b/tests/test_doh.py @@ -203,7 +203,7 @@ def testDoH3GetRequest(self): post=False, timeout=4, family=family, - h3=True, + http_version=dns.query.HTTPVersion.H3, ) self.assertTrue(q.is_response(r)) @@ -216,7 +216,7 @@ def testDoH3PostRequest(self): post=True, timeout=4, family=family, - h3=True, + http_version=dns.query.HTTPVersion.H3, ) self.assertTrue(q.is_response(r)) @@ -233,7 +233,7 @@ def test_build_url_from_ip(self): nameserver_ip, post=False, timeout=4, - h3=True, + http_version=dns.query.HTTPVersion.H3, ) self.assertTrue(q.is_response(r)) if resolver_v6_addresses: @@ -244,7 +244,7 @@ def test_build_url_from_ip(self): nameserver_ip, post=False, timeout=4, - h3=True, + http_version=dns.query.HTTPVersion.H3, ) self.assertTrue(q.is_response(r))