In [1]:
spark

In [2]:
spark.conf.set(
  "fs.azure.account.key.storagestudent.blob.core.windows.net", 
  "8ykAjWw0X+FVY8PnmVhnY5XyDdbpBLceLsFTbuTttdn2W76+sdAsSwYNQ9E+H8Oxlsay8/uCm8ayyWq73kB82Q=="
)

In [3]:
datasets = {
  dataset: spark.read.load( 
    "wasbs://default@storagestudent.blob.core.windows.net/datasets/S8-3/Exo/restaurant-data-with-consumer-ratings/{0}.csv".format(dataset), 
    format="csv",
    header="true"
  )
  for dataset in [
    "chefmozaccepts", 
    "chefmozcuisine", 
    "chefmozhours4", 
    "chefmozparking", 
    "geoplaces2", 
    "rating_final", 
    "usercuisine", 
    "userpayment", 
    "userprofile"
  ]
} 

In [4]:
import pyspark.sql.functions as F
from pyspark.sql.types import StringType, BooleanType, FloatType, IntegerType

In [5]:
chefmozaccepts = datasets["chefmozaccepts"]

In [6]:
chefmozaccepts.count()

In [7]:
display(chefmozaccepts.describe())

summary,placeID,Rpayment
count,1314.0,1314
mean,133218.7397260274,
stddev,1058.2707377096558,
min,132002.0,American_Express
max,135110.0,gift_certificates


In [8]:
display(chefmozaccepts.head(5))

placeID,Rpayment
135110,cash
135110,VISA
135110,MasterCard-Eurocard
135110,American_Express
135110,bank_debit_cards


In [9]:
display(chefmozaccepts.groupBy("Rpayment").count().alias("paymentModeCount").orderBy(F.desc("count")))

Rpayment,count
cash,500
MasterCard-Eurocard,194
VISA,172
American_Express,153
bank_debit_cards,130
Visa,83
Diners_Club,42
Discover,11
checks,10
Carte_Blanche,7


In [10]:
chefmozcuisine = datasets["chefmozcuisine"]

In [11]:
chefmozcuisine.count()

In [12]:
display(chefmozcuisine.describe())

summary,placeID,Rcuisine
count,916.0,916
mean,132897.17467248908,
stddev,923.6017108449752,
min,132001.0,Afghan
max,135110.0,Vietnamese


In [13]:
display(
  chefmozcuisine.groupBy("Rcuisine").count().orderBy(F.desc("count")).limit(10)
)

Rcuisine,count
Mexican,239
International,62
American,59
Dutch-Belgian,55
Italian,42
Greek,33
Bar,32
French,31
Cafe-Coffee_Shop,27
Pizzeria,25


In [14]:
display(
  chefmozcuisine
  .groupBy("placeID").agg(F.countDistinct("Rcuisine").alias("Speciality_count"))
  .orderBy(F.desc("Speciality_count"))
  .filter(F.col("Speciality_count") > 1)
)

placeID,Speciality_count
132774,9
135099,6
135097,6
135103,4
135098,4
132177,3
132253,3
135007,3
132237,3
135053,3


In [15]:
chefmozhours4 = datasets["chefmozhours4"]
chefmozhours4.count()

In [16]:
display(chefmozhours4.head(5))

placeID,hours,days
135111,00:00-23:30;,Mon;Tue;Wed;Thu;Fri;
135111,00:00-23:30;,Sat;
135111,00:00-23:30;,Sun;
135110,08:00-19:00;,Mon;Tue;Wed;Thu;Fri;
135110,00:00-00:00;,Sat;


In [17]:
display(chefmozhours4.describe())

summary,placeID,hours,days
count,2339.0,2339,2339
mean,133082.31167165455,,
stddev,935.9970534596042,,
min,132012.0,00:00-00:00;,Mon;Tue;Wed;Thu;Fri;
max,135111.0,21:30-23:00;,Sun;


In [18]:
display(
  chefmozhours4.groupBy("hours").count().orderBy(F.desc("count")).limit(10)
)

hours,count
00:00-23:30;,681
00:00-00:00;,100
17:00-22:00;,56
14:00-23:30;,32
09:00-23:30;,31
11:00-21:00;,31
08:00-23:30;,30
11:00-22:00;,30
12:00-22:00;,29
12:00-23:30;,24


In [19]:
chefmozparking = datasets["chefmozparking"]
chefmozparking.count()

