Skip to content

Commit

Permalink
Merge 55fccff into 55252ef
Browse files Browse the repository at this point in the history
  • Loading branch information
willgraf committed Jul 31, 2019
2 parents 55252ef + 55fccff commit 5063db3
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 1 deletion.
1 change: 1 addition & 0 deletions data_processing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from data_processing.processing import deepcell
from data_processing.processing import mibi
from data_processing.processing import watershed
from data_processing.processing import retinanet_to_label_image

del absolute_import
del division
Expand Down
160 changes: 160 additions & 0 deletions data_processing/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@

import numpy as np
from scipy import ndimage
from scipy.ndimage.morphology import distance_transform_edt
from skimage import morphology
from skimage.feature import peak_local_max
from skimage.measure import label
from skimage.transform import resize
from skimage.segmentation import random_walker, relabel_sequential


def noramlize(image):
Expand Down Expand Up @@ -174,3 +177,160 @@ def deepcell(prediction, threshold=.8):
labeled = morphology.remove_small_objects(
labeled, min_size=50, connectivity=1)
return labeled


def retinanet_to_label_image(retinanet_outputs,
score_threshold=0.5,
multi_iou_threshold=0.25,
binarize_threshold=0.5,
watershed_threshold=0.5,
small_objects_threshold=100):

def compute_iou(a, b):
"""Computes the IoU overlap of boxes in a and b.
Args:
a: (N, H, W) ndarray of float
b: (K, H, W) ndarray of float
Returns
overlaps: (N, K) ndarray of overlap between boxes and query_boxes
"""
intersection = np.zeros((a.shape[0], b.shape[0]))
union = np.zeros((a.shape[0], b.shape[0]))
for index, mask in enumerate(a):
intersection[index, :] = np.sum(np.count_nonzero(
np.logical_and(b, mask), axis=1), axis=1)
union[index, :] = np.sum(np.count_nonzero(b + mask, axis=1), axis=1)

return intersection / union

boxes_batch = retinanet_outputs[-5]
scores_batch = retinanet_outputs[-4]
labels_batch = retinanet_outputs[-3]
masks_batch = retinanet_outputs[-2]
semantic_batch = retinanet_outputs[-1]

# Create empty label matrix
label_images = np.zeros(
(masks_batch.shape[0], semantic_batch.shape[1], semantic_batch.shape[2]))

# Iterate over batches
for i in range(boxes_batch.shape[0]):
boxes = boxes_batch[i]
scores = scores_batch[i]
labels = labels_batch[i]
masks = masks_batch[i]
semantic = semantic_batch[i]

# Get good detections
selection = np.nonzero(scores > score_threshold)[0]
boxes = boxes[selection]
scores = scores[selection]
labels = labels[selection]
masks = masks[selection, ..., -1]

# Compute overlap of masks with each other
mask_image = np.zeros((masks.shape[0], semantic.shape[0],
semantic.shape[1]), dtype='float32')

for j in range(masks.shape[0]):
mask = masks[j]
box = boxes[j].astype(int)
mask = resize(mask, (box[3] - box[1], box[2] - box[0]))
mask = (mask > binarize_threshold).astype('float32')
mask_image[j, box[1]:box[3], box[0]:box[2]] = mask

ious = compute_iou(mask_image, mask_image)

# Identify all the masks with no overlaps and
# add to the label matrix
summed_ious = np.sum(ious, axis=-1)
no_overlaps = np.where(summed_ious == 1)

masks_no_overlaps = mask_image[no_overlaps]
range_no_overlaps = np.arange(1, masks_no_overlaps.shape[0] + 1)
masks_no_overlaps *= np.expand_dims(
np.expand_dims(range_no_overlaps, axis=-1), axis=-1)

masks_concat = masks_no_overlaps

# If a mask has a big iou with two other masks, remove it
overlaps = np.where(summed_ious > 1)
bad_mask = np.sum(ious > multi_iou_threshold, axis=0)
good_overlaps = np.logical_and(summed_ious > 1, bad_mask < 3)
good_overlaps = np.where(good_overlaps == 1)

# Identify all the ambiguous pixels and resolve
# by performing marker based watershed using unambiguous
# pixels as the markers
masks_overlaps = mask_image[good_overlaps]
range_overlaps = np.arange(1, masks_overlaps.shape[0] + 1)
masks_overlaps_label = masks_overlaps * np.expand_dims(
np.expand_dims(range_overlaps, axis=-1), axis=-1)

masks_overlaps_sum = np.sum(masks_overlaps, axis=0)
ambiguous_pixels = np.where(masks_overlaps_sum > 1)
markers = np.sum(masks_overlaps_label, axis=0)

