In [1]:
#!pip install --pre --upgrade stav

In [2]:
# Training code from
# https://mlflow.org/docs/latest/getting-started/intro-quickstart/index.html

import os

import mlflow
from mlflow.models import infer_signature

import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import stav

# Tracking server: honour MLFLOW_TRACKING_URI when it is set so the notebook can
# target another server without a code edit; otherwise fall back to the local
# server this notebook was originally run against.
TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:8080")
mlflow.set_tracking_uri(uri=TRACKING_URI)

# Load the Iris dataset as (features, labels) arrays
X, y = datasets.load_iris(return_X_y=True)

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Model hyperparameters. This dict is passed straight to the estimator and is
# also what gets recorded in MLflow, so it should only contain real arguments.
# `multi_class` is intentionally not set: scikit-learn 1.5 deprecated it (any
# explicit value triggers a FutureWarning and the argument is slated for
# removal), and the previously passed "auto" is the estimator's default behaviour.
params = {
    "solver": "lbfgs",
    "max_iter": 1000,
    "random_state": 8888,
}

# Train the model
lr = LogisticRegression(**params)
lr.fit(X_train, y_train)

# Predict on the held-out test set
y_pred = lr.predict(X_test)

# Test-set accuracy
accuracy = accuracy_score(y_test, y_pred)

# Record this training run in MLflow; the context manager ends the run on exit.
with mlflow.start_run():
    # Run tags keyed by stav-defined names. MLflow string-ifies non-string tag
    # values, so the hyperparameter dict ends up stored as str(params).
    run_tags = {
        stav.INFO_TRAINING: "Basic LR model for iris data",
        stav.AI_PROVIDER: "Acme Corporation",
        stav.AI_DEPLOYER: "Sirius Cybernetics",
        stav.AUTONOMY_TYPE: "No",
        stav.USE_SENSITIVE_PERSONAL_INFO: "No",
        stav.HYPERPARAMETER: params,
    }
    mlflow.set_tags(run_tags)

    # The same hyperparameters, also logged as first-class run params
    mlflow.log_params(params)

    mlflow.log_metric(stav.METRICS_ACCURACY, accuracy)

    # Signature derived from the training inputs and the model's own outputs on them
    train_predictions = lr.predict(X_train)
    signature = infer_signature(X_train, train_predictions)

    # Log the model artifact and register it under "tracking-quickstart";
    # as the cell output shows, re-running adds a new version of that model.
    model_info = mlflow.sklearn.log_model(
        sk_model=lr,
        artifact_path="iris_model",
        signature=signature,
        # NOTE(review): this stores the entire training split as the input
        # example; a small slice (e.g. a few rows) would likely suffice.
        input_example=X_train,
        registered_model_name="tracking-quickstart",
    )

Registered model 'tracking-quickstart' already exists. Creating a new version of this model...
2024/04/16 12:48:38 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: tracking-quickstart, version 3
Created version '3' of model 'tracking-quickstart'.
