## Compare runs, choose the best one, and deploy it

#### Run hyperparameter optimization

#### Compare runs

#### Deploy the best run and register the model

#### Deploy the model to the REST API

#### Build an image for the deployment

In [1]:
import keras
import numpy as np
import pandas as pd
from hyperopt import STATUS_OK, Trials, fmin, hp, tpe
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

import mlflow
from mlflow.models import infer_signature

In [2]:
# Load the white-wine quality dataset from the MLflow repository.
# The file is semicolon-delimited, not comma-delimited.
data = pd.read_csv(
    filepath_or_buffer=(
        "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/winequality-white.csv"
    ),
    sep=";",
)


In [3]:
# Hold out 25% of the rows as the test set, then carve 20% of the remaining
# training rows off as a validation set. A fixed random_state keeps both splits reproducible.
train, test = train_test_split(data, test_size=0.25, random_state=42)

# Features are every column except the "quality" target; targets are 1-D arrays.
train_x, train_y = train.drop(columns="quality").to_numpy(), train["quality"].to_numpy()
test_x, test_y = test.drop(columns="quality").to_numpy(), test["quality"].to_numpy()

train_x, valid_x, train_y, valid_y = train_test_split(
    train_x, train_y, test_size=0.2, random_state=42
)

In [4]:
# Input/output schema inferred from the training arrays; attached to the models logged below.
signature = infer_signature(model_input=train_x, model_output=train_y)

In [5]:
def train_model(params, epochs, train_x, train_y, valid_x, valid_y, test_x, test_y):
    """Train one Keras regressor for a hyperparameter trial and log it to MLflow.

    Builds a small feed-forward network (normalization -> 64-unit ReLU -> 1 output),
    trains it with SGD, scores it on the validation split, and records the params,
    the validation RMSE, and the model in a nested MLflow run.

    Args:
        params: Dict with ``"lr"`` (SGD learning rate) and ``"momentum"``.
        epochs: Number of training epochs.
        train_x, train_y: Training features and targets.
        valid_x, valid_y: Validation features and targets; used for the per-epoch
            validation metrics and for the returned score.
        test_x, test_y: Accepted for interface compatibility but not used here;
            the test split is not touched during tuning.

    Returns:
        A Hyperopt result dict ``{"loss": <validation RMSE>, "status": STATUS_OK,
        "model": <trained keras model>}``.
    """
    # Normalization statistics come from the training split only.
    mean = np.mean(train_x, axis=0)
    var = np.var(train_x, axis=0)
    model = keras.Sequential(
        [
            keras.Input([train_x.shape[1]]),
            keras.layers.Normalization(mean=mean, variance=var),
            keras.layers.Dense(64, activation="relu"),
            keras.layers.Dense(1),
        ]
    )

    model.compile(
        optimizer=keras.optimizers.SGD(
            learning_rate=params["lr"], momentum=params["momentum"]
        ),
        loss="mean_squared_error",
        metrics=[keras.metrics.RootMeanSquaredError()],
    )

    # Derive the signature from the data actually passed in rather than reading a
    # module-level global, so the logged schema always matches this call's inputs.
    model_signature = infer_signature(train_x, train_y)

    # nested=True: each trial is logged as a child of the enclosing (parent) run.
    with mlflow.start_run(nested=True):
        model.fit(
            train_x,
            train_y,
            validation_data=(valid_x, valid_y),
            epochs=epochs,
            batch_size=64,
        )
        # Look the metric up by name instead of relying on its position in the
        # list returned by evaluate().
        eval_metrics = model.evaluate(
            valid_x, valid_y, batch_size=64, return_dict=True
        )
        eval_rmse = eval_metrics["root_mean_squared_error"]

        # Log parameters, the validation score, and the trained model.
        mlflow.log_params(params)
        mlflow.log_metric("eval_rmse", eval_rmse)
        mlflow.tensorflow.log_model(model, "model", signature=model_signature)

        return {"loss": eval_rmse, "status": STATUS_OK, "model": model}


