In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import monotonically_increasing_id, col, count, row_number, lit, when, rand, floor
from pyspark.sql.window import Window
import os

### Initialize Spark Session and Load Data

In [2]:
# Build (or reuse) the Spark session shared by every cell below.
# Settings: 128g driver / 32g executor memory (+4g overhead), 400 shuffle
# partitions, and Kryo serialization with a 1024m maximum buffer.
spark = (
    SparkSession.builder
    .appName("Combine_All_Features_ALS_MLP")
    .config("spark.driver.memory", "128g")
    .config("spark.executor.memory", "32g")
    .config("spark.executor.memoryOverhead", "4g")
    .config("spark.sql.shuffle.partitions", "400")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryoserializer.buffer.max", "1024m")
    .getOrCreate()
)

# Input CSVs: every file is read with one shared schema, its core ALS columns
# are cast to fixed types, and the results are unioned into `combined_df`.
file_path_train = "data/als/training_data.csv"
file_path_prod = "data/als/production_data.csv"
file_path_val = "data/als/validation_data.csv"

csv_files_to_load = [file_path_train, file_path_prod, file_path_val]
all_dataframes = []
# Paths of the files that actually loaded, index-aligned with `all_dataframes`.
# A file that fails to load is skipped, so indexing `csv_files_to_load` by
# position in `all_dataframes` could name the wrong file in diagnostics.
loaded_file_paths = []

# The schema is inferred once, from the first file, and applied to every file.
master_schema = spark.read.csv(csv_files_to_load[0], header=True, inferSchema=True).schema
# NOTE(review): not referenced anywhere else in the visible cells.
combined_parquet_path = "data/combined_features_deduplicated.parquet"

print("--- Reading CSVs with inferSchema=True ---")
for file_path in csv_files_to_load:
    print(f"Reading CSV file: {file_path}")
    try:
        df_temp = spark.read.csv(file_path, header=True, schema=master_schema)
        print(f"Successfully read {df_temp.count()} records from {file_path}.")

        critical_als_cols_to_select = [
            col("user_id").cast("string").alias("user_id"),          # string join key
            col("business_id").cast("string").alias("business_id"),  # string join key
            col("stars").cast("double").alias("stars"),              # numeric rating
        ]

        # Every remaining column is carried over unchanged.
        other_cols = [c for c in df_temp.columns if c.lower() not in ["user_id", "business_id", "stars"]]
        select_exprs = critical_als_cols_to_select + [col(c) for c in other_cols]

        df_temp_standardized_core = df_temp.select(select_exprs)

        all_dataframes.append(df_temp_standardized_core)
        loaded_file_paths.append(file_path)
        print(f"Schema after standardizing core ALS columns for {file_path}:")
        # Show only the three standardized columns; the full schema is far too
        # wide to print once per file.
        print(df_temp_standardized_core.select("user_id", "business_id", "stars").schema.simpleString())

    except Exception as e:
        # Best effort: report the failing file and continue with the others.
        print(f"Error processing {file_path}: {e}")

if not all_dataframes:
    raise ValueError("No DataFrames were successfully processed. Please check file paths and CSV formats.")

# From here on `all_dataframes` is guaranteed non-empty, so the former
# "no data" fallback (spark.stop(); exit()) was unreachable and is removed.

# Warn when a DataFrame's column set differs from the first one. unionByName
# matches columns by name, so column order is irrelevant.
base_columns = set(all_dataframes[0].columns)
for i, (source_path, df_to_union) in enumerate(zip(loaded_file_paths[1:], all_dataframes[1:]), 1):
    if set(df_to_union.columns) != base_columns:
        print(f"Warning: DataFrame from file {source_path} has different columns than the first DataFrame.")
        print(f"Base columns: {sorted(list(base_columns))}")
        print(f"Columns in DF {i}: {sorted(list(set(df_to_union.columns)))}")
        print("Ensure 'unionByName' can resolve this or select a common subset of columns for all DataFrames if this is an issue.")

