## Automatically pick up the last model training from MLFlow instead of hardcoding the run id

In [1]:
import mlflow

last_parent_run = set()
exp = mlflow.get_experiment_by_name("Default")

df = mlflow.search_runs([exp.experiment_id], order_by=["Created DESC"])
last_run_id = df.loc[0,'run_id']

print(last_run_id)

715ffc45700d474caa97257fa1787911


## Load back the model from MLFlow

In [2]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

logged_model = f'runs:/{last_run_id}/model'
# logged_model = 'runs:/502530abbc2a4ae3b043462eaa4b8828/model'

# Load model as a Spark UDF. Override result_type if the model does not return double values.
loaded_model = mlflow.pyfunc.spark_udf(spark, model_uri=logged_model, result_type='string')

23/01/04 06:37:29 WARN Utils: Your hostname, unreal resolves to a loopback address: 127.0.1.1; using 192.168.20.68 instead (on interface wlp1s0)
23/01/04 06:37:29 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
23/01/04 06:37:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/01/04 06:37:31 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


## Create a Spark DataFrame with some rows of text

In [3]:
from pyspark.sql.functions import struct, col

df = spark.createDataFrame([
    ['ca', 'number of bananas sold in december', 1, 10], 
    ['wi', 'number of carrots in 2022', 2, 20], 
], schema=['state', 'text', 'cat', 'id'])

df.show()
display(df)

+-----+--------------------+---+---+
|state|                text|cat| id|
+-----+--------------------+---+---+
|   ca|number of bananas...|  1| 10|
|   wi|number of carrots...|  2| 20|
+-----+--------------------+---+---+



DataFrame[state: string, text: string, cat: bigint, id: bigint]

## Now apply the Model on the DataFrame

In [4]:
df2 = df.withColumn('predictions', loaded_model(struct(*map(col, df.columns))))
df2.show()

+-----+--------------------+---+---+--------------------+
|state|                text|cat| id|         predictions|
+-----+--------------------+---+---+--------------------+
|   ca|number of bananas...|  1| 10|SELECT Bananen FR...|
|   wi|number of carrots...|  2| 20|SELECT COUNT Carr...|
+-----+--------------------+---+---+--------------------+



## Now register the model as a Spark UDF

In [5]:
spark.udf.register("predict_text_to_sql", loaded_model)

spark.sql("SHOW FUNCTIONS LIKE '*predict_text_to_sql*'").show()

+-------------------+
|           function|
+-------------------+
|predict_text_to_sql|
+-------------------+



## Use the same DataFrame example data as a table

In [6]:
df.createOrReplaceTempView("veggies")

spark.sql("SELECT * FROM veggies").show()

+-----+--------------------+---+---+
|state|                text|cat| id|
+-----+--------------------+---+---+
|   ca|number of bananas...|  1| 10|
|   wi|number of carrots...|  2| 20|
+-----+--------------------+---+---+



## The crux, now we run the SQL query with the model function and get the predictions!

In [7]:
spark.sql("SELECT predict_text_to_sql(text) AS prediction FROM veggies").show()

+--------------------+
|          prediction|
+--------------------+
|SELECT Bananen FR...|
|SELECT COUNT Carr...|
+--------------------+



## Bonus: Run the model on a Panda DataFrames

In [8]:
import pandas as pd

# Load model as a PyFuncModel.
loaded_model = mlflow.pyfunc.load_model(logged_model)

# From Spark DataFrame 
data = df.toPandas()
display(data)
print(loaded_model.predict(data[1:2]))

# From Panda DataFrame
d = { 'text': ['number of bananas sold in december', 'number of carrots in 2022']}
pdf = pd.DataFrame(data=d)

# Predict on a single Pandas DataFrame row
print(loaded_model.predict(pdf[1:2]))

# Predict on a Pandas DataFrame.
pdf['sql'] = pdf.apply(loaded_model.predict, axis=1)
display(pdf['sql'].values)

Unnamed: 0,state,text,cat,id
0,ca,number of bananas sold in december,1,10
1,wi,number of carrots in 2022,2,20


                                                   0
0  SELECT COUNT Carrots FROM table WHERE Year = 2...
                                                   0
0  SELECT COUNT Carrots FROM table WHERE Year = 2...


array([                                                 0
       0  SELECT Bananen FROM table WHERE Date = December,
                                                          0
       0  SELECT COUNT Carrots FROM table WHERE Year = 2022], dtype=object)

In [15]:
%%sql
sElect * FROM veggies