In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[*]").getOrCreate()

In [2]:
# !pip install numpy

## Reading and analyzing the data

### The 'Kaggle Telco Customer Churn' data is used: [link](https://www.kaggle.com/blastchar/telco-customer-churn)

In [3]:
from pyspark.sql.functions import col
from pyspark.sql.types import *

churn_schema = StructType([
    StructField("customerID", StringType(), True),
    StructField("gender", StringType(), True),
    StructField("SeniorCitizen", IntegerType(), True),
    StructField("Partner", StringType(), True),
    StructField("Dependents", StringType(), True),
    StructField("tenure", IntegerType(), True),
    StructField("PhoneService", StringType(), True),
    StructField("MultipleLines", StringType(), True),
    StructField("InternetService", StringType(), True),
    StructField("OnlineSecurity", StringType(), True),
    StructField("OnlineBackup", StringType(), True),
    StructField("DeviceProtection", StringType(), True),
    StructField("TechSupport", StringType(), True),
    StructField("StreamingTV", StringType(), True),
    StructField("StreamingMovies", StringType(), True),
    StructField("Contract", StringType(), True),
    StructField("PaperlessBilling", StringType(), True),
    StructField("PaymentMethod", StringType(), True),
    StructField("MonthlyCharges", FloatType(), True),
    StructField("TotalCharges", FloatType(), True),
    StructField("Churn", StringType(), True),

])

In [4]:
churn_df = spark.read.csv("WA_Fn-UseC_-Telco-Customer-Churn.xls", header=True, schema=churn_schema)

In [5]:
churn_df.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: float (nullable = true)
 |-- TotalCharges: float (nullable = true)
 |-- Churn: string (nullable = true)



In [6]:
churn_df.show()

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|   MultipleLines|InternetService|     OnlineSecurity|       OnlineBackup|   DeviceProtection|        TechSupport|        StreamingTV|    StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+
|7590-VHVEG|Female|            0|    Yes|        No|     1|  

In [7]:
# Counting number of missing (nan) values in each column, used this blogpost as a reference:
#  https://www.datasciencemadesimple.com/count-of-missing-nanna-and-null-values-in-pyspark/

churn_df.select([F.count(F.when(F.isnan(col), col)).alias(col) for col in churn_df.columns]).show()

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|Contract|PaperlessBilling|PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|         0|     0|            0|      0|         0|     0|           0|            0|              0|             0|           0|               0|          0|          0|              0|       0|               0| 

In [8]:
# Counting number of null values in each column, used this blogpost as a reference:
#  https://www.datasciencemadesimple.com/count-of-missing-nanna-and-null-values-in-pyspark/

churn_df.select([F.count(F.when(F.isnull(col), col)).alias(col) for col in churn_df.columns]).show()

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|Contract|PaperlessBilling|PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|         0|     0|            0|      0|         0|     0|           0|            0|              0|             0|           0|               0|          0|          0|              0|       0|               0| 

In [9]:
# Checking number of duplicate rows
#  Used this stackoverflow answer: https://stackoverflow.com/a/48554666
churn_df.groupBy(churn_df.columns)\
    .count()\
    .where(F.col('count') > 1)\
    .select(F.sum('count'))\
    .show()

+----------+
|sum(count)|
+----------+
|      null|
+----------+



## Imputing null values with mean value 

In [10]:
churn_df = churn_df.na.fill(
    {
        'MonthlyCharges': churn_df.agg({"MonthlyCharges": "avg"}).collect()[0][0],
        'TotalCharges': churn_df.agg({"TotalCharges": "avg"}).collect()[0][0]
    }
)

## Data preparation

In [11]:
from pyspark.ml.feature import StringIndexer, VectorIndexer, OneHotEncoder

In [12]:
# Change the label column to label
churn_df = churn_df.withColumnRenamed('Churn','label')

