diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 94d13685..ce8dbc86 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -475,35 +475,50 @@ def insert_sites(self, tables): metadata = self.sample_data.sites_metadata[:] position = self.sample_data.sites_position[:] _, node, derived_state, parent = self.tree_sequence_builder.dump_mutations() + num_non_inference_sites = np.sum(inference == 0) logger.info( "Starting mutation positioning for {} non inference sites".format( - np.sum(inference == 0))) + num_non_inference_sites)) + progress_monitor = self.progress_monitor.get("ms_sites", num_sites) - inferred_site = 0 - trees = ts.trees() - tree = next(trees) - for site_id, genotypes in self.sample_data.genotypes(): - x = position[site_id] - while tree.interval[1] <= x: - tree = next(trees) - assert tree.interval[0] <= x < tree.interval[1] - tables.sites.add_row( - position=x, - ancestral_state=alleles[site_id][0], - metadata=self.encode_metadata(metadata[site_id])) - if inference[site_id] == 1: + if num_non_inference_sites > 0: + inferred_site = 0 + trees = ts.trees() + tree = next(trees) + for site_id, genotypes in self.sample_data.genotypes(): + x = position[site_id] + while tree.interval[1] <= x: + tree = next(trees) + assert tree.interval[0] <= x < tree.interval[1] + tables.sites.add_row( + position=x, + ancestral_state=alleles[site_id][0], + metadata=self.encode_metadata(metadata[site_id])) + if inference[site_id] == 1: + tables.mutations.add_row( + site=site_id, node=node[inferred_site], + derived_state=alleles[site_id][derived_state[inferred_site]]) + inferred_site += 1 + elif ts.num_edges > 0: + self.locate_mutations_on_tree( + tree, site_id, genotypes, alleles[site_id], tables.mutations) + else: + # If we have no tree topology this is all we can do. + self.locate_mutations_over_samples( + site_id, genotypes, alleles[site_id], tables.mutations) + progress_monitor.update() + else: + # Simple case where all sites are inference sites. + for site_id in range(self.sample_data.num_sites): + x = position[site_id] + tables.sites.add_row( + position=x, + ancestral_state=alleles[site_id][0], + metadata=self.encode_metadata(metadata[site_id])) tables.mutations.add_row( - site=site_id, node=node[inferred_site], - derived_state=alleles[site_id][derived_state[inferred_site]]) - inferred_site += 1 - elif ts.num_edges > 0: - self.locate_mutations_on_tree( - tree, site_id, genotypes, alleles[site_id], tables.mutations) - else: - # If we have no tree topology this is all we can do. - self.locate_mutations_over_samples( - site_id, genotypes, alleles[site_id], tables.mutations) - progress_monitor.update() + site=site_id, node=node[site_id], + derived_state=alleles[site_id][derived_state[site_id]]) + progress_monitor.update() progress_monitor.close() def get_samples_tree_sequence(self):