Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion c/tskit/tables.c
Original file line number Diff line number Diff line change
Expand Up @@ -6982,7 +6982,7 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id)
keep_unary = true;
}
if ((self->options & TSK_KEEP_UNARY_IN_INDIVIDUALS)
&& (self->tables->nodes.individual[input_id] != TSK_NULL)) {
&& (self->input_tables.nodes.individual[input_id] != TSK_NULL)) {
keep_unary = true;
}

Expand Down
15 changes: 10 additions & 5 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -4960,18 +4960,20 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds)
int filter_individuals = false;
int filter_populations = false;
int keep_unary = false;
int keep_unary_in_individuals = false;
int keep_input_roots = false;
int reduce_to_site_topology = false;
static char *kwlist[]
= { "samples", "filter_sites", "filter_populations", "filter_individuals",
"reduce_to_site_topology", "keep_unary", "keep_input_roots", NULL };
static char *kwlist[] = { "samples", "filter_sites", "filter_populations",
"filter_individuals", "reduce_to_site_topology", "keep_unary",
"keep_unary_in_individuals", "keep_input_roots", NULL };

if (TableCollection_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiii", kwlist, &samples,
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiii", kwlist, &samples,
&filter_sites, &filter_populations, &filter_individuals,
&reduce_to_site_topology, &keep_unary, &keep_input_roots)) {
&reduce_to_site_topology, &keep_unary, &keep_unary_in_individuals,
&keep_input_roots)) {
goto out;
}
samples_array = (PyArrayObject *) PyArray_FROMANY(
Expand All @@ -4996,6 +4998,9 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds)
if (keep_unary) {
options |= TSK_KEEP_UNARY;
}
if (keep_unary_in_individuals) {
options |= TSK_KEEP_UNARY_IN_INDIVIDUALS;
}
if (keep_input_roots) {
options |= TSK_KEEP_INPUT_ROOTS;
}
Expand Down
26 changes: 21 additions & 5 deletions python/tests/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
filter_populations=True,
filter_individuals=True,
keep_unary=False,
keep_unary_in_individuals=False,
keep_input_roots=False,
):
self.ts = ts
Expand All @@ -119,6 +120,7 @@ def __init__(
self.filter_populations = filter_populations
self.filter_individuals = filter_individuals
self.keep_unary = keep_unary
self.keep_unary_in_individuals = keep_unary_in_individuals
self.keep_input_roots = keep_input_roots
self.num_mutations = ts.num_mutations
self.input_sites = list(ts.sites())
Expand Down Expand Up @@ -295,7 +297,10 @@ def merge_labeled_ancestors(self, S, input_id):
if is_sample:
self.record_edge(left, right, output_id, ancestry_node)
ancestry_node = output_id
elif self.keep_unary:
elif self.keep_unary or (
self.keep_unary_in_individuals
and self.ts.node(input_id).individual >= 0
):
if output_id == -1:
output_id = self.record_node(input_id)
self.record_edge(left, right, output_id, ancestry_node)
Expand All @@ -308,7 +313,10 @@ def merge_labeled_ancestors(self, S, input_id):
if is_sample and left != prev_right:
# Fill in any gaps in the ancestry for the sample
self.add_ancestry(input_id, prev_right, left, output_id)
if self.keep_unary:
if self.keep_unary or (
self.keep_unary_in_individuals
and self.ts.node(input_id).individual >= 0
):
ancestry_node = output_id
self.add_ancestry(input_id, left, right, ancestry_node)
prev_right = right
Expand Down Expand Up @@ -757,7 +765,6 @@ def print_state(self):

samples = list(map(int, sys.argv[3:]))

# When keep_unary = True
print("When keep_unary = True:")
s = Simplifier(ts, samples, keep_unary=True)
# s.print_state()
Expand All @@ -768,8 +775,7 @@ def print_state(self):
print(tables.sites)
print(tables.mutations)

# When keep_unary = False
print("\nWhen keep_unary = False:")
print("\nWhen keep_unary = False")
s = Simplifier(ts, samples, keep_unary=False)
# s.print_state()
tss, _ = s.simplify()
Expand All @@ -779,6 +785,16 @@ def print_state(self):
print(tables.sites)
print(tables.mutations)

print("\nWhen keep_unary_in_individuals = True")
s = Simplifier(ts, samples, keep_unary_in_individuals=True)
# s.print_state()
tss, _ = s.simplify()
tables = tss.dump_tables()
print(tables.nodes)
print(tables.edges)
print(tables.sites)
print(tables.mutations)

elif class_to_implement == "AncestorMap":

samples = sys.argv[3]
Expand Down
2 changes: 2 additions & 0 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ def test_simplify_bad_args(self):
tc.simplify("asdf")
with pytest.raises(TypeError):
tc.simplify([0, 1], keep_unary="sdf")
with pytest.raises(TypeError):
tc.simplify([0, 1], keep_unary_in_individuals="abc")
with pytest.raises(TypeError):
tc.simplify([0, 1], keep_input_roots="sdf")
with pytest.raises(TypeError):
Expand Down
23 changes: 22 additions & 1 deletion python/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2368,12 +2368,13 @@ def wf_sim_with_individual_metadata(self):
9,
10,
seed=1,
deep_history=True,
deep_history=False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't thought throught he implications of changing this for test_shuffled_individual_parent_mapping, have you?

Copy link
Member Author

@hyanwong hyanwong Feb 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Erm, not sure why should this make a difference? We still shuffle the individuals and check that they are (a) shuffled and (b) the original individual ids correspond.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But perhaps I'm not understanding something?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change the default here? Seems gratuitous to me.

Copy link
Member Author

@hyanwong hyanwong Feb 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See below. If we want to check that all nodes have an individual, and all the individuals point to their correct parents, then we can't include the deep history nodes, which have no individuals (and also the nodes at the top of the WF generated simulation will not have the correct parents).

I.e. if we have a deep history, we (almost by definition) cannot fully map the trees onto the individuals pedigree. It seemed worth testing that we were capturing all of the genetic pedigree in the individuals.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(yan has thought this over)

initial_generation_samples=False,
num_loci=5,
record_individuals=True,
)
assert tables.individuals.num_rows > 50
assert np.all(tables.nodes.individual >= 0)
individuals_copy = tables.copy().individuals
tables.individuals.clear()
tables.individuals.metadata_schema = tskit.MetadataSchema({"codec": "json"})
Expand Down Expand Up @@ -2404,6 +2405,26 @@ def test_individual_parent_mapping(self, wf_sim_with_individual_metadata):
)
assert set(tables.individuals.parents) != {tskit.NULL}

