### Libraries Required

In [1]:
import pandas as pd
from pyspark import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql.types import IntegerType, StringType
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS

### Initiate Spark Session

In [2]:
# Reuse the active Spark session when one exists, otherwise start a new one.
# SparkContext.getOrCreate() never returns None, so the former
# `if sc is None` fallback (master="local[2]", appName=...) could never run
# and the application name was silently ignored. The builder applies the name
# for a newly created session; the master is left at the environment default,
# which is what the previous code effectively used.
spark = (
    SparkSession.builder
    .appName("Meal Recipe Collaborative Filtering")
    .getOrCreate()
)

# Keep `sc` available for any cell that uses the SparkContext directly.
sc = spark.sparkContext

22/10/06 18:39:12 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/10/06 18:39:13 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
# Load the wide rating table (one row per recipe; besides 'Title', the
# remaining columns are melted into User_ID / Rating pairs further down).
# `sep` must be passed by keyword: positional arguments after the file path
# are deprecated in pandas 1.x and rejected by pandas 2.x.
test_user_data = pd.read_csv('user_rating_test_data.csv', sep=',')

# Use the row position as the numeric recipe identifier.
test_user_data['Recipe_Index'] = test_user_data.index

# The recipe title is not carried into the model input.
test_user_data = test_user_data.drop(columns='Title')

  exec(code_obj, self.user_global_ns, self.user_ns)


### Input Format

