diff --git a/tests/test_data_model.py b/tests/test_data_model.py index eea4b55..6ab9b9a 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -232,6 +232,39 @@ def test_simulated_mutations(self, seed): assert ts.num_mutations > 0 self.check_ts(ts) + def test_no_metadata_schema(self): + ts = msprime.sim_mutations(self.example_ts(), rate=1e-6, random_seed=43) + assert ts.num_mutations > 0 + tables = ts.dump_tables() + tables.populations.metadata_schema = tskit.MetadataSchema(None) + self.check_ts(tables.tree_sequence()) + + def test_no_populations(self): + tables = single_tree_example_ts().dump_tables() + tables.populations.add_row(b"{}") + tsm = model.TSModel(tables.tree_sequence()) + with pytest.raises(ValueError, match="must be assigned to populations"): + tsm.mutations_df + + +class TestNodeIsSample: + def test_simple_example(self): + ts = single_tree_example_ts() + is_sample = model.node_is_sample(ts) + for node in ts.nodes(): + assert node.is_sample() == is_sample[node.id] + + @pytest.mark.parametrize("bit", [1, 2, 17, 31]) + def test_sample_and_other_flags(self, bit): + tables = single_tree_example_ts().dump_tables() + flags = tables.nodes.flags + tables.nodes.flags = flags | (1 << bit) + ts = tables.tree_sequence() + is_sample = model.node_is_sample(ts) + for node in ts.nodes(): + assert node.is_sample() == is_sample[node.id] + assert (node.flags & (1 << bit)) != 0 + class TestTreesDataTable: def test_single_tree_example(self): diff --git a/tsqc/model.py b/tsqc/model.py index 2e106ca..079b265 100644 --- a/tsqc/model.py +++ b/tsqc/model.py @@ -241,6 +241,7 @@ class MutationCounts: def compute_mutation_counts(ts): + logger.info("Computing mutation inheritance counts") tree_pos = alloc_tree_position(ts) mutations_position = ts.sites_position[ts.mutations_site].astype(int) num_descendants, num_inheritors = _compute_mutation_inheritance_counts( @@ -266,17 +267,22 @@ def _compute_population_mutation_counts( num_populations, edges_parent, edges_child, - num_pop_samples, + nodes_is_sample, nodes_population, mutations_position, mutations_node, mutations_parent, ): + num_pop_samples = np.zeros((num_nodes, num_populations), dtype=np.int32) + pop_mutation_count = np.zeros((num_populations, num_mutations), dtype=np.int32) parent = np.zeros(num_nodes, dtype=np.int32) - 1 - mut_id = 0 + for u in range(num_nodes): + if nodes_is_sample[u]: + num_pop_samples[u, nodes_population[u]] = 1 + mut_id = 0 while tree_pos.next(): for j in range(tree_pos.out_range[0], tree_pos.out_range[1]): e = tree_pos.edge_removal_order[j] @@ -285,7 +291,8 @@ def _compute_population_mutation_counts( parent[c] = -1 u = p while u != -1: - num_pop_samples[u] -= num_pop_samples[c] + for k in range(num_populations): + num_pop_samples[u, k] -= num_pop_samples[c, k] u = parent[u] for j in range(tree_pos.in_range[0], tree_pos.in_range[1]): @@ -295,7 +302,8 @@ def _compute_population_mutation_counts( parent[c] = p u = p while u != -1: - num_pop_samples[u] += num_pop_samples[c] + for k in range(num_populations): + num_pop_samples[u, k] += num_pop_samples[c, k] u = parent[u] left, right = tree_pos.interval @@ -309,18 +317,23 @@ def _compute_population_mutation_counts( return pop_mutation_count +def node_is_sample(ts): + sample_flag = np.full_like(ts.nodes_flags, tskit.NODE_IS_SAMPLE) + return np.bitwise_and(ts.nodes_flags, sample_flag) != 0 + + def compute_population_mutation_counts(ts): """ - Return a dataframe that gives the frequency of each mutation - in each of the populations in the specified tree sequence. + Return a (num_populations, num_mutations) array that gives the frequency + of each mutation in each of the populations in the specified tree sequence. """ + logger.info( + f"Computing mutation frequencies within {ts.num_populations} populations" + ) mutations_position = ts.sites_position[ts.mutations_site].astype(int) - num_pop_samples = np.zeros((ts.num_nodes, ts.num_populations), dtype=np.int32) - for pop in range(ts.num_populations): - samples = np.logical_and( - ts.nodes_population == pop, ts.nodes_flags == 1 # Not quite right! - ) - num_pop_samples[samples, pop] = 1 + + if np.any(ts.nodes_population[ts.samples()] == -1): + raise ValueError("Sample nodes must be assigned to populations") return _compute_population_mutation_counts( alloc_tree_position(ts), @@ -329,7 +342,7 @@ def compute_population_mutation_counts(ts): ts.num_populations, ts.edges_parent, ts.edges_child, - num_pop_samples, + node_is_sample(ts), ts.nodes_population, mutations_position, ts.mutations_node, @@ -409,7 +422,6 @@ def mutations_df(self): unknown = tskit.is_unknown_time(mutations_time) mutations_time[unknown] = self.ts.nodes_time[mutations_node[unknown]] - # node_flag = ts.nodes_flags[mutations_node] position = ts.sites_position[ts.mutations_site] tables = self.ts.tables