In [1]:
import findspark
# Make the local Spark installation's pyspark importable; must run before `import pyspark`.
findspark.init()

In [2]:
import pyspark

In [3]:
from pyspark import SparkContext
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession

In [4]:
# Get the SparkContext for this notebook, reusing one if it is already running so that
# re-executing this cell does not fail with "Cannot run multiple SparkContexts at once".
# The app name is set here because the SparkSession builder in the next cell attaches
# to this existing context and would not rename it.
sc = SparkContext.getOrCreate(SparkConf().setAppName('ex_demo'))

In [5]:
# Build (or reuse) the SparkSession used by every spark.read / spark.sql call below.
spark = (
    SparkSession.builder
    .appName('ex_demo')
    .getOrCreate()
)

In [17]:
from pyspark.ml.fpm import FPGrowth

In [7]:
# Load the order line items from order_products__train.csv
# (columns: order_id, product_id, add_to_cart_order, reordered).
data = (
    spark.read
    .option('header', True)
    .option('inferSchema', True)
    .csv('instacart_2017_05_01/order_products__train.csv')
)

In [8]:
# Number of order/product rows loaded.
n_order_lines = data.count()
n_order_lines

1384617

In [9]:
# Preview the first 20 rows (show()'s default row count).
data.show(n=20)

+--------+----------+-----------------+---------+
|order_id|product_id|add_to_cart_order|reordered|
+--------+----------+-----------------+---------+
|       1|     49302|                1|        1|
|       1|     11109|                2|        1|
|       1|     10246|                3|        0|
|       1|     49683|                4|        0|
|       1|     43633|                5|        1|
|       1|     13176|                6|        0|
|       1|     47209|                7|        0|
|       1|     22035|                8|        1|
|      36|     39612|                1|        0|
|      36|     19660|                2|        1|
|      36|     49235|                3|        0|
|      36|     43086|                4|        1|
|      36|     46620|                5|        1|
|      36|     34497|                6|        1|
|      36|     48679|                7|        1|
|      36|     46979|                8|        1|
|      38|     11913|                1|        0|


In [10]:
# Pre-processing data
from pyspark.sql.functions import collect_list, col, count, collect_set

In [11]:
# Expose the order lines to SQL under the name queried by the spark.sql calls below.
data.createOrReplaceTempView(name="order_products_train")

In [12]:
# Count how many distinct product_ids appear in the order lines.
distinct_product_query = "select distinct product_id from order_products_train"
products = spark.sql(distinct_product_query)
products.count()

39123

In [13]:
# Build one "basket" per order: the set of product_ids in that order.
# collect_set (rather than collect_list) removes repeated product_ids within an order;
# FPGrowth expects the items of each transaction to be unique.
rawData = spark.sql("select * from order_products_train")
baskets = (
    rawData
    .groupBy('order_id')
    .agg(collect_set('product_id').alias('items'))
)
baskets.createOrReplaceTempView('baskets')

In [14]:
# Peek at five baskets without truncating the item arrays.
baskets.show(n=5, truncate=False)

