# The Decision Tree on the Churn Dataset in Spark

In [22]:
from pyspark.sql import DataFrameReader
from pyspark.sql import SparkSession
from pyspark.ml.feature import IndexToString, Normalizer, StringIndexer, VectorAssembler, VectorIndexer
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.classification import DecisionTreeClassifier

# for pretty printing
def printDf(sprkDF):
    """Render a Spark DataFrame as an HTML table for notebook display.

    The frame is converted with ``toPandas()``, which collects every row on
    the driver, so pass a small or pre-limited DataFrame.

    Args:
        sprkDF: Spark DataFrame (any object exposing ``toPandas()``).

    Returns:
        An ``IPython.display.HTML`` object wrapping the pandas HTML table; it
        renders when it is the last expression of a cell.
    """
    # Imported locally so the helper is self-contained; only HTML is needed.
    from IPython.display import HTML

    return HTML(sprkDF.toPandas().to_html())

## Select the churn file 

In [23]:
# Relative path to the churn CSV file that is loaded into Spark below.
inputFile = "../data/churn.csv"

## Create the Spark Session 

In [24]:
# Obtain the SparkSession for this notebook (getOrCreate returns an
# already-running session if there is one).
spark = SparkSession.builder.appName("ChurnDecisionTree").getOrCreate()

# Load the CSV into a DataFrame using an inferred schema: the first line is
# the header and fields are separated by semicolons.
df = (
    spark.read
    .options(header="true", inferSchema="true", delimiter=";")
    .csv(inputFile)
)

## Data Preparation
### Index the label and the categorical columns

In [25]:
# Fit one StringIndexer per string column on the full DataFrame.
# The target column LEAVE is indexed into "label"; each remaining string
# column gets a numeric "<NAME>_NUM" counterpart.
labelIndexer = StringIndexer(inputCol="LEAVE", outputCol="label").fit(df)
collegeIndexer = StringIndexer(inputCol="COLLEGE", outputCol="COLLEGE_NUM").fit(df)
satIndexer = StringIndexer(
    inputCol="REPORTED_SATISFACTION",
    outputCol="REPORTED_SATISFACTION_NUM",
).fit(df)
usageIndexer = StringIndexer(
    inputCol="REPORTED_USAGE_LEVEL",
    outputCol="REPORTED_USAGE_LEVEL_NUM",
).fit(df)
changeIndexer = StringIndexer(
    inputCol="CONSIDERING_CHANGE_OF_PLAN",
    outputCol="CONSIDERING_CHANGE_OF_PLAN_NUM",
).fit(df)

### Build the list of feature columns

In [26]:
# Start from every input column, then drop the target and the raw string
# columns; their indexed "_NUM" versions are appended at the end instead.
featureCols = list(df.columns)
for rawColumn in (
    "LEAVE",
    "COLLEGE",
    "REPORTED_SATISFACTION",
    "REPORTED_USAGE_LEVEL",
    "CONSIDERING_CHANGE_OF_PLAN",
):
    featureCols.remove(rawColumn)
featureCols.extend([
    "COLLEGE_NUM",
    "REPORTED_SATISFACTION_NUM",
    "REPORTED_USAGE_LEVEL_NUM",
    "CONSIDERING_CHANGE_OF_PLAN_NUM",
])

### Build the feature Vector Assembler

In [27]:
# Combine all feature columns into a single vector column named "features".
assembler = VectorAssembler(
    inputCols=list(featureCols),
    outputCol="features",
)

## Do the Data Preparation

In [28]:
# Add the numeric "label" column derived from LEAVE.
labeledData = labelIndexer.transform(df)

# printSchema() writes the schema itself and returns None, so it must not be
# wrapped in print() (that would emit an extra "None" line).
labeledData.printSchema()

# Apply the categorical indexers one after another. The order determines the
# order in which the "_NUM" columns are appended to the DataFrame.
indexedLabedData = labeledData
for indexer in (changeIndexer, usageIndexer, satIndexer, collegeIndexer):
    indexedLabedData = indexer.transform(indexedLabedData)

# Assemble the feature columns into the "features" vector and preview the result.
labeledPointData = assembler.transform(indexedLabedData)
labeledPointData.show()

