In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import math

import findspark
findspark.init()

import pyspark
findspark.find()

from pyspark.sql import SparkSession
from pyspark.sql.functions import count

from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator

spark = (SparkSession
         .builder
         .appName("ALS recommendation spark session")
         .getOrCreate())

Đọc dữ liệu

In [2]:
ratingData = (spark.read
          .option("HEADER", True)
          .option("inferSchema", True)
          .csv("ratings_Beauty.csv")
         )

ratingData.show(5)

+--------------+----------+------+----------+
|        UserId| ProductId|Rating| Timestamp|
+--------------+----------+------+----------+
|A39HTATAQ9V7YF|0205616461|   5.0|1369699200|
|A3JM6GV9MNOF9X|0558925278|   3.0|1355443200|
|A1Z513UWSAAO0F|0558925278|   5.0|1404691200|
|A1WMRR494NWEWV|0733001998|   4.0|1382572800|
|A3IAAVS479H7M7|0737104473|   1.0|1274227200|
+--------------+----------+------+----------+
only showing top 5 rows



In [3]:
nd=ratingData.select(ratingData['UserId'],ratingData['Rating'],ratingData['ProductId'])
nd.show(5)

+--------------+------+----------+
|        UserId|Rating| ProductId|
+--------------+------+----------+
|A39HTATAQ9V7YF|   5.0|0205616461|
|A3JM6GV9MNOF9X|   3.0|0558925278|
|A1Z513UWSAAO0F|   5.0|0558925278|
|A1WMRR494NWEWV|   4.0|0733001998|
|A3IAAVS479H7M7|   1.0|0737104473|
+--------------+------+----------+
only showing top 5 rows



Chuyển cột UserId (kiểu string) thành label1 (kiểu double)

In [4]:
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

indexer = [StringIndexer(inputCol=column, outputCol=column+"_index") 
for column in list(set(nd.columns)-set(['Rating'])) ]
pipeline = Pipeline(stages=indexer)
transformed = pipeline.fit(nd).transform(nd)
transformed.show(10)

+--------------+------+----------+------------+---------------+
|        UserId|Rating| ProductId|UserId_index|ProductId_index|
+--------------+------+----------+------------+---------------+
|A39HTATAQ9V7YF|   5.0|0205616461|     49380.0|        43788.0|
|A3JM6GV9MNOF9X|   3.0|0558925278|    263677.0|        28932.0|
|A1Z513UWSAAO0F|   5.0|0558925278|    142369.0|        28932.0|
|A1WMRR494NWEWV|   4.0|0733001998|    137071.0|        43789.0|
|A3IAAVS479H7M7|   1.0|0737104473|    260907.0|        43790.0|
| AKJHHD5VEH7VG|   5.0|0762451459|       186.0|        43791.0|
|A1BG8QW55XHN6U|   5.0|1304139212|       619.0|        43792.0|
|A22VW0P4VZHDE3|   5.0|1304139220|      2034.0|        43793.0|
|A3V3RE4132GKRO|   5.0|130414089X|    288317.0|        43794.0|
|A327B0I7CYTEJC|   4.0|130414643X|    226073.0|        21588.0|
+--------------+------+----------+------------+---------------+
only showing top 10 rows



Chuyển cột productId (kiểu string) thành label2 (kiểu double)

Khởi tạo sparkDataFrame từ pandas dataFrame

In [5]:
(train, test) = transformed.randomSplit([.8, .2], seed = 1)

In [6]:
als=ALS(maxIter=1,regParam=0.09,rank=25,userCol="UserId_index",itemCol="ProductId_index",
ratingCol="Rating",coldStartStrategy="drop",nonnegative=True)
model=als.fit(train)

 Xây dựng và đánh giá mô hình ALS trên pyspark.

Tính RMSE

In [7]:
# Xây dựng mô hình recommendation sử dụng thuật toán ALS trên tập dữ liệu huấn luyện
evaluator=RegressionEvaluator(metricName="rmse",labelCol="Rating",predictionCol="prediction")
predictions=model.transform(test)
rmse=evaluator.evaluate(predictions)
print("RMSE="+str(rmse))
predictions.show(10)

RMSE=3.0772670795474943
+--------------+------+----------+------------+---------------+----------+
|        UserId|Rating| ProductId|UserId_index|ProductId_index|prediction|
+--------------+------+----------+------------+---------------+----------+
| A96FRN3LGYUX9|   5.0|B005XIDZHO|       451.0|           26.0|0.73248327|
|A25G3M46N5BEW2|   5.0|B005XIDZHO|      8782.0|           26.0| 0.3127757|
|A1H691MQUAA31K|   5.0|B005XIDZHO|     13960.0|           26.0| 0.5683287|
|A38JDRJ4SK0KN2|   5.0|B005XIDZHO|     49061.0|           26.0| 0.1787904|
| AICDTM94WTGZH|   5.0|B005XIDZHO|     60369.0|           26.0|  0.870149|
| ANKUJGLW1TFXH|   5.0|B005XIDZHO|     61933.0|           26.0|0.34665826|
|A256NW4OC7466Z|   5.0|B005XIDZHO|      2768.0|           26.0|0.79248327|
|A3A7WDJ5FF4BD1|   1.0|B005XIDZHO|      3045.0|           26.0| 0.5653503|
|A2WHUNQ6R2NBHH|   3.0|B005XIDZHO|      9783.0|           26.0| 1.8064873|
|A349QJY8WW4TB5|   5.0|B005XIDZHO|     10070.0|           26.0| 0.1317767|
+