In [1]:
from pyspark.sql.types import *

columns = ["age", "workclass", "fnlwgt", "education", "marital_status",
           "occupation", "relationship", "race", "sex", "capital_gain", "capital_loss",
           "hours_per_week", "native_country", "income"]

# Explicit schema for the census DataFrame built from `columns`.
# The five continuous attributes are DoubleType; every other attribute is a
# StringType category. All fields are nullable.
# Fields are ordered alphabetically (sorted(columns)).
# NOTE(review): presumably the alphabetical order was chosen to line up with
# Row(**kwargs), which sorts its fields by name in Spark < 3.0 — confirm
# before reordering.
adultschema = StructType([
    StructField(
        name,
        DoubleType() if name in ("age", "capital_gain", "capital_loss",
                                 "fnlwgt", "hours_per_week") else StringType(),
        True,
    )
    for name in sorted(columns)
])

In [2]:
def toDouble(v):
    """Convert ``v`` to ``float`` when possible, otherwise return it unchanged.

    Applied to every field of a split census record so that numeric fields
    (e.g. ``'39'``) become floats, while categorical fields (e.g.
    ``'State-gov'``, ``'?'``) pass through as strings.

    Only genuine conversion failures are treated as "not a number":
    ``ValueError`` for unparsable strings and ``TypeError`` for values such as
    ``None``. Any other exception, including ``KeyboardInterrupt``, propagates
    instead of being silently swallowed.
    """
    try:
        return float(v)
    except (TypeError, ValueError):
        return v

In [3]:
# Load the raw census file as an RDD of records, each a list of
# whitespace-stripped string fields split on commas.
# Blank lines are skipped. Otherwise each would become a one-field row ['']
# that cannot be matched against the 14-column schema when the DataFrame is
# built.
census_raw = (sc.textFile("data-files/adult.raw")
              .filter(lambda line: line.strip())
              .map(lambda line: [field.strip() for field in line.split(',')]))

In [5]:
# Peek at the first record: a list of 14 stripped string fields.
census_raw.take(num=1)

[['39',
  'State-gov',
  '77516',
  'Bachelors',
  'Never-married',
  'Adm-clerical',
  'Not-in-family',
  'White',
  'Male',
  '2174',
  '0',
  '40',
  'United-States',
  '<=50K']]

In [6]:
# Turn every numeric-looking field into a float; categorical fields stay strings.
# Re-running this cell is harmless: toDouble returns floats unchanged.
census_raw = census_raw.map(lambda fields: [toDouble(field) for field in fields])

In [7]:
# Peek at the first record again: numeric-looking fields are now floats.
census_raw.take(num=1)

[[39.0,
  'State-gov',
  77516.0,
  'Bachelors',
  'Never-married',
  'Adm-clerical',
  'Not-in-family',
  'White',
  'Male',
  2174.0,
  0.0,
  40.0,
  'United-States',
  '<=50K']]

In [11]:
from pyspark.sql import Row

# Wrap each parsed record in a Row keyed by column name, then build the
# DataFrame with the explicit `adultschema` (numeric fields as DoubleType,
# the rest as StringType).
dfraw = spark.createDataFrame(
    census_raw.map(lambda fields: Row(**dict(zip(columns, fields)))),
    adultschema,
)

In [12]:
# Display the first five rows of the typed DataFrame.
dfraw.show(n=5)

+----+------------+------------+---------+--------+--------------+------+------------------+--------------+-----------------+-----+-------------+------+----------------+
| age|capital_gain|capital_loss|education|  fnlwgt|hours_per_week|income|    marital_status|native_country|       occupation| race| relationship|   sex|       workclass|
+----+------------+------------+---------+--------+--------------+------+------------------+--------------+-----------------+-----+-------------+------+----------------+
|39.0|      2174.0|         0.0|Bachelors| 77516.0|          40.0| <=50K|     Never-married| United-States|     Adm-clerical|White|Not-in-family|  Male|       State-gov|
|50.0|         0.0|         0.0|Bachelors| 83311.0|          13.0| <=50K|Married-civ-spouse| United-States|  Exec-managerial|White|      Husband|  Male|Self-emp-not-inc|
|38.0|         0.0|         0.0|  HS-grad|215646.0|          40.0| <=50K|          Divorced| United-States|Handlers-cleaners|White|Not-in-family|  Mal