+--------+---------------------------------------------------------------------------------------------------------------------+
|order_id|items                                                                                                                |
+--------+---------------------------------------------------------------------------------------------------------------------+
|762     |[41220, 21137, 30391, 15872]                                                                                         |
|844     |[14992, 18599, 21405, 31766, 11182, 28289, 9387]                                                                     |
|988     |[4818, 12626, 45061, 28464]                                                                                          |
|1139    |[34969, 1376, 13431, 45757, 40396, 7559, 21137, 24852, 46993]                                                        |
|1143    |[3464, 29307, 47209, 39275, 19660, 7552, 27966, 12206, 47626, 42958, 21405, 42719, 3675

In [15]:
# Check the type of `baskets`: it is the Spark DataFrame passed to FPGrowth.fit below.
type(baskets)

pyspark.sql.dataframe.DataFrame

In [18]:
# Mine frequent itemsets and association rules over the product_id baskets.
#   minSupport    - minimum fraction of baskets an itemset must occur in (0.003 = 0.3%)
#   minConfidence - minimum confidence an association rule must reach
fpGrowth = FPGrowth(
    itemsCol="items",
    minSupport=0.003,
    minConfidence=0.003,
)
model = fpGrowth.fit(baskets)

In [19]:
# Display frequent itemsets; `freq` is the number of baskets containing the itemset.
model.freqItemsets.show(n=20)

+--------------+----+
|         items|freq|
+--------------+----+
|       [33120]| 834|
|       [13263]| 420|
|       [34448]| 684|
|        [8021]|1183|
|       [17316]| 468|
|       [37524]| 450|
|       [16759]|1742|
|[16759, 13176]| 465|
|       [25588]| 401|
|        [4086]| 510|
|       [20842]|1046|
|       [37646]|2809|
|[37646, 21137]| 574|
|[37646, 13176]| 810|
|[37646, 21903]| 461|
|[37646, 47209]| 401|
|[37646, 24852]| 666|
|       [45104]| 436|
|       [20995]|1361|
|       [33401]| 568|
+--------------+----+
only showing top 20 rows



In [20]:
# transform examines the input items against all the association rules and
# summarizes the consequents as `prediction`.
# NOTE: despite the variable name, `prediction` is a list of rule consequents per
# order (empty when no rule applies), not a single most-popular item.
mostPopularItemInABasket = model.transform(dataset=baskets)

In [21]:
# Show the first 20 orders with their predicted items (long arrays are truncated).
mostPopularItemInABasket.show(n=20)

+--------+--------------------+--------------------+
|order_id|               items|          prediction|
+--------+--------------------+--------------------+
|     762|[41220, 21137, 30...|[13176, 37646, 22...|
|     844|[14992, 18599, 21...|             [24852]|
|     988|[4818, 12626, 450...|                  []|
|    1139|[34969, 1376, 134...|[37646, 22935, 39...|
|    1143|[3464, 29307, 472...|[21137, 47766, 16...|
|    1280|[27845, 23955, 49...|[13176, 24852, 21...|
|    1342|[30827, 3798, 149...|[37646, 22935, 39...|
|    1350|[1017, 23776, 305...|                  []|
|    1468|[17794, 34243, 27...|[13176, 21903, 21...|
|    1591|[48246, 44116, 24...|[47626, 37646, 22...|
|    1721|[14197, 5134, 278...|[21137, 47766, 47...|
|    1890|[8277, 8424, 4463...|[22935, 24964, 31...|
|    1955|[16254, 35469, 12...|[24852, 16759, 37...|
|    2711|[27325, 3873, 171...|             [24852]|
|    2888|[35622, 5077, 382...|[37646, 22935, 39...|
|    3010|[45535, 21195, 47...|[37646, 22935, 

### Use product_name instead of product_id


In [23]:
# Product catalogue: product_id -> product_name, aisle_id, department_id.
product_data = (
    spark.read
    .option('header', True)
    .option('inferSchema', True)
    .csv('instacart_2017_05_01/products.csv')
)

In [24]:
# Preview five products with full (untruncated) names.
product_data.show(n=5, truncate=False)

+----------+-----------------------------------------------------------------+--------+-------------+
|product_id|product_name                                                     |aisle_id|department_id|
+----------+-----------------------------------------------------------------+--------+-------------+
|1         |Chocolate Sandwich Cookies                                       |61      |19           |
|2         |All-Seasons Salt                                                 |104     |13           |
|3         |Robust Golden Unsweetened Oolong Tea                             |94      |7            |
|4         |Smart Ones Classic Favorites Mini Rigatoni With Vodka Cream Sauce|38      |1            |
|5         |Green Chile Anytime Sauce                                        |5       |13           |
+----------+-----------------------------------------------------------------+--------+-------------+
only showing top 5 rows



In [25]:
# Register the catalogue so the next cell can join it in SQL.
product_data.createOrReplaceTempView(name="products")

In [26]:
# Rebuild the baskets with product names instead of product ids by joining each
# order line to the product catalogue. The join condition is given explicitly in
# ON instead of a condition-less join filtered by WHERE (an implicit cross join).
rawData_1 = spark.sql('''select p.product_name, o.order_id
 from products p
 inner join order_products_train o
 on o.product_id = p.product_id''')
# collect_set keeps each name once per order, as FPGrowth expects unique items.
# NOTE(review): if two product_ids share a product_name they collapse into one item here.
baskets_1 = (
    rawData_1
    .groupBy('order_id')
    .agg(collect_set('product_name').alias('items'))
)
# NOTE: this replaces the earlier 'baskets' temp view (product_id baskets) with the
# product-name baskets.
baskets_1.createOrReplaceTempView('baskets')

In [27]:
# First three product-name baskets as Row objects.
baskets_1.take(3)

[Row(order_id=762, items=['Organic Cucumber', 'Organic Romaine Lettuce', 'Celery Hearts', 'Organic Strawberries']),
 Row(order_id=1139, items=['Cinnamon Rolls with Icing', 'Red Vine Tomato', 'Picnic Potato Salad', 'Flaky Biscuits', 'Organic Strawberries', 'Organic Bakery Hamburger Buns Wheat - 8 CT', 'Buttermilk Biscuits', 'Banana', 'Guacamole']),
 Row(order_id=1143, items=['Water', 'Natural Premium Coconut Water', 'Organic Red Radish, Bunch', 'Organic Capellini Whole Wheat Pasta', 'Organic Raspberries', 'Calming Lavender Body Wash', 'Organic Garlic', 'Rustic Baguette', 'Organic Brussel Sprouts', 'Organic Butterhead (Boston, Butter, Bibb) Lettuce', 'Organic Blueberries', 'Spring Water', 'Large Lemon', 'Basil Pesto', 'Baby Arugula', 'Organic Hass Avocado', 'Unscented Long Lasting Stick Deodorant'])]

In [28]:
# FP-Growth over the product-name baskets, with the same thresholds as the
# product_id model (minSupport = minimum fraction of baskets, minConfidence =
# minimum rule confidence).
fpGrowth_1 = FPGrowth(itemsCol="items",
                      minSupport=0.003,
                      minConfidence=0.003)
# Fit the estimator configured in this cell (not the earlier `fpGrowth`), so that
# changes to fpGrowth_1's parameters actually take effect.
model_1 = fpGrowth_1.fit(baskets_1)

In [29]:
# Display frequent itemsets by product name, without truncating long names.
model_1.freqItemsets.show(n=20, truncate=False)

+-------------------------------------+----+
|items                                |freq|
+-------------------------------------+----+
|[Total 0% Raspberry Yogurt]          |441 |
|[Organic Egg Whites]                 |834 |
|[100% Raw Coconut Water]             |1298|
|[Mint]                               |510 |
|[Organic Red Potato]                 |411 |
|[Lime]                               |472 |
|[Raspberries]                        |3279|
|[Raspberries, Organic Strawberries]  |563 |
|[Raspberries, Strawberries]          |658 |
|[Raspberries, Organic Blueberries]   |574 |
|[Raspberries, Bag of Organic Bananas]|556 |
|[Raspberries, Banana]                |623 |
|[Strawberry Preserves]               |622 |
|[Mini Original Babybel Cheese]       |564 |
|[Roma Tomato]                        |1953|
|[Roma Tomato, Banana]                |572 |
|[Honeycrisp Apples]                  |670 |
|[Whole Almonds]                      |448 |
|[Organic Grade A Large Brown Eggs]   |874 |
|[Non Fat 

In [30]:
# Apply the product-name rules to every basket; adds a `prediction` array column.
mostPopularItemInABasket_1 = model_1.transform(dataset=baskets_1)

In [31]:
# First three orders with their items and predictions.
mostPopularItemInABasket_1.take(3)

[Row(order_id=762, items=['Organic Cucumber', 'Organic Romaine Lettuce', 'Celery Hearts', 'Organic Strawberries'], prediction=['Bag of Organic Bananas', 'Raspberries', 'Organic Zucchini', 'Organic Small Bunch Celery', 'Organic Yellow Onion', 'Organic Garnet Sweet Potato (Yam)', 'Organic Tomato Cluster', 'Banana', 'Small Hass Avocado', 'Apple Honeycrisp Organic', 'Organic Garlic', 'Broccoli Crown', 'Seedless Red Grapes', 'Organic Baby Spinach', 'Asparagus', 'Large Lemon', 'Organic Baby Arugula', 'Organic Peeled Whole Baby Carrots', 'Organic Avocado', "Organic D'Anjou Pears", 'Organic Grape Tomatoes', 'Organic Large Extra Fancy Fuji Apple', 'Organic Kiwi', 'Organic Red Onion', 'Organic Hass Avocado', 'Original Hummus', 'Limes', 'Organic Blackberries', 'Organic Baby Carrots', 'Organic Gala Apples', 'Honeycrisp Apple', 'Organic Raspberries', 'Organic Red Bell Pepper', 'Organic Cilantro', 'Organic Whole String Cheese', 'Organic Granny Smith Apple', 'Organic Blueberries', 'Fresh Cauliflower'

In [32]:
# Check the type of the transform output: it is a Spark DataFrame.
type(mostPopularItemInABasket_1)

pyspark.sql.dataframe.DataFrame

In [33]:
# Convert the array (list) columns to strings (the CSV writer used below cannot store array columns)
from pyspark.sql.types import StringType

In [34]:
# Inspect the schema: `items` and `prediction` are both array<string> columns here.
mostPopularItemInABasket_1.printSchema()

root
 |-- order_id: integer (nullable = true)
 |-- items: array (nullable = false)
 |    |-- element: string (containsNull = false)
 |-- prediction: array (nullable = true)
 |    |-- element: string (containsNull = false)



In [35]:
# Make the predictions queryable from SQL as `popular_items`.
mostPopularItemInABasket_1.createOrReplaceTempView(name="popular_items")

In [36]:
# Cast both array columns to strings so the frame can be written as CSV.
# NOTE(review): product names can themselves contain commas (e.g.
# 'Organic Red Radish, Bunch'), so the resulting "[a, b, ...]" strings cannot be
# split back into items reliably.
items_as_text = mostPopularItemInABasket_1['items'].cast(StringType())
prediction_as_text = mostPopularItemInABasket_1['prediction'].cast(StringType())
DF_cast = mostPopularItemInABasket_1.select('order_id', items_as_text, prediction_as_text)

In [37]:
# Confirm `items` and `prediction` are now plain string columns.
DF_cast.printSchema()

root
 |-- order_id: integer (nullable = true)
 |-- items: string (nullable = false)
 |-- prediction: string (nullable = true)



In [38]:
# First three rows of the stringified predictions.
DF_cast.take(3)

[Row(order_id=762, items='[Organic Cucumber, Organic Romaine Lettuce, Celery Hearts, Organic Strawberries]', prediction="[Bag of Organic Bananas, Raspberries, Organic Zucchini, Organic Small Bunch Celery, Organic Yellow Onion, Organic Garnet Sweet Potato (Yam), Organic Tomato Cluster, Banana, Small Hass Avocado, Apple Honeycrisp Organic, Organic Garlic, Broccoli Crown, Seedless Red Grapes, Organic Baby Spinach, Asparagus, Large Lemon, Organic Baby Arugula, Organic Peeled Whole Baby Carrots, Organic Avocado, Organic D'Anjou Pears, Organic Grape Tomatoes, Organic Large Extra Fancy Fuji Apple, Organic Kiwi, Organic Red Onion, Organic Hass Avocado, Original Hummus, Limes, Organic Blackberries, Organic Baby Carrots, Organic Gala Apples, Honeycrisp Apple, Organic Raspberries, Organic Red Bell Pepper, Organic Cilantro, Organic Whole String Cheese, Organic Granny Smith Apple, Organic Blueberries, Fresh Cauliflower, Organic Banana, Organic Whole Milk, Organic Navel Orange, Organic Lemon]"),
 Ro

In [39]:
# Spark writes a *directory* named 'mostPopularItemInABasket.csv' containing part files.
# mode='overwrite' lets the notebook be re-run (the default mode raises if the path
# already exists); header=True keeps the column names (order_id, items, prediction).
DF_cast.write.csv('mostPopularItemInABasket.csv', mode='overwrite', header=True)