Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional_nodes DTWF/FIXED_PEDIGREE #2176

Merged
merged 1 commit into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
flag as well as record edges for coalescing nodes along non-coalescing segments. ({issue}`2128`, {issue}`2132`, {pr}`2162`,
{user}`GertjanBisschop`)

- Enable `additional_nodes` and `coalescing_segments_only` flags for `DTWF` and `FIXED_PEDIGREE` models ({issue}`2129`, {issue}`2133`, {issue}`2167`, {pr}`2176`,
{user}`GertjanBisschop`)

## [1.2.0] - 2022-05-18

**New features**
Expand Down
127 changes: 92 additions & 35 deletions algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,11 @@ def flush_edges(self):
self.tables.edges.add_row(left, right, parent, child)
self.edge_buffer = []

# Add `flag` to the flags of the already-recorded node `node_id`.
# The existing flags are preserved: the new value is the bitwise OR of the
# node's current flags and `flag`.
def update_node_flag(self, node_id, flag):
# Fetch the current row for this node from the node table.
node_obj = self.tables.nodes[node_id]
# Build a replacement row whose flags also include `flag`; the row is
# rebuilt via replace() rather than modified in place.
node_obj = node_obj.replace(flags=node_obj.flags | flag)
# Write the replacement row back at the same index.
self.tables.nodes[node_id] = node_obj

def store_edge(self, left, right, parent, child):
"""
Stores the specified edge to the output tree sequence.
Expand Down Expand Up @@ -1109,6 +1114,19 @@ def dtwf_generation(self):
"""
Evolves one generation of a Wright Fisher population
"""
# Migration events happen at the rates in the matrix.
for j in range(len(self.P)):
source_size = self.P[j].get_num_ancestors()
for k in range(len(self.P)):
if j == k:
continue
mig_rate = source_size * self.migration_matrix[j][k]
num_migs = min(source_size, np.random.poisson(mig_rate))
for _ in range(num_migs):
mig_source = j
mig_dest = k
self.migration_event(mig_source, mig_dest)

for pop_idx, pop in enumerate(self.P):
# Cluster haploid inds by parent
parent_inds = pop.get_ind_range(self.t)
Expand All @@ -1123,9 +1141,10 @@ def dtwf_generation(self):
# Draw recombinations in children and sort segments by
# inheritance direction
for children in offspring.values():
parent_nodes = [-1, -1]
H = [[], []]
for child in children:
segs_pair = self.dtwf_recombine(child)
segs_pair = self.dtwf_recombine(child, parent_nodes)
for seg in segs_pair:
if seg is not None and seg.index != child.index:
pop.add(seg)
Expand All @@ -1139,32 +1158,33 @@ def dtwf_generation(self):
heapq.heappush(H[i], (seg.left, seg))

# Merge segments
for h in H:
for ploid, h in enumerate(H):
segments_to_merge = len(h)
if segments_to_merge == 1:
if (
self.additional_nodes.value & msprime.NODE_IS_PASS_THROUGH
> 0
):
parent_nodes[ploid] = self.store_additional_nodes_edges(
msprime.NODE_IS_PASS_THROUGH,
parent_nodes[ploid],
h[0][1],
)
h = []
elif segments_to_merge >= 2:
for _, individual in h:
pop.remove_individual(individual)
# parent_nodes[ploid] does not need to be updated here
if segments_to_merge == 2:
self.merge_two_ancestors(pop_idx, 0, h[0][1], h[1][1])
self.merge_two_ancestors(
pop_idx, 0, h[0][1], h[1][1], parent_nodes[ploid]
)
else:
self.merge_ancestors(h, pop_idx, 0) # label 0 only
self.merge_ancestors(
h, pop_idx, 0, parent_nodes[ploid]
) # label 0 only
self.verify()

# Migration events happen at the rates in the matrix.
for j in range(len(self.P)):
source_size = self.P[j].get_num_ancestors()
for k in range(len(self.P)):
if j == k:
continue
mig_rate = source_size * self.migration_matrix[j][k]
num_migs = min(source_size, np.random.poisson(mig_rate))
for _ in range(num_migs):
mig_source = j
mig_dest = k
self.migration_event(mig_source, mig_dest)

def process_pedigree_common_ancestors(self, ind, ploid):
"""
Merge the ancestral material that has been inherited on this "ploid"
Expand Down Expand Up @@ -1212,7 +1232,7 @@ def process_pedigree_common_ancestors(self, ind, ploid):
# created by recombining between the parent's monoploid genomes
# to create two independent lines of ancestry.
parent = self.pedigree.individuals[ind.parents[ploid]]
parent_ancestry = self.dtwf_recombine(genome)
parent_ancestry = self.dtwf_recombine(genome, parent.nodes)
for parent_ploid in range(ind.ploidy):
seg = parent_ancestry[parent_ploid]
if seg is not None:
Expand Down Expand Up @@ -1249,8 +1269,8 @@ def dtwf_climb_pedigree(self):
for ploid in range(ind.ploidy):
self.process_pedigree_common_ancestors(ind, ploid)

