Skip to content

Commit

Permalink
Fix setting source for sync/trio quic queries.
Browse files Browse the repository at this point in the history
The sync code called connect() before bind(), which meant that any
attempt to specify a source resulted in an exception.  This switches the
order.

The trio code called a nonexistent method in the wrong place, so didn't
work at all.  This fixes the call and puts it in the right place.

The asyncio code worked, so no changes were needed.
  • Loading branch information
bwelling committed Dec 7, 2023
1 parent 860ba4d commit 186922d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
8 changes: 4 additions & 4 deletions dns/quic/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,6 @@ class SyncQuicConnection(BaseQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
self._socket.connect(self._peer)
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
if self._source is not None:
try:
self._socket.bind(
Expand All @@ -94,6 +90,10 @@ def __init__(self, connection, address, port, source, source_port, manager):
except Exception:
self._socket.close()
raise
self._socket.connect(self._peer)
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
self._handshake_complete = threading.Event()
self._worker_thread = None
self._lock = threading.Lock()
Expand Down
6 changes: 4 additions & 2 deletions dns/quic/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,17 @@ class TrioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
if self._source:
trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af))
self._handshake_complete = trio.Event()
self._run_done = trio.Event()
self._worker_scope = None
self._send_pending = False

async def _worker(self):
try:
if self._source:
await self._socket.bind(
dns.inet.low_level_address_tuple(self._source, self._af)
)
await self._socket.connect(self._peer)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
Expand Down

0 comments on commit 186922d

Please sign in to comment.