In [1]:
import os
import sys
from collections import Counter
import numpy as np
import pandas as pd
from pyspark.ml.recommendation import ALS, ALSModel
import pyspark as ps

# Create the Spark session for this notebook (or attach to one that is
# already running) under the application name "recommend".
spark = ps.sql.SparkSession.builder.appName("recommend").getOrCreate()

# Keep a handle on the session's underlying SparkContext.
sc = spark.sparkContext

In [8]:
## Load the movies table and rename its id column to snake_case.
data_dir = "ml-latest-small"
movies_df = (
    pd.read_csv(os.path.join(data_dir, "movies.csv"))
    .rename(columns={"movieId": "movie_id"})
)
# top_movies = pd.read_csv("top-movies.csv")

In [9]:
## Restore the previously saved ALS model from its directory on disk.
save_dir = "saved-recommender"
model = ALSModel.load(path=save_dir)

In [16]:
## Sample movie ratings, one (movie_id, title, rating) tuple per film.
## The films are taken from the top ten.
new_ratings = [
    (1210, "Episode VI - Return of the Jedi (1983)", 5),
    (179819, "Star Wars: The Last Jedi (2017)", 5),
    (187595, "Solo: A Star Wars Story (2018)", 4),
    (122886, "Star Wars: Episode VII - The Force Awakens (2015)", 3),
    (1196, "Star Wars: Episode V - The Empire Strikes Back (1980)", 5),
    (2628, "Star Wars: Episode I - The Phantom Menace (1999)", 5),
    (260, "Star Wars: Episode IV - A New Hope (1977)", 5),
]


In [15]:
# Echo the ratings list to check its current contents.
new_ratings

[(1210, 'Episode VI - Return of the Jedi (1983)', 5),
 (179819, 'Star Wars: The Last Jedi (2017)', 5),
 (187595, 'Solo: A Star Wars Story (2018)', 4),
 (122886, 'Star Wars: Episode VII - The Force Awakens (2015)', 3),
 (1196, 'Star Wars: Episode V - The Empire Strikes Back (1980)', 5),
 (2628, 'Star Wars: Episode I - The Phantom Menace (1999)', 5),
 (260, 'Star Wars: Episode IV - A New Hope (1977)', 5)]

In [17]:
# Order the ratings from highest to lowest score. Entries with equal scores
# end up in the reverse of their previous relative order, so re-running this
# cell flips the ties again.
new_ratings = sorted(new_ratings[::-1], key=lambda entry: entry[2], reverse=True)
new_ratings

[(260, 'Star Wars: Episode IV - A New Hope (1977)', 5),
 (2628, 'Star Wars: Episode I - The Phantom Menace (1999)', 5),
 (1196, 'Star Wars: Episode V - The Empire Strikes Back (1980)', 5),
 (179819, 'Star Wars: The Last Jedi (2017)', 5),
 (1210, 'Episode VI - Return of the Jedi (1983)', 5),
 (187595, 'Solo: A Star Wars Story (2018)', 4),
 (122886, 'Star Wars: Episode VII - The Force Awakens (2015)', 3)]

In [18]:
# Wrap every rated movie id in a 1-tuple, keeping the order of new_ratings,
# so the list can serve as single-column rows.
best_rated = [(movie_id,) for movie_id, *_ in new_ratings]
print('best rated', best_rated)

best rated [(260,), (2628,), (1196,), (179819,), (1210,), (187595,), (122886,)]


In [19]:
# The rated movie ids as a NumPy array, in the same order as new_ratings.
rated_movies = np.array([movie_id for movie_id, *_ in new_ratings])
rated_movies

array([   260,   2628,   1196, 179819,   1210, 187595, 122886])

In [24]:
## Build a one-column Spark DataFrame ("movieId") from the rated movie ids;
## it is the item subset used to query the model for the closest users.
query1 = spark.createDataFrame(best_rated, schema=["movieId"])
query1.show()

+-------+
|movieId|
+-------+
|    260|
|   2628|
|   1196|
| 179819|
|   1210|
| 187595|
| 122886|
+-------+



