In [None]:
import glob
import os
import errno
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from deepcell.datasets import SpotNet
from tensorflow.keras import callbacks
from tensorflow.keras.utils import get_file

from deepcell_spots import dotnet_losses
from deepcell_spots import image_generators
from deepcell_spots.utils.postprocessing_utils import *

In [None]:
# SpotNet dataset handle (from `deepcell.datasets`); its `load_data` method is
# called below once per split ('train', 'val', 'test').
spot_net = SpotNet()

# Load training, validation, and test datasets

In [None]:
# Fetch each split as an (images, annotations) pair. The annotations are
# presumably per-image arrays of spot coordinates — TODO confirm.
train_X, train_y = spot_net.load_data(split='train')
val_X, val_y = spot_net.load_data(split='val')
test_X, test_y = spot_net.load_data(split='test')

# Package every split in the {'X': images, 'y': annotations} layout that the
# rest of the notebook indexes into.
train_dict = dict(X=train_X, y=train_y)
val_dict = dict(X=val_X, y=val_y)
test_dict = dict(X=test_X, y=test_y)

In [None]:
# Report the shape of each split's image array as a quick sanity check.
size_messages = (
    ('Training set size: {}', train_dict),
    ('Validation set size: {}', val_dict),
    ('Test set size: {}', test_dict),
)
for template, split_dict in size_messages:
    print(template.format(split_dict['X'].shape))

In [None]:
# Visual sanity check: show one training image (channel 0) with its annotated
# spots drawn as red circles. Annotation column 1 is plotted on the horizontal
# axis and column 0 on the vertical axis.
ind = 0
sample_image = train_dict['X'][ind, ..., 0]
sample_spots = train_dict['y'][ind]
plt.imshow(sample_image)
plt.scatter(sample_spots[:, 1], sample_spots[:, 0],
            edgecolors='r', facecolors='None', s=80)
plt.show()

# Set up model

In [None]:
# Set up required filepaths

modeldir = './models'
logdir = './logs'

# Create the output directories if they do not exist. `exist_ok=True` makes
# this safe to re-run (and tolerant of a concurrent creator), while still
# raising FileExistsError if a path exists but is not a directory.
for d in (modeldir, logdir):
    os.makedirs(d, exist_ok=True)

print('model dir: ', modeldir)
print('log dir: ', logdir)

In [None]:
from tensorflow.keras.optimizers import SGD
from deepcell.utils.train_utils import rate_scheduler

conv_model_name = 'example_conv_dots_model'

n_epoch = 10  # Number of training epochs
norm_method = None  # data normalization - options are: 'std','max', None, 'whole_image'
receptive_field = 13  # should be adjusted for the scale of the data

# Single source of truth for the initial learning rate, shared by the
# optimizer and the per-epoch scheduler so the two cannot drift apart.
learning_rate = 0.01

# `learning_rate` is the supported keyword; `lr` is only a deprecated alias.
# NOTE(review): `decay` is accepted only by the legacy Keras optimizers; on
# TF >= 2.11 use tf.keras.optimizers.legacy.SGD or drop `decay` — confirm
# against the pinned TensorFlow version.
optimizer = SGD(learning_rate=learning_rate, decay=1e-6, momentum=0.9, nesterov=True)

lr_sched = rate_scheduler(lr=learning_rate, decay=0.99)

# FC training settings
n_skips = 3  # number of skip-connections (only for FC training)
batch_size = 1  # FC training uses 1 image per batch

In [None]:
from deepcell_spots.dotnet import dot_net_2D

# Build the dot-detection network. The input shape is that of a single
# training sample, i.e. the training array's shape minus its leading
# (per-image) axis.
model_kwargs = dict(
    receptive_field=receptive_field,
    input_shape=tuple(train_dict['X'].shape[1:]),
    inputs=None,
    n_skips=n_skips,
    norm_method=norm_method,
    padding_mode='reflect',
)
dots_model = dot_net_2D(**model_kwargs)

In [None]:
# Print the Keras model summary (layers, output shapes, parameter counts).
dots_model.summary()

In [None]:
# Hyperparameters forwarded to DotNetLosses.
sigma = 3.0
gamma = 0.5
focal = False

losses = dotnet_losses.DotNetLosses(sigma=sigma, gamma=gamma, focal=focal)

# One loss function per named model output head.
loss = dict(
    offset_regression=losses.regression_loss,
    classification=losses.classification_loss,
)

