In [83]:
import numpy as np
import pandas as pd
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.ml.classification import DecisionTreeClassificationModel,DecisionTreeClassifier
import pyspark.ml.linalg as linalg
from pyspark.ml.feature import StringIndexer,VectorAssembler


In [2]:
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

In [107]:
# 生成测试数据集
data = {
         'Age': {0: 22, 1: 38, 2: 26, 3: 35, 4: 35, 5: 40},
         'Fare': {0: 7.3, 1: 71.3, 2: 7.9, 3: 53.1, 4: 8.0, 5: 53.1},
         'Pclass': {0: 3, 1: 1, 2: 3, 3: 1, 4: 3, 5: 1},
         'Survived': {0: 0., 1: 1., 2: 1., 3: 1., 4: 0.,5: 1.}}

In [94]:
#转换pandas df到spark df
df_pd = pd.DataFrame(data)
df_spark = spark.createDataFrame(df_pd)

In [95]:
df_spark.show()

+---+----+------+--------+
|Age|Fare|Pclass|Survived|
+---+----+------+--------+
| 22| 7.3|     3|     0.0|
| 38|71.3|     1|     1.0|
| 26| 7.9|     3|     1.0|
| 35|53.1|     1|     1.0|
| 35| 8.0|     3|     0.0|
| 40|53.1|     1|     1.0|
+---+----+------+--------+



### 数据预处理（最简单的特征工程）

In [103]:
indexer = StringIndexer(inputCol="Survived", outputCol="indexed_label").fit(df_spark)
indexed_df = indexer.transform(df_spark)
indexed_df.show()

+---+----+------+--------+-------------+
|Age|Fare|Pclass|Survived|indexed_label|
+---+----+------+--------+-------------+
| 22| 7.3|     3|     0.0|          1.0|
| 38|71.3|     1|     1.0|          0.0|
| 26| 7.9|     3|     1.0|          0.0|
| 35|53.1|     1|     1.0|          0.0|
| 35| 8.0|     3|     0.0|          1.0|
| 40|53.1|     1|     1.0|          0.0|
+---+----+------+--------+-------------+



In [97]:
assembler = VectorAssembler(inputCols=['Age','Fare','Pclass'],outputCol='features')

In [112]:
final_df = assembler.transform(indexed_df)
final_df.show()

+---+----+------+--------+-------------+---------------+
|Age|Fare|Pclass|Survived|indexed_label|       features|
+---+----+------+--------+-------------+---------------+
| 22| 7.3|     3|     0.0|          1.0| [22.0,7.3,3.0]|
| 38|71.3|     1|     1.0|          0.0|[38.0,71.3,1.0]|
| 26| 7.9|     3|     1.0|          0.0| [26.0,7.9,3.0]|
| 35|53.1|     1|     1.0|          0.0|[35.0,53.1,1.0]|
| 35| 8.0|     3|     0.0|          1.0| [35.0,8.0,3.0]|
| 40|53.1|     1|     1.0|          0.0|[40.0,53.1,1.0]|
+---+----+------+--------+-------------+---------------+



In [120]:
#拆分数据集
train, test = final_df.randomSplit([0.8,0.2])

### 模型训练

In [122]:
dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed_label")

In [123]:
model = dt.fit(final_df)

In [127]:
#预测
result = model.transform(test)

In [134]:
#对比预测与实际
result.show()
test.show()

+---+----+------+--------+-------------+--------------+-------------+-----------+----------+
|Age|Fare|Pclass|Survived|indexed_label|      features|rawPrediction|probability|prediction|
+---+----+------+--------+-------------+--------------+-------------+-----------+----------+
| 26| 7.9|     3|     1.0|          0.0|[26.0,7.9,3.0]|    [1.0,1.0]|  [0.5,0.5]|       0.0|
+---+----+------+--------+-------------+--------------+-------------+-----------+----------+

+---+----+------+--------+-------------+--------------+
|Age|Fare|Pclass|Survived|indexed_label|      features|
+---+----+------+--------+-------------+--------------+
| 26| 7.9|     3|     1.0|          0.0|[26.0,7.9,3.0]|
+---+----+------+--------+-------------+--------------+



In [125]:
model.save()

DecisionTreeClassificationModel (uid=DecisionTreeClassifier_442e36057159) of depth 2 with 5 nodes

### 文档示例

In [98]:
df = spark.createDataFrame([
  (1.0, Vectors.dense(1.0,1)),
     (0.0, Vectors.dense(1,0))], ["label", "features"])

In [99]:
df.show()

+-----+---------+
|label| features|
+-----+---------+
|  1.0|[1.0,1.0]|
|  0.0|[1.0,0.0]|
+-----+---------+



In [100]:
stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")

In [101]:
si_model = stringIndexer.fit(df)
td = si_model.transform(df)

In [102]:
td.show()

+-----+---------+-------+
|label| features|indexed|
+-----+---------+-------+
|  1.0|[1.0,1.0]|    1.0|
|  0.0|[1.0,0.0]|    0.0|
+-----+---------+-------+

