diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 3c9f92c64a..83295aef02 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -25,6 +25,9 @@ - Added ``TableCollection.indexes`` for access to the edge insertion/removal order indexes. (:user:`benjeffery`, :issue:`4`, :pr:`916`) +- The dictionary representation of a TableCollection now contains its index. + (:user:`benjeffery`, :issue:`870`, :pr:`921`) + - Added ``TreeSequence._repr_html_`` for use in jupyter notebooks. (:user:`benjeffery`, :issue:`872`, :pr:`923`) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index a76d60042a..55d69aacae 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -5226,26 +5226,31 @@ TableCollection_get_indexes(TableCollection *self, void *closure) goto out; } - insertion = table_get_column_array(self->tables->indexes.num_edges, - self->tables->indexes.edge_insertion_order, NPY_INT32, sizeof(tsk_id_t)); - if (insertion == NULL) { - goto out; - } - removal = table_get_column_array(self->tables->indexes.num_edges, - self->tables->indexes.edge_removal_order, NPY_INT32, sizeof(tsk_id_t)); - if (removal == NULL) { - goto out; - } indexes_dict = PyDict_New(); if (indexes_dict == NULL) { goto out; } - if (PyDict_SetItemString(indexes_dict, "edge_insertion_order", insertion) != 0) { - goto out; - } - if (PyDict_SetItemString(indexes_dict, "edge_removal_order", removal) != 0) { - goto out; + + if (tsk_table_collection_has_index(self->tables, 0)) { + insertion = table_get_column_array(self->tables->indexes.num_edges, + self->tables->indexes.edge_insertion_order, NPY_INT32, sizeof(tsk_id_t)); + if (insertion == NULL) { + goto out; + } + removal = table_get_column_array(self->tables->indexes.num_edges, + self->tables->indexes.edge_removal_order, NPY_INT32, sizeof(tsk_id_t)); + if (removal == NULL) { + goto out; + } + + if (PyDict_SetItemString(indexes_dict, "edge_insertion_order", insertion) != 0) { + goto out; + } + if (PyDict_SetItemString(indexes_dict, "edge_removal_order", removal) != 0) { + goto out; + } } + ret = indexes_dict; indexes_dict = NULL; out: diff --git a/python/lwt_interface/CHANGELOG.rst b/python/lwt_interface/CHANGELOG.rst new file mode 100644 index 0000000000..2fbf607ac3 --- /dev/null +++ b/python/lwt_interface/CHANGELOG.rst @@ -0,0 +1,6 @@ +-------------------- +[0.1.2] - 2020-10-22 +-------------------- + + - Added optional top-level key ``indexes`` which has contains ``edge_insertion_order`` and + ``edge_removal_order`` \ No newline at end of file diff --git a/python/lwt_interface/dict_encoding_testlib.py b/python/lwt_interface/dict_encoding_testlib.py index a9e0a0db78..eb381ddb5b 100644 --- a/python/lwt_interface/dict_encoding_testlib.py +++ b/python/lwt_interface/dict_encoding_testlib.py @@ -148,7 +148,7 @@ def get_example_tables(): class TestEncodingVersion: def test_version(self): lwt = lwt_module.LightweightTableCollection() - assert lwt.asdict()["encoding_version"] == (1, 1) + assert lwt.asdict()["encoding_version"] == (1, 2) class TestRoundTrip: @@ -270,6 +270,7 @@ def test_missing_tables(self): "metadata", "metadata_schema", "encoding_version", + "indexes", } for table_name in table_names: d = tables.asdict() @@ -292,6 +293,7 @@ def verify_columns(self, value): "metadata", "metadata_schema", "encoding_version", + "indexes", } for table_name in table_names: table_dict = d[table_name] @@ -313,7 +315,7 @@ def test_str(self): def test_bad_top_level_types(self): tables = get_example_tables() d = tables.asdict() - for key in set(d.keys()) - {"encoding_version"}: + for key in set(d.keys()) - {"encoding_version", "indexes"}: bad_type_dict = tables.asdict() # A list should be a ValueError for both the tables and sequence_length bad_type_dict[key] = ["12345"] @@ -336,6 +338,7 @@ def verify(self, num_rows): "metadata", "metadata_schema", "encoding_version", + "indexes", } for table_name in sorted(table_names): table_dict = d[table_name] @@ -354,6 +357,30 @@ def test_two_rows(self): def test_zero_rows(self): self.verify(0) + def test_bad_index_length(self): + tables = get_example_tables() + for col in ("insertion", "removal"): + d = tables.asdict() + d["indexes"][f"edge_{col}_order"] = d["indexes"][f"edge_{col}_order"][:-1] + lwt = lwt_module.LightweightTableCollection() + with pytest.raises( + ValueError, + match="^edge_insertion_order and" + " edge_removal_order must be the same" + " length$", + ): + lwt.fromdict(d) + d = tables.asdict() + for col in ("insertion", "removal"): + d["indexes"][f"edge_{col}_order"] = d["indexes"][f"edge_{col}_order"][:-1] + lwt = lwt_module.LightweightTableCollection() + with pytest.raises( + ValueError, + match="^edge_insertion_order and edge_removal_order must be" + " the same length as the number of edges$", + ): + lwt.fromdict(d) + class TestRequiredAndOptionalColumns: """ @@ -563,6 +590,41 @@ def test_provenances(self): ["record", "record_offset", "timestamp", "timestamp_offset"], ) + def test_index(self): + tables = get_example_tables() + d = tables.asdict() + lwt = lwt_module.LightweightTableCollection() + lwt.fromdict(d) + other = lwt.asdict() + assert np.array_equal( + d["indexes"]["edge_insertion_order"], + other["indexes"]["edge_insertion_order"], + ) + assert np.array_equal( + d["indexes"]["edge_removal_order"], other["indexes"]["edge_removal_order"] + ) + + # index is optional + d = tables.asdict() + del d["indexes"] + lwt = lwt_module.LightweightTableCollection() + lwt.fromdict(d) + # and a tc without indexes has empty dict + assert lwt.asdict()["indexes"] == {} + + # Both columns must be provided, if one is + for col in ("insertion", "removal"): + d = tables.asdict() + del d["indexes"][f"edge_{col}_order"] + lwt = lwt_module.LightweightTableCollection() + with pytest.raises( + TypeError, + match="^edge_insertion_order and " + "edge_removal_order must be specified " + "together$", + ): + lwt.fromdict(d) + def test_top_level_metadata(self): tables = get_example_tables() d = tables.asdict() diff --git a/python/lwt_interface/tskit_lwt_interface.h b/python/lwt_interface/tskit_lwt_interface.h index b6e9332977..b658f319e8 100644 --- a/python/lwt_interface/tskit_lwt_interface.h +++ b/python/lwt_interface/tskit_lwt_interface.h @@ -1172,14 +1172,7 @@ parse_provenance_table_dict( return ret; } -// TODO REMOVE THIS ONCE INDEXES IN LWT -#ifdef __GNUC__ -#define VARIABLE_IS_NOT_USED __attribute__((unused)) -#else -#define VARIABLE_IS_NOT_USED -#endif - -VARIABLE_IS_NOT_USED static int +static int parse_indexes_dict(tsk_table_collection_t *tables, PyObject *dict) { int err; @@ -1409,6 +1402,21 @@ parse_table_collection_dict(tsk_table_collection_t *tables, PyObject *tables_dic goto out; } + /* indexes */ + value = get_table_dict_value(tables_dict, "indexes", false); + if (value == NULL) { + goto out; + } + if (value != Py_None) { + if (!PyDict_Check(value)) { + PyErr_SetString(PyExc_TypeError, "not a dictionary"); + goto out; + } + if (parse_indexes_dict(tables, value) != 0) { + goto out; + } + } + ret = 0; out: return ret; @@ -1550,6 +1558,18 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict) { NULL }, }; + struct table_col indexes_cols[] = { + { "edge_insertion_order", (void *) tables->indexes.edge_insertion_order, + tables->indexes.num_edges, NPY_INT32 }, + { "edge_removal_order", (void *) tables->indexes.edge_removal_order, + tables->indexes.num_edges, NPY_INT32 }, + { NULL }, + }; + + struct table_col no_indexes_cols[] = { + { NULL }, + }; + struct table_desc table_descs[] = { { "individuals", individual_cols, tables->individuals.metadata_schema, tables->individuals.metadata_schema_length }, @@ -1566,6 +1586,11 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict) { "populations", population_cols, tables->populations.metadata_schema, tables->populations.metadata_schema_length }, { "provenances", provenance_cols, NULL, 0 }, + /* We don't want to insert empty indexes, return an empty dict if there are none + */ + { "indexes", + tsk_table_collection_has_index(tables, 0) ? indexes_cols : no_indexes_cols, + NULL, 0 }, }; for (j = 0; j < sizeof(table_descs) / sizeof(*table_descs); j++) { @@ -1605,6 +1630,7 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict) Py_DECREF(table_dict); table_dict = NULL; } + ret = 0; out: Py_XDECREF(array); @@ -1627,7 +1653,7 @@ dump_tables_dict(tsk_table_collection_t *tables) } /* Dict representation version */ - val = Py_BuildValue("ll", 1, 1); + val = Py_BuildValue("ll", 1, 2); if (val == NULL) { goto out; } diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index ac6f6d6fa4..0fb3180552 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -417,12 +417,7 @@ def test_index(self): tc.indexes["edge_removal_order"], np.arange(18, dtype=np.int32)[::-1] ) tc.drop_index() - assert np.array_equal( - tc.indexes["edge_insertion_order"], np.arange(0, dtype=np.int32) - ) - assert np.array_equal( - tc.indexes["edge_removal_order"], np.arange(0, dtype=np.int32)[::-1] - ) + assert tc.indexes == {} tc.build_index() assert np.array_equal( tc.indexes["edge_insertion_order"], np.arange(18, dtype=np.int32) @@ -444,6 +439,11 @@ def test_index(self): tc.indexes["edge_removal_order"], np.arange(4242, 4242 + 18, dtype=np.int32) ) + def test_no_indexes(self): + tc = msprime.simulate(10, random_seed=42).tables._ll_tables + tc.drop_index() + assert tc.indexes == {} + def test_bad_indexes(self): tc = msprime.simulate(10, random_seed=42).tables._ll_tables for col in ("insertion", "removal"): diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index b8040eb59f..e790953b67 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -1310,16 +1310,23 @@ def test_add_row_bad_data(self): t.add_row(metadata=[0]) -class TestTableCollectionIndex: +class TestTableCollectionIndexes: def test_index(self): i = np.arange(20) r = np.arange(20)[::-1] index = tskit.TableCollectionIndexes( edge_insertion_order=i, edge_removal_order=r ) - assert index.edge_insertion_order is i - assert index.edge_removal_order is r - assert index.asdict() == {"edge_insertion_order": i, "edge_removal_order": r} + assert np.array_equal(index.edge_insertion_order, i) + assert np.array_equal(index.edge_removal_order, r) + d = index.asdict() + assert np.array_equal(d["edge_insertion_order"], i) + assert np.array_equal(d["edge_removal_order"], r) + + index = tskit.TableCollectionIndexes() + assert index.edge_insertion_order is None + assert index.edge_removal_order is None + assert index.asdict() == {} class TestSortTables: @@ -2255,7 +2262,7 @@ def test_asdict(self): t = ts.tables self.add_metadata(t) d1 = { - "encoding_version": (1, 1), + "encoding_version": (1, 2), "sequence_length": t.sequence_length, "metadata_schema": str(t.metadata_schema), "metadata": t.metadata_schema.encode_row(t.metadata), @@ -2267,12 +2274,15 @@ def test_asdict(self): "mutations": t.mutations.asdict(), "migrations": t.migrations.asdict(), "provenances": t.provenances.asdict(), + "indexes": t.indexes.asdict(), } d2 = t.asdict() assert set(d1.keys()) == set(d2.keys()) t1 = tskit.TableCollection.fromdict(d1) t2 = tskit.TableCollection.fromdict(d2) assert t1 == t2 + assert t1.has_index() + assert t2.has_index() def test_from_dict(self): ts = msprime.simulate(10, mutation_rate=1, random_seed=1) @@ -2291,6 +2301,7 @@ def test_from_dict(self): "mutations": t1.mutations.asdict(), "migrations": t1.migrations.asdict(), "provenances": t1.provenances.asdict(), + "indexes": t1.indexes.asdict(), } t2 = tskit.TableCollection.fromdict(d) assert t1 == t2 @@ -2551,16 +2562,9 @@ def test_sequence_length_longer_than_edges(self): tree = next(trees) assert len(tree.parent_dict) == 0 - def test_index_read(self, simple_ts_fixture): + def test_indexes(self, simple_ts_fixture): tc = tskit.TableCollection(sequence_length=1) - assert tc.indexes.edge_insertion_order.dtype == np.int32 - assert tc.indexes.edge_removal_order.dtype == np.int32 - assert np.array_equal( - tc.indexes.edge_insertion_order, np.arange(0, dtype=np.int32) - ) - assert np.array_equal( - tc.indexes.edge_removal_order, np.arange(0, dtype=np.int32)[::-1] - ) + assert tc.indexes == tskit.TableCollectionIndexes() tc = simple_ts_fixture.tables assert np.array_equal( tc.indexes.edge_insertion_order, np.arange(18, dtype=np.int32) @@ -2569,20 +2573,72 @@ def test_index_read(self, simple_ts_fixture): tc.indexes.edge_removal_order, np.arange(18, dtype=np.int32)[::-1] ) tc.drop_index() + assert tc.indexes == tskit.TableCollectionIndexes() + tc.build_index() assert np.array_equal( - tc.indexes.edge_insertion_order, np.arange(0, dtype=np.int32) + tc.indexes.edge_insertion_order, np.arange(18, dtype=np.int32) ) assert np.array_equal( - tc.indexes.edge_removal_order, np.arange(0, dtype=np.int32)[::-1] + tc.indexes.edge_removal_order, np.arange(18, dtype=np.int32)[::-1] ) - tc.build_index() + + modify_indexes = tskit.TableCollectionIndexes( + edge_insertion_order=np.arange(42, 42 + 18, dtype=np.int32), + edge_removal_order=np.arange(4242, 4242 + 18, dtype=np.int32), + ) + tc.indexes = modify_indexes assert np.array_equal( - tc.indexes.edge_insertion_order, np.arange(18, dtype=np.int32) + tc.indexes.edge_insertion_order, np.arange(42, 42 + 18, dtype=np.int32) ) assert np.array_equal( - tc.indexes.edge_removal_order, np.arange(18, dtype=np.int32)[::-1] + tc.indexes.edge_removal_order, np.arange(4242, 4242 + 18, dtype=np.int32) ) + def test_indexes_roundtrip(self, simple_ts_fixture): + # Indexes shouldn't be made by roundtripping + tables = tskit.TableCollection(sequence_length=1) + assert not tables.has_index() + assert not tskit.TableCollection.fromdict(tables.asdict()).has_index() + + tables = simple_ts_fixture.dump_tables() + tables.drop_index() + assert not tskit.TableCollection.fromdict(tables.asdict()).has_index() + + def test_asdict_lwt_concordence(self, ts_fixture): + def check_concordence(d1, d2): + assert set(d1.keys()) == set(d2.keys()) + for k1, v1 in d1.items(): + v2 = d2[k1] + assert type(v1) == type(v2) + if type(v1) == dict: + assert set(v1.keys()) == set(v2.keys()) + for sk1, sv1 in v1.items(): + sv2 = v2[sk1] + assert type(sv1) == type(sv2) + if type(sv1) == np.ndarray: + assert np.array_equal(sv1, sv2) or ( + np.all(tskit.is_unknown_time(sv1)) + and np.all(tskit.is_unknown_time(sv2)) + ) + elif type(sv1) in [bytes, str]: + assert sv1 == sv2 + else: + raise AssertionError() + + else: + assert v1 == v2 + + tables = ts_fixture.dump_tables() + assert tables.has_index() + lwt = _tskit.LightweightTableCollection() + lwt.fromdict(tables.asdict()) + check_concordence(lwt.asdict(), tables.asdict()) + + tables.drop_index() + lwt = _tskit.LightweightTableCollection() + lwt.fromdict(tables.asdict()) + check_concordence(lwt.asdict(), tables.asdict()) + class TestEqualityOptions: def test_equals_provenance(self): diff --git a/python/tskit/tables.py b/python/tskit/tables.py index f689a71397..cb4f4b3bda 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -142,11 +142,11 @@ class ProvenanceTableRow: @attr.s(**attr_options) class TableCollectionIndexes: - edge_insertion_order: np.ndarray - edge_removal_order: np.ndarray + edge_insertion_order: np.ndarray = attr.ib(default=None) + edge_removal_order: np.ndarray = attr.ib(default=None) def asdict(self): - return attr.asdict(self) + return attr.asdict(self, filter=lambda k, v: v is not None) def keep_with_offset(keep, data, offset): @@ -2080,7 +2080,12 @@ def provenances(self): @property def indexes(self): - return TableCollectionIndexes(**self._ll_tables.indexes) + indexes = self._ll_tables.indexes + return TableCollectionIndexes(**indexes) + + @indexes.setter + def indexes(self, indexes): + self._ll_tables.indexes = indexes.asdict() @property def sequence_length(self): @@ -2134,8 +2139,8 @@ def asdict(self): Note: the semantics of this method changed at tskit 0.1.0. Previously a map of table names to the tables themselves was returned. """ - return { - "encoding_version": (1, 1), + ret = { + "encoding_version": (1, 2), "sequence_length": self.sequence_length, "metadata_schema": str(self.metadata_schema), "metadata": self.metadata_schema.encode_row(self.metadata), @@ -2147,7 +2152,9 @@ def asdict(self): "mutations": self.mutations.asdict(), "populations": self.populations.asdict(), "provenances": self.provenances.asdict(), + "indexes": self.indexes.asdict(), } + return ret @property def name_map(self): @@ -2278,7 +2285,6 @@ def fromdict(self, tables_dict): tables.metadata = tables.metadata_schema.decode_row(tables_dict["metadata"]) except KeyError: pass - tables.individuals.set_columns(**tables_dict["individuals"]) tables.nodes.set_columns(**tables_dict["nodes"]) tables.edges.set_columns(**tables_dict["edges"]) @@ -2287,6 +2293,12 @@ def fromdict(self, tables_dict): tables.mutations.set_columns(**tables_dict["mutations"]) tables.populations.set_columns(**tables_dict["populations"]) tables.provenances.set_columns(**tables_dict["provenances"]) + + # Indexes must be last as other wise the check for their consistency will fail + try: + tables.indexes = TableCollectionIndexes(**tables_dict["indexes"]) + except KeyError: + pass return tables def copy(self):