#### Churn Prediction Model Inference

#### 01. Inference with the Champion model

In [0]:
# Register Parameters: Unity Catalog location (catalog + schema) used below.
# Note: dbutils.library.restartPython() in the install cell wipes these from the
# Python session, so cells that run after it cannot rely on them.
catalog, db = "workspace", "customer_churn"

In [0]:
from mlflow.store.artifact.models_artifact_repo import ModelsArtifactRepository

# Fetch only the requirements.txt logged with the model version behind the
# Champion alias. That is the version scored later in this notebook, so its
# pinned dependencies are the ones to install. (The Challenger alias may point
# at a version with different pins.)
champion_model_uri = f"models:/{catalog}.{db}.advanced_mlops_churn@Champion"
requirements_path = ModelsArtifactRepository(champion_model_uri).download_artifacts(
    artifact_path="requirements.txt"
)

In [0]:
%pip install --quiet uv
# Install the packages pinned in the model's requirements.txt (path produced by
# the previous cell) using uv, invoked as a shell command via `!`.
# NOTE(review): `%pip install -r $requirements_path` is the documented way to add
# notebook-scoped libraries on Databricks; confirm `uv pip` installs into the
# notebook's Python environment rather than another interpreter.
!uv pip install -r $requirements_path
# Restart Python so the newly installed versions are importable. This discards
# every Python variable defined so far (e.g. catalog, db, requirements_path), so
# later cells must re-declare anything they need.
dbutils.library.restartPython()

In [0]:
# Seed this schema's scoring population from the demo schema. IF NOT EXISTS
# makes re-runs a no-op once the table is present.
create_cust_ids_sql = """
CREATE TABLE IF NOT EXISTS workspace.customer_churn.advanced_churn_cust_ids
AS SELECT * FROM workspace.dbdemos_mlops.advanced_churn_cust_ids
"""
spark.sql(create_cust_ids_sql)

In [0]:

# Upsert the demo features into this schema's feature table, matching rows on
# customer_id: matched rows are overwritten, new customers are inserted.
# MERGE INTO does not create the target, so it must already exist.
merge_features_sql = """
MERGE INTO workspace.customer_churn.advanced_churn_feature_table AS target
USING workspace.dbdemos_mlops.advanced_churn_feature_table AS source
ON target.customer_id = source.customer_id
WHEN MATCHED THEN UPDATE SET *
WHEN NOT MATCHED THEN INSERT *
"""
spark.sql(merge_features_sql)

In [0]:
from databricks.feature_engineering import FeatureEngineeringClient
import pyspark.sql.functions as F

# dbutils.library.restartPython() in the install cell discards every Python
# variable, including the parameters from the "Register Parameters" cell.
# Re-declare them here (keep the values in sync with that cell); otherwise a
# fresh Run All fails with NameError below.
catalog = "workspace"
db = "customer_churn"

# Load customer_id and transaction_ts columns to be scored
inference_df = spark.read.table(f"{catalog}.{db}.advanced_churn_cust_ids")
fe = FeatureEngineeringClient()

# Fully qualified Unity Catalog model name
model_name = f"{catalog}.{db}.advanced_mlops_churn"

# Score with whichever version the Champion alias currently points to
model_uri = f"models:/{model_name}@Champion"

# Batch score; predictions are returned as strings (result_type="string")
preds_df = fe.score_batch(df=inference_df, model_uri=model_uri, result_type="string")
display(preds_df)

#### 02. Save the predictions for monitoring

In [0]:
from mlflow import MlflowClient
from datetime import datetime


client = MlflowClient()

# Registered-model metadata for the model being scored.
model = client.get_registered_model(name=model_name)

# Resolve the version the Champion alias points to. The registry reports
# versions as strings, so cast to int for the prediction log.
champion_model_version = client.get_model_version_by_alias(name=model_name, alias="Champion")
model_version = int(champion_model_version.version)

In [0]:
import pyspark.sql.functions as F
from datetime import datetime, timedelta

# Predictions are stamped this far in the past.
# NOTE(review): this looks like a deliberate backdate so the monitoring demo has
# an earlier window to compare against; confirm before using it for real scoring.
INFERENCE_TS_OFFSET = timedelta(days=2)
inference_ts = datetime.now() - INFERENCE_TS_OFFSET

offline_inference_table = f"{catalog}.{db}.advanced_churn_offline_inference"

# Tag each prediction with the model identity and alias that produced it.
offline_inference_df = (
    preds_df
    .withColumn("model_name", F.lit(model_name))
    .withColumn("model_version", F.lit(model_version))
    .withColumn("model_alias", F.lit("Champion"))
    .withColumn("inference_timestamp", F.lit(inference_ts))
)

# NOTE(review): "overwrite" replaces all previously saved predictions on every
# run; use append if the monitor should see prediction history across runs.
offline_inference_df.write.mode("overwrite").saveAsTable(offline_inference_table)

# Preview the persisted rows. Displaying offline_inference_df would re-execute
# the lazy score_batch plan (model inference) just to render the output.
display(spark.read.table(offline_inference_table))