# Collaborative Filtering ALS for a retail data set

Import findspark and initiate. Then import pyspark

In [1]:
import findspark
findspark.init('/usr/local/spark')
import pyspark

Start Spark Session

In [2]:
from pyspark.sql import SparkSession

In [3]:
spark = SparkSession.builder.appName("PySpark Collaborative Filtering ALS example").getOrCreate()

Import libraries for ALS

In [4]:
from pyspark.ml.recommendation import ALS
from pyspark.sql.functions import *

Create Spark Context object

In [5]:
retailData = spark.read.load("retail_2013.csv", format="csv", inferSchema="true", header="false")
retailData.count()

2554634

In [6]:
retailData.printSchema()

root
 |-- _c0: integer (nullable = true)
 |-- _c1: timestamp (nullable = true)
 |-- _c2: integer (nullable = true)
 |-- _c3: string (nullable = true)
 |-- _c4: string (nullable = true)
 |-- _c5: integer (nullable = true)
 |-- _c6: integer (nullable = true)
 |-- _c7: integer (nullable = true)



Collaborative Filtering of spark.ml uses ALS (Alternatiing Least Squares) algorithm.
The input dataframe is expected to have the three columns userCol, itemCol and ratingCol.
In our retail data set, we will use the customer id column as userCol.
We will generate a product index for each product name (column _c4) and use them as itemCol.
We will get the number of times a user bought a product and use it for ratingCol.
We will set the input parameter implicitPrefs as true because we do not have explicit feedback such as rating of a product.

In [7]:
from pyspark.ml.feature import StringIndexer

In [8]:
indexer = StringIndexer(inputCol="_c4", outputCol="ProductIndex")
retailData1 = indexer.fit(retailData).transform(retailData)

In [9]:
retailData1.show(truncate=False)

+---+-------------------+------+-------------------------------------+---------------------+----+-----+---+------------+
|_c0|_c1                |_c2   |_c3                                  |_c4                  |_c5 |_c6  |_c7|ProductIndex|
+---+-------------------+------+-------------------------------------+---------------------+----+-----+---+------------+
|1  |2013-01-01 01:41:46|443111|Household appliance stores           |Refrigerators        |875 |14640|1  |68.0        |
|1  |2013-01-01 02:18:34|45322 |Gift  novelty  and souvenir stores   |Sun glasses          |58  |12503|1  |97.0        |
|1  |2013-01-01 12:09:03|446   |Health and personal care stores      |Health monitors      |184 |16266|1  |32.0        |
|1  |2013-01-01 00:10:20|4421  |Furniture stores                     |Study room furniture |209 |16625|1  |17.0        |
|1  |2013-01-01 13:39:08|44413 | Hardware stores                     |Saws                 |57  |16194|1  |57.0        |
|1  |2013-01-01 22:13:46|44312 |

In [10]:
retailData2 = retailData1.select('ProductIndex',col("_c6").alias("CustomerId"))
retailData2.printSchema()
retailData2.show(5)

root
 |-- ProductIndex: double (nullable = true)
 |-- CustomerId: integer (nullable = true)

+------------+----------+
|ProductIndex|CustomerId|
+------------+----------+
|        68.0|     14640|
|        97.0|     12503|
|        32.0|     16266|
|        17.0|     16625|
|        57.0|     16194|
+------------+----------+
only showing top 5 rows



In [11]:
retailData3=retailData2.groupBy("ProductIndex", "CustomerId").count()

In [12]:
retailData3.printSchema()
retailData3.show()

root
 |-- ProductIndex: double (nullable = true)
 |-- CustomerId: integer (nullable = true)
 |-- count: long (nullable = false)

