In [27]:
# Make the parent of the working directory importable so that `src/`
# (expected to live there) can be imported from this notebook.
import sys
import os
import pandas as pd

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))

# Guard against duplicates: re-running this cell must not keep growing sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [28]:
# 1) Force Spark to bind and advertise on localhost
#    (or "host.docker.internal" if you're in Docker on Windows)
os.environ["SPARK_LOCAL_IP"] = "127.0.0.1"

# Run the Python workers with the same interpreter as this kernel. When the
# default worker executable cannot be launched, jobs that need a Python worker
# fail with "Python worker failed to connect back". setdefault() keeps any
# value the user exported deliberately. Both variables must be set before the
# session below is created.
os.environ.setdefault("PYSPARK_PYTHON", sys.executable)
os.environ.setdefault("PYSPARK_DRIVER_PYTHON", sys.executable)

from pyspark.sql import SparkSession
import mlflow

# 2) Build your SparkSession with matching driver configs
#    NOTE: getOrCreate() returns an already-running session if one exists; in
#    that case the builder settings below are not applied to it.
spark = (
    SparkSession
      .builder
      .appName("LoadAndPredict")
      .master("local[1]")
      .config("spark.driver.bindAddress", "0.0.0.0")
      .config("spark.driver.host", "127.0.0.1")
      # If you’re inside Windows Docker you may need:
      # .config("spark.driver.host", "host.docker.internal")
      .getOrCreate()
)

In [29]:
# Spark session
from pyspark.sql import SparkSession, functions as F

# Import custom functions
from src.preprocessing import load_data, clean_data, feature_engineer, assemble_features
from src.save_outputs import save_predictions, save_feature_importances, save_model_metadata
from src.model import pick_best_model_from_grid, evaluate_model, manual_grid_search_rf, manual_grid_search_dt, manual_grid_search_gbt, manual_grid_search_lr, manual_grid_search_nb, train_and_log_mlflow
from src.mongo_export import create_mongo_database, get_collection

In [3]:
# Second SparkSession builder in this notebook.
# getOrCreate() hands back the session created earlier if one is already
# running, so appName/master/driver memory below only matter when this cell is
# the first to start Spark. For that case, carry the same localhost driver
# settings as the first builder so both cells yield an equivalently configured
# session (the failed run recorded below reports "host.docker.internal" as the
# host, i.e. a session that was started without them).
spark = (
    SparkSession.builder
    .appName("Ecommerce Behavior Exploration")
    .master("local[*]")
    .config("spark.driver.memory", "8g")
    .config("spark.driver.bindAddress", "0.0.0.0")
    .config("spark.driver.host", "127.0.0.1")
    .getOrCreate()
)


In [30]:
# Load the raw event CSV through the project's load_data helper
# (src.preprocessing), using the SparkSession created above.
file_path = "../data/testData-2019-Nov.csv"
raw_df = load_data(file_path, spark)

# Sanity check: print the first 5 rows and the column schema
raw_df.show(5)
raw_df.printSchema()


+--------------------+----------+----------+-------------------+--------------------+--------+------+---------+--------------------+
|          event_time|event_type|product_id|        category_id|       category_code|   brand| price|  user_id|        user_session|
+--------------------+----------+----------+-------------------+--------------------+--------+------+---------+--------------------+
|2019-11-19 08:35:...|      view|  30200005|2053013554449088861|                NULL|   elari|  77.2|512412397|f62be3c5-18af-4ab...|
|2019-11-26 14:16:...|      view|   1005115|2053013555631882655|electronics.smart...|   apple|916.37|568675496|c857db53-cd0a-480...|
|2019-11-10 17:50:...|      view|  15700275|2053013559733912211|                NULL|imperial|206.16|513262731|c637d18a-6fc5-4c1...|
|2019-11-04 14:23:...|      view|   1004589|2053013555631882655|electronics.smart...|    inoi| 61.36|562973725|e41d3c3f-830e-48d...|
|2019-11-29 17:11:...|  purchase|   5300157|2053013563173241677|     

In [31]:
# Clean the raw events with the project's clean_data helper (src.preprocessing).
# NOTE(review): in the recorded previews, event_time values after cleaning are
# 5 hours earlier than in the raw preview (e.g. 08:35 -> 03:35) — confirm this
# timezone shift inside clean_data is intended.
clean_df = clean_data(raw_df)

# Sanity check: print the first 5 cleaned rows
clean_df.show(5)


+-------------------+----------+----------+-------------------+--------------------+--------+------+---------+--------------------+
|         event_time|event_type|product_id|        category_id|       category_code|   brand| price|  user_id|        user_session|
+-------------------+----------+----------+-------------------+--------------------+--------+------+---------+--------------------+
|2019-11-19 03:35:46|      view|  30200005|2053013554449088861|                NULL|   elari|  77.2|512412397|f62be3c5-18af-4ab...|
|2019-11-26 09:16:08|      view|   1005115|2053013555631882655|electronics.smart...|   apple|916.37|568675496|c857db53-cd0a-480...|
|2019-11-10 12:50:50|      view|  15700275|2053013559733912211|                NULL|imperial|206.16|513262731|c637d18a-6fc5-4c1...|
|2019-11-04 09:23:52|      view|   1004589|2053013555631882655|electronics.smart...|    inoi| 61.36|562973725|e41d3c3f-830e-48d...|
|2019-11-29 12:11:17|  purchase|   5300157|2053013563173241677|             

In [32]:
# Derive model features from the cleaned events with the project's
# feature_engineer helper (src.preprocessing).
# Judging by the recorded preview, the result appears to hold one row per
# user_session with event counts, session start/end, avg_price and
# session_duration — confirm against the helper.
feature_df = feature_engineer(clean_df)

# Sanity check: print the first 5 feature rows
feature_df.show(5)


+--------------------+---------+-------------+-------------+-------------------+-------------------+---------+----------------+
|        user_session|num_views|num_cart_adds|num_purchases|      session_start|        session_end|avg_price|session_duration|
+--------------------+---------+-------------+-------------+-------------------+-------------------+---------+----------------+
|08327d25-5fe5-4ec...|        1|            0|            0|2019-11-10 23:43:37|2019-11-10 23:43:37|     39.9|               0|
|cbd7e3a0-2c5e-493...|        1|            0|            0|2019-11-17 15:48:48|2019-11-17 15:48:48|     4.61|               0|
|a08f39b4-71f4-431...|        1|            0|            0|2019-11-24 10:21:36|2019-11-24 10:21:36|    23.81|               0|
|544365e8-d7e6-4ea...|        1|            1|            0|2019-11-09 11:44:13|2019-11-09 11:46:48|    47.47|             155|
|71697745-4b20-419...|        1|            0|            0|2019-11-17 12:22:24|2019-11-17 12:22:24|   1

In [33]:
# Per-event flag: 1 for a purchase event, 0 for anything else.
is_purchase_event = F.when(F.col("event_type") == "purchase", 1).otherwise(0)

# A session is labelled 1 as soon as it contains at least one purchase event.
labels = (
    clean_df
    .groupBy("user_session")
    .agg(F.max(is_purchase_event).alias("label"))
)

# Attach the label to the session features; nulls left after the left join
# are replaced with 0.
joined_df = feature_df.join(labels, on="user_session", how="left")
final_df = joined_df.na.fill(0)


In [34]:
# Input columns for the model. num_purchases (present in feature_df) is not
# listed — presumably because the label is itself derived from purchase events.
feature_cols = ['num_views', 'num_cart_adds', 'session_duration', 'avg_price']

# Assemble the listed columns via the project's assemble_features helper.
# NOTE(review): this rebinds final_df to the helper's output, so re-running only
# this cell feeds the already-assembled frame back in; re-run the join cell
# first when iterating.
final_df = assemble_features(final_df, feature_cols)

In [35]:
# 80/20 train/test split with a fixed seed
train_df, test_df = final_df.randomSplit([0.8, 0.2], seed=42)

