In [0]:
from databricks.feature_engineering import FeatureEngineeringClient
from mlflow.tracking import MlflowClient

# Client used further down for batch scoring against the registered model.
fe = FeatureEngineeringClient()

# Customers to score: a one-column DataFrame with one row per customer ID.
inference_df = spark.createDataFrame(
    [(cust_id,) for cust_id in ("C1001", "C1002", "C1003", "C1004", "C1005")],
    ["customer_id"],
)

# Get latest model version from registry
def get_latest_model_version(model_name, client=None):
    """Return the highest registered version number of ``model_name``.

    Args:
        model_name: Name of the registered model, e.g.
            ``"oms_analytics.ml.churn_prediction_model"``.
        client: Optional ``MlflowClient``-like object exposing
            ``search_model_versions``. When omitted, a fresh
            ``MlflowClient()`` is created. Pass one in to reuse an existing
            client or to call this without a live registry.

    Returns:
        The largest version number, as an ``int``.

    Raises:
        ValueError: If no versions are registered under ``model_name``.
    """
    # NOTE(review): the highest version number is not necessarily the
    # validated/approved one; if the registry uses aliases, consider
    # resolving an alias instead.
    if client is None:
        client = MlflowClient()
    versions = client.search_model_versions(f"name='{model_name}'")
    # Version numbers come back as strings; compare them as ints so that
    # "10" ranks above "9".
    version_numbers = [int(mv.version) for mv in versions]
    if not version_numbers:
        # max() on an empty sequence would raise an opaque ValueError; keep
        # the same exception type but say which model was missing.
        raise ValueError(f"No registered versions found for model '{model_name}'.")
    return max(version_numbers)

# Registered model to score with. Its highest version number is looked up
# in the MLflow registry at run time.
model_name = "oms_analytics.ml.churn_prediction_model"
latest_version = get_latest_model_version(model_name)

# Batch inference: pass score_batch the customer-ID DataFrame together with
# a model URI pinned to the version resolved above.
predictions_df = fe.score_batch(
    df=inference_df,
    model_uri=f"models:/{model_name}/{latest_version}",
)

# Show each customer ID next to its prediction:
# 1.0 = customer might leave (churn), 0.0 = customer likely stays.
display(predictions_df.select(["customer_id", "prediction"]))
