Skip to content

Commit

Permalink
Merge pull request #4 from terrapower/packSpecialFixes
Browse files Browse the repository at this point in the history
Fix bugs in database3 pack/unpackSpecialData
  • Loading branch information
ntouran committed Nov 5, 2019
2 parents 67688e4 + 4db697b commit 148095d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 24 deletions.
22 changes: 18 additions & 4 deletions armi/bookkeeping/db/database3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Reactor Model to the flat representation in the database.
"""
import collections
import copy
import io
import itertools
import os
Expand Down Expand Up @@ -1611,8 +1612,16 @@ def packSpecialData(
"""

# Check to make sure that we even need to do this. If the numpy data type is
# not "O", chances are we have nice, clean data.
if data.dtype != "O":
return data, {}

attrs: Dict[str, Any] = {"specialFormatting": True}

# make a copy of the data, so that the original is unchanged
data = copy.copy(data)

# find locations of Nones. The below works for ndarrays, whereas `data == None`
# gives a single True/False value
nones = numpy.where([d is None for d in data])[0]
Expand All @@ -1621,10 +1630,10 @@ def packSpecialData(
# Everything is None, so why bother?
return None, attrs

if nones.any():
if len(nones) > 0:
attrs["nones"] = True

# XXX: this whole if/iften/elif/else can be optimized by looping once and then
# XXX: this whole if/then/elif/else can be optimized by looping once and then
# determining the correct action
# A robust solution would need
# to do this on a case-by-case basis, and re-do it any time we want to
Expand Down Expand Up @@ -1745,8 +1754,13 @@ def unpackSpecialData(data: numpy.ndarray, attrs, paramName: str) -> numpy.ndarr
--------
packSpecialData
"""
if not attrs.get("specialFormatting", False):
# The data were not subjected to any special formatting; short circuit.
assert data.dtype != "O"
return data

unpackedData: List[Any]
if attrs.get("nones", False):
if attrs.get("nones", False) and not attrs.get("jagged", False):
data = replaceNonsenseWithNones(data, paramName)
return data
if attrs.get("jagged", False):
Expand Down Expand Up @@ -1812,7 +1826,7 @@ def replaceNonsenseWithNones(data: numpy.ndarray, paramName: str) -> numpy.ndarr
isNone = data == "<!None!>"
else:
raise TypeError(
"Unable to resolve values that should be None for {}".format(paramName)
"Unable to resolve values that should be None for `{}`".format(paramName)
)

if data.ndim > 1:
Expand Down
64 changes: 44 additions & 20 deletions armi/bookkeeping/db/tests/test_database3.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import unittest
import numpy
import numpy.testing
import h5py
import os

Expand All @@ -38,6 +39,36 @@ def tearDown(self):
self.db.close()
self.stateRetainer.__exit__()

def _compareArrays(self, ref, src):
"""
Compare two numpy arrays.
Comparing numpy arrays that may have unsavory data (NaNs, Nones, jagged
data, etc.) is really difficult. For now, convert to a list and compare
element-by-element.
"""
self.assertEqual(type(ref), type(src))
if isinstance(ref, numpy.ndarray):
ref = ref.tolist()
src = src.tolist()

for v1, v2 in zip(ref, src):
# Entries may be None
if isinstance(v1, numpy.ndarray):
v1 = v1.tolist()
if isinstance(v2, numpy.ndarray):
v2 = v2.tolist()
self.assertEqual(v1, v2)

def _compareRoundTrip(self, data):
"""
Make sure that data is unchanged by packing/unpacking.
"""
packed, attrs = database.packSpecialData(data, "testing")
roundTrip = database.unpackSpecialData(packed, attrs, "testing")
self._compareArrays(data, roundTrip)


def test_replaceNones(self):
"""
This definitely needs some work.
Expand All @@ -47,28 +78,21 @@ def test_replaceNones(self):
data1iNones = numpy.array([1, 2, None, 5, 6])
data1fNones = numpy.array([None, 2.0, None, 5.0, 6.0])
data2fNones = numpy.array([None, [[1.0, 2.0, 6.0], [2.0, 3.0, 4.0]]])
data_jag = numpy.array([[[1, 2], [3, 4]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]])
data_dict = numpy.array(
dataJag = numpy.array([[[1, 2], [3, 4]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]])
dataJagNones = numpy.array(
[[[1, 2], [3, 4]], [[1],[1]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
)
dataDict = numpy.array(
[{"bar": 2, "baz": 3}, {"foo": 4, "baz": 6}, {"foo": 7, "bar": 8}]
)
# nones = numpy.where([d is None for d in data1])[0]
# conv_d1 = database.replaceNonesWithNonsense(data1, None, nones)
print("data3: ", database.packSpecialData(data3, ""))
print("data_jag", database.packSpecialData(data_jag, ""))
# print("data1", database.packSpecialData(data1, ""))
print("data1iNones", database.packSpecialData(data1iNones, ""))
print("data1fNones", database.packSpecialData(data1fNones, ""))
print("data2fNones", database.packSpecialData(data2fNones, ""))
print("dataDict", database.packSpecialData(data_dict, ""))

packedData, attrs = database.packSpecialData(data_jag, "")
roundTrip = database.unpackSpecialData(packedData, attrs, "")
print("round-tripped jagged:", roundTrip)
print("round-tripped dtype:", roundTrip.dtype)

packedData, attrs = database.packSpecialData(data_dict, "")
roundTrip = database.unpackSpecialData(packedData, attrs, "")
print("round-tripped dict:", roundTrip)
self._compareRoundTrip(data3)
self._compareRoundTrip(data1)
self._compareRoundTrip(data1iNones)
self._compareRoundTrip(data1fNones)
self._compareRoundTrip(data2fNones)
self._compareRoundTrip(dataJag)
self._compareRoundTrip(dataJagNones)
self._compareRoundTrip(dataDict)

def test_splitDatabase(self):
for cycle, node in ((cycle, node) for cycle in range(3) for node in range(3)):
Expand Down

0 comments on commit 148095d

Please sign in to comment.