Skip to content

Commit

Permalink
The Tudoor fix should not eat valid Truncated exceptions [#1053] (#1054)
Browse files Browse the repository at this point in the history
* The Tudoor fix should not eat valid Truncated exceptions [##1053]

* Make logic more readable

(cherry picked from commit 2ab3d16)
  • Loading branch information
rthalley committed Feb 18, 2024
1 parent f12d398 commit 0ea5ad0
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 2 deletions.
10 changes: 10 additions & 0 deletions dns/asyncquery.py
Expand Up @@ -151,6 +151,16 @@ async def receive_udp(
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
except dns.message.Truncated as e:
# See the comment in query.py for details.
if (
ignore_errors
and query is not None
and not query.is_response(e.message())
):
continue
else:
raise
except Exception:
if ignore_errors:
continue
Expand Down
14 changes: 14 additions & 0 deletions dns/query.py
Expand Up @@ -638,6 +638,20 @@ def receive_udp(
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
except dns.message.Truncated as e:
# If we got Truncated and not FORMERR, we at least got the header with TC
# set, and very likely the question section, so we'll re-raise if the
# message seems to be a response as we need to know when truncation happens.
# We need to check that it seems to be a response as we don't want a random
# injected message with TC set to cause us to bail out.
if (
ignore_errors
and query is not None
and not query.is_response(e.message())
):
continue
else:
raise
except Exception:
if ignore_errors:
continue
Expand Down
60 changes: 59 additions & 1 deletion tests/test_async.py
Expand Up @@ -705,17 +705,22 @@ async def mock_receive(
from2,
ignore_unexpected=True,
ignore_errors=True,
raise_on_truncation=False,
good_r=None,
):
if good_r is None:
good_r = self.good_r
s = MockSock(wire1, from1, wire2, from2)
(r, when, _) = await dns.asyncquery.receive_udp(
s,
("127.0.0.1", 53),
time.time() + 2,
ignore_unexpected=ignore_unexpected,
ignore_errors=ignore_errors,
raise_on_truncation=raise_on_truncation,
query=self.q,
)
self.assertEqual(r, self.good_r)
self.assertEqual(r, good_r)

def test_good_mock(self):
async def run():
Expand Down Expand Up @@ -802,6 +807,59 @@ async def run():

self.async_run(run)

def test_good_wire_with_truncation_flag_and_no_truncation_raise(self):
async def run():
tc_r = dns.message.make_response(self.q)
tc_r.flags |= dns.flags.TC
tc_r_wire = tc_r.to_wire()
await self.mock_receive(
tc_r_wire, ("127.0.0.1", 53), None, None, good_r=tc_r
)

self.async_run(run)

def test_good_wire_with_truncation_flag_and_truncation_raise(self):
async def agood():
tc_r = dns.message.make_response(self.q)
tc_r.flags |= dns.flags.TC
tc_r_wire = tc_r.to_wire()
await self.mock_receive(
tc_r_wire, ("127.0.0.1", 53), None, None, raise_on_truncation=True
)

def good():
self.async_run(agood)

self.assertRaises(dns.message.Truncated, good)

def test_wrong_id_wire_with_truncation_flag_and_no_truncation_raise(self):
async def run():
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r.flags |= dns.flags.TC
bad_r_wire = bad_r.to_wire()
await self.mock_receive(
bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
)

self.async_run(run)

def test_wrong_id_wire_with_truncation_flag_and_truncation_raise(self):
async def run():
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r.flags |= dns.flags.TC
bad_r_wire = bad_r.to_wire()
await self.mock_receive(
bad_r_wire,
("127.0.0.1", 53),
self.good_r_wire,
("127.0.0.1", 53),
raise_on_truncation=True,
)

self.async_run(run)

def test_bad_wire_not_ignored(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
Expand Down
44 changes: 43 additions & 1 deletion tests/test_query.py
Expand Up @@ -29,6 +29,7 @@
have_ssl = False

import dns.exception
import dns.flags
import dns.inet
import dns.message
import dns.name
Expand Down Expand Up @@ -706,7 +707,11 @@ def mock_receive(
from2,
ignore_unexpected=True,
ignore_errors=True,
raise_on_truncation=False,
good_r=None,
):
if good_r is None:
good_r = self.good_r
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
with mock_udp_recv(wire1, from1, wire2, from2):
Expand All @@ -716,9 +721,10 @@ def mock_receive(
time.time() + 2,
ignore_unexpected=ignore_unexpected,
ignore_errors=ignore_errors,
raise_on_truncation=raise_on_truncation,
query=self.q,
)
self.assertEqual(r, self.good_r)
self.assertEqual(r, good_r)
finally:
s.close()

Expand Down Expand Up @@ -787,6 +793,42 @@ def test_bad_wire(self):
bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
)

def test_good_wire_with_truncation_flag_and_no_truncation_raise(self):
tc_r = dns.message.make_response(self.q)
tc_r.flags |= dns.flags.TC
tc_r_wire = tc_r.to_wire()
self.mock_receive(tc_r_wire, ("127.0.0.1", 53), None, None, good_r=tc_r)

def test_good_wire_with_truncation_flag_and_truncation_raise(self):
def good():
tc_r = dns.message.make_response(self.q)
tc_r.flags |= dns.flags.TC
tc_r_wire = tc_r.to_wire()
self.mock_receive(
tc_r_wire, ("127.0.0.1", 53), None, None, raise_on_truncation=True
)

self.assertRaises(dns.message.Truncated, good)

def test_wrong_id_wire_with_truncation_flag_and_no_truncation_raise(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r.flags |= dns.flags.TC
bad_r_wire = bad_r.to_wire()
self.mock_receive(
bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
)

def test_wrong_id_wire_with_truncation_flag_and_truncation_raise(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r.flags |= dns.flags.TC
bad_r_wire = bad_r.to_wire()
self.mock_receive(
bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53),
raise_on_truncation=True
)

def test_bad_wire_not_ignored(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
Expand Down

0 comments on commit 0ea5ad0

Please sign in to comment.