In [32]:
import os
from pathlib import Path
import pandas as pd
from surprise import BaselineOnly, Dataset, Reader, accuracy, SVD, KNNBasic
from surprise.dump import dump, load
from surprise.model_selection import train_test_split

import surprise_utils as s_utils

In [45]:
# path to dataset file
file_path = os.path.expanduser("~/.surprise_data/ml-100k/ml-100k/u.data")

# file_path = "recommender.csv"
file_path = Path(r"C:\\Users\doem\projects\dsgapps\test data\recommender\collab filtering Restaurant Recommender.csv")

In [46]:
df = pd.read_csv(file_path, delimiter=';')

# A reader is still needed but only the rating_scale param is requiered.
reader = Reader(rating_scale=(1, 5))

# The columns must correspond to user id, item id and ratings (in that order).
data = Dataset.load_from_df(df[["USER_ID", "OSM_ID", "RATING"]], reader)

In [47]:
data

<surprise.dataset.DatasetAutoFolds at 0x1fdee7c2800>

In [50]:
# sample random trainset and testset
# test set is made of 25% of the ratings.
trainset, testset = train_test_split(data, test_size=0.2)

In [51]:
trainset

<surprise.trainset.Trainset at 0x1fdf0765ed0>

In [52]:
sim_options = {
    "name": "cosine",
    "user_based": False,  # compute  similarities between items
}
algo = KNNBasic(sim_options=sim_options)

# Train the algorithm on the trainset, and predict ratings for the testset
algo.fit(trainset)
predictions = algo.test(testset)
# Then compute RMSE
accuracy.rmse(predictions)

Computing the cosine similarity matrix...
Done computing similarity matrix.
RMSE: 1.1081


1.1081254449193887

In [53]:
predictions

