In [None]:
# install Java8 (Spark không tương thích tốt với các phiên bản Java khác)
!apt-get install openjdk-8-jdk-headless -qq > /dev/null

# download Spark (ví dụ với spark-3.5.1)
!wget -q https://archive.apache.org/dist/spark/spark-3.5.1/spark-3.5.1-bin-hadoop3.tgz

!tar xf spark-3.5.1-bin-hadoop3.tgz

# install findspark
!pip install -q findspark


# Set Environment Variables
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.5.1-bin-hadoop3"

# Quick Installation Test
import findspark
findspark.init()
from pyspark.sql import SparkSession
# Check the pyspark version
import pyspark
print(pyspark.__version__)


3.5.1


In [None]:
import math
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.mllib.linalg.distributed import RowMatrix
from pyspark.mllib.linalg import Vectors

# ======================
# Spark session
# ======================
spark = (
    SparkSession.builder
    .appName("SVD_Impute_Final")
    .getOrCreate()
)
sc = spark.sparkContext  # used below for broadcasts



In [None]:
# ======================
# Load index mappings, item means and the normalized splits
# ======================
MAPPINGS_DIR = "/content/drive/MyDrive/XuLyDuLieuLon/mappings_svd_All_beautyv2"
PREPROCESSED_DIR = "/content/drive/MyDrive/XuLyDuLieuLon/preprocessed_svd_All_beautyv2"

user_index = spark.read.parquet(f"{MAPPINGS_DIR}/user_index.parquet")
item_index = spark.read.parquet(f"{MAPPINGS_DIR}/item_index.parquet")
item_means = spark.read.parquet(f"{PREPROCESSED_DIR}/item_means.parquet")

train_df_norm = spark.read.parquet(f"{PREPROCESSED_DIR}/train_norm.parquet")
valid_df_norm = spark.read.parquet(f"{PREPROCESSED_DIR}/valid_norm.parquet")
test_df_norm = spark.read.parquet(f"{PREPROCESSED_DIR}/test_norm.parquet")

In [None]:
from pyspark.sql import functions as F

# ==========================================
# 1. Keep only the columns needed for train, valid and test
# ==========================================
def _select_rating_cols(df):
    """Return itemIndex, userIndex and rating/rating_norm/item_mean cast to double.

    Casting in one place keeps the three splits on an identical schema so
    they can be unioned by name below.
    """
    return df.select(
        "itemIndex",
        "userIndex",
        df["rating"].cast("double"),
        df["rating_norm"].cast("double"),
        df["item_mean"].cast("double"),
    )


train_df = _select_rating_cols(train_df_norm)
valid_df = _select_rating_cols(valid_df_norm)
test_df = _select_rating_cols(test_df_norm)

# ==========================================
# 2. Merge valid into train -> train_full
# ==========================================
train_full = train_df.unionByName(valid_df, allowMissingColumns=True)

# ==========================================
# 3. + 4. Union train_full and test, flagging test rows
# ==========================================
# The flag is attached to each split BEFORE the union. Joining a flag table
# back on (userIndex, itemIndex) multiplies rows whenever a pair occurs more
# than once. A pair present in both splits would also come back with both
# flags, so its test rating would be kept as a training value.
full = (
    train_full.withColumn("is_test", F.lit(False))
    .unionByName(test_df.withColumn("is_test", F.lit(True)), allowMissingColumns=True)
    .select("userIndex", "itemIndex", "rating", "rating_norm", "item_mean", "is_test")
)

# ==========================================
# 5. Check the result
# ==========================================
print("Train+Valid count:", train_full.count())
print("Test count:", test_df.count())
full.printSchema()
full.show(5)


Train+Valid count: 2027
Test count: 508
root
 |-- userIndex: long (nullable = true)
 |-- itemIndex: long (nullable = true)
 |-- rating: double (nullable = true)
 |-- rating_norm: double (nullable = true)
 |-- item_mean: double (nullable = true)
 |-- is_test: boolean (nullable = false)

