# Multiplex Model Training with Hyper Parameter Search

The hyperparameter search is run with `optuna`.

In [None]:
import os

import numpy as np
import tensorflow as tf

import deepcell

In [None]:
# Optuna search budget
TRIALS = 25  # maximum number of HPS trials to run
TIMEOUT = 24 * 3600  # time limit for the whole search, in seconds

# Training parameters
SEED = 0
EPOCHS = 20
BATCH_SIZE = 8

# Names used to build the data and model file names
NPZ_NAME = '20200819_multiplex_normalized_512x512'
EXP_NAME = '20200819_hyper_parameter'
MODEL_NAME = '{}_deep_watershed'.format(NPZ_NAME)

# Output locations: models are grouped per experiment, logs are shared
ROOT_DIR = '/data'
MODEL_DIR = os.path.join(ROOT_DIR, 'models', EXP_NAME)
LOG_DIR = os.path.join(ROOT_DIR, 'logs')

# Input data: one .npz archive per split, all under DATA_DIR
DATA_DIR = os.path.join(ROOT_DIR, 'users/willgraf/mibi-hps')

TRAIN_DATA_FILE = os.path.join(DATA_DIR, '{}_train.npz'.format(NPZ_NAME))
VAL_DATA_FILE = os.path.join(DATA_DIR, '{}_val.npz'.format(NPZ_NAME))
TEST_DATA_FILE = os.path.join(DATA_DIR, '{}_test.npz'.format(NPZ_NAME))

# Create the experiment's model directory up front. exist_ok avoids the
# check-then-create race of a separate isdir() test.
os.makedirs(MODEL_DIR, exist_ok=True)

In [None]:
def load_data(train_file=None, val_file=None):
    """Load the training and validation arrays from their ``.npz`` archives.

    Each archive must contain the keys ``'X'`` and ``'y'``. The test split
    (``TEST_DATA_FILE``) is intentionally not loaded here.

    Args:
        train_file: Path of the training archive. Defaults to
            ``TRAIN_DATA_FILE``.
        val_file: Path of the validation archive. Defaults to
            ``VAL_DATA_FILE``.

    Returns:
        A pair of tuples ``((X_train, y_train), (X_val, y_val))``.
    """
    # TRAIN_DATA_FILE / VAL_DATA_FILE are already full paths built from
    # DATA_DIR, so they must not be joined with DATA_DIR a second time.
    if train_file is None:
        train_file = TRAIN_DATA_FILE
    if val_file is None:
        val_file = VAL_DATA_FILE

    # Indexing an NpzFile reads the array into memory, so the archive can be
    # closed as soon as both arrays have been pulled out.
    with np.load(train_file) as train_data:
        X_train, y_train = train_data['X'], train_data['y']

    with np.load(val_file) as val_data:
        X_val, y_val = val_data['X'], val_data['y']

    return (X_train, y_train), (X_val, y_val)

In [None]:
from deepcell import losses
from deepcell.model_zoo.panopticnet import PanopticNet


def semantic_loss(n_classes):
    """Return the loss function for a semantic head with ``n_classes`` outputs.

    Heads with more than one class use a weighted categorical cross-entropy
    scaled by 0.01; single-channel heads use mean squared error.

    Args:
        n_classes: Size of the head's last output dimension.

    Returns:
        A callable ``(y_true, y_pred) -> loss``.
    """
    # Keras calls a loss as fn(y_true, y_pred). The arguments are forwarded
    # positionally in that same order, so only the names changed here.
    def _semantic_loss(y_true, y_pred):
        if n_classes > 1:
            return 0.01 * losses.weighted_categorical_crossentropy(
                y_true, y_pred, n_classes=n_classes)
        return tf.keras.losses.MSE(y_true, y_pred)
    return _semantic_loss


