In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, lag, avg, dayofmonth, month, year, date_format,dayofweek,
    monotonically_increasing_id
)
from pyspark.sql.window import Window
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

In [2]:
from pyspark.sql import SparkSession
# Tạo Spark session + bật Hive
spark = SparkSession.builder \
    .appName("CryptoTrainAllCoins_Optimized") \
    .config("spark.sql.catalogImplementation", "hive") \
    .config("spark.hadoop.fs.defaultFS", "hdfs://namenode:9000") \
    .config("spark.sql.warehouse.dir", "hdfs://namenode:9000/user/hive/warehouse") \
    .config("hive.metastore.uris", "thrift://hive-metastore:9083") \
    .config("spark.sql.shuffle.partitions", "8")\
    .enableHiveSupport() \
    .getOrCreate()

In [3]:
# 2. Đọc dữ liệu thô từ Hive
df_raw = spark.sql("SELECT Date, Open, High, Low, Close, Volume, coin FROM crypto_db.crypto_prices")

In [4]:
# 3. Tạo dimension tables (dim_coin, dim_date)
dim_coin = df_raw.select("coin").distinct() \
    .withColumn("coin_id", monotonically_increasing_id())
dim_coin.write.mode("overwrite").saveAsTable("crypto_db.dim_coin")

dim_date = df_raw.select("Date").distinct() \
    .withColumn("date_id", date_format("Date", "yyyyMMdd").cast("int")) \
    .withColumn("day", dayofmonth("Date")) \
    .withColumn("month", month("Date")) \
    .withColumn("year", year("Date")) \
    .withColumn("weekday", date_format("Date", "E")) \
    .withColumn("is_weekend", ((dayofweek("Date") == 1) | (dayofweek("Date") == 7)).cast("boolean"))

In [5]:
# Ghi vào Hive table
dim_date.write \
    .format("parquet") \
    .mode("overwrite") \
    .saveAsTable("crypto_db.dim_date")

In [6]:
# 4. Tính feature & label cho toàn bộ coin
window = Window.partitionBy("coin").orderBy("Date")
df_feat = df_raw.orderBy("coin", "Date")
df_feat = df_feat.withColumn("prev_close", lag("Close", 1).over(window)) \
                 .withColumn("label", (col("Close") > col("prev_close")).cast("int")) \
                 .withColumn("pct_change", ((col("Close") - col("prev_close")) / col("prev_close")) * 100) \
                 .withColumn("ma7", avg("Close").over(window.rowsBetween(-6, 0))) \
                 .withColumn("ma30", avg("Close").over(window.rowsBetween(-29, 0))) \
                 .dropna()

In [7]:
# 5. Join dim_coin để lấy coin_id, tạo date_id
df_joined = df_feat.join(dim_coin, on="coin", how="left") \
                   .join(dim_date.select("Date", "date_id"), on="Date", how="left")



In [8]:
# 6. Tạo fact table & lưu
fact_df = df_joined.select(
    "date_id", "coin_id", "Open", "High", "Low", "Close", "Volume",
    "pct_change", "ma7", "ma30", "label"
)
fact_df.write.partitionBy("coin_id").mode("overwrite").saveAsTable("crypto_db.fact_crypto_price")
fact_df.show(20)  # Hiển thị 20 dòng đầu tiên


+--------+-------+-------------------+-------------------+-------------------+-------------------+--------+-------------------+-------------------+-------------------+-----+
| date_id|coin_id|               Open|               High|                Low|              Close|  Volume|         pct_change|                ma7|               ma30|label|
+--------+-------+-------------------+-------------------+-------------------+-------------------+--------+-------------------+-------------------+-------------------+-----+
|20211121|     65| 0.2990280091762543| 0.3323259949684143|0.27083298563957214|  0.293969988822937|666013.0|-1.6789933455869162| 0.2964800000190735| 0.2964800000190735|    0|
|20211122|     65|0.29424700140953064| 0.3332839906215668|0.25044599175453186| 0.2847540080547333|620981.0|-3.1350073540175805| 0.2925713360309601| 0.2925713360309601|    0|
|20211123|     65|0.28477799892425537|0.28861498832702637|0.20972900092601776| 0.2461860030889511|642836.0|-13.544323828575896|0.2

In [9]:
# 7. ML Pipeline
assembler = VectorAssembler(
    inputCols=["Open", "High", "Low", "Volume", "pct_change", "ma7", "ma30"],
    outputCol="features"
)
rf = RandomForestClassifier(featuresCol="features", labelCol="label", seed=42)
pipeline = Pipeline(stages=[assembler, rf])

In [10]:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

coin_ids = [2, 1, 21, 9, 11, 12, 22, 6, 17, 23] # top 10 coins
results = []