+---------+---------+------+-------------------+-----------------+-------+
|userIndex|itemIndex|rating|        rating_norm|        item_mean|is_test|
+---------+---------+------+-------------------+-----------------+-------+
|        0|        0|   5.0|0.20000000000000018|              4.8|  false|
|        1|        1|   5.0|0.33333333333333304|4.666666666666667|  false|
|        2|        2|   4.0|               -0.5|              4.5|  false|
|        3|        3|   4.0|-0.5454545454545459|4.545454545454546|  false|
|        4|        4|   5.0|0.16666666666666696|4.833333333333333|  false|
+---------+---------+------+-------------------+-----------------+-------+
only showing top 5 rows



In [None]:
# ==========================================
# 4. Build rating_masked, the value the model is allowed to see:
#    test rows                  -> -99 (held out)
#    rows with a rating_norm    -> rating_norm
#    anything else              -> -10 (no usable rating)
# ==========================================
full = full.withColumn(
    "rating_masked",
    F.when(F.col("is_test"), F.lit(-99))
     .when(F.col("rating_norm").isNotNull(), F.col("rating_norm"))
     .otherwise(F.lit(-10)),
)

# Matrix dimensions: indices are 0-based, so size = max index + 1
_max_user, _max_item = full.agg(F.max("userIndex"), F.max("itemIndex")).first()
n_users = int(_max_user) + 1
n_items = int(_max_item) + 1

In [None]:
# ==========================================
# ✅ Kết quả: full
# ==========================================
# full gồm các cột:
# userIndex, itemIndex, rating, rating_norm, item_mean, is_test
full.printSchema()
full.show(50)

root
 |-- userIndex: long (nullable = true)
 |-- itemIndex: long (nullable = true)
 |-- rating: double (nullable = true)
 |-- rating_norm: double (nullable = true)
 |-- item_mean: double (nullable = true)
 |-- is_test: boolean (nullable = false)
 |-- rating_masked: double (nullable = true)

+---------+---------+------+--------------------+------------------+-------+--------------------+
|userIndex|itemIndex|rating|         rating_norm|         item_mean|is_test|       rating_masked|
+---------+---------+------+--------------------+------------------+-------+--------------------+
|        0|        0|   5.0| 0.20000000000000018|               4.8|  false| 0.20000000000000018|
|        1|        1|   5.0| 0.33333333333333304| 4.666666666666667|  false| 0.33333333333333304|
|        2|        2|   4.0|                -0.5|               4.5|  false|                -0.5|
|        3|        3|   4.0| -0.5454545454545459| 4.545454545454546|  false| -0.5454545454545459|
|        4|        4| 

In [None]:
# -------------------
# Group ratings by user: one row per user. Each row later becomes one sparse
# row of the user x item rating matrix.
# -------------------
from pyspark.sql import functions as F

# Sort by userIndex so the RDD row order follows userIndex. Matrix rows are
# later labelled positionally (zipWithIndex), which only matches userIndex if
# the rows arrive in userIndex order (and the ids are contiguous from 0).
grouped = (
    full.groupBy("userIndex")
    .agg(F.collect_list(F.struct("itemIndex", "rating_masked")).alias("ratings"))
    .orderBy("userIndex")
)

# Broadcast the number of item columns (n_items = max itemIndex + 1, computed
# together with rating_masked) instead of launching another Spark job.
n_items_b = sc.broadcast(n_items)

def map_partition_to_sparse(iterator):
    """Turn grouped user rows into sparse rating vectors, exactly one per row.

    Each input row carries ``ratings``: a list of (itemIndex, rating_masked)
    structs. Some entries are dropped:
      * rating_masked is a sentinel set when building ``full``
        (-99 = held-out test rating, -10 = no normalized rating);
      * itemIndex or rating_masked is null or not numeric.
    A user with nothing left yields an all-zero vector, so the output stays
    aligned 1:1 with the input rows.

    The vector size is read from the ``n_items_b`` broadcast.
    """
    from pyspark.mllib.linalg import Vectors

    n = n_items_b.value
    sentinels = (-99.0, -10.0)
    for row in iterator:
        entries = {}
        for x in (row.ratings or []):
            if x.itemIndex is None or x.rating_masked is None:
                continue
            try:
                idx = int(x.itemIndex)
                val = float(x.rating_masked)
            except (TypeError, ValueError):
                continue
            if val in sentinels:
                continue
            # Keyed by item: a repeated (user, item) pair keeps its last value.
            entries[idx] = val
        # Pass a dict so SparseVector sorts the indices itself. The
        # (indices, values) list form requires strictly increasing indices,
        # and collect_list does not guarantee any order.
        yield Vectors.sparse(n, entries)

