#### Challenger model validation

#### 01. Fetch Model information

In [0]:
# We are interested in validating the Challenger model
from mlflow.tracking import MlflowClient

# Three-level registered model name (catalog.db.model) and the alias under validation.
catalog = "workspace"
db = "booking"
model_name = f"{catalog}.{db}.mlops_booking"
model_alias = "Challenger"

# Resolve the alias to the concrete model version it currently points at.
client = MlflowClient()
model_details = client.get_model_version_by_alias(name=model_name, alias=model_alias)
model_version = int(model_details.version)

print(f"Validating {model_alias} model for {model_name} on model version {model_version}")

#### 02. Model checks

#### Description check

In [0]:
# A model version passes this check only if its description is present and
# strictly longer than MIN_DESCRIPTION_CHARS; the outcome is recorded as a tag.
# The threshold lives in one place so the check and the message cannot drift
# apart (the message previously claimed a 40-character minimum while the code
# enforced 20).
MIN_DESCRIPTION_CHARS = 20

if not model_details.description:
    # Covers both a missing (None) and an empty description.
    has_description = False
    print("Please add model description")
elif len(model_details.description) <= MIN_DESCRIPTION_CHARS:
    has_description = False
    print(f"Please add detailed model description (more than {MIN_DESCRIPTION_CHARS} characters).")
else:
    has_description = True

print(f'Model {model_name} version {model_details.version} has description: {has_description}')
# Record the result on the model version so it can be inspected later.
client.set_model_version_tag(name=model_name, version=str(model_details.version), key="has_description", value=has_description)

#### Model performance metric

In [0]:
import mlflow
from mlflow.exceptions import MlflowException

# F1 check: the Challenger passes if its test F1 is at least the Champion's,
# or if there is nothing to compare against. The outcome is recorded as a tag.

# The Challenger's own metric is mandatory: a missing 'test_f1_score' raises
# KeyError here rather than being silently accepted.
model_run_id = model_details.run_id
f1_score = mlflow.get_run(model_run_id).data.metrics['test_f1_score']

try:
    # Only the alias lookup sits in the try body, so unrelated failures
    # (typos, interrupts, a broken Challenger run) are no longer reported as
    # "No Champion found".
    champion_model = client.get_model_version_by_alias(model_name, "Champion")
except MlflowException as err:
    # NOTE(review): MlflowException also covers other registry failures
    # (permissions, connectivity); the reason is printed so they stay visible.
    print("No Champion found. Accept the model as it's the first one.")
    print(f"Champion lookup failed with: {err}")
    metric_f1_passed = True
else:
    # Compare the challenger f1 score to the existing champion
    champion_f1 = mlflow.get_run(champion_model.run_id).data.metrics.get('test_f1_score')
    if champion_f1 is None:
        # Keep the original best-effort outcome (accept), but say why.
        print("Champion run has no 'test_f1_score' metric to compare against. Accepting the Challenger.")
        metric_f1_passed = True
    else:
        print(f'Champion f1 score: {champion_f1}. Challenger f1 score: {f1_score}.')
        metric_f1_passed = f1_score >= champion_f1

print(f'Model {model_name} version {model_details.version} metric_f1_passed: {metric_f1_passed}')
# Tag the model version with the outcome of the F1 check (passed or not)
client.set_model_version_tag(name=model_name, version=model_details.version, key="metric_f1_passed", value=metric_f1_passed)

#### Benchmark or business metrics on the eval dataset

In [0]:
import pyspark.sql.functions as F
# Validation rows: the slice of the training table where split='validate'.
validation_df = (
    spark.table(f"{catalog}.{db}.mlops_booking_training")
    .filter("split='validate'")
)

def predict_churn(validation_df, model_alias):
    """Return validation_df with a 'predictions' column from the aliased model.

    The registered model behind ``model_alias`` is loaded as a Spark UDF and
    applied to the columns named in the model's input schema.
    """
    # Use env_manager="virtualenv" to recreate a venv with the same python version if needed
    aliased_model_uri = f"models:/{catalog}.{db}.mlops_booking@{model_alias}"
    scoring_udf = mlflow.pyfunc.spark_udf(spark, model_uri=aliased_model_uri)

    # Feed the UDF exactly the columns the model signature declares.
    feature_columns = scoring_udf.metadata.get_input_schema().input_names()
    return validation_df.withColumn('predictions', scoring_udf(*feature_columns))