diff --git a/dns/versioned.py b/dns/versioned.py index 686a83b0..42f2c814 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -11,12 +11,9 @@ import dns.exception import dns.immutable import dns.name -import dns.node import dns.rdataclass import dns.rdatatype -import dns.rdata import dns.rdtypes.ANY.SOA -import dns.transaction import dns.zone @@ -24,142 +21,13 @@ class UseTransaction(dns.exception.DNSException): """To alter a versioned zone, use a transaction.""" -class Version: - def __init__(self, zone, id): - self.zone = zone - self.id = id - self.nodes = {} - - def _validate_name(self, name): - if name.is_absolute(): - if not name.is_subdomain(self.zone.origin): - raise KeyError("name is not a subdomain of the zone origin") - if self.zone.relativize: - name = name.relativize(self.origin) - return name - - def get_node(self, name): - name = self._validate_name(name) - return self.nodes.get(name) - - def get_rdataset(self, name, rdtype, covers): - node = self.get_node(name) - if node is None: - return None - return node.get_rdataset(self.zone.rdclass, rdtype, covers) - - def items(self): - return self.nodes.items() # pylint: disable=dict-items-not-iterating - - -class WritableVersion(Version): - def __init__(self, zone, replacement=False): - # The zone._versions_lock must be held by our caller. - if len(zone._versions) > 0: - id = zone._versions[-1].id + 1 - else: - id = 1 - super().__init__(zone, id) - if not replacement: - # We copy the map, because that gives us a simple and thread-safe - # way of doing versions, and we have a garbage collector to help - # us. We only make new node objects if we actually change the - # node. - self.nodes.update(zone.nodes) - # We have to copy the zone origin as it may be None in the first - # version, and we don't want to mutate the zone until we commit. - self.origin = zone.origin - self.changed = set() - - def _maybe_cow(self, name): - name = self._validate_name(name) - node = self.nodes.get(name) - if node is None or node.id != self.id: - new_node = self.zone.node_factory() - new_node.id = self.id - if node is not None: - # moo! copy on write! - new_node.rdatasets.extend(node.rdatasets) - self.nodes[name] = new_node - self.changed.add(name) - return new_node - else: - return node - - def delete_node(self, name): - name = self._validate_name(name) - if name in self.nodes: - del self.nodes[name] - self.changed.add(name) - - def put_rdataset(self, name, rdataset): - node = self._maybe_cow(name) - node.replace_rdataset(rdataset) - - def delete_rdataset(self, name, rdtype, covers): - node = self._maybe_cow(name) - node.delete_rdataset(self.zone.rdclass, rdtype, covers) - if len(node) == 0: - del self.nodes[name] - - -@dns.immutable.immutable -class ImmutableVersion(Version): - def __init__(self, version): - # We tell super() that it's a replacement as we don't want it - # to copy the nodes, as we're about to do that with an - # immutable Dict. - super().__init__(version.zone, True) - # set the right id! - self.id = version.id - # Make changed nodes immutable - for name in version.changed: - node = version.nodes.get(name) - # it might not exist if we deleted it in the version - if node: - version.nodes[name] = ImmutableNode(node) - self.nodes = dns.immutable.Dict(version.nodes, True) - - -# A node with a version id. - -class Node(dns.node.Node): - __slots__ = ['id'] - - def __init__(self): - super().__init__() - # A proper id will get set by the Version - self.id = 0 - - -@dns.immutable.immutable -class ImmutableNode(Node): - __slots__ = ['id'] - - def __init__(self, node): - super().__init__() - self.id = node.id - self.rdatasets = tuple( - [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] - ) - - def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): - if create: - raise TypeError("immutable") - return super().find_rdataset(rdclass, rdtype, covers, False) - - def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): - if create: - raise TypeError("immutable") - return super().get_rdataset(rdclass, rdtype, covers, False) - - def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): - raise TypeError("immutable") - - def replace_rdataset(self, replacement): - raise TypeError("immutable") +# Backwards compatibility +Node = dns.zone.VersionedNode +ImmutableNode = dns.zone.ImmutableVersionedNode +Version = dns.zone.Version +WritableVersion = dns.zone.WritableVersion +ImmutableVersion = dns.zone.ImmutableVersion +Transaction = dns.zone.Transaction class Zone(dns.zone.Zone): @@ -198,7 +66,9 @@ def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True, self._write_event = None self._write_waiters = collections.deque() self._readers = set() - self._commit_version_unlocked(None, WritableVersion(self), origin) + self._commit_version_unlocked(None, + WritableVersion(self, replacement=True), + origin) def reader(self, id=None, serial=None): # pylint: disable=arguments-differ if id is not None and serial is not None: @@ -247,7 +117,8 @@ def writer(self, replacement=False): # give up the lock, so that we hold the lock as # short a time as possible. This is why we call # _setup_version() below. - self._write_txn = Transaction(self, replacement) + self._write_txn = Transaction(self, replacement, + make_immutable=True) # give up our exclusive right to make a Transaction self._write_event = None break @@ -367,6 +238,13 @@ def _commit_version(self, txn, version, origin): with self._version_lock: self._commit_version_unlocked(txn, version, origin) + def _get_next_version_id(self): + if len(self._versions) > 0: + id = self._versions[-1].id + 1 + else: + id = 1 + return id + def find_node(self, name, create=False): if create: raise UseTransaction @@ -394,62 +272,3 @@ def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE): def replace_rdataset(self, name, replacement): raise UseTransaction - - -class Transaction(dns.transaction.Transaction): - - def __init__(self, zone, replacement, version=None): - read_only = version is not None - super().__init__(zone, replacement, read_only) - self.version = version - - @property - def zone(self): - return self.manager - - def _setup_version(self): - assert self.version is None - self.version = WritableVersion(self.zone, self.replacement) - - def _get_rdataset(self, name, rdtype, covers): - return self.version.get_rdataset(name, rdtype, covers) - - def _put_rdataset(self, name, rdataset): - assert not self.read_only - self.version.put_rdataset(name, rdataset) - - def _delete_name(self, name): - assert not self.read_only - self.version.delete_node(name) - - def _delete_rdataset(self, name, rdtype, covers): - assert not self.read_only - self.version.delete_rdataset(name, rdtype, covers) - - def _name_exists(self, name): - return self.version.get_node(name) is not None - - def _changed(self): - if self.read_only: - return False - else: - return len(self.version.changed) > 0 - - def _end_transaction(self, commit): - if self.read_only: - self.zone._end_read(self) - elif commit and len(self.version.changed) > 0: - self.zone._commit_version(self, ImmutableVersion(self.version), - self.version.origin) - else: - # rollback - self.zone._end_write(self) - - def _set_origin(self, origin): - if self.version.origin is None: - self.version.origin = origin - - def _iterate_rdatasets(self): - for (name, node) in self.version.items(): - for rdataset in node: - yield (name, rdataset) diff --git a/dns/zone.py b/dns/zone.py index 2f99b1b7..510be2df 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -24,6 +24,7 @@ import struct import dns.exception +import dns.immutable import dns.name import dns.node import dns.rdataclass @@ -772,10 +773,13 @@ def verify_digest(self, zonemd=None): # TransactionManager methods def reader(self): - return Transaction(self, False, True) + return Transaction(self, False, + Version(self, 1, self.nodes, self.origin)) def writer(self, replacement=False): - return Transaction(self, replacement, False) + txn = Transaction(self, replacement) + txn._setup_version() + return txn def origin_information(self): if self.relativize: @@ -787,107 +791,238 @@ def origin_information(self): def get_class(self): return self.rdclass + # Transaction methods -class Transaction(dns.transaction.Transaction): + def _end_read(self, txn): + pass + + def _end_write(self, txn): + pass + + def _commit_version(self, txn, version, origin): + self.nodes = version.nodes + if self.origin is None: + self.origin = origin + + def _get_next_version_id(self): + # Versions are ephemeral and all have id 1 + return 1 + + +# These classes used to be in dns.versioned, but have moved here so we can use +# the copy-on-write transaction mechanism for both kinds of zones. In a +# regular zone, the version only exists during the transaction, and the nodes +# are regular dns.node.Nodes. + +# A node with a version id. + +class VersionedNode(dns.node.Node): + __slots__ = ['id'] + + def __init__(self): + super().__init__() + # A proper id will get set by the Version + self.id = 0 + + +@dns.immutable.immutable +class ImmutableVersionedNode(VersionedNode): + __slots__ = ['id'] + + def __init__(self, node): + super().__init__() + self.id = node.id + self.rdatasets = tuple( + [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] + ) + + def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise TypeError("immutable") + return super().find_rdataset(rdclass, rdtype, covers, False) + + def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise TypeError("immutable") + return super().get_rdataset(rdclass, rdtype, covers, False) - _deleted_rdataset = dns.rdataset.Rdataset(dns.rdataclass.ANY, - dns.rdatatype.ANY) + def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + raise TypeError("immutable") + + def replace_rdataset(self, replacement): + raise TypeError("immutable") + + +class Version: + def __init__(self, zone, id, nodes=None, origin=None): + self.zone = zone + self.id = id + if nodes is not None: + self.nodes = nodes + else: + self.nodes = {} + self.origin = origin + + def _validate_name(self, name): + if name.is_absolute(): + if not name.is_subdomain(self.zone.origin): + raise KeyError("name is not a subdomain of the zone origin") + if self.zone.relativize: + # XXXRTH should it be an error if self.origin is still None? + name = name.relativize(self.origin) + return name + + def get_node(self, name): + name = self._validate_name(name) + return self.nodes.get(name) + + def get_rdataset(self, name, rdtype, covers): + node = self.get_node(name) + if node is None: + return None + return node.get_rdataset(self.zone.rdclass, rdtype, covers) - def __init__(self, zone, replacement, read_only): + def items(self): + return self.nodes.items() # pylint: disable=dict-items-not-iterating + + +class WritableVersion(Version): + def __init__(self, zone, replacement=False): + # The zone._versions_lock must be held by our caller in a versioned + # zone. + id = zone._get_next_version_id() + super().__init__(zone, id) + if not replacement: + # We copy the map, because that gives us a simple and thread-safe + # way of doing versions, and we have a garbage collector to help + # us. We only make new node objects if we actually change the + # node. + self.nodes.update(zone.nodes) + # We have to copy the zone origin as it may be None in the first + # version, and we don't want to mutate the zone until we commit. + self.origin = zone.origin + self.changed = set() + + def _maybe_cow(self, name): + name = self._validate_name(name) + node = self.nodes.get(name) + if node is None or name not in self.changed: + new_node = self.zone.node_factory() + if hasattr(new_node, 'id'): + # We keep doing this for backwards compatibility, as earlier + # code used new_node.id != self.id for the "do we need to CoW?" + # test. Now we use the changed set as this works with both + # regular zones and versioned zones. + new_node.id = self.id + if node is not None: + # moo! copy on write! + new_node.rdatasets.extend(node.rdatasets) + self.nodes[name] = new_node + self.changed.add(name) + return new_node + else: + return node + + def delete_node(self, name): + name = self._validate_name(name) + if name in self.nodes: + del self.nodes[name] + self.changed.add(name) + + def put_rdataset(self, name, rdataset): + node = self._maybe_cow(name) + node.replace_rdataset(rdataset) + + def delete_rdataset(self, name, rdtype, covers): + node = self._maybe_cow(name) + node.delete_rdataset(self.zone.rdclass, rdtype, covers) + if len(node) == 0: + del self.nodes[name] + + +@dns.immutable.immutable +class ImmutableVersion(Version): + def __init__(self, version): + # We tell super() that it's a replacement as we don't want it + # to copy the nodes, as we're about to do that with an + # immutable Dict. + super().__init__(version.zone, True) + # set the right id! + self.id = version.id + # keep the origin + self.origin = version.origin + # Make changed nodes immutable + for name in version.changed: + node = version.nodes.get(name) + # it might not exist if we deleted it in the version + if node: + version.nodes[name] = ImmutableVersionedNode(node) + self.nodes = dns.immutable.Dict(version.nodes, True) + + +class Transaction(dns.transaction.Transaction): + + def __init__(self, zone, replacement, version=None, make_immutable=False): + read_only = version is not None super().__init__(zone, replacement, read_only) - self.rdatasets = {} + self.version = version + self.make_immutable = make_immutable @property def zone(self): return self.manager + def _setup_version(self): + assert self.version is None + self.version = WritableVersion(self.zone, self.replacement) + def _get_rdataset(self, name, rdtype, covers): - rdataset = self.rdatasets.get((name, rdtype, covers)) - if rdataset is self._deleted_rdataset: - return None - elif rdataset is None and not self.replacement: - rdataset = self.zone.get_rdataset(name, rdtype, covers) - return rdataset + return self.version.get_rdataset(name, rdtype, covers) def _put_rdataset(self, name, rdataset): assert not self.read_only - self.zone._validate_name(name) - self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset + self.version.put_rdataset(name, rdataset) def _delete_name(self, name): assert not self.read_only - # First remove any changes involving the name - remove = [] - for key in self.rdatasets: - if key[0] == name: - remove.append(key) - if len(remove) > 0: - for key in remove: - del self.rdatasets[key] - # Next add deletion records for any rdatasets matching the - # name in the zone - node = self.zone.get_node(name) - if node is not None: - for rdataset in node.rdatasets: - self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \ - self._deleted_rdataset + self.version.delete_node(name) def _delete_rdataset(self, name, rdtype, covers): assert not self.read_only - try: - del self.rdatasets[(name, rdtype, covers)] - except KeyError: - pass - rdataset = self.zone.get_rdataset(name, rdtype, covers) - if rdataset is not None: - self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \ - self._deleted_rdataset + self.version.delete_rdataset(name, rdtype, covers) def _name_exists(self, name): - for key, rdataset in self.rdatasets.items(): - if key[0] == name: - if rdataset != self._deleted_rdataset: - return True - else: - return None - self.zone._validate_name(name) - if self.zone.get_node(name): - return True - return False + return self.version.get_node(name) is not None def _changed(self): if self.read_only: return False else: - return len(self.rdatasets) > 0 + return len(self.version.changed) > 0 def _end_transaction(self, commit): - if commit and self._changed(): - if self.replacement: - self.zone.nodes = {} - for (name, rdtype, covers), rdataset in \ - self.rdatasets.items(): - if rdataset is self._deleted_rdataset: - self.zone.delete_rdataset(name, rdtype, covers) - else: - self.zone.replace_rdataset(name, rdataset) + if self.read_only: + self.zone._end_read(self) + elif commit and len(self.version.changed) > 0: + if self.make_immutable: + version = ImmutableVersion(self.version) + else: + version = self.version + self.zone._commit_version(self, version, self.version.origin) + else: + # rollback + self.zone._end_write(self) def _set_origin(self, origin): - if self.zone.origin is None: - self.zone.origin = origin + if self.version.origin is None: + self.version.origin = origin def _iterate_rdatasets(self): - # Expensive but simple! Use a versioned zone for efficient txn - # iteration. - if self.replacement: - rdatasets = self.rdatasets - else: - rdatasets = {} - for (name, rdataset) in self.zone.iterate_rdatasets(): - rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset - rdatasets.update(self.rdatasets) - for (name, _, _), rdataset in rdatasets.items(): - yield (name, rdataset) + for (name, node) in self.version.items(): + for rdataset in node: + yield (name, rdataset) def from_text(text, origin=None, rdclass=dns.rdataclass.IN,