# Diatom analysis

See https://www.nature.com/articles/s41524-019-0202-3:

**Deep data analytics for genetic engineering of diatoms linking genotype to phenotype via machine learning**, Artem A. Trofimov, Alison A. Pawlicki, Nikolay Borodinov, Shovon Mandal, Teresa J. Mathews, Mark Hildebrand, Maxim A. Ziatdinov, Katherine A. Hausladen, Paulina K. Urbanowicz, Chad A. Steed, Anton V. Ievlev, Alex Belianinov, Joshua K. Michener, Rama Vasudevan, and Olga S. Ovchinnikova.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Notebook-wide matplotlib defaults: bigger figures and a grayscale colormap
import matplotlib
matplotlib.rcParams['figure.figsize'] = (10, 10)
matplotlib.rcParams['image.cmap'] = 'gray'

In [None]:
from skimage import io
# Read the sample image from a path relative to the working directory.
image = io.imread('../data/diatom-wild-032.jpg')

# Show the full image; the trailing semicolon keeps the notebook from
# echoing the return value of imshow.
plt.imshow(image);

In [None]:
# Keep only the first 690 rows of the image (all columns).
# NOTE(review): presumably this trims a strip at the bottom that does not
# show pores (e.g. an annotation bar) -- confirm against the image.
pores = image[:690, :]

plt.imshow(pores);

In [None]:
from skimage import filters
from skimage import util

# Convert the crop to floating point, then median-filter it. No footprint is
# passed, so the library default neighbourhood is used; behavior='ndimage'
# selects the scipy.ndimage-style implementation.
denoised = filters.median(util.img_as_float(pores), behavior='ndimage')

In [None]:
# Element-wise square root of the denoised image.
# NOTE(review): assuming intensities lie in [0, 1], this brightens dark
# values more than bright ones -- confirm the input range.
pores_sqrt = np.sqrt(denoised)
plt.imshow(pores_sqrt);

In [None]:
# Invert the square-root image so that dark areas become bright
# (a true inversion only if the values lie in [0, 1]).
# NOTE(review): appears to be shown for inspection only; the thresholding
# further down works on pores_sqrt, not on pores_inv.
pores_inv = 1 - pores_sqrt
plt.imshow(pores_inv);

In [None]:
from skimage import filters

In [None]:
# Global threshold for the square-root image, chosen by Li's method.
T = filters.threshold_li(pores_sqrt)
# Boolean mask that is True where the image is at or below the threshold,
# i.e. the darker pixels are treated as foreground.
thresholded = (pores_sqrt <= T)

plt.imshow(thresholded);

In [None]:
from scipy import ndimage as ndi
from skimage import segmentation, morphology, color

In [None]:
# Euclidean distance from every foreground (True) pixel of the mask to the
# nearest background pixel; background pixels get 0.
distance = ndi.distance_transform_edt(thresholded)
# Mask marking the local maxima of the distance map; these are used as
# watershed seeds further down.
local_maxima = morphology.local_maxima(distance)

In [None]:
# (rows, cols) index arrays of every pixel flagged as a local maximum
maxi_coords = np.nonzero(local_maxima)

f, ax = plt.subplots(1, 1, figsize=(20, 20))
ax.imshow(pores)
# scatter expects x (columns) first, then y (rows)
ax.scatter(maxi_coords[1], maxi_coords[0]);

In [None]:
def shuffle_labels(labels):
    """Randomly permute the non-zero ids of a label image.

    Neighbouring regions usually carry consecutive ids and therefore nearly
    identical colours; shuffling the ids makes them easier to tell apart
    when the label image is displayed.

    Parameters
    ----------
    labels : ndarray of int
        Label image with non-negative ids, where 0 marks the background.

    Returns
    -------
    ndarray
        Array with the same shape as `labels`. Background pixels stay 0 and
        every other id is replaced by an id drawn from the same set, using
        a random one-to-one reassignment (``np.random.permutation``).
    """
    labels = np.asarray(labels)

    # Leave the background out of the permutation so that no region can be
    # reassigned to 0 and thereby vanish from the plot.
    foreground_ids = np.unique(labels[labels != 0])
    if foreground_ids.size == 0:
        return labels.copy()

    # Lookup table indexed by the old id. Sizing it by the largest id (not
    # by the number of ids) also supports non-consecutive label values.
    lookup = np.zeros(foreground_ids.max() + 1, dtype=foreground_ids.dtype)
    lookup[foreground_ids] = np.random.permutation(foreground_ids)
    return lookup[labels]

In [None]:
# Give every connected group of local-maximum pixels its own integer id.
# ndi.label returns (label_image, count); only the label image is kept.
markers = ndi.label(local_maxima)[0]
# Flood from the markers over the negated distance map, so that the distance
# maxima act as basins. No mask is given, so the whole image gets labelled.
labels = segmentation.watershed(-distance, markers)

In [None]:
f, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(20, 5))

# Left to right: binary mask, log-scaled distance map, shuffled watershed labels
panels = [
    (ax0, thresholded, {}),
    (ax1, np.log(1 + distance), {}),
    (ax2, shuffle_labels(labels), {'cmap': 'magma'}),
]
for axis, picture, options in panels:
    axis.imshow(picture, **options)

