# Segmentation of cell nuclei

## Setup

### Libraries

In [None]:
import cv2 
import pandas as pd
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from itertools import chain

In [None]:
# Optional: use the Qt backend so the graphs can be zoomed (I prefer it over %matplotlib widget)
%matplotlib qt

### Load data

Data comes from the git repo through [git lfs](https://git-lfs.github.com/)

In [None]:
# Read every PNG under ./data into a single stacked array (one entry per image).
images = []

# sorted() keeps the "image 1..N" numbering stable across runs and platforms;
# Path.glob makes no promise about ordering.
for imagePath in sorted(Path("./data").glob("*.png")):
    image = cv2.imread(str(imagePath))
    # cv2.imread reports failure by returning None rather than raising,
    # so check for None before touching any array attribute.
    if image is not None and image.size > 0:
        images.append(image)
    else:
        print(f"Failed reading image {imagePath}")

# np.stack requires all images to share the same shape
images = np.stack(images)
N = len(images)
print(images.shape)

In [None]:
def colPlot(images, **kwargs):
    """Show a sequence of images side by side in a single row.

    Images with more than two dimensions are treated as OpenCV BGR colour
    images and are flipped to RGB for display. ``kwargs`` are forwarded to
    ``imshow`` for single-channel images only (e.g. ``cmap="gray"``).
    """
    # squeeze=False keeps axs two-dimensional, so a single image does not
    # yield a bare (non-iterable) Axes object.
    fig, axs = plt.subplots(1, len(images), squeeze=False)

    for im, ax in zip(images, axs[0]):
        if im.ndim > 2:  # if color
            ax.imshow(im[..., ::-1])  # opencv is BGR
        else:
            ax.imshow(im, **kwargs)

    fig.tight_layout()

def multiRowPlot(images, titles, nrows, ncols, **kwargs):
    """Plot images on an ``nrows`` x ``ncols`` grid, filled row by row.

    ``images`` and ``titles`` are consumed in lockstep; plotting stops at the
    shortest of images, titles and grid cells. ``kwargs`` are forwarded to
    every ``imshow`` call. Axes in the same column share their x and y axes.
    """
    # squeeze=False guarantees a 2-D array of Axes even for a 1x1 grid
    fig, axs = plt.subplots(nrows, ncols, sharey="col", sharex="col", squeeze=False)

    for im, title, ax in zip(images, titles, axs.ravel()):
        ax.set_title(title)
        ax.imshow(im, **kwargs)
        ax.axis("off")

    fig.tight_layout()

def imageTitles(pattern, n=None):
    """Return one title per image, numbered from 1.

    ``pattern`` is a ``str.format`` template containing an ``{i}`` field
    (literal braces must be doubled). ``n`` is the number of titles to
    generate; when omitted, the notebook-level image count ``N`` is used.
    """
    count = N if n is None else n
    return [pattern.format(i=i) for i in range(1, count + 1)]

## Part 1

First we analyse the channels of the image and pick the best way to "grayscale" it.

The red channel is a highlighting of cell nuclei, and the G and B channels (equivalent) are the grayscale, monochromatic image from the microscope.

### Plot all channels (RGB)

In [None]:
# colPlot already flips BGR to RGB for colour images, so pass the stack as is.
# (images[:,:,::-1] reversed the width axis of the 4-D stack, mirroring each image.)
colPlot(images)

### Plot red channel and R-B for comparison

* $I_R-I_B$ clearly shows nuclei with high contrast and no unwanted features.
* $I_R$ shows more detail for other parts of the cell, but that detail introduces unwanted features that don't have a very high contrast with the nuclei.

**Use $I_f=I_R-I_B$ for segmentation**

In [None]:
# the R channel (OpenCV stores channels as BGR, so index 2 is red)
R = images[..., 2]

# the R-B difference image
# Subtracting unsigned 8-bit channels directly wraps around wherever B > R
# (e.g. 10 - 20 -> 246). Subtract in a wider signed type and clip to [0, 255]
# so such pixels end up at 0 instead of near-white.
diffRB = np.clip(
    images[..., 2].astype(np.int16) - images[..., 0].astype(np.int16), 0, 255
).astype(np.uint8)

# For plotting: first row is R, second row is R-B
imPlots = chain(R, diffRB)
imTitles = chain(imageTitles("$I_{{ {i}R }}$"),
        imageTitles("$I_{{ {i}R }} - I_{{ {i}B }}$"))

multiRowPlot(imPlots, imTitles, nrows=2, ncols=N, cmap="gray")

### Histograms

We can also see in a log-histogram that $I_R-I_B$ has a more distinct peak in its histogram, meaning higher contrast between background and features.


In [None]:

def histColPlot(images: np.ndarray, hist_args: dict):
    """Plot each image (top row) above the histogram of its pixels (bottom row).

    Args:
        images: sequence of single-channel images, one column per image.
        hist_args: keyword arguments forwarded to ``Axes.hist``
            (e.g. ``bins``, ``log``).

    Returns:
        The ``(fig, axs)`` pair from ``plt.subplots``; ``axs[0]`` holds the
        image axes and ``axs[1]`` the histogram axes.
    """
    # squeeze=False keeps axs two-dimensional (2, n) even for a single image,
    # so axs[0] / axs[1] are always iterable rows.
    fig, axs = plt.subplots(2, len(images), squeeze=False)

    for i, (im, imAx, histAx) in enumerate(zip(images, axs[0], axs[1]), start=1):
        imAx.set_title(f"Image {i}")
        imAx.imshow(im, cmap="gray")
        imAx.axis("off")

        histAx.set_title(f"Image {i} hist")
        histAx.hist(im.flatten(), **hist_args)

    fig.tight_layout()

    return fig, axs

In [None]:
# Log-scale histograms of the raw red channel
histColPlot(R, hist_args={"bins": 255, "log": True})

In [None]:
# Log-scale histograms of the R-B difference image
histColPlot(diffRB, hist_args={"bins": 255, "log": True})

### Determine $I_f$

$I_f$ is the grayscale image we'll use for segmentation. We do a linear stretching of the $I_R-I_B$ image so that $\max(I_R-I_B)=255$ and $\min(I_R-I_B)=0$.

In [None]:
# Source image for the linear stretch
useChannel = diffRB

# The extrema are taken over ALL images at once, so every image receives
# the same stretch (e.g. image 1 is not stretched more than image 2).
minval = np.min(useChannel)
maxval = np.max(useChannel)
valueRange = maxval - minval

# Map [minval, maxval] linearly onto [0, 255], then truncate to 8 bits
If = (useChannel - minval) * 255.0 / valueRange
If = If.astype(np.uint8)

histColPlot(If, hist_args={"bins": 255, "log": True})
plt.suptitle("$I_f$ and Histograms")

## Part 2 - Thresholding

We apply two thresholding techniques to segment the nuclei of cells using $I_f$.

1. "handmade" threshold: from the histograms, we choose $T$ such that it isolates the background
2. Otsu's technique: we use the opencv implementation of Otsu's thresholding to determine $T$

We then extract the nuclei count and areas using `cv2.connectedComponents`, for result analysis.


In [None]:
# Log-scale histograms of the stretched image, used to pick a threshold
histColPlot(If, hist_args={"bins": 255, "log": True})

### Threshold: Manual

In [None]:
# Feels like this is a good value:
# Hand-picked from the histograms above
T = 40
fig, axs = histColPlot(If, hist_args={"bins": 255, "log": False})

# Mark the chosen threshold on every histogram (second row of axes)
for histAx in axs[1]:
    histAx.axvline(T, color="r", label=f"{T=}")
    histAx.legend()

plt.suptitle("$I_f$ histograms and chosen Threshold")

In [None]:
# Binary mask: 255 where the pixel is strictly above T, 0 elsewhere
threshManual = np.where(If > T, 255, 0).astype(np.uint8)

# Row 1: stretched images, row 2: their manual-threshold masks
titlesIf = imageTitles("$I_f{{ {i} }}$")
titlesT = imageTitles("$T_{i}$")
imPlots = chain(If, threshManual)
imTitles = chain(titlesIf, titlesT)

multiRowPlot(imPlots, imTitles, nrows=2, ncols=N, cmap="gray")

### Threshold: Otsu's

In [None]:
# Otsu's threshold is computed independently for each image.
threshOtsu = []
threshOtsuVals = []
for If_i in If:
    # With THRESH_OTSU the passed threshold (127) is ignored; the value chosen
    # by Otsu's method is returned as the first element.
    ret, otsuThresh_i = cv2.threshold(If_i, 127, 255, cv2.THRESH_OTSU | cv2.THRESH_BINARY)
    threshOtsu.append(otsuThresh_i)
    threshOtsuVals.append(ret)

threshOtsu = np.stack(threshOtsu, axis=0)

# Row 1: the grayscale image I_f (not the raw BGR input, which does not match
# the "$I_f$" titles), row 2: manual threshold, row 3: Otsu threshold.
imPlots = chain(If, threshManual, threshOtsu)

# Materialised as a list so the print below shows the actual titles instead of
# the repr of an itertools.chain object.
imTitles = list(chain(
    imageTitles("$I_f{{ {i} }}$"),
    (title + f",{T=}" for title in imageTitles("$T_{i} Manual$")),
    # A separate name avoids shadowing the manual threshold T inside the generator
    (
        title + f", T={otsuT!r}"
        for otsuT, title in zip(threshOtsuVals, imageTitles("$T_{i}$ Otsu"))
    ),
))

print(imTitles)
multiRowPlot(imPlots, imTitles, nrows=3, ncols=N, cmap="gray")

### Evaluation

`threshManual` looks more consistent overall, with `threshOtsu` choosing a value that is slightly too high and leaves out part of some cells.

This makes sense because `Otsu` assumes a strong **bimodal distribution**, but the image has a very small foreground area. 

We can see this by calculating the percentage of pixels labeled as foreground by manual thresholding:

In [None]:
# Foreground = non-zero mask pixels, counted separately for each image
fgCounts = np.count_nonzero(threshManual, axis=(1, 2))
pixelsPerImage = threshManual[0].size

print("Number of foreground pixels (by image):")
print(fgCounts)

print("Percent of foreground pixels (by image):")
percents = np.round(fgCounts / pixelsPerImage * 100, 2)
print(", ".join(f"{p}%" for p in percents))

In [None]:
# Segmentation used for the rest of the analysis: the manual threshold masks
thresh = threshManual

### Get connected areas

In [None]:
# Pixel connectivity used for labelling (8 = diagonal neighbours count too)
conn = 8

# areas[k] holds the pixel count of every connected component found in image k
areas = []
for image in thresh:
    n_labels, labels = cv2.connectedComponents(image, connectivity=conn)

    # Count the pixels of each label in a single pass; bincount[i] is the
    # number of pixels carrying label i. This avoids building one full-size
    # boolean mask per component.
    area = np.bincount(labels.ravel(), minlength=n_labels)

    # First item is background label
    area = area[1:]

    areas.append(area)

* There are a lot of areas with only 1 or 2 pixels.
    These are likely noise.

    We'll calculate stats including them first, but later we'll try filtering out these very small areas.

In [None]:
# Dump the raw component areas, one block per image (numbered from 1)
for imageNumber, componentAreas in enumerate(areas, start=1):
    print(f"Image {imageNumber} areas:")
    print(componentAreas)

### Stats

**Without filtering** 
* We see there are a lot of 1-pixel areas (over 25% for some images, over 50% for image 1).
* Maybe because of this, the mean area for the first image is much smaller
* The biggest nucleus is in image 3, which also has the highest mean size. But the median size of cells is higher in image 4.

**With filtering**
* Now the mean size of images is more similar, with images 3 and 4 having the bigger cells
* Standard deviation is very high, so cells are varying a lot in size.
* We see this on the quartiles, with 20% of cells being smaller than 50 pixels on all images but image 4

In [None]:
def print_stats(areas):
    """Print a table of summary statistics with one row per image.

    ``areas`` is a sequence of 1-D collections of component areas; entry
    ``k`` is labelled ``image{k+1}``. The table shows the ``describe()``
    summary of each entry, rounded to two decimals.
    """
    summaries = []
    for number, area in enumerate(areas, start=1):
        frame = pd.DataFrame(data={f"image{number}": area})
        # describe() yields one column per image; transpose so it becomes a row
        summaries.append(frame.describe().T.round(2))

    print(pd.concat(summaries))

# Statistics over every connected component
print("====== ALL")
print_stats(areas)

# Statistics again, keeping only components strictly larger than min_area pixels
min_area = 5
filteredAreas = [a[a > min_area] for a in areas]
print(f"====== Areas > {min_area=}")
print_stats(filteredAreas)