In [0]:
from pyspark.sql import SparkSession

Load the data

In [0]:
file_location = "/FileStore/TelcoChurn.csv"
file_type = "csv"
 
# CSV options
infer_schema = "true"
first_row_is_header = "true"
delimiter = ","
 
df = spark.read.format(file_type) \
  .option("inferSchema", infer_schema) \
  .option("header", first_row_is_header) \
  .option("sep", delimiter) \
  .load(file_location)\

display(df)

customerID,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,OnlineSecurity,OnlineBackup,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
7590-VHVEG,Female,0,Yes,No,1,No,No phone service,DSL,No,Yes,No,No,No,No,Month-to-month,Yes,Electronic check,29.85,29.85,No
5575-GNVDE,Male,0,No,No,34,Yes,No,DSL,Yes,No,Yes,No,No,No,One year,No,Mailed check,56.95,1889.5,No
3668-QPYBK,Male,0,No,No,2,Yes,No,DSL,Yes,Yes,No,No,No,No,Month-to-month,Yes,Mailed check,53.85,108.15,Yes
7795-CFOCW,Male,0,No,No,45,No,No phone service,DSL,Yes,No,Yes,Yes,No,No,One year,No,Bank transfer (automatic),42.3,1840.75,No
9237-HQITU,Female,0,No,No,2,Yes,No,Fiber optic,No,No,No,No,No,No,Month-to-month,Yes,Electronic check,70.7,151.65,Yes
9305-CDSKC,Female,0,No,No,8,Yes,Yes,Fiber optic,No,No,Yes,No,Yes,Yes,Month-to-month,Yes,Electronic check,99.65,820.5,Yes
1452-KIOVK,Male,0,No,Yes,22,Yes,Yes,Fiber optic,No,Yes,No,No,Yes,No,Month-to-month,Yes,Credit card (automatic),89.1,1949.4,No
6713-OKOMC,Female,0,No,No,10,No,No phone service,DSL,Yes,No,No,No,No,No,Month-to-month,No,Mailed check,29.75,301.9,No
7892-POOKP,Female,0,Yes,No,28,Yes,Yes,Fiber optic,No,No,Yes,Yes,Yes,Yes,Month-to-month,Yes,Electronic check,104.8,3046.05,Yes
6388-TABGU,Male,0,No,Yes,62,Yes,No,DSL,Yes,Yes,No,No,No,No,One year,No,Bank transfer (automatic),56.15,3487.95,No


Data types

In [0]:
df.createOrReplaceTempView("telco")
spark.sql("describe telco").show()

+----------------+---------+-------+
|        col_name|data_type|comment|
+----------------+---------+-------+
|      customerID|   string|   null|
|          gender|   string|   null|
|   SeniorCitizen|      int|   null|
|         Partner|   string|   null|
|      Dependents|   string|   null|
|          tenure|      int|   null|
|    PhoneService|   string|   null|
|   MultipleLines|   string|   null|
| InternetService|   string|   null|
|  OnlineSecurity|   string|   null|
|    OnlineBackup|   string|   null|
|DeviceProtection|   string|   null|
|     TechSupport|   string|   null|
|     StreamingTV|   string|   null|
| StreamingMovies|   string|   null|
|        Contract|   string|   null|
|PaperlessBilling|   string|   null|
|   PaymentMethod|   string|   null|
|  MonthlyCharges|   double|   null|
|    TotalCharges|   string|   null|
+----------------+---------+-------+
only showing top 20 rows



Data Cleaning

In [0]:
# Importing necessary libraries
from pyspark.sql.functions import isnan, when, count, col

# Using describe() method to get non-null count for each column
df.describe().select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in df.columns]).show()

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|Contract|PaperlessBilling|PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|         2|     2|            0|      2|         2|     0|           2|            2|              2|             2|           2|               2|          2|          2|              2|       2|               2| 

In [0]:
#Dropping the NA's
df = df.na.drop()

Feature Engineering

In [0]:
from pyspark.sql.functions import col, when
# Convert gender column to a factor and set "Female" as the reference level
df = df.withColumn("gender", when(col("gender") == "Female", 0).otherwise(1))

