From e0409d26d37455096c6d2be4ff5e42f317f6c64c Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Mon, 10 Jan 2022 13:57:35 +0000 Subject: [PATCH] Add ref seq equals --- python/CHANGELOG.rst | 5 +- python/tests/test_reference_sequence.py | 66 +++++++++++++++++++------ python/tskit/tables.py | 10 ++++ 3 files changed, 66 insertions(+), 15 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 5fc6946170..8e3fad6fd5 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -11,7 +11,10 @@ **Fixes** - ``TreeSequence.dump_text`` now prints decoded metadata if there is a schema. - (:user:`bejeffery`, :issue:`1860`, :issue:`1527`, + (:user:`benjeffery`, :issue:`1860`, :issue:`1527`) + +- Add missing ``ReferenceSequence.__eq__`` method. + (:user:`benjeffery`, :issue:`2063`, :pr:`2085`) ---------------------- diff --git a/python/tests/test_reference_sequence.py b/python/tests/test_reference_sequence.py index c4054dcea7..fad82cbae7 100644 --- a/python/tests/test_reference_sequence.py +++ b/python/tests/test_reference_sequence.py @@ -148,42 +148,80 @@ def test_repr(self): assert repr(refseq).startswith("ReferenceSequence") -class TestAssertEquals: - def test_success_self(self, ts_fixture): +class TestEquals: + def test_equal_self(self, ts_fixture): ts_fixture.reference_sequence.assert_equals(ts_fixture.reference_sequence) + assert ts_fixture.reference_sequence == ts_fixture.reference_sequence + assert not ts_fixture.reference_sequence != ts_fixture.reference_sequence + assert ts_fixture.reference_sequence.equals(ts_fixture.reference_sequence) - def test_success_empty(self): + def test_equal_empty(self): tables = tskit.TableCollection(1) tables.reference_sequence.assert_equals(tables.reference_sequence) + assert tables.reference_sequence == tables.reference_sequence + assert tables.reference_sequence.equals(tables.reference_sequence) @pytest.mark.parametrize("attr", ["url", "data"]) - def test_fails_attr_missing(self, ts_fixture, attr): + def test_unequal_attr_missing(self, ts_fixture, attr): t1 = ts_fixture.tables d = t1.asdict() del d["reference_sequence"][attr] t2 = tskit.TableCollection.fromdict(d) with pytest.raises(AssertionError, match=attr): t1.reference_sequence.assert_equals(t2.reference_sequence) + assert t1.reference_sequence != t2.reference_sequence + assert not t1.reference_sequence.equals(t2.reference_sequence) with pytest.raises(AssertionError, match=attr): t2.reference_sequence.assert_equals(t1.reference_sequence) - - def test_fails_metadata_different(self, ts_fixture): + assert t2.reference_sequence != t1.reference_sequence + assert not t2.reference_sequence.equals(t1.reference_sequence) + + @pytest.mark.parametrize( + ("attr", "val"), + [ + ("url", "foo"), + ("data", "bar"), + ("metadata", {"json": "runs the world"}), + ("metadata_schema", tskit.MetadataSchema(None)), + ], + ) + def test_different_not_equal(self, ts_fixture, attr, val): t1 = ts_fixture.dump_tables() t2 = t1.copy() - t1.reference_sequence.metadata = {"different": "metadata"} - with pytest.raises(AssertionError, match="metadata"): + setattr(t1.reference_sequence, attr, val) + + with pytest.raises(AssertionError): t1.reference_sequence.assert_equals(t2.reference_sequence) - with pytest.raises(AssertionError, match="metadata"): + assert t1.reference_sequence != t2.reference_sequence + assert not t1.reference_sequence.equals(t2.reference_sequence) + with pytest.raises(AssertionError): t2.reference_sequence.assert_equals(t1.reference_sequence) - - def test_fails_metadata_schema_different(self, ts_fixture): + assert t2.reference_sequence != t1.reference_sequence + assert not t2.reference_sequence.equals(t1.reference_sequence) + + @pytest.mark.parametrize( + ("attr", "val"), + [ + ("metadata", {"json": "runs the world"}), + ("metadata_schema", tskit.MetadataSchema(None)), + ], + ) + def test_different_but_ignore(self, ts_fixture, attr, val): t1 = ts_fixture.dump_tables() t2 = t1.copy() - t1.reference_sequence.metadata_schema = tskit.MetadataSchema(None) - with pytest.raises(AssertionError, match="schemas"): + setattr(t1.reference_sequence, attr, val) + + with pytest.raises(AssertionError): t1.reference_sequence.assert_equals(t2.reference_sequence) - with pytest.raises(AssertionError, match="schemas"): + assert t1.reference_sequence != t2.reference_sequence + assert not t1.reference_sequence.equals(t2.reference_sequence) + with pytest.raises(AssertionError): t2.reference_sequence.assert_equals(t1.reference_sequence) + assert t2.reference_sequence != t1.reference_sequence + assert not t2.reference_sequence.equals(t1.reference_sequence) + + t2.reference_sequence.assert_equals(t1.reference_sequence, ignore_metadata=True) + assert t2.reference_sequence.equals(t1.reference_sequence, ignore_metadata=True) class TestTreeSequenceProperties: diff --git a/python/tskit/tables.py b/python/tskit/tables.py index c6d33218b6..e6050cd451 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -2773,6 +2773,16 @@ def asdict(self) -> dict: "url": self.url, } + def __eq__(self, other): + return self.equals(other) + + def equals(self, other, ignore_metadata=False): + try: + self.assert_equals(other, ignore_metadata) + return True + except AssertionError: + return False + def assert_equals(self, other, ignore_metadata=False): if not ignore_metadata: super().assert_equals(other)