# Relative importance of the two heads, normalized so the weights sum to 1.
regression_weight = 1
classification_weight = 5
total_weight = regression_weight + classification_weight

loss_weights = {
    head: raw_weight / total_weight
    for head, raw_weight in (('offset_regression', regression_weight),
                             ('classification', classification_weight))
}

dots_model.compile(
    optimizer=optimizer,
    loss=loss,
    loss_weights=loss_weights,
    metrics=['accuracy'],
)

In [None]:
# Augmentation settings for the training generator.
rotation_range = 0
flip = True
shear = 0
zoom_range = 0
fill_mode = 'nearest'
cval = 0.
seed = 0

datagen = image_generators.ImageFullyConvDotDataGenerator(
    rotation_range=rotation_range,
    shear_range=shear,
    zoom_range=zoom_range,
    horizontal_flip=flip,
    vertical_flip=flip,
    fill_mode=fill_mode,
    cval=cval)

# DataGenerator object for validation data - generates data with no augmentation.
# The flip options are boolean flags (the training generator passes `flip = True`),
# so pass False rather than the integer 0.
datagen_val = image_generators.ImageFullyConvDotDataGenerator(
    rotation_range=0,
    shear_range=0,
    zoom_range=0,
    horizontal_flip=False,
    vertical_flip=False)

# Both iterators receive the same seed; presumably this fixes their
# shuffling/augmentation order for reproducibility — TODO confirm.
train_data = datagen.flow(
    train_dict,
    seed=seed,
    batch_size=batch_size)

val_data = datagen_val.flow(
    val_dict,
    seed=seed,
    batch_size=batch_size)

# Train model

In [None]:
num_gpus = 1  # when >= 2, the checkpoint below saves weights only

# Checkpoint to a named path inside `modeldir`, mirroring the TensorBoard log
# subdirectory. Passing `modeldir` itself would write a full-model save
# straight into that directory. With save_weights_only=True it would use
# './models' as a checkpoint prefix, producing files beside the directory
# instead of inside it.
model_path = os.path.join(modeldir, conv_model_name)

loss_history = dots_model.fit(
    train_data,
    steps_per_epoch=train_data.y.shape[0] // batch_size,
    epochs=n_epoch,
    validation_data=val_data,
    validation_steps=val_data.y.shape[0] // batch_size,
    callbacks=[
        callbacks.LearningRateScheduler(lr_sched),
        callbacks.ModelCheckpoint(
            model_path, monitor='val_loss', verbose=1,
            save_best_only=True, save_weights_only=num_gpus >= 2),
        callbacks.TensorBoard(log_dir=os.path.join(logdir, conv_model_name))
    ])

# Test model

In [None]:
# Run inference on the whole test set. The model returns one array per output
# head (presumably in model output order — TODO confirm); report the shape of
# the first one.
test_images = test_dict['X']
y_pred_test = dots_model.predict(test_images)
first_output_shape = y_pred_test[0].shape
print('Test image result shape:', first_output_shape)

In [None]:
ind = 0  # index of the test image to display

# Left panel: the raw test image (channel 0).
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(test_dict['X'][ind, ..., 0], cmap='gray')
ax[0].set_title('Raw image')

# Convert the model outputs into per-image lists of predicted spot locations.
# A high threshold gives a restrictive decision; `threshold` presumably acts as
# a cutoff on the classification output — TODO confirm against
# y_annotations_to_point_list_max.
threshold = 0.95
y_pred_test_dict = {
    'classification': y_pred_test[1],
    'offset_regression': y_pred_test[0],
}
points_list = y_annotations_to_point_list_max(y_pred_test_dict, threshold, min_distance=1)

# Right panel: channel 1 of the classification output, with predicted spots
# (red circles) and ground-truth spots (blue crosses) overlaid. The color scale
# is left to matplotlib's autoscaling (`vmax` is not defined in this notebook).
ax[1].imshow(y_pred_test[1][ind, :, :, 1], cmap='gray')
ax[1].scatter(points_list[ind][:, 1], points_list[ind][:, 0], edgecolors='r',
              facecolors='None', s=200, label='Predicted')
ax[1].plot(test_dict['y'][ind][:, 1], test_dict['y'][ind][:, 0], 'xb', label='GT')
ax[1].legend()
ax[1].set_title('Classification prediction')
plt.show()