#Convert Total Charges to double
df = df.withColumn("TotalCharges", df["TotalCharges"].cast("double"))

# Convert MultipleLines, OnlineSecurity, OnlineBackup, DeviceProtection, TechSupport, StreamingTV, and StreamingMovies columns to binary variables
cols = ["Partner","Dependents", "PhoneService","MultipleLines", "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies","PaperlessBilling", "Churn"]
for c in cols:
    df = df.withColumn(c, when(col(c) == "Yes", 1).otherwise(0))

Descriptive Analysis

In [0]:
from pyspark.sql.functions import countDistinct, mean, stddev, min, max

# Count the number of distinct values for each categorical variable
df.select([countDistinct(c).alias(c) for c in ['gender', 'SeniorCitizen', 'Partner', 'Dependents',
                                                'PhoneService', 'MultipleLines', 'InternetService',
                                                'OnlineSecurity', 'OnlineBackup', 'DeviceProtection',
                                                'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract',
                                                'PaperlessBilling', 'PaymentMethod', 'Churn']]).show()

+------+-------------+-------+----------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+-----+
|gender|SeniorCitizen|Partner|Dependents|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|Contract|PaperlessBilling|PaymentMethod|Churn|
+------+-------------+-------+----------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+-----+
|     2|            2|      2|         2|           2|            2|              3|             2|           2|               2|          2|          2|              2|       3|               2|            4|    2|
+------+-------------+-------+----------+------------+-------------+---------------+--------------+------------+----------------+-------

In [0]:
# Calculate the mean, standard deviation, minimum, and maximum for the numerical variables
df.select(mean('tenure'), stddev('tenure'), min('tenure'), max('tenure')).show()
df.select(mean('MonthlyCharges'), stddev('MonthlyCharges'), min('MonthlyCharges'), max('MonthlyCharges')).show()
df.select(mean('TotalCharges'), stddev('TotalCharges'), min('TotalCharges'), max('TotalCharges')).show()

+-----------------+-------------------+-----------+-----------+
|      avg(tenure)|stddev_samp(tenure)|min(tenure)|max(tenure)|
+-----------------+-------------------+-----------+-----------+
|32.37114865824223| 24.559481023094442|          0|         72|
+-----------------+-------------------+-----------+-----------+

+-------------------+---------------------------+-------------------+-------------------+
|avg(MonthlyCharges)|stddev_samp(MonthlyCharges)|min(MonthlyCharges)|max(MonthlyCharges)|
+-------------------+---------------------------+-------------------+-------------------+
|  64.76169246059922|         30.090047097678482|              18.25|             118.75|
+-------------------+---------------------------+-------------------+-------------------+

+------------------+-------------------------+-----------------+-----------------+
| avg(TotalCharges)|stddev_samp(TotalCharges)|min(TotalCharges)|max(TotalCharges)|
+------------------+-------------------------+----------------

In [0]:
# Show the top 10 most common values for each categorical variable
cat_cols = ['gender', 'SeniorCitizen', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines',
            'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport',
            'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod', "Churn"]
for c in cat_cols:
    df.groupBy(c).count().orderBy('count', ascending=False).show(10)

+------+-----+
|gender|count|
+------+-----+
|     1| 3555|
|     0| 3488|
+------+-----+

+-------------+-----+
|SeniorCitizen|count|
+-------------+-----+
|            0| 5901|
|            1| 1142|
+-------------+-----+

+-------+-----+
|Partner|count|
+-------+-----+
|      0| 3641|
|      1| 3402|
+-------+-----+

+----------+-----+
|Dependents|count|
+----------+-----+
|         0| 4933|
|         1| 2110|
+----------+-----+

+------------+-----+
|PhoneService|count|
+------------+-----+
|           1| 6361|
|           0|  682|
+------------+-----+

+-------------+-----+
|MultipleLines|count|
+-------------+-----+
|            0| 4072|
|            1| 2971|
+-------------+-----+

+---------------+-----+
|InternetService|count|
+---------------+-----+
|    Fiber optic| 3096|
|            DSL| 2421|
|             No| 1526|
+---------------+-----+