def create_model(trial):
    """Build and compile a PanopticNet for one Optuna trial.

    The architecture is fixed; the only hyperparameter sampled here is the
    Adam learning rate ``lr`` (log-uniform between 1e-4 and 1e-1).

    Args:
        trial: Optuna trial used to sample the learning rate.

    Returns:
        The compiled model, with one loss registered for every layer whose
        name starts with ``semantic_``.
    """
    model = PanopticNet(
        backbone='resnet50',
        input_shape=(256, 256, 2),
        norm_method=None,
        num_semantic_heads=4,
        # inner distance, outer distance, fgbg, pixelwise
        num_semantic_classes=[1, 1, 2, 3],
        location=True,  # should always be true
        include_top=True)

    lr = trial.suggest_float('lr', 1e-4, 1e-1, log=True)
    clipnorm = .001

    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=clipnorm)

    # Give a loss to each semantic head, chosen by its number of output
    # channels (last dimension of the layer's output shape).
    loss = {
        layer.name: semantic_loss(layer.output_shape[-1])
        for layer in model.layers
        if layer.name.startswith('semantic_')
    }

    model.compile(loss=loss, optimizer=optimizer)

    return model

In [None]:
from skimage.segmentation import relabel_sequential

from deepcell.image_generators import CroppingDataGenerator


def create_data_generators(trial, train_dict, val_dict):
    """Build the training and validation iterators for one Optuna trial.

    Samples the preprocessing, label-transform and augmentation
    hyperparameters from ``trial``, preprocesses the images and wraps both
    splits in ``CroppingDataGenerator`` flows that yield 256x256 crops.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        train_dict: Dict with keys ``'X'`` and ``'y'`` for the training split.
        val_dict: Dict with keys ``'X'`` and ``'y'`` for the validation split.

    Returns:
        A tuple ``(train_data, val_data)`` of data iterators.
    """
    # Tunable parameters
    # preprocessing parameter (second argument of phase_preprocess)
    k = trial.suggest_int('k', 32, 256)

    # data generator parameter
    min_objects = trial.suggest_int('min_objects', 1, 50)

    # transform parameters
    inner_erosion = trial.suggest_int('inner_erosion', 0, 5)
    outer_erosion = trial.suggest_int('outer_erosion', 0, 5)
    dilation_radius = trial.suggest_int('dilation_radius', 0, 5)

    # data augmentation parameters
    zoom_min = trial.suggest_float('zoom_min', 0.25, 1.)
    zoom_max = trial.suggest_float('zoom_max', 1., 4.)

    fill_modes = ['constant', 'nearest', 'reflect', 'wrap']
    fill_mode = trial.suggest_categorical('fill_mode', fill_modes)

    # Preprocess the images. Work on shallow copies so the dicts passed in
    # by the caller are not modified.
    train_dict = dict(train_dict, X=phase_preprocess(train_dict['X'], k))
    val_dict = dict(val_dict, X=phase_preprocess(val_dict['X'], k))

    # use augmentation for training but not validation
    datagen = CroppingDataGenerator(
        fill_mode=fill_mode,
        rotation_range=180,
        shear_range=0,
        zoom_range=(zoom_min, zoom_max),
        horizontal_flip=True,
        vertical_flip=True,
        crop_size=(256, 256))

    datagen_val = CroppingDataGenerator(
        fill_mode=fill_mode,
        rotation_range=0,
        shear_range=0,
        zoom_range=0,
        horizontal_flip=False,
        vertical_flip=False,
        crop_size=(256, 256))

    # Order matches the model heads:
    # inner distance, outer distance ('watershed-cont'), fgbg, pixelwise.
    transforms = ['inner-distance', 'watershed-cont', 'fgbg', 'pixelwise']

    # Each erosion hyperparameter drives the transform it is named after:
    # 'inner_erosion' -> inner distance, 'outer_erosion' -> outer distance.
    transforms_kwargs = {
        'inner-distance': {
            'erosion_width': inner_erosion,
            'alpha': 'auto'
        },
        'watershed-cont': {
            'erosion_width': outer_erosion
        },
        'pixelwise': {
            'dilation_radius': dilation_radius
        }
    }

    train_data = datagen.flow(
        train_dict,
        seed=SEED,
        transforms=transforms,
        transforms_kwargs=transforms_kwargs,
        min_objects=min_objects,
        batch_size=BATCH_SIZE)

    val_data = datagen_val.flow(
        val_dict,
        seed=SEED,
        transforms=transforms,
        transforms_kwargs=transforms_kwargs,
        min_objects=min_objects,
        batch_size=BATCH_SIZE)

    return train_data, val_data