def verify_complete_genetic_pedigree(self, tables):
ts = tables.tree_sequence()
for edge in ts.edges():
child = ts.individual(ts.node(edge.child).individual)
parent = ts.individual(ts.node(edge.parent).individual)
assert parent.id in child.parents
assert parent.metadata["original_id"] in child.metadata["original_parents"]

def test_no_complete_genetic_pedigree(self, wf_sim_with_individual_metadata):
tables = wf_sim_with_individual_metadata.copy()
tables.simplify() # Will remove intermediate individuals
with pytest.raises(AssertionError):
self.verify_complete_genetic_pedigree(tables)

def test_complete_genetic_pedigree(self, wf_sim_with_individual_metadata):
for params in [{"keep_unary": True}, {"keep_unary_in_individuals": True}]:
tables = wf_sim_with_individual_metadata.copy()
tables.simplify(**params) # Keep intermediate individuals
self.verify_complete_genetic_pedigree(tables)

def test_shuffled_individual_parent_mapping(self, wf_sim_with_individual_metadata):
tables = wf_sim_with_individual_metadata.copy()
tsutil.shuffle_tables(
Expand Down
18 changes: 12 additions & 6 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -2363,19 +2363,20 @@ def test_ladder_tree(self):
def verify_unary_tree_sequence(self, ts):
"""
Take the specified tree sequence and produce an equivalent in which
unary records have been interspersed.
unary records have been interspersed, every other with an associated individual
"""
assert ts.num_trees > 2
assert ts.num_mutations > 2
tables = ts.dump_tables()
next_node = ts.num_nodes
node_times = {j: node.time for j, node in enumerate(ts.nodes())}
edges = []
for e in ts.edges():
for i, e in enumerate(ts.edges()):
node = ts.node(e.parent)
t = node.time - 1e-14 # Arbitrary small value.
next_node = len(tables.nodes)
tables.nodes.add_row(time=t, population=node.population)
indiv = tables.individuals.add_row() if i % 2 == 0 else tskit.NULL
tables.nodes.add_row(time=t, population=node.population, individual=indiv)
edges.append(
tskit.Edge(left=e.left, right=e.right, parent=next_node, child=e.child)
)
Expand All @@ -2398,11 +2399,16 @@ def verify_unary_tree_sequence(self, ts):
self.assert_haplotypes_equal(ts, ts_simplified)
self.assert_variants_equal(ts, ts_simplified)
assert len(list(ts.edge_diffs())) == ts.num_trees
assert 0 < ts_new.num_individuals < ts_new.num_nodes

for keep_unary in [True, False]:
s = tests.Simplifier(ts, ts.samples(), keep_unary=keep_unary)
for params in [
{"keep_unary": False, "keep_unary_in_individuals": False},
{"keep_unary": True, "keep_unary_in_individuals": False},
{"keep_unary": False, "keep_unary_in_individuals": True},
]:
s = tests.Simplifier(ts_new, ts_new.samples(), **params)
py_ts, py_node_map = s.simplify()
lib_ts, lib_node_map = ts.simplify(keep_unary=keep_unary, map_nodes=True)
lib_ts, lib_node_map = ts_new.simplify(map_nodes=True, **params)
py_tables = py_ts.dump_tables()
py_tables.provenances.clear()
lib_tables = lib_ts.dump_tables()
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_ploidy_2_reversed(self):
ts = msprime.simulate(10, random_seed=1)
assert ts.num_individuals == 0
samples = ts.samples()[::-1]
ts = tsutil.insert_individuals(ts, samples=samples, ploidy=2)
ts = tsutil.insert_individuals(ts, nodes=samples, ploidy=2)
assert ts.num_individuals == 5
for j, ind in enumerate(ts.individuals()):
assert list(ind.nodes) == [samples[2 * j + 1], samples[2 * j]]
4 changes: 2 additions & 2 deletions python/tests/test_vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,14 @@ def test_simple_infinite_sites_ploidy_2(self):
def test_simple_infinite_sites_ploidy_2_reversed_samples(self):
ts = msprime.simulate(10, mutation_rate=1, random_seed=2)
samples = ts.samples()[::-1]
ts = tsutil.insert_individuals(ts, samples=samples, ploidy=2)
ts = tsutil.insert_individuals(ts, nodes=samples, ploidy=2)
assert ts.num_sites > 2
self.verify(ts)

def test_simple_infinite_sites_ploidy_2_even_samples(self):
ts = msprime.simulate(20, mutation_rate=1, random_seed=2)
samples = ts.samples()[0::2]
ts = tsutil.insert_individuals(ts, samples=samples, ploidy=2)
ts = tsutil.insert_individuals(ts, nodes=samples, ploidy=2)
assert ts.num_sites > 2
self.verify(ts)

Expand Down
55 changes: 55 additions & 0 deletions python/tests/test_wright_fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,58 @@ def test_simplify_tables(self, ts, nsamples):
other_tables.provenances.clear()
assert tables == other_tables
self.verify_simplify(ts, small_ts, sub_samples, node_map)

@pytest.mark.parametrize("ts", wf_sims)
@pytest.mark.parametrize("nsamples", [2, 5])
def test_simplify_keep_unary(self, ts, nsamples):
np.random.seed(123)
ts = tsutil.mark_metadata(ts, "nodes")
sub_samples = random.sample(list(ts.samples()), min(nsamples, ts.num_samples))
random_nodes = np.random.choice(ts.num_nodes, ts.num_nodes // 2)
ts = tsutil.insert_individuals(ts, random_nodes)
ts = tsutil.mark_metadata(ts, "individuals")

for params in [{}, {"keep_unary": True}, {"keep_unary_in_individuals": True}]:
sts = ts.simplify(sub_samples, **params)
# check samples match
assert sts.num_samples == len(sub_samples)
for n, sn in zip(sub_samples, sts.samples()):
assert ts.node(n).metadata == sts.node(sn).metadata

# check that nodes are correctly retained: only nodes ancestral to
# retained samples, and: by default, only coalescent events; if
# keep_unary_in_individuals then also nodes in individuals; if
# keep_unary then all such nodes.
for t in ts.trees(tracked_samples=sub_samples):
st = sts.at(t.interval[0])
visited = [False for _ in sts.nodes()]
for n, sn in zip(sub_samples, sts.samples()):
last_n = t.num_tracked_samples(n)
while n != tskit.NULL:
ind = ts.node(n).individual
keep = False
if t.num_tracked_samples(n) > last_n:
# a coalescent node
keep = True
if "keep_unary_in_individuals" in params and ind != tskit.NULL:
keep = True
if "keep_unary" in params:
keep = True
if (n in sub_samples) or keep:
visited[sn] = True
assert sn != tskit.NULL
assert ts.node(n).metadata == sts.node(sn).metadata
assert t.num_tracked_samples(n) == st.num_samples(sn)
if ind != tskit.NULL:
sind = sts.node(sn).individual
assert sind != tskit.NULL
assert (
ts.individual(ind).metadata
== sts.individual(sind).metadata
)
sn = st.parent(sn)
last_n = t.num_tracked_samples(n)
n = t.parent(n)
st_nodes = list(st.nodes())
for k, v in enumerate(visited):
assert v == (k in st_nodes)
36 changes: 25 additions & 11 deletions python/tests/tsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,30 +242,44 @@ def insert_random_ploidy_individuals(ts, max_ploidy=5, max_dimension=3, seed=1):
return tables.tree_sequence()


def insert_individuals(ts, samples=None, ploidy=1):
def insert_individuals(ts, nodes=None, ploidy=1):
"""
Inserts individuals into the tree sequence using the specified list
of samples (or all samples if None) with the specified ploidy by combining
ploidy-sized chunks of the list.
of node (or use all sample nodes if None) with the specified ploidy by combining
ploidy-sized chunks of the list. Add metadata to the individuals so we can
track them
"""
if samples is None:
samples = ts.samples()
if len(samples) % ploidy != 0:
raise ValueError("number of samples must be divisible by ploidy")
if nodes is None:
nodes = ts.samples()
assert len(nodes) % ploidy == 0 # To allow mixed ploidies we could comment this out
tables = ts.dump_tables()
tables.individuals.clear()
individual = tables.nodes.individual[:]
individual[:] = tskit.NULL
j = 0
while j < len(samples):
nodes = samples[j : j + ploidy]
ind_id = tables.individuals.add_row()
individual[nodes] = ind_id
while j < len(nodes):
nodes_in_individual = nodes[j : min(len(nodes), j + ploidy)]
# should we warn here if nodes[j : j + ploidy] are at different times?
# probably not, as although this is unusual, it is actually allowed
ind_id = tables.individuals.add_row(
metadata=f"orig_id {tables.individuals.num_rows}".encode()
)
individual[nodes_in_individual] = ind_id
j += ploidy
tables.nodes.individual = individual
return tables.tree_sequence()


def mark_metadata(ts, table_name, prefix="orig_id:"):
"""
Add metadata to all rows of the form prefix + row_number
"""
tables = ts.dump_tables()
table = getattr(tables, table_name)
table.packset_metadata([(prefix + str(i)).encode() for i in range(table.num_rows)])
return tables.tree_sequence()


def permute_nodes(ts, node_map):
"""
Returns a copy of the specified tree sequence such that the nodes are
Expand Down
16 changes: 13 additions & 3 deletions python/tskit/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2493,6 +2493,7 @@ def simplify(
filter_individuals=True,
filter_sites=True,
keep_unary=False,
keep_unary_in_individuals=None,
keep_input_roots=False,
record_provenance=True,
filter_zero_mutation_sites=None, # Deprecated alias for filter_sites
Expand Down Expand Up @@ -2538,9 +2539,14 @@ def simplify(
not referenced by mutations after simplification; new site IDs are
allocated sequentially from zero. If False, the site table will not
be altered in any way. (Default: True)
:param bool keep_unary: If True, any unary nodes (i.e. nodes with exactly
one child) that exist on the path from samples to root will be preserved
in the output. (Default: False)
:param bool keep_unary: If True, preserve unary nodes (i.e. nodes with
exactly one child) that exist on the path from samples to root.
(Default: False)
:param bool keep_unary_in_individuals: If True, preserve unary nodes
that exist on the path from samples to root, but only if they are
associated with an individual in the individuals table. Cannot be
specified at the same time as ``keep_unary``. (Default: ``None``,
equivalent to False)
:param bool keep_input_roots: Whether to retain history ancestral to the
MRCA of the samples. If ``False``, no topology older than the MRCAs of the
samples will be included. If ``True`` the roots of all trees in the returned
Expand Down Expand Up @@ -2568,13 +2574,17 @@ def simplify(
].astype(np.int32)
else:
samples = util.safe_np_int_cast(samples, np.int32)
if keep_unary_in_individuals is None:
keep_unary_in_individuals = False

node_map = self._ll_tables.simplify(
samples,
filter_sites=filter_sites,
filter_individuals=filter_individuals,
filter_populations=filter_populations,
reduce_to_site_topology=reduce_to_site_topology,
keep_unary=keep_unary,
keep_unary_in_individuals=keep_unary_in_individuals,
keep_input_roots=keep_input_roots,
)
if record_provenance:
Expand Down
Loading