In [24]:
# Train on train_df / evaluate on test_df and log the run to MLflow via the
# project helper (the recorded output reports the logged AUC and run URL).
# NOTE(review): the meaning of the third argument (10) is not visible in this
# notebook — confirm against src.model.train_and_log_mlflow.
train_and_log_mlflow(train_df, test_df, 10)



✔ Logged run with AUC=0.9366
🏃 View run defiant-ram-266 at: http://localhost:5000/#/experiments/1/runs/7fdafa6c452c47789d0cbd81be42165e
🧪 View experiment at: http://localhost:5000/#/experiments/1


In [36]:
import mlflow
from mlflow.tracking import MlflowClient

# 1) Point at your MLflow tracking server
mlflow.set_tracking_uri("http://localhost:5000")

# 2) Resolve the experiment by name BEFORE activating it.
#    mlflow.set_experiment() creates the experiment when the name is unknown,
#    so calling it first would make the "not found" guard below unreachable
#    and a typo would silently create an empty experiment.
experiment_name = "test_lr_experiment"
client = MlflowClient()
exp = client.get_experiment_by_name(experiment_name)
if exp is None:
    raise ValueError(f"Experiment '{experiment_name}' not found")
mlflow.set_experiment(experiment_name)

# 3) Find the most recent *finished* run in that experiment. Failed or
#    still-running runs are skipped because they may not hold a model artifact.
runs = client.search_runs(
    experiment_ids=[exp.experiment_id],
    filter_string="attributes.status = 'FINISHED'",
    order_by=["attribute.start_time DESC"],
    max_results=1,
)
if not runs:
    raise ValueError(f"No finished runs found in experiment '{experiment_name}'")
run_id = runs[0].info.run_id
print("Loading model from run:", run_id)

# 4) Construct the model URI and load it
#    ("spark-model" is the artifact path the run is expected to have logged to)
model_uri = f"runs:/{run_id}/spark-model"
loaded_model = mlflow.spark.load_model(model_uri)

# 5) Now you can call .transform() on any DataFrame with the same schema
# e.g. reuse your `test_df` from before, or create a new one:
predictions = loaded_model.transform(test_df)
predictions.show()


2025/05/04 13:09:55 INFO mlflow.spark: URI 'runs:/7fdafa6c452c47789d0cbd81be42165e/spark-model/sparkml' does not point to the current DFS.
2025/05/04 13:09:55 INFO mlflow.spark: File 'runs:/7fdafa6c452c47789d0cbd81be42165e/spark-model/sparkml' not found on DFS. Will attempt to upload the file.
2025/05/04 13:09:55 INFO mlflow.spark: Copied SparkML model to /tmp/mlflow/543ee1c3-0316-4d3e-8045-3d152a9db10b


Loading model from run: 7fdafa6c452c47789d0cbd81be42165e


Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.runJob.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 6.0 failed 1 times, most recent failure: Lost task 0.0 in stage 6.0 (TID 6) (host.docker.internal executor driver): org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:203)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:67)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.base/java.net.PlainSocketImpl.waitForNewConnection(Native Method)
	at java.base/java.net.PlainSocketImpl.socketAccept(PlainSocketImpl.java:163)
	at java.base/java.net.AbstractPlainSocketImpl.accept(AbstractPlainSocketImpl.java:474)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:551)
	at java.base/java.net.ServerSocket.accept(ServerSocket.java:519)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:190)
	... 17 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2393)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2414)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2433)
	at org.apache.spark.api.python.PythonRDD$.runJob(PythonRDD.scala:181)
	at org.apache.spark.api.python.PythonRDD.runJob(PythonRDD.scala)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:203)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:67)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.base/java.net.PlainSocketImpl.waitForNewConnection(Native Method)
	at java.base/java.net.PlainSocketImpl.socketAccept(PlainSocketImpl.java:163)
	at java.base/java.net.AbstractPlainSocketImpl.accept(AbstractPlainSocketImpl.java:474)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:551)
	at java.base/java.net.ServerSocket.accept(ServerSocket.java:519)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:190)
	... 17 more


