From e49fe6d542643cbbc171ad663a5db18f447ae287 Mon Sep 17 00:00:00 2001 From: peter Date: Fri, 29 Jan 2021 15:45:16 -0800 Subject: [PATCH] Implement keep_unary_in_individuals in Python and fix the associated simplify bug --- c/tskit/tables.c | 2 +- python/_tskitmodule.c | 15 +++++--- python/tests/simplify.py | 26 +++++++++++--- python/tests/test_lowlevel.py | 2 ++ python/tests/test_tables.py | 23 ++++++++++++- python/tests/test_topology.py | 18 ++++++---- python/tests/test_utilities.py | 2 +- python/tests/test_vcf.py | 4 +-- python/tests/test_wright_fisher.py | 55 ++++++++++++++++++++++++++++++ python/tests/tsutil.py | 36 +++++++++++++------ python/tskit/tables.py | 16 +++++++-- python/tskit/trees.py | 13 +++++-- 12 files changed, 174 insertions(+), 38 deletions(-) diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 165ed85e9b..e11cb77561 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -6982,7 +6982,7 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id) keep_unary = true; } if ((self->options & TSK_KEEP_UNARY_IN_INDIVIDUALS) - && (self->tables->nodes.individual[input_id] != TSK_NULL)) { + && (self->input_tables.nodes.individual[input_id] != TSK_NULL)) { keep_unary = true; } diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index c3f595b11a..5e61f37e4d 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -4960,18 +4960,20 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds) int filter_individuals = false; int filter_populations = false; int keep_unary = false; + int keep_unary_in_individuals = false; int keep_input_roots = false; int reduce_to_site_topology = false; - static char *kwlist[] - = { "samples", "filter_sites", "filter_populations", "filter_individuals", - "reduce_to_site_topology", "keep_unary", "keep_input_roots", NULL }; + static char *kwlist[] = { "samples", "filter_sites", "filter_populations", + "filter_individuals", "reduce_to_site_topology", "keep_unary", + 
"keep_unary_in_individuals", "keep_input_roots", NULL }; if (TableCollection_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiii", kwlist, &samples, + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiii", kwlist, &samples, &filter_sites, &filter_populations, &filter_individuals, - &reduce_to_site_topology, &keep_unary, &keep_input_roots)) { + &reduce_to_site_topology, &keep_unary, &keep_unary_in_individuals, + &keep_input_roots)) { goto out; } samples_array = (PyArrayObject *) PyArray_FROMANY( @@ -4996,6 +4998,9 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds) if (keep_unary) { options |= TSK_KEEP_UNARY; } + if (keep_unary_in_individuals) { + options |= TSK_KEEP_UNARY_IN_INDIVIDUALS; + } if (keep_input_roots) { options |= TSK_KEEP_INPUT_ROOTS; } diff --git a/python/tests/simplify.py b/python/tests/simplify.py index 3ca97c101f..e82e778a0a 100644 --- a/python/tests/simplify.py +++ b/python/tests/simplify.py @@ -109,6 +109,7 @@ def __init__( filter_populations=True, filter_individuals=True, keep_unary=False, + keep_unary_in_individuals=False, keep_input_roots=False, ): self.ts = ts @@ -119,6 +120,7 @@ def __init__( self.filter_populations = filter_populations self.filter_individuals = filter_individuals self.keep_unary = keep_unary + self.keep_unary_in_individuals = keep_unary_in_individuals self.keep_input_roots = keep_input_roots self.num_mutations = ts.num_mutations self.input_sites = list(ts.sites()) @@ -295,7 +297,10 @@ def merge_labeled_ancestors(self, S, input_id): if is_sample: self.record_edge(left, right, output_id, ancestry_node) ancestry_node = output_id - elif self.keep_unary: + elif self.keep_unary or ( + self.keep_unary_in_individuals + and self.ts.node(input_id).individual >= 0 + ): if output_id == -1: output_id = self.record_node(input_id) self.record_edge(left, right, output_id, ancestry_node) @@ -308,7 +313,10 @@ def merge_labeled_ancestors(self, S, input_id): if 
is_sample and left != prev_right: # Fill in any gaps in the ancestry for the sample self.add_ancestry(input_id, prev_right, left, output_id) - if self.keep_unary: + if self.keep_unary or ( + self.keep_unary_in_individuals + and self.ts.node(input_id).individual >= 0 + ): ancestry_node = output_id self.add_ancestry(input_id, left, right, ancestry_node) prev_right = right @@ -757,7 +765,6 @@ def print_state(self): samples = list(map(int, sys.argv[3:])) - # When keep_unary = True print("When keep_unary = True:") s = Simplifier(ts, samples, keep_unary=True) # s.print_state() @@ -768,8 +775,7 @@ def print_state(self): print(tables.sites) print(tables.mutations) - # When keep_unary = False - print("\nWhen keep_unary = False:") + print("\nWhen keep_unary = False") s = Simplifier(ts, samples, keep_unary=False) # s.print_state() tss, _ = s.simplify() @@ -779,6 +785,16 @@ def print_state(self): print(tables.sites) print(tables.mutations) + print("\nWhen keep_unary_in_individuals = True") + s = Simplifier(ts, samples, keep_unary_in_individuals=True) + # s.print_state() + tss, _ = s.simplify() + tables = tss.dump_tables() + print(tables.nodes) + print(tables.edges) + print(tables.sites) + print(tables.mutations) + elif class_to_implement == "AncestorMap": samples = sys.argv[3] diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index bdb81a6326..f0dbb6cda3 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -263,6 +263,8 @@ def test_simplify_bad_args(self): tc.simplify("asdf") with pytest.raises(TypeError): tc.simplify([0, 1], keep_unary="sdf") + with pytest.raises(TypeError): + tc.simplify([0, 1], keep_unary_in_individuals="abc") with pytest.raises(TypeError): tc.simplify([0, 1], keep_input_roots="sdf") with pytest.raises(TypeError): diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 006018edb4..310479f13d 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -2368,12 
+2368,13 @@ def wf_sim_with_individual_metadata(self): 9, 10, seed=1, - deep_history=True, + deep_history=False, initial_generation_samples=False, num_loci=5, record_individuals=True, ) assert tables.individuals.num_rows > 50 + assert np.all(tables.nodes.individual >= 0) individuals_copy = tables.copy().individuals tables.individuals.clear() tables.individuals.metadata_schema = tskit.MetadataSchema({"codec": "json"}) @@ -2404,6 +2405,26 @@ def test_individual_parent_mapping(self, wf_sim_with_individual_metadata): ) assert set(tables.individuals.parents) != {tskit.NULL} + def verify_complete_genetic_pedigree(self, tables): + ts = tables.tree_sequence() + for edge in ts.edges(): + child = ts.individual(ts.node(edge.child).individual) + parent = ts.individual(ts.node(edge.parent).individual) + assert parent.id in child.parents + assert parent.metadata["original_id"] in child.metadata["original_parents"] + + def test_no_complete_genetic_pedigree(self, wf_sim_with_individual_metadata): + tables = wf_sim_with_individual_metadata.copy() + tables.simplify() # Will remove intermediate individuals + with pytest.raises(AssertionError): + self.verify_complete_genetic_pedigree(tables) + + def test_complete_genetic_pedigree(self, wf_sim_with_individual_metadata): + for params in [{"keep_unary": True}, {"keep_unary_in_individuals": True}]: + tables = wf_sim_with_individual_metadata.copy() + tables.simplify(**params) # Keep intermediate individuals + self.verify_complete_genetic_pedigree(tables) + def test_shuffled_individual_parent_mapping(self, wf_sim_with_individual_metadata): tables = wf_sim_with_individual_metadata.copy() tsutil.shuffle_tables( diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 8e890d0b26..3c1e226cae 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -2363,7 +2363,7 @@ def test_ladder_tree(self): def verify_unary_tree_sequence(self, ts): """ Take the specified tree sequence and produce an equivalent 
in which - unary records have been interspersed. + unary records have been interspersed, every other with an associated individual """ assert ts.num_trees > 2 assert ts.num_mutations > 2 @@ -2371,11 +2371,12 @@ def verify_unary_tree_sequence(self, ts): next_node = ts.num_nodes node_times = {j: node.time for j, node in enumerate(ts.nodes())} edges = [] - for e in ts.edges(): + for i, e in enumerate(ts.edges()): node = ts.node(e.parent) t = node.time - 1e-14 # Arbitrary small value. next_node = len(tables.nodes) - tables.nodes.add_row(time=t, population=node.population) + indiv = tables.individuals.add_row() if i % 2 == 0 else tskit.NULL + tables.nodes.add_row(time=t, population=node.population, individual=indiv) edges.append( tskit.Edge(left=e.left, right=e.right, parent=next_node, child=e.child) ) @@ -2398,11 +2399,16 @@ def verify_unary_tree_sequence(self, ts): self.assert_haplotypes_equal(ts, ts_simplified) self.assert_variants_equal(ts, ts_simplified) assert len(list(ts.edge_diffs())) == ts.num_trees + assert 0 < ts_new.num_individuals < ts_new.num_nodes - for keep_unary in [True, False]: - s = tests.Simplifier(ts, ts.samples(), keep_unary=keep_unary) + for params in [ + {"keep_unary": False, "keep_unary_in_individuals": False}, + {"keep_unary": True, "keep_unary_in_individuals": False}, + {"keep_unary": False, "keep_unary_in_individuals": True}, + ]: + s = tests.Simplifier(ts_new, ts_new.samples(), **params) py_ts, py_node_map = s.simplify() - lib_ts, lib_node_map = ts.simplify(keep_unary=keep_unary, map_nodes=True) + lib_ts, lib_node_map = ts_new.simplify(map_nodes=True, **params) py_tables = py_ts.dump_tables() py_tables.provenances.clear() lib_tables = lib_ts.dump_tables() diff --git a/python/tests/test_utilities.py b/python/tests/test_utilities.py index 824ccb500f..1a12a83843 100644 --- a/python/tests/test_utilities.py +++ b/python/tests/test_utilities.py @@ -138,7 +138,7 @@ def test_ploidy_2_reversed(self): ts = msprime.simulate(10, random_seed=1) assert 
ts.num_individuals == 0 samples = ts.samples()[::-1] - ts = tsutil.insert_individuals(ts, samples=samples, ploidy=2) + ts = tsutil.insert_individuals(ts, nodes=samples, ploidy=2) assert ts.num_individuals == 5 for j, ind in enumerate(ts.individuals()): assert list(ind.nodes) == [samples[2 * j + 1], samples[2 * j]] diff --git a/python/tests/test_vcf.py b/python/tests/test_vcf.py index 6e7b1f27b0..5e5b1f1a2a 100644 --- a/python/tests/test_vcf.py +++ b/python/tests/test_vcf.py @@ -212,14 +212,14 @@ def test_simple_infinite_sites_ploidy_2(self): def test_simple_infinite_sites_ploidy_2_reversed_samples(self): ts = msprime.simulate(10, mutation_rate=1, random_seed=2) samples = ts.samples()[::-1] - ts = tsutil.insert_individuals(ts, samples=samples, ploidy=2) + ts = tsutil.insert_individuals(ts, nodes=samples, ploidy=2) assert ts.num_sites > 2 self.verify(ts) def test_simple_infinite_sites_ploidy_2_even_samples(self): ts = msprime.simulate(20, mutation_rate=1, random_seed=2) samples = ts.samples()[0::2] - ts = tsutil.insert_individuals(ts, samples=samples, ploidy=2) + ts = tsutil.insert_individuals(ts, nodes=samples, ploidy=2) assert ts.num_sites > 2 self.verify(ts) diff --git a/python/tests/test_wright_fisher.py b/python/tests/test_wright_fisher.py index de82c7057a..4d072956d2 100644 --- a/python/tests/test_wright_fisher.py +++ b/python/tests/test_wright_fisher.py @@ -592,3 +592,58 @@ def test_simplify_tables(self, ts, nsamples): other_tables.provenances.clear() assert tables == other_tables self.verify_simplify(ts, small_ts, sub_samples, node_map) + + @pytest.mark.parametrize("ts", wf_sims) + @pytest.mark.parametrize("nsamples", [2, 5]) + def test_simplify_keep_unary(self, ts, nsamples): + np.random.seed(123) + ts = tsutil.mark_metadata(ts, "nodes") + sub_samples = random.sample(list(ts.samples()), min(nsamples, ts.num_samples)) + random_nodes = np.random.choice(ts.num_nodes, ts.num_nodes // 2) + ts = tsutil.insert_individuals(ts, random_nodes) + ts = 
tsutil.mark_metadata(ts, "individuals") + + for params in [{}, {"keep_unary": True}, {"keep_unary_in_individuals": True}]: + sts = ts.simplify(sub_samples, **params) + # check samples match + assert sts.num_samples == len(sub_samples) + for n, sn in zip(sub_samples, sts.samples()): + assert ts.node(n).metadata == sts.node(sn).metadata + + # check that nodes are correctly retained: only nodes ancestral to + # retained samples, and: by default, only coalescent events; if + # keep_unary_in_individuals then also nodes in individuals; if + # keep_unary then all such nodes. + for t in ts.trees(tracked_samples=sub_samples): + st = sts.at(t.interval[0]) + visited = [False for _ in sts.nodes()] + for n, sn in zip(sub_samples, sts.samples()): + last_n = t.num_tracked_samples(n) + while n != tskit.NULL: + ind = ts.node(n).individual + keep = False + if t.num_tracked_samples(n) > last_n: + # a coalescent node + keep = True + if "keep_unary_in_individuals" in params and ind != tskit.NULL: + keep = True + if "keep_unary" in params: + keep = True + if (n in sub_samples) or keep: + visited[sn] = True + assert sn != tskit.NULL + assert ts.node(n).metadata == sts.node(sn).metadata + assert t.num_tracked_samples(n) == st.num_samples(sn) + if ind != tskit.NULL: + sind = sts.node(sn).individual + assert sind != tskit.NULL + assert ( + ts.individual(ind).metadata + == sts.individual(sind).metadata + ) + sn = st.parent(sn) + last_n = t.num_tracked_samples(n) + n = t.parent(n) + st_nodes = list(st.nodes()) + for k, v in enumerate(visited): + assert v == (k in st_nodes) diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index ae1aa2272c..24c66f9b59 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -242,30 +242,44 @@ def insert_random_ploidy_individuals(ts, max_ploidy=5, max_dimension=3, seed=1): return tables.tree_sequence() -def insert_individuals(ts, samples=None, ploidy=1): +def insert_individuals(ts, nodes=None, ploidy=1): """ Inserts individuals into the 
tree sequence using the specified list - of samples (or all samples if None) with the specified ploidy by combining - ploidy-sized chunks of the list. + of nodes (or all sample nodes if None) with the specified ploidy by combining + ploidy-sized chunks of the list. Add metadata to the individuals so we can + track them. """ - if samples is None: - samples = ts.samples() - if len(samples) % ploidy != 0: - raise ValueError("number of samples must be divisible by ploidy") + if nodes is None: + nodes = ts.samples() + assert len(nodes) % ploidy == 0 # To allow mixed ploidies we could comment this out tables = ts.dump_tables() tables.individuals.clear() individual = tables.nodes.individual[:] individual[:] = tskit.NULL j = 0 - while j < len(samples): - nodes = samples[j : j + ploidy] - ind_id = tables.individuals.add_row() - individual[nodes] = ind_id + while j < len(nodes): + nodes_in_individual = nodes[j : min(len(nodes), j + ploidy)] + # should we warn here if nodes[j : j + ploidy] are at different times? 
+ # probably not, as although this is unusual, it is actually allowed + ind_id = tables.individuals.add_row( + metadata=f"orig_id {tables.individuals.num_rows}".encode() + ) + individual[nodes_in_individual] = ind_id j += ploidy tables.nodes.individual = individual return tables.tree_sequence() +def mark_metadata(ts, table_name, prefix="orig_id:"): + """ + Add metadata to all rows of the form prefix + row_number + """ + tables = ts.dump_tables() + table = getattr(tables, table_name) + table.packset_metadata([(prefix + str(i)).encode() for i in range(table.num_rows)]) + return tables.tree_sequence() + + def permute_nodes(ts, node_map): """ Returns a copy of the specified tree sequence such that the nodes are diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 1b000b9d8c..1daed02c20 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -2493,6 +2493,7 @@ def simplify( filter_individuals=True, filter_sites=True, keep_unary=False, + keep_unary_in_individuals=None, keep_input_roots=False, record_provenance=True, filter_zero_mutation_sites=None, # Deprecated alias for filter_sites @@ -2538,9 +2539,14 @@ def simplify( not referenced by mutations after simplification; new site IDs are allocated sequentially from zero. If False, the site table will not be altered in any way. (Default: True) - :param bool keep_unary: If True, any unary nodes (i.e. nodes with exactly - one child) that exist on the path from samples to root will be preserved - in the output. (Default: False) + :param bool keep_unary: If True, preserve unary nodes (i.e. nodes with + exactly one child) that exist on the path from samples to root. + (Default: False) + :param bool keep_unary_in_individuals: If True, preserve unary nodes + that exist on the path from samples to root, but only if they are + associated with an individual in the individuals table. Cannot be + specified at the same time as ``keep_unary``. 
(Default: ``None``, + equivalent to False) :param bool keep_input_roots: Whether to retain history ancestral to the MRCA of the samples. If ``False``, no topology older than the MRCAs of the samples will be included. If ``True`` the roots of all trees in the returned @@ -2568,6 +2574,9 @@ def simplify( ].astype(np.int32) else: samples = util.safe_np_int_cast(samples, np.int32) + if keep_unary_in_individuals is None: + keep_unary_in_individuals = False + node_map = self._ll_tables.simplify( samples, filter_sites=filter_sites, @@ -2575,6 +2584,7 @@ def simplify( filter_populations=filter_populations, reduce_to_site_topology=reduce_to_site_topology, keep_unary=keep_unary, + keep_unary_in_individuals=keep_unary_in_individuals, keep_input_roots=keep_input_roots, ) if record_provenance: diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 4c5ccfb1ee..1d31943474 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -4917,6 +4917,7 @@ def simplify( filter_individuals=True, filter_sites=True, keep_unary=False, + keep_unary_in_individuals=None, keep_input_roots=False, record_provenance=True, filter_zero_mutation_sites=None, # Deprecated alias for filter_sites @@ -4979,9 +4980,14 @@ def simplify( not referenced by mutations after simplification; new site IDs are allocated sequentially from zero. If False, the site table will not be altered in any way. (Default: True) - :param bool keep_unary: If True, any unary nodes (i.e. nodes with exactly - one child) that exist on the path from samples to root will be preserved - in the output. (Default: False) + :param bool keep_unary: If True, preserve unary nodes (i.e. nodes with + exactly one child) that exist on the path from samples to root. + (Default: False) + :param bool keep_unary_in_individuals: If True, preserve unary nodes + that exist on the path from samples to root, but only if they are + associated with an individual in the individuals table. 
Cannot be + specified at the same time as ``keep_unary``. (Default: ``None``, + equivalent to False) :param bool keep_input_roots: Whether to retain history ancestral to the MRCA of the samples. If ``False``, no topology older than the MRCAs of the samples will be included. If ``True`` the roots of all trees in the returned @@ -5008,6 +5014,7 @@ def simplify( filter_individuals=filter_individuals, filter_sites=filter_sites, keep_unary=keep_unary, + keep_unary_in_individuals=keep_unary_in_individuals, keep_input_roots=keep_input_roots, record_provenance=record_provenance, filter_zero_mutation_sites=filter_zero_mutation_sites,