In [0]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('log_reg_hw').getOrCreate()
from pyspark.ml.classification import LogisticRegression

In [0]:
df = spark.read.table('customer_churn_csv')
df.show(2)

+----------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------+-----+
|           Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|       Onboard_date|            Location|   Company|Churn|
+----------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------+-----+
|Cameron Williams|42.0|       11066.8|              0| 7.22|      8.0|2013-08-30 07:00:40|10265 Elizabeth M...|Harvey LLC|    1|
|   Kevin Mueller|41.0|      11916.22|              0|  6.5|     11.0|2013-08-13 00:38:46|6157 Frank Garden...|Wilson PLC|    1|
+----------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------+-----+
only showing top 2 rows



In [0]:
df.printSchema()

root
 |-- Names: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- Total_Purchase: double (nullable = true)
 |-- Account_Manager: integer (nullable = true)
 |-- Years: double (nullable = true)
 |-- Num_Sites: double (nullable = true)
 |-- Onboard_date: timestamp (nullable = true)
 |-- Location: string (nullable = true)
 |-- Company: string (nullable = true)
 |-- Churn: integer (nullable = true)



In [0]:
df.columns

Out[6]: ['Names',
 'Age',
 'Total_Purchase',
 'Account_Manager',
 'Years',
 'Num_Sites',
 'Onboard_date',
 'Location',
 'Company',
 'Churn']

In [0]:
my_cols = df.select(['Churn','Age','Total_Purchase','Years','Num_Sites'])#,'Onboard_date'])
my_final_data = my_cols.na.drop()
my_final_data.show(5)

+-----+----+--------------+-----+---------+
|Churn| Age|Total_Purchase|Years|Num_Sites|
+-----+----+--------------+-----+---------+
|    1|42.0|       11066.8| 7.22|      8.0|
|    1|41.0|      11916.22|  6.5|     11.0|
|    1|38.0|      12884.75| 6.67|     12.0|
|    1|42.0|       8010.76| 6.71|     10.0|
|    1|37.0|       9191.58| 5.56|      9.0|
+-----+----+--------------+-----+---------+
only showing top 5 rows



In [0]:
from pyspark.ml.feature import VectorAssembler, VectorIndexer, OneHotEncoder, StringIndexer

In [0]:
# gender_indexer = StringIndexer(inputCol = 'Sex',outputCol = 'SexIndex')
# gender_encoder = OneHotEncoder(inputCol = 'SexIndex', outputCol = 'SexVec')


In [0]:
assembler = VectorAssembler(
    inputCols = ['Age','Total_Purchase','Years','Num_Sites'],
    outputCol = 'features'
)

In [0]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline

In [0]:
log_reg = LogisticRegression(featuresCol = 'features',labelCol = 'Churn')

In [0]:
pipeline = Pipeline(
    stages = [
        assembler,
        log_reg
    ]
)

In [0]:
train_df, test_df = my_final_data.randomSplit([0.7,0.3])

In [0]:
fit_model = pipeline.fit(train_df)
results = fit_model.transform(test_df)

In [0]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
my_eval = BinaryClassificationEvaluator(rawPredictionCol = 'prediction',labelCol = 'Churn')

In [0]:
results.select('Churn','prediction').show(10)


+-----+----------+
|Churn|prediction|
+-----+----------+
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
+-----+----------+
only showing top 10 rows



In [0]:
AUC = my_eval.evaluate(results)
AUC

Out[46]: 0.7580429551733358

In [0]:
final_lr_model = pipeline.fit(my_final_data)

In [0]:
new_customers = spark.read.table('new_customers_csv')
new_customers.show(5)

+--------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------------+
|         Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|       Onboard_date|            Location|         Company|
+--------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------------+
| Andrew Mccall|37.0|       9935.53|              1| 7.71|      8.0|2011-08-29 18:37:54|38612 Johnny Stra...|        King Ltd|
|Michele Wright|23.0|       7526.94|              1| 9.28|     15.0|2013-07-22 18:19:54|21083 Nicole Junc...|   Cannon-Benson|
|  Jeremy Chang|65.0|         100.0|              1|  1.0|     15.0|2006-12-11 07:48:13|085 Austin Views ...|Barron-Robertson|
|Megan Ferguson|32.0|        6487.5|              0|  9.4|     14.0|2016-10-28 05:32:13|922 Wright Branch...|   Sexton-Golden|
|  Taylor Young|32.0|      13147.71|              1| 10.0|      8.0|2012-03-20 00:36:46|Unit 0789 Box 073...|  

In [0]:
new_customers.printSchema()

root
 |-- Names: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- Total_Purchase: double (nullable = true)
 |-- Account_Manager: integer (nullable = true)
 |-- Years: double (nullable = true)
 |-- Num_Sites: double (nullable = true)
 |-- Onboard_date: timestamp (nullable = true)
 |-- Location: string (nullable = true)
 |-- Company: string (nullable = true)



In [0]:
test_new_customers = assembler.transform(new_customers)
test_new_customers = test_new_customers.drop("features")

test_new_customers.show(5)

+--------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------------+
|         Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|       Onboard_date|            Location|         Company|
+--------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------------+
| Andrew Mccall|37.0|       9935.53|              1| 7.71|      8.0|2011-08-29 18:37:54|38612 Johnny Stra...|        King Ltd|
|Michele Wright|23.0|       7526.94|              1| 9.28|     15.0|2013-07-22 18:19:54|21083 Nicole Junc...|   Cannon-Benson|
|  Jeremy Chang|65.0|         100.0|              1|  1.0|     15.0|2006-12-11 07:48:13|085 Austin Views ...|Barron-Robertson|
|Megan Ferguson|32.0|        6487.5|              0|  9.4|     14.0|2016-10-28 05:32:13|922 Wright Branch...|   Sexton-Golden|
|  Taylor Young|32.0|      13147.71|              1| 10.0|      8.0|2012-03-20 00:36:46|Unit 0789 Box 073...|  

In [0]:
test_new_customers.printSchema()

root
 |-- Names: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- Total_Purchase: double (nullable = true)
 |-- Account_Manager: integer (nullable = true)
 |-- Years: double (nullable = true)
 |-- Num_Sites: double (nullable = true)
 |-- Onboard_date: timestamp (nullable = true)
 |-- Location: string (nullable = true)
 |-- Company: string (nullable = true)



In [0]:
final_results = final_lr_model.transform(test_new_customers)
final_results.select('prediction').show(5)

+----------+
|prediction|
+----------+
|       0.0|
|       1.0|
|       1.0|
|       1.0|
|       0.0|
+----------+
only showing top 5 rows

