In [171]:
import warnings
warnings.filterwarnings('ignore')
import pyspark
import json
import requests
import pymysql
import numpy as np
import pandas as pd
import pyspark.pandas as ps
import pyspark.sql.functions as F
from scipy import stats
from pca import pca

from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV

from scipy import stats
from functools import reduce
from pyspark.sql import DataFrame
from typing import Optional
# from backend_spark.doris_common.doris_client import DorisClient
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType, IntegerType, FloatType
from pyspark.sql.functions import pandas_udf, PandasUDFType, monotonically_increasing_id, lit, col, when, countDistinct

In [2]:
import os
import warnings
warnings.filterwarnings('ignore')
from pyspark.sql import SparkSession

os.environ['PYSPARK_PYTHON'] = '/usr/local/python-3.9.13/bin/python3'

spark = SparkSession.builder \
    .appName("pandas_udf") \
    .config('spark.sql.session.timeZone', 'Asia/Shanghai') \
    .config("spark.scheduler.mode", "FAIR") \
    .config('spark.driver.memory', '1024m') \
    .config('spark.driver.cores', '3') \
    .config('spark.executor.memory', '1024m') \
    .config('spark.executor.cores', '1') \
    .config('spark.cores.max', '2') \
    .config('spark.driver.host','192.168.22.28') \
    .master("spark://192.168.12.47:7077,192.168.12.48:7077") \
    .getOrCreate()

### 数据预处理

In [251]:
class DataPreprocessorForInline:
    def __init__(self,
                 df: pyspark.sql.dataframe,
                 columns_list: list[str],
                 certain_column: str,
                 key_words: list[str],
                 convert_to_numeric_list: list[str]):
        self.df = df
        self.columns_list = columns_list
        self.certain_column = certain_column
        self.key_words = key_words
        self.convert_to_numeric_list = convert_to_numeric_list

    def select_columns(self, df):
        return df.select(self.columns_list)

    def exclude_some_data(self, df):
        key_words_str = '|'.join(self.key_words)
        df_filtered = df.filter(~col(self.certain_column).rlike(key_words_str))
        return df_filtered

    def pre_process(self, df):
        for column in self.convert_to_numeric_list:
            df = df.withColumn(column, col(column).cast('double'))
        if 'SITE_COUNT' in self.convert_to_numeric_list:
            self.convert_to_numeric_list.remove('SITE_COUNT')
        df = df.dropna(subset=self.convert_to_numeric_list, how='all')
        return df

    def run(self):
        df_select = self.select_columns(df=self.df)
        df_esd = self.exclude_some_data(df=df_select)
        df_pp = self.pre_process(df=df_esd)
        return df_pp

In [61]:
# 1. 读取数据
df_pandas = pd.read_csv("D:/Jupyterfiles/晶合MVAFDC_general开发/MVAanlysisDevelop/inline_algorithm/inline_case5_label.csv")
df_spark = ps.from_pandas(df_pandas).to_spark()
num_rows = df_spark.count()
num_columns = len(df_spark.columns)
print(f"df_spark shape: ({num_rows}, {num_columns})")

# 2. 数据预处理
dp = DataPreprocessorForInline(df=df_spark,
                               columns_list=['WAFER_ID', 'OPE_NO', 'INLINE_PARAMETER_ID', 'SITE_COUNT', 'AVERAGE', 'SITE1_VAL', 
                                             'SITE2_VAL', 'SITE3_VAL', 'SITE4_VAL', 'SITE5_VAL', 'SITE6_VAL', 'SITE7_VAL', 'SITE8_VAL', 
                                             'SITE9_VAL', 'SITE10_VAL', 'SITE11_VAL', 'SITE12_VAL', 'SITE13_VAL', 'SITE14_VAL', 
                                             'SITE15_VAL', 'SITE16_VAL', 'SITE17_VAL'],
                               certain_column='INLINE_PARAMETER_ID',
                               key_words=['CXS', 'CYS', 'FDS'],
                               convert_to_numeric_list=['SITE_COUNT', 'AVERAGE', 'SITE1_VAL', 
                                             'SITE2_VAL', 'SITE3_VAL', 'SITE4_VAL', 'SITE5_VAL', 'SITE6_VAL', 'SITE7_VAL', 'SITE8_VAL', 
                                             'SITE9_VAL', 'SITE10_VAL', 'SITE11_VAL', 'SITE12_VAL', 'SITE13_VAL', 'SITE14_VAL', 
                                             'SITE15_VAL', 'SITE16_VAL', 'SITE17_VAL'])
