# Get Started with MLflow + TensorFlow

In this guide, we will show how to train a model with TensorFlow and log the training run using MLflow.

We will use [Databricks Community Edition](https://community.cloud.databricks.com/) (CE) as our tracking server, which has built-in support for MLflow. Databricks CE is the free version of the Databricks platform. If you don't have an account yet, please [register one](https://www.databricks.com/try-databricks).

You can run the code in this guide from a cloud-based notebook such as a Databricks notebook or Google Colab, or on your local machine.

## Install dependencies

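Let's install the `mlflow` package. The guide also imports `tensorflow` and `tensorflow_datasets`; hosted notebooks such as Google Colab usually provide them, so install them too if your environment does not.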

```
%pip install -q mlflow
```

Then let's import the packages.

In [10]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras

## Load the dataset

We will do a simple image classification on clothing images with the [Fashion-MNIST dataset](https://www.tensorflow.org/datasets/catalog/fashion_mnist).

Let's load the dataset using `tensorflow_datasets` (`tfds`), which returns datasets in the format of `tf.data.Dataset`.

In [11]:
# Load the Fashion-MNIST dataset.
train_ds, test_ds = tfds.load(
    "fashion_mnist",
    split=["train", "test"],
    shuffle_files=True,
)

Let's preprocess our data with the following steps:
- Scale each pixel's value to `[0, 1]`.
- Batch the dataset.
- Use `prefetch` to speed up the training.

In [12]:
def preprocess_fn(data):
    image = tf.cast(data["image"], tf.float32) / 255
    label = data["label"]
    return (image, label)


train_ds = train_ds.map(preprocess_fn).batch(128).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.map(preprocess_fn).batch(128).prefetch(tf.data.AUTOTUNE)
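
As a quick sanity check on the steps above, the following sketch works through the arithmetic in plain Python. It assumes the standard Fashion-MNIST split sizes (60,000 training and 10,000 test images) and 8-bit pixel values. The helper names are illustrative and not part of TensorFlow.

```python
import math

TRAIN_EXAMPLES = 60_000  # assumed size of the Fashion-MNIST "train" split
TEST_EXAMPLES = 10_000   # assumed size of the Fashion-MNIST "test" split
BATCH_SIZE = 128         # same value passed to .batch() above


def scale_pixel(value: int) -> float:
    """Map an 8-bit pixel value (0..255) to a float, as preprocess_fn does."""
    return value / 255


def batches_per_epoch(num_examples: int, batch_size: int) -> int:
    """Number of batches when the last, smaller batch is kept."""
    return math.ceil(num_examples / batch_size)


train_steps = batches_per_epoch(TRAIN_EXAMPLES, BATCH_SIZE)
test_steps = batches_per_epoch(TEST_EXAMPLES, BATCH_SIZE)

print(scale_pixel(0), scale_pixel(255))  # endpoints of the scaled range
print(train_steps, test_steps)
```

Under these assumptions the batch counts come out to 469 and 79, which are the step totals Keras reports during training and evaluation.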

## Define the Model

Let's define a convolutional neural network as our classifier. We can use `keras.Sequential` to stack up the layers.

In [13]:
input_shape = (28, 28, 1)
num_classes = 10

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
)
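
To build intuition for the architecture, the sketch below traces the tensor shapes and counts the trainable parameters by hand. It assumes the Keras defaults for these layers (`"valid"` padding, stride 1 for convolutions, pooling stride equal to the pool size). The helper functions are illustrative, not Keras APIs.

```python
def conv_out(size: int, kernel: int) -> int:
    """Spatial size after a 'valid' convolution with stride 1."""
    return size - kernel + 1


def pool_out(size: int, pool: int) -> int:
    """Spatial size after max pooling with stride equal to the pool size."""
    return size // pool


def conv_params(kernel: int, in_channels: int, filters: int) -> int:
    """Kernel weights plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters


size = 28
size = pool_out(conv_out(size, 3), 2)  # Conv2D(32) + MaxPooling2D: 28 -> 26 -> 13
size = pool_out(conv_out(size, 3), 2)  # Conv2D(64) + MaxPooling2D: 13 -> 11 -> 5
flat_features = size * size * 64       # Flatten: 5 * 5 * 64

dense_params = flat_features * 10 + 10  # Dense(num_classes): weights + biases
total_params = conv_params(3, 1, 32) + conv_params(3, 32, 64) + dense_params

print(flat_features, total_params)
```

You can compare these numbers with the output of `model.summary()`.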

Next, configure training: the loss function, the optimizer, and the metrics to report.

In [14]:
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["sparse_categorical_accuracy"],
)
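
For readers new to the "sparse" variants, here is a minimal plain-Python sketch of what the loss and metric compute when labels are integers rather than one-hot vectors. The function names mirror the Keras identifiers, but these are simplified stand-ins with made-up inputs. They skip details such as the probability clipping a real implementation would apply.

```python
import math


def sparse_categorical_crossentropy(probs: list[float], label: int) -> float:
    """Negative log of the probability assigned to the true class."""
    return -math.log(probs[label])


def sparse_categorical_accuracy(batch_probs, labels) -> float:
    """Fraction of examples whose highest-probability class equals the label."""
    hits = 0
    for probs, label in zip(batch_probs, labels):
        predicted = max(range(len(probs)), key=lambda i: probs[i])
        hits += predicted == label
    return hits / len(labels)


# Two made-up predictions over three classes, with integer labels.
batch_probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
labels = [0, 1]

mean_loss = sum(
    sparse_categorical_crossentropy(p, y) for p, y in zip(batch_probs, labels)
) / len(labels)
accuracy = sparse_categorical_accuracy(batch_probs, labels)
print(f"loss={mean_loss:.4f} accuracy={accuracy:.2f}")
```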

## Set up tracking/visualization tool

In this tutorial, we will use Databricks CE as the MLflow tracking server. For other options, such as using a local MLflow server, please read the [Tracking Server Overview](https://mlflow.org/docs/latest/getting-started/tracking-server-overview/index.html).

If you have not done so yet, please register a [Databricks Community Edition](https://www.databricks.com/try-databricks#account) account. Registration should take no more than a minute. Databricks CE is a free platform for trying out Databricks features. For this guide, we need its ML experiment dashboard to track our training progress.




After successfully registering an account on Databricks CE, let's connect MLflow to it. You will need to enter the following information:
- **Databricks Host**: https://community.cloud.databricks.com/
- **Username**: the email address you signed up with
- **Password**: your password

In [15]:
import mlflow

mlflow.login()

2024/08/04 13:09:52 INFO mlflow.utils.credentials: No valid Databricks credentials found, please enter your credentials...


Databricks Host (should begin with https://): https://community.cloud.databricks.com/?o=632534433870215
Username: spravin20032001@gmail.com
Password: ··········


2024/08/04 13:10:35 INFO mlflow.utils.credentials: Successfully connected to MLflow hosted tracking server! Host: https://community.cloud.databricks.com.


Now this notebook is connected to the hosted tracking server. Let's configure the MLflow metadata. There are two things to set up:
- `mlflow.set_tracking_uri`: always use `"databricks"`.
- `mlflow.set_experiment`: pick any name you like, starting with `/`.

## Logging with MLflow

There are two ways to log to MLflow from your TensorFlow pipeline:
- MLflow auto logging.
- Use a callback.

Auto logging is simple to configure, but gives you less control. Using a callback is more flexible. Let's see how each way is done.

### MLflow Auto Logging

All you need to do is call `mlflow.tensorflow.autolog()` before kicking off training. MLflow will then automatically log the metrics to the server you configured earlier, which in our case is Databricks CE.

In [16]:
# Choose any name that you like.
mlflow.set_experiment("/mlflow-tf-keras-mnist")

mlflow.tensorflow.autolog()

model.fit(x=train_ds, epochs=3)

2024/08/04 13:11:12 INFO mlflow.tracking.fluent: Experiment with name '/mlflow-tf-keras-mnist' does not exist. Creating a new experiment.
2024/08/04 13:11:13 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '3d35d417c7994df48a10726933ea80f5', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current tensorflow workflow


Epoch 1/3
469/469 ━━━━━━━━━━━━━━━━━━━━ 8s 8ms/step - loss: 0.9778 - sparse_categorical_accuracy: 0.6565
Epoch 2/3
469/469 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: 0.4510 - sparse_categorical_accuracy: 0.8380
Epoch 3/3
469/469 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - loss: 0.3966 - sparse_categorical_accuracy: 0.8596

2024/08/04 13:11:46 INFO mlflow.tracking._tracking_service.client: 🏃 View run bittersweet-auk-82 at: https://community.cloud.databricks.com/?o=632534433870215/ml/experiments/780218162467344/runs/3d35d417c7994df48a10726933ea80f5.
2024/08/04 13:11:46 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://community.cloud.databricks.com/?o=632534433870215/ml/experiments/780218162467344.


<keras.src.callbacks.history.History at 0x7eb1affd1c30>

While training is in progress, you can find the run in your dashboard. Log in to your [Databricks CE](https://community.cloud.databricks.com/) account and select Machine Learning from the drop-down list.


You can click on metrics to see the chart.

Let's evaluate the training result.

In [17]:
score = model.evaluate(test_ds)

print(f"Test loss: {score[0]:.4f}")
print(f"Test accuracy: {score[1]: .2f}")

79/79 ━━━━━━━━━━━━━━━━━━━━ 2s 17ms/step - loss: 0.3458 - sparse_categorical_accuracy: 0.8773
Test loss: 0.3533
Test accuracy:  0.87


### Log with MLflow Callback

Auto logging is powerful and convenient, but if you prefer an approach that is more native to TensorFlow pipelines, you can pass `mlflow.tensorflow.MlflowCallback` to `model.fit()`. It will log:
- Your model configuration, layers, hyperparameters and so on.
- The training stats, including losses and metrics configured with `model.compile()`.

In [18]:
from mlflow.tensorflow import MlflowCallback

# Turn off autologging.
mlflow.tensorflow.autolog(disable=True)

with mlflow.start_run() as run:
    model.fit(
        x=train_ds,
        epochs=2,
        callbacks=[MlflowCallback(run)],
    )

Epoch 1/2
469/469 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.3658 - sparse_categorical_accuracy: 0.8690
Epoch 2/2
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 4ms/step - loss: 0.3484 - sparse_categorical_accuracy: 0.8763


2024/08/04 13:15:46 INFO mlflow.tracking._tracking_service.client: 🏃 View run overjoyed-shrike-156 at: https://community.cloud.databricks.com/?o=632534433870215/ml/experiments/780218162467344/runs/958e9ac4a8e74d78a9bbb8104a9f91ec.
2024/08/04 13:15:46 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://community.cloud.databricks.com/?o=632534433870215/ml/experiments/780218162467344.


In the Databricks CE experiment view, you will see a dashboard similar to the previous one.

### Customize the MLflow Callback

If you want to add extra logging logic, you can customize the MLflow callback. You can either subclass `keras.callbacks.Callback` and write everything from scratch, or subclass `mlflow.tensorflow.MlflowCallback` and add your custom logging logic.

Let's look at an example in which we replace the loss with its logarithm before logging it to MLflow.

In [None]:
import math


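# Create our own callback by subclassing `MlflowCallback`.
class MlflowCustomCallback(MlflowCallback):
    def on_epoch_end(self, epoch, logs=None):
        if not self.log_every_epoch or not logs:
            return
        # Work on a copy so that other callbacks still see the original `loss`.
        metrics = dict(logs)
        # Replace `loss` with its natural logarithm under the name `log_loss`.
        metrics["log_loss"] = math.log(metrics.pop("loss"))
        self.metrics_logger.record_metrics(metrics, epoch)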
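
To see what the callback sends to MLflow, here is the same transformation applied to a plain dictionary, with no Keras or MLflow involved. The function name `replace_loss_with_log` and the sample values are illustrative. The sketch works on a copy, so the dictionary passed in is left untouched.

```python
import math


def replace_loss_with_log(logs: dict) -> dict:
    """Return a copy of `logs` where `loss` is replaced by `log_loss`."""
    metrics = dict(logs)
    metrics["log_loss"] = math.log(metrics.pop("loss"))
    return metrics


# Example values in the shape Keras passes to `on_epoch_end`.
epoch_logs = {"loss": 0.3484, "sparse_categorical_accuracy": 0.8763}
print(replace_loss_with_log(epoch_logs))
```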

Train the model with the new callback.

In [None]:
with mlflow.start_run() as run:
    run_id = run.info.run_id
    model.fit(
        x=train_ds,
        epochs=2,
        callbacks=[MlflowCustomCallback(run)],
    )

Epoch 1/2
Epoch 2/2




On your Databricks CE page, you should find that the `log_loss` metric has replaced `loss`, similar to the screenshot below.

![log loss screenshot](https://i.imgur.com/dncAwaP.png)