Replace object references with soft links
This replaces the use of object references to handle HDF5 attributes
that are too large to store naturally. Object references break when
copying an object that contains them to a new HDF5 file, as is done in
mergeHistory(). Instead, we use a simple string to refer to the target
dataset, like "@[path/to/dataset]".
youngmit committed Jan 16, 2020
1 parent ec47cfb commit b835817
Showing 2 changed files with 105 additions and 25 deletions.
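For illustration, here is a minimal standalone sketch of the new scheme (assumed file and group names; the real logic lives in _writeAttrs() and _resolveAttrs() in the diff below):

    import re

    import h5py
    import numpy

    ATTR_LINK = re.compile("^@(.*)$")

    with h5py.File("example.h5", "w") as f:
        ds = f.create_dataset("c00n00/layout/serialNum", data=numpy.arange(3))

        # Store the oversized attribute data in its own dataset...
        f["c00n00/attrs/0_fakeBigData"] = numpy.eye(64)

        # ...and point at it with a plain string rather than an object reference.
        ds.attrs["fakeBigData"] = "@{}".format(f["c00n00/attrs/0_fakeBigData"].name)

        # Resolving the link is just a regex match plus a path lookup.
        m = ATTR_LINK.match(ds.attrs["fakeBigData"])
        bigData = f[m.group(1)][()]  # recovers numpy.eye(64)

Because the link is an ordinary string, it survives copies between files unchanged; only the leading "@" convention marks it as a link to another dataset.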
74 changes: 62 additions & 12 deletions armi/bookkeeping/db/database3.py
@@ -77,7 +77,9 @@
from armi.utils.textProcessors import resolveMarkupInclusions

ORDER = interfaces.STACK_ORDER.BOOKKEEPING
DB_VERSION = "3.1"
DB_VERSION = "3.2"

ATTR_LINK = re.compile("^@(.*)$")


def getH5GroupName(cycle, timeNode, statePointName=None):
@@ -384,8 +386,6 @@ class Database3(database.Database):
`doc/user/outputs/database` for more details.
"""

version = DB_VERSION

timeNodeGroupPattern = re.compile(r"^c(\d\d)n(\d\d)$")

def __init__(self, fileName: str, permission: str):
@@ -413,6 +413,31 @@ def __init__(self, fileName: str, permission: str):
# closed.
self._openCount: int = 0

if permission == "w":
self.version = DB_VERSION
else:
# will be set upon read
self._version = None
self._versionMajor = None
self._versionMinor = None

@property
def version(self):
return self._version

@version.setter
def version(self, value):
self._version = value
self._versionMajor, self._versionMinor = (int(v) for v in value.split("."))

@property
def versionMajor(self):
return self._versionMajor

@property
def versionMinor(self):
return self._versionMinor

def __repr__(self):
return "<{} {}>".format(
self.__class__.__name__, repr(self.h5db).replace("<", "").replace(">", "")
@@ -430,6 +455,7 @@ def open(self):
if self._permission in {"r", "a"}:
self._fullPath = os.path.abspath(filePath)
self.h5db = h5py.File(filePath, self._permission)
self.version = self.h5db.attrs["databaseVersion"]
return

if self._permission == "w":
@@ -447,7 +473,7 @@ def open(self):
self.h5db = h5py.File(filePath, self._permission)
self.h5db.attrs["successfulCompletion"] = False
self.h5db.attrs["version"] = armi.__version__
self.h5db.attrs["databaseVersion"] = DB_VERSION
self.h5db.attrs["databaseVersion"] = self.version
self.h5db.attrs["user"] = armi.USER
self.h5db.attrs["python"] = sys.version
self.h5db.attrs["armiLocation"] = os.path.dirname(armi.ROOT)
@@ -637,6 +663,20 @@ def mergeHistory(self, inputDB, startCycle, startNode):
return
self.h5db.copy(h5ts, h5ts.name)

if inputDB.versionMinor < 2:
# The source database may have object references in some attributes.
# Make sure to link those up using our manual path strategy.
references = []
def findReferences(name, obj):
for key, attr in obj.attrs.items():
if isinstance(attr, h5py.h5r.Reference):
references.append((name, key, inputDB.h5db[attr].name))
h5ts.visititems(findReferences)

for key, attr, path in references:
destTs = self.h5db[h5ts.name]
destTs[key].attrs[attr] = "@{}".format(path)

def __enter__(self):
"""Context management support"""
if self._openCount == 0:
@@ -1976,11 +2016,10 @@ def _writeAttrs(obj, group, attrs):
In such cases, this will store the attribute data in a proper dataset and
place a reference to that dataset in the attribute instead.
In practice, this takes ``linkedDims`` attrs from a particular
component type (like ``c00/n00/Circle/id``) and stores them
in new datasets (like ``c00n00/attrs/1_linkedDims``,
``c00n00/attrs/2_linkedDims``) and
then sets the object's attrs to links to those datasets.
In practice, this takes ``linkedDims`` attrs from a particular component type (like
``c00n00/Circle/id``) and stores them in new datasets (like
``c00n00/attrs/1_linkedDims``, ``c00n00/attrs/2_linkedDims``) and then sets the
object's attrs to links to those datasets.
"""
for key, value in attrs.items():
try:
@@ -2001,7 +2040,10 @@ def _writeAttrs(obj, group, attrs):
dataName = str(len(attrGroup)) + "_" + key
attrGroup[dataName] = value

obj.attrs[key] = attrGroup[dataName].ref
# using a soft link here allows us to cheaply copy time nodes without
# needing to crawl through and update object references.
linkName = attrGroup[dataName].name
obj.attrs[key] = "@{}".format(linkName)


def _resolveAttrs(attrs, group):
@@ -2015,9 +2057,17 @@ def _resolveAttrs(attrs, group):
for key, val in attrs.items():
try:
if isinstance(val, h5py.h5r.Reference):
# dereference the .ref to get the data
# out of the dataset.
# Old style object reference. If this cannot be dereferenced, it is
# likely because mergeHistory was used to get the current database,
# which does not preserve references.
resolved[key] = group[val]
elif isinstance(val, str):
m = ATTR_LINK.match(val)
if m:
# dereference the path to get the data out of the dataset.
resolved[key] = group[m.group(1)][()]
else:
resolved[key] = val
else:
resolved[key] = val
except ValueError:
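For context, the failure mode the commit message describes can be reproduced with plain h5py (an illustrative sketch with made-up file names; not code from this commit):

    import h5py
    import numpy

    with h5py.File("src.h5", "w") as src, h5py.File("dst.h5", "w") as dst:
        grp = src.create_group("c01n00")
        big = grp.create_dataset("attrs/0_fakeBigData", data=numpy.eye(64))
        grp.attrs["fakeBigData"] = big.ref  # old-style object reference

        # Copy the time node into another file, as mergeHistory() does.
        src.copy(grp, dst)

        # The copied attribute still holds a reference minted in src.h5, so
        # dereferencing it against dst.h5 fails rather than finding the data,
        # which is why database version 3.2 stores "@/path" strings instead.
        ref = dst["c01n00"].attrs["fakeBigData"]
        # dst[ref]  # <- broken after the cross-file copy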
56 changes: 43 additions & 13 deletions armi/bookkeeping/db/tests/test_database3.py
@@ -26,20 +26,29 @@


class TestDatabase3(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.o, cls.r = test_reactors.loadTestReactor(TEST_ROOT)

def setUp(self):
self.o, self.r = test_reactors.loadTestReactor(TEST_ROOT)
self.db = database.Database3(self._testMethodName + ".h5", "w")
self.db.open()
print(self.db._fullPath)
self.stateRetainer = self.r.retainState().__enter__()

def tearDown(self):
self.db.close()
self.stateRetainer.__exit__()

def makeHistory(self):
"""
Walk the reactor through a few time steps and write them to the db.
"""
for cycle, node in ((cycle, node) for cycle in range(3) for node in range(3)):
self.r.p.cycle = cycle
self.r.p.timeNode = node
# something that splitDatabase won't change, so that we can make sure that
# the right data went to the right new groups/cycles
self.r.p.cycleLength = cycle

self.db.writeToDB(self.r)

def _compareArrays(self, ref, src):
"""
Compare two numpy arrays.
@@ -94,15 +103,36 @@ def test_replaceNones(self):
self._compareRoundTrip(dataJagNones)
self._compareRoundTrip(dataDict)

def test_splitDatabase(self):
for cycle, node in ((cycle, node) for cycle in range(3) for node in range(3)):
self.r.p.cycle = cycle
self.r.p.timeNode = node
# something that splitDatabase won't change, so that we can make sure that
# the right data went to the right new groups/cycles
self.r.p.cycleLength = cycle
def test_mergeHistory(self):
self.makeHistory()

# Put some big data in an HDF5 attribute. This will exercise the code that pulls
# such attributes into a formal dataset and an attribute link.
self.r.p.cycle = 1
self.r.p.timeNode = 0
tnGroup = self.db.getH5Group(self.r)
database._writeAttrs(
tnGroup["layout/serialNum"], tnGroup, {"fakeBigData": numpy.eye(6400),
"someString": "this isn't a reference to another dataset"}
)

self.db.writeToDB(self.r)
db2 = database.Database3("restartDB.h5", "w")
with db2:
db2.mergeHistory(self.db, 2, 2)
self.r.p.cycle = 1
self.r.p.timeNode = 0
tnGroup = db2.getH5Group(self.r)

# this test is a little bit implementation-specific, but nice to be explicit
self.assertEqual(tnGroup["layout/serialNum"].attrs["fakeBigData"],
"@/c01n00/attrs/0_fakeBigData")

# actually exercise the _resolveAttrs function
attrs = database._resolveAttrs(tnGroup["layout/serialNum"].attrs, tnGroup)
self.assertTrue(numpy.array_equal(attrs["fakeBigData"], numpy.eye(6400)))

def test_splitDatabase(self):
self.makeHistory()

self.db.splitDatabase(
[(c, n) for c in (1, 2) for n in range(3)], "-all-iterations"