# Sparse rating matrix as an RDD: one vector per grouped user row
rows_rdd = grouped.rdd.mapPartitions(map_partition_to_sparse, preservesPartitioning=False)


In [None]:
# -------------------
# Compute a truncated SVD of the masked rating matrix
# -------------------
from pyspark.mllib.linalg.distributed import IndexedRow, IndexedRowMatrix


def _to_indexed_rows(rows):
    """Yield IndexedRow(userIndex, sparse rating vector) for one partition.

    map_partition_to_sparse emits exactly one vector per input row, so the
    materialised rows and their vectors can be zipped positionally.
    """
    rows = list(rows)
    for row, vec in zip(rows, map_partition_to_sparse(iter(rows))):
        yield IndexedRow(int(row.userIndex), vec)


# Carry userIndex with every row. A plain RowMatrix + zipWithIndex labels U's
# rows by position, which equals userIndex only if the grouped rows happen to
# arrive sorted and contiguous. The input is cached because computeSVD scans
# it more than once, and every scan must see the same rows.
indexed_rows = grouped.rdd.mapPartitions(_to_indexed_rows).cache()
mat = IndexedRowMatrix(indexed_rows)

k_init = 10
k_init = min(k_init, n_items, n_users)
print("computeSVD with k_init =", k_init)
svd = mat.computeSVD(k_init, computeU=True)
U_mat = svd.U  # IndexedRowMatrix: each row's index is its userIndex
k_used = k_init  # cannot keep more components than were computed

# -------------------
# Build user features DataFrame (first k_used columns of U)
# -------------------
def zip_row_with_index(pair):
    """Convert a (vector, index) pair into a (userIndex, list-of-floats) tuple."""
    row, idx = pair
    return (int(idx), row.toArray().tolist())


user_feat_rdd = U_mat.rows.map(lambda r: zip_row_with_index((r.vector, r.index)))
user_features_df = spark.createDataFrame(user_feat_rdd, ["userIndex", "features"])


def slice_features(feat):
    """Keep the first k_used SVD components of a feature list."""
    return feat[:k_used]


slice_udf = F.udf(slice_features, ArrayType(DoubleType()))
user_features_df = user_features_df.withColumn("feat_k", slice_udf(F.col("features"))).select(F.col("userIndex"), F.col("feat_k"))

computeSVD with k_init = 10


In [None]:
# A bare `.schema` expression that is not the last line of a cell is evaluated
# and discarded, so print the schema explicitly.
user_features_df.printSchema()
user_features_df.show()

