In [1]:
from pyspark.sql import SparkSession, Row
from pyspark.sql.utils import AnalysisException
from pyspark.sql.functions import split, col, sum
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, StringType
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.ml.feature import Binarizer, Normalizer, VectorAssembler
from pyspark.ml.linalg import DenseVector

spark = SparkSession.builder.appName('ALSExample').getOrCreate()

ratingSchema = StructType(( StructField('user', IntegerType(), False),
                            StructField('item', IntegerType(), False),
                            StructField('rating', FloatType(), False) ))

userSchema = StructType(( StructField('user', IntegerType(), False),
                          StructField('gender', IntegerType(), False),
                          StructField('age', IntegerType(), False) ))

itemSchema = StructType(( StructField('item', IntegerType(), False),
                          StructField('title', StringType(), False) ))

In [63]:
# Read data
ratings = spark.read.csv('ml-1m/ratings_timesorted.dat', ratingSchema, '::')
users = spark.read.csv('ml-1m/users_binarygender.dat', userSchema, '::')
items = spark.read.csv('ml-1m/movies.dat', itemSchema, '::')

#training, test = ratings.randomSplit((0.95, 0.05))

# Build the recommendation model using ALS on the training data
# Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics
als = ALS(rank=8, maxIter=20, regParam=0.005, nonnegative=False, numUserBlocks=4, numItemBlocks=4)

In [64]:
#model = als.fit(training)
model = als.fit(ratings)

In [65]:
# Evaluate the model by computing the RMSE on the test data
#predictions = model.transform(test)
predictions = model.transform(ratings)
evaluator = RegressionEvaluator(labelCol='rating')
rmse = evaluator.evaluate(predictions)
print("Root-mean-square error =", rmse)

Root-mean-square error = 0.7762990224592068


In [49]:
predictions.show()

+----+----+------+-------------+
|user|item|rating|   prediction|
+----+----+------+-------------+
|  53| 148|   2.0|     1.189728|
| 673| 148|   2.0|     1.184251|
|4169| 148|   0.0| -0.009268612|
|4227| 148|  -1.0|    -2.607967|
|5333| 148|   0.0|  -0.17176326|
|3184| 148|   1.0|  0.033986837|
|4387| 148|  -2.0|  -0.64815944|
|4784| 148|   0.0|    0.6079085|
|2383| 148|  -1.0|  -0.17862234|
|1242| 148|   0.0| -0.022532284|
|3539| 148|   0.0|   0.41909608|
|1069| 148|  -1.0|   0.11546504|
|1605| 148|  -1.0|   -0.3939552|
| 840| 148|  -2.0|   -0.8237774|
| 216| 148|  -1.0|  -0.36422893|
| 482| 148|  -1.0|-0.0028511211|
| 752| 148|   1.0|    0.3757552|
|1150| 148|  -1.0|  -0.23748167|
|3829| 148|  -1.0|   -0.5319092|
| 424| 148|   1.0|   0.04847695|
+----+----+------+-------------+
only showing top 20 rows



In [50]:
print(model.userFactors.count())
assert model.userFactors.count() == users.count() # true
fusers = model.userFactors.selectExpr('*', # id: int, features: array<float>
    'transform(features, x->float(x/sqrt(aggregate(transform(features, x->x*x), float(0), (s,x)->s+x)))) as norm_features',
    'float(sqrt(aggregate(transform(features, x->x*x), float(0), (s,x)->s+x))) as norm',
    'transform(features, x->int(x>0)) as bin_arr',
    'aggregate(transform(features, x->int(x>0)), 0, (a,x)->shiftleft(a,1)+x) as partition')

fusers.groupby('partition').count().sort('count').show()
fusers.groupby('partition').count().sort('count', ascending=False).show()

6040
+---------+-----+
|partition|count|
+---------+-----+
|       84|    1|
|      114|    1|
|      238|    1|
|      176|    1|
|      173|    1|
|      165|    1|
|       41|    1|
|       37|    1|
|      231|    1|
|      168|    1|
|      220|    1|
|      227|    1|
|       34|    1|
|      182|    1|
|       96|    1|
|      117|    1|
|      132|    1|
|       76|    1|
|      246|    1|
|      122|    1|
+---------+-----+
only showing top 20 rows

+---------+-----+
|partition|count|
+---------+-----+
|      157|  620|
|      159|  411|
|      151|  317|
|      215|  262|
|      155|  238|
|       29|  203|
|      149|  202|
|      147|  192|
|      153|  183|
|       31|  173|
|      211|  162|
|      223|  137|
|       27|  119|
|       23|  111|
|      219|  107|
|       83|   84|
|       19|   84|
|       25|   82|
|       91|   81|
|       95|   80|
+---------+-----+
only showing top 20 rows



In [51]:
model.userFactors.show()
model.userFactors.head(3)

