diff --git a/README.md b/README.md
index 6168a0f..1449722 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,139 @@
-# segmentify
+# Segmentify
-Python image segmentation plugin
+Segmentify is an interactive and general purpose cell segmentation plugin for the multi-dimentional image viewer [Napari](https://github.com/napari/napari).
+
+

+
+In the example above, the user is using segmentify to segment out the nucleus, cytoplasm and the background for all the cells in the image. The user uses the paint brush tool to label some training examples on the **trian layer**. The entire image is featurized using a pretrained featurizer, and the selected examples are used to train a Random Forest Classifier. Lastly, the trained classifier is used to predict the label for all the remaining unlabeled pixels, and the segmentation output is displayed at the **output layer**.
+
+## Installation
+
+Segmentify can be installed using **pip**
+
+```
+pip install segmentify
+```
+
+## Launching Segmentify Viewer
+
+Segmentify's viewer can be launched either using the command line interface or a python script.
+
+
+### Segmentify Command Line Interface
+
+The Segmentify Viewer can be launched from the command line by entering:
+
+```
+segmentify
+```
+
+### Segmentify Script
+
+To open Segmentify's Viewer from a python script, enter the following scripts:
+
+```python
+from segmentify import Viewer, gui_qt, util
+
+img = util.parse_img()
+
+with gui_qt():
+ viewer = Viewer(img)
+```
+
+An example can be found [here](./examples/viewer_example.py)
+
+## Featurization
+
+Segmentify works by featurizing input images using either pretrained UNet models or some classical image filters. Users can then use the paint brush to label some of the pixels to train a Random Forest Classifier. The trained classifier is used to predict the label for all remaining unlabeled pixels. The training labels should be provided in the **train layer** and the segmentation will be displayed on the **output layer**. Segmentify includes the following featurization strategies:
+- HPA\_4: Trained by decomposing images from the [Human Protein Atlas](https://www.kaggle.com/c/human-protein-atlas-image-classification) into Nucleus, ER, Protein and Cytoplasm
+- HPA\_3: Trained by decomposing images from the [Human Protein Atlas](https://www.kaggle.com/c/human-protein-atlas-image-classification) into Nucleus, ER, Cytoplasm
+- HPA: Trained by segmenting out the nucleus in images from the [Human Protein Atlas](https://www.kaggle.com/c/human-protein-atlas-image-classification)
+
+- Nuclei: Trained by segmenting out the nucleus in images from the [Kaggle Data Science Bowl 2018](https://www.kaggle.com/c/data-science-bowl-2018/overview)
+- Filter: Combining several classical image filter (Gaussian, Sobel, Laplace, Gabor, Canny)
+
+For more information about the training strategies for each featurizer, please refer to the notebooks [here](https://github.com/marshuang80/CellSegmentation/tree/master/notebooks)
+
+### Training your own featurizer
+
+Users can also train their own image featurizers using [Cell Segmentation](https://github.com/marshuang80/CellSegmentation). The trained model should be placed in the *./segmentify/model/saved_model* directory.
+
+## Key Bindings
+
+The following is a list of Segmentify supported functions along with their key bindings:
+- Segmentation (Shift-S)
+- Uncertainty Heatmap (Shift-H)
+- Next Featurizer (Shif-N)
+- Dilation (Shift-D)
+- Erosion (Shift-E)
+- Close (Shift-C)
+- Open (Shift-O)
+- Fill holes (Shift-F)
+- Next featurizer (Shift-N)
+
+> For all image morphology steps (dilation, erosion, close, open, fill holes), **the operations will only be applied to the selected label for the selected layer**. For example, if you want the blue labels in the segmented images to dilate, you will have to click on the *output layer* and increment to the desired label.
+
+| Morphology on labels | Morphology on segmentation |
+| --- | --- |
+|  |  |
+
+
+### Segmentation (Shift-S)
+
+Some labeled training examples must be provided before segmentify can segment the input images. The training labels can be provided by using the paint brush on the left-hand side of the Viewer to paint on the image. Please make that you are incrementing the labels in the **train** layer.
+
+
+
+The segmented output can be found on the **output** layer.
+
+### Uncertainty Heatmap (Shift-H)
+
+The Segmentify Viewer can also output a heatmap of areas where the model is having low confidence in it's segmentation. Users can relabel these areas to help improve the model's prediction. The uncertainty heatmap is generated by calculating the normalized entropy for the prediction probabilities. Similar prediction probability for all classes indicates that the model does not have confidence in one particular class, which results in low entropy; whereas high probability for one particular class results in high entropy. Using normalized entropy allow us to generate the uncertainty heatmap even when there are many segmentation classes. The lower the entropy for a pixel (high uncertainty), the brighter it will appear in red on the heatmap.
+
+To show the uncertainty heatmap, set the heatmap parameter to true when defining the Viewer:
+
+```python
+Viewer(imgs, heatmap=True)
+```
+
+If you are using the Command Line Interface, set the heatmap flag to True:
+
+```
+segmentify --heatmap True
+```
+
+
+
+
+### Next Featurizer (Shift-N)
+
+This feature cycles through all the featurizers in *segmentify/model/saved_model* as well as the filter featurizer. The name of the selected featurizer is shonw in the bottom left corner of the viewer. Note that after switching to the next featurizer, the user still needs to press Shift-S to re-segment the image.
+
+
+
+### Dilation (Shift-D)
+
+The [dilation operation](https://homepages.inf.ed.ac.uk/rbf/HIPR2/dilate.htm) expands all connected components with the selected label.
+
+
+
+### Erosion (Shift-E)
+
+The [erosion operation](https://homepages.inf.ed.ac.uk/rbf/HIPR2/erode.htm) shrinks all connected components with the selected label.
+
+
+
+### Fill Holes (Shift-F)
+
+This operation fills holes within a connected component for the selected label.
+
+
+
+### Close (Shift-C)
+
+The [close operation](https://homepages.inf.ed.ac.uk/rbf/HIPR2/close.htm)is done by applying dilation on the connected components with the selected label, following by an erosion.
+
+### Open (Shift-O)
+
+The [open operation](https://homepages.inf.ed.ac.uk/rbf/HIPR2/open.htm) is done by applying erosion on the connected components with the selected label, following by an dilation.
-#### **Development Status :: Pre-Alpha**
diff --git a/examples/image_stack.tif b/examples/image_stack.tif
deleted file mode 100644
index bccaabe..0000000
Binary files a/examples/image_stack.tif and /dev/null differ
diff --git a/examples/instance.py b/examples/instance.py
deleted file mode 100644
index 79dd8f1..0000000
--- a/examples/instance.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""
-Perform interactive semantic segmentation
-"""
-
-import numpy as np
-from napari import Viewer, gui_qt
-from skimage import data
-from segmentify import semantic
-from segmentify import instances
-
-
-coins = data.coins()
-labels = np.zeros(coins.shape, dtype=int)
-
-
-with gui_qt():
-
- # create an empty viewer
- viewer = Viewer()
-
- viewer.add_image(coins, name='input', colormap='gray')
-
- # add empty labels
- viewer.add_labels(labels, name='classes')
- viewer.add_labels(labels, name='instances')
- viewer.add_labels(labels, name='train')
- viewer.layers['train'].opacity = 0.9
-
- @viewer.bind_key('s')
- def segment(viewer):
- image = viewer.layers['input'].data
- labels = viewer.layers['train'].data
- clf = semantic.fit(image, labels)
- segmentation = semantic.predict(clf, image)
- print('classes', len(np.unique(segmentation)))
- instances = instances.predict(image, segmentation)
- print('instances', instances.max())
- viewer.layers['classes'].data = segmentation
- viewer.layers['instances'].data = instances
diff --git a/examples/nuclei.png b/examples/nuclei.png
deleted file mode 100644
index a116c9f..0000000
Binary files a/examples/nuclei.png and /dev/null differ
diff --git a/examples/semantic.py b/examples/semantic.py
deleted file mode 100644
index 1d500d2..0000000
--- a/examples/semantic.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""
-Perform interactive semantic segmentation
-"""
-import numpy as np
-from napari import Viewer, gui_qt
-from skimage import data
-from skimage.color import rgb2gray
-from segmentify.semantic import fit, predict
-import napari
-import imageio
-import os
-print(napari.__version__)
-
-
-
-with gui_qt():
-
- # create an empty viewer
- viewer = Viewer()
-
- # read in sample data
- example_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), "hpa.png")
- nuclei = imageio.imread(example_file)
- nuclei = rgb2gray(nuclei)
- labels = np.zeros(nuclei.shape, dtype=int)
-
- viewer.add_image(nuclei, name='input')
-
- # add empty labels
- viewer.add_labels(labels, name='output')
- viewer.add_labels(labels, name='train')
- viewer.layers['train'].opacity = 0.9
-
- @viewer.bind_key('s')
- def segment(viewer):
- image = viewer.layers['input'].data
- labels = viewer.layers['train'].data
-
- clf, features = fit(image, labels)
- segmentation = predict(clf, features)
- segmentation = np.squeeze(segmentation)
- viewer.layers['output'].data = segmentation
diff --git a/examples/semantic_3d.py b/examples/semantic_3d.py
deleted file mode 100644
index 6edcc16..0000000
--- a/examples/semantic_3d.py
+++ /dev/null
@@ -1,43 +0,0 @@
-"""
-Perform interactive semantic segmentation on 3D volume
-"""
-import numpy as np
-import napari
-import imageio
-import os
-from napari import Viewer, gui_qt
-from skimage import data, io
-from skimage.color import rgb2gray
-from segmentify.semantic import fit, predict
-
-print(napari.__version__)
-
-
-with gui_qt():
-
- # create an empty viewer
- viewer = Viewer()
-
- # read in sample data
- example_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),"image_stack.tif")
- nuclei = io.imread(example_file)
- #nuclei = np.transpose(nuclei, (2,0,1))
- labels = np.zeros(nuclei.shape, dtype=int)
-
- viewer.add_image(nuclei, name='input')
-
- # add empty labels
- viewer.add_labels(labels, name='output')
- viewer.add_labels(labels, name='train')
- viewer.layers['train'].opacity = 0.9
-
- @viewer.bind_key('s')
- def segment(viewer):
- image = viewer.layers['input'].data
- labels = viewer.layers['train'].data
-
- clf, features = fit(image, labels)
-
- segmentation = predict(clf, features)
- segmentation = np.squeeze(segmentation)
- viewer.layers['output'].data = segmentation
diff --git a/examples/viewer_example.py b/examples/viewer_example.py
new file mode 100644
index 0000000..60f51a8
--- /dev/null
+++ b/examples/viewer_example.py
@@ -0,0 +1,10 @@
+import os
+from segmentify import Viewer, gui_qt, util
+
+# parse input file
+example_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), "hpa.png")
+img = util.parse_img(example_file)
+
+with gui_qt():
+ viewer = Viewer(img)
+
diff --git a/examples/watershed.py b/examples/watershed.py
deleted file mode 100644
index e126e61..0000000
--- a/examples/watershed.py
+++ /dev/null
@@ -1,65 +0,0 @@
-"""
-Perform interactive semantic segmentation
-"""
-
-"""
-Display a labels layer above of an image layer using the add_labels and
-add_image APIs
-"""
-
-from napari import Viewer, gui_qt
-import numpy as np
-from scipy import ndimage as ndi
-from skimage import data
-from skimage.morphology import watershed
-from skimage.feature import peak_local_max
-
-# Generate an initial image with blobs
-blobs = data.binary_blobs(length=256, blob_size_fraction=0.1, n_dim=2,
- volume_fraction=.2, seed=999)
-
-
-# Now we want to separate the two objects in image
-# Generate the markers as local maxima of the distance to the background
-distance = ndi.distance_transform_edt(blobs)
-
-local_maxi = peak_local_max(distance, indices=True, footprint=np.ones((5, 5)),
- labels=blobs)
-
-local_maxi_image = np.zeros(blobs.shape, dtype='bool')
-for cord in local_maxi:
- local_maxi_image[tuple(cord)] = True
-markers = ndi.label(local_maxi_image)[0]
-
-labels = watershed(-distance, markers, mask=blobs)
-
-
-with gui_qt():
-
- # create an empty viewer
- viewer = Viewer()
-
- # add the input image
- viewer.add_image(blobs.astype('float'), name='input', colormap='gray')
-
- # add the distance image
- viewer.add_image(distance, name='distance', colormap='gray')
-
- # add the resulting labels image
- viewer.add_labels(labels, name='output')
-
- # add the points
- viewer.add_points(local_maxi, face_color='blue', size=3, name='markers')
-
- @viewer.bind_key('r')
- def rerun(viewer):
- blobs = viewer.layers['input'].data
- distance = viewer.layers['distance'].data
- local_maxi = viewer.layers['markers'].data
- print('Number of markers: ', len(local_maxi))
- local_maxi_image = np.zeros(blobs.shape, dtype='bool')
- for cord in local_maxi:
- local_maxi_image[tuple(np.round(cord).astype(int))] = True
- markers = ndi.label(local_maxi_image)[0]
- labels = watershed(-distance, markers, mask=blobs)
- viewer.layers['output'].data = labels
diff --git a/figs/dilation.gif b/figs/dilation.gif
new file mode 100644
index 0000000..a1eabcd
Binary files /dev/null and b/figs/dilation.gif differ
diff --git a/figs/erosion.gif b/figs/erosion.gif
new file mode 100644
index 0000000..f9a9533
Binary files /dev/null and b/figs/erosion.gif differ
diff --git a/figs/fill_holes.gif b/figs/fill_holes.gif
new file mode 100644
index 0000000..f0182cc
Binary files /dev/null and b/figs/fill_holes.gif differ
diff --git a/figs/heatmap.gif b/figs/heatmap.gif
new file mode 100644
index 0000000..36afd70
Binary files /dev/null and b/figs/heatmap.gif differ
diff --git a/figs/intro.gif b/figs/intro.gif
new file mode 100644
index 0000000..86487c2
Binary files /dev/null and b/figs/intro.gif differ
diff --git a/figs/label_morph.gif b/figs/label_morph.gif
new file mode 100644
index 0000000..10276b0
Binary files /dev/null and b/figs/label_morph.gif differ
diff --git a/figs/next.gif b/figs/next.gif
new file mode 100644
index 0000000..5c9b05d
Binary files /dev/null and b/figs/next.gif differ
diff --git a/figs/seg_example.gif b/figs/seg_example.gif
new file mode 100644
index 0000000..208684a
Binary files /dev/null and b/figs/seg_example.gif differ
diff --git a/figs/seg_morph.gif b/figs/seg_morph.gif
new file mode 100644
index 0000000..64b3252
Binary files /dev/null and b/figs/seg_morph.gif differ
diff --git a/requirements.txt b/requirements.txt
index 8f32774..3f80f08 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,11 @@
-napari >= 0.0.8
+imageio>=2.5.0
+napari>=0.1.4
+numba>=0.45.1
numpy>=1.10.0
scipy>=1.2.0
scikit-image>=0.15.0
scikit-learn>=0.21.2
+vispy>=0.6.1
+torch>=1.2.0
+torchvision>=1.2.0
+
diff --git a/segmentify/__init__.py b/segmentify/__init__.py
index c57bfd5..af61e0f 100644
--- a/segmentify/__init__.py
+++ b/segmentify/__init__.py
@@ -1 +1,5 @@
__version__ = '0.0.0'
+
+from .viewer import Viewer
+from . import util
+from napari import gui_qt
diff --git a/segmentify/main.py b/segmentify/main.py
new file mode 100644
index 0000000..d1d875c
--- /dev/null
+++ b/segmentify/main.py
@@ -0,0 +1,32 @@
+"""segmentify command line viewer.
+"""
+
+import argparse
+import numpy as np
+import imageio
+from napari import gui_qt
+from segmentify import Viewer, util
+
+
+def main(args):
+ """The main Command Line Interface for Segmentify"""
+
+ # parse in images
+ imgs = [util.parse_img(img) for img in args.images]
+
+ if len(imgs) > 1:
+ imgs = np.stack(imgs, axis=0)
+ else:
+ imgs = np.array(imgs)
+
+ with gui_qt():
+ viewer = Viewer(imgs, heatmap=args.heatmap)
+
+
+if __name__ == "__main__":
+ # parser
+ parser = argparse.ArgumentParser()
+ parser.add_argument("images", nargs="*", type=str, help="Image to view and segment.")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/segmentify/model/saved_model/UNet_hpa_max.pth b/segmentify/model/saved_model/HPA.pth
similarity index 100%
rename from segmentify/model/saved_model/UNet_hpa_max.pth
rename to segmentify/model/saved_model/HPA.pth
diff --git a/segmentify/model/saved_model/HPA_3.pth b/segmentify/model/saved_model/HPA_3.pth
new file mode 100644
index 0000000..5aedb6b
Binary files /dev/null and b/segmentify/model/saved_model/HPA_3.pth differ
diff --git a/segmentify/model/saved_model/HPA_4.pth b/segmentify/model/saved_model/HPA_4.pth
new file mode 100644
index 0000000..6c64c70
Binary files /dev/null and b/segmentify/model/saved_model/HPA_4.pth differ
diff --git a/segmentify/model/saved_model/UNet_nuclei.pth b/segmentify/model/saved_model/Nuclei.pth
similarity index 98%
rename from segmentify/model/saved_model/UNet_nuclei.pth
rename to segmentify/model/saved_model/Nuclei.pth
index 2a867a5..cdf6dbd 100644
Binary files a/segmentify/model/saved_model/UNet_nuclei.pth and b/segmentify/model/saved_model/Nuclei.pth differ
diff --git a/segmentify/semantic/main.py b/segmentify/semantic/main.py
index 7f391a2..0a4541d 100644
--- a/segmentify/semantic/main.py
+++ b/segmentify/semantic/main.py
@@ -7,26 +7,21 @@
from ..model import UNet, layers
from skimage import feature
-def _load_model(pretrained_model):
+def _load_model(featurizer_path):
"""Load the featurization model
Parameters
----------
- model_path: str
+ featurizer_path: str
Path to the saved model file
+
+ Returns
+ -------
+ The loaded PyTorch model
"""
- ## TODO better way to store and define file path
- file_name = os.path.abspath(os.path.dirname(__file__))
- if pretrained_model == "HPA":
- model_path = os.path.join(file_name,"..","model","saved_model","UNet_hpa_max.pth")
- elif pretrained_model == "nuclei":
- model_path = os.path.join(file_name,"..","model","saved_model","UNet_nuclei.pth")
- else:
- raise ValueError("pretrained model not defined")
# load in saved model
- # TODO allow gpu -> issue
- pth = torch.load(model_path)
+ pth = torch.load(featurizer_path)
model_args = pth['model_args']
model_state = pth['model_state']
model = UNet(**model_args)
@@ -40,14 +35,14 @@ def _load_model(pretrained_model):
return model
-def unet_featurize(image, pretrained_model="HPA"):
+def unet_featurize(image, featurizer_path):
"""Featurize pixels in an image using pretrained UNet
Parameters
----------
image: numpy.ndarray
Image data to be featurized
- pretrained_model: str (HPA)
+ featurizer_path: str (HPA)
name of the pretraine model to use for featurization
Returns
@@ -56,7 +51,7 @@ def unet_featurize(image, pretrained_model="HPA"):
One feature vector per pixel in the image
"""
- model = _load_model(pretrained_model)
+ model = _load_model(featurizer_path)
image = torch.Tensor(image).float()
@@ -95,7 +90,7 @@ def filter_featurize(image):
-def fit(image, labels, multichannel=False, featurizer="HPA"):
+def fit(image, labels, featurizer="../model/saved_model/UNet_hpa_4c_mean_8.pth"):
"""Train a pixel classifier.
Parameters
@@ -113,6 +108,7 @@ def fit(image, labels, multichannel=False, featurizer="HPA"):
classifier: sklearn.ensemble.RandomForestClassifier
Object that can perform classifications
"""
+ print(featurizer)
# pad input image
w,h = image.shape[-2:]
w_padding = int((16-w%16)/2) if w%16 >0 else 0
@@ -122,32 +118,31 @@ def fit(image, labels, multichannel=False, featurizer="HPA"):
elif len(image.shape) == 2:
image = np.pad(image, ((w_padding, w_padding),(h_padding, h_padding)), 'constant')
- clf = RandomForestClassifier(n_estimators=10)
-
- # TODO should this be elsewhere?
+ # make sure image has four dimentions (b,c,w,h)
while len(image.shape) < 4:
image = np.expand_dims(image, 0)
image = np.transpose(image, (1,0,2,3))
- # TODO better way to choose featurizer
+ # choose filter or unet featurizer
if featurizer == "filter":
- print("filter")
features = filter_featurize(image)
else:
features = unet_featurize(image, featurizer)
-
# crop out paddings
if w_padding > 0:
features = features[w_padding:-w_padding]
if h_padding > 0:
features = features[:,h_padding:-h_padding]
+ # reshape and extract data
X = features.reshape([-1, features.shape[-1]])
y = labels.reshape(-1)
X = X[y != 0]
y = y[y != 0]
+ # define and fit classifier
+ clf = RandomForestClassifier(n_estimators=10)
if len(X) > 0:
clf = clf.fit(X, y)
@@ -175,9 +170,15 @@ def predict(classifier, features):
X = features.reshape([-1, features.shape[-1]])
try:
+ # get prediction and probability
y = classifier.predict(X)
+ prob = classifier.predict_proba(X)
labels = y.reshape(features.shape[:-1])
+ prob_shape = features.shape[:-1] + (prob.shape[-1], )
+ prob = prob.reshape(prob_shape)
+
except:
# If classifer has not yet been fit return zeros
labels = np.zeros(features.shape[:-1], dtype=int)
- return labels
+ prob = np.zeros(features.shape[:-1], dtype=int)
+ return labels, prob
diff --git a/segmentify/util.py b/segmentify/util.py
new file mode 100644
index 0000000..c6b1e70
--- /dev/null
+++ b/segmentify/util.py
@@ -0,0 +1,130 @@
+import imageio
+import numpy as np
+import numba
+import math
+from scipy import stats
+
+def parse_img(img):
+ """parse in a input image into the standardized format for segmentify
+
+ Parameters
+ ----------
+ img : np.array
+ The image to be parsed
+
+ Returns
+ -------
+ The parsed image
+ """
+
+ # read in image
+ img = imageio.imread(img)
+
+ # check dimention order
+ if img.shape[-1] <=5:
+ img = np.transpose(img, (2,0,1))
+
+ # normalize to 0 - 1
+ if img.max() > 1:
+ img = img / 255.
+
+ # use only one dimention
+ if img.shape[0] > 1:
+ img = np.mean(img, axis=0)
+
+ return img
+
+@numba.jit(nopython=True, fastmath=True, cache=True)
+def _get_mode(target_region):
+ """get the mode of list
+
+ Parameters:
+ -----------
+ target_region: list(int)
+ A list of pixels
+
+ Returns
+ -------
+ The mode of a given list
+ """
+ counter = 0
+ target = target_region[0]
+ for t in target_region:
+ curr_freq = target_region.count(t)
+ if curr_freq > counter:
+ counter = curr_freq
+ target = t
+ return target
+
+@numba.jit(fastmath=True, cache=True)
+def _erode_img(img, target_label):
+ """multi-class image erosion
+
+ This function performs a multi class by finding the mode in a sliding kernel
+
+ Parameters
+ ----------
+ img: np.array
+ The image to be eroded
+ target_label: int
+ The label of the pixels to be eroded
+
+ Returns
+ -------
+ The eroded image
+ """
+
+ rows, cols = img.shape
+
+ # padd image with sides of original image
+ padded_img = np.zeros((rows+2, cols+2))
+ padded_img[1:-1,1:-1] = img
+ padded_img[:,0] = padded_img[:,1]
+ padded_img[:,-1] = padded_img[:,-2]
+ padded_img[0,:] = padded_img[1,:]
+ padded_img[-1,:] = padded_img[-2,:]
+
+ output_img = np.zeros_like(img)
+
+ for r in range(rows):
+ for c in range(cols):
+ # use a 3x3 square sliding window
+ region = np.copy(padded_img[r:r+3, c:c+3])
+ region_flattern = np.copy(np.reshape(region,-1))
+
+ # use the target label if no other labels exisit in the sliding window
+ if np.all(region_flattern == target_label):
+ output_img[r,c] = target_label
+ else:
+ target_region = [v for v in region_flattern \
+ if v != target_label]
+
+ if len(target_region) == 0:
+ continue
+
+ target = _get_mode(target_region)
+ output_img[r,c] = target
+ return output_img
+
+
+def _norm_entropy(probs):
+ """get the normalized entropy based on a list of proabilities
+
+ Parameters
+ ----------
+ probs: list
+ list of probabilities
+
+ Returns
+ -------
+ normalized entropy of the probabilities
+ """
+
+ entropy = 0
+ for prob in probs:
+ if prob > 0:
+ entropy += prob * math.log(prob, math.e)
+ else:
+ entropy += 0
+ return - entropy / len(probs)
+
diff --git a/segmentify/viewer.py b/segmentify/viewer.py
new file mode 100644
index 0000000..c149fe5
--- /dev/null
+++ b/segmentify/viewer.py
@@ -0,0 +1,295 @@
+"""Segmentify Viewer Class with key bindings
+"""
+
+import numpy as np
+import math
+import os
+from . import util
+from napari import Viewer as NapariViewer
+from segmentify.semantic import fit, predict
+from vispy.color import Colormap
+from skimage import morphology, measure
+from itertools import cycle
+from scipy import stats
+
+class Viewer(NapariViewer):
+ """viewer for segmentify
+
+ A NapariViewer based viewer with predefined keybindings for segmenting and post-processing
+
+ Parameters
+ ----------
+ img : np.array
+ Input image in numpy array format. If no image is passed in, matrix of zeros is used
+ heatmap: bool
+ If a probability heatmap layer should be created
+ """
+
+ def __init__(self, img=None):
+
+ super(Viewer, self).__init__()
+
+ # use empty image if none provided
+ if img is None:
+ self.img = np.zeros((256,256))
+ else:
+ self.img = img
+
+ # class variables
+ self.min_object_size = 25
+ self.background_label = 1
+ self.labels = np.zeros(self.img.shape, dtype=int)
+ self.add_image(self.img, name='input')
+ self.selem = morphology.selem.square(3)
+ self.prob = None
+ self.segmentation = None
+
+ if len(img.shape) > 2:
+ self.selem = np.array([self.selem])
+
+ # create empty heatmap label
+ self.probability_heatmap = self.add_image(self.labels.astype(float), \
+ name="prediction probability")
+ self.probability_heatmap.opacity = 0.0
+ self.colormap = Colormap([[0.0,0.0,0.0,0.0],[1.0,0.0,0.0,1.0],[0.0,0.0,0.0,0.0]])
+
+ # add label layers
+ self.add_labels(self.labels, name='output')
+ self.add_labels(self.labels, name='train')
+ self.layers['train'].opacity = 0.9
+
+ # define featurizers
+ featurizer_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)),"model","saved_model")
+ featurizer_paths = os.listdir(featurizer_dir)
+ featurizer_paths = sorted([os.path.join(featurizer_dir,path) for path in featurizer_paths])
+ featurizer_paths.append("filter")
+ self.featurizers = cycle(featurizer_paths)
+ self.cur_featurizer = next(self.featurizers)
+ self.status = self.cur_featurizer.split("/")[-1]
+
+ # key-bindings
+ self.bind_key('Shift-S', self.segment)
+ self.bind_key('Shift-H', self.show_heatmap)
+ self.bind_key('Shift-N', self.next_featurizer)
+ self.bind_key('Shift-D', self.dilation)
+ self.bind_key('Shift-E', self.erosion)
+ self.bind_key('Shift-C', self.closing)
+ self.bind_key('Shift-O', self.opening)
+ self.bind_key('Shift-F', self.fill_holes)
+
+ def _extract_label(func):
+ """decorator that extract pixels with selected label
+
+ This decorator takes in a function with a Napari Viewer parameter,
+ and only run the function on pixels with selected labels.
+ The output of the function will replace the original selected pixels.
+
+
+ Parameters
+ ----------
+ func : python function
+ Function to be decorated
+
+ Returns
+ -------
+ new_func: python function
+ The decorated function
+ """
+
+ def new_func(self, viewer):
+ """Decorator function that extracts selected label
+
+ Parameters
+ ----------
+ viewer : Segmentify Viewer
+ """
+
+ curr_label = viewer.active_layer.selected_label
+
+ # extract only pixels with selected label
+ original = viewer.active_layer.data
+ labeled_img = np.zeros_like(original)
+ labeled_img[original == curr_label] = 1
+ viewer.active_layer.data = labeled_img
+
+ # run morphology
+ viewer.active_layer.data = func(self, viewer)
+
+ # merge processed image with original
+ all_labels = np.unique(original)
+ if len(all_labels) == 2:
+ background_label = all_labels[all_labels != curr_label][0]
+ original[(viewer.active_layer.data == 0) & (original == curr_label)] = background_label
+ else:
+ original[(viewer.active_layer.data == 0) & (original == curr_label)] = self.background_label
+ original[viewer.active_layer.data == 1] = curr_label
+
+ viewer.active_layer.data = original
+
+ viewer.status = f"Finished {func.__name__} on label {viewer.active_layer.selected_label}"
+
+ return new_func
+
+
+ def segment(self, viewer):
+ """function that fit and segment input image (key-binding: SHIFT-S)
+
+ This function takes in a Segmentify Viewer and featurize the input image.
+ A Random Forest Classifier is trained based on the selected training labels.
+ The entire input image is segmented based on the Random Forest Classifier's predictions.
+
+ Parameters
+ ----------
+ viewer : Segmentify Viewer
+ """
+
+ viewer.status = "Segmenting..."
+
+ # get data
+ image = viewer.layers['input'].data
+ labels = viewer.layers['train'].data
+
+ # fit and predict
+ clf, features = fit(image, labels, featurizer=self.cur_featurizer)
+ segmentation, self.prob = predict(clf, features)
+
+ # show prediction
+ self.segmentation = np.squeeze(segmentation)
+ viewer.layers['output'].data = self.segmentation
+
+ viewer.status = "Segmentation Completed"
+
+ def show_heatmap(self, viewer):
+ """This function generates the confidence heatmap of model's prediction.
+
+ The heatmap is generated based on the normalized entropy of the prediction probabilities.
+ """
+
+ if self.prob is not None:
+ # calcualte entropy for probability
+ prob = np.apply_along_axis(util._norm_entropy,-1, self.prob)
+ prob = (prob - np.min(prob)) / np.ptp(prob)
+ prob = np.squeeze(prob)
+
+ self.probability_heatmap.colormap = "uncertainty", self.colormap
+ self.probability_heatmap.data = prob
+ self.probability_heatmap.opacity = 0.7
+ viewer.status = "Probability heatmap generated"
+ else:
+ viewer.status = "Segmentation required before heatmap generation"
+
+ def next_featurizer(self, viewer):
+ """get next featurizer from self.featurizer (key-binding: SHIFT-N)
+
+ This function cycles through the availible featurizers in segmentify/model/saved_model,
+ as well as the image filter featurizer.
+
+ Parameters
+ ----------
+ viewer : Segmentify Viewer
+ """
+ self.cur_featurizer = next(self.featurizers)
+ viewer.status = self.cur_featurizer.split("/")[-1]
+
+
+ def closing(self, viewer):
+ """apply the closing operation (key-binding: SHIFT-C)
+
+ This function applies the closing operation by dilating the selected label pixels,
+ following by erosion
+
+ Parameters
+ ----------
+ viewer : Segmentify Viewer
+
+ Returns
+ -------
+ The procssed image
+ """
+ viewer.status = "Closing"
+ _ = self.dilation(viewer)
+ processed_img = self.erosion(viewer)
+ viewer.status = f"Finished closing on label {viewer.active_layer.selected_label}"
+ return processed_img
+
+ def opening(self, viewer):
+ """apply the opening operation (key-binding: SHIFT-O)
+
+ This function applies the opening operation by eroding the selected label pixels,
+ following by dilation
+
+ Parameters
+ ----------
+ viewer : Segmentify Viewer
+
+ Returns
+ -------
+ The procssed image
+ """
+ viewer.status = "Closing"
+ _ = self.erosion(viewer)
+ processed_img = self.dilation(viewer)
+ viewer.status = f"Finished opening on label {viewer.active_layer.selected_label}"
+ return processed_img
+
+ def erosion(self, viewer):
+ """apply the erosion operation (key-binding: SHIFT-E)
+
+ This function applies the erosion operation on selected label pixels
+
+ Parameters
+ ----------
+ viewer : Segmentify Viewer
+
+ Returns
+ -------
+ The procssed image
+ """
+ viewer.status = "Eroding"
+ processed_img = util._erode_img(viewer.active_layer.data, \
+ target_label=viewer.active_layer.selected_label)
+ viewer.active_layer.data = processed_img
+ viewer.status = f"Finished erosion on label {viewer.active_layer.selected_label}"
+ return processed_img
+
+ @_extract_label
+ def dilation(self, viewer):
+ """apply the dilation operation (key-binding: SHIFT-D)
+
+ This function applies the dilation operation on selected label pixels
+
+ Parameters
+ ----------
+ viewer : Segmentify Viewer
+
+ Returns
+ -------
+ The procssed image
+ """
+ processed_img = morphology.dilation(viewer.active_layer.data, self.selem)
+ viewer.status = f"Finished Dilation on label {viewer.active_layer.selected_label}"
+ return processed_img
+
+ @_extract_label
+ def fill_holes(self, viewer):
+ """apply the fill holes operation (key-binding: SHIFT-D)
+
+ This function applies the fill holes operation on the selected label pixels
+
+ Parameters
+ ----------
+ viewer : Segmentify Viewer
+
+ Returns
+ -------
+ The procssed image
+ """
+ if len(viewer.active_layer.data.shape) > 2:
+ processed_imgs = []
+ for i in range(viewer.active_layer.data.shape[0]):
+ processed_img = morphology.remove_small_holes(viewer.active_layer.data[i].astype(bool)).astype(int)
+ processed_imgs.append(processed_img)
+ return np.stack(processed_imgs, 0)
+ else:
+ return morphology.remove_small_holes(viewer.active_layer.data.astype(bool)).astype(int)
+