In [None]:
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import get_file 
import os
import matplotlib.pyplot as plt
import numpy as np
from skimage.transform import resize,rescale
from skimage.util import random_noise
from skimage.io import imread,imsave
from skimage.filters import gaussian
from skimage.feature import shape_index
import tensorflow as tf
import tensorflow.keras.backend as K

# necessary functions

In [None]:
def jaccard_coef(y_true, y_pred, axis=(0, -1, -2)):
    '''Soft (differentiable) Jaccard index between ground truth and prediction.

    intersection = sum(y_true * y_pred) and union = sum(y_true + y_pred) - intersection
    are both summed over ``axis``. Their ratio is then averaged (K.mean) over every
    remaining axis to give a scalar. K.epsilon() is added to numerator and denominator,
    so a region where both tensors are empty scores 1 instead of dividing by zero.

    Parameters
    ----------
    y_true, y_pred : tensors of the same shape
    axis : sequence of int, optional
        Axes reduced before the ratio is taken. The default (0, -1, -2) keeps the
        original behaviour: for channels-last (N, H, W, C) targets it sums over
        batch, channels and width, i.e. one score per image row, then averages.
        NOTE(review): a per-class IoU would use axis=(0, 1, 2); confirm which was intended.

    Returns
    -------
    Scalar tensor with the mean Jaccard index.
    '''
    smooth = K.epsilon()
    axis = list(axis)
    intersection = K.sum(y_true * y_pred, axis=axis)
    union = K.sum(y_true + y_pred, axis=axis) - intersection
    jac = (intersection + smooth) / (union + smooth)
    return K.mean(jac)

def bce_and_jac(y_true, y_pred):
    '''Training loss: binary cross-entropy minus log(jaccard_coef).

    The -log(J) term grows as the soft Jaccard index falls toward 0, pushing overlap
    up alongside the BCE term. jaccard_coef returns a scalar (K.mean), so it is
    broadcast onto the BCE tensor.
    '''
    bce_term = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    log_jaccard = K.log(jaccard_coef(y_true, y_pred))
    return bce_term - log_jaccard

def shapeindex_preprocess(im, sigmas=(1, 1.5, 2)):
    '''Stack skimage shape-index maps of a grayscale image at several scales.

    Parameters
    ----------
    im : ndarray
        Grayscale image; the first two dimensions size the output.
    sigmas : sequence of float, optional
        Scales passed to ``skimage.feature.shape_index``, one output channel per
        scale. The default (1, 1.5, 2) reproduces the three-channel input used
        for training in this notebook.

    Returns
    -------
    ndarray of shape (H, W, len(sigmas)), float64
        Shape-index maps with NaNs replaced by 0. An image whose maximum is 0,
        or an empty image, yields an all-zero array.
    '''
    sh = np.zeros((im.shape[0], im.shape[1], len(sigmas)))
    # Blank image: nothing to filter. The size check also keeps np.max from
    # raising on an empty array.
    if im.size == 0 or np.max(im) == 0:
        return sh

    # mode='reflect' mirrors the image at its borders to limit edge artifacts
    # (no explicit padding is needed).
    for channel, sigma in enumerate(sigmas):
        sh[:, :, channel] = shape_index(im, sigma, mode='reflect')

    # (Kevin) shape_index can return NaNs; zero them so the network never sees NaN.
    sh[np.isnan(sh)] = 0
    return sh

# get model and compile 

In [None]:

# Fetch the published MiSiC weights (get_file reuses its cached copy when present),
# load them without the saved training configuration, and compile with the loss
# and metric defined in this notebook.
model_path = get_file('misic_model', 'https://github.com/pswapnesh/Models/raw/master/MiSiDC04082020.h5')
model = load_model(model_path, compile=False)
model.compile(
    optimizer='adam',
    loss=bce_and_jac,
    metrics=['accuracy', jaccard_coef],
)
model.summary()

