In [3]:
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import Correlation

spark = SparkSession.builder.appName("SParkCrossVal").getOrCreate()

df = spark.read.format("csv")\
    .load("/run/media/vas1a/str1/Downloads1/santander/train.csv", header=True, inferSchema=True)


In [4]:
#Features Vector generated

from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(inputCols=df.columns[2:], outputCol="features")

output = assembler.transform(df)


In [5]:
inputDF = output.selectExpr('target as label', 'features')


In [6]:
inputDF.show(5)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|    0|[8.9255,-6.7863,1...|
|    0|[11.5006,-4.1473,...|
|    0|[8.6093,-2.7457,1...|
|    0|[11.0604,-2.1518,...|
|    0|[9.8369,-1.4834,1...|
+-----+--------------------+
only showing top 5 rows



In [7]:
from pyspark.ml.classification import DecisionTreeClassifier
decisionTree = DecisionTreeClassifier(labelCol = "label")


In [8]:
from pyspark.ml import Pipeline
pipeline = Pipeline(stages = [decisionTree])


In [9]:
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator


In [10]:
paramGrid = ParamGridBuilder()\
    .addGrid(decisionTree.maxDepth, [1,2,4,5,6,7,8])\
    .build()


In [120]:
evaluator = MulticlassClassificationEvaluator(labelCol = "label", predictionCol = "prediction", metricName = "accuracy")

crossVal = CrossValidator(estimator = pipeline,
                         estimatorParamMaps = paramGrid,
                         evaluator = evaluator,
                         numFolds = 10)

In [121]:
cvModel = crossVal.fit(inputDF)

In [122]:
cvModel.avgMetrics

[0.8995119939847793,
 0.8995119939847793,
 0.8993668993384226,
 0.8990409122081747,
 0.8988212877104155,
 0.8982867721122823,
 0.8972866285232752]

In [123]:
print (cvModel.bestModel.stages[0])

DecisionTreeClassificationModel (uid=DecisionTreeClassifier_4a53b9e076edb4e3e186) of depth 1 with 3 nodes
