In [1]:
import warnings

from snowflake.ml.modeling.impute import SimpleImputer
from snowflake.ml.modeling.metrics import accuracy_score, precision_score, recall_score
from snowflake.ml.modeling.preprocessing import OneHotEncoder
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session
from snowflake.snowpark import types as T
from snowflake.snowpark.functions import col

warnings.simplefilter(action="ignore", category=UserWarning)

In [2]:
# Connection parameters come from SnowflakeLoginOptions, so no credentials live in the notebook.
# NOTE(review): the cell output flags SnowflakeLoginOptions as private preview -- not for production.
connection_params = SnowflakeLoginOptions()
session = Session.builder.configs(connection_params).getOrCreate()

SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. 


In [3]:
# Reference the "titanic" table through the session created above.
# NOTE(review): the table is assumed to already exist in the session's current database/schema.
titanic_df = session.table("titanic")

In [4]:
# Print a preview of the raw table to inspect the available columns and sample values.
titanic_df.show()

-------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"SURVIVED"  |"PCLASS"  |"AGE"  |"SIBSP"  |"PARCH"  |"FARE"   |"ADULT_MALE"  |"DECK"  |"ALIVE"  |"ALONE"  |"SEX"   |"EMBARKED"  |"CLASS"  |"WHO"  |"EMBARK_TOWN"  |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
|0           |3         |22.00  |1        |0        |7.2500   |True          |NULL    |False    |False    |MALE    |S           |THIRD    |MAN    |SOUTHAMPTON    |
|1           |1         |38.00  |1        |0        |71.2833  |False         |C       |True     |False    |FEMALE  |C           |FIRST    |WOMAN  |CHERBOURG      |
|1           |3         |26.00  |0        |0        |7.9250   |False         |NULL    |True     |True     |FEMALE  |S           |THIRD    |WOMAN  |SOUTHAMPTON    |
|1           |1 

In [5]:
# Count the NULL rows in every column (one count() call per column) ...
null_counts = {
    name: titanic_df.where(col(name).is_null()).count() for name in titanic_df.columns
}

# ... and display only the columns that actually contain NULLs.
{name: n_null for name, n_null in null_counts.items() if n_null > 0}

{'AGE': 177, 'DECK': 688, 'EMBARKED': 2, 'EMBARK_TOWN': 2}

In [6]:
# Remove columns that will not be used as features.
# AGE and DECK have the most NULLs per the count above (177 and 688).
# NOTE(review): ALIVE, ADULT_MALE and EMBARKED look like restatements of SURVIVED, WHO and
# EMBARK_TOWN in the preview rows -- presumably dropped as leaky/redundant; confirm.
titanic_df = titanic_df.drop(
    "AGE",
    "DECK",
    "ALIVE",
    "ADULT_MALE",
    "EMBARKED",
)

In [7]:
# Replace FARE with a float-typed version of itself, then preview the result.
fare_as_float = titanic_df["FARE"].cast(T.FloatType())
titanic_df = titanic_df.with_column("FARE", fare_as_float)

titanic_df.show()

------------------------------------------------------------------------------------------------------------
|"SURVIVED"  |"PCLASS"  |"SIBSP"  |"PARCH"  |"ALONE"  |"SEX"   |"CLASS"  |"WHO"  |"EMBARK_TOWN"  |"FARE"   |
------------------------------------------------------------------------------------------------------------
|0           |3         |1        |0        |False    |MALE    |THIRD    |MAN    |SOUTHAMPTON    |7.25     |
|1           |1         |1        |0        |False    |FEMALE  |FIRST    |WOMAN  |CHERBOURG      |71.2833  |
|1           |3         |0        |0        |True     |FEMALE  |THIRD    |WOMAN  |SOUTHAMPTON    |7.925    |
|1           |1         |1        |0        |False    |FEMALE  |FIRST    |WOMAN  |SOUTHAMPTON    |53.1     |
|0           |3         |0        |0        |True     |MALE    |THIRD    |MAN    |SOUTHAMPTON    |8.05     |
|0           |3         |0        |0        |True     |MALE    |THIRD    |MAN    |QUEENSTOWN     |8.4583   |
|0           |1    

In [8]:
# String-valued columns; these are imputed and one-hot encoded in the cells below.
cat_cols = ["SEX", "CLASS", "WHO", "EMBARK_TOWN"]
# Numeric feature columns.
# NOTE(review): num_cols is not referenced by any cell visible here -- confirm it is still needed.
num_cols = ["PCLASS", "SIBSP", "PARCH", "FARE"]

