Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tests/test_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,39 @@ def test_simulated_mutations(self, seed):
assert ts.num_mutations > 0
self.check_ts(ts)

def test_no_metadata_schema(self):
ts = msprime.sim_mutations(self.example_ts(), rate=1e-6, random_seed=43)
assert ts.num_mutations > 0
tables = ts.dump_tables()
tables.populations.metadata_schema = tskit.MetadataSchema(None)
self.check_ts(tables.tree_sequence())

def test_no_populations(self):
tables = single_tree_example_ts().dump_tables()
tables.populations.add_row(b"{}")
tsm = model.TSModel(tables.tree_sequence())
with pytest.raises(ValueError, match="must be assigned to populations"):
tsm.mutations_df


class TestNodeIsSample:
def test_simple_example(self):
ts = single_tree_example_ts()
is_sample = model.node_is_sample(ts)
for node in ts.nodes():
assert node.is_sample() == is_sample[node.id]

@pytest.mark.parametrize("bit", [1, 2, 17, 31])
def test_sample_and_other_flags(self, bit):
tables = single_tree_example_ts().dump_tables()
flags = tables.nodes.flags
tables.nodes.flags = flags | (1 << bit)
ts = tables.tree_sequence()
is_sample = model.node_is_sample(ts)
for node in ts.nodes():
assert node.is_sample() == is_sample[node.id]
assert (node.flags & (1 << bit)) != 0


class TestTreesDataTable:
def test_single_tree_example(self):
Expand Down
40 changes: 26 additions & 14 deletions tsqc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ class MutationCounts:


def compute_mutation_counts(ts):
logger.info("Computing mutation inheritance counts")
tree_pos = alloc_tree_position(ts)
mutations_position = ts.sites_position[ts.mutations_site].astype(int)
num_descendants, num_inheritors = _compute_mutation_inheritance_counts(
Expand All @@ -266,17 +267,22 @@ def _compute_population_mutation_counts(
num_populations,
edges_parent,
edges_child,
num_pop_samples,
nodes_is_sample,
nodes_population,
mutations_position,
mutations_node,
mutations_parent,
):
num_pop_samples = np.zeros((num_nodes, num_populations), dtype=np.int32)

pop_mutation_count = np.zeros((num_populations, num_mutations), dtype=np.int32)
parent = np.zeros(num_nodes, dtype=np.int32) - 1

mut_id = 0
for u in range(num_nodes):
if nodes_is_sample[u]:
num_pop_samples[u, nodes_population[u]] = 1

mut_id = 0
while tree_pos.next():
for j in range(tree_pos.out_range[0], tree_pos.out_range[1]):
e = tree_pos.edge_removal_order[j]
Expand All @@ -285,7 +291,8 @@ def _compute_population_mutation_counts(
parent[c] = -1
u = p
while u != -1:
num_pop_samples[u] -= num_pop_samples[c]
for k in range(num_populations):
num_pop_samples[u, k] -= num_pop_samples[c, k]
u = parent[u]

for j in range(tree_pos.in_range[0], tree_pos.in_range[1]):
Expand All @@ -295,7 +302,8 @@ def _compute_population_mutation_counts(
parent[c] = p
u = p
while u != -1:
num_pop_samples[u] += num_pop_samples[c]
for k in range(num_populations):
num_pop_samples[u, k] += num_pop_samples[c, k]
u = parent[u]

left, right = tree_pos.interval
Expand All @@ -309,18 +317,23 @@ def _compute_population_mutation_counts(
return pop_mutation_count


def node_is_sample(ts):
sample_flag = np.full_like(ts.nodes_flags, tskit.NODE_IS_SAMPLE)
return np.bitwise_and(ts.nodes_flags, sample_flag) != 0


def compute_population_mutation_counts(ts):
"""
Return a dataframe that gives the frequency of each mutation
in each of the populations in the specified tree sequence.
Return a (num_populations, num_mutations) array that gives the frequency
of each mutation in each of the populations in the specified tree sequence.
"""
logger.info(
f"Computing mutation frequencies within {ts.num_populations} populations"
)
mutations_position = ts.sites_position[ts.mutations_site].astype(int)
num_pop_samples = np.zeros((ts.num_nodes, ts.num_populations), dtype=np.int32)
for pop in range(ts.num_populations):
samples = np.logical_and(
ts.nodes_population == pop, ts.nodes_flags == 1 # Not quite right!
)
num_pop_samples[samples, pop] = 1

if np.any(ts.nodes_population[ts.samples()] == -1):
raise ValueError("Sample nodes must be assigned to populations")

return _compute_population_mutation_counts(
alloc_tree_position(ts),
Expand All @@ -329,7 +342,7 @@ def compute_population_mutation_counts(ts):
ts.num_populations,
ts.edges_parent,
ts.edges_child,
num_pop_samples,
node_is_sample(ts),
ts.nodes_population,
mutations_position,
ts.mutations_node,
Expand Down Expand Up @@ -409,7 +422,6 @@ def mutations_df(self):
unknown = tskit.is_unknown_time(mutations_time)
mutations_time[unknown] = self.ts.nodes_time[mutations_node[unknown]]

# node_flag = ts.nodes_flags[mutations_node]
position = ts.sites_position[ts.mutations_site]

tables = self.ts.tables
Expand Down