Skip to content

Commit

Permalink
Test for recent-enough versions of optional packages. (#1041)
Browse files Browse the repository at this point in the history
  • Loading branch information
rthalley committed Feb 9, 2024
1 parent 1d35451 commit 3c6a797
Show file tree
Hide file tree
Showing 14 changed files with 243 additions and 80 deletions.
5 changes: 3 additions & 2 deletions dns/_asyncio_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys

import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet

Expand Down Expand Up @@ -122,7 +123,7 @@ async def getpeercert(self, timeout):
return self.writer.get_extra_info("peercert")


try:
if dns._features.have("doh"):
import anyio
import httpcore
import httpcore._backends.anyio
Expand Down Expand Up @@ -206,7 +207,7 @@ def __init__(
resolver, local_port, bootstrap_address, family
)

except ImportError:
else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore


Expand Down
92 changes: 92 additions & 0 deletions dns/_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

import importlib.metadata
import itertools
import string
from typing import Dict, List, Tuple


def _tuple_from_text(version: str) -> Tuple:
text_parts = version.split(".")
int_parts = []
for text_part in text_parts:
digit_prefix = "".join(
itertools.takewhile(lambda x: x in string.digits, text_part)
)
try:
int_parts.append(int(digit_prefix))
except Exception:
break
return tuple(int_parts)


def _version_check(
requirement: str,
) -> bool:
"""Is the requirement fulfilled?
The requirement must be of the form
package>=version
"""
package, minimum = requirement.split(">=")
try:
version = importlib.metadata.version(package)
except Exception:
return False
t_version = _tuple_from_text(version)
t_minimum = _tuple_from_text(minimum)
if t_version < t_minimum:
return False
return True


_cache: Dict[str, bool] = {}


def have(feature: str) -> bool:
"""Is *feature* available?
This tests if all optional packages needed for the
feature are available and recent enough.
Returns ``True`` if the feature is available,
and ``False`` if it is not or if metadata is
missing.
"""
value = _cache.get(feature)
if value is not None:
return value
requirements = _requirements.get(feature)
if requirements is None:
# we make a cache entry here for consistency not performance
_cache[feature] = False
return False
ok = True
for requirement in requirements:
if not _version_check(requirement):
ok = False
break
_cache[feature] = ok
return ok


def force(feature: str, enabled: bool) -> None:
"""Force the status of *feature* to be *enabled*.
This method is provided as a workaround for any cases
where importlib.metadata is ineffective, or for testing.
"""
_cache[feature] = enabled


_requirements: Dict[str, List[str]] = {
### BEGIN generated requirements
"dnssec": ["cryptography>=42"],
"doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"],
"doq": ["aioquic>=0.9.25"],
"idna": ["idna>=3.6"],
"trio": ["trio>=0.23"],
"wmi": ["wmi>=1.5.1"],
### END generated requirements
}
8 changes: 6 additions & 2 deletions dns/_trio_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
import trio.socket # type: ignore

import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet

if not dns._features.have("trio"):
raise ImportError("trio not found or too old")


def _maybe_timeout(timeout):
if timeout is not None:
Expand Down Expand Up @@ -95,7 +99,7 @@ async def getpeercert(self, timeout):
raise NotImplementedError


try:
if dns._features.have("doh"):
import httpcore
import httpcore._backends.trio
import httpx
Expand Down Expand Up @@ -177,7 +181,7 @@ def __init__(
resolver, local_port, bootstrap_address, family
)

except ImportError:
else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore


Expand Down
5 changes: 2 additions & 3 deletions dns/asyncquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
NoDOQ,
UDPMode,
_compute_times,
_have_http2,
_make_dot_ssl_context,
_matches_destination,
_remaining,
Expand Down Expand Up @@ -534,7 +533,7 @@ async def https(
transport = backend.get_transport_class()(
local_address=local_address,
http1=True,
http2=_have_http2,
http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
Expand All @@ -546,7 +545,7 @@ async def https(
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else:
cm = httpx.AsyncClient(
http1=True, http2=_have_http2, verify=verify, transport=transport
http1=True, http2=True, verify=verify, transport=transport
)

async with cm as the_client:
Expand Down
19 changes: 10 additions & 9 deletions dns/dnssec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from datetime import datetime
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast

import dns._features
import dns.exception
import dns.name
import dns.node
Expand Down Expand Up @@ -1169,7 +1170,7 @@ def _need_pyca(*args, **kwargs):
) # pragma: no cover


try:
if dns._features.have("dnssec"):
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611
Expand All @@ -1184,20 +1185,20 @@ def _need_pyca(*args, **kwargs):
get_algorithm_cls_from_dnskey,
)
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
except ImportError: # pragma: no cover
validate = _need_pyca
validate_rrsig = _need_pyca
sign = _need_pyca
make_dnskey = _need_pyca
make_cdnskey = _need_pyca
_have_pyca = False
else:

validate = _validate # type: ignore
validate_rrsig = _validate_rrsig # type: ignore
sign = _sign
make_dnskey = _make_dnskey
make_cdnskey = _make_cdnskey
_have_pyca = True
else: # pragma: no cover
validate = _need_pyca
validate_rrsig = _need_pyca
sign = _need_pyca
make_dnskey = _need_pyca
make_cdnskey = _need_pyca
_have_pyca = False

### BEGIN generated Algorithm constants

Expand Down
4 changes: 2 additions & 2 deletions dns/dnssecalgs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dns.name

try:
if dns._features.have("dnssec"):
from dns.dnssecalgs.base import GenericPrivateKey
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
Expand All @@ -16,7 +16,7 @@
)

_have_cryptography = True
except ImportError:
else:
_have_cryptography = False

from dns.dnssectypes import Algorithm
Expand Down
7 changes: 4 additions & 3 deletions dns/name.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
import struct
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

try:
import dns._features

if dns._features.have("idna"):
import idna # type: ignore

have_idna_2008 = True
except ImportError: # pragma: no cover
else: # pragma: no cover
have_idna_2008 = False

import dns.enum
Expand Down Expand Up @@ -355,7 +357,6 @@ def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes:

@dns.immutable.immutable
class Name:

"""A DNS name.
The dns.name.Name class represents a DNS name as a tuple of
Expand Down
23 changes: 6 additions & 17 deletions dns/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import time
from typing import Any, Dict, Optional, Tuple, Union

import dns._features
import dns.exception
import dns.inet
import dns.message
Expand Down Expand Up @@ -58,24 +59,14 @@ def _expiration_for_this_attempt(timeout, expiration):
return min(time.time() + timeout, expiration)


_have_httpx = False
_have_http2 = False
try:
import httpcore
_have_httpx = dns._features.have("doh")
if _have_httpx:
import httpcore._backends.sync
import httpx

_CoreNetworkBackend = httpcore.NetworkBackend
_CoreSyncStream = httpcore._backends.sync.SyncStream

_have_httpx = True
try:
# See if http2 support is available.
with httpx.Client(http2=True):
_have_http2 = True
except Exception:
pass

class _NetworkBackend(_CoreNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
Expand Down Expand Up @@ -148,7 +139,7 @@ def __init__(
resolver, local_port, bootstrap_address, family
)

except ImportError: # pragma: no cover
else:

class _HTTPTransport: # type: ignore
def connect_tcp(self, host, port, timeout, local_address):
Expand Down Expand Up @@ -462,7 +453,7 @@ def https(
transport = _HTTPTransport(
local_address=local_address,
http1=True,
http2=_have_http2,
http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
Expand All @@ -473,9 +464,7 @@ def https(
if session:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
else:
cm = httpx.Client(
http1=True, http2=_have_http2, verify=verify, transport=transport
)
cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
with cm as session:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
Expand Down
10 changes: 5 additions & 5 deletions dns/quic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

try:
import dns._features

if dns._features.have("doq"):
import aioquic.quic.configuration # type: ignore

import dns.asyncbackend
Expand Down Expand Up @@ -31,7 +33,7 @@ def _asyncio_manager_factory(

_async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}

try:
if dns._features.have("trio"):
import trio

from dns.quic._trio import ( # pylint: disable=ungrouped-imports
Expand All @@ -47,15 +49,13 @@ def _trio_manager_factory(context, *args, **kwargs):
return TrioQuicManager(context, *args, **kwargs)

_async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
except ImportError:
pass

def factories_for_backend(backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
return _async_factories[backend.name()]

except ImportError:
else: # pragma: no cover
have_quic = False

from typing import Any
Expand Down
6 changes: 4 additions & 2 deletions dns/win32util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys

import dns._features

if sys.platform == "win32":
from typing import Any

Expand All @@ -15,14 +17,14 @@
except KeyError:
WindowsError = Exception

try:
if dns._features.have("wmi"):
import threading

import pythoncom # pylint: disable=import-error
import wmi # pylint: disable=import-error

_have_wmi = True
except Exception:
else:
_have_wmi = False

def _config_domain(domain):
Expand Down

0 comments on commit 3c6a797

Please sign in to comment.