# Train MNIST with TensorFlow Keras and log with MLflow

## Prerequisites:

You may create a new conda virtual environment using JupyterHub's conda tab.

You need to install the following in this virtual environment:

1.  Python 3.9.2
2.  TensorFlow 2.4.1
3.  tensorflow_datasets 1.2.0

Make sure your kernel is properly selected to use the virtual environment you created.

In [None]:
#
# If mlflow is not installed, uncomment the line below and run this cell. Restart your kernel after installation.

#%pip install mlflow==1.15.0

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import mlflow
import mlflow.keras
import os
# Show the installed TensorFlow version.
print(tf.__version__)

# Show the version of the Python interpreter running this kernel.
import platform
print(platform.python_version())

In [None]:
# Load the MNIST train and test splits together with the dataset metadata.
# `as_supervised=True` makes each element an (image, label) pair, and
# `with_info=True` makes the call return the metadata object as well.
mnist_splits, ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
ds_train, ds_test = mnist_splits

In [None]:
def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`.

    The image is cast to ``tf.float32`` and divided by 255; the label is
    returned unchanged alongside it.
    """
    scaled_image = tf.cast(image, tf.float32) / 255.
    return scaled_image, label

# Training input pipeline: normalize, cache, shuffle with a buffer as large as
# the whole training split, batch by 128, then prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [None]:
# Evaluation input pipeline: normalize, batch by 128, cache the batches, then
# prefetch. No shuffling is applied to the test split here.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.experimental.AUTOTUNE)
)

## Step 2: Create and train the model

Plug the input pipeline into Keras.

In [None]:
# Send runs to the MLflow tracking server listening on this machine.
mlflow.set_tracking_uri("http://127.0.0.1:5000")

# SECURITY: storage credentials are deliberately NOT embedded in this notebook.
# Export AZURE_STORAGE_CONNECTION_STRING (or AZURE_STORAGE_ACCESS_KEY) in the
# environment that starts the kernel. Any account key that was ever committed
# to this notebook must be treated as compromised and rotated.
# The check runs before training so a missing credential fails immediately
# instead of after all epochs, when the model artifact is uploaded.
if not (
    os.environ.get("AZURE_STORAGE_CONNECTION_STRING")
    or os.environ.get("AZURE_STORAGE_ACCESS_KEY")
):
    raise RuntimeError(
        "No Azure storage credential found: set AZURE_STORAGE_CONNECTION_STRING "
        "(or AZURE_STORAGE_ACCESS_KEY) in the environment before running this cell."
    )

# Single source of truth for the learning rate, so the value passed to the
# optimizer and the value logged to MLflow cannot drift apart.
learning_rate = 0.001

mlflow.set_experiment("mnist_ash1")
with mlflow.start_run(run_name="mnist-run-ex") as run:
    # Flatten -> Dense(128, relu) -> Dense(10). The last layer has no
    # activation, so it emits logits (matching from_logits=True below).
    # NOTE(review): the dataset images presumably carry a trailing channel
    # axis, i.e. (28, 28, 1) -- confirm input_shape is as intended.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

    # Record the hyperparameters on the active run.
    params = {"lr": learning_rate}
    mlflow.log_params(params)
    mlflow.log_param("param_1", "1")

    # Train for 6 epochs, evaluating on the test pipeline after each epoch.
    model.fit(
        ds_train,
        epochs=6,
        validation_data=ds_test,
    )

    # Log the trained model as a run artifact and register it in the model
    # registry; "conda.yaml" is resolved relative to the working directory.
    mlflow.keras.log_model(
        keras_model=model,
        artifact_path="mnist-model",
        conda_env="conda.yaml",
        registered_model_name="mnist-model2",
    )
    
    
    