# Segmentation of cell nuclei

## Setup

### Libraries

In [3]:
from itertools import chain
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from skimage.segmentation import watershed

In [4]:
# Optional: use the Qt backend so figures open in their own windows and can be zoomed.
%matplotlib qt

### Load data

Data comes from the git repo through [git lfs](https://git-lfs.github.com/)

In [5]:
images = []

# Sort the paths so that the image numbering is the same on every run.
for imagePath in sorted(Path("./data").glob("*.png")):
    image = cv2.imread(str(imagePath))
    # cv2.imread reports failure by returning None rather than raising
    # (e.g. when the file is not a decodable image), so test for None first.
    if image is not None and image.size > 0:
        images.append(image)
    else:
        print(f"Failed reading image {imagePath}")

# Fail with a clear message instead of the generic np.stack error.
if not images:
    raise FileNotFoundError("No readable PNG images found in ./data (was `git lfs pull` run?)")

images = np.stack(images)
N = len(images)
print(images.shape)

(4, 1040, 1408, 3)


In [6]:
# When True, the plotting helpers try to maximise the figure window they create.
maximized=True
def colPlot(images, **kwargs):
    """Show `images` side by side in a single row of subplots.

    Arrays with more than two dimensions are treated as OpenCV BGR colour
    images and have their channel axis reversed for display. Two-dimensional
    arrays are passed to `imshow` together with `**kwargs`
    (for example ``cmap="gray"``).
    """
    # squeeze=False keeps `axs` two-dimensional, so a single image also works.
    fig, axs = plt.subplots(1, len(images), squeeze=False)

    for im, ax in zip(images, axs[0]):
        if im.ndim > 2: # if color
            ax.imshow(im[...,::-1]) # opencv is BGR
        else:
            ax.imshow(im, **kwargs)

    if maximized:
        # showMaximized() is a Qt window method; other backends may not offer
        # it, so skip maximising there instead of raising AttributeError.
        window = getattr(plt.get_current_fig_manager(), "window", None)
        if hasattr(window, "showMaximized"):
            window.showMaximized()
    fig.tight_layout()

def multiRowPlot(images, titles, nrows, ncols, **kwargs):
    """Show `images` on an `nrows` x `ncols` grid, filled row by row.

    `images` and `titles` are consumed in parallel (plain iterators such as
    `itertools.chain` are fine). Each image is drawn with `imshow(**kwargs)`,
    gets its title and has its axis decorations hidden. Axes within a column
    share their x and y limits.
    """
    # squeeze=False guarantees a 2-D array of axes, even for a 1x1 grid.
    fig, axs = plt.subplots(nrows, ncols, sharey="col", sharex="col", squeeze=False)

    for im, title, ax in zip(images, titles, axs.flatten()):
        ax.set_title(title)
        ax.imshow(im, **kwargs)
        ax.axis("off")

    if maximized:
        # showMaximized() is a Qt window method; other backends may not offer
        # it, so skip maximising there instead of raising AttributeError.
        window = getattr(plt.get_current_fig_manager(), "window", None)
        if hasattr(window, "showMaximized"):
            window.showMaximized()
    fig.tight_layout()


def imageTitles(pattern, n=None):
    """Return one title per image by formatting `pattern` with ``i = 1..n``.

    `pattern` is a `str.format` template that may contain ``{i}`` (literal
    braces, e.g. in LaTeX, must be doubled). When `n` is omitted the
    notebook-level image count `N` is used.
    """
    count = N if n is None else n
    return [pattern.format(i=i) for i in range(1, count + 1)]

## Part 1

First we analyse the channels of the image and pick the best way to "grayscale" it.

The red channel highlights the cell nuclei, and the G and B channels (which are equivalent) hold the grayscale, monochromatic image from the microscope.

### Plot all channels (RGB)

In [7]:
# colPlot already reverses the BGR channel axis for display, so pass the stack
# unchanged. Indexing the (N, height, width, channel) stack with [:, :, ::-1]
# would reverse the width axis instead and mirror every image.
colPlot(images)

### Plot red channel and R-B for comparison

* $I_R-I_B$ clearly shows the nuclei with high contrast and no unwanted features.
* $I_R$ shows more detail for other parts of the cell, but that detail introduces unwanted features that don't have a very high contrast with the nuclei.

**Use $I_f=I_R-I_B$ for segmentation**

In [8]:
# the R channel (OpenCV stores the channels as B, G, R)
R = images[...,2]

# the R-B difference image.
# Subtract in a signed type and clip at 0: a plain uint8 subtraction would
# wrap around to large values wherever B exceeds R.
diffRB = np.clip(images[...,2].astype(np.int16) - images[...,0], 0, 255).astype(np.uint8)

# For plotting
imPlots = chain(R, diffRB)
imTitles = chain(imageTitles("$I_{{ {i}R }}$"),
        imageTitles("$I_{{ {i}R }} - I_{{ {i}B }}$"))

multiRowPlot(imPlots, imTitles, nrows=2, ncols=N, cmap="gray")

In [9]:
# One figure per image: blue channel on the left, R-B difference on the right.
# The two panels share their axes, so zooming one also zooms the other.
for bgr, diff in zip(images, diffRB):
    fig, (ax_blue, ax_diff) = plt.subplots(1, 2, sharex=True, sharey=True)
    ax_blue.imshow(bgr[..., 0], cmap="gray")
    ax_diff.imshow(diff)

### Histograms

We can also see in a log-histogram that $I_R-I_B$ has a more distinct peak in its histogram, meaning higher contrast between background and features.


In [10]:

def histColPlot(images: np.ndarray, hist_args: dict):
    """Plot every image (top row) above the histogram of its pixel values.

    Parameters
    ----------
    images : stack (or other sequence) of 2-D grayscale images.
    hist_args : keyword arguments forwarded to `Axes.hist`,
        e.g. ``dict(bins=255, log=True)``.

    Returns
    -------
    (fig, axs) : the figure and its 2 x len(images) array of axes;
        ``axs[0]`` holds the image axes and ``axs[1]`` the histogram axes.
    """
    # squeeze=False keeps `axs` two-dimensional, so a single image also works.
    fig, axs = plt.subplots(2, len(images), squeeze=False)

    for number, (im, ax_image, ax_hist) in enumerate(zip(images, axs[0], axs[1]), start=1):
        ax_image.set_title(f"Image {number}")
        ax_image.imshow(im, cmap="gray")
        ax_image.axis("off")

        ax_hist.set_title(f"Image {number} hist")
        ax_hist.hist(im.flatten(), **hist_args)

    fig.tight_layout()

    return fig, axs

In [11]:
# Red channel next to its log-scaled histograms.
histColPlot(R, hist_args={"bins": 255, "log": True})

(<Figure size 640x480 with 8 Axes>,
 array([[<AxesSubplot:title={'center':'Image 1'}>,
         <AxesSubplot:title={'center':'Image 2'}>,
         <AxesSubplot:title={'center':'Image 3'}>,
         <AxesSubplot:title={'center':'Image 4'}>],
        [<AxesSubplot:title={'center':'Image 1 hist'}>,
         <AxesSubplot:title={'center':'Image 2 hist'}>,
         <AxesSubplot:title={'center':'Image 3 hist'}>,
         <AxesSubplot:title={'center':'Image 4 hist'}>]], dtype=object))

In [12]:
# R-B difference image next to its log-scaled histograms.
histColPlot(diffRB, hist_args={"bins": 255, "log": True})

(<Figure size 640x480 with 8 Axes>,
 array([[<AxesSubplot:title={'center':'Image 1'}>,
         <AxesSubplot:title={'center':'Image 2'}>,
         <AxesSubplot:title={'center':'Image 3'}>,
         <AxesSubplot:title={'center':'Image 4'}>],
        [<AxesSubplot:title={'center':'Image 1 hist'}>,
         <AxesSubplot:title={'center':'Image 2 hist'}>,
         <AxesSubplot:title={'center':'Image 3 hist'}>,
         <AxesSubplot:title={'center':'Image 4 hist'}>]], dtype=object))

### Determine $I_f$

$I_f$ is the grayscale image we'll use for segmentation. We obtain it by linearly stretching the $I_R-I_B$ image so that $\max(I_f)=255$ and $\min(I_f)=0$.

In [13]:
useChannel = diffRB

# Linear stretch to the 0..255 range. The extrema are taken over the whole
# stack (ALL images), so every image gets the same mapping and, for example,
# image 1 is not stretched more than image 2.
minval = np.min(useChannel)
maxval = np.max(useChannel)
If = ((useChannel - minval) * 255.0 / (maxval - minval)).astype(np.uint8)

histColPlot(If, hist_args={"bins": 255, "log": True})
plt.suptitle("$I_f$ and Histograms")

Text(0.5, 0.98, '$I_f$ and Histograms')

## Part 2 - Thresholding

We apply two thresholding techniques to segment the nuclei of cells using $I_f$.

1. "Handmade" threshold: from the histograms, we choose $T$ such that it isolates the background
2. Otsu's technique: we use the OpenCV implementation of Otsu's thresholding to determine $T$

We then extract the nuclei count and areas using `cv2.connectedComponents`, for result analysis.


In [14]:
# I_f next to its log-scaled histograms, used to pick the threshold.
histColPlot(If, hist_args={"bins": 255, "log": True})

(<Figure size 640x480 with 8 Axes>,
 array([[<AxesSubplot:title={'center':'Image 1'}>,
         <AxesSubplot:title={'center':'Image 2'}>,
         <AxesSubplot:title={'center':'Image 3'}>,
         <AxesSubplot:title={'center':'Image 4'}>],
        [<AxesSubplot:title={'center':'Image 1 hist'}>,
         <AxesSubplot:title={'center':'Image 2 hist'}>,
         <AxesSubplot:title={'center':'Image 3 hist'}>,
         <AxesSubplot:title={'center':'Image 4 hist'}>]], dtype=object))

### Threshold: Manual

In [15]:
# Manual threshold on I_f, chosen by eye as a value that looks good.
T = 35

In [16]:
fig, axs = histColPlot(If, hist_args={"bins": 255, "log": True})

# Mark the manual threshold on every histogram (second row of axes).
for hist_ax in axs[1]:
    hist_ax.axvline(T, color="r", label=f"{T=}")
    hist_ax.legend()
plt.suptitle("$I_f$ histograms and chosen Threshold")

Text(0.5, 0.98, '$I_f$ histograms and chosen Threshold')

In [17]:
# Binary mask: 255 where I_f exceeds the manual threshold T, 0 elsewhere.
threshManual = np.where(If > T, 255, 0).astype(np.uint8)

multiRowPlot(
    chain(If, threshManual),
    chain(imageTitles("$I_f{{ {i} }}$"), imageTitles("$T_{i}$")),
    nrows=2,
    ncols=N,
    cmap="gray",
)
plt.show()

### Threshold: Otsu's

In [74]:
threshOtsu = []
threshOtsuVals = []
for If_i in If:
    # `ret` is the threshold value reported by cv2.threshold for this image;
    # it is kept so it can be shown in the plot titles.
    ret, otsuThresh_i = cv2.threshold(If_i, 127, 255, cv2.THRESH_OTSU | cv2.THRESH_BINARY)
    threshOtsu.append(otsuThresh_i)
    threshOtsuVals.append(ret)

threshOtsu=np.stack(threshOtsu, axis=0)

# Titles for the manual row show the global manual threshold T ...
manualTitles = [title + f",{T=}" for title in imageTitles("$T_{i} Manual$")]
# ... and titles for the Otsu row show each image's own Otsu threshold.
# A separate loop name avoids shadowing the global T.
otsuTitles = [
    title + f", T={otsuT!r}"
    for otsuT, title in zip(threshOtsuVals, imageTitles("$T_{i}$ Otsu"))
]

imPlots = chain(If, threshManual, threshOtsu)
imTitles = chain(imageTitles("$I_f{{ {i} }}$"), manualTitles, otsuTitles)

multiRowPlot(imPlots, imTitles, nrows=3, ncols=N, cmap="gray")

<itertools.chain object at 0x7f1d4f90ef70>


### Evaluation

`threshManual` looks more consistent overall, with `threshOtsu` choosing a value that is slightly too high and leaves out part of some cells.

This makes sense because `Otsu` assumes a strong **bimodal distribution**, but the image has a very small foreground area. 

We can see this by calculating the percentage of pixels labeled as foreground by manual thresholding:

In [19]:
# Foreground = non-zero pixels of the manual mask, counted separately per image.
foreground_counts = np.count_nonzero(threshManual, axis=(1,2))
pixels_per_image = threshManual[0].size

print("Number of foreground pixels (by image):")
print(foreground_counts)
print("Percent of foreground pixels (by image):")
percents = np.round(foreground_counts / pixels_per_image * 100, 2)
print(", ".join(f"{p}%" for p in percents))

Number of foreground pixels (by image):
[ 5519  8306 13064  7621]
Percent of foreground pixels (by image):
0.38%, 0.57%, 0.89%, 0.52%


In [75]:
# Mask used by the following cells: the manual-threshold result.
thresh = threshManual

### Get connected areas

In [21]:
conn = 8
areas = []
for image in thresh:
    n_labels, labels = cv2.connectedComponents(image, None, conn)

    # Count the pixels of every label in one pass over the label image,
    # instead of building one full-size boolean mask per component.
    area = np.bincount(labels.ravel(), minlength=n_labels)

    # First item is background label
    area = area[1:]

    areas.append(area)

* There are a lot of areas with 1 or 2 pixels. 
    These are likely noise.

    We'll calculate stats including them first, but later we'll try filtering out the smallest areas (those not larger than `min_area` pixels).

In [22]:
# Dump the component sizes (in pixels) found in each image.
for number, area in enumerate(areas, start=1):
    print(f"Image {number} areas:")
    print(area)

Image 1 areas:
[568   1 674 641   1   5  58   1   2 451   1   1   1   2  27   3  26   8
   1   1  24   1   1   1   1   8   1   1 710   1   1   1   1   1 346   1
   1 321 990   1   1   1   1  29 516   1   1   1   1  49   1   1   1   1
   1  25   1   1]
Image 2 areas:
[ 612    1    1   46    9    1   10    8    1    6    1  601    1    1
  447  515    1  642    1    1  704  678  820  639    1  532 1247  772
    1    1    1    2    1    1]
Image 3 areas:
[ 963 1821 1216  761    1   20    1    1   42    2    1    2    3    8
   41    1    2    2    3    2    1    1   20    2    1   26    1    3
    1    8    2    7    3    1    1    1    1    1    1    1    1    2
   18    1    1    1   56    7    2    1    1    1    1  779    1  736
  690    1  554  591    1  737    1  552  679  634  786    1  593  658
    1]
Image 4 areas:
[  1   1   1 619 504   1 980   1 690   1 663   1 886 636  15   1   1  17
   1   2   1   1  78   1   1   4   1   1   1   1   1   1 501 799 475 732]


### Stats

**Original** 
* We see there are a lot of 1-pixel areas (over 25% for some images, over 50% for image 1) .
* Maybe because of this, the mean area for the first image is much smaller
* The biggest nucleus is in image 3, which also has the highest mean size. But the median size of cells is higher on image4

**With minimum area**
* Now the mean size of images is more similar, with images 3 and 4 having the bigger cells
* Standard deviation is very high, so cells are varying a lot in size.
* We see this on the quartiles, with 25% of cells being smaller than 50 pixels on all images but image 4

**Comparison with manual count**
* Manual count is more similar to the minimum area values. 
* But either way these counts are different from the manual count, mainly because "close-together" nuclei are counted as only one

In [23]:
def print_stats(areas):
    """Print one row of `describe()` statistics per image.

    `areas` is a sequence with one array of component sizes per image; the
    rows of the printed table are labelled ``image1``, ``image2``, ... and
    the values are rounded to two decimals.
    """
    summary_rows = []
    for number, area in enumerate(areas, start=1):
        frame = pd.DataFrame({f"image{number}": area})
        summary_rows.append(frame.describe().T.round(2))
    print(pd.concat(summary_rows))

def get_areas_and_print_stats(images):
    """Label each binary image and print size statistics of its components.

    Every image in `images` is labelled with `cv2.connectedComponents`
    using the notebook-level connectivity `conn`. The pixel count of each
    component except the first (background) label is collected per image,
    and the resulting list is handed to `print_stats`.
    """
    areas = []
    for image in images:
        n_labels, labels = cv2.connectedComponents(image, None, conn)
        # Count the pixels of every label in one pass over the label image,
        # instead of building one full-size boolean mask per component.
        area = np.bincount(labels.ravel(), minlength=n_labels)
        # First item is background label
        areas.append(area[1:])

    print_stats(areas)


In [24]:
print("====== ALL")
print_stats(areas)

# Same statistics again, keeping only components larger than min_area pixels.
min_area=5
print(f"====== Areas > {min_area=}")
areas_above_min = [a[a>min_area] for a in areas]
print_stats(areas_above_min)

        count    mean     std  min  25%  50%     75%     max
image1   58.0   95.16  223.55  1.0  1.0  1.0   24.75   990.0
image2   34.0  244.29  350.60  1.0  1.0  4.0  583.75  1247.0
image3   71.0  184.00  364.50  1.0  1.0  2.0   41.50  1821.0
image4   36.0  211.69  327.19  1.0  1.0  1.0  501.75   980.0
        count    mean     std   min     25%    50%    75%     max
image1   18.0  303.94  317.24   8.0   26.25  189.5  555.0   990.0
image2   17.0  487.53  357.49   6.0   46.00  601.0  678.0  1247.0
image3   27.0  481.59  457.49   7.0   23.00  591.0  736.5  1821.0
image4   14.0  542.50  308.83  15.0  481.50  627.5  721.5   980.0


## Watershed

### Gradient
(need to study this)

In [26]:
# Output depth requested from cv2.Sobel (32-bit float).
ddepth = cv2.CV_32F

# First-order Sobel derivatives along x and y for every image ...
sobel_pairs = (
    (cv2.Sobel(If_i, ddepth, 1, 0), cv2.Sobel(If_i, ddepth, 0, 1))
    for If_i in If
)
# ... combined into the gradient magnitude sqrt(dx^2 + dy^2).
gradients = np.stack([np.sqrt(dx**2+dy**2) for dx, dy in sobel_pairs], axis=0)

multiRowPlot(
    chain(If, gradients),
    chain(imageTitles("$ I_{{ f{i} }} $"), imageTitles("$ | \\Delta I_{{ f{i} }} | $")),
    nrows=2,
    ncols=N,
)

### Markers

In [27]:
# Foreground markers: morphological closing of the mask followed by an
# erosion, both with a 3x3 kernel.
kernel = np.ones((3,3))
markers = np.stack(
    [cv2.erode(cv2.morphologyEx(m, cv2.MORPH_CLOSE, kernel), kernel) for m in thresh],
    axis=0,
)

# Background: bitwise complement of the mask dilated three times with a 5x5 kernel.
kernel = np.ones((5,5))
bgs = np.stack(
    [~cv2.dilate(m, kernel, iterations=3) for m in thresh],
    axis=0,
)

multiRowPlot(
    chain(thresh, markers, bgs),
    chain(imageTitles("thresh"), imageTitles("marker"), imageTitles("background")),
    nrows=3,
    ncols=N,
)

In [28]:
# Rescale the gradient magnitudes so the largest value across the whole stack becomes 255.
gradients = gradients*255/gradients.max()

In [29]:
# `watershed` (skimage.segmentation) is imported in the imports cell at the top.
conn = 8

watersheds=[]
for m, bg, grad in zip(markers, bgs, gradients):
    n_labels, labels = cv2.connectedComponents(m, None, conn)

    # Component labels run from 0 to n_labels-1, so n_labels is the first free
    # id; use it as the marker label for the background region.
    bg_label = n_labels
    labels[bg>0] = bg_label

    out = watershed(grad, labels.astype(np.int32))
    # Clear exactly the background basin. Zeroing `out.max()` instead would
    # wipe the last nucleus whenever the background mask happens to be empty.
    out[out == bg_label] = 0
    watersheds.append(out)
watersheds=np.stack(watersheds,axis=0)


In [30]:
# Rows: input image, foreground markers, watershed labels.
watershed_titles = chain(
    imageTitles("$ I_{{ f{i} }} $"),
    imageTitles("$marker_{{ B{i} }}$"),
    imageTitles("$watershed_{{ {i} }}$"),
)
multiRowPlot(chain(If, markers, watersheds), watershed_titles, nrows=3, ncols=N)

In [31]:
print("====== ALL")
print_stats(areas)

# Same statistics again, keeping only components larger than min_area pixels.
min_area=5
print(f"====== Areas > {min_area=}")
areas_above_min = [a[a>min_area] for a in areas]
print_stats(areas_above_min)

# Statistics of the watershed result: every labelled pixel becomes foreground
# (255) in a binary mask that is labelled again.
print("====== watershed")
watershed_masks = 255*(watersheds>0).astype(np.uint8)
get_areas_and_print_stats(watershed_masks)

        count    mean     std  min  25%  50%     75%     max
image1   58.0   95.16  223.55  1.0  1.0  1.0   24.75   990.0
image2   34.0  244.29  350.60  1.0  1.0  4.0  583.75  1247.0
image3   71.0  184.00  364.50  1.0  1.0  2.0   41.50  1821.0
image4   36.0  211.69  327.19  1.0  1.0  1.0  501.75   980.0
        count    mean     std   min     25%    50%    75%     max
image1   18.0  303.94  317.24   8.0   26.25  189.5  555.0   990.0
image2   17.0  487.53  357.49   6.0   46.00  601.0  678.0  1247.0
image3   27.0  481.59  457.49   7.0   23.00  591.0  736.5  1821.0
image4   14.0  542.50  308.83  15.0  481.50  627.5  721.5   980.0
        count    mean     std    min     25%    50%     75%     max
image1   16.0  332.38  291.47    4.0   40.75  324.0  580.25   933.0
image2   14.0  538.29  267.64    5.0  518.00  575.5  594.00  1113.0
image3   21.0  621.81  370.61    9.0  499.00  651.0  726.00  1680.0
image4   12.0  651.75  146.55  409.0  565.75  645.0  702.75   964.0


In [82]:
# The structuring element is the same for every image, so build it once.
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5))

