In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Start the Spark session for this notebook, or reuse one that is already running.
spark = (
    SparkSession
    .builder
    .appName("classificationExample1")
    .getOrCreate()
)
# Show the session summary as this cell's output.
spark

# 타이타닉 데이터를 이용한 생존여부 예측 모델
## 로지스틱 회귀

In [2]:
# Load the Titanic CSV. The first row holds the column names, and Spark infers each column's type.
data = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("learning_spark_data/titanic.csv")
)
# Preview the first rows of the raw data.
data.show(10)

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|Gender| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| NULL|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| NULL|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| NULL|       S|
|          6|       0|     3|    Moran, Mr. James|  male|NULL|    0|    0|      

In [4]:
from pyspark.sql.functions import col, sum, when, isnan

## Missing-value check: count null (and, for floating-point columns, NaN) entries per column.
# NaN can only occur in float/double columns, so isnan() is applied to those columns only.
# Other columns are checked with isNull() alone. Calling isnan() on string columns relies on
# an implicit string->double cast, which can raise an error when ANSI SQL mode is enabled.
# NOTE: `sum` here is pyspark.sql.functions.sum (imported above), which shadows Python's builtin sum.
nan_capable_cols = {name for name, dtype in data.dtypes if dtype in ('float', 'double')}
null_counts = data.select(
    [
        sum(
            when(
                (col(c).isNull() | isnan(c)) if c in nan_capable_cols else col(c).isNull(),
                1,
            ).otherwise(0)
        ).alias(c)
        for c in data.columns
    ]
)
null_counts.show()

+-----------+--------+------+----+------+---+-----+-----+------+----+-----+--------+
|PassengerId|Survived|Pclass|Name|Gender|Age|SibSp|Parch|Ticket|Fare|Cabin|Embarked|
+-----------+--------+------+----+------+---+-----+-----+------+----+-----+--------+
|          0|       0|     0|   0|     0|177|    0|    0|     0|   0|  687|       2|
+-----------+--------+------+----+------+---+-----+-----+------+----+-----+--------+



In [5]:
# Feature selection: keep the label (Survived) plus the candidate predictor columns.
selected_cols = ['Survived', 'Pclass', 'Gender', 'Age', 'SibSp', 'Parch', 'Fare']
data_1 = data.select(*selected_cols)
data_1.show(3)

+--------+------+------+----+-----+-----+-------+
|Survived|Pclass|Gender| Age|SibSp|Parch|   Fare|
+--------+------+------+----+-----+-----+-------+
|       0|     3|  male|22.0|    1|    0|   7.25|
|       1|     1|female|38.0|    1|    0|71.2833|
|       1|     3|female|26.0|    0|    0|  7.925|
+--------+------+------+----+-----+-----+-------+
only showing top 3 rows



In [6]:
# Age missing-value handling, step 1: compute the mean of the non-null Age values to use as the fill value.
# NOTE(review): this mean is taken over all rows of data_1, before the train/test split
# performed later in the notebook, so test rows influence the imputed value.
mean_age = data_1.agg({"Age": "mean"}).first()[0]
mean_age

29.69911764705882

In [7]:
# Age missing-value handling, step 2: replace null Age entries with mean_age.
data_1 = data_1.na.fill({'Age': mean_age})
data_1.show(10)

+--------+------+------+-----------------+-----+-----+-------+
|Survived|Pclass|Gender|              Age|SibSp|Parch|   Fare|
+--------+------+------+-----------------+-----+-----+-------+
|       0|     3|  male|             22.0|    1|    0|   7.25|
|       1|     1|female|             38.0|    1|    0|71.2833|
|       1|     3|female|             26.0|    0|    0|  7.925|
|       1|     1|female|             35.0|    1|    0|   53.1|
|       0|     3|  male|             35.0|    0|    0|   8.05|
|       0|     3|  male|29.69911764705882|    0|    0| 8.4583|
|       0|     1|  male|             54.0|    0|    0|51.8625|
|       0|     3|  male|              2.0|    3|    1| 21.075|
|       1|     3|female|             27.0|    0|    2|11.1333|
|       1|     2|female|             14.0|    1|    0|30.0708|
+--------+------+------+-----------------+-----+-----+-------+
only showing top 10 rows



