In [11]:
from pyspark.sql import SparkSession
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
from pyspark.ml import Pipeline

In [12]:
spark = SparkSession.builder.appName("SParkCrossVal").getOrCreate()

df = spark.read.format("csv")\
    .load("../train.csv", header=True, inferSchema=True)

In [13]:
#Features Vector generated

assembler = VectorAssembler(inputCols=df.columns[2:], outputCol="features")

output = assembler.transform(df)


In [14]:
trainDF = output.selectExpr('target as label', 'features')
trainDF.show(5)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|    0|[8.9255,-6.7863,1...|
|    0|[11.5006,-4.1473,...|
|    0|[8.6093,-2.7457,1...|
|    0|[11.0604,-2.1518,...|
|    0|[9.8369,-1.4834,1...|
+-----+--------------------+
only showing top 5 rows



In [15]:
#from pyspark.ml.classification import DecisionTreeClassifier
#decisionTree = DecisionTreeClassifier(labelCol = "label")

# Train a GBT model.
gbt = GBTClassifier(labelCol="label", featuresCol="features", maxIter=10)

pipeline = Pipeline(stages = [gbt])


In [16]:
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator


In [17]:
paramGrid = ParamGridBuilder()\
    .addGrid(decisionTree.maxDepth, [1,2,4,5,6,7,8])\
    .build()


In [18]:
evaluator = MulticlassClassificationEvaluator(labelCol = "label", predictionCol = "prediction",
                                              metricName = "accuracy")

crossVal = CrossValidator(estimator = pipeline,
                         estimatorParamMaps = paramGrid,
                         evaluator = evaluator,
                         numFolds = 10)

In [21]:
cvModel = crossVal.fit(trainDF)

In [22]:
cvModel.avgMetrics

[0.8995794816579963,
 0.8995794816579963,
 0.8995794816579963,
 0.8995794816579963,
 0.8995794816579963,
 0.8995794816579963,
 0.8995794816579963]

In [25]:
print (cvModel.bestModel.stages[0])

GBTClassificationModel (uid=GBTClassifier_4e2abde9099877b217e6) with 10 trees


In [26]:
print(gbt)

GBTClassifier_4e2abde9099877b217e6