In [20]:
display(chefmozparking.head(5))

placeID,parking_lot
135111,public
135110,none
135109,none
135108,none
135107,none


In [21]:
display(chefmozparking.describe())

summary,placeID,parking_lot
count,702.0,702
mean,133180.94586894586,
stddev,942.044238115414,
min,132012.0,fee
max,135111.0,yes


In [22]:
display(
  chefmozparking
  .groupBy("parking_lot")
  .count()
  .orderBy(F.desc("count"))
)

parking_lot,count
none,348
yes,174
public,102
street,32
fee,22
valet parking,21
validated parking,3


In [23]:
rating_final = datasets["rating_final"]
user = datasets["userprofile"]
place = datasets["geoplaces2"]

In [24]:
display(rating_final.head(5))

userID,placeID,rating,food_rating,service_rating
U1077,135085,2,2,2
U1077,135038,2,2,1
U1077,132825,2,2,2
U1077,135060,1,2,2
U1068,135104,1,1,2


In [25]:
rating_cleaned = rating_final.select(
  F.col("userID"), 
  F.col("placeID"), 
  F.when(F.col("rating")=="2", 1).otherwise(0).alias("label")
)

display(rating_cleaned.head(5))

userID,placeID,label
U1077,135085,1
U1077,135038,1
U1077,132825,1
U1077,135060,0
U1068,135104,0


In [26]:
[f.name for f in user.schema.fields if f.dataType==StringType()]

In [27]:
print(*['F.col("{}"),\n'.format(c) for c in user.columns])

In [28]:
user_cleaned = user.select(
  F.col("userID"),
#  F.col("latitude"),
#  F.col("longitude"),
  F.col("smoker").cast(BooleanType()),
  F.col("drink_level"),
  F.col("dress_preference"),
  F.col("ambience"),
  F.col("transport"),
  F.col("marital_status"),
  F.col("hijos"),
  F.col("birth_year"),
  F.col("interest"),
  F.col("personality"),
  F.col("religion"),
  F.col("activity"),
  F.col("color"),
  F.col("weight").cast(FloatType()),
  F.col("budget"),
  F.col("height").cast(FloatType())
)

In [29]:
display(user_cleaned.head(5))

userID,smoker,drink_level,dress_preference,ambience,transport,marital_status,hijos,birth_year,interest,personality,religion,activity,color,weight,budget,height
U1001,False,abstemious,informal,family,on foot,single,independent,1989,variety,thrifty-protector,none,student,black,69.0,medium,1.7699999809265137
U1002,False,abstemious,informal,family,public,single,independent,1990,technology,hunter-ostentatious,Catholic,student,red,40.0,low,1.870000004768372
U1003,False,social drinker,formal,family,public,single,independent,1989,none,hard-worker,Catholic,student,blue,60.0,low,1.690000057220459
U1004,False,abstemious,informal,family,public,single,independent,1940,variety,hard-worker,none,professional,green,44.0,medium,1.5299999713897705
U1005,False,abstemious,no preference,family,public,single,independent,1992,none,thrifty-protector,Catholic,student,black,65.0,medium,1.690000057220459


In [30]:
print(*['F.col("{}"),\n'.format(c) for c in place.columns])

In [31]:
display(place.head(5))

placeID,latitude,longitude,the_geom_meter,name,address,city,state,country,fax,zip,alcohol,smoking_area,dress_code,accessibility,price,url,Rambience,franchise,area,other_services
134999,18.915421,-99.184871,0101000020957F000088568DE356715AC138C0A525FC464A41,Kiku Cuernavaca,Revolucion,Cuernavaca,Morelos,Mexico,?,?,No_Alcohol_Served,none,informal,no_accessibility,medium,kikucuernavaca.com.mx,familiar,f,closed,none
132825,22.1473922,-100.983092,0101000020957F00001AD016568C4858C1243261274BA54B41,puesto de tacos,esquina santos degollado y leon guzman,s.l.p.,s.l.p.,mexico,?,78280,No_Alcohol_Served,none,informal,completely,low,?,familiar,f,open,none
135106,22.1497088,-100.9760928,0101000020957F0000649D6F21634858C119AE9BF528A34B41,El Rinc�n de San Francisco,Universidad 169,San Luis Potosi,San Luis Potosi,Mexico,?,78000,Wine-Beer,only at bar,informal,partially,medium,?,familiar,f,open,none
132667,23.7526973,-99.1633594,0101000020957F00005D67BCDDED8157C1222A2DC8D84D4941,little pizza Emilio Portes Gil,calle emilio portes gil,victoria,tamaulipas,?,?,?,No_Alcohol_Served,none,informal,completely,low,?,familiar,t,closed,none
132613,23.7529035,-99.165076,0101000020957F00008EBA2D06DC8157C194E03B7B504E4941,carnitas_mata,lic. Emilio portes gil,victoria,Tamaulipas,Mexico,?,?,No_Alcohol_Served,permitted,informal,completely,medium,?,familiar,t,closed,none