def store_arg_edges(self, segment, u=None):
if u is None:
def store_arg_edges(self, segment, u=-1):
if u == -1:
u = len(self.tables.nodes) - 1
# Store edges pointing to current node to the left
x = segment
Expand Down Expand Up @@ -1619,7 +1639,7 @@ def dtwf_generate_breakpoint(self, start):
bp = math.floor(bp)
return bp

def dtwf_recombine(self, x):
def dtwf_recombine(self, x, ind_nodes):
"""
Chooses breakpoints and returns segments sorted by inheritance
direction, by iterating through segment chain starting with x
Expand Down Expand Up @@ -1695,6 +1715,16 @@ def dtwf_recombine(self, x):
v = s.next
self.free_segment(s)

if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0:
re_event = all(segment is not None for segment in [u, v])
if re_event:
for ploid, segment in enumerate([u, v]):
ind_nodes[ploid] = self.store_additional_nodes_edges(
msprime.NODE_IS_RE_EVENT,
ind_nodes[ploid],
segment,
)

return u, v

def census_event(self, time):
Expand Down Expand Up @@ -1722,10 +1752,21 @@ def bottleneck_event(self, pop_id, label, intensity):
heapq.heappush(H, (x.left, x))
self.merge_ancestors(H, pop_id, label)

# Record an "additional" node of type `flag` for the segment chain `z`, but
# only if that flag is enabled in self.additional_nodes.
#
# flag:        node flag to record (callers in this file pass
#              msprime.NODE_IS_PASS_THROUGH, NODE_IS_RE_EVENT or
#              NODE_IS_CA_EVENT).
# new_node_id: id of an existing node to reuse, or -1 to create a new one.
# z:           segment passed on to store_arg_edges; its `population` is used
#              when a new node has to be created.
#
# Returns the id of the node that was used or created. When the flag is not
# enabled nothing is recorded and `new_node_id` is returned unchanged
# (possibly still -1).
def store_additional_nodes_edges(self, flag, new_node_id, z):
# No-op unless the requested flag bit is set in additional_nodes.
if self.additional_nodes.value & flag > 0:
if new_node_id == -1:
# No node to reuse: create one carrying this flag and use the id that
# store_node returns.
new_node_id = self.store_node(z.population, flags=flag)
else:
# Reusing an existing node: flush the buffered edges first, then OR the
# flag into that node's existing flags.
# NOTE(review): presumably the flush keeps edges for this parent
# grouped in the edge table - confirm against flush_edges/store_edge.
self.flush_edges()
self.update_node_flag(new_node_id, flag)
# Store edges for z with this node passed as `u` to store_arg_edges.
self.store_arg_edges(z, new_node_id)
return new_node_id

def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
pop = self.P[pop_id]
defrag_required = False
coalescence = False
pass_through = len(H) == 1
alpha = None
z = None
merged_head = None
Expand Down Expand Up @@ -1754,9 +1795,11 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
alpha = x
alpha.next = None
else:
coalescence = True
if new_node_id == -1:
coalescence = True
new_node_id = self.store_node(pop_id)
else:
self.flush_edges()
# We must also break if the next left value is less than
# any of the right values in the current overlap set.
if left not in self.S:
Expand All @@ -1777,7 +1820,8 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
alpha = self.alloc_segment(left, right, new_node_id, pop_id)
# Update the heaps and make the record.
for x in X:
self.store_edge(left, right, new_node_id, x.node)
if x.node != new_node_id: # required for dtwf and fixed_pedigree
self.store_edge(left, right, new_node_id, x.node)
if x.right == right:
self.free_segment(x)
if x.next is not None:
Expand Down Expand Up @@ -1807,11 +1851,20 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
z = alpha
if coalescence:
if not self.coalescing_segments_only:
self.store_arg_edges(z)
self.store_arg_edges(z, new_node_id)
else:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
self.store_node(pop_id, flags=msprime.NODE_IS_CA_EVENT)
self.store_arg_edges(z)
if not pass_through:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
new_node_id = self.store_additional_nodes_edges(
msprime.NODE_IS_CA_EVENT, new_node_id, z
)
else:
if self.additional_nodes.value & msprime.NODE_IS_PASS_THROUGH > 0:
assert new_node_id != -1
assert self.model == "fixed_pedigree"
new_node_id = self.store_additional_nodes_edges(
msprime.NODE_IS_PASS_THROUGH, new_node_id, z
)

if defrag_required:
self.defrag_segment_chain(z)
Expand Down Expand Up @@ -1855,7 +1908,7 @@ def common_ancestor_event(self, population_index, label):
y = pop.remove(j, label)
self.merge_two_ancestors(population_index, label, x, y)

def merge_two_ancestors(self, population_index, label, x, y):
def merge_two_ancestors(self, population_index, label, x, y, u=-1):
pop = self.P[population_index]
self.num_ca_events += 1
z = None
Expand Down Expand Up @@ -1888,8 +1941,11 @@ def merge_two_ancestors(self, population_index, label, x, y):
else:
if not coalescence:
coalescence = True
self.store_node(population_index)
u = len(self.tables.nodes) - 1
if u == -1:
self.store_node(population_index)
u = len(self.tables.nodes) - 1
else:
self.flush_edges()
# Put in breakpoints for the outer edges of the coalesced
# segment
left = x.left
Expand All @@ -1916,8 +1972,10 @@ def merge_two_ancestors(self, population_index, label, x, y):
population=population_index,
label=label,
)
self.store_edge(left, right, u, x.node)
self.store_edge(left, right, u, y.node)
if x.node != u: # required for dtwf and fixed_pedigree
self.store_edge(left, right, u, x.node)
if y.node != u: # required for dtwf and fixed_pedigree
self.store_edge(left, right, u, y.node)
# Now trim the ends of x and y to the right sizes.
if x.right == right:
self.free_segment(x)
Expand Down Expand Up @@ -1953,8 +2011,7 @@ def merge_two_ancestors(self, population_index, label, x, y):
self.store_arg_edges(z, u)
else:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
u = self.store_node(population_index, flags=msprime.NODE_IS_CA_EVENT)
self.store_arg_edges(z, u)
self.store_additional_nodes_edges(msprime.NODE_IS_CA_EVENT, u, z)

if defrag_required:
self.defrag_segment_chain(z)
Expand Down Expand Up @@ -2293,7 +2350,7 @@ def add_simulator_arguments(parser):
help="Only record edges along coalescing segments.",
)
parser.add_argument(
"--additional_nodes",
"--additional-nodes",
type=int,
default=0,
help="Record edges along all segments for coalescing nodes.",
Expand Down