# Optional: train only the decoder side

In [None]:
# Freeze the first `keep_frozen` layers so that only the later layers are updated
# (intent per the heading: train only the decoder side).
keep_frozen = 12
for layer in model.layers[:keep_frozen]:
    layer.trainable = False

# Keras only takes changes to `trainable` into account once the model is compiled
# again. The model was already compiled when it was loaded, so recompile with the
# same optimizer, loss and metrics.
model.compile(optimizer='adam', loss=bce_and_jac, metrics=['accuracy', jaccard_coef])

# training
## Data preparation:
### Make sure the images are of size 256 x 256
Given a grayscale image IM, use the preprocessing function provided above, <em>x = shapeindex_preprocess(IM)</em>, to obtain an output of shape (256,256,3).
The ground truth y should have shape (256,256,2), where the first channel is the cell body and the second channel is the cell boundary.

If you do not have the cell boundary information, you can use <em>skimage.segmentation.find_boundaries(label_img)</em> to generate the boundaries.

Finally, your full training data should have the following shapes:

X -> [N,256,256,3]

y -> [N,256,256,2]



In [None]:
from pathlib import Path
from cellpose import io
import skimage.io

# Folder with training images and their label files (named with mask_filter).
# NOTE(review): machine-specific absolute path; adjust for your setup.
basedir = '/home/kcutler/DataDrive/final_train'
mask_filter = '_masks'

# cellpose helpers: list the images, then the label file paired with each image
# (presumably one mask per image, in the same order; verify against cellpose docs).
img_names = io.get_image_files(basedir, mask_filter)
mask_names, _ = io.get_label_files(img_names, mask_filter)

# Raw images and label masks (tifffile.imread may be preferable for TIFF data)
X = list(map(skimage.io.imread, img_names))
Y = list(map(skimage.io.imread, mask_names))

In [None]:
from cellpose import utils
from skimage.segmentation import find_boundaries
def convert_mask(m):
    '''Turn a label image into a two-channel (interior, boundary) training target.

    Channel 0 is True on labelled pixels (m > 0) that are not boundary pixels.
    Channel 1 is skimage.segmentation.find_boundaries(m) with its default settings.
    The boolean output has shape m.shape + (2,).
    Assumes 0 marks background (presumably integer cell IDs elsewhere; confirm).
    '''
    edges = find_boundaries(m)
    body = (m > 0) & ~edges
    return np.stack([body, edges], axis=-1)
# Quick check on the first mask: expect (H, W, 2) = (interior, boundary)
y = convert_mask(Y[0])
y.shape

In [None]:
# Required for retraining: builds the targets y (interior/boundary channels, one
# array per mask) and replaces every image in X with its 3-channel shape-index map.
# These X and y are what the crop/pad and fit cells use.
# NOTE: X is overwritten in place. Re-running this cell without re-reading the
# raw images would apply shapeindex_preprocess to already-processed arrays.
y = [convert_mask(mask) for mask in Y]
X = [shapeindex_preprocess(img) for img in X]

In [None]:
# Center-crop or zero-pad every sample to 256 x 256, then stack into batch arrays
y_crop = np.stack(
    [tf.image.resize_with_crop_or_pad(mask.astype(np.float64), 256, 256) for mask in y]
)
X_crop = np.stack(
    [tf.image.resize_with_crop_or_pad(img.astype(np.float64), 256, 256) for img in X]
)
X_crop.shape, y_crop.shape

In [None]:
# shapeindex_preprocess zeroes NaNs; fail loudly (instead of printing a bool that
# must be eyeballed) if any reach the training inputs.
assert not np.isnan(X_crop).any(), 'NaNs found in preprocessed training inputs X_crop'

In [None]:
# Tune the epoch count (and batch size) to the size of your dataset.
num_epochs = 100
history = model.fit(
    X_crop,
    y_crop,
    batch_size=8,
    epochs=num_epochs,
)