In [32]:
place_cleaned = place.select(
  F.col("placeID"),
#  F.col("latitude"),
#  F.col("longitude"),
#  F.col("the_geom_meter"),
#  F.col("name"),
#  F.col("address"),
#  F.col("city"),
#  F.col("state"),
#  F.col("country"),
#  F.col("fax"),
#  F.col("zip"),
  F.col("alcohol"),
  F.col("smoking_area"),
  F.col("dress_code"),
  F.col("accessibility"),
  F.col("price"),
#  F.col("url"),
  F.col("Rambience"),
  F.col("franchise"),
  F.col("area"),
  F.col("other_services")
)
display(place_cleaned)

placeID,alcohol,smoking_area,dress_code,accessibility,price,Rambience,franchise,area,other_services
134999,No_Alcohol_Served,none,informal,no_accessibility,medium,familiar,f,closed,none
132825,No_Alcohol_Served,none,informal,completely,low,familiar,f,open,none
135106,Wine-Beer,only at bar,informal,partially,medium,familiar,f,open,none
132667,No_Alcohol_Served,none,informal,completely,low,familiar,t,closed,none
132613,No_Alcohol_Served,permitted,informal,completely,medium,familiar,t,closed,none
135040,Wine-Beer,none,informal,no_accessibility,high,familiar,f,closed,none
132732,No_Alcohol_Served,none,casual,completely,low,familiar,f,open,none
132875,Wine-Beer,section,informal,no_accessibility,high,familiar,t,open,Internet
132609,No_Alcohol_Served,not permitted,informal,completely,low,quiet,t,closed,none
135082,No_Alcohol_Served,none,informal,no_accessibility,medium,familiar,f,closed,none


In [33]:
full_df = rating_cleaned.join(user_cleaned, "userID", "left").join(place_cleaned, "placeID", "left").dropna()

display(full_df.head(5))

placeID,userID,label,smoker,drink_level,dress_preference,ambience,transport,marital_status,hijos,birth_year,interest,personality,religion,activity,color,weight,budget,height,alcohol,smoking_area,dress_code,accessibility,price,Rambience,franchise,area,other_services
135085,U1077,1,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.7100000381469729,No_Alcohol_Served,not permitted,informal,no_accessibility,medium,familiar,f,closed,none
135038,U1077,1,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.7100000381469729,No_Alcohol_Served,section,informal,no_accessibility,medium,familiar,f,closed,none
132825,U1077,1,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.7100000381469729,No_Alcohol_Served,none,informal,completely,low,familiar,f,open,none
135060,U1077,0,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.7100000381469729,No_Alcohol_Served,none,informal,no_accessibility,medium,familiar,f,closed,none
135104,U1068,0,False,casual drinker,informal,friends,public,single,independent,1988,technology,thrifty-protector,Catholic,student,blue,72.0,low,1.5700000524520874,Full_Bar,not permitted,informal,completely,medium,familiar,t,closed,variety


In [34]:
from pyspark.mllib.stat import Statistics
import pandas as pd

In [35]:
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StandardScaler
from pyspark.ml import Pipeline

string_columns = [f.name for f in full_df.schema.fields if f.dataType == StringType()][2:]
string_columns_Index = [f + "Index" for f in string_columns]
string_columns_Vec = [f + "Vec" for f in string_columns]

numeric_columns = [f.name for f in full_df.schema.fields if f.dataType != StringType()]

# String columns : One hot encoding 
string_indexer = StringIndexer(
  inputCols = string_columns, 
  outputCols = string_columns_Index
)  # inputCols et outputCols in spark 3.0

one_hot_encoder = OneHotEncoder( # Only one OneHotEncoder for all columns
  inputCols = string_columns_Index,
  outputCols = string_columns_Vec
)