filtered=[]
for If_i in If:
    # Manual threshold -> binary 0/255 mask.
    f = (255*(If_i > T)).astype(np.uint8)

    # Morphological closing followed by opening with the elliptical kernel.
    f = cv2.morphologyEx(f, cv2.MORPH_CLOSE, kernel)
    f = cv2.morphologyEx(f, cv2.MORPH_OPEN, kernel)
    filtered.append(f)

multiRowPlot(
    chain(If, thresh, filtered),
    chain(imageTitles("$ I_{{ f{i} }} $"), imageTitles("$ Thresh I_{{ f{i} }} $"), imageTitles("$ Filter I_{{ f{i} }} $")),
    3,
    N
)

In [88]:
# k-means settings are identical for every image, so define them once.
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
flags = cv2.KMEANS_RANDOM_CENTERS

kmeans=[]
for If_i in If:
    # One sample per pixel with a single feature (its intensity): an
    # (n_pixels, 1) float32 matrix.
    samples = If_i.reshape(-1, 1).astype(np.float32)

    # Apply KMeans with 2 clusters and 10 attempts
    compactness, labels, centers = cv2.kmeans(samples, 2, None, criteria, 10, flags)

    # Cluster ids are arbitrary because the centers are initialised randomly,
    # so "cluster 1" is not always the bright one. Take the cluster with the
    # larger center as foreground so the mask is never inverted.
    foreground = int(np.argmax(centers))
    km = (labels.reshape(If_i.shape) == foreground).astype(np.uint8)

    kmeans.append(km)

In [87]:
# Rows: input image, k-means mask, manual-threshold mask.
kmeans_titles = chain(
    imageTitles("$ I_{{ f{i} }} $"),
    imageTitles("$ KMeans ~I_{{ f{i} }} $"),
    imageTitles("$ Thresh~I_{{ f{i} }} $"),
)
multiRowPlot(chain(If, kmeans, thresh), kmeans_titles, nrows=3, ncols=N)
