From 7cc450323128687fad1257ae0e1c35746df65009 Mon Sep 17 00:00:00 2001
From: mufernando
Date: Mon, 1 Jun 2020 17:59:22 -0300
Subject: [PATCH] method to subset table collection

---
 c/tests/test_tables.c         | 143 +++++++++++++++++++++++++++++
 c/tskit/core.c                |   3 +
 c/tskit/core.h                |   1 +
 c/tskit/tables.c              | 158 ++++++++++++++++++++++++++++++++
 c/tskit/tables.h              |  37 ++++++++
 python/_tskitmodule.c         |  34 +++++++
 python/tests/test_lowlevel.py |  12 +++
 python/tests/test_tables.py   | 166 ++++++++++++++++++++++++++++++++++
 python/tests/tsutil.py        |  70 ++++++++++++++
 python/tskit/tables.py        |  20 ++++
 python/tskit/trees.py         |  39 ++++++++
 11 files changed, 683 insertions(+)

diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c
index 17c99e6c8d..7b087d4d09 100644
--- a/c/tests/test_tables.c
+++ b/c/tests/test_tables.c
@@ -3212,6 +3212,147 @@ test_table_collection_check_integrity(void)
     tsk_table_collection_free(&tables);
 }
 
+static void
+test_table_collection_subset(void)
+{
+    int ret;
+    tsk_table_collection_t tables;
+    tsk_table_collection_t tables_copy;
+    int k;
+    tsk_id_t nodes[4];
+
+    ret = tsk_table_collection_init(&tables, 0);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+    tables.sequence_length = 1;
+    ret = tsk_table_collection_init(&tables_copy, 0);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+
+    // does not error on empty tables
+    ret = tsk_table_collection_subset(&tables, NULL, 0);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+
+    // four nodes from two diploids; the first is from pop 0
+    ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 0, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 0, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_node_table_add_row(
+        &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, 1, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_node_table_add_row(
+        &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, 1, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_individual_table_add_row(&tables.individuals, 0, NULL, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_individual_table_add_row(&tables.individuals, 0, NULL, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_population_table_add_row(&tables.populations, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_edge_table_add_row(&tables.edges, 0.0, 1.0, 1, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_edge_table_add_row(&tables.edges, 0.0, 1.0, 2, 1, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_site_table_add_row(&tables.sites, 0.2, "A", 1, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_site_table_add_row(&tables.sites, 0.4, "A", 1, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_mutation_table_add_row(
+        &tables.mutations, 0, 0, TSK_NULL, NULL, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_mutation_table_add_row(&tables.mutations, 0, 0, 0, NULL, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_mutation_table_add_row(
+        &tables.mutations, 1, 1, TSK_NULL, NULL, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+
+    // empty nodes should get empty tables
+    ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+    ret = tsk_table_collection_subset(&tables_copy, NULL, 0);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+    CU_ASSERT_EQUAL_FATAL(tables_copy.nodes.num_rows, 0);
+    CU_ASSERT_EQUAL_FATAL(tables_copy.individuals.num_rows, 0);
+    CU_ASSERT_EQUAL_FATAL(tables_copy.populations.num_rows, 0);
+
+    // the identity transformation
+    for (k = 0; k < 4; k++) {
+        nodes[k] = k;
+    }
+    ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+    ret = tsk_table_collection_subset(&tables_copy, nodes, 4);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+    CU_ASSERT_FATAL(tsk_table_collection_equals(&tables, &tables_copy));
+
+    // reverse twice should get back to the start
+    for (k = 0; k < 4; k++) {
+        nodes[k] = 3 - k;
+    }
+    ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+    ret = tsk_table_collection_subset(&tables_copy, nodes, 4);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+    ret = tsk_table_collection_subset(&tables_copy, nodes, 4);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+    CU_ASSERT_FATAL(tsk_table_collection_equals(&tables, &tables_copy));
+
+    tsk_table_collection_free(&tables_copy);
+    tsk_table_collection_free(&tables);
+}
+
+static void
+test_table_collection_subset_errors(void)
+{
+    int ret;
+    tsk_table_collection_t tables;
+    tsk_table_collection_t tables_copy;
+    tsk_id_t nodes[4] = { 0, 1, 2, 3 };
+
+    ret = tsk_table_collection_init(&tables, 0);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+    ret = tsk_table_collection_init(&tables_copy, 0);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+
+    // four nodes from two diploids; the first is from pop 0
+    ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 0, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 0, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_node_table_add_row(
+        &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, 1, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_node_table_add_row(
+        &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, 1, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_individual_table_add_row(&tables.individuals, 0, NULL, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_individual_table_add_row(&tables.individuals, 0, NULL, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_population_table_add_row(&tables.populations, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+    ret = tsk_edge_table_add_row(&tables.edges, 0.0, 1.0, 1, 0, NULL, 0);
+    CU_ASSERT_FATAL(ret >= 0);
+
+    /* Migrations are not supported */
+    ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+    tsk_migration_table_add_row(&tables_copy.migrations, 0, 1, 0, 0, 0, 0, NULL, 0);
+    CU_ASSERT_EQUAL_FATAL(tables_copy.migrations.num_rows, 1);
+    ret = tsk_table_collection_subset(&tables_copy, nodes, 4);
+    CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MIGRATIONS_NOT_SUPPORTED);
+
+    // test out of bounds nodes
+    nodes[0] = -1;
+    ret = tsk_table_collection_subset(&tables, nodes, 4);
+    CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);
+    nodes[0] = 6;
+    ret = tsk_table_collection_subset(&tables, nodes, 4);
+    CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);
+
+    tsk_table_collection_free(&tables);
+    tsk_table_collection_free(&tables_copy);
+}
+
 int
 main(int argc, char **argv)
 {
@@ -3257,6 +3398,8 @@ main(int argc, char **argv)
         { "test_column_overflow", test_column_overflow },
         { "test_table_collection_check_integrity",
             test_table_collection_check_integrity },
+        { "test_table_collection_subset", test_table_collection_subset },
+        { "test_table_collection_subset_errors", test_table_collection_subset_errors },
         { NULL, NULL },
     };
 
diff --git a/c/tskit/core.c b/c/tskit/core.c
index 317c2c8097..ae6f12b930 100644
--- a/c/tskit/core.c
+++ b/c/tskit/core.c
@@ -323,6 +323,9 @@ tsk_strerror_internal(int err)
         case TSK_ERR_SORT_MIGRATIONS_NOT_SUPPORTED:
             ret = "Migrations not currently supported by sort";
             break;
+        case TSK_ERR_MIGRATIONS_NOT_SUPPORTED:
+            ret = "Migrations not currently supported by this operation";
+            break;
         case TSK_ERR_SORT_OFFSET_NOT_SUPPORTED:
             ret = "Specifying position for mutation, sites or migrations is not "
                   "supported";
diff --git a/c/tskit/core.h b/c/tskit/core.h
index 0cc0cec33f..9925c89b91 100644
--- a/c/tskit/core.h
+++ b/c/tskit/core.h
@@ -225,6 +225,7 @@ not found in the file.
#define TSK_ERR_SORT_MIGRATIONS_NOT_SUPPORTED -802 #define TSK_ERR_SORT_OFFSET_NOT_SUPPORTED -803 #define TSK_ERR_NONBINARY_MUTATIONS_UNSUPPORTED -804 +#define TSK_ERR_MIGRATIONS_NOT_SUPPORTED -805 /* Stats errors */ #define TSK_ERR_BAD_NUM_WINDOWS -900 diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 6f5e8ef38e..d462079e9f 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -7799,6 +7799,164 @@ tsk_table_collection_clear(tsk_table_collection_t *self) return tsk_table_collection_truncate(self, &start); } +int TSK_WARN_UNUSED +tsk_table_collection_subset( + tsk_table_collection_t *self, tsk_id_t *nodes, tsk_size_t num_nodes) +{ + int ret = 0; + tsk_id_t k, i, new_ind, new_pop, new_parent, new_child, new_node; + tsk_node_t node; + tsk_individual_t ind; + tsk_population_t pop; + tsk_edge_t edge; + tsk_mutation_t mut; + tsk_site_t site; + tsk_id_t *node_map = NULL; + tsk_id_t *individual_map = NULL; + tsk_id_t *population_map = NULL; + tsk_id_t *site_map = NULL; + tsk_id_t *mutation_map = NULL; + tsk_table_collection_t tables; + + ret = tsk_table_collection_copy(self, &tables, 0); + if (ret != 0) { + goto out; + } + ret = tsk_table_collection_clear(self); + if (ret != 0) { + goto out; + } + + node_map = malloc(tables.nodes.num_rows * sizeof(*node_map)); + individual_map = malloc(tables.individuals.num_rows * sizeof(*individual_map)); + population_map = malloc(tables.populations.num_rows * sizeof(*population_map)); + site_map = malloc(tables.sites.num_rows * sizeof(*site_map)); + mutation_map = malloc(tables.mutations.num_rows * sizeof(*mutation_map)); + if (node_map == NULL || individual_map == NULL || population_map == NULL + || site_map == NULL || mutation_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + memset(node_map, 0xff, tables.nodes.num_rows * sizeof(*node_map)); + memset(individual_map, 0xff, tables.individuals.num_rows * sizeof(*individual_map)); + memset(population_map, 0xff, tables.populations.num_rows * sizeof(*population_map)); + 
memset(site_map, 0xff, tables.sites.num_rows * sizeof(*site_map)); + memset(mutation_map, 0xff, tables.mutations.num_rows * sizeof(*mutation_map)); + + // nodes, individuals, populations + for (k = 0; k < (tsk_id_t) num_nodes; k++) { + ret = tsk_node_table_get_row(&tables.nodes, nodes[k], &node); + if (ret < 0) { + goto out; + } + new_ind = TSK_NULL; + if (node.individual != TSK_NULL) { + if (individual_map[node.individual] == TSK_NULL) { + tsk_individual_table_get_row(&tables.individuals, node.individual, &ind); + ret = tsk_individual_table_add_row(&self->individuals, ind.flags, + ind.location, ind.location_length, ind.metadata, + ind.metadata_length); + if (ret < 0) { + goto out; + } + individual_map[node.individual] = ret; + } + new_ind = individual_map[node.individual]; + } + new_pop = TSK_NULL; + if (node.population != TSK_NULL) { + if (population_map[node.population] == TSK_NULL) { + tsk_population_table_get_row(&tables.populations, node.population, &pop); + ret = tsk_population_table_add_row( + &self->populations, pop.metadata, pop.metadata_length); + if (ret < 0) { + goto out; + } + population_map[node.population] = ret; + } + new_pop = population_map[node.population]; + } + ret = tsk_node_table_add_row(&self->nodes, node.flags, node.time, new_pop, + new_ind, node.metadata, node.metadata_length); + if (ret < 0) { + goto out; + } + node_map[node.id] = ret; + } + + // edges + for (k = 0; k < (tsk_id_t) tables.edges.num_rows; k++) { + tsk_edge_table_get_row(&tables.edges, k, &edge); + new_parent = node_map[edge.parent]; + new_child = node_map[edge.child]; + if ((new_parent != TSK_NULL) && (new_child != TSK_NULL)) { + ret = tsk_edge_table_add_row(&self->edges, edge.left, edge.right, new_parent, + new_child, edge.metadata, edge.metadata_length); + if (ret < 0) { + goto out; + } + } + } + + // mutations and sites + i = 0; + for (k = 0; k < (tsk_id_t) tables.sites.num_rows; k++) { + tsk_site_table_get_row(&tables.sites, k, &site); + while ((i < (tsk_id_t) 
tables.mutations.num_rows) + && (tables.mutations.site[i] == site.id)) { + tsk_mutation_table_get_row(&tables.mutations, i, &mut); + new_node = node_map[mut.node]; + if (new_node != TSK_NULL) { + if (site_map[site.id] == TSK_NULL) { + ret = tsk_site_table_add_row(&self->sites, site.position, + site.ancestral_state, site.ancestral_state_length, site.metadata, + site.metadata_length); + if (ret < 0) { + goto out; + } + site_map[site.id] = ret; + } + new_parent = TSK_NULL; + if (mut.parent != TSK_NULL) { + new_parent = mutation_map[mut.parent]; + } + ret = tsk_mutation_table_add_row(&self->mutations, site_map[site.id], + new_node, new_parent, mut.derived_state, mut.derived_state_length, + mut.metadata, mut.metadata_length); + if (ret < 0) { + goto out; + } + mutation_map[mut.id] = ret; + } + i++; + } + } + + /* TODO: Subset of the Migrations Table. The way to do this properly is not + * well-defined, mostly because migrations might contain events from/to + * populations that have not been kept in after the subset. */ + if (tables.migrations.num_rows != 0) { + ret = TSK_ERR_MIGRATIONS_NOT_SUPPORTED; + goto out; + } + + // provenance (new record is added in python) + ret = tsk_provenance_table_copy( + &tables.provenances, &self->provenances, TSK_NO_INIT); + if (ret < 0) { + goto out; + } + +out: + tsk_safe_free(node_map); + tsk_safe_free(individual_map); + tsk_safe_free(population_map); + tsk_safe_free(site_map); + tsk_safe_free(mutation_map); + tsk_table_collection_free(&tables); + return ret; +} + static int cmp_edge_cl(const void *a, const void *b) { diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 47d3d9e3cd..5d3bffe1a5 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -2409,6 +2409,43 @@ completes. int tsk_table_collection_simplify(tsk_table_collection_t *self, tsk_id_t *samples, tsk_size_t num_samples, tsk_flags_t options, tsk_id_t *node_map); +/** +@brief Subsets and reorders a table collection according to an array of nodes. 
+
+@rst
+Reduces the table collection to contain only the entries referring to
+the provided list of nodes, with nodes reordered according to the order
+they appear in the ``nodes`` argument. Specifically, this subsets and reorders
+each of the tables as follows:
+
+1. Nodes: if in the list of nodes, and in the order provided.
+2. Individuals and Populations: if referred to by a retained node,
+   and in the order first seen when traversing the list of retained nodes.
+3. Edges: if both parent and child are retained nodes.
+4. Mutations: if the mutation's node is a retained node.
+5. Sites: if any mutations remain at the site after removing mutations.
+6. Migrations: if the migration's node is a retained node.
+
+Retained edges, mutations, sites, and migrations appear in the same
+order as in the original tables.
+
+If ``nodes`` is the entire list of nodes in the tables, then the
+resulting tables will be identical to the original tables, but with
+nodes (and individuals and populations) reordered.
+
+.. note:: Migrations are currently not supported by subset, and an error will
+    be raised if we attempt to call subset on a table collection with greater
+    than zero migrations.
+@endrst
+
+@param self A pointer to a tsk_table_collection_t object.
+@param nodes An array of num_nodes valid node IDs.
+@param num_nodes The number of node IDs in the input nodes array.
+@return Return 0 on success or a negative value on failure.
+*/
+int tsk_table_collection_subset(
+    tsk_table_collection_t *self, tsk_id_t *nodes, tsk_size_t num_nodes);
+
 /**
 @brief Set the metadata
 @rst
diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c
index 2a1a987785..df096aa219 100644
--- a/python/_tskitmodule.c
+++ b/python/_tskitmodule.c
@@ -6256,6 +6256,41 @@ TableCollection_link_ancestors(TableCollection *self, PyObject *args, PyObject *
     return ret;
 }
 
+/* Subsets the underlying table collection in place to the node IDs given as the
+ * single positional argument, which is converted to a one-dimensional int32
+ * numpy array. Returns None, or NULL with a Python exception set on failure. */
+static PyObject *
+TableCollection_subset(TableCollection *self, PyObject *args)
+{
+    int err;
+    PyObject *ret = NULL;
+    PyObject *nodes = NULL;
+    PyArrayObject *nodes_array = NULL;
+    npy_intp *shape;
+    size_t num_nodes;
+
+    if (!PyArg_ParseTuple(args, "O", &nodes)) {
+        goto out;
+    }
+    nodes_array = (PyArrayObject *) PyArray_FROMANY(nodes, NPY_INT32, 1, 1,
+            NPY_ARRAY_IN_ARRAY);
+    if (nodes_array == NULL) {
+        goto out;
+    }
+    shape = PyArray_DIMS(nodes_array);
+    num_nodes = shape[0];
+
+    err = tsk_table_collection_subset(self->tables, PyArray_DATA(nodes_array), num_nodes);
+    if (err != 0) {
+        handle_library_error(err);
+        goto out;
+    }
+    ret = Py_BuildValue("");
+out:
+    Py_XDECREF(nodes_array);
+    return ret;
+}
+
 static PyObject *
 TableCollection_sort(TableCollection *self, PyObject *args, PyObject *kwds)
 {
@@ -6397,6 +6432,8 @@ static PyMethodDef TableCollection_methods[] = {
     {"link_ancestors", (PyCFunction) TableCollection_link_ancestors,
         METH_VARARGS|METH_KEYWORDS,
         "Returns an edge table linking samples to a set of specified ancestors." },
+    {"subset", (PyCFunction) TableCollection_subset, METH_VARARGS,
+        "Subsets the table collection to a set of nodes." },
     {"sort", (PyCFunction) TableCollection_sort, METH_VARARGS|METH_KEYWORDS,
         "Sorts the tables to satisfy tree sequence requirements."
},
     {"equals", (PyCFunction) TableCollection_equals, METH_VARARGS,
diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py
index ed65d7233d..7b8fd18791 100644
--- a/python/tests/test_lowlevel.py
+++ b/python/tests/test_lowlevel.py
@@ -265,6 +265,18 @@ def test_link_ancestors(self):
         del edges
         self.assertEqual(tc.edges.num_rows, 2)
 
+    def test_subset_bad_args(self):
+        ts = msprime.simulate(10, random_seed=1)
+        tc = ts.tables.ll_tables
+        with self.assertRaises(TypeError):
+            tc.subset(np.array(["a"]))
+        with self.assertRaises(ValueError):
+            tc.subset(np.array([[1], [2]], dtype="int32"))
+        with self.assertRaises(TypeError):
+            tc.subset()
+        with self.assertRaises(_tskit.LibraryError):
+            tc.subset(np.array([100, 200], dtype="int32"))
+
 
 class TestTreeSequence(LowLevelTestCase, MetadataTestMixin):
     """
diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py
index a4f7e3e75f..8a26eef2af 100644
--- a/python/tests/test_tables.py
+++ b/python/tests/test_tables.py
@@ -36,6 +36,7 @@ import numpy as np
 
 
 import _tskit
+import tests.test_wright_fisher as wf
 import tests.tsutil as tsutil
 import tskit
 import tskit.exceptions as exceptions
@@ -2363,3 +2364,173 @@ def test_set_columns_not_implemented(self):
         t = tskit.BaseTable(None, None)
         with self.assertRaises(NotImplementedError):
             t.set_columns()
+
+
+class TestSubsetTables(unittest.TestCase):
+    """
+    Tests for the TableCollection.subset method.
+    """
+
+    def get_msprime_example(self, sample_size=10, seed=1234):
+        # two-population msprime simulation with migration and mutations
+        M = [[0.0, 0.1], [1.0, 0.0]]
+        population_configurations = [
+            msprime.PopulationConfiguration(sample_size=sample_size),
+            msprime.PopulationConfiguration(sample_size=sample_size),
+        ]
+        ts = msprime.simulate(
+            population_configurations=population_configurations,
+            migration_matrix=M,
+            length=2e5,
+            recombination_rate=1e-8,
+            mutation_rate=1e-7,
+            record_migrations=False,
+            random_seed=seed,
+        )
+        # adding metadata and locations
+        ts = tsutil.add_random_metadata(ts, seed)
+        ts = tsutil.insert_random_ploidy_individuals(ts, max_ploidy=1)
+        return ts.tables
+
+    def get_wf_example(self, N=5, ngens=5, seed=1249):
+        # ngens is passed through to wf_sim; it defaults to the value of N used so far
+        tables = wf.wf_sim(N, ngens, seed=seed)
+        tables.sort()
+        ts = tables.tree_sequence()
+        # adding muts
+        ts = tsutil.jukes_cantor(ts, 1, 10, seed=seed)
+        ts = tsutil.add_random_metadata(ts, seed)
+        ts = tsutil.insert_random_ploidy_individuals(ts, max_ploidy=2)
+        return ts.tables
+
+    def get_examples(self, seed):
+        # one msprime and one Wright-Fisher table collection, from the same seed
+        yield self.get_msprime_example(seed=seed)
+        yield self.get_wf_example(seed=seed)
+
+    def verify_subset_equality(self, tables, nodes):
+        # TableCollection.subset must agree with the Python reference py_subset
+        sub1 = tables.copy()
+        sub2 = tables.copy()
+        tsutil.py_subset(sub1, nodes, record_provenance=False)
+        sub2.subset(nodes, record_provenance=False)
+        self.assertEqual(sub1, sub2)
+
+    def verify_subset(self, tables, nodes):
+        # check each subset table against ID maps that are rebuilt here
+        self.verify_subset_equality(tables, nodes)
+        subset = tables.copy()
+        subset.subset(nodes, record_provenance=False)
+        # adding one so the last element always maps to NULL (-1 -> -1)
+        node_map = np.repeat(tskit.NULL, tables.nodes.num_rows + 1)
+        indivs = []
+        pops = []
+        for k, n in enumerate(nodes):
+            node_map[n] = k
+            ind = tables.nodes[n].individual
+            pop = tables.nodes[n].population
+            if ind not in indivs and ind != tskit.NULL:
+                indivs.append(ind)
+            if pop not in pops and pop != tskit.NULL:
+                pops.append(pop)
+        ind_map = np.repeat(tskit.NULL, tables.individuals.num_rows + 1)
+        ind_map[indivs] = np.arange(len(indivs), dtype="int32")
+        pop_map = np.repeat(tskit.NULL, tables.populations.num_rows + 1)
+        pop_map[pops] = np.arange(len(pops), dtype="int32")
+        self.assertEqual(subset.nodes.num_rows, len(nodes))
+        for k, n in zip(nodes, subset.nodes):
+            nn = tables.nodes[k]
+            self.assertEqual(nn.time, n.time)
+            self.assertEqual(nn.flags, n.flags)
+            self.assertEqual(nn.metadata, n.metadata)
+            self.assertEqual(ind_map[nn.individual], n.individual)
+            self.assertEqual(pop_map[nn.population], n.population)
+        self.assertEqual(subset.individuals.num_rows, len(indivs))
+        for k, i in zip(indivs, subset.individuals):
+            ii = tables.individuals[k]
+            self.assertEqual(ii, i)
+        self.assertEqual(subset.populations.num_rows, len(pops))
+        for k, p in zip(pops, subset.populations):
+            pp = tables.populations[k]
+            self.assertEqual(pp, p)
+        edges = [
+            i
+            for i, e in enumerate(tables.edges)
+            if e.parent in nodes and e.child in nodes
+        ]
+        self.assertEqual(subset.edges.num_rows, len(edges))
+        for k, e in zip(edges, subset.edges):
+            ee = tables.edges[k]
+            self.assertEqual(ee.left, e.left)
+            self.assertEqual(ee.right, e.right)
+            self.assertEqual(node_map[ee.parent], e.parent)
+            self.assertEqual(node_map[ee.child], e.child)
+            self.assertEqual(ee.metadata, e.metadata)
+        muts = []
+        sites = []
+        for k, m in enumerate(tables.mutations):
+            if m.node in nodes:
+                muts.append(k)
+                if m.site not in sites:
+                    sites.append(m.site)
+        site_map = np.repeat(-1, tables.sites.num_rows)
+        site_map[sites] = np.arange(len(sites), dtype="int32")
+        mutation_map = np.repeat(tskit.NULL, tables.mutations.num_rows + 1)
+        mutation_map[muts] = np.arange(len(muts), dtype="int32")
+        self.assertEqual(subset.sites.num_rows, len(sites))
+        for k, s in zip(sites, subset.sites):
+            ss = tables.sites[k]
+            self.assertEqual(ss, s)
+        self.assertEqual(subset.mutations.num_rows, len(muts))
+        for k, m in zip(muts, subset.mutations):
+            mm = tables.mutations[k]
+            self.assertEqual(mutation_map[mm.parent], m.parent)
+            self.assertEqual(site_map[mm.site], m.site)
+            self.assertEqual(node_map[mm.node], m.node)
+            self.assertEqual(mm.derived_state, m.derived_state)
+            self.assertEqual(mm.metadata, m.metadata)
+        self.assertEqual(tables.migrations, subset.migrations)
+        self.assertEqual(tables.provenances, subset.provenances)
+
+    def test_ts_subset(self):
+        nodes = np.array([0, 1])
+        for tables in self.get_examples(83592):
+            ts = tables.tree_sequence()
+            tables2 = ts.subset(nodes, record_provenance=False).dump_tables()
+            tables.subset(nodes, record_provenance=False)
+            self.assertEqual(tables, tables2)
+
+    def test_subset_all(self):
+        # subsetting to everything shouldn't change things
+        # except the individual ids in the node tables if
+        # there are gaps
+        for tables in self.get_examples(123583):
+            tables2 = tables.copy()
+            tables2.subset(np.arange(tables.nodes.num_rows))
+            tables.provenances.clear()
+            tables2.provenances.clear()
+            tables.individuals.clear()
+            tables2.individuals.clear()
+            tables.nodes.clear()
+            tables2.nodes.clear()
+            self.assertEqual(tables, tables2)
+
+    def test_random_subsets(self):
+        rng = np.random.default_rng(1542)
+        for tables in self.get_examples(9412):
+            for n in [2, tables.nodes.num_rows - 10]:
+                nodes = rng.choice(np.arange(tables.nodes.num_rows), n, replace=False)
+                self.verify_subset(tables, nodes)
+
+    def test_empty_nodes(self):
+        for tables in self.get_examples(8724):
+            subset = tables.copy()
+            subset.subset(np.array([]), record_provenance=False)
+            self.assertEqual(subset.nodes.num_rows, 0)
+            self.assertEqual(subset.edges.num_rows, 0)
+            self.assertEqual(subset.populations.num_rows, 0)
+            self.assertEqual(subset.individuals.num_rows, 0)
+            self.assertEqual(subset.migrations.num_rows, 0)
+            self.assertEqual(subset.sites.num_rows, 0)
+            self.assertEqual(subset.mutations.num_rows, 0)
+            self.assertEqual(subset.provenances, tables.provenances)
diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py
index 5d8a88f6ff..aeb5b9f743 100644
--- a/python/tests/tsutil.py
+++ b/python/tests/tsutil.py
@@
-539,6 +539,82 @@ def compute_mutation_parent(ts):
     return mutation_parent
 
 
+def py_subset(tables, nodes, record_provenance=True):
+    """
+    Naive implementation of the TableCollection.subset method using the Python API.
+
+    Modifies ``tables`` in place. ``nodes`` must be a numpy array of node IDs;
+    a ValueError is raised, before anything is modified, if any ID is out of
+    bounds or if the tables contain migrations. ``record_provenance`` is accepted
+    for signature parity with TableCollection.subset but is not used here.
+    """
+    # valid IDs are 0, ..., num_rows - 1, so num_rows itself is out of bounds
+    if np.any(nodes >= tables.nodes.num_rows) or np.any(nodes < 0):
+        raise ValueError("Nodes out of bounds.")
+    if tables.migrations.num_rows > 0:
+        raise ValueError("Migrations are currently not supported in this operation.")
+    full = tables.copy()
+    # there is no table collection clear in the py API
+    tables.nodes.clear()
+    tables.individuals.clear()
+    tables.populations.clear()
+    tables.edges.clear()
+    tables.migrations.clear()
+    tables.sites.clear()
+    tables.mutations.clear()
+    # mapping from old to new ids
+    node_map = {}
+    ind_map = {tskit.NULL: tskit.NULL}
+    pop_map = {tskit.NULL: tskit.NULL}
+    for old_id in nodes:
+        node = full.nodes[old_id]
+        if node.individual not in ind_map and node.individual != tskit.NULL:
+            ind = full.individuals[node.individual]
+            new_ind_id = tables.individuals.add_row(
+                ind.flags, ind.location, ind.metadata
+            )
+            ind_map[node.individual] = new_ind_id
+        if node.population not in pop_map and node.population != tskit.NULL:
+            pop = full.populations[node.population]
+            new_pop_id = tables.populations.add_row(pop.metadata)
+            pop_map[node.population] = new_pop_id
+        new_id = tables.nodes.add_row(
+            node.flags,
+            node.time,
+            pop_map[node.population],
+            ind_map[node.individual],
+            node.metadata,
+        )
+        node_map[old_id] = new_id
+    for edge in full.edges:
+        if edge.child in node_map and edge.parent in node_map:
+            tables.edges.add_row(
+                edge.left,
+                edge.right,
+                node_map[edge.parent],
+                node_map[edge.child],
+                edge.metadata,
+            )
+    site_map = {}
+    mutation_map = {tskit.NULL: tskit.NULL}
+    for i, mut in enumerate(full.mutations):
+        if mut.node in node_map:
+            if mut.site not in site_map:
+                site = full.sites[mut.site]
+                new_site = tables.sites.add_row(
+                    site.position, site.ancestral_state, site.metadata
+                )
+                site_map[mut.site] = new_site
+            new_mut = tables.mutations.add_row(
+                site_map[mut.site],
+                node_map[mut.node],
+                mut.derived_state,
+                mutation_map.get(mut.parent, tskit.NULL),
+                mut.metadata,
+            )
+            mutation_map[i] = new_mut
+
+
 def algorithm_T(ts):
     """
     Simple implementation of algorithm T from the PLOS paper, taking into
diff --git a/python/tskit/tables.py b/python/tskit/tables.py
index f14a958b74..b97e01ac33 100644
--- a/python/tskit/tables.py
+++ b/python/tskit/tables.py
@@ -2494,3 +2494,23 @@ def drop_index(self):
         currently indexed this method has no effect.
         """
         self.ll_tables.drop_index()
+
+    def subset(self, nodes, record_provenance=True):
+        """
+        Modifies the tables in place to contain only the entries referring to
+        the provided list of nodes, with nodes reordered according to the order
+        they appear in the list. See :meth:`TreeSequence.subset` for a more
+        detailed description.
+
+        :param list nodes: The list of nodes for which to retain information. This
+            may be a numpy array (or array-like) object (dtype=np.int32).
+        :param bool record_provenance: Whether to record a provenance entry
+            in the provenance table for this operation.
+        """
+        nodes = util.safe_np_int_cast(nodes, np.int32)
+        self.ll_tables.subset(nodes)
+        if record_provenance:
+            parameters = {"command": "subset", "nodes": nodes.tolist()}
+            self.provenances.add_row(
+                record=json.dumps(provenance.get_provenance_dict(parameters))
+            )
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index 6bb3b62599..2382121d14 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -4470,6 +4470,45 @@ def trim(self, record_provenance=True):
         tables.trim(record_provenance)
         return tables.tree_sequence()
 
+    def subset(self, nodes, record_provenance=True):
+        """
+        Returns a tree sequence modified to contain only the entries referring to
+        the provided list of nodes, with nodes reordered according to the order
+        they appear in the ``nodes`` argument. Specifically, this subsets and reorders
+        each of the tables as follows:
+
+        1. Nodes: if in the list of nodes, and in the order provided.
+        2. Individuals and Populations: if referred to by a retained node,
+           and in the order first seen when traversing the list of retained nodes.
+        3. Edges: if both parent and child are retained nodes.
+        4. Mutations: if the mutation's node is a retained node.
+        5. Sites: if any mutations remain at the site after removing mutations.
+        6. Migrations: not currently supported; an error is raised if any are present.
+
+        Retained edges, mutations, and sites appear in the same
+        order as in the original tables.
+
+        If ``nodes`` is the entire list of nodes in the tables, then the
+        resulting tables will be identical to the original tables, but with
+        nodes (and individuals and populations) reordered.
+
+        To instead subset the tables to a given portion of the *genome*, see
+        :meth:`.keep_intervals`.
+
+        **Note:** This is quite different from :meth:`.simplify`: the resulting
+        tables contain only the nodes given, not ancestral ones as well, and
+        does not simplify the relationships in any way.
+
+        :param list nodes: The list of nodes for which to retain information. This
+            may be a numpy array (or array-like) object (dtype=np.int32).
+        :param bool record_provenance: If True, add details of this operation to the
+            provenance information of the returned tree sequence. (Default: True).
+        :rtype: .TreeSequence
+        """
+        tables = self.dump_tables()
+        tables.subset(nodes, record_provenance)
+        return tables.tree_sequence()
+
     def draw_svg(
         self,
         path=None,