In [None]:
# Same watershed as before, but restricted to the thresholded mask: pixels
# outside the mask are left at 0. connectivity=2 also treats diagonal
# neighbours as adjacent (for a 2-D image).
labels_masked = segmentation.watershed(-distance, markers, mask=thresholded, connectivity=2)

In [None]:
f, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(20, 5))

# Left to right: binary mask, log-scaled distance map, shuffled masked labels
panels = [
    (ax0, thresholded, {}),
    (ax1, np.log(1 + distance), {}),
    (ax2, shuffle_labels(labels_masked), {'cmap': 'magma'}),
]
for axis, picture, options in panels:
    axis.imshow(picture, **options)

In [None]:
from skimage import measure

# Outline the labelled regions: a 0.5 iso-level separates the background (0)
# from any label id (>= 1).
contours = measure.find_contours(labels_masked, level=0.5)
plt.imshow(pores)
for c in contours:
    # Each contour is an (n, 2) array of (row, col) points; plot as x=col, y=row.
    rr, cc = c.T
    plt.plot(cc, rr)

In [None]:
# Per-region measurements (area, etc.) for every non-zero label of the
# masked watershed result.
regions = measure.regionprops(labels_masked)

In [None]:
f, ax = plt.subplots(figsize=(10, 3))
# Histogram of region areas; areas outside 0..200 fall outside the plotted range
region_areas = [region.area for region in regions]
ax.hist(region_areas, bins=100, range=(0, 200));

In [None]:
from keras import models, layers
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D

M = 76
# Edge length of the square tiles fed to the network.
# NOTE(review): later cells hard-code a tile size of 46; confirm that this
# expression evaluates to 46 (it relies on 23 / 76 * M not rounding below 23).
N = int(23 / 76 * M) * 2

# Small convolutional network: downsample once, two 3x3 convolutions,
# upsample back, and a single sigmoid output channel per pixel.
model = models.Sequential([
    Conv2D(
        32,
        kernel_size=(2, 2),
        activation='relu',
        input_shape=(N, N, 1),
        padding='same'
    ),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(64, (3, 3), activation='relu', padding='same'),
    Conv2D(64, (3, 3), activation='relu', padding='same'),
    UpSampling2D(size=(2, 2)),
    Conv2D(
        1,
        kernel_size=(2, 2),
        activation='sigmoid',
        padding='same'
    ),
])
model.compile(loss='mse', optimizer='Adam', metrics=['accuracy'])

# Restore the pre-trained weights from disk
model.load_weights('../data/keras_model-diatoms-pores.h5')

In [None]:
# Grow each image dimension to the next multiple of 46 so that the image
# splits into whole 46x46 tiles.
shape = np.asarray(pores.shape)
padded_shape = (46 * np.ceil(shape / 46)).astype(int)
delta_shape = padded_shape - shape

# Pad only at the bottom/right edge, mirroring the image content
pad_width = [(0, extra) for extra in delta_shape[:2]]
padded_pores = np.pad(pores, pad_width=pad_width, mode='symmetric')

# Non-overlapping 46x46 tiles, indexed as blocks[tile_row, tile_col]
blocks = util.view_as_blocks(padded_pores, (46, 46))

In [None]:
# blocks has four axes: the first two count the tiles along rows and columns,
# the last two are the per-tile dimensions, which are not needed here.
B_rows, B_cols, _, _ = blocks.shape

In [None]:
# Run the network on all tiles with ONE batched predict call. Calling
# model.predict once per tile pays the per-call overhead for every tile,
# which is much slower than a single batched call.

# Stack the tiles in row-major order (the order a nested loop over tile rows,
# then tile columns, would visit them) and add the trailing channel axis, so
# the batch has shape (n_tiles, tile_height, tile_width, 1).
tile_batch = blocks.reshape(B_rows * B_cols, *blocks.shape[2:])[..., np.newaxis]
predicted_masks = model.predict(tile_batch)

# One 2-D mask per tile (channel axis dropped), kept as a list in the same
# row-major tile order.
tile_masks = list(predicted_masks[..., 0])

In [None]:
# Stitch the per-tile masks back into one image laid out on the tile grid.
nn_mask = util.montage(tile_masks, grid_shape=(B_rows, B_cols))
# Crop away the padding that was added before tiling, restoring the size
# of the original crop.
nn_mask = nn_mask[:shape[0], :shape[1]]

In [None]:
# Display the stitched network output.
plt.imshow(nn_mask);

In [None]:
# Outline the network prediction at the 0.5 iso-level of its output
contours = measure.find_contours(nn_mask, level=0.5)
plt.imshow(pores)
for c in contours:
    # Each contour is an (n, 2) array of (row, col) points; plot as x=col, y=row.
    rr, cc = c.T
    plt.plot(cc, rr)

In [None]:
# Binarize the network output before labelling. nn_mask holds per-pixel
# sigmoid activations, and measure.label groups neighbouring pixels that share
# the same value, so labelling the raw floats does not give one region per
# predicted pore. The 0.5 cut-off matches the contour level used to outline
# the prediction.
nn_regions = measure.regionprops(
    measure.label(nn_mask > 0.5)
)

In [None]:
f, ax = plt.subplots(figsize=(10, 3))
# Overlay the area distributions of both segmentation approaches
for found_regions, method in ((regions, 'Classic'), (nn_regions, 'NN')):
    ax.hist(
        [r.area for r in found_regions],
        bins=100,
        range=(0, 200),
        alpha=0.5,
        label=method,
    )
ax.legend();