diff --git a/armi/bookkeeping/db/database3.py b/armi/bookkeeping/db/database3.py index 7d69c334d2..31f263b619 100644 --- a/armi/bookkeeping/db/database3.py +++ b/armi/bookkeeping/db/database3.py @@ -77,7 +77,9 @@ from armi.utils.textProcessors import resolveMarkupInclusions ORDER = interfaces.STACK_ORDER.BOOKKEEPING -DB_VERSION = "3.1" +DB_VERSION = "3.2" + +ATTR_LINK = re.compile("^@(.*)$") def getH5GroupName(cycle, timeNode, statePointName=None): @@ -384,8 +386,6 @@ class Database3(database.Database): `doc/user/outputs/database` for more details. """ - version = DB_VERSION - timeNodeGroupPattern = re.compile(r"^c(\d\d)n(\d\d)$") def __init__(self, fileName: str, permission: str): @@ -413,6 +413,31 @@ def __init__(self, fileName: str, permission: str): # closed. self._openCount: int = 0 + if permission == "w": + self.version = DB_VERSION + else: + # populated from the file's "databaseVersion" attribute by open() + self._version = None + self._versionMajor = None + self._versionMinor = None + + @property + def version(self): + return self._version + + @version.setter + def version(self, value):  # value is a "major.minor" string + self._version = value + self._versionMajor, self._versionMinor = (int(v) for v in value.split(".")) + + @property + def versionMajor(self):  # int major part of version; None until version is set + return self._versionMajor + + @property + def versionMinor(self):  # int minor part of version; None until version is set + return self._versionMinor + def __repr__(self): return "<{} {}>".format( self.__class__.__name__, repr(self.h5db).replace("<", "").replace(">", "") @@ -430,6 +455,7 @@ def open(self): if self._permission in {"r", "a"}: self._fullPath = os.path.abspath(filePath) self.h5db = h5py.File(filePath, self._permission) + self.version = self.h5db.attrs["databaseVersion"] return if self._permission == "w": @@ -447,7 +473,7 @@ def open(self): self.h5db = h5py.File(filePath, self._permission) self.h5db.attrs["successfulCompletion"] = False self.h5db.attrs["version"] = armi.__version__ - self.h5db.attrs["databaseVersion"] = DB_VERSION + self.h5db.attrs["databaseVersion"] = self.version self.h5db.attrs["user"] = armi.USER
self.h5db.attrs["python"] = sys.version self.h5db.attrs["armiLocation"] = os.path.dirname(armi.ROOT) @@ -637,6 +663,20 @@ def mergeHistory(self, inputDB, startCycle, startNode): return self.h5db.copy(h5ts, h5ts.name) + if (inputDB.versionMajor, inputDB.versionMinor) < (3, 2): + # Databases older than 3.2 may store HDF5 object references in attributes. + # Rewrite those as "@<path>" strings, since references do not survive the copy. + references = [] + def findReferences(name, obj): + for key, attr in obj.attrs.items(): + if isinstance(attr, h5py.h5r.Reference): + references.append((name, key, inputDB.h5db[attr].name)) + h5ts.visititems(findReferences) + + destTs = self.h5db[h5ts.name] + for name, key, path in references: + destTs[name].attrs[key] = "@{}".format(path) + def __enter__(self): """Context management support""" if self._openCount == 0: @@ -1976,11 +2016,10 @@ def _writeAttrs(obj, group, attrs): In such cases, this will store the attribute data in a proper dataset and place a reference to that dataset in the attribute instead. - In practice, this takes ``linkedDims`` attrs from a particular - component type (like ``c00/n00/Circle/id``) and stores them - in new datasets (like ``c00n00/attrs/1_linkedDims``, - ``c00n00/attrs/2_linkedDims``) and - then sets the object's attrs to links to those datasets. + In practice, this takes ``linkedDims`` attrs from a particular component type (like + ``c00n00/Circle/id``) and stores them in new datasets (like + ``c00n00/attrs/1_linkedDims``, ``c00n00/attrs/2_linkedDims``) and then sets the + object's attrs to links to those datasets. """ for key, value in attrs.items(): try: @@ -2001,7 +2040,10 @@ def _writeAttrs(obj, group, attrs): dataName = str(len(attrGroup)) + "_" + key attrGroup[dataName] = value - obj.attrs[key] = attrGroup[dataName].ref + # storing an "@"-prefixed path instead of an object reference lets us + # cheaply copy time nodes without crawling through to fix up references.
+ linkName = attrGroup[dataName].name + obj.attrs[key] = "@{}".format(linkName) def _resolveAttrs(attrs, group): @@ -2015,9 +2057,17 @@ def _resolveAttrs(attrs, group): for key, val in attrs.items(): try: if isinstance(val, h5py.h5r.Reference): - # dereference the .ref to get the data - # out of the dataset. + # Old style object reference. If this cannot be dereferenced, it is + # likely because mergeHistory was used to get the current database, + # which does not preserve references. resolved[key] = group[val] + elif isinstance(val, str): + m = ATTR_LINK.match(val) + if m: + # "@<path>" link: look the dataset up by path and read its contents. + resolved[key] = group[m.group(1)][()] + else: + resolved[key] = val else: resolved[key] = val except ValueError: diff --git a/armi/bookkeeping/db/tests/test_database3.py b/armi/bookkeeping/db/tests/test_database3.py index 94be71138e..8f1b24ba48 100644 --- a/armi/bookkeeping/db/tests/test_database3.py +++ b/armi/bookkeeping/db/tests/test_database3.py @@ -26,20 +26,29 @@ class TestDatabase3(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.o, cls.r = test_reactors.loadTestReactor(TEST_ROOT) - def setUp(self): + self.o, self.r = test_reactors.loadTestReactor(TEST_ROOT) self.db = database.Database3(self._testMethodName + ".h5", "w") self.db.open() - print(self.db._fullPath) self.stateRetainer = self.r.retainState().__enter__() def tearDown(self): self.db.close() self.stateRetainer.__exit__() + def makeHistory(self): + """ + Walk the reactor through cycles 0-2, nodes 0-2 and write each state to the db. + """ + for cycle, node in ((cycle, node) for cycle in range(3) for node in range(3)): + self.r.p.cycle = cycle + self.r.p.timeNode = node + # something that splitDatabase won't change, so that we can make sure that + # the right data went to the right new groups/cycles + self.r.p.cycleLength = cycle + + self.db.writeToDB(self.r) + def _compareArrays(self, ref, src): """ Compare two numpy arrays.
@@ -94,15 +103,36 @@ def test_replaceNones(self): self._compareRoundTrip(dataJagNones) self._compareRoundTrip(dataDict) - def test_splitDatabase(self): - for cycle, node in ((cycle, node) for cycle in range(3) for node in range(3)): - self.r.p.cycle = cycle - self.r.p.timeNode = node - # something that splitDatabase won't change, so that we can make sure that - # the right data went to the right new groups/cycles - self.r.p.cycleLength = cycle + def test_mergeHistory(self): + self.makeHistory() + + # Put some big data in an HDF5 attribute. This exercises the code that moves + # such attributes into a formal dataset and leaves an "@<path>" link behind. + self.r.p.cycle = 1 + self.r.p.timeNode = 0 + tnGroup = self.db.getH5Group(self.r) + database._writeAttrs( + tnGroup["layout/serialNum"], tnGroup, {"fakeBigData": numpy.eye(6400), + "someString": "this isn't a reference to another dataset"} + ) - self.db.writeToDB(self.r) + db2 = database.Database3("restartDB.h5", "w") + with db2: + db2.mergeHistory(self.db, 2, 2) + self.r.p.cycle = 1 + self.r.p.timeNode = 0 + tnGroup = db2.getH5Group(self.r) + + # this check pins the exact link format; implementation-specific, but explicit + self.assertEqual(tnGroup["layout/serialNum"].attrs["fakeBigData"], + "@/c01n00/attrs/0_fakeBigData") + + # resolve the link through _resolveAttrs and check the round-tripped data + attrs = database._resolveAttrs(tnGroup["layout/serialNum"].attrs, tnGroup) + self.assertTrue(numpy.array_equal(attrs["fakeBigData"], numpy.eye(6400))) + + def test_splitDatabase(self): + self.makeHistory() self.db.splitDatabase( [(c, n) for c in (1, 2) for n in range(3)], "-all-iterations"