+------------+----------+-----+
|ProductIndex|CustomerId|count|
+------------+----------+-----+
|        17.0|     10024|    1|
|        58.0|     13278|    1|
|         3.0|     15331|    1|
|        34.0|     13037|    2|
|        96.0|     26986|    1|
|       130.0|     27242|    1|
|       153.0|     40674|    1|
|        44.0|     40769|    1|
|        97.0|     57281|    1|
|       155.0|     50594|    1|
|        14.0|     54659|    1|
|        56.0|     57917|    1|
|        66.0|     67818|    1|
|        80.0|     63618|    1|
|        67.0|     72022|    1|
|       120.0|     76786|    1|
|        87.0|     70551|    1|
|       149.0|     79213|    2|
|       147.0|     71824|    1|
|        36.0|     83221|    2|
+------------+----------+-----+
only showing top 20 rows



In [13]:
# Changing column name of count to Count and type to double

In [14]:
retailData4=retailData3.withColumn("Count", expr("CAST(count AS DOUBLE)"))

In [15]:
retailData4.printSchema()
retailData4.show()

root
 |-- ProductIndex: double (nullable = true)
 |-- CustomerId: integer (nullable = true)
 |-- Count: double (nullable = false)

+------------+----------+-----+
|ProductIndex|CustomerId|Count|
+------------+----------+-----+
|        17.0|     10024|  1.0|
|        58.0|     13278|  1.0|
|         3.0|     15331|  1.0|
|        34.0|     13037|  2.0|
|        96.0|     26986|  1.0|
|       130.0|     27242|  1.0|
|       153.0|     40674|  1.0|
|        44.0|     40769|  1.0|
|        97.0|     57281|  1.0|
|       155.0|     50594|  1.0|
|        14.0|     54659|  1.0|
|        56.0|     57917|  1.0|
|        66.0|     67818|  1.0|
|        80.0|     63618|  1.0|
|        67.0|     72022|  1.0|
|       120.0|     76786|  1.0|
|        87.0|     70551|  1.0|
|       149.0|     79213|  2.0|
|       147.0|     71824|  1.0|
|        36.0|     83221|  2.0|
+------------+----------+-----+
only showing top 20 rows



Define ALS object and traing the model using fit method on the dataframe

In [16]:
als = ALS(maxIter=5, regParam=0.01, userCol="CustomerId", itemCol="ProductIndex", ratingCol="Count", implicitPrefs=True, alpha=1.0)

In [17]:
model = als.fit(retailData4)

In [18]:
# Create a dataframe, productDF that contains distinct product names and product indexes

In [19]:
productDF = retailData1.select(col("_c4").alias("ProductName"),'ProductIndex').distinct()
productDF.printSchema()

root
 |-- ProductName: string (nullable = true)
 |-- ProductIndex: double (nullable = true)



In [20]:
productDF.count()

160

In [21]:
productDF.orderBy('ProductIndex').show()

+--------------------+------------+
|         ProductName|ProductIndex|
+--------------------+------------+
|            Perfumes|         0.0|
|      Airated drinks|         1.0|
|  Mystery & thriller|         2.0|
|         Dishwashers|         3.0|
|       Other jewelry|         4.0|
|     Packaged fruits|         5.0|
|      Wheels & Tires|         6.0|
|       Running shoes|         7.0|
| Parts & Accessories|         8.0|
|      Personal items|         9.0|
|     Steoreo Systems|        10.0|
|             Juicers|        11.0|
|        Writing aids|        12.0|
|       Classic books|        13.0|
|       Arts & Crafts|        14.0|
| Dolls & accessories|        15.0|
|Engine & Drive train|        16.0|
|Study room furniture|        17.0|
|     Loose gemstones|        18.0|
|Boating & water s...|        19.0|
+--------------------+------------+
only showing top 20 rows



To get recommendations for a customer, for example customer id 43124, create a data frame, testDF from productDF that contains all the product indexes (and product names). To this add a column CustomerId whose value is the given customer id which in this case 43124 and column Count as 0.0.

In [22]:
testDF=productDF.withColumn('CustomerId', lit(43124)).withColumn('Count', lit(0.0))
testDF.printSchema()
testDF.show()

root
 |-- ProductName: string (nullable = true)
 |-- ProductIndex: double (nullable = true)
 |-- CustomerId: integer (nullable = false)
 |-- Count: double (nullable = false)

