Skip to content

Commit

Permalink
Ensure asyncio datagram sockets on windows have had a bind() before
Browse files Browse the repository at this point in the history
recvfrom().

The fix for [#637] erroneously concluded that that windows asyncio
needed connected datagram sockets, but subsequent further
investation showed that the actual problem was that windows wants
an unconnected datagram socket to be bound before recvfrom is called.
Linux autobinds in this case to the wildcard address and port, so
that's why we didn't see any problems there.  We now ensure that
the source is bound.
  • Loading branch information
rthalley committed Dec 15, 2023
1 parent 143c264 commit adfc942
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 27 deletions.
13 changes: 6 additions & 7 deletions dns/_asyncio_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import dns._asyncbackend
import dns.exception
import dns.inet

_is_win32 = sys.platform == "win32"

Expand Down Expand Up @@ -224,14 +225,12 @@ async def make_socket(
ssl_context=None,
server_hostname=None,
):
if destination is None and socktype == socket.SOCK_DGRAM and _is_win32:
raise NotImplementedError(
"destinationless datagram sockets "
"are not supported by asyncio "
"on Windows"
)
loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM:
if _is_win32 and source is None:
# Win32 wants explicit binding before recvfrom(). This is the
# proper fix for [#637].
source = (dns.inet.any_for_af(af), 0)
transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol,
source,
Expand Down Expand Up @@ -266,7 +265,7 @@ async def sleep(self, interval):
await asyncio.sleep(interval)

def datagram_connection_required(self):
return _is_win32
return False

def get_transport_class(self):
return _HTTPTransport
Expand Down
25 changes: 5 additions & 20 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,6 @@ def test_source_tuple(self):

@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
class AsyncTests(unittest.TestCase):
connect_udp = sys.platform == "win32"

def setUp(self):
self.backend = dns.asyncbackend.set_default_backend("asyncio")

Expand Down Expand Up @@ -327,12 +325,12 @@ def testQueryUDPWithSocket(self):
qname = dns.name.from_text("dns.google.")

async def run():
if self.connect_udp:
dtuple = (address, 53)
else:
dtuple = None
async with await self.backend.make_socket(
dns.inet.af_for_address(address), socket.SOCK_DGRAM, 0, None, dtuple
dns.inet.af_for_address(address),
socket.SOCK_DGRAM,
0,
None,
None,
) as s:
q = dns.message.make_query(qname, dns.rdatatype.A)
return await dns.asyncquery.udp(q, address, sock=s, timeout=2)
Expand Down Expand Up @@ -485,9 +483,6 @@ async def run():
self.assertFalse(tcp)

def testUDPReceiveQuery(self):
if self.connect_udp:
self.skipTest("test needs connectionless sockets")

async def run():
async with await self.backend.make_socket(
socket.AF_INET, socket.SOCK_DGRAM, source=("127.0.0.1", 0)
Expand All @@ -509,9 +504,6 @@ async def run():
self.assertEqual(sender_address, recv_address)

def testUDPReceiveTimeout(self):
if self.connect_udp:
self.skipTest("test needs connectionless sockets")

async def arun():
async with await self.backend.make_socket(
socket.AF_INET, socket.SOCK_DGRAM, 0, ("127.0.0.1", 0)
Expand Down Expand Up @@ -616,18 +608,13 @@ async def run():

@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
class AsyncioOnlyTests(unittest.TestCase):
connect_udp = sys.platform == "win32"

def setUp(self):
self.backend = dns.asyncbackend.set_default_backend("asyncio")

def async_run(self, afunc):
return asyncio.run(afunc())

def testUseAfterTimeout(self):
if self.connect_udp:
self.skipTest("test needs connectionless sockets")

# Test #843 fix.
async def run():
qname = dns.name.from_text("dns.google")
Expand Down Expand Up @@ -678,8 +665,6 @@ def async_run(self, afunc):
return trio.run(afunc)

class TrioAsyncTests(AsyncTests):
connect_udp = False

def setUp(self):
self.backend = dns.asyncbackend.set_default_backend("trio")

Expand Down

0 comments on commit adfc942

Please sign in to comment.