<a href="https://colab.research.google.com/github/santhoshsrivi/study/blob/main/Pyspark_customer_churn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pyspark

Collecting pyspark
  Downloading pyspark-3.4.1.tar.gz (310.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m310.8/310.8 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.4.1-py2.py3-none-any.whl size=311285397 sha256=0114ce346f999115420a9db7e9b19f3dfdc50486bd7344193ef997d205dc5339
  Stored in directory: /root/.cache/pip/wheels/0d/77/a3/ff2f74cc9ab41f8f594dabf0579c2a7c6de920d584206e0834
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.4.1


In [2]:
import pyspark

In [3]:
from pyspark.sql import SparkSession

In [4]:
spark = SparkSession.builder.appName('customer_churn_prediction').getOrCreate()

In [5]:
spark

In [6]:
data = spark.read.csv('/content/Churn_Modelling.csv',inferSchema=True,header=True)

In [7]:
data.show(5)

+---------+----------+--------+-----------+---------+------+---+------+---------+-------------+---------+--------------+---------------+------+
|RowNumber|CustomerId| Surname|CreditScore|Geography|Gender|Age|Tenure|  Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|
+---------+----------+--------+-----------+---------+------+---+------+---------+-------------+---------+--------------+---------------+------+
|        1|  15634602|Hargrave|        619|   France|Female| 42|     2|      0.0|            1|        1|             1|      101348.88|     1|
|        2|  15647311|    Hill|        608|    Spain|Female| 41|     1| 83807.86|            1|        0|             1|      112542.58|     0|
|        3|  15619304|    Onio|        502|   France|Female| 42|     8| 159660.8|            3|        1|             0|      113931.57|     1|
|        4|  15701354|    Boni|        699|   France|Female| 39|     1|      0.0|            2|        0|             0|       93826.63|

In [8]:
data.printSchema()

root
 |-- RowNumber: integer (nullable = true)
 |-- CustomerId: integer (nullable = true)
 |-- Surname: string (nullable = true)
 |-- CreditScore: integer (nullable = true)
 |-- Geography: string (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Tenure: integer (nullable = true)
 |-- Balance: double (nullable = true)
 |-- NumOfProducts: integer (nullable = true)
 |-- HasCrCard: integer (nullable = true)
 |-- IsActiveMember: integer (nullable = true)
 |-- EstimatedSalary: double (nullable = true)
 |-- Exited: integer (nullable = true)



In [9]:
data.describe().show()

+-------+------------------+-----------------+-------+-----------------+---------+------+------------------+------------------+-----------------+------------------+-------------------+-------------------+-----------------+-------------------+
|summary|         RowNumber|       CustomerId|Surname|      CreditScore|Geography|Gender|               Age|            Tenure|          Balance|     NumOfProducts|          HasCrCard|     IsActiveMember|  EstimatedSalary|             Exited|
+-------+------------------+-----------------+-------+-----------------+---------+------+------------------+------------------+-----------------+------------------+-------------------+-------------------+-----------------+-------------------+
|  count|             10000|            10000|  10000|            10000|    10000| 10000|             10000|             10000|            10000|             10000|              10000|              10000|            10000|              10000|
|   mean|            5000.5|

In [10]:
data.columns

['RowNumber',
 'CustomerId',
 'Surname',
 'CreditScore',
 'Geography',
 'Gender',
 'Age',
 'Tenure',
 'Balance',
 'NumOfProducts',
 'HasCrCard',
 'IsActiveMember',
 'EstimatedSalary',
 'Exited']

In [11]:
from pyspark.ml.feature import VectorAssembler

In [12]:
assembler = VectorAssembler(inputCols=['RowNumber','CustomerId','CreditScore','Age','Tenure','Balance','NumOfProducts','HasCrCard','IsActiveMember'],outputCol='features')

In [13]:
out = assembler.transform(data)

In [16]:
final_data = out.select('features','Exited')

In [17]:
final_data.show()

+--------------------+------+
|            features|Exited|
+--------------------+------+
|[1.0,1.5634602E7,...|     1|
|[2.0,1.5647311E7,...|     0|
|[3.0,1.5619304E7,...|     1|
|[4.0,1.5701354E7,...|     0|
|[5.0,1.5737888E7,...|     0|
|[6.0,1.5574012E7,...|     1|
|[7.0,1.5592531E7,...|     0|
|[8.0,1.5656148E7,...|     1|
|[9.0,1.5792365E7,...|     0|
|[10.0,1.5592389E7...|     0|
|[11.0,1.5767821E7...|     0|
|[12.0,1.5737173E7...|     0|
|[13.0,1.5632264E7...|     0|
|[14.0,1.5691483E7...|     0|
|[15.0,1.5600882E7...|     0|
|[16.0,1.5643966E7...|     0|
|[17.0,1.5737452E7...|     1|
|[18.0,1.5788218E7...|     0|
|[19.0,1.5661507E7...|     0|
|[20.0,1.5568982E7...|     0|
+--------------------+------+
only showing top 20 rows



In [18]:
train_churn, test_churn = final_data.randomSplit([0.8,0.2],seed=42)

In [19]:
from pyspark.ml.classification import LogisticRegression
lr_churn = LogisticRegression(labelCol='Exited')
fitted_churn_model = lr_churn.fit(train_churn)
training_sum = fitted_churn_model.summary

In [20]:
training_sum.predictions.describe().show()

+-------+-------------------+-------------------+
|summary|             Exited|         prediction|
+-------+-------------------+-------------------+
|  count|               8079|               8079|
|   mean|0.20336675331105336|0.05483351899987622|
| stddev| 0.4025279772969834|0.22766910196703233|
|    min|                0.0|                0.0|
|    max|                1.0|                1.0|
+-------+-------------------+-------------------+

