Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions c/tests/test_tables.c
Original file line number Diff line number Diff line change
Expand Up @@ -5629,6 +5629,7 @@ test_table_collection_union(void)
tsk_table_collection_t tables_empty;
tsk_table_collection_t tables_copy;
tsk_id_t node_mapping[3];
tsk_id_t parents[2] = { -1, -1 };
char example_metadata[100] = "An example of metadata with unicode 🎄🌳🌴🌲🎋";
tsk_size_t example_metadata_length = (tsk_size_t) strlen(example_metadata);

Expand Down Expand Up @@ -5662,13 +5663,15 @@ test_table_collection_union(void)
ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.5, 1, 2, NULL, 0);
CU_ASSERT_FATAL(ret >= 0);
ret = tsk_individual_table_add_row(
&tables.individuals, 0, NULL, 0, NULL, 0, NULL, 0);
&tables.individuals, 0, NULL, 0, parents, 2, NULL, 0);
CU_ASSERT_FATAL(ret >= 0);
parents[0] = 0;
ret = tsk_individual_table_add_row(
&tables.individuals, 0, NULL, 0, NULL, 0, NULL, 0);
&tables.individuals, 0, NULL, 0, parents, 2, NULL, 0);
CU_ASSERT_FATAL(ret >= 0);
parents[1] = 1;
ret = tsk_individual_table_add_row(
&tables.individuals, 0, NULL, 0, NULL, 0, NULL, 0);
&tables.individuals, 0, NULL, 0, parents, 2, NULL, 0);
CU_ASSERT_FATAL(ret >= 0);
ret = tsk_population_table_add_row(&tables.populations, NULL, 0);
CU_ASSERT_FATAL(ret >= 0);
Expand Down
22 changes: 22 additions & 0 deletions c/tskit/tables.c
Original file line number Diff line number Diff line change
Expand Up @@ -9975,6 +9975,7 @@ tsk_table_collection_union(tsk_table_collection_t *self,
int ret = 0;
tsk_id_t k, i, new_parent, new_child;
tsk_size_t num_shared_nodes = 0;
tsk_size_t num_individuals_self = self->individuals.num_rows;
tsk_edge_t edge;
tsk_mutation_t mut;
tsk_site_t site;
Expand Down Expand Up @@ -10027,6 +10028,18 @@ tsk_table_collection_union(tsk_table_collection_t *self,
memset(population_map, 0xff, other->populations.num_rows * sizeof(*population_map));
memset(site_map, 0xff, other->sites.num_rows * sizeof(*site_map));

/* We have to map the individuals who are linked to nodes in the intersection first
as otherwise an individual linked to one node in the intersection and one in
`other` would be duplicated. We assume that the individual in `self` takes
priority.
*/
for (k = 0; k < (tsk_id_t) other->nodes.num_rows; k++) {
if (other_node_mapping[k] != TSK_NULL
&& other->nodes.individual[k] != TSK_NULL) {
individual_map[other->nodes.individual[k]]
= self->nodes.individual[other_node_mapping[k]];
}
}
// nodes, individuals, populations
for (k = 0; k < (tsk_id_t) other->nodes.num_rows; k++) {
if (other_node_mapping[k] != TSK_NULL) {
Expand All @@ -10040,6 +10053,15 @@ tsk_table_collection_union(tsk_table_collection_t *self,
}
}

/* Now we know the full individual map we can remap the parents of the new
* individuals*/
for (k = (tsk_id_t) self->individuals.parents_offset[num_individuals_self];
k < (tsk_id_t) self->individuals.parents_length; k++) {
if (self->individuals.parents[k] != TSK_NULL) {
self->individuals.parents[k] = individual_map[self->individuals.parents[k]];
}
}

// edges
for (k = 0; k < (tsk_id_t) other->edges.num_rows; k++) {
tsk_edge_table_get_row_unsafe(&other->edges, k, &edge);
Expand Down
144 changes: 72 additions & 72 deletions python/tests/data/svg/ts_multiroot.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 0 additions & 4 deletions python/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,6 @@ def get_wf_example(self, seed):
seed=seed,
num_loci=3,
record_migrations=True,
record_individuals=True,
)
tables.sort()
ts = tables.tree_sequence()
Expand All @@ -1547,7 +1546,6 @@ def test_wf_example(self):
deep_history=False,
seed=42,
record_migrations=True,
record_individuals=True,
)
self.verify_sort(tables, 42)

Expand Down Expand Up @@ -2405,7 +2403,6 @@ def wf_sim_with_individual_metadata(self):
deep_history=False,
initial_generation_samples=False,
num_loci=5,
record_individuals=True,
)
assert tables.individuals.num_rows > 50
assert np.all(tables.nodes.individual >= 0)
Expand Down Expand Up @@ -2502,7 +2499,6 @@ def test_individual_mapping(self):
deep_history=True,
initial_generation_samples=False,
num_loci=5,
record_individuals=True,
)
assert tables.individuals.num_rows > 50
node_md = []
Expand Down
9 changes: 8 additions & 1 deletion python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -5790,6 +5790,7 @@ def verify(self, ts):
self.verify_keep_input_roots(ts, samples)

