Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
- Added ``TableCollection.indexes`` for access to the edge insertion/removal order indexes.
(:user:`benjeffery`, :issue:`4`, :pr:`916`)

- The dictionary representation of a TableCollection now contains its index.
(:user:`benjeffery`, :issue:`870`, :pr:`921`)

- Added ``TreeSequence._repr_html_`` for use in jupyter notebooks.
(:user:`benjeffery`, :issue:`872`, :pr:`923`)

Expand Down
35 changes: 20 additions & 15 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -5226,26 +5226,31 @@ TableCollection_get_indexes(TableCollection *self, void *closure)
goto out;
}

insertion = table_get_column_array(self->tables->indexes.num_edges,
self->tables->indexes.edge_insertion_order, NPY_INT32, sizeof(tsk_id_t));
if (insertion == NULL) {
goto out;
}
removal = table_get_column_array(self->tables->indexes.num_edges,
self->tables->indexes.edge_removal_order, NPY_INT32, sizeof(tsk_id_t));
if (removal == NULL) {
goto out;
}
indexes_dict = PyDict_New();
if (indexes_dict == NULL) {
goto out;
}
if (PyDict_SetItemString(indexes_dict, "edge_insertion_order", insertion) != 0) {
goto out;
}
if (PyDict_SetItemString(indexes_dict, "edge_removal_order", removal) != 0) {
goto out;

if (tsk_table_collection_has_index(self->tables, 0)) {
insertion = table_get_column_array(self->tables->indexes.num_edges,
self->tables->indexes.edge_insertion_order, NPY_INT32, sizeof(tsk_id_t));
if (insertion == NULL) {
goto out;
}
removal = table_get_column_array(self->tables->indexes.num_edges,
self->tables->indexes.edge_removal_order, NPY_INT32, sizeof(tsk_id_t));
if (removal == NULL) {
goto out;
}

if (PyDict_SetItemString(indexes_dict, "edge_insertion_order", insertion) != 0) {
goto out;
}
if (PyDict_SetItemString(indexes_dict, "edge_removal_order", removal) != 0) {
goto out;
}
}

ret = indexes_dict;
indexes_dict = NULL;
out:
Expand Down
6 changes: 6 additions & 0 deletions python/lwt_interface/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
--------------------
[0.1.2] - 2020-10-22
--------------------

- Added optional top-level key ``indexes`` which contains ``edge_insertion_order`` and
``edge_removal_order``
66 changes: 64 additions & 2 deletions python/lwt_interface/dict_encoding_testlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def get_example_tables():
class TestEncodingVersion:
def test_version(self):
lwt = lwt_module.LightweightTableCollection()
assert lwt.asdict()["encoding_version"] == (1, 1)
assert lwt.asdict()["encoding_version"] == (1, 2)


class TestRoundTrip:
Expand Down Expand Up @@ -270,6 +270,7 @@ def test_missing_tables(self):
"metadata",
"metadata_schema",
"encoding_version",
"indexes",
}
for table_name in table_names:
d = tables.asdict()
Expand All @@ -292,6 +293,7 @@ def verify_columns(self, value):
"metadata",
"metadata_schema",
"encoding_version",
"indexes",
}
for table_name in table_names:
table_dict = d[table_name]
Expand All @@ -313,7 +315,7 @@ def test_str(self):
def test_bad_top_level_types(self):
tables = get_example_tables()
d = tables.asdict()
for key in set(d.keys()) - {"encoding_version"}:
for key in set(d.keys()) - {"encoding_version", "indexes"}:
bad_type_dict = tables.asdict()
# A list should be a ValueError for both the tables and sequence_length
bad_type_dict[key] = ["12345"]
Expand All @@ -336,6 +338,7 @@ def verify(self, num_rows):
"metadata",
"metadata_schema",
"encoding_version",
"indexes",
}
for table_name in sorted(table_names):
table_dict = d[table_name]
Expand All @@ -354,6 +357,30 @@ def test_two_rows(self):
def test_zero_rows(self):
self.verify(0)

def test_bad_index_length(self):
    """Check that ``fromdict`` rejects index columns of inconsistent length.

    Two failure modes are exercised, both expected to raise ``ValueError``:
    one column shorter than the other, and both columns shortened together.
    """
    order_keys = [f"edge_{col}_order" for col in ("insertion", "removal")]
    tables = get_example_tables()

    # Drop the last entry of a single column, so the two orders disagree.
    unequal_msg = (
        "^edge_insertion_order and edge_removal_order must be the same length$"
    )
    for key in order_keys:
        d = tables.asdict()
        d["indexes"][key] = d["indexes"][key][:-1]
        lwt = lwt_module.LightweightTableCollection()
        with pytest.raises(ValueError, match=unequal_msg):
            lwt.fromdict(d)

    # Drop the last entry of both columns in the same dict: they agree with
    # each other, and the expected message is about the number of edges.
    wrong_size_msg = (
        "^edge_insertion_order and edge_removal_order must be the same length"
        " as the number of edges$"
    )
    d = tables.asdict()
    for key in order_keys:
        d["indexes"][key] = d["indexes"][key][:-1]
    lwt = lwt_module.LightweightTableCollection()
    with pytest.raises(ValueError, match=wrong_size_msg):
        lwt.fromdict(d)