+--------------+-----+
|OnlineSecurity|count|
+--------------+-----+
|             0| 5024|
|             1| 2019|
+--------------+----

In [0]:
from pyspark.sql.functions import col
# Calculate the churn rate (proportion of customers who churned)
churn_rate = df.filter(col('Churn') == 1).count() / df.count()
print("Churn rate:", churn_rate)

Churn rate: 0.2653698707936959


In [0]:
# Show the distribution of the dependent variable
df.groupBy('Churn').count().show()
# Calculate the churn rate by different categories
df.groupBy('SeniorCitizen', 'Churn').count().show()
df.groupBy('Partner', 'Churn').count().show()
df.groupBy('Dependents', 'Churn').count().show()
df.groupBy('MultipleLines', 'Churn').count().show()
df.groupBy('InternetService', 'Churn').count().show()
df.groupBy('OnlineSecurity', 'Churn').count().show()
df.groupBy('OnlineBackup', 'Churn').count().show()
df.groupBy('DeviceProtection', 'Churn').count().show()
df.groupBy('TechSupport', 'Churn').count().show()
df.groupBy('StreamingTV', 'Churn').count().show()
df.groupBy('StreamingMovies', 'Churn').count().show()
df.groupBy('Contract', 'Churn').count().show()
df.groupBy('PaperlessBilling', 'Churn').count().show()
df.groupBy('PaymentMethod', 'Churn').count().show()

+-----+-----+
|Churn|count|
+-----+-----+
|    1| 1869|
|    0| 5174|
+-----+-----+

+-------------+-----+-----+
|SeniorCitizen|Churn|count|
+-------------+-----+-----+
|            1|    0|  666|
|            1|    1|  476|
|            0|    0| 4508|
|            0|    1| 1393|
+-------------+-----+-----+

+-------+-----+-----+
|Partner|Churn|count|
+-------+-----+-----+
|      1|    0| 2733|
|      1|    1|  669|
|      0|    0| 2441|
|      0|    1| 1200|
+-------+-----+-----+

+----------+-----+-----+
|Dependents|Churn|count|
+----------+-----+-----+
|         1|    0| 1784|
|         1|    1|  326|
|         0|    0| 3390|
|         0|    1| 1543|
+----------+-----+-----+

+-------------+-----+-----+
|MultipleLines|Churn|count|
+-------------+-----+-----+
|            1|    0| 2121|
|            1|    1|  850|
|            0|    0| 3053|
|            0|    1| 1019|
+-------------+-----+-----+

+---------------+-----+-----+
|InternetService|Churn|count|
+---------------+-----+----

In [0]:
# Calculate the mean, standard deviation, minimum, and maximum for numerical variables, by churn status
df.groupBy('Churn').agg(mean('tenure'), stddev('tenure'), min('tenure'), max('tenure')).show()
df.groupBy('Churn').agg(mean('MonthlyCharges'), stddev('MonthlyCharges'), min('MonthlyCharges'), max('MonthlyCharges')).show()
df.groupBy('Churn').agg(mean('TotalCharges'), stddev('TotalCharges'), min('TotalCharges'), max('TotalCharges')).show()

+-----+------------------+-------------------+-----------+-----------+
|Churn|       avg(tenure)|stddev_samp(tenure)|min(tenure)|max(tenure)|
+-----+------------------+-------------------+-----------+-----------+
|    1|17.979133226324237|  19.53112305451955|          1|         72|
|    0| 37.56996521066873|  24.11377669070408|          0|         72|
+-----+------------------+-------------------+-----------+-----------+

+-----+-------------------+---------------------------+-------------------+-------------------+
|Churn|avg(MonthlyCharges)|stddev_samp(MonthlyCharges)|min(MonthlyCharges)|max(MonthlyCharges)|
+-----+-------------------+---------------------------+-------------------+-------------------+
|    1|   74.4413322632423|         24.666053259397422|              18.85|             118.35|
|    0|   61.2651236953999|         31.092648119345316|              18.25|             118.75|
+-----+-------------------+---------------------------+-------------------+------------------

Data Cleaning

Removing the NA's

In [0]:
#Dropping the NA's
df = df.na.drop()

