Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions c/tests/test_tables.c
Original file line number Diff line number Diff line change
Expand Up @@ -11240,6 +11240,109 @@ test_table_collection_union(void)
tsk_table_collection_free(&tables);
}

static void
test_table_collection_disjoint_union(void)
{
int ret;
tsk_id_t ret_id;
tsk_table_collection_t tables;
tsk_table_collection_t tables1;
tsk_table_collection_t tables2;
tsk_table_collection_t tables12;
tsk_id_t node_mapping[4];

tsk_memset(node_mapping, 0xff, sizeof(node_mapping));

ret = tsk_table_collection_init(&tables1, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
tables1.sequence_length = 2;

// set up nodes, which will be shared
// flags, time, pop, ind, metadata, metadata_length
ret_id = tsk_node_table_add_row(
&tables1.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_node_table_add_row(
&tables1.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_node_table_add_row(&tables1.nodes, 0, 0.5, TSK_NULL, TSK_NULL, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_node_table_add_row(&tables1.nodes, 0, 1.5, TSK_NULL, TSK_NULL, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret = tsk_table_collection_copy(&tables1, &tables2, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

// for tables1:
// on [0, 1] we have 0, 1 inherit from 2
// left, right, parent, child, metadata, metadata_length
ret_id = tsk_edge_table_add_row(&tables1.edges, 0.0, 1.0, 2, 0, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_edge_table_add_row(&tables1.edges, 0.0, 1.0, 2, 1, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_site_table_add_row(&tables1.sites, 0.4, "T", 1, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_mutation_table_add_row(
&tables1.mutations, ret_id, 0, TSK_NULL, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret = tsk_table_collection_build_index(&tables1, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_table_collection_sort(&tables1, NULL, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

// all this goes in tables12 so far
ret = tsk_table_collection_copy(&tables1, &tables12, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

// for tables2; and need to add to tables12 also:
// on [1, 2] we have 0, 1 inherit from 3
// left, right, parent, child, metadata, metadata_length
ret_id = tsk_edge_table_add_row(&tables2.edges, 1.0, 2.0, 3, 0, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_edge_table_add_row(&tables2.edges, 1.0, 2.0, 3, 1, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_site_table_add_row(&tables2.sites, 1.4, "A", 1, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_mutation_table_add_row(
&tables2.mutations, ret_id, 1, TSK_NULL, TSK_UNKNOWN_TIME, "T", 1, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret = tsk_table_collection_build_index(&tables2, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_table_collection_sort(&tables2, NULL, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
// also tables12
ret_id = tsk_edge_table_add_row(&tables12.edges, 1.0, 2.0, 3, 0, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_edge_table_add_row(&tables12.edges, 1.0, 2.0, 3, 1, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_site_table_add_row(&tables12.sites, 1.4, "A", 1, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret_id = tsk_mutation_table_add_row(
&tables12.mutations, ret_id, 1, TSK_NULL, TSK_UNKNOWN_TIME, "T", 1, NULL, 0);
CU_ASSERT_FATAL(ret_id >= 0);
ret = tsk_table_collection_build_index(&tables12, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_table_collection_sort(&tables12, NULL, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

// now disjoint union-ing tables1 and tables2 should get tables12
ret = tsk_table_collection_copy(&tables1, &tables, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
node_mapping[0] = 0;
node_mapping[1] = 1;
node_mapping[2] = 2;
node_mapping[3] = 3;
ret = tsk_table_collection_union(&tables, &tables2, node_mapping,
TSK_UNION_NO_CHECK_SHARED | TSK_UNION_ALL_EDGES | TSK_UNION_ALL_MUTATIONS);
CU_ASSERT_EQUAL_FATAL(ret, 0);
CU_ASSERT_FATAL(
tsk_table_collection_equals(&tables, &tables12, TSK_CMP_IGNORE_PROVENANCE));

tsk_table_collection_free(&tables12);
tsk_table_collection_free(&tables2);
tsk_table_collection_free(&tables1);
tsk_table_collection_free(&tables);
}

static void
test_table_collection_union_middle_merge(void)
{
Expand Down Expand Up @@ -11836,6 +11939,7 @@ main(int argc, char **argv)
test_table_collection_subset_unsorted },
{ "test_table_collection_subset_errors", test_table_collection_subset_errors },
{ "test_table_collection_union", test_table_collection_union },
{ "test_table_collection_disjoint_union", test_table_collection_disjoint_union },
{ "test_table_collection_union_middle_merge",
test_table_collection_union_middle_merge },
{ "test_table_collection_union_errors", test_table_collection_union_errors },
Expand Down
25 changes: 22 additions & 3 deletions c/tskit/tables.c
Original file line number Diff line number Diff line change
Expand Up @@ -13202,6 +13202,8 @@ tsk_table_collection_union(tsk_table_collection_t *self,
tsk_id_t *site_map = NULL;
bool add_populations = !(options & TSK_UNION_NO_ADD_POP);
bool check_shared_portion = !(options & TSK_UNION_NO_CHECK_SHARED);
bool all_edges = !!(options & TSK_UNION_ALL_EDGES);
bool all_mutations = !!(options & TSK_UNION_ALL_MUTATIONS);

/* Not calling TSK_CHECK_TREES so casting to int is safe */
ret = (int) tsk_table_collection_check_integrity(self, 0);
Expand Down Expand Up @@ -13285,7 +13287,7 @@ tsk_table_collection_union(tsk_table_collection_t *self,
// edges
for (k = 0; k < (tsk_id_t) other->edges.num_rows; k++) {
tsk_edge_table_get_row_unsafe(&other->edges, k, &edge);
if ((other_node_mapping[edge.parent] == TSK_NULL)
if (all_edges || (other_node_mapping[edge.parent] == TSK_NULL)
|| (other_node_mapping[edge.child] == TSK_NULL)) {
new_parent = node_map[edge.parent];
new_child = node_map[edge.child];
Expand All @@ -13298,14 +13300,31 @@ tsk_table_collection_union(tsk_table_collection_t *self,
}
}

// mutations and sites
// sites
// first do the "disjoint" (all_mutations) case, where we just add all sites;
// otherwise we want to just add sites for new mutations
if (all_mutations) {
for (k = 0; k < (tsk_id_t) other->sites.num_rows; k++) {
tsk_site_table_get_row_unsafe(&other->sites, k, &site);
ret_id = tsk_site_table_add_row(&self->sites, site.position,
site.ancestral_state, site.ancestral_state_length, site.metadata,
site.metadata_length);
if (ret_id < 0) {
ret = (int) ret_id;
goto out;
}
site_map[site.id] = ret_id;
}
}

// mutations (and maybe sites)
i = 0;
for (k = 0; k < (tsk_id_t) other->sites.num_rows; k++) {
tsk_site_table_get_row_unsafe(&other->sites, k, &site);
while ((i < (tsk_id_t) other->mutations.num_rows)
&& (other->mutations.site[i] == site.id)) {
tsk_mutation_table_get_row_unsafe(&other->mutations, i, &mut);
if (other_node_mapping[mut.node] == TSK_NULL) {
if (all_mutations || (other_node_mapping[mut.node] == TSK_NULL)) {
if (site_map[site.id] == TSK_NULL) {
ret_id = tsk_site_table_add_row(&self->sites, site.position,
site.ancestral_state, site.ancestral_state_length, site.metadata,
Expand Down
16 changes: 15 additions & 1 deletion c/tskit/tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -858,11 +858,21 @@ equality of the subsets.
*/
#define TSK_UNION_NO_CHECK_SHARED (1 << 0)
/**
By default, all nodes new to ``self`` are assigned new populations. If this
By default, all nodes new to ``self`` are assigned new populations. If this
option is specified, nodes that are added to ``self`` will retain the
population IDs they have in ``other``.
*/
#define TSK_UNION_NO_ADD_POP (1 << 1)
/**
By default, union only adds edges adjacent to a newly added node;
this option adds all edges.
*/
#define TSK_UNION_ALL_EDGES (1 << 2)
/**
By default, union adds only mutations on newly added edges, and the
sites for those mutations; this option adds all mutations and all sites.
*/
#define TSK_UNION_ALL_MUTATIONS (1 << 3)
/** @} */

/**
Expand Down Expand Up @@ -4414,6 +4424,10 @@ that are exclusive ``other`` are added to ``self``, along with:
By default, populations of newly added nodes are assumed to be new populations,
and added to the population table as well.

The behavior can be changed by the flags ``TSK_UNION_ALL_EDGES`` and
``TSK_UNION_ALL_MUTATIONS``, which will (respectively) add *all* edges
or *all* sites and mutations instead.

This operation will also sort the resulting tables, so the tables may change
even if nothing new is added, if the original tables were not sorted.

Expand Down
7 changes: 7 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@
- Add ``Mutation.inherited_state`` property which returns the inherited state
for a single mutation. (:user:`benjeffery`, :pr:`3277`, :issue:`2631`)

- Add ``all_mutations`` and ``all_edges`` options to ``TreeSequence.union``,
allowing greater flexibility in "disjoint union" situations.
(:user:`hyanwong`, :user:`petrelharp`, :issue:`3181`)

**Bugfixes**

- In some tables with mutations out-of-order ``TableCollection.sort`` did not re-order
Expand Down Expand Up @@ -84,6 +88,9 @@
- Prevent iterating over a ``TopologyCounter``
(:user:`benjeffery` , :pr:`3202`, :issue:`1462`)

- Fix ``TreeSequence.concatenate()`` to work with internal samples by using the
``all_mutations`` and ``all_edges`` parameters in ``union()``
(:user:`hyanwong`, :pr:`3283`, :issue:`3181`)

**Breaking changes**

Expand Down
15 changes: 12 additions & 3 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -4347,15 +4347,18 @@ TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds)
npy_intp *shape;
tsk_flags_t options = 0;
int check_shared = true;
int all_edges = false;
int all_mutations = false;
int add_populations = true;
static char *kwlist[] = { "other", "other_node_mapping", "check_shared_equality",
"add_populations", NULL };
"add_populations", "all_edges", "all_mutations", NULL };

if (TableCollection_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O|ii", kwlist, &TableCollectionType,
&other, &other_node_mapping, &check_shared, &add_populations)) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O|iiii", kwlist,
&TableCollectionType, &other, &other_node_mapping, &check_shared,
&add_populations, &all_edges, &all_mutations)) {
goto out;
}
nmap_array = (PyArrayObject *) PyArray_FROMANY(
Expand All @@ -4370,6 +4373,12 @@ TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds)
" number of nodes in the other tree sequence.");
goto out;
}
if (all_edges) {
options |= TSK_UNION_ALL_EDGES;
}
if (all_mutations) {
options |= TSK_UNION_ALL_MUTATIONS;
}
if (!check_shared) {
options |= TSK_UNION_NO_CHECK_SHARED;
}
Expand Down
22 changes: 22 additions & 0 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2536,6 +2536,28 @@ def test_mutation_parent_errors(self, mutations, error):
else:
tables.tree_sequence()

def test_union(self, ts_fixture):
    # Sanity check only; the bulk of the union tests are in test_tables.py.
    # Adding every edge, site and mutation of ``ts`` to an empty tree sequence
    # should give back the tables of ``ts`` (metadata, reference sequence and
    # provenance are not compared).
    source_tables = ts_fixture.dump_tables()
    source_tables.migrations.clear()  # migrations not supported in union()
    ts = source_tables.tree_sequence()

    blank_tables = tskit.TableCollection(ts.sequence_length)
    blank_tables.time_units = ts.time_units
    empty = blank_tables.tree_sequence()

    # no node of ``ts`` is mapped to a node of the empty tree sequence
    all_unmapped = np.full(ts.num_nodes, tskit.NULL, dtype=int)
    union_ts = empty.union(
        ts,
        node_mapping=all_unmapped,
        check_shared_equality=False,
        all_edges=True,
        all_mutations=True,
    )
    union_ts.tables.assert_equals(
        ts.tables,
        ignore_provenance=True,
        ignore_metadata=True,
        ignore_reference_sequence=True,
    )


class TestSimplify:
# This class was factored out of the old TestHighlevel class 2022-12-13,
Expand Down
19 changes: 19 additions & 0 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,25 @@ def test_union_bad_args(self):
with pytest.raises(ValueError):
tc.union(tc2, np.array([[1], [2]], dtype="int32"))

@pytest.mark.parametrize("value", [True, False])
@pytest.mark.parametrize(
    "flag",
    [
        "all_edges",
        "all_mutations",
        "check_shared_equality",
        "add_populations",
    ],
)
def test_union_options(self, flag, value):
    # Each keyword option of the low-level union() is passed on its own,
    # once as True and once as False; the call just has to go through.
    ts = msprime.simulate(10, random_seed=1)
    ll_tables = ts.dump_tables()._ll_tables

    # second operand: the same table collection with every table cleared
    cleared = ts.dump_tables()
    for table_name in cleared.table_name_map:
        getattr(cleared, table_name).clear()
    ll_cleared = cleared._ll_tables

    # the cleared collection has no nodes, so the node mapping is empty
    empty_mapping = np.arange(0, dtype="int32")
    option = {flag: value}
    ll_tables.union(ll_cleared, empty_mapping, **option)

def test_equals_bad_args(self):
ts = msprime.simulate(10, random_seed=1242)
tc = ts.dump_tables()._ll_tables
Expand Down
76 changes: 76 additions & 0 deletions python/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5271,6 +5271,82 @@ def test_examples(self):
ts = tables.tree_sequence()
self.verify_union(*self.split_example(ts, T))

def test_split_and_rejoin(self):
    # Delete alternating quarters of the sequence from two copies of the
    # same tables, then check that a union with all edges and all mutations
    # (followed by squash and sort) restores the original tables.
    ts = self.get_msprime_example(5, T=2, seed=928)
    cutpoints = np.array([0, 0.25, 0.5, 0.75, 1]) * ts.sequence_length
    first_quarter = cutpoints[0:2]
    second_quarter = cutpoints[1:3]
    third_quarter = cutpoints[2:4]
    fourth_quarter = cutpoints[3:]

    # tables1 loses quarters one and three
    tables1 = ts.dump_tables()
    tables1.delete_intervals([first_quarter, third_quarter], simplify=False)
    # tables2 loses quarters two and four
    tables2 = ts.dump_tables()
    tables2.delete_intervals([second_quarter, fourth_quarter], simplify=False)

    same_nodes = np.arange(ts.num_nodes)
    tables1.union(
        tables2,
        node_mapping=same_nodes,
        check_shared_equality=False,
        all_edges=True,
        all_mutations=True,
    )
    tables1.edges.squash()
    tables1.sort()
    tables1.assert_equals(ts.tables, ignore_provenance=True)

def test_both_empty(self):
    # The union of two empty table collections should still be empty
    # (provenance is not compared).
    reference = tskit.TableCollection(sequence_length=1)
    left = reference.copy()
    right = reference.copy()
    no_nodes = np.arange(0)
    left.union(right, node_mapping=no_nodes, all_edges=True, all_mutations=True)
    left.assert_equals(reference, ignore_provenance=True)

def test_one_empty(self):
    # simplify first: the example has a load of unreferenced individuals
    ts = self.get_msprime_example(5, T=2, seed=928).simplify()
    tables = ts.dump_tables()
    empty = tskit.TableCollection(sequence_length=tables.sequence_length)
    empty.time_units = tables.time_units
    add_everything = {"all_edges": True, "all_mutations": True}

    # adding an empty collection should leave the tables unchanged
    tables.union(empty, node_mapping=np.arange(0), **add_everything)
    tables.assert_equals(ts.dump_tables(), ignore_provenance=True)

    # adding the tables to an empty collection should reproduce the tables;
    # none of their nodes is mapped to a node of the empty collection
    all_unmapped = np.full(tables.nodes.num_rows, tskit.NULL)
    empty.union(
        tables,
        node_mapping=all_unmapped,
        check_shared_equality=False,
        **add_everything,
    )
    empty.assert_equals(tables, ignore_provenance=True)

def test_reciprocal_empty(self):
    # One collection holds only edges and the other only sites and mutations
    # (on the same nodes); each is then completed from the other, and the two
    # results should agree.
    edges_table = tskit.Tree.generate_comb(6, span=6).tree_sequence.dump_tables()
    muts_table = tskit.TableCollection(sequence_length=6)
    muts_table.nodes.replace_with(edges_table.nodes)  # same nodes, no edges
    for pos in range(6):
        new_site = muts_table.sites.add_row(position=pos, ancestral_state="0")
        # only even positions get a mutation, so some sites stay empty
        if pos % 2 == 0:
            muts_table.mutations.add_row(site=new_site, node=pos, derived_state="1")
    params = {
        "node_mapping": np.arange(len(muts_table.nodes), dtype="int32"),
        "check_shared_equality": False,
    }

    test_table = edges_table.copy()
    # muts_table has no edges, so all_edges alone adds nothing
    test_table.union(muts_table, all_edges=True, **params)
    assert len(test_table.sites) == 0
    assert len(test_table.mutations) == 0
    test_table.union(muts_table, all_mutations=True, **params)
    assert test_table.sites == muts_table.sites
    assert test_table.mutations == muts_table.mutations

    # edges_table has no sites or mutations, so all_mutations alone adds nothing
    muts_table.union(edges_table, all_mutations=True, **params)
    assert len(muts_table.edges) == 0
    muts_table.union(edges_table, all_edges=True, **params)
    assert muts_table.edges == edges_table.edges

    muts_table.assert_equals(test_table, ignore_provenance=True)


class TestTableSetitemMetadata:
@pytest.mark.parametrize("table_name", tskit.TABLE_NAMES)
Expand Down
Loading