+---+--------------------+
| id|            features|
+---+--------------------+
| 10|[0.6651825, -0.93...|
| 20|[0.104887925, -0....|
| 30|[-0.585375, -0.46...|
| 40|[0.047282364, 0.0...|
| 50|[0.7905751, -0.76...|
| 60|[0.53856766, -1.4...|
| 70|[0.4566475, -0.49...|
| 80|[-0.4555481, -1.5...|
| 90|[-0.30143672, 0.0...|
|100|[0.95588344, 0.13...|
|110|[0.41917062, -0.5...|
|120|[-1.1208363, 0.85...|
|130|[-0.24234508, 0.1...|
|140|[1.7980086, 0.123...|
|150|[0.3380763, -0.95...|
|160|[0.86701286, -1.0...|
|170|[0.38731375, -0.4...|
|180|[-0.52567625, -0....|
|190|[0.31519017, -0.3...|
|200|[-0.5407522, -0.6...|
+---+--------------------+
only showing top 20 rows



[Row(id=10, features=[0.6651824712753296, -0.9377222657203674, -0.37672102451324463, 0.588722825050354, 1.2778269052505493, 0.5031376481056213, -0.7216995358467102, 1.384491205215454]),
 Row(id=20, features=[0.10488792508840561, -0.204710453748703, -1.5924307107925415, 0.5988529920578003, 0.7721832990646362, 0.3685716390609741, -0.5720065832138062, 0.9694386720657349]),
 Row(id=30, features=[-0.5853750109672546, -0.46003457903862, -0.34607887268066406, 0.04201498255133629, 0.6973044276237488, -0.04515616223216057, 0.2587907314300537, 0.508955180644989])]

In [52]:
# Generate top 10 movie recommendations for each user
userRecs = model.recommendForAllUsers(10)
# Generate top 10 user recommendations for each movie
movieRecs = model.recommendForAllItems(10)

# Generate top 10 movie recommendations for a specified set of users
users = ratings.select(als.getUserCol()).distinct().limit(3)
userSubsetRecs = model.recommendForUserSubset(users, 10)
# Generate top 10 user recommendations for a specified set of movies
movies = ratings.select(als.getItemCol()).distinct().limit(3)
movieSubSetRecs = model.recommendForItemSubset(movies, 10)

In [53]:
userRecs.show()

+----+--------------------+
|user|     recommendations|
+----+--------------------+
|1580|[[2705, 4.6829414...|
|4900|[[1567, 13.155405...|
|5300|[[3012, 4.175614]...|
| 471|[[106, 4.116742],...|
|1591|[[2964, 5.1865582...|
|4101|[[1539, 8.265969]...|
|1342|[[614, 3.5433302]...|
|2122|[[2765, 4.835989]...|
|2142|[[1567, 6.1453977...|
| 463|[[632, 4.4248056]...|
| 833|[[2998, 5.3107677...|
|5803|[[3636, 11.271729...|
|3794|[[811, 5.978745],...|
|1645|[[2933, 7.141486]...|
|3175|[[2192, 7.75389],...|
|4935|[[2705, 8.246386]...|
| 496|[[1420, 6.2241726...|
|2366|[[1930, 4.7507305...|
|2866|[[3184, 4.9383516...|
|5156|[[2192, 8.229196]...|
+----+--------------------+
only showing top 20 rows



In [10]:
movieRecs.show()

+----+--------------------+
|item|     recommendations|
+----+--------------------+
|1580|[[283, 2.0554464]...|
| 471|[[5529, 2.0453987...|
|1591|[[283, 1.964103],...|
|1342|[[1445, 2.046189]...|
|2122|[[5670, 2.0586808...|
|2142|[[101, 1.7309802]...|
| 463|[[283, 2.1129339]...|
| 833|[[1070, 1.3529685...|
|3794|[[4751, 2.418365]...|
|1645|[[283, 2.10782], ...|
|3175|[[2037, 2.310983]...|
| 496|[[1445, 2.523903]...|
|2366|[[1535, 2.4748394...|
|2866|[[2339, 1.8196274...|
| 148|[[540, 3.1303573]...|
|1088|[[2364, 2.4093258...|
|1238|[[446, 2.1394424]...|
|3918|[[46, 2.2402115],...|
|1829|[[283, 1.8440673]...|
|1959|[[2560, 1.983925]...|
+----+--------------------+
only showing top 20 rows



In [11]:
userSubsetRecs.show()

+----+--------------------+
|user|     recommendations|
+----+--------------------+
|5300|[[53, 2.377953], ...|
|5803|[[887, 4.9874797]...|
|5518|[[687, 1.9456208]...|
+----+--------------------+



In [12]:
movieSubSetRecs.show()

+----+--------------------+
|item|     recommendations|
+----+--------------------+
| 471|[[5529, 2.0453987...|
|3175|[[2037, 2.310983]...|
|1088|[[2364, 2.4093258...|
+----+--------------------+



In [13]:
spark.stop()