# Image segmentation using persistent homology

With this document you can play around with image segmentation using a persistent homology algorithm.

In [8]:
# Standard data libraries
import numpy as np 
from ipywidgets import widgets, interact, interactive, fixed, interact_manual, Layout
import ipywidgets as widgets
from IPython.display import display

# My favourite plotting libraries
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

#To manipulate images
import cv2
import scipy
from scipy import ndimage
import io
from PIL import Image

from persim import plot_diagrams
from ripser import ripser, lower_star_img

# Binder deployments have limited memory, so when `resize` is on, images whose
# largest dimension exceeds `max_size` pixels are downscaled before processing.
resize, max_size = True, 1280

In [10]:
def _as_rgb(arr):
    """Return `arr` as a contiguous H x W x 3 array.

    Greyscale (2-D) images are replicated into three channels, and any channel
    beyond the third (e.g. alpha) is dropped. The colour conversions in
    manipulate_image therefore always receive a 3-channel image.
    """
    if arr.ndim == 2:
        arr = np.stack([arr] * 3, axis=-1)
    return np.ascontiguousarray(arr[..., :3])


def _load_upload(upl):
    """Decode the image held by a FileUpload widget value, or return None if empty.

    Accepts both value formats:
    - ipywidgets 7: a dict mapping file name -> dict with a 'content' bytes entry.
    - ipywidgets 8: a tuple of dicts with a 'content' memoryview.
    If several files are present, the last one is used.
    """
    files = list(upl.values()) if isinstance(upl, dict) else list(upl or ())
    if not files:
        return None
    content = bytes(files[-1]['content'])
    # convert('RGB') normalises palette, greyscale and RGBA uploads to 8-bit RGB.
    return np.array(Image.open(io.BytesIO(content)).convert('RGB'))


def manipulate_image(upl, channel='grey', blur=1, inv=False, example="Cells", smooth_size=10):
    """Load, pre-process and display the image the persistence analysis runs on.

    Called by ``interactive_output`` whenever one of the widgets changes. Sets the
    module-level globals ``img`` (the RGB image shown to the user) and
    ``smoothed`` (the float image; the diagram is later built from ``-smoothed``).

    Parameters
    ----------
    upl : dict or tuple
        Value of the FileUpload widget (dict in ipywidgets 7, tuple in ipywidgets 8).
    channel : {'grey', 'red', 'green', 'blue'}
        Greyscale conversion or single colour channel used for thresholding.
    blur : int
        Box-blur kernel size; values <= 0 disable blurring.
    inv : bool
        Negate the result so that dark features are searched for instead of bright ones.
    example : str
        One of the bundled example names, or 'Uploaded File'.
    smooth_size : int, default 10
        Size of the uniform filter applied after histogram equalisation.

    Raises
    ------
    ValueError
        If `example` or `channel` is not one of the supported values.
    """
    global img
    global smoothed

    example_files = {
        'Cells': 'Cells.jpg',
        'Gannets': 'Gannets.jpg',
        'Blossoms': 'Blossoms.jpg',
        'Cats': 'Cats.jpg',
        'Trees': 'Trees.jpg',
    }
    if example == 'Uploaded File':
        uploaded = _load_upload(upl)
        if uploaded is None:
            # Without this check the previous image would silently be reused.
            print("Please upload an image first.")
            return
        img = uploaded
    elif example in example_files:
        img = plt.imread(example_files[example])
    else:
        raise ValueError("Unknown example image: {!r}".format(example))
    img = _as_rgb(img)

    longest_side = max(img.shape[:2])
    if resize and longest_side > max_size:
        scale = max_size / longest_side
        # cv2.resize expects the target size as (width, height).
        dim = (int(img.shape[1] * scale), int(img.shape[0] * scale))
        img = cv2.resize(img, dim, interpolation=cv2.INTER_AREA)

    image = cv2.blur(img, (blur, blur)) if blur > 0 else img

    # plt.imread and PIL both return RGB channel order (not OpenCV's BGR),
    # so red is index 0 and blue is index 2.
    colour_channels = {'red': (0, 'Reds'), 'green': (1, 'Greens'), 'blue': (2, 'Blues')}
    if channel == 'grey':
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        cmap = 'Greys'
    elif channel in colour_channels:
        index, cmap = colour_channels[channel]
        image = np.ascontiguousarray(image[:, :, index])
    else:
        raise ValueError("Unknown channel: {!r}".format(channel))
    image = cv2.equalizeHist(image)

    smoothed = ndimage.uniform_filter(image.astype(np.float64), size=smooth_size)
    # Tiny noise makes pixel values (almost surely) distinct, presumably so a
    # diagram birth value can be mapped back to a single pixel.
    smoothed += 0.01 * np.random.randn(*smoothed.shape)

    if inv:
        smoothed = -smoothed

    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(img, aspect='auto')
    plt.subplot(1, 2, 2)
    plt.title("Image used for thresholding")
    plt.imshow(-smoothed, aspect='auto', cmap=cmap)
    plt.show()


In [11]:
def make_diagram(b):
    """Recompute the persistence diagram of the current image into the global `dgm`.

    Registered as the button's ``on_click`` callback, so `b` receives the clicked
    button; it is not used. Status messages are printed into the ``output``
    widget. The widget is cleared first so repeated clicks do not pile up stale
    messages.
    """
    global dgm
    with output:
        output.clear_output()
        print("Computing persistence diagram...")
        # Lower-star filtration of -smoothed: the brightest regions of `smoothed` are born first.
        dgm = lower_star_img(-smoothed)
        print("Diagram generated! Please continue below.")
        
    