root
 |-- COLLEGE: string (nullable = true)
 |-- INCOME: integer (nullable = true)
 |-- OVERAGE: integer (nullable = true)
 |-- LEFTOVER: integer (nullable = true)
 |-- HOUSE: integer (nullable = true)
 |-- HANDSET_PRICE: integer (nullable = true)
 |-- OVER_15MINS_CALLS_PER_MONTH: integer (nullable = true)
 |-- AVERAGE_CALL_DURATION: integer (nullable = true)
 |-- REPORTED_SATISFACTION: string (nullable = true)
 |-- REPORTED_USAGE_LEVEL: string (nullable = true)
 |-- CONSIDERING_CHANGE_OF_PLAN: string (nullable = true)
 |-- LEAVE: string (nullable = true)
 |-- label: double (nullable = false)

None
+-------+------+-------+--------+------+-------------+---------------------------+---------------------+---------------------+--------------------+--------------------------+-----+-----+------------------------------+------------------------+-------------------------+-----------+--------------------+
|COLLEGE|INCOME|OVERAGE|LEFTOVER| HOUSE|HANDSET_PRICE|OVER_15MINS_CALLS_PER_MONTH|AVERAGE_CA

### As formatted output

In [29]:
# Pass at most 10 rows to printDf so only a small sample is converted to pandas.
printDf(labeledPointData.limit(10))

Unnamed: 0,COLLEGE,INCOME,OVERAGE,LEFTOVER,HOUSE,HANDSET_PRICE,OVER_15MINS_CALLS_PER_MONTH,AVERAGE_CALL_DURATION,REPORTED_SATISFACTION,REPORTED_USAGE_LEVEL,CONSIDERING_CHANGE_OF_PLAN,LEAVE,label,CONSIDERING_CHANGE_OF_PLAN_NUM,REPORTED_USAGE_LEVEL_NUM,REPORTED_SATISFACTION_NUM,COLLEGE_NUM,features
0,zero,31953,0,6,313378,161,0,4,unsat,little,no,STAY,0.0,2.0,0.0,2.0,1.0,"[31953.0, 0.0, 6.0, 313378.0, 161.0, 0.0, 4.0, 1.0, 2.0, 0.0, 2.0]"
1,one,36147,0,13,800586,244,0,6,unsat,little,considering,STAY,0.0,0.0,0.0,2.0,0.0,"(36147.0, 0.0, 13.0, 800586.0, 244.0, 0.0, 6.0, 0.0, 2.0, 0.0, 0.0)"
2,one,27273,230,0,305049,201,16,15,unsat,very_little,perhaps,STAY,0.0,4.0,2.0,2.0,0.0,"[27273.0, 230.0, 0.0, 305049.0, 201.0, 16.0, 15.0, 0.0, 2.0, 2.0, 4.0]"
3,zero,120070,38,33,788235,780,3,2,unsat,very_high,considering,LEAVE,1.0,0.0,1.0,2.0,1.0,"[120070.0, 38.0, 33.0, 788235.0, 780.0, 3.0, 2.0, 1.0, 2.0, 1.0, 0.0]"
4,one,29215,208,85,224784,241,21,1,very_unsat,little,never_thought,STAY,0.0,3.0,0.0,0.0,0.0,"[29215.0, 208.0, 85.0, 224784.0, 241.0, 21.0, 1.0, 0.0, 0.0, 0.0, 3.0]"
5,zero,133728,64,48,632969,626,3,2,unsat,high,no,STAY,0.0,2.0,3.0,2.0,1.0,"[133728.0, 64.0, 48.0, 632969.0, 626.0, 3.0, 2.0, 1.0, 2.0, 3.0, 2.0]"
6,zero,42052,224,0,697949,191,10,5,very_unsat,little,actively_looking_into_it,STAY,0.0,1.0,0.0,0.0,1.0,"[42052.0, 224.0, 0.0, 697949.0, 191.0, 10.0, 5.0, 1.0, 0.0, 0.0, 1.0]"
7,one,84744,0,20,688098,357,0,5,very_unsat,little,considering,STAY,0.0,0.0,0.0,0.0,0.0,"(84744.0, 0.0, 20.0, 688098.0, 357.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0)"
8,zero,38171,0,7,274218,190,0,5,very_sat,little,actively_looking_into_it,STAY,0.0,1.0,0.0,1.0,1.0,"[38171.0, 0.0, 7.0, 274218.0, 190.0, 0.0, 5.0, 1.0, 1.0, 0.0, 1.0]"
9,zero,105824,174,18,153560,687,25,4,very_sat,little,never_thought,LEAVE,1.0,3.0,0.0,1.0,1.0,"[105824.0, 174.0, 18.0, 153560.0, 687.0, 25.0, 4.0, 1.0, 1.0, 0.0, 3.0]"


In [30]:
# Stop the Spark session used by this notebook.
spark.stop()