# allowMissingColumns=True fills a column absent from one side with nulls
# instead of failing the union.
combined_df = all_dataframes[0]
for df_to_union in all_dataframes[1:]:
    combined_df = combined_df.unionByName(df_to_union, allowMissingColumns=True)

print(f"\n--- Combined DataFrame ---")
print(f"Total records after union: {combined_df.count()}")

# Columns whose combined value identifies one record; extend this list if
# further columns are needed to define uniqueness.
unique_key_columns = ["user_id", "business_id", "stars"]

# Fail early if any deduplication key is absent from the combined data.
available_columns = combined_df.columns
missing_keys = [key for key in unique_key_columns if key not in available_columns]
if len(missing_keys) > 0:
    raise ValueError(f"Missing one or more key columns for deduplication: {missing_keys}. Available columns: {combined_df.columns}")

print(f"\n--- Performing deduplication based on {unique_key_columns} ---")
# Keep one row per key combination and cache the result; the count below is
# the first action on the cached DataFrame.
distinct_combined_df = combined_df.dropDuplicates(unique_key_columns)
distinct_combined_df.cache()
print(f"Total records after deduplication: {distinct_combined_df.count()}")

# Release the pre-deduplication DataFrame.
combined_df.unpersist()

# Alias consumed by the ID-mapping step.
df_for_mapping = distinct_combined_df

print("\n--- DataFrame ready for ID mapping ---")

# --- ID mapping and ALS input preparation ---
# Map the string user/business ids to contiguous integer ids (0, 1, 2, ...).
#
# monotonically_increasing_id() is avoided here: its values are 64-bit and
# not consecutive, so casting them to int for ALS would overflow.
# row_number() over a global ordering gives dense IntegerType ids that depend
# only on the set of id strings.
#
# The mappings are built from the deduplicated `df_for_mapping` (the frame they
# are joined to) and cached so both the join and any later reuse see one
# materialized mapping.
#
# Window.orderBy without partitionBy pulls the rows into a single partition;
# that is applied only to the distinct-id lists, not to the full feature data.
user_id_mapping_loaded = (
    df_for_mapping.select("user_id").distinct()
    .withColumn("userCol", row_number().over(Window.orderBy("user_id")) - 1)
    .cache()
)
business_id_mapping_loaded = (
    df_for_mapping.select("business_id").distinct()
    .withColumn("itemCol", row_number().over(Window.orderBy("business_id")) - 1)
    .cache()
)

# NOTE(review): the message mentions saving the mappings, but nothing in this
# cell writes them out -- confirm whether a save step is missing.
print("Joining combined data with mapping tables and saving the mappings...")
# "user_id" and "business_id" in df_for_mapping were cast to string at load
# time, matching the mapping tables' key type.
als_input_df_loaded = df_for_mapping.join(user_id_mapping_loaded, "user_id", "inner")
als_input_df_loaded = als_input_df_loaded.join(business_id_mapping_loaded, "business_id", "inner")

# For a plain ALS model only three columns are needed from this frame:
# userCol, itemCol (both int) and stars (double, usually aliased to "rating").

print("\n--- Count of the new comprehensive als_input_df_loaded ---")
print(f"Total records in new als_input_df_loaded: {als_input_df_loaded.count()}")

--- Reading CSVs with inferSchema=True ---
Reading CSV file: data/als/training_data.csv
Successfully read 928878 records from data/als/training_data.csv.
Schema after standardizing core ALS columns for data/als/training_data.csv:
Reading CSV file: data/als/production_data.csv
Successfully read 4164772 records from data/als/production_data.csv.
Schema after standardizing core ALS columns for data/als/production_data.csv:
Reading CSV file: data/als/validation_data.csv
Successfully read 32637 records from data/als/validation_data.csv.
Schema after standardizing core ALS columns for data/als/validation_data.csv:

--- Combined DataFrame ---
Total records after union: 5126287

--- Performing deduplication based on ['user_id', 'business_id', 'stars'] ---
Total records after deduplication: 5027330

--- DataFrame ready for ID mapping ---
Joining combined data with mapping tables and saving the mappings...

--- Count of the new comprehensive als_input_df_loaded ---
Total records in new als_input