# Assemble features in one unique vector
vector_assembler = VectorAssembler(
    inputCols = string_columns_Vec + numeric_columns, 
    outputCol = "features_not_scaled" # Pas le choix : toujours "features"
)

# Scale all features (can't be done before the VectorAssembler step because Scalers work with vectors and not with columns...)
standard_scaler = StandardScaler(inputCol="features_not_scaled", outputCol="features")

preprocessing_pipeline = Pipeline(
  stages= [string_indexer, one_hot_encoder, vector_assembler, standard_scaler]
)


In [36]:
preprocessing_pipeline_fitted = preprocessing_pipeline.fit(full_df)

dataset_preprocessed = preprocessing_pipeline_fitted.transform(full_df)

In [37]:
display(dataset_preprocessed.head(5))

placeID,userID,label,smoker,drink_level,dress_preference,ambience,transport,marital_status,hijos,birth_year,interest,personality,religion,activity,color,weight,budget,height,alcohol,smoking_area,dress_code,accessibility,price,Rambience,franchise,area,other_services,RambienceIndex,transportIndex,ambienceIndex,birth_yearIndex,interestIndex,budgetIndex,priceIndex,drink_levelIndex,religionIndex,smoking_areaIndex,dress_preferenceIndex,other_servicesIndex,hijosIndex,areaIndex,personalityIndex,marital_statusIndex,accessibilityIndex,dress_codeIndex,alcoholIndex,colorIndex,activityIndex,franchiseIndex,alcoholVec,hijosVec,activityVec,religionVec,ambienceVec,franchiseVec,budgetVec,drink_levelVec,dress_codeVec,birth_yearVec,marital_statusVec,other_servicesVec,transportVec,dress_preferenceVec,interestVec,accessibilityVec,personalityVec,areaVec,smoking_areaVec,colorVec,priceVec,RambienceVec,features_not_scaled,features
135085,U1077,1,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.7100000381469729,No_Alcohol_Served,not permitted,informal,no_accessibility,medium,familiar,f,closed,none,0.0,0.0,0.0,4.0,1.0,0.0,0.0,2.0,0.0,2.0,3.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,"List(0, 2, List(0), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 4, List(0), List(1.0))","List(0, 4, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 2, List(), List())","List(0, 2, List(0), List(1.0))","List(0, 20, List(4), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 4, List(3), List(1.0))","List(0, 4, List(1), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 4, List(2), List(1.0))","List(0, 7, List(0), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 84, List(5, 6, 9, 13, 16, 22, 39, 43, 45, 49, 53, 60, 63, 67, 69, 71, 73, 75, 76, 77, 78, 80, 82, 83), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 65.0, 1.7100000381469727))","List(0, 84, List(5, 6, 9, 13, 16, 22, 39, 43, 45, 49, 53, 60, 63, 67, 69, 71, 73, 75, 76, 77, 78, 80, 82, 83), List(6.430080052290919, 2.0017534024588266, 2.0279876035569173, 4.232355691294307, 3.8529965947639644, 4.146524540888921, 2.280812862741132, 2.039589452300133, 2.3591480815390318, 2.8746039339998752, 2.095031560615419, 2.095031560615419, 2.1025098027232763, 2.5243766601831674, 3.1647096091325246, 2.0867309823636893, 1.999193132980752, 4.389395576184009, 2.913405417419759, 3.1867511802111435, 4.292938931075585, 2.026759636381925, 4.069458723569071, 15.120833680304443))"
135038,U1077,1,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.7100000381469729,No_Alcohol_Served,section,informal,no_accessibility,medium,familiar,f,closed,none,0.0,0.0,0.0,4.0,1.0,0.0,0.0,2.0,0.0,1.0,3.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,"List(0, 2, List(0), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 4, List(0), List(1.0))","List(0, 4, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 2, List(), List())","List(0, 2, List(0), List(1.0))","List(0, 20, List(4), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 4, List(3), List(1.0))","List(0, 4, List(1), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 4, List(1), List(1.0))","List(0, 7, List(0), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 84, List(5, 6, 9, 13, 16, 22, 39, 43, 45, 49, 53, 60, 63, 66, 69, 71, 73, 75, 76, 77, 78, 80, 82, 83), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 65.0, 1.7100000381469727))","List(0, 84, List(5, 6, 9, 13, 16, 22, 39, 43, 45, 49, 53, 60, 63, 66, 69, 71, 73, 75, 76, 77, 78, 80, 82, 83), List(6.430080052290919, 2.0017534024588266, 2.0279876035569173, 4.232355691294307, 3.8529965947639644, 4.146524540888921, 2.280812862741132, 2.039589452300133, 2.3591480815390318, 2.8746039339998752, 2.095031560615419, 2.095031560615419, 2.1025098027232763, 2.511482276237347, 3.1647096091325246, 2.0867309823636893, 1.999193132980752, 4.389395576184009, 2.913405417419759, 3.1867511802111435, 4.292938931075585, 2.026759636381925, 4.069458723569071, 15.120833680304443))"
132825,U1077,1,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.7100000381469729,No_Alcohol_Served,none,informal,completely,low,familiar,f,open,none,0.0,0.0,0.0,4.0,1.0,0.0,1.0,2.0,0.0,0.0,3.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,"List(0, 2, List(0), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 4, List(0), List(1.0))","List(0, 4, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 2, List(), List())","List(0, 2, List(0), List(1.0))","List(0, 20, List(4), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 4, List(3), List(1.0))","List(0, 4, List(1), List(1.0))","List(0, 2, List(1), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 1, List(), List())","List(0, 4, List(0), List(1.0))","List(0, 7, List(0), List(1.0))","List(0, 2, List(1), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 84, List(5, 6, 9, 13, 16, 22, 39, 43, 45, 49, 53, 60, 63, 65, 69, 72, 74, 75, 76, 78, 80, 82, 83), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 65.0, 1.7100000381469727))","List(0, 84, List(5, 6, 9, 13, 16, 22, 39, 43, 45, 49, 53, 60, 63, 65, 69, 72, 74, 75, 76, 78, 80, 82, 83), List(6.430080052290919, 2.0017534024588266, 2.0279876035569173, 4.232355691294307, 3.8529965947639644, 4.146524540888921, 2.280812862741132, 2.039589452300133, 2.3591480815390318, 2.8746039339998752, 2.095031560615419, 2.095031560615419, 2.1025098027232763, 2.008654582582969, 3.1647096091325246, 2.1905664175050243, 2.1739610888729937, 4.389395576184009, 2.913405417419759, 4.292938931075585, 2.026759636381925, 4.069458723569071, 15.120833680304443))"
135060,U1077,0,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.7100000381469729,No_Alcohol_Served,none,informal,no_accessibility,medium,familiar,f,closed,none,0.0,0.0,0.0,4.0,1.0,0.0,0.0,2.0,0.0,0.0,3.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,"List(0, 2, List(0), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 4, List(0), List(1.0))","List(0, 4, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 2, List(), List())","List(0, 2, List(0), List(1.0))","List(0, 20, List(4), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 4, List(3), List(1.0))","List(0, 4, List(1), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 4, List(0), List(1.0))","List(0, 7, List(0), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 84, List(5, 6, 9, 13, 16, 22, 39, 43, 45, 49, 53, 60, 63, 65, 69, 71, 73, 75, 76, 77, 78, 82, 83), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 65.0, 1.7100000381469727))","List(0, 84, List(5, 6, 9, 13, 16, 22, 39, 43, 45, 49, 53, 60, 63, 65, 69, 71, 73, 75, 76, 77, 78, 82, 83), List(6.430080052290919, 2.0017534024588266, 2.0279876035569173, 4.232355691294307, 3.8529965947639644, 4.146524540888921, 2.280812862741132, 2.039589452300133, 2.3591480815390318, 2.8746039339998752, 2.095031560615419, 2.095031560615419, 2.1025098027232763, 2.008654582582969, 3.1647096091325246, 2.0867309823636893, 1.999193132980752, 4.389395576184009, 2.913405417419759, 3.1867511802111435, 4.292938931075585, 4.069458723569071, 15.120833680304443))"
135104,U1068,0,False,casual drinker,informal,friends,public,single,independent,1988,technology,thrifty-protector,Catholic,student,blue,72.0,low,1.5700000524520874,Full_Bar,not permitted,informal,completely,medium,familiar,t,closed,variety,0.0,0.0,1.0,3.0,1.0,1.0,0.0,0.0,0.0,2.0,2.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,2.0,0.0,0.0,1.0,"List(0, 2, List(), List())","List(0, 3, List(0), List(1.0))","List(0, 4, List(0), List(1.0))","List(0, 4, List(0), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 1, List(), List())","List(0, 3, List(1), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 20, List(3), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 2, List(1), List(1.0))","List(0, 3, List(0), List(1.0))","List(0, 4, List(2), List(1.0))","List(0, 4, List(1), List(1.0))","List(0, 2, List(1), List(1.0))","List(0, 3, List(1), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 4, List(2), List(1.0))","List(0, 7, List(0), List(1.0))","List(0, 2, List(0), List(1.0))","List(0, 1, List(0), List(1.0))","List(0, 84, List(0, 4, 7, 9, 12, 15, 21, 39, 43, 45, 49, 53, 61, 67, 69, 72, 73, 75, 77, 79, 82, 83), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 72.0, 1.5700000524520874))","List(0, 84, List(0, 4, 7, 9, 12, 15, 21, 39, 43, 45, 49, 53, 61, 67, 69, 72, 73, 75, 77, 79, 82, 83), List(2.0735242720382665, 2.2082415234516, 2.105078895612003, 2.0279876035569173, 3.568208018318354, 2.809596905044346, 4.066186264524531, 2.280812862741132, 2.039589452300133, 2.3591480815390318, 2.8746039339998752, 2.095031560615419, 2.227052468421988, 2.5243766601831674, 3.1647096091325246, 2.1905664175050243, 1.999193132980752, 4.389395576184009, 3.1867511802111435, 5.851153035256578, 4.5077081245688175, 13.88287084304548))"