In [6]:
def objective(params):
    """Hyperopt objective: train and score one sampled configuration.

    Forwards ``params`` to ``train_model`` together with the notebook-level data
    splits, using a fixed budget of 3 epochs per trial, and returns its result
    dict unchanged (Hyperopt minimizes the ``"loss"`` entry).
    """
    splits = dict(
        train_x=train_x,
        train_y=train_y,
        valid_x=valid_x,
        valid_y=valid_y,
        test_x=test_x,
        test_y=test_y,
    )
    return train_model(params, epochs=3, **splits)


In [7]:
# Search space: the learning rate is log-uniform between 1e-5 and 1e-1 (bounds are
# given in log space), and momentum is uniform on [0, 1].
space = dict(
    lr=hp.loguniform("lr", np.log(1e-5), np.log(1e-1)),
    momentum=hp.uniform("momentum", 0.0, 1.0),
)


In [9]:
# MLflow tracking server that all runs below are logged to. Keep the URI free of
# stray whitespace: a leading space is carried verbatim into every run link MLflow prints.
MLFLOW_TRACKING_URI = "http://127.0.0.1:5000"
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

In [10]:
# Parent run that groups the whole search; each trial (train_model) opens a
# nested child run underneath it.
mlflow.set_experiment("/wine-quality")
with mlflow.start_run():
    # Conduct the hyperparameter search using Hyperopt's TPE sampler.
    trials = Trials()
    best = fmin(
        fn=objective,
        space=space,
        algo=tpe.suggest,
        max_evals=8,
        trials=trials,
    )

    # Trial result dict (as returned by objective) with the lowest validation RMSE.
    # min() is a single pass; sorting every result just to take the first is unnecessary.
    best_run = min(trials.results, key=lambda result: result["loss"])

    # Log the best parameters, loss, and model on the parent run. `best` holds the
    # sampled values directly because the space uses only continuous distributions
    # (an hp.choice entry would come back as an index instead).
    mlflow.log_params(best)
    mlflow.log_metric("eval_rmse", best_run["loss"])
    mlflow.tensorflow.log_model(best_run["model"], "model", signature=signature)

    # Print out the best parameters and corresponding loss
    print(f"Best parameters: {best}")
    print(f"Best eval rmse: {best_run['loss']}")

2024/10/23 00:39:43 INFO mlflow.tracking.fluent: Experiment with name '/wine-quality' does not exist. Creating a new experiment.


Epoch 1/3                                            