In [25]:
# Ask the model for 100 recommendations per movie in query1 and pull the
# result into a pandas DataFrame.
user_recs = model.recommendForItemSubset(query1, 100).toPandas()

In [28]:
# Show the recommendations frame: one row per queried movieId, with its list
# of recommendations.
user_recs

Unnamed: 0,movieId,recommendations
0,1210,"[(543, 5.272472381591797), (53, 5.205809116363..."
1,2628,"[(452, 4.151139259338379), (543, 4.06827688217..."
2,179819,"[(406, 4.228682518005371), (224, 4.20913028717..."
3,122886,"[(494, 5.200282573699951), (310, 5.09558153152..."
4,1196,"[(543, 5.298564434051514), (276, 5.25375032424..."
5,260,"[(543, 5.494180202484131), (337, 5.41877126693..."
6,187595,"[(360, 4.236614227294922), (452, 4.10780477523..."


In [29]:
# Sum, per recommended id, the scores it received across all queried movies.
urecs = Counter()

for movie_recs in user_recs['recommendations']:
    for row in movie_recs:
        # The first two fields of each recommendation are the id and its score.
        user_id, score = list(row.asDict().values())[:2]
        urecs[user_id] += score

# Ids ranked by summed score (highest first), each wrapped in a 1-tuple.
closest_users = [(user_id,) for user_id, _total in urecs.most_common()]
print("closest_users\n",closest_users)

closest_users
 [(452,), (523,), (475,), (236,), (93,), (519,), (515,), (122,), (25,), (169,), (48,), (250,), (491,), (371,), (251,), (348,), (337,), (1,), (267,), (53,), (273,), (77,), (69,), (594,), (95,), (319,), (246,), (12,), (380,), (597,), (472,), (97,), (450,), (586,), (336,), (543,), (540,), (595,), (389,), (382,), (527,), (186,), (340,), (533,), (453,), (400,), (201,), (494,), (220,), (52,), (80,), (45,), (276,), (171,), (413,), (554,), (375,), (37,), (574,), (579,), (106,), (441,), (99,), (44,), (548,), (435,), (505,), (377,), (534,), (409,), (355,), (573,), (408,), (550,), (417,), (210,), (43,), (447,), (40,), (284,), (456,), (486,), (585,), (79,), (147,), (367,), (538,), (31,), (465,), (30,), (275,), (234,), (458,), (360,), (451,), (362,), (310,), (58,), (244,), (269,), (115,), (358,), (155,), (278,), (556,), (324,), (98,), (35,), (303,), (545,), (13,), (176,), (238,), (62,), (410,), (376,), (419,), (96,), (544,), (492,), (20,), (302,), (32,), (562,), (66,), (572,), (406,),

In [38]:
# List every (id, summed score) pair in urecs, highest total first.
urecs.most_common()

[(452, 31.6476788520813),
 (523, 30.947514057159424),
 (475, 30.940001964569092),
 (236, 30.932044744491577),
 (93, 30.824015855789185),
 (519, 30.739495515823364),
 (515, 30.276190757751465),
 (122, 30.244189500808716),
 (25, 29.931838274002075),
 (169, 29.89694905281067),
 (48, 29.8144633769989),
 (250, 29.70211386680603),
 (491, 29.556551933288574),
 (371, 29.5282199382782),
 (251, 29.317179918289185),
 (348, 29.13367009162903),
 (337, 27.514169216156006),
 (1, 27.51026940345764),
 (267, 27.278703212738037),
 (53, 26.732878923416138),
 (273, 26.439852237701416),
 (77, 26.31610918045044),
 (69, 26.312480449676514),
 (594, 26.10871148109436),
 (95, 25.766714811325073),
 (319, 25.71707057952881),
 (246, 25.666338682174683),
 (12, 25.594131469726562),
 (380, 25.453202724456787),
 (597, 25.322821855545044),
 (472, 25.18815517425537),
 (97, 25.10091805458069),
 (450, 25.097493886947632),
 (586, 25.05040454864502),
 (336, 24.259069681167603),
 (543, 24.045047283172607),
 (540, 24.021739244