df_pp_ = dp.run()
num_rows = df_pp_.count()
num_columns = len(df_pp_.columns)
print(f"df_pp_ shape: ({num_rows}, {num_columns})")

df_spark shape: (32278, 144)
df_pp_ shape: (31791, 22)


In [62]:
df_pandas_select = df_pp_.toPandas()

In [63]:
df_pandas_select.columns

Index(['WAFER_ID', 'OPE_NO', 'INLINE_PARAMETER_ID', 'SITE_COUNT', 'AVERAGE',
       'SITE1_VAL', 'SITE2_VAL', 'SITE3_VAL', 'SITE4_VAL', 'SITE5_VAL',
       'SITE6_VAL', 'SITE7_VAL', 'SITE8_VAL', 'SITE9_VAL', 'SITE10_VAL',
       'SITE11_VAL', 'SITE12_VAL', 'SITE13_VAL', 'SITE14_VAL', 'SITE15_VAL',
       'SITE16_VAL', 'SITE17_VAL'],
      dtype='object')

In [64]:
df_pandas_select['OPE_NO'].unique()

array(['1F.FQE10', '1C.CDG10', '1U.CDG10', '1U.CDG20', '1U.EQW10',
       '1U.PQW10', '1U.PQX10', '1U.ECU10', '1V.ECU10', '1V.PQA10',
       '1V.PQX10', '1V.PQX20', '2U.CDG10', '2U.CDG20', '2U.EQW10',
       '2U.PQA10', '2U.PQX10', '2V.ECU10', '2V.PQA10', '2V.PQW10',
       '2V.PQX10', '2V.PQX20', '3U.CDG10', '3U.CDG20', '3U.PQA10',
       '3U.PQW10', '3U.PQX10', '6V.CDG10', '6V.CDG20', '6V.PQA10',
       '6V.PQW10', '6V.PQX10', '7U.ECU10', '7U.EQA10', '7U.EQA20',
       '7U.PQA10', '7U.PQA20', '7U.PQW10', '7U.PQX10', 'PV.CDG10',
       'PV.EQA10', 'PV.PQA10', 'PV.PQX10', 'TM.EQA10', 'TM.PQA10',
       'TM.PQW10', 'TM.PQX10', 'TV.CDG10', 'TV.EQA20', 'TV.PQA10',
       'TV.PQW10', 'TV.PQX10', '1V.EQW10', '1V.PQW10', '2U.CQC50',
       '2U.PQW10', '2V.EQW10', '3U.CQC10', '3U.EQW10', '5V.PQA10',
       '6U.CDG10', '6U.CDG20', '6U.PQA10', '7U.CDG10', '7U.EQW10',
       'PV.PQW10', 'XX.PQW03', 'XX.PQX01', 'XX.PQX02', '1C.PQA10',
       '1V.EQW20', '6V.EQA10', 'XX.CCX01', 'XX.CCZ01', '1U.EQW

In [176]:
df_pp_

DataFrame[WAFER_ID: string, OPE_NO: string, INLINE_PARAMETER_ID: string, SITE_COUNT: double, AVERAGE: double, SITE1_VAL: double, SITE2_VAL: double, SITE3_VAL: double, SITE4_VAL: double, SITE5_VAL: double, SITE6_VAL: double, SITE7_VAL: double, SITE8_VAL: double, SITE9_VAL: double, SITE10_VAL: double, SITE11_VAL: double, SITE12_VAL: double, SITE13_VAL: double, SITE14_VAL: double, SITE15_VAL: double, SITE16_VAL: double, SITE17_VAL: double]

### 根据site来处理数据、根据OPE_NO拟合模型

In [181]:
def process_missing_values_for_site(df, good_site_columns, bad_site_columns, missing_value_threshold=0.6, process_miss_site_mode='drop'):
    assert process_miss_site_mode in ['drop', 'fill']
    site_columns = good_site_columns + bad_site_columns
    if process_miss_site_mode == 'drop':
        # 按照missing_value_threshold删除行数据
        df = df.dropna(subset=site_columns, thresh=missing_value_threshold)
    else:
        # 使用该行的AVERAGE去填充对应行site的缺失值
        df[site_columns] = df[site_columns].apply(lambda col: col.fillna(df['AVERAGE']))
    return df 


def calculate_statistics(row):
    return pd.Series({
        'MAX_VAL': row.max(),
        'MIN_VAL': row.min(),
        'MEDIAN': row.median(),
        'AVERAGE': row.mean(),
        'STD_DEV': row.std(),
        'PERCENTILE_25': row.quantile(0.25),
        'PERCENTILE_75': row.quantile(0.75)})


def calculate_site_stats(df, site_columns, good_or_bad):
    assert good_or_bad in ['good', 'bad'], "Label could only be 'good' or 'bad'"
    selected_df = df[['WAFER_ID', 'OPE_NO', 'INLINE_PARAMETER_ID'] + site_columns].reset_index(drop=True)
    # 对每一行进行统计计算
    side_features = selected_df.apply(lambda row: calculate_statistics(row[site_columns]), axis=1)
    side_features = side_features.fillna(0)
    df_with_features = pd.concat([selected_df, side_features], axis=1)
    if good_or_bad == 'good':
        df_with_features['label'] = 0
    else:
        df_with_features['label'] = 1
    return df_with_features


def extract_features_by_site(df, good_site_columns, bad_site_columns, missing_value_threshold=0.6, process_miss_site_mode='drop'):
    df_pandas_specific_oper = process_missing_values_for_site(df=df,  
                                                              good_site_columns=good_site_columns, 
                                                              bad_site_columns=bad_site_columns, 
                                                              missing_value_threshold=missing_value_threshold, 
                                                              process_miss_site_mode=process_miss_site_mode)
    if df_pandas_specific_oper.shape[0] != 0:
        side_with_features1 = calculate_site_stats(df_pandas_specific_oper, good_site_columns, good_or_bad='good')
        side_with_features2 = calculate_site_stats(df_pandas_specific_oper, bad_site_columns, good_or_bad='bad')
        side_with_features1_select = side_with_features1[['WAFER_ID', 'OPE_NO', 'INLINE_PARAMETER_ID', 'MAX_VAL', 'MIN_VAL', 'MEDIAN', 
                                                          'AVERAGE', 'STD_DEV', 'PERCENTILE_25', 'PERCENTILE_75', 'label']]
        side_with_features2_select = side_with_features2[['WAFER_ID', 'OPE_NO', 'INLINE_PARAMETER_ID', 'MAX_VAL', 'MIN_VAL', 'MEDIAN', 
                                                          'AVERAGE', 'STD_DEV', 'PERCENTILE_25', 'PERCENTILE_75', 'label']]
        side_with_features_all = pd.concat([side_with_features1_select, side_with_features2_select], axis=0)
        return side_with_features_all
    
    
def process_missing_values(df, columns_to_process, missing_value_threshold):
    for column in columns_to_process:
        missing_percentage = df[column].isnull().mean()
        if missing_percentage > missing_value_threshold:
            df = df.drop(columns=[column])
        else:
            df[column] = df[column].fillna(df[column].mean())
    return df


def get_pivot_table(df, columns_to_process, missing_value_threshold):
    df_specific = process_missing_values(df, columns_to_process, missing_value_threshold)
    index_list = ['WAFER_ID', 'label']
    values_list = df.columns.difference(['WAFER_ID', 'OPE_NO', 'INLINE_PARAMETER_ID', 'label'])
    pivot_result = df.pivot_table(index=index_list, 
                                  columns=['OPE_NO', 'INLINE_PARAMETER_ID'], 
                                  values=values_list)
    pivot_result.columns = pivot_result.columns.map('#'.join)
    pivot_result = process_missing_values(pivot_result, pivot_result.columns, missing_value_threshold)
    pivot_result = pivot_result.reset_index(drop=False)
    return pivot_result

In [246]:
def fit_pca_model(df, by, 
                  good_site_columns, 
                  bad_site_columns, 
                  process_miss_site_mode,
                  columns_to_process, 
                  missing_value_threshold):
    schema_all = StructType([StructField("features", StringType(), True),
                             StructField("importance", FloatType(), True)])

    @pandas_udf(returnType=schema_all, functionType=PandasUDFType.GROUPED_MAP)
    def get_model_result(df_run):
        side_with_features_all = extract_features_by_site(df=df_run, good_site_columns=good_site_columns, 
                                                      bad_site_columns=bad_site_columns, 
                                                      missing_value_threshold=missing_value_threshold, 
                                                      process_miss_site_mode=process_miss_site_mode)
        if side_with_features_all is None:
            return pd.DataFrame({"features": "STATS#OPE#PARAM", "importance": -100}, index=[0])
        
        pivot_result = get_pivot_table(df=side_with_features_all, columns_to_process=columns_to_process, missing_value_threshold=missing_value_threshold)
        x_train = pivot_result[pivot_result.columns.difference(['WAFER_ID', 'label']).tolist()]

        if x_train.shape[1] > 1:
            n_components = min(min(x_train.shape)-2, 20)
            model = pca(n_components=n_components, verbose=None)
            results = model.fit_transform(x_train)
            res_top = results['topfeat']
            res_top_select = res_top[res_top['type'] == 'best'][['feature', 'loading']]
            res_top_select['importance'] = abs(res_top_select['loading'])
            res_top_select = res_top_select.rename(columns={'feature': 'features'}).drop("loading", axis=1).drop_duplicates()
            return res_top_select
        else:
            res_top_select = pd.DataFrame({"features": "STATS#OPE#PARAM", "importance": -101}, index=[0])
            return res_top_select
    return df.groupby(by).apply(get_model_result)

----PCA

In [247]:
good_site_columns = ['SITE4_VAL', 'SITE8_VAL', 'SITE9_VAL', 'SITE12_VAL', 'SITE13_VAL']
bad_site_columns = ['SITE2_VAL', 'SITE6_VAL', 'SITE7_VAL', 'SITE10_VAL', 'SITE11_VAL']
process_miss_site_mode = 'drop'
columns_to_process = ['AVERAGE', 'MAX_VAL', 'MEDIAN', 'MIN_VAL', 'STD_DEV', 'PERCENTILE_25', 'PERCENTILE_75']
missing_value_threshold = 0.6

res = fit_pca_model(df=df_pp_, 
                    by=['OPE_NO'],
                    good_site_columns=good_site_columns,
                    bad_site_columns=bad_site_columns,
                    process_miss_site_mode=process_miss_site_mode,
                    columns_to_process=columns_to_process,
                    missing_value_threshold=missing_value_threshold)
res.show()

+--------------------+----------+
|            features|importance|
+--------------------+----------+
|     STATS#OPE#PARAM|    -100.0|
|     STATS#OPE#PARAM|    -100.0|
|     STATS#OPE#PARAM|    -100.0|
|     STATS#OPE#PARAM|    -100.0|
|MAX_VAL#1C.PQX10#...|0.22420332|
|MEDIAN#1C.PQX10#MX02| 0.2645363|
|MAX_VAL#1C.PQX10#...|0.38837156|
|MAX_VAL#1C.PQX10#...| 0.4017471|
|PERCENTILE_75#1C....| 0.4057569|
|MEDIAN#1C.PQX10#MX03|0.29744515|
|MEDIAN#1C.PQX10#MY01| 0.3725794|
|MAX_VAL#1C.PQX10#...| 0.4259164|
|MIN_VAL#1C.PQX10#...|0.33780888|
|MAX_VAL#1C.PQX10#...|0.36506242|
|MIN_VAL#1F.FQE10#...|0.45260897|
|PERCENTILE_75#1F....| 0.5346337|
|     STATS#OPE#PARAM|    -100.0|
|     STATS#OPE#PARAM|    -100.0|
|MIN_VAL#1U.CQC10#...|0.90289927|
|PERCENTILE_25#1U....| 0.7958148|
+--------------------+----------+
only showing top 20 rows



In [248]:
res_p = res.toPandas()

In [228]:
res_p.query("importance < 0")['importance'].value_counts() 

-100.0    61
Name: importance, dtype: int64

In [249]:
res_p = res_p.query("importance > 0")
res_p.shape

(403, 2)

In [250]:
res_p.sort_values('importance').head(9)

Unnamed: 0,features,importance
364,PERCENTILE_75#PV.PQX10#MY04,0.199323
291,MEDIAN#6V.PQX10#FY01,0.201347
411,MIN_VAL#TV.PQX10#FY01,0.201777
337,MIN_VAL#7U.PQX10#FY01,0.201789
385,MIN_VAL#TM.PQX10#FY01,0.201818
457,MEDIAN#XX.PQX02#MY01,0.209944
361,MEDIAN#PV.PQX10#FY01,0.210319
202,MAX_VAL#2V.PQX20#MY04,0.211106
123,MIN_VAL#1V.PQX30#MX09,0.214443


----RF

In [220]:
def fit_rf_model(df, by, 
                  good_site_columns, 
                  bad_site_columns, 
                  process_miss_site_mode,
                  columns_to_process, 
                  missing_value_threshold):
    schema_all = StructType([StructField("features", StringType(), True),
                             StructField("importance", FloatType(), True)])

    @pandas_udf(returnType=schema_all, functionType=PandasUDFType.GROUPED_MAP)
    def get_model_result(df_run):
        side_with_features_all = extract_features_by_site(df=df_run, good_site_columns=good_site_columns, 
                                                          bad_site_columns=bad_site_columns, 
                                                          missing_value_threshold=missing_value_threshold, 
                                                          process_miss_site_mode=process_miss_site_mode)
        if side_with_features_all is None:
            return pd.DataFrame({"features": "STATS#OPE#PARAM", "importance": -100}, index=[0])
        
        pivot_result = get_pivot_table(df=side_with_features_all, columns_to_process=columns_to_process, missing_value_threshold=missing_value_threshold)
        x_train = pivot_result[pivot_result.columns.difference(['WAFER_ID', 'label']).tolist()]
        y_train = pivot_result[['label']]
        if min(x_train.shape) > 4 and y_train['label'].nunique() > 1:
            pipe = Pipeline(steps=[
                ('imputer', SimpleImputer(strategy='constant', fill_value=-999)),
                ('scaler', StandardScaler()),
                ('model', RandomForestClassifier(random_state=2024))])
            param_grid = {'model__n_estimators': [*range(10, 60, 10)],
                          'model__max_depth': [*range(5, 50, 10)],
                          'model__min_samples_split': [2, 5],
                          'model__min_samples_leaf': [1, 3]}
            grid = GridSearchCV(estimator=pipe, scoring='roc_auc', param_grid=param_grid, cv=3, n_jobs=-1)
            grid.fit(x_train.values, y_train.values.ravel())
            roc_auc_score_ = grid.best_score_
            if roc_auc_score_ >= 0.6:
                small_importance_res = pd.DataFrame({'features': x_train.columns,
                                                     'importance': grid.best_estimator_.steps[2][1].feature_importances_})
                return small_importance_res
            else:
                small_importance_res = pd.DataFrame({"features": "STATS#OPE#PARAM", "importance": -101}, index=[0])
                return small_importance_res
        else:
            small_importance_res = pd.DataFrame({"features": "STATS#OPE#PARAM", "importance": -102}, index=[0])
            return small_importance_res
    return df.groupby(by).apply(get_model_result)

In [221]:
good_site_columns = ['SITE4_VAL', 'SITE8_VAL', 'SITE9_VAL', 'SITE12_VAL', 'SITE13_VAL']
bad_site_columns = ['SITE2_VAL', 'SITE6_VAL', 'SITE7_VAL', 'SITE10_VAL', 'SITE11_VAL']
process_miss_site_mode = 'drop'
columns_to_process = ['AVERAGE', 'MAX_VAL', 'MEDIAN', 'MIN_VAL', 'STD_DEV', 'PERCENTILE_25', 'PERCENTILE_75']
missing_value_threshold = 0.6


res1 = fit_rf_model(df=df_pp_, 
                    by=['OPE_NO'],
                    good_site_columns=good_site_columns,
                    bad_site_columns=bad_site_columns,
                    process_miss_site_mode=process_miss_site_mode,
                    columns_to_process=columns_to_process,
                    missing_value_threshold=missing_value_threshold)
res1.show()

+--------------------+----------+
|            features|importance|
+--------------------+----------+
|     STATS#OPE#PARAM|    -100.0|
|     STATS#OPE#PARAM|    -100.0|
|     STATS#OPE#PARAM|    -100.0|
|     STATS#OPE#PARAM|    -100.0|
|AVERAGE#1C.PQX10#...|       0.1|
|AVERAGE#1C.PQX10#...|       0.1|
|AVERAGE#1C.PQX10#...|       0.1|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
|AVERAGE#1C.PQX10#...|       0.0|
+--------------------+----------+
only showing top 20 rows



In [222]:
res_p1= res1.toPandas()
res_p1

Unnamed: 0,features,importance
0,STATS#OPE#PARAM,-100.0
1,STATS#OPE#PARAM,-100.0
2,STATS#OPE#PARAM,-100.0
3,STATS#OPE#PARAM,-100.0
4,AVERAGE#1C.PQX10#FX01,0.1
...,...,...
3800,STD_DEV#XX.PQX02#QY02,0.0
3801,STD_DEV#XX.PQX02#QY03,0.0
3802,STD_DEV#XX.PQX02#QY04,0.0
3803,STD_DEV#XX.PQX02#QY05,0.0


In [223]:
res_p1.query("importance < 0")['importance'].value_counts() 

-100.0    61
-102.0     7
-101.0     6
Name: importance, dtype: int64

In [224]:
res_p1 = res_p1.query("importance > 0")
res_p1.sort_values('importance').reset_index(drop=True)

Unnamed: 0,features,importance
0,PERCENTILE_75#1U.ECU10#T2S1,0.000085
1,MEDIAN#1U.ECU10#TGS1,0.000548
2,MIN_VAL#XX.PQX02#MX05,0.001118
3,AVERAGE#1U.ECU10#TGS1,0.001326
4,MEDIAN#XX.PQX02#QX04,0.001575
...,...,...
467,MAX_VAL#3U.CQC10#TDS1,0.233333
468,MAX_VAL#6V.CQC40#TDS1,0.236667
469,PERCENTILE_75#1U.CQC10#TDS1,0.238679
470,STD_DEV#3U.CQC50#TDS2,0.300000


In [172]:
# res1.toPandas().sort_values('importance').reset_index(drop=True)

In [173]:
# res.toPandas().sort_values('importance').reset_index(drop=True)

### 结果整理

In [252]:
class SplitInlineModelResults:
    def __init__(self, df: pyspark.sql.dataframe, request_id: str):
        self.df = df
        self.request_id = request_id

    @staticmethod
    def split_features(df: pd.DataFrame, index: int) -> str:
        return df['features'].apply(lambda x: x.split('#')[index])

    @staticmethod
    def get_split_features(df: pd.DataFrame) -> pd.DataFrame:
        df['STATISTIC_RESULT'] = SplitInlineModelResults.split_features(df, 0)
        df['OPE_NO'] = SplitInlineModelResults.split_features(df, 1)
        df['INLINE_PARAMETER_ID'] = SplitInlineModelResults.split_features(df, 2)
        df = df.drop(['features', 'STATISTIC_RESULT'], axis=1).reset_index(drop=True)
        return df

    @staticmethod
    def split_calculate_features(df: pyspark.sql.dataframe, by: str) -> pyspark.sql.dataframe:
        schema_all = StructType([StructField("OPE_NO", StringType(), True),
                                 StructField("INLINE_PARAMETER_ID", StringType(), True),
                                 StructField("importance", FloatType(), True)])

        @pandas_udf(returnType=schema_all, functionType=PandasUDFType.GROUPED_MAP)
        def get_model_result(df_run):
            split_table = SplitInlineModelResults.get_split_features(df_run)
            split_table_grpby = split_table.groupby(['OPE_NO', 'INLINE_PARAMETER_ID'])['importance'].sum().reset_index(drop=False)
            return split_table_grpby
        return df.groupby(by).apply(get_model_result)

    @staticmethod
    def add_certain_column(df: pyspark.sql.dataframe, by: str) -> pyspark.sql.dataframe:
        schema_all = StructType([StructField("OPER_NO", StringType(), True),
                                StructField("INLINE_PARAMETER_ID", StringType(), True),
                                StructField("AVG_SPEC_CHK_RESULT_COUNT", FloatType(), True),
                                StructField("weight", FloatType(), True),
                                StructField("weight_percent", FloatType(), True),
                                StructField("index_no", IntegerType(), True)])

        @pandas_udf(returnType=schema_all, functionType=PandasUDFType.GROUPED_MAP)
        def get_result(final_res):
            # 计算weight, 归一化
            final_res['importance'] = final_res['importance'].astype(float)
            final_res = final_res.query("importance > 0")
            final_res['weight'] = final_res['importance'] / final_res['importance'].sum()
            final_res['weight_percent'] = final_res['weight'] * 100
            final_res = final_res.sort_values('weight', ascending=False)
            # 增加列
            final_res['index_no'] = [i + 1 for i in range(len(final_res))]
            final_res['AVG_SPEC_CHK_RESULT_COUNT'] = 0.0
            final_res = final_res.rename(columns={'OPE_NO': 'OPER_NO'})
            final_res = final_res.drop(['importance', 'temp'], axis=1)
            return final_res
        return df.groupby(by).apply(get_result)

    def run(self):
        df = self.df.withColumn('temp', lit(0))
        res = self.split_calculate_features(df=df, by='temp')
        res = res.withColumn('temp', lit(1))
        final_res = self.add_certain_column(df=res, by='temp')
        final_res = final_res.withColumn('request_id', lit(self.request_id))
        return final_res

---pca的结果

In [253]:
print(res.count())
res.show()

464
+--------------------+----------+
|            features|importance|
+--------------------+----------+
|     STATS#OPE#PARAM|    -100.0|
|     STATS#OPE#PARAM|    -100.0|
|     STATS#OPE#PARAM|    -100.0|
|     STATS#OPE#PARAM|    -100.0|
|MAX_VAL#1C.PQX10#...|0.22420332|
|MEDIAN#1C.PQX10#MX02| 0.2645363|
|MAX_VAL#1C.PQX10#...|0.38837156|
|MAX_VAL#1C.PQX10#...| 0.4017471|
|PERCENTILE_75#1C....| 0.4057569|
|MEDIAN#1C.PQX10#MX03|0.29744515|
|MEDIAN#1C.PQX10#MY01| 0.3725794|
|MAX_VAL#1C.PQX10#...| 0.4259164|
|MIN_VAL#1C.PQX10#...|0.33780888|
|MAX_VAL#1C.PQX10#...|0.36506242|
|MIN_VAL#1F.FQE10#...|0.45260897|
|PERCENTILE_75#1F....| 0.5346337|
|     STATS#OPE#PARAM|    -100.0|
|     STATS#OPE#PARAM|    -100.0|
|MIN_VAL#1U.CQC10#...|0.90289927|
|PERCENTILE_25#1U....| 0.7958148|
+--------------------+----------+
only showing top 20 rows



In [254]:
final_res_pca = SplitInlineModelResults(df=res, request_id='855s').run()
print(final_res_pca.count())
final_res_pca.show()

197
+--------+-------------------+-------------------------+-----------+--------------+--------+----------+
| OPER_NO|INLINE_PARAMETER_ID|AVG_SPEC_CHK_RESULT_COUNT|     weight|weight_percent|index_no|request_id|
+--------+-------------------+-------------------------+-----------+--------------+--------+----------+
|7U.ECU10|               FGS1|                      0.0|0.040170122|      4.017012|       1|      855s|
|7U.ECU10|               TGS1|                      0.0|0.034431532|     3.4431534|       2|      855s|
|7U.ECU10|               TGS2|                      0.0|0.034431532|     3.4431534|       3|      855s|
|2V.ECU10|               TGS2|                      0.0|0.028241418|     2.8241417|       4|      855s|
|1V.ECU10|               TGS2|                      0.0|0.025817852|     2.5817852|       5|      855s|
|1U.CQC10|               TDS1|                      0.0|0.022986656|     2.2986655|       6|      855s|
|1V.ECU10|               TGS1|                      0.0|0.02

---rf的结果

In [255]:
final_res_rf = SplitInlineModelResults(df=res1, request_id='855rf').run()
print(final_res_rf.count())
final_res_rf.show()

224
+--------+-------------------+-------------------------+------------+--------------+--------+----------+
| OPER_NO|INLINE_PARAMETER_ID|AVG_SPEC_CHK_RESULT_COUNT|      weight|weight_percent|index_no|request_id|
+--------+-------------------+-------------------------+------------+--------------+--------+----------+
|3U.CQC10|               TDS1|                      0.0|  0.03846154|     3.8461537|       1|     855rf|
|1U.CQC10|               TDS1|                      0.0|  0.03846154|     3.8461537|       2|     855rf|
|3U.CQC50|               TDS2|                      0.0|  0.03041514|     3.0415142|       3|     855rf|
|6V.CQC20|               TDS1|                      0.0| 0.021013888|     2.1013887|       4|     855rf|
|6V.CQC40|               TDS1|                      0.0| 0.019517703|     1.9517703|       5|     855rf|
|6V.CQC20|               TDS2|                      0.0|  0.01744765|      1.744765|       6|     855rf|
|1U.PQX10|               FY05|                     

In [256]:
final_res_rf_topandas = final_res_rf.toPandas()

In [257]:
final_res_rf_topandas.tail(20)

Unnamed: 0,OPER_NO,INLINE_PARAMETER_ID,AVG_SPEC_CHK_RESULT_COUNT,weight,weight_percent,index_no,request_id
204,2U.PQX10,QX05,0.0,0.000357,0.035714,205,855rf
205,2V.PQX10,MY03,0.0,0.000357,0.035651,206,855rf
206,2V.PQX10,QY04,0.0,0.000347,0.034722,207,855rf
207,2U.PQX10,FY02,0.0,0.000278,0.027778,208,855rf
208,2V.PQX10,MX03,0.0,0.00027,0.027006,209,855rf
209,2U.PQX10,MY04,0.0,0.000256,0.025641,210,855rf
210,XX.PQX02,MX03,0.0,0.000217,0.02173,211,855rf
211,2U.PQX10,MX03,0.0,0.000216,0.021562,212,855rf
212,2V.PQX10,MX04,0.0,0.000189,0.018939,213,855rf
213,2V.PQX10,QX03,0.0,0.000184,0.018424,214,855rf