In [12]:
def segment(threshold=50):
    """Show the lifetime diagram with a cut-off and mark the selected features on the image.

    Reads the module-level globals ``dgm`` (from make_diagram) and ``smoothed``
    and ``img`` (from manipulate_image).

    Parameters
    ----------
    threshold : float
        Minimum lifetime |death - birth| a feature needs in order to be marked.
    """
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 3])
    plt.figure(figsize=(12, 6))

    # Left panel: lifetime diagram with the kept region (lifetime > threshold)
    # shaded. axhline/axhspan span the whole x-range. Births are <= 0 for a normal
    # image but positive for an inverted one, so a fixed [-255, 0] range would
    # draw the cut-off off-panel.
    ax_dgm = plt.subplot(gs[0])
    plot_diagrams(dgm, lifetime=True, ax=ax_dgm)
    y_low, y_high = ax_dgm.get_ylim()
    y_high = max(y_high, threshold)
    ax_dgm.axhline(threshold, color='red')
    ax_dgm.axhspan(threshold, y_high, color='r', alpha=.25)
    ax_dgm.set_ylim(y_low, y_high)

    # Birth values of all features whose lifetime exceeds the threshold; the
    # infinite-death point (global maximum) always qualifies.
    lifetimes = np.abs(dgm[:, 1] - dgm[:, 0])
    births = dgm[lifetimes > threshold, 0]

    # The diagram was built from -smoothed, so a feature born at value b starts
    # at the pixel where smoothed is closest to -b.
    flat = smoothed.ravel()
    flat_idx = np.array([np.argmin(np.abs(flat + b)) for b in births], dtype=np.intp)
    rows, cols = np.unravel_index(flat_idx, smoothed.shape)

    plt.subplot(gs[1])
    plt.imshow(img)
    plt.scatter(cols, rows, 20, 'red')
    plt.axis('off')
    plt.show()

Here, you can upload your own image file to replace the example images (select **Uploaded File** to use it). **With the current version, the image must be a color image!** The algorithm needs a greyscale input, so choose either a greyscale conversion or one of the color channels. Blurring the image might improve the results, as it reduces noise. By default, the algorithm searches for bright spots; if the features you want to find are dark, invert the image.

In [13]:
# Image source: one of the bundled examples, or a user upload (images only).
w0 = widgets.FileUpload(accept='image/*', multiple=False)
w4 = widgets.ToggleButtons(
    options=['Cells', 'Gannets', 'Blossoms', 'Cats', 'Trees', 'Uploaded File'],
    description='Images:',
)
# Greyscale conversion or single colour channel used for thresholding.
w1 = widgets.RadioButtons(options=['grey', 'red', 'green', 'blue'], description='Color:')
# Pre-processing: optional inversion (find dark features) and box-blur kernel size.
w3 = widgets.Checkbox(value=False, description='Invert image', indent=False)
w2 = widgets.IntSlider(value=1, min=1, max=100, step=1, description='Blur:',
                       continuous_update=False)

w40 = widgets.VBox([w4, w0])
w23 = widgets.VBox([w3, w2])
ui = widgets.HBox([w40, w1, w23])
# Re-run manipulate_image whenever any control changes.
out = widgets.interactive_output(
    manipulate_image,
    {'upl': w0, 'channel': w1, 'blur': w2, 'inv': w3, 'example': w4},
)

display(ui, out)

HBox(children=(VBox(children=(ToggleButtons(description='Images:', options=('Cells', 'Gannets', 'Blossoms', 'C…

Output()

# Persistence diagram

The persistence diagram shows how long features survive as a threshold is swept through the image's intensity values. Points at the bottom are very short-lived and are essentially noise. If the diagram for your image shows a distinct cloud of points well above the x-axis, these are likely the features you want to find. Use this diagram to choose a threshold (lifetime value) for the next step: only features whose lifetime exceeds this threshold are marked below.

Every time you change the image above, press the button to generate a new diagram. This is not done automatically, as it might take some time. The next step below uses the diagram, so be sure to generate it before moving the threshold.

In [6]:
# Status messages from make_diagram are printed into this Output widget.
output = widgets.Output()
# The diagram is recomputed only on demand because it might take some time.
button = widgets.Button(
    description="Generate new persistence diagram",
    layout=Layout(width='50%', height='80px'),
)

# Build an initial diagram right away so the threshold step works without a click.
make_diagram(0)

button.on_click(make_diagram)
display(button, output)

Button(description='Generate new persistence diagram', layout=Layout(height='80px', width='50%'), style=Button…

Output()

# Thresholding

Play with the threshold value here to see which value works best for your image. If you changed the image, the new image should appear here once you move the slider. All features whose lifetime lies above the threshold in the persistence diagram are marked. If the threshold is set too low, lots of noise gets selected. The single point at the top always exists and represents the global maximum of the thresholding image (i.e. its brightest pixel, or the darkest pixel if the image is inverted).

In [7]:
# Re-plot the segmentation whenever the lifetime-threshold slider is released.
interact(
    segment,
    threshold=widgets.IntSlider(value=50, min=1, max=255, continuous_update=False),
);

interactive(children=(IntSlider(value=50, continuous_update=False, description='threshold', max=255, min=1), O…