
Commit 1bb8c77

Merge b835817 into ec47cfb

youngmit committed Jan 16, 2020
2 parents ec47cfb + b835817
Showing 2 changed files with 105 additions and 25 deletions.
74 changes: 62 additions & 12 deletions armi/bookkeeping/db/database3.py
@@ -77,7 +77,9 @@
from armi.utils.textProcessors import resolveMarkupInclusions

ORDER = interfaces.STACK_ORDER.BOOKKEEPING
DB_VERSION = "3.1"
DB_VERSION = "3.2"

ATTR_LINK = re.compile("^@(.*)$")


def getH5GroupName(cycle, timeNode, statePointName=None):
@@ -384,8 +386,6 @@ class Database3(database.Database):
`doc/user/outputs/database` for more details.
"""

version = DB_VERSION

timeNodeGroupPattern = re.compile(r"^c(\d\d)n(\d\d)$")

def __init__(self, fileName: str, permission: str):
Expand Down Expand Up @@ -413,6 +413,31 @@ def __init__(self, fileName: str, permission: str):
# closed.
self._openCount: int = 0

if permission == "w":
self.version = DB_VERSION
else:
# will be set upon read
self._version = None
self._versionMajor = None
self._versionMinor = None

@property
def version(self):
return self._version

@version.setter
def version(self, value):
self._version = value
self._versionMajor, self._versionMinor = (int(v) for v in value.split("."))

@property
def versionMajor(self):
return self._versionMajor

@property
def versionMinor(self):
return self._versionMinor

def __repr__(self):
return "<{} {}>".format(
self.__class__.__name__, repr(self.h5db).replace("<", "").replace(">", "")
@@ -430,6 +455,7 @@ def open(self):
if self._permission in {"r", "a"}:
self._fullPath = os.path.abspath(filePath)
self.h5db = h5py.File(filePath, self._permission)
self.version = self.h5db.attrs["databaseVersion"]
return

if self._permission == "w":
@@ -447,7 +473,7 @@ def open(self):
self.h5db = h5py.File(filePath, self._permission)
self.h5db.attrs["successfulCompletion"] = False
self.h5db.attrs["version"] = armi.__version__
self.h5db.attrs["databaseVersion"] = DB_VERSION
self.h5db.attrs["databaseVersion"] = self.version
self.h5db.attrs["user"] = armi.USER
self.h5db.attrs["python"] = sys.version
self.h5db.attrs["armiLocation"] = os.path.dirname(armi.ROOT)
@@ -637,6 +663,20 @@ def mergeHistory(self, inputDB, startCycle, startNode):
return
self.h5db.copy(h5ts, h5ts.name)

if inputDB.versionMinor < 2:
# The source database may have object references in some attributes.
# Make sure to link those up using our manual path strategy.
references = []
def findReferences(name, obj):
for key, attr in obj.attrs.items():
if isinstance(attr, h5py.h5r.Reference):
references.append((name, key, inputDB.h5db[attr].name))
h5ts.visititems(findReferences)

for key, attr, path in references:
destTs = self.h5db[h5ts.name]
destTs[key].attrs[attr] = "@{}".format(path)

def __enter__(self):
"""Context management support"""
if self._openCount == 0:
@@ -1976,11 +2016,10 @@ def _writeAttrs(obj, group, attrs):
In such cases, this will store the attribute data in a proper dataset and
place a reference to that dataset in the attribute instead.
In practice, this takes ``linkedDims`` attrs from a particular
component type (like ``c00/n00/Circle/id``) and stores them
in new datasets (like ``c00n00/attrs/1_linkedDims``,
``c00n00/attrs/2_linkedDims``) and
then sets the object's attrs to links to those datasets.
In practice, this takes ``linkedDims`` attrs from a particular component type (like
``c00n00/Circle/id``) and stores them in new datasets (like
``c00n00/attrs/1_linkedDims``, ``c00n00/attrs/2_linkedDims``) and then sets the
object's attrs to links to those datasets.
"""
for key, value in attrs.items():
try:
@@ -2001,7 +2040,10 @@ def _writeAttrs(obj, group, attrs):
dataName = str(len(attrGroup)) + "_" + key
attrGroup[dataName] = value

obj.attrs[key] = attrGroup[dataName].ref
# using a soft link here allows us to cheaply copy time nodes without
# needing to crawl through and update object references.
linkName = attrGroup[dataName].name
obj.attrs[key] = "@{}".format(linkName)


def _resolveAttrs(attrs, group):
@@ -2015,9 +2057,17 @@ def _resolveAttrs(attrs, group):
for key, val in attrs.items():
try:
if isinstance(val, h5py.h5r.Reference):
# dereference the .ref to get the data
# out of the dataset.
# Old style object reference. If this cannot be dereferenced, it is
# likely because mergeHistory was used to get the current database,
# which does not preserve references.
resolved[key] = group[val]
elif isinstance(val, str):
m = ATTR_LINK.match(val)
if m:
# dereference the path to get the data out of the dataset.
resolved[key] = group[m.group(1)][()]
else:
resolved[key] = val
else:
resolved[key] = val
except ValueError:
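
For illustration, here is a minimal standalone sketch (not part of this commit) of the "@"-path attribute links these changes introduce: instead of storing an h5py object reference, ``_writeAttrs`` now records the linked dataset's absolute path as a string prefixed with ``@``, and ``_resolveAttrs`` resolves it with a plain path lookup. The file, group, and dataset names below are made up.

    import re

    import h5py
    import numpy

    ATTR_LINK = re.compile("^@(.*)$")

    with h5py.File("attr-link-demo.h5", "w") as f:
        # a stand-in for a time-node group like c00n00
        group = f.create_group("c00n00")
        big = group.create_dataset("attrs/0_linkedDims", data=numpy.arange(5))
        ds = group.create_dataset("serialNum", data=numpy.arange(3))

        # new style: a string soft link rather than an h5py object reference,
        # so copying the group (e.g. in mergeHistory) does not invalidate it
        ds.attrs["linkedDims"] = "@{}".format(big.name)

        # resolution, as in _resolveAttrs: strip the "@" and look the path up
        m = ATTR_LINK.match(ds.attrs["linkedDims"])
        linked = group[m.group(1)][()]  # -> array([0, 1, 2, 3, 4])
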
56 changes: 43 additions & 13 deletions armi/bookkeeping/db/tests/test_database3.py
@@ -26,20 +26,29 @@


class TestDatabase3(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.o, cls.r = test_reactors.loadTestReactor(TEST_ROOT)

def setUp(self):
self.o, self.r = test_reactors.loadTestReactor(TEST_ROOT)
self.db = database.Database3(self._testMethodName + ".h5", "w")
self.db.open()
print(self.db._fullPath)
self.stateRetainer = self.r.retainState().__enter__()

def tearDown(self):
self.db.close()
self.stateRetainer.__exit__()

def makeHistory(self):
"""
Walk the reactor through a few time steps and write them to the db.
"""
for cycle, node in ((cycle, node) for cycle in range(3) for node in range(3)):
self.r.p.cycle = cycle
self.r.p.timeNode = node
# something that splitDatabase won't change, so that we can make sure that
# the right data went to the right new groups/cycles
self.r.p.cycleLength = cycle

self.db.writeToDB(self.r)

def _compareArrays(self, ref, src):
"""
Compare two numpy arrays.
@@ -94,15 +103,36 @@ def test_replaceNones(self):
self._compareRoundTrip(dataJagNones)
self._compareRoundTrip(dataDict)

def test_splitDatabase(self):
for cycle, node in ((cycle, node) for cycle in range(3) for node in range(3)):
self.r.p.cycle = cycle
self.r.p.timeNode = node
# something that splitDatabase won't change, so that we can make sure that
# the right data went to the right new groups/cycles
self.r.p.cycleLength = cycle
def test_mergeHistory(self):
self.makeHistory()

# put some big data in an HDF5 attribute. This will exercise the code that pulls
# such attributes into a formal dataset and a reference.
self.r.p.cycle = 1
self.r.p.timeNode = 0
tnGroup = self.db.getH5Group(self.r)
database._writeAttrs(
tnGroup["layout/serialNum"], tnGroup, {"fakeBigData": numpy.eye(6400),
"someString": "this isn't a reference to another dataset"}
)

self.db.writeToDB(self.r)
db2 = database.Database3("restartDB.h5", "w")
with db2:
db2.mergeHistory(self.db, 2, 2)
self.r.p.cycle = 1
self.r.p.timeNode = 0
tnGroup = db2.getH5Group(self.r)

# this test is a little bit implementation-specific, but nice to be explicit
self.assertEqual(tnGroup["layout/serialNum"].attrs["fakeBigData"],
"@/c01n00/attrs/0_fakeBigData")

# actually exercise the _resolveAttrs function
attrs = database._resolveAttrs(tnGroup["layout/serialNum"].attrs, tnGroup)
self.assertTrue(numpy.array_equal(attrs["fakeBigData"], numpy.eye(6400)))

def test_splitDatabase(self):
self.makeHistory()

self.db.splitDatabase(
[(c, n) for c in (1, 2) for n in range(3)], "-all-iterations"
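
And a short sketch (also not part of the commit) of how the new ``version`` property behaves once a database is opened for reading. The file name is hypothetical, and ``__enter__`` is assumed to call ``open()`` as the context-management code above suggests.

    from armi.bookkeeping.db.database3 import Database3

    db = Database3("existing-run.h5", "r")
    with db:
        # open() reads h5db.attrs["databaseVersion"] and feeds it through the
        # version setter, which also splits out the integer major/minor parts
        print(db.version)       # e.g. "3.2"
        print(db.versionMajor)  # 3
        print(db.versionMinor)  # 2
        if db.versionMinor < 2:
            print("attrs may still hold old-style object references")
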