In [None]:
# Persist the fine-tuned model; the .h5 suffix selects the HDF5 format.
# NOTE(review): machine-specific absolute path.
model.save(filepath='/home/kcutler/DataDrive/misic_etc/new_misic_model_kevin_3.h5')

In [None]:
# Predict on the first training sample; slicing with [:1] keeps the batch axis,
# giving shape (1, 256, 256, 3).
xx = X_crop[:1]
yp = model.predict(xx)
yp.shape


In [None]:
# The three shape-index channels are shown as RGB. Shape index lies in [-1, 1], and
# imshow clips float RGB data to [0, 1], so map [-1, 1] -> [0, 1] first.
# Zero shape index (and zero padding) then shows as mid-grey.
plt.imshow(0.5 * (X_crop[0] + 1.0))
plt.title('X_crop[0]: shape index at 3 scales (as RGB)')
plt.axis('off');

In [None]:
# Largest value in the first preprocessed input
X_crop[0].max()

In [None]:
# Reload the published MiSiC weights and the fine-tuned model saved to disk,
# compiled the same way as the training model.
model_path = get_file('misic_model','https://github.com/pswapnesh/Models/raw/master/MiSiDC04082020.h5')
model_orig = load_model(model_path,compile=False)
model_orig.compile(optimizer='adam',loss=bce_and_jac,metrics=['accuracy',jaccard_coef])

model_path ='/home/kcutler/DataDrive/misic_etc/new_misic_model_kevin_3.h5'
model_new = load_model(model_path,compile=False)
model_new.compile(optimizer='adam',loss=bce_and_jac,metrics=['accuracy',jaccard_coef])

# `model_new == model_orig` falls back to object identity, which is False for two
# separately loaded models whatever their weights. Compare the weight arrays instead:
# True means fine-tuning left every weight unchanged.
weights_orig = model_orig.get_weights()
weights_new = model_new.get_weights()
len(weights_new) == len(weights_orig) and all(
    np.array_equal(w_new, w_orig) for w_new, w_orig in zip(weights_new, weights_orig)
)

In [None]:
# A second, independent load of the published weights, used as a control
# for the weight comparison.
model_path = get_file('misic_model', 'https://github.com/pswapnesh/Models/raw/master/MiSiDC04082020.h5')
model_orig2 = load_model(model_path, compile=False)
model_orig2.compile(
    optimizer='adam',
    loss=bce_and_jac,
    metrics=['accuracy', jaccard_coef],
)

In [None]:
# Control: two loads of the same file should have identical weights (expect True).
# `model_orig2 == model_orig` would only test object identity, so compare arrays.
weights_orig2 = model_orig2.get_weights()
weights_orig1 = model_orig.get_weights()
len(weights_orig2) == len(weights_orig1) and all(
    np.array_equal(w2, w1) for w2, w1 in zip(weights_orig2, weights_orig1)
)

In [None]:
# history.model refers to the model fit() trained (presumably the same `model` object),
# saved here under a second filename. NOTE(review): machine-specific absolute path.
history.model.save(filepath='/home/kcutler/DataDrive/misic_etc/new_misic_model_kevin_hist.h5')

In [None]:
# Peek at the first trainable weight tensor (frozen layers are excluded from this list)
print(model.trainable_weights[0][0])

In [None]:
# Element-wise comparison of slice [0, 0] of each model's first trainable variable
model_new.trainable_variables[0][0, 0] == model_orig.trainable_variables[0][0, 0]

In [None]:
# Show every entry of X. axis('off') must come before show(): after show() it
# would act on a new, empty figure instead of the one just displayed.
# NOTE(review): once the preprocessing cell has run, X holds 3-channel shape-index
# maps in [-1, 1], which imshow clips to [0, 1]; rescale if you need the full range.
for x in X:
    plt.imshow(x)
    plt.axis('off')
    plt.show()