+--------------------+------------+----------+-----+
|         ProductName|ProductIndex|CustomerId|Count|
+--------------------+------------+----------+-----+
|      Romance novels|       102.0|     43124|  0.0|
|       Personal care|       100.0|     43124|  0.0|
|Portable audio & ...|       106.0|     43124|  0.0|
|          Text books|       108.0|     43124|  0.0|
|  Office electronics|       120.0|     43124|  0.0|
|   Smoking cessation|        61.0|     43124|  0.0|
|            Perfumes|         0.0|     43124|  0.0|
|    Unbranded paints|        44.0|     43124|  0.0|
|       Arts & Crafts|        14.0|     43124|  0.0|
|             Tablets|        81.0|     43124|  0.0|
|  Gameroom & Leisure|        27.0|     43124|  0.0|
|      Girls clothing|        35.0|     43124|  0.0|
|             Fiction|        

In [23]:
predictions = model.transform(testDF)
predictions.printSchema()

root
 |-- ProductName: string (nullable = true)
 |-- ProductIndex: double (nullable = true)
 |-- CustomerId: integer (nullable = false)
 |-- Count: double (nullable = false)
 |-- prediction: float (nullable = true)



In [24]:
predictions.orderBy("prediction", ascending=False).limit(5).show()

+--------------------+------------+----------+-----+----------+
|         ProductName|ProductIndex|CustomerId|Count|prediction|
+--------------------+------------+----------+-----+----------+
|            Perfumes|         0.0|     43124|  0.0| 0.5862731|
|Sterling silver j...|       124.0|     43124|  0.0| 0.5718691|
|         board games|       156.0|     43124|  0.0| 0.4935474|
|   Smoking cessation|        61.0|     43124|  0.0| 0.4852053|
|      Wheels & Tires|         6.0|     43124|  0.0| 0.4648381|
+--------------------+------------+----------+-----+----------+



In [25]:
testDF2=productDF.withColumn('CustomerId', lit(99970)).withColumn('Count', lit(0.0))
testDF2.printSchema()
testDF2.show()

root
 |-- ProductName: string (nullable = true)
 |-- ProductIndex: double (nullable = true)
 |-- CustomerId: integer (nullable = false)
 |-- Count: double (nullable = false)

+--------------------+------------+----------+-----+
|         ProductName|ProductIndex|CustomerId|Count|
+--------------------+------------+----------+-----+
|      Romance novels|       102.0|     99970|  0.0|
|       Personal care|       100.0|     99970|  0.0|
|Portable audio & ...|       106.0|     99970|  0.0|
|          Text books|       108.0|     99970|  0.0|
|  Office electronics|       120.0|     99970|  0.0|
|   Smoking cessation|        61.0|     99970|  0.0|
|            Perfumes|         0.0|     99970|  0.0|
|    Unbranded paints|        44.0|     99970|  0.0|
|       Arts & Crafts|        14.0|     99970|  0.0|
|             Tablets|        81.0|     99970|  0.0|
|  Gameroom & Leisure|        27.0|     99970|  0.0|
|      Girls clothing|        35.0|     99970|  0.0|
|             Fiction|        

In [26]:
predictions2 = model.transform(testDF2)
predictions2.printSchema()

root
 |-- ProductName: string (nullable = true)
 |-- ProductIndex: double (nullable = true)
 |-- CustomerId: integer (nullable = false)
 |-- Count: double (nullable = false)
 |-- prediction: float (nullable = true)



In [27]:
predictions2.orderBy("prediction", ascending=False).limit(5).show()

+--------------------+------------+----------+-----+----------+
|         ProductName|ProductIndex|CustomerId|Count|prediction|
+--------------------+------------+----------+-----+----------+
|            Perfumes|         0.0|     99970|  0.0|  0.599639|
| Wearable Technology|        24.0|     99970|  0.0| 0.5077851|
|         Audio books|       153.0|     99970|  0.0|0.44745424|
|Soccer clothing &...|        94.0|     99970|  0.0| 0.4139315|
|      Branded paints|        70.0|     99970|  0.0|  0.412652|
+--------------------+------------+----------+-----+----------+

