In [None]:
# Install the pyspark package that the imports in this notebook rely on.
!pip install pyspark

In [None]:
# Install pyarrow. NOTE(review): nothing visible in this notebook imports it directly;
# presumably it is wanted as an optional PySpark dependency - confirm it is needed.
!pip install pyarrow

In [8]:
# Import libraries

from pyspark.sql import SparkSession

from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.tuning import CrossValidator
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import pyspark.sql.functions as fn
from pyspark.ml.feature import OneHotEncoder, VectorAssembler, StringIndexer
from pyspark.sql.functions import udf, StringType

from google.colab import drive
# Mount Google Drive at /content/drive; the dataset path used below lives under this mount point.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
# Create the Spark session for this notebook (or reuse one that already exists),
# using the 'local' master.
spark = (
    SparkSession.builder
    .master('local')
    .getOrCreate()
)

In [10]:
# Location of the churn dataset (CSV) on the mounted Google Drive.
path = '/content/drive/MyDrive/Marketing-Analytics/Data/Churn_Modelling.csv'

In [11]:
# Load the raw CSV: the first line is the header and column types are inferred from the data.
df1 = spark.read.csv(
    path,
    header=True,
    inferSchema=True,
)

In [12]:
# Preview the first 8 rows of the raw data.
df1.show(n=8)

+----------+--------+---------+------+---+------+----------+-----------+---------+------+---------------+------+
|CustomerId|CredRate|Geography|Gender|Age|Tenure|   Balance|Prod Number|HasCrCard|ActMem|EstimatedSalary|Exited|
+----------+--------+---------+------+---+------+----------+-----------+---------+------+---------------+------+
|  15634602|     619|   France|Female| 42|     2|       0.0|          1|        1|     1|      91213.992|     1|
|  15647311|     608|    Spain|Female| 41|     1| 75427.074|          1|        0|     1|     101288.322|     0|
|  15619304|     502|   France|Female| 42|     8| 143694.72|          3|        1|     0|     102538.413|     1|
|  15701354|     699|   France|Female| 39|     1|       0.0|          2|        0|     0|      84443.967|     0|
|  15737888|     850|    Spain|Female| 43|     2|112959.738|          1|        1|     1|       71175.69|     0|
|  15574012|     645|    Spain|  Male| 44|     8|102380.202|          2|        1|     0|     13

In [13]:
# Fraction (0.0 - 1.0) of missing observations in each column:
# count(column) skips nulls while count('*') counts every row.
total_rows = fn.count('*')
missing_exprs = [
    (1 - (fn.count(col_name) / total_rows)).alias(col_name + '_missing')
    for col_name in df1.columns
]
df1.agg(*missing_exprs).show()

+------------------+----------------+-----------------+--------------------+--------------------+--------------+---------------+-------------------+-----------------+--------------+-----------------------+--------------+
|CustomerId_missing|CredRate_missing|Geography_missing|      Gender_missing|         Age_missing|Tenure_missing|Balance_missing|Prod Number_missing|HasCrCard_missing|ActMem_missing|EstimatedSalary_missing|Exited_missing|
+------------------+----------------+-----------------+--------------------+--------------------+--------------+---------------+-------------------+-----------------+--------------+-----------------------+--------------+
|               0.0|             0.0|              0.0|3.999999999999559...|6.000000000000449E-4|           0.0|            0.0|                0.0|              0.0|           0.0|                    0.0|           0.0|
+------------------+----------------+-----------------+--------------------+--------------------+--------------+----

In [14]:
# Summary statistics for the raw columns; the 'count' row also exposes columns with missing values.
df1.describe().show()

+-------+-----------------+-----------------+---------+------+------------------+------------------+-----------------+------------------+-------------------+-------------------+------------------+-------------------+
|summary|       CustomerId|         CredRate|Geography|Gender|               Age|            Tenure|          Balance|       Prod Number|          HasCrCard|             ActMem|   EstimatedSalary|             Exited|
+-------+-----------------+-----------------+---------+------+------------------+------------------+-----------------+------------------+-------------------+-------------------+------------------+-------------------+
|  count|            10000|            10000|    10000|  9996|              9994|             10000|            10000|             10000|              10000|              10000|             10000|              10000|
|   mean|  1.56909405694E7|         650.5288|     null|  null|38.925255153091854|            5.0128| 68837.3003592003|            1.

