In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml.recommendation import ALS

In [2]:
spark = (
    SparkSession.builder 
    .appName("ALSTraining") 
    .master("local[*]")  # Adjust as needed for your environment
    .config("spark.driver.memory", "4g")  # Adjust as needed
    .config("spark.executor.memory", "4g") # Adjust as needed
    .config("spark.sql.shuffle.partitions", "200") # Adjust as needed
    .getOrCreate()
)

print("Spark Session created for training")

Spark Session created for training


In [3]:
user_id_mapping_path = "/home/jovyan/work/id_mappings/user_id_mapping"
user_id_mapping_loaded = spark.read.parquet(user_id_mapping_path)
user_id_mapping_loaded.show(5)

business_id_mapping_path = "/home/jovyan/work/id_mappings/business_id_mapping"
business_id_mapping_loaded = spark.read.parquet(business_id_mapping_path)
business_id_mapping_loaded.show(5)

print("ID mappings loaded successfully")

+--------------------+------------+
|             user_id|     userCol|
+--------------------+------------+
|--2bpE5vyR-2hAP7s...|274877906944|
|--T_QxqWcEu76n1da...|274877906945|
|--q3Qv-yYG9jFqXlM...|274877906946|
|-0BfVK9AA00ynhvW6...|274877906947|
|-0bmx13qzWqXCafeA...|274877906948|
+--------------------+------------+
only showing top 5 rows

+--------------------+------------+
|         business_id|     itemCol|
+--------------------+------------+
|NAMen7YzwlYDs_5EC...|326417514496|
|BjBDHqHhMXSxgyVip...|326417514497|
|wm5mQ4cSpvko9WlCq...|326417514498|
|KD9-X5AykKmiZszrM...|326417514499|
|eLOqWp2OLfr5dfJHw...|326417514500|
+--------------------+------------+
only showing top 5 rows

ID mappings loaded successfully


In [5]:
file_path = "training_data.csv"  # Replace with the actual path to your data
df = spark.read.csv(file_path, header=True, inferSchema=True)

als_input_df_loaded = df.join(user_id_mapping_loaded, "user_id", "inner").drop("user_id")
als_input_df_loaded = als_input_df_loaded.join(business_id_mapping_loaded, "business_id", "inner").drop("business_id")

als_input_df_loaded = als_input_df_loaded.select(
    col("userCol").cast("int"),
    col("itemCol").cast("int"),
    col("stars").cast("float").alias("rating") # ALS expects 'rating' column
)

als_input_df_loaded.printSchema()
als_input_df_loaded.show(5)

print("ALS input data loaded and prepared")

root
 |-- userCol: integer (nullable = true)
 |-- itemCol: integer (nullable = true)
 |-- rating: float (nullable = true)

+-------+-------+------+
|userCol|itemCol|rating|
+-------+-------+------+
|      0|    203|   1.0|
|      1|    215|   5.0|
|     89|    398|   4.0|
|      2|      6|   5.0|
|      3|    120|   1.0|
+-------+-------+------+
only showing top 5 rows

ALS input data loaded and prepared


In [6]:
als = ALS(userCol="userCol",
          itemCol="itemCol",
          ratingCol="rating",
          rank=10,          # Number of latent factors
          maxIter=10,       # Maximum number of iterations
          regParam=0.01,    # Regularization parameter
          coldStartStrategy="drop") # Handle new users/items

# Train the model
model = als.fit(als_input_df_loaded)

print("ALS model trained successfully")

ALS model trained successfully


In [7]:
# Get user latent factors with original user IDs
user_latent_vectors = model.userFactors.withColumnRenamed("id", "userCol").join(user_id_mapping_loaded, "userCol", "inner")
user_latent_vectors.show(5)

# Get item latent factors with original business IDs
item_latent_vectors = model.itemFactors.withColumnRenamed("id", "itemCol").join(business_id_mapping_loaded, "itemCol", "inner")
item_latent_vectors.show(5)

+-------+--------------------+--------------------+
|userCol|            features|             user_id|
+-------+--------------------+--------------------+
|     26|[0.44952798, -3.1...|-EmuvqfmhKylSVG1H...|
|     29|[0.70582867, -2.7...|-FybaJ4pQZsfoRfW5...|
|    474|[0.68184626, -3.1...|-GsRfCDYv0myI_YCv...|
|    964|[-0.12040462, -3....|0LhRYfa7YErm-Y0Xh...|
|   1677|[1.041442, -2.852...|22ml-CTcoabnc-uu4...|
+-------+--------------------+--------------------+
only showing top 5 rows

+-------+--------------------+--------------------+
|itemCol|            features|         business_id|
+-------+--------------------+--------------------+
|      0|[0.32506737, -0.3...|vXb4OWsjPoiBtmarf...|
|     10|[0.009884086, -0....|Q9G2gtnVDZgsNerHD...|
|     20|[0.0032310067, -0...|LnJSsNVZkStgtj86f...|
|     30|[0.5399539, -0.17...|bQ_R0bLvTWS4jHsF6...|
|     40|[-0.21802746, -0....|jzsspHqP9kATWji8x...|
+-------+--------------------+--------------------+
only showing top 5 rows



In [10]:
from pyspark.sql.functions import concat_ws, col

# Convert the 'features' column to a string before writing to CSV for users
user_latent_vectors_string = user_latent_vectors.withColumn("features", concat_ws(",", col("features")))

# Export user_latent_vectors to a CSV file
user_latent_vectors_string.coalesce(1).write.csv("user_latent_vectors.csv", header=True, mode="overwrite")

# Convert the 'features' column to a string before writing to CSV for businesses
item_latent_vectors_string = item_latent_vectors.withColumn("features", concat_ws(",", col("features")))

# Export item_latent_vectors to a CSV file
item_latent_vectors_string.coalesce(1).write.csv("item_latent_vectors.csv", header=True, mode="overwrite")
