In [6]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml.recommendation import ALS

In [3]:
spark = (
    SparkSession.builder 
    .appName("ALSTraining") 
    .master("local[*]")  # Adjust as needed for your environment
    .config("spark.driver.memory", "4g")  # Adjust as needed
    .config("spark.executor.memory", "4g") # Adjust as needed
    .config("spark.sql.shuffle.partitions", "200") # Adjust as needed
    .getOrCreate()
)

print("Spark Session created for training")

Spark Session created for training


In [4]:
user_id_mapping_path = "/home/jovyan/work/id_mappings/user_id_mapping"
user_id_mapping_loaded = spark.read.parquet(user_id_mapping_path)
user_id_mapping_loaded.show(5)

business_id_mapping_path = "/home/jovyan/work/id_mappings/business_id_mapping"
business_id_mapping_loaded = spark.read.parquet(business_id_mapping_path)
business_id_mapping_loaded.show(5)

print("ID mappings loaded successfully")

+--------------------+------------+
|             user_id|     userCol|
+--------------------+------------+
|FCZKqerr3-dq9SXG6...|558345748480|
|CIFrgKgUrr6xyigd1...|558345748481|
|1e0y0vUJiEU2l2U5Z...|558345748482|
|KvcNVhNrnIYEEQtR6...|558345748483|
|Q5kovyJgHn6_rTy4X...|558345748484|
+--------------------+------------+
only showing top 5 rows

+--------------------+------------+
|         business_id|     itemCol|
+--------------------+------------+
|aH3xlewNKQk5K4mEk...|300647710720|
|-q4YqKsWJY6NtkQdn...|300647710721|
|cv-SmPhbpwQCtlI2Q...|300647710722|
|wIXYreqGaO5AEVjNQ...|300647710723|
|6d25hRt6Hz4SPc9Ih...|300647710724|
+--------------------+------------+
only showing top 5 rows

ID mappings loaded successfully


In [5]:
file_path = "yelp_train.csv"  # Replace with the actual path to your data
df = spark.read.csv(file_path, header=True, inferSchema=True)

als_input_df_loaded = df.join(user_id_mapping_loaded, "user_id", "inner").drop("user_id")
als_input_df_loaded = als_input_df_loaded.join(business_id_mapping_loaded, "business_id", "inner").drop("business_id")

als_input_df_loaded = als_input_df_loaded.select(
    col("userCol").cast("int"),
    col("itemCol").cast("int"),
    col("stars").cast("float").alias("rating") # ALS expects 'rating' column
)

als_input_df_loaded.printSchema()
als_input_df_loaded.show(5)

print("ALS input data loaded and prepared")

root
 |-- userCol: integer (nullable = true)
 |-- itemCol: integer (nullable = true)
 |-- rating: float (nullable = true)

+-------+-------+------+
|userCol|itemCol|rating|
+-------+-------+------+
|  13125|    261|   1.0|
|  13387|    282|   5.0|
|   9030|    119|   4.0|
|  15762|    393|   5.0|
|  12526|    406|   1.0|
+-------+-------+------+
only showing top 5 rows

ALS input data loaded and prepared


In [7]:
als = ALS(userCol="userCol",
          itemCol="itemCol",
          ratingCol="rating",
          rank=10,          # Number of latent factors
          maxIter=10,       # Maximum number of iterations
          regParam=0.01,    # Regularization parameter
          coldStartStrategy="drop") # Handle new users/items

# Train the model
model = als.fit(als_input_df_loaded)

print("ALS model trained successfully")

ALS model trained successfully


In [10]:
# Get user latent factors with original user IDs
user_latent_vectors = model.userFactors.withColumnRenamed("id", "userCol").join(user_id_mapping_loaded, "userCol", "inner")
user_latent_vectors.show(5)

# Get item latent factors with original business IDs
item_latent_vectors = model.itemFactors.withColumnRenamed("id", "itemCol").join(business_id_mapping_loaded, "itemCol", "inner")
item_latent_vectors.show(5)

+-------+--------------------+--------------------+
|userCol|            features|             user_id|
+-------+--------------------+--------------------+
|     26|[-1.1553166, 0.45...|HB8tVmNjWaa_18nWW...|
|     29|[-0.8118293, 0.50...|cNoBKr08_gQgYW0Kc...|
|    474|[-1.1891398, 0.38...|IZaWH5nva6mXn_MVx...|
|    964|[-1.4100026, 0.27...|DLAdQTxg2jMH-bGWZ...|
|   1677|[-1.3440387, 0.34...|YuR6Z0uhC1I_e6t7U...|
+-------+--------------------+--------------------+
only showing top 5 rows

+-------+--------------------+--------------------+
|itemCol|            features|         business_id|
+-------+--------------------+--------------------+
|      0|[-0.13696466, 0.0...|NLV0ppsHTiJk6JVdF...|
|     10|[0.053821385, 0.3...|VdB1YL718sAxae12P...|
|     20|[-0.38813618, -0....|YlAI0sW0bsVsvErTn...|
|     30|[-8.9343965E-5, -...|7hRaOnXRRS8q620F6...|
|     40|[-0.1328803, -0.0...|TF3qNGUBUgIYp6u0j...|
+-------+--------------------+--------------------+
only showing top 5 rows

