In [2]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer
from pyspark.ml.recommendation import ALS
from pyspark.sql.functions import explode, col, lit, array, struct, udf
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.types import StringType

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
spark = SparkSession.builder.appName("recsys").getOrCreate()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [4]:
bucket = "recsys-aws"
key_user_ad_matrix_prefix = "silver_data/user_ad_matrix/"
user_ad_interactions_df = spark.read.csv(f"s3://{bucket}/{key_user_ad_matrix_prefix}",
                                         header=True,
                                         inferSchema=True)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [5]:
cols = [c for c in user_ad_interactions_df.columns if c != "userId"]
long_format = (user_ad_interactions_df
               .withColumn("adId_rating", explode(array([struct(col(c).alias("rating"), lit(c).alias("adId")) for c in cols])))
               .select("userId", "adId_rating.adId", "adId_rating.rating"))

user_indexer = StringIndexer(inputCol="userId", outputCol="userId_indexed")
ad_indexer = StringIndexer(inputCol="adId", outputCol="adId_indexed")
user_model = user_indexer.fit(long_format)
adId_model = ad_indexer.fit(long_format)
long_format = user_model.transform(long_format)
long_format = adId_model.transform(long_format)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [6]:
(training, test) = long_format.randomSplit([0.8, 0.2])

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [7]:
als = ALS(
    maxIter=10, 
    regParam=0.01, 
    userCol="userId_indexed", 
    itemCol="adId_indexed", 
    ratingCol="rating", 
    coldStartStrategy="drop"
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [8]:
model = als.fit(training)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [9]:
user_id_to_recommend = "4f3aecdc-f7d8-4718-925c-96d81c3765f3"
n_recommendations = 10
user_indexed = user_model.transform(spark.createDataFrame([(user_id_to_recommend,)], ["userId"]))
recs = model.recommendForUserSubset(user_indexed, n_recommendations)
recs = recs.withColumn("adId_indexed", explode(col("recommendations.adId_indexed")))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [10]:
def index_to_id(index):
    return ad_id_labels[int(index)]

ad_id_labels = adId_model.labels
index_to_id_udf = udf(index_to_id, StringType())
recs_with_original_ids = recs.withColumn("original_adId", index_to_id_udf("adId_indexed"))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [11]:
exploded_recommendations = recs_with_original_ids.select("userId_indexed", explode("recommendations").alias("recommendation"))
exploded_adId_indexed = exploded_recommendations.select("userId_indexed", "recommendation.adId_indexed")

labels_df = spark.createDataFrame([(i, label) for i, label in enumerate(ad_id_labels)], ["index", "label"])
final_recommended_adIds = exploded_adId_indexed.join(labels_df, exploded_adId_indexed.adId_indexed == labels_df.index)
final_result = final_recommended_adIds.select("userId_indexed", "label").withColumnRenamed("label", "recommended_adId")
final_result = final_result.dropDuplicates(['userId_indexed', 'recommended_adId'])
final_result.show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------+--------------------+
|userId_indexed|    recommended_adId|
+--------------+--------------------+
|            27|d83f913c-87e3-48c...|
|            27|552bfdf5-f621-4ff...|
|            27|d5f4d7ef-ac83-48f...|
|            27|1a47c3ae-0788-442...|
|            27|e8ea96db-964a-4ca...|
|            27|bc4d3e39-7e06-4c3...|
|            27|a4596d16-e59d-40c...|
|            27|13a8f121-c79f-476...|
|            27|55b3b717-b1d3-41c...|
|            27|023d7aba-d5df-459...|
+--------------+--------------------+