for cid in coin_ids:
    print(f"\n========== ĐANG TRAIN CHO COIN ID {cid} ==========")
    
    coin_data = fact_df.filter(col("coin_id") == cid).orderBy("date_id")
    total = coin_data.count()
    if total < 100:
        print(f"[WARN] Coin ID {cid} có {total} record, bỏ qua.")
        continue
    
    # Tách train/test theo thứ tự thời gian
    window_spec = Window.orderBy("date_id")
    coin_data = coin_data.withColumn("row_num", row_number().over(window_spec))
    
    train_cnt = int(total * 0.8)
    train_df = coin_data.filter(col("row_num") <= train_cnt)
    test_df = coin_data.filter(col("row_num") > train_cnt)
    
    if train_df.count() == 0 or test_df.count() == 0:
        print(f"[ERROR] Train/Test rỗng cho Coin ID {cid}, bỏ qua.")
        continue
    
    # Kiểm tra phân phối nhãn
    label_dist = train_df.groupBy("label").count().collect()
    if len(label_dist) < 2:
        print(f"[SKIP] Coin ID {cid} chỉ có 1 nhãn duy nhất trong train set. Bỏ qua.")
        continue

    # Chuyển đổi dữ liệu thành vector features
    assembler = VectorAssembler(
        inputCols=["Open", "High", "Low", "Volume", "pct_change", "ma7", "ma30"],
        outputCol="features"
    )
    train_df = assembler.transform(train_df)

    test_df = assembler.transform(test_df)
    
    # Huấn luyện mô hình với tập tham số cố định
    rf = RandomForestClassifier(featuresCol="features", labelCol="label", numTrees=100, maxDepth=10, seed=42)
    model = rf.fit(train_df)
    predictions = model.transform(test_df)
    
    evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="f1")
    f1_score = evaluator.evaluate(predictions)
    
    print(f"[Coin ID {cid}] Model => F1 = {f1_score:.4f}")
    results.append((cid, f1_score))
    
    predictions.select("date_id", "coin_id", "label", "prediction") \
        .write.mode("overwrite") \
        .saveAsTable(f"crypto_db.predictions_coin_{cid}")
    
    
print("\n=== TỔNG KẾT KẾT QUẢ F1 ===")
for cid, score in results:
    print(f"Coin ID {cid} => F1 Score: {score:.4f}")



[Coin ID 2] Model => F1 = 0.9722

[Coin ID 1] Model => F1 = 0.9981

[Coin ID 21] Model => F1 = 0.9981

[Coin ID 9] Model => F1 = 1.0000

[Coin ID 11] Model => F1 = 0.9888

[Coin ID 12] Model => F1 = 0.9870

[Coin ID 22] Model => F1 = 1.0000

[Coin ID 6] Model => F1 = 0.9814

[Coin ID 17] Model => F1 = 0.9758

[Coin ID 23] Model => F1 = 0.9963

=== TỔNG KẾT KẾT QUẢ F1 ===
Coin ID 2 => F1 Score: 0.9722
Coin ID 1 => F1 Score: 0.9981
Coin ID 21 => F1 Score: 0.9981
Coin ID 9 => F1 Score: 1.0000
Coin ID 11 => F1 Score: 0.9888
Coin ID 12 => F1 Score: 0.9870
Coin ID 22 => F1 Score: 1.0000
Coin ID 6 => F1 Score: 0.9814
Coin ID 17 => F1 Score: 0.9758
Coin ID 23 => F1 Score: 0.9963


In [11]:
from pyspark.sql.functions import col, count, isnan

print("===> KIỂM TRA NULL TRONG fact_df")
fact_df.select(
    count(col("date_id").isNull().cast("int")).alias("null_date_id"),
    count(col("coin_id").isNull().cast("int")).alias("null_coin_id"),
    count(col("label").isNull().cast("int")).alias("null_label")
).show()

print("===> MẪU 10 DÒNG fact_df")
fact_df.select("date_id", "coin_id", "label").show(10)

===> KIỂM TRA NULL TRONG fact_df
+------------+------------+----------+
|null_date_id|null_coin_id|null_label|
+------------+------------+----------+
|      171607|      171607|    171607|
+------------+------------+----------+

===> MẪU 10 DÒNG fact_df
+--------+-------+-----+
| date_id|coin_id|label|
+--------+-------+-----+
|20211121|     65|    0|
|20211122|     65|    0|
|20211123|     65|    0|
|20211124|     65|    0|
|20211125|     65|    0|
|20211126|     65|    1|
|20211127|     65|    0|
|20211128|     65|    1|
|20211129|     65|    0|
|20211130|     65|    0|
+--------+-------+-----+
only showing top 10 rows



In [12]:
spark.sql("SHOW TABLES IN crypto_db").show(truncate=False)

+---------+-------------------+-----------+
|namespace|tableName          |isTemporary|
+---------+-------------------+-----------+
|crypto_db|crypto_prices      |false      |
|crypto_db|dim_coin           |false      |
|crypto_db|dim_date           |false      |
|crypto_db|fact_crypto_price  |false      |
|crypto_db|predictions_coin_1 |false      |
|crypto_db|predictions_coin_11|false      |
|crypto_db|predictions_coin_12|false      |
|crypto_db|predictions_coin_17|false      |
|crypto_db|predictions_coin_2 |false      |
|crypto_db|predictions_coin_21|false      |
|crypto_db|predictions_coin_22|false      |
|crypto_db|predictions_coin_23|false      |
|crypto_db|predictions_coin_6 |false      |
|crypto_db|predictions_coin_9 |false      |
+---------+-------------------+-----------+



In [13]:
spark.sql("""
    SELECT * 
    FROM crypto_db.fact_crypto_price
""").show()

+--------+------------------+------------------+------------------+------------------+-----------+--------------------+------------------+------------------+-----+-------+
| date_id|              Open|              High|               Low|             Close|     Volume|          pct_change|               ma7|              ma30|label|coin_id|
+--------+------------------+------------------+------------------+------------------+-----------+--------------------+------------------+------------------+-----+-------+
|20140918| 456.8599853515625| 456.8599853515625|   413.10400390625|424.44000244140625|  3.44832E7|  -7.192557601231182| 440.8870086669922| 440.8870086669922|    0|      3|
|20140919| 424.1029968261719| 427.8349914550781| 384.5320129394531| 394.7959899902344|  3.79197E7|   -6.98426450868382|425.52333577473956|425.52333577473956|    0|      3|
|20140920| 394.6730041503906| 423.2959899902344|389.88299560546875|408.90399169921875|  3.68636E7|   3.573491642945347| 421.3684997558594| 4