In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.recommendation import ALS
from pyspark.sql.functions import lit

In [2]:

spark = SparkSession.builder \
	.appName("RSys") \
	.config("spark.executor.memory", "4g") \
    .config("spark.executor.cores", "4") \
	.getOrCreate()

24/05/24 20:35:46 WARN Utils: Your hostname, trnmah-IdeaPad-Gaming-3-15ACH6 resolves to a loopback address: 127.0.1.1; using 192.168.0.7 instead (on interface wlo1)
24/05/24 20:35:46 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/24 20:35:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
df = spark.read.format("csv").option("delimiter", "\t").load("../moviedata/ml-100k/ua.base")

In [4]:
df = df.drop("_c3")

In [5]:
df.head(5)

[Row(_c0='1', _c1='1', _c2='5'),
 Row(_c0='1', _c1='2', _c2='3'),
 Row(_c0='1', _c1='3', _c2='4'),
 Row(_c0='1', _c1='4', _c2='3'),
 Row(_c0='1', _c1='5', _c2='3')]

In [6]:
col_names = ["user_id", "item_id", "rating"]

In [7]:
for i, col_name in enumerate(col_names):
	df = df.withColumnRenamed("_c" + str(col_names.index(col_name)), col_name) #rename columns
	df = df.withColumn(col_name, df[col_name].cast("int")) #cast columns to integer

In [8]:
df.schema

StructType([StructField('user_id', IntegerType(), True), StructField('item_id', IntegerType(), True), StructField('rating', IntegerType(), True)])

In [9]:
df.head(5)

[Row(user_id=1, item_id=1, rating=5),
 Row(user_id=1, item_id=2, rating=3),
 Row(user_id=1, item_id=3, rating=4),
 Row(user_id=1, item_id=4, rating=3),
 Row(user_id=1, item_id=5, rating=3)]

In [10]:
als = ALS(maxIter=5, regParam=0.01, userCol="user_id", itemCol="item_id", ratingCol="rating", coldStartStrategy="drop")

In [11]:
model = als.fit(df)

In [12]:

# Get all unique user_ids and item_ids
unique_user_ids = df.select("user_id").distinct()
unique_item_ids = df.select("item_id").distinct()


In [13]:
unique_item_ids.head(5)

[Row(item_id=148),
 Row(item_id=471),
 Row(item_id=496),
 Row(item_id=463),
 Row(item_id=833)]

In [14]:
unique_user_ids.head(5)

[Row(user_id=148),
 Row(user_id=463),
 Row(user_id=471),
 Row(user_id=496),
 Row(user_id=833)]

In [15]:
# Create a dataframe of all user_id and item_id pairs
all_pairs = unique_user_ids.crossJoin(unique_item_ids)
# Remove pairs that are present in the original dataframe
missing_pairs = all_pairs.join(df, on=["user_id", "item_id"], how="left_anti")





In [16]:
all_pairs.count()

1584240

In [17]:
missing_pairs.head(5)

[Row(user_id=148, item_id=148),
 Row(user_id=463, item_id=148),
 Row(user_id=471, item_id=148),
 Row(user_id=496, item_id=148),
 Row(user_id=833, item_id=148)]

In [39]:
# Add a dummy column for rating
# missing_pairs = missing_pairs.withColumn("rating", lit(0))

In [40]:
missing_pairs.head(5)

[Row(user_id=148, item_id=148, rating=0),
 Row(user_id=463, item_id=148, rating=0),
 Row(user_id=471, item_id=148, rating=0),
 Row(user_id=496, item_id=148, rating=0),
 Row(user_id=833, item_id=148, rating=0)]

In [62]:
# predict the ratings for the missing pairs
predictions = model.transform(missing_pairs)
predictions.head(5)

In [47]:
predictions.head(5)

[Row(user_id=148, item_id=148, prediction=2.535693645477295),
 Row(user_id=463, item_id=148, prediction=2.374476671218872),
 Row(user_id=471, item_id=148, prediction=5.163250923156738),
 Row(user_id=496, item_id=148, prediction=2.6346185207366943),
 Row(user_id=833, item_id=148, prediction=1.1117370128631592)]

In [53]:
missing_pairs.count()

1493670

In [61]:
predictions.select('*').filter('user_id = 1').show()

+-------+-------+----------+
|user_id|item_id|prediction|
+-------+-------+----------+
|      1|    471| 3.4238014|
|      1|    496|   3.70701|
|      1|    463| 3.7824821|
|      1|    833|  4.629636|
|      1|   1088| 2.4879954|
|      1|   1238| 3.0842245|
|      1|   1342| 1.0175483|
|      1|   1580| 0.7497052|
|      1|   1591| 1.2021633|
|      1|   1645| 5.1979885|
|      1|    392| 3.1565173|
|      1|    540| 2.7254908|
|      1|    623| 1.8366802|
|      1|    737| 3.6124377|
|      1|    858| 0.9244086|
|      1|    897| 1.0481304|
|      1|   1084| 5.5868597|
|      1|   1025| 2.6102464|
|      1|   1127|  3.204022|
|      1|   1395|  4.680956|
+-------+-------+----------+
only showing top 20 rows



In [51]:
df.count()

90570