In [1]:
# findspark.init() is called before any pyspark import in this notebook.
import findspark
findspark.init()

In [2]:
from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.sql.functions import col, explode

In [3]:
# Reuse the active SparkContext when one already exists. Calling SparkContext()
# a second time in the same kernel (e.g. when this cell is re-run) raises
# "Cannot run multiple SparkContexts at once"; getOrCreate() is idempotent.
sc = SparkContext.getOrCreate()

In [4]:
# Build (or fetch the existing) SparkSession for this notebook.
spark = (
    SparkSession.builder
    .appName('Recommendation_project_2')
    .getOrCreate()
)

In [5]:
# Load the reviews CSV; the first row is the header and column types are inferred.
data = (
    spark.read
    .option('header', True)
    .option('inferSchema', True)
    .csv('review.csv')
)

In [6]:
# Keep only the three columns used below: user id, rating and item id.
data = data.select('customer_id', 'rating', 'product_id')

In [7]:
# Preview the first five rows without truncating cell contents.
data.show(n=5, truncate=False)

+-----------+------+----------+
|customer_id|rating|product_id|
+-----------+------+----------+
|709310     |3     |10001012  |
|10701688   |5     |10001012  |
|11763074   |5     |10001012  |
|9909549    |5     |10001012  |
|1827148    |5     |10001012  |
+-----------+------+----------+
only showing top 5 rows



In [8]:
# Print the column names and data types of `data`.
data.printSchema()

root
 |-- customer_id: string (nullable = true)
 |-- rating: string (nullable = true)
 |-- product_id: string (nullable = true)



In [9]:
# Cast every column to int, one at a time; values that cannot be parsed as an
# integer become null.
for column_name in ('customer_id', 'rating', 'product_id'):
    data = data.withColumn(column_name, col(column_name).cast('int'))

In [10]:
from pyspark.sql.types import DoubleType
from pyspark.sql.functions import isnan, when, count, col, udf

In [11]:
# Count null values per column. `when` yields a non-null value only for null
# cells, and `count` tallies the non-null values it is given.
null_counts = [
    count(when(col(name).isNull(), name)).alias(name)
    for name in data.columns
]
data.select(*null_counts).toPandas().T

Unnamed: 0,0
customer_id,1722
rating,1752
product_id,1722


In [12]:
# Discard every row that has a null in any column.
data = data.na.drop()

In [13]:
# Recompute the per-column null counts on the current `data`.
null_counts = [
    count(when(col(name).isNull(), name)).alias(name)
    for name in data.columns
]
data.select(*null_counts).toPandas().T

Unnamed: 0,0
customer_id,0
rating,0
product_id,0


In [14]:
# Dataset size: number of distinct customers, number of distinct products,
# and total number of rating rows.
users = data.select('customer_id').dropDuplicates().count()
products = data.select('product_id').dropDuplicates().count()
numerator = data.count()

In [15]:
# Displayed in order: total rating rows, distinct customers, distinct products.
display(numerator, users, products)

364069

251467

4218

In [16]:
# Split the ratings into training (80%) and test (20%) sets.
# The fixed seed pins the random split so that the metrics computed from
# `training` / `test` can be compared between reruns of the notebook.
(training, test) = data.randomSplit([0.8, 0.2], seed=42)

### Recommendation model

In [17]:
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS

In [18]:
# Baseline ALS collaborative-filtering model (5 iterations).
# - coldStartStrategy='drop': rows whose prediction is NaN are dropped from
#   transform() output, so the evaluator does not receive NaN predictions.
# - nonnegative=True: constrain the learned factors to be non-negative.
# - seed: set explicitly so the fit does not depend on the default seed.
als = ALS(
    maxIter=5,
    regParam=0.4,
    userCol='customer_id',
    itemCol='product_id',
    ratingCol='rating',
    coldStartStrategy='drop',
    nonnegative=True,
    seed=42,
)
model = als.fit(training)

In [19]:
# Generate predictions for the test split; the RMSE itself is computed in a later cell.
predictions = model.transform(test)

In [20]:
# Peek at the first five predicted ratings.
predictions.show(n=5)

+-----------+------+----------+----------+
|customer_id|rating|product_id|prediction|
+-----------+------+----------+----------+
|    6104746|     4|   2774881| 3.6907284|
|    6722335|     1|   2774881| 3.4951613|
|    7572140|     5|   3222489| 3.2909632|
|   11535292|     5|   5983423|  3.813942|
|   11681400|     5|   2774881| 4.3285723|
+-----------+------+----------+----------+
only showing top 5 rows



In [21]:
# RMSE between the observed `rating` and the model's `prediction`.
evaluator = RegressionEvaluator(
    predictionCol='prediction',
    labelCol='rating',
    metricName='rmse',
)
rmse = evaluator.evaluate(predictions)
print(f'Root-mean-square error = {rmse}')

