In [1]:
import os
import time

import mlflow
import seaborn as sns
import sklearn
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, struct
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

In [2]:
import pandas as pd

In [3]:
# MLflow tracking server. Defaults to http://mlflow-server:8888; set the
# MLFLOW_TRACKING_URI environment variable to point the notebook elsewhere
# without editing code.
mlflow_uri = os.environ.get("MLFLOW_TRACKING_URI", "http://mlflow-server:8888")
mlflow.set_tracking_uri(mlflow_uri)

# NOTE(review): fixed 5 s pause, presumably to give the tracking server time to
# become reachable — TODO confirm and replace with an explicit readiness check.
time.sleep(5)

In [4]:
# Keep only the numeric features: the categorical columns are dropped rather
# than encoded. NOTE(review): this feature set must match what the logged
# MLflow model was trained on — confirm against the training run.
diamonds_df = sns.load_dataset("diamonds").drop(columns=["cut", "color", "clarity"])

# Fixed random_state so the split, and therefore the X_test rows that are later
# scored through Spark, is reproducible across kernel restarts.
X_train, X_test, y_train, y_test = train_test_split(
    diamonds_df.drop(columns=["price"]),
    diamonds_df["price"],
    random_state=42,
)
X_train.head()

<class 'pandas.core.frame.DataFrame'>


Unnamed: 0,carat,depth,table,x,y,z
35965,0.25,64.9,58.0,3.95,3.97,2.57
52281,0.84,61.8,56.0,6.04,6.07,3.74
6957,1.05,61.1,58.0,6.56,6.51,3.99
9163,1.02,60.7,56.0,6.53,6.5,3.95
50598,0.61,61.8,57.0,5.43,5.47,3.37


