Permalink
Browse files

Merge dns-py3-6057

Author: exarkun
Reviewer: itamarst
Fixes: #6057

Port `twisted.names.dns` to Python 3.  As part of this, add some missing test coverage
and fix minor comparison issues.


git-svn-id: svn://svn.twistedmatrix.com/svn/Twisted/trunk@36014 bbbe8e31-12d6-0310-92fd-ac37d47ddeeb
  • Loading branch information...
1 parent ea5a005 commit a327460722a0823333b20dadea949eb7927d77e4 @exarkun exarkun committed Oct 9, 2012
View
4 admin/_twistedpython3.py
@@ -34,6 +34,9 @@
"twisted.internet.test.modulehelpers",
"twisted.internet.test.reactormixins",
"twisted.internet._utilspy3",
+ "twisted.names",
+ "twisted.names.dns",
+ "twisted.names.test",
"twisted.protocols",
"twisted.protocols.basic",
"twisted.protocols.test",
@@ -93,6 +96,7 @@
"twisted.internet.test.test_udp",
"twisted.internet.test.test_udp_internals",
"twisted.internet.test.test_utilspy3",
+ "twisted.names.test.test_dns",
"twisted.protocols.test.test_basic",
"twisted.python.test.test_components",
"twisted.python.test.test_deprecate",
View
10 doc/core/howto/python3.xhtml
@@ -52,5 +52,15 @@
string literals rather than "native" string literals (which are text on
Python 3).</p>
+ <p><code>twisted.names.dns</code> deals with strings with a wide range of
+ meanings, often several for each DNS record type. Most of these strings
+ have remained as byte strings, which will probably require application
+ updates (for the reason given in the <code>FilePath</code> section above).
+ Some strings have changed to text strings, though. Any string representing
+ a human readable address (for
+ example, <code>Record_A</code>'s <code>address</code> parameter) is now a
+ text string. Additionally, time-to-live (ttl) values given as strings must
+ now be given as text strings.</p>
+
</body>
</html>
View
291 twisted/names/dns.py
@@ -7,11 +7,10 @@
Future Plans:
- Get rid of some toplevels, maybe.
-
-@author: Moshe Zadka
-@author: Jean-Paul Calderone
"""
+from __future__ import division, absolute_import
+
__all__ = [
'IEncodable', 'IRecord',
@@ -47,20 +46,61 @@
import warnings
import struct, random, types, socket
+from itertools import chain
-import cStringIO as StringIO
+from io import BytesIO
AF_INET6 = socket.AF_INET6
-from zope.interface import implements, Interface, Attribute
+from zope.interface import implementer, Interface, Attribute
# Twisted imports
from twisted.internet import protocol, defer
from twisted.internet.error import CannotListenError
from twisted.python import log, failure
-from twisted.python import util as tputil
+from twisted.python import _utilpy3 as tputil
from twisted.python import randbytes
+from twisted.python.compat import _PY3, unicode, comparable, cmp, nativeString
+
+
+if _PY3:
+ def _ord2bytes(ordinal):
+ """
+ Construct a bytes object representing a single byte with the given
+ ordinal value.
+
+ @type ordinal: C{int}
+ @rtype: C{bytes}
+ """
+ return bytes([ordinal])
+
+
+ def _nicebytes(bytes):
+ """
+ Represent a mostly textful bytes object in a way suitable for presentation
+ to an end user.
+
+ @param bytes: The bytes to represent.
+ @rtype: C{str}
+ """
+ return repr(bytes)[1:]
+
+
+ def _nicebyteslist(list):
+ """
+ Represent a list of mostly textful bytes objects in a way suitable for
+ presentation to an end user.
+
+ @param list: The list of bytes to represent.
+ @rtype: C{str}
+ """
+ return '[%s]' % (
+ ', '.join([_nicebytes(b) for b in list]),)
+else:
+ _ord2bytes = chr
+ _nicebytes = _nicebyteslist = repr
+
def randomSource():
@@ -123,7 +163,7 @@ def randomSource():
}
REV_TYPES = dict([
- (v, k) for (k, v) in QUERY_TYPES.items() + EXT_QUERIES.items()
+ (v, k) for (k, v) in chain(QUERY_TYPES.items(), EXT_QUERIES.items())
])
IN, CS, CH, HS = range(1, 5)
@@ -165,19 +205,34 @@ class IRecord(Interface):
def str2time(s):
+ """
+ Parse a string description of an interval into an integer number of seconds.
+
+ @param s: An interval definition constructed as an interval duration
+ followed by an interval unit. An interval duration is a base ten
+ representation of an integer. An interval unit is one of the following
+ letters: S (seconds), M (minutes), H (hours), D (days), W (weeks), or Y
+ (years). For example: C{"3S"} indicates an interval of three seconds;
+ C{"5D"} indicates an interval of five days. Alternatively, C{s} may be
+ any non-string and it will be returned unmodified.
+ @type s: text string (C{str}) for parsing; anything else for passthrough.
+
+ @return: an C{int} giving the interval represented by the string C{s}, or
+ whatever C{s} is if it is not a string.
+ """
suffixes = (
('S', 1), ('M', 60), ('H', 60 * 60), ('D', 60 * 60 * 24),
('W', 60 * 60 * 24 * 7), ('Y', 60 * 60 * 24 * 365)
)
- if isinstance(s, types.StringType):
+ if isinstance(s, str):
s = s.upper().strip()
for (suff, mult) in suffixes:
if s.endswith(suff):
return int(float(s[:-1]) * mult)
try:
s = int(s)
except ValueError:
- raise ValueError, "Invalid time interval specifier: " + s
+ raise ValueError("Invalid time interval specifier: " + s)
return s
@@ -225,12 +280,12 @@ def decode(strio, length = None):
+@implementer(IEncodable)
class Charstr(object):
- implements(IEncodable)
- def __init__(self, string=''):
- if not isinstance(string, str):
- raise ValueError("%r is not a string" % (string,))
+ def __init__(self, string=b''):
+ if not isinstance(string, bytes):
+ raise ValueError("%r is not a byte string" % (string,))
self.string = string
@@ -244,13 +299,13 @@ def encode(self, strio, compDict=None):
"""
string = self.string
ind = len(string)
- strio.write(chr(ind))
+ strio.write(_ord2bytes(ind))
strio.write(string)
def decode(self, strio, length=None):
"""
- Decode a byte string into this Name.
+ Decode a byte string into this Charstr.
@type strio: file
@param strio: Bytes will be read from this file until the full string
@@ -259,33 +314,50 @@ def decode(self, strio, length=None):
@raise EOFError: Raised when there are not enough bytes available from
C{strio}.
"""
- self.string = ''
+ self.string = b''
l = ord(readPrecisely(strio, 1))
self.string = readPrecisely(strio, l)
def __eq__(self, other):
if isinstance(other, Charstr):
return self.string == other.string
- return False
+ return NotImplemented
+
+
+ def __ne__(self, other):
+ if isinstance(other, Charstr):
+ return self.string != other.string
+ return NotImplemented
def __hash__(self):
return hash(self.string)
def __str__(self):
- return self.string
+ """
+ Represent this L{Charstr} instance by its string value.
+ """
+ return nativeString(self.string)
+@implementer(IEncodable)
class Name:
- implements(IEncodable)
+ """
+ A name in the domain name system, made up of multiple labels. For example,
+ I{twistedmatrix.com}.
- def __init__(self, name=''):
- assert isinstance(name, types.StringTypes), "%r is not a string" % (name,)
+ @ivar name: A byte string giving the name.
+ @type name: C{bytes}
+ """
+ def __init__(self, name=b''):
+ if not isinstance(name, bytes):
+ raise TypeError("%r is not a byte string" % (name,))
self.name = name
+
def encode(self, strio, compDict=None):
"""
Encode this Name into the appropriate byte format.
@@ -308,15 +380,17 @@ def encode(self, strio, compDict=None):
return
else:
compDict[name] = strio.tell() + Message.headerSize
- ind = name.find('.')
+ ind = name.find(b'.')
if ind > 0:
label, name = name[:ind], name[ind + 1:]
else:
- label, name = name, ''
+ # This is the last label, end the loop after handling it.
+ label = name
+ name = None
ind = len(label)
- strio.write(chr(ind))
+ strio.write(_ord2bytes(ind))
strio.write(label)
- strio.write(chr(0))
+ strio.write(b'\x00')
def decode(self, strio, length=None):
@@ -334,7 +408,7 @@ def decode(self, strio, length=None):
because it contains a loop).
"""
visited = set()
- self.name = ''
+ self.name = b''
off = 0
while 1:
l = ord(readPrecisely(strio, 1))
@@ -353,24 +427,37 @@ def decode(self, strio, length=None):
strio.seek(new_off)
continue
label = readPrecisely(strio, l)
- if self.name == '':
+ if self.name == b'':
self.name = label
else:
- self.name = self.name + '.' + label
+ self.name = self.name + b'.' + label
def __eq__(self, other):
if isinstance(other, Name):
- return str(self) == str(other)
- return 0
+ return self.name == other.name
+ return NotImplemented
+
+
+ def __ne__(self, other):
+ if isinstance(other, Name):
+ return self.name != other.name
+ return NotImplemented
def __hash__(self):
- return hash(str(self))
+ return hash(self.name)
def __str__(self):
- return self.name
+ """
+ Represent this L{Name} instance by its string name.
+ """
+ return nativeString(self.name)
+
+
+@comparable
+@implementer(IEncodable)
class Query:
"""
Represent a single DNS query.
@@ -379,16 +466,13 @@ class Query:
@ivar type: The query type.
@ivar cls: The query class.
"""
-
- implements(IEncodable)
-
name = None
type = None
cls = None
- def __init__(self, name='', type=A, cls=IN):
+ def __init__(self, name=b'', type=A, cls=IN):
"""
- @type name: C{str}
+ @type name: C{bytes}
@param name: The name about which to request information.
@type type: C{int}
@@ -418,10 +502,11 @@ def __hash__(self):
def __cmp__(self, other):
- return isinstance(other, Query) and cmp(
- (str(self.name).lower(), self.type, self.cls),
- (str(other.name).lower(), other.type, other.cls)
- ) or cmp(self.__class__, other.__class__)
+ if isinstance(other, Query):
+ return cmp(
+ (str(self.name).lower(), self.type, self.cls),
+ (str(other.name).lower(), other.type, other.cls))
+ return NotImplemented
def __str__(self):
@@ -434,6 +519,8 @@ def __repr__(self):
return 'Query(%r, %r, %r)' % (str(self.name), self.type, self.cls)
+
+@implementer(IEncodable)
class RRHeader(tputil.FancyEqMixin):
"""
A resource record header.
@@ -449,9 +536,6 @@ class RRHeader(tputil.FancyEqMixin):
@ivar auth: A C{bool} indicating whether this C{RRHeader} was parsed from an
authoritative message.
"""
-
- implements(IEncodable)
-
compareAttributes = ('name', 'type', 'cls', 'ttl', 'payload', 'auth')
fmt = "!HHIH"
@@ -465,9 +549,9 @@ class RRHeader(tputil.FancyEqMixin):
cachedResponse = None
- def __init__(self, name='', type=A, cls=IN, ttl=0, payload=None, auth=False):
+ def __init__(self, name=b'', type=A, cls=IN, ttl=0, payload=None, auth=False):
"""
- @type name: C{str}
+ @type name: C{bytes}
@param name: The name about which this reply contains information.
@type type: C{int}
@@ -531,6 +615,7 @@ def __str__(self):
+@implementer(IEncodable, IRecord)
class SimpleRecord(tputil.FancyStrMixin, tputil.FancyEqMixin):
"""
A Resource Record which consists of a single RFC 1035 domain-name.
@@ -542,15 +627,13 @@ class SimpleRecord(tputil.FancyStrMixin, tputil.FancyEqMixin):
@ivar ttl: The maximum number of seconds which this record should be
cached.
"""
- implements(IEncodable, IRecord)
-
showAttributes = (('name', 'name', '%s'), 'ttl')
compareAttributes = ('name', 'ttl')
TYPE = None
name = None
- def __init__(self, name='', ttl=None):
+ def __init__(self, name=b'', ttl=None):
self.name = Name(name)
self.ttl = str2time(ttl)
@@ -671,6 +754,7 @@ class Record_DNAME(SimpleRecord):
+@implementer(IEncodable, IRecord)
class Record_A(tputil.FancyEqMixin):
"""
An IPv4 host address.
@@ -683,8 +767,6 @@ class Record_A(tputil.FancyEqMixin):
@ivar ttl: The maximum number of seconds which this record should be
cached.
"""
- implements(IEncodable, IRecord)
-
compareAttributes = ('address', 'ttl')
TYPE = A
@@ -718,6 +800,7 @@ def dottedQuad(self):
+@implementer(IEncodable, IRecord)
class Record_SOA(tputil.FancyEqMixin, tputil.FancyStrMixin):
"""
Marks the start of a zone of authority.
@@ -757,15 +840,14 @@ class Record_SOA(tputil.FancyEqMixin, tputil.FancyStrMixin):
@type ttl: C{int}
@ivar ttl: The default TTL to use for records served from this zone.
"""
- implements(IEncodable, IRecord)
-
fancybasename = 'SOA'
compareAttributes = ('serial', 'mname', 'rname', 'refresh', 'expire', 'retry', 'minimum', 'ttl')
showAttributes = (('mname', 'mname', '%s'), ('rname', 'rname', '%s'), 'serial', 'refresh', 'retry', 'expire', 'minimum', 'ttl')
TYPE = SOA
- def __init__(self, mname='', rname='', serial=0, refresh=0, retry=0, expire=0, minimum=0, ttl=None):
+ def __init__(self, mname=b'', rname=b'', serial=0, refresh=0, retry=0,
+ expire=0, minimum=0, ttl=None):
self.mname, self.rname = Name(mname), Name(rname)
self.serial, self.refresh = str2time(serial), str2time(refresh)
self.minimum, self.expire = str2time(minimum), str2time(expire)
@@ -801,6 +883,7 @@ def __hash__(self):
+@implementer(IEncodable, IRecord)
class Record_NULL(tputil.FancyStrMixin, tputil.FancyEqMixin):
"""
A null record.
@@ -811,10 +894,9 @@ class Record_NULL(tputil.FancyStrMixin, tputil.FancyEqMixin):
@ivar ttl: The maximum number of seconds which this record should be
cached.
"""
- implements(IEncodable, IRecord)
-
fancybasename = 'NULL'
- showAttributes = compareAttributes = ('payload', 'ttl')
+ showAttributes = (('payload', _nicebytes), 'ttl')
+ compareAttributes = ('payload', 'ttl')
TYPE = NULL
@@ -836,6 +918,7 @@ def __hash__(self):
+@implementer(IEncodable, IRecord)
class Record_WKS(tputil.FancyEqMixin, tputil.FancyStrMixin):
"""
A well known service description.
@@ -858,8 +941,6 @@ class Record_WKS(tputil.FancyEqMixin, tputil.FancyStrMixin):
@ivar ttl: The maximum number of seconds which this record should be
cached.
"""
- implements(IEncodable, IRecord)
-
fancybasename = "WKS"
compareAttributes = ('address', 'protocol', 'map', 'ttl')
showAttributes = [('_address', 'address', '%s'), 'protocol', 'ttl']
@@ -891,6 +972,7 @@ def __hash__(self):
+@implementer(IEncodable, IRecord)
class Record_AAAA(tputil.FancyEqMixin, tputil.FancyStrMixin):
"""
An IPv6 host address.
@@ -905,7 +987,6 @@ class Record_AAAA(tputil.FancyEqMixin, tputil.FancyStrMixin):
@see: U{http://www.faqs.org/rfcs/rfc1886.html}
"""
- implements(IEncodable, IRecord)
TYPE = AAAA
fancybasename = 'AAAA'
@@ -914,7 +995,7 @@ class Record_AAAA(tputil.FancyEqMixin, tputil.FancyStrMixin):
_address = property(lambda self: socket.inet_ntop(AF_INET6, self.address))
- def __init__(self, address = '::', ttl=None):
+ def __init__(self, address='::', ttl=None):
self.address = socket.inet_pton(AF_INET6, address)
self.ttl = str2time(ttl)
@@ -932,6 +1013,7 @@ def __hash__(self):
+@implementer(IEncodable, IRecord)
class Record_A6(tputil.FancyStrMixin, tputil.FancyEqMixin):
"""
An IPv6 address.
@@ -959,7 +1041,6 @@ class Record_A6(tputil.FancyStrMixin, tputil.FancyEqMixin):
@see: U{http://www.faqs.org/rfcs/rfc3363.html}
@see: U{http://www.faqs.org/rfcs/rfc3364.html}
"""
- implements(IEncodable, IRecord)
TYPE = A6
fancybasename = 'A6'
@@ -968,7 +1049,7 @@ class Record_A6(tputil.FancyStrMixin, tputil.FancyEqMixin):
_suffix = property(lambda self: socket.inet_ntop(AF_INET6, self.suffix))
- def __init__(self, prefixLen=0, suffix='::', prefix='', ttl=None):
+ def __init__(self, prefixLen=0, suffix='::', prefix=b'', ttl=None):
self.prefixLen = prefixLen
self.suffix = socket.inet_pton(AF_INET6, suffix)
self.prefix = Name(prefix)
@@ -989,7 +1070,7 @@ def decode(self, strio, length = None):
self.prefixLen = struct.unpack('!B', readPrecisely(strio, 1))[0]
self.bytes = int((128 - self.prefixLen) / 8.0)
if self.bytes:
- self.suffix = '\x00' * (16 - self.bytes) + readPrecisely(strio, self.bytes)
+ self.suffix = b'\x00' * (16 - self.bytes) + readPrecisely(strio, self.bytes)
if self.prefixLen:
self.prefix.decode(strio)
@@ -1016,6 +1097,7 @@ def __str__(self):
+@implementer(IEncodable, IRecord)
class Record_SRV(tputil.FancyEqMixin, tputil.FancyStrMixin):
"""
The location of the server(s) for a specific protocol and domain.
@@ -1050,14 +1132,13 @@ class Record_SRV(tputil.FancyEqMixin, tputil.FancyStrMixin):
@see: U{http://www.faqs.org/rfcs/rfc2782.html}
"""
- implements(IEncodable, IRecord)
TYPE = SRV
fancybasename = 'SRV'
compareAttributes = ('priority', 'weight', 'target', 'port', 'ttl')
showAttributes = ('priority', 'weight', ('target', 'target', '%s'), 'port', 'ttl')
- def __init__(self, priority=0, weight=0, port=0, target='', ttl=None):
+ def __init__(self, priority=0, weight=0, port=0, target=b'', ttl=None):
self.priority = int(priority)
self.weight = int(weight)
self.port = int(port)
@@ -1083,6 +1164,7 @@ def __hash__(self):
+@implementer(IEncodable, IRecord)
class Record_NAPTR(tputil.FancyEqMixin, tputil.FancyStrMixin):
"""
The location of the server(s) for a specific protocol and domain.
@@ -1100,7 +1182,7 @@ class Record_NAPTR(tputil.FancyEqMixin, tputil.FancyStrMixin):
@type flag: L{Charstr}
@ivar flag: A <character-string> containing flags to control aspects of the
rewriting and interpretation of the fields in the record. Flags
- aresingle characters from the set [A-Z0-9]. The case of the alphabetic
+ are single characters from the set [A-Z0-9]. The case of the alphabetic
characters is not significant.
At this time only four flags, "S", "A", "U", and "P", are defined.
@@ -1127,18 +1209,18 @@ class Record_NAPTR(tputil.FancyEqMixin, tputil.FancyStrMixin):
@see: U{http://www.faqs.org/rfcs/rfc2915.html}
"""
- implements(IEncodable, IRecord)
TYPE = NAPTR
compareAttributes = ('order', 'preference', 'flags', 'service', 'regexp',
'replacement')
fancybasename = 'NAPTR'
+
showAttributes = ('order', 'preference', ('flags', 'flags', '%s'),
('service', 'service', '%s'), ('regexp', 'regexp', '%s'),
('replacement', 'replacement', '%s'), 'ttl')
- def __init__(self, order=0, preference=0, flags='', service='', regexp='',
- replacement='', ttl=None):
+ def __init__(self, order=0, preference=0, flags=b'', service=b'', regexp=b'',
+ replacement=b'', ttl=None):
self.order = int(order)
self.preference = int(preference)
self.flags = Charstr(flags)
@@ -1177,6 +1259,7 @@ def __hash__(self):
+@implementer(IEncodable, IRecord)
class Record_AFSDB(tputil.FancyStrMixin, tputil.FancyEqMixin):
"""
Map from a domain name to the name of an AFS cell database server.
@@ -1197,14 +1280,13 @@ class Record_AFSDB(tputil.FancyStrMixin, tputil.FancyEqMixin):
@see: U{http://www.faqs.org/rfcs/rfc1183.html}
"""
- implements(IEncodable, IRecord)
TYPE = AFSDB
fancybasename = 'AFSDB'
compareAttributes = ('subtype', 'hostname', 'ttl')
showAttributes = ('subtype', ('hostname', 'hostname', '%s'), 'ttl')
- def __init__(self, subtype=0, hostname='', ttl=None):
+ def __init__(self, subtype=0, hostname=b'', ttl=None):
self.subtype = int(subtype)
self.hostname = Name(hostname)
self.ttl = str2time(ttl)
@@ -1226,6 +1308,7 @@ def __hash__(self):
+@implementer(IEncodable, IRecord)
class Record_RP(tputil.FancyEqMixin, tputil.FancyStrMixin):
"""
The responsible person for a domain.
@@ -1244,14 +1327,13 @@ class Record_RP(tputil.FancyEqMixin, tputil.FancyStrMixin):
@see: U{http://www.faqs.org/rfcs/rfc1183.html}
"""
- implements(IEncodable, IRecord)
TYPE = RP
fancybasename = 'RP'
compareAttributes = ('mbox', 'txt', 'ttl')
showAttributes = (('mbox', 'mbox', '%s'), ('txt', 'txt', '%s'), 'ttl')
- def __init__(self, mbox='', txt='', ttl=None):
+ def __init__(self, mbox=b'', txt=b'', ttl=None):
self.mbox = Name(mbox)
self.txt = Name(txt)
self.ttl = str2time(ttl)
@@ -1274,6 +1356,7 @@ def __hash__(self):
+@implementer(IEncodable, IRecord)
class Record_HINFO(tputil.FancyStrMixin, tputil.FancyEqMixin):
"""
Host information.
@@ -1288,11 +1371,11 @@ class Record_HINFO(tputil.FancyStrMixin, tputil.FancyEqMixin):
@ivar ttl: The maximum number of seconds which this record should be
cached.
"""
- implements(IEncodable, IRecord)
TYPE = HINFO
fancybasename = 'HINFO'
- showAttributes = compareAttributes = ('cpu', 'os', 'ttl')
+ showAttributes = (('cpu', _nicebytes), ('os', _nicebytes), 'ttl')
+ compareAttributes = ('cpu', 'os', 'ttl')
def __init__(self, cpu='', os='', ttl=None):
self.cpu, self.os = cpu, os
@@ -1324,6 +1407,7 @@ def __hash__(self):
+@implementer(IEncodable, IRecord)
class Record_MINFO(tputil.FancyEqMixin, tputil.FancyStrMixin):
"""
Mailbox or mail list information.
@@ -1345,7 +1429,6 @@ class Record_MINFO(tputil.FancyEqMixin, tputil.FancyStrMixin):
@ivar ttl: The maximum number of seconds which this record should be
cached.
"""
- implements(IEncodable, IRecord)
TYPE = MINFO
rmailbx = None
@@ -1357,7 +1440,7 @@ class Record_MINFO(tputil.FancyEqMixin, tputil.FancyStrMixin):
('emailbx', 'errors', '%s'),
'ttl')
- def __init__(self, rmailbx='', emailbx='', ttl=None):
+ def __init__(self, rmailbx=b'', emailbx=b'', ttl=None):
self.rmailbx, self.emailbx = Name(rmailbx), Name(emailbx)
self.ttl = str2time(ttl)
@@ -1378,6 +1461,7 @@ def __hash__(self):
+@implementer(IEncodable, IRecord)
class Record_MX(tputil.FancyStrMixin, tputil.FancyEqMixin):
"""
Mail exchange.
@@ -1394,14 +1478,13 @@ class Record_MX(tputil.FancyStrMixin, tputil.FancyEqMixin):
@ivar ttl: The maximum number of seconds which this record should be
cached.
"""
- implements(IEncodable, IRecord)
TYPE = MX
fancybasename = 'MX'
compareAttributes = ('preference', 'name', 'ttl')
showAttributes = ('preference', ('name', 'name', '%s'), 'ttl')
- def __init__(self, preference=0, name='', ttl=None, **kwargs):
+ def __init__(self, preference=0, name=b'', ttl=None, **kwargs):
self.preference, self.name = int(preference), Name(kwargs.get('exchange', name))
self.ttl = str2time(ttl)
@@ -1420,36 +1503,35 @@ def __hash__(self):
-# Oh god, Record_TXT how I hate thee.
+@implementer(IEncodable, IRecord)
class Record_TXT(tputil.FancyEqMixin, tputil.FancyStrMixin):
"""
Freeform text.
- @type data: C{list} of C{str}
+ @type data: C{list} of C{bytes}
@ivar data: Freeform text which makes up this record.
@type ttl: C{int}
@ivar ttl: The maximum number of seconds which this record should be cached.
"""
- implements(IEncodable, IRecord)
-
TYPE = TXT
fancybasename = 'TXT'
- showAttributes = compareAttributes = ('data', 'ttl')
+ showAttributes = (('data', _nicebyteslist), 'ttl')
+ compareAttributes = ('data', 'ttl')
def __init__(self, *data, **kw):
self.data = list(data)
# arg man python sucks so bad
self.ttl = str2time(kw.get('ttl', None))
- def encode(self, strio, compDict = None):
+ def encode(self, strio, compDict=None):
for d in self.data:
strio.write(struct.pack('!B', len(d)) + d)
- def decode(self, strio, length = None):
+ def decode(self, strio, length=None):
soFar = 0
self.data = []
while soFar < length:
@@ -1468,28 +1550,25 @@ def __hash__(self):
return hash(tuple(self.data))
-
-# This is a fallback record
+@implementer(IEncodable, IRecord)
class UnknownRecord(tputil.FancyEqMixin, tputil.FancyStrMixin, object):
"""
- Encapsulate the wire data for unkown record types so that they can
+ Encapsulate the wire data for unknown record types so that they can
pass through the system unchanged.
- @type data: C{str}
+ @type data: C{bytes}
@ivar data: Wire data which makes up this record.
@type ttl: C{int}
@ivar ttl: The maximum number of seconds which this record should be cached.
@since: 11.1
"""
- implements(IEncodable, IRecord)
-
fancybasename = 'UNKNOWN'
compareAttributes = ('data', 'ttl')
- showAttributes = ('data', 'ttl')
+ showAttributes = (('data', _nicebytes), 'ttl')
- def __init__(self, data='', ttl=None):
+ def __init__(self, data=b'', ttl=None):
self.data = data
self.ttl = str2time(ttl)
@@ -1565,7 +1644,7 @@ def addQuery(self, name, type=ALL_RECORDS, cls=IN):
"""
Add another query to this Message.
- @type name: C{str}
+ @type name: C{bytes}
@param name: The name to query.
@type type: C{int}
@@ -1579,7 +1658,7 @@ def addQuery(self, name, type=ALL_RECORDS, cls=IN):
def encode(self, strio):
compDict = {}
- body_tmp = StringIO.StringIO()
+ body_tmp = BytesIO()
for q in self.queries:
q.encode(body_tmp, compDict)
for q in self.answers:
@@ -1681,13 +1760,25 @@ def lookupRecordType(self, type):
def toStr(self):
- strio = StringIO.StringIO()
+ """
+ Encode this L{Message} into a byte string in the format described by RFC
+ 1035.
+
+ @rtype: C{bytes}
+ """
+ strio = BytesIO()
self.encode(strio)
return strio.getvalue()
def fromStr(self, str):
- strio = StringIO.StringIO(str)
+ """
+ Decode a byte string in the format described by RFC 1035 into this
+ L{Message}.
+
+ @param str: L{bytes}
+ """
+ strio = BytesIO(str)
self.decode(strio)
@@ -1882,7 +1973,7 @@ class DNSProtocol(DNSMixin, protocol.Protocol):
DNS protocol over TCP.
"""
length = None
- buffer = ''
+ buffer = b''
def writeMessage(self, message):
"""
View
825 twisted/names/test/test_dns.py
@@ -6,7 +6,9 @@
Tests for twisted.names.dns.
"""
-from cStringIO import StringIO
+from __future__ import division, absolute_import
+
+from io import BytesIO
import struct
@@ -27,30 +29,124 @@
dns.Record_AAAA, dns.Record_A6, dns.Record_NAPTR, dns.UnknownRecord,
]
+
+class Ord2ByteTests(unittest.TestCase):
+ """
+ Tests for L{dns._ord2bytes}.
+ """
+ def test_ord2byte(self):
+ """
+ L{dns._ord2byte} accepts an integer and returns a byte string of length
+ one with an ordinal value equal to the given integer.
+ """
+ self.assertEqual(b'\x10', dns._ord2bytes(0x10))
+
+
+
+class Str2TimeTests(unittest.TestCase):
+ """
+ Tests for L{dns.str2name}.
+ """
+ def test_nonString(self):
+ """
+ When passed a non-string object, L{dns.str2name} returns it unmodified.
+ """
+ time = object()
+ self.assertIdentical(time, dns.str2time(time))
+
+
+ def test_seconds(self):
+ """
+ Passed a string giving a number of seconds, L{dns.str2time} returns the
+ number of seconds represented. For example, C{"10S"} represents C{10}
+ seconds.
+ """
+ self.assertEqual(10, dns.str2time("10S"))
+
+
+ def test_minutes(self):
+ """
+ Like C{test_seconds}, but for the C{"M"} suffix which multiplies the
+ time value by C{60} (the number of seconds in a minute!).
+ """
+ self.assertEqual(2 * 60, dns.str2time("2M"))
+
+
+ def test_hours(self):
+ """
+ Like C{test_seconds}, but for the C{"H"} suffix which multiplies the
+ time value by C{3600}, the number of seconds in an hour.
+ """
+ self.assertEqual(3 * 3600, dns.str2time("3H"))
+
+
+ def test_days(self):
+ """
+ Like L{test_seconds}, but for the C{"D"} suffix which multiplies the
+ time value by C{86400}, the number of seconds in a day.
+ """
+ self.assertEqual(4 * 86400, dns.str2time("4D"))
+
+
+ def test_weeks(self):
+ """
+ Like L{test_seconds}, but for the C{"W"} suffix which multiplies the
+ time value by C{604800}, the number of seconds in a week.
+ """
+ self.assertEqual(5 * 604800, dns.str2time("5W"))
+
+
+ def test_years(self):
+ """
+ Like L{test_seconds}, but for the C{"Y"} suffix which multiplies the
+ time value by C{31536000}, the number of seconds in a year.
+ """
+ self.assertEqual(6 * 31536000, dns.str2time("6Y"))
+
+
+ def test_invalidPrefix(self):
+ """
+ If a non-integer prefix is given, L{dns.str2time} raises L{ValueError}.
+ """
+ self.assertRaises(ValueError, dns.str2time, "fooS")
+
+
+
class NameTests(unittest.TestCase):
"""
Tests for L{Name}, the representation of a single domain name with support
for encoding into and decoding from DNS message format.
"""
+ def test_nonStringName(self):
+ """
+ When constructed with a name which is neither C{bytes} nor C{str},
+ L{Name} raises L{TypeError}.
+ """
+ self.assertRaises(TypeError, dns.Name, 123)
+ self.assertRaises(TypeError, dns.Name, object())
+ self.assertRaises(TypeError, dns.Name, [])
+ self.assertRaises(TypeError, dns.Name, u"text")
+
+
def test_decode(self):
"""
L{Name.decode} populates the L{Name} instance with name information read
from the file-like object passed to it.
"""
n = dns.Name()
- n.decode(StringIO("\x07example\x03com\x00"))
- self.assertEqual(n.name, "example.com")
+ n.decode(BytesIO(b"\x07example\x03com\x00"))
+ self.assertEqual(n.name, b"example.com")
def test_encode(self):
"""
L{Name.encode} encodes its name information and writes it to the
file-like object passed to it.
"""
- name = dns.Name("foo.example.com")
- stream = StringIO()
+ name = dns.Name(b"foo.example.com")
+ stream = BytesIO()
name.encode(stream)
- self.assertEqual(stream.getvalue(), "\x03foo\x07example\x03com\x00")
+ self.assertEqual(stream.getvalue(), b"\x03foo\x07example\x03com\x00")
def test_encodeWithCompression(self):
@@ -61,23 +157,23 @@ def test_encodeWithCompression(self):
output. It also updates the compression dictionary with the location of
the name it writes to the stream.
"""
- name = dns.Name("foo.example.com")
- compression = {"example.com": 0x17}
+ name = dns.Name(b"foo.example.com")
+ compression = {b"example.com": 0x17}
# Some bytes already encoded into the stream for this message
- previous = "some prefix to change .tell()"
- stream = StringIO()
+ previous = b"some prefix to change .tell()"
+ stream = BytesIO()
stream.write(previous)
# The position at which the encoded form of this new name will appear in
# the stream.
expected = len(previous) + dns.Message.headerSize
name.encode(stream, compression)
self.assertEqual(
- "\x03foo\xc0\x17",
+ b"\x03foo\xc0\x17",
stream.getvalue()[len(previous):])
self.assertEqual(
- {"example.com": 0x17, "foo.example.com": expected},
+ {b"example.com": 0x17, b"foo.example.com": expected},
compression)
@@ -89,50 +185,49 @@ def test_unknown(self):
was parsed from.
"""
wire = (
- '\x01\x00' # Message ID
- '\x00' # answer bit, opCode nibble, auth bit, trunc bit, recursive
- # bit
- '\x00' # recursion bit, empty bit, empty bit, empty bit, response
- # code nibble
- '\x00\x01' # number of queries
- '\x00\x01' # number of answers
- '\x00\x00' # number of authorities
- '\x00\x01' # number of additionals
+ b'\x01\x00' # Message ID
+ b'\x00' # answer bit, opCode nibble, auth bit, trunc bit, recursive
+ # bit
+ b'\x00' # recursion bit, empty bit, empty bit, empty bit, response
+ # code nibble
+ b'\x00\x01' # number of queries
+ b'\x00\x01' # number of answers
+ b'\x00\x00' # number of authorities
+ b'\x00\x01' # number of additionals
# query
- '\x03foo\x03bar\x00' # foo.bar
- '\xde\xad' # type=0xdead
- '\xbe\xef' # cls=0xbeef
+ b'\x03foo\x03bar\x00' # foo.bar
+ b'\xde\xad' # type=0xdead
+ b'\xbe\xef' # cls=0xbeef
# 1st answer
- '\xc0\x0c' # foo.bar - compressed
- '\xde\xad' # type=0xdead
- '\xbe\xef' # cls=0xbeef
- '\x00\x00\x01\x01' # ttl=257
- '\x00\x08somedata' # some payload data
+ b'\xc0\x0c' # foo.bar - compressed
+ b'\xde\xad' # type=0xdead
+ b'\xbe\xef' # cls=0xbeef
+ b'\x00\x00\x01\x01' # ttl=257
+ b'\x00\x08somedata' # some payload data
# 1st additional
- '\x03baz\x03ban\x00' # baz.ban
- '\x00\x01' # type=A
- '\x00\x01' # cls=IN
- '\x00\x00\x01\x01' # ttl=257
- '\x00\x04' # len=4
- '\x01\x02\x03\x04' # 1.2.3.4
-
+ b'\x03baz\x03ban\x00' # baz.ban
+ b'\x00\x01' # type=A
+ b'\x00\x01' # cls=IN
+ b'\x00\x00\x01\x01' # ttl=257
+ b'\x00\x04' # len=4
+ b'\x01\x02\x03\x04' # 1.2.3.4
)
msg = dns.Message()
msg.fromStr(wire)
self.assertEqual(msg.queries, [
- dns.Query('foo.bar', type=0xdead, cls=0xbeef),
+ dns.Query(b'foo.bar', type=0xdead, cls=0xbeef),
])
self.assertEqual(msg.answers, [
- dns.RRHeader('foo.bar', type=0xdead, cls=0xbeef, ttl=257,
- payload=dns.UnknownRecord('somedata', ttl=257)),
+ dns.RRHeader(b'foo.bar', type=0xdead, cls=0xbeef, ttl=257,
+ payload=dns.UnknownRecord(b'somedata', ttl=257)),
])
self.assertEqual(msg.additional, [
- dns.RRHeader('baz.ban', type=dns.A, cls=dns.IN, ttl=257,
+ dns.RRHeader(b'baz.ban', type=dns.A, cls=dns.IN, ttl=257,
payload=dns.Record_A('1.2.3.4', ttl=257)),
])
@@ -149,27 +244,27 @@ def test_decodeWithCompression(self):
included in the name being decoded.
"""
# Slightly modified version of the example from RFC 1035, section 4.1.4.
- stream = StringIO(
- "x" * 20 +
- "\x01f\x03isi\x04arpa\x00"
- "\x03foo\xc0\x14"
- "\x03bar\xc0\x20")
+ stream = BytesIO(
+ b"x" * 20 +
+ b"\x01f\x03isi\x04arpa\x00"
+ b"\x03foo\xc0\x14"
+ b"\x03bar\xc0\x20")
stream.seek(20)
name = dns.Name()
name.decode(stream)
# Verify we found the first name in the stream and that the stream
# position is left at the first byte after the decoded name.
- self.assertEqual("f.isi.arpa", name.name)
+ self.assertEqual(b"f.isi.arpa", name.name)
self.assertEqual(32, stream.tell())
# Get the second name from the stream and make the same assertions.
name.decode(stream)
- self.assertEqual(name.name, "foo.f.isi.arpa")
+ self.assertEqual(name.name, b"foo.f.isi.arpa")
self.assertEqual(38, stream.tell())
# Get the third and final name
name.decode(stream)
- self.assertEqual(name.name, "bar.foo.f.isi.arpa")
+ self.assertEqual(name.name, b"bar.foo.f.isi.arpa")
self.assertEqual(44, stream.tell())
@@ -180,20 +275,22 @@ def test_rejectCompressionLoop(self):
undecodable.
"""
name = dns.Name()
- stream = StringIO("\xc0\x00")
+ stream = BytesIO(b"\xc0\x00")
self.assertRaises(ValueError, name.decode, stream)
class RoundtripDNSTestCase(unittest.TestCase):
- """Encoding and then decoding various objects."""
+ """
+ Encoding and then decoding various objects.
+ """
- names = ["example.org", "go-away.fish.tv", "23strikesback.net"]
+ names = [b"example.org", b"go-away.fish.tv", b"23strikesback.net"]
def testName(self):
for n in self.names:
# encode the name
- f = StringIO()
+ f = BytesIO()
dns.Name(n).encode(f)
# decode the name
@@ -202,12 +299,17 @@ def testName(self):
result.decode(f)
self.assertEqual(result.name, n)
- def testQuery(self):
+ def test_query(self):
+ """
+ L{dns.Query.encode} returns a byte string representing the fields of the
+ query which can be decoded into a new L{dns.Query} instance using
+ L{dns.Query.decode}.
+ """
for n in self.names:
for dnstype in range(1, 17):
for dnscls in range(1, 5):
# encode the query
- f = StringIO()
+ f = BytesIO()
dns.Query(n, dnstype, dnscls).encode(f)
# decode the result
@@ -218,36 +320,49 @@ def testQuery(self):
self.assertEqual(result.type, dnstype)
self.assertEqual(result.cls, dnscls)
- def testRR(self):
+ def test_resourceRecordHeader(self):
+ """
+ L{dns.RRHeader.encode} encodes the record header's information and
+ writes it to the file-like object passed to it and
+ L{dns.RRHeader.decode} reads from a file-like object to re-construct a
+ L{dns.RRHeader} instance.
+ """
# encode the RR
- f = StringIO()
- dns.RRHeader("test.org", 3, 4, 17).encode(f)
+ f = BytesIO()
+ dns.RRHeader(b"test.org", 3, 4, 17).encode(f)
# decode the result
f.seek(0, 0)
result = dns.RRHeader()
result.decode(f)
- self.assertEqual(str(result.name), "test.org")
+ self.assertEqual(result.name, dns.Name(b"test.org"))
self.assertEqual(result.type, 3)
self.assertEqual(result.cls, 4)
self.assertEqual(result.ttl, 17)
- def testResources(self):
+ def test_resources(self):
+ """
+ L{dns.SimpleRecord.encode} encodes the record's name information and
+ writes it to the file-like object passed to it and
+ L{dns.SimpleRecord.decode} reads from a file-like object to re-construct
+ a L{dns.SimpleRecord} instance.
+ """
names = (
- "this.are.test.name",
- "will.compress.will.this.will.name.will.hopefully",
- "test.CASE.preSErVatIOn.YeAH",
- "a.s.h.o.r.t.c.a.s.e.t.o.t.e.s.t",
- "singleton"
+ b"this.are.test.name",
+ b"will.compress.will.this.will.name.will.hopefully",
+ b"test.CASE.preSErVatIOn.YeAH",
+ b"a.s.h.o.r.t.c.a.s.e.t.o.t.e.s.t",
+ b"singleton"
)
for s in names:
- f = StringIO()
+ f = BytesIO()
dns.SimpleRecord(s).encode(f)
f.seek(0, 0)
result = dns.SimpleRecord()
result.decode(f)
- self.assertEqual(str(result.name), s)
+ self.assertEqual(result.name, dns.Name(s))
+
def test_hashable(self):
"""
@@ -266,7 +381,7 @@ def test_Charstr(self):
"""
for n in self.names:
# encode the name
- f = StringIO()
+ f = BytesIO()
dns.Charstr(n).encode(f)
# decode the name
@@ -276,21 +391,104 @@ def test_Charstr(self):
self.assertEqual(result.string, n)
+ def _recordRoundtripTest(self, record):
+ """
+ Assert that encoding C{record} and then decoding the resulting bytes
+ creates a record which compares equal to C{record}.
+ """
+ stream = BytesIO()
+ record.encode(stream)
+
+ length = stream.tell()
+ stream.seek(0, 0)
+ replica = record.__class__()
+ replica.decode(stream, length)
+ self.assertEqual(record, replica)
+
+
+ def test_SOA(self):
+ """
+ The byte stream written by L{dns.Record_SOA.encode} can be used by
+ L{dns.Record_SOA.decode} to reconstruct the state of the original
+ L{dns.Record_SOA} instance.
+ """
+ self._recordRoundtripTest(
+ dns.Record_SOA(
+ mname=b'foo', rname=b'bar', serial=12, refresh=34,
+ retry=56, expire=78, minimum=90))
+
+
+ def test_A(self):
+ """
+ The byte stream written by L{dns.Record_A.encode} can be used by
+ L{dns.Record_A.decode} to reconstruct the state of the original
+ L{dns.Record_A} instance.
+ """
+ self._recordRoundtripTest(dns.Record_A('1.2.3.4'))
+
+
+ def test_NULL(self):
+ """
+ The byte stream written by L{dns.Record_NULL.encode} can be used by
+ L{dns.Record_NULL.decode} to reconstruct the state of the original
+ L{dns.Record_NULL} instance.
+ """
+ self._recordRoundtripTest(dns.Record_NULL(b'foo bar'))
+
+
+ def test_WKS(self):
+ """
+ The byte stream written by L{dns.Record_WKS.encode} can be used by
+ L{dns.Record_WKS.decode} to reconstruct the state of the original
+ L{dns.Record_WKS} instance.
+ """
+ self._recordRoundtripTest(dns.Record_WKS('1.2.3.4', 3, b'xyz'))
+
+
+ def test_AAAA(self):
+ """
+ The byte stream written by L{dns.Record_AAAA.encode} can be used by
+ L{dns.Record_AAAA.decode} to reconstruct the state of the original
+ L{dns.Record_AAAA} instance.
+ """
+ self._recordRoundtripTest(dns.Record_AAAA('::1'))
+
+
+ def test_A6(self):
+ """
+ The byte stream written by L{dns.Record_A6.encode} can be used by
+ L{dns.Record_A6.decode} to reconstruct the state of the original
+ L{dns.Record_A6} instance.
+ """
+ self._recordRoundtripTest(dns.Record_A6(8, '::1:2', b'foo'))
+
+
+ def test_SRV(self):
+ """
+ The byte stream written by L{dns.Record_SRV.encode} can be used by
+ L{dns.Record_SRV.decode} to reconstruct the state of the original
+ L{dns.Record_SRV} instance.
+ """
+ self._recordRoundtripTest(dns.Record_SRV(
+ priority=1, weight=2, port=3, target=b'example.com'))
+
+
def test_NAPTR(self):
"""
Test L{dns.Record_NAPTR} encode and decode.
"""
- naptrs = [(100, 10, "u", "sip+E2U",
- "!^.*$!sip:information@domain.tld!", ""),
- (100, 50, "s", "http+I2L+I2C+I2R", "",
- "_http._tcp.gatech.edu")]
+ naptrs = [
+ (100, 10, b"u", b"sip+E2U",
+ b"!^.*$!sip:information@domain.tld!", b""),
+ (100, 50, b"s", b"http+I2L+I2C+I2R",
+ b"", b"_http._tcp.gatech.edu")]
for (order, preference, flags, service, regexp, replacement) in naptrs:
rin = dns.Record_NAPTR(order, preference, flags, service, regexp,
replacement)
- e = StringIO()
+ e = BytesIO()
rin.encode(e)
- e.seek(0,0)
+ e.seek(0, 0)
rout = dns.Record_NAPTR()
rout.decode(e)
self.assertEqual(rin.order, rout.order)
@@ -302,8 +500,66 @@ def test_NAPTR(self):
self.assertEqual(rin.ttl, rout.ttl)
+ def test_AFSDB(self):
+ """
+ The byte stream written by L{dns.Record_AFSDB.encode} can be used by
+ L{dns.Record_AFSDB.decode} to reconstruct the state of the original
+ L{dns.Record_AFSDB} instance.
+ """
+ self._recordRoundtripTest(dns.Record_AFSDB(
+ subtype=3, hostname=b'example.com'))
+
+
+ def test_RP(self):
+ """
+ The byte stream written by L{dns.Record_RP.encode} can be used by
+ L{dns.Record_RP.decode} to reconstruct the state of the original
+ L{dns.Record_RP} instance.
+ """
+ self._recordRoundtripTest(dns.Record_RP(
+ mbox=b'alice.example.com', txt=b'example.com'))
+
-class MessageTestCase(unittest.TestCase):
+ def test_HINFO(self):
+ """
+ The byte stream written by L{dns.Record_HINFO.encode} can be used by
+ L{dns.Record_HINFO.decode} to reconstruct the state of the original
+ L{dns.Record_HINFO} instance.
+ """
+ self._recordRoundtripTest(dns.Record_HINFO(cpu=b'fast', os=b'great'))
+
+
+ def test_MINFO(self):
+ """
+ The byte stream written by L{dns.Record_MINFO.encode} can be used by
+ L{dns.Record_MINFO.decode} to reconstruct the state of the original
+ L{dns.Record_MINFO} instance.
+ """
+ self._recordRoundtripTest(dns.Record_MINFO(
+ rmailbx=b'foo', emailbx=b'bar'))
+
+
+ def test_MX(self):
+ """
+ The byte stream written by L{dns.Record_MX.encode} can be used by
+ L{dns.Record_MX.decode} to reconstruct the state of the original
+ L{dns.Record_MX} instance.
+ """
+ self._recordRoundtripTest(dns.Record_MX(
+ preference=1, name=b'example.com'))
+
+
+ def test_TXT(self):
+ """
+ The byte stream written by L{dns.Record_TXT.encode} can be used by
+ L{dns.Record_TXT.decode} to reconstruct the state of the original
+ L{dns.Record_TXT} instance.
+ """
+ self._recordRoundtripTest(dns.Record_TXT(b'foo', b'bar'))
+
+
+
+class MessageTestCase(unittest.SynchronousTestCase):
"""
Tests for L{twisted.names.dns.Message}.
"""
@@ -314,23 +570,23 @@ def testEmptyMessage(self):
be raised when it is parsed.
"""
msg = dns.Message()
- self.assertRaises(EOFError, msg.fromStr, '')
+ self.assertRaises(EOFError, msg.fromStr, b'')
- def testEmptyQuery(self):
+ def test_emptyQuery(self):
"""
Test that bytes representing an empty query message can be decoded
as such.
"""
msg = dns.Message()
msg.fromStr(
- '\x01\x00' # Message ID
- '\x00' # answer bit, opCode nibble, auth bit, trunc bit, recursive bit
- '\x00' # recursion bit, empty bit, empty bit, empty bit, response code nibble
- '\x00\x00' # number of queries
- '\x00\x00' # number of answers
- '\x00\x00' # number of authorities
- '\x00\x00' # number of additionals
+ b'\x01\x00' # Message ID
+ b'\x00' # answer bit, opCode nibble, auth bit, trunc bit, recursive bit
+ b'\x00' # recursion bit, empty bit, empty bit, empty bit, response code nibble
+ b'\x00\x00' # number of queries
+ b'\x00\x00' # number of answers
+ b'\x00\x00' # number of authorities
+ b'\x00\x00' # number of additionals
)
self.assertEqual(msg.id, 256)
self.failIf(msg.answer, "Message was not supposed to be an answer.")
@@ -343,13 +599,17 @@ def testEmptyQuery(self):
self.assertEqual(msg.additional, [])
- def testNULL(self):
- bytes = ''.join([chr(i) for i in range(256)])
+ def test_NULL(self):
+ """
+ A I{NULL} record with an arbitrary payload can be encoded and decoded as
+ part of a L{dns.Message}.
+ """
+ bytes = b''.join([dns._ord2bytes(i) for i in range(256)])
rec = dns.Record_NULL(bytes)
- rr = dns.RRHeader('testname', dns.NULL, payload=rec)
+ rr = dns.RRHeader(b'testname', dns.NULL, payload=rec)
msg1 = dns.Message()
msg1.answers.append(rr)
- s = StringIO()
+ s = BytesIO()
msg1.encode(s)
s.seek(0, 0)
msg2 = dns.Message()
@@ -377,21 +637,21 @@ def test_nonAuthoritativeMessage(self):
The L{RRHeader} instances created by L{Message} from a non-authoritative
message are marked as not authoritative.
"""
- buf = StringIO()
+ buf = BytesIO()
answer = dns.RRHeader(payload=dns.Record_A('1.2.3.4', ttl=0))
answer.encode(buf)
message = dns.Message()
message.fromStr(
- '\x01\x00' # Message ID
+ b'\x01\x00' # Message ID
# answer bit, opCode nibble, auth bit, trunc bit, recursive bit
- '\x00'
+ b'\x00'
# recursion bit, empty bit, empty bit, empty bit, response code
# nibble
- '\x00'
- '\x00\x00' # number of queries
- '\x00\x01' # number of answers
- '\x00\x00' # number of authorities
- '\x00\x00' # number of additionals
+ b'\x00'
+ b'\x00\x00' # number of queries
+ b'\x00\x01' # number of answers
+ b'\x00\x00' # number of authorities
+ b'\x00\x00' # number of additionals
+ buf.getvalue()
)
self.assertEqual(message.answers, [answer])
@@ -403,21 +663,21 @@ def test_authoritativeMessage(self):
The L{RRHeader} instances created by L{Message} from an authoritative
message are marked as authoritative.
"""
- buf = StringIO()
+ buf = BytesIO()
answer = dns.RRHeader(payload=dns.Record_A('1.2.3.4', ttl=0))
answer.encode(buf)
message = dns.Message()
message.fromStr(
- '\x01\x00' # Message ID
+ b'\x01\x00' # Message ID
# answer bit, opCode nibble, auth bit, trunc bit, recursive bit
- '\x04'
+ b'\x04'
# recursion bit, empty bit, empty bit, empty bit, response code
# nibble
- '\x00'
- '\x00\x00' # number of queries
- '\x00\x01' # number of answers
- '\x00\x00' # number of authorities
- '\x00\x00' # number of additionals
+ b'\x00'
+ b'\x00\x00' # number of queries
+ b'\x00\x01' # number of answers
+ b'\x00\x00' # number of authorities
+ b'\x00\x00' # number of additionals
+ buf.getvalue()
)
answer.auth = True
@@ -471,19 +731,19 @@ def test_truncatedPacket(self):
Test that when a short datagram is received, datagramReceived does
not raise an exception while processing it.
"""
- self.proto.datagramReceived('',
- address.IPv4Address('UDP', '127.0.0.1', 12345))
+ self.proto.datagramReceived(
+ b'', address.IPv4Address('UDP', '127.0.0.1', 12345))
self.assertEqual(self.controller.messages, [])
def test_simpleQuery(self):
"""
Test content received after a query.
"""
- d = self.proto.query(('127.0.0.1', 21345), [dns.Query('foo')])
+ d = self.proto.query(('127.0.0.1', 21345), [dns.Query(b'foo')])
self.assertEqual(len(self.proto.liveMessages.keys()), 1)
m = dns.Message()
- m.id = self.proto.liveMessages.items()[0][0]
+ m.id = next(iter(self.proto.liveMessages.keys()))
m.answers = [dns.RRHeader(payload=dns.Record_A(address='1.2.3.4'))]
def cb(result):
self.assertEqual(result.answers[0].payload.dottedQuad(), '1.2.3.4')
@@ -496,7 +756,7 @@ def test_queryTimeout(self):
"""
Test that query timeouts after some seconds.
"""
- d = self.proto.query(('127.0.0.1', 21345), [dns.Query('foo')])
+ d = self.proto.query(('127.0.0.1', 21345), [dns.Query(b'foo')])
self.assertEqual(len(self.proto.liveMessages), 1)
self.clock.advance(10)
self.assertFailure(d, dns.DNSQueryTimeoutError)
@@ -514,7 +774,7 @@ def writeError(message, addr):
raise RuntimeError("bar")
self.proto.transport.write = writeError
- d = self.proto.query(('127.0.0.1', 21345), [dns.Query('foo')])
+ d = self.proto.query(('127.0.0.1', 21345), [dns.Query(b'foo')])
return self.assertFailure(d, RuntimeError)
@@ -530,7 +790,7 @@ def startListeningError():
# Clean up transport so that the protocol calls startListening again
self.proto.transport = None
- d = self.proto.query(('127.0.0.1', 21345), [dns.Query('foo')])
+ d = self.proto.query(('127.0.0.1', 21345), [dns.Query(b'foo')])
return self.assertFailure(d, CannotListenError)
@@ -589,7 +849,7 @@ def test_queryTimeout(self):
"""
Test that query timeouts after some seconds.
"""
- d = self.proto.query([dns.Query('foo')])
+ d = self.proto.query([dns.Query(b'foo')])
self.assertEqual(len(self.proto.liveMessages), 1)
self.clock.advance(60)
self.assertFailure(d, dns.DNSQueryTimeoutError)
@@ -601,10 +861,10 @@ def test_simpleQuery(self):
"""
Test content received after a query.
"""
- d = self.proto.query([dns.Query('foo')])
+ d = self.proto.query([dns.Query(b'foo')])
self.assertEqual(len(self.proto.liveMessages.keys()), 1)
m = dns.Message()
- m.id = self.proto.liveMessages.items()[0][0]
+ m.id = next(iter(self.proto.liveMessages.keys()))
m.answers = [dns.RRHeader(payload=dns.Record_A(address='1.2.3.4'))]
def cb(result):
self.assertEqual(result.answers[0].payload.dottedQuad(), '1.2.3.4')
@@ -625,7 +885,7 @@ def writeError(message):
raise RuntimeError("bar")
self.proto.transport.write = writeError
- d = self.proto.query([dns.Query('foo')])
+ d = self.proto.query([dns.Query(b'foo')])
return self.assertFailure(d, RuntimeError)
@@ -640,7 +900,7 @@ def test_ns(self):
nameserver and the TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_NS('example.com', 4321)),
+ repr(dns.Record_NS(b'example.com', 4321)),
"<NS name=example.com ttl=4321>")
@@ -650,7 +910,7 @@ def test_md(self):
mail destination and the TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_MD('example.com', 4321)),
+ repr(dns.Record_MD(b'example.com', 4321)),
"<MD name=example.com ttl=4321>")
@@ -660,7 +920,7 @@ def test_mf(self):
mail forwarder and the TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_MF('example.com', 4321)),
+ repr(dns.Record_MF(b'example.com', 4321)),
"<MF name=example.com ttl=4321>")
@@ -670,7 +930,7 @@ def test_cname(self):
mail forwarder and the TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_CNAME('example.com', 4321)),
+ repr(dns.Record_CNAME(b'example.com', 4321)),
"<CNAME name=example.com ttl=4321>")
@@ -680,7 +940,7 @@ def test_mb(self):
mailbox and the TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_MB('example.com', 4321)),
+ repr(dns.Record_MB(b'example.com', 4321)),
"<MB name=example.com ttl=4321>")
@@ -690,7 +950,7 @@ def test_mg(self):
mail group memeber and the TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_MG('example.com', 4321)),
+ repr(dns.Record_MG(b'example.com', 4321)),
"<MG name=example.com ttl=4321>")
@@ -700,7 +960,7 @@ def test_mr(self):
mail rename domain and the TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_MR('example.com', 4321)),
+ repr(dns.Record_MR(b'example.com', 4321)),
"<MR name=example.com ttl=4321>")
@@ -710,7 +970,7 @@ def test_ptr(self):
pointer and the TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_PTR('example.com', 4321)),
+ repr(dns.Record_PTR(b'example.com', 4321)),
"<PTR name=example.com ttl=4321>")
@@ -720,7 +980,7 @@ def test_dname(self):
non-terminal DNS name redirection and the TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_DNAME('example.com', 4321)),
+ repr(dns.Record_DNAME(b'example.com', 4321)),
"<DNAME name=example.com ttl=4321>")
@@ -741,7 +1001,7 @@ def test_soa(self):
authority fields.
"""
self.assertEqual(
- repr(dns.Record_SOA(mname='mName', rname='rName', serial=123,
+ repr(dns.Record_SOA(mname=b'mName', rname=b'rName', serial=123,
refresh=456, retry=789, expire=10,
minimum=11, ttl=12)),
"<SOA mname=mName rname=rName serial=123 refresh=456 "
@@ -754,7 +1014,7 @@ def test_null(self):
payload and the TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_NULL('abcd', 123)),
+ repr(dns.Record_NULL(b'abcd', 123)),
"<NULL payload='abcd' ttl=123>")
@@ -787,7 +1047,7 @@ def test_a6(self):
record.
"""
self.assertEqual(
- repr(dns.Record_A6(0, '1234::5678', 'foo.bar', ttl=10)),
+ repr(dns.Record_A6(0, '1234::5678', b'foo.bar', ttl=10)),
"<A6 suffix=1234::5678 prefix=foo.bar ttl=10>")
@@ -797,7 +1057,7 @@ def test_srv(self):
the target and the priority, weight, and TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_SRV(1, 2, 3, 'example.org', 4)),
+ repr(dns.Record_SRV(1, 2, 3, b'example.org', 4)),
"<SRV priority=1 weight=2 target=example.org port=3 ttl=4>")
@@ -807,8 +1067,10 @@ def test_naptr(self):
preference, flags, service, regular expression, replacement, and TTL of
the record.
"""
+ record = dns.Record_NAPTR(
+ 5, 9, b"S", b"http", b"/foo/bar/i", b"baz", 3)
self.assertEqual(
- repr(dns.Record_NAPTR(5, 9, "S", "http", "/foo/bar/i", "baz", 3)),
+ repr(record),
"<NAPTR order=5 preference=9 flags=S service=http "
"regexp=/foo/bar/i replacement=baz ttl=3>")
@@ -819,7 +1081,7 @@ def test_afsdb(self):
hostname, and TTL of the record.
"""
self.assertEqual(
- repr(dns.Record_AFSDB(3, 'example.org', 5)),
+ repr(dns.Record_AFSDB(3, b'example.org', 5)),
"<AFSDB subtype=3 hostname=example.org ttl=5>")
@@ -829,7 +1091,7 @@ def test_rp(self):
fields of the record.
"""
self.assertEqual(
- repr(dns.Record_RP('alice.example.com', 'admin.example.com', 3)),
+ repr(dns.Record_RP(b'alice.example.com', b'admin.example.com', 3)),
"<RP mbox=alice.example.com txt=admin.example.com ttl=3>")
@@ -839,7 +1101,7 @@ def test_hinfo(self):
TTL fields of the record.
"""
self.assertEqual(
- repr(dns.Record_HINFO('sparc', 'minix', 12)),
+ repr(dns.Record_HINFO(b'sparc', b'minix', 12)),
"<HINFO cpu='sparc' os='minix' ttl=12>")
@@ -848,8 +1110,10 @@ def test_minfo(self):
The repr of a L{dns.Record_MINFO} instance includes the rmailbx,
emailbx, and TTL fields of the record.
"""
+ record = dns.Record_MINFO(
+ b'alice.example.com', b'bob.example.com', 15)
self.assertEqual(
- repr(dns.Record_MINFO('alice.example.com', 'bob.example.com', 15)),
+ repr(record),
"<MINFO responsibility=alice.example.com "
"errors=bob.example.com ttl=15>")
@@ -860,7 +1124,7 @@ def test_mx(self):
and TTL fields of the record.
"""
self.assertEqual(
- repr(dns.Record_MX(13, 'mx.example.com', 2)),
+ repr(dns.Record_MX(13, b'mx.example.com', 2)),
"<MX preference=13 name=mx.example.com ttl=2>")
@@ -870,7 +1134,7 @@ def test_txt(self):
fields of the record.
"""
self.assertEqual(
- repr(dns.Record_TXT("foo", "bar", ttl=15)),
+ repr(dns.Record_TXT(b"foo", b"bar", ttl=15)),
"<TXT data=['foo', 'bar'] ttl=15>")
@@ -880,7 +1144,7 @@ def test_spf(self):
fields of the record.
"""
self.assertEqual(
- repr(dns.Record_SPF("foo", "bar", ttl=15)),
+ repr(dns.Record_SPF(b"foo", b"bar", ttl=15)),
"<SPF data=['foo', 'bar'] ttl=15>")
@@ -890,7 +1154,7 @@ def test_unknown(self):
fields of the record.
"""
self.assertEqual(
- repr(dns.UnknownRecord("foo\x1fbar", 12)),
+ repr(dns.UnknownRecord(b"foo\x1fbar", 12)),
"<UNKNOWN data='foo\\x1fbar' ttl=12>")
@@ -949,17 +1213,42 @@ def _equalityTest(self, firstValueOne, secondValueOne, valueTwo):
self.assertTrue(firstValueOne != _NotEqual())
+ def test_charstr(self):
+ """
+ Two L{dns.Charstr} instances compare equal if and only if they have the
+ same string value.
+ """
+ self._equalityTest(
+ dns.Charstr(b'abc'), dns.Charstr(b'abc'), dns.Charstr(b'def'))
+
+
+ def test_name(self):
+ """
+ Two L{dns.Name} instances compare equal if and only if they have the
+ same name value.
+ """
+ self._equalityTest(
+ dns.Name(b'abc'), dns.Name(b'abc'), dns.Name(b'def'))
+
+
def _simpleEqualityTest(self, cls):
+ """
+ Assert that instances of C{cls} with the same attributes compare equal
+ to each other and instances with different attributes compare as not
+ equal.
+
+ @param cls: A L{dns.SimpleRecord} subclass.
+ """
# Vary the TTL
self._equalityTest(
- cls('example.com', 123),
- cls('example.com', 123),
- cls('example.com', 321))
+ cls(b'example.com', 123),
+ cls(b'example.com', 123),
+ cls(b'example.com', 321))
# Vary the name
self._equalityTest(
- cls('example.com', 123),
- cls('example.com', 123),
- cls('example.org', 123))
+ cls(b'example.com', 123),
+ cls(b'example.com', 123),
+ cls(b'example.org', 123))
def test_rrheader(self):
@@ -970,40 +1259,40 @@ def test_rrheader(self):
"""
# Vary the name
self._equalityTest(
- dns.RRHeader('example.com', payload=dns.Record_A('1.2.3.4')),
- dns.RRHeader('example.com', payload=dns.Record_A('1.2.3.4')),
- dns.RRHeader('example.org', payload=dns.Record_A('1.2.3.4')))
+ dns.RRHeader(b'example.com', payload=dns.Record_A('1.2.3.4')),
+ dns.RRHeader(b'example.com', payload=dns.Record_A('1.2.3.4')),
+ dns.RRHeader(b'example.org', payload=dns.Record_A('1.2.3.4')))
# Vary the payload
self._equalityTest(
- dns.RRHeader('example.com', payload=dns.Record_A('1.2.3.4')),
- dns.RRHeader('example.com', payload=dns.Record_A('1.2.3.4')),
- dns.RRHeader('example.com', payload=dns.Record_A('1.2.3.5')))
+ dns.RRHeader(b'example.com', payload=dns.Record_A('1.2.3.4')),
+ dns.RRHeader(b'example.com', payload=dns.Record_A('1.2.3.4')),
+ dns.RRHeader(b'example.com', payload=dns.Record_A('1.2.3.5')))
# Vary the type. Leave the payload as None so that we don't have to
# provide non-equal values.
self._equalityTest(
- dns.RRHeader('example.com', dns.A),
- dns.RRHeader('example.com', dns.A),
- dns.RRHeader('example.com', dns.MX))
+ dns.RRHeader(b'example.com', dns.A),
+ dns.RRHeader(b'example.com', dns.A),
+ dns.RRHeader(b'example.com', dns.MX))
# Probably not likely to come up. Most people use the internet.
self._equalityTest(
- dns.RRHeader('example.com', cls=dns.IN, payload=dns.Record_A('1.2.3.4')),
- dns.RRHeader('example.com', cls=dns.IN, payload=dns.Record_A('1.2.3.4')),
- dns.RRHeader('example.com', cls=dns.CS, payload=dns.Record_A('1.2.3.4')))
+ dns.RRHeader(b'example.com', cls=dns.IN, payload=dns.Record_A('1.2.3.4')),
+ dns.RRHeader(b'example.com', cls=dns.IN, payloa