Skip to content

Commit

Permalink
Merge pull request #1063 from petrelharp/mask_intervals
Browse files Browse the repository at this point in the history
mask_intervals utility.
  • Loading branch information
andrewkern committed Oct 28, 2021
2 parents 9454010 + 5988f1f commit 06f5a7e
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
36 changes: 36 additions & 0 deletions stdpopsim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,39 @@ def build_intervals_array(intervals, start=0, end=np.inf):
intervals = intervals[sorter]
check_intervals_validity(intervals, start, end)
return intervals


def mask_intervals(intervals, mask):
"""
Removes the intervals of ``mask`` from those in ``intervals``.
Both ``intervals`` and ``mask`` should be sets of intervals
(i.e., sorted numpy arrays with two columns); then the result will
be the the intervals in ``intervals`` except with those in
``mask`` removed.
"""
check_intervals_validity(intervals)
check_intervals_validity(mask)
out = []
last_mask_right = -1 * np.inf
j = 0
if j < len(mask):
next_mask_left, next_mask_right = mask[j]
else:
# the mask is of zero length, return for efficiency
return intervals
for inter in intervals:
left, right = inter
while left < right:
while left >= next_mask_left:
last_mask_right = next_mask_right
j += 1
if j < len(mask):
next_mask_left, next_mask_right = mask[j]
else:
next_mask_left, next_mask_right = np.inf, np.inf
next_left = max(left, last_mask_right)
next_right = min(right, next_mask_left)
if next_left < next_right:
out.append([next_left, next_right])
left = max(next_left, next_right)
return np.array(out).reshape((len(out), 2))
52 changes: 52 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tests for the utils module.
"""
import functools
import os
import pathlib
import tarfile
Expand Down Expand Up @@ -497,3 +498,54 @@ def test_interval_sorting(self):
casted = utils.build_intervals_array(intervals)
assert not (np.all(np.diff(intervals[:, 0]) >= 0))
assert np.all(np.diff(casted[:, 0]) >= 0)

def intervals_to_set(self, intervals):
# note: we may want to use intervaltree for more complicated operations
# in the future
utils.check_intervals_validity(intervals)
return functools.reduce(
lambda a, b: a | b, [set(range(a, b)) for a, b in intervals], set()
)

def test_mask_intervals(self):
empty_array = np.array([]).reshape((0, 2))
for a, b in [
(empty_array, [[1, 2]]),
(
[[0, 5], [5, 10]],
[[0, 10]],
),
(
[[3, 4], [5, 12]],
[[0, 10]],
),
(
[[0, 3], [3, 4], [5, 9]],
[[1, 10]],
),
(
[[k, k + 1] for k in range(12)],
[[0, 10]],
),
(
[[1, 4], [6, 9], [11, 13], [14, 17], [19, 21]],
[[0, 3], [5, 9], [11, 13], [15, 17], [20, 22]],
),
(
[[1, 2], [6, 7], [11, 12], [14, 15], [19, 20], [20, 21], [22, 23]],
[[0, 3], [5, 7], [11, 12], [14, 16], [19, 23]],
),
]:
for u, v in [(a, b), (b, a), (a, empty_array), (empty_array, a)]:
u = np.array(u)
v = np.array(v)
umv = utils.mask_intervals(u, v)
x = self.intervals_to_set(umv)
y = self.intervals_to_set(u) - self.intervals_to_set(v)
assert x == y

def test_mask_intervals_errors(self):
with pytest.raises(ValueError):
utils.mask_intervals(intervals=np.array([[50, 10]]), mask=np.array([[]]))
with pytest.raises(ValueError):
utils.mask_intervals(intervals=np.array([[]]), mask=np.array([[10, 5]]))

0 comments on commit 06f5a7e

Please sign in to comment.