Skip to content

Commit

Permalink
finish testing of resolver business logic
Browse files Browse the repository at this point in the history
  • Loading branch information
rthalley committed May 20, 2020
1 parent 1573dd6 commit 410d7f5
Showing 1 changed file with 154 additions and 15 deletions.
169 changes: 154 additions & 15 deletions tests/test_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dns.message
import dns.name
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.resolver
Expand Down Expand Up @@ -45,14 +46,28 @@ def bad():
(request, answer) = self.resn.next_request()
self.assertRaises(dns.resolver.NXDOMAIN, bad)

def test_next_request_cache_hit(self):
self.resolver.cache = dns.resolver.Cache()
q = dns.message.make_query(self.qname, dns.rdatatype.A)
def make_address_response(self, q):
r = dns.message.make_response(q)
rrs = r.get_rrset(r.answer, self.qname, dns.rdataclass.IN,
dns.rdatatype.A, create=True)
rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A,
'10.0.0.1'), 300)
return r

def make_negative_response(self, q, nxdomain=False):
r = dns.message.make_response(q)
rrs = r.get_rrset(r.authority, self.qname, dns.rdataclass.IN,
dns.rdatatype.SOA, create=True)
rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA,
'. . 1 2 3 4 300'), 300)
if nxdomain:
r.set_rcode(dns.rcode.NXDOMAIN)
return r