Building the Classification Models

We will be building 3 classification models to predict the churn. We will compare and analyze and select a best model out of it and then we will optimize the neccessary metric on that model.

In [0]:
# Import required libraries
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier, RandomForestClassifier, GBTClassifier, LinearSVC
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.feature import StringIndexer, VectorAssembler, OneHotEncoder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics
import numpy as np
from sklearn.metrics import confusion_matrix

In [0]:
# Use StringIndexer to convert the categorical columns to hold numerical data 
InternetService_indexer = StringIndexer(inputCol='InternetService',outputCol='InternetService_index',handleInvalid='keep')
Contract_indexer = StringIndexer(inputCol='Contract',outputCol='Contract_index',handleInvalid='keep')
PaymentMethod_indexer = StringIndexer(inputCol='PaymentMethod',outputCol='PaymentMethod_index',handleInvalid='keep')

In [0]:
# OneHotEncoderEstimator converts the indexed data into a vector
data_encoder = OneHotEncoder(inputCols=['InternetService_index','Contract_index',
                                            'PaymentMethod_index'], outputCols= ['InternetService_vec','Contract_vec','PaymentMethod_vec'],
                                      handleInvalid='keep')

In [0]:
# Vector assembler is used to create a vector of input features
assembler = VectorAssembler(inputCols=['gender','SeniorCitizen', 'Partner', 'Dependents', 'tenure', 'PhoneService','MultipleLines', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV','StreamingMovies','PaperlessBilling', 'MonthlyCharges', 'TotalCharges','InternetService_vec', 'Contract_vec','PaymentMethod_vec'],
                            outputCol="features")

In [0]:
# Define classifiers
dt = DecisionTreeClassifier(labelCol="Churn", featuresCol="features")
rf = RandomForestClassifier(labelCol="Churn", featuresCol="features")
gbt = GBTClassifier(labelCol="Churn", featuresCol="features")

In [0]:
# Define evaluation metrics
bin_evaluator = BinaryClassificationEvaluator(labelCol="Churn")
multi_evaluator = MulticlassClassificationEvaluator(labelCol="Churn")

In [0]:
# Define pipeline for each model
dt_pipeline = Pipeline(stages=[InternetService_indexer,Contract_indexer,PaymentMethod_indexer, data_encoder, assembler, dt])
rf_pipeline = Pipeline(stages=[InternetService_indexer,Contract_indexer,PaymentMethod_indexer, data_encoder, assembler, rf])
gbt_pipeline = Pipeline(stages=[InternetService_indexer,Contract_indexer,PaymentMethod_indexer, data_encoder, assembler, gbt])
svm_pipeline = Pipeline(stages=[InternetService_indexer,Contract_indexer,PaymentMethod_indexer, data_encoder, assembler, svm])

In [0]:
# Split the data into training and testing sets
train_data, test_data = df.randomSplit([0.7, 0.3], seed=1234)

In [0]:
# Train and evaluate Decision Tree model
dt_model = dt_pipeline.fit(train_data)
dt_predictions = dt_model.transform(test_data)
print("Decision Tree evaluation metrics:")
print("Accuracy:", multi_evaluator.evaluate(dt_predictions))
print("AUC:", bin_evaluator.evaluate(dt_predictions))

Decision Tree evaluation metrics:
Accuracy: 0.7897297459579693
AUC: 0.5796615602339322


In [0]:
#Looking at the confusion matrix, Precision, Recall, F-1 score and misclassification rate for the DT.
y_true =dt_predictions.select("Churn")
y_true = y_true.toPandas()
 
y_pred = dt_predictions.select("prediction")
y_pred = y_pred.toPandas()
 
cnf_matrix_dt = confusion_matrix(y_true, y_pred)
print("Below is the confusion matrix \n {}".format(cnf_matrix))

Below is the confusion matrix 
 [[1278  217]
 [ 214  339]]


In [0]:
tn = cnf_matrix_dt[0][0]
fp = cnf_matrix_dt[0][1]
fn = cnf_matrix_dt[1][0]
tp = cnf_matrix_dt[1][1]
accuracy = (tp+tn)/(tp+tn+fp+fn)
precision = tp/(tp+fp)
recall = tp/(tp+fn)
f1_score = 2*(precision*recall)/(precision+recall)
print(f"Accuracy: {accuracy:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1_score:.2f}")
print(f"Misclassification error: { 1 - accuracy:.2f}")

Accuracy: 0.79
Precision: 0.61
Recall: 0.61
F1 Score: 0.61
Misclassification error: 0.21


In [0]:
# Train and evaluate Random Forest model
rf_model = rf_pipeline.fit(train_data)
rf_predictions = rf_model.transform(test_data)
print("Random Forest evaluation metrics:")
print("Accuracy:", multi_evaluator.evaluate(rf_predictions))
print("AUC:", bin_evaluator.evaluate(rf_predictions))

Random Forest evaluation metrics:
Accuracy: 0.768320230788529
AUC: 0.8361070959860168


In [0]:
#Looking at the confusion matrix, Precision, Recall, F-1 score and misclassification rate for the Random Forest model.
y_true = rf_predictions.select("Churn")
y_true = y_true.toPandas()
 
y_pred = rf_predictions.select("prediction")
y_pred = y_pred.toPandas()

cnf_matrix_rf = confusion_matrix(y_true, y_pred)
print("Below is the confusion matrix \n {}".format(cnf_matrix_rf))

Below is the confusion matrix 
 [[1391  104]
 [ 330  223]]


In [0]:
tn = cnf_matrix_rf[0][0]
fp = cnf_matrix_rf[0][1]
fn = cnf_matrix_rf[1][0]
tp = cnf_matrix_rf[1][1]
accuracy = (tp+tn)/(tp+tn+fp+fn)
precision = tp/(tp+fp)
recall = tp/(tp+fn)
f1_score = 2*(precision*recall)/(precision+recall)
print(f"Accuracy: {accuracy:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1_score:.2f}")
print(f"Misclassification error: { 1 - accuracy:.2f}")

Accuracy: 0.79
Precision: 0.68
Recall: 0.40
F1 Score: 0.51
Misclassification error: 0.21


In [0]:
# Train and evaluate Gradient Boosted Trees model
gbt_model = gbt_pipeline.fit(train_data)
gbt_predictions = gbt_model.transform(test_data)
print("Gradient Boosted Trees evaluation metrics:")
print("Accuracy:", multi_evaluator.evaluate(gbt_predictions))
print("AUC:", bin_evaluator.evaluate(gbt_predictions))

Gradient Boosted Trees evaluation metrics:
Accuracy: 0.7876599979424899
AUC: 0.8368098604752421


In [0]:
#Looking at the confusion matrix, Precision, Recall, F-1 score and misclassification rate for the Gradient Boost Trees model.
y_true = gbt_predictions.select("Churn")
y_true = y_true.toPandas()
 
y_pred = gbt_predictions.select("prediction")
y_pred = y_pred.toPandas()

cnf_matrix_gbt = confusion_matrix(y_true, y_pred)
print("Below is the confusion matrix \n {}".format(cnf_matrix_gbt))

Below is the confusion matrix 
 [[1341  154]
 [ 265  288]]


In [0]:
tn = cnf_matrix_gbt[0][0]
fp = cnf_matrix_gbt[0][1]
fn = cnf_matrix_gbt[1][0]
tp = cnf_matrix_gbt[1][1]
accuracy = (tp+tn)/(tp+tn+fp+fn)
precision = tp/(tp+fp)
recall = tp/(tp+fn)
f1_score = 2*(precision*recall)/(precision+recall)
print(f"Accuracy: {accuracy:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1_score:.2f}")
print(f"Misclassification error: { 1 - accuracy:.2f}")

Accuracy: 0.80
Precision: 0.65
Recall: 0.52
F1 Score: 0.58
Misclassification error: 0.20


Model selection
| Metric Evaluation | Decison Tree   | Random Forest   | Gradient Boost Tree  |
|-------------------|------|------|------|
| Accuracy          | 0.79 | 0.79 | 0.80 |
| Precision         | 0.61 | 0.68 | 0.65 |
| Recall            | 0.61 | 0.40 | 0.52 |
| F1 Score          | 0.61 | 0.51 | 0.58 |
| Misclassification error | 0.21 | 0.21 | 0.20 |

The metric to optimize is the Recall or Sensitivity. Maximizing sensitivity is important when false negatives (i.e., failing to identify a churable employee) is costly and should be minimized.

Methods to address the class imbalance problem

*Collect more data: Increasing the size of the dataset, especially the number of positive (minority) samples, can help to improve recall.

*Resampling techniques: There are two common resampling techniques to balance the dataset: oversampling and undersampling. Oversampling involves randomly duplicating the minority class samples to create a more balanced dataset, while undersampling involves randomly removing samples from the majority class to balance the dataset. Both techniques have their advantages and disadvantages, and the choice depends on the specific context of the problem.

*Class weight: Adjusting the class weights in the model can help to give more emphasis to the minority class during training. This can be done by setting a higher weight for the minority class and a lower weight for the majority class.

*Threshold adjustment: The threshold used to classify samples can be adjusted to prioritize recall over precision. A lower threshold means that more samples will be classified as positive, which can increase recall at the cost of lower precision.

*Ensemble methods: Combining the predictions of multiple models using ensemble techniques, such as bagging or boosting, can help to improve recall by reducing the impact of individual model biases.

Applying the class weight method

In [0]:
# Define the class weights as a dictionary
class_weights = {0: 1.0, 1: 3.0} # set the weight for the positive class (1) higher than the negative class (0)

# Define the Decision Tree model with weightCol parameter
dt_weighted = DecisionTreeClassifier(labelCol="Churn", featuresCol="features", weightCol="classWeights")

# Define the pipeline
dt_pipeline_weighted = Pipeline(stages=[InternetService_indexer, Contract_indexer, PaymentMethod_indexer, data_encoder, assembler, dt_weighted])

In [0]:
# Define the evaluator
multi_evaluator_weighted = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="Churn", metricName="accuracy")
bin_evaluator_weighted = BinaryClassificationEvaluator(rawPredictionCol="prediction", labelCol="Churn", metricName="areaUnderROC")