In [16]:
from pyspark.sql.functions import *
# NOTE(review): this wildcard import replaces Python builtins such as sum,
# max, min, abs and round with Spark column functions for every later cell.
# Prefer explicit imports (e.g. `from pyspark.sql.functions import col`)
# once the notebook's remaining uses are known.

# Count rows per workclass category; '?' shows up as a placeholder category.
dfraw.groupBy("workclass").count().show()

+----------------+-----+
|       workclass|count|
+----------------+-----+
|Self-emp-not-inc| 3862|
|       Local-gov| 3136|
|       State-gov| 1981|
|         Private|33906|
|     Without-pay|   21|
|     Federal-gov| 1432|
|    Never-worked|   10|
|               ?| 2799|
|    Self-emp-inc| 1695|
+----------------+-----+



In [17]:
# Impute the '?' placeholder in three categorical columns with a fixed category.
# For workclass this is the most frequent value ('Private', per the counts above).
# NOTE(review): presumably 'Prof-specialty' and 'United-States' are the modes
# of their columns as well — confirm.
dfrawnona = (
    dfraw
    .na.replace(['?'], ['Private'], subset=['workclass'])
    .na.replace(['?'], ['Prof-specialty'], subset=['occupation'])
    .na.replace(['?'], ['United-States'], subset=['native_country'])
)

In [18]:
# Re-count workclass: the former '?' rows are now folded into 'Private'.
dfrawnona.groupBy("workclass").count().show()

+----------------+-----+
|       workclass|count|
+----------------+-----+
|Self-emp-not-inc| 3862|
|       Local-gov| 3136|
|       State-gov| 1981|
|         Private|36705|
|     Without-pay|   21|
|     Federal-gov| 1432|
|    Never-worked|   10|
|    Self-emp-inc| 1695|
+----------------+-----+



In [22]:
def indexStringColumns(df, cols):
    """Replace each string column named in ``cols`` with its StringIndexer index.

    For every name in ``cols``, in order:
      1. fit a ``StringIndexer`` on the current DataFrame, writing to a
         temporary ``<name>-num`` column;
      2. drop the original string column;
      3. rename the temporary column back to ``<name>``.

    Args:
        df: Input DataFrame. It is never reassigned or altered here; a new
            DataFrame is returned.
        cols: Names of the columns to index, processed in the given order.

    Returns:
        A DataFrame in which each column in ``cols`` holds the indexer's
        numeric output. Because each indexed column is dropped and re-added,
        the indexed columns come after the untouched ones, in ``cols`` order.
    """
    from pyspark.ml.feature import StringIndexer

    result = df
    for name in cols:
        tmp_name = name + "-num"
        model = StringIndexer(inputCol=name, outputCol=tmp_name).fit(result)
        # Swap the string column for its numeric index under the same name.
        result = (model.transform(result)
                  .drop(name)
                  .withColumnRenamed(tmp_name, name))
    return result

In [23]:
# Index every categorical column (including income) so all columns are numeric.
dfnumeric = indexStringColumns(
    dfrawnona,
    ["workclass", "education", "marital_status", "occupation", "relationship",
     "race", "sex", "native_country", "income"],
)
dfnumeric.show(n=5)

+----+------------+------------+--------+--------------+---------+---------+--------------+----------+------------+----+---+--------------+------+
| age|capital_gain|capital_loss|  fnlwgt|hours_per_week|workclass|education|marital_status|occupation|relationship|race|sex|native_country|income|
+----+------------+------------+--------+--------------+---------+---------+--------------+----------+------------+----+---+--------------+------+
|39.0|      2174.0|         0.0| 77516.0|          40.0|      3.0|      2.0|           1.0|       3.0|         1.0| 0.0|0.0|           0.0|   0.0|
|50.0|         0.0|         0.0| 83311.0|          13.0|      1.0|      2.0|           0.0|       2.0|         0.0| 0.0|0.0|           0.0|   0.0|
|38.0|         0.0|         0.0|215646.0|          40.0|      0.0|      0.0|           2.0|       8.0|         1.0| 0.0|0.0|           0.0|   0.0|
|53.0|         0.0|         0.0|234721.0|          40.0|      0.0|      5.0|           0.0|       8.0|         0.0| 1.