In [1]:
# 数据来自于https://archive.ics.uci.edu/dataset/117/census+income+kdd
# 下载后把文件解压到data/raw/

In [2]:
from ucimlrepo import fetch_ucirepo 
  
# Download the Census-Income (KDD) dataset, UCI repository id 117.
census_income_kdd = fetch_ucirepo(id=117)

# Per-variable metadata table (name, role, type, description, missing_values, ...).
# Kept as the cell's last expression so the notebook renders it.
variable_description = census_income_kdd.variables
variable_description

Unnamed: 0,name,role,type,demographic,description,units,missing_values
0,AAGE,Feature,Integer,Age,age,,no
1,ACLSWKR,Feature,Categorical,,class of worker,,no
2,ADTINK,Feature,Integer,,industry code,,no
3,ADTOCC,Feature,Integer,Occupation,occupation code,,no
4,AHGA,Feature,Integer,Education Level,education,,no
5,AHSCOL,Feature,Categorical,Education Level,enrolled in edu last week,,no
6,AMARITL,Feature,Categorical,Marital Status,marital status,,no
7,AMJIND,Feature,Categorical,,major industry code,,no
8,AMJOCC,Feature,Categorical,Occupation,major occupation code,,no
9,ARACE,Feature,Categorical,Race,race,,no


In [3]:
# Collect the names of all variables whose metadata marks them as having missing values.
has_missing = variable_description["missing_values"] == "yes"
missing_columns = variable_description.loc[has_missing, "name"].tolist()
missing_columns

['MIGMTR1', 'MIGMTR3', 'MIGMTR4', 'MIGSUN', 'PEFNTVTY', 'PEMNTVTY']

In [4]:
# 分别读取训练集和测试集
import pandas as pd

# Column names, in file order; they are passed explicitly through `names=`.
columns = [
    'AAGE', 'ACLSWKR', 'ADTINK', 'ADTOCC', 'AHGA', 'AHRSPAY', 'AHSCOL',
    'AMARITL', 'AMJIND', 'AMJOCC', 'ARACE', 'AREORGN', 'ASEX', 'AUNMEM',
    'AUNTYPE', 'AWKSTAT', 'CAPGAIN', 'GAPLOSS', 'DIVVAL', 'FILESTAT',
    'GRINREG', 'GRINST', 'HHDFMX', 'HHDREL', 'MARSUPWRT', 'MIGMTR1',
    'MIGMTR3', 'MIGMTR4', 'MIGSAME', 'MIGSUN', 'NOEMP', 'PARENT',
    'PEFNTVTY', 'PEMNTVTY', 'PENATVTY', 'PRCITSHP', 'SEOTR', 'VETQVA',
    'VETYN', 'WKSWORK', 'year', 'income',
]

# Read the training and test splits with the same header.
df_train = pd.read_csv("data/raw/census-income.data", names=columns)
df_test = pd.read_csv("data/raw/census-income.test", names=columns)

df_train

Unnamed: 0,AAGE,ACLSWKR,ADTINK,ADTOCC,AHGA,AHRSPAY,AHSCOL,AMARITL,AMJIND,AMJOCC,...,PEFNTVTY,PEMNTVTY,PENATVTY,PRCITSHP,SEOTR,VETQVA,VETYN,WKSWORK,year,income
0,73,Not in universe,0,0,High school graduate,0,Not in universe,Widowed,Not in universe or children,Not in universe,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,2,0,95,- 50000.
1,58,Self-employed-not incorporated,4,34,Some college but no degree,0,Not in universe,Divorced,Construction,Precision production craft & repair,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,2,52,94,- 50000.
2,18,Not in universe,0,0,10th grade,0,High school,Never married,Not in universe or children,Not in universe,...,Vietnam,Vietnam,Vietnam,Foreign born- Not a citizen of U S,0,Not in universe,2,0,95,- 50000.
3,9,Not in universe,0,0,Children,0,Not in universe,Never married,Not in universe or children,Not in universe,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,0,0,94,- 50000.
4,10,Not in universe,0,0,Children,0,Not in universe,Never married,Not in universe or children,Not in universe,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,0,0,94,- 50000.
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
199518,87,Not in universe,0,0,7th and 8th grade,0,Not in universe,Married-civilian spouse present,Not in universe or children,Not in universe,...,Canada,United-States,United-States,Native- Born in the United States,0,Not in universe,2,0,95,- 50000.
199519,65,Self-employed-incorporated,37,2,11th grade,0,Not in universe,Married-civilian spouse present,Business and repair services,Executive admin and managerial,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,2,52,94,- 50000.
199520,47,Not in universe,0,0,Some college but no degree,0,Not in universe,Married-civilian spouse present,Not in universe or children,Not in universe,...,Poland,Poland,Germany,Foreign born- U S citizen by naturalization,0,Not in universe,2,52,95,- 50000.
199521,16,Not in universe,0,0,10th grade,0,High school,Never married,Not in universe or children,Not in universe,...,United-States,United-States,United-States,Native- Born in the United States,0,Not in universe,2,0,95,- 50000.


## 处理缺失值

In [5]:
def get_unique_len(column_name, frame=None):
    """Print how many distinct values a column holds and return those values.

    Args:
        column_name: Name of the column to inspect.
        frame: pandas DataFrame to inspect. When omitted, the notebook-level
            ``df_train`` is used, so existing calls behave as before.

    Returns:
        The array produced by ``Series.unique()``. Note that, despite the
        function name, the values themselves are returned, not their count.
    """
    # Resolve the global lazily so an explicit frame never needs df_train to exist.
    source = df_train if frame is None else frame
    unique_values = source[column_name].unique()
    print(f"Unique len of {column_name}:", len(unique_values))
    return unique_values

