Use dependency levels rather than epochs #486

Closed

12 changes: 8 additions & 4 deletions tsinfer/formats.py
@@ -2546,16 +2546,20 @@ def finalise(self):
     # Read mode
     ####################################

-    def ancestors(self):
+    def ancestors(self, indexes=None):
         """
-        Returns an iterator over all the ancestors.
+        Returns an iterator over all the ancestors. If indexes is provided, it should
+        be a sorted list of indexes giving a subset of ancestors to return.
+        For efficiency, the indexes should be a numpy integer array.
         """
         # TODO document properly.
         start = self.ancestors_start[:]
         end = self.ancestors_end[:]
         time = self.ancestors_time[:]
         focal_sites = self.ancestors_focal_sites[:]
-        for j, h in enumerate(chunk_iterator(self.ancestors_haplotype)):
+        haplotypes = chunk_iterator(self.ancestors_haplotype, indexes)
+        if indexes is None:
+            indexes = range(len(time))
+        for j, h in zip(indexes, haplotypes):
             yield Ancestor(
                 id=j,
                 start=start[j],
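A minimal sketch of how the new indexes argument can be used, assuming an ancestor data file already exists on disk (the file name and ID values here are hypothetical, not part of the diff):

import numpy as np
import tsinfer

ancestor_data = tsinfer.load("example.ancestors")  # hypothetical file name
# A sorted numpy integer array of ancestor IDs, as the docstring advises.
subset = np.array([0, 2, 5], dtype=np.int64)
for ancestor in ancestor_data.ancestors(indexes=subset):
    # Each yielded Ancestor keeps its original ID, so IDs still match node IDs.
    print(ancestor.id, ancestor.start, ancestor.end, ancestor.time)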
129 changes: 68 additions & 61 deletions tsinfer/inference.py
@@ -24,6 +24,7 @@
 import collections
 import copy
 import heapq
+import itertools
 import json
 import logging
 import queue
@@ -853,11 +854,10 @@ def _run_synchronous(self, progress):
             start, end = self.ancestor_builder.make_ancestor(focal_sites, a)
             duration = time.perf_counter() - before
             logger.debug(
-                "Made ancestor in {:.2f}s at timepoint {} (epoch {}) "
+                "Made ancestor in {:.2f}s at timepoint {} "
                 "from {} to {} (len={}) with {} focal sites ({})".format(
                     duration,
                     t,
-                    self.timepoint_to_epoch[t],
                     start,
                     end,
                     end - start,
@@ -938,16 +938,11 @@ def build_worker(thread_index):
     def run(self):
         self.descriptors = self.ancestor_builder.ancestor_descriptors()
         self.num_ancestors = len(self.descriptors)
-        # Maps epoch numbers to their corresponding ancestor times.
-        self.timepoint_to_epoch = {}
-        for t, _ in reversed(self.descriptors):
-            if t not in self.timepoint_to_epoch:
-                self.timepoint_to_epoch[t] = len(self.timepoint_to_epoch) + 1
         if self.num_ancestors > 0:
             logger.info(f"Starting build for {self.num_ancestors} ancestors")
             progress = self.progress_monitor.get("ga_generate", self.num_ancestors)
             a = np.zeros(self.num_sites, dtype=np.int8)
-            root_time = max(self.timepoint_to_epoch.keys()) + 1
+            root_time = self.descriptors[0][0] + 1  # first descriptor is the oldest
             ultimate_ancestor_time = root_time + 1
             # Add the ultimate ancestor. This is an awkward hack really; we don't
             # ever insert this ancestor. The only reason to add it here is that
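As a quick sanity check on the root_time change above: ancestor_descriptors() returns (time, focal_sites) pairs ordered oldest first (per the inline comment), so taking the first entry's time is equivalent to the old maximum over all distinct timepoints. A toy example with made-up values:

descriptors = [(4.0, [17]), (3.0, [2, 9]), (3.0, [5]), (1.0, [11])]
root_time = descriptors[0][0] + 1  # 5.0
assert root_time == max(t for t, _ in descriptors) + 1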
@@ -1241,31 +1236,48 @@ def __init__(self, sample_data, ancestor_data, **kwargs):
         super().__init__(sample_data, ancestor_data.sites_position[:], **kwargs)
         self.ancestor_data = ancestor_data
         self.num_ancestors = self.ancestor_data.num_ancestors
-        self.epoch = self.ancestor_data.ancestors_time[:]
+        self.ancestors_dependency_level = {}
+        # Ancestors are chunked into groups of different dependency levels.
+        # Level 0 is the ultimate ancestor. Level 1 are all the ancestors which can only
+        # copy from ancestors at level 0. Level 2 are all the ancestors which can only
+        # copy from ancestors at levels 0 and 1, etc. All ancestors within the same
+        # level can be processed in parallel.
+
+        # Find the dependencies
+        anc_start = ancestor_data.ancestors_start[:]
+        anc_end = ancestor_data.ancestors_end[:]
+        anc_time = ancestor_data.ancestors_time[:]
+        dep_level = np.zeros(self.num_ancestors, dtype=int)
+        anc_iter = enumerate(zip(anc_start, anc_end, anc_time))
+        for epoch_time, epoch_grp in itertools.groupby(anc_iter, key=lambda x: x[1][2]):
+            curr_epoch_start = None
+            for anc_id, (lft, rgt, t) in epoch_grp:
+                if curr_epoch_start is None:
+                    curr_epoch_start = anc_id
+                assert epoch_time == t
+                # NB the line below is currently quite slow, and should be optimised
+                prev_ancestors = slice(0, curr_epoch_start)
+                dependencies = np.where(
+                    np.logical_and(
+                        anc_start[prev_ancestors] < rgt, anc_end[prev_ancestors] > lft
+                    )
+                )[0]
+                if len(dependencies) > 0:
+                    dep_level[anc_id] = np.max(dep_level[dependencies]) + 1
+            # One issue is that overlapping ancestors within the same epoch can be at
+            # different dep levels => force all ancestors in this epoch to the max level.
+            # Perhaps better to do this in a single pass rather than within this loop?
+            dep_level[curr_epoch_start : (anc_id + 1)] = np.max(
+                dep_level[curr_epoch_start : (anc_id + 1)]
+            )
+        for level in np.unique(dep_level):
+            if level > 0:  # Only run matching for ancestors that have dependencies
+                self.ancestors_dependency_level[level] = np.where(dep_level == level)[0]

         # Add nodes for all the ancestors so that the ancestor IDs are equal
         # to the node IDs.
-        for ancestor_id in range(self.num_ancestors):
-            self.tree_sequence_builder.add_node(self.epoch[ancestor_id])
-        self.ancestors = self.ancestor_data.ancestors()
-        # Consume the first ancestor.
-        a = next(self.ancestors, None)
-        self.num_epochs = 0
-        if a is not None:
-            # assert np.array_equal(a.haplotype, np.zeros(self.num_sites, dtype=np.int8))
-            # Create a list of all ID ranges in each epoch.
-            breaks = np.where(self.epoch[1:] != self.epoch[:-1])[0]
-            start = np.hstack([[0], breaks + 1])
-            end = np.hstack([breaks + 1, [self.num_ancestors]])
-            self.epoch_slices = np.vstack([start, end]).T
-            self.num_epochs = self.epoch_slices.shape[0]
-        self.start_epoch = 1
-
-    def __epoch_info_dict(self, epoch_index):
-        start, end = self.epoch_slices[epoch_index]
-        return collections.OrderedDict(
-            [("epoch", str(self.epoch[start])), ("nanc", str(end - start))]
-        )
+        for ancestor in self.ancestor_data.ancestors():
+            self.tree_sequence_builder.add_node(ancestor.time)

     def __ancestor_find_path(self, ancestor, thread_index=0):
         # NOTE we're no longer using the ancestor's focal sites as a way
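To make the level computation above concrete, here is a self-contained sketch of the same algorithm on made-up intervals: with ancestors sorted oldest first, an ancestor's level is one more than the maximum level among strictly older ancestors whose intervals overlap it, and each epoch is then forced up to its maximum level:

import itertools
import numpy as np

# Hypothetical ancestors sorted by time, oldest first: (start, end, time).
anc_start = np.array([0, 0, 3, 0, 4])
anc_end = np.array([10, 5, 8, 4, 9])
anc_time = np.array([3.0, 2.0, 2.0, 1.0, 1.0])

dep_level = np.zeros(len(anc_time), dtype=int)
anc_iter = enumerate(zip(anc_start, anc_end, anc_time))
for _, epoch_grp in itertools.groupby(anc_iter, key=lambda x: x[1][2]):
    curr_epoch_start = None
    for anc_id, (lft, rgt, _t) in epoch_grp:
        if curr_epoch_start is None:
            curr_epoch_start = anc_id
        # Only ancestors in strictly older epochs can be copied from.
        prev = slice(0, curr_epoch_start)
        deps = np.where((anc_start[prev] < rgt) & (anc_end[prev] > lft))[0]
        if len(deps) > 0:
            dep_level[anc_id] = dep_level[deps].max() + 1
    # Force every ancestor in this epoch up to the epoch's maximum level.
    dep_level[curr_epoch_start : anc_id + 1] = dep_level[
        curr_epoch_start : anc_id + 1
    ].max()

print(dep_level)  # [0 1 1 2 2]

Note that the matching loops below rely on levels being visited in ascending order; this holds because np.unique returns sorted levels and Python dicts preserve insertion order.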
@@ -1279,21 +1291,18 @@ def __ancestor_find_path(self, ancestor, thread_index=0):
         haplotype[start:end] = ancestor.haplotype
         self._find_path(ancestor.id, haplotype, start, end, thread_index)

-    def __start_epoch(self, epoch_index):
-        start, end = self.epoch_slices[epoch_index]
+    def __start_level(self, level, ancestor_ids):
         info = collections.OrderedDict(
-            [("epoch", str(self.epoch[start])), ("nanc", str(end - start))]
+            [("level", str(level)), ("nanc", str(len(ancestor_ids)))]
         )
         self.progress_monitor.set_detail(info)
         self.tree_sequence_builder.freeze_indexes()

-    def __complete_epoch(self, epoch_index):
-        start, end = map(int, self.epoch_slices[epoch_index])
-        num_ancestors_in_epoch = end - start
-        current_time = self.epoch[start]
+    def __complete_level(self, level, ancestor_ids):
+        num_ancestors_in_level = len(ancestor_ids)
         nodes_before = self.tree_sequence_builder.num_nodes
-
-        for child_id in range(start, end):
+        for child_id in ancestor_ids:
+            child_id = int(child_id)
             left, right, parent = self.results.get_path(child_id)
             self.tree_sequence_builder.add_path(
                 child_id,
@@ -1309,10 +1318,10 @@ def __complete_epoch(self, epoch_index):
         extra_nodes = self.tree_sequence_builder.num_nodes - nodes_before
         mean_memory = np.mean([matcher.total_memory for matcher in self.matcher])
         logger.debug(
-            "Finished epoch {} with {} ancestors; {} extra nodes inserted; "
+            "Finished level {} with {} ancestors; {} extra nodes inserted; "
             "mean_tb_size={:.2f} edges={}; mean_matcher_mem={}".format(
-                current_time,
-                num_ancestors_in_epoch,
+                level,
+                num_ancestors_in_level,
                 extra_nodes,
                 np.sum(self.mean_traceback_size) / np.sum(self.num_matches),
                 self.tree_sequence_builder.num_edges,
@@ -1324,16 +1333,13 @@ def __complete_epoch(self, epoch_index):
         self.results.clear()

     def __match_ancestors_single_threaded(self):
-        for j in range(self.start_epoch, self.num_epochs):
-            self.__start_epoch(j)
-            start, end = map(int, self.epoch_slices[j])
-            for ancestor_id in range(start, end):
-                a = next(self.ancestors)
-                assert ancestor_id == a.id
-                self.__ancestor_find_path(a)
-            self.__complete_epoch(j)
+        for level, ancestor_ids in self.ancestors_dependency_level.items():
+            self.__start_level(level, ancestor_ids)
+            for ancestor in self.ancestor_data.ancestors(indexes=ancestor_ids):
+                self.__ancestor_find_path(ancestor)
+            self.__complete_level(level, ancestor_ids)

-    def __match_ancestors_multi_threaded(self, start_epoch=1):
+    def __match_ancestors_multi_threaded(self):
         # See note on match samples multithreaded below. Should combine these
         # into a single function. Possibly when trying to make the thread
         # error handling more robust.
@@ -1357,16 +1363,13 @@ def match_worker(thread_index):
         ]
         logger.debug(f"Started {self.num_threads} match worker threads")

-        for j in range(self.start_epoch, self.num_epochs):
-            self.__start_epoch(j)
-            start, end = map(int, self.epoch_slices[j])
-            for ancestor_id in range(start, end):
-                a = next(self.ancestors)
-                assert a.id == ancestor_id
-                match_queue.put(a)
+        for level, ancestor_ids in self.ancestors_dependency_level.items():
+            self.__start_level(level, ancestor_ids)
+            for ancestor in self.ancestor_data.ancestors(indexes=ancestor_ids):
+                match_queue.put(ancestor)
             # Block until all matches have completed.
             match_queue.join()
-            self.__complete_epoch(j)
+            self.__complete_level(level, ancestor_ids)

         # Stop the worker threads.
         for _ in range(self.num_threads):
@@ -1375,7 +1378,11 @@ def match_worker(thread_index):
             match_threads[j].join()

     def match_ancestors(self):
-        logger.info(f"Starting ancestor matching for {self.num_epochs} epochs")
+        logger.info(
+            "Starting ancestor matching for {} dependency levels".format(
+                len(self.ancestors_dependency_level)
+            )
+        )
         self.match_progress = self.progress_monitor.get("ma_match", self.num_ancestors)
         if self.num_threads <= 0:
             self.__match_ancestors_single_threaded()
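For reference, the level-synchronised threading used above is the standard queue barrier pattern: enqueue one level's ancestors, let Queue.join() block until every item has been marked task_done() by a worker, and only then release the next level. A stripped-down sketch of that pattern, independent of tsinfer (worker count, level dict, and payloads are made up):

import queue
import threading

work_queue = queue.Queue()

def worker():
    while True:
        item = work_queue.get()
        if item is None:  # sentinel: shut this worker down
            break
        # ... find the copying path for one ancestor here ...
        work_queue.task_done()

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()

levels = {1: [1, 2, 3], 2: [4, 5]}  # hypothetical level -> ancestor IDs
for level, ancestor_ids in levels.items():
    for ancestor_id in ancestor_ids:
        work_queue.put(ancestor_id)
    work_queue.join()  # barrier: this level must finish before the next starts

for _ in threads:
    work_queue.put(None)
for t in threads:
    t.join()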