The input should be a long-format table with the columns User_ID, Recipe_Index and Rating (the individual's rating for that recipe), as shown below.

In [4]:
# Reshape wide -> long: one row per (user, recipe) pair, keeping only the
# pairs that actually carry a rating.
test_user_data = (
    test_user_data
    .melt(id_vars='Recipe_Index', var_name='User_ID', value_name='Rating')
    .loc[:, ['User_ID', 'Recipe_Index', 'Rating']]
    .dropna()
)

test_user_data.head(5)

Unnamed: 0,User_ID,Recipe_Index,Rating
4,0,4,4.0
6,0,6,5.0
8,0,8,4.0
9,0,9,5.0
22,0,22,1.0


In [5]:
# Hand the long-format pandas ratings over to Spark for training.
test_user_df = spark.createDataFrame(test_user_data)

# Cast User_ID to an integer column (it is used as the ALS userCol below).
test_user_df = test_user_df.withColumn(
    'User_ID',
    test_user_df.User_ID.cast(IntegerType()),
)

### Model Training

In [6]:
# Fixed seed so the split and the ALS factor initialisation are repeatable
# across kernel restarts (both are random when no seed is given).
SEED = 42

# train / test split (80% / 20%)
train, test = test_user_df.randomSplit([0.8, 0.2], seed=SEED)

# define ALS model hyperparameters
# coldStartStrategy="drop" removes rows whose prediction would be NaN
# (users/recipes unseen during training) so evaluation metrics stay finite.
als = ALS(maxIter=4, regParam=0.1, userCol="User_ID", itemCol="Recipe_Index", ratingCol="Rating",
          coldStartStrategy="drop", seed=SEED)
model = als.fit(train)

22/10/06 18:40:02 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
22/10/06 18:40:02 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
22/10/06 18:40:02 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeSystemLAPACK
22/10/06 18:40:02 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeRefLAPACK


#### Model Performance

In [7]:
# Apply the model to the held-out split and report RMSE.
# This cell was fully commented out although its output is shown and
# discussed below; it is restored so the reported number is reproducible.
predictions = model.transform(test)

# Named `evaluator` rather than `eval` to avoid shadowing the Python builtin.
evaluator = RegressionEvaluator(metricName="rmse", labelCol="Rating", predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
print("Root-mean-square error = " + str(rmse))



Root-mean-square error = 3.9565706811683357


                                                                                

RMSE is quite poor, suggesting that more data is required to provide more reliable predictions<br>
Ideally we should have below 1.0 RMSE

### Extracting Features (Pandas)

In [7]:
# Ask the fitted model for the 50 highest-scoring recipes per user.
userRecs = model.recommendForAllUsers(numItems=50)

In [8]:
# Show full cell contents so the long recommendation lists are not truncated.
pd.set_option('display.max_colwidth', None)

# Collect the Spark recommendations into pandas and store User_ID as strings.
user_predictions = userRecs.toPandas().astype({'User_ID': str})

                                                                                

In [9]:
# Preview the first ten users and their recommendation lists.
user_predictions.head(n=10)

Unnamed: 0,User_ID,recommendations
0,12,"[(8, 5.376016139984131), (358, 5.023803234100342), (43, 5.023803234100342), (71, 5.023803234100342), (374, 5.023803234100342), (52, 5.023803234100342), (279, 4.97892951965332), (318, 4.954041957855225), (177, 4.8562774658203125), (219, 4.785390377044678), (384, 4.771224021911621), (229, 4.678145885467529), (134, 4.36651086807251), (133, 4.01904296875), (183, 4.01904296875), (368, 4.01904296875), (2, 4.01904296875), (264, 4.01904296875), (13, 4.01904296875), (315, 4.01904296875), (95, 4.01904296875), (29, 4.01904296875), (228, 3.9831433296203613), (435, 3.849364995956421), (319, 3.8318212032318115), (359, 3.8318212032318115), (514, 3.72397518157959), (537, 3.6805195808410645), (545, 3.6805195808410645), (351, 3.6805195808410645), (61, 3.529479503631592), (9, 3.5080208778381348), (189, 3.4907331466674805), (10, 3.4907331466674805), (207, 3.4907331466674805), (256, 3.4907331466674805), (76, 3.4907331466674805), (273, 3.4907331466674805), (320, 3.4907331466674805), (136, 3.4907331466674805), (80, 3.4907331466674805), (36, 3.4907331466674805), (21, 3.4807207584381104), (99, 3.1906204223632812), (424, 3.119018077850342), (597, 3.0619282722473145), (55, 2.9766366481781006), (532, 2.94441556930542), (379, 2.94441556930542), (244, 2.9198625087738037)]"
1,1,"[(234, 4.946934700012207), (54, 4.937650680541992), (123, 4.896336555480957), (170, 4.896336555480957), (348, 4.896336555480957), (248, 4.0796685218811035), (167, 4.07916259765625), (152, 3.9875729084014893), (450, 3.9501209259033203), (555, 3.707733154296875), (349, 3.645171642303467), (520, 3.5876309871673584), (378, 3.422382116317749), (6, 3.422382116317749), (517, 3.422382116317749), (258, 3.422382116317749), (423, 3.422382116317749), (484, 3.422382116317749), (480, 3.422382116317749), (561, 3.422382116317749), (116, 3.422382116317749), (448, 3.422382116317749), (276, 3.422382116317749), (206, 3.422382116317749), (291, 3.422382116317749), (403, 3.422382116317749), (559, 3.422382116317749), (215, 3.422382116317749), (298, 3.1325600147247314), (134, 3.0961716175079346), (153, 2.999960422515869), (601, 2.820251941680908), (350, 2.7947275638580322), (296, 2.7947275638580322), (124, 2.7947275638580322), (419, 2.7947275638580322), (539, 2.737905979156494), (562, 2.737905979156494), (375, 2.737905979156494), (109, 2.737905979156494), (201, 2.737905979156494), (401, 2.737905979156494), (364, 2.5513076782226562), (55, 2.5405144691467285), (456, 2.51198673248291), (9, 2.435243844985962), (137, 2.404531955718994), (369, 2.394360065460205), (446, 2.2357821464538574), (278, 2.2357821464538574)]"
2,13,"[(8, 4.96788215637207), (537, 4.940181255340576), (545, 4.940181255340576), (351, 4.940181255340576), (514, 4.719198703765869), (384, 4.406377792358398), (379, 3.9521446228027344), (532, 3.9521446228027344), (43, 3.8388102054595947), (52, 3.8388102054595947), (358, 3.8388102054595947), (71, 3.8388102054595947), (374, 3.8388102054595947), (279, 3.740809917449951), (229, 3.7381157875061035), (318, 3.5370681285858154), (219, 3.511699914932251), (177, 3.373915910720825), (315, 3.0710482597351074), (264, 3.0710482597351074), (183, 3.0710482597351074), (133, 3.0710482597351074), (13, 3.0710482597351074), (368, 3.0710482597351074), (2, 3.0710482597351074), (95, 3.0710482597351074), (29, 3.0710482597351074), (228, 2.992648124694824), (99, 2.869570732116699), (319, 2.7320306301116943), (359, 2.7320306301116943), (577, 2.59279465675354), (136, 2.565070152282715), (36, 2.565070152282715), (207, 2.565070152282715), (80, 2.565070152282715), (10, 2.565070152282715), (320, 2.565070152282715), (256, 2.565070152282715), (189, 2.565070152282715), (76, 2.565070152282715), (273, 2.565070152282715), (380, 2.5426597595214844), (20, 2.5426597595214844), (14, 2.366821527481079), (435, 2.345383405685425), (172, 2.052056312561035), (237, 2.052056312561035), (78, 2.052056312561035), (250, 2.052056312561035)]"
3,6,"[(280, 5.049895286560059), (74, 5.049895286560059), (156, 5.049895286560059), (128, 5.049895286560059), (32, 4.801599979400635), (318, 4.014624118804932), (424, 3.9623234272003174), (564, 3.7898190021514893), (455, 3.7151968479156494), (120, 3.320910692214966), (99, 3.163013458251953), (384, 3.049278736114502), (410, 2.8927135467529297), (31, 2.799010753631592), (577, 2.6932201385498047), (55, 2.652770519256592), (9, 2.566408157348633), (20, 2.524834632873535), (380, 2.524834632873535), (177, 2.4735302925109863), (14, 2.2978034019470215), (165, 2.248216390609741), (520, 2.2348837852478027), (555, 2.1792964935302734), (71, 2.17818546295166), (43, 2.17818546295166), (358, 2.17818546295166), (52, 2.17818546295166), (374, 2.17818546295166), (27, 2.081308364868164), (433, 2.053126096725464), (219, 2.024628162384033), (513, 2.0198676586151123), (240, 2.0136380195617676), (56, 1.9743144512176514), (452, 1.9743144512176514), (115, 1.9743144512176514), (143, 1.9743144512176514), (400, 1.9743144512176514), (166, 1.9743144512176514), (404, 1.9743144512176514), (70, 1.9743144512176514), (486, 1.965259075164795), (293, 1.886983871459961), (361, 1.886983871459961), (381, 1.886983871459961), (67, 1.886983871459961), (28, 1.886983871459961), (147, 1.886983871459961), (508, 1.886983871459961)]"
4,3,"[(490, 4.994433403015137), (397, 4.994433403015137), (321, 4.994433403015137), (515, 4.994433403015137), (324, 4.994433403015137), (456, 4.841506004333496), (177, 4.70665979385376), (134, 4.44935941696167), (153, 4.269294261932373), (555, 4.225594997406006), (275, 3.99554705619812), (220, 3.99554705619812), (369, 3.981689929962158), (99, 3.8507981300354004), (286, 3.832630157470703), (483, 3.8036949634552), (597, 3.6867713928222656), (601, 3.559250593185425), (192, 3.07696533203125), (365, 3.07696533203125), (53, 3.07696533203125), (385, 3.07696533203125), (429, 3.07696533203125), (40, 3.07696533203125), (470, 3.07696533203125), (571, 3.07696533203125), (253, 3.07696533203125), (149, 3.07696533203125), (419, 3.0428969860076904), (296, 3.0428969860076904), (124, 3.0428969860076904), (350, 3.0428969860076904), (72, 3.0378382205963135), (138, 3.0378382205963135), (55, 3.023186683654785), (152, 2.986955165863037), (9, 2.977414608001709), (298, 2.928626298904419), (120, 2.925765037536621), (248, 2.8345534801483154), (54, 2.8068289756774902), (56, 2.76029109954834), (166, 2.76029109954834), (115, 2.76029109954834), (143, 2.76029109954834), (70, 2.76029109954834), (400, 2.76029109954834), (452, 2.76029109954834), (404, 2.76029109954834), (5, 2.701261043548584)]"
5,5,"[(54, 5.027791500091553), (555, 4.9833478927612305), (601, 4.963939666748047), (419, 4.950215816497803), (350, 4.950215816497803), (296, 4.950215816497803), (124, 4.950215816497803), (456, 4.907982349395752), (248, 4.900073528289795), (152, 4.83144474029541), (153, 4.096888542175293), (234, 4.048573017120361), (450, 4.022233486175537), (349, 4.0196075439453125), (446, 3.960172653198242), (278, 3.960172653198242), (298, 3.91361927986145), (365, 3.377232074737549), (253, 3.377232074737549), (53, 3.377232074737549), (571, 3.377232074737549), (429, 3.377232074737549), (192, 3.377232074737549), (40, 3.377232074737549), (470, 3.377232074737549), (385, 3.377232074737549), (149, 3.377232074737549), (520, 3.371901035308838), (134, 3.1124191284179688), (490, 3.026341438293457), (515, 3.026341438293457), (321, 3.026341438293457), (324, 3.026341438293457), (397, 3.026341438293457), (348, 3.0161991119384766), (123, 3.0161991119384766), (170, 3.0161991119384766), (167, 2.9713733196258545), (137, 2.8041398525238037), (166, 2.7875208854675293), (452, 2.7875208854675293), (56, 2.7875208854675293), (404, 2.7875208854675293), (70, 2.7875208854675293), (400, 2.7875208854675293), (115, 2.7875208854675293), (143, 2.7875208854675293), (517, 2.7859950065612793), (480, 2.7859950065612793), (561, 2.7859950065612793)]"
6,15,"[(520, 4.997949123382568), (116, 4.969391345977783), (276, 4.969391345977783), (480, 4.969391345977783), (403, 4.969391345977783), (258, 4.969391345977783), (484, 4.969391345977783), (561, 4.969391345977783), (378, 4.969391345977783), (206, 4.969391345977783), (291, 4.969391345977783), (517, 4.969391345977783), (423, 4.969391345977783), (6, 4.969391345977783), (215, 4.969391345977783), (448, 4.969391345977783), (559, 4.969391345977783), (364, 4.950148582458496), (46, 4.893669605255127), (349, 4.8473992347717285), (167, 4.838993549346924), (27, 4.808556079864502), (54, 4.290230751037598), (302, 4.210588455200195), (555, 4.071502685546875), (342, 4.049943923950195), (298, 4.010075092315674), (234, 3.9933009147644043), (539, 3.975513458251953), (201, 3.975513458251953), (562, 3.975513458251953), (109, 3.975513458251953), (375, 3.975513458251953), (401, 3.975513458251953), (410, 3.9633982181549072), (134, 3.886234998703003), (348, 3.7890570163726807), (123, 3.7890570163726807), (170, 3.7890570163726807), (248, 3.7209630012512207), (152, 3.6640334129333496), (165, 3.4482901096343994), (450, 3.432185173034668), (153, 3.248350143432617), (361, 3.136890411376953), (147, 3.136890411376953), (381, 3.136890411376953), (544, 3.136890411376953), (67, 3.136890411376953), (508, 3.136890411376953)]"
7,9,"[(279, 4.977869987487793), (76, 4.961957931518555), (10, 4.961957931518555), (320, 4.961957931518555), (136, 4.961957931518555), (273, 4.961957931518555), (207, 4.961957931518555), (256, 4.961957931518555), (189, 4.961957931518555), (80, 4.961957931518555), (36, 4.961957931518555), (61, 4.936647415161133), (21, 4.913447856903076), (8, 4.815358638763428), (244, 4.729251384735107), (134, 4.10179328918457), (72, 4.070106506347656), (138, 4.070106506347656), (228, 3.9822962284088135), (322, 3.969567060470581), (237, 3.969567060470581), (287, 3.969567060470581), (250, 3.969567060470581), (172, 3.969567060470581), (323, 3.969567060470581), (78, 3.969567060470581), (46, 3.92425274848938), (52, 3.3804032802581787), (43, 3.3804032802581787), (358, 3.3804032802581787), (374, 3.3804032802581787), (71, 3.3804032802581787), (229, 3.2641894817352295), (219, 3.2520089149475098), (177, 3.1787161827087402), (435, 2.7218611240386963), (318, 2.71132755279541), (368, 2.7043228149414062), (133, 2.7043228149414062), (315, 2.7043228149414062), (29, 2.7043228149414062), (183, 2.7043228149414062), (2, 2.7043228149414062), (95, 2.7043228149414062), (264, 2.7043228149414062), (13, 2.7043228149414062), (597, 2.703433036804199), (384, 2.664287567138672), (302, 2.6615896224975586), (319, 2.6216495037078857)]"
8,4,"[(168, 4.945924758911133), (163, 4.945924758911133), (0, 4.945924758911133), (97, 4.945924758911133), (83, 4.945924758911133), (77, 4.945924758911133), (132, 4.945924758911133), (162, 4.945924758911133), (174, 4.945924758911133), (113, 4.945924758911133), (158, 4.945924758911133), (145, 4.945924758911133), (161, 4.945924758911133), (55, 4.918060779571533), (34, 4.753960609436035), (18, 4.753960609436035), (119, 3.956739664077759), (212, 3.956739664077759), (118, 3.956739664077759), (282, 3.956739664077759), (160, 3.956739664077759), (173, 3.956739664077759), (85, 3.956739664077759), (153, 3.843278169631958), (9, 2.332942485809326), (555, 2.1516101360321045), (520, 2.101238489151001), (27, 2.0243966579437256), (364, 1.9963338375091553), (102, 1.9783698320388794), (479, 1.9783698320388794), (180, 1.9783698320388794), (200, 1.9783698320388794), (409, 1.9783698320388794), (459, 1.9783698320388794), (451, 1.9783698320388794), (57, 1.967224359512329), (286, 1.9361381530761719), (99, 1.9318532943725586), (345, 1.9095765352249146), (120, 1.8677523136138916), (433, 1.8663538694381714), (234, 1.8268479108810425), (167, 1.7958945035934448), (424, 1.739809274673462), (54, 1.7346208095550537), (170, 1.7268784046173096), (348, 1.7268784046173096), (123, 1.7268784046173096), (31, 1.6185508966445923)]"
9,8,"[(166, 5.011740684509277), (115, 5.011740684509277), (70, 5.011740684509277), (143, 5.011740684509277), (400, 5.011740684509277), (56, 5.011740684509277), (452, 5.011740684509277), (404, 5.011740684509277), (410, 4.9638671875), (120, 4.938035011291504), (312, 4.746097087860107), (202, 4.746097087860107), (5, 4.733315944671631), (482, 4.009392738342285), (214, 4.009392738342285), (117, 4.009392738342285), (340, 4.009392738342285), (151, 4.009392738342285), (302, 3.3569483757019043), (456, 3.2868106365203857), (555, 3.2247111797332764), (153, 3.1972920894622803), (369, 3.0814669132232666), (601, 3.0618951320648193), (520, 2.9822123050689697), (296, 2.979816198348999), (350, 2.979816198348999), (124, 2.979816198348999), (419, 2.979816198348999), (167, 2.932933807373047), (137, 2.64085054397583), (31, 2.5772249698638916), (349, 2.574199914932251), (298, 2.566957950592041), (134, 2.444838285446167), (515, 2.443878650665283), (324, 2.443878650665283), (321, 2.443878650665283), (490, 2.443878650665283), (397, 2.443878650665283), (278, 2.383852958679199), (446, 2.383852958679199), (99, 2.3299005031585693), (234, 2.242629051208496), (152, 2.2189924716949463), (248, 2.1989738941192627), (571, 2.1784253120422363), (385, 2.1784253120422363), (365, 2.1784253120422363), (53, 2.1784253120422363)]"


### Extract Recommendations

Run this function to get the N-number of recommendations that should be shown to the user

extractRecommendations(ratings_df, predictions_df, user_id, num_of_recommendations)<br>

Where:<br><br>
<b>ratings_df</b><br> refers to the user data table from the database, including user_id, recipe_id, recipe_ratings<br>

<b>predictions_df</b><br> holds the recommendations for each user, as produced by the trained model<br>

<b>user_id</b><br> being the unique user identifying number<br>

<b>num_of_recommendations</b><br> being the number of recommendations you wish to output<br>

In [13]:
def extractRecommendations(ratings_df, predictions_df, user_id, num_of_recommendations):
    """Return the top recommended recipes that the user has not rated yet.

    NOTE(review): this function is defined again in a later cell of the
    notebook; the later definition shadows this one, so only one of the two
    cells should be kept.

    Parameters
    ----------
    ratings_df : pandas.DataFrame
        Ratings table with at least the columns 'User_ID' and 'Recipe_Index'.
    predictions_df : pandas.DataFrame
        One row per user with the columns 'User_ID' and 'recommendations'.
        Each 'recommendations' entry is a sequence of items whose first
        element is the recipe index; the sequence order is preserved.
    user_id
        Value compared with ``==`` against the 'User_ID' column of both
        frames, so it must have the same type as the stored identifiers.
    num_of_recommendations : int
        Maximum number of recipe indices to return.

    Returns
    -------
    list
        Up to ``num_of_recommendations`` recipe indices, in recommendation
        order, excluding recipes the user already rated. An empty list is
        returned when ``predictions_df`` has no row for ``user_id``
        (previously this raised ``IndexError``).
    """
    user_rows = predictions_df.loc[predictions_df['User_ID'] == user_id, 'recommendations']
    if user_rows.empty:
        # No predictions exist for this user.
        return []

    # A set gives constant-time "already rated?" checks.
    already_rated = set(
        ratings_df.loc[ratings_df['User_ID'] == user_id, 'Recipe_Index'].tolist()
    )

    # Only the first matching row is used, as before.
    predicted_recipes = [item[0] for item in user_rows.iloc[0]]

    unseen = [recipe for recipe in predicted_recipes if recipe not in already_rated]
    return unseen[:num_of_recommendations]
    
    
# Example: five unseen recipe suggestions for user '11'.
print(extractRecommendations(ratings_df=test_user_data, predictions_df=user_predictions,
                             user_id='11', num_of_recommendations=5))

[153, 134, 298, 72, 138]


In [25]:
def extractRecommendations(ratings_df, predictions_df, user_id, num_of_recommendations):
    """Return the top recommended recipes that the user has not rated yet.

    NOTE(review): this duplicates the definition in the preceding cell and
    shadows it once this cell runs; only one of the two cells should be kept.

    Parameters
    ----------
    ratings_df : pandas.DataFrame
        Ratings table with at least the columns 'User_ID' and 'Recipe_Index'.
    predictions_df : pandas.DataFrame
        One row per user with the columns 'User_ID' and 'recommendations'.
        Each 'recommendations' entry is a sequence of items whose first
        element is the recipe index; the sequence order is preserved.
    user_id
        Value compared with ``==`` against the 'User_ID' column of both
        frames, so it must have the same type as the stored identifiers.
    num_of_recommendations : int
        Maximum number of recipe indices to return.

    Returns
    -------
    list
        Up to ``num_of_recommendations`` recipe indices, in recommendation
        order, excluding recipes the user already rated. An empty list is
        returned when ``predictions_df`` has no row for ``user_id``
        (previously this raised ``IndexError``).
    """
    user_rows = predictions_df.loc[predictions_df['User_ID'] == user_id, 'recommendations']
    if user_rows.empty:
        # No predictions exist for this user.
        return []

    # A set gives constant-time "already rated?" checks.
    already_rated = set(
        ratings_df.loc[ratings_df['User_ID'] == user_id, 'Recipe_Index'].tolist()
    )

    # Only the first matching row is used, as before.
    predicted_recipes = [item[0] for item in user_rows.iloc[0]]

    unseen = [recipe for recipe in predicted_recipes if recipe not in already_rated]
    return unseen[:num_of_recommendations]
    
    
    
# Example: five unseen recipe suggestions for user '11'.
print(extractRecommendations(ratings_df=test_user_data, predictions_df=user_predictions,
                             user_id='11', num_of_recommendations=5))

[153, 134, 298, 72, 138]


In [33]:
# Every user that received predictions from the model.
# Reading the column by label replaces the iterrows() loop and `row[0]`,
# which relied on positional access into a label-indexed Series (deprecated).
user_list = user_predictions['User_ID'].tolist()

# Top five not-yet-rated recipes for each of those users.
recommended_recipes = [
    extractRecommendations(test_user_data, user_predictions, user, 5)
    for user in user_list
]

output_df = pd.DataFrame({'User': user_list, 'Recommendations': recommended_recipes})

# Persist the per-user recommendations.
output_df.to_csv('user_recommendations.csv', sep=',', encoding='utf-8', index=False)