# Multiplex Watershed: Hyper Parameter Search

This notebook uses [`optuna`](https://github.com/optuna/optuna) to search for the best watershed post-processing hyperparameters.

In [None]:
import optuna
import numpy as np

from deepcell_toolbox.metrics import to_precision
from deepcell_toolbox.metrics import Metrics

from deepcell.applications import MultiplexSegmentation
from deepcell.datasets import multiplex_tissue

Create an objective function to run with `optuna`.

In [None]:
def objective(trial):
    """Score one sampled set of watershed post-processing parameters.

    Runs whole-cell inference on the combined train and test splits of the
    multiplex tissue dataset with the parameters sampled from ``trial``,
    then computes an F1 score from the object-level detection statistics.

    Args:
        trial: the optuna trial used to sample the hyper parameters.

    Returns:
        The F1 score computed from the summed ``correct_detections``,
        ``n_true`` and ``n_pred`` statistics, or 0 when the score is
        undefined (nothing to detect, nothing predicted, or no matches).
    """
    # NOTE(review): the application and the dataset are re-created on every
    # trial; consider hoisting them out of the objective if this is slow.
    # Create application
    app = MultiplexSegmentation()

    # Load the dataset
    (X_train, y_train), (X_test, y_test) = multiplex_tissue.load_data()

    # Combine into single dataset
    X = np.concatenate([X_train, X_test], axis=0)
    y = np.concatenate([y_train, y_test], axis=0)

    # Define parameters
    postprocess_kwargs = {
        'radius': trial.suggest_int('radius', 1, 15),
        # The thresholds are continuous values in [0, 1]; sampling them with
        # suggest_int would only ever try the two end points.
        'maxima_threshold': trial.suggest_float('maxima_threshold', 0., 1.),
        'interior_threshold': trial.suggest_float('interior_threshold', 0., 1.),
        'small_objects_threshold': trial.suggest_int('small_objects_threshold', 1, 15),
        'fill_holes_threshold': trial.suggest_int('fill_holes_threshold', 1, 15),
        'interior_model_smooth': trial.suggest_int('interior_model_smooth', 0, 4),
    }

    # Run the inference
    pred = app.predict(
        X, compartment='whole-cell',  # TODO: other options depend on data
        postprocess_kwargs_whole_cell=postprocess_kwargs,
    )

    # Run the metrics
    m = Metrics('Hyper Parameter Search', seg=False)
    m.calc_object_stats(y, pred)

    # Aggregate the detection counts over all entries of the stats table
    n_correct = m.stats['correct_detections'].sum()
    n_true = m.stats['n_true'].sum()
    n_pred = m.stats['n_pred'].sum()

    # F1 is undefined when any of these is zero; report the worst score
    # instead of dividing by zero.
    if n_correct == 0 or n_true == 0 or n_pred == 0:
        return 0

    recall = n_correct / n_true
    precision = n_correct / n_pred

    # Return the value to optimize
    f1 = 2 * precision * recall / (precision + recall)
    return f1 if not np.isnan(f1) else 0

In [None]:
def show_result(study):
    """Print a summary of ``study``.

    Reports the total, pruned and complete trial counts, followed by the
    value and the parameters of the best trial.

    Args:
        study: object exposing ``trials`` and ``best_trial``, such as the
            study created with ``optuna.create_study``.
    """
    # Tally the trial states in a single pass
    n_pruned = 0
    n_complete = 0
    for candidate in study.trials:
        if candidate.state == optuna.trial.TrialState.PRUNED:
            n_pruned += 1
        elif candidate.state == optuna.trial.TrialState.COMPLETE:
            n_complete += 1

    print('Study statistics: ')
    print('  Number of finished trials: ', len(study.trials))
    print('  Number of pruned trials: ', n_pruned)
    print('  Number of complete trials: ', n_complete)

    print('Best trial:')
    best = study.best_trial

    print('  Value: ', best.value)

    print('  Params: ')
    for name in best.params:
        print('    {}: {}'.format(name, best.params[name]))


In [None]:
# Maximize the value returned by `objective`.
# NOTE(review): `objective` never reports intermediate values to the trial,
# so the median pruner presumably has nothing to act on - confirm.
study = optuna.create_study(
    pruner=optuna.pruners.MedianPruner(n_startup_trials=2),
    direction='maximize',
)

# Stop after 100 trials or 7200 seconds
study.optimize(objective, timeout=7200, n_trials=100)

# Summarize the search
show_result(study)