In [11]:
from pyspark.sql.functions import col, when
from pyspark.ml.feature import StringIndexer, VectorAssembler

In [12]:
# Encode the categorical Gender column as a numeric index (SexIndexer) with StringIndexer.
# By default StringIndexer orders labels by descending frequency, so the most frequent label gets 0.0.
# StringIndexer refuses to write to an output column that already exists. Drop any SexIndexer
# column left over from a previous execution so this cell can be safely re-run
# (drop() is a no-op when the column is absent).
gender_input = data_1.drop('SexIndexer')
indexer = StringIndexer(inputCol='Gender', outputCol='SexIndexer')
data_1 = indexer.fit(gender_input).transform(gender_input)
data_1.show(5)


+--------+------+------+----+-----+-----+-------+----------+
|Survived|Pclass|Gender| Age|SibSp|Parch|   Fare|SexIndexer|
+--------+------+------+----+-----+-----+-------+----------+
|       0|     3|  male|22.0|    1|    0|   7.25|       0.0|
|       1|     1|female|38.0|    1|    0|71.2833|       1.0|
|       1|     3|female|26.0|    0|    0|  7.925|       1.0|
|       1|     1|female|35.0|    1|    0|   53.1|       1.0|
|       0|     3|  male|35.0|    0|    0|   8.05|       0.0|
+--------+------+------+----+-----+-----+-------+----------+
only showing top 5 rows



In [14]:
# Build the feature vector: combine the numeric predictor columns into the single
# vector column ('features') that Spark ML estimators consume.
assembler = VectorAssembler(
    inputCols=['Pclass', 'SexIndexer', 'Age', 'SibSp', 'Parch', 'Fare'],
    outputCol='features'
)
# VectorAssembler will not overwrite an existing output column. Drop a 'features' column
# left over from an earlier execution so re-running this cell does not fail.
data_1 = assembler.transform(data_1.drop('features'))
data_1.select('features', 'Survived').show(5)

+--------------------+--------+
|            features|Survived|
+--------------------+--------+
|[3.0,0.0,22.0,1.0...|       0|
|[1.0,1.0,38.0,1.0...|       1|
|[3.0,1.0,26.0,0.0...|       1|
|[1.0,1.0,35.0,1.0...|       1|
|[3.0,0.0,35.0,0.0...|       0|
+--------------------+--------+
only showing top 5 rows



In [15]:
# Split the dataset into roughly 80% training and 20% test rows. The fixed seed keeps the split
# repeatable for the same input partitioning.
# NOTE(review): the Age mean and the StringIndexer above were fit on all of data_1 before
# this split, so the test set is not fully unseen by the preprocessing.
train_data, test_data = data_1.randomSplit([0.8, 0.2], seed=42)
# Show each split as its own statement. Writing them as a tuple would echo a stray "(None, None)".
train_data.show(5)
test_data.show(5)

+--------+------+------+----+-----+-----+------+----------+--------------------+
|Survived|Pclass|Gender| Age|SibSp|Parch|  Fare|SexIndexer|            features|
+--------+------+------+----+-----+-----+------+----------+--------------------+
|       0|     1|female| 2.0|    1|    2|151.55|       1.0|[1.0,1.0,2.0,1.0,...|
|       0|     1|female|25.0|    1|    2|151.55|       1.0|[1.0,1.0,25.0,1.0...|
|       0|     1|  male|18.0|    1|    0| 108.9|       0.0|[1.0,0.0,18.0,1.0...|
|       0|     1|  male|19.0|    1|    0|  53.1|       0.0|[1.0,0.0,19.0,1.0...|
|       0|     1|  male|19.0|    3|    2| 263.0|       0.0|[1.0,0.0,19.0,3.0...|
+--------+------+------+----+-----+-----+------+----------+--------------------+
only showing top 5 rows

+--------+------+------+-----------------+-----+-----+-------+----------+--------------------+
|Survived|Pclass|Gender|              Age|SibSp|Parch|   Fare|SexIndexer|            features|
+--------+------+------+-----------------+-----+-----+--

(None, None)