Skip to content

Commit

Permalink
Merge d516c2d into a8e7328
Browse files Browse the repository at this point in the history
  • Loading branch information
ntouran committed Apr 9, 2020
2 parents a8e7328 + d516c2d commit 43dbbaa
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 47 deletions.
155 changes: 130 additions & 25 deletions armi/bookkeeping/db/database3.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
for reading previous versions, and for performing a ``mergeHistory()`` and converting
to the new reference strategy, but the old version cannot be written.
- 3.3: Compress the way locations are stored in the database and allow MultiIndex locations
to be read and written.
"""
import collections
import copy
Expand Down Expand Up @@ -89,13 +92,20 @@
from armi.utils.textProcessors import resolveMarkupInclusions

ORDER = interfaces.STACK_ORDER.BOOKKEEPING
DB_VERSION = "3.2"
DB_MAJOR = 3
DB_MINOR = 3
DB_VERSION = f"{DB_MAJOR}.{DB_MINOR}"

ATTR_LINK = re.compile("^@(.*)$")

_SERIALIZER_NAME = "serializerName"
_SERIALIZER_VERSION = "serializerVersion"

LOC_NONE = "N"
LOC_COORD = "C"
LOC_INDEX = "I"
LOC_MULTI = "M:"


def getH5GroupName(cycle, timeNode, statePointName=None):
return "c{:0>2}n{:0>2}{}".format(cycle, timeNode, statePointName or "")
Expand Down Expand Up @@ -809,7 +819,7 @@ def writeToDB(self, reactor, statePointName=None):
# _createLayout is recursive
h5group = self.getH5Group(reactor, statePointName)
runLog.info("Writing to database for statepoint: {}".format(h5group.name))
layout = Layout(comp=reactor)
layout = Layout(comp=reactor, version=(self.versionMajor, self.versionMinor))
layout.writeToDB(h5group)
groupedComps = layout.groupedComps

Expand Down Expand Up @@ -850,7 +860,7 @@ def load(self, cycle, node, cs=None, bp=None, statePointName=None):

h5group = self.h5db[getH5GroupName(cycle, node, statePointName)]

layout = Layout(h5group=h5group)
layout = Layout(h5group=h5group, version=(self.versionMajor, self.versionMinor))
comps, groupedComps = layout._initComps(cs, bp)

# populate data onto initialized components
Expand Down Expand Up @@ -1191,7 +1201,9 @@ def getHistories(

cycle = h5TimeNodeGroup.attrs["cycle"]
timeNode = h5TimeNodeGroup.attrs["timeNode"]
layout = Layout(h5group=h5TimeNodeGroup)
layout = Layout(
h5group=h5TimeNodeGroup, version=(self.versionMajor, self.versionMinor)
)

for compType, compsBySerialNum in compsByTypeThenSerialNum.items():
compTypeName = compType.__name__
Expand Down Expand Up @@ -1284,6 +1296,87 @@ def getHistories(
return histData


def _packLocation(
loc: grids.LocationBase, version: Tuple[int, int] = (DB_MAJOR, DB_MINOR)
) -> Tuple[str, List]:
"""
Extract information from a location needed to write it to this DB.
Each locator has one locationType and up to N location-defining datums,
where N is the number of entries in a possible multiindex, or just 1
for everything else.
Shrink grid locator names for storage efficiency.
Notes
-----
Contains some conditionals to still load databases made before
db version 3.3 which can be removed once no users care about
those DBs anymore.
"""
oldStyle = version[0] == 3 and version[1] < 3
if oldStyle:
locationType = loc.__class__.__name__
if loc is None:
if oldStyle:
locationType = "None"
else:
locationType = LOC_NONE
locData = [(0.0, 0.0, 0.0)]
elif loc.__class__ is grids.CoordinateLocation:
if not oldStyle:
locationType = LOC_COORD
locData = [loc.indices]
elif loc.__class__ is grids.IndexLocation:
if not oldStyle:
locationType = LOC_INDEX
locData = [loc.indices]
elif loc.__class__ is grids.MultiIndexLocation:
# encode number of sub-locations to allow in-line unpacking.
locationType = LOC_MULTI + f"{len(loc)}"
locData = [subloc.indices for subloc in loc]
else:
raise ValueError(f"Invalid location type: {loc}")

return locationType, locData


def _unpackLocation(
locationTypes, locData, version: Tuple[int, int] = (DB_MAJOR, DB_MINOR)
):
"""
Convert location data as read from DB back into data structure for building reactor model.
location and locationType will only have different lengths
when multiindex locations are used.
"""
oldStyle = version[0] == 3 and version[1] < 3
locsIter = iter(locData)
unpackedLocs = []
for lt in locationTypes:
if (oldStyle and lt == "None") or lt == LOC_NONE:
loc = next(locsIter)
unpackedLocs.append(None)
elif (oldStyle and lt == "IndexLocation") or lt == LOC_INDEX:
loc = next(locsIter)
# the data is stored as float, so cast back to int
unpackedLocs.append(tuple(int(i) for i in loc))
elif oldStyle or lt == LOC_COORD:
loc = next(locsIter)
unpackedLocs.append(tuple(loc))
elif lt.startswith(LOC_MULTI):
# extract number of sublocations from e.g. "M:345" string.
numSubLocs = int(lt.split(":")[1])
for _ in range(numSubLocs):
subLoc = next(locsIter)
# All multiindexes sublocs are index locs
unpackedLocs.append(tuple(int(i) for i in subLoc))
else:
raise ValueError(f"Read unknown location type {lt}. Invalid DB.")

return unpackedLocs


class Layout(object):
"""
The Layout class describes the hierarchical layout of the composite structure in a flat representation.
Expand All @@ -1300,7 +1393,7 @@ class Layout(object):
Layout.
"""

def __init__(self, h5group=None, comp=None):
def __init__(self, h5group=None, comp=None, version=None):
self.type = []
self.name = []
self.serialNum = []
Expand All @@ -1316,6 +1409,7 @@ def __init__(self, h5group=None, comp=None):
self._seenGridParams = dict()
# actual list of grid parameters, with stable order for safe indexing
self.gridParams = []
self.version = version

self.groupedComps = collections.defaultdict(list)

Expand All @@ -1341,8 +1435,19 @@ def __getitem__(self, sn):
)

def _createLayout(self, comp):
"""Recursive function to populate a hierarchical representation and group the
items by type."""
"""
Populate a hierarchical representation and group the reactor model items by type.
This is used when writing a reactor model to the database.
Notes
-----
This is recursive.
See Also
--------
_readLayout : does the opposite
"""
compList = self.groupedComps[type(comp)]
compList.append(comp)

Expand All @@ -1361,12 +1466,7 @@ def _createLayout(self, comp):
else:
self.gridIndex.append(None)

if comp.spatialLocator is None:
self.locationType.append("None")
self.location.append((0.0, 0.0, 0.0))
else:
self.locationType.append(comp.spatialLocator.__class__.__name__)
self.location.append(comp.spatialLocator.indices)
self._addSpatialLocatorData(comp.spatialLocator)

try:
self.temperatures.append((comp.inputTemperatureInC, comp.temperatureInC))
Expand All @@ -1376,17 +1476,31 @@ def _createLayout(self, comp):
self.material.append("")

try:
comps = sorted([c for c in comp])
comps = sorted(list(comp))
except ValueError:
runLog.error(
"Failed to sort some collection of ArmiObjects for database output: {} "
"value {}".format(type(comp), [c for c in comp])
"value {}".format(type(comp), list(comp))
)
raise

for c in comps:
self._createLayout(c)

def _addSpatialLocatorData(self, locator):
"""
Extend ``locationType`` and ``location`` attributes with location info.
Notes
-----
There are several types of locators and they must be encoded properly.
Most complicated are the MultiIndexLocations from grids, which
have multiple indices per component.
"""
locationType, locations = _packLocation(locator, self.version)
self.locationType.append(locationType)
self.location.extend(locations)

def _readLayout(self, h5group):
try:
# location is either an index, or a point
Expand All @@ -1395,16 +1509,7 @@ def _readLayout(self, h5group):
self.locationType = numpy.char.decode(
h5group["layout/locationType"][:]
).tolist()
self.location = locs = []
for l, lt in zip(locations, self.locationType):
if lt == "None":
locs.append(None)
elif lt == "IndexLocation":
# the data is stored as float, so cast back to int
locs.append(tuple(int(i) for i in l))
else:
locs.append(tuple(l))

self.location = _unpackLocation(self.locationType, locations, self.version)
self.type = numpy.char.decode(h5group["layout/type"][:])
self.name = numpy.char.decode(h5group["layout/name"][:])
self.serialNum = h5group["layout/serialNum"][:]
Expand Down
46 changes: 38 additions & 8 deletions armi/bookkeeping/db/tests/test_database3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
# limitations under the License.

import unittest
import os

import numpy
import numpy.testing
import h5py

from armi.bookkeeping.db import database3 as database
from armi.reactor import grids
from armi.reactor.tests import test_reactors
from armi.tests import TEST_ROOT, ARMI_RUN_PATH

from armi.bookkeeping.db import database3 as database
from armi.tests import TEST_ROOT


class TestDatabase3(unittest.TestCase):
Expand Down Expand Up @@ -112,8 +111,12 @@ def test_mergeHistory(self):
self.r.p.timeNode = 0
tnGroup = self.db.getH5Group(self.r)
database._writeAttrs(
tnGroup["layout/serialNum"], tnGroup, {"fakeBigData": numpy.eye(6400),
"someString": "this isn't a reference to another dataset"}
tnGroup["layout/serialNum"],
tnGroup,
{
"fakeBigData": numpy.eye(6400),
"someString": "this isn't a reference to another dataset",
},
)

db2 = database.Database3("restartDB.h5", "w")
Expand All @@ -124,8 +127,10 @@ def test_mergeHistory(self):
tnGroup = db2.getH5Group(self.r)

# this test is a little bit implementation-specific, but nice to be explicit
self.assertEqual(tnGroup["layout/serialNum"].attrs["fakeBigData"],
"@/c01n00/attrs/0_fakeBigData")
self.assertEqual(
tnGroup["layout/serialNum"].attrs["fakeBigData"],
"@/c01n00/attrs/0_fakeBigData",
)

# actually exercise the _resolveAttrs function
attrs = database._resolveAttrs(tnGroup["layout/serialNum"].attrs, tnGroup)
Expand All @@ -148,6 +153,31 @@ def test_splitDatabase(self):
self.assertTrue(newDb.attrs["databaseVersion"] == database.DB_VERSION)


class Test_LocationPacking(unittest.TestCase):
def test_locationPacking(self):
# pylint: disable=protected-access
loc1 = grids.IndexLocation(1, 2, 3, None)
loc2 = grids.CoordinateLocation(4.0, 5.0, 6.0, None)
loc3 = grids.MultiIndexLocation(None)
loc3.append(grids.IndexLocation(7, 8, 9, None))
loc3.append(grids.IndexLocation(10, 11, 12, None))

tp, data = database._packLocation(loc1)
self.assertEqual(tp, database.LOC_INDEX)
unpacked = database._unpackLocation([tp], data)
self.assertEqual(unpacked[0], (1, 2, 3))

tp, data = database._packLocation(loc2)
self.assertEqual(tp, database.LOC_COORD)
unpacked = database._unpackLocation([tp], data)
self.assertEqual(unpacked[0], (4.0, 5.0, 6.0))

tp, data = database._packLocation(loc3)
self.assertEqual(tp, database.LOC_MULTI + "2")
unpacked = database._unpackLocation([tp], data)
self.assertEqual(unpacked[0], (7, 8, 9))


if __name__ == "__main__":
import sys

Expand Down
8 changes: 5 additions & 3 deletions armi/bookkeeping/mainInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import os
import subprocess
import re

from six.moves import zip_longest
import shutil
import itertools

from armi import interfaces
from armi import runLog
Expand Down Expand Up @@ -122,7 +122,7 @@ def _moveFiles(self):
# if any files to copy, then use the first as the default, i.e. len() == 1,
# otherwise assume '.'
default = copyFilesTo[0] if any(copyFilesTo) else "."
for filename, dest in zip_longest(
for filename, dest in itertools.zip_longest(
copyFilesFrom, copyFilesTo, fillvalue=default
):
pathTools.copyOrWarn("copyFilesFrom", filename, dest)
Expand Down Expand Up @@ -165,6 +165,8 @@ def interactEOL(self):

def updateClusterProgress(self):
"""Updates the status window on the Windows HPC client."""
if not shutil.which("job"):
return
totalSteps = max(
(self.cs["burnSteps"] + 1) * self.cs["nCycles"] - 1, 1
) # 0 through 5 if 2 cycles
Expand Down
21 changes: 11 additions & 10 deletions armi/physics/neutronics/fissionProductModel/fissionProductModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@
Examples
--------
from armi.physics.neutronics.fissionProductModel import fissionProductModel
fpInterface = fissionProductModel.FissionProductModel()
lfp = fpInterface.getGlobalLumpedFissionProducts()
lfp['LFP35']
lfp35 = lfp['LFP35']
lfp35.printDensities(0.05)
lfp35.values()
allFPs = [(fpY, fpNuc) for (fpNuc,fpY) in lfp35.items()]
allFPs.sort()
lfp35.keys()
from armi.physics.neutronics.fissionProductModel import fissionProductModel
fpInterface = fissionProductModel.FissionProductModel()
lfp = fpInterface.getGlobalLumpedFissionProducts()
lfp['LFP35']
lfp35 = lfp['LFP35']
lfp35.printDensities(0.05)
lfp35.values()
allFPs = [(fpY, fpNuc) for (fpNuc,fpY) in lfp35.items()]
allFPs.sort()
lfp35.keys()
"""

from armi import runLog
Expand Down
Loading

0 comments on commit 43dbbaa

Please sign in to comment.