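References:
- How to install XGBoost on the PySpark Jupyter notebook: https://medium.com/@bogdan.cojocar/how-to-install-xgboost-on-the-pyspark-jupyter-notebook-efb092064ef4
- PySpark and XGBoost integration tested on the Kaggle Titanic dataset: https://towardsdatascience.com/pyspark-and-xgboost-integration-tested-on-the-kaggle-titanic-dataset-4e75a568bdb
- Companion notebook (titanic_xgboost.ipynb): https://github.com/BogdanCojocar/medium-articles/blob/master/titanic_xgboost/titanic_xgboost.ipynb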

In [1]:
# Stop any SparkSession left over from an earlier run so the builder below
# starts a fresh one. Static settings such as spark.jars only take effect
# when a new SparkContext is launched.
# On a fresh kernel `spark` does not exist yet; that NameError is the only
# failure we deliberately ignore. Anything else should surface.
try:
    spark.stop()
except NameError:
    pass

In [2]:
# Alternative (unused) way to ship the XGBoost4J jars: set PYSPARK_SUBMIT_ARGS
# before the first SparkSession is created. spark-submit options such as
# --jars must come BEFORE the trailing "pyspark-shell"; anything after it is
# treated as an application argument and ignored.
#import os
#os.environ['PYSPARK_SUBMIT_ARGS'] = '--master local[*] --jars /home/jovyan/work/maven/xgboost4j-spark-0.72.jar,/home/jovyan/work/maven/xgboost4j-0.72.jar pyspark-shell'

In [4]:
import os

import pyspark
from pyspark.sql.session import SparkSession

# Directory holding the XGBoost4J jars. It defaults to the original Docker
# (jovyan) path; set the XGBOOST_JAR_DIR environment variable to run elsewhere.
XGBOOST_JAR_DIR = os.environ.get("XGBOOST_JAR_DIR", "/home/jovyan/work/maven")

# spark.jars expects a comma-separated list of jar paths.
XGBOOST_JARS = ",".join(
    os.path.join(XGBOOST_JAR_DIR, jar_name)
    for jar_name in ("xgboost4j-spark-0.72.jar", "xgboost4j-0.72.jar")
)

# spark.jars is a static setting: it is only honoured when getOrCreate()
# launches a new SparkContext, not when it reuses a running one.
spark = (SparkSession
         .builder
         .appName("PySpark XGBOOST Titanic")
         .config("spark.jars", XGBOOST_JARS)
         .getOrCreate())

In [5]:
# Register the sparkxgb Python wrapper (a zip) with the SparkContext. The
# import on the next line relies on this registration having happened
# first, so keep the two statements in this order.
spark.sparkContext.addPyFile(path="/home/jovyan/work/maven/sparkxgb.zip")
from sparkxgb import XGBoostEstimator

In [9]:
from pyspark.sql.types import *
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.sql.functions import col

In [10]:
# Explicit schema for train.csv, listed in the same column order as the file.
# Numeric columns are read as doubles; free-text and categorical columns stay
# strings. Every field uses StructField's default nullable=True.
schema = StructType([
    StructField(field_name, field_type)
    for field_name, field_type in (
        ("PassengerId", DoubleType()),
        ("Survival", DoubleType()),
        ("Pclass", DoubleType()),
        ("Name", StringType()),
        ("Sex", StringType()),
        ("Age", DoubleType()),
        ("SibSp", DoubleType()),
        ("Parch", DoubleType()),
        ("Ticket", StringType()),
        ("Fare", DoubleType()),
        ("Cabin", StringType()),
        ("Embarked", StringType()),
    )
])

In [11]:
# Load the training CSV with the explicit `schema` instead of inferring types.
# The first line of the file is a header row, not data.
df_raw = spark.read.csv("train.csv", schema=schema, header=True)

In [14]:
# Replace nulls in the numeric columns with 0. String columns are not touched
# by a numeric fill value, so e.g. Cabin keeps its nulls (visible below).
# NOTE(review): this also turns a missing Age into 0.0, which the model sees
# as a newborn. Consider a median/mean imputation instead.
df = df_raw.fillna(0)
df.show(n=20)

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survival|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|        1.0|     0.0|   3.0|Braund, Mr. Owen ...|  male|22.0|  1.0|  0.0|       A/5 21171|   7.25| null|       S|
|        2.0|     1.0|   1.0|Cumings, Mrs. Joh...|female|38.0|  1.0|  0.0|        PC 17599|71.2833|  C85|       C|
|        3.0|     1.0|   3.0|Heikkinen, Miss. ...|female|26.0|  0.0|  0.0|STON/O2. 3101282|  7.925| null|       S|
|        4.0|     1.0|   1.0|Futrelle, Mrs. Ja...|female|35.0|  1.0|  0.0|          113803|   53.1| C123|       S|
|        5.0|     0.0|   3.0|Allen, Mr. Willia...|  male|35.0|  0.0|  0.0|          373450|   8.05| null|       S|
|        6.0|     0.0|   3.0|    Moran, Mr. James|  male| 0.0|  0.0|  0.0|      

In [15]:
# Turn each categorical string column into a numeric index column.
# handleInvalid="keep" puts unseen labels (and nulls) into an extra index
# instead of raising, which matters for Cabin's many nulls.
sexIndexer = StringIndexer(inputCol="Sex", outputCol="SexIndex", handleInvalid="keep")
cabinIndexer = StringIndexer(inputCol="Cabin", outputCol="CabinIndex", handleInvalid="keep")
embarkedIndexer = StringIndexer(inputCol="Embarked", outputCol="EmbarkedIndex", handleInvalid="keep")

# Pack the numeric predictors and the three index columns into the single
# "features" vector the estimator consumes.
vectorAssembler = VectorAssembler(
    inputCols=["Pclass", "SexIndex", "Age", "SibSp", "Parch", "Fare", "CabinIndex", "EmbarkedIndex"],
    outputCol="features",
)

In [16]:
# XGBoost estimator from the sparkxgb wrapper. Only the column wiring is set:
# it reads the assembled "features" vector and the "Survival" label and writes
# its output to "prediction". All other hyper-parameters are left at the
# wrapper's defaults.
xgboost = XGBoostEstimator(labelCol="Survival", featuresCol="features", predictionCol="prediction")

In [17]:
# Chain indexing, assembly and XGBoost into one pipeline, fit it on the
# training frame, and preview the predicted label per passenger.
# NOTE(review): predictions are made on the same rows used for fitting
# (in-sample), so they overstate accuracy. Evaluate on a held-out split.
pipeline = Pipeline(stages=[sexIndexer, cabinIndexer, embarkedIndexer, vectorAssembler, xgboost])
model = pipeline.fit(df)
model.transform(df).select("PassengerId", "prediction").show()

+-----------+----------+
|PassengerId|prediction|
+-----------+----------+
|        1.0|       0.0|
|        2.0|       1.0|
|        3.0|       1.0|
|        4.0|       1.0|
|        5.0|       0.0|
|        6.0|       0.0|
|        7.0|       0.0|
|        8.0|       0.0|
|        9.0|       1.0|
|       10.0|       1.0|
|       11.0|       1.0|
|       12.0|       1.0|
|       13.0|       0.0|
|       14.0|       0.0|
|       15.0|       1.0|
|       16.0|       1.0|
|       17.0|       0.0|
|       18.0|       0.0|
|       19.0|       0.0|
|       20.0|       1.0|
+-----------+----------+
only showing top 20 rows