In [16]:
# Show the column types that were inferred when the CSV was read.
df1.printSchema()

root
 |-- CustomerId: integer (nullable = true)
 |-- CredRate: integer (nullable = true)
 |-- Geography: string (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tenure: integer (nullable = true)
 |-- Balance: double (nullable = true)
 |-- Prod Number: integer (nullable = true)
 |-- HasCrCard: integer (nullable = true)
 |-- ActMem: integer (nullable = true)
 |-- EstimatedSalary: double (nullable = true)
 |-- Exited: integer (nullable = true)



In [17]:
# Compare the total row count with the distinct row count to spot fully duplicated rows.
total_count = df1.count()
distinct_count = df1.distinct().count()
print('Count of rows: {0}'.format(total_count))
print('Count of distinct rows: {0}'.format(distinct_count))

Count of rows: 10000
Count of distinct rows: 10000


In [19]:
# Frequency of each Geography value, most common first.
(
    df1.groupBy("Geography")
    .count()
    .orderBy("count", ascending=False)
    .show()
)

+---------+-----+
|Geography|count|
+---------+-----+
|   France| 5014|
|  Germany| 2509|
|    Spain| 2477|
+---------+-----+



In [20]:
# Frequency of each Gender value (nulls form their own group), most common first.
(
    df1.groupBy("Gender")
    .count()
    .orderBy("count", ascending=False)
    .show()
)

+------+-----+
|Gender|count|
+------+-----+
|  Male| 5453|
|Female| 4543|
|  null|    4|
+------+-----+



In [27]:
# Most frequent Age value: sort the per-age counts in descending order and
# read the Age field of the top row.
age_counts = df1.groupBy("Age").count().orderBy("count", ascending=False)
age_counts.first()[0]

37

In [21]:
# Frequency of each HasCrCard value, most common first.
(
    df1.groupBy("HasCrCard")
    .count()
    .orderBy("count", ascending=False)
    .show()
)

+---------+-----+
|HasCrCard|count|
+---------+-----+
|        1| 7055|
|        0| 2945|
+---------+-----+



In [22]:
# Frequency of each "Prod Number" value, most common first.
(
    df1.groupBy("Prod Number")
    .count()
    .orderBy("count", ascending=False)
    .show()
)

+-----------+-----+
|Prod Number|count|
+-----------+-----+
|          1| 5084|
|          2| 4590|
|          3|  266|
|          4|   60|
+-----------+-----+



In [23]:
# Frequency of each ActMem value, most common first.
(
    df1.groupBy("ActMem")
    .count()
    .orderBy("count", ascending=False)
    .show()
)

+------+-----+
|ActMem|count|
+------+-----+
|     1| 5151|
|     0| 4849|
+------+-----+



In [24]:
# Frequency of each Exited value (the label that is later renamed to Churn), most common first.
(
    df1.groupBy("Exited")
    .count()
    .orderBy("count", ascending=False)
    .show()
)

+------+-----+
|Exited|count|
+------+-----+
|     0| 7963|
|     1| 2037|
+------+-----+



In [29]:
# Impute missing values with each column's mode: 'Male' and 37 are the most
# frequent Gender and Age values found by the frequency checks earlier in the notebook.
df2 = df1.fillna(value={'Age': 37, 'Gender': 'Male'})

In [31]:
# Re-check the fraction (0.0 - 1.0) of missing observations per column after imputation:
# count(column) skips nulls while count('*') counts every row.
total_rows = fn.count('*')
missing_exprs = [
    (1 - (fn.count(col_name) / total_rows)).alias(col_name + '_missing')
    for col_name in df2.columns
]
df2.agg(*missing_exprs).show()

+------------------+----------------+-----------------+--------------+-----------+--------------+---------------+-------------------+-----------------+--------------+-----------------------+--------------+
|CustomerId_missing|CredRate_missing|Geography_missing|Gender_missing|Age_missing|Tenure_missing|Balance_missing|Prod Number_missing|HasCrCard_missing|ActMem_missing|EstimatedSalary_missing|Exited_missing|
+------------------+----------------+-----------------+--------------+-----------+--------------+---------------+-------------------+-----------------+--------------+-----------------------+--------------+
|               0.0|             0.0|              0.0|           0.0|        0.0|           0.0|            0.0|                0.0|              0.0|           0.0|                    0.0|           0.0|
+------------------+----------------+-----------------+--------------+-----------+--------------+---------------+-------------------+-----------------+--------------+----------

In [32]:
# Give the abbreviated columns descriptive names (arguments are: original name, new name).
df3 = (
    df2
    .withColumnRenamed('CredRate', 'CreditScore')
    .withColumnRenamed('ActMem', 'IsActiveMember')
    .withColumnRenamed('Prod Number', 'NumOfProducts')
    .withColumnRenamed('Exited', 'Churn')
)

In [33]:
# Verify the renamed columns and their types.
df3.printSchema()

root
 |-- CustomerId: integer (nullable = true)
 |-- CreditScore: integer (nullable = true)
 |-- Geography: string (nullable = true)
 |-- Gender: string (nullable = false)
 |-- Age: integer (nullable = false)
 |-- Tenure: integer (nullable = true)
 |-- Balance: double (nullable = true)
 |-- NumOfProducts: integer (nullable = true)
 |-- HasCrCard: integer (nullable = true)
 |-- IsActiveMember: integer (nullable = true)
 |-- EstimatedSalary: double (nullable = true)
 |-- Churn: integer (nullable = true)



In [35]:
# Summary statistics after imputation and renaming.
df3.describe().show()

+-------+-----------------+-----------------+---------+------+------------------+------------------+-----------------+------------------+-------------------+-------------------+------------------+-------------------+
|summary|       CustomerId|      CreditScore|Geography|Gender|               Age|            Tenure|          Balance|     NumOfProducts|          HasCrCard|     IsActiveMember|   EstimatedSalary|              Churn|
+-------+-----------------+-----------------+---------+------+------------------+------------------+-----------------+------------------+-------------------+-------------------+------------------+-------------------+
|  count|            10000|            10000|    10000| 10000|             10000|             10000|            10000|             10000|              10000|              10000|             10000|              10000|
|   mean|  1.56909405694E7|         650.5288|     null|  null|           38.9241|            5.0128| 68837.3003592003|            1.

In [36]:
# Binary-encode Gender: 'Male' -> 1, every other value -> 0.
# NOTE: this overwrites the string column in df3, so re-running this cell on the
# already-encoded column would not reproduce the original encoding.
is_male = fn.col("Gender") == 'Male'
df3 = df3.withColumn("Gender", fn.when(is_male, 1).otherwise(0))

In [37]:
def udf_Geo(Geo):
    """Encode a geography label as an integer code.

    Returns 0 for 'France', 1 for 'Germany', and 2 for any other value
    (including a missing value).
    """
    known_codes = (('France', 0), ('Germany', 1))
    for label, code in known_codes:
        if Geo == label:
            return code
    # Anything that is not one of the known labels falls into the catch-all code.
    return 2

In [38]:
# Wrap the Python encoder as a Spark UDF. The return type is declared explicitly:
# when it is omitted, udf() types the result column as string even though
# udf_Geo returns Python ints, which forces a separate cast later on.
Geo_udf = udf(udf_Geo, returnType='int')
# Replace the Geography labels with their integer codes.
df3 = df3.withColumn("Geography", Geo_udf('Geography'))

In [39]:
# NOTE(review): `typing.cast` imported on the next line is not used by the statements
# in this cell; the `.cast('integer')` calls below are the Spark Column method.
from typing import cast
# Make sure both encoded categorical columns are integer-typed before modelling.
df3 = df3.withColumn('Geography', df3['Geography'].cast('integer'))
df3 = df3.withColumn('Gender', df3['Gender'].cast('integer'))

In [40]:
# Confirm that Geography and Gender are now integer columns.
df3.printSchema()

root
 |-- CustomerId: integer (nullable = true)
 |-- CreditScore: integer (nullable = true)
 |-- Geography: integer (nullable = true)
 |-- Gender: integer (nullable = false)
 |-- Age: integer (nullable = false)
 |-- Tenure: integer (nullable = true)
 |-- Balance: double (nullable = true)
 |-- NumOfProducts: integer (nullable = true)
 |-- HasCrCard: integer (nullable = true)
 |-- IsActiveMember: integer (nullable = true)
 |-- EstimatedSalary: double (nullable = true)
 |-- Churn: integer (nullable = true)



In [41]:
# Split the data into training (70%) and testing (30%) sets.
# A fixed seed keeps the split identical across runs, so the tuned model and the
# reported test metrics can be reproduced after a kernel restart.
train_data, test_data = df3.randomSplit([0.7, 0.3], seed=42)

In [42]:
# Predictor columns fed to the model (CustomerId and the Churn label are left out).
Pred_corr = [
    'CreditScore',
    'Geography',
    'Gender',
    'Age',
    'Tenure',
    'Balance',
    'NumOfProducts',
    'HasCrCard',
    'IsActiveMember',
    'EstimatedSalary',
]

In [43]:
# Name of the single vector column that will hold all predictors.
vector_col = "Predictors"
# VectorAssembler packs the predictor columns into that vector column.
assembler = VectorAssembler(outputCol=vector_col, inputCols=Pred_corr)

In [44]:
# Random forest classifier that predicts the Churn label from the assembled vector.
# featuresCol reuses vector_col so it cannot drift from the assembler's output column,
# and a fixed seed makes the forest's randomised training repeatable across runs.
classifier = RandomForestClassifier(
    featuresCol=vector_col,
    labelCol="Churn",
    predictionCol="Churn Prediction",
    seed=42,
)

In [45]:
# Chain feature assembly and the classifier so both run as a single estimator.
pipeline_stages = [assembler, classifier]
pipeline = Pipeline(stages=pipeline_stages)

In [46]:
# Hyper-parameter grid: every combination of tree depth and forest size listed below.
paramGrid = (
    ParamGridBuilder()
    .addGrid(classifier.maxDepth, [3, 5, 10, 15])
    .addGrid(classifier.numTrees, [3, 5, 10, 15])
    .build()
)

In [47]:
# Evaluator used for model selection during cross-validation.
# No metric name is set here, so the evaluator's default metric is what gets optimised;
# accuracy is only configured for the final test-set evaluation.
evaluator = (
    MulticlassClassificationEvaluator()
    .setLabelCol("Churn")
    .setPredictionCol("Churn Prediction")
)

In [48]:
# Initializing the cross validator: it fits the pipeline for every parameter map in
# paramGrid over 10 folds and scores each fit with the evaluator.
# A fixed seed makes the fold assignment, and therefore the selected model, reproducible.
crossValidator = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=10,
    seed=42,
)

In [49]:
# Training cross validator model on the training split.
# This is the expensive step: each of the 4 x 4 parameter combinations in paramGrid
# is evaluated over 10 folds.
tuned_model = crossValidator.fit(train_data)

In [50]:
# Getting predictions for the held-out test split from the tuned model.
predictions = tuned_model.transform(test_data)

In [51]:
# Evaluating accuracy of the model on the test-set predictions.
# Note that this rebinds `evaluator` to an accuracy-based evaluator.
evaluator = MulticlassClassificationEvaluator(
    labelCol="Churn",
    predictionCol="Churn Prediction",
    metricName="accuracy",
)

accuracy = evaluator.evaluate(predictions)

# Both figures are reported as percentages.
print("Test Error : {}".format(100*(1.0 - accuracy)))
print("Accuracy of the model : {}".format(100*accuracy))

Test Error : 15.110812625923442
Accuracy of the model : 84.88918737407656
