diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 4bffc7c6f3..a23c050125 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -56,6 +56,12 @@ provenances, table schemas and tree-sequence level metadata and schema. (:user:`benjeffery`, :issue:`929`, :pr:`1001`) +**Bugfixes** + +- ``LightWeightTableCollection.asdict`` and ``TableCollection.asdict`` now return copies + of arrays. + (:user:`benjeffery`, :issue:`1025`, :pr:`1029`) + **Breaking changes** - The argument to ``ts.dump`` and ``tskit.load`` has been renamed `file` from `path`. diff --git a/python/lwt_interface/dict_encoding_testlib.py b/python/lwt_interface/dict_encoding_testlib.py index eb381ddb5b..23bca3e9d7 100644 --- a/python/lwt_interface/dict_encoding_testlib.py +++ b/python/lwt_interface/dict_encoding_testlib.py @@ -26,6 +26,9 @@ compiled module exporting the LightweightTableCollection class. See the test_example_c_module file for an example. """ +import json + +import kastore import msprime import numpy as np import pytest @@ -36,113 +39,76 @@ lwt_module = None -def get_example_tables(): +@pytest.fixture(scope="session") +def full_ts(): """ Return a tree sequence that has data in all fields. """ - pop_configs = [msprime.PopulationConfiguration(5) for _ in range(2)] - migration_matrix = [[0, 1], [1, 0]] + """ + A tree sequence with data in all fields - duplcated from tskit's conftest.py + as other test suites using this file will not have that fixture defined. + """ + n = 10 + t = 1 + population_configurations = [ + msprime.PopulationConfiguration(n // 2), + msprime.PopulationConfiguration(n // 2), + msprime.PopulationConfiguration(0), + ] + demographic_events = [ + msprime.MassMigration(time=t, source=0, destination=2), + msprime.MassMigration(time=t, source=1, destination=2), + ] ts = msprime.simulate( - population_configurations=pop_configs, - migration_matrix=migration_matrix, + population_configurations=population_configurations, + demographic_events=demographic_events, + random_seed=1, mutation_rate=1, record_migrations=True, - random_seed=1, ) - tables = ts.dump_tables() - for j in range(ts.num_samples): - tables.individuals.add_row(flags=j, location=np.arange(j), metadata=b"x" * j) - tables.nodes.clear() - for node in ts.nodes(): - tables.nodes.add_row( - flags=node.flags, - time=node.time, - population=node.population, - individual=node.id if node.id < ts.num_samples else -1, - metadata=b"y" * node.id, - ) - tables.edges.clear() - for edge in ts.edges(): - tables.edges.add_row( - left=edge.left, - right=edge.right, - child=edge.child, - parent=edge.parent, - metadata=b"y" * edge.id, - ) - tables.sites.clear() - for site in ts.sites(): - tables.sites.add_row( - position=site.position, - ancestral_state="A" * site.id, - metadata=b"q" * site.id, - ) - tables.mutations.clear() - for mutation in ts.mutations(): - mut_id = tables.mutations.add_row( - site=mutation.site, - node=mutation.node, - time=0, - parent=-1, - derived_state="C" * mutation.id, - metadata=b"x" * mutation.id, - ) - # Add another mutation on the same branch. - tables.mutations.add_row( - site=mutation.site, - node=mutation.node, - time=0, - parent=mut_id, - derived_state="G" * mutation.id, - metadata=b"y" * mutation.id, - ) - tables.migrations.clear() - for migration in ts.migrations(): - tables.migrations.add_row( - left=migration.left, - right=migration.right, - node=migration.node, - source=migration.source, - dest=migration.dest, - time=migration.time, - metadata=b"y" * migration.id, - ) + # TODO replace this with properly linked up individuals using sim_ancestry + # once 1.0 is released. + for j in range(n): + tables.individuals.add_row(flags=j, location=(j, j)) + + for name, table in tables.name_map.items(): + if name != "provenances": + table.metadata_schema = tskit.MetadataSchema({"codec": "json"}) + metadatas = [f"n_{name}_{u}" for u in range(len(table))] + metadata, metadata_offset = tskit.pack_strings(metadatas) + table.set_columns( + **{ + **table.asdict(), + "metadata": metadata, + "metadata_offset": metadata_offset, + } + ) + tables.metadata_schema = tskit.MetadataSchema({"codec": "json"}) + tables.metadata = "Test metadata" + + # Add some more provenance so we have enough rows for the offset deletion test. for j in range(10): - tables.populations.add_row(metadata=b"p" * j) tables.provenances.add_row(timestamp="x" * j, record="y" * j) - tables.metadata_schema = tskit.MetadataSchema( - { - "codec": "struct", - "type": "object", - "properties": { - "top-level": { - "type": "array", - "items": {"type": "integer", "binaryFormat": "B"}, - "noLengthEncodingExhaustBuffer": True, - } - }, - } - ) - tables.metadata = {"top-level": [1, 2, 3, 4]} - for table in [ - "individuals", - "nodes", - "edges", - "migrations", - "sites", - "mutations", - "populations", - ]: - t = getattr(tables, table) - t.metadata_schema = tskit.MetadataSchema( - { - "codec": "struct", - "type": "object", - "properties": {table: {"type": "string", "binaryFormat": "50p"}}, - } - ) - return tables + return tables.tree_sequence() + + +# The ts above is used for the whole test session, but our tests need fresh tables to +# modify +@pytest.fixture +def tables(full_ts): + return full_ts.dump_tables() + + +def test_check_ts_full(tmp_path, full_ts): + """ + Check that the example ts has data in all fields + """ + full_ts.dump(tmp_path / "tables") + store = kastore.load(tmp_path / "tables") + for v in store.values(): + # Check we really have data in every field + assert v.nbytes > 0 class TestEncodingVersion: @@ -196,8 +162,7 @@ def test_migration(self): ) self.verify(ts.tables) - def test_example(self): - tables = get_example_tables() + def test_example(self, tables): tables.metadata_schema = tskit.MetadataSchema( { "codec": "struct", @@ -233,16 +198,14 @@ class TestMissingData: Tests what happens when we have missing data in the encoded dict. """ - def test_missing_sequence_length(self): - tables = get_example_tables() + def test_missing_sequence_length(self, tables): d = tables.asdict() del d["sequence_length"] lwt = lwt_module.LightweightTableCollection() with pytest.raises(TypeError): lwt.fromdict(d) - def test_missing_metadata(self): - tables = get_example_tables() + def test_missing_metadata(self, tables): assert tables.metadata != b"" d = tables.asdict() del d["metadata"] @@ -250,10 +213,10 @@ def test_missing_metadata(self): lwt.fromdict(d) tables = tskit.TableCollection.fromdict(lwt.asdict()) # Empty byte field still gets interpreted by schema - assert tables.metadata == {"top-level": []} + with pytest.raises(json.decoder.JSONDecodeError): + tables.metadata - def test_missing_metadata_schema(self): - tables = get_example_tables() + def test_missing_metadata_schema(self, tables): assert str(tables.metadata_schema) != "" d = tables.asdict() del d["metadata_schema"] @@ -262,8 +225,7 @@ def test_missing_metadata_schema(self): tables = tskit.TableCollection.fromdict(lwt.asdict()) assert str(tables.metadata_schema) == "" - def test_missing_tables(self): - tables = get_example_tables() + def test_missing_tables(self, tables): d = tables.asdict() table_names = d.keys() - { "sequence_length", @@ -285,8 +247,7 @@ class TestBadTypes: Tests for setting each column to a type that can't be converted to 1D numpy array. """ - def verify_columns(self, value): - tables = get_example_tables() + def verify_columns(self, value, tables): d = tables.asdict() table_names = set(d.keys()) - { "sequence_length", @@ -306,14 +267,13 @@ def verify_columns(self, value): with pytest.raises(ValueError): lwt.fromdict(d) - def test_2d_array(self): - self.verify_columns([[1, 2], [3, 4]]) + def test_2d_array(self, tables): + self.verify_columns([[1, 2], [3, 4]], tables) - def test_str(self): - self.verify_columns("aserg") + def test_str(self, tables): + self.verify_columns("aserg", tables) - def test_bad_top_level_types(self): - tables = get_example_tables() + def test_bad_top_level_types(self, tables): d = tables.asdict() for key in set(d.keys()) - {"encoding_version", "indexes"}: bad_type_dict = tables.asdict() @@ -329,9 +289,7 @@ class TestBadLengths: Tests for setting each column to a length incompatible with the table. """ - def verify(self, num_rows): - - tables = get_example_tables() + def verify(self, num_rows, tables): d = tables.asdict() table_names = set(d.keys()) - { "sequence_length", @@ -351,14 +309,13 @@ def verify(self, num_rows): with pytest.raises(ValueError): lwt.fromdict(d) - def test_two_rows(self): - self.verify(2) + def test_two_rows(self, tables): + self.verify(2, tables) - def test_zero_rows(self): - self.verify(0) + def test_zero_rows(self, tables): + self.verify(0, tables) - def test_bad_index_length(self): - tables = get_example_tables() + def test_bad_index_length(self, tables): for col in ("insertion", "removal"): d = tables.asdict() d["indexes"][f"edge_{col}_order"] = d["indexes"][f"edge_{col}_order"][:-1] @@ -511,8 +468,7 @@ def verify_metadata_schema(self, tables, table_name): tables = tskit.TableCollection.fromdict(out) assert str(getattr(tables, table_name).metadata_schema) == "" - def test_individuals(self): - tables = get_example_tables() + def test_individuals(self, tables): self.verify_required_columns(tables, "individuals", ["flags"]) self.verify_offset_pair( tables, len(tables.individuals), "individuals", "location" @@ -522,24 +478,21 @@ def test_individuals(self): ) self.verify_metadata_schema(tables, "individuals") - def test_nodes(self): - tables = get_example_tables() + def test_nodes(self, tables): self.verify_offset_pair(tables, len(tables.nodes), "nodes", "metadata") self.verify_optional_column(tables, len(tables.nodes), "nodes", "population") self.verify_optional_column(tables, len(tables.nodes), "nodes", "individual") self.verify_required_columns(tables, "nodes", ["flags", "time"]) self.verify_metadata_schema(tables, "nodes") - def test_edges(self): - tables = get_example_tables() + def test_edges(self, tables): self.verify_required_columns( tables, "edges", ["left", "right", "parent", "child"] ) self.verify_offset_pair(tables, len(tables.edges), "edges", "metadata") self.verify_metadata_schema(tables, "edges") - def test_migrations(self): - tables = get_example_tables() + def test_migrations(self, tables): self.verify_required_columns( tables, "migrations", ["left", "right", "node", "source", "dest", "time"] ) @@ -549,16 +502,14 @@ def test_migrations(self): self.verify_optional_column(tables, len(tables.nodes), "nodes", "individual") self.verify_metadata_schema(tables, "migrations") - def test_sites(self): - tables = get_example_tables() + def test_sites(self, tables): self.verify_required_columns( tables, "sites", ["position", "ancestral_state", "ancestral_state_offset"] ) self.verify_offset_pair(tables, len(tables.sites), "sites", "metadata") self.verify_metadata_schema(tables, "sites") - def test_mutations(self): - tables = get_example_tables() + def test_mutations(self, tables): self.verify_required_columns( tables, "mutations", @@ -574,24 +525,21 @@ def test_mutations(self): out = lwt.asdict() assert all(util.is_unknown_time(val) for val in out["mutations"]["time"]) - def test_populations(self): - tables = get_example_tables() + def test_populations(self, tables): self.verify_required_columns( tables, "populations", ["metadata", "metadata_offset"] ) self.verify_metadata_schema(tables, "populations") self.verify_offset_pair(tables, len(tables.nodes), "nodes", "metadata", True) - def test_provenances(self): - tables = get_example_tables() + def test_provenances(self, tables): self.verify_required_columns( tables, "provenances", ["record", "record_offset", "timestamp", "timestamp_offset"], ) - def test_index(self): - tables = get_example_tables() + def test_index(self, tables): d = tables.asdict() lwt = lwt_module.LightweightTableCollection() lwt.fromdict(d) @@ -625,8 +573,7 @@ def test_index(self): ): lwt.fromdict(d) - def test_top_level_metadata(self): - tables = get_example_tables() + def test_top_level_metadata(self, tables): d = tables.asdict() # None should give default value d["metadata"] = None @@ -635,13 +582,11 @@ def test_top_level_metadata(self): out = lwt.asdict() assert "metadata" not in out tables = tskit.TableCollection.fromdict(out) - # We only removed the metadata, not the schema. So empty bytefield - # still gets interpreted - assert tables.metadata == {"top-level": []} + with pytest.raises(json.decoder.JSONDecodeError): + tables.metadata # Missing is tested in TestMissingData above - def test_top_level_metadata_schema(self): - tables = get_example_tables() + def test_top_level_metadata_schema(self, tables): d = tables.asdict() # None should give default value d["metadata_schema"] = None @@ -652,3 +597,31 @@ def test_top_level_metadata_schema(self): tables = tskit.TableCollection.fromdict(out) assert str(tables.metadata_schema) == "" # Missing is tested in TestMissingData above + + +class TestLifecycle: + def test_unassigned_empty(self): + lwt_dict = lwt_module.LightweightTableCollection().asdict() + assert tskit.TableCollection.fromdict(lwt_dict) == tskit.TableCollection(-1) + + def test_del_empty(self): + lwt = lwt_module.LightweightTableCollection() + lwt_dict = lwt.asdict() + del lwt + assert tskit.TableCollection.fromdict(lwt_dict) == tskit.TableCollection(-1) + + def test_del_full(self, tables): + lwt = lwt_module.LightweightTableCollection() + lwt.fromdict(tables.asdict()) + lwt_dict = lwt.asdict() + del lwt + assert tskit.TableCollection.fromdict(lwt_dict) == tables + + def test_del_lwt_and_tables(self, tables): + lwt = lwt_module.LightweightTableCollection() + lwt.fromdict(tables.asdict()) + lwt_dict = lwt.asdict() + del lwt + tables2 = tables.copy() + del tables + assert tskit.TableCollection.fromdict(lwt_dict) == tables2 diff --git a/python/lwt_interface/tskit_lwt_interface.h b/python/lwt_interface/tskit_lwt_interface.h index 52bdae9be6..c278674e3a 100644 --- a/python/lwt_interface/tskit_lwt_interface.h +++ b/python/lwt_interface/tskit_lwt_interface.h @@ -1600,10 +1600,12 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict) } col = table_descs[j].cols; while (col->name != NULL) { - array = PyArray_SimpleNewFromData(1, &col->num_rows, col->type, col->data); + array = (PyObject *) PyArray_EMPTY(1, &col->num_rows, col->type, 0); if (array == NULL) { goto out; } + memcpy(PyArray_DATA((PyArrayObject *) array), col->data, + col->num_rows * PyArray_ITEMSIZE((PyArrayObject *) array)); if (PyDict_SetItemString(table_dict, col->name, array) != 0) { goto out; } diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index d463d50b38..92d2027f83 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -2309,6 +2309,12 @@ def test_asdict(self, ts_fixture): assert t1.has_index() assert t2.has_index() + def test_asdict_lifecycle(self, ts_fixture): + tables = ts_fixture.dump_tables() + tables_dict = tables.asdict() + del tables + assert tskit.TableCollection.fromdict(tables_dict) == ts_fixture.dump_tables() + def test_from_dict(self, ts_fixture): t1 = ts_fixture.tables d = {