+---------+--------------------+
|userIndex|              feat_k|
+---------+--------------------+
|        0|[0.0, 0.0, 0.0, 0...|
|        1|[0.0, 0.0, 0.0, 0...|
|        2|[0.0, 0.0, 0.0, 0...|
|        3|[0.0, 0.0, 0.0, 0...|
|        4|[0.0, 0.0, 0.0, 0...|
|        5|[0.0, 0.0, 0.0, 0...|
|        6|[0.0, 0.0, 0.0, 0...|
|        7|[0.0, 0.0, 0.0, 0...|
|        8|[0.0, 0.0, 0.0, 0...|
|        9|[0.0, 0.0, 0.0, 0...|
|       10|[0.0, 0.0, 0.0, 0...|
|       11|[0.0, 0.0, 0.0, 0...|
|       12|[0.0, 0.0, 0.0, 0...|
|       13|[0.0, 0.0, 0.0, 0...|
|       14|[0.0, 0.0, 0.0, 0...|
|       15|[0.0, 0.0, 0.0, 0...|
|       16|[0.0, 0.0, 0.0, 0...|
|       17|[0.0, 0.0, 0.0, 0...|
|       18|[0.0, 0.0, 0.0, 0...|
|       19|[0.0, 0.0, 0.0, 0...|
+---------+--------------------+
only showing top 20 rows



In [None]:
# -------------------
# Cosine similarity UDF
# -------------------
import math

def cosine_sim_list(v1, v2):
    """Cosine similarity of two numeric sequences.

    Returns 0.0 when either argument is None, or when either vector has zero
    norm (cosine is undefined there). If the lengths differ, the extra trailing
    elements of the longer sequence are ignored (``zip`` semantics).

    This runs inside a Spark UDF once per user pair, so it must not print.
    A print here would emit one log line per pair on the executors.
    """
    if v1 is None or v2 is None:
        return 0.0
    dot = norm1 = norm2 = 0.0
    for a, b in zip(v1, v2):
        dot += a * b
        norm1 += a * a
        norm2 += b * b
    if norm1 == 0.0 or norm2 == 0.0:
        return 0.0
    return float(dot / (math.sqrt(norm1) * math.sqrt(norm2)))

cosine_udf = F.udf(cosine_sim_list, returnType=DoubleType())  # Spark wrapper returning double

In [None]:
# -------------------
# Top-10 most similar users for every user (user-based CF)
# -------------------
N_NEIGHBORS = 10

uf = user_features_df.select(F.col("userIndex").alias("a_idx"), F.col("feat_k").alias("a_feat"))
vf = user_features_df.select(F.col("userIndex").alias("b_idx"), F.col("feat_k").alias("b_feat"))

# All ordered pairs of distinct users (O(n_users^2) rows), scored by cosine
# similarity of their SVD features
pairs = (
    uf.crossJoin(vf)
    .filter(F.col("a_idx") != F.col("b_idx"))
    .withColumn("sim", cosine_udf(F.col("a_feat"), F.col("b_feat")))
)

# Rank candidates by similarity. Ties (e.g. many exact 0.0 similarities) are
# broken by neighbour id, so the selected neighbours are reproducible across
# runs instead of depending on partition order.
w = Window.partitionBy("a_idx").orderBy(F.desc("sim"), F.asc("b_idx"))
top10 = (
    pairs.withColumn("rn", F.row_number().over(w))
    .filter(F.col("rn") <= N_NEIGHBORS)
    .select(
        F.col("a_idx").alias("userIndex"),
        F.col("b_idx").alias("neighborIndex"),
        F.col("sim"),
    )
    .persist()  # reused by several downstream joins
)


In [None]:
top10.orderBy(F.col("sim").desc()).show(n=10)

+---------+-------------+--------------------+
|userIndex|neighborIndex|                 sim|
+---------+-------------+--------------------+
|      145|           90|6.472790933722232...|
|       90|          145|6.472790933722232...|
|      167|           90|2.240247496700285...|
|       90|          167|2.240247496700285...|
|      235|           90|2.151306929674381...|
|       90|          235|2.151306929674381...|
|      199|           90|1.726354329082323...|
|       90|          199|1.726354329082323...|
|      145|          199|9.068686549769544...|
|      199|          145|9.068686549769544...|
+---------+-------------+--------------------+
only showing top 10 rows



In [None]:
# -------------------
# User-item pairs to predict: exactly the pairs in test_df
# -------------------
test_pairs = test_df.select(
    "userIndex",
    "itemIndex",
    "item_mean",
    F.col("rating").alias("true_rating"),
)

# Masked rating of every (user, item) row in full; rows that came from
# test_df carry the -99 sentinel here
ratings_masked = full.select(
    "userIndex",
    "itemIndex",
    F.col("rating_masked").alias("rating"),
)

# Test pairs, annotated with their masked rating
missing_df = test_pairs.join(ratings_masked, on=["userIndex", "itemIndex"], how="left")


In [None]:
missing_df.show(n=5)

+---------+---------+-----------------+-----------+------+
|userIndex|itemIndex|        item_mean|true_rating|rating|
+---------+---------+-----------------+-----------+------+
|       24|       92|              4.2|        2.0| -99.0|
|       32|      267|              5.0|        5.0| -99.0|
|       34|       67|4.333333333333333|        5.0| -99.0|
|       53|      107|              4.0|        5.0| -99.0|
|       56|      272|              5.0|        4.0| -99.0|
+---------+---------+-----------------+-----------+------+
only showing top 5 rows



In [None]:
# ======================================================
#  IMPUTATION + FILL for test pairs
# ======================================================

# -------------------
# Attach each test pair's top-10 neighbours (one output row per neighbour;
# users without neighbours keep a single row with null neighbour columns)
# -------------------
miss_nei = missing_df.join(top10, "userIndex", "left")

In [None]:
miss_nei.orderBy(F.col("sim").desc()).show(n=100)

+---------+---------+-----------------+-----------+------+-------------+--------------------+
|userIndex|itemIndex|        item_mean|true_rating|rating|neighborIndex|                 sim|
+---------+---------+-----------------+-----------+------+-------------+--------------------+
|      145|      225|             4.25|        3.0| -99.0|           90|6.472790933722232...|
|      145|      293|              3.5|        5.0| -99.0|           90|6.472790933722232...|
|       90|       89|            3.875|        5.0| -99.0|          145|6.472790933722232...|
|      145|      111|              5.0|        3.0| -99.0|           90|6.472790933722232...|
|       90|      187|              4.5|        4.0| -99.0|          145|6.472790933722232...|
|       90|      194|              4.4|        5.0| -99.0|          145|6.472790933722232...|
|      167|      303|              5.0|        5.0| -99.0|           90|2.240247496700285...|
|       90|       89|            3.875|        5.0| -99.0|  

In [None]:
# Neighbour ratings: every (user, item, rating_masked) row of full, renamed so
# it can be matched against the neighbour side of miss_nei
neighbor_ratings = full.selectExpr(
    "userIndex AS neighborIndex",
    "itemIndex AS neighborItemIndex",
    "rating_masked AS neighbor_rating",
)

In [None]:
neighbor_ratings.show(n=10)

+-------------+-----------------+-------------------+
|neighborIndex|neighborItemIndex|    neighbor_rating|
+-------------+-----------------+-------------------+
|            0|                0|0.20000000000000018|
|            1|                1|0.33333333333333304|
|            2|                2|               -0.5|
|            3|                3|-0.5454545454545459|
|            4|                4|0.16666666666666696|
|            5|                5| 1.1428571428571428|
|            6|                6|               -1.5|
|            7|                7|0.40000000000000036|
|            8|                8|0.40000000000000036|
|            9|                9|               0.75|
+-------------+-----------------+-------------------+
only showing top 10 rows



In [None]:
# Join each (test pair, neighbour) with that neighbour's rating of the same
# item. The expression join keeps both sides' neighborIndex columns, so the
# right-hand copy is dropped to avoid an ambiguous duplicate column.
miss_nei = (
    miss_nei.join(
        neighbor_ratings,
        on=[miss_nei.neighborIndex == neighbor_ratings.neighborIndex,
            miss_nei.itemIndex == neighbor_ratings.neighborItemIndex],
        how="left",
    )
    .drop(neighbor_ratings.neighborIndex)
)

In [None]:
miss_nei.orderBy("userIndex").show(n=100)

+---------+---------+-----------------+-----------+------+-------------+---+-------------+-----------------+--------------------+
|userIndex|itemIndex|        item_mean|true_rating|rating|neighborIndex|sim|neighborIndex|neighborItemIndex|     neighbor_rating|
+---------+---------+-----------------+-----------+------+-------------+---+-------------+-----------------+--------------------+
|        0|      229|              4.6|        5.0| -99.0|            4|0.0|         NULL|             NULL|                NULL|
|        0|      330|              5.0|        5.0| -99.0|            5|0.0|         NULL|             NULL|                NULL|
|        0|       55|              4.0|        5.0| -99.0|            8|0.0|         NULL|             NULL|                NULL|
|        0|       55|              4.0|        5.0| -99.0|            5|0.0|         NULL|             NULL|                NULL|
|        0|      229|              4.6|        5.0| -99.0|            8|0.0|         NULL|

In [None]:
# Keep only neighbours that actually rated the item: drop unmatched joins and
# the rating_masked sentinels (-99 = test row, -10 = no normalized rating)
miss_nei = miss_nei.filter(
    F.col("neighbor_rating").isNotNull()
    & ~F.col("neighbor_rating").isin(-99, -10)
)


In [None]:
# Optional similarity cut-off for neighbours. None keeps every neighbour,
# which is what the averaging below has always used. Set e.g. 0.5 to average
# only close neighbours.
# NOTE(review): the similarities shown above are mostly tiny or 0.0, so a 0.5
# cut-off may leave most test pairs with no neighbour at all.
MIN_SIM = None

miss_nei_filtered = (
    miss_nei if MIN_SIM is None else miss_nei.filter(F.col("sim") > MIN_SIM)
)

# Imputed normalized rating = unweighted mean of the neighbours' ratings
imputed = (
    miss_nei_filtered.groupBy("userIndex", "itemIndex")
    .agg(F.avg("neighbor_rating").alias("imputed_rating"))
)
imputed.orderBy(F.asc("userIndex")).show(20)

# -------------------
# Build filled ratings for the test pairs only
# -------------------
# A pair with no usable neighbour has no row in `imputed`. It falls back to a
# normalized rating of 0.0, so its prediction is just item_mean.
filled_test = (
    missing_df.join(imputed, on=["userIndex", "itemIndex"], how="left")
    .withColumn("filled_rating", F.coalesce(F.col("imputed_rating"), F.lit(0.0)))
)
filled_test.orderBy(F.asc("userIndex")).show(20)

# Add item_mean back to get a rating on the original scale (the rows shown
# earlier are consistent with rating_norm = rating - item_mean).
filled_test = filled_test.withColumn(
    "pred_rating",
    F.col("filled_rating") + F.col("item_mean")
)
# Keep only the columns used by the evaluation cells
filled_test = filled_test.select("userIndex", "itemIndex", "pred_rating", "true_rating")


+---------+---------+--------------------+
|userIndex|itemIndex|      imputed_rating|
+---------+---------+--------------------+
|        1|      214|                0.75|
|        1|      106|-0.16666666666666696|
|        2|       30|                 0.0|
|        7|      142|               0.875|
|        7|      173|               0.125|
|       11|        3| -0.5454545454545459|
|       16|      196|                 0.5|
|       16|       42|               0.375|
|       16|      355|                 0.0|
|       18|       86|  0.3636363636363633|
|       20|      355|                 0.0|
|       20|       30|                 0.0|
|       24|       92|  0.7999999999999998|
|       24|      139|                 1.0|
|       24|      264|                 0.0|
|       26|      107|                 1.0|
|       27|      159|                0.25|
|       31|      168|               -0.25|
|       32|      188|                 0.0|
|       32|      173|               0.125|
+---------+

In [None]:
filled_test.show(100)
from pyspark.sql import functions as F
from pyspark.sql import types as T

# Prediction error per test pair, shared by both metrics
err_expr = F.col("pred_rating") - F.col("true_rating")

# -------------------
# RMSE = sqrt(mean(error^2))
# -------------------
rmse_df = filled_test.withColumn("squared_error", F.pow(err_expr, 2))
rmse_val = rmse_df.agg(F.sqrt(F.avg("squared_error")).alias("RMSE")).first()["RMSE"]
print("Test RMSE:", rmse_val)

# -------------------
# MAE = mean(|error|)
# -------------------
mae_df = filled_test.withColumn("abs_error", F.abs(err_expr))
mae_val = mae_df.agg(F.avg("abs_error").alias("MAE")).first()["MAE"]
print("Test MAE:", mae_val)

+---------+---------+-----------------+-----------+
|userIndex|itemIndex|      pred_rating|true_rating|
+---------+---------+-----------------+-----------+
|        0|       55|              4.0|        5.0|
|        0|       67|4.333333333333333|        5.0|
|        0|      330|              5.0|        5.0|
|        1|      214|              5.0|        1.0|
|        2|      146|              4.5|        4.0|
|        4|      190|4.166666666666667|        5.0|
|        5|      185|              4.5|        5.0|
|       11|        3|              4.0|        5.0|
|       16|       10|4.666666666666667|        4.0|
|       16|      204|              4.0|        5.0|
|       16|      274|4.166666666666667|        5.0|
|       16|      355|              5.0|        4.0|
|       18|       86|              5.0|        5.0|
|       18|      282|4.333333333333333|        3.0|
|       20|      289|4.666666666666667|        5.0|
|       23|      287|              4.2|        3.0|
|       24| 

In [None]:
from pyspark.sql import Window
from pyspark.sql import functions as F

# Per-user ranks of the test items (1 = best). row_number over tied ratings
# is otherwise arbitrary and can change between runs, which makes DCG
# unstable. Ties are therefore broken by itemIndex.

# true_rank: position by true rating, descending
w_true = Window.partitionBy("userIndex").orderBy(F.desc("true_rating"), F.asc("itemIndex"))
filled_test = filled_test.withColumn("true_rank", F.row_number().over(w_true))

# pred_rank: position by predicted rating, descending
w_pred = Window.partitionBy("userIndex").orderBy(F.desc("pred_rating"), F.asc("itemIndex"))
filled_test = filled_test.withColumn("pred_rank", F.row_number().over(w_pred))

# Inspect the result
filled_test.orderBy(F.asc("userIndex")).show(100)


+---------+---------+------------------+-----------+---------+---------+
|userIndex|itemIndex|       pred_rating|true_rating|true_rank|pred_rank|
+---------+---------+------------------+-----------+---------+---------+
|        0|      330|               5.0|        5.0|        4|        1|
|        0|      229|               4.6|        5.0|        3|        2|
|        0|       67| 4.333333333333333|        5.0|        2|        3|
|        0|       55|               4.0|        5.0|        1|        4|
|        1|      315|               5.0|        5.0|        2|        1|
|        1|      214|               5.0|        1.0|        3|        2|
|        1|      106|               4.0|        5.0|        1|        3|
|        2|       30|               5.0|        5.0|        1|        1|
|        2|      146|               4.5|        4.0|        2|        2|
|        4|       67| 4.333333333333333|        5.0|        1|        1|
|        4|      190| 4.166666666666667|        5.0

In [None]:
from pyspark.sql import functions as F

K = 10  # top-K cut-off
precision_col = f"precision_at_{K}"

# Top-K test items per user by predicted rank; "relevant" = true rating >= 4
topk_pred = (
    filled_test
    .where(F.col("pred_rank") <= K)
    .withColumn("relevant", F.when(F.col("true_rating") >= 4, 1).otherwise(0))
)

# Precision@K per user = relevant hits / K. The denominator is always K, even
# for users with fewer than K test items (the convention of Spark's
# RankingMetrics.precisionAt), so such users cannot reach 1.0.
precision_per_user = topk_pred.groupBy("userIndex").agg(
    (F.sum("relevant") / F.lit(K)).alias(precision_col)
)

# Mean Precision@K over all users with at least one test item
precision_at_k = precision_per_user.agg(F.mean(precision_col).alias(precision_col)).first()[0]

print(f"Precision@{K} trung bình: {precision_at_k:.4f}")


Precision@10 trung bình: 0.2064


In [None]:
from pyspark.sql import Window
from pyspark.sql import functions as F
import math

K = 10


def _gain_at(rank_col):
    """Discounted gain (2^true_rating - 1) / log2(rank + 1) as a Column."""
    return (2 ** F.col("true_rating") - 1) / F.log2(F.col(rank_col) + 1)


# 1. DCG@K: gains discounted by the predicted rank
filled_test = filled_test.withColumn("dcg_contrib", _gain_at("pred_rank"))

w = Window.partitionBy("userIndex").orderBy("pred_rank")
topk_pred = filled_test.withColumn("rn", F.row_number().over(w)).filter(F.col("rn") <= K)
dcg_user = topk_pred.groupBy("userIndex").agg(F.sum("dcg_contrib").alias("dcg"))

# 2. IDCG@K: the same gains in the ideal order (true_rating descending)
w_true = Window.partitionBy("userIndex").orderBy(F.desc("true_rating"))
topk_true = (
    filled_test
    .withColumn("rn_true", F.row_number().over(w_true))
    .filter(F.col("rn_true") <= K)
    .withColumn("idcg_contrib", _gain_at("rn_true"))
)
idcg_user = topk_true.groupBy("userIndex").agg(F.sum("idcg_contrib").alias("idcg"))

# 3. NDCG@K per user (defined as 0 when the ideal DCG is 0)
ndcg_user = (
    dcg_user.join(idcg_user, on="userIndex")
    .withColumn(
        "ndcg",
        F.when(F.col("idcg") == 0, 0.0).otherwise(F.col("dcg") / F.col("idcg")),
    )
)

# 4. Mean NDCG@K over users
ndcg_at_k = ndcg_user.agg(F.mean("ndcg").alias(f"ndcg_at_{K}")).first()[0]

print(f"NDCG@{K} trung bình: {ndcg_at_k:.4f}")


NDCG@10 trung bình: 0.9651
