In [1]:
# Run findspark before pyspark is imported in the next cell
# (presumably it locates the local Spark install and adds it to sys.path — confirm against the findspark docs).
import findspark
findspark.init()

In [2]:
import pyspark

In [3]:
from pyspark import SparkContext
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession

In [4]:
# Reuse the active SparkContext when one exists: a bare SparkContext() raises
# if this cell is re-run in the same kernel, because a context is already running.
sc = SparkContext.getOrCreate()

In [5]:
# Obtain (or create) the SparkSession used by every DataFrame and SQL call below.
spark = (
    SparkSession.builder
    .appName('Association Rules')
    .getOrCreate()
)

In [6]:
from pyspark.ml.fpm import FPGrowth

In [7]:
# Import only the types this schema needs; a wildcard import hides where names come from.
from pyspark.sql.types import IntegerType, StructField, StructType

# Schema applied when reading the transaction CSV: three nullable integer columns.
df_schema = StructType([
    StructField("tran_id", IntegerType(), True),
    StructField("quantity", IntegerType(), True),
    StructField("product", IntegerType(), True),
])

In [8]:
# Input locations, relative to the notebook's working directory.
file = "Du lieu cung cap/75000/75000i.csv"
file_goods = "Du lieu cung cap/75000/goods.csv"

# Transactions: read with the explicit schema declared above (no header option is set).
df = spark.read.schema(df_schema).csv(file)
# Goods catalogue: first row is the header, column types are inferred.
df_goods = spark.read.options(header=True, inferSchema=True).csv(file_goods)

In [9]:
from pyspark.sql.functions import col, regexp_replace

# Remove single quotes from the values of every goods column.
# regexp_replace produces strings, so all columns come out string-typed.
columns = df_goods.columns

df_goods = df_goods.select(
    [regexp_replace(col(name), "'", "").alias(name) for name in columns]
)

In [10]:
# Check that the transaction frame carries the declared integer schema.
df.printSchema()

root
 |-- tran_id: integer (nullable = true)
 |-- quantity: integer (nullable = true)
 |-- product: integer (nullable = true)



In [11]:
# Preview the first five transaction rows.
df.show(5)

+-------+--------+-------+
|tran_id|quantity|product|
+-------+--------+-------+
|      1|       1|     21|
|      1|       5|     11|
|      2|       1|      7|
|      2|       3|     11|
|      2|       4|     37|
+-------+--------+-------+
only showing top 5 rows



In [12]:
# Inspect the goods schema after the quote-stripping step.
df_goods.printSchema()

root
 |-- Id: string (nullable = true)
 |-- Flavor: string (nullable = true)
 |-- Food: string (nullable = true)
 |-- Price: string (nullable = true)
 |-- Type: string (nullable = true)



In [13]:
# Preview the first five goods rows.
df_goods.show(5)

+---+----------+----+-----+----+
| Id|    Flavor|Food|Price|Type|
+---+----------+----+-----+----+
|  0| Chocolate|Cake| 8.95|Food|
|  1|     Lemon|Cake| 8.95|Food|
|  2|    Casino|Cake|15.95|Food|
|  3|     Opera|Cake|15.95|Food|
|  4|Strawberry|Cake|11.95|Food|
+---+----------+----+-----+----+
only showing top 5 rows



In [14]:
# Number of rows in the goods table.
df_goods.count()

50

In [15]:
from pyspark.sql.functions import collect_list, col, count, collect_set

In [16]:
# Register both frames as temp views so the spark.sql queries below can refer to them by name.
df.createOrReplaceTempView("order_products_train")
df_goods.createOrReplaceTempView("goods_list")

In [17]:
# Count the distinct rows in the transaction view.
products = spark.sql("select distinct * from order_products_train")
products.count()

266209

In [18]:
# Count the distinct rows in the goods view.
goods = spark.sql("select distinct * from goods_list")
goods.count()

50

In [19]:
# Build the basket dataset: one row per transaction holding its set of product ids.
# collect_set drops duplicates, and the quantity column is not used.
rawData = spark.sql("select * from order_products_train")
baskets = (
    rawData
    .groupBy('tran_id')
    .agg(collect_set('product').alias('items'))
)
baskets.createOrReplaceTempView('baskets')

In [20]:
# Preview the raw transaction rows that feed the baskets.
rawData.show(5)

+-------+--------+-------+
|tran_id|quantity|product|
+-------+--------+-------+
|      1|       1|     21|
|      1|       5|     11|
|      2|       1|      7|
|      2|       3|     11|
|      2|       4|     37|
+-------+--------+-------+
only showing top 5 rows



In [21]:
# Show the first ten baskets ordered by transaction id, without truncating the item lists.
baskets.sort('tran_id').show(10, truncate=False)

