Skip to content

Commit

Permalink
Merge pull request #478 from bwelling/query-use-context-managers
Browse files Browse the repository at this point in the history
Use context managers in the query methods.
  • Loading branch information
rthalley committed May 21, 2020
2 parents 198cf1a + 0fa0d19 commit e51e70c
Showing 1 changed file with 33 additions and 56 deletions.
89 changes: 33 additions & 56 deletions dns/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,10 +463,7 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
wire = q.to_wire()
(af, destination, source) = _destination_and_source(af, where, port,
source, source_port)
s = socket_factory(af, socket.SOCK_DGRAM, 0)
received_time = None
sent_time = None
try:
with socket_factory(af, socket.SOCK_DGRAM, 0) as s:
expiration = _compute_expiration(timeout)
s.setblocking(0)
if source is not None:
Expand All @@ -475,16 +472,10 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
(r, received_time) = receive_udp(s, destination, expiration,
ignore_unexpected, one_rr_per_rrset,
q.keyring, q.mac, ignore_trailing)
finally:
if sent_time is None or received_time is None:
response_time = 0
else:
response_time = received_time - sent_time
s.close()
r.time = response_time
if not q.is_response(r):
raise BadResponse
return r
r.time = received_time - sent_time
if not q.is_response(r):
raise BadResponse
return r


def _net_read(sock, count, expiration):
Expand Down Expand Up @@ -637,10 +628,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
wire = q.to_wire()
(af, destination, source) = _destination_and_source(af, where, port,
source, source_port)
s = socket_factory(af, socket.SOCK_STREAM, 0)
begin_time = None
received_time = None
try:
with socket_factory(af, socket.SOCK_STREAM, 0) as s:
expiration = _compute_expiration(timeout)
s.setblocking(0)
begin_time = time.time()
Expand All @@ -650,16 +638,21 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
send_tcp(s, wire, expiration)
(r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
q.keyring, q.mac, ignore_trailing)
finally:
if begin_time is None or received_time is None:
response_time = 0
else:
response_time = received_time - begin_time
s.close()
r.time = response_time
if not q.is_response(r):
raise BadResponse
return r
r.time = received_time - begin_time
if not q.is_response(r):
raise BadResponse
return r


def _tls_handshake(s, expiration):
while True:
try:
s.do_handshake()
return
except ssl.SSLWantReadError:
_wait_for_readable(s, expiration)
except ssl.SSLWantWriteError:
_wait_for_writable(s, expiration)


def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0,
Expand Down Expand Up @@ -708,43 +701,27 @@ def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0,
wire = q.to_wire()
(af, destination, source) = _destination_and_source(af, where, port,
source, source_port)
s = socket_factory(af, socket.SOCK_STREAM, 0)
begin_time = None
received_time = None
try:
if ssl_context is None:
ssl_context = ssl.create_default_context()
if server_hostname is None:
ssl_context.check_hostname = False
with ssl_context.wrap_socket(socket_factory(af, socket.SOCK_STREAM, 0),
do_handshake_on_connect=False,
server_hostname=server_hostname) as s:
expiration = _compute_expiration(timeout)
s.setblocking(0)
begin_time = time.time()
if source is not None:
s.bind(source)
_connect(s, destination, expiration)
if ssl_context is None:
ssl_context = ssl.create_default_context()
if server_hostname is None:
ssl_context.check_hostname = False
s = ssl_context.wrap_socket(s, do_handshake_on_connect=False,
server_hostname=server_hostname)
while True:
try:
s.do_handshake()
break
except ssl.SSLWantReadError:
_wait_for_readable(s, expiration)
except ssl.SSLWantWriteError:
_wait_for_writable(s, expiration)
_tls_handshake(s, expiration)
send_tcp(s, wire, expiration)
(r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
q.keyring, q.mac, ignore_trailing)
finally:
if begin_time is None or received_time is None:
response_time = 0
else:
response_time = received_time - begin_time
s.close()
r.time = response_time
if not q.is_response(r):
raise BadResponse
return r
r.time = received_time - begin_time
if not q.is_response(r):
raise BadResponse
return r


def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
Expand Down

0 comments on commit e51e70c

Please sign in to comment.