def verify_keep_input_roots(self, ts, samples):
ts = tsutil.insert_unique_metadata(ts, "individuals")
ts_with_roots, node_map = self.do_simplify(
ts, samples, keep_input_roots=True, filter_sites=False, compare_lib=True
)
Expand All @@ -5807,7 +5808,13 @@ def verify_keep_input_roots(self, ts, samples):
new_node = ts_with_roots.node(root)
assert new_node.time == input_node.time
assert new_node.population == input_node.population
assert new_node.individual == input_node.individual
if new_node.individual == tskit.NULL:
assert new_node.individual == input_node.individual
else:
assert (
ts_with_roots.individual(new_node.individual).metadata
== ts.individual(input_node.individual).metadata
)
assert new_node.metadata == input_node.metadata
# This should only be marked as a sample if it's an
# element of the samples list.
Expand Down
23 changes: 8 additions & 15 deletions python/tests/test_wright_fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
num_pops=1,
mig_rate=0.0,
record_migrations=False,
record_individuals=False,
record_individuals=True,
):
self.N = N
self.num_pops = num_pops
Expand Down Expand Up @@ -214,7 +214,7 @@ def wf_sim(
num_pops=1,
mig_rate=0.0,
record_migrations=False,
record_individuals=False,
record_individuals=True,
):
sim = WrightFisherSimulator(
N,
Expand Down Expand Up @@ -248,7 +248,6 @@ def test_one_gen_multipop_mig_no_deep(self):
deep_history=False,
seed=self.random_seed,
record_migrations=True,
record_individuals=True,
)
assert tables.nodes.num_rows == 5 * 4 * (1 + 1)
assert tables.edges.num_rows > 0
Expand All @@ -266,7 +265,6 @@ def test_multipop_mig_deep(self):
mig_rate=1.0,
seed=self.random_seed,
record_migrations=True,
record_individuals=True,
)
assert tables.nodes.num_rows > (num_pops * N * ngens) + N
assert tables.edges.num_rows > 0
Expand Down Expand Up @@ -297,7 +295,6 @@ def test_multipop_mig_no_deep(self):
deep_history=False,
seed=self.random_seed,
record_migrations=True,
record_individuals=True,
)
assert tables.nodes.num_rows == num_pops * N * (ngens + 1)
assert tables.edges.num_rows > 0
Expand All @@ -323,7 +320,7 @@ def test_non_overlapping_generations(self):
assert tables.sites.num_rows == 0
assert tables.mutations.num_rows == 0
assert tables.migrations.num_rows == 0
assert tables.individuals.num_rows == 0
assert tables.individuals.num_rows > 0
tables.sort()
tables.simplify()
ts = tables.tree_sequence()
Expand All @@ -344,7 +341,7 @@ def test_overlapping_generations(self):
assert tables.sites.num_rows == 0
assert tables.mutations.num_rows == 0
assert tables.migrations.num_rows == 0
assert tables.individuals.num_rows == 0
assert tables.individuals.num_rows > 0
tables.sort()
tables.simplify()
ts = tables.tree_sequence()
Expand All @@ -359,7 +356,7 @@ def test_one_generation_no_deep_history(self):
assert tables.sites.num_rows == 0
assert tables.mutations.num_rows == 0
assert tables.migrations.num_rows == 0
assert tables.individuals.num_rows == 0
assert tables.individuals.num_rows > 0
tables.sort()
tables.simplify()
ts = tables.tree_sequence()
Expand All @@ -383,7 +380,7 @@ def test_many_generations_no_deep_history(self):
assert tables.sites.num_rows == 0
assert tables.mutations.num_rows == 0
assert tables.migrations.num_rows == 0
assert tables.individuals.num_rows == 0
assert tables.individuals.num_rows > 0
tables.sort()
tables.simplify()
ts = tables.tree_sequence()
Expand Down Expand Up @@ -448,9 +445,7 @@ def test_with_recurrent_mutations(self):

