Skip to content

Commit

Permalink
feat: speed up processing incoming data (#1167)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed May 1, 2023
1 parent 1431517 commit fbaaf7b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 20 deletions.
7 changes: 1 addition & 6 deletions src/zeroconf/_protocol/incoming.pxd
Expand Up @@ -54,7 +54,7 @@ cdef class DNSIncoming:
cdef unsigned int _data_len
cdef public cython.dict name_cache
cdef public cython.list questions
cdef object _answers
cdef cython.list _answers
cdef public object id
cdef public cython.uint num_questions
cdef public cython.uint num_answers
Expand All @@ -78,8 +78,6 @@ cdef class DNSIncoming:

cdef _initial_parse(self)

cdef _unpack(self, object unpacker, object length)

@cython.locals(
end=cython.uint,
length=cython.uint
Expand All @@ -88,9 +86,6 @@ cdef class DNSIncoming:

cdef _read_questions(self)

@cython.locals(
length=cython.uint
)
cdef bytes _read_character_string(self)

cdef _read_string(self, unsigned int length)
Expand Down
32 changes: 18 additions & 14 deletions src/zeroconf/_protocol/incoming.py
Expand Up @@ -22,7 +22,7 @@

import struct
import sys
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from .._dns import (
DNSAddress,
Expand Down Expand Up @@ -194,10 +194,6 @@ def __repr__(self) -> str:
]
)

def _unpack(self, unpacker: Callable[[bytes, int], tuple], length: int) -> tuple:
self.offset += length
return unpacker(self.data, self.offset - length)

def _read_header(self) -> None:
"""Reads header portion of packet"""
(
Expand All @@ -207,7 +203,8 @@ def _read_header(self) -> None:
self.num_answers,
self.num_authorities,
self.num_additionals,
) = self._unpack(UNPACK_6H, 12)
) = UNPACK_6H(self.data)
self.offset += 12

def _read_questions(self) -> None:
"""Reads questions section of packet"""
Expand Down Expand Up @@ -264,18 +261,24 @@ def _read_record(
) -> Optional[DNSRecord]:
"""Read known records types and skip unknown ones."""
if type_ == _TYPE_A:
return DNSAddress(domain, type_, class_, ttl, self._read_string(4), created=self.now)
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(4))
dns_address.created = self.now
return dns_address
if type_ in (_TYPE_CNAME, _TYPE_PTR):
return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now)
if type_ == _TYPE_TXT:
return DNSText(domain, type_, class_, ttl, self._read_string(length), self.now)
if type_ == _TYPE_SRV:
priority, weight, port = UNPACK_3H(self.data, self.offset)
self.offset += 6
return DNSService(
domain,
type_,
class_,
ttl,
*cast(Tuple[int, int, int], self._unpack(UNPACK_3H, 6)),
priority,
weight,
port,
self._read_name(),
self.now,
)
Expand All @@ -285,14 +288,15 @@ def _read_record(
type_,
class_,
ttl,
self._read_character_string().decode('utf-8'),
self._read_character_string().decode('utf-8'),
self._read_character_string().decode('utf-8', 'replace'),
self._read_character_string().decode('utf-8', 'replace'),
self.now,
)
if type_ == _TYPE_AAAA:
return DNSAddress(
domain, type_, class_, ttl, self._read_string(16), created=self.now, scope_id=self.scope_id
)
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(16))
dns_address.created = self.now
dns_address.scope_id = self.scope_id
return dns_address
if type_ == _TYPE_NSEC:
name_start = self.offset
return DNSNsec(
Expand Down Expand Up @@ -384,4 +388,4 @@ def _decode_labels_at_offset(self, off: int, labels: List[str], seen_pointers: S
)
return off + DNS_COMPRESSION_POINTER_LEN

raise IncomingDecodeError("Corrupt packet received while decoding name from {self.source}")
raise IncomingDecodeError(f"Corrupt packet received while decoding name from {self.source}")

0 comments on commit fbaaf7b

Please sign in to comment.