def test_next_request_cache_hit(self):
self.resolver.cache = dns.resolver.Cache()
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)
cache_answer = dns.resolver.Answer(self.qname, dns.rdatatype.A,
dns.rdataclass.IN, r)
self.resolver.cache.put((self.qname, dns.rdatatype.A,
Expand All @@ -65,12 +80,9 @@ def test_next_request_no_answer(self):
# In default mode, we should raise on a no-answer hit
self.resolver.cache = dns.resolver.Cache()
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = dns.message.make_response(q)
# We need an SOA so the cache doesn't expire the answer immediately.
rrs = r.get_rrset(r.authority, self.qname, dns.rdataclass.IN,
dns.rdatatype.SOA, create=True)
rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA,
'. . 1 2 3 4 300'), 300)
# Note we need an SOA so the cache doesn't expire the answer
# immediately, but our negative response code does that.
r = self.make_negative_response(q)
cache_answer = dns.resolver.Answer(self.qname, dns.rdatatype.A,
dns.rdataclass.IN, r, False)
self.resolver.cache.put((self.qname, dns.rdatatype.A,
Expand All @@ -87,15 +99,14 @@ def bad():
self.assertTrue(answer is cache_answer)

def test_next_nameserver_udp(self):
nameservers = {'10.0.0.1', '10.0.0.2'}
(request, answer) = self.resn.next_request()
(nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver1 in nameservers)
self.assertTrue(nameserver1 in self.resolver.nameservers)
self.assertEqual(port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
(nameserver2, port, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver2 in nameservers)
self.assertTrue(nameserver2 in self.resolver.nameservers)
self.assertTrue(nameserver2 != nameserver1)
self.assertEqual(port, 53)
self.assertFalse(tcp)
Expand All @@ -117,10 +128,9 @@ def test_next_nameserver_udp(self):
self.assertEqual(backoff, 0.2)

def test_next_nameserver_retry_with_tcp(self):
nameservers = {'10.0.0.1', '10.0.0.2'}
(request, answer) = self.resn.next_request()
(nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver1 in nameservers)
self.assertTrue(nameserver1 in self.resolver.nameservers)
self.assertEqual(port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
Expand All @@ -131,7 +141,7 @@ def test_next_nameserver_retry_with_tcp(self):
self.assertTrue(tcp)
self.assertEqual(backoff, 0.0)
(nameserver3, port, tcp, backoff) = self.resn.next_nameserver()
self.assertTrue(nameserver3 in nameservers)
self.assertTrue(nameserver3 in self.resolver.nameservers)
self.assertTrue(nameserver3 != nameserver1)
self.assertEqual(port, 53)
self.assertFalse(tcp)
Expand All @@ -146,3 +156,132 @@ def test_next_nameserver_no_nameservers(self):
def bad():
(nameserver, _, _, _) = self.resn.next_nameserver()
self.assertRaises(dns.resolver.NoNameservers, bad)

def test_query_result_nameserver_removing_exceptions(self):
# add some nameservers so we have enough to remove :)
self.resolver.nameservers.extend(['10.0.0.3', '10.0.0.4'])
(request, _) = self.resn.next_request()
exceptions = [dns.exception.FormError(), EOFError(),
NotImplementedError(), dns.message.Truncated()]
for i in range(4):
(nameserver, _, _, _) = self.resn.next_nameserver()
if i == 3:
# Truncated is only bad if we're doing TCP, make it look
# like that's the case
self.resn.tcp_attempt = True
self.assertTrue(nameserver in self.resn.nameservers)
(answer, done) = self.resn.query_result(None, exceptions[i])
self.assertTrue(answer is None)
self.assertFalse(done)
self.assertFalse(nameserver in self.resn.nameservers)
self.assertEqual(len(self.resn.nameservers), 0)

def test_query_result_nameserver_continuing_exception(self):
# except for the exceptions tested in
# test_query_result_nameserver_removing_exceptions(), we should
# not remove any nameservers and just continue resolving.
(_, _) = self.resn.next_request()
(_, _, _, _) = self.resn.next_nameserver()
nameservers = self.resn.nameservers[:]
(answer, done) = self.resn.query_result(None, dns.exception.Timeout())
self.assertTrue(answer is None)
self.assertFalse(done)
self.assertEqual(nameservers, self.resn.nameservers)

def test_query_result_retry_with_tcp(self):
(request, _) = self.resn.next_request()
(nameserver, _, tcp, _) = self.resn.next_nameserver()
self.assertFalse(tcp)
(answer, done) = self.resn.query_result(None, dns.message.Truncated())
self.assertTrue(answer is None)
self.assertFalse(done)
self.assertTrue(self.resn.retry_with_tcp)
# The rest of TCP retry logic was tested above in
# test_next_nameserver_retry_with_tcp(), so we do not repeat
# it.

def test_query_result_no_error_with_data(self):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)
(_, _) = self.resn.next_request()
(_, _, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertFalse(answer is None)
self.assertTrue(done)
self.assertEqual(answer.qname, self.qname)
self.assertEqual(answer.rdtype, dns.rdatatype.A)

def test_query_result_no_error_with_data_cached(self):
self.resolver.cache = dns.resolver.Cache()
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)
(_, _) = self.resn.next_request()
(_, _, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertFalse(answer is None)
cache_answer = self.resolver.cache.get((self.qname, dns.rdatatype.A,
dns.rdataclass.IN))
self.assertTrue(answer is cache_answer)

def test_query_result_no_error_no_data(self):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_negative_response(q)
(_, _) = self.resn.next_request()
(_, _, _, _) = self.resn.next_nameserver()
def bad():
(answer, done) = self.resn.query_result(r, None)
self.assertRaises(dns.resolver.NoAnswer, bad)

def test_query_result_nxdomain(self):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_negative_response(q, True)
(_, _) = self.resn.next_request()
(_, _, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertTrue(done)

def test_query_result_yxdomain(self):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)
r.set_rcode(dns.rcode.YXDOMAIN)
(_, _) = self.resn.next_request()
(_, _, _, _) = self.resn.next_nameserver()
def bad():
(answer, done) = self.resn.query_result(r, None)
self.assertRaises(dns.resolver.YXDOMAIN, bad)

def test_query_result_servfail_no_retry(self):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)
r.set_rcode(dns.rcode.SERVFAIL)
(_, _) = self.resn.next_request()
(nameserver, _, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertFalse(done)
self.assertTrue(nameserver not in self.resn.nameservers)

def test_query_result_servfail_with_retry(self):
self.resolver.retry_servfail = True
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)
r.set_rcode(dns.rcode.SERVFAIL)
(_, _) = self.resn.next_request()
(_, _, _, _) = self.resn.next_nameserver()
nameservers = self.resn.nameservers[:]
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertFalse(done)
self.assertEqual(nameservers, self.resn.nameservers)

def test_query_result_other_unhappy_rcode(self):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)
r.set_rcode(dns.rcode.REFUSED)
(_, _) = self.resn.next_request()
(nameserver, _, _, _) = self.resn.next_nameserver()
(answer, done) = self.resn.query_result(r, None)
self.assertTrue(answer is None)
self.assertFalse(done)
self.assertTrue(nameserver not in self.resn.nameservers)

0 comments on commit 410d7f5

Please sign in to comment.