In [None]:
# Manual grid search for the RF model (src.model helper).
# NOTE(review): presumably writes grid_search_results.csv under ../output/rf/,
# which the model-selection cell further down reads — confirm.
manual_grid_search_rf(train_df, test_df)

In [None]:
# Manual grid search for the LR model (src.model helper).
# NOTE(review): presumably writes grid_search_results.csv under ../output/lr/,
# which the model-selection cell further down reads — confirm.
manual_grid_search_lr(train_df, test_df)


In [None]:
# Manual grid search for the DT model (src.model helper).
# NOTE(review): presumably writes grid_search_results.csv under ../output/dt/,
# which the model-selection cell further down reads — confirm.
manual_grid_search_dt(train_df, test_df)

In [None]:

# Manual grid search for the NB model (src.model helper).
# NOTE(review): presumably writes grid_search_results.csv under ../output/nb/,
# which the model-selection cell further down reads — confirm.
manual_grid_search_nb(train_df, test_df)

In [None]:
# Manual grid search for the GBT model (src.model helper).
# NOTE(review): presumably writes grid_search_results.csv under ../output/gbt/,
# which the model-selection cell further down reads — confirm.
manual_grid_search_gbt(train_df, test_df)

In [None]:
# Pick the overall best model across every grid search (RF, LR, DT, NB, GBT).

base_rf_dir = "../output/rf/"
base_lr_dir = "../output/lr/"
base_dt_dir = "../output/dt/"
base_nb_dir = "../output/nb/"
base_gbt_dir = "../output/gbt/"

# Best candidate seen so far. The name `dic` is kept because later cells
# read dic["model"], dic["params"] and dic["AUC"].
dic = {"model":"", "params": "", "AUC": 0 }

# One entry per model family: the comma-split string returned by
# pick_best_model_from_grid for that family's results file.
results = []

grid_search_dirs = (
    ("RF", base_rf_dir),
    ("LR", base_lr_dir),
    ("DT", base_dt_dir),
    ("NB", base_nb_dir),
    ("GBT", base_gbt_dir),
)

for family, base_dir in grid_search_dirs:
    print(f"{family}: ")
    results_csv_path = os.path.join(base_dir, "grid_search_results.csv")
    results.append(pick_best_model_from_grid(results_csv_path, base_dir).split(','))

# Field layout used below: [0] = params, [1] = AUC, [2] = model name.
# `>=` means that on a tie the family listed later wins.
for fields in results:
    auc = float(fields[1])
    if auc >= dic["AUC"]:
        dic["model"] = fields[2]
        # NOTE(review): drops the last two characters of the params field —
        # presumably a suffix in the helper's return format; confirm.
        dic["params"] = fields[0][:-2]
        dic["AUC"] = auc
print(dic)

In [None]:
# Persist the winning model summary (model, params, AUC) as a JSON file
import json

with open("../output/best_params.json", "w") as best_params_file:
    best_params_file.write(json.dumps(dic))


In [None]:
# Load the predictions CSV stored for the best model selected above.
# Path layout: ../output/<model>/<model>_<params>/predictions.csv
# (adjust if the output directory layout differs).
model, params = dic["model"], dic["params"]
predictions_path = f"../output/{model}/{model}_{params}/predictions.csv"
df_predictions = pd.read_csv(predictions_path)

# Preview the first rows
df_predictions.head()


In [None]:
# Hand the predictions to the project's MongoDB helper as a list of per-row dicts.
# NOTE(review): whether re-running this cell duplicates documents depends on
# create_mongo_database — confirm before re-executing.
create_mongo_database(df_predictions.to_dict(orient='records'))

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Figures are written here; create the folder on first use.
output_dir = "../output/figures"
os.makedirs(output_dir, exist_ok=True)

# Read every stored document back from MongoDB into a DataFrame and discard
# MongoDB's auto-generated _id field when it is present.
collection = get_collection()
df_mongo = pd.DataFrame(list(collection.find()))
df_mongo = df_mongo.drop(columns=["_id"], errors='ignore')