Root-mean-square error = 1.2602289484884135


- The RMSE is ≈ 1.26, i.e. the model's predicted ratings deviate from the actual ratings in the test set by about 1.26 rating points (root-mean-square).

### Hiệu chỉnh tham số

In [22]:
# Tuned ALS variant: identical settings to the baseline except maxIter=10.
# - coldStartStrategy='drop': rows whose prediction is NaN are dropped from
#   transform() output, so the evaluator does not receive NaN predictions.
# - nonnegative=True: constrain the learned factors to be non-negative.
# - seed: set explicitly so the fit does not depend on the default seed.
als_t = ALS(
    maxIter=10,
    regParam=0.4,
    userCol='customer_id',
    itemCol='product_id',
    ratingCol='rating',
    coldStartStrategy='drop',
    nonnegative=True,
    seed=42,
)
model_t = als_t.fit(training)

In [23]:
# Generate predictions for the test split with the tuned model; the RMSE is computed in a later cell.
predictions_t = model_t.transform(test)

In [24]:
# Score the tuned model's predictions with the existing `evaluator`.
rmse_t = evaluator.evaluate(predictions_t)
print(f'Root-mean-square error = {rmse_t}')

Root-mean-square error = 1.1146470113238718


- Chọn model_t vì cho rmse thấp hơn

### Đưa ra đề xuất cho tất cả các user

In [28]:
# For every user, get the 20 products with the highest predicted rating.
user_recs = model_t.recommendForAllUsers(20)

In [29]:
# Print the column names and data types of `user_recs`.
user_recs.printSchema()

root
 |-- customer_id: integer (nullable = false)
 |-- recommendations: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- product_id: integer (nullable = true)
 |    |    |-- rating: float (nullable = true)



In [30]:
# Print the first two rows of user_recs, each followed by blank lines.
for rec_row in user_recs.take(2):
    print(rec_row)
    print('\n')