def get_unique(column_name, frame=None):
    """Print and return the distinct values of a column.

    Args:
        column_name: Name of the column to inspect.
        frame: pandas DataFrame to inspect. When omitted, the notebook-level
            ``df_train`` is used, so existing calls behave as before.

    Returns:
        The array produced by ``Series.unique()`` for that column.
    """
    # Resolve the global lazily so an explicit frame never needs df_train to exist.
    source = df_train if frame is None else frame
    unique_values = source[column_name].unique()
    print(f"Unique values of {column_name}:", unique_values)
    return unique_values

In [6]:
# findspark is initialised before pyspark is imported.
# NOTE(review): the original comment said to uncomment these lines on Windows,
# but they are active here and therefore run on every platform.
import findspark
findspark.init()
findspark.find()

import pyspark

from pyspark.sql import SparkSession
from pyspark import SparkContext, SQLContext
from pyspark.sql.functions import when, col

# Application name and master URL handed to the Spark configuration below.
appName = "mmoe"
master = "local"

# Create Configuration object for Spark (driver host pinned to 127.0.0.1).
conf = pyspark.SparkConf()\
    .set('spark.driver.host','127.0.0.1')\
    .setAppName(appName)\
    .setMaster(master)

# Create Spark Context with the new configurations rather than relying on the default one
sc = SparkContext.getOrCreate(conf=conf)

# SQL context built on top of the Spark context; only used to reach the session below.
sqlContext = SQLContext(sc)

# Obtain the SparkSession that the rest of the notebook uses as `spark`.
spark = sqlContext.sparkSession.builder.getOrCreate()



In [7]:
# Turn the pandas training frame into a Spark DataFrame and preview one record.
data = df_train
df = spark.createDataFrame(data)
df.show(n=1, vertical=True)

-RECORD 0-------------------------
 AAGE      | 73                   
 ACLSWKR   |  Not in universe     
 ADTINK    | 0                    
 ADTOCC    | 0                    
 AHGA      |  High school grad... 
 AHRSPAY   | 0                    
 AHSCOL    |  Not in universe     
 AMARITL   |  Widowed             
 AMJIND    |  Not in universe ... 
 AMJOCC    |  Not in universe     
 ARACE     |  White               
 AREORGN   |  All other           
 ASEX      |  Female              
 AUNMEM    |  Not in universe     
 AUNTYPE   |  Not in universe     
 AWKSTAT   |  Not in labor force  
 CAPGAIN   | 0                    
 GAPLOSS   | 0                    
 DIVVAL    | 0                    
 FILESTAT  |  Nonfiler            
 GRINREG   |  Not in universe     
 GRINST    |  Not in universe     
 HHDFMX    |  Other Rel 18+ ev... 
 HHDREL    |  Other relative o... 
 MARSUPWRT | 1700.09              
 MIGMTR1   |  ?                   
 MIGMTR3   |  ?                   
 MIGMTR4   |  ?     

In [8]:
def check_distinct_values(col_name, frame=None):
    """Print the distinct values found in one column of a Spark DataFrame.

    Args:
        col_name: Name of the column to inspect.
        frame: DataFrame to inspect. When omitted, the notebook-level ``df``
            is used (looked up at call time), so existing calls behave as
            before.
    """
    # The global is only touched when no explicit frame is supplied.
    source = df if frame is None else frame
    distinct_rows = source.select(col_name).distinct().collect()
    values = [row[col_name] for row in distinct_rows]
    print(f"Distinct values of {col_name}:", values)

def check_distinct_len(col_name, frame=None):
    """Print the number of distinct values in one column of a Spark DataFrame.

    Args:
        col_name: Name of the column to inspect.
        frame: DataFrame to inspect. When omitted, the notebook-level ``df``
            is used (looked up at call time), so existing calls behave as
            before.
    """
    # The global is only touched when no explicit frame is supplied.
    source = df if frame is None else frame
    # Count on the cluster instead of collecting every distinct row to the
    # driver just to take len() of the result.
    distinct_count = source.select(col_name).distinct().count()
    print(f"Distinct len of {col_name}:", distinct_count)

In [9]:
# Print the distinct values of every column flagged as having missing values;
# the output shows that a missing entry is encoded as the string ' ?'.
for missing_name in missing_columns:
    check_distinct_values(missing_name)

