diff --git a/docs/development.rst b/docs/development.rst index 4cf88220b1..767e741c37 100644 --- a/docs/development.rst +++ b/docs/development.rst @@ -197,7 +197,9 @@ To include the changes that the hooks made, ``git add`` any files that were modified and run ``git commit`` (or, use ``git commit -a`` to commit all changed files.) -If you would like to run the checks without committing, use ``pre-commit run``. +If you would like to run the checks without committing, use ``pre-commit run`` +(but, note that this will *only* check changes that have been *staged*; +do ``pre-commit run --all`` to check unstaged changes as well). To bypass the checks (to save or get feedback on work-in-progress) use ``git commit --no-verify`` diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index d93e422bc8..25afb55ef1 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -857,6 +857,57 @@ def test_packset_location(self): self.assertEqual(list(t[0].location), [0]) self.assertEqual(list(t[1].location), [1, 2, 3]) + def test_missing_time_equal_to_self(self): + t = tskit.TableCollection(sequence_length=10) + t.sites.add_row(position=1, ancestral_state="0") + t.mutations.add_row(site=0, node=0, derived_state="1", time=tskit.UNKNOWN_TIME) + self.assertEqual(t.mutations[0], t.mutations[0]) + + def test_various_not_equals(self): + args = { + "site": 0, + "node": 0, + "derived_state": "a", + "parent": 0, + "metadata": b"abc", + "time": 0, + } + a = tskit.MutationTableRow(**args) + self.assertNotEqual(a, []) + self.assertNotEqual(a, 12) + self.assertNotEqual(a, None) + b = tskit.MutationTableRow(**args) + self.assertEqual(a, b) + args["site"] = 2 + b = tskit.MutationTableRow(**args) + self.assertNotEqual(a, b) + args["site"] = 0 + args["node"] = 2 + b = tskit.MutationTableRow(**args) + self.assertNotEqual(a, b) + args["node"] = 0 + args["derived_state"] = "b" + b = tskit.MutationTableRow(**args) + self.assertNotEqual(a, b) + args["derived_state"] = "a" + args["parent"] = 2 + b = tskit.MutationTableRow(**args) + self.assertNotEqual(a, b) + args["parent"] = 0 + args["metadata"] = b"" + b = tskit.MutationTableRow(**args) + self.assertNotEqual(a, b) + args["metadata"] = b"abc" + args["time"] = 1 + b = tskit.MutationTableRow(**args) + self.assertNotEqual(a, b) + args["time"] = 0 + args["time"] = tskit.UNKNOWN_TIME + b = tskit.MutationTableRow(**args) + self.assertNotEqual(a, b) + a = tskit.MutationTableRow(**args) + self.assertEqual(a, b) + class TestNodeTable(unittest.TestCase, CommonTestsMixin, MetadataTestsMixin): @@ -1061,9 +1112,17 @@ def test_simple_example(self): t = tskit.MutationTable() t.add_row(site=0, node=1, derived_state="2", parent=3, metadata=b"4", time=5) t.add_row(1, 2, "3", 4, b"\xf0", 6) + t.add_row( + site=0, + node=1, + derived_state="2", + parent=3, + metadata=b"4", + time=tskit.UNKNOWN_TIME, + ) s = str(t) self.assertGreater(len(s), 0) - self.assertEqual(len(t), 2) + self.assertEqual(len(t), 3) self.assertEqual(attr.astuple(t[0]), (0, 1, "2", 3, b"4", 5)) self.assertEqual(attr.astuple(t[1]), (1, 2, "3", 4, b"\xf0", 6)) self.assertEqual(t[0].site, 0) @@ -1072,9 +1131,10 @@ def test_simple_example(self): self.assertEqual(t[0].parent, 3) self.assertEqual(t[0].metadata, b"4") self.assertEqual(t[0].time, 5) - self.assertEqual(t[0], t[-2]) - self.assertEqual(t[1], t[-1]) - self.assertRaises(IndexError, t.__getitem__, -3) + self.assertEqual(t[0], t[-3]) + self.assertEqual(t[1], t[-2]) + self.assertEqual(t[2], t[-1]) + self.assertRaises(IndexError, t.__getitem__, -4) def test_add_row_bad_data(self): t = tskit.MutationTable() diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 79317bb966..263fba4ae7 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -103,7 +103,7 @@ class SiteTableRow: metadata: bytes -@attr.s(**attr_options) +@attr.s(eq=False, **attr_options) class MutationTableRow: site: int node: int @@ -112,6 +112,22 @@ class MutationTableRow: metadata: bytes time: float + def __eq__(self, other): + return ( + isinstance(other, MutationTableRow) + and self.site == other.site + and self.node == other.node + and self.derived_state == other.derived_state + and self.parent == other.parent + and self.metadata == other.metadata + and ( + self.time == other.time + or ( + util.is_unknown_time(self.time) and util.is_unknown_time(other.time) + ) + ) + ) + @attr.s(**attr_options) class PopulationTableRow: