Skip to content

Commit

Permalink
unary nodes DTWF/fixed_pedigree
Browse files Browse the repository at this point in the history
  • Loading branch information
GertjanBisschop authored and mergify[bot] committed May 23, 2023
1 parent 5583e5b commit 461dfd2
Show file tree
Hide file tree
Showing 12 changed files with 904 additions and 168 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
flag as well as record edges for coalescing nodes along non-coalescing segments. ({issue}`2128`, {issue}`2132`, {pr}`2162`,
{user}`GertjanBisschop`)

- Enable `additional_nodes` and `coalescing_segments_only` flags for `DTWF` and `FIXED_PEDIGREE` models ({issue}`2129`, {issue}`2133`, {issue}`2167`, {pr}`2176`,
{user}`GertjanBisschop`)

## [1.2.0] - 2022-05-18

**New features**
Expand Down
127 changes: 92 additions & 35 deletions algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,11 @@ def flush_edges(self):
self.tables.edges.add_row(left, right, parent, child)
self.edge_buffer = []

def update_node_flag(self, node_id, flag):
    """
    Adds the specified flag bits to the node with the specified ID.

    The row at ``node_id`` in the node table is read, its ``flags`` value
    is bitwise-ORed with ``flag``, and the resulting row is written back
    to the same position. Bits that were already set stay set.
    """
    current_row = self.tables.nodes[node_id]
    combined_flags = current_row.flags | flag
    self.tables.nodes[node_id] = current_row.replace(flags=combined_flags)

def store_edge(self, left, right, parent, child):
"""
Stores the specified edge to the output tree sequence.
Expand Down Expand Up @@ -1109,6 +1114,19 @@ def dtwf_generation(self):
"""
Evolves one generation of a Wright Fisher population
"""
# Migration events happen at the rates in the matrix.
for j in range(len(self.P)):
source_size = self.P[j].get_num_ancestors()
for k in range(len(self.P)):
if j == k:
continue
mig_rate = source_size * self.migration_matrix[j][k]
num_migs = min(source_size, np.random.poisson(mig_rate))
for _ in range(num_migs):
mig_source = j
mig_dest = k
self.migration_event(mig_source, mig_dest)

for pop_idx, pop in enumerate(self.P):
# Cluster haploid inds by parent
parent_inds = pop.get_ind_range(self.t)
Expand All @@ -1123,9 +1141,10 @@ def dtwf_generation(self):
# Draw recombinations in children and sort segments by
# inheritance direction
for children in offspring.values():
parent_nodes = [-1, -1]
H = [[], []]
for child in children:
segs_pair = self.dtwf_recombine(child)
segs_pair = self.dtwf_recombine(child, parent_nodes)
for seg in segs_pair:
if seg is not None and seg.index != child.index:
pop.add(seg)
Expand All @@ -1139,32 +1158,33 @@ def dtwf_generation(self):
heapq.heappush(H[i], (seg.left, seg))

# Merge segments
for h in H:
for ploid, h in enumerate(H):
segments_to_merge = len(h)
if segments_to_merge == 1:
if (
self.additional_nodes.value & msprime.NODE_IS_PASS_THROUGH
> 0
):
parent_nodes[ploid] = self.store_additional_nodes_edges(
msprime.NODE_IS_PASS_THROUGH,
parent_nodes[ploid],
h[0][1],
)
h = []
elif segments_to_merge >= 2:
for _, individual in h:
pop.remove_individual(individual)
# parent_nodes[ploid] does not need to be updated here
if segments_to_merge == 2:
self.merge_two_ancestors(pop_idx, 0, h[0][1], h[1][1])
self.merge_two_ancestors(
pop_idx, 0, h[0][1], h[1][1], parent_nodes[ploid]
)
else:
self.merge_ancestors(h, pop_idx, 0) # label 0 only
self.merge_ancestors(
h, pop_idx, 0, parent_nodes[ploid]
) # label 0 only
self.verify()

# Migration events happen at the rates in the matrix.
for j in range(len(self.P)):
source_size = self.P[j].get_num_ancestors()
for k in range(len(self.P)):
if j == k:
continue
mig_rate = source_size * self.migration_matrix[j][k]
num_migs = min(source_size, np.random.poisson(mig_rate))
for _ in range(num_migs):
mig_source = j
mig_dest = k
self.migration_event(mig_source, mig_dest)

def process_pedigree_common_ancestors(self, ind, ploid):
"""
Merge the ancestral material that has been inherited on this "ploid"
Expand Down Expand Up @@ -1212,7 +1232,7 @@ def process_pedigree_common_ancestors(self, ind, ploid):
# created by recombining between the parent's monoploid genomes
# to create two independent lines of ancestry.
parent = self.pedigree.individuals[ind.parents[ploid]]
parent_ancestry = self.dtwf_recombine(genome)
parent_ancestry = self.dtwf_recombine(genome, parent.nodes)
for parent_ploid in range(ind.ploidy):
seg = parent_ancestry[parent_ploid]
if seg is not None:
Expand Down Expand Up @@ -1249,8 +1269,8 @@ def dtwf_climb_pedigree(self):
for ploid in range(ind.ploidy):
self.process_pedigree_common_ancestors(ind, ploid)