+-------+--------------------+
|tran_id|items               |
+-------+--------------------+
|1      |[21, 11]            |
|2      |[45, 37, 7, 11]     |
|3      |[33, 42, 3]         |
|4      |[12, 5, 17, 47]     |
|5      |[42, 6, 18]         |
|6      |[34, 2, 4]          |
|7      |[15, 16, 40, 23]    |
|8      |[34, 2, 3, 29]      |
|9      |[35, 36, 18, 26, 23]|
|10     |[45, 44]            |
+-------+--------------------+
only showing top 10 rows



In [22]:
# Confirm that baskets is a Spark DataFrame.
type(baskets)

pyspark.sql.dataframe.DataFrame

In [23]:
from pyspark.ml.fpm import FPGrowth

# Configure FP-Growth on the id-based baskets, then fit it.
fpGrowth = FPGrowth(
    itemsCol="items",
    minSupport=0.003,
    minConfidence=0.3,
)
model = fpGrowth.fit(baskets)

In [24]:
# Display the frequent itemsets and their frequencies.
model.freqItemsets.show()

+--------+----+
|   items|freq|
+--------+----+
|    [19]|5685|
|[19, 27]| 359|
|[19, 33]| 334|
| [19, 1]|2764|
|[19, 28]| 408|
|[19, 37]| 274|
|[19, 35]| 312|
|[19, 16]| 286|
| [19, 4]| 388|
|[19, 46]| 324|
|[19, 15]| 298|
| [19, 5]| 323|
|[19, 22]| 368|
|[19, 32]| 297|
|[19, 45]| 344|
|[19, 47]| 331|
| [19, 3]| 294|
|[19, 14]| 350|
|[19, 11]| 296|
| [19, 0]| 305|
+--------+----+
only showing top 20 rows



In [25]:
# transform() examines each basket's items against all the association rules and
# summarizes the matching consequents in a 'prediction' column.
mostPopularItemInABasket = model.transform(baskets)

In [26]:
# Show three baskets with their predicted items, one field per line.
mostPopularItemInABasket.show(3, truncate=False, vertical=True)

-RECORD 0---------------------
 tran_id    | 1               
 items      | [21, 11]        
 prediction | [37, 45, 7]     
-RECORD 1---------------------
 tran_id    | 2               
 items      | [45, 37, 7, 11] 
 prediction | [15, 16, 32]    
-RECORD 2---------------------
 tran_id    | 3               
 items      | [33, 42, 3]     
 prediction | [35, 18]        
only showing top 3 rows



## Use product names instead of product IDs

In [27]:
# Attach a readable "<Flavor> <Food>" name to every transaction line.
# The join condition sits in an explicit ON clause instead of a WHERE filter, and
# goods_list.Id is cast to int explicitly because the quote-stripping step turned
# every goods column into a string while order_products_train.product is an integer.
# The output column keeps the name 'flavor' because the basket aggregation reads it.
rawData_1 = spark.sql('''select p.flavor || ' ' || p.food as flavor, o.tran_id
                          from goods_list p
                          inner join order_products_train o
                          on o.product = cast(p.id as int)''')

In [28]:
# Preview the name-based transaction rows.
rawData_1.show(5)

+--------------+-------+
|        flavor|tran_id|
+--------------+-------+
|Ganache Cookie|      1|
|     Apple Pie|      1|
| Coffee Eclair|      2|
|     Apple Pie|      2|
|  Almond Twist|      2|
+--------------+-------+
only showing top 5 rows



In [29]:
# Group the name-based rows into one set of product names per transaction.
baskets_1 = (
    rawData_1
    .groupBy('tran_id')
    .agg(collect_set('flavor').alias('items'))
)
# NOTE: this reuses the view name 'baskets', replacing any view previously registered under it.
baskets_1.createOrReplaceTempView('baskets')

In [30]:
# Look at the first three name-based baskets.
baskets_1.head(3)

[Row(tran_id=1, items=['Ganache Cookie', 'Apple Pie']),
 Row(tran_id=6, items=['Strawberry Cake', 'Chocolate Croissant', 'Casino Cake']),
 Row(tran_id=12, items=['Almond Twist', 'Ganache Cookie', 'Opera Cake', 'Single Espresso', 'Casino Cake', 'Raspberry Lemonade', 'Apple Pie'])]

In [31]:
# FP-Growth for the name-based baskets, with a lower confidence threshold (0.2).
fpGrowth_1 = FPGrowth(itemsCol="items", minSupport=0.003, minConfidence=0.2)
# Fit with fpGrowth_1: the original called the earlier `fpGrowth` estimator here,
# so the 0.2 threshold configured above was never applied.
model_1 = fpGrowth_1.fit(baskets_1)

In [32]:
# Display the frequent itemsets of the name-based model without truncating the item names.
model_1.freqItemsets.show(truncate=False)