Distinct values of MIGMTR1: [' MSA to MSA', ' NonMSA to MSA', ' Abroad to MSA', ' Abroad to nonMSA', ' MSA to nonMSA', ' Not in universe', ' ?', ' Not identifiable', ' Nonmover', ' NonMSA to nonMSA']
Distinct values of MIGMTR3: [' Abroad', ' Different county same state', ' Different region', ' Same county', ' Not in universe', ' ?', ' Nonmover', ' Different division same region', ' Different state same division']
Distinct values of MIGMTR4: [' Different state in West', ' Abroad', ' Different state in South', ' Different state in Midwest', ' Different county same state', ' Same county', ' Not in universe', ' ?', ' Different state in Northeast', ' Nonmover']
Distinct values of MIGSUN: [' Yes', ' Not in universe', ' ?', ' No']
Distinct values of PEFNTVTY: [' Dominican-Republic', ' Ireland', ' Cuba', ' Guatemala', ' Iran', ' Panama', ' El-Salvador', ' Taiwan', ' Hong Kong', ' United-States', ' Japan', ' Nicaragua', ' Canada', ' Cambodia', ' Laos', ' Germany', ' South Korea', ' Trinadad&Tob

发现缺失值以 " ?"（注意问号前有一个空格）和 "NaN" 形式存在

In [10]:
# Convert the ' ?' placeholder into a real null, then confirm on one column
# that None now appears instead of ' ?'.
MISSING_TOKEN = ' ?'
df = df.replace([MISSING_TOKEN], [None])
check_distinct_values("MIGMTR1")

Distinct values of MIGMTR1: [' MSA to MSA', ' NonMSA to MSA', None, ' Abroad to MSA', ' Abroad to nonMSA', ' MSA to nonMSA', ' Not in universe', ' Not identifiable', ' Nonmover', ' NonMSA to nonMSA']


In [11]:
# Remove every row that contains a null; in the recorded run this shrinks the
# data from 199523 to 95130 rows.
df = df.dropna()
print(df.count())

95130


## 数据集的统计数据

接下来开始处理每一列的数据类型

In [12]:
# Summary statistics for every column, printed one field per line without truncation.
summary_stats = df.summary()
summary_stats.show(truncate=False, vertical=True)

-RECORD 0--------------------------------------------
 summary   | count                                   
 AAGE      | 95130                                   
 ACLSWKR   | 95130                                   
 ADTINK    | 95130                                   
 ADTOCC    | 95130                                   
 AHGA      | 95130                                   
 AHRSPAY   | 95130                                   
 AHSCOL    | 95130                                   
 AMARITL   | 95130                                   
 AMJIND    | 95130                                   
 AMJOCC    | 95130                                   
 ARACE     | 95130                                   
 AREORGN   | 95130                                   
 ASEX      | 95130                                   
 AUNMEM    | 95130                                   
 AUNTYPE   | 95130                                   
 AWKSTAT   | 95130                                   
 CAPGAIN   | 95130          

In [13]:
# List the (column name, Spark type) pairs of the cleaned frame.
df.dtypes

[('AAGE', 'bigint'),
 ('ACLSWKR', 'string'),
 ('ADTINK', 'bigint'),
 ('ADTOCC', 'bigint'),
 ('AHGA', 'string'),
 ('AHRSPAY', 'bigint'),
 ('AHSCOL', 'string'),
 ('AMARITL', 'string'),
 ('AMJIND', 'string'),
 ('AMJOCC', 'string'),
 ('ARACE', 'string'),
 ('AREORGN', 'string'),
 ('ASEX', 'string'),
 ('AUNMEM', 'string'),
 ('AUNTYPE', 'string'),
 ('AWKSTAT', 'string'),
 ('CAPGAIN', 'bigint'),
 ('GAPLOSS', 'bigint'),
 ('DIVVAL', 'bigint'),
 ('FILESTAT', 'string'),
 ('GRINREG', 'string'),
 ('GRINST', 'string'),
 ('HHDFMX', 'string'),
 ('HHDREL', 'string'),
 ('MARSUPWRT', 'double'),
 ('MIGMTR1', 'string'),
 ('MIGMTR3', 'string'),
 ('MIGMTR4', 'string'),
 ('MIGSAME', 'string'),
 ('MIGSUN', 'string'),
 ('NOEMP', 'bigint'),
 ('PARENT', 'string'),
 ('PEFNTVTY', 'string'),
 ('PEMNTVTY', 'string'),
 ('PENATVTY', 'string'),
 ('PRCITSHP', 'string'),
 ('SEOTR', 'bigint'),
 ('VETQVA', 'string'),
 ('VETYN', 'bigint'),
 ('WKSWORK', 'bigint'),
 ('year', 'bigint'),
 ('income', 'string')]

In [14]:
# Report the cardinality of every column.
for name in df.columns:
    check_distinct_len(name)

Distinct len of AAGE: 91
Distinct len of ACLSWKR: 9
Distinct len of ADTINK: 52
Distinct len of ADTOCC: 47
Distinct len of AHGA: 17
Distinct len of AHRSPAY: 874
Distinct len of AHSCOL: 3
Distinct len of AMARITL: 7
Distinct len of AMJIND: 24
Distinct len of AMJOCC: 15
Distinct len of ARACE: 5
Distinct len of AREORGN: 10
Distinct len of ASEX: 2
Distinct len of AUNMEM: 3
Distinct len of AUNTYPE: 6
Distinct len of AWKSTAT: 1
Distinct len of CAPGAIN: 127
Distinct len of GAPLOSS: 106
Distinct len of DIVVAL: 932
Distinct len of FILESTAT: 6
Distinct len of GRINREG: 6
Distinct len of GRINST: 50
Distinct len of HHDFMX: 37
Distinct len of HHDREL: 8
Distinct len of MARSUPWRT: 52461
Distinct len of MIGMTR1: 9
Distinct len of MIGMTR3: 8
Distinct len of MIGMTR4: 9
Distinct len of MIGSAME: 3
Distinct len of MIGSUN: 3
Distinct len of NOEMP: 7
Distinct len of PARENT: 5
Distinct len of PEFNTVTY: 40
Distinct len of PEMNTVTY: 40
Distinct len of PENATVTY: 40
Distinct len of PRCITSHP: 5
Distinct len of SEOTR:

## 区分categorical和numerical数据

In [15]:
# 我们把原始数据集中的bigint和double数据当做连续数据；把string数据当做类别数据

In [16]:
from pyspark.sql.types import DoubleType, LongType, StringType

# Partition the columns by Spark type: double/long -> numerical, string -> categorical.
numerical_column = []
categorical_column = []
for field in df.schema.fields:
    if isinstance(field.dataType, (DoubleType, LongType)):
        numerical_column.append(field.name)
    if isinstance(field.dataType, StringType):
        categorical_column.append(field.name)

# Show the result
print("Numerical Columns:", numerical_column)
print("Categorical Columns:", categorical_column)

Numerical Columns: ['AAGE', 'ADTINK', 'ADTOCC', 'AHRSPAY', 'CAPGAIN', 'GAPLOSS', 'DIVVAL', 'MARSUPWRT', 'NOEMP', 'SEOTR', 'VETYN', 'WKSWORK', 'year']
Categorical Columns: ['ACLSWKR', 'AHGA', 'AHSCOL', 'AMARITL', 'AMJIND', 'AMJOCC', 'ARACE', 'AREORGN', 'ASEX', 'AUNMEM', 'AUNTYPE', 'AWKSTAT', 'FILESTAT', 'GRINREG', 'GRINST', 'HHDFMX', 'HHDREL', 'MIGMTR1', 'MIGMTR3', 'MIGMTR4', 'MIGSAME', 'MIGSUN', 'PARENT', 'PEFNTVTY', 'PEMNTVTY', 'PENATVTY', 'PRCITSHP', 'VETQVA', 'income']


## 处理categorical数据

In [17]:
# AHGA and AHSCOL overlap, so for simplicity AHGA is dropped and AHSCOL is
# binarised: 1 for ' College or university', 0 for anything else.
check_distinct_values("AHGA")
check_distinct_values("AHSCOL")

enrolled_in_college = col("AHSCOL") == ' College or university'
df = df.drop("AHGA").withColumn(
    "AHSCOL",
    when(enrolled_in_college, 1).otherwise(0).cast("bigint"),
)
check_distinct_values("AHSCOL")

Distinct values of AHGA: [' 10th grade', ' Some college but no degree', ' Doctorate degree(PhD EdD)', ' Less than 1st grade', ' 12th grade no diploma', ' Associates degree-academic program', ' Bachelors degree(BA AB BS)', ' High school graduate', ' Prof school degree (MD DDS DVM LLB JD)', ' 9th grade', ' Associates degree-occup /vocational', ' 11th grade', ' Children', ' Masters degree(MA MS MEng MEd MSW MBA)', ' 5th or 6th grade', ' 1st 2nd 3rd or 4th grade', ' 7th and 8th grade']
Distinct values of AHSCOL: [' College or university', ' High school', ' Not in universe']
Distinct values of AHSCOL: [0, 1]


In [18]:
# One prediction target is whether the marital status is "never married":
# 1 for ' Never married', 0 otherwise.
check_distinct_values("AMARITL")
never_married = col("AMARITL") == ' Never married'
df = df.withColumn(
    "AMARITL",
    when(never_married, 1).otherwise(0).cast("bigint"),  # store the flag as bigint
)
check_distinct_values("AMARITL")

Distinct values of AMARITL: [' Widowed', ' Married-A F spouse present', ' Never married', ' Divorced', ' Married-spouse absent', ' Separated', ' Married-civilian spouse present']
Distinct values of AMARITL: [0, 1]


In [19]:
# 1 means income > 50K, 0 means income <= 50K.
check_distinct_values("income")
low_income = col("income") == ' - 50000.'
high_income = col("income") == ' 50000+.'
df = df.withColumn(
    "income",
    when(low_income, 0).when(high_income, 1).cast("bigint"),  # store the label as bigint
)
check_distinct_values("income")

Distinct values of income: [' - 50000.', ' 50000+.']
Distinct values of income: [0, 1]


In [20]:
from copy import deepcopy
# The remaining categorical columns are the ones StringIndexer will encode;
# the four columns already handled by hand are taken out of the copy.
catergorical_column_others = list(categorical_column)
for handled_name in ("AHGA", "AHSCOL", "AMARITL", "income"):
    catergorical_column_others.remove(handled_name)

In [21]:
from pyspark.ml.feature import StringIndexer, StringIndexerModel
from pyspark.ml import Pipeline

# One StringIndexer per remaining categorical column; each writes "<column>_index".
stages = []
for column_name in catergorical_column_others:
    stages.append(
        StringIndexer(inputCol=column_name, outputCol=column_name + '_index')
    )

# Fit all indexers on the current frame and append the index columns to it.
pipeline = Pipeline(stages=stages)
pipeline_model = pipeline.fit(df)
df = pipeline_model.transform(df)

In [22]:
# The raw string columns are no longer needed once their "_index" twins exist.
df = df.drop(*catergorical_column_others)
df.show(n=1, vertical=True)

-RECORD 0-----------------
 AAGE           | 58      
 ADTINK         | 4       
 ADTOCC         | 34      
 AHRSPAY        | 0       
 AHSCOL         | 0       
 AMARITL        | 0       
 CAPGAIN        | 0       
 GAPLOSS        | 0       
 DIVVAL         | 0       
 MARSUPWRT      | 1053.55 
 NOEMP          | 1       
 SEOTR          | 0       
 VETYN          | 2       
 WKSWORK        | 52      
 year           | 94      
 income         | 0       
 ACLSWKR_index  | 2.0     
 AMJIND_index   | 6.0     
 AMJOCC_index   | 6.0     
 ARACE_index    | 0.0     
 AREORGN_index  | 0.0     
 ASEX_index     | 1.0     
 AUNMEM_index   | 0.0     
 AUNTYPE_index  | 0.0     
 AWKSTAT_index  | 0.0     
 FILESTAT_index | 4.0     
 GRINREG_index  | 1.0     
 GRINST_index   | 24.0    
 HHDFMX_index   | 0.0     
 HHDREL_index   | 0.0     
 MIGMTR1_index  | 1.0     
 MIGMTR3_index  | 1.0     
 MIGMTR4_index  | 1.0     
 MIGSAME_index  | 1.0     
 MIGSUN_index   | 2.0     
 PARENT_index   | 0.0     
 

In [23]:
# Column order: indexed categoricals (25), numericals (13), then the 3 label columns.
# From here on `catergorical_column_others` holds the "_index" column names.
catergorical_column_others = [f"{col_name}_index" for col_name in catergorical_column_others]
label_columns = ["AHSCOL", "AMARITL", "income"]
columns_reordered = [*catergorical_column_others, *numerical_column, *label_columns]
print(len(catergorical_column_others), len(numerical_column), 3)
print(columns_reordered)
df = df.select(*columns_reordered)
df.show(n=1, vertical=True)

25 13 3
['ACLSWKR_index', 'AMJIND_index', 'AMJOCC_index', 'ARACE_index', 'AREORGN_index', 'ASEX_index', 'AUNMEM_index', 'AUNTYPE_index', 'AWKSTAT_index', 'FILESTAT_index', 'GRINREG_index', 'GRINST_index', 'HHDFMX_index', 'HHDREL_index', 'MIGMTR1_index', 'MIGMTR3_index', 'MIGMTR4_index', 'MIGSAME_index', 'MIGSUN_index', 'PARENT_index', 'PEFNTVTY_index', 'PEMNTVTY_index', 'PENATVTY_index', 'PRCITSHP_index', 'VETQVA_index', 'AAGE', 'ADTINK', 'ADTOCC', 'AHRSPAY', 'CAPGAIN', 'GAPLOSS', 'DIVVAL', 'MARSUPWRT', 'NOEMP', 'SEOTR', 'VETYN', 'WKSWORK', 'year', 'AHSCOL', 'AMARITL', 'income']
-RECORD 0-----------------
 ACLSWKR_index  | 2.0     
 AMJIND_index   | 6.0     
 AMJOCC_index   | 6.0     
 ARACE_index    | 0.0     
 AREORGN_index  | 0.0     
 ASEX_index     | 1.0     
 AUNMEM_index   | 0.0     
 AUNTYPE_index  | 0.0     
 AWKSTAT_index  | 0.0     
 FILESTAT_index | 4.0     
 GRINREG_index  | 1.0     
 GRINST_index   | 24.0    
 HHDFMX_index   | 0.0     
 HHDREL_index   | 0.0     
 MIGMTR1_

In [24]:
# Cardinality of every indexed column, recorded for sizing nn.Embedding later.
for index_column in catergorical_column_others:
    check_distinct_len(index_column)

Distinct len of ACLSWKR_index: 9
Distinct len of AMJIND_index: 24
Distinct len of AMJOCC_index: 15
Distinct len of ARACE_index: 5
Distinct len of AREORGN_index: 10
Distinct len of ASEX_index: 2
Distinct len of AUNMEM_index: 3
Distinct len of AUNTYPE_index: 6
Distinct len of AWKSTAT_index: 1
Distinct len of FILESTAT_index: 6
Distinct len of GRINREG_index: 6
Distinct len of GRINST_index: 50
Distinct len of HHDFMX_index: 37
Distinct len of HHDREL_index: 8
Distinct len of MIGMTR1_index: 9
Distinct len of MIGMTR3_index: 8
Distinct len of MIGMTR4_index: 9
Distinct len of MIGSAME_index: 3
Distinct len of MIGSUN_index: 3
Distinct len of PARENT_index: 5
Distinct len of PEFNTVTY_index: 40
Distinct len of PEMNTVTY_index: 40
Distinct len of PENATVTY_index: 40
Distinct len of PRCITSHP_index: 5
Distinct len of VETQVA_index: 3


## 拼接特征变为一个向量

In [25]:
from copy import deepcopy

feature_list = df.columns
print("feature_list:", feature_list)

# Experiment 1: every column except "income" and "AMARITL".
feature_list1 = [name for name in feature_list if name not in ("income", "AMARITL")]
print("feature_list for exp1:", feature_list1)

# Experiment 2: every column except "AHSCOL" and "AMARITL".
feature_list2 = [name for name in feature_list if name not in ("AHSCOL", "AMARITL")]
print("feature_list for exp2:", feature_list2)

feature_list: ['ACLSWKR_index', 'AMJIND_index', 'AMJOCC_index', 'ARACE_index', 'AREORGN_index', 'ASEX_index', 'AUNMEM_index', 'AUNTYPE_index', 'AWKSTAT_index', 'FILESTAT_index', 'GRINREG_index', 'GRINST_index', 'HHDFMX_index', 'HHDREL_index', 'MIGMTR1_index', 'MIGMTR3_index', 'MIGMTR4_index', 'MIGSAME_index', 'MIGSUN_index', 'PARENT_index', 'PEFNTVTY_index', 'PEMNTVTY_index', 'PENATVTY_index', 'PRCITSHP_index', 'VETQVA_index', 'AAGE', 'ADTINK', 'ADTOCC', 'AHRSPAY', 'CAPGAIN', 'GAPLOSS', 'DIVVAL', 'MARSUPWRT', 'NOEMP', 'SEOTR', 'VETYN', 'WKSWORK', 'year', 'AHSCOL', 'AMARITL', 'income']
feature_list for exp1: ['ACLSWKR_index', 'AMJIND_index', 'AMJOCC_index', 'ARACE_index', 'AREORGN_index', 'ASEX_index', 'AUNMEM_index', 'AUNTYPE_index', 'AWKSTAT_index', 'FILESTAT_index', 'GRINREG_index', 'GRINST_index', 'HHDFMX_index', 'HHDREL_index', 'MIGMTR1_index', 'MIGMTR3_index', 'MIGMTR4_index', 'MIGSAME_index', 'MIGSUN_index', 'PARENT_index', 'PEFNTVTY_index', 'PEMNTVTY_index', 'PENATVTY_index', 'P

In [26]:
from pyspark.ml.feature import VectorAssembler

# Assemble one feature vector per experiment, each in its own output column.
for feature_names, output_name in (
    (feature_list1, "vectorized_features_1"),
    (feature_list2, "vectorized_features_2"),
):
    vector_assembler = VectorAssembler(inputCols=feature_names, outputCol=output_name)
    df = vector_assembler.transform(df)

In [27]:
from pyspark.sql.functions import udf
from pyspark.ml.linalg import DenseVector, SparseVector
from pyspark.sql.types import ArrayType, DoubleType

# Convert SparseVector / DenseVector values into plain Python lists.
def convert_vector_to_list(vector):
    """Return a Spark ML vector as a plain Python list.

    Dense and sparse vectors are both expanded with ``toArray().tolist()``;
    any other value is handed back unchanged.
    """
    if not isinstance(vector, (DenseVector, SparseVector)):
        return vector  # not a vector type: leave as is
    return vector.toArray().tolist()

# Wrap the converter as a UDF whose declared return type is array<double>.
convert_vector_udf = udf(convert_vector_to_list, ArrayType(DoubleType()))

# Overwrite both assembled vector columns with their plain-array form.
for vector_column in ("vectorized_features_1", "vectorized_features_2"):
    df = df.withColumn(vector_column, convert_vector_udf(df[vector_column]))

In [28]:
# Pull the experiment-1 feature array and the income / AMARITL columns into pandas.
# NOTE(review): the original comment credited toPandas() with turning bigint
# into floats; the array column is already declared as array<double> by the
# UDF, so the floats presumably originate there - confirm.
exp1_columns = ["vectorized_features_1", "income", "AMARITL"]
panda_df_1 = df.select(*exp1_columns).toPandas()
print(panda_df_1.iloc[0, 0])
panda_df_1

[2.0, 6.0, 6.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 4.0, 1.0, 24.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 58.0, 4.0, 34.0, 0.0, 0.0, 0.0, 0.0, 1053.55, 1.0, 0.0, 2.0, 52.0, 94.0, 0.0]


Unnamed: 0,vectorized_features_1,income,AMARITL
0,"[2.0, 6.0, 6.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...",0,0
1,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0,1
2,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0,1
3,"[1.0, 5.0, 3.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...",0,0
4,"[1.0, 6.0, 7.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...",0,0
...,...,...,...
95125,"[3.0, 3.0, 2.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...",1,0
95126,"[1.0, 1.0, 8.0, 4.0, 0.0, 1.0, 0.0, 4.0, 0.0, ...",0,1
95127,"[1.0, 1.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0,0
95128,"[0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...",0,1


In [29]:
# Pull the experiment-2 feature array and the AHSCOL / AMARITL columns into pandas.
exp2_columns = ["vectorized_features_2", "AHSCOL", "AMARITL"]
panda_df_2 = df.select(*exp2_columns).toPandas()
print(panda_df_2.iloc[0, 0])
panda_df_2

[2.0, 6.0, 6.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 4.0, 1.0, 24.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 58.0, 4.0, 34.0, 0.0, 0.0, 0.0, 0.0, 1053.55, 1.0, 0.0, 2.0, 52.0, 94.0, 0.0]


Unnamed: 0,vectorized_features_2,AHSCOL,AMARITL
0,"[2.0, 6.0, 6.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...",0,0
1,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0,1
2,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0,1
3,"[1.0, 5.0, 3.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...",0,0
4,"[1.0, 6.0, 7.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...",0,0
...,...,...,...
95125,"[3.0, 3.0, 2.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...",0,0
95126,"[1.0, 1.0, 8.0, 4.0, 0.0, 1.0, 0.0, 4.0, 0.0, ...",0,1
95127,"[1.0, 1.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0,0
95128,"[0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...",0,1


In [30]:
import os

# Make sure both experiment output directories exist. exist_ok=True replaces
# the separate exists() check, which avoids the check-then-create race and
# keeps the cell safe to re-run.
for output_dir in ("data/exp1", "data/exp2"):
    os.makedirs(output_dir, exist_ok=True)

In [31]:
# Write the training frames of both experiments, without the pandas index.
for train_frame, train_path in (
    (panda_df_1, "data/exp1/raw_train.csv"),
    (panda_df_2, "data/exp2/raw_train.csv"),
):
    train_frame.to_csv(train_path, index=False)

# 对df_test 做一样的处理

In [32]:
# Apply the same cleaning to the test split: ' ?' becomes null and rows with
# nulls are dropped (47391 test rows remain in the recorded run).
data = df_test
df = spark.createDataFrame(data).replace([' ?'], [None]).dropna()
print(df.count())

47391


In [33]:
# Partition the test columns by Spark type: double/long -> numerical, string -> categorical.
numerical_column = []
categorical_column = []
for field in df.schema.fields:
    if isinstance(field.dataType, (DoubleType, LongType)):
        numerical_column.append(field.name)
    if isinstance(field.dataType, StringType):
        categorical_column.append(field.name)

# Show the result
print("Numerical Columns:", numerical_column)
print("Categorical Columns:", categorical_column)

Numerical Columns: ['AAGE', 'ADTINK', 'ADTOCC', 'AHRSPAY', 'CAPGAIN', 'GAPLOSS', 'DIVVAL', 'MARSUPWRT', 'NOEMP', 'SEOTR', 'VETYN', 'WKSWORK', 'year']
Categorical Columns: ['ACLSWKR', 'AHGA', 'AHSCOL', 'AMARITL', 'AMJIND', 'AMJOCC', 'ARACE', 'AREORGN', 'ASEX', 'AUNMEM', 'AUNTYPE', 'AWKSTAT', 'FILESTAT', 'GRINREG', 'GRINST', 'HHDFMX', 'HHDREL', 'MIGMTR1', 'MIGMTR3', 'MIGMTR4', 'MIGSAME', 'MIGSUN', 'PARENT', 'PEFNTVTY', 'PEMNTVTY', 'PENATVTY', 'PRCITSHP', 'VETQVA', 'income']


In [34]:
# Drop AHGA and binarise AHSCOL: 1 for ' College or university', 0 otherwise.
enrolled_in_college = col("AHSCOL") == ' College or university'
df = df.drop("AHGA").withColumn(
    "AHSCOL",
    when(enrolled_in_college, 1).otherwise(0).cast("bigint"),
)
check_distinct_values("AHSCOL")

Distinct values of AHSCOL: [0, 1]


In [35]:
# One prediction target is whether the marital status is "never married":
# 1 for ' Never married', 0 otherwise.
check_distinct_values("AMARITL")
never_married = col("AMARITL") == ' Never married'
df = df.withColumn(
    "AMARITL",
    when(never_married, 1).otherwise(0).cast("bigint"),  # store the flag as bigint
)
check_distinct_values("AMARITL")

Distinct values of AMARITL: [' Widowed', ' Married-A F spouse present', ' Never married', ' Divorced', ' Married-spouse absent', ' Separated', ' Married-civilian spouse present']
Distinct values of AMARITL: [0, 1]


In [36]:
# 1 means income > 50K, 0 means income <= 50K.
check_distinct_values("income")
low_income = col("income") == ' - 50000.'
high_income = col("income") == ' 50000+.'
df = df.withColumn(
    "income",
    when(low_income, 0).when(high_income, 1).cast("bigint"),  # store the label as bigint
)
check_distinct_values("income")

Distinct values of income: [' - 50000.', ' 50000+.']
Distinct values of income: [0, 1]


In [37]:
from copy import deepcopy
# The remaining categorical columns are the ones StringIndexer will encode;
# the four columns already handled by hand are taken out of the copy.
catergorical_column_others = list(categorical_column)
for handled_name in ("AHGA", "AHSCOL", "AMARITL", "income"):
    catergorical_column_others.remove(handled_name)

In [38]:
# Encode the test split with the StringIndexer models that were fitted on the
# TRAINING data (`pipeline_model`). The previous version re-fitted the indexers
# on the test split; StringIndexer assigns indices by label frequency in the
# data it is fitted on, so the same category could receive different integers
# in the train and test CSVs.
# Labels that occur only in the test split are mapped to one extra index
# (handleInvalid="keep") instead of raising. Each fitted stage is copied so
# the training model itself is left untouched.
for fitted_indexer in pipeline_model.stages:
    df = fitted_indexer.copy().setHandleInvalid("keep").transform(df)

# The raw string columns are no longer needed once their "_index" twins exist.
df = df.drop(*catergorical_column_others)

In [39]:
# Column order: indexed categoricals (25), numericals (13), then the 3 label columns.
# The list of StringIndexer stages that used to be rebuilt here was never used
# afterwards (the indexing had already happened in the previous cell), so that
# dead assignment is gone.
# From here on `catergorical_column_others` holds the "_index" column names.
catergorical_column_others = [col_name + "_index" for col_name in catergorical_column_others]
columns_reordered = catergorical_column_others + numerical_column + ["AHSCOL", "AMARITL", "income"]
print(len(catergorical_column_others), len(numerical_column), 3)
print(columns_reordered)
df = df.select(*columns_reordered)

25 13 3
['ACLSWKR_index', 'AMJIND_index', 'AMJOCC_index', 'ARACE_index', 'AREORGN_index', 'ASEX_index', 'AUNMEM_index', 'AUNTYPE_index', 'AWKSTAT_index', 'FILESTAT_index', 'GRINREG_index', 'GRINST_index', 'HHDFMX_index', 'HHDREL_index', 'MIGMTR1_index', 'MIGMTR3_index', 'MIGMTR4_index', 'MIGSAME_index', 'MIGSUN_index', 'PARENT_index', 'PEFNTVTY_index', 'PEMNTVTY_index', 'PENATVTY_index', 'PRCITSHP_index', 'VETQVA_index', 'AAGE', 'ADTINK', 'ADTOCC', 'AHRSPAY', 'CAPGAIN', 'GAPLOSS', 'DIVVAL', 'MARSUPWRT', 'NOEMP', 'SEOTR', 'VETYN', 'WKSWORK', 'year', 'AHSCOL', 'AMARITL', 'income']


In [40]:
# Cardinality of every indexed column, recorded for sizing nn.Embedding later.
# PENATVTY shows 41 distinct values here, i.e. the test split contains a value
# that never occurred in the training split, so the larger of the train/test
# sizes should be used when configuring nn.Embedding.
for index_column in catergorical_column_others:
    check_distinct_len(index_column)

Distinct len of ACLSWKR_index: 9
Distinct len of AMJIND_index: 24
Distinct len of AMJOCC_index: 15
Distinct len of ARACE_index: 5
Distinct len of AREORGN_index: 10
Distinct len of ASEX_index: 2
Distinct len of AUNMEM_index: 3
Distinct len of AUNTYPE_index: 6
Distinct len of AWKSTAT_index: 1
Distinct len of FILESTAT_index: 6
Distinct len of GRINREG_index: 6
Distinct len of GRINST_index: 50
Distinct len of HHDFMX_index: 36
Distinct len of HHDREL_index: 8
Distinct len of MIGMTR1_index: 9
Distinct len of MIGMTR3_index: 8
Distinct len of MIGMTR4_index: 9
Distinct len of MIGSAME_index: 3
Distinct len of MIGSUN_index: 3
Distinct len of PARENT_index: 5
Distinct len of PEFNTVTY_index: 40
Distinct len of PEMNTVTY_index: 40
Distinct len of PENATVTY_index: 41
Distinct len of PRCITSHP_index: 5
Distinct len of VETQVA_index: 3


In [41]:
feature_list = df.columns
print("feature_list:", feature_list)

# Experiment 1: every column except "income" and "AMARITL".
feature_list1 = [name for name in feature_list if name not in ("income", "AMARITL")]
print("feature_list for exp1:", feature_list1)

# Experiment 2: every column except "AHSCOL" and "AMARITL".
feature_list2 = [name for name in feature_list if name not in ("AHSCOL", "AMARITL")]
print("feature_list for exp2:", feature_list2)

feature_list: ['ACLSWKR_index', 'AMJIND_index', 'AMJOCC_index', 'ARACE_index', 'AREORGN_index', 'ASEX_index', 'AUNMEM_index', 'AUNTYPE_index', 'AWKSTAT_index', 'FILESTAT_index', 'GRINREG_index', 'GRINST_index', 'HHDFMX_index', 'HHDREL_index', 'MIGMTR1_index', 'MIGMTR3_index', 'MIGMTR4_index', 'MIGSAME_index', 'MIGSUN_index', 'PARENT_index', 'PEFNTVTY_index', 'PEMNTVTY_index', 'PENATVTY_index', 'PRCITSHP_index', 'VETQVA_index', 'AAGE', 'ADTINK', 'ADTOCC', 'AHRSPAY', 'CAPGAIN', 'GAPLOSS', 'DIVVAL', 'MARSUPWRT', 'NOEMP', 'SEOTR', 'VETYN', 'WKSWORK', 'year', 'AHSCOL', 'AMARITL', 'income']
feature_list for exp1: ['ACLSWKR_index', 'AMJIND_index', 'AMJOCC_index', 'ARACE_index', 'AREORGN_index', 'ASEX_index', 'AUNMEM_index', 'AUNTYPE_index', 'AWKSTAT_index', 'FILESTAT_index', 'GRINREG_index', 'GRINST_index', 'HHDFMX_index', 'HHDREL_index', 'MIGMTR1_index', 'MIGMTR3_index', 'MIGMTR4_index', 'MIGSAME_index', 'MIGSUN_index', 'PARENT_index', 'PEFNTVTY_index', 'PEMNTVTY_index', 'PENATVTY_index', 'P

In [42]:
from pyspark.ml.feature import VectorAssembler

# Assemble one feature vector per experiment, each in its own output column.
for feature_names, output_name in (
    (feature_list1, "vectorized_features_1"),
    (feature_list2, "vectorized_features_2"),
):
    vector_assembler = VectorAssembler(inputCols=feature_names, outputCol=output_name)
    df = vector_assembler.transform(df)

In [43]:
from pyspark.sql.functions import udf
from pyspark.ml.linalg import DenseVector, SparseVector
from pyspark.sql.types import ArrayType, DoubleType

# Convert SparseVector / DenseVector values into plain Python lists.
def convert_vector_to_list(vector):
    """Return a Spark ML vector as a plain Python list.

    Dense and sparse vectors are both expanded with ``toArray().tolist()``;
    any other value is handed back unchanged.
    """
    if not isinstance(vector, (DenseVector, SparseVector)):
        return vector  # not a vector type: leave as is
    return vector.toArray().tolist()

# Wrap the converter as a UDF whose declared return type is array<double>.
convert_vector_udf = udf(convert_vector_to_list, ArrayType(DoubleType()))

# Overwrite both assembled vector columns with their plain-array form.
for vector_column in ("vectorized_features_1", "vectorized_features_2"):
    df = df.withColumn(vector_column, convert_vector_udf(df[vector_column]))

In [44]:
# Pull the test-split inputs and labels of both experiments into pandas.
# NOTE(review): the original comment credited toPandas() with turning bigint
# into floats; the array columns are already declared as array<double> by the
# UDF, so the floats presumably originate there - confirm.
exp1_columns = ["vectorized_features_1", "income", "AMARITL"]
panda_df_1 = df.select(*exp1_columns).toPandas()
print(panda_df_1.iloc[0, 0])
# A bare expression in the middle of a cell is evaluated but not rendered;
# only the final `panda_df_2` below is displayed.
panda_df_1

exp2_columns = ["vectorized_features_2", "AHSCOL", "AMARITL"]
panda_df_2 = df.select(*exp2_columns).toPandas()
print(panda_df_2.iloc[0, 0])
panda_df_2

[1.0, 9.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 35.0, 29.0, 3.0, 0.0, 0.0, 0.0, 0.0, 1866.88, 5.0, 2.0, 2.0, 52.0, 94.0, 0.0]
[1.0, 9.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 35.0, 29.0, 3.0, 0.0, 0.0, 0.0, 0.0, 1866.88, 5.0, 2.0, 2.0, 52.0, 94.0, 0.0]


Unnamed: 0,vectorized_features_2,AHSCOL,AMARITL
0,"[1.0, 9.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0,0
1,"[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...",0,1
2,"[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...",0,1
3,"[1.0, 11.0, 2.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0,...",0,0
4,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0,1
...,...,...,...
47386,"[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0,1
47387,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0,0
47388,"[1.0, 3.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1,1
47389,"[2.0, 14.0, 10.0, 0.0, 2.0, 1.0, 0.0, 0.0, 0.0...",0,0


In [45]:
# Write the test frames of both experiments, without the pandas index.
for test_frame, test_path in (
    (panda_df_1, "data/exp1/raw_test.csv"),
    (panda_df_2, "data/exp2/raw_test.csv"),
):
    test_frame.to_csv(test_path, index=False)