class TestRequiredAndOptionalColumns:
"""
Expand Down Expand Up @@ -563,6 +590,41 @@ def test_provenances(self):
["record", "record_offset", "timestamp", "timestamp_offset"],
)

def test_index(self):
    """Check handling of the optional top-level ``indexes`` entry.

    Covers three cases: the index columns survive a fromdict/asdict round
    trip, the ``indexes`` key may be omitted entirely, and supplying only
    one of the two columns raises ``TypeError``.
    """
    order_keys = ("edge_insertion_order", "edge_removal_order")
    tables = get_example_tables()

    # Round trip: both columns come back equal to what went in.
    source = tables.asdict()
    lwt = lwt_module.LightweightTableCollection()
    lwt.fromdict(source)
    round_tripped = lwt.asdict()
    for key in order_keys:
        assert np.array_equal(source["indexes"][key], round_tripped["indexes"][key])

    # The "indexes" key is optional on input ...
    without_indexes = tables.asdict()
    del without_indexes["indexes"]
    lwt = lwt_module.LightweightTableCollection()
    lwt.fromdict(without_indexes)
    # ... and a collection loaded without it reports an empty dict.
    assert lwt.asdict()["indexes"] == {}

    # Removing just one of the two columns must be rejected.
    together_msg = (
        "^edge_insertion_order and edge_removal_order must be specified together$"
    )
    for key in order_keys:
        partial = tables.asdict()
        del partial["indexes"][key]
        lwt = lwt_module.LightweightTableCollection()
        with pytest.raises(TypeError, match=together_msg):
            lwt.fromdict(partial)

def test_top_level_metadata(self):
tables = get_example_tables()
d = tables.asdict()
Expand Down
44 changes: 35 additions & 9 deletions python/lwt_interface/tskit_lwt_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -1172,14 +1172,7 @@ parse_provenance_table_dict(
return ret;
}

// TODO REMOVE THIS ONCE INDEXES IN LWT
#ifdef __GNUC__
#define VARIABLE_IS_NOT_USED __attribute__((unused))
#else
#define VARIABLE_IS_NOT_USED
#endif

VARIABLE_IS_NOT_USED static int
static int
parse_indexes_dict(tsk_table_collection_t *tables, PyObject *dict)
{
int err;
Expand Down Expand Up @@ -1409,6 +1402,21 @@ parse_table_collection_dict(tsk_table_collection_t *tables, PyObject *tables_dic
goto out;
}

/* indexes */
value = get_table_dict_value(tables_dict, "indexes", false);
if (value == NULL) {
goto out;
}
if (value != Py_None) {
if (!PyDict_Check(value)) {
PyErr_SetString(PyExc_TypeError, "not a dictionary");
goto out;
}
if (parse_indexes_dict(tables, value) != 0) {
goto out;
}
}

ret = 0;
out:
return ret;
Expand Down Expand Up @@ -1550,6 +1558,18 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict)
{ NULL },
};

struct table_col indexes_cols[] = {
{ "edge_insertion_order", (void *) tables->indexes.edge_insertion_order,
tables->indexes.num_edges, NPY_INT32 },
{ "edge_removal_order", (void *) tables->indexes.edge_removal_order,
tables->indexes.num_edges, NPY_INT32 },
{ NULL },
};

struct table_col no_indexes_cols[] = {
{ NULL },
};

struct table_desc table_descs[] = {
{ "individuals", individual_cols, tables->individuals.metadata_schema,
tables->individuals.metadata_schema_length },
Expand All @@ -1566,6 +1586,11 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict)
{ "populations", population_cols, tables->populations.metadata_schema,
tables->populations.metadata_schema_length },
{ "provenances", provenance_cols, NULL, 0 },
/* We don't want to insert empty indexes, return an empty dict if there are none
*/
{ "indexes",
tsk_table_collection_has_index(tables, 0) ? indexes_cols : no_indexes_cols,
NULL, 0 },
};

for (j = 0; j < sizeof(table_descs) / sizeof(*table_descs); j++) {
Expand Down Expand Up @@ -1605,6 +1630,7 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict)
Py_DECREF(table_dict);
table_dict = NULL;
}

ret = 0;
out:
Py_XDECREF(array);
Expand All @@ -1627,7 +1653,7 @@ dump_tables_dict(tsk_table_collection_t *tables)
}

/* Dict representation version */
val = Py_BuildValue("ll", 1, 1);
val = Py_BuildValue("ll", 1, 2);
if (val == NULL) {
goto out;
}
Expand Down
12 changes: 6 additions & 6 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,12 +417,7 @@ def test_index(self):
tc.indexes["edge_removal_order"], np.arange(18, dtype=np.int32)[::-1]
)
tc.drop_index()
assert np.array_equal(
tc.indexes["edge_insertion_order"], np.arange(0, dtype=np.int32)
)
assert np.array_equal(
tc.indexes["edge_removal_order"], np.arange(0, dtype=np.int32)[::-1]
)
assert tc.indexes == {}
tc.build_index()
assert np.array_equal(
tc.indexes["edge_insertion_order"], np.arange(18, dtype=np.int32)
Expand All @@ -444,6 +439,11 @@ def test_index(self):
tc.indexes["edge_removal_order"], np.arange(4242, 4242 + 18, dtype=np.int32)
)

def test_no_indexes(self):
    """After ``drop_index()`` the ``indexes`` attribute is an empty dict."""
    sim = msprime.simulate(10, random_seed=42)
    ll_tables = sim.tables._ll_tables
    ll_tables.drop_index()
    assert ll_tables.indexes == {}

def test_bad_indexes(self):
tc = msprime.simulate(10, random_seed=42).tables._ll_tables
for col in ("insertion", "removal"):
Expand Down
Loading