In [13]:
# List of columns need to be indexed and featurized
col_list = [
    "gender",
    "SeniorCitizen", 
    "Partner",
    "Dependents",
    "PhoneService", 
    "MultipleLines",
    "InternetService", 
    "OnlineSecurity",
    "OnlineBackup",
    "DeviceProtection",
    "TechSupport", 
    "StreamingTV",
    "StreamingMovies",
    "Contract",
    "PaperlessBilling", 
    "PaymentMethod"
]

featurized_col_list = col_list + ["tenure", "MonthlyCharges", "TotalCharges"]

In [14]:
# List of features and label indexers 
indexers = [
    StringIndexer(inputCol=c, outputCol=f'{c}_indexed')
    for c in col_list
]

label_indexer = StringIndexer(inputCol="label", outputCol="labelIndex")
indexers.append(label_indexer)

In [15]:
# One hot encode the categorical columns
encoder = OneHotEncoder(
    inputCols = [f'{c}_indexed' for c in col_list], 
    outputCols=[f'{c}_vector' for c in col_list],
    dropLast=True
)

In [16]:
# Vectorizing the features
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

vector_assembler = VectorAssembler(inputCols=[f'{c}_vector' for c in col_list], outputCol="features")

### Pipeline preparation

In [17]:
# Define the pipline
from pyspark.ml import Pipeline

pipeline = Pipeline(stages=indexers + [encoder, vector_assembler])

In [18]:
# fit and transform the pipeline
pipeline_model = pipeline.fit(churn_df)
df = pipeline_model.transform(churn_df)

In [19]:
df.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: float (nullable = false)
 |-- TotalCharges: float (nullable = false)
 |-- label: string (nullable = true)
 |-- gender_indexed: double (nullable = false)
 |-- SeniorCitizen_indexed: double 

In [20]:
df.select("StreamingMovies", "StreamingMovies_indexed", "StreamingMovies_vector").show()

+-------------------+-----------------------+----------------------+
|    StreamingMovies|StreamingMovies_indexed|StreamingMovies_vector|
+-------------------+-----------------------+----------------------+
|                 No|                    0.0|         (2,[0],[1.0])|
|                 No|                    0.0|         (2,[0],[1.0])|
|                 No|                    0.0|         (2,[0],[1.0])|
|                 No|                    0.0|         (2,[0],[1.0])|
|                 No|                    0.0|         (2,[0],[1.0])|
|                Yes|                    1.0|         (2,[1],[1.0])|
|                 No|                    0.0|         (2,[0],[1.0])|
|                 No|                    0.0|         (2,[0],[1.0])|
|                Yes|                    1.0|         (2,[1],[1.0])|
|                 No|                    0.0|         (2,[0],[1.0])|
|                 No|                    0.0|         (2,[0],[1.0])|
|No internet service|             

In [21]:
cols_drop = [f'{c}_indexed' for c in col_list] + [f'{c}_vector' for c in col_list] + [f'{c}' for c in featurized_col_list] + ['customerID']
df = df.drop(*cols_drop)

In [22]:
df.printSchema()

root
 |-- label: string (nullable = true)
 |-- labelIndex: double (nullable = false)
 |-- features: vector (nullable = true)



In [23]:
df.show()