In [38]:
rdd_table = dataset_preprocessed.select(cols).rdd.map(lambda row: row[0:])
corr_mat = pd.DataFrame(Statistics.corr(rdd_table, method="pearson")[-1], columns=["label"], index=cols)
corr_mat.sort_values(by="label", ascending=False)

Unnamed: 0,label
label,1.0
interestIndex,0.147376
drink_levelIndex,0.14111
birth_yearIndex,0.102381
transportIndex,0.086598
dress_codeIndex,0.074044
alcoholIndex,0.056952
other_servicesIndex,0.053715
hijosIndex,0.025359
marital_statusIndex,0.021331


## La cible est le plus fortement corrélée avec l'interest, le drink_level et inversement à color

In [40]:
display(full_df)

placeID,userID,label,smoker,drink_level,dress_preference,ambience,transport,marital_status,hijos,birth_year,interest,personality,religion,activity,color,weight,budget,height,alcohol,smoking_area,dress_code,accessibility,price,Rambience,franchise,area,other_services
135085,U1077,1,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.71,No_Alcohol_Served,not permitted,informal,no_accessibility,medium,familiar,f,closed,none
135038,U1077,1,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.71,No_Alcohol_Served,section,informal,no_accessibility,medium,familiar,f,closed,none
132825,U1077,1,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.71,No_Alcohol_Served,none,informal,completely,low,familiar,f,open,none
135060,U1077,0,False,social drinker,elegant,family,public,married,kids,1987,technology,thrifty-protector,Catholic,student,blue,65.0,medium,1.71,No_Alcohol_Served,none,informal,no_accessibility,medium,familiar,f,closed,none
135104,U1068,0,False,casual drinker,informal,friends,public,single,independent,1988,technology,thrifty-protector,Catholic,student,blue,72.0,low,1.57,Full_Bar,not permitted,informal,completely,medium,familiar,t,closed,variety
132740,U1068,0,False,casual drinker,informal,friends,public,single,independent,1988,technology,thrifty-protector,Catholic,student,blue,72.0,low,1.57,No_Alcohol_Served,permitted,informal,completely,low,familiar,f,open,none
132663,U1068,0,False,casual drinker,informal,friends,public,single,independent,1988,technology,thrifty-protector,Catholic,student,blue,72.0,low,1.57,No_Alcohol_Served,none,informal,completely,low,familiar,f,closed,none
132732,U1068,0,False,casual drinker,informal,friends,public,single,independent,1988,technology,thrifty-protector,Catholic,student,blue,72.0,low,1.57,No_Alcohol_Served,none,casual,completely,low,familiar,f,open,none
132630,U1068,0,False,casual drinker,informal,friends,public,single,independent,1988,technology,thrifty-protector,Catholic,student,blue,72.0,low,1.57,No_Alcohol_Served,none,informal,completely,low,familiar,f,closed,none
132584,U1067,1,False,abstemious,no preference,family,public,single,independent,1987,technology,thrifty-protector,Christian,student,green,92.0,medium,1.73,No_Alcohol_Served,not permitted,informal,completely,medium,familiar,t,closed,none


