In [1]:
import findspark
findspark.init()
import os
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType
from dotenv import load_dotenv

load_dotenv()

# Create a Spark session
spark = SparkSession.builder.appName("ProductRatings").getOrCreate()

def get_rating_data():
    schema = StructType([
        StructField("customer_id", IntegerType(), True),
        StructField("product_id", IntegerType(), True),
        StructField("rating", FloatType(), True)
                ])
    data = spark.read.csv(os.getenv('BASE_PROJECT_PATH') + 'data.data', sep=',', schema=schema, header=False)
    return data

In [2]:
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder

# Load the data
data = get_rating_data()

# Split the data into training and testing sets
(training, test) = data.randomSplit([0.8, 0.2], seed=42)

# Build the ALS model
als = ALS(userCol="customer_id", itemCol="product_id", ratingCol="rating", coldStartStrategy="drop")

# Define a parameter grid for hyperparameter tuning
param_grid = ParamGridBuilder() \
    .addGrid(als.rank, [20, 30, 40]) \
    .addGrid(als.maxIter, [15, 18, 20]) \
    .addGrid(als.regParam, [0.2, 0.3, 0.5]) \
    .build()

# Define an evaluators 
evaluator_rmse = RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")
evaluator_r2 = RegressionEvaluator(metricName="r2", labelCol="rating", predictionCol="prediction")
evaluator_mse = RegressionEvaluator(metricName="mse", labelCol="rating", predictionCol="prediction")

# Use TrainValidationSplit to choose the best combination of parameters
tvs = TrainValidationSplit(estimator=als,
                           estimatorParamMaps=param_grid,
                           evaluator=evaluator_rmse,
                           trainRatio=0.8)


print("Number of records in the training set:", training.count())
print("Training set schema:")
training.printSchema()


# Train the model
model = tvs.fit(training)

# Make predictions on the test set
predictions = model.transform(test)

# Evaluate the model
rmse = evaluator_rmse.evaluate(predictions)
r2 = evaluator_r2.evaluate(predictions)
mse = evaluator_mse.evaluate(predictions)

print("Root Mean Squared Error (RMSE) on test data =", rmse)
print("R Squared (R2) on test data =", r2)
print("Mean Squared Error (MSE) on test data =", mse)

Number of records in the training set: 4050
Training set schema:
root
 |-- customer_id: integer (nullable = true)
 |-- product_id: integer (nullable = true)
 |-- rating: float (nullable = true)

Root Mean Squared Error (RMSE) on test data = 1.617305881987267
R Squared (R2) on test data = -1.1498870175975955
Mean Squared Error (MSE) on test data = 2.615678315910612


In [3]:
# Get the best model from the tuning
best_model = model.bestModel

# Get the best combination of hyperparameters
best_rank = best_model.rank
best_max_iter = best_model._java_obj.parent().getMaxIter()
best_reg_param = best_model._java_obj.parent().getRegParam()

# Print the best hyperparameters
print("Best Rank:", best_rank)
print("Best Max Iter:", best_max_iter)
print("Best Reg Param:", best_reg_param)

Best Rank: 30
Best Max Iter: 18
Best Reg Param: 0.3


In [4]:
# Save the model to a new path
model.bestModel.write().save(os.getenv('als_model') + 'best_model')

In [11]:
# load the model
from pyspark.ml.recommendation import ALSModel

model1 = ALSModel.load(os.getenv('als_model') + 'best_model_als')

# Make recommendations for users
userRecs = model1.recommendForAllUsers(10)

print(userRecs.show(10, False))

+-----------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|customer_id|recommendations                                                                                                                                                           |
+-----------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|1          |[{15, 0.9255558}, {4, 0.92176926}, {13, 0.7850331}, {21, 0.78123623}, {17, 0.76684755}, {18, 0.7440187}, {14, 0.7237903}, {2, 0.7121044}, {10, 0.711445}, {19, 0.7088011}]|
|2          |[{21, 2.81583}, {18, 1.9832426}, {3, 1.9223797}, {15, 1.9175744}, {13, 1.9149837}, {4, 1.8535999}, {10, 1.8357399}, {7, 1.8019316}, {17, 1.7964588}, {5, 1.7505347}]      |
|3          |[{1, 1.7382505}, {9, 1.2178743}, {12, 1.2083862}, {10, 1.14712