In [None]:
# Must be included at the beginning of each new notebook. Remember to change the app name.
import findspark
findspark.init('/home/ubuntu/spark-2.1.1-bin-hadoop2.7')
import pyspark
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('722-model').getOrCreate()

In [None]:
data1= spark.read.csv("./data_china.csv", header=True, inferSchema=True)
data1.printSchema()

In [3]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
assembler = VectorAssembler(
  inputCols=['mp_price_RMB'],
              outputCol="features")
output = assembler.transform(data1)
from pyspark.ml.feature import StringIndexer
indexer = StringIndexer(inputCol="cm_name", outputCol="PrivateIndex")
output_fixed = indexer.fit(output).transform(output)
final_data = output_fixed.select("features",'PrivateIndex')

In [4]:
from pyspark.ml.classification import DecisionTreeClassifier,GBTClassifier,RandomForestClassifier
from pyspark.ml import Pipeline
dtc = DecisionTreeClassifier(labelCol='PrivateIndex',featuresCol='features')
train_data,test_data = final_data.randomSplit([0.7,0.3])
dtc_model = dtc.fit(train_data)
dtc_predictions = dtc_model.transform(test_data)

In [5]:
# Let's start off with binary classification.
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Note that the label column isn't named label, it's named PrivateIndex in this case.
my_binary_eval = BinaryClassificationEvaluator(labelCol = 'PrivateIndex')

In [6]:
print("DTC")
print(my_binary_eval.evaluate(dtc_predictions))

DTC
0.7545731707317074


In [7]:
dtc_predictions.show()
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
acc_evaluator = MulticlassClassificationEvaluator(labelCol="PrivateIndex", predictionCol="prediction", metricName="accuracy")
dtc_acc = acc_evaluator.evaluate(dtc_predictions)
print('-'*40)
print('A single decision tree has an accuracy of: {0:2.2f}%'.format(dtc_acc*100))


+--------+------------+--------------------+--------------------+----------+
|features|PrivateIndex|       rawPrediction|         probability|prediction|
+--------+------------+--------------------+--------------------+----------+
| [1.315]|         1.0|[0.0,103.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|[1.3625]|         1.0|[0.0,103.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
| [1.395]|         1.0|[0.0,103.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|[1.5025]|         1.0|[0.0,103.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|  [1.54]|         1.0|[0.0,103.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
| [1.545]|         1.0|[0.0,103.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|[1.5567]|         1.0|[0.0,103.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|  [1.57]|         1.0|[0.0,103.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|  [1.59]|         1.0|[0.0,103.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|
|[1.5975]|         1.0|[0.0,103.0,0.0,0....|[0.0,1.0,0.0,0.0,...|       1.0|

In [8]:
from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.sql.types import *
from pyspark.sql.functions import *
final_data.show()
data1.select('cm_name').distinct().show()
final_data.select('PrivateIndex').distinct().show()

+--------+------------+
|features|PrivateIndex|
+--------+------------+
|[2.6567]|         3.0|
|  [2.61]|         3.0|
| [2.616]|         3.0|
|  [2.57]|         3.0|
|  [2.54]|         3.0|
| [2.516]|         3.0|
|  [2.54]|         3.0|
|[2.5775]|         3.0|
|  [2.62]|         3.0|
|[2.6175]|         3.0|
| [2.594]|         3.0|
| [2.595]|         3.0|
|   [2.6]|         3.0|
|  [2.58]|         3.0|
| [2.582]|         3.0|
| [2.595]|         3.0|
| [2.552]|         3.0|
|  [2.47]|         3.0|
|  [2.45]|         3.0|
| [2.444]|         3.0|
+--------+------------+
only showing top 20 rows

+--------------------+
|             cm_name|
+--------------------+
|Wheat flour (firs...|
|Rice (Indica) - W...|
|Rice (Japonica) -...|
|   Wheat - Wholesale|
|   Maize - Wholesale|
+--------------------+

+------------+
|PrivateIndex|
+------------+
|         0.0|
|         1.0|
|         4.0|
|         3.0|
|         2.0|
+------------+



In [9]:
print(dtc_model.toDebugString)

DecisionTreeClassificationModel (uid=DecisionTreeClassifier_42adac62f17be4bf4032) of depth 5 with 29 nodes
  If (feature 0 <= 3.75)
   If (feature 0 <= 3.08)
    If (feature 0 <= 2.2525)
     If (feature 0 <= 1.923)
      Predict: 1.0
     Else (feature 0 > 1.923)
      Predict: 1.0
    Else (feature 0 > 2.2525)
     If (feature 0 <= 2.325)
      Predict: 3.0
     Else (feature 0 > 2.325)
      If (feature 0 <= 2.57)
       Predict: 3.0
      Else (feature 0 > 2.57)
       Predict: 3.0
   Else (feature 0 > 3.08)
    If (feature 0 <= 3.465)
     Predict: 2.0
    Else (feature 0 > 3.465)
     Predict: 2.0
  Else (feature 0 > 3.75)
   If (feature 0 <= 4.23)
    If (feature 0 <= 4.1525)
     If (feature 0 <= 3.96)
      If (feature 0 <= 3.92)
       Predict: 4.0
      Else (feature 0 > 3.92)
       Predict: 4.0
     Else (feature 0 > 3.96)
      If (feature 0 <= 4.0375)
       Predict: 4.0
      Else (feature 0 > 4.0375)
       Predict: 4.0
    Else (feature 0 > 4.1525)
     Predict: 4.0
 