In [41]:
train, test = dataset_preprocessed.randomSplit([0.7, 0.3], 100)

In [42]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

rf = RandomForestClassifier(labelCol="label", featuresCol="features")

evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", metricName="areaUnderPR") # dataset déséquilibré

paramGrid = ParamGridBuilder()\
             .addGrid(rf.maxDepth, [2, 5])\
             .addGrid(rf.maxBins, [4, 8])\
             .addGrid(rf.numTrees, [2, 5])\
             .build()
 
cv_rf = CrossValidator(estimator=rf, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=5)
cv_rf_fitted = cv_rf.fit(train)

In [43]:
bestModel = cv_rf_fitted.bestModel

bestModelParameters = {
  key.name: value
  for key, value in bestModel.extractParamMap().items()
  if key.name in ["maxDepth", "maxBins", "numTrees"]
}

print(bestModelParameters)

In [44]:
predictions_train = cv_rf_fitted.transform(train)
predictions_test = cv_rf_fitted.transform(test)

In [45]:
print(
  "AUC train : {0} - AUC test : {1}"
  .format(
    evaluator.evaluate(predictions_train), 
    evaluator.evaluate(predictions_test)
  )
)

In [46]:
cf_count = predictions_test.select(F.col("prediction"), F.col("label").cast(IntegerType()))\
                           .groupBy("prediction")\
                           .pivot("label")\
                           .count() 

