In [None]:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import datetime
import os
import numpy as np
from timeit import default_timer
import ssl 
# Swap the ssl module's default HTTPS context factory for the unverified one.
# This is a process-wide change: any later HTTPS request in this kernel that
# relies on the default context will skip certificate verification.
# NOTE(review): presumably added so the pretrained-model download works behind
# a broken certificate chain -- confirm it is still needed; otherwise remove it.
ssl._create_default_https_context = ssl._create_unverified_context
import syotil
from skimage import io
import glob
from matplotlib import pyplot as plt
%matplotlib inline

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.losses import MSE
from tensorflow.python.data import Dataset
from tensorflow.keras.optimizers import SGD, Adam

from deepcell import losses
from deepcell import image_generators
from deepcell.utils import train_utils
from deepcell.utils import tracking_utils
from deepcell.utils.data_utils import get_data
from deepcell.utils.train_utils import rate_scheduler
from deepcell.utils.train_utils import get_callbacks
from deepcell.utils.train_utils import count_gpus
from deepcell.applications import CytoplasmSegmentation

In [None]:
# Load the training images and their masks, stack them into arrays, save them
# as an .npz archive and split that archive into train/test dictionaries.

# read image files
# glob returns paths in arbitrary order; sort so the dataset is reproducible.
imgfiles = sorted(glob.glob('images/training/*_img.png'))
if not imgfiles:
    raise FileNotFoundError("No files match 'images/training/*_img.png'")
print(imgfiles)

# show one file
img = io.imread(imgfiles[0])
print(img.shape)
im = img[:, :, 0]
io.imshow(im)
plt.show()

# Only the first channel of every image is used.
imgs = [io.imread(imgfile)[:, :, 0] for imgfile in imgfiles]

# Stack with NumPy directly (no round trip through a TensorFlow tensor) and
# add a trailing axis.
X_train = np.expand_dims(np.stack(imgs), axis=-1)
print(X_train.shape)

# read mask files
# Derive every mask path from its image path instead of globbing separately:
# two independent listings are not guaranteed to come back in matching order,
# which would silently pair images with the wrong masks.
img_suffix, mask_suffix = '_img.png', '_masks.png'
maskfiles = [imgfile[:-len(img_suffix)] + mask_suffix for imgfile in imgfiles]
missing_masks = [maskfile for maskfile in maskfiles if not os.path.exists(maskfile)]
if missing_masks:
    raise FileNotFoundError(
        'Training images without a mask file: {}'.format(missing_masks))

# show one mask
img = io.imread(maskfiles[0])
print(img.shape)
im = img
io.imshow(im)
plt.show()

masks = [io.imread(maskfile) for maskfile in maskfiles]

y_train = np.expand_dims(np.stack(masks), axis=-1)
print(y_train.shape)

np.savez("K_training_data", X=X_train, y=y_train) # objects to save need to be key value pairs

test_size = .2
seed = 0
train_dict, test_dict = get_data("K_training_data.npz", test_size=test_size, seed=seed)

In [None]:
# Training model
# One-channel
# Fetch the pretrained CytoplasmSegmentation archive (checked against
# MODEL_HASH), let Keras extract it into the 'models' cache sub-directory,
# and load the saved model twice.
MODEL_PATH = 'https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/CytoplasmSegmentation-3.tar.gz'
MODEL_HASH = '6a244f561b4d37169cb1a58b6029910f'
archive_path = tf.keras.utils.get_file(
    fname='CytoplasmSegmentation.tgz',
    origin=MODEL_PATH,
    file_hash=MODEL_HASH,
    extract=True,
    cache_subdir='models',
)

# The model is loaded from the archive path with its extension stripped.
model_path = os.path.splitext(archive_path)[0]

# Two independent copies: one kept as the untouched baseline, one to train.
pretrained_model = tf.keras.models.load_model(model_path)  # this copy will not be trained
new_model = tf.keras.models.load_model(model_path)
#new_model.save_weights('/home/shan/cyto_pretrained_weights.h5')

In [None]:
# Training hyper-parameters and data generators.
model_name = '20221009'
n_epoch = 100
learning_rate = 1e-4  # single source for both the optimizer and the schedule
optimizer = Adam(learning_rate=learning_rate, clipnorm=0.001)
lr_sched = rate_scheduler(lr=learning_rate, decay=0.99)
batch_size = 1 # 8
min_objects = 0  # throw out images with fewer than this many objects
seed = 0

# Augmenting generator for training: rotations, zoom, flips and random crops.
datagen = image_generators.CroppingDataGenerator(
    rotation_range=180,
    shear_range=0,
    zoom_range=(0.7, 1/0.7),
    horizontal_flip=True,
    vertical_flip=True,
    #crop_size=(256, 256)) # generate error
    crop_size=(512, 512))

# Validation generator: every augmentation switched off.
datagen_val = image_generators.SemanticDataGenerator(
    rotation_range=0,
    shear_range=0,
    zoom_range=0,
    horizontal_flip=False,
    vertical_flip=False)


def _semantic_flow(generator, data_dict):
    """Return `generator.flow(...)` over `data_dict` with the shared settings.

    Training and validation must use identical transforms, so the
    configuration lives here once instead of being copied per call. Fresh
    list/dict literals are built on every call so the two iterators never
    share mutable state.
    """
    return generator.flow(
        data_dict,
        seed=seed,
        transforms=['inner-distance', 'pixelwise'],
        transforms_kwargs={
            'pixelwise': {'dilation_radius': 1},
            'inner-distance': {'erosion_width': 1, 'alpha': 'auto'},
        },
        min_objects=min_objects,
        batch_size=batch_size)