+-----+----------+--------------------+
|label|labelIndex|            features|
+-----+----------+--------------------+
|   No|       0.0|(27,[1,3,8,9,12,1...|
|   No|       0.0|(27,[0,1,2,3,4,5,...|
|  Yes|       1.0|(27,[0,1,2,3,4,5,...|
|   No|       0.0|(27,[0,1,2,3,8,10...|
|  Yes|       1.0|(27,[1,2,3,4,5,7,...|
|  Yes|       1.0|(27,[1,2,3,4,6,7,...|
|   No|       0.0|(27,[0,1,2,4,6,7,...|
|   No|       0.0|(27,[1,2,3,8,10,1...|
|  Yes|       1.0|(27,[1,3,4,6,7,9,...|
|   No|       0.0|(27,[0,1,2,4,5,8,...|
|   No|       0.0|(27,[0,1,4,5,8,10...|
|   No|       0.0|(27,[0,1,2,3,4,5,...|
|   No|       0.0|(27,[0,1,3,4,6,7,...|
|  Yes|       1.0|(27,[0,1,2,3,4,6,...|
|   No|       0.0|(27,[0,1,2,3,4,5,...|
|   No|       0.0|(27,[1,4,6,7,10,1...|
|   No|       0.0|(27,[1,2,3,4,5,25...|
|   No|       0.0|(27,[0,1,2,4,6,7,...|
|  Yes|       1.0|(27,[1,4,5,8,9,11...|
|   No|       0.0|(27,[1,2,3,4,5,7,...|
+-----+----------+--------------------+
only showing top 20 rows



### Splitting data to train/test splits

In [24]:
(train_data, test_data) = df.randomSplit(weights=[0.7, 0.3], seed=420)

In [25]:
train_data.describe().show()

+-------+-----+-------------------+
|summary|label|         labelIndex|
+-------+-----+-------------------+
|  count| 4840|               4840|
|   mean| null|0.26776859504132233|
| stddev| null| 0.4428420632218157|
|    min|   No|                0.0|
|    max|  Yes|                1.0|
+-------+-----+-------------------+



In [26]:
test_data.describe().show()

+-------+-----+-------------------+
|summary|label|         labelIndex|
+-------+-----+-------------------+
|  count| 2203|               2203|
|   mean| null|0.26009986382206085|
| stddev| null|0.43878847015331973|
|    min|   No|                0.0|
|    max|  Yes|                1.0|
+-------+-----+-------------------+



### Classifier model training

In [27]:
from pyspark.ml.classification import DecisionTreeClassifier

model = DecisionTreeClassifier(
    labelCol="labelIndex",
    featuresCol="features",
    maxDepth=5,
    maxBins=32
)

In [28]:
model = model.fit(train_data)

In [29]:
predictions = model.transform(test_data)
predictions.show()

+-----+----------+--------------------+--------------+--------------------+----------+
|label|labelIndex|            features| rawPrediction|         probability|prediction|
+-----+----------+--------------------+--------------+--------------------+----------+
|   No|       0.0|(27,[0,1,2,3,4,5]...|[2031.0,146.0]|[0.93293523197060...|       0.0|
|   No|       0.0|(27,[0,1,2,3,4,5]...|[2031.0,146.0]|[0.93293523197060...|       0.0|
|   No|       0.0|(27,[0,1,2,3,4,5]...|[2031.0,146.0]|[0.93293523197060...|       0.0|
|   No|       0.0|(27,[0,1,2,3,4,5,...| [149.0,166.0]|[0.47301587301587...|       1.0|
|   No|       0.0|(27,[0,1,2,3,4,5,...| [263.0,491.0]|[0.34880636604774...|       1.0|
|   No|       0.0|(27,[0,1,2,3,4,5,...| [263.0,491.0]|[0.34880636604774...|       1.0|
|   No|       0.0|(27,[0,1,2,3,4,5,...| [149.0,166.0]|[0.47301587301587...|       1.0|
|   No|       0.0|(27,[0,1,2,3,4,5,...| [149.0,166.0]|[0.47301587301587...|       1.0|
|   No|       0.0|(27,[0,1,2,3,4,5,...| [14

### Link to MulticlassClassificationEvaluator list of metric names:
[link](https://spark.apache.org/docs/latest/api/python/pyspark.ml.html#pyspark.ml.evaluation.MulticlassClassificationEvaluator.metricName)

In [30]:
# Confusion Matrix calculation
#  taken from this stackoverflow post: https://stackoverflow.com/a/58405759/10086080
from pyspark.mllib.evaluation import MulticlassMetrics

#select only prediction and label columns
preds_and_labels = predictions.select(['prediction','labelIndex'])

metrics = MulticlassMetrics(preds_and_labels.rdd.map(tuple))

print(metrics.confusionMatrix().toArray())

[[1412.  218.]
 [ 277.  296.]]


In [31]:
# Accuracy metric calculation
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(labelCol="labelIndex", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)

print("Test Error = %g " % (1.0 - accuracy))
print("Accuracy = %g " % accuracy)

Test Error = 0.224694 
Accuracy = 0.775306 


In [32]:
# Weighted Precision metric calculation
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(labelCol="labelIndex", predictionCol="prediction", metricName="weightedPrecision")
precision = evaluator.evaluate(predictions)

print("Weighted Precision = %g " % precision)

Weighted Precision = 0.76834 


In [33]:
# Weighted Recall metric calculation
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(labelCol="labelIndex", predictionCol="prediction", metricName="weightedRecall")
recall = evaluator.evaluate(predictions)

print("Weighted Recall = %g " % recall)

Weighted Recall = 0.775306 


In [34]:
# F1-score metric calculation
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(labelCol="labelIndex", predictionCol="prediction", metricName="f1")
f1 = evaluator.evaluate(predictions)

print("F1-score = %g " % f1)

F1-score = 0.771206 


In [35]:
# Receiver Operating Characteristic Area Under Curve metric calculation
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator(labelCol="labelIndex", rawPredictionCol="prediction", metricName="areaUnderROC")
roc_auc = evaluator.evaluate(predictions)

print(roc_auc)

0.6914185376717095


In [36]:
model_parameter_string = model.toDebugString

In [37]:
print(model_parameter_string)

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_3ef04cb3d4b8, depth=5, numNodes=17, numClasses=2, numFeatures=27
  If (feature 21 in {0.0})
   Predict: 0.0
  Else (feature 21 not in {0.0})
   If (feature 7 in {0.0})
    If (feature 11 in {0.0})
     Predict: 0.0
    Else (feature 11 not in {0.0})
     If (feature 24 in {0.0})
      Predict: 0.0
     Else (feature 24 not in {0.0})
      If (feature 1 in {0.0})
       Predict: 1.0
      Else (feature 1 not in {0.0})
       Predict: 0.0
   Else (feature 7 not in {0.0})
    If (feature 9 in {0.0})
     Predict: 0.0
    Else (feature 9 not in {0.0})
     If (feature 24 in {0.0})
      If (feature 11 in {0.0})
       Predict: 0.0
      Else (feature 11 not in {0.0})
       Predict: 1.0
     Else (feature 24 not in {0.0})
      Predict: 1.0



In [53]:
model.extractParamMap()

{Param(parent='DecisionTreeClassifier_3ef04cb3d4b8', name='featuresCol', doc='features column name.'): 'features',
 Param(parent='DecisionTreeClassifier_3ef04cb3d4b8', name='labelCol', doc='label column name.'): 'labelIndex',
 Param(parent='DecisionTreeClassifier_3ef04cb3d4b8', name='predictionCol', doc='prediction column name.'): 'prediction',
 Param(parent='DecisionTreeClassifier_3ef04cb3d4b8', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.'): 'probability',
 Param(parent='DecisionTreeClassifier_3ef04cb3d4b8', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.'): 'rawPrediction',
 Param(parent='DecisionTreeClassifier_3ef04cb3d4b8', name='seed', doc='random seed.'): 1898502691226342252,
 Param(parent='DecisionTreeClassifier_3ef04cb3d4b8', name='cacheNodeIds', doc='If false,

### Testing grid search training with Cross Validation

In [57]:
model = DecisionTreeClassifier(
    labelCol="labelIndex",
    featuresCol="features",
    maxDepth=5,
    maxBins=32
)

In [58]:
# Cross Validator takes a pipeline object containing the model not the model object itself
pipeline = Pipeline(stages=[model])

In [59]:
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

param_grid = ParamGridBuilder() \
    .baseOn({model.labelCol: 'labelIndex'}) \
    .baseOn([model.predictionCol, 'prediction']) \
    .addGrid(model.maxDepth, [i for i in range(2, 10)]) \
    .addGrid(model.maxBins, [i for i in range(2, 60)]) \
    .build()

In [60]:
len(param_grid)

464

In [61]:
evaluator = MulticlassClassificationEvaluator(labelCol="labelIndex", predictionCol="prediction", metricName="accuracy")

In [62]:
cross_validator = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=10
)

In [63]:
# Run cross validations
dtcv_model = cross_validator.fit(train_data)
print(dtcv_model)

CrossValidatorModel_cebded9f894c


In [64]:
len(dtcv_model.avgMetrics)

464

In [65]:
dtcv_model.avgMetrics

[0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180010716,
 0.7663922180

In [66]:
# Use test set here so we can measure the accuracy of our model on new data
dtpredictions = dtcv_model.transform(test_data)

In [67]:
# cvModel uses the best model found from the Cross Validation
# Evaluate best model
from pyspark.mllib.evaluation import BinaryClassificationMetrics

print('Accuracy:', evaluator.evaluate(dtpredictions))

Accuracy: 0.7653200181570585


### Extacting best model from cross validator object

In [68]:
best_model = dtcv_model.bestModel

In [69]:
dir(best_model)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__metaclass__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_copyValues',
 '_copy_params',
 '_defaultParamMap',
 '_dummy',
 '_from_java',
 '_paramMap',
 '_params',
 '_randomUID',
 '_resetUid',
 '_resolveParam',
 '_set',
 '_setDefault',
 '_shouldOwn',
 '_to_java',
 '_transform',
 'clear',
 'copy',
 'explainParam',
 'explainParams',
 'extractParamMap',
 'getOrDefault',
 'getParam',
 'hasDefault',
 'hasParam',
 'isDefined',
 'isSet',
 'load',
 'params',
 'read',
 'save',
 'set',
 'stages',
 'transform',
 'uid',
 'write']

In [70]:
best_model.stages

[DecisionTreeClassificationModel: uid=DecisionTreeClassifier_09e51251a702, depth=7, numNodes=105, numClasses=2, numFeatures=27]

In [71]:
best_model = best_model.stages[0]

In [72]:
print(best_model.toDebugString)

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_09e51251a702, depth=7, numNodes=105, numClasses=2, numFeatures=27
  If (feature 21 in {0.0})
   If (feature 7 in {0.0})
    If (feature 9 in {0.0})
     Predict: 0.0
    Else (feature 9 not in {0.0})
     If (feature 22 in {1.0})
      If (feature 15 in {0.0})
       Predict: 0.0
      Else (feature 15 not in {0.0})
       If (feature 24 in {0.0})
        Predict: 0.0
       Else (feature 24 not in {0.0})
        If (feature 4 in {1.0})
         Predict: 0.0
        Else (feature 4 not in {1.0})
         Predict: 1.0
     Else (feature 22 not in {1.0})
      If (feature 23 in {0.0})
       If (feature 4 in {1.0})
        Predict: 0.0
       Else (feature 4 not in {1.0})
        If (feature 2 in {1.0})
         Predict: 0.0
        Else (feature 2 not in {1.0})
         Predict: 1.0
      Else (feature 23 not in {0.0})
       Predict: 0.0
   Else (feature 7 not in {0.0})
    If (feature 22 in {1.0})
     Predict: 0.0
    Else (

In [73]:
best_model.extractParamMap()

{Param(parent='DecisionTreeClassifier_09e51251a702', name='featuresCol', doc='features column name.'): 'features',
 Param(parent='DecisionTreeClassifier_09e51251a702', name='labelCol', doc='label column name.'): 'labelIndex',
 Param(parent='DecisionTreeClassifier_09e51251a702', name='predictionCol', doc='prediction column name.'): 'prediction',
 Param(parent='DecisionTreeClassifier_09e51251a702', name='probabilityCol', doc='Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.'): 'probability',
 Param(parent='DecisionTreeClassifier_09e51251a702', name='rawPredictionCol', doc='raw prediction (a.k.a. confidence) column name.'): 'rawPrediction',
 Param(parent='DecisionTreeClassifier_09e51251a702', name='seed', doc='random seed.'): 1898502691226342252,
 Param(parent='DecisionTreeClassifier_09e51251a702', name='cacheNodeIds', doc='If false,