From 06a4e5d51f5c8db59162b28a88b74faa9e19d6e6 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 21 Oct 2021 12:30:42 +0100 Subject: [PATCH] Add time_units enumeration --- python/_tskitmodule.c | 4 ++++ python/lwt_interface/dict_encoding_testlib.py | 4 ++-- python/tests/test_highlevel.py | 4 +++- python/tests/test_lowlevel.py | 9 +++++++-- python/tests/test_tables.py | 2 +- python/tests/test_tree_stats.py | 4 ++-- python/tskit/__init__.py | 6 ++++++ python/tskit/drawing.py | 3 ++- python/tskit/trees.py | 3 ++- 9 files changed, 29 insertions(+), 10 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index e56187dff3..d8f59a19dc 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -11915,5 +11915,9 @@ PyInit__tskit(void) PyModule_AddIntConstant(module, "FORWARD", TSK_DIR_FORWARD); PyModule_AddIntConstant(module, "REVERSE", TSK_DIR_REVERSE); + PyModule_AddStringConstant(module, "DEFAULT_TIME_UNITS", TSK_DEFAULT_TIME_UNITS); + PyModule_AddStringConstant( + module, "TIME_UNITS_UNCALIBRATED", TSK_TIME_UNITS_UNCALIBRATED); + return module; } diff --git a/python/lwt_interface/dict_encoding_testlib.py b/python/lwt_interface/dict_encoding_testlib.py index 3f6d2d72b4..c74fc8bb61 100644 --- a/python/lwt_interface/dict_encoding_testlib.py +++ b/python/lwt_interface/dict_encoding_testlib.py @@ -211,7 +211,7 @@ def test_missing_time_units(self, tables): lwt = lwt_module.LightweightTableCollection() lwt.fromdict(d) tables = tskit.TableCollection.fromdict(lwt.asdict()) - assert tables.time_units == "unknown" + assert tables.time_units == tskit.DEFAULT_TIME_UNITS def test_missing_metadata(self, tables): assert tables.metadata != b"" @@ -603,7 +603,7 @@ def test_top_level_time_units(self, tables): lwt.fromdict(d) out = lwt.asdict() tables = tskit.TableCollection.fromdict(out) - assert tables.time_units == "unknown" + assert tables.time_units == tskit.DEFAULT_TIME_UNITS # Missing is tested in TestMissingData above d = tables.asdict() # None should give default value diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 09adff6d57..5daa3e4cdd 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -2397,7 +2397,7 @@ def test_tree_sequence_metadata(self): def test_tree_sequence_time_units(self): tc = tskit.TableCollection(1) ts = tc.tree_sequence() - assert ts.time_units == "unknown" + assert ts.time_units == tskit.DEFAULT_TIME_UNITS tc.time_units = "something else" ts = tc.tree_sequence() assert ts.time_units == "something else" @@ -2405,6 +2405,8 @@ def test_tree_sequence_time_units(self): del ts.time_units with pytest.raises(AttributeError): ts.time_units = "readonly" + assert tskit.DEFAULT_TIME_UNITS == "unknown" + assert tskit.TIME_UNITS_UNCALIBRATED == "uncalibrated" def test_table_metadata_schemas(self): ts = msprime.simulate(5) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index be9374929f..b6a93d0460 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -234,7 +234,7 @@ def test_set_time_units_errors(self): def test_set_time_units(self): tables = _tskit.TableCollection(1) - assert tables.time_units == "unknown" + assert tables.time_units == tskit.DEFAULT_TIME_UNITS for value in ["foo", "", "💩", "null char \0 in string"]: tables.time_units = value assert tables.time_units == value @@ -1309,7 +1309,7 @@ def test_time_units(self): tables.build_index() ts = _tskit.TreeSequence() ts.load_tables(tables) - assert ts.get_time_units() == "unknown" + assert ts.get_time_units() == tskit.DEFAULT_TIME_UNITS for value in ["foo", "", "💩", "null char \0 in string"]: tables.time_units = value ts = _tskit.TreeSequence() @@ -3254,3 +3254,8 @@ def test_uninitialised(): method = getattr(uninitialised, method_name) with pytest.raises((SystemError, ValueError)): method() + + +def test_constants(): + assert _tskit.DEFAULT_TIME_UNITS == "unknown" + assert _tskit.TIME_UNITS_UNCALIBRATED == "uncalibrated" diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 91c861b8fd..960b624725 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -4078,7 +4078,7 @@ def test_set_metadata(self): def test_set_time_units(self): tc = tskit.TableCollection(1) - assert tc.time_units == "unknown" + assert tc.time_units == tskit.DEFAULT_TIME_UNITS ex1 = "years" ex2 = "generations" diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index 273c73d3af..74ec56741d 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -6084,7 +6084,7 @@ class TestTimeUncalibratedErrors: def test_uncalibrated_time_allele_frequency_spectrum(self, ts_fixture): ts_fixture.allele_frequency_spectrum(mode="branch") tables = ts_fixture.dump_tables() - tables.time_units = "uncalibrated" + tables.time_units = tskit.TIME_UNITS_UNCALIBRATED ts_uncalibrated = tables.tree_sequence() ts_uncalibrated.allele_frequency_spectrum(mode="site") with pytest.raises( @@ -6100,7 +6100,7 @@ def test_uncalibrated_time_general_stat(self, ts_fixture): W, lambda x: x * (x < ts_fixture.num_samples), W.shape[1], mode="branch" ) tables = ts_fixture.dump_tables() - tables.time_units = "uncalibrated" + tables.time_units = tskit.TIME_UNITS_UNCALIBRATED ts_uncalibrated = tables.tree_sequence() ts_uncalibrated.general_stat( W, lambda x: x * (x < ts_uncalibrated.num_samples), W.shape[1], mode="site" diff --git a/python/tskit/__init__.py b/python/tskit/__init__.py index e868e17ded..cfcb61263f 100644 --- a/python/tskit/__init__.py +++ b/python/tskit/__init__.py @@ -50,6 +50,12 @@ #: NAN value, you cannot use `==` to test for it. Use :func:`is_unknown_time` instead. UNKNOWN_TIME = _tskit.UNKNOWN_TIME +#: Default value of ts.time_units +DEFAULT_TIME_UNITS = _tskit.DEFAULT_TIME_UNITS + +#: ts.time_units value when dimension is uncalibrated +TIME_UNITS_UNCALIBRATED = _tskit.TIME_UNITS_UNCALIBRATED + #: Options for printing to strings and HTML, modify with tskit.set_print_options. _print_options = {"max_lines": 40} diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py index c66c476384..469513c978 100644 --- a/python/tskit/drawing.py +++ b/python/tskit/drawing.py @@ -38,6 +38,7 @@ import numpy as np import svgwrite +import tskit import tskit.util as util from _tskit import NODE_IS_SAMPLE from _tskit import NULL @@ -551,7 +552,7 @@ def __init__( y_label = "Node time" else: y_label = "Time" - if ts.time_units != "unknown": + if ts.time_units != tskit.DEFAULT_TIME_UNITS: y_label += f" ({ts.time_units})" self.x_label = x_label self.y_label = y_label diff --git a/python/tskit/trees.py b/python/tskit/trees.py index dda5524bdc..74028739ea 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -41,6 +41,7 @@ import numpy as np import _tskit +import tskit import tskit.combinatorics as combinatorics import tskit.drawing as drawing import tskit.exceptions as exceptions @@ -3153,7 +3154,7 @@ def parse_mutations( if len(tokens) >= 3: site = int(tokens[site_index]) node = int(tokens[node_index]) - if time_index is None or tokens[time_index] == "unknown": + if time_index is None or tokens[time_index] == tskit.DEFAULT_TIME_UNITS: time = UNKNOWN_TIME else: time = float(tokens[time_index])