
[BUG] mmlspark==0.18.1 BarrierJobUnsupportedRDDChainException #2393

Open
@wyfSunflower

Description


SynapseML version

0.18.1

System information

  • Language version (e.g. python 3.7.9., scala 2.12):
  • Spark Version (e.g. 2.4.5):
  • Spark Platform (e.g. dataworks maxcompute):

Describe the problem

An error occurred while calling o2446.fit.
: org.apache.spark.scheduler.BarrierJobUnsupportedRDDChainException: [SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following pattern of RDD chain within a barrier stage:
Ancestor RDDs that have different number of partitions from the resulting RDD (eg. union()/coalesce()/first()/take()/PartitionPruningRDD). A workaround for first()/take() can be barrierRdd.collect().head (scala) or barrierRdd.collect()[0] (python).
An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2)).
at org.apache.spark.scheduler.DAGScheduler.checkBarrierStageWithRDDChainPattern(DAGScheduler.scala:372)
at org.apache.spark.scheduler.DAGScheduler.createResultStage(DAGScheduler.scala:448)
at org.apache.spark.scheduler.DAGScheduler.handleJobSubmitted(DAGScheduler.scala:963)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2073)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2065)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2054)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala)
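For reference, the scheduler rejects any barrier stage whose ancestor RDDs have a different partition count than the resulting RDD. The pattern can be reproduced without LightGBM at all; a minimal sketch follows (an illustration under assumed local settings, not taken from the report):

# Minimal sketch (assumption: a plain local Spark session, independent of mmlspark).
# A narrow-dependency coalesce() changes the partition count inside the barrier stage,
# which is the chain pattern checkBarrierStageWithRDDChainPattern rejects.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").getOrCreate()
rdd = spark.sparkContext.parallelize(range(8), 4).coalesce(2)  # ancestor: 4 partitions, result: 2
barrier_rdd = rdd.barrier().mapPartitions(lambda it: it)
barrier_rdd.collect()  # raises BarrierJobUnsupportedRDDChainException, as in the trace above

In the reproduction below, the barrier stage is created internally by LightGBM's fit() when barrier execution mode is enabled, so any partition-count-changing operation left in the input DataFrame's lineage can surface as the same error.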

Code to reproduce issue

# -*- coding: utf-8 -*-
import sys,os,base64,pickle,time,gc
print(f'CDC:{os.getcwd()}')
sys.path.append(os.getcwd() + '/mmlspark')
from mmlspark.lightgbm import LightGBMClassifier
from mmlspark.train import ComputeModelStatistics
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, when, count, isnan, udf, broadcast, lit, array, avg, expr
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, StringType, IntegerType, FloatType, DecimalType, LongType, BooleanType
import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from datetime import datetime
from pyspark.ml.feature import VectorAssembler
# 1. Create an optimized Spark session
spark = SparkSession.builder \
    .appName("Distributed_LightGBM_Pipeline") \
    .config("spark.driver.memory", "128g") \
    .config("spark.driver.cores", "32") \
    .config("spark.driver.maxResultSize", "32g") \
    .config("spark.shuffle.service.enabled", "false") \
    .config("spark.dynamicAllocation.enabled", "false") \
    .config("spark.dynamicAllocation.minExecutors", "10") \
    .config("spark.dynamicAllocation.maxExecutors", "200") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.kryoserializer.buffer.max", "512m") \
    .config("spark.sql.shuffle.partitions", "800") \
    .config("spark.default.parallelism", "800") \
    .config("spark.executor.memoryOverhead", "4096") \
    .config("spark.memory.fraction", "0.7") \
    .config("spark.network.timeout", "800s") \
    .config("spark.rpc.askTimeout", "600s") \
    .config("spark.sql.autoBroadcastJoinThreshold", "-1") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.inMemoryColumnarStorage.compressed", "true") \
    .config("spark.sql.files.maxPartitionBytes", "256") \
    .enableHiveSupport() \
    .getOrCreate()
# Key config: lower the log level so print output is visible
spark.sparkContext.setLogLevel("INFO")
print("Spark configuration:")
print(f"Number of executors: {spark.conf.get('spark.executor.instances')}")
print(f"Memory per executor: {spark.conf.get('spark.executor.memory')}")
print(f"Cores per executor: {spark.conf.get('spark.executor.cores')}")
task_list_cn = ['复购概率预测', '降频预警预测', '沉默流失召回可能性预测', '沉默预警预测']
task_list = ['repurchase', 'decrease_alert', 'silent_recall', 'silent_alert']
database_name = 'chagee_cdp_dev'
execute = True
interval_list = [7, 14]
label_name, time_interval = task_list[1], interval_list[0]
c0 = str(time_interval)
label_cn = task_list_cn[task_list.index(label_name)]
table_name = f'{database_name}.ads_algo_{label_name}_{c0}_define_df'
task_col = 'probability'
# 2. Configuration parameters
class Config:
    # Data source configuration
    TRAIN_TABLE = table_name # "chagee_cdp_dev.ads_algo_decrease_alert_14_define_df"
    TRAIN_PARTITION = "20250529"
    PREDICT_PARTITION = "20250612"
    OUTPUT_TABLE = "chagee_cdp_prod.ads_algo_dtc_df"
    OUTPUT_PARTITION = {"pt": "20250612", "label_name": label_name, "time_interval": c0}
    # Column name configuration
    ID_COLUMNS = ["member_id", "oneid"]
    LABEL_COLUMN = "probability"
    # Processing configuration
    TEST_SIZE = 0.2
    RANDOM_SEED = 42
    NUM_PARTITIONS = 2000  # number of partitions used for repartitioning
    ARTIFACTS_TABLE = "chagee_cdp_dev.ads_algo_df" # new: table used to store the model and encoder
    # item_name values stored in the table
    ENCODER_ITEM_NAME = f"target_encoder_{label_name}_{c0}d"
    MODEL_ITEM_NAME = f"lgbm_model_{label_name}_{c0}d"
    # Storage paths
    # MODEL_DIR = "/tmp/models/decrease_alert_fixed" # use the /tmp directory
    # SPARK_MODEL_PATH = os.path.join(MODEL_DIR, "spark_lgbm_model")
    # ENCODERS_FILE = "target_encoders.pkl"
    # ENCODER_PATH = os.path.join(MODEL_DIR, ENCODERS_FILE) # unified path variable
    # LightGBM parameters, tuned for large data volumes
    LGBM_PARAMS = {
        'objective': 'binary',
        'boostingType': 'gbdt',
        'numLeaves': 63,
        'learningRate': 0.05,
        'featureFraction': 0.8,
        'baggingFraction': 0.8,
        'baggingFreq': 5,
        'maxDepth': 8,
        # 'minDataInLeaf': 100,  # this parameter does not exist and must be removed
        'minSumHessianInLeaf': 20.0, # use the existing minSumHessianInLeaf instead; the value must be a float
        'lambdaL1': 0.1,
        'lambdaL2': 0.2,
        'maxBin': 255,
        'isUnbalance': True,
    }
    LGBM_NUM_ROUNDS = 200
# Modified save_artifact_to_table function
# def save_artifact_to_table(spark_session, obj, item_name, item_type, item_version, description, table_name):
#     """
#     Serialize a Python object and store it Base64-encoded in a MaxCompute table.
#     """
#     print(f"Saving artifact '{item_name}' (version: {item_version}) to table '{table_name}'...")
#     pickled_obj = pickle.dumps(obj)
#     encoded_str = base64.b64encode(pickled_obj).decode('utf-8')
#     update_ts = datetime.now() # use a datetime object; Spark maps it to the DATETIME type automatically
#     df_to_write = spark_session.createDataFrame(
#         [(item_name, encoded_str, item_type, item_version, update_ts, description)],
#         ["item_name", "item_value", "item_type", "item_version", "update_time", "description"]
#     )
#     # Append mode
#     # Note: to replace an old version you must either DELETE it manually before writing,
#     # or filter for the latest version when reading; here we rely on filtering at read time
#     df_to_write.write.mode("append").saveAsTable(table_name)
#     print(f"Artifact '{item_name}' saved successfully.")
# def load_artifact_from_table(spark_session, item_name, table_name, item_version=None):
#     """
#     Load an artifact from the table.
#     - If item_version is provided, load exactly that version.
#     - If item_version is not provided, load the latest version.
#     """
#     if item_version:
#         print(f"Loading artifact '{item_name}' at version '{item_version}' from table '{table_name}'...")
#         query = f"""
#         SELECT item_value, item_version, update_time, description
#         FROM {table_name}
#         WHERE item_name = '{item_name}' AND item_version = '{item_version}'
#         LIMIT 1
#         """
#     else:
#         print(f"Loading the latest version of artifact '{item_name}' from table '{table_name}'...")
#         query = f"""
#         SELECT item_value, item_version, update_time, description
#         FROM (
#             SELECT 
#                 item_value,
#                 item_version,
#                 update_time,
#                 description,
#                 ROW_NUMBER() OVER (PARTITION BY item_name ORDER BY update_time DESC) as rn
#             FROM {table_name}
#             WHERE item_name = '{item_name}'
#         ) t
#         WHERE rn = 1
#         """

#     result_df = spark_session.sql(query)
    
#     if result_df.rdd.isEmpty():
#         version_str = f"version '{item_version}' of " if item_version else "any version of "
#         raise ValueError(f"Could not find {version_str}artifact '{item_name}' in table '{table_name}'")
        
#     loaded_artifact = result_df.collect()[0]
#     encoded_str = loaded_artifact['item_value']
    
#     print(f"  > Loading version: {loaded_artifact['item_version']}")
#     print(f"  > Updated at: {loaded_artifact['update_time']}")
#     print(f"  > Description: {loaded_artifact['description']}")
    
#     pickled_obj = base64.b64decode(encoded_str)
#     obj = pickle.loads(pickled_obj)
    
#     print(f"Artifact '{item_name}' loaded successfully.")
#     return obj


# 3. Batched target encoder (kept largely as-is; the original implementation was already optimized)
class BatchTargetEncoder:
    def __init__(self, smoothing=10, min_samples=10):
        self.encodings = {}
        self.global_means = {}
        self.smoothing = smoothing
        self.min_samples = min_samples

    def fit(self, df, categorical_cols, label_col, batch_size=None):
        print(f"Target-encoding {len(categorical_cols)} categorical features...")
        global_mean_df = df.selectExpr(f"AVG(CAST({label_col} AS DOUBLE)) as global_mean")
        global_mean = global_mean_df.collect()[0]['global_mean']
        print(f"Global label mean: {global_mean}")

        # The original batching logic is already reasonably optimized, so it is kept here
        batch_size = batch_size or min(50, len(categorical_cols))
        for i in range(0, len(categorical_cols), batch_size):
            batch_cols = categorical_cols[i:i+batch_size]
            print(f"Processing feature batch {i//batch_size + 1}/{(len(categorical_cols)+batch_size-1)//batch_size}, number of features: {len(batch_cols)}")
            
            for col_name in batch_cols:
                start_time = time.time()
                print(f"  Processing feature: {col_name}")
                # The SQL aggregation here is efficient and avoids looping on the Python side
                encoding_df = spark.sql(f"""
                    SELECT 
                        `{col_name}`,
                        COUNT(*) as count,
                        AVG(CAST(`{label_col}` AS DOUBLE)) as category_mean
                    FROM {Config.TRAIN_TABLE}
                    WHERE pt = '{Config.TRAIN_PARTITION}'
                    AND `{col_name}` IS NOT NULL  
                    GROUP BY `{col_name}`
                    HAVING COUNT(*) >= {self.min_samples}
                """)
                
                encoding_df = encoding_df.withColumn(
                    "encoded_value",
                    (col("count") * col("category_mean") + self.smoothing * lit(global_mean)) / 
                    (col("count") + self.smoothing)
                )

                mappings = {
                    row[col_name]: float(row["encoded_value"]) 
                    for row in encoding_df.select(col_name, "encoded_value").collect()
                }

                self.encodings[col_name] = mappings
                self.global_means[col_name] = float(global_mean)
                print(f"  Feature {col_name} done, elapsed: {time.time() - start_time:.2f}s, number of encoded values: {len(mappings)}")

                # unpersist may not be strictly necessary since encoding_df is a temporary variable,
                # but explicit cleanup helps
                if encoding_df.is_cached:
                    encoding_df.unpersist()
                gc.collect()

        print("Target encoding finished for all features!")
        return self

    def transform(self, df, categorical_cols):
        print(f"Applying target encoding to {len(categorical_cols)} categorical features...")

        df_transformed = df
        for col_name in categorical_cols:
            if col_name not in self.encodings:
                print(f"Warning: feature {col_name} has no encoding mapping; the global mean will be used")
                default_value = self.global_means.get(col_name, 0.0) # provide a fallback default value
                encoded_col_name = f"{col_name}_encoded"
                df_transformed = df_transformed.withColumn(encoded_col_name, lit(default_value).cast(DoubleType()))
                continue
                
            start_time = time.time()
            
            encoder_map_bc = spark.sparkContext.broadcast(self.encodings[col_name])
            default_value = self.global_means[col_name]
            
            @udf(DoubleType())
            def encode_udf(x):
                # handle None values inside the UDF
                if x is None:
                    return default_value
                return encoder_map_bc.value.get(x, default_value)
            
            encoded_col_name = f"{col_name}_encoded"
            df_transformed = df_transformed.withColumn(encoded_col_name, encode_udf(col(col_name)))
            
            print(f"Feature {col_name} encoded, elapsed: {time.time() - start_time:.2f}s")
            
        return df_transformed
    
    # The save and load methods are kept as-is; they are well designed

    def save(self, filepath):
        """Save the encoder, avoiding large-object serialization issues"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        # ... original code ...
        encodings_path = filepath + ".encodings"
        os.makedirs(encodings_path, exist_ok=True)
        feature_batches = list(self.encodings.keys())
        for i, feature in enumerate(feature_batches):
            feature_path = os.path.join(encodings_path, f"feature_{i}.pkl")
            with open(feature_path, 'wb') as f:
                pickle.dump({'name': feature,'mapping': self.encodings[feature]}, f)
        meta_data = {
            'global_means': self.global_means, 'smoothing': self.smoothing,
            'min_samples': self.min_samples, 'num_features': len(feature_batches)
        }
        with open(filepath + ".meta", 'wb') as f:
            pickle.dump(meta_data, f)
        print(f"Encoder saved to: {filepath}")

    @classmethod
    def load(cls, filepath):
        """Load the encoder, supporting batched loading of large objects"""
        # ... original code ...
        with open(filepath + ".meta", 'rb') as f:
            meta_data = pickle.load(f)
        encoder = cls(smoothing=meta_data['smoothing'],min_samples=meta_data['min_samples'])
        encoder.global_means = meta_data['global_means']
        encoder.encodings = {}
        encodings_path = filepath + ".encodings"
        for i in range(meta_data['num_features']):
            feature_path = os.path.join(encodings_path, f"feature_{i}.pkl")
            with open(feature_path, 'rb') as f:
                feature_data = pickle.load(f)
                encoder.encodings[feature_data['name']] = feature_data['mapping']
        print(f"Encoder loaded from {filepath}, containing encodings for {len(encoder.encodings)} features")
        return encoder


# 4. Data processing utilities
def identify_column_types(df, id_cols, label_col):
    """Identify numeric and categorical features"""
    schema = df.schema
    # broaden the range of numeric types
    numeric_types = (DoubleType, IntegerType, FloatType, DecimalType, LongType, BooleanType)
    categorical_cols = []
    numeric_cols = []

    for field in schema.fields:
        col_name = field.name
        # exclude ID columns, the label column, and the partition column
        if col_name in id_cols or col_name == label_col or col_name == "pt":
            continue
        
        if isinstance(field.dataType, numeric_types):
            numeric_cols.append(col_name)
        # only treat StringType as categorical; other complex types can be handled separately
        elif isinstance(field.dataType, StringType):
            categorical_cols.append(col_name)
    
    return numeric_cols, categorical_cols

# Core optimization 2: rewrite the preprocessing function so it does not trigger a Spark job inside a loop
def preprocess_data(df, numeric_cols, categorical_cols):
    """
    Data preprocessing: handle null values.
    - For categorical features, fill nulls with the string 'MISSING'.
    - For numeric features, fill nulls with the column mean.
    By computing all means in a single pass, this implementation avoids triggering a Spark job
    inside the loop, which resolves the disk-space problem.
    """
    print("Starting efficient data preprocessing...")

    # 1. Fill nulls in categorical features
    print("  Filling nulls in categorical features...")
    df_processed = df.na.fill("MISSING", subset=categorical_cols)

    # 2. Compute the means of all numeric columns in one pass
    print("  Computing the means of all numeric features in one pass...")
    if not numeric_cols:
        print("  No numeric features to fill.")
        return df_processed
        
    # build the list of aggregation expressions
    avg_exprs = [avg(col(c)).alias(c) for c in numeric_cols]
    
    # run a single aggregation to obtain all means
    avg_values_row = df_processed.agg(*avg_exprs).collect()[0]
    avg_values_dict = avg_values_row.asDict()

    # filter out None means (if a column is entirely NULL, its mean is also NULL)
    fill_values = {k: float(v) for k, v in avg_values_dict.items() if v is not None}
    
    print(f"  The computed means will be used to fill {len(fill_values)} numeric columns.")
    
    # 3. Fill nulls in all numeric columns with the computed means
    df_processed = df_processed.na.fill(fill_values)
    
    print("Data preprocessing finished.")
    return df_processed

# 5. Main pipeline
def main():
    start_time_total = datetime.now()
    print(f"Start time: {start_time_total}")
    try:
        # 6.1 Load the training data
        print("Loading training data...")
        train_df = spark.sql(f"SELECT * FROM {Config.TRAIN_TABLE} WHERE pt = '{Config.TRAIN_PARTITION}'") \
                        .repartition(Config.NUM_PARTITIONS)
        train_df.cache()
        train_count = train_df.count()
        print(f"Training data loaded, total rows: {train_count}")
        all_columns = train_df.columns
        id_cols_set = set(Config.ID_COLUMNS)
        cols_to_exclude_in_prediction = {Config.LABEL_COLUMN, "pt"} | id_cols_set
        # Select from the training columns those needed for the prediction set
        # ID columns are needed for joining at prediction time, so they are not excluded here
        predict_cols = [f"`{c}`" for c in all_columns if c not in cols_to_exclude_in_prediction]
        
        # Join the column names into a SQL query string
        predict_cols_str = ", ".join(predict_cols)
        print(f"\n--- Columns that will be used to load the prediction data ---")
        print(predict_cols_str)
        print("---------------------------------\n")
        # 6.2 Identify feature types
        numeric_cols, categorical_cols = identify_column_types(train_df, Config.ID_COLUMNS, Config.LABEL_COLUMN)
        print(f"Feature type identification finished: {len(numeric_cols)} numeric features, {len(categorical_cols)} categorical features")
        
        # 6.3 Preprocess the data (using the optimized function)
        train_df = preprocess_data(train_df, numeric_cols, categorical_cols)

        # 6.4 Create and fit the target encoder
        target_encoder = BatchTargetEncoder(smoothing=10, min_samples=20)
        target_encoder.fit(train_df, categorical_cols, Config.LABEL_COLUMN)
        # # Save the encoder
        # encoder_path = os.path.join(Config.MODEL_DIR, Config.ENCODERS_FILE)
        # target_encoder.save(encoder_path)
        # print(f"Target encoder saved to: {encoder_path}")
        # target_encoder.save(Config.ENCODER_PATH)
        # print(f"Target encoder saved to: {Config.ENCODER_PATH}")
        # save_artifact_to_table(spark, target_encoder, Config.ENCODER_ITEM_NAME, 'TARGET_ENCODER', 'v1.0', 'Target encoder for the 14-day churn-alert model', Config.ARTIFACTS_TABLE)
        # 6.5 Apply the encoding
        print("Applying target encoding to the training data...")
        train_encoded_df = target_encoder.transform(train_df, categorical_cols)
        train_df.unpersist() # release the old DataFrame

        # 6.6 Prepare data for distributed training
        print("Preparing data for distributed model training...")
        feature_cols = numeric_cols + [f"{col}_encoded" for col in categorical_cols]
        assembler = VectorAssembler(inputCols=feature_cols, outputCol="features", handleInvalid="skip")
        data_for_training = assembler.transform(train_encoded_df)
        final_df = data_for_training.select(
            col(Config.LABEL_COLUMN).alias("label"), 
            col("features")
        )

        # ======================== Recommended alternative ========================
        # 6.7 Create a validation indicator column instead of using randomSplit
        validation_col_name = "is_validation"
        # Add a boolean column marking each row as training (false) or validation (true) based on a random number
        df_with_validation = final_df.withColumn(
            validation_col_name,
            (F.rand(seed=Config.RANDOM_SEED) >= (1.0 - Config.TEST_SIZE))
        )
        print("Writing the prepared data to a temporary view to cut the RDD lineage...")
        # Create a temporary view that is only valid within the current Spark session
        temp_view_name = "data_ready_for_barrier_training"
        df_with_validation.createOrReplaceTempView(temp_view_name)

        # Force-materialize this view into the cache. This step is crucial!
        spark.catalog.cacheTable(temp_view_name)

        # Re-read the data from this clean, cached view
        # This should keep the lineage of training_data_final very simple
        training_data_final = spark.table(temp_view_name)

        # Trigger an action to confirm materialization and to inspect the data
        print("Validating the cached final training data...")
        split_summary = training_data_final.groupBy("is_validation").count().collect()
        print("Row counts split by the validation indicator column:")
        for row in split_summary:
            print(f"  is_validation={row['is_validation']}: {row['count']} rows")
        # ========================================================================
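        # NOTE (assumption added for clarity, not part of the original script): caching a temp view
        # keeps the full RDD lineage for recomputation, so it may not actually shorten the chain that
        # checkBarrierStageWithRDDChainPattern inspects. An explicit truncation, e.g.
        #     training_data_final = training_data_final.localCheckpoint(eager=True)
        # would cut the lineage before fit(); this is only a hedged alternative, not a verified fix.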

        # Release the caches of all upstream DataFrames; their data has been materialized into the cached view
        final_df.unpersist()
        df_with_validation.unpersist() # make sure it is released

        # 6.8 Train the distributed LightGBM model
        print("Starting distributed LightGBM model training...")
        lgbm = LightGBMClassifier(
            numIterations=Config.LGBM_NUM_ROUNDS,
            labelCol="label",
            featuresCol="features",
            validationIndicatorCol="is_validation", # make sure the validation column is specified
            useBarrierExecutionMode=False, #True,
            timeout=1800.0,
            defaultListenPort=25500,
            # earlyStoppingRound=10, # with the validation set, early stopping could be enabled
            **Config.LGBM_PARAMS
        )
        start_time_train = time.time()
        # Pass in the whole DataFrame; LightGBM splits it automatically based on the indicator column
        model = lgbm.fit(training_data_final) #(df_with_validation) 
        print(f"Distributed model training finished, elapsed: {time.time() - start_time_train:.2f}s")
        spark.catalog.uncacheTable(temp_view_name)
        spark.catalog.dropTempView(temp_view_name)
        # After training, the cache can be released
        # df_with_validation.unpersist()
        # final_df = data_for_training.select(
        #     col(Config.LABEL_COLUMN).alias("label"), 
        #     col("features")
        # ).cache()
        # print("Feature vectorization finished.")
        # # 6.7 Split into training and test sets
        # train_data, test_data = final_df.randomSplit([1.0 - Config.TEST_SIZE, Config.TEST_SIZE], seed=Config.RANDOM_SEED)
        # print(f"Training set size: {train_data.count()}, test set size: {test_data.count()}")
        # # 6.8 Train the distributed LightGBM model
        # print("Starting distributed LightGBM model training...")
        # # ==============================================================================
        # #  Core fix: insert diagnostic code to probe the available parameters
        # # ==============================================================================
        # # try:
        # #     import mmlspark
        # #     print(f"MMLSpark version: {mmlspark.__version__}")
        # # except ImportError:
        # #     print("Could not import mmlspark to check version.")

        # # Create a temporary instance to inspect the available parameters
        # temp_lgbm_for_inspection = LightGBMClassifier()
        # print("\n--- Available LightGBMClassifier parameters ---")
        # for param in temp_lgbm_for_inspection.params:
        #     print(f"  Parameter name: {param.name}")
        #     print(f"  Documentation: {param.doc}\n")
        # print("------------------------------------------\n")
        # lgbm = LightGBMClassifier(
        #     numIterations=Config.LGBM_NUM_ROUNDS,
        #     labelCol="label",
        #     featuresCol="features",
        #     useBarrierExecutionMode=True,  # <--- add or modify this line
        #     timeout=1800.0,
        #     defaultListenPort=25500,
        #     # Pass the other parameters in dynamically from the dict
        #     # The benefit is that we do not have to hard-code parameter names here
        #     # and can use the key-value pairs from the dict directly
        #     # **Note**: we assume the keys in the dict are already in the correct camelCase format
        #     # We fix them up in the Config class
        #     **Config.LGBM_PARAMS 
        # )
        # start_time_train = time.time()
        # model = lgbm.fit(train_data)
        # print(f"Distributed model training finished, elapsed: {time.time() - start_time_train:.2f}s")
        # final_df.unpersist()

        # 6.9 Evaluate the model on the distributed test set
        print("Evaluating the model on the distributed test set...")
        predictions = model.transform(training_data_final.filter(F.col("is_validation") == True))
        predictions.cache() # cache the predictions to speed up evaluation

        from pyspark.ml.evaluation import BinaryClassificationEvaluator
        evaluator_auc = BinaryClassificationEvaluator(rawPredictionCol="probability", labelCol="label", metricName="areaUnderROC")
        auc = evaluator_auc.evaluate(predictions)
        
        tp = predictions.filter("prediction == 1.0 AND label == 1.0").count()
        fp = predictions.filter("prediction == 1.0 AND label == 0.0").count()
        tn = predictions.filter("prediction == 0.0 AND label == 0.0").count()
        fn = predictions.filter("prediction == 0.0 AND label == 1.0").count()

        accuracy = (tp + tn) / (tp + fp + tn + fn) if (tp + fp + tn + fn) > 0 else 0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        print("\n--- Model performance evaluation ---")
        print(f"AUC: {auc:.4f}")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1 score: {f1:.4f}")
        training_data_final.unpersist() 
        predictions.unpersist()
        # 6.10 Save the Spark ML model
        print("Saving the Spark ML LightGBM model to a table...")
        # Extract the native Booster object, which is serializable
        # native_booster = model._model
        # save_artifact_to_table(spark, native_booster, Config.MODEL_ITEM_NAME, 'LGB_MODEL', 'v1.0', f"14-day churn-alert model, AUC on test set: {auc:.4f}", Config.ARTIFACTS_TABLE)
        # model.save(Config.SPARK_MODEL_PATH)
        # print(f"Model saved to: {Config.SPARK_MODEL_PATH}")
        # --- Prediction pipeline ---
        print("\nLoading and processing the prediction data...")
        predict_df = spark.sql(f"SELECT {predict_cols_str} FROM {Config.TRAIN_TABLE} WHERE pt = '{Config.PREDICT_PARTITION}'") \
                          .repartition(Config.NUM_PARTITIONS)
        
        predict_df_processed = preprocess_data(predict_df, numeric_cols, categorical_cols)
        # loaded_encoder = load_artifact_from_table(spark, Config.ENCODER_ITEM_NAME, Config.ARTIFACTS_TABLE)
        predict_encoded_df = target_encoder.transform(predict_df_processed, categorical_cols)
        predict_vectorized = assembler.transform(predict_encoded_df)
        print("Running distributed prediction...")
        # loaded_native_booster = load_artifact_from_table(spark, Config.MODEL_ITEM_NAME, Config.ARTIFACTS_TABLE)
        print("Rebuilding the Spark ML model from the loaded native Booster...")
        # Use the modelString parameter to rebuild, from the native Booster's model string, a Spark ML model usable for prediction
        # reconstructed_model = LightGBMClassifier(
        #     modelString=loaded_native_booster.model_str()
        # )     
        predict_results = model.transform(predict_vectorized)
        get_prob_udf = udf(lambda v: float(v[1]) if v else 0.0, DoubleType())
        result_df = predict_results.select(
            *Config.ID_COLUMNS,
            get_prob_udf("probability").alias("prediction_probability"),
            col("prediction").cast(IntegerType()).alias("prediction_label")
        )

        # 6.11 Write the results
        print("Writing prediction results...")
        for partition_key, partition_value in Config.OUTPUT_PARTITION.items():
            result_df = result_df.withColumn(partition_key, lit(partition_value))
            
        result_df.write \
            .format("orc") \
            .mode("overwrite") \
            .partitionBy(*Config.OUTPUT_PARTITION.keys()) \
            .saveAsTable(Config.OUTPUT_TABLE)
        
        print(f"Prediction results written to table: {Config.OUTPUT_TABLE}")
        print(f"Partition path: {'/'.join([f'{k}={v}' for k, v in Config.OUTPUT_PARTITION.items()])}")

    except Exception as e:
        print(f"A serious error occurred during processing: {str(e)}")
        import traceback
        traceback.print_exc()
    finally:
        end_time_total = datetime.now()
        total_time = (end_time_total - start_time_total).total_seconds() / 60
        print(f"\nEnd time: {end_time_total}")
        print(f"Total elapsed: {total_time:.2f} minutes")
        print("Cleaning up resources...")
        spark.catalog.clearCache()
        spark.stop()
        print("Spark session stopped, processing finished.")


if __name__ == "__main__":
    main()
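As a hedged diagnostic (the names below refer to the script above; this snippet is not part of the original report), printing the lineage and partition count of the DataFrame passed to fit() makes it easier to spot the union()/coalesce()-style ancestor the scheduler complains about:

# Diagnostic sketch (assumption: run just before lgbm.fit(training_data_final)).
rdd = training_data_final.rdd
print("partitions:", rdd.getNumPartitions())
print(rdd.toDebugString().decode("utf-8"))  # toDebugString() returns bytes in PySpark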

Other info / logs

No response

What component(s) does this bug affect?

  • area/cognitive: Cognitive project
  • area/core: Core project
  • area/deep-learning: DeepLearning project
  • area/lightgbm: Lightgbm project
  • area/opencv: Opencv project
  • area/vw: VW project
  • area/website: Website
  • area/build: Project build system
  • area/notebooks: Samples under notebooks folder
  • area/docker: Docker usage
  • area/models: models related issue

What language(s) does this bug affect?

  • language/scala: Scala source code
  • language/python: Pyspark APIs
  • language/r: R APIs
  • language/csharp: .NET APIs
  • language/new: Proposals for new client languages

What integration(s) does this bug affect?

  • integrations/synapse: Azure Synapse integrations
  • integrations/azureml: Azure ML integrations
  • integrations/databricks: Databricks integrations