In [3]:
# Display the schema of the joined DataFrame built in the previous cell.
als_input_df_loaded.schema

StructType([StructField('business_id', StringType(), True), StructField('user_id', StringType(), True), StructField('stars', DoubleType(), True), StructField('date', TimestampType(), True), StructField('city', StringType(), True), StructField('category_0', DoubleType(), True), StructField('category_1', DoubleType(), True), StructField('category_2', DoubleType(), True), StructField('category_3', DoubleType(), True), StructField('category_4', DoubleType(), True), StructField('category_5', DoubleType(), True), StructField('category_6', DoubleType(), True), StructField('category_7', DoubleType(), True), StructField('category_8', DoubleType(), True), StructField('category_9', DoubleType(), True), StructField('category_10', DoubleType(), True), StructField('category_11', DoubleType(), True), StructField('category_12', DoubleType(), True), StructField('category_13', DoubleType(), True), StructField('category_14', DoubleType(), True), StructField('category_15', DoubleType(), True), StructField

In [4]:
# --- User Defined Parameters ---
MIN_INTERACTIONS_PER_USER = 3
MIN_INTERACTIONS_PER_ITEM = 3  # For businesses/restaurants
TRAIN_RATIO = 0.6
VALIDATION_RATIO = 0.2
# TEST_RATIO is implicitly 1.0 - TRAIN_RATIO - VALIDATION_RATIO (i.e., 0.2)

# Fail fast on parameters that would silently produce empty or skewed splits.
if not (0.0 < TRAIN_RATIO <= 1.0 and 0.0 <= VALIDATION_RATIO <= 1.0
        and TRAIN_RATIO + VALIDATION_RATIO <= 1.0):
    raise ValueError(
        f"Invalid split ratios: TRAIN_RATIO={TRAIN_RATIO}, VALIDATION_RATIO={VALIDATION_RATIO}. "
        "Both must lie in [0, 1] (train > 0) and sum to at most 1."
    )
if MIN_INTERACTIONS_PER_USER < 1 or MIN_INTERACTIONS_PER_ITEM < 1:
    raise ValueError("MIN_INTERACTIONS_PER_USER and MIN_INTERACTIONS_PER_ITEM must be at least 1.")

# Column names (ensure these match your DataFrame)
USER_ID_COL = "user_id"
ITEM_ID_COL = "business_id"
# RATING_COL = "stars"  # present in the data but not used by the split logic

# Output locations; Spark writes each one as a directory of part files.
TRAIN_CSV_PATH = "./als_split/train_data"
VALIDATION_CSV_PATH = "./als_split/validation_data"
TEST_CSV_PATH = "./als_split/test_data"

# `als_input_df_loaded` is the DataFrame produced by the ID-mapping cell.
print(f"Original dataset record count: {als_input_df_loaded.count()}")

# Columns kept in every split.
# NOTE(review): this list comes from `distinct_combined_df`, not from
# `als_input_df_loaded`, so the joined `userCol` / `itemCol` columns are
# dropped by every `.select(*original_cols)` below -- confirm this is intended.
original_cols = distinct_combined_df.columns

# --- Step 1: K-Core Filtering (Iterative) ---
# Repeatedly drop users with fewer than MIN_INTERACTIONS_PER_USER rows and then
# items with fewer than MIN_INTERACTIONS_PER_ITEM rows. Removing items can push
# users back under their threshold, so the pass repeats until the row count
# stops changing or 10 passes have run.
print("\nStarting K-core filtering...")
current_df = als_input_df_loaded

# Count once up front. Each later pass reuses the previous pass's end count as
# its start count instead of launching another job over the growing lineage.
initial_count_iter = current_df.count()

for i in range(10):  # Max 10 iterations for k-core filtering
    print(f"  K-core Iteration {i+1}: Starting with {initial_count_iter} records.")

    # Filter by min interactions per user
    user_interaction_counts = current_df.groupBy(USER_ID_COL).agg(count("*").alias("u_count"))
    df_after_user_filter = current_df.join(user_interaction_counts, USER_ID_COL, "inner") \
                                     .filter(col("u_count") >= MIN_INTERACTIONS_PER_USER) \
                                     .select(*original_cols)  # drop the helper count column

    # Filter by min interactions per item (on the result of user filtering)
    item_interaction_counts = df_after_user_filter.groupBy(ITEM_ID_COL).agg(count("*").alias("i_count"))
    current_df = df_after_user_filter.join(item_interaction_counts, ITEM_ID_COL, "inner") \
                                       .filter(col("i_count") >= MIN_INTERACTIONS_PER_ITEM) \
                                       .select(*original_cols)  # drop the helper count column

    current_iter_count_after_filter = current_df.count()
    print(f"  K-core Iteration {i+1}: Ended with {current_iter_count_after_filter} records.")

    if current_iter_count_after_filter == initial_count_iter:
        print("  K-core filtering stabilized.")
        break
    # The next pass starts from exactly what this pass produced.
    initial_count_iter = current_iter_count_after_filter
else:
    # Reached only when the loop ran out of passes without stabilizing.
    print("  Warning: K-core filtering reached max iterations.")

filtered_df = current_df.cache()  # Cache the result of k-core filtering
final_filtered_count = filtered_df.count()  # first action; populates the cache
print(f"K-core filtering complete. Records after filtering: {final_filtered_count}")

if final_filtered_count == 0:
    # Deliberately non-fatal: report and continue, although later steps will
    # have no data to work with.
    print("ERROR: K-core filtering resulted in an empty DataFrame. ")
    print("Please check your MIN_INTERACTIONS thresholds or the density of your input data.")

# --- Step 2: Per-User Splitting ---
# Split each user's interactions proportionally into train / validation / test.
print("\nStarting per-user splitting...")

# Total number of interactions per user (post k-core filtering).
df_with_user_total_interactions = filtered_df.withColumn(
    "total_interactions_for_user",
    count("*").over(Window.partitionBy(USER_ID_COL))
)

# Random order within each user; the seed is fixed at 42.
window_spec_user_split = Window.partitionBy(USER_ID_COL).orderBy(rand(seed=42))

# Position of each interaction within its user's shuffled list (1-based).
df_with_rn = df_with_user_total_interactions.withColumn("rn_in_user", row_number().over(window_spec_user_split))

# Index boundaries for the split. With floor() a user with 3 interactions gets
# train_boundary_idx = floor(0.6 * 3) = 1 and validation_boundary_idx =
# floor(0.8 * 3) = 2, i.e. one row each for train, validation and test.
#
# The result is cached because all three splits are filtered from it. rand()
# is not guaranteed to reproduce the same ordering on recomputation, so the
# splits must read one materialized assignment to stay disjoint and complete.
df_with_boundaries = df_with_rn.withColumn(
    "train_boundary_idx",
    floor(TRAIN_RATIO * col("total_interactions_for_user"))
).withColumn(
    "validation_boundary_idx",
    floor((TRAIN_RATIO + VALIDATION_RATIO) * col("total_interactions_for_user"))
).cache()

# Train:      rn_in_user <= train_boundary_idx
# Validation: train_boundary_idx < rn_in_user <= validation_boundary_idx
# Test:       rn_in_user > validation_boundary_idx
train_df_intermediate = df_with_boundaries.filter(col("rn_in_user") <= col("train_boundary_idx"))
validation_df_intermediate = df_with_boundaries.filter(
    (col("rn_in_user") > col("train_boundary_idx")) & (col("rn_in_user") <= col("validation_boundary_idx"))
)
test_df_intermediate = df_with_boundaries.filter(col("rn_in_user") > col("validation_boundary_idx"))

# Keep only the original columns in the final DataFrames and cache them.
train_df = train_df_intermediate.select(*original_cols).cache()
validation_df = validation_df_intermediate.select(*original_cols).cache()
test_df = test_df_intermediate.select(*original_cols).cache()

# These counts are the first actions on the splits: the first one materializes
# df_with_boundaries, and all three populate their own caches.
print(f"  Train set interaction count (after per-user split): {train_df.count()}")
print(f"  Validation set interaction count (after per-user split): {validation_df.count()}")
print(f"  Test set interaction count (after per-user split): {test_df.count()}")

# Release the upstream caches only now that the splits are materialized.
# Unpersisting earlier would force each split to recompute the whole k-core
# lineage.
df_with_boundaries.unpersist()
filtered_df.unpersist()

# --- Step 3: Ensure Items in Validation/Test are in Train (Warm-Start Evaluation) ---
# Validation and test rows are kept only when their item also occurs in the
# training set.
print("\nEnsuring warm-start for validation and test sets (items must exist in train set)...")

train_items = train_df.select(ITEM_ID_COL).distinct().cache()
num_train_items = train_items.count()
print(f"  Number of unique items in the final training set: {num_train_items}")


def _keep_rows_with_train_items(split_df, split_label):
    """Restrict `split_df` to rows whose item appears in `train_items`.

    Counts the split, inner-joins it with the distinct training items, caches
    and counts the result, unpersists the unfiltered split, and reports how
    many rows were removed. `split_label` is the lowercase split name used in
    the printed messages.

    Returns a tuple (filtered DataFrame, count before, count after).
    """
    count_before = split_df.count()
    kept_df = split_df.join(train_items, ITEM_ID_COL, "inner").select(*original_cols).cache()
    count_after = kept_df.count()
    split_df.unpersist()  # the pre-filter version is no longer needed
    print(f"  {split_label.capitalize()} interactions: Original = {count_before}, After ensuring items in train = {count_after}")
    if count_before > count_after:
        print(f"  INFO: {count_before - count_after} interactions removed from {split_label} set because their item was not present in the training set.")
    return kept_df, count_before, count_after


# Validation first, then test -- same order as the reported output.
validation_df_final, original_val_count, new_val_count = _keep_rows_with_train_items(validation_df, "validation")
test_df_final, original_test_count, new_test_count = _keep_rows_with_train_items(test_df, "test")

# The distinct training items are not needed past this point.
train_items.unpersist()

# --- Step 4: Saving DataFrames to CSV ---
print("\nSaving data sets to CSV (output will be directories)...")

# train_df needs no warm-start filtering, so it is already the final version.
final_train_count = train_df.count()
final_validation_count = validation_df_final.count()
final_test_count = test_df_final.count()

print(f"Final counts for output: Train={final_train_count}, Validation={final_validation_count}, Test={final_test_count}")

# Write each non-empty split with a header row, overwriting any previous
# output at the same path; empty splits are skipped with a message.
for split_label, split_df, split_count, output_path in (
    ("training", train_df, final_train_count, TRAIN_CSV_PATH),
    ("validation", validation_df_final, final_validation_count, VALIDATION_CSV_PATH),
    ("test", test_df_final, final_test_count, TEST_CSV_PATH),
):
    if split_count > 0:
        split_df.write.mode("overwrite").option("header", "true").csv(output_path)
        print(f"  {split_label.capitalize()} data saved to: {output_path}")
    else:
        print(f"  Skipping saving {split_label} data as it is empty.")

print("\nData splitting and saving process complete.")

# --- Clean up cached DataFrames ---
for cached_split in (train_df, validation_df_final, test_df_final):
    cached_split.unpersist()

print("\nAll cached DataFrames unpersisted.")

Original dataset record count: 5027330

Starting K-core filtering...
  K-core Iteration 1: Starting with 5027330 records.
  K-core Iteration 1: Ended with 3665543 records.
  K-core Iteration 2: Starting with 3665543 records.
  K-core Iteration 2: Ended with 3665225 records.
  K-core Iteration 3: Starting with 3665225 records.
  K-core Iteration 3: Ended with 3665225 records.
  K-core filtering stabilized.
K-core filtering complete. Records after filtering: 3665225

Starting per-user splitting...
  Train set interaction count (after per-user split): 2008955
  Validation set interaction count (after per-user split): 785055
  Test set interaction count (after per-user split): 871215

Ensuring warm-start for validation and test sets (items must exist in train set)...
  Number of unique items in the final training set: 63487
  Validation interactions: Original = 785055, After ensuring items in train = 784397
  INFO: 658 interactions removed from validation set because their item was not pre