In [27]:
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
import findspark
findspark.init()
spark = SparkSession.builder.appName("A").config("B", "C").getOrCreate()
sc = spark.sparkContext


In [28]:
df = spark.read.csv(path='./data/WineData.csv', header=True, inferSchema=True)


In [29]:
df = df.withColumn('quality', F.when(df.quality <= 4, 0).when(
    (df.quality <= 6) & (df.quality > 4), 1).otherwise(2))


In [33]:
from pyspark.mllib.regression import LabeledPoint
data = df.rdd.map(lambda row: LabeledPoint(row[-1], row[:-1]))


In [35]:
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])
model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                     impurity='gini', maxDepth=5, maxBins=32)
predictions = model.predict(testData.map(lambda x: x.features))  # 解释变量
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(
    lambda lp: lp[0] != lp[1]).count() / float(testData.count())
# lp[0]是label，lp[1]是predictions，用filter保留不一样的元素，计数，再除以总数
print('Test Error = ' + str(testErr))
print('Learned classification tree model:')
print(model.toDebugString())  # 可以返回拟合出来的决策树的形状


Test Error = 0.0
Learned classification tree model:
DecisionTreeModel classifier of depth 2 with 5 nodes
  If (feature 434 <= 79.5)
   If (feature 100 <= 193.5)
    Predict: 0.0
   Else (feature 100 > 193.5)
    Predict: 1.0
  Else (feature 434 > 79.5)
   Predict: 1.0

