From b835817724486749c16a2c93dda2ad9a12bee288 Mon Sep 17 00:00:00 2001 From: Mitchell Young Date: Mon, 13 Jan 2020 10:37:48 -0800 Subject: [PATCH] Replace object references with soft links This replaces the use of object references to handle HDF5 attributes that are too large to store natrually. This is because object references break when copying an object that contains them to a new HDF5 file, as is done in mergeHistory(). Instead, we use a simple string to refer to the targed dataset like "@[path/to/dataset]". --- armi/bookkeeping/db/database3.py | 74 +++++++++++++++++---- armi/bookkeeping/db/tests/test_database3.py | 56 ++++++++++++---- 2 files changed, 105 insertions(+), 25 deletions(-) diff --git a/armi/bookkeeping/db/database3.py b/armi/bookkeeping/db/database3.py index 7d69c334d2..31f263b619 100644 --- a/armi/bookkeeping/db/database3.py +++ b/armi/bookkeeping/db/database3.py @@ -77,7 +77,9 @@ from armi.utils.textProcessors import resolveMarkupInclusions ORDER = interfaces.STACK_ORDER.BOOKKEEPING -DB_VERSION = "3.1" +DB_VERSION = "3.2" + +ATTR_LINK = re.compile("^@(.*)$") def getH5GroupName(cycle, timeNode, statePointName=None): @@ -384,8 +386,6 @@ class Database3(database.Database): `doc/user/outputs/database` for more details. """ - version = DB_VERSION - timeNodeGroupPattern = re.compile(r"^c(\d\d)n(\d\d)$") def __init__(self, fileName: str, permission: str): @@ -413,6 +413,31 @@ def __init__(self, fileName: str, permission: str): # closed. self._openCount: int = 0 + if permission == "w": + self.version = DB_VERSION + else: + # will be set upon read + self._version = None + self._versionMajor = None + self._versionMinor = None + + @property + def version(self): + return self._version + + @version.setter + def version(self, value): + self._version = value + self._versionMajor, self._versionMinor = (int(v) for v in value.split(".")) + + @property + def versionMajor(self): + return self._versionMajor + + @property + def versionMinor(self): + return self._versionMinor + def __repr__(self): return "<{} {}>".format( self.__class__.__name__, repr(self.h5db).replace("<", "").replace(">", "") @@ -430,6 +455,7 @@ def open(self): if self._permission in {"r", "a"}: self._fullPath = os.path.abspath(filePath) self.h5db = h5py.File(filePath, self._permission) + self.version = self.h5db.attrs["databaseVersion"] return if self._permission == "w": @@ -447,7 +473,7 @@ def open(self): self.h5db = h5py.File(filePath, self._permission) self.h5db.attrs["successfulCompletion"] = False self.h5db.attrs["version"] = armi.__version__ - self.h5db.attrs["databaseVersion"] = DB_VERSION + self.h5db.attrs["databaseVersion"] = self.version self.h5db.attrs["user"] = armi.USER self.h5db.attrs["python"] = sys.version self.h5db.attrs["armiLocation"] = os.path.dirname(armi.ROOT) @@ -637,6 +663,20 @@ def mergeHistory(self, inputDB, startCycle, startNode): return self.h5db.copy(h5ts, h5ts.name) + if inputDB.versionMinor < 2: + # The source database may have object references in some attributes. + # make sure to link those up using our manual path strategy. + references = [] + def findReferences(name, obj): + for key, attr in obj.attrs.items(): + if isinstance(attr, h5py.h5r.Reference): + references.append((name, key, inputDB.h5db[attr].name)) + h5ts.visititems(findReferences) + + for key, attr, path in references: + destTs = self.h5db[h5ts.name] + destTs[key].attrs[attr] = "@{}".format(path) + def __enter__(self): """Context management support""" if self._openCount == 0: @@ -1976,11 +2016,10 @@ def _writeAttrs(obj, group, attrs): In such cases, this will store the attribute data in a proper dataset and place a reference to that dataset in the attribute instead. - In practice, this takes ``linkedDims`` attrs from a particular - component type (like ``c00/n00/Circle/id``) and stores them - in new datasets (like ``c00n00/attrs/1_linkedDims``, - ``c00n00/attrs/2_linkedDims``) and - then sets the object's attrs to links to those datasets. + In practice, this takes ``linkedDims`` attrs from a particular component type (like + ``c00n00/Circle/id``) and stores them in new datasets (like + ``c00n00/attrs/1_linkedDims``, ``c00n00/attrs/2_linkedDims``) and then sets the + object's attrs to links to those datasets. """ for key, value in attrs.items(): try: @@ -2001,7 +2040,10 @@ def _writeAttrs(obj, group, attrs): dataName = str(len(attrGroup)) + "_" + key attrGroup[dataName] = value - obj.attrs[key] = attrGroup[dataName].ref + # using a soft link here allows us to cheaply copy time nodes without + # needing to crawl through and update object references. + linkName = attrGroup[dataName].name + obj.attrs[key] = "@{}".format(linkName) def _resolveAttrs(attrs, group): @@ -2015,9 +2057,17 @@ def _resolveAttrs(attrs, group): for key, val in attrs.items(): try: if isinstance(val, h5py.h5r.Reference): - # dereference the .ref to get the data - # out of the dataset. + # Old style object reference. If this cannot be dereferenced, it is + # likely because mergeHistory was used to get the current database, + # which does not preserve references. resolved[key] = group[val] + elif isinstance(val, str): + m = ATTR_LINK.match(val) + if m: + # dereference the path to get the data out of the dataset. + resolved[key] = group[m.group(1)][()] + else: + resolved[key] = val else: resolved[key] = val except ValueError: diff --git a/armi/bookkeeping/db/tests/test_database3.py b/armi/bookkeeping/db/tests/test_database3.py index 94be71138e..8f1b24ba48 100644 --- a/armi/bookkeeping/db/tests/test_database3.py +++ b/armi/bookkeeping/db/tests/test_database3.py @@ -26,20 +26,29 @@ class TestDatabase3(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.o, cls.r = test_reactors.loadTestReactor(TEST_ROOT) - def setUp(self): + self.o, self.r = test_reactors.loadTestReactor(TEST_ROOT) self.db = database.Database3(self._testMethodName + ".h5", "w") self.db.open() - print(self.db._fullPath) self.stateRetainer = self.r.retainState().__enter__() def tearDown(self): self.db.close() self.stateRetainer.__exit__() + def makeHistory(self): + """ + Walk the reactor through a few time steps and write them to the db. + """ + for cycle, node in ((cycle, node) for cycle in range(3) for node in range(3)): + self.r.p.cycle = cycle + self.r.p.timeNode = node + # something that splitDatabase won't change, so that we can make sure that + # the right data went to the right new groups/cycles + self.r.p.cycleLength = cycle + + self.db.writeToDB(self.r) + def _compareArrays(self, ref, src): """ Compare two numpy arrays. @@ -94,15 +103,36 @@ def test_replaceNones(self): self._compareRoundTrip(dataJagNones) self._compareRoundTrip(dataDict) - def test_splitDatabase(self): - for cycle, node in ((cycle, node) for cycle in range(3) for node in range(3)): - self.r.p.cycle = cycle - self.r.p.timeNode = node - # something that splitDatabase won't change, so that we can make sure that - # the right data went to the right new groups/cycles - self.r.p.cycleLength = cycle + def test_mergeHistory(self): + self.makeHistory() + + # put some big data in an HDF5 attribute. This will exercise the code that pulls + # such attributes into a formal dataset and a reference. + self.r.p.cycle = 1 + self.r.p.timeNode = 0 + tnGroup = self.db.getH5Group(self.r) + database._writeAttrs( + tnGroup["layout/serialNum"], tnGroup, {"fakeBigData": numpy.eye(6400), + "someString": "this isn't a reference to another dataset"} + ) - self.db.writeToDB(self.r) + db2 = database.Database3("restartDB.h5", "w") + with db2: + db2.mergeHistory(self.db, 2, 2) + self.r.p.cycle = 1 + self.r.p.timeNode = 0 + tnGroup = db2.getH5Group(self.r) + + # this test is a little bit implementation-specific, but nice to be explicit + self.assertEqual(tnGroup["layout/serialNum"].attrs["fakeBigData"], + "@/c01n00/attrs/0_fakeBigData") + + # actually exercise the _resolveAttrs function + attrs = database._resolveAttrs(tnGroup["layout/serialNum"].attrs, tnGroup) + self.assertTrue(numpy.array_equal(attrs["fakeBigData"], numpy.eye(6400))) + + def test_splitDatabase(self): + self.makeHistory() self.db.splitDatabase( [(c, n) for c in (1, 2) for n in range(3)], "-all-iterations"