In [None]:
from deepcell.utils.train_utils import count_gpus, get_callbacks, rate_scheduler
from deepcell_toolbox.processing import phase_preprocess

from optuna.integration import TFKerasPruningCallback


def objective(trial):
    """Train one model for an Optuna trial and return its validation loss.

    Builds the model and data generators from the trial's hyperparameters,
    trains for up to ``EPOCHS`` epochs with checkpointing, pruning and early
    stopping, and reports the last recorded ``val_loss``.

    Args:
        trial: Optuna trial providing the hyperparameters.

    Returns:
        The ``val_loss`` of the final epoch that was run (lower is better).
    """
    # Clear clutter from previous TensorFlow graphs.
    tf.keras.backend.clear_session()

    monitor = 'val_loss'

    # One checkpoint per trial, so a later trial does not overwrite the
    # model saved by an earlier (possibly better) one.
    model_path = os.path.join(
        MODEL_DIR, '{}_trial_{}.h5'.format(MODEL_NAME, trial.number))

    # Create model instance.
    model = create_model(trial)

    # load the data
    (X_train, y_train), (X_val, y_val) = load_data()

    train_dict = {'X': X_train, 'y': y_train}
    val_dict = {'X': X_val, 'y': y_val}

    # Create dataset instance.
    train_data, val_data = create_data_generators(trial, train_dict, val_dict)

    # Start the decay schedule from the learning rate sampled for this trial
    # in create_model. A hard-coded base rate here would replace the sampled
    # value as soon as the scheduler runs, making the 'lr' search a no-op.
    base_lr = trial.params['lr']

    # Create callbacks for checkpointing, early stopping and pruning.
    train_callbacks = get_callbacks(
        model_path,
        lr_sched=rate_scheduler(lr=base_lr, decay=0.99),
        tensorboard_log_dir=LOG_DIR,
        save_weights_only=count_gpus() >= 2,
        monitor=monitor,
        verbose=1)

    train_callbacks.append(TFKerasPruningCallback(trial, monitor))
    train_callbacks.append(
        tf.keras.callbacks.EarlyStopping(monitor=monitor, patience=5))

    # Train model.
    # NOTE(review): fit_generator is deprecated in newer TensorFlow releases
    # in favour of model.fit; kept as-is to match the installed version.
    history = model.fit_generator(
        train_data,
        steps_per_epoch=train_data.y.shape[0] // BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=val_data,
        validation_steps=val_data.y.shape[0] // BATCH_SIZE,
        callbacks=train_callbacks)

    return history.history[monitor][-1]


In [None]:
def show_result(study):
    """Print a summary of an Optuna study and its best trial.

    Reports how many trials ran, were pruned and completed, followed by the
    value and parameters of the best trial. If no trial completed, a short
    notice is printed instead of the best-trial section.

    Args:
        study: The Optuna study to summarize.
    """
    trial_state = optuna.trial.TrialState
    pruned_trials = [t for t in study.trials if t.state == trial_state.PRUNED]
    complete_trials = [t for t in study.trials if t.state == trial_state.COMPLETE]

    print('Study statistics: ')
    print('  Number of finished trials: ', len(study.trials))
    print('  Number of pruned trials: ', len(pruned_trials))
    print('  Number of complete trials: ', len(complete_trials))

    print('Best trial:')
    # study.best_trial raises when nothing completed (e.g. every trial was
    # pruned or failed), so bail out with a readable message instead.
    if not complete_trials:
        print('  None (no trial completed).')
        return

    trial = study.best_trial

    print('  Value: ', trial.value)

    print('  Params: ')
    for key, value in trial.params.items():
        print('    {}: {}'.format(key, value))


In [None]:
import optuna

# The objective returns the validation loss, so the study (and the median
# pruner, which follows the study direction) must minimize it.
study = optuna.create_study(
    direction='minimize',
    pruner=optuna.pruners.MedianPruner(n_startup_trials=2)
)

study.optimize(objective, n_trials=TRIALS, timeout=TIMEOUT)

show_result(study)