Skip to content

Commit

Permalink
Refactor zone transactions to always use versioned CoW code.
Browse files Browse the repository at this point in the history
  • Loading branch information
rthalley committed Dec 1, 2021
1 parent c706e26 commit 9a16076
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 271 deletions.
219 changes: 19 additions & 200 deletions dns/versioned.py
Expand Up @@ -11,155 +11,23 @@
import dns.exception
import dns.immutable
import dns.name
import dns.node
import dns.rdataclass
import dns.rdatatype
import dns.rdata
import dns.rdtypes.ANY.SOA
import dns.transaction
import dns.zone


class UseTransaction(dns.exception.DNSException):
"""To alter a versioned zone, use a transaction."""


class Version:
def __init__(self, zone, id):
self.zone = zone
self.id = id
self.nodes = {}

def _validate_name(self, name):
if name.is_absolute():
if not name.is_subdomain(self.zone.origin):
raise KeyError("name is not a subdomain of the zone origin")
if self.zone.relativize:
name = name.relativize(self.origin)
return name

def get_node(self, name):
name = self._validate_name(name)
return self.nodes.get(name)

def get_rdataset(self, name, rdtype, covers):
node = self.get_node(name)
if node is None:
return None
return node.get_rdataset(self.zone.rdclass, rdtype, covers)

def items(self):
return self.nodes.items() # pylint: disable=dict-items-not-iterating


class WritableVersion(Version):
def __init__(self, zone, replacement=False):
# The zone._versions_lock must be held by our caller.
if len(zone._versions) > 0:
id = zone._versions[-1].id + 1
else:
id = 1
super().__init__(zone, id)
if not replacement:
# We copy the map, because that gives us a simple and thread-safe
# way of doing versions, and we have a garbage collector to help
# us. We only make new node objects if we actually change the
# node.
self.nodes.update(zone.nodes)
# We have to copy the zone origin as it may be None in the first
# version, and we don't want to mutate the zone until we commit.
self.origin = zone.origin
self.changed = set()

def _maybe_cow(self, name):
name = self._validate_name(name)
node = self.nodes.get(name)
if node is None or node.id != self.id:
new_node = self.zone.node_factory()
new_node.id = self.id
if node is not None:
# moo! copy on write!
new_node.rdatasets.extend(node.rdatasets)
self.nodes[name] = new_node
self.changed.add(name)
return new_node
else:
return node

def delete_node(self, name):
name = self._validate_name(name)
if name in self.nodes:
del self.nodes[name]
self.changed.add(name)

def put_rdataset(self, name, rdataset):
node = self._maybe_cow(name)
node.replace_rdataset(rdataset)

def delete_rdataset(self, name, rdtype, covers):
node = self._maybe_cow(name)
node.delete_rdataset(self.zone.rdclass, rdtype, covers)
if len(node) == 0:
del self.nodes[name]


@dns.immutable.immutable
class ImmutableVersion(Version):
def __init__(self, version):
# We tell super() that it's a replacement as we don't want it
# to copy the nodes, as we're about to do that with an
# immutable Dict.
super().__init__(version.zone, True)
# set the right id!
self.id = version.id
# Make changed nodes immutable
for name in version.changed:
node = version.nodes.get(name)
# it might not exist if we deleted it in the version
if node:
version.nodes[name] = ImmutableNode(node)
self.nodes = dns.immutable.Dict(version.nodes, True)


# A node with a version id.

class Node(dns.node.Node):
__slots__ = ['id']

def __init__(self):
super().__init__()
# A proper id will get set by the Version
self.id = 0


@dns.immutable.immutable
class ImmutableNode(Node):
__slots__ = ['id']

def __init__(self, node):
super().__init__()
self.id = node.id
self.rdatasets = tuple(
[dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
)

def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
create=False):
if create:
raise TypeError("immutable")
return super().find_rdataset(rdclass, rdtype, covers, False)

def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
create=False):
if create:
raise TypeError("immutable")
return super().get_rdataset(rdclass, rdtype, covers, False)

def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
raise TypeError("immutable")

def replace_rdataset(self, replacement):
raise TypeError("immutable")
# Backwards compatibility
Node = dns.zone.VersionedNode
ImmutableNode = dns.zone.ImmutableVersionedNode
Version = dns.zone.Version
WritableVersion = dns.zone.WritableVersion
ImmutableVersion = dns.zone.ImmutableVersion
Transaction = dns.zone.Transaction


class Zone(dns.zone.Zone):
Expand Down Expand Up @@ -198,7 +66,9 @@ def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True,
self._write_event = None
self._write_waiters = collections.deque()
self._readers = set()
self._commit_version_unlocked(None, WritableVersion(self), origin)
self._commit_version_unlocked(None,
WritableVersion(self, replacement=True),
origin)

def reader(self, id=None, serial=None): # pylint: disable=arguments-differ
if id is not None and serial is not None:
Expand Down Expand Up @@ -247,7 +117,8 @@ def writer(self, replacement=False):
# give up the lock, so that we hold the lock as
# short a time as possible. This is why we call
# _setup_version() below.
self._write_txn = Transaction(self, replacement)
self._write_txn = Transaction(self, replacement,
make_immutable=True)
# give up our exclusive right to make a Transaction
self._write_event = None
break
Expand Down Expand Up @@ -367,6 +238,13 @@ def _commit_version(self, txn, version, origin):
with self._version_lock:
self._commit_version_unlocked(txn, version, origin)

def _get_next_version_id(self):
if len(self._versions) > 0:
id = self._versions[-1].id + 1
else:
id = 1
return id

def find_node(self, name, create=False):
if create:
raise UseTransaction
Expand Down Expand Up @@ -394,62 +272,3 @@ def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE):

def replace_rdataset(self, name, replacement):
raise UseTransaction


class Transaction(dns.transaction.Transaction):

def __init__(self, zone, replacement, version=None):
read_only = version is not None
super().__init__(zone, replacement, read_only)
self.version = version

@property
def zone(self):
return self.manager

def _setup_version(self):
assert self.version is None
self.version = WritableVersion(self.zone, self.replacement)

def _get_rdataset(self, name, rdtype, covers):
return self.version.get_rdataset(name, rdtype, covers)

def _put_rdataset(self, name, rdataset):
assert not self.read_only
self.version.put_rdataset(name, rdataset)

def _delete_name(self, name):
assert not self.read_only
self.version.delete_node(name)

def _delete_rdataset(self, name, rdtype, covers):
assert not self.read_only
self.version.delete_rdataset(name, rdtype, covers)

def _name_exists(self, name):
return self.version.get_node(name) is not None

def _changed(self):
if self.read_only:
return False
else:
return len(self.version.changed) > 0

def _end_transaction(self, commit):
if self.read_only:
self.zone._end_read(self)
elif commit and len(self.version.changed) > 0:
self.zone._commit_version(self, ImmutableVersion(self.version),
self.version.origin)
else:
# rollback
self.zone._end_write(self)

def _set_origin(self, origin):
if self.version.origin is None:
self.version.origin = origin

def _iterate_rdatasets(self):
for (name, node) in self.version.items():
for rdataset in node:
yield (name, rdataset)

0 comments on commit 9a16076

Please sign in to comment.