In [8]:
import mlflow

# Run the login function to authenticate with Databricks CE
mlflow.login()


2024/08/29 02:00:25 INFO mlflow.utils.credentials: Successfully connected to MLflow hosted tracking server! Host: https://community.cloud.databricks.com.


In [9]:

# dummy experiment from tutorial
mlflow.set_tracking_uri("databricks")

mlflow.set_experiment("/check-databricks-connection")

with mlflow.start_run():
    mlflow.log_metric("foo", 1)
    mlflow.log_metric("bar", 2)



2024/08/29 02:00:34 INFO mlflow.tracking._tracking_service.client: 🏃 View run awesome-seal-612 at: https://community.cloud.databricks.com/ml/experiments/2597702965538188/runs/a4d786667a4f4e408d7a75d0062d959b.
2024/08/29 02:00:34 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://community.cloud.databricks.com/ml/experiments/2597702965538188.


## Simple Model and MLflow Experiment with Iris Dataset
Uses classic Iris dataset and RandomForest algo to train an ML model for classifying species of irises.

### Key Steps
1. Load dataset
2. Split data set into training and testing subsets (0.7, 0.3)
3. Train a Random Forest Classifier model
4. Make predictions on the test data
5. Calculate and log model accuracy to MLflow 

In [20]:
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split, cross_val_score
import mlflow
import mlflow.sklearn

# kill old processes
if mlflow.active_run():
    mlflow.end_run()

# set url and experiment in databricks
mlflow.set_tracking_uri('databricks')
mlflow.set_experiment('/iris-random-forest')

# variables to store the best results
best_n = None
best_score = 0
best_accuracy = 0

# start an mlflow experiment
with mlflow.start_run():
    # load iris dataset
    data = load_iris()
    # split data into train and test subsets
    X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.3, random_state=23)

    # train randomforest model (cross-validating 50, 100, 200, 500 trees)
    for n in [50, 100, 200, 500]:
        with mlflow.start_run(nested=True):  # start a nested run for each value of n_estimators
            model = RandomForestClassifier(n_estimators=n, random_state=42)
            
            # perform cross-validation and calculate mean score
            scores = cross_val_score(model, X_train, y_train, cv=5)
            mean_score = scores.mean()
            
            # train model on the full training set
            model.fit(X_train, y_train)
            
            # predict on the test set
            predictions = model.predict(X_test)
            
            # calculate accuracy on the test set
            accuracy = accuracy_score(y_test, predictions)
            
            # check if this is the best score
            if accuracy > best_accuracy:
                best_n = n
                best_score = mean_score
                best_accuracy = accuracy

            # log mean score and accuracy to mlflow for this model
            mlflow.log_param('n_estimators', n)
            mlflow.log_metric('mean_score', mean_score)
            mlflow.log_metric('accuracy', accuracy)

    # log the best model's results to mlflow
    mlflow.log_param('best_n_estimators', best_n)
    mlflow.log_metric('best_mean_score', best_score)
    mlflow.log_metric('best_accuracy', best_accuracy)
    mlflow.sklearn.log_model(model, 'best_model')

    print('best number of trees (n_estimators):', best_n)
    print('best mean cross-validation score:', best_score)
    print('best accuracy on test set:', best_accuracy)

    print('run id:', mlflow.active_run().info.run_id)

    print('Training matrix: ', X_train, '\nTraining vector: ', y_train, '\nTesting matrix: ', X_test, '\nTesting vector: ', y_test)


2024/08/29 19:00:08 INFO mlflow.tracking._tracking_service.client: 🏃 View run amusing-ant-818 at: https://community.cloud.databricks.com/ml/experiments/3146434618114174/runs/fe2d6eaa51ba4465b9abfe6e9ffa29a0.
2024/08/29 19:00:08 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://community.cloud.databricks.com/ml/experiments/3146434618114174.
2024/08/29 19:00:12 INFO mlflow.tracking._tracking_service.client: 🏃 View run adorable-rook-277 at: https://community.cloud.databricks.com/ml/experiments/3146434618114174/runs/ffd4dee5d9d74bb687ffffd954187927.
2024/08/29 19:00:12 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://community.cloud.databricks.com/ml/experiments/3146434618114174.
2024/08/29 19:00:16 INFO mlflow.tracking._tracking_service.client: 🏃 View run wise-fowl-193 at: https://community.cloud.databricks.com/ml/experiments/3146434618114174/runs/7a79adea41ad4fc88a53522ced0a8696.
2024/08/29 19:00:16 INFO mlflow.tracking._tracking_ser

best number of trees (n_estimators): 50
best mean cross-validation score: 0.9333333333333333
best accuracy on test set: 0.9777777777777777
run id: 7203ed72fb38411ebc57d09546505351
Training matrix:  [[6.3 2.5 5.  1.9]
 [4.9 2.5 4.5 1.7]
 [7.9 3.8 6.4 2. ]
 [4.9 3.1 1.5 0.2]
 [5.7 2.8 4.1 1.3]
 [7.4 2.8 6.1 1.9]
 [6.2 3.4 5.4 2.3]
 [5.5 4.2 1.4 0.2]
 [6.1 2.6 5.6 1.4]
 [5.4 3.  4.5 1.5]
 [7.  3.2 4.7 1.4]
 [7.7 3.8 6.7 2.2]
 [5.9 3.  4.2 1.5]
 [5.  3.6 1.4 0.2]
 [6.5 3.2 5.1 2. ]
 [6.7 3.  5.  1.7]
 [4.7 3.2 1.6 0.2]
 [5.8 2.6 4.  1.2]
 [5.1 3.3 1.7 0.5]
 [6.4 2.8 5.6 2.1]
 [5.5 2.3 4.  1.3]
 [6.8 2.8 4.8 1.4]
 [6.4 2.8 5.6 2.2]
 [5.8 2.8 5.1 2.4]
 [6.3 2.5 4.9 1.5]
 [4.9 3.6 1.4 0.1]
 [5.  3.4 1.6 0.4]
 [6.3 3.4 5.6 2.4]
 [4.9 3.  1.4 0.2]
 [4.8 3.4 1.6 0.2]
 [6.8 3.  5.5 2.1]
 [5.8 2.7 5.1 1.9]
 [5.5 2.5 4.  1.3]
 [4.4 2.9 1.4 0.2]
 [5.6 3.  4.1 1.3]
 [6.4 2.7 5.3 1.9]
 [6.2 2.2 4.5 1.5]
 [6.5 3.  5.8 2.2]
 [7.6 3.  6.6 2.1]
 [6.2 2.9 4.3 1.3]
 [4.4 3.2 1.3 0.2]
 [6.1 2.8 4.  1.3]
 [5.

2024/08/29 19:00:27 INFO mlflow.tracking._tracking_service.client: 🏃 View run welcoming-sheep-17 at: https://community.cloud.databricks.com/ml/experiments/3146434618114174/runs/7203ed72fb38411ebc57d09546505351.
2024/08/29 19:00:27 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://community.cloud.databricks.com/ml/experiments/3146434618114174.
