In [0]:

# Read the uploaded CSV from DBFS into a Spark DataFrame and preview it.
# header=True takes column names from the first row; inferSchema=True lets
# Spark derive the column types from the data.
file_path = 'dbfs:/FileStore/shared_uploads/podurikarthikeya69@gmail.com/sf_lgbm.csv'
read_options = {'header': True, 'inferSchema': True}
df = spark.read.csv(file_path, **read_options)
df.show()


+--------------+------+------+------+------------+------------+--------------+----------+----------------+----------------+------------------------+----+---------+-----+-----+-----------+---------------------+-------------------------+
|    HostItemID|  Cost|  List|OnHand|MonthsNoSale|LastSaleDate|PriorYearSales|OnOrderQty|BestReorderPoint|SafetyStockLevel|MonthsWithAtLeastOneSale|Year|    Month|Sales|Lag_1|HolidayFlag|InventoryDiff_Reorder|InventoryDiff_SafetyStock|
+--------------+------+------+------+------------+------------+--------------+----------+----------------+----------------+------------------------+----+---------+-----+-----+-----------+---------------------+-------------------------+
|PMP*14680*2541|381.96|873.72|   184|          17|  2021-04-17|           217|        62|              74|              39|                       8|2021|  January|   39|  NaN|          0|                  110|                       70|
|PMP*14680*2541| 385.0|873.72|   184|          17|  2021

In [0]:
from pyspark.sql.functions import col, isnan, when
# Clean the loaded frame and derive a numeric month column.
df = df.withColumn('LastSaleDate', col('LastSaleDate').cast('date'))

# One fill with 0 handles every numeric column, 'Sales' included (it is
# numeric: it is later used as the regression label), so no separate
# per-column fill of 'Sales' is needed. This is also what turns the NaN in
# 'Lag_1' into 0.0 in the preview below.
df = df.fillna(0)

# Month name -> calendar number (January = 1 ... December = 12). The position
# in this list defines the number. A Month value matching none of the names
# produces null, because the when-chain has no otherwise() branch.
month_names = [
    'January', 'February', 'March', 'April', 'May', 'June',
    'July', 'August', 'September', 'October', 'November', 'December',
]
month_num = when(col('Month') == month_names[0], 1)
for number, name in enumerate(month_names[1:], start=2):
    month_num = month_num.when(col('Month') == name, number)
df = df.withColumn('Month_Num', month_num)

df.show()

+--------------+------+------+------+------------+------------+--------------+----------+----------------+----------------+------------------------+----+---------+-----+-----+-----------+---------------------+-------------------------+---------+
|    HostItemID|  Cost|  List|OnHand|MonthsNoSale|LastSaleDate|PriorYearSales|OnOrderQty|BestReorderPoint|SafetyStockLevel|MonthsWithAtLeastOneSale|Year|    Month|Sales|Lag_1|HolidayFlag|InventoryDiff_Reorder|InventoryDiff_SafetyStock|Month_Num|
+--------------+------+------+------+------------+------------+--------------+----------+----------------+----------------+------------------------+----+---------+-----+-----+-----------+---------------------+-------------------------+---------+
|PMP*14680*2541|381.96|873.72|   184|          17|  2021-04-17|           217|        62|              74|              39|                       8|2021|  January|   39|  0.0|          0|                  110|                       70|        1|
|PMP*14680*2541|

In [0]:
from pyspark.sql import functions as F
# Remove HostItemID, LastSaleDate and the text Month column from the frame.
df = df.drop('HostItemID', 'LastSaleDate', 'Month')

# 'Sales' is the target; every other remaining column is listed as a feature.
# The loop variable is not called `col`, so it cannot be confused with the
# pyspark `col` function imported earlier.
target_column = 'Sales'
feature_columns = [name for name in df.columns if name != 'Sales']

df.select(feature_columns).show()

+------+------+------+------------+--------------+----------+----------------+----------------+------------------------+----+-----+-----------+---------------------+-------------------------+---------+
|  Cost|  List|OnHand|MonthsNoSale|PriorYearSales|OnOrderQty|BestReorderPoint|SafetyStockLevel|MonthsWithAtLeastOneSale|Year|Lag_1|HolidayFlag|InventoryDiff_Reorder|InventoryDiff_SafetyStock|Month_Num|
+------+------+------+------------+--------------+----------+----------------+----------------+------------------------+----+-----+-----------+---------------------+-------------------------+---------+
|381.96|873.72|   184|          17|           217|        62|              74|              39|                       8|2021|  0.0|          0|                  110|                       70|        1|
| 385.0|873.72|   184|          17|           217|        62|              74|              39|                       8|2021| 39.0|          0|                   90|                       50| 

In [0]:
# Randomly partition the rows using 0.8 / 0.2 weights and a fixed seed.
splits = df.randomSplit([0.8, 0.2], seed=123)
train_df, test_df = splits

# Report how many rows ended up in each partition.
for label, frame in (("Training data count: ", train_df), ("Test data count: ", test_df)):
    print(label, frame.count())