[Prediction(uid='4-Team Rick', iid=8199027204, r_ui=5.0, est=3.0, details={'actual_k': 3, 'was_impossible': False}),
 Prediction(uid='6-Team21', iid=359695274, r_ui=3.0, est=3.3275786826695875, details={'actual_k': 3, 'was_impossible': False}),
 Prediction(uid='6-Kool and the floating point', iid=4095987436, r_ui=3.0, est=3.725303201689979, details={'actual_k': 7, 'was_impossible': False}),
 Prediction(uid='3-Kool and the floating point', iid=247766478, r_ui=2.0, est=3.0000000000000004, details={'actual_k': 3, 'was_impossible': False}),
 Prediction(uid='6-Data in Destiny', iid=9813646886, r_ui=4.0, est=4.6640378824896205, details={'actual_k': 3, 'was_impossible': False}),
 Prediction(uid='U16-Data Bunny', iid=359687037, r_ui=4.0, est=3.745929212426745, details={'actual_k': 4, 'was_impossible': False}),
 Prediction(uid='115-S3J', iid=282655574, r_ui=3.0, est=3.000169801133423, details={'actual_k': 9, 'was_impossible': False}),
 Prediction(uid='8-Team Rick', iid=2795232150, r_ui=3.0, est

In [10]:
algo=SVD()
# Train the algorithm on the trainset, and predict ratings for the testset
algo.fit(trainset)
predictions = algo.test(testset)
# Then compute RMSE
accuracy.rmse(predictions)

RMSE: 0.9088


0.9087707776624142

In [11]:
# dump('bla', predictions=None, algo=algo, verbose=0)
# grmpf = load('bla')

In [12]:
predictions

[Prediction(uid='U5-Data Bunny', iid=659475608, r_ui=4.0, est=4.004185641660566, details={'was_impossible': False}),
 Prediction(uid='6-Kool and the floating point', iid=564888514, r_ui=4.0, est=3.8048099022841475, details={'was_impossible': False}),
 Prediction(uid='10-Kool and the floating point', iid=7503079102, r_ui=3.0, est=3.5231903624904954, details={'was_impossible': False}),
 Prediction(uid='U16-Data Bunny', iid=43011471, r_ui=4.0, est=3.195964086010014, details={'was_impossible': False}),
 Prediction(uid='U12-Data Bunny', iid=256866280, r_ui=5.0, est=3.8616633685575374, details={'was_impossible': False}),
 Prediction(uid='108-S3J', iid=561504426, r_ui=3.0, est=3.27854248466876, details={'was_impossible': False}),
 Prediction(uid='115-S3J', iid=6223030513, r_ui=5.0, est=3.521205219443351, details={'was_impossible': False}),
 Prediction(uid='12-Kool and the floating point', iid=359687037, r_ui=3.0, est=4.219129540768112, details={'was_impossible': False}),
 Prediction(uid='9-Te

In [13]:
# predictions = algo.test(testset)
precisions, recalls = s_utils.precision_recall_at_k(predictions, k=5, threshold=4)

In [14]:
len(precisions)

86

In [15]:
print(sum(prec for prec in precisions.values()) / len(precisions))
print(sum(rec for rec in recalls.values()) / len(recalls))

0.2558139534883721
0.22093023255813954


In [30]:
# testset = trainset.build_anti_testset()
predictions = algo.test(testset)

top_n = s_utils.get_top_n(predictions, n=5)

# Print the recommended items for each user
for uid, user_ratings in top_n.items():
    print(uid, [iid for (iid, _) in user_ratings])

U5-Data Bunny [359694784, 256866280, 659475608]
6-Kool and the floating point [564888514, 277746978, 4095400151, 359693315]
10-Kool and the floating point [7503079102, 4816412579]
U16-Data Bunny [359694784, 472774192, 43011471]
U12-Data Bunny [256866280, 246253651]
108-S3J [731833436, 359686907, 561504426]
115-S3J [6223030513]
12-Kool and the floating point [359687037, 2466736061, 7503079102, 135915056]
9-Team Rick [359686907, 1519474203]
3-Team21 [359695274, 1519474203, 4127859153, 703057903]
U15-Data Bunny [9708663946, 282655574, 43011471]
13-Team Rick [76852026291]
U4-Data Bunny [7439496294, 574685031, 353186982]
18-Data in Destiny [380094845, 2466736061, 405400186]
12-Team Rick [1479407508, 7503079102, 282655574]
107-S3J [4095987422, 854335396]
19-Team Rick [9581108183]
14-Team Rick [76852026291, 359686907, 471628000, 353186982]
U7-Data Bunny [564888543, 4095400186]
U9-Data Bunny [270908998, 380488748]
2-Team21 [282655581, 703057903]
20-Data in Destiny [380094845]
11-Team21 [282655

In [31]:
top_n

defaultdict(list,
            {'U5-Data Bunny': [(359694784, 4.213136498013448),
              (256866280, 4.140649067144969),
              (659475608, 4.004185641660566)],
             '6-Kool and the floating point': [(564888514, 3.8048099022841475),
              (277746978, 3.6889426182556653),
              (4095400151, 3.455377613648257),
              (359693315, 2.8108507714245956)],
             '10-Kool and the floating point': [(7503079102,
               3.5231903624904954),
              (4816412579, 3.5197522955900262)],
             'U16-Data Bunny': [(359694784, 3.7982552072420317),
              (472774192, 3.7770414447711866),
              (43011471, 3.195964086010014)],
             'U12-Data Bunny': [(256866280, 3.8616633685575374),
              (246253651, 3.769639572017036)],
             '108-S3J': [(731833436, 3.8611087603574545),
              (359686907, 3.6850521744322804),
              (561504426, 3.27854248466876)],
             '115-S3J': [(6223030513,

In [None]:
new user
retrain
get the top n recommendations for those neighbours

In [17]:
df

Unnamed: 0,USER_ID,OSM_ID,RATING
0,U1-Data Bunny,353186982,4
1,U1-Data Bunny,8128785390,2
2,U1-Data Bunny,7439496294,1
3,U1-Data Bunny,7439496294,4
4,U2-Data Bunny,359687975,4
...,...,...,...
876,17-Team21,4486137546,4
877,19-Team21,359695274,3
878,19-Team21,8199027204,3
879,20-Team21,359695274,4


In [20]:
df_new_user = pd.DataFrame({'USER_ID': ['new_user', 'new_user', 'new_user'],
                            'OSM_ID': [76852026291, 359686907, 471628000], 
                            'RATING': [5, 5, 5]
                           }
                          )

# df_total = df.append(df_new_user, ignore_index=True)
df_tmp = pd.concat([df, df_new_user], ignore_index = True)
# df_total.reset_index()

In [21]:
df_total

Unnamed: 0,USER_ID,OSM_ID,RATING
0,U1-Data Bunny,353186982,4
1,U1-Data Bunny,8128785390,2
2,U1-Data Bunny,7439496294,1
3,U1-Data Bunny,7439496294,4
4,U2-Data Bunny,359687975,4
...,...,...,...
879,20-Team21,359695274,4
880,20-Team21,8199027204,4
881,new_user,76852026291,5
882,new_user,359686907,5


In [22]:
data_tmp = Dataset.load_from_df(df_total[["USER_ID", "OSM_ID", "RATING"]], reader)

In [28]:
algo_tmp = SVD()
trainset_tmp = data_tmp.build_full_trainset()
algo_tmp.fit(trainset_tmp)

#predictions_new_user = algo_tmp.predict(uid, iid, r_ui=None, clip=True, verbose=False)

# Then predict ratings for all pairs (u, i) that are NOT in the training set.
anti_testset_tmp = trainset_tmp.build_anti_testset()
predictions_tmp = algo_tmp.test(anti_testset_tmp)

top_n_tmp = s_utils.get_top_n(predictions_tmp, n=5)

In [29]:
top_n_tmp['new_user']

[(6223030513, 4.736568163449599),
 (359687975, 4.546574060498998),
 (4095987419, 4.514561857823522),
 (4434577593, 4.513280053757089),
 (359694784, 4.50174913545392)]

In [59]:
bla = []
for u, item_ratings in top_n_tmp.items():
    for item, rating in item_ratings:
        bla.append([u, item, rating])

In [62]:
pd.DataFrame(bla, columns=['USER_ID', 'ITEM_ID', 'RATING'])

Unnamed: 0,USER_ID,ITEM_ID,RATING
0,U1-Data Bunny,6223030513,4.210050
1,U1-Data Bunny,4434577593,4.125262
2,U1-Data Bunny,359686638,3.971164
3,U1-Data Bunny,359687975,3.950164
4,U1-Data Bunny,4095400186,3.900128
...,...,...,...
555,new_user,6223030513,4.736568
556,new_user,359687975,4.546574
557,new_user,4095987419,4.514562
558,new_user,4434577593,4.513280


In [55]:
pd.DataFrame(top_n_tmp)

Unnamed: 0,U1-Data Bunny,U2-Data Bunny,U3-Data Bunny,U4-Data Bunny,U5-Data Bunny,U6-Data Bunny,U7-Data Bunny,U8-Data Bunny,U9-Data Bunny,U10-Data Bunny,...,11-Team21,12-Team21,13-Team21,14-Team21,15-Team21,16-Team21,17-Team21,19-Team21,20-Team21,new_user
0,"(6223030513, 4.210049705643355)","(256866280, 4.141171734976134)","(6223030513, 4.520904481669662)","(6223030513, 4.677453475894925)","(6223030513, 4.484484076464255)","(6223030513, 4.500488825624482)","(256866280, 4.806635137535432)","(6223030513, 4.3792065709660495)","(4434577593, 4.805647022615433)","(6223030513, 4.042237773744769)",...,"(256866280, 4.240259894909435)","(6223030513, 4.346625710479949)","(256866280, 4.327422635259509)","(256866280, 4.431241847868084)","(6223030513, 4.359824402043575)","(6223030513, 4.416068959414806)","(4095987419, 4.30252705014796)","(256866280, 4.213698527544506)","(6223030513, 4.224983601604321)","(6223030513, 4.736568163449599)"
1,"(4434577593, 4.125262229132057)","(282655581, 3.9876514070273363)","(4095987419, 4.315206518886112)","(256866280, 4.631672091927025)","(4434577593, 4.427189430732434)","(4434577593, 4.44017454521395)","(4434577593, 4.806123101760217)","(4095987422, 4.2841665592203135)","(359686638, 4.773930972765266)","(4434577593, 4.002780506484053)",...,"(4434577593, 4.166472631702044)","(256866280, 4.037364326612054)","(6223030513, 4.295954141051041)","(9655661535, 4.328161777159039)","(4095987419, 4.339937013904146)","(256866280, 4.212455940284786)","(256866280, 4.2733838183980035)","(9655661535, 4.072661191254685)","(4095987419, 4.207457169139077)","(359687975, 4.546574060498998)"
2,"(359686638, 3.971164141612961)","(6223030513, 3.9724082097545144)","(359695274, 4.301687472883813)","(4095987424, 4.468860155636384)","(359687975, 4.390838602362455)","(564888514, 4.364418774710649)","(4095987419, 4.773580812495828)","(498891528, 4.214019601857021)","(4095987419, 4.695684337310117)","(359695274, 3.944201416716761)",...,"(380488748, 4.151063109730835)","(9655661535, 4.010784858853815)","(282655581, 4.265025647947657)","(6223030513, 4.269403411296697)","(282655581, 4.337419981964846)","(4095987419, 4.211640278371585)","(4095400186, 4.208746003460079)","(4434577593, 4.053277027932539)","(534878149, 4.195665667081286)","(4095987419, 4.514561857823522)"
3,"(359687975, 3.9501643604386665)","(4095987419, 3.8656477212810105)","(560157037, 4.293099806268862)","(4095987419, 4.423433118245608)","(9655661535, 4.324742137966524)","(256866280, 4.324386852073746)","(534878149, 4.768654740105913)","(359694784, 4.213022912683253)","(6223030513, 4.683626881551704)","(359686638, 3.9210725720743786)",...,"(6223030513, 4.145040274833909)","(498891528, 3.98707953156145)","(4095987419, 4.157591497874724)","(564888543, 4.247160898162837)","(256866280, 4.31362134813561)","(9655661535, 4.17487238073402)","(6223030513, 4.20530507272088)","(4095987419, 4.030657242202946)","(256866280, 4.18480816509312)","(4434577593, 4.513280053757089)"
4,"(4095400186, 3.9001282083842854)","(369825427, 3.8482031461881423)","(359687975, 4.264703729733945)","(4434577593, 4.39364948924547)","(4095987419, 4.3177272809284855)","(498891528, 4.317552437513661)","(6223030513, 4.754182212968831)","(359687975, 4.204955767080571)","(256866280, 4.682496431952673)","(9813646886, 3.9066480127694336)",...,"(9655661535, 4.144024156867603)","(4095987419, 3.963019961555809)","(380488748, 4.148894422837941)","(498891528, 4.2447745264065855)","(359694784, 4.266381323042245)","(4434577593, 4.1664278669807855)","(4434577593, 4.147449610384641)","(498891528, 4.005171880025994)","(4095987422, 4.182984300902256)","(359694784, 4.50174913545392)"
