# Spark ML Production

In [4]:
# Render matplotlib figures inline at high (retina) resolution.
# NOTE(review): no cell visible here plots anything — drop these two magics
# if the rest of the notebook does not use matplotlib either.
%matplotlib inline
%config InlineBackend.figure_format='retina'

# findspark locates the local Spark installation and makes `pyspark`
# importable; it must run before any `from pyspark ... import` below.
import findspark
findspark.init()

In [5]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

# Local Spark session using all available cores. getOrCreate() returns an
# already-running session if the kernel has one.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("Spark ML Production")
    # Render a DataFrame as a table when it is the last expression of a cell.
    .config("spark.sql.repl.eagerEval.enabled", True)
    .getOrCreate()
)

SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/usr/lib/spark/jars/log4j-slf4j-impl-2.17.2.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/usr/lib/hadoop/lib/slf4j-log4j12-1.7.30.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.apache.logging.slf4j.Log4jLoggerFactory]
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


24/12/04 14:58:25 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


## Загружаем данные

In [7]:
# The first CSV row holds the column names; inferSchema makes Spark take an
# extra pass over the file to detect column types instead of reading every
# column as a string.
data = spark.read.csv(
    "data/BankChurners.csv",
    header=True,
    inferSchema=True,
)

                                                                                

## Загружаем модель

In [8]:
from pyspark.ml import PipelineModel

# Restore the fitted pipeline (all stages plus the trained model) from disk.
model = PipelineModel.read().load("data/pipelineModel")

## Вычисляем предсказания модели

In [9]:
# transform() is lazy: every action on `predicted` would otherwise re-read the
# CSV and re-run every pipeline stage. Later cells run more than one action
# on it (the preview and the confusion-matrix counts), so keep it cached.
predicted = model.transform(data).cache()

In [11]:
# First rows of the scored data; shown as a table thanks to eagerEval.
predicted.limit(num=5)

