Skip to content

Commit

Permalink
Merge pull request #505 from grahamgower/numpy-str
Browse files Browse the repository at this point in the history
Fix use of numpy strings for deme names
  • Loading branch information
grahamgower committed Apr 3, 2023
2 parents d28db30 + 0eefa86 commit 8b76542
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
16 changes: 11 additions & 5 deletions demes/demes.py
Expand Up @@ -2369,21 +2369,27 @@ def filt(attrib, value):
or (not (hasattr(value, "__len__") and len(value) == 0))
) and attrib.name != "_deme_map"

def coerce_numbers(inst, attribute, value):
# Explicitly convert numeric types to int or float, so that they
def coerce_types(inst, attribute, value):
# Explicitly convert numeric and string types, so that they
# don't cause problems for the YAML and JSON serialisers.
# E.g. numpy int32/int64 are part of Python's numeric tower as
# Numpy int32/int64 are part of Python's numeric tower as
# subclasses of numbers.Integral, similarly numpy's float32/float64
# are subclasses of numbers.Real. There are yet other numeric types,
# such as the standard library's decimal.Decimal, which are not part
# of the numeric tower, but provide a __float__() method.
if isinstance(value, numbers.Integral):
# Likewise, string subclasses such as numpy.str_ aren't recognised
# by the YAML serialiser, so we explicitly convert them to str.
# We check for 'str' first, because numpy.str_ also has a
# __float__() method.
if isinstance(value, str):
value = str(value)
elif isinstance(value, numbers.Integral):
value = int(value)
elif isinstance(value, numbers.Real) or hasattr(value, "__float__"):
value = float(value)
return value

data = attr.asdict(self, filter=filt, value_serializer=coerce_numbers)
data = attr.asdict(self, filter=filt, value_serializer=coerce_types)
# translate to spec data model
for deme in data["demes"]:
for epoch in deme["epochs"]:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_load_dump.py
Expand Up @@ -575,6 +575,16 @@ def test_float_subclass(self):
)
self.check_dump_load_roundtrip(b.resolve())

def test_str_subclass(self):
# Check that numpy.str_ are round-trippable.
b = demes.Builder(defaults=dict(epoch=dict(start_size=1)))
names = np.array(["a", "b"])
b.add_deme(names[0])
b.add_deme(names[1], ancestors=[names[0]], start_time=50)
b.add_pulse(sources=[names[0]], dest=names[1], time=10, proportions=[0.1])
b.add_migration(source=names[0], dest=names[1], rate=1e-3)
self.check_dump_load_roundtrip(b.resolve())

def test_json_infinities_get_stringified(self):
b = demes.Builder()
b.add_deme("a", epochs=[dict(start_size=1)])
Expand Down

0 comments on commit 8b76542

Please sign in to comment.