In [0]:
# Train the model
train_data, test_data = df.randomSplit([0.7, 0.3], seed=1234)
train_data_weighted = train_data.withColumn("classWeights", when(train_data.Churn == 1, class_weights[1]).otherwise(class_weights[0]))
dt_model_weighted = dt_pipeline_weighted.fit(train_data_weighted)

# Make predictions on the test set and evaluate the model
test_data_weighted = test_data.withColumn("classWeights", when(test_data.Churn == 1, class_weights[1]).otherwise(class_weights[0]))
dt_predictions_weighted = dt_model_weighted.transform(test_data_weighted)

In [0]:
TP = dt_predictions_weighted.filter('prediction = 1 AND Churn = 1').count()
FP = dt_predictions_weighted.filter('prediction = 1 AND Churn = 0').count()
TN = dt_predictions_weighted.filter('prediction = 0 AND Churn = 0').count()
FN = dt_predictions_weighted.filter('prediction = 0 AND Churn = 1').count()

In [0]:
precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1_score = 2 * precision * recall / (precision + recall)
misclassification_rate = (FP + FN) / (TP + TN + FP + FN)

In [0]:
print("Accuracy:", multi_evaluator_weighted.evaluate(dt_predictions_weighted))
print("AUC:", bin_evaluator_weighted.evaluate(dt_predictions_weighted))
print("Precision:", precision)
print("Recall:", recall)
print("F-1 Score:", f1_score)
print("Misclassification Rate:", misclassification_rate)

Accuracy: 0.73046875
AUC: 0.7521466975512106
Precision: 0.5005662514156285
Recall: 0.7992766726943942
F-1 Score: 0.6155988857938718
Misclassification Rate: 0.26953125


In [0]:
dt_predictions_weighted.groupBy('Churn', 'prediction').count().show()

+-----+----------+-----+
|Churn|prediction|count|
+-----+----------+-----+
|    1|       0.0|  111|
|    0|       0.0| 1054|
|    1|       1.0|  442|
|    0|       1.0|  441|
+-----+----------+-----+