[1m 1/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m18s[0m 404ms/step - loss: 35.2453 - root_mean_squared_error: 5.8840
[1m 2/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 657us/step - loss: 34.6237 - root_mean_squared_error: 5.8840
[1m 3/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 34.4843 - root_mean_squared_error: 5.8722
[1m 4/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 34.5409 - root_mean_squared_error: 5.8770
[1m 6/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 34.4627 - root_mean_squared_error: 5.8704
[1m 5/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 34.5522 - root_mean_squared_error: 5.8704
[1m 7/46[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 34.3448 - root_mean_squared_error: 5.8603
[1m 8/46[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 34.1987 - root_m

2024/10/23 00:39:52 INFO mlflow.tracking._tracking_service.client: 🏃 View run adventurous-dog-953 at:  http://127.0.0.1:5000/#/experiments/512988094738068919/runs/a27542434df946e8a68a2024666d73bb.

2024/10/23 00:39:52 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at:  http://127.0.0.1:5000/#/experiments/512988094738068919.



Epoch 1/3                                                                      

[1m 1/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m20s[0m 465ms/step - loss: 39.1877 - root_mean_squared_error: 6.2600
[1m 2/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 3ms/step - loss: 36.3564 - root_mean_squared_error: 6.0250
[1m 3/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 34.1470 - root_mean_squared_error: 5.8341
[1m 4/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 32.1897 - root_mean_squared_error: 5.4826
[1m 5/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 30.3227 - root_mean_squared_error: 5.4826
[1m 6/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 28.6588 - root_mean_squared_error: 5.3205
[1m 7/46[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 27.1682 - root_mean_squared_error: 5.1703   
[1m46/46[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/st

2024/10/23 00:39:59 INFO mlflow.tracking._tracking_service.client: 🏃 View run brawny-turtle-425 at:  http://127.0.0.1:5000/#/experiments/512988094738068919/runs/fa38042e146145dea04ca5e03c105d70.

2024/10/23 00:39:59 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at:  http://127.0.0.1:5000/#/experiments/512988094738068919.



Epoch 1/3                                                                      

[1m 1/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m15s[0m 347ms/step - loss: 43.7297 - root_mean_squared_error: 6.6128
[1m 2/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 709us/step - loss: 42.5022 - root_mean_squared_error: 6.5187
[1m 3/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 41.7898 - root_mean_squared_error: 6.4331
[1m 5/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 41.1016 - root_mean_squared_error: 6.3915
[1m 4/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 41.3961 - root_mean_squared_error: 6.4331
[1m 6/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 40.8622 - root_mean_squared_error: 6.3915
[1m 7/46[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 40.7053 - root_mean_squared_error: 6.3792   
[1m46/46[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/

2024/10/23 00:40:07 INFO mlflow.tracking._tracking_service.client: 🏃 View run valuable-fox-127 at:  http://127.0.0.1:5000/#/experiments/512988094738068919/runs/e38fe4cad778446cadef412d8ca7a137.

2024/10/23 00:40:07 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at:  http://127.0.0.1:5000/#/experiments/512988094738068919.



Epoch 1/3                                                                      

[1m 1/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m12s[0m 287ms/step - loss: 31.1674 - root_mean_squared_error: 5.5828
[1m 2/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 30.8280 - root_mean_squared_error: 5.5522
[1m 3/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 30.6151 - root_mean_squared_error: 5.5330
[1m 4/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 30.6298 - root_mean_squared_error: 5.5343
[1m 5/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 30.7351 - root_mean_squared_error: 5.5469
[1m 7/46[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 30.8530 - root_mean_squared_error: 5.5544
[1m 6/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 30.7692 - root_mean_squared_error: 5.5469
[1m 8/46[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step 

2024/10/23 00:40:14 INFO mlflow.tracking._tracking_service.client: 🏃 View run skittish-fox-148 at:  http://127.0.0.1:5000/#/experiments/512988094738068919/runs/574028b359fe443aba4de8c7eeb27119.

2024/10/23 00:40:14 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at:  http://127.0.0.1:5000/#/experiments/512988094738068919.



Epoch 1/3                                                                      

[1m 1/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m21s[0m 474ms/step - loss: 34.0865 - root_mean_squared_error: 5.8384
[1m 2/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 10ms/step - loss: 34.6401 - root_mean_squared_error: 5.8854
[1m 3/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 7ms/step - loss: 34.9426 - root_mean_squared_error: 5.9110
[1m 4/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 6ms/step - loss: 35.1966 - root_mean_squared_error: 5.9457
[1m 5/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 5ms/step - loss: 35.3552 - root_mean_squared_error: 5.9531   
[1m 6/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 4ms/step - loss: 35.4420 - root_mean_squared_error: 5.9580   
[1m46/46[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - loss: 35.0146 - root_mean_squared_error: 5.9171 - val_loss: 31.1895 - val_root_mean_squared_error: 5.5848

Epoch 2

2024/10/23 00:40:22 INFO mlflow.tracking._tracking_service.client: 🏃 View run sincere-chimp-723 at:  http://127.0.0.1:5000/#/experiments/512988094738068919/runs/9f2d18594b8641c1b3792ce59b963aa2.

2024/10/23 00:40:22 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at:  http://127.0.0.1:5000/#/experiments/512988094738068919.



Epoch 1/3                                                                      

[1m 1/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m12s[0m 277ms/step - loss: 40.0708 - root_mean_squared_error: 6.3301
[1m 2/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 39.9461 - root_mean_squared_error: 6.3203
[1m 4/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 39.4388 - root_mean_squared_error: 6.2799
[1m 3/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 39.6577 - root_mean_squared_error: 6.2973
[1m 5/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 39.1575 - root_mean_squared_error: 6.2573   
[1m46/46[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 23.7060 - root_mean_squared_error: 4.7878 - val_loss: 2.6687 - val_root_mean_squared_error: 1.6336

Epoch 2/3                                                                      

[1m 2/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m

2024/10/23 00:40:28 INFO mlflow.tracking._tracking_service.client: 🏃 View run chill-zebra-650 at:  http://127.0.0.1:5000/#/experiments/512988094738068919/runs/1d20fc6568624319b453ba7ff36815e4.

2024/10/23 00:40:28 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at:  http://127.0.0.1:5000/#/experiments/512988094738068919.



Epoch 1/3                                                                      

[1m 1/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m12s[0m 273ms/step - loss: 34.4812 - root_mean_squared_error: 5.8721
[1m 2/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 35.5381 - root_mean_squared_error: 5.9607
[1m 3/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 36.2148 - root_mean_squared_error: 6.0169
[1m 4/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 36.6115 - root_mean_squared_error: 6.0498
[1m 5/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 36.9398 - root_mean_squared_error: 6.0768
[1m 6/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 37.0896 - root_mean_squared_error: 6.0892   
[1m 7/46[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 1ms/step - loss: 37.1290 - root_mean_squared_error: 6.0926   
[1m46/46[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms

2024/10/23 00:40:33 INFO mlflow.tracking._tracking_service.client: 🏃 View run entertaining-bee-269 at:  http://127.0.0.1:5000/#/experiments/512988094738068919/runs/f54c31a4f5944ae08c6691574004ca1d.

2024/10/23 00:40:33 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at:  http://127.0.0.1:5000/#/experiments/512988094738068919.



Epoch 1/3                                                                      

[1m 1/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m12s[0m 271ms/step - loss: 37.6005 - root_mean_squared_error: 6.1319
[1m 2/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 37.6632 - root_mean_squared_error: 6.1370
[1m 3/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 37.9373 - root_mean_squared_error: 6.1592
[1m 4/46[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 37.9808 - root_mean_squared_error: 6.1628
[1m 5/46[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m0s[0m 2ms/step - loss: 37.9729 - root_mean_squared_error: 6.1622   
[1m46/46[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 37.1776 - root_mean_squared_error: 6.0970 - val_loss: 33.5609 - val_root_mean_squared_error: 5.7932

Epoch 2/3                                                                      

[1m 1/46[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m0s[0

2024/10/23 00:40:38 INFO mlflow.tracking._tracking_service.client: 🏃 View run dapper-turtle-383 at:  http://127.0.0.1:5000/#/experiments/512988094738068919/runs/5b63ab60fc4d4099a2144dff2d21e9c6.

2024/10/23 00:40:38 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at:  http://127.0.0.1:5000/#/experiments/512988094738068919.



100%|██████████| 8/8 [00:55<00:00,  6.90s/trial, best loss: 0.8738760352134705]


2024/10/23 00:40:43 INFO mlflow.tracking._tracking_service.client: 🏃 View run able-pig-921 at:  http://127.0.0.1:5000/#/experiments/512988094738068919/runs/fb88bd31aeb846a88c790af0a3eb7da0.
2024/10/23 00:40:43 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at:  http://127.0.0.1:5000/#/experiments/512988094738068919.


Best parameters: {'lr': 0.01051513745776196, 'momentum': 0.3995259729788678}
Best eval rmse: 0.8738760352134705


In [15]:
## Inference
from mlflow.models import convert_input_example_to_serving_input, validate_serving_input

# Resolve the parent run logged by the search cell instead of hard-coding its run ID,
# so this cell keeps working after Restart & Run All (which creates new run IDs).
last_run = mlflow.last_active_run()
if last_run is None:
    raise RuntimeError(
        "No MLflow run found in this session; run the hyperparameter search cell first."
    )
model_uri = f"runs:/{last_run.info.run_id}/model"

# The logged model has no input_example, so build a serving payload from the
# held-out test features and check the model accepts it before deployment.
serving_payload = convert_input_example_to_serving_input(test_x)
validate_serving_input(model_uri, serving_payload)

Downloading artifacts: 100%|██████████| 7/7 [00:00<00:00, 190.49it/s]  

[1m39/39[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step 





array([[5.165843 ],
       [7.076086 ],
       [6.3972936],
       ...,
       [6.486343 ],
       [6.6715965],
       [5.423612 ]], dtype=float32)