if np.sum(markers.flatten()) > 0:
markers[markers == 0] = -1
markers[ambiguous_pixels] = 0

foreground = masks_overlaps_sum > 0
segments = random_walker(foreground, markers)

masks_overlaps = np.zeros((np.amax(segments).astype(int),
masks_overlaps.shape[1],
masks_overlaps.shape[2]))

for j in range(1, masks_overlaps.shape[0] + 1):
masks_overlaps[j - 1] = segments == j

range_overlaps = np.arange(
masks_no_overlaps.shape[0] + 1,
masks_no_overlaps.shape[0] + masks_overlaps.shape[0] + 1)

masks_overlaps *= np.expand_dims(
np.expand_dims(range_overlaps, axis=-1), axis=-1)
masks_concat = np.concatenate([masks_concat, masks_overlaps], axis=0)

# Find peaks in watershed that are not within any
# box and perform watershed
semantic_argmax = np.argmax(semantic, axis=-1)
semantic_argmax *= np.sum(masks_concat, axis=0) < 1
foreground = semantic_argmax > 0

inner_most = semantic[..., -1] * (np.sum(masks_concat, axis=0) < 1)
local_maxi = inner_most > watershed_threshold

if np.sum(local_maxi.flatten()) > 0:
markers_semantic = label(local_maxi)
distance = semantic_argmax
segments_semantic = morphology.watershed(
-distance, markers_semantic, mask=foreground)
masks_semantic = np.zeros((np.amax(segments_semantic).astype(int),
semantic.shape[0], semantic.shape[1]))
for j in range(1, masks_semantic.shape[0] + 1):
masks_semantic[j - 1] = segments_semantic == j

range_semantic = np.arange(
masks_no_overlaps.shape[0] + masks_overlaps.shape[0] + 1,
masks_no_overlaps.shape[0] + masks_overlaps.shape[0] +
masks_semantic.shape[0] + 1)
masks_semantic *= np.expand_dims(
np.expand_dims(range_semantic, axis=-1), axis=-1)

masks_concat = np.concatenate([masks_concat, masks_semantic], axis=0)

label_image = np.sum(masks_concat, axis=0).astype(int)

# Remove small objects
label_image = morphology.remove_small_objects(
label_image, min_size=small_objects_threshold)

# Relabel the label image
label_image, _, _ = relabel_sequential(label_image)

# Store in batched array
label_images[i] = label_image

return label_images
53 changes: 53 additions & 0 deletions data_processing/processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from __future__ import print_function

import numpy as np
from skimage.measure import regionprops

from data_processing import processing

Expand All @@ -40,6 +41,51 @@ def _get_image(img_h=300, img_w=300):
return img


def _sample1(w, h, imw, imh):
"""Basic single cell synthetic sample"""
x = np.random.randint(0, imw - w * 2)
y = np.random.randint(0, imh - h * 2)

im = np.zeros((imw, imh))
im[x:x + w, y:y + h] = 1

# Randomly rotate to pick horizontal or vertical
if np.random.random() > 0.5:
im = np.rot90(im)

return im


def _retinanet_data(im):
n_batch = 1
n_det = 1
mask_size = 14 # Is this correct?
n_labels = 1

# boxes
rp = regionprops(im.astype(int))[0].bbox
boxes = np.zeros((n_batch, n_det, 4))
boxes[0, 0, :] = rp

# scores
scores = np.zeros((n_batch, n_det, 1))
scores[0, 0, 0] = np.random.rand()

# labels
labels = np.zeros((n_batch, n_det, n_labels))
labels[0, 0, 0] = 1

# masks
masks = np.ones((n_batch, n_det, mask_size, mask_size))

# semantic
semantic = np.zeros((n_batch, im.shape[0], im.shape[1], 4))
semantic[:, :, :] = processing.watershed(np.reshape(
im, (1, im.shape[0], im.shape[1], 1)))

return [boxes, scores, labels, masks, semantic]


def test_normalize():
height, width = 300, 300
img = _get_image(height, width)
Expand Down Expand Up @@ -67,3 +113,10 @@ def test_watershed():
img = np.random.rand(300, 300, channels)
watershed_img = processing.watershed(img)
np.testing.assert_equal(watershed_img.shape, (300, 300, 1))


def test_retinanet():
im = _sample1(10, 10, 40, 40)
out = _retinanet_data(im)

label = processing.retinanet_to_label_image(out)
3 changes: 2 additions & 1 deletion data_processing/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
'post': {
'deepcell': processing.deepcell,
'mibi': processing.mibi,
'watershed': processing.watershed
'watershed': processing.watershed,
'retinanet': processing.retinanet_to_label_image
},
}

0 comments on commit 5063db3

Please sign in to comment.