In [3]:
from recommender import Recommender

In [4]:
# Load restaurant reviews
reviews_df = spark.read.parquet('../data/ratings_ugt10_igt10')

# Randomly split data into train and test datasets
train_df, test_df = reviews_df.randomSplit(weights=[0.75, 0.25])

print(train_df.printSchema())

root
 |-- user: integer (nullable = true)
 |-- item: integer (nullable = true)
 |-- rating: byte (nullable = true)

None


In [5]:
estimator = Recommender(
    useALS=True,
    useBias=True,
    lambda_1=10,
    lambda_2=15,
    userCol='user',
    itemCol='item',
    ratingCol='rating',
    rank=76,
    regParam=0.7,
    maxIter=10,
    nonnegative=True
)
model = estimator.fit(train_df)

train_predictions_df = model.transform(train_df)
test_predictions_df = model.transform(test_df)

print(test_predictions_df.printSchema())

root
 |-- user: integer (nullable = true)
 |-- item: integer (nullable = true)
 |-- rating: byte (nullable = true)
 |-- prediction: double (nullable = true)

None


In [13]:
test_predictions_df.registerTempTable('test_predictions_df')
df1 = spark.sql(
'''
select
    user,
    item,
    rating,
    prediction,
    row_number() over (
        partition by user
        order by prediction desc
    ) as pred_row_num,
    row_number() over (
        partition by user
        order by rating desc
    ) as actual_row_num
from test_predictions_df
where user = 22
order by pred_row_num
'''
)

df1.show(100)

+----+----+------+------------------+------------+--------------+
|user|item|rating|        prediction|pred_row_num|actual_row_num|
+----+----+------+------------------+------------+--------------+
|  22| 137|     5|3.9933987595712743|           1|             2|
|  22|1025|     5| 3.935163781519597|           2|             1|
|  22|1440|     4| 3.807981914033812|           3|            66|
|  22|  67|     5|3.7998659444128773|           4|            17|
|  22|1233|     5| 3.796777460368327|           5|            12|
|  22|  11|     5|3.7485844194795224|           6|            16|
|  22|  43|     5|3.7399988775074604|           7|             6|
|  22| 383|     5|3.7023914423366104|           8|             9|
|  22| 746|     4| 3.701308989315276|           9|            47|
|  22|2257|     4|3.6941495453100384|          10|            44|
|  22|1093|     4| 3.634463903569177|          11|            25|
|  22| 256|     5|3.6214131907469316|          12|            14|
|  22| 729