In [5]:
# Connect to the Spark cluster used for distributed scoring. The master URL
# defaults to spark://spark-master:7077 and can be overridden through the
# SPARK_MASTER_URL environment variable.
spark = (
    SparkSession.builder
    .appName("mlflow_predict")
    .master(os.environ.get("SPARK_MASTER_URL", "spark://spark-master:7077"))
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/06/06 11:04:52 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [6]:
# Independent (deep) copy of the held-out features that is handed to Spark for
# scoring, so X_test itself is never shared with downstream cells.
temp_df = X_test.copy(deep=True)
temp_df.head(n=5)

Unnamed: 0,carat,depth,table,x,y,z
1388,0.24,62.1,56.0,3.97,4.0,2.47
50052,0.58,60.0,57.0,5.44,5.42,3.26
41645,0.4,62.1,55.0,4.76,4.74,2.95
42377,0.43,60.8,57.0,4.92,4.89,2.98
17244,1.55,62.3,55.0,7.44,7.37,4.61


In [7]:
# MLflow run whose logged "model" artifact is used for scoring.
MODEL_RUN_ID = "264000fa54314d598db2cf6f634dd78b"
logged_model = f"runs:/{MODEL_RUN_ID}/model"

# Wrap the pyfunc model as a Spark UDF so it can be applied to Spark DataFrame
# columns. NOTE(review): the scoring output shows each prediction as a
# single-element array; passing result_type="double" may yield a scalar column
# instead — verify against the model's logged signature before changing.
loaded_model = mlflow.pyfunc.spark_udf(spark, model_uri=logged_model)


  from .autonotebook import tqdm as notebook_tqdm
Downloading artifacts: 100%|██████████| 5/5 [00:00<00:00, 73.40it/s]  
Downloading artifacts: 100%|██████████| 5/5 [00:00<00:00, 1040.72it/s]
2025/06/06 11:04:53 INFO mlflow.models.flavor_backend_registry: Selected backend for flavor 'python_function'


In [8]:
# Convert the pandas test features into a Spark DataFrame; the pandas index is
# not carried over, only the feature columns.
df = spark.createDataFrame(temp_df)
df.show()

<class 'pyspark.sql.dataframe.DataFrame'>


                                                                                

+-----+-----+-----+----+----+----+
|carat|depth|table|   x|   y|   z|
+-----+-----+-----+----+----+----+
| 0.24| 62.1| 56.0|3.97| 4.0|2.47|
| 0.58| 60.0| 57.0|5.44|5.42|3.26|
|  0.4| 62.1| 55.0|4.76|4.74|2.95|
| 0.43| 60.8| 57.0|4.92|4.89|2.98|
| 1.55| 62.3| 55.0|7.44|7.37|4.61|
|  1.0| 55.4| 62.0|6.63|6.59|3.66|
| 0.51| 60.2| 56.0|5.22|5.24|3.15|
| 0.52| 62.0| 56.0|5.17|5.19|3.21|
| 0.62| 60.0| 59.0|5.58|5.56|3.34|
| 1.14| 60.3| 57.0|6.79|6.85|4.11|
|  0.4| 62.8| 56.0|4.73| 4.7|2.96|
| 1.83| 62.8| 56.0|7.76|7.82|4.89|
|  0.6| 55.3| 63.0|5.67|5.61|3.12|
|  0.7| 63.5| 56.0|5.58|5.66|3.57|
| 0.53| 61.2| 65.0|5.16|5.27|3.19|
| 1.55| 61.2| 55.0|7.49|7.47|4.58|
| 0.98| 61.6| 66.0|6.46|6.24|3.92|
|  0.4| 62.6| 56.0|4.73| 4.7|2.95|
|  1.5| 62.8| 56.0|7.26|7.33|4.58|
| 0.35| 60.7| 62.0|4.53|4.59|2.77|
+-----+-----+-----+----+----+----+
only showing top 20 rows



In [9]:
# Pack every feature column into a single struct argument for the model UDF,
# then append the model output as a new "predictions" column.
feature_cols = [col(name) for name in df.columns]
df_with_preds = df.withColumn("predictions", loaded_model(struct(*feature_cols)))

# Preview the scored rows.
df_with_preds.show()

[Stage 1:>                                                          (0 + 1) / 1]

+-----+-----+-----+----+----+----+--------------------+
|carat|depth|table|   x|   y|   z|         predictions|
+-----+-----+-----+----+----+----+--------------------+
| 0.24| 62.1| 56.0|3.97| 4.0|2.47|[499.46524777281684]|
| 0.58| 60.0| 57.0|5.44|5.42|3.26|[1770.8910964268757]|
|  0.4| 62.1| 55.0|4.76|4.74|2.95|[1016.5741870113009]|
| 0.43| 60.8| 57.0|4.92|4.89|2.98|[1062.2121713852703]|
| 1.55| 62.3| 55.0|7.44|7.37|4.61|[11199.816620678614]|
|  1.0| 55.4| 62.0|6.63|6.59|3.66|[4467.1252610300235]|
| 0.51| 60.2| 56.0|5.22|5.24|3.15|[1707.3808620955772]|
| 0.52| 62.0| 56.0|5.17|5.19|3.21|[1727.0030647988162]|
| 0.62| 60.0| 59.0|5.58|5.56|3.34|[2159.1935709093577]|
| 1.14| 60.3| 57.0|6.79|6.85|4.11| [7678.507135442683]|
|  0.4| 62.8| 56.0|4.73| 4.7|2.96|  [981.936803967939]|
| 1.83| 62.8| 56.0|7.76|7.82|4.89|[12673.807236673114]|
|  0.6| 55.3| 63.0|5.67|5.61|3.12|[1733.5521919965825]|
|  0.7| 63.5| 56.0|5.58|5.66|3.57| [2440.946265601723]|
| 0.53| 61.2| 65.0|5.16|5.27|3.19|[1639.18167580

                                                                                

In [10]:
# Shut down the SparkSession created earlier in the notebook to release its
# cluster resources; later cells (if any) must not use `spark` after this.
spark.stop()