diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c
index a6af0bb4d1..991a63523d 100644
--- a/c/tests/test_trees.c
+++ b/c/tests/test_trees.c
@@ -1076,9 +1076,9 @@ test_simplest_unary_with_individuals(void)
         "1 2 6 4\n";
     const char *individuals = "0 0.5 -1,-1\n"
                               "0 1.5,3.1 -1,-1\n"
-                              "0 2.1 -1,-1\n"
-                              "0 3.2 -1,-1\n"
-                              "0 4.2 -1,-1\n";
+                              "0 2.1 0,1\n"
+                              "0 3.2 1,2\n"
+                              "0 4.2 2,3\n";
     const char *nodes_expect = "1 0 0 -1\n"
                                "1 0 0 0\n"
                                "0 1 0 1\n"
@@ -1091,8 +1091,8 @@ test_simplest_unary_with_individuals(void)
         "1 2 5 4\n";
     const char *individuals_expect = "0 0.5 -1,-1\n"
                                      "0 1.5,3.1 -1,-1\n"
-                                     "0 2.1 -1,-1\n"
-                                     "0 3.2 -1,-1\n";
+                                     "0 2.1 0,1\n"
+                                     "0 3.2 1,2\n";
     tsk_treeseq_t ts, simplified, expected;
     tsk_id_t sample_ids[] = { 0, 1 };
 
diff --git a/c/tskit/tables.c b/c/tskit/tables.c
index 4c448fd9d9..165ed85e9b 100644
--- a/c/tskit/tables.c
+++ b/c/tskit/tables.c
@@ -7352,6 +7352,14 @@ simplifier_finalise_references(simplifier_t *self)
         }
     }
 
+    /* Remap parent IDs */
+    for (j = 0; j < self->tables->individuals.parents_length; j++) {
+        self->tables->individuals.parents[j]
+            = self->tables->individuals.parents[j] == TSK_NULL
+                  ? TSK_NULL
+                  : individual_id_map[self->tables->individuals.parents[j]];
+    }
+
     /* Remap node IDs referencing the above */
     for (j = 0; j < num_nodes; j++) {
         pop_id = node_population[j];
diff --git a/python/tests/simplify.py b/python/tests/simplify.py
index 061a949f68..3ca97c101f 100644
--- a/python/tests/simplify.py
+++ b/python/tests/simplify.py
@@ -453,7 +453,10 @@ def finalise_references(self):
             if count > 0:
                 row = input_individuals[input_id]
                 output_id = self.tables.individuals.add_row(
-                    flags=row.flags, location=row.location, metadata=row.metadata
+                    flags=row.flags,
+                    location=row.location,
+                    parents=row.parents,
+                    metadata=row.metadata,
                 )
                 individual_id_map[input_id] = output_id
 
@@ -468,6 +471,23 @@ def finalise_references(self):
                 population=population_id_map[nodes.population],
             )
 
+        # Remap the parent ids of individuals
+        individuals_copy = self.tables.individuals.copy()
+        self.tables.individuals.clear()
+        for row in individuals_copy:
+            mapped_parents = []
+            for p in row.parents:
+                if p == -1:
+                    mapped_parents.append(-1)
+                else:
+                    mapped_parents.append(individual_id_map[p])
+            self.tables.individuals.add_row(
+                flags=row.flags,
+                location=row.location,
+                parents=mapped_parents,
+                metadata=row.metadata,
+            )
+
         # We don't support migrations for now. We'll need to remap these as well.
         assert self.ts.num_migrations == 0
 
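Both the C and Python simplifiers apply the same rule in the hunks above: once the surviving individuals have been copied across, every parent reference is passed through the individual ID map, with null parents left untouched. A minimal standalone sketch of that rule against the public tskit Python API (the helper name and the dict-style ID map are illustrative, not part of the library):

import tskit


def remap_individual_parents(tables, individual_id_map):
    # Rewrite individuals.parents so they refer to post-simplify IDs.
    # individual_id_map maps pre-simplify IDs to post-simplify IDs;
    # tskit.NULL (-1) parents pass through unchanged.
    individuals_copy = tables.individuals.copy()
    tables.individuals.clear()
    for row in individuals_copy:
        tables.individuals.add_row(
            flags=row.flags,
            location=row.location,
            parents=[
                p if p == tskit.NULL else individual_id_map[p] for p in row.parents
            ],
            metadata=row.metadata,
        )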
diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py
index dbbc8c1c0a..006018edb4 100644
--- a/python/tests/test_tables.py
+++ b/python/tests/test_tables.py
@@ -31,6 +31,7 @@
 import pickle
 import platform
 import random
+import struct
 import time
 import unittest
 import warnings
@@ -2361,6 +2362,109 @@ def test_samples_interface(self):
         with pytest.raises(OverflowError):
             tables.simplify(samples=np.array([0, bad_node]))
 
+    @pytest.fixture(scope="session")
+    def wf_sim_with_individual_metadata(self):
+        tables = wf.wf_sim(
+            9,
+            10,
+            seed=1,
+            deep_history=True,
+            initial_generation_samples=False,
+            num_loci=5,
+            record_individuals=True,
+        )
+        assert tables.individuals.num_rows > 50
+        individuals_copy = tables.copy().individuals
+        tables.individuals.clear()
+        tables.individuals.metadata_schema = tskit.MetadataSchema({"codec": "json"})
+        for i, individual in enumerate(individuals_copy):
+            tables.individuals.add_row(
+                flags=individual.flags,
+                location=individual.location,
+                parents=individual.parents,
+                metadata={
+                    "original_id": i,
+                    "original_parents": [int(p) for p in individual.parents],
+                },
+            )
+        tables.sort()
+        return tables
+
+    def test_individual_parent_mapping(self, wf_sim_with_individual_metadata):
+        tables = wf_sim_with_individual_metadata.copy()
+        tables.simplify()
+        ts = tables.tree_sequence()
+        for individual in tables.individuals:
+            for parent, original_parent in zip(
+                individual.parents, individual.metadata["original_parents"]
+            ):
+                if parent != tskit.NULL:
+                    assert (
+                        ts.individual(parent).metadata["original_id"] == original_parent
+                    )
+        assert set(tables.individuals.parents) != {tskit.NULL}
+
+    def test_shuffled_individual_parent_mapping(self, wf_sim_with_individual_metadata):
+        tables = wf_sim_with_individual_metadata.copy()
+        tsutil.shuffle_tables(
+            tables,
+            42,
+            shuffle_edges=False,
+            shuffle_populations=False,
+            shuffle_individuals=True,
+            shuffle_sites=False,
+            shuffle_mutations=False,
+            shuffle_migrations=False,
+        )
+        # Check we have a mixed up order
+        with pytest.raises(
+            tskit.LibraryError,
+            match="Individuals must be provided in an order where"
+            " children are after their parent individuals",
+        ):
+            tables.tree_sequence()
+
+        tables.simplify()
+        metadata = [
+            tables.individuals.metadata_schema.decode_row(m)
+            for m in tskit.unpack_bytes(
+                tables.individuals.metadata, tables.individuals.metadata_offset
+            )
+        ]
+        for individual in tables.individuals:
+            for parent, original_parent in zip(
+                individual.parents, individual.metadata["original_parents"]
+            ):
+                if parent != tskit.NULL:
+                    assert metadata[parent]["original_id"] == original_parent
+        assert set(tables.individuals.parents) != {tskit.NULL}
+
+    def test_individual_mapping(self):
+        tables = wf.wf_sim(
+            9,
+            10,
+            seed=1,
+            deep_history=True,
+            initial_generation_samples=False,
+            num_loci=5,
+            record_individuals=True,
+        )
+        assert tables.individuals.num_rows > 50
+        node_md = []
+        individual_md = [b""] * tables.individuals.num_rows
+        for i, node in enumerate(tables.nodes):
+            node_md.append(struct.pack("i", i))
+            individual_md[node.individual] = struct.pack("i", i)
+        tables.nodes.packset_metadata(node_md)
+        tables.individuals.packset_metadata(individual_md)
+        tables.sort()
+        tables.simplify()
+        ts = tables.tree_sequence()
+        for node in tables.nodes:
+            if node.individual != tskit.NULL:
+                assert ts.individual(node.individual).metadata == node.metadata
+        assert set(tables.individuals.parents) != {tskit.NULL}
+
     def test_bad_individuals(self, simple_ts_fixture):
         tables = simple_ts_fixture.dump_tables()
         tables.individuals.clear()
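The tests above rely on a simple round trip: each individual's original ID and parent IDs are stashed in JSON metadata before simplification, and afterwards the remapped parent references must still resolve to the same original individuals. A condensed sketch of that check (assuming `tables` carries the `original_id`/`original_parents` metadata written by the fixture above; the function name is illustrative):

import tskit


def check_parent_remapping(tables):
    # Assumes each individual's JSON metadata holds "original_id" and
    # "original_parents", as set up by the fixture above.
    tables.simplify()
    ts = tables.tree_sequence()
    for individual in ts.individuals():
        originals = individual.metadata["original_parents"]
        for parent, original in zip(individual.parents, originals):
            if parent != tskit.NULL:
                assert ts.individual(parent).metadata["original_id"] == original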
diff --git a/python/tests/test_wright_fisher.py b/python/tests/test_wright_fisher.py
index 12efef453b..de82c7057a 100644
--- a/python/tests/test_wright_fisher.py
+++ b/python/tests/test_wright_fisher.py
@@ -62,6 +62,7 @@ def __init__(
         num_pops=1,
         mig_rate=0.0,
         record_migrations=False,
+        record_individuals=False,
     ):
         self.N = N
         self.num_pops = num_pops
@@ -69,6 +70,7 @@ def __init__(
         self.survival = survival
         self.mig_rate = mig_rate
         self.record_migrations = record_migrations
+        self.record_individuals = record_individuals
         self.deep_history = deep_history
         self.debug = debug
         self.initial_generation_samples = initial_generation_samples
@@ -116,7 +118,12 @@ def run(self, ngens):
             flags = tskit.NODE_IS_SAMPLE
         for p in range(self.num_pops):
             for _ in range(self.N):
-                tables.nodes.add_row(flags=flags, time=ngens, population=p)
+                individual = -1
+                if self.record_individuals:
+                    individual = tables.individuals.add_row(parents=[-1, -1])
+                tables.nodes.add_row(
+                    flags=flags, time=ngens, population=p, individual=individual
+                )
         pops = [
             list(range(p * self.N, (p * self.N) + self.N))
             for p in range(self.num_pops)
@@ -150,8 +157,18 @@ def run(self, ngens):
                 k = 0
                 for j in range(self.N):
                     if dead[p][j]:
-                        # this is: offspring ID, lparent, rparent, breakpoint
-                        offspring = tables.nodes.add_row(time=t, population=p)
+                        lparent, rparent = new_parents[p][k]
+                        individual = -1
+                        if self.record_individuals:
+                            individual = tables.individuals.add_row(
+                                parents=[
+                                    tables.nodes[lparent].individual,
+                                    tables.nodes[rparent].individual,
+                                ]
+                            )
+                        offspring = tables.nodes.add_row(
+                            time=t, population=p, individual=individual
+                        )
                         if parent_pop[p][k] != p and self.record_migrations:
                             tables.migrations.add_row(
                                 left=0.0,
@@ -161,7 +178,6 @@ def run(self, ngens):
                                 dest=p,
                                 time=t,
                             )
-                        lparent, rparent = new_parents[p][k]
                         k += 1
                         bp = self.random_breakpoint()
                         if self.debug:
@@ -182,9 +198,7 @@ def run(self, ngens):
             flags = tables.nodes.flags
             flattened = [n for pop in pops for n in pop]
             flags[flattened] = tskit.NODE_IS_SAMPLE
-            tables.nodes.set_columns(
-                flags=flags, time=tables.nodes.time, population=tables.nodes.population
-            )
+            tables.nodes.flags = flags
 
         return tables
 
@@ -200,6 +214,7 @@ def wf_sim(
     num_pops=1,
     mig_rate=0.0,
     record_migrations=False,
+    record_individuals=False,
 ):
     sim = WrightFisherSimulator(
         N,
@@ -212,6 +227,7 @@ def wf_sim(
         num_pops=num_pops,
         mig_rate=mig_rate,
         record_migrations=record_migrations,
+        record_individuals=record_individuals,
     )
     return sim.run(ngens)
 
@@ -232,10 +248,12 @@ def test_one_gen_multipop_mig_no_deep(self):
             deep_history=False,
             seed=self.random_seed,
             record_migrations=True,
+            record_individuals=True,
         )
         assert tables.nodes.num_rows == 5 * 4 * (1 + 1)
         assert tables.edges.num_rows > 0
         assert tables.migrations.num_rows == 5 * 4
+        assert tables.individuals.num_rows == tables.nodes.num_rows
 
     def test_multipop_mig_deep(self):
         N = 10
@@ -248,6 +266,7 @@ def test_multipop_mig_deep(self):
             mig_rate=1.0,
             seed=self.random_seed,
             record_migrations=True,
+            record_individuals=True,
         )
         assert tables.nodes.num_rows > (num_pops * N * ngens) + N
         assert tables.edges.num_rows > 0
@@ -255,6 +274,8 @@ def test_multipop_mig_deep(self):
         assert tables.mutations.num_rows == 0
         assert tables.migrations.num_rows >= N * num_pops * ngens
         assert tables.populations.num_rows == num_pops
+        assert tables.individuals.num_rows >= num_pops * N * ngens
+
         # sort does not support mig
         tables.migrations.clear()
         # making sure trees are valid
@@ -276,6 +297,7 @@ def test_multipop_mig_no_deep(self):
             deep_history=False,
             seed=self.random_seed,
             record_migrations=True,
+            record_individuals=True,
         )
         assert tables.nodes.num_rows == num_pops * N * (ngens + 1)
         assert tables.edges.num_rows > 0
@@ -283,6 +305,8 @@ def test_multipop_mig_no_deep(self):
         assert tables.mutations.num_rows == 0
         assert tables.migrations.num_rows == N * num_pops * ngens
         assert tables.populations.num_rows == num_pops
+        assert tables.individuals.num_rows == tables.nodes.num_rows
+        # FIXME this is no longer needed.
         # sort does not support mig
         tables.migrations.clear()
         # making sure trees are valid
@@ -299,6 +323,7 @@ def test_non_overlapping_generations(self):
         assert tables.sites.num_rows == 0
         assert tables.mutations.num_rows == 0
         assert tables.migrations.num_rows == 0
+        assert tables.individuals.num_rows == 0
         tables.sort()
         tables.simplify()
         ts = tables.tree_sequence()
@@ -319,6 +344,7 @@ def test_overlapping_generations(self):
         assert tables.sites.num_rows == 0
         assert tables.mutations.num_rows == 0
         assert tables.migrations.num_rows == 0
+        assert tables.individuals.num_rows == 0
         tables.sort()
         tables.simplify()
         ts = tables.tree_sequence()
@@ -333,6 +359,7 @@ def test_one_generation_no_deep_history(self):
         assert tables.sites.num_rows == 0
         assert tables.mutations.num_rows == 0
         assert tables.migrations.num_rows == 0
+        assert tables.individuals.num_rows == 0
         tables.sort()
         tables.simplify()
         ts = tables.tree_sequence()
@@ -356,6 +383,7 @@ def test_many_generations_no_deep_history(self):
         assert tables.sites.num_rows == 0
         assert tables.mutations.num_rows == 0
         assert tables.migrations.num_rows == 0
+        assert tables.individuals.num_rows == 0
         tables.sort()
         tables.simplify()
         ts = tables.tree_sequence()
@@ -418,12 +446,54 @@ def test_with_recurrent_mutations(self):
         for hap in ts.haplotypes():
             assert len(hap) == ts.num_sites
 
+    def test_record_individuals_initial_state(self):
+        N = 10
+        tables = wf_sim(
+            N=N, ngens=0, seed=12345, record_individuals=True, deep_history=False
+        )
+        tables.sort()
+        assert len(tables.individuals) == N
+        assert len(tables.nodes) == N
+        for individual in list(tables.individuals)[:N]:
+            assert list(individual.parents) == [-1, -1]
+        for j, node in enumerate(tables.nodes):
+            assert node.individual == j
+
+    def test_record_individuals(self):
+        N = 10
+        tables = wf_sim(
+            N=N, ngens=10, seed=12345, record_individuals=True, deep_history=False
+        )
+        tables.sort()
+        assert len(tables.individuals) == len(tables.nodes)
+        ts = tables.tree_sequence()
+        for node in ts.nodes():
+            assert node.individual == node.id
 
-class TestIncrementalBuild:
-    """
-    Tests for incrementally building a tree sequence from forward time
-    simulations.
-    """
+        for tree in ts.trees():
+            for u in tree.nodes():
+                parent_node = tree.parent(u)
+                # We already know the individual has the same ID as the node
+                individual = ts.individual(u)
+                assert parent_node in individual.parents
+
+
+def get_wf_sims(seed):
+    wf_sims = []
+    for N in [5, 10, 20]:
+        for surv in [0.0, 0.5, 0.9]:
+            for mut in [0.01, 1.0]:
+                for nloci in [1, 2, 3]:
+                    tables = wf_sim(N=N, ngens=N, survival=surv, seed=seed)
+                    tables.sort()
+                    ts = tables.tree_sequence()
+                    ts = tsutil.jukes_cantor(ts, num_sites=nloci, mu=mut, seed=seed)
+                    wf_sims.append(ts)
+    return wf_sims
+
+
+# List of simulations used to parametrize tests.
+wf_sims = get_wf_sims(1234)
 
 
 class TestSimplify:
@@ -431,60 +501,6 @@ class TestSimplify:
     Tests for simplify on cases generated by the Wright-Fisher simulator.
""" - def assertArrayEqual(self, x, y): - nt.assert_equal(x, y) - - def assertTreeSequencesEqual(self, ts1, ts2): - assert list(ts1.samples()) == list(ts2.samples()) - assert ts1.sequence_length == ts2.sequence_length - ts1_tables = ts1.dump_tables() - ts2_tables = ts2.dump_tables() - # print("compare") - # print(ts1_tables.nodes) - # print(ts2_tables.nodes) - assert ts1_tables.nodes == ts2_tables.nodes - assert ts1_tables.edges == ts2_tables.edges - assert ts1_tables.sites == ts2_tables.sites - assert ts1_tables.mutations == ts2_tables.mutations - - def get_wf_sims(self, seed): - """ - Returns an iterator of example tree sequences produced by the WF simulator. - """ - for N in [5, 10, 20]: - for surv in [0.0, 0.5, 0.9]: - for mut in [0.01, 1.0]: - for nloci in [1, 2, 3]: - tables = wf_sim(N=N, ngens=N, survival=surv, seed=seed) - tables.sort() - ts = tables.tree_sequence() - ts = tsutil.jukes_cantor(ts, num_sites=nloci, mu=mut, seed=seed) - self.verify_simulation(ts, ngens=N) - yield ts - - def verify_simulation(self, ts, ngens): - """ - Verify that in the full set of returned tables there is parentage - information for every individual, except those initially present. - """ - tables = ts.dump_tables() - for u in range(tables.nodes.num_rows): - if tables.nodes.time[u] <= ngens: - lefts = [] - rights = [] - k = 0 - for edge in ts.edges(): - if u == edge.child: - lefts.append(edge.left) - rights.append(edge.right) - k += 1 - lefts.sort() - rights.sort() - assert lefts[0] == 0.0 - assert rights[-1] == 1.0 - for k in range(len(lefts) - 1): - assert lefts[k + 1] == rights[k] - def verify_simplify(self, ts, new_ts, samples, node_map): """ Check that trees in `ts` match `new_ts` using the specified node_map. @@ -522,7 +538,7 @@ def verify_simplify(self, ts, new_ts, samples, node_map): assert mrca2 != tskit.NULL assert node_map[mrca1] == mrca2 mut_parent = tsutil.compute_mutation_parent(ts=ts) - self.assertArrayEqual(mut_parent, ts.tables.mutations.parent) + nt.assert_equal(mut_parent, ts.tables.mutations.parent) def verify_haplotypes(self, ts, samples): """ @@ -541,42 +557,38 @@ def verify_haplotypes(self, ts, samples): mapped_ids.append(mapped_node_id) assert sorted(mapped_ids) == sorted(sub_ts.samples()) - @pytest.mark.slow - def test_simplify(self): - # check that simplify(big set) -> simplify(subset) equals simplify(subset) - seed = 23 - random.seed(seed) - for ts in self.get_wf_sims(seed=seed): - s = tests.Simplifier(ts, ts.samples()) - py_full_ts, py_full_map = s.simplify() - full_ts, full_map = ts.simplify(ts.samples(), map_nodes=True) - assert all(py_full_map == full_map) - self.assertTreeSequencesEqual(full_ts, py_full_ts) - - for nsamples in [2, 5, 10]: - sub_samples = random.sample( - list(ts.samples()), min(nsamples, ts.sample_size) - ) - s = tests.Simplifier(ts, sub_samples) - py_small_ts, py_small_map = s.simplify() - small_ts, small_map = ts.simplify(samples=sub_samples, map_nodes=True) - self.assertTreeSequencesEqual(small_ts, py_small_ts) - self.verify_simplify(ts, small_ts, sub_samples, small_map) - self.verify_haplotypes(ts, samples=sub_samples) - - @pytest.mark.slow - def test_simplify_tables(self): - seed = 71 - for ts in self.get_wf_sims(seed=seed): - for nsamples in [2, 5, 10]: - tables = ts.dump_tables() - sub_samples = random.sample( - list(ts.samples()), min(nsamples, ts.num_samples) - ) - node_map = tables.simplify(samples=sub_samples) - small_ts = tables.tree_sequence() - other_tables = small_ts.dump_tables() - tables.provenances.clear() - 
diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py
index 7f7186e741..ae1aa2272c 100644
--- a/python/tests/tsutil.py
+++ b/python/tests/tsutil.py
@@ -873,6 +873,10 @@ def shuffle_tables(
         ind_id_map[j] = tables.individuals.add_row(
             flags=i.flags, location=i.location, parents=i.parents, metadata=i.metadata
         )
+    tables.individuals.parents = [
+        tskit.NULL if i == tskit.NULL else ind_id_map[i]
+        for i in tables.individuals.parents
+    ]
     # nodes (same order, but remapped populations and individuals)
     for n in orig.nodes:
         tables.nodes.add_row(
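The shuffle_tables change applies the same NULL-preserving remap: individuals are re-added in shuffled order while an old-to-new ID map is built, and the flattened parents column is then pushed through that map so references follow their rows. A self-contained sketch of the pattern (the helper name and the explicit order argument are illustrative, not part of tsutil):

import numpy as np
import tskit


def reorder_individuals(tables, order):
    # order[k] gives the old ID of the individual placed at new position k.
    orig = tables.individuals.copy()
    tables.individuals.clear()
    ind_id_map = np.full(len(orig), tskit.NULL, dtype=np.int32)
    for old_id in order:
        row = orig[old_id]
        ind_id_map[old_id] = tables.individuals.add_row(
            flags=row.flags,
            location=row.location,
            parents=row.parents,
            metadata=row.metadata,
        )
    # Parent references still use old IDs at this point; remap them,
    # leaving tskit.NULL untouched.
    tables.individuals.parents = [
        tskit.NULL if p == tskit.NULL else ind_id_map[p]
        for p in tables.individuals.parents
    ]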