diff --git a/docs/data-model.rst b/docs/data-model.rst index 72a5129534..c445ae6f5e 100644 --- a/docs/data-model.rst +++ b/docs/data-model.rst @@ -634,12 +634,12 @@ requirements for a valid set of mutations are: - ``site`` must refer to a valid site ID; - ``node`` must refer to a valid node ID; -- ``time`` must either be UNKNOWN_TIME (a NAN value which indicates +- ``time`` must either be ``UNKNOWN_TIME`` (a NAN value which indicates the time is unknown) or be a finite value which is greater or equal to the mutation ``node``'s ``time``, less than the ``node`` above the mutation's ``time`` and equal to or less than the ``time`` of the ``parent`` mutation - if this mutation has one. If one mutation on a site has UNKNOWN_TIME then all mutations - at that site must, a mixture of known and unknown is not valid. + if this mutation has one. If one mutation on a site has ``UNKNOWN_TIME`` then + all mutations at that site must, as a mixture of known and unknown is not valid. - ``parent`` must either be the null ID (-1) or a valid mutation ID within the current table @@ -668,6 +668,9 @@ mutation does not result in any change of state. This error is raised at run-time when we reconstruct sample genotypes, for example in the :meth:`TreeSequence.variants` iterator. +.. note:: As ``tskit.UNKNOWN_TIME`` is implemented as a ``NaN`` value, tests for + equality will always fail. Use ``tskit.is_unknown_time`` to detect unknown + values. .. _sec_migration_requirements: diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 644be33d84..99981b7b02 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -28,6 +28,8 @@ - New ``tree.is_isolated(u)`` method (:user:`hyanwong`, :pr:`443`). +- ``tskit.is_unknown_time`` can now check arrays. (:user:`benjeffery`, :pr:`857`). + -------------------- [0.3.1] - 2020-09-04 -------------------- diff --git a/python/tests/test_util.py b/python/tests/test_util.py index 657ed293cc..151365d145 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -29,6 +29,7 @@ import unittest import numpy as np +from numpy.testing import assert_array_equal import tests.tsutil as tsutil import tskit.util as util @@ -64,7 +65,15 @@ def test_canonical_json(self): class TestUnknownTime(unittest.TestCase): - def test_unknown_time(self): + def test_unknown_time_bad_types(self): + with self.assertRaises(ValueError): + util.is_unknown_time("bad") + with self.assertRaises(ValueError): + util.is_unknown_time(np.array(["bad"])) + with self.assertRaises(ValueError): + util.is_unknown_time(["bad"]) + + def test_unknown_time_scalar(self): self.assertTrue(math.isnan(UNKNOWN_TIME)) self.assertTrue(util.is_unknown_time(UNKNOWN_TIME)) self.assertFalse(util.is_unknown_time(math.nan)) @@ -72,6 +81,28 @@ def test_unknown_time(self): self.assertFalse(util.is_unknown_time(0)) self.assertFalse(util.is_unknown_time(math.inf)) self.assertFalse(util.is_unknown_time(1)) + self.assertFalse(util.is_unknown_time(None)) + self.assertFalse(util.is_unknown_time([None])) + + def test_unknown_time_array(self): + test_arrays = ( + [], + [True], + [False], + [True, False] * 5, + [[True], [False]], + [[[True, False], [True, False]], [[False, True], [True, False]]], + ) + for spec in test_arrays: + spec = np.asarray(spec, dtype=bool) + array = np.zeros(shape=spec.shape) + array[spec] = UNKNOWN_TIME + assert_array_equal(spec, util.is_unknown_time(array)) + + weird_array = [0, UNKNOWN_TIME, np.nan, 1, math.inf] + assert_array_equal( + [False, True, False, False, False], util.is_unknown_time(weird_array) + ) class TestNumpyArrayCasting(unittest.TestCase): diff --git a/python/tskit/util.py b/python/tskit/util.py index 1b145d4b72..3b0e5cbcb3 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -23,7 +23,6 @@ Module responsible for various utility functions used in other modules. """ import json -import struct import numpy as np @@ -45,9 +44,16 @@ def canonical_json(obj): def is_unknown_time(time): """ As the default unknown mutation time is NAN equality always fails. This - method compares the bitfield. + method compares the bitfield such that unknown times can be detected. + Either single floats can be passed or lists/arrays. + + :param float or array-like time: Value or array to check. + :return: A single boolean or array of booleans the same shape as ``time``. + :rtype: bool or np.array(dtype=bool) """ - return struct.pack(">d", UNKNOWN_TIME) == struct.pack(">d", time) + return np.asarray(time, dtype=np.float64).view(np.uint64) == np.float64( + UNKNOWN_TIME + ).view(np.uint64) def safe_np_int_cast(int_array, dtype, copy=False):