From 55e2d3fc94dd81379ee50beecb1eac2eadcae627 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 4 Apr 2024 10:29:01 +0100 Subject: [PATCH 1/2] Fix incorrect flags check in sample computation --- tests/test_data_model.py | 19 +++++++++++++++++++ tsqc/model.py | 10 +++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/test_data_model.py b/tests/test_data_model.py index eea4b55..8b7ac69 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -233,6 +233,25 @@ def test_simulated_mutations(self, seed): self.check_ts(ts) +class TestNodeIsSample: + def test_simple_example(self): + ts = single_tree_example_ts() + is_sample = model.node_is_sample(ts) + for node in ts.nodes(): + assert node.is_sample() == is_sample[node.id] + + @pytest.mark.parametrize("bit", [1, 2, 17, 31]) + def test_sample_and_other_flags(self, bit): + tables = single_tree_example_ts().dump_tables() + flags = tables.nodes.flags + tables.nodes.flags = flags | (1 << bit) + ts = tables.tree_sequence() + is_sample = model.node_is_sample(ts) + for node in ts.nodes(): + assert node.is_sample() == is_sample[node.id] + assert (node.flags & (1 << bit)) != 0 + + class TestTreesDataTable: def test_single_tree_example(self): ts = single_tree_example_ts() diff --git a/tsqc/model.py b/tsqc/model.py index 2e106ca..43de339 100644 --- a/tsqc/model.py +++ b/tsqc/model.py @@ -309,6 +309,11 @@ def _compute_population_mutation_counts( return pop_mutation_count +def node_is_sample(ts): + sample_flag = np.full_like(ts.nodes_flags, tskit.NODE_IS_SAMPLE) + return np.bitwise_and(ts.nodes_flags, sample_flag) != 0 + + def compute_population_mutation_counts(ts): """ Return a dataframe that gives the frequency of each mutation @@ -316,10 +321,9 @@ def compute_population_mutation_counts(ts): """ mutations_position = ts.sites_position[ts.mutations_site].astype(int) num_pop_samples = np.zeros((ts.num_nodes, ts.num_populations), dtype=np.int32) + is_sample = node_is_sample(ts) for pop in range(ts.num_populations): - samples = np.logical_and( - ts.nodes_population == pop, ts.nodes_flags == 1 # Not quite right! - ) + samples = np.logical_and(ts.nodes_population == pop, is_sample) num_pop_samples[samples, pop] = 1 return _compute_population_mutation_counts( From 6277d76c771c43693b67858654150c23775f5d66 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 4 Apr 2024 10:51:36 +0100 Subject: [PATCH 2/2] Refactor to use explicit loops --- tests/test_data_model.py | 14 ++++++++++++++ tsqc/model.py | 34 +++++++++++++++++++++------------- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/tests/test_data_model.py b/tests/test_data_model.py index 8b7ac69..6ab9b9a 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -232,6 +232,20 @@ def test_simulated_mutations(self, seed): assert ts.num_mutations > 0 self.check_ts(ts) + def test_no_metadata_schema(self): + ts = msprime.sim_mutations(self.example_ts(), rate=1e-6, random_seed=43) + assert ts.num_mutations > 0 + tables = ts.dump_tables() + tables.populations.metadata_schema = tskit.MetadataSchema(None) + self.check_ts(tables.tree_sequence()) + + def test_no_populations(self): + tables = single_tree_example_ts().dump_tables() + tables.populations.add_row(b"{}") + tsm = model.TSModel(tables.tree_sequence()) + with pytest.raises(ValueError, match="must be assigned to populations"): + tsm.mutations_df + class TestNodeIsSample: def test_simple_example(self): diff --git a/tsqc/model.py b/tsqc/model.py index 43de339..079b265 100644 --- a/tsqc/model.py +++ b/tsqc/model.py @@ -241,6 +241,7 @@ class MutationCounts: def compute_mutation_counts(ts): + logger.info("Computing mutation inheritance counts") tree_pos = alloc_tree_position(ts) mutations_position = ts.sites_position[ts.mutations_site].astype(int) num_descendants, num_inheritors = _compute_mutation_inheritance_counts( @@ -266,17 +267,22 @@ def _compute_population_mutation_counts( num_populations, edges_parent, edges_child, - num_pop_samples, + nodes_is_sample, nodes_population, mutations_position, mutations_node, mutations_parent, ): + num_pop_samples = np.zeros((num_nodes, num_populations), dtype=np.int32) + pop_mutation_count = np.zeros((num_populations, num_mutations), dtype=np.int32) parent = np.zeros(num_nodes, dtype=np.int32) - 1 - mut_id = 0 + for u in range(num_nodes): + if nodes_is_sample[u]: + num_pop_samples[u, nodes_population[u]] = 1 + mut_id = 0 while tree_pos.next(): for j in range(tree_pos.out_range[0], tree_pos.out_range[1]): e = tree_pos.edge_removal_order[j] @@ -285,7 +291,8 @@ def _compute_population_mutation_counts( parent[c] = -1 u = p while u != -1: - num_pop_samples[u] -= num_pop_samples[c] + for k in range(num_populations): + num_pop_samples[u, k] -= num_pop_samples[c, k] u = parent[u] for j in range(tree_pos.in_range[0], tree_pos.in_range[1]): @@ -295,7 +302,8 @@ def _compute_population_mutation_counts( parent[c] = p u = p while u != -1: - num_pop_samples[u] += num_pop_samples[c] + for k in range(num_populations): + num_pop_samples[u, k] += num_pop_samples[c, k] u = parent[u] left, right = tree_pos.interval @@ -316,15 +324,16 @@ def node_is_sample(ts): def compute_population_mutation_counts(ts): """ - Return a dataframe that gives the frequency of each mutation - in each of the populations in the specified tree sequence. + Return a (num_populations, num_mutations) array that gives the frequency + of each mutation in each of the populations in the specified tree sequence. """ + logger.info( + f"Computing mutation frequencies within {ts.num_populations} populations" + ) mutations_position = ts.sites_position[ts.mutations_site].astype(int) - num_pop_samples = np.zeros((ts.num_nodes, ts.num_populations), dtype=np.int32) - is_sample = node_is_sample(ts) - for pop in range(ts.num_populations): - samples = np.logical_and(ts.nodes_population == pop, is_sample) - num_pop_samples[samples, pop] = 1 + + if np.any(ts.nodes_population[ts.samples()] == -1): + raise ValueError("Sample nodes must be assigned to populations") return _compute_population_mutation_counts( alloc_tree_position(ts), @@ -333,7 +342,7 @@ def compute_population_mutation_counts(ts): ts.num_populations, ts.edges_parent, ts.edges_child, - num_pop_samples, + node_is_sample(ts), ts.nodes_population, mutations_position, ts.mutations_node, @@ -413,7 +422,6 @@ def mutations_df(self): unknown = tskit.is_unknown_time(mutations_time) mutations_time[unknown] = self.ts.nodes_time[mutations_node[unknown]] - # node_flag = ts.nodes_flags[mutations_node] position = ts.sites_position[ts.mutations_site] tables = self.ts.tables