In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.sql.types import StructField , StructType ,DoubleType,StringType
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from sparkxgb import XGBoostEstimator
from pyspark.ml import Pipeline
from pyspark.sql.functions import expr,col,column

In [None]:
# Create (or reuse) the notebook's SparkSession: application name "XGBoost",
# with the master set to "local[*]".
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("XGBoost")
    .getOrCreate()
)

In [None]:
# Explicit schema applied when reading the training CSV, as (column name, Spark type)
# pairs. Numeric columns are declared as doubles; text columns stay strings.
# NOTE(review): the order presumably mirrors the column order of the CSV file — confirm.
titanic_columns = [
    ("PassengerId", DoubleType()),
    ("Survival", DoubleType()),
    ("Pclass", DoubleType()),
    ("Name", StringType()),
    ("Sex", StringType()),
    ("Age", DoubleType()),
    ("SibSp", DoubleType()),
    ("Parch", DoubleType()),
    ("Ticket", StringType()),
    ("Fare", DoubleType()),
    ("Cabin", StringType()),
    ("Embarked", StringType()),
]
schema = StructType(
    [StructField(field_name, field_type) for field_name, field_type in titanic_columns]
)

In [None]:
# Read the training CSV using the explicit `schema`; the first line is a header row.
df_raw = spark.read.csv("data/train.csv", header=True, schema=schema)

# Replace nulls with 0.
# NOTE(review): with a numeric fill value Spark should only touch numeric columns,
# so string columns (e.g. Cabin, Embarked) presumably keep their nulls — confirm.
df = df_raw.fillna(0)

In [None]:
# Encode the "Sex" string column as a numeric index in "SexIndex".
# handleInvalid="keep" is set so invalid values are kept rather than raising.
sexIndexer = StringIndexer(
    inputCol="Sex",
    outputCol="SexIndex",
    handleInvalid="keep",
)
    
# Encode the "Cabin" string column as a numeric index in "CabinIndex".
# handleInvalid="keep" is set so invalid values are kept rather than raising.
cabinIndexer = StringIndexer(
    inputCol="Cabin",
    outputCol="CabinIndex",
    handleInvalid="keep",
)
    
# Encode the "Embarked" string column as a numeric index in "EmbarkedIndex".
# handleInvalid="keep" is set so invalid values are kept rather than raising.
embarkedIndexer = StringIndexer(
    inputCol="Embarked",
    outputCol="EmbarkedIndex",
    handleInvalid="keep",
)

# Columns that feed the model: the raw numeric columns plus the indexed
# versions of the categorical ones, in the order they appear in the vector.
feature_columns = [
    "Pclass",
    "SexIndex",
    "Age",
    "SibSp",
    "Parch",
    "Fare",
    "CabinIndex",
    "EmbarkedIndex",
]

# Pack the feature columns into a single vector column named "features".
vectorAssembler = VectorAssembler(inputCols=feature_columns, outputCol="features")

In [None]:
# Column wiring for the XGBoost stage: it reads the assembled "features" vector,
# uses "Survival" as the label, and writes its output to "prediction".
xgboost_columns = {
    "featuresCol": "features",
    "labelCol": "Survival",
    "predictionCol": "prediction",
}
xgboost = XGBoostEstimator(**xgboost_columns)

In [None]:
# Chain the stages in execution order: index the three categorical columns,
# assemble the feature vector, then run the XGBoost estimator.
pipeline = Pipeline(
    stages=[
        sexIndexer,
        cabinIndexer,
        embarkedIndexer,
        vectorAssembler,
        xgboost,
    ]
)

In [None]:
# 80/20 random split with a fixed seed so the split is repeatable.
trainDF, testDF = df.randomSplit([0.8, 0.2], seed=24)

# Fit every pipeline stage on the training split only.
model = pipeline.fit(trainDF)

# Apply the fitted pipeline to the held-out split and show each passenger's prediction.
predictions = model.transform(testDF)
predictions.select("PassengerId", "prediction").show()