display(cf_count)

prediction,0,1
0.0,196.0,
1.0,,137.0


In [47]:
rating = rating_final.select(
  F.col("userID"),
  F.col("placeID").cast(IntegerType()),
  F.col("rating").cast(IntegerType())
)

userIdIndexer = StringIndexer(inputCol="userID", outputCol="userIdIndex").fit(rating)

rating = userIdIndexer.transform(rating).select(
  F.col("userIdIndex"),
  F.col("placeID").cast(IntegerType()),
  F.col("rating").cast(IntegerType())
)

display(rating)

userIdIndex,placeID,rating
112.0,135085,2
112.0,135038,2
112.0,132825,2
112.0,135060,1
78.0,135104,1
78.0,132740,0
78.0,132663,1
78.0,132732,0
78.0,132630,1
95.0,132584,2


In [48]:
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

train_r, test_r = rating.randomSplit([0.7, 0.3], seed=42)

In [49]:
als = ALS(userCol="userIdIndex", itemCol="placeID", ratingCol="rating", coldStartStrategy="drop", nonnegative=True)

evaluator_r = RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")

paramGrid_r = (ParamGridBuilder()
               .addGrid(als.rank, [1, 3])
               .addGrid(als.maxIter, [5, 10])
               .build()
              )

cv_als = CrossValidator(
  estimator=als,
  evaluator=evaluator_r,
  estimatorParamMaps=paramGrid_r,
  numFolds=5
)

In [50]:
model = cv_als.fit(train_r)

In [51]:
bestModel = model.bestModel

print("rank: {}, maxIter: {}".format(bestModel.rank, bestModel._java_obj.parent().getMaxIter()))

In [52]:
predictions_train_r = model.transform(train_r)
predictions_test_r = model.transform(test_r)

In [53]:
print(
  "rmse train : {0} - rmse test : {1}"
  .format(
    evaluator_r.evaluate(predictions_train_r), 
    evaluator_r.evaluate(predictions_test_r)
  )
)

In [54]:
from pyspark.ml.feature import IndexToString

predictions_test_r = IndexToString(inputCol="userIdIndex", outputCol="userID").transform(predictions_test_r).drop("userIdIndex")

display(predictions_test_r)

placeID,rating,prediction,userID
135000,2,1.7727768,U1100
135000,1,0.8122277,U1117
135027,2,0.8517891,U1101
135027,0,0.95360076,U1116
135066,2,1.5807979,U1106
135066,1,0.80565107,U1018
135108,1,0.9829387,U1007
135108,1,0.9633849,U1008
135108,2,1.2215414,U1126
135108,2,0.9867994,U1111