Row(customer_id=28, recommendations=[Row(product_id=69507754, rating=5.22247314453125), Row(product_id=73238633, rating=4.940127849578857), Row(product_id=70567940, rating=4.84923791885376), Row(product_id=2080951, rating=4.847939491271973), Row(product_id=73830099, rating=4.846951961517334), Row(product_id=45327625, rating=4.82234001159668), Row(product_id=68174409, rating=4.790010452270508), Row(product_id=38606217, rating=4.757803916931152), Row(product_id=73844240, rating=4.749032497406006), Row(product_id=72520984, rating=4.744541168212891), Row(product_id=53751834, rating=4.737171649932861), Row(product_id=49729049, rating=4.736989498138428), Row(product_id=3525255, rating=4.729927062988281), Row(product_id=77737982, rating=4.725241661071777), Row(product_id=50592901, rating=4.72523832321167), Row(product_id=21057555, rating=4.72262716293335), Row(product_id=8321616, rating=4.722252368927002), Row(product_id=76283023, rating=4.718403339385986), Row(product_id=20015885, rating=4.7

### Đưa ra đề xuất cho 1 customer cụ thể có customer_id là 588

In [40]:
# Pick one customer and keep only that customer's row of recommendations.
customer_id = 588
is_selected_customer = user_recs.customer_id == customer_id
result = user_recs.where(is_selected_customer)
result.show(truncate=False)

+-----------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|customer_id|recommendations                                                                                                                                                                                                                                                                                                                                                                                                                                                |
+-----------+---------------------------------------------------------------

### Chuẩn hóa dữ liệu cho user

In [41]:
# Flatten the recommendations array: one row per recommended product.
# The exploded struct stays in column `col`; its two fields are also copied
# out into top-level `product_id` and `rating` columns.
exploded = result.select('customer_id', explode('recommendations').alias('col'))
result = (
    exploded
    .withColumn('product_id', exploded['col'].getField('product_id'))
    .withColumn('rating', exploded['col'].getField('rating'))
)
result.show()

+-----------+--------------------+----------+---------+
|customer_id|                 col|product_id|   rating|
+-----------+--------------------+----------+---------+
|        588|{73238633, 4.352064}|  73238633| 4.352064|
|        588|{15623237, 4.138151}|  15623237| 4.138151|
|        588|{52785519, 4.064938}|  52785519| 4.064938|
|        588|{69507754, 4.0479...|  69507754|4.0479884|
|        588| {28075354, 4.01815}|  28075354|  4.01815|
|        588|{19395453, 4.017176}|  19395453| 4.017176|
|        588| {4597127, 4.012729}|   4597127| 4.012729|
|        588|{8877900, 3.9880514}|   8877900|3.9880514|
|        588|{13583766, 3.9676...|  13583766|3.9676995|
|        588|{10001353, 3.9597...|  10001353|3.9597583|
|        588|{59081231, 3.9572...|  59081231|3.9572663|
|        588|{57625269, 3.945741}|  57625269| 3.945741|
|        588|{74489817, 3.9331...|  74489817|3.9331427|
|        588|{51466982, 3.9301...|  51466982|3.9301553|
|        588|{73830099, 3.9290...|  73830099|3.9

### Lọc đề xuất dựa trên ngưỡng

In [42]:
# Keep only the recommendations whose predicted rating is at least 3.0.
meets_threshold = result['rating'] >= 3.0
result.where(meets_threshold).show()

+-----------+--------------------+----------+---------+
|customer_id|                 col|product_id|   rating|
+-----------+--------------------+----------+---------+
|        588|{73238633, 4.352064}|  73238633| 4.352064|
|        588|{15623237, 4.138151}|  15623237| 4.138151|
|        588|{52785519, 4.064938}|  52785519| 4.064938|
|        588|{69507754, 4.0479...|  69507754|4.0479884|
|        588| {28075354, 4.01815}|  28075354|  4.01815|
|        588|{19395453, 4.017176}|  19395453| 4.017176|
|        588| {4597127, 4.012729}|   4597127| 4.012729|
|        588|{8877900, 3.9880514}|   8877900|3.9880514|
|        588|{13583766, 3.9676...|  13583766|3.9676995|
|        588|{10001353, 3.9597...|  10001353|3.9597583|
|        588|{59081231, 3.9572...|  59081231|3.9572663|
|        588|{57625269, 3.945741}|  57625269| 3.945741|
|        588|{74489817, 3.9331...|  74489817|3.9331427|
|        588|{51466982, 3.9301...|  51466982|3.9301553|
|        588|{73830099, 3.9290...|  73830099|3.9

In [55]:
# Load the product catalogue, rename its `item_id` column to `product_id`,
# and keep just the id and the product name.
product_raw = spark.read.csv('product.csv', header=True, inferSchema=True)
product = (
    product_raw
    .withColumnRenamed('item_id', 'product_id')
    .select('product_id', 'name')
)
product.show(n=5, truncate=False)

+----------+-----------------------------------------------------------------------------------------------+
|product_id|name                                                                                           |
+----------+-----------------------------------------------------------------------------------------------+
|48102821  |Tai nghe Bluetooth Inpods 12 - Cảm biến vân tay, chống nước,màu sắc đa dạng- 5 màu sắc lựa chọn|
|52333193  |Tai nghe bluetooth không dây F9 True wireless Dock Sạc có Led Báo Pin Kép                      |
|299461    |Chuột Không Dây Logitech M331 Silent - Hàng Chính Hãng                                         |
|57440329  |Loa Bluetooth 5.0 Kiêm Đồng Hồ Báo Thức - [[ 2 Trong 1 ]] - Robot - Hàng Chính Hãng            |
|38458616  |Tai Nghe Bluetooth Apple AirPods Pro True Wireless - MWP22 - Hàng Chính Hãng VN/A              |
+----------+-----------------------------------------------------------------------------------------------+
only showing top 5 

In [57]:
# Attach product names to the selected customer's recommendations.
# - Filter with the `customer_id` variable rather than a hardcoded 588, so that
#   changing the variable does not silently produce an empty table here.
# - Sort by predicted rating explicitly: row order after a join is not guaranteed.
# The join is an inner join, so recommendations with no matching row in
# `product` are not shown.
(
    result
    .join(product, on='product_id')
    .filter(col('customer_id') == customer_id)
    .orderBy(col('rating').desc())
    .show()
)

+----------+-----------+--------------------+---------+--------------------+
|product_id|customer_id|                 col|   rating|                name|
+----------+-----------+--------------------+---------+--------------------+
|  73238633|        588|{73238633, 4.352064}| 4.352064|Ổ cứng di động Ex...|
|  15623237|        588|{15623237, 4.138151}| 4.138151|Tủ đông ALASKA 2 ...|
|  52785519|        588|{52785519, 4.064938}| 4.064938|Máy giặt Toshiba ...|
|  69507754|        588|{69507754, 4.0479...|4.0479884|Bo mạch chủ Gigab...|
|  28075354|        588| {28075354, 4.01815}|  4.01815|Card Màn Hình VGA...|
|  19395453|        588|{19395453, 4.017176}| 4.017176|MÁY RỬA CHÉN BOSC...|
|   4597127|        588| {4597127, 4.012729}| 4.012729|Giá Treo Tivi Sát...|
|   8877900|        588|{8877900, 3.9880514}|3.9880514|Miếng Dán Bảo Vệ ...|
|  13583766|        588|{13583766, 3.9676...|3.9676995|Thẻ Nhớ SDXC SanD...|
|  10001353|        588|{10001353, 3.9597...|3.9597583|RAM Laptop Samsun...|