+---------------------------------------+----+
|items                                  |freq|
+---------------------------------------+----+
|[Vanilla Meringue]                     |3179|
|[Vanilla Meringue, Lemon Tart]         |252 |
|[Vanilla Meringue, Marzipan Cookie]    |277 |
|[Vanilla Meringue, Cheese Croissant]   |260 |
|[Vanilla Meringue, Chocolate Tart]     |233 |
|[Vanilla Meringue, Lemon Cake]         |293 |
|[Vanilla Meringue, Tuile Cookie]       |312 |
|[Vanilla Meringue, Apricot Danish]     |249 |
|[Vanilla Meringue, Blueberry Tart]     |229 |
|[Vanilla Meringue, Chocolate Coffee]   |245 |
|[Vanilla Meringue, Strawberry Cake]    |299 |
|[Vanilla Meringue, Blackberry Tart]    |229 |
|[Vanilla Meringue, Gongolais Cookie]   |293 |
|[Vanilla Meringue, Truffle Cake]       |255 |
|[Vanilla Meringue, Apricot Croissant]  |229 |
|[Vanilla Meringue, Hot Coffee]         |227 |
|[Vanilla Meringue, Vanilla Frappuccino]|260 |
|[Vanilla Meringue, Berry Tart]         |281 |
|[Vanilla Mer

In [33]:
# Apply the name-based model to its baskets; the result gains a 'prediction' column.
mostPopularItemInABasket_1 = model_1.transform(baskets_1)

In [34]:
# Look at the first three baskets together with their predicted items.
mostPopularItemInABasket_1.head(3)

[Row(tran_id=1, items=['Ganache Cookie', 'Apple Pie'], prediction=['Almond Twist', 'Hot Coffee', 'Coffee Eclair']),
 Row(tran_id=6, items=['Strawberry Cake', 'Chocolate Croissant', 'Casino Cake'], prediction=['Napoleon Cake', 'Chocolate Coffee', 'Chocolate Cake']),
 Row(tran_id=12, items=['Almond Twist', 'Ganache Cookie', 'Opera Cake', 'Single Espresso', 'Casino Cake', 'Raspberry Lemonade', 'Apple Pie'], prediction=['Hot Coffee', 'Coffee Eclair', 'Green Tea', 'Raspberry Cookie', 'Lemon Cookie', 'Lemon Lemonade', 'Chocolate Coffee', 'Chocolate Cake', 'Apricot Danish', 'Cherry Tart', 'Blackberry Tart'])]

In [35]:
# Confirm that the transform result is a Spark DataFrame.
type(mostPopularItemInABasket_1)

pyspark.sql.dataframe.DataFrame

In [36]:
# StringType is used below to cast the array columns to strings.
from pyspark.sql.types import StringType

In [37]:
# Inspect the column types before casting: items and prediction are arrays of strings.
mostPopularItemInABasket_1.printSchema()

root
 |-- tran_id: integer (nullable = true)
 |-- items: array (nullable = false)
 |    |-- element: string (containsNull = false)
 |-- prediction: array (nullable = true)
 |    |-- element: string (containsNull = false)



In [38]:
# Register the prediction results as the temp view "popular_items".
mostPopularItemInABasket_1.createOrReplaceTempView("popular_items")

In [39]:
# Cast the two array columns to plain strings, keeping tran_id unchanged.
DF_cast = mostPopularItemInABasket_1.select(
    'tran_id',
    mostPopularItemInABasket_1['items'].cast(StringType()),
    mostPopularItemInABasket_1['prediction'].cast(StringType()),
)
DF_cast.printSchema()

root
 |-- tran_id: integer (nullable = true)
 |-- items: string (nullable = false)
 |-- prediction: string (nullable = true)



In [40]:
# Basket contents and the list of recommended products for each transaction.
DF_cast.head(3)

[Row(tran_id=1, items='[Ganache Cookie, Apple Pie]', prediction='[Almond Twist, Hot Coffee, Coffee Eclair]'),
 Row(tran_id=6, items='[Strawberry Cake, Chocolate Croissant, Casino Cake]', prediction='[Napoleon Cake, Chocolate Coffee, Chocolate Cake]'),
 Row(tran_id=12, items='[Almond Twist, Ganache Cookie, Opera Cake, Single Espresso, Casino Cake, Raspberry Lemonade, Apple Pie]', prediction='[Hot Coffee, Coffee Eclair, Green Tea, Raspberry Cookie, Lemon Cookie, Lemon Lemonade, Chocolate Coffee, Chocolate Cake, Apricot Danish, Cherry Tart, Blackberry Tart]')]

In [41]:
# Overwrite any previous export: the default save mode raises when the target path
# already exists, which makes the notebook fail on a second Restart & Run All.
DF_cast.write.mode('overwrite').csv('mostPopularItemInABasket.csv')