# --- Visualization 1: Conversion Funnel ---
# Stage sizes: all rows, rows with at least one cart add, rows labelled 1.
has_cart_add = df_mongo["num_cart_adds"] > 0
has_purchase = df_mongo["label"] == 1
funnel = {
    "Viewed": len(df_mongo),
    "Cart Add": int(has_cart_add.sum()),
    "Purchased": int(has_purchase.sum()),
}
funnel_ax = sns.barplot(x=list(funnel.keys()), y=list(funnel.values()))
funnel_ax.set_title("User Conversion Funnel")
funnel_ax.set_ylabel("Number of Users")
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "conversion_funnel.png"), dpi=300)
plt.show()

# --- Visualization 2: Predicted Purchase Probabilities ---
# Only drawn when the stored documents carry a 'probability' field.
if 'probability' in df_mongo.columns:
    probability_ax = sns.histplot(df_mongo['probability'], bins=30, kde=True)
    probability_ax.set_title("Predicted Purchase Probability")
    probability_ax.set_xlabel("Probability")
    probability_ax.set_ylabel("Frequency")
    # Fixed y-range: bars taller than 10000 are cut off at the top.
    probability_ax.set_ylim(0, 10000)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "purchase_probability.png"), dpi=300)
    plt.show()


In [None]:
# Tabulate the funnel stages in order, then express each stage as the
# percentage lost relative to the first ("Viewed") stage.
stage_names = ["Viewed", "Cart Add", "Purchased"]
funnel_df = pd.DataFrame({
    "Stage": stage_names,
    "Users": [funnel[stage] for stage in stage_names],
})
users_at_top = funnel_df["Users"].iloc[0]
funnel_df["Dropoff %"] = 100 * (1 - funnel_df["Users"] / users_at_top)

dropoff_ax = sns.lineplot(data=funnel_df, x="Stage", y="Dropoff %", marker="o")
dropoff_ax.set_title("Drop-Off Rate at Funnel Stages")
dropoff_ax.set_ylabel("Drop-Off (%)")
dropoff_ax.set_ylim(0, 100)
dropoff_ax.grid(True)
plt.tight_layout()
plt.savefig("../output/figures/dropoff_rate.png")
plt.show()


In [None]:
# Session duration distribution, split by the purchase label (0 / 1).
duration_ax = sns.boxplot(data=df_mongo, x="label", y="session_duration")
# Swap the numeric tick labels for readable names (positions 0 and 1).
duration_ax.set_xticks([0, 1])
duration_ax.set_xticklabels(["No Purchase", "Purchase"])
duration_ax.set_title("Session Duration vs Purchase Behavior")
duration_ax.set_ylabel("Duration (s)")
plt.tight_layout()
plt.savefig("../output/figures/session_duration_vs_purchase.png")
plt.show()


In [None]:
# Average-price distribution, split by purchase outcome.
#
# The readable class names are mapped onto the data and seaborn is left to
# build the legend itself. Calling plt.legend([...]) after a hue plot replaces
# seaborn's legend and attaches the given strings to the axes' artists in draw
# order, which is not guaranteed to match label order 0, 1 — so the two
# classes can end up with each other's name.
# Rows whose label is neither 0 nor 1 map to NaN and are left out of the plot.
price_plot_df = df_mongo.assign(
    Outcome=df_mongo["label"].map({0: "No Purchase", 1: "Purchase"})
)
price_ax = sns.histplot(
    data=price_plot_df,
    x="avg_price",
    hue="Outcome",
    hue_order=["No Purchase", "Purchase"],  # keeps the colour order of labels 0, 1
    bins=40,
    kde=True,
    palette="Set2",
)
price_ax.set_title("Price Distribution by Purchase Label")
price_ax.set_xlabel("Average Price of Viewed Products")
price_ax.set_ylabel("Session Count")

# Drop the legend title so the legend shows just the two class names.
price_legend = price_ax.get_legend()
if price_legend is not None:
    price_legend.set_title("")

plt.tight_layout()
plt.savefig("../output/figures/price_vs_purchase.png")
plt.show()