In [0]:
%run ./include/setup

## Load model

Retrieve the registered model from 'Models'.

Copy the full model name from the UI, e.g. `dbacademy.labuser9128531_1738705451.rf2` (the name is specific to each lab user, so substitute your own).

In [0]:
import mlflow

# Registered model to score with, as `<catalog>.<schema>.<model>` in Unity Catalog.
# NOTE(review): this name is lab-user specific — copy your own from the Models UI.
MODEL_NAME = "dbacademy.labuser9128531_1739377194.rf2"
# Alias of the model version to load. `models:/<name>@<alias>` addresses a model by alias;
# Unity Catalog does not support the workspace-registry "stages" (use an alias or a version number).
MODEL_ALIAS = "production"

# Point the MLflow Model Registry client at Databricks Unity Catalog.
mlflow.set_registry_uri("databricks-uc")
modelURL = f"models:/{MODEL_NAME}@{MODEL_ALIAS}"
print("Retrieving model " + modelURL)

# Wrap the model as a Spark UDF. "int" is the UDF's declared result type: one integer
# per row (presumably the churn class label — confirm against how the model was trained).
predict_churn_udf = mlflow.pyfunc.spark_udf(spark, modelURL, "int")
# Register the UDF so Spark SQL queries can call it as `predict_churn(...)`.
spark.udf.register("predict_churn", predict_churn_udf)

## Check that the features in the new data match the input schema the model expects

In [0]:
# Load the feature table to be scored and print its schema for inspection.
df = spark.table('churn_features')
df.printSchema()

# Input column names from the model's signature; the UDF is later called with these.
# get_input_schema() returns None when the model was logged without a signature,
# so fail with a clear message instead of an AttributeError on None.
input_schema = predict_churn_udf.metadata.get_input_schema()
if input_schema is None:
    raise ValueError(
        "Model has no input signature; cannot verify the feature columns. "
        "Log the model with a signature (see mlflow.models.infer_signature)."
    )
model_features = input_schema.input_names()

# Every model input must exist as a column in the DataFrame. Collect all missing
# columns so a single run reports the full mismatch rather than only the first one.
available_columns = set(df.columns)
missing_features = [feature for feature in model_features if feature not in available_columns]
if missing_features:
    raise ValueError(f"Columns {missing_features} not found in DataFrame")

In [0]:
# For each column the model requires, print the Spark data type it has in the new
# data table so the types can be checked against what the model expects.
for required_col in model_features:
    col_type = df.schema[required_col].dataType
    print(f"Column: {required_col}, Type: {col_type}")

In [0]:
# Input column names from the model signature, recomputed here instead of reusing
# the value from the schema-check cell.
model_features = predict_churn_udf.metadata.get_input_schema().input_names()

# Score the Delta feature table: append a `churn_prediction` column produced by
# calling the model UDF on the required input columns.
churn_features = spark.table('churn_features')
predictions = churn_features.withColumn(
    'churn_prediction',
    predict_churn_udf(*model_features),
)
# Expose the scored rows to SQL as a temporary view.
predictions.createOrReplaceTempView("v_churn_prediction")

In [0]:
# Persist the scored rows as the Delta table `churn_prediction`,
# replacing any previous contents of that table.
(
    predictions.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("churn_prediction")
)

In [0]:
%sql
-- Commented-out alternative (nothing in this cell executes): build churn_prediction with
-- CTAS from the v_churn_prediction temp view instead of the Python saveAsTable write.
-- NOTE(review): dead cell — consider deleting it, or keep only one of the two write paths.
-- create or replace table churn_prediction as select * from v_churn_prediction

In [0]:
%sql
-- Inspect the persisted predictions table.
SELECT * FROM churn_prediction