CLIENTNUM,Attrition_Flag,Customer_Age,Gender,Dependent_count,Education_Level,Marital_Status,Income_Category,Card_Category,Months_on_book,Total_Relationship_Count,Months_Inactive_12_mon,Contacts_Count_12_mon,Credit_Limit,Total_Revolving_Bal,Avg_Open_To_Buy,Total_Amt_Chng_Q4_Q1,Total_Trans_Amt,Total_Trans_Ct,Total_Ct_Chng_Q4_Q1,Avg_Utilization_Ratio,Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_1,Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_2,Gender_Indexed,Education_Level_Indexed,Marital_Status_Indexed,Income_Category_Indexed,Card_Category_Indexed,Gender_Indexed_Coded,Education_Level_Indexed_Coded,Marital_Status_Indexed_Coded,Income_Category_Indexed_Coded,Card_Category_Indexed_Coded,features,scaledFeatures,selectedFeatures,rawPrediction,probability,prediction
768805383,Existing Customer,45,M,3,High School,Married,$60K - $80K,Blue,39,5,1,3,12691.0,777,11914.0,1.335,1144,42,1.625,0.061,9.3448e-05,0.99991,1.0,1.0,0.0,3.0,0.0,"(1,[],[])","(6,[1],[1.0])","(3,[0],[1.0])","(5,[3],[1.0])","(3,[0],[1.0])","(28,[0,1,2,3,4,5,...","(28,[0,1,2,3,4,5,...","(21,[0,1,2,3,4,5,...",[3.51896371925687...,[0.97122255488720...,0.0
818770008,Existing Customer,49,F,5,Graduate,Single,Less than $40K,Blue,44,6,1,2,8256.0,864,7392.0,1.541,1291,33,3.714,0.105,5.6861e-05,0.99994,0.0,0.0,1.0,0.0,0.0,"(1,[0],[1.0])","(6,[0],[1.0])","(3,[1],[1.0])","(5,[0],[1.0])","(3,[0],[1.0])","(28,[0,1,2,3,4,5,...","(28,[0,1,2,3,4,5,...","(21,[0,1,2,3,4,5,...",[10.8446596535426...,[0.99998049186815...,0.0
713982108,Existing Customer,51,M,3,Graduate,Married,$80K - $120K,Blue,36,4,1,0,3418.0,0,3418.0,2.594,1887,20,2.333,0.0,2.1081e-05,0.99998,1.0,0.0,0.0,2.0,0.0,"(1,[],[])","(6,[0],[1.0])","(3,[0],[1.0])","(5,[2],[1.0])","(3,[0],[1.0])","(28,[0,1,2,3,5,6,...","(28,[0,1,2,3,5,6,...","(21,[0,1,2,3,5,6,...",[6.43360263097913...,[0.99839592497014...,0.0
769911858,Existing Customer,40,F,4,High School,Unknown,Less than $40K,Blue,34,3,4,1,3313.0,2517,796.0,1.405,1171,20,2.333,0.76,0.00013366,0.99987,0.0,1.0,2.0,0.0,0.0,"(1,[0],[1.0])","(6,[1],[1.0])","(3,[2],[1.0])","(5,[0],[1.0])","(3,[0],[1.0])","(28,[0,1,2,3,4,5,...","(28,[0,1,2,3,4,5,...","(21,[0,1,2,3,4,5,...",[5.78319104183641...,[0.99693057547687...,0.0
709106358,Existing Customer,40,M,3,Uneducated,Married,$60K - $80K,Blue,21,5,1,0,4716.0,0,4716.0,2.175,816,28,2.5,0.0,2.1676e-05,0.99998,1.0,3.0,0.0,3.0,0.0,"(1,[],[])","(6,[3],[1.0])","(3,[0],[1.0])","(5,[3],[1.0])","(3,[0],[1.0])","(28,[0,1,2,3,5,6,...","(28,[0,1,2,3,5,6,...","(21,[0,1,2,3,5,6,...",[7.17715157194106...,[0.99923674251591...,0.0


## Проверяем результат

In [12]:
# Values of the target column; "Attrited Customer" (churn) is the positive class.
POSITIVE_LABEL = "Attrited Customer"
NEGATIVE_LABEL = "Existing Customer"

actual_positive = f.col("Attrition_Flag") == POSITIVE_LABEL
actual_negative = f.col("Attrition_Flag") == NEGATIVE_LABEL
predicted_positive = f.col("prediction") == 1
predicted_negative = f.col("prediction") == 0

# A single aggregation runs one Spark job instead of four filter().count()
# actions. f.when(...) without otherwise() yields NULL for non-matching rows,
# and f.count ignores NULLs, so each column counts exactly the matching rows.
counts = predicted.agg(
    f.count(f.when(actual_positive & predicted_positive, True)).alias("tp"),
    f.count(f.when(actual_negative & predicted_negative, True)).alias("tn"),
    f.count(f.when(actual_negative & predicted_positive, True)).alias("fp"),
    f.count(f.when(actual_positive & predicted_negative, True)).alias("fn"),
).first()
tp, tn, fp, fn = counts["tp"], counts["tn"], counts["fp"], counts["fn"]

# Layout: rows = predicted positive / predicted negative,
#         columns = actual positive / actual negative.
print(f"Confusion Matrix:\n{tp:>4}\t{fp:>4}\n{fn:>4}\t{tn:>4}")

Confusion Matrix:
1262	2097
 365	6403


In [13]:
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero.

    Precision is undefined when the model predicts no positives, and recall is
    undefined when the data has no actual positives; NaN reports that instead
    of raising ZeroDivisionError.
    """
    return numerator / denominator if denominator else float("nan")


accuracy = _safe_ratio(tp + tn, tp + tn + fp + fn)
precision = _safe_ratio(tp, tp + fp)  # share of predicted churners that churned
recall = _safe_ratio(tp, tp + fn)     # share of actual churners that were caught

print(f"Accuracy = {accuracy}")
print(f"Precision = {precision}")
print(f"Recall = {recall}")

Accuracy = 0.756887528389454
Precision = 0.3757070556713308
Recall = 0.775660725261217
