In [0]:
# This notebook trains a classification model on the Iris dataset and registers it in Unity Catalog (UC).

In [0]:
%pip install mlflow --upgrade --pre
# NOTE(review): the install above is unpinned and `--pre` allows pre-release builds,
# so the MLflow version may differ between runs — consider pinning once a version
# has been validated.
# Restart the Python process after the install; presumably so that the upgraded
# package is the one picked up by the imports that follow — confirm against the
# Databricks documentation for restartPython.
dbutils.library.restartPython()

In [0]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
import mlflow
import mlflow.sklearn
from mlflow.models.signature import infer_signature
from mlflow.tracking.client import MlflowClient
import requests

In [0]:
# Target catalog, exposed as a notebook widget so a job or caller can override
# the default without editing the notebook.
dbutils.widgets.text("catalog_name", "pedroz_e2edata_dev")

# Widget values are free text: trim stray whitespace and fail fast on an empty
# value, which would otherwise produce malformed names such as
# ".default.iris_data" further down the notebook.
catalog_name = dbutils.widgets.get("catalog_name").strip()
if not catalog_name:
    raise ValueError("Widget 'catalog_name' must not be empty")

In [0]:
# Unqualified name under which the trained model is registered.
model_name = "iris_model"

In [0]:
# Three-part name (<catalog>.default.iris_data) of the table the training data is read from.
feature_table_name = f"{catalog_name}.default.iris_data"

In [0]:
# Train a decision-tree classifier on the Iris table, log it and its test metrics
# to MLflow, then register the logged model and tag it with the "challenger" alias.

# Fixed seed so the train/test split, the fitted tree, and therefore the logged
# metrics are reproducible from one run of the notebook to the next.
RANDOM_SEED = 42

# The model is registered under a three-level <catalog>.<schema>.<model> name, which
# targets the Unity Catalog registry; select it explicitly rather than relying on
# whatever registry the workspace / MLflow version defaults to.
mlflow.set_registry_uri("databricks-uc")

mlflow.autolog()

with mlflow.start_run() as run:
    # Load data from Unity Catalog table
    df_iris = spark.table(feature_table_name).toPandas()
    features = ['sepal_length_cm', 'sepal_width_cm', 'petal_length_cm', 'petal_width_cm']
    target = 'species'

    X = df_iris[features]
    y = df_iris[target]

    # 70/30 split; seeded so the held-out set is the same on every run.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=0.7, random_state=RANDOM_SEED
    )

    # Train the model (seeded: the tree builder permutes features at each split).
    model = DecisionTreeClassifier(random_state=RANDOM_SEED)
    model.fit(X_train, y_train)

    # Make predictions on the held-out set
    y_pred = model.predict(X_test)

    # Calculate and log metrics; macro averaging weights every class equally.
    accuracy = accuracy_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred, average='macro')
    precision = precision_score(y_test, y_pred, average='macro')
    f1 = f1_score(y_test, y_pred, average='macro')

    mlflow.log_metric("test_accuracy", accuracy)
    mlflow.log_metric("test_recall", recall)
    mlflow.log_metric("test_precision", precision)
    mlflow.log_metric("test_f1", f1)

    # Infer model signature (input columns -> target column)
    signature = infer_signature(X_train, y_train)

    # Log the model and keep the returned handle so that exactly this logged model
    # is the one registered below.
    model_info = mlflow.sklearn.log_model(
        sk_model=model,
        artifact_path="model",
        signature=signature,
        input_example=X_train.head()
    )

    # Log input dataset for lineage
    data_source = mlflow.data.load_delta(table_name=feature_table_name)
    mlflow.log_input(data_source, context="training")

    # Register the model in Unity Catalog.
    # Use the URI reported by log_model instead of a hand-built "runs:/<id>/model"
    # string: autologging is also enabled, so a hand-built URI may not refer
    # unambiguously to the explicitly logged model (NOTE(review): confirm against
    # the MLflow version in use).
    model_uri = model_info.model_uri
    registered_model = mlflow.register_model(model_uri, f"{catalog_name}.default.{model_name}")

    # Point the "challenger" alias at the version that was just registered.
    client = mlflow.tracking.MlflowClient()
    client.set_registered_model_alias(name=registered_model.name, alias="challenger", version=registered_model.version)