# Segmentation of cell nuclei

## Setup

### Libraries

In [1]:
import cv2 
import pandas as pd
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from itertools import chain

In [2]:
# Optional: select the Qt backend so figures open in separate, zoomable windows.
# Comment this line out to keep the notebook's default backend.
%matplotlib qt

### Load data

Data comes from the git repo through [git lfs](https://git-lfs.github.com/)

In [3]:
images = []

# Sort the paths: glob order is filesystem dependent, and the image numbering
# used throughout the notebook should be reproducible.
for imagePath in sorted(Path("./data").glob("*.png")):
    image = cv2.imread(str(imagePath))
    # cv2.imread does not raise on failure, it returns None
    # (e.g. for a git-lfs pointer file that was never pulled).
    if image is not None and image.size > 0:
        images.append(image)
    else:
        print(f"Failed reading image {imagePath}")

if not images:
    raise FileNotFoundError("No readable PNG images found in ./data (did you run git lfs pull?)")

# Stack into one array of shape (n_images, height, width, channels);
# this requires every image to have the same size.
images = np.stack(images)
N = len(images)
print(images.shape)

(4, 1040, 1408, 3)


In [4]:
def colPlot(images, **kwargs):
    """Show a sequence of images side by side in a single row.

    Parameters
    ----------
    images : sequence of arrays
        Images to display (must support ``len``). Arrays with more than two
        dimensions are treated as OpenCV BGR colour images and are converted
        to RGB for display.
    **kwargs
        Forwarded to ``Axes.imshow`` for 2-D (single channel) images only.
    """
    # squeeze=False keeps `axs` two-dimensional even for a single image,
    # so the zip below also works when len(images) == 1.
    fig, axs = plt.subplots(1, len(images), squeeze=False)

    for im, ax in zip(images, axs[0]):
        if im.ndim > 2: # if color
            ax.imshow(im[...,::-1]) # opencv is BGR
        else:
            ax.imshow(im, **kwargs)

    fig.tight_layout()

def multiRowPlot(images, titles, nrows, ncols, **kwargs):
    """Plot images on an ``nrows`` x ``ncols`` grid, filled row by row.

    Parameters
    ----------
    images : iterable of arrays
        Images to show. Arrays with more than two dimensions are treated as
        OpenCV BGR colour images and converted to RGB, as in ``colPlot``.
    titles : iterable of str
        One title per image, in the same order as ``images``.
    nrows, ncols : int
        Grid shape. Axes in the same column share their x and y limits.
    **kwargs
        Forwarded to ``Axes.imshow`` for 2-D (single channel) images.
    """
    # squeeze=False guarantees a 2-D array of axes, so flatten() also works
    # for a 1x1 grid (where subplots would otherwise return a bare Axes).
    fig, axs = plt.subplots(nrows, ncols, sharey="col", sharex="col", squeeze=False)

    for im, title, ax in zip(images, titles, axs.flatten()):
        ax.set_title(title)
        if im.ndim > 2:
            ax.imshow(im[..., ::-1])  # opencv is BGR
        else:
            ax.imshow(im, **kwargs)
        ax.axis("off")

    fig.tight_layout()

def imageTitles(pattern, n=None):
    """Build one title per image by formatting ``pattern`` with a 1-based index.

    Parameters
    ----------
    pattern : str
        ``str.format`` pattern; ``{i}`` is replaced by the image number.
        Literal braces (e.g. for LaTeX) must be doubled: ``"$I_{{ {i} }}$"``.
    n : int, optional
        Number of titles to generate. Defaults to the notebook-level image
        count ``N``.

    Returns
    -------
    list of str
        Titles for images 1..n.
    """
    if n is None:
        n = N
    return [pattern.format(i=i) for i in range(1, n+1)]

## Part 1

First we analyse the channels of the image and pick the best way to "grayscale" it.

The red channel highlights the cell nuclei, while the G and B channels (which are equivalent) hold the grayscale, monochromatic image from the microscope.

### Plot all channels (RGB)

In [5]:
# colPlot already converts OpenCV's BGR order to RGB, so pass the images as-is.
# (Slicing with images[:,:,::-1] reverses the column axis, i.e. mirrors each image.)
colPlot(images)

### Plot red channel and R-B for comparison

* $I_R-I_B$ clearly shows the nuclei with high contrast and no unwanted features.
* $I_R$ shows more detail for other parts of the cell, but that detail introduces unwanted features that do not contrast strongly with the nuclei.

**Use $I_f=I_R-I_B$ for segmentation**

In [6]:
# the R channel (OpenCV stores the channels in B, G, R order)
R = images[...,2]

# the R-B difference image
# Subtract in a wider signed type and clip to [0, 255]: a plain uint8
# subtraction wraps around (e.g. 10 - 20 == 246) wherever B is brighter than R,
# which would turn dark pixels into spuriously bright ones.
diffRB = np.clip(
    images[...,2].astype(np.int16) - images[...,0].astype(np.int16), 0, 255
).astype(np.uint8)

# For plotting: first row the R channel, second row the difference image
imPlots = chain(R, diffRB)
imTitles = chain(imageTitles("$I_{{ {i}R }}$"),
        imageTitles("$I_{{ {i}R }} - I_{{ {i}B }}$"))

multiRowPlot(imPlots, imTitles, nrows=2, ncols=N, cmap="gray")

In [7]:
# One figure per image: blue channel on the left, R-B difference on the right.
# The two panels share their axes, so zooming one zooms the other.
for bgrImage, diffImage in zip(images, diffRB):
    fig, (axBlue, axDiff) = plt.subplots(1, 2, sharex=True, sharey=True)
    axBlue.imshow(bgrImage[..., 0], cmap="gray")
    axDiff.imshow(diffImage)

### Histograms

We can also see in a log-histogram that $I_R-I_B$ has a more distinct peak in its histogram, meaning higher contrast between background and features.


In [8]:

def histColPlot(images: np.ndarray, hist_args: dict = None):
    """Plot each image (top row) above the histogram of its pixel values (bottom row).

    Parameters
    ----------
    images : np.ndarray
        Stack of single-channel images, iterated along the first axis.
    hist_args : dict, optional
        Keyword arguments forwarded to ``Axes.hist`` (e.g. ``bins``, ``log``).
        Defaults to no extra arguments.

    Returns
    -------
    fig : matplotlib Figure
    axs : array of Axes with shape ``(2, len(images))``
        ``axs[0]`` holds the image axes, ``axs[1]`` the histogram axes.
    """
    if hist_args is None:
        hist_args = {}

    # squeeze=False keeps `axs` two-dimensional, so axs[0] / axs[1] are rows
    # of axes even when there is only one image.
    fig, axs = plt.subplots(2, len(images), squeeze=False)
    for i, (im, ax) in enumerate(zip(images, axs[0])):
        ax.set_title(f"Image {i+1}")
        ax.imshow(im, cmap="gray")
        ax.axis("off")

    for i, (im, ax) in enumerate(zip(images, axs[1])):
        ax.set_title(f"Image {i+1} hist")
        # ravel() gives the same 1-D view of the pixels without forcing a copy
        ax.hist(im.ravel(), **hist_args)

    fig.tight_layout()

    return fig, axs

In [9]:
# Red channel: images on top, log-scaled histograms underneath
histColPlot(R, hist_args={"bins": 255, "log": True})

(<Figure size 640x480 with 8 Axes>,
 array([[<AxesSubplot:title={'center':'Image 1'}>,
         <AxesSubplot:title={'center':'Image 2'}>,
         <AxesSubplot:title={'center':'Image 3'}>,
         <AxesSubplot:title={'center':'Image 4'}>],
        [<AxesSubplot:title={'center':'Image 1 hist'}>,
         <AxesSubplot:title={'center':'Image 2 hist'}>,
         <AxesSubplot:title={'center':'Image 3 hist'}>,
         <AxesSubplot:title={'center':'Image 4 hist'}>]], dtype=object))

In [10]:
# R-B difference: images on top, log-scaled histograms underneath
histColPlot(diffRB, hist_args={"bins": 255, "log": True})

(<Figure size 640x480 with 8 Axes>,
 array([[<AxesSubplot:title={'center':'Image 1'}>,
         <AxesSubplot:title={'center':'Image 2'}>,
         <AxesSubplot:title={'center':'Image 3'}>,
         <AxesSubplot:title={'center':'Image 4'}>],
        [<AxesSubplot:title={'center':'Image 1 hist'}>,
         <AxesSubplot:title={'center':'Image 2 hist'}>,
         <AxesSubplot:title={'center':'Image 3 hist'}>,
         <AxesSubplot:title={'center':'Image 4 hist'}>]], dtype=object))

### Determine $I_f$

$I_f$ is the grayscale image we'll use for segmentation. We do a linear stretching of the $I_R-I_B$ image so that $max(I_R-I_B)=255$ and $min(I_R-I_B)=0$

In [11]:
useChannel = diffRB

# min and max are calculated over ALL images at once.
# This means e.g. we don't stretch image 1 more than image 2
minval = np.min(useChannel)
maxval = np.max(useChannel)
if maxval == minval:
    # A constant channel cannot be stretched (it would divide by zero below)
    raise ValueError("Cannot stretch the channel: all pixels have the same value")

# Linear stretch to [0, 255]; the multiplication by 255.0 promotes to float
If = (useChannel - minval) * 255.0 / (maxval - minval)

# Back to 8-bit (astype truncates the fractional part)
If = If.astype(np.uint8)

histColPlot(If, hist_args=dict(bins=255, log=True))
plt.suptitle("$I_f$ and Histograms")

Text(0.5, 0.98, '$I_f$ and Histograms')

## Part 2 - Thresholding

We apply two thresholding techniques to segment the cell nuclei using $I_f$.

1. "handmade" threshold: from the histograms, we choose $T$ such that it isolates the background
2. Otsu's technique: we use the opencv implementation of Otsu's thresholding to determine $T$

We then extract the nuclei count and areas using `cv2.connectedComponents`, for result analysis.


In [12]:
# I_f images with their log-scaled histograms, used to pick a threshold
histColPlot(If, hist_args={"bins": 255, "log": True})

(<Figure size 640x480 with 8 Axes>,
 array([[<AxesSubplot:title={'center':'Image 1'}>,
         <AxesSubplot:title={'center':'Image 2'}>,
         <AxesSubplot:title={'center':'Image 3'}>,
         <AxesSubplot:title={'center':'Image 4'}>],
        [<AxesSubplot:title={'center':'Image 1 hist'}>,
         <AxesSubplot:title={'center':'Image 2 hist'}>,
         <AxesSubplot:title={'center':'Image 3 hist'}>,
         <AxesSubplot:title={'center':'Image 4 hist'}>]], dtype=object))

### Threshold: Manual

In [13]:
# Manual threshold, picked by eye from the histograms above
T = 40
fig, axs = histColPlot(If, hist_args={"bins": 255, "log": False})

# Mark the chosen threshold on every histogram (bottom row of axes)
for histAx in axs[1]:
    histAx.axvline(T, color="r", label=f"{T=}")
    histAx.legend()
plt.suptitle("$I_f$ histograms and chosen Threshold")

Text(0.5, 0.98, '$I_f$ histograms and chosen Threshold')

In [14]:
# Binary masks: 255 where I_f is above the manual threshold, 0 elsewhere
threshManual = np.where(If > T, 255, 0).astype(np.uint8)

# First row: I_f, second row: the thresholded masks
multiRowPlot(
    chain(If, threshManual),
    chain(imageTitles("$I_f{{ {i} }}$"), imageTitles("$T_{i}$")),
    nrows=2,
    ncols=N,
    cmap="gray",
)

### Threshold: Otsu's

In [15]:
threshOtsu = []
threshOtsuVals = []
for If_i in If:
    # With THRESH_OTSU the threshold is computed by OpenCV and returned as `ret`;
    # the 127 passed here is not used.
    ret, otsuThresh_i =  cv2.threshold(If_i, 127, 255, cv2.THRESH_OTSU | cv2.THRESH_BINARY)
    threshOtsu.append(otsuThresh_i)
    threshOtsuVals.append(ret)

threshOtsu=np.stack(threshOtsu, axis=0)

# Rows: I_f (the thresholding input, matching its title), manual mask, Otsu mask
imPlots = chain(If, threshManual, threshOtsu)
imTitles = chain(imageTitles("$I_f{{ {i} }}$"),
    (
      title + f", {T=}"
      for title in imageTitles("$T_{i}$ Manual")
    ),
    # Use a distinct name for the per-image Otsu value so it does not shadow
    # the manual threshold T; the rendered text is still "T=<value>".
    ( title + f", T={otsuT}"
      for otsuT, title
      in zip(threshOtsuVals,imageTitles("$T_{i}$ Otsu"))
    ))

multiRowPlot(imPlots, imTitles, nrows=3, ncols=N, cmap="gray")

<itertools.chain object at 0x7fba21af75e0>


### Evaluation

`threshManual` looks more consistent overall, with `threshOtsu` choosing a value that is slightly too high and leaving out part of some cells.

This makes sense because `Otsu` assumes a strong **bimodal distribution**, but the image has a very small foreground area. 

We can see this by calculating the percentage of pixels labeled as foreground by manual thresholding:

In [16]:
# Foreground pixel count of each manual mask, reduced over the two image axes
foregroundCounts = np.count_nonzero(threshManual, axis=(1,2))

print("Number of foreground pixels (by image):")
print(foregroundCounts)

# Same counts as a percentage of the pixels in one image
print("Percent of foreground pixels (by image):")
percents = np.round(foregroundCounts / threshManual[0].size * 100, 2)
print(", ".join(f"{p}%" for p in percents))

Number of foreground pixels (by image):
[ 5060  7986 12246  7126]
Percent of foreground pixels (by image):
0.35%, 0.55%, 0.84%, 0.49%


In [17]:
# Masks used for the rest of the analysis: we picked the manual threshold
# (threshManual) over Otsu's.
thresh = threshManual

### Get connected areas

In [18]:
conn = 8
areas = []
for image in thresh:
    n_labels, labels =  cv2.connectedComponents(image, None, conn)

    # Count the number of pixels in each connected component in one pass.
    # bincount tallies how often each label occurs, which avoids building one
    # full-size boolean mask per label.
    area = np.bincount(labels.ravel(), minlength=n_labels)

    # First item is background label
    area = area[1:]

    areas.append(area)

* There are a lot of areas with 1 or 2 pixels. 
    These are likely noise.

    We'll calculate stats including them first, but later we'll try filtering out areas that equal 1 or 2

In [19]:
# Dump the raw component areas (in pixels) of every image
for imageNumber, imageAreas in enumerate(areas, start=1):
    print(f"Image {imageNumber} areas:")
    print(imageAreas)

Image 1 areas:
[497   1   1 647 613   1   4  52   1 427  26   1  17   1   4   1   1   1
   4   2   2   1   1   5   1   1 679   1   1   1   1   1 296 265 964   1
  19 463   1   1  38   9   1   1   4]
Image 2 areas:
[ 587    1    1   44    9    1    7    1    7    6    1  583    1    1
  373    3  493    1  623    1    1  694  663  802  610    1  487    1
 1222  760    1]
Image 3 areas:
[ 937 1768 1175  718   15    1    1    1   10    1    1   22    1    1
   21    1    1    3    5    1    2    7   36    1    7    1    1    1
    1  732    1  689  659    1  514  573    1  708    1  531  655  575
  735    1  498    1  630]
Image 4 areas:
[  1   1   1 481 598   1 945 669 583   1 860 610  11   3  59   1 440 776
 379   1 705]


### Stats

**Original** 
* We see there are a lot of 1-pixel areas (over 25% for some images, over 50% for image 1).
* Maybe because of this, the mean area for the first image is much smaller
* The biggest nucleus is in image 3, which also has the highest mean size. But the median size of cells is higher in image 4.

**With minimum area**
* Now the mean size of images is more similar, with images 3 and 4 having the bigger cells
* Standard deviation is very high, so cells are varying a lot in size.
* We see this on the quartiles, with 20% of cells being smaller than 50 pixels on all images but image 4

**Comparison with manual count**
* Manual count is more similar to the minimum area values. 
* But either way these counts are different from the manual count, mainly because "close-together" nuclei are counted as only one.

In [20]:
def print_stats(areas):
    """Print a table of summary statistics, one row per image.

    ``areas`` is a sequence with one array of component areas per image.
    Each row is labelled ``image<k>`` (1-based) and holds the output of
    ``DataFrame.describe`` for that image, rounded to two decimals.
    """
    summaries = []
    for number, area in enumerate(areas, start=1):
        frame = pd.DataFrame({f"image{number}": area})
        # describe() has the statistics as rows; transpose so they become columns
        summaries.append(frame.describe().T.round(2))
    print(pd.concat(summaries))

def get_areas_and_print_stats(images, connectivity=None):
    """Measure the connected-component areas of binary images and print their stats.

    Parameters
    ----------
    images : iterable of arrays
        Binary masks accepted by ``cv2.connectedComponents``.
    connectivity : int, optional
        Pixel connectivity passed to ``cv2.connectedComponents``. Defaults to
        the notebook-level ``conn`` setting.

    Returns
    -------
    list of np.ndarray
        For each image, the pixel count of every component, background excluded.
    """
    if connectivity is None:
        connectivity = conn

    areas = []
    for image in images:
        n_labels, labels =  cv2.connectedComponents(image, None, connectivity)
        # count number of pixels in each connected component in one pass
        # (no per-label boolean mask needed)
        area = np.bincount(labels.ravel(), minlength=n_labels)
        # First item is background label
        areas.append(area[1:])

    print_stats(areas)
    return areas


In [21]:
# Statistics over every connected component
print("====== ALL")
print_stats(areas)

# Same statistics, keeping only components strictly larger than min_area pixels
min_area = 5
print(f"====== Areas > {min_area=}")
areasFiltered = [a[a > min_area] for a in areas]
print_stats(areasFiltered)

        count    mean     std  min  25%    50%    75%     max
image1   45.0  112.44  234.66  1.0  1.0    1.0   26.0   964.0
image2   31.0  257.61  348.53  1.0  1.0    7.0  585.0  1222.0
image3   47.0  260.55  403.17  1.0  1.0    5.0  574.0  1768.0
image4   21.0  339.33  346.63  1.0  1.0  379.0  610.0   945.0
        count    mean     std   min    25%    50%    75%     max
image1   15.0  334.13  306.73   9.0   32.0  296.0  555.0   964.0
image2   17.0  468.82  350.44   6.0   44.0  583.0  663.0  1222.0
image3   23.0  531.09  435.93   7.0   29.0  575.0  713.0  1768.0
image4   13.0  547.38  277.79  11.0  440.0  598.0  705.0   945.0


## Watershed

### Gradient
(need to study this)

In [23]:
# Sobel output depth: 32-bit float, so negative derivatives are not clipped
ddepth = cv2.CV_32F

gradients=[]
for If_i in If:
    # First-order derivatives along x and y
    dx = cv2.Sobel(If_i, ddepth, 1, 0)
    dy = cv2.Sobel(If_i, ddepth, 0, 1)

    # Gradient magnitude
    gradients.append(np.sqrt(dx**2+dy**2))
gradients=np.stack(gradients,axis=0)

# The plotted quantity is the gradient magnitude, so label it with nabla
# (Delta would denote the Laplacian).
multiRowPlot(
    chain(If, gradients),
    chain(imageTitles("$ I_{{ f{i} }} $"), imageTitles("$ | \\nabla I_{{ f{i} }} | $")),
    2,
    N
)

### Markers

In [27]:
# 3x3 structuring element of ones
kernel = np.ones((3,3))

# Per mask: a morphological opening followed by one extra erosion,
# then stack the results back into a single array.
markers = np.stack(
    [
        cv2.erode(cv2.morphologyEx(Thresh_i, cv2.MORPH_OPEN, kernel), kernel)
        for Thresh_i in thresh
    ],
    axis=0,
)

# First row: threshold masks, second row: the derived markers
multiRowPlot(
    chain(thresh, markers),
    chain(imageTitles("thresh"), imageTitles("marker")),
    2,
    N,
)

In [43]:
# Rescale the gradient magnitudes so that the largest value over all images is 255
gradients = np.divide(gradients * 255, gradients.max())

In [56]:
watersheds=[]
conn = 8
for m, grad in zip(markers, gradients):
    # Give every marker blob its own positive integer label (0 = not a seed).
    n_labels, labels =  cv2.connectedComponents(m, None, conn)

    # cv2.watershed needs an 8-bit, 3-channel image: replicate the gradient
    im = np.stack([grad,grad,grad], axis=-1).astype(np.uint8)

    # Seed the watershed with the per-blob labels. Passing the raw 0/255 mask
    # would give every seed the same label, so nothing could be separated.
    # NOTE(review): no explicit background seed is provided, so presumably the
    # basins also grow over the background - consider adding one.
    out = cv2.watershed(im, labels)
    watersheds.append(out)


multiRowPlot(
    chain(If, watersheds),
    chain(imageTitles("$ I_{{ f{i} }} $"), imageTitles("watershed {i}")),
    2,
    N
)