train_data = _semantic_flow(datagen, train_dict)
val_data = _semantic_flow(datagen_val, test_dict)

# Define loss (create a dictionary of losses for each semantic head)

def semantic_loss(n_classes):
    """Build the loss function for one semantic head.

    Args:
        n_classes: size of the head's last output dimension.

    Returns:
        A callable ``loss(y_true, y_pred)``. For ``n_classes > 1`` it is
        ``0.01 * losses.weighted_categorical_crossentropy``; otherwise it is
        ``MSE``.
    """
    # Keras invokes a loss as loss(y_true, y_pred). The arguments are
    # forwarded positionally in that same order, so the parameter names
    # now match what each argument actually holds.
    def _semantic_loss(y_true, y_pred):
        if n_classes > 1:
            return 0.01 * losses.weighted_categorical_crossentropy(
                y_true, y_pred, n_classes=n_classes)
        return MSE(y_true, y_pred)
    return _semantic_loss

# Map every layer whose name starts with 'semantic_' to a loss built from
# the size of that layer's last output dimension.
loss = {
    layer.name: semantic_loss(layer.output_shape[-1])
    for layer in new_model.layers
    if layer.name.startswith('semantic_')
}

new_model.compile(loss=loss, optimizer=optimizer)

In [None]:
# Output locations for the trained weights and the TensorBoard logs.
MODEL_DIR = 'models'
LOG_DIR = 'logs'

# exist_ok=True makes the cell safe to re-run and removes the
# check-then-create race of `if not exists: mkdir`.
for directory in (MODEL_DIR, LOG_DIR):
    os.makedirs(directory, exist_ok=True)

# From here on model_path is the .h5 file for the trained weights.
model_path = os.path.join(MODEL_DIR, '{}.h5'.format(model_name))
loss_path = os.path.join(MODEL_DIR, '{}.npz'.format(model_name))

In [None]:
# Train new_model on the augmented training iterator and time the run.
train_callbacks = get_callbacks(
    model_path,
    lr_sched=lr_sched,
    tensorboard_log_dir=LOG_DIR,
    save_weights_only=True,
    # No validation data is passed to fit() below, so the callbacks monitor
    # the training loss; switch to 'val_loss' if validation is enabled.
    monitor='loss',
    verbose=1)

start = default_timer()
# Model.fit accepts generators directly; fit_generator is deprecated and
# absent from recent TensorFlow/Keras releases.
loss_history = new_model.fit(
    train_data,
    steps_per_epoch=train_data.y.shape[0] // batch_size,
    epochs=n_epoch,
    #validation_data=val_data,
    #validation_steps=val_data.y.shape[0] // batch_size,
    callbacks=train_callbacks)
training_time = default_timer() - start
print('Training time: ', training_time, 'seconds.')


In [None]:
# read test files
TEST_IMG_PATH = 'images/test/M872956_JML_Position8_CD3_test_img.png'
TEST_MASK_PATH = "images/test/M872956_JML_Position8_CD3_test_masks.png"

# Test image: report its shape and display it as loaded.
im0 = io.imread(TEST_IMG_PATH)
print(im0.shape)
io.imshow(im0)
plt.show()

# Add a leading and a trailing singleton axis in one indexing step.
im = im0[np.newaxis, ..., np.newaxis]

# Ground-truth mask for the same test image.
mask_true = io.imread(TEST_MASK_PATH)
print(mask_true.shape)
io.imshow(mask_true)
plt.show()

In [None]:
# predict using newly trained model
# Flow: tile the test image, predict on the tiles, stitch the result back
# together, then display it and score it against the true mask.
# NOTE(review): this calls the private _tile_input/_untile_output helpers
# around app.predict -- confirm predict() does not already tile internally,
# in which case passing `x` straight to predict() may be what is intended.
app = CytoplasmSegmentation(new_model)
x = im
y, tile_info = app._tile_input(x)
pred = app.predict(y, image_mpp=1)
prd = app._untile_output(pred, tile_info)

# First batch entry, first channel of the stitched output.
pred_mask = prd[0, :, :, 0]
io.imshow(pred_mask)
#plt.show()
score = syotil.csi(mask_true, pred_mask)
print(score) # 0.21 without passing image_mpp. Setting image_mpp to 1 improves to 0.37

In [None]:
# predict using CytoplasmSegmentation constructed with pretrained model
# same as app = CytoplasmSegmentation() because the model is the default one
# Flow: tile the test image, predict on the tiles, stitch the result back
# together, then display it and score it against the true mask.
# NOTE(review): this calls the private _tile_input/_untile_output helpers
# around app.predict -- confirm predict() does not already tile internally,
# in which case passing `x` straight to predict() may be what is intended.
app = CytoplasmSegmentation(pretrained_model)
x = im
y, tile_info = app._tile_input(x)
pred = app.predict(y, image_mpp=1)
prd = app._untile_output(pred, tile_info)

# First batch entry, first channel of the stitched output.
pred_mask = prd[0, :, :, 0]
io.imshow(pred_mask)
#plt.show()
score = syotil.csi(mask_true, pred_mask)
print(score) # 0.21 without passing image_mpp. Setting image_mpp to 1 improves to 0.37