In [9]:
# Fill missing values in the categorical columns using the "most_frequent" strategy.
# Input and output column names are the same, so the imputed values replace the originals.
# NOTE(review): the imputer is fit on the full table, before the train/test split further down,
# so test rows influence the fitted statistics -- consider fitting on the training split only.
impute_cat = SimpleImputer(
    strategy="most_frequent",
    input_cols=cat_cols,
    output_cols=cat_cols,
    drop_input_cols=True,
)
impute_cat.fit(titanic_df)

titanic_df = impute_cat.transform(titanic_df)
titanic_df.show()

------------------------------------------------------------------------------------------------------------
|"SEX"   |"CLASS"  |"WHO"  |"EMBARK_TOWN"  |"SURVIVED"  |"PCLASS"  |"SIBSP"  |"PARCH"  |"ALONE"  |"FARE"   |
------------------------------------------------------------------------------------------------------------
|MALE    |THIRD    |MAN    |SOUTHAMPTON    |0           |3         |1        |0        |False    |7.25     |
|FEMALE  |FIRST    |WOMAN  |CHERBOURG      |1           |1         |1        |0        |False    |71.2833  |
|FEMALE  |THIRD    |WOMAN  |SOUTHAMPTON    |1           |3         |0        |0        |True     |7.925    |
|FEMALE  |FIRST    |WOMAN  |SOUTHAMPTON    |1           |1         |1        |0        |False    |53.1     |
|MALE    |THIRD    |MAN    |SOUTHAMPTON    |0           |3         |0        |0        |True     |8.05     |
|MALE    |THIRD    |MAN    |QUEENSTOWN     |0           |3         |0        |0        |True     |8.4583   |
|MALE    |FIRST    

In [10]:
# One-hot encode the categorical columns. The output names are used as prefixes for the
# generated indicator columns (e.g. SEX -> SEX_MALE), and the original columns are dropped.
# drop="first" leaves one category per feature without its own column (e.g. no SEX_FEMALE).
# NOTE(review): like the imputer, the encoder is fit on the full table before the train/test split.
OHE = OneHotEncoder(
    input_cols=cat_cols,
    output_cols=cat_cols,
    drop="first",
    handle_unknown="ignore",
    drop_input_cols=True,
)
OHE.fit(titanic_df)

titanic_df = OHE.transform(titanic_df)
titanic_df.show()

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"SEX_MALE"  |"CLASS_SECOND"  |"CLASS_THIRD"  |"WHO_MAN"  |"WHO_WOMAN"  |"EMBARK_TOWN_QUEENSTOWN"  |"EMBARK_TOWN_SOUTHAMPTON"  |"SURVIVED"  |"PCLASS"  |"SIBSP"  |"PARCH"  |"ALONE"  |"FARE"   |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|1.0         |0.0             |1.0            |1.0        |0.0          |0.0                       |1.0                        |0           |3         |1        |0        |False    |7.25     |
|0.0         |0.0             |0.0            |0.0        |1.0          |0.0                       |0.0                        |1           |1         |1        |0        |False    |71.2833  |
|0.0         |0.0             |1.0 

In [11]:
# 80/20 train/test split; the fixed seed is there to make the split repeatable across runs.
train_df, test_df = titanic_df.random_split([0.8, 0.2], seed=8)

In [12]:
# Every column except the SURVIVED label is used as a model feature.
feature_cols = train_df.drop("SURVIVED").columns

xgb = XGBClassifier(
    input_cols=feature_cols,
    label_cols="SURVIVED",
    output_cols="PRED_SURVIVED",
)

# Train on the training split; predictions will be written to PRED_SURVIVED.
xgb.fit(train_df)

<snowflake.ml.modeling.xgboost.xgb_classifier.XGBClassifier at 0x12f02fa90>

In [13]:
# Score the held-out split; the model was configured to write predictions to PRED_SURVIVED.
result = xgb.predict(test_df)

In [14]:
# All three metrics compare the same label/prediction columns of the scored test split,
# so the shared arguments are spelled out once.
score_args = {
    "df": result,
    "y_true_col_names": "SURVIVED",
    "y_pred_col_names": "PRED_SURVIVED",
}

accuracy = accuracy_score(**score_args)
precision = precision_score(**score_args)
recall = recall_score(**score_args)

print(f"Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}")

DataFrame.flatten() is deprecated since 0.7.0. Use `DataFrame.join_table_function()` instead.


Accuracy: 0.824468, Precision: 0.7808219178082192, Recall: 0.7702702702702703
