# Model Serving & AI Functions for Batch Inference

In [0]:
import mlflow.pyfunc
from pyspark.sql.functions import struct, col
from pyspark.sql.types import StringType
import pyspark.sql.functions as F

## Create UDF for Batch Scoring

In [0]:
# Point at the registered model's "champion" alias (three-part registry name)
model_name = "main.ttw_workshop_demo.customer_anomaly_detector"
champion_uri = f"models:/{model_name}@champion"

# Wrap the pyfunc model as a Spark UDF; inputs are checked against the
# model's logged signature when the pyfunc model predicts.
predict_udf = mlflow.pyfunc.spark_udf(spark=spark, model_uri=champion_uri)

# Make the same UDF callable from SQL; the returned UDF object is the cell output
spark.udf.register("predict_customer_anomaly", predict_udf)


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/11 [00:00<?, ?it/s]



Downloading artifacts:   0%|          | 0/12 [00:00<?, ?it/s]

2025/08/04 17:41:58 INFO mlflow.models.flavor_backend_registry: Selected backend for flavor 'python_function'


<pyspark.sql.udf.UserDefinedFunction at 0x7f1f08749590>

## Batch Inference

In [0]:
# Apply batch inference to the customer feature table with the champion UDF
inference_df = spark.table("main.ttw_workshop_demo.customer_features")
predictions_table = "main.ttw_workshop_demo.batch_predictions"

# The model is loaded here only to read its signature, so the UDF receives the
# feature columns in the order the model expects.
# NOTE(review): `@champion` is resolved here independently of the UDF cell; if
# the alias is moved in between, schema and UDF could come from different
# versions — consider pinning a concrete version number.
champion_model = mlflow.pyfunc.load_model(champion_uri)
input_schema = champion_model.metadata.get_input_schema()
if input_schema is None:
    # Without a logged signature there is no way to know which columns to pass.
    raise ValueError(
        f"Model at {champion_uri} has no input signature; "
        "log it with a signature to enable column-based batch scoring."
    )
feature_columns = input_schema.input_names()

preds_df = inference_df.withColumn(
    "anomaly_prediction",
    predict_udf(*feature_columns),
).withColumn(
    "anomaly_status",
    # A prediction of -1 is labelled ANOMALY; every other value (including a
    # null prediction) falls through to NORMAL.
    F.when(F.col("anomaly_prediction") == -1, "ANOMALY").otherwise("NORMAL"),
)

# Persist results
preds_df.write.mode("overwrite").saveAsTable(predictions_table)

# Read back from the persisted table: displaying `preds_df` directly would
# re-execute its full lineage, running model inference a second time.
display(
    spark.table(predictions_table)
    .select("Customer_ID", "engagement_score", "anomaly_status")
    .limit(10)
)

Downloading artifacts:   0%|          | 0/11 [00:00<?, ?it/s]

Customer_ID,engagement_score,anomaly_status
bdd640fb06674ad19c80317fa3b1799d,3.3690000534057614,NORMAL
1a3d1fa7bc8940a9a3b8c1e9392456de,5.53899998664856,NORMAL
972a846916414f828b9d2434e465e150,3.6450000166893,NORMAL
3b8faa1837f8488b97fc695a07a0ca6e,16.26400032043457,NORMAL
b74d0fb132e746298fadc1a606cb0fb3,8.966000080108643,NORMAL
72ff5d2a386e4be0ab65a6a48b8148f6,5.314999866485596,NORMAL
c241330b01a9471f9e8a774bcf36d58b,5.966000080108643,NORMAL
47229389571a4876ac307511b2b9437a,8.835000133514404,NORMAL
1a2a73ed562b4f79837459eef50bea63,7.002999973297118,NORMAL
580d7b71d8f544139be6128e18c26797,10.27799997329712,NORMAL