def store_arg_edges(self, segment, u=None):
if u is None:
def store_arg_edges(self, segment, u=-1):
if u == -1:
u = len(self.tables.nodes) - 1
# Store edges pointing to current node to the left
x = segment
Expand Down Expand Up @@ -1619,7 +1639,7 @@ def dtwf_generate_breakpoint(self, start):
bp = math.floor(bp)
return bp

def dtwf_recombine(self, x):
def dtwf_recombine(self, x, ind_nodes):
"""
Chooses breakpoints and returns segments sorted by inheritance
direction, by iterating through segment chain starting with x
Expand Down Expand Up @@ -1695,6 +1715,16 @@ def dtwf_recombine(self, x):
v = s.next
self.free_segment(s)

if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0:
re_event = all(segment is not None for segment in [u, v])
if re_event:
for ploid, segment in enumerate([u, v]):
ind_nodes[ploid] = self.store_additional_nodes_edges(
msprime.NODE_IS_RE_EVENT,
ind_nodes[ploid],
segment,
)

return u, v

def census_event(self, time):
Expand Down Expand Up @@ -1722,10 +1752,21 @@ def bottleneck_event(self, pop_id, label, intensity):
heapq.heappush(H, (x.left, x))
self.merge_ancestors(H, pop_id, label)

def store_additional_nodes_edges(self, flag, new_node_id, z):
    """
    Records an additional node of type ``flag`` above the segment chain
    ``z``, provided that ``flag`` is enabled in ``self.additional_nodes``.

    If ``new_node_id`` is -1 a new node carrying ``flag`` is stored in
    ``z.population``. Otherwise the buffered edges are flushed and ``flag``
    is added to the flags of the existing node ``new_node_id``. In both
    cases edges are then stored via ``store_arg_edges`` for ``z`` and that
    node.

    Returns the ID of the node that was used. When ``flag`` is not enabled
    nothing is recorded and ``new_node_id`` is returned unchanged.
    """
    flag_enabled = (self.additional_nodes.value & flag) > 0
    if not flag_enabled:
        return new_node_id

    if new_node_id == -1:
        node = self.store_node(z.population, flags=flag)
    else:
        # Reusing an existing node: write out pending edges before
        # touching the node's flags.
        self.flush_edges()
        self.update_node_flag(new_node_id, flag)
        node = new_node_id
    self.store_arg_edges(z, node)
    return node

def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
pop = self.P[pop_id]
defrag_required = False
coalescence = False
pass_through = len(H) == 1
alpha = None
z = None
merged_head = None
Expand Down Expand Up @@ -1754,9 +1795,11 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
alpha = x
alpha.next = None
else:
coalescence = True
if new_node_id == -1:
coalescence = True
new_node_id = self.store_node(pop_id)
else:
self.flush_edges()
# We must also break if the next left value is less than
# any of the right values in the current overlap set.
if left not in self.S:
Expand All @@ -1777,7 +1820,8 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
alpha = self.alloc_segment(left, right, new_node_id, pop_id)
# Update the heaps and make the record.
for x in X:
self.store_edge(left, right, new_node_id, x.node)
if x.node != new_node_id: # required for dtwf and fixed_pedigree
self.store_edge(left, right, new_node_id, x.node)
if x.right == right:
self.free_segment(x)
if x.next is not None:
Expand Down Expand Up @@ -1807,11 +1851,20 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
z = alpha
if coalescence:
if not self.coalescing_segments_only:
self.store_arg_edges(z)
self.store_arg_edges(z, new_node_id)
else:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
self.store_node(pop_id, flags=msprime.NODE_IS_CA_EVENT)
self.store_arg_edges(z)
if not pass_through:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
new_node_id = self.store_additional_nodes_edges(
msprime.NODE_IS_CA_EVENT, new_node_id, z
)
else:
if self.additional_nodes.value & msprime.NODE_IS_PASS_THROUGH > 0:
assert new_node_id != -1
assert self.model == "fixed_pedigree"
new_node_id = self.store_additional_nodes_edges(
msprime.NODE_IS_PASS_THROUGH, new_node_id, z
)

if defrag_required:
self.defrag_segment_chain(z)
Expand Down Expand Up @@ -1855,7 +1908,7 @@ def common_ancestor_event(self, population_index, label):
y = pop.remove(j, label)
self.merge_two_ancestors(population_index, label, x, y)

def merge_two_ancestors(self, population_index, label, x, y):
def merge_two_ancestors(self, population_index, label, x, y, u=-1):
pop = self.P[population_index]
self.num_ca_events += 1
z = None
Expand Down Expand Up @@ -1888,8 +1941,11 @@ def merge_two_ancestors(self, population_index, label, x, y):
else:
if not coalescence:
coalescence = True
self.store_node(population_index)
u = len(self.tables.nodes) - 1
if u == -1:
self.store_node(population_index)
u = len(self.tables.nodes) - 1
else:
self.flush_edges()
# Put in breakpoints for the outer edges of the coalesced
# segment
left = x.left
Expand All @@ -1916,8 +1972,10 @@ def merge_two_ancestors(self, population_index, label, x, y):
population=population_index,
label=label,
)
self.store_edge(left, right, u, x.node)
self.store_edge(left, right, u, y.node)
if x.node != u: # required for dtwf and fixed_pedigree
self.store_edge(left, right, u, x.node)
if y.node != u: # required for dtwf and fixed_pedigree
self.store_edge(left, right, u, y.node)
# Now trim the ends of x and y to the right sizes.
if x.right == right:
self.free_segment(x)
Expand Down Expand Up @@ -1953,8 +2011,7 @@ def merge_two_ancestors(self, population_index, label, x, y):
self.store_arg_edges(z, u)
else:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
u = self.store_node(population_index, flags=msprime.NODE_IS_CA_EVENT)
self.store_arg_edges(z, u)
self.store_additional_nodes_edges(msprime.NODE_IS_CA_EVENT, u, z)

if defrag_required:
self.defrag_segment_chain(z)
Expand Down Expand Up @@ -2293,7 +2350,7 @@ def add_simulator_arguments(parser):
help="Only record edges along coalescing segments.",
)
parser.add_argument(
"--additional_nodes",
"--additional-nodes",
type=int,
default=0,
help="Record edges along all segments for coalescing nodes.",
Expand Down

0 comments on commit 461dfd2

Please sign in to comment.