Skip to content

Commit

Permalink
Merge pull request #6 from seung-lab/wms_2d_ops
Browse files Browse the repository at this point in the history
feat: 2D operations for multilabel dilate and erode + fix
  • Loading branch information
william-silversmith committed Jun 22, 2024
2 parents d2340fd + 4dbb50e commit 54cd4e1
Show file tree
Hide file tree
Showing 4 changed files with 716 additions and 24 deletions.
105 changes: 103 additions & 2 deletions automated_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_spherical_close():
assert res[5,5,5] == True


def test_multilabel_dilate():
def test_multilabel_dilate_3d():
labels = np.zeros((3,3,3), dtype=bool)

out = fastmorph.dilate(labels)
Expand Down Expand Up @@ -116,8 +116,46 @@ def test_multilabel_dilate():
ans[1:,:,:] = 2
assert np.all(ans == out)

def test_multilabel_dilate_2d():
labels = np.zeros((3,3), dtype=bool)

def test_multilabel_erode():
out = fastmorph.dilate(labels)
assert not np.any(out)

labels[1,1] = True

out = fastmorph.dilate(labels)
assert np.all(out)

labels = np.zeros((3,3), dtype=bool)
labels[0,0] = True
out = fastmorph.dilate(labels)

ans = np.zeros((3,3), dtype=bool)
ans[:2,:2] = True

assert np.all(out == ans)

labels = np.zeros((3,3), dtype=int)
labels[0,1] = 1
labels[2,1] = 2

out = fastmorph.dilate(labels)
ans = np.ones((3,3), dtype=int)
ans[2,:] = 2
assert np.all(ans == out)

labels = np.zeros((3,3), dtype=int, order="F")
labels[0,1] = 1
labels[1,1] = 2
labels[2,1] = 2

out = fastmorph.dilate(labels)
ans = np.ones((3,3), dtype=int, order="F")
ans[1:,:] = 2
assert np.all(ans == out)

def test_multilabel_erode_3d():
labels = np.ones((3,3,3), dtype=bool)
out = fastmorph.erode(labels)
assert np.sum(out) == 1 and out[1,1,1] == True
Expand Down Expand Up @@ -147,11 +185,54 @@ def test_multilabel_erode():
out = fastmorph.erode(labels)
assert np.sum(out) == 27

def test_multilabel_erode_2d():
labels = np.ones((3,3), dtype=bool)
out = fastmorph.erode(labels)
assert np.sum(out) == 1 and out[1,1] == True

out = fastmorph.erode(out)
assert not np.any(out)

out = fastmorph.erode(out)
assert not np.any(out)

labels = np.ones((3,3), dtype=int, order="F")
labels[0,:] = 1
labels[1,:] = 2
labels[2,:] = 3

out = fastmorph.erode(labels)
ans = np.zeros((3,3), dtype=int, order="F")

assert np.all(ans == out)

labels = np.zeros((5,5), dtype=bool)
labels[1:4,1:4] = True
out = fastmorph.erode(labels)
assert np.sum(out) == 1 and out[2,2] == True

labels = np.ones((5,5), dtype=bool)
out = fastmorph.erode(labels)
assert np.sum(out) == 9

@pytest.mark.parametrize('dtype', [
np.uint8,np.uint16,np.uint32,np.uint64,
np.int8,np.int16,np.int32,np.int64,
])
def test_grey_erode(dtype):
labels = np.arange(9, dtype=dtype).reshape((3,3), order="F")
out = fastmorph.erode(labels, mode=fastmorph.Mode.grey)

ans = np.array([
[0, 0, 1],
[0, 0, 1],
[3, 3, 4],
], dtype=dtype).T
assert np.all(out == ans)

out = fastmorph.erode(out, mode=fastmorph.Mode.grey)
assert np.all(out == 0)

labels = np.arange(27, dtype=dtype).reshape((3,3,3), order="F")
out = fastmorph.erode(labels, mode=fastmorph.Mode.grey)

Expand Down Expand Up @@ -186,6 +267,20 @@ def test_grey_dilate(dtype):
L = 5
H = 10

labels = np.zeros((3,3), dtype=dtype)
labels[0,0] = L
labels[2,2] = H

out = fastmorph.dilate(labels, mode=fastmorph.Mode.grey)

ans = np.array([
[L, L, 0],
[L, H, H],
[0, H, H],
], dtype=dtype).T

assert np.all(out == ans)

labels = np.zeros((3,3,3), dtype=dtype)
labels[0,0,0] = L
labels[2,2,2] = H
Expand Down Expand Up @@ -222,4 +317,10 @@ def test_grey_dilate_bool():
out = fastmorph.dilate(labels, mode=fastmorph.Mode.grey)
assert np.all(out == True)

labels = np.zeros((3,3), dtype=bool)
labels[1,1] = True

out = fastmorph.dilate(labels, mode=fastmorph.Mode.grey)
assert np.all(out == True)


7 changes: 4 additions & 3 deletions fastmorph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ def dilate(
parallel = min(parallel, mp.cpu_count())

labels = np.asfortranarray(labels)
while labels.ndim < 3:
while labels.ndim < 2:
labels = labels[..., np.newaxis]

if mode == Mode.multilabel:
output = fastmorphops.multilabel_dilate(labels, background_only, parallel)
else:
Expand All @@ -66,7 +67,7 @@ def erode(
parallel = min(parallel, mp.cpu_count())

labels = np.asfortranarray(labels)
while labels.ndim < 3:
while labels.ndim < 2:
labels = labels[..., np.newaxis]

if mode == Mode.multilabel:
Expand Down Expand Up @@ -264,7 +265,7 @@ def fill_holes(

if return_fill_count:
for label in removed_set:
del fill_counts[label]
fill_counts.pop(label, None)
ret.append(fill_counts)

if return_removed:
Expand Down
Loading

0 comments on commit 54cd4e1

Please sign in to comment.