Training data count:  22
Test data count:  6


In [0]:
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline
# Columns packed into the model's feature vector.
# NOTE(review): this list leaves out 'Year' and 'Month_Num' even though both
# are present in `feature_columns` — confirm the omission is intentional.
model_inputs = [
    'Cost', 'List', 'OnHand', 'MonthsNoSale', 'PriorYearSales', 'OnOrderQty',
    'BestReorderPoint', 'SafetyStockLevel', 'MonthsWithAtLeastOneSale',
    'Lag_1', 'HolidayFlag', 'InventoryDiff_Reorder', 'InventoryDiff_SafetyStock',
]
assembler = VectorAssembler(inputCols=model_inputs, outputCol='features')

# Gradient-boosted trees regressor that predicts 'Sales' from the assembled vector.
gbt_settings = {
    'featuresCol': 'features',
    'labelCol': 'Sales',
    'maxIter': 100,
    'maxDepth': 5,
}
gbt = GBTRegressor(**gbt_settings)

# Fit assembler + regressor on the training split, then score the test split
# and preview the label next to the prediction.
pipeline = Pipeline(stages=[assembler, gbt])
gbt_model = pipeline.fit(train_df)
predictions = gbt_model.transform(test_df)
predictions.select('Sales', 'prediction').show(5)

+-----+------------------+
|Sales|        prediction|
+-----+------------------+
|   22| 35.09858808838739|
|   21| 27.23841913263138|
|   48|33.999852146205875|
|   40| 33.99982288878686|
|   22|28.103160314073488|
+-----+------------------+
only showing top 5 rows



In [0]:
from pyspark.ml.evaluation import RegressionEvaluator
# Score the test-split predictions against the 'Sales' label with three
# regression metrics; rmse, mae and r2 are kept as variables for later cells.
evaluator = RegressionEvaluator(
    labelCol='Sales',
    predictionCol='prediction',
    metricName='rmse',
)
rmse = evaluator.evaluate(predictions)
print(f"Root Mean Squared Error (RMSE): {rmse}")

# Reuse the same evaluator, switching only the metric it reports.
evaluator.setMetricName("mae")
mae = evaluator.evaluate(predictions)
print(f"Mean Absolute Error (MAE): {mae}")

evaluator.setMetricName("r2")
r2 = evaluator.evaluate(predictions)
print(f"R-squared: {r2}")

Root Mean Squared Error (RMSE): 9.830015742254334
Mean Absolute Error (MAE): 9.240328254010583
R-squared: 0.10436366072429226


In [0]:
%pip install mlflow
import mlflow
import mlflow.spark
with mlflow.start_run() as run:
    mlflow.log_param("maxIter", 100)
    mlflow.log_param("maxDepth", 5)

    mlflow.log_metric("rmse", rmse)  
    mlflow.log_metric("mae", mae)
    mlflow.log_metric("r2", r2)

    mlflow.spark.log_model(gbt_model, "GBTRegressorModel")
print(f"Run ID: {run.info.run_id}")

Python interpreter will be restarted.
Collecting mlflow
  Downloading mlflow-2.16.2-py3-none-any.whl (26.7 MB)
Collecting sqlalchemy<3,>=1.4.0
  Downloading SQLAlchemy-2.0.35-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
Collecting graphene<4
  Downloading graphene-3.3-py2.py3-none-any.whl (128 kB)
Collecting docker<8,>=4.0.0
  Downloading docker-7.1.0-py3-none-any.whl (147 kB)
Collecting Flask<4
  Downloading flask-3.0.3-py3-none-any.whl (101 kB)
Collecting gunicorn<24
  Downloading gunicorn-23.0.0-py3-none-any.whl (85 kB)
Collecting markdown<4,>=3.3
  Downloading Markdown-3.7-py3-none-any.whl (106 kB)
Collecting alembic!=1.10.0,<2
  Downloading alembic-1.13.3-py3-none-any.whl (233 kB)
Collecting mlflow-skinny==2.16.2
  Downloading mlflow_skinny-2.16.2-py3-none-any.whl (5.6 MB)
Collecting sqlparse<1,>=0.4.0
  Downloading sqlparse-0.5.1-py3-none-any.whl (44 kB)
Collecting opentelemetry-api<3,>=1.9.0
  Downloading opentelemetry_api-1.27.0-py3-none-any.whl (63 kB)
Col

2024/10/07 15:14:22 INFO mlflow.spark: Inferring pip requirements by reloading the logged model from the databricks artifact repository, which can be time-consuming. To speed up, explicitly specify the conda_env or pip_requirements when calling log_model().
2024/10/07 15:14:53 INFO mlflow.tracking._tracking_service.client: 🏃 View run indecisive-owl-62 at: https://community.cloud.databricks.com/ml/experiments/2817418682773692/runs/76bb3138f61742e28b35c112f3b3de8a.
2024/10/07 15:14:53 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://community.cloud.databricks.com/ml/experiments/2817418682773692.


Run ID: 76bb3138f61742e28b35c112f3b3de8a
