From ddc1dba3f73c45a8f36be74cd50ce677ed3236d1 Mon Sep 17 00:00:00 2001 From: peter Date: Wed, 24 Sep 2025 15:12:18 -0700 Subject: [PATCH] c stuff for disjoint options to union Add some python tests And fix concatenate() all_mutations implies really all_sites --- c/tests/test_tables.c | 104 +++++++++++++++++++++++++++++++++ c/tskit/tables.c | 25 +++++++- c/tskit/tables.h | 16 ++++- python/CHANGELOG.rst | 7 +++ python/_tskitmodule.c | 15 ++++- python/tests/test_highlevel.py | 22 +++++++ python/tests/test_lowlevel.py | 19 ++++++ python/tests/test_tables.py | 76 ++++++++++++++++++++++++ python/tests/test_topology.py | 93 ++++++++++++++++++++++++----- python/tskit/tables.py | 9 +++ python/tskit/trees.py | 62 +++++++++++++++++--- 11 files changed, 418 insertions(+), 30 deletions(-) diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 9def0a29d6..0b9ff088e5 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -11240,6 +11240,109 @@ test_table_collection_union(void) tsk_table_collection_free(&tables); } +static void +test_table_collection_disjoint_union(void) +{ + int ret; + tsk_id_t ret_id; + tsk_table_collection_t tables; + tsk_table_collection_t tables1; + tsk_table_collection_t tables2; + tsk_table_collection_t tables12; + tsk_id_t node_mapping[4]; + + tsk_memset(node_mapping, 0xff, sizeof(node_mapping)); + + ret = tsk_table_collection_init(&tables1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables1.sequence_length = 2; + + // set up nodes, which will be shared + // flags, time, pop, ind, metadata, metadata_length + ret_id = tsk_node_table_add_row( + &tables1.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row( + &tables1.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row(&tables1.nodes, 0, 0.5, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = 
tsk_node_table_add_row(&tables1.nodes, 0, 1.5, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret = tsk_table_collection_copy(&tables1, &tables2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // for tables1: + // on [0, 1] we have 0, 1 inherit from 2 + // left, right, parent, child, metadata, metadata_length + ret_id = tsk_edge_table_add_row(&tables1.edges, 0.0, 1.0, 2, 0, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&tables1.edges, 0.0, 1.0, 2, 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_site_table_add_row(&tables1.sites, 0.4, "T", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row( + &tables1.mutations, ret_id, 0, TSK_NULL, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret = tsk_table_collection_build_index(&tables1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_sort(&tables1, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // all this goes in tables12 so far + ret = tsk_table_collection_copy(&tables1, &tables12, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // for tables2; and need to add to tables12 also: + // on [1, 2] we have 0, 1 inherit from 3 + // left, right, parent, child, metadata, metadata_length + ret_id = tsk_edge_table_add_row(&tables2.edges, 1.0, 2.0, 3, 0, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&tables2.edges, 1.0, 2.0, 3, 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_site_table_add_row(&tables2.sites, 1.4, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row( + &tables2.mutations, ret_id, 1, TSK_NULL, TSK_UNKNOWN_TIME, "T", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret = tsk_table_collection_build_index(&tables2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_sort(&tables2, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + // also tables12 + ret_id = tsk_edge_table_add_row(&tables12.edges, 1.0, 2.0, 3, 0, NULL, 0); 
+ CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&tables12.edges, 1.0, 2.0, 3, 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_site_table_add_row(&tables12.sites, 1.4, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row( + &tables12.mutations, ret_id, 1, TSK_NULL, TSK_UNKNOWN_TIME, "T", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret = tsk_table_collection_build_index(&tables12, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_sort(&tables12, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // now disjoint union-ing tables1 and tables2 should get tables12 + ret = tsk_table_collection_copy(&tables1, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + node_mapping[0] = 0; + node_mapping[1] = 1; + node_mapping[2] = 2; + node_mapping[3] = 3; + ret = tsk_table_collection_union(&tables, &tables2, node_mapping, + TSK_UNION_NO_CHECK_SHARED | TSK_UNION_ALL_EDGES | TSK_UNION_ALL_MUTATIONS); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FATAL( + tsk_table_collection_equals(&tables, &tables12, TSK_CMP_IGNORE_PROVENANCE)); + + tsk_table_collection_free(&tables12); + tsk_table_collection_free(&tables2); + tsk_table_collection_free(&tables1); + tsk_table_collection_free(&tables); +} + static void test_table_collection_union_middle_merge(void) { @@ -11836,6 +11939,7 @@ main(int argc, char **argv) test_table_collection_subset_unsorted }, { "test_table_collection_subset_errors", test_table_collection_subset_errors }, { "test_table_collection_union", test_table_collection_union }, + { "test_table_collection_disjoint_union", test_table_collection_disjoint_union }, { "test_table_collection_union_middle_merge", test_table_collection_union_middle_merge }, { "test_table_collection_union_errors", test_table_collection_union_errors }, diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 9805d669a5..7106300f3d 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -13202,6 +13202,8 @@ 
tsk_table_collection_union(tsk_table_collection_t *self, tsk_id_t *site_map = NULL; bool add_populations = !(options & TSK_UNION_NO_ADD_POP); bool check_shared_portion = !(options & TSK_UNION_NO_CHECK_SHARED); + bool all_edges = !!(options & TSK_UNION_ALL_EDGES); + bool all_mutations = !!(options & TSK_UNION_ALL_MUTATIONS); /* Not calling TSK_CHECK_TREES so casting to int is safe */ ret = (int) tsk_table_collection_check_integrity(self, 0); @@ -13285,7 +13287,7 @@ tsk_table_collection_union(tsk_table_collection_t *self, // edges for (k = 0; k < (tsk_id_t) other->edges.num_rows; k++) { tsk_edge_table_get_row_unsafe(&other->edges, k, &edge); - if ((other_node_mapping[edge.parent] == TSK_NULL) + if (all_edges || (other_node_mapping[edge.parent] == TSK_NULL) || (other_node_mapping[edge.child] == TSK_NULL)) { new_parent = node_map[edge.parent]; new_child = node_map[edge.child]; @@ -13298,14 +13300,31 @@ tsk_table_collection_union(tsk_table_collection_t *self, } } - // mutations and sites + // sites + // first do the "disjoint" (all_mutations) case, where we just add all sites; + // otherwise we want to just add sites for new mutations + if (all_mutations) { + for (k = 0; k < (tsk_id_t) other->sites.num_rows; k++) { + tsk_site_table_get_row_unsafe(&other->sites, k, &site); + ret_id = tsk_site_table_add_row(&self->sites, site.position, + site.ancestral_state, site.ancestral_state_length, site.metadata, + site.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + site_map[site.id] = ret_id; + } + } + + // mutations (and maybe sites) i = 0; for (k = 0; k < (tsk_id_t) other->sites.num_rows; k++) { tsk_site_table_get_row_unsafe(&other->sites, k, &site); while ((i < (tsk_id_t) other->mutations.num_rows) && (other->mutations.site[i] == site.id)) { tsk_mutation_table_get_row_unsafe(&other->mutations, i, &mut); - if (other_node_mapping[mut.node] == TSK_NULL) { + if (all_mutations || (other_node_mapping[mut.node] == TSK_NULL)) { if (site_map[site.id] == 
TSK_NULL) { ret_id = tsk_site_table_add_row(&self->sites, site.position, site.ancestral_state, site.ancestral_state_length, site.metadata, diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 85ed29d58c..9523ee1274 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -858,11 +858,21 @@ equality of the subsets. */ #define TSK_UNION_NO_CHECK_SHARED (1 << 0) /** - By default, all nodes new to ``self`` are assigned new populations. If this +By default, all nodes new to ``self`` are assigned new populations. If this option is specified, nodes that are added to ``self`` will retain the population IDs they have in ``other``. */ #define TSK_UNION_NO_ADD_POP (1 << 1) +/** +By default, union only adds edges adjacent to a newly added node; +this option adds all edges. + */ +#define TSK_UNION_ALL_EDGES (1 << 2) +/** +By default, union only adds mutations on newly added nodes, and +sites for those mutations; this option adds all mutations and all sites. + */ +#define TSK_UNION_ALL_MUTATIONS (1 << 3) /** @} */ /** @@ -4414,6 +4424,10 @@ that are exclusive ``other`` are added to ``self``, along with: By default, populations of newly added nodes are assumed to be new populations, and added to the population table as well. +The behavior can be changed by the flags ``TSK_UNION_ALL_EDGES`` and +``TSK_UNION_ALL_MUTATIONS``, which will (respectively) add *all* edges +or *all* sites and mutations instead. + This operation will also sort the resulting tables, so the tables may change even if nothing new is added, if the original tables were not sorted. diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 102c111981..28cf1496b6 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -50,6 +50,10 @@ - Add ``Mutation.inherited_state`` property which returns the inherited state for a single mutation. 
(:user:`benjeffery`, :pr:`3277`, :issue:`2631`) +- Add ``all_mutations`` and ``all_edges`` options to ``TreeSequence.union``, + allowing greater flexibility in "disjoint union" situations. + (:user:`hyanwong`, :user:`petrelharp`, :issue:`3181`) + **Bugfixes** - In some tables with mutations out-of-order ``TableCollection.sort`` did not re-order @@ -84,6 +88,9 @@ - Prevent iterating over a ``TopologyCounter`` (:user:`benjeffery` , :pr:`3202`, :issue:`1462`) +- Fix ``TreeSequence.concatenate()`` to work with internal samples by using the + ``all_mutations`` and ``all_edges`` parameters in ``union()`` + (:user:`hyanwong`, :pr:`3283`, :issue:`3181`) **Breaking changes** diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 23ab663538..78cb9f7c8e 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -4347,15 +4347,18 @@ TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds) npy_intp *shape; tsk_flags_t options = 0; int check_shared = true; + int all_edges = false; + int all_mutations = false; int add_populations = true; static char *kwlist[] = { "other", "other_node_mapping", "check_shared_equality", - "add_populations", NULL }; + "add_populations", "all_edges", "all_mutations", NULL }; if (TableCollection_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O|ii", kwlist, &TableCollectionType, - &other, &other_node_mapping, &check_shared, &add_populations)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O|iiii", kwlist, + &TableCollectionType, &other, &other_node_mapping, &check_shared, + &add_populations, &all_edges, &all_mutations)) { goto out; } nmap_array = (PyArrayObject *) PyArray_FROMANY( @@ -4370,6 +4373,12 @@ TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds) " number of nodes in the other tree sequence."); goto out; } + if (all_edges) { + options |= TSK_UNION_ALL_EDGES; + } + if (all_mutations) { + options |= TSK_UNION_ALL_MUTATIONS; + } if 
(!check_shared) { options |= TSK_UNION_NO_CHECK_SHARED; } diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index b3fed482a6..61777f1e79 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -2536,6 +2536,28 @@ def test_mutation_parent_errors(self, mutations, error): else: tables.tree_sequence() + def test_union(self, ts_fixture): + # most of the union tests are in test_tables.py, here we just sanity check + tables = ts_fixture.dump_tables() + tables.migrations.clear() # migrations not supported in union() + ts = tables.tree_sequence() + tables = tskit.TableCollection(ts.sequence_length) + tables.time_units = ts.time_units + empty = tables.tree_sequence() + union_ts = empty.union( + ts, + node_mapping=np.full(ts.num_nodes, tskit.NULL, dtype=int), + all_edges=True, + all_mutations=True, + check_shared_equality=False, + ) + union_ts.tables.assert_equals( + ts.tables, + ignore_metadata=True, + ignore_reference_sequence=True, + ignore_provenance=True, + ) + class TestSimplify: # This class was factored out of the old TestHighlevel class 2022-12-13, diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index ab003ffd7a..53d073b6e8 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -433,6 +433,25 @@ def test_union_bad_args(self): with pytest.raises(ValueError): tc.union(tc2, np.array([[1], [2]], dtype="int32")) + @pytest.mark.parametrize("value", [True, False]) + @pytest.mark.parametrize( + "flag", + [ + "all_edges", + "all_mutations", + "check_shared_equality", + "add_populations", + ], + ) + def test_union_options(self, flag, value): + ts = msprime.simulate(10, random_seed=1) + tc = ts.dump_tables()._ll_tables + empty_tables = ts.dump_tables() + for table in empty_tables.table_name_map.keys(): + getattr(empty_tables, table).clear() + tc2 = empty_tables._ll_tables + tc.union(tc2, np.arange(0, dtype="int32"), **{flag: value}) + def test_equals_bad_args(self): ts = 
msprime.simulate(10, random_seed=1242) tc = ts.dump_tables()._ll_tables diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 427e9130c2..6971780b47 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -5271,6 +5271,82 @@ def test_examples(self): ts = tables.tree_sequence() self.verify_union(*self.split_example(ts, T)) + def test_split_and_rejoin(self): + ts = self.get_msprime_example(5, T=2, seed=928) + cutpoints = np.array([0, 0.25, 0.5, 0.75, 1]) * ts.sequence_length + tables1 = ts.dump_tables() + tables1.delete_intervals([cutpoints[0:2], cutpoints[2:4]], simplify=False) + tables2 = ts.dump_tables() + tables2.delete_intervals([cutpoints[1:3], cutpoints[3:]], simplify=False) + tables1.union( + tables2, + all_edges=True, + all_mutations=True, + node_mapping=np.arange(ts.num_nodes), + check_shared_equality=False, + ) + tables1.edges.squash() + tables1.sort() + tables1.assert_equals(ts.tables, ignore_provenance=True) + + def test_both_empty(self): + tables = tskit.TableCollection(sequence_length=1) + t1 = tables.copy() + t2 = tables.copy() + t1.union(t2, node_mapping=np.arange(0), all_edges=True, all_mutations=True) + t1.assert_equals(tables, ignore_provenance=True) + + def test_one_empty(self): + ts = self.get_msprime_example(5, T=2, seed=928) + ts = ts.simplify() # the example has a load of unreferenced individuals + tables = ts.dump_tables() + empty = tskit.TableCollection(sequence_length=tables.sequence_length) + empty.time_units = tables.time_units + + # union with empty should be no-op + tables.union( + empty, node_mapping=np.arange(0), all_edges=True, all_mutations=True + ) + tables.assert_equals(ts.dump_tables(), ignore_provenance=True) + + # empty union with tables should be tables + empty.union( + tables, + node_mapping=np.full(tables.nodes.num_rows, tskit.NULL), + all_edges=True, + all_mutations=True, + check_shared_equality=False, + ) + empty.assert_equals(tables, ignore_provenance=True) + + def 
test_reciprocal_empty(self): + # reciprocally add mutations from one table and edges from the other + edges_table = tskit.Tree.generate_comb(6, span=6).tree_sequence.dump_tables() + muts_table = tskit.TableCollection(sequence_length=6) + muts_table.nodes.replace_with(edges_table.nodes) # same nodes, no edges + for j in range(0, 6): + site_id = muts_table.sites.add_row(position=j, ancestral_state="0") + if j % 2 == 0: + # Some sites empty + muts_table.mutations.add_row(site=site_id, node=j, derived_state="1") + identity_map = np.arange(len(muts_table.nodes), dtype="int32") + params = {"node_mapping": identity_map, "check_shared_equality": False} + + test_table = edges_table.copy() + test_table.union(muts_table, **params, all_edges=True) # null op + assert len(test_table.sites) == 0 + assert len(test_table.mutations) == 0 + test_table.union(muts_table, **params, all_mutations=True) + assert test_table.sites == muts_table.sites + assert test_table.mutations == muts_table.mutations + + muts_table.union(edges_table, **params, all_mutations=True) # null op + assert len(muts_table.edges) == 0 + muts_table.union(edges_table, **params, all_edges=True) + assert muts_table.edges == edges_table.edges + + muts_table.assert_equals(test_table, ignore_provenance=True) + class TestTableSetitemMetadata: @pytest.mark.parametrize("table_name", tskit.TABLE_NAMES) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index e19c957e9c..b693743a2a 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -7215,18 +7215,54 @@ def test_reference_sequence(self): class TestConcatenate: def test_simple(self): ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence + ts1 = msprime.sim_mutations(ts1, rate=1, random_seed=1) ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence + ts2 = msprime.sim_mutations(ts2, rate=1, random_seed=2) assert ts1.num_samples == ts2.num_samples assert ts1.num_nodes != ts2.num_nodes joint_ts = 
ts1.concatenate(ts2) assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 5 assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length assert joint_ts.num_samples == ts1.num_samples + assert joint_ts.num_sites == ts1.num_sites + ts2.num_sites + assert joint_ts.num_mutations == ts1.num_mutations + ts2.num_mutations ts3 = joint_ts.delete_intervals([[2, 5]]).rtrim() # Have to simplify here, to remove the redundant nodes - assert ts3.equals(ts1.simplify(), ignore_provenance=True) + ts3.tables.assert_equals(ts1.tables, ignore_provenance=True) ts4 = joint_ts.delete_intervals([[0, 2]]).ltrim() - assert ts4.equals(ts2.simplify(), ignore_provenance=True) + ts4.tables.assert_equals(ts2.tables, ignore_provenance=True) + + def test_metadata(self, ts_fixture): + tables = ts_fixture.dump_tables() + tables.reference_sequence.clear() + tables.migrations.clear() + ts = tables.tree_sequence() + num_sites = ts.num_sites + assert num_sites > 0 + joint_ts = ts.concatenate(ts) + for s1, s2 in zip(range(num_sites), range(num_sites, num_sites * 2)): + site1 = joint_ts.site(s1) + site2 = joint_ts.site(s2) + assert site1.metadata == site2.metadata + assert site1.ancestral_state == site2.ancestral_state + assert len(site1.mutations) == len(site2.mutations) + for m1, m2 in zip(site1.mutations, site2.mutations): + assert m1.metadata == m2.metadata + assert m1.derived_state == m2.derived_state + assert m1.time == m2.time + ns_nodes = np.where(ts.tables.nodes.flags & tskit.NODE_IS_SAMPLE == 0)[0] + assert len(ns_nodes) > 0 + for u in ns_nodes: + node1 = ts.node(u) + node2 = joint_ts.node(u) + assert node1.metadata == node2.metadata + assert node1.flags == node2.flags + assert node1.time == node2.time + ind1 = joint_ts.individual(node1.individual) + ind2 = joint_ts.individual(node2.individual) + assert ind1.metadata == ind2.metadata + assert ind1.flags == ind2.flags + assert np.all(ind1.location == ind2.location) def test_multiple(self): np.random.seed(42) @@ -7278,15 
+7314,47 @@ def test_internal_samples(self): assert joint_ts.sequence_length == ts.sequence_length * 2 def test_some_shared_samples(self): - ts1 = tskit.Tree.generate_comb(4, span=2).tree_sequence - ts2 = tskit.Tree.generate_balanced(8, arity=3, span=3).tree_sequence - shared = np.full(ts2.num_nodes, tskit.NULL) - shared[0] = 1 - shared[1] = 0 - joint_ts = ts1.concatenate(ts2, node_mappings=[shared]) - assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length - assert joint_ts.num_samples == ts1.num_samples + ts2.num_samples - 2 - assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 2 + tables = tskit.Tree.generate_comb(5).tree_sequence.dump_tables() + tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE) + ts1 = tables.tree_sequence() + tables = tskit.Tree.generate_balanced(5).tree_sequence.dump_tables() + tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE) + ts2 = tables.tree_sequence() + assert ts1.num_samples == ts2.num_samples + joint_ts = ts1.concatenate(ts2) + assert joint_ts.num_samples == ts1.num_samples + assert joint_ts.num_edges == ts1.num_edges + ts2.num_edges + for tree in joint_ts.trees(): + assert tree.num_roots == 1 + + @pytest.mark.parametrize("simplify", [True, False]) + def test_wf_sim(self, simplify): + # Test that we can split & concat a wf_sim ts, which has internal samples + tables = wf.wf_sim( + 6, + 5, + seed=3, + deep_history=True, + initial_generation_samples=True, + num_loci=10, + ) + tables.sort() + tables.simplify() + ts = msprime.mutate(tables.tree_sequence(), rate=0.05, random_seed=234) + assert ts.num_trees > 2 + assert len(np.unique(ts.nodes_time[ts.samples()])) > 1 + ts1 = ts.keep_intervals([[0, 4.5]], simplify=False).trim() + ts2 = ts.keep_intervals([[4.5, ts.sequence_length]], simplify=False).trim() + if simplify: + ts1 = ts1.simplify(filter_nodes=False) + ts2, node_map = ts2.simplify(map_nodes=True) + node_mapping = np.zeros_like(node_map, shape=ts2.num_nodes) + kept 
= node_map != tskit.NULL + node_mapping[node_map[kept]] = np.arange(len(node_map))[kept] + else: + node_mapping = np.arange(ts.num_nodes) + ts_new = ts1.concatenate(ts2, node_mappings=[node_mapping]).simplify() + ts_new.tables.assert_equals(ts.tables, ignore_provenance=True) def test_provenance(self): ts = tskit.Tree.generate_comb(2).tree_sequence @@ -7304,9 +7372,6 @@ def test_unequal_samples(self): with pytest.raises(ValueError, match="must have the same number of samples"): ts1.concatenate(ts2) - @pytest.mark.skip( - reason="union bug: https://github.com/tskit-dev/tskit/issues/3168" - ) def test_duplicate_ts(self): ts1 = tskit.Tree.generate_comb(3, span=4).tree_sequence ts = ts1.keep_intervals([[0, 1]]).trim() # a quarter of the original diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 58f4f718fc..cab3407c27 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -4337,6 +4337,9 @@ def union( check_shared_equality=True, add_populations=True, record_provenance=True, + *, + all_edges=False, + all_mutations=False, ): """ Modifies the table collection in place by adding the non-shared @@ -4358,6 +4361,10 @@ def union( assigned new population IDs. :param bool record_provenance: Whether to record a provenance entry in the provenance table for this operation. + :param bool all_edges: If True, then all edges in ``other`` are added + to ``self``. + :param bool all_mutations: If True, then all mutations and sites in + ``other`` are added to ``self``. 
""" node_mapping = util.safe_np_int_cast(node_mapping, np.int32) self._ll_tables.union( @@ -4365,6 +4372,8 @@ def union( node_mapping, check_shared_equality=check_shared_equality, add_populations=add_populations, + all_edges=all_edges, + all_mutations=all_mutations, ) if record_provenance: other_records = [prov.record for prov in other.provenances] diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 561b5a4b24..9904b60a98 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7198,19 +7198,32 @@ def concatenate( self, *args, node_mappings=None, record_provenance=True, add_populations=None ): r""" - Concatenate a set of tree sequences to the right of this one, by repeatedly - calling :meth:`~TreeSequence.union` with an (optional) - node mapping for each of the ``others``. If any node mapping is ``None`` - only map the sample nodes between the input tree sequence and this one, - based on the numerical order of sample node IDs. + Concatenate a set of tree sequences to the right of this one, by shifting + their coordinate systems and adding all edges, sites, mutations, and + any additional nodes, individuals, or populations needed for these. + Concretely, to concatenate an ``other`` tree sequence to ``self``, the value + of ``self.sequence_length`` is added to all genomic coordinates in ``other``, + and then the concatenated tree sequence will contain all edges, sites, and + mutations in both. Which nodes in ``other`` are treated as "new", and hence + added as well, is controlled by ``node_mappings``. Any individuals to which + new nodes belong are added as well. + + The method uses :meth:`.shift` followed by :meth:`.union`, with + ``all_mutations=True``, ``all_edges=True``, and ``check_shared_equality=False``. + + By default, the samples in current and input tree sequences are assumed to + refer to the same nodes, and are matched based on the numerical order of + sample node IDs; all other nodes are assumed to be new. 
This can be + changed by providing explicit ``node_mappings`` for each input tree sequence + (see below). .. note:: - To add gaps between the concatenated tables, use :meth:`shift` or - to remove gaps, use :meth:`trim` before concatenating. + To add gaps between the concatenated tree sequences, use :meth:`shift` + or to remove gaps, use :meth:`trim` before concatenating. :param TreeSequence \*args: A list of other tree sequences to append to the right of this one. - :param Union[list, None] node_mappings: An list of node mappings for each + :param Union[list, None] node_mappings: A list of node mappings for each input tree sequence in ``args``. Each should either be an array of integers of the same length as the number of nodes in the equivalent input tree sequence (see :meth:`~TreeSequence.union` for details), or @@ -7252,6 +7265,8 @@ def concatenate( other_tables, node_mapping=node_mapping, check_shared_equality=False, # Else checks fail with internal samples + all_mutations=True, + all_edges=True, record_provenance=False, add_populations=add_populations, ) @@ -7480,6 +7495,9 @@ def union( check_shared_equality=True, add_populations=True, record_provenance=True, + *, + all_edges=False, + all_mutations=False, ): """ Returns an expanded tree sequence which contains the node-wise union of @@ -7513,6 +7531,26 @@ def union( nodes are in entirely new populations, then you must set up the population table first, and then union with ``add_populations=False``. + This method makes sense if the "shared" portions of the tree sequences + are equal; the option ``check_shared_equality`` performs a consistency + check that this is true. If this check is disabled, it is very easy to + produce nonsensical results via subtle inconsistencies. + + The behavior above can be changed by ``all_edges`` and ``all_mutations``. + If ``all_edges`` is True, then all edges in ``other`` are added to + ``self``, instead of only edges adjacent to added nodes. 
If + ``all_mutations`` is True, then similarly all mutations in ``other`` + are added (not just those on added nodes); furthermore, all sites + at positions without a site already present are added to ``self``. + The intended use case for these options is a "disjoint" union, + where for instance the two tree sequences contain information about + disjoint segments of the genome (see :meth:`.concatenate`). + For some such applications it may be necessary to set + ``check_shared_equality=False``: for instance, if ``other`` has + an identical copy of the node table but no edges, then + ``all_mutations=True, check_shared_equality=False`` can be used + to add mutations to ``self``. + If the resulting tree sequence is invalid (for instance, a node is specified to have two distinct parents on the same interval), an error will be raised. @@ -7521,9 +7559,13 @@ def union( resulting tree sequence may not be equal to ``self`` even if nothing new was added (although it would differ only in ordering of the tables). - :param TableCollection other: Another table collection. + :param TreeSequence other: Another tree sequence. :param list node_mapping: An array of node IDs that relate nodes in ``other`` to nodes in ``self``. + :param bool all_edges: If True, then all edges in ``other`` are added + to ``self``. + :param bool all_mutations: If True, then all mutations and sites in + ``other`` are added to ``self``. :param bool check_shared_equality: If True, the shared portions of the tree sequences will be checked for equality. It does so by running :meth:`TreeSequence.subset` on both ``self`` and ``other`` @@ -7542,6 +7584,8 @@ def union( check_shared_equality=check_shared_equality, add_populations=add_populations, record_provenance=record_provenance, + all_edges=all_edges, + all_mutations=all_mutations, ) return tables.tree_sequence()