def test_record_individuals_initial_state(self):
N = 10
tables = wf_sim(
N=N, ngens=0, seed=12345, record_individuals=True, deep_history=False
)
tables = wf_sim(N=N, ngens=0, seed=12345, deep_history=False)
tables.sort()
assert len(tables.individuals) == N
assert len(tables.nodes) == N
Expand All @@ -461,9 +456,7 @@ def test_record_individuals_initial_state(self):

def test_record_individuals(self):
N = 10
tables = wf_sim(
N=N, ngens=10, seed=12345, record_individuals=True, deep_history=False
)
tables = wf_sim(N=N, ngens=10, seed=12345, deep_history=False)
assert len(tables.individuals) == len(tables.nodes)
for node_id, individual in enumerate(tables.nodes.individual):
assert node_id == individual
Expand Down
20 changes: 20 additions & 0 deletions python/tests/tsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import json
import random
import string
import struct

import numpy as np

Expand Down Expand Up @@ -727,6 +728,12 @@ def py_union(tables, other, nodes, record_provenance=True, add_populations=True)
node_map = [tskit.NULL for _ in range(other.nodes.num_rows + 1)]
site_map = [tskit.NULL for _ in range(other.sites.num_rows + 1)]
mut_map = [tskit.NULL for _ in range(other.mutations.num_rows + 1)]
original_num_individuals = tables.individuals.num_rows

for other_id, node in enumerate(other.nodes):
if nodes[other_id] != tskit.NULL and node.individual != tskit.NULL:
ind_map[node.individual] = tables.nodes[nodes[other_id]].individual

for other_id, node in enumerate(other.nodes):
if nodes[other_id] != tskit.NULL:
node_map[other_id] = nodes[other_id]
Expand Down Expand Up @@ -755,6 +762,11 @@ def py_union(tables, other, nodes, record_provenance=True, add_populations=True)
flags=node.flags,
)
node_map[other_id] = node_id
individuals = tables.individuals
for i in range(
individuals.parents_offset[original_num_individuals], len(individuals.parents)
):
individuals.parents[i] = ind_map[individuals.parents[i]]
for edge in other.edges:
if (nodes[edge.parent] == tskit.NULL) or (nodes[edge.child] == tskit.NULL):
tables.edges.add_row(
Expand Down Expand Up @@ -1745,3 +1757,11 @@ def sort_individual_table(tables):
tables.nodes.individual = [ind_id_map[i] for i in tables.nodes.individual]

return tables


def insert_unique_metadata(ts, table):
tables = ts.dump_tables()
getattr(tables, table).packset_metadata(
[struct.pack("I", i) for i in range(getattr(tables, table).num_rows)]
)
return tables.tree_sequence()