Skip to content

Commit

Permalink
Add time_units enumeration
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Oct 21, 2021
1 parent 5bd854d commit 06a4e5d
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 10 deletions.
4 changes: 4 additions & 0 deletions python/_tskitmodule.c
Expand Up @@ -11915,5 +11915,9 @@ PyInit__tskit(void)
PyModule_AddIntConstant(module, "FORWARD", TSK_DIR_FORWARD);
PyModule_AddIntConstant(module, "REVERSE", TSK_DIR_REVERSE);

PyModule_AddStringConstant(module, "DEFAULT_TIME_UNITS", TSK_DEFAULT_TIME_UNITS);
PyModule_AddStringConstant(
module, "TIME_UNITS_UNCALIBRATED", TSK_TIME_UNITS_UNCALIBRATED);

return module;
}
4 changes: 2 additions & 2 deletions python/lwt_interface/dict_encoding_testlib.py
Expand Up @@ -211,7 +211,7 @@ def test_missing_time_units(self, tables):
lwt = lwt_module.LightweightTableCollection()
lwt.fromdict(d)
tables = tskit.TableCollection.fromdict(lwt.asdict())
assert tables.time_units == "unknown"
assert tables.time_units == tskit.DEFAULT_TIME_UNITS

def test_missing_metadata(self, tables):
assert tables.metadata != b""
Expand Down Expand Up @@ -603,7 +603,7 @@ def test_top_level_time_units(self, tables):
lwt.fromdict(d)
out = lwt.asdict()
tables = tskit.TableCollection.fromdict(out)
assert tables.time_units == "unknown"
assert tables.time_units == tskit.DEFAULT_TIME_UNITS
# Missing is tested in TestMissingData above
d = tables.asdict()
# None should give default value
Expand Down
4 changes: 3 additions & 1 deletion python/tests/test_highlevel.py
Expand Up @@ -2397,14 +2397,16 @@ def test_tree_sequence_metadata(self):
def test_tree_sequence_time_units(self):
tc = tskit.TableCollection(1)
ts = tc.tree_sequence()
assert ts.time_units == "unknown"
assert ts.time_units == tskit.DEFAULT_TIME_UNITS
tc.time_units = "something else"
ts = tc.tree_sequence()
assert ts.time_units == "something else"
with pytest.raises(AttributeError):
del ts.time_units
with pytest.raises(AttributeError):
ts.time_units = "readonly"
assert tskit.DEFAULT_TIME_UNITS == "unknown"
assert tskit.TIME_UNITS_UNCALIBRATED == "uncalibrated"

def test_table_metadata_schemas(self):
ts = msprime.simulate(5)
Expand Down
9 changes: 7 additions & 2 deletions python/tests/test_lowlevel.py
Expand Up @@ -234,7 +234,7 @@ def test_set_time_units_errors(self):

def test_set_time_units(self):
tables = _tskit.TableCollection(1)
assert tables.time_units == "unknown"
assert tables.time_units == tskit.DEFAULT_TIME_UNITS
for value in ["foo", "", "💩", "null char \0 in string"]:
tables.time_units = value
assert tables.time_units == value
Expand Down Expand Up @@ -1309,7 +1309,7 @@ def test_time_units(self):
tables.build_index()
ts = _tskit.TreeSequence()
ts.load_tables(tables)
assert ts.get_time_units() == "unknown"
assert ts.get_time_units() == tskit.DEFAULT_TIME_UNITS
for value in ["foo", "", "💩", "null char \0 in string"]:
tables.time_units = value
ts = _tskit.TreeSequence()
Expand Down Expand Up @@ -3254,3 +3254,8 @@ def test_uninitialised():
method = getattr(uninitialised, method_name)
with pytest.raises((SystemError, ValueError)):
method()


def test_constants():
assert _tskit.DEFAULT_TIME_UNITS == "unknown"
assert _tskit.TIME_UNITS_UNCALIBRATED == "uncalibrated"
2 changes: 1 addition & 1 deletion python/tests/test_tables.py
Expand Up @@ -4078,7 +4078,7 @@ def test_set_metadata(self):

def test_set_time_units(self):
tc = tskit.TableCollection(1)
assert tc.time_units == "unknown"
assert tc.time_units == tskit.DEFAULT_TIME_UNITS

ex1 = "years"
ex2 = "generations"
Expand Down
4 changes: 2 additions & 2 deletions python/tests/test_tree_stats.py
Expand Up @@ -6084,7 +6084,7 @@ class TestTimeUncalibratedErrors:
def test_uncalibrated_time_allele_frequency_spectrum(self, ts_fixture):
ts_fixture.allele_frequency_spectrum(mode="branch")
tables = ts_fixture.dump_tables()
tables.time_units = "uncalibrated"
tables.time_units = tskit.TIME_UNITS_UNCALIBRATED
ts_uncalibrated = tables.tree_sequence()
ts_uncalibrated.allele_frequency_spectrum(mode="site")
with pytest.raises(
Expand All @@ -6100,7 +6100,7 @@ def test_uncalibrated_time_general_stat(self, ts_fixture):
W, lambda x: x * (x < ts_fixture.num_samples), W.shape[1], mode="branch"
)
tables = ts_fixture.dump_tables()
tables.time_units = "uncalibrated"
tables.time_units = tskit.TIME_UNITS_UNCALIBRATED
ts_uncalibrated = tables.tree_sequence()
ts_uncalibrated.general_stat(
W, lambda x: x * (x < ts_uncalibrated.num_samples), W.shape[1], mode="site"
Expand Down
6 changes: 6 additions & 0 deletions python/tskit/__init__.py
Expand Up @@ -50,6 +50,12 @@
#: NAN value, you cannot use `==` to test for it. Use :func:`is_unknown_time` instead.
UNKNOWN_TIME = _tskit.UNKNOWN_TIME

#: Default value of ts.time_units
DEFAULT_TIME_UNITS = _tskit.DEFAULT_TIME_UNITS

#: ts.time_units value when dimension is uncalibrated
TIME_UNITS_UNCALIBRATED = _tskit.TIME_UNITS_UNCALIBRATED

#: Options for printing to strings and HTML, modify with tskit.set_print_options.
_print_options = {"max_lines": 40}

Expand Down
3 changes: 2 additions & 1 deletion python/tskit/drawing.py
Expand Up @@ -38,6 +38,7 @@
import numpy as np
import svgwrite

import tskit
import tskit.util as util
from _tskit import NODE_IS_SAMPLE
from _tskit import NULL
Expand Down Expand Up @@ -551,7 +552,7 @@ def __init__(
y_label = "Node time"
else:
y_label = "Time"
if ts.time_units != "unknown":
if ts.time_units != tskit.DEFAULT_TIME_UNITS:
y_label += f" ({ts.time_units})"
self.x_label = x_label
self.y_label = y_label
Expand Down
3 changes: 2 additions & 1 deletion python/tskit/trees.py
Expand Up @@ -41,6 +41,7 @@
import numpy as np

import _tskit
import tskit
import tskit.combinatorics as combinatorics
import tskit.drawing as drawing
import tskit.exceptions as exceptions
Expand Down Expand Up @@ -3153,7 +3154,7 @@ def parse_mutations(
if len(tokens) >= 3:
site = int(tokens[site_index])
node = int(tokens[node_index])
if time_index is None or tokens[time_index] == "unknown":
if time_index is None or tokens[time_index] == tskit.DEFAULT_TIME_UNITS:
time = UNKNOWN_TIME
else:
time = float(tokens[time_index])
Expand Down

0 comments on commit 06a4e5d

Please sign in to comment.