Skip to content

Commit

Permalink
feat: reduce overhead to send responses (#1135)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Apr 1, 2023
1 parent d45c2f9 commit c4077dd
Showing 1 changed file with 64 additions and 24 deletions.
88 changes: 64 additions & 24 deletions src/zeroconf/_core.py
Expand Up @@ -28,7 +28,7 @@
import sys
import threading
from types import TracebackType # noqa # used in type hints
from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union, cast
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Type, Union, cast

from ._cache import DNSCache
from ._dns import DNSQuestion, DNSQuestionType
Expand Down Expand Up @@ -105,6 +105,48 @@
_REGISTER_BROADCASTS = 3


class _WrappedTransport:
"""A wrapper for transports."""

__slots__ = (
'transport',
'is_ipv6',
'sock',
'fileno',
'sock_name',
)

def __init__(
self,
transport: asyncio.DatagramTransport,
is_ipv6: bool,
sock: socket.socket,
fileno: int,
sock_name: Any,
) -> None:
"""Initialize the wrapped transport.
These attributes are used when sending packets.
"""
self.transport = transport
self.is_ipv6 = is_ipv6
self.sock = sock
self.fileno = fileno
self.sock_name = sock_name


def _make_wrapped_transport(transport: asyncio.DatagramTransport) -> _WrappedTransport:
"""Make a wrapped transport."""
sock: socket.socket = transport.get_extra_info('socket')
return _WrappedTransport(
transport=transport,
is_ipv6=sock.family == socket.AF_INET6,
sock=sock,
fileno=sock.fileno(),
sock_name=sock.getsockname(),
)


class AsyncEngine:
"""An engine wraps sockets in the event loop."""

Expand All @@ -117,8 +159,8 @@ def __init__(
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.zc = zeroconf
self.protocols: List[AsyncListener] = []
self.readers: List[asyncio.DatagramTransport] = []
self.senders: List[asyncio.DatagramTransport] = []
self.readers: List[_WrappedTransport] = []
self.senders: List[_WrappedTransport] = []
self.running_event: Optional[asyncio.Event] = None
self._listen_socket = listen_socket
self._respond_sockets = respond_sockets
Expand Down Expand Up @@ -158,9 +200,9 @@ async def _async_create_endpoints(self) -> None:
for s in reader_sockets:
transport, protocol = await loop.create_datagram_endpoint(lambda: AsyncListener(self.zc), sock=s)
self.protocols.append(cast(AsyncListener, protocol))
self.readers.append(cast(asyncio.DatagramTransport, transport))
self.readers.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport)))
if s in sender_sockets:
self.senders.append(cast(asyncio.DatagramTransport, transport))
self.senders.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport)))

def _async_cache_cleanup(self) -> None:
"""Periodic cache cleanup."""
Expand All @@ -186,8 +228,8 @@ def _async_shutdown(self) -> None:
"""Shutdown transports and sockets."""
assert self.running_event is not None
self.running_event.clear()
for transport in itertools.chain(self.senders, self.readers):
transport.close()
for wrapped_transport in itertools.chain(self.senders, self.readers):
wrapped_transport.transport.close()

def close(self) -> None:
"""Close from sync context.
Expand Down Expand Up @@ -221,7 +263,7 @@ def __init__(self, zc: 'Zeroconf') -> None:
self.zc = zc
self.data: Optional[bytes] = None
self.last_time: float = 0
self.transport: Optional[asyncio.DatagramTransport] = None
self.transport: Optional[_WrappedTransport] = None
self.sock_description: Optional[str] = None
self._deferred: Dict[str, List[DNSIncoming]] = {}
self._timers: Dict[str, asyncio.TimerHandle] = {}
Expand Down Expand Up @@ -309,7 +351,7 @@ def handle_query_or_defer(
msg: DNSIncoming,
addr: str,
port: int,
transport: asyncio.DatagramTransport,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
"""Deal with incoming query packets. Provides a response if
Expand Down Expand Up @@ -341,7 +383,7 @@ def _respond_query(
msg: Optional[DNSIncoming],
addr: str,
port: int,
transport: asyncio.DatagramTransport,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
"""Respond to a query and reassemble any truncated deferred packets."""
Expand All @@ -362,27 +404,25 @@ def error_received(self, exc: Exception) -> None:
self.log_exception_once(exc, msg_str, exc)

def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.transport = cast(asyncio.DatagramTransport, transport)
sock_name = self.transport.get_extra_info('sockname')
sock_fileno = self.transport.get_extra_info('socket').fileno()
self.sock_description = f"{sock_fileno} ({sock_name})"
wrapped_transport = _make_wrapped_transport(cast(asyncio.DatagramTransport, transport))
self.transport = wrapped_transport
self.sock_description = f"{wrapped_transport.fileno} ({wrapped_transport.sock_name})"

def connection_lost(self, exc: Optional[Exception]) -> None:
"""Handle connection lost."""


def async_send_with_transport(
log_debug: bool,
transport: asyncio.DatagramTransport,
transport: _WrappedTransport,
packet: bytes,
packet_num: int,
out: DNSOutgoing,
addr: Optional[str],
port: int,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
s = transport.get_extra_info('socket')
ipv6_socket = s.family == socket.AF_INET6
ipv6_socket = transport.is_ipv6
if addr is None:
real_addr = _MDNS_ADDR6 if ipv6_socket else _MDNS_ADDR
else:
Expand All @@ -394,8 +434,8 @@ def async_send_with_transport(
'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...',
real_addr,
port or _MDNS_PORT,
s.fileno(),
transport.get_extra_info('sockname'),
transport.fileno,
transport.sock_name,
len(packet),
packet_num + 1,
out,
Expand All @@ -404,9 +444,9 @@ def async_send_with_transport(
# Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6
# address tuple: https://docs.python.org/3.6/library/socket.html#socket-families
if ipv6_socket and not v6_flow_scope:
_, _, sock_flowinfo, sock_scopeid = s.getsockname()
_, _, sock_flowinfo, sock_scopeid = transport.sock_name
v6_flow_scope = (sock_flowinfo, sock_scopeid)
transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope))
transport.transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope))


class Zeroconf(QuietLogger):
Expand Down Expand Up @@ -832,7 +872,7 @@ def handle_assembled_query(
packets: List[DNSIncoming],
addr: str,
port: int,
transport: asyncio.DatagramTransport,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
"""Respond to a (re)assembled query.
Expand Down Expand Up @@ -870,7 +910,7 @@ def send(
addr: Optional[str] = None,
port: int = _MDNS_PORT,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
transport: Optional[asyncio.DatagramTransport] = None,
transport: Optional[_WrappedTransport] = None,
) -> None:
"""Sends an outgoing packet threadsafe."""
assert self.loop is not None
Expand All @@ -882,7 +922,7 @@ def async_send(
addr: Optional[str] = None,
port: int = _MDNS_PORT,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
transport: Optional[asyncio.DatagramTransport] = None,
transport: Optional[_WrappedTransport] = None,
) -> None:
"""Sends an outgoing packet."""
if self.done:
Expand Down

0 comments on commit c4077dd

Please sign in to comment.