In [1]:
from recommender import Recommender

In [2]:
# Load restaurant reviews
reviews_df = spark.read.parquet('../data/ratings_ugt10_igt10')

# Randomly split data into train and test datasets
train_df, test_df = reviews_df.randomSplit(weights=[0.75, 0.25])

print(train_df.printSchema())

root
 |-- user: integer (nullable = true)
 |-- item: integer (nullable = true)
 |-- rating: byte (nullable = true)

None


In [10]:
estimator = Recommender(
    useALS=True,
    useBias=True,
    lambda_1=7,
    lambda_2=12,
    userCol='user',
    itemCol='item',
    ratingCol='rating',
    rank=100,
    regParam=0.7,
    maxIter=15,
    nonnegative=True
)
model = estimator.fit(train_df)

train_predictions_df = model.transform(train_df)
test_predictions_df = model.transform(test_df)

print(test_predictions_df.printSchema())

root
 |-- user: integer (nullable = true)
 |-- item: integer (nullable = true)
 |-- rating: byte (nullable = true)
 |-- prediction: double (nullable = true)

None


In [12]:
test_predictions_df.registerTempTable('test_predictions_df')
df1 = spark.sql(
'''
select
    user,
    item,
    rating,
    prediction,
    row_number() over (
        partition by user
        order by prediction desc
    ) as pred_row_num,
    row_number() over (
        partition by user
        order by rating desc
    ) as actual_row_num
from test_predictions_df
where user = 100
order by pred_row_num
'''
)

df1.show(100)

+----+----+------+------------------+------------+--------------+
|user|item|rating|        prediction|pred_row_num|actual_row_num|
+----+----+------+------------------+------------+--------------+
| 100|2397|     5|3.8271621666996065|           1|             1|
| 100|1256|     5| 3.775430929157153|           2|             2|
| 100|2433|     4|3.7721955593739835|           3|             4|
| 100|3560|     5|3.7430339837811477|           4|             3|
| 100|4110|     4|3.7029026405044814|           5|             5|
| 100| 796|     4| 3.640831027679008|           6|             6|
| 100|2505|     4|3.5801973842770565|           7|             7|
| 100|3871|     4|3.5722447786664784|           8|             8|
| 100|1314|     4| 3.551952102842634|           9|             9|
| 100| 603|     4|3.4440679077903127|          10|            10|
| 100|3899|     4| 3.386013989135117|          11|            11|
| 100|1806|     4| 3.383053083690733|          12|            12|
| 100| 770