# **Labs 1 and 2: PySpark**

In these labs we will use the "[[NeurIPS 2020] Data Science for COVID-19 (DS4C)](https://www.kaggle.com/datasets/kimjihoo/coronavirusdataset?select=PatientInfo.csv)" dataset, retrieved from [Kaggle](https://www.kaggle.com/) on 1/6/2022 for educational, non-commercial purposes. License: [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/)


The CSV file used in these labs is **PatientInfo.csv**.

## PatientInfo.csv

**patient_id**
the ID of the patient

**sex**
the sex of the patient

**age**
the age of the patient

**country**
the country of the patient

**province**
the province of the patient

**city**
the city of the patient

**infection_case**
the case of infection

**infected_by**
the ID of who infected the patient


**contact_number**
the number of contacts with people

**symptom_onset_date**
the date of symptom onset

**confirmed_date**
the date of being confirmed

**released_date**
the date of being released

**deceased_date**
the date of being deceased

**state**
isolated / released / deceased

### Import pyspark and check its version

In [551]:
from google.colab import drive 
drive.mount('/content/drive') 

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [552]:
!pip install pyspark

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
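The heading asks for a version check, which the cells above do not perform. A minimal sketch, assuming only the standard library's `importlib.metadata`; the helper name `installed_version` is hypothetical and not part of the lab:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package_name):
    """Return the installed version string of a package, or None if it is missing."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None


# Prints the PySpark version if the package is installed, otherwise None.
print(installed_version("pyspark"))
```

Once a session exists, `pyspark.__version__` and `spark.version` report the same information.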


### Import and create SparkSession

In [553]:
from pyspark.sql import SparkSession
spark=SparkSession.builder.getOrCreate()

### Load the PatientInfo.csv file and show the first 5 rows

In [554]:
from IPython.display import display, HTML
display(HTML("<style>pre { white-space: pre !important; }</style>"))

In [555]:
df = spark.read.csv('/content/drive/MyDrive/PatientInfo.csv',header=True)#inferSchema=True

In [556]:
df.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|         null|released|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|   202

### Display the schema of the dataset

In [557]:
df.printSchema()

root
 |-- patient_id: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: string (nullable = true)
 |-- released_date: string (nullable = true)
 |-- deceased_date: string (nullable = true)
 |-- state: string (nullable = true)



### Display the statistical summary

In [558]:
df.describe().show()

+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------------+-------------+-------------+--------+
|summary|          patient_id|   sex| age|   country|province|          city|      infection_case|         infected_by|      contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------------+-------------+-------------+--------+
|  count|                5165|  4043|3785|      5165|    5165|          5071|                4246|                1346|                 791|               690|          5162|         1587|           66|    5165|
|   mean|2.8636345618679576E9|  null|null|      null|    null|          null|                null|2.2845944015643125E9|1.6772572523506988E7|            

### Using the state column: how many people survived (released), and how many didn't survive (isolated/deceased)?

In [559]:
df.columns

['patient_id',
 'sex',
 'age',
 'country',
 'province',
 'city',
 'infection_case',
 'infected_by',
 'contact_number',
 'symptom_onset_date',
 'confirmed_date',
 'released_date',
 'deceased_date',
 'state']

In [560]:
df.select('state').show()

+--------+
|   state|
+--------+
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|deceased|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
+--------+
only showing top 20 rows



In [561]:
df.groupBy('state').count().show()

+--------+-----+
|   state|count|
+--------+-----+
|isolated| 2158|
|released| 2929|
|deceased|   78|
+--------+-----+



In [562]:
df.filter(df.state == "released").count()

2929

In [563]:
# Every state other than "released" (isolated or deceased) counts as not survived.
df.filter(df.state != "released").count()

2236
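As a quick arithmetic check, the two answers should add up to the total number of rows. A plain-Python sketch using only the counts printed by `groupBy('state')` above:

```python
# Counts copied from the groupBy('state') output above.
state_counts = {"isolated": 2158, "released": 2929, "deceased": 78}

total = sum(state_counts.values())
released = state_counts["released"]
not_released = total - released

print(total, released, not_released)  # 5165 2929 2236
```

The total of 5165 matches the `patient_id` count reported by `describe()`.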

### Display the number of null values in each column

In [564]:
df.schema

StructType([StructField('patient_id', StringType(), True), StructField('sex', StringType(), True), StructField('age', StringType(), True), StructField('country', StringType(), True), StructField('province', StringType(), True), StructField('city', StringType(), True), StructField('infection_case', StringType(), True), StructField('infected_by', StringType(), True), StructField('contact_number', StringType(), True), StructField('symptom_onset_date', StringType(), True), StructField('confirmed_date', StringType(), True), StructField('released_date', StringType(), True), StructField('deceased_date', StringType(), True), StructField('state', StringType(), True)])


In [566]:
# from pyspark.sql.types import IntegerType,BooleanType,DateType
# # Convert String to Integer Type
# df.withColumn("age",df.age.cast(IntegerType()))
# df.withColumn("age",df.age.cast('int'))
# df.withColumn("age",df.age.cast('integer'))

# # Using select
# df.select(col("age").cast('int').alias("age"))

# #Using selectExpr()
# df.selectExpr("cast(age as int) age")

# #Using with spark.sql()
# spark.sql("SELECT INT(age),BOOLEAN(isGraduated),DATE(jobStartDate) from CastExample")

In [567]:
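from pyspark.sql.functions import col, count, when

# when() without otherwise() yields null for non-matching rows, and count() skips nulls,
# so each aggregate counts the rows where that column is null.
df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns]).show()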

+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+
|patient_id| sex| age|country|province|city|infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|state|
+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+
|         0|1122|1380|      0|       0|  94|           919|       3819|          4374|              4475|             3|         3578|         5099|    0|
+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+
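These null counts can be cross-checked against the `count` row of the statistical summary, since `describe()` counts only non-null values. A small plain-Python sketch, using only numbers copied from the two outputs in this notebook:

```python
total_rows = 5165  # count of patient_id in the describe() output

# Non-null counts per column, copied from the describe() output.
non_null = {
    "sex": 4043, "age": 3785, "city": 5071, "infection_case": 4246,
    "infected_by": 1346, "contact_number": 791, "symptom_onset_date": 690,
    "confirmed_date": 5162, "released_date": 1587, "deceased_date": 66,
}

# Null counts per column, copied from the table above.
reported_nulls = {
    "sex": 1122, "age": 1380, "city": 94, "infection_case": 919,
    "infected_by": 3819, "contact_number": 4374, "symptom_onset_date": 4475,
    "confirmed_date": 3, "released_date": 3578, "deceased_date": 5099,
}

derived_nulls = {column: total_rows - n for column, n in non_null.items()}
print(derived_nulls == reported_nulls)  # True
```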



## Data preprocessing

### Fill the nulls in the deceased_date with the released_date. 
- You can use <b>coalesce</b> function
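`coalesce` returns its first non-null argument, evaluated row by row. A plain-Python sketch of that behaviour (the function and the sample rows are hypothetical stand-ins, not Spark code):

```python
def coalesce(*values):
    """Return the first argument that is not None, or None if all are None."""
    for value in values:
        if value is not None:
            return value
    return None


# Hypothetical rows: only one of the two dates is usually present.
rows = [
    {"deceased_date": None, "released_date": "2020-02-05"},
    {"deceased_date": "2020-02-19", "released_date": None},
    {"deceased_date": None, "released_date": None},
]

filled = [coalesce(r["deceased_date"], r["released_date"]) for r in rows]
print(filled)  # ['2020-02-05', '2020-02-19', None]
```

Rows where both dates are missing stay null after the fill.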

In [568]:
from pyspark.sql.functions import coalesce

df=df.withColumn('deceased_date',coalesce(col('deceased_date'),col('released_date')))

In [569]:
df.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|   202

In [570]:
#df.select(col("deceased_date")).fillna(df.select(col("released_date")))

### Add a column named no_days, the difference between deceased_date and confirmed_date, then show the top 5 rows. Print the schema.
- <b>Hint: you need to typecast these columns as dates first.</b>
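To see what the day difference should look like, here is a plain-Python sketch using `datetime.date`; the helper name `days_between` is hypothetical, and the two date pairs are the `confirmed_date` and filled `deceased_date` of the first two patients shown earlier:

```python
from datetime import date


def days_between(start, end):
    """Number of days from start to end, both given as 'YYYY-MM-DD' strings."""
    return (date.fromisoformat(end) - date.fromisoformat(start)).days


print(days_between("2020-01-23", "2020-02-05"))  # 13
print(days_between("2020-01-30", "2020-03-02"))  # 32 (2020 is a leap year)
```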

In [571]:
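# Explicit imports avoid shadowing Python built-ins; regexp_replace is used in a later cell.
from pyspark.sql.functions import col, datediff, regexp_replace, to_date
# Cast both string columns to dates before taking the difference in days.
df=df.withColumn('no_days',datediff(to_date(col('deceased_date')),to_date(col('confirmed_date'))))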

In [572]:
df.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|            17|

### Add an is_male column that is true if the patient is male and false otherwise

In [573]:
df=df.withColumn('is_male',df['sex']=='male')

In [574]:
df.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   true|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   true|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact 

### Add an is_dead column that is true if the patient's state is not released and false otherwise

- Use a <b>UDF</b> to perform this task.
- However, UDFs are not recommended unless no built-in function can perform the required operation.
- UDFs are slower than built-in functions.
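Besides speed, a Python UDF and a built-in comparison such as `col('state') != 'released'` can differ on missing values: Python compares `None` like any other object, while SQL-style expressions propagate null. The sketch below illustrates this, assuming Python's built-in `sqlite3` as a stand-in for a SQL engine (it is not Spark, but follows the same null rule for comparisons). The helper names are hypothetical.

```python
import sqlite3


def is_dead(state):
    """Python-side comparison, as a UDF body would evaluate it."""
    return state != "released"


conn = sqlite3.connect(":memory:")


def sql_is_dead(state):
    """SQL-side comparison; returns 1, 0, or None when the input is NULL."""
    return conn.execute("SELECT ? != 'released'", (state,)).fetchone()[0]


for state in ["released", "isolated", None]:
    print(state, is_dead(state), sql_is_dead(state))
```

In this dataset the `state` column has no nulls, so both approaches give the same result.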

In [575]:
def isDead(state):
    # Any state other than 'released' (isolated or deceased) counts as dead in this lab.
    return state != 'released'

In [576]:
from pyspark.sql.functions import udf,col
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, StringType,BooleanType

is_dead_col = udf(isDead,BooleanType())

df=df.withColumn('is_dead',is_dead_col(col('state')))

In [577]:
df.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|is_dead|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   true|  false|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   true|  false|
|1000000003|  male|5

### Change the age bins from 0s, 10s, 20s, etc. to 0, 10, 20

In [578]:
# df.select("age", translate(col("age"), "s", "").alias("age")).show(5)
df=df.withColumn("age", regexp_replace(col("age"), "[s]", ""))

In [580]:
df.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|is_dead|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|1000000001|  male| 50|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   true|  false|
|1000000002|  male| 30|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   true|  false|
|1000000003|  male| 

### Cast age and no_days to Double

In [581]:
df=df.withColumn("age",df.age.cast('double'))
df=df.withColumn("no_days",df.no_days.cast('double'))

In [582]:
df.printSchema()

root
 |-- patient_id: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: double (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: string (nullable = true)
 |-- released_date: string (nullable = true)
 |-- deceased_date: string (nullable = true)
 |-- state: string (nullable = true)
 |-- no_days: double (nullable = true)
 |-- is_male: boolean (nullable = true)
 |-- is_dead: boolean (nullable = true)



### Drop the columns
["patient_id","sex","infected_by","contact_number","released_date","state",
"symptom_onset_date","confirmed_date","deceased_date","country","no_days",
"city","infection_case"]

In [583]:
df=df.drop("patient_id", "sex", "infected_by","contact_number","released_date","state", "symptom_onset_date","confirmed_date","deceased_date","country","no_days", "city","infection_case")

In [584]:
ml_df=df

In [585]:
df.columns

['age', 'province', 'is_male', 'is_dead']

In [586]:
ml_df.columns

['age', 'province', 'is_male', 'is_dead']

### Recount the number of nulls now

In [587]:
df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns]
   ).show()

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|1380|       0|   1122|      0|
+----+--------+-------+-------+



## Now do the same but using SQL select statement

### From the original Patient DataFrame, Create a temporary view (table).

In [588]:
df = spark.read.csv('/content/drive/MyDrive/PatientInfo.csv',header=True)#inferSchema=True
df.createOrReplaceTempView('DF_view')

### Use SELECT statement to select all columns from the dataframe and show the output.

In [589]:
spark.sql(""" SELECT * FROM DF_view """).show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|         null|released|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|

### *Using SQL commands*, limit the output to only 5 rows 

In [590]:
spark.sql(""" SELECT * FROM DF_view LIMIT 5 """).show()

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|         null|released|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|   202

### Select the count of males and females in the dataset

In [591]:
spark.sql(""" SELECT sex , count(sex) FROM DF_view group by (sex) """).show()

+------+----------+
|   sex|count(sex)|
+------+----------+
|  null|         0|
|female|      2218|
|  male|      1825|
+------+----------+
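The `null` group shows a count of 0 even though 1122 rows have no `sex` value, because `count(column)` skips nulls while `count(*)` counts rows. A minimal sketch of the difference, assuming Python's built-in `sqlite3` as a stand-in for Spark SQL and a hypothetical five-row table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE patients (sex TEXT)")
# Hypothetical miniature data: two females, one male, two missing values.
conn.executemany(
    "INSERT INTO patients VALUES (?)",
    [("female",), ("female",), ("male",), (None,), (None,)],
)

rows = conn.execute(
    "SELECT sex, COUNT(sex), COUNT(*) FROM patients GROUP BY sex ORDER BY sex"
).fetchall()
print(rows)  # [(None, 0, 2), ('female', 2, 2), ('male', 1, 1)]
```

Filtering with `WHERE sex IS NOT NULL` is another way to keep the null group out of the result.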



### How many people survived, and how many didn't?

In [592]:
spark.sql(""" SELECT state , count(state) FROM DF_view group by (state) """).show()

+--------+------------+
|   state|count(state)|
+--------+------------+
|isolated|        2158|
|released|        2929|
|deceased|          78|
+--------+------------+



### Now, let's perform some preprocessing using SQL:
1. Convert *age* column to double after removing the 's' at the end -- *hint: check SUBSTRING method*
2. Select only the following columns: `['sex', 'age', 'province', 'state']`
3. Store the result of the query in a new dataframe
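For step 1, a fixed-length substring only works for two-digit bins. If the data contains bins such as `0s` or `100s`, taking the first two characters gives `0s` (not numeric) and `10` (the wrong bin). A plain-Python sketch of the two approaches, with hypothetical helper names:

```python
def first_two_chars(age_bin):
    """Keep the first two characters, like a fixed-length substring."""
    return age_bin[:2]


def drop_last_char(age_bin):
    """Drop only the trailing 's', whatever the length of the bin label."""
    return age_bin[:-1]


for age_bin in ["0s", "50s", "100s"]:
    print(age_bin, first_two_chars(age_bin), drop_last_char(age_bin))
```

In Spark SQL, where string positions are 1-based, the counterpart of the second helper is `SUBSTRING(age, 1, LENGTH(age) - 1)`.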

In [593]:
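# Spark SQL string positions are 1-based; drop only the trailing 's' so '0s' and '100s' also work.
spark.sql("""select SUBSTRING(age, 1, LENGTH(age) - 1) AS age from DF_view """)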

DataFrame[age: string]

In [594]:
spark.sql("""SELECT CAST(age AS double) as age from DF_view""")

DataFrame[age: double]

In [595]:
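# Drop only the trailing 's' before casting, so bins such as '0s' and '100s' convert correctly.
df2=spark.sql(""" select sex ,CAST(SUBSTRING(age, 1, LENGTH(age) - 1) AS DOUBLE) AS age ,province,state from DF_view  """)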

In [596]:
df2.show(5)

+------+----+--------+--------+
|   sex| age|province|   state|
+------+----+--------+--------+
|  male|50.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|  male|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|20.0|   Seoul|released|
+------+----+--------+--------+
only showing top 5 rows



## Machine Learning 
### Create a pipeline model to predict is_dead and evaluate the performance.
- Use <b>StringIndexer</b> to transform <b>string</b> data type to indices.
- Use <b>OneHotEncoder</b> to deal with categorical values.
- Use <b>Imputer</b> to fill missing data with mean.
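To make the three transformers concrete, here is a plain-Python sketch of what each one computes. It is an illustration, not Spark code: the helper names and sample values are hypothetical, and it returns a full-width indicator vector, whereas Spark's `OneHotEncoder` drops the last category by default.

```python
from collections import Counter


def fit_string_index(values):
    """Map each label to an index, most frequent label first (ties alphabetical here)."""
    counts = Counter(v for v in values if v is not None)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: i for i, label in enumerate(ordered)}


def one_hot(index, size):
    """Indicator vector with a single 1.0 at the given index."""
    return [1.0 if i == index else 0.0 for i in range(size)]


def impute_mean(values):
    """Replace None with the mean of the observed values."""
    observed = [v for v in values if v is not None]
    mean = sum(observed) / len(observed)
    return [mean if v is None else v for v in values]


provinces = ["Seoul", "Busan", "Seoul", "Daegu", "Seoul", "Busan"]  # hypothetical
ages = [50.0, None, 30.0, 20.0, None, 60.0]  # hypothetical

index = fit_string_index(provinces)
print(index)                                  # {'Seoul': 0, 'Busan': 1, 'Daegu': 2}
print(one_hot(index["Busan"], len(index)))    # [0.0, 1.0, 0.0]
print(impute_mean(ages))                      # [50.0, 40.0, 30.0, 20.0, 40.0, 60.0]
```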

In [597]:
ml_df.show(5)

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|50.0|   Seoul|   true|  false|
|30.0|   Seoul|   true|  false|
|50.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|20.0|   Seoul|  false|  false|
+----+--------+-------+-------+
only showing top 5 rows



In [598]:
ml_df.show()

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|50.0|   Seoul|   true|  false|
|30.0|   Seoul|   true|  false|
|50.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|20.0|   Seoul|  false|  false|
|50.0|   Seoul|  false|  false|
|20.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|30.0|   Seoul|   true|  false|
|60.0|   Seoul|  false|  false|
|50.0|   Seoul|  false|  false|
|20.0|   Seoul|   true|  false|
|80.0|   Seoul|   true|   true|
|60.0|   Seoul|  false|  false|
|70.0|   Seoul|   true|  false|
|70.0|   Seoul|   true|  false|
|70.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|70.0|   Seoul|  false|  false|
|70.0|   Seoul|  false|  false|
+----+--------+-------+-------+
only showing top 20 rows



In [599]:
ml_df=ml_df.withColumn("is_dead",ml_df.is_dead.cast('int'))
ml_df=ml_df.withColumn("is_male",ml_df.is_male.cast('int'))

In [600]:
ml_df.show(5)

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|50.0|   Seoul|      1|      0|
|30.0|   Seoul|      1|      0|
|50.0|   Seoul|      1|      0|
|20.0|   Seoul|      1|      0|
|20.0|   Seoul|      0|      0|
+----+--------+-------+-------+
only showing top 5 rows



In [601]:
from pyspark.ml.feature import StringIndexer,OneHotEncoder,VectorAssembler
categoricalCols = [f for (f,d) in ml_df.dtypes if d=='string' and f!='is_dead']

In [602]:
categoricalCols

['province']

In [603]:
indexOutputCols = [s + "_Index" for s in categoricalCols]

In [604]:
indexOutputCols

['province_Index']

In [605]:
oheOutputCols = [s + "_OHE" for s in categoricalCols]

In [606]:
oheOutputCols

['province_OHE']

In [607]:
stringIndexer = StringIndexer(inputCols=categoricalCols,outputCols=indexOutputCols,handleInvalid='keep')
oheEncoder = OneHotEncoder(inputCols=indexOutputCols,outputCols=oheOutputCols)

In [608]:
# Numeric feature columns: double or int, excluding the label column.
numericCols = [f for (f,d) in ml_df.dtypes if d in ('double','int') and f!='is_dead']

In [609]:
numericCols

['age', 'is_male']
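When a filter mixes `|` and `&` without explicit grouping, note that Python's `&` binds tighter than `|`, so `a | b & c` is evaluated as `a | (b & c)`. The result above is still correct because the label `is_dead` is an int column, but a label stored as a double would slip through an ungrouped condition. A small sketch with a hypothetical column:

```python
# Hypothetical column: a label named is_dead that happens to be stored as a double.
d, f = "double", "is_dead"

ungrouped = (d == "double") | (d == "int") & (f != "is_dead")
intended = (d in ("double", "int")) and (f != "is_dead")

print(ungrouped, intended)  # True False
```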

In [610]:
assemblerInputs = oheOutputCols + numericCols

In [611]:
assemblerInputs

['province_OHE', 'age', 'is_male']

In [612]:
ml_df.columns[:-1]

['age', 'province', 'is_male']

In [613]:
numericCols

['age', 'is_male']

In [614]:
numericCols[:-1]

['age']

In [615]:
ml_df.columns[:-1]

['age', 'province', 'is_male']

In [616]:
ml_df.columns[0]

'age'

In [617]:
from pyspark.ml.feature import Imputer

imputer_age = Imputer(
    inputCols=['age'], 
    outputCols=['age']
    ).setStrategy("mean")

imputer_is_male = Imputer(
    inputCols=['is_male'], 
    outputCols=['is_male']
    ).setStrategy("mode")

# Add imputation cols to df
#df = imputer.fit(df).transform(df)

In [618]:
vecAssembler = VectorAssembler(inputCols=assemblerInputs,outputCol='features')

In [619]:
trainDF, testDF = ml_df.randomSplit([.8,.2],seed=42)
print(f"There are {trainDF.count()} rows in the training set, and {testDF.count()} in the test set")

There are 4166 rows in the training set, and 999 in the test set


In [620]:
from pyspark.ml.classification import LogisticRegression
lr=LogisticRegression(
    featuresCol='features',
    labelCol='is_dead',
    predictionCol='prediction'
)

In [621]:
myStages = [stringIndexer,oheEncoder,imputer_age,imputer_is_male,vecAssembler,lr]

In [622]:
from pyspark.ml import Pipeline
pipeline = Pipeline(stages=myStages)

In [623]:
pipelineModel = pipeline.fit(trainDF)

In [624]:
predDF = pipelineModel.transform(testDF)

In [625]:
predDF.select('is_dead','prediction').show()

+-------+----------+
|is_dead|prediction|
+-------+----------+
|      0|       0.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
|      1|       1.0|
+-------+----------+
only showing top 20 rows



In [626]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [627]:
evaluator = BinaryClassificationEvaluator(rawPredictionCol='prediction', labelCol='is_dead')

In [628]:
evaluator.evaluate(predDF)

0.812747868433268
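The evaluator above is given the hard 0/1 `prediction` column, so its ROC curve has a single threshold. `BinaryClassificationEvaluator` defaults to the `rawPrediction` column, which carries continuous scores and usually gives a more informative area under the ROC curve. The plain-Python sketch below, with a hypothetical `auc` helper and toy data, shows how the two inputs can yield different values. It is not the lab's result.

```python
def auc(labels, scores):
    """Probability that a random positive outranks a random negative (ties count half)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


labels = [1, 1, 1, 0, 0, 0]                 # hypothetical ground truth
probs = [0.9, 0.8, 0.6, 0.7, 0.3, 0.2]      # hypothetical model scores
hard = [1.0 if p >= 0.5 else 0.0 for p in probs]

auc_scores = auc(labels, probs)   # 8 of 9 positive/negative pairs ranked correctly
auc_hard = auc(labels, hard)      # ties between equal hard labels count as half
print(round(auc_scores, 4), round(auc_hard, 4))  # 0.8889 0.8333
```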