<a href="https://colab.research.google.com/github/weiczhu/BERT-BiLSTM-CRF-NER/blob/master/Kernel_IEEE_CIS_Fraud_Detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install Java, Spark, and Findspark
This installs Apache Spark 2.3.3, Java 8, and Findspark, a library that helps Python locate the Spark installation.

In [0]:
# Install Spark
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q http://apache.osuosl.org/spark/spark-2.3.3/spark-2.3.3-bin-hadoop2.7.tgz
!tar xf spark-2.3.3-bin-hadoop2.7.tgz
!pip install -q findspark

# Install kaggle
!pip install kaggle

api_token = {"username":"weiczhu","key":"<YOUR_KAGGLE_API_KEY>"}  # supply your own key; do not commit real credentials
import json
import zipfile
import os
import subprocess

if not os.path.exists("/root/.kaggle/kaggle.json"):
    # "!" shell escapes only work as notebook magics, so use os calls here
    os.makedirs("/root/.kaggle/", exist_ok=True)
    with open('/root/.kaggle/kaggle.json', 'w') as file:
        json.dump(api_token, file)
    # Restrict the token file to the current user
    os.chmod('/root/.kaggle/kaggle.json', 0o600)
    
!kaggle competitions download -c ieee-fraud-detection --force
data_dir = "/content/competitions/ieee-fraud-detection"
# Unzip on every run: the download above already creates this directory,
# so guarding on its absence would skip extraction.
os.makedirs(data_dir, exist_ok=True)
os.chdir(data_dir)
for file in os.listdir():
    if file.endswith(".zip"):
        with zipfile.ZipFile(file, 'r') as zip_ref:
            zip_ref.extractall()
        os.remove(file)

Downloading train_transaction.csv.zip to /content/competitions/ieee-fraud-detection
 78% 41.0M/52.5M [00:00<00:00, 192MB/s]
100% 52.5M/52.5M [00:00<00:00, 208MB/s]
Downloading train_identity.csv.zip to /content/competitions/ieee-fraud-detection
  0% 0.00/3.02M [00:00<?, ?B/s]
100% 3.02M/3.02M [00:00<00:00, 99.8MB/s]
Downloading test_transaction.csv.zip to /content/competitions/ieee-fraud-detection
 76% 36.0M/47.3M [00:00<00:00, 162MB/s]
100% 47.3M/47.3M [00:00<00:00, 178MB/s]
Downloading test_identity.csv.zip to /content/competitions/ieee-fraud-detection
  0% 0.00/2.97M [00:00<?, ?B/s]
100% 2.97M/2.97M [00:00<00:00, 197MB/s]
Downloading sample_submission.csv.zip to /content/competitions/ieee-fraud-detection
  0% 0.00/1.14M [00:00<?, ?B/s]
100% 1.14M/1.14M [00:00<00:00, 163MB/s]


### Set Environment Variables
Set the locations where Spark and Java are installed.

In [0]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-2.3.3-bin-hadoop2.7"

### Start Actual Code

In [0]:
# Core libraries used in this notebook (adapted from the Kaggle kernel template)

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# The competition files were downloaded to /content/competitions/ieee-fraud-detection.
# List the directory to confirm that the CSV files are present.

import os
print(os.listdir("/content/competitions/ieee-fraud-detection"))

['test_transaction.csv', 'train_identity.csv', 'test_transaction.csv.zip', 'test_identity.csv.zip', 'train_transaction.csv', 'train_transaction.csv.zip', 'sample_submission.csv', 'sample_submission.csv.zip', 'test_identity.csv', 'train_identity.csv.zip']


The fields in each table are described in more detail below:

Transaction Table
TransactionDT: timedelta from a given reference datetime (not an actual timestamp)
TransactionAmt: transaction payment amount in USD
ProductCD: product code, the product for each transaction
card1 - card6: payment card information, such as card type, card category, issue bank, country, etc.
addr: address
dist: distance
P_emaildomain and R_emaildomain: purchaser and recipient email domain
C1-C14: counting, such as how many addresses are found to be associated with the payment card, etc. The actual meaning is masked.
D1-D15: timedelta, such as days since the previous transaction, etc.
M1-M9: match, such as names on card and address, etc.
Vxxx: Vesta engineered rich features, including ranking, counting, and other entity relations.

Categorical Features:
ProductCD
card1 - card6
addr1, addr2
P_emaildomain, R_emaildomain
M1 - M9
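
As a rough illustration of how the `TransactionDT` offset might be turned into time-of-day features, the sketch below assumes the offset is measured in seconds. That unit is an assumption: the smallest value in the summary statistics later in this notebook is 86400, which equals one day in seconds, but the field description does not confirm it. The helper name `split_transaction_dt` is hypothetical.

```python
SECONDS_PER_HOUR = 3600
SECONDS_PER_DAY = 24 * SECONDS_PER_HOUR

def split_transaction_dt(transaction_dt):
    """Split an offset (assumed to be in seconds) into (day index, hour of day)."""
    day_index, seconds_into_day = divmod(transaction_dt, SECONDS_PER_DAY)
    hour_of_day = seconds_into_day // SECONDS_PER_HOUR
    return day_index, hour_of_day

print(split_transaction_dt(86400))     # smallest TransactionDT in the training data
print(split_transaction_dt(15811131))  # largest TransactionDT in the training data
```

If the seconds assumption holds, the training data spans roughly half a year of offsets, which may matter when choosing a time-based validation split.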

Identity Table
Variables in this table are identity information – network connection information (IP, ISP, Proxy, etc) and digital signature (UA/browser/os/version, etc) associated with transactions. 
They're collected by Vesta’s fraud protection system and digital security partners.
(The field names are masked and pairwise dictionary will not be provided for privacy protection and contract agreement)

Categorical Features:
DeviceType
DeviceInfo
id_12 - id_38

In [0]:
import findspark
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.functions import mean, col, split, regexp_extract, when, lit
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import QuantileDiscretizer

In [0]:
spark = SparkSession \
    .builder \
    .appName("fraud detection") \
    .master("local[*]") \
    .getOrCreate()
#     .config("spark.sql.execution.arrow.enabled", "true") \
#     .config("spark.driver.maxResultSize", "0") \
#     .config("spark.driver.memory", "16g") \
#     .getOrCreate()

In [0]:
# spark.conf.set("spark.sql.execution.arrow.enabled", "true")

In [0]:
train_transaction = spark.read.csv('file:///content/competitions/ieee-fraud-detection/train_transaction.csv', header=True, inferSchema=True)
train_identity = spark.read.csv('file:///content/competitions/ieee-fraud-detection/train_identity.csv', header=True, inferSchema=True)

test_transaction = spark.read.csv('file:///content/competitions/ieee-fraud-detection/test_transaction.csv', header=True, inferSchema=True)
test_identity = spark.read.csv('file:///content/competitions/ieee-fraud-detection/test_identity.csv', header=True, inferSchema=True)

In [0]:
train = train_transaction.join(train_identity, "TransactionID", how='left')

In [0]:
test = test_transaction.join(test_identity, "TransactionID", how='left')

In [0]:
train.printSchema()

In [0]:
# train.cache()

In [0]:
categorical_cols = ["ProductCD"] + \
    ["card" + str(i) for i in range(1, 7)] + \
    ["addr1", "addr2"] + \
    ["P_emaildomain", "R_emaildomain"] + \
    ["M" + str(i) for i in range(1, 10)] + \
    ["DeviceType", "DeviceInfo"] + \
    ["id_{:02d}".format(i) for i in range(12, 39)]
# categorical_cols

In [0]:
label_col = "isFraud"

In [0]:
contiguous_cols = list(filter(lambda x: x not in categorical_cols and x != label_col, train.columns))
contiguous_cols.remove("TransactionID")
# contiguous_cols

In [0]:
total_count = train.count()
# total_count

In [0]:
select_cols = train.columns[:17] + train.columns[432:]
del select_cols[1]
train.select(*(select_cols)).show(5)

+-------------+-------------+--------------+---------+-----+-----+-----+----------+-----+------+-----+-----+-----+-----+-------------+-------------+----------+--------------------+
|TransactionID|TransactionDT|TransactionAmt|ProductCD|card1|card2|card3|     card4|card5| card6|addr1|addr2|dist1|dist2|P_emaildomain|R_emaildomain|DeviceType|          DeviceInfo|
+-------------+-------------+--------------+---------+-----+-----+-----+----------+-----+------+-----+-----+-----+-----+-------------+-------------+----------+--------------------+
|      2987000|        86400|          68.5|        W|13926| null|150.0|  discover|142.0|credit|315.0| 87.0| 19.0| null|         null|         null|      null|                null|
|      2987001|        86401|          29.0|        W| 2755|404.0|150.0|mastercard|102.0|credit|325.0| 87.0| null| null|    gmail.com|         null|      null|                null|
|      2987002|        86469|          59.0|        W| 4663|490.0|150.0|      visa|166.0| debit

In [0]:
train.select(*(select_cols)).select("TransactionDT", "TransactionAmt", "ProductCD", "P_emaildomain", "R_emaildomain", "DeviceType", "DeviceInfo").describe().show()

+-------+-----------------+------------------+---------+-------------+-------------+----------+------------------+
|summary|    TransactionDT|    TransactionAmt|ProductCD|P_emaildomain|R_emaildomain|DeviceType|        DeviceInfo|
+-------+-----------------+------------------+---------+-------------+-------------+----------+------------------+
|  count|           590540|            590540|   590540|       496084|       137291|    140810|            118666|
|   mean|7372311.310116165|135.02717637240136|     null|         null|         null|      null|           14088.0|
| stddev|4617223.646539707|239.16252201373396|     null|         null|         null|      null|22267.786643490188|
|    min|            86400|             0.251|        C|      aim.com|      aim.com|   desktop|             0PAJ5|
|    max|         15811131|         31937.391|        W|    ymail.com|    ymail.com|    mobile|    xs-Z47b7VqTMxs|
+-------+-----------------+------------------+---------+-------------+----------

In [0]:
train.select(*(select_cols)).groupBy("ProductCD").count().show()

+---------+------+
|ProductCD| count|
+---------+------+
|        C| 68519|
|        W|439670|
|        S| 11628|
|        R| 37699|
|        H| 33024|
+---------+------+



In [0]:
train.select(*(select_cols)).groupBy("DeviceType").count().show()

+----------+------+
|DeviceType| count|
+----------+------+
|   desktop| 85165|
|      null|449730|
|    mobile| 55645|
+----------+------+




### Feature engineering

In [0]:
train = train.withColumn("Emaildomain_match", when(train["P_emaildomain"].isNull() | train["R_emaildomain"].isNull(), 2). 
                         when(train["P_emaildomain"] == train["R_emaildomain"], 1).otherwise(0))
train.groupBy("Emaildomain_match").count().show()

categorical_cols += ["Emaildomain_match"] if "Emaildomain_match" not in categorical_cols else []
if "P_emaildomain" in categorical_cols:
    categorical_cols.remove("P_emaildomain")
if "R_emaildomain" in categorical_cols:
    categorical_cols.remove("R_emaildomain")

+-----------------+------+
|Emaildomain_match| count|
+-----------------+------+
|                1|102504|
|                2|464313|
|                0| 23723|
+-----------------+------+
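
The three-way encoding above can be restated in plain Python, which may make the precedence of the `when` chain easier to see: the null check comes first, so a pair with either domain missing is always coded 2. This is only an illustrative sketch; the function name `emaildomain_match` is hypothetical and is not part of the Spark pipeline.

```python
def emaildomain_match(p_domain, r_domain):
    """Mirror the when/otherwise chain: 2 = missing, 1 = same domain, 0 = different."""
    if p_domain is None or r_domain is None:
        return 2
    return 1 if p_domain == r_domain else 0

print(emaildomain_match("gmail.com", "gmail.com"))
print(emaildomain_match("gmail.com", "ymail.com"))
print(emaildomain_match("gmail.com", None))
```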



In [0]:
train.groupBy("DeviceType").count().show()

+----------+------+
|DeviceType| count|
+----------+------+
|   desktop| 85165|
|      null|449730|
|    mobile| 55645|
+----------+------+



In [0]:
train.select("DeviceInfo").distinct().count()

1787

In [0]:
categorical_df = train.select(*categorical_cols)
categorical_df.show(5)

+---------+-----+-----+-----+-----+-----+------+-----+-----+----+----+----+----+----+----+----+----+----+----------+----------+--------+-----+------+-----+--------+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+----------+---------+-----+-----+--------------+-----+-----+-----+-----+-----------------+
|ProductCD|card1|card2|card3|card4|card5| card6|addr1|addr2|  M1|  M2|  M3|  M4|  M5|  M6|  M7|  M8|  M9|DeviceType|DeviceInfo|   id_12|id_13| id_14|id_15|   id_16|id_17|id_18|id_19|id_20|id_21|id_22|id_23|id_24|id_25|id_26|id_27|id_28|id_29|     id_30|    id_31|id_32|id_33|         id_34|id_35|id_36|id_37|id_38|Emaildomain_match|
+---------+-----+-----+-----+-----+-----+------+-----+-----+----+----+----+----+----+----+----+----+----+----------+----------+--------+-----+------+-----+--------+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+----------+---------+-----+-----+--------------+-----+-----+-----+-----+-----------------+
|

In [0]:
# label_df = train.select(label_col)
# label_df.groupBy("isFraud").count().show()

train.groupBy("isFraud").count().where(col("isFraud") == 0).show()

+-------+------+
|isFraud| count|
+-------+------+
|      0|569877|
+-------+------+



In [0]:
contiguous_df = train.select(*contiguous_cols)
contiguous_df.show(5)


In [0]:
# Return (column, null count) pairs for every column that contains nulls
def null_value_count(df):
    null_columns_counts = []
    for k in df.columns:
        nullRows = df.where(col(k).isNull()).count()
        if nullRows > 0:
            null_columns_counts.append((k, nullRows))
    return null_columns_counts

In [0]:
# null_columns_count_list = null_value_count(train)
# print("null_columns_count_list:", null_columns_count_list)

null_columns_count_list = [('card2', 8933), ('card3', 1565), ('card4', 1577), ('card5', 4259), ('card6', 1571), ('addr1', 65706), ('addr2', 65706), ('dist1', 352271), ('dist2', 552913), ('P_emaildomain', 94456), ('R_emaildomain', 453249), ('D1', 1269), ('D2', 280797), ('D3', 262878), ('D4', 168922), ('D5', 309841), ('D6', 517353), ('D7', 551623), ('D8', 515614), ('D9', 515614), ('D10', 76022), ('D11', 279287), ('D12', 525823), ('D13', 528588), ('D14', 528353), ('D15', 89113), ('M1', 271100), ('M2', 271100), ('M3', 271100), ('M4', 281444), ('M5', 350482), ('M6', 169360), ('M7', 346265), ('M8', 346252), ('M9', 346252), ('V1', 279287), ('V2', 279287), ('V3', 279287), ('V4', 279287), ('V5', 279287), ('V6', 279287), ('V7', 279287), ('V8', 279287), ('V9', 279287), ('V10', 279287), ('V11', 279287), ('V12', 76073), ('V13', 76073), ('V14', 76073), ('V15', 76073), ('V16', 76073), ('V17', 76073), ('V18', 76073), ('V19', 76073), ('V20', 76073), ('V21', 76073), ('V22', 76073), ('V23', 76073), ('V24', 76073), ('V25', 76073), ('V26', 76073), ('V27', 76073), ('V28', 76073), ('V29', 76073), ('V30', 76073), ('V31', 76073), ('V32', 76073), ('V33', 76073), ('V34', 76073), ('V35', 168969), ('V36', 168969), ('V37', 168969), ('V38', 168969), ('V39', 168969), ('V40', 168969), ('V41', 168969), ('V42', 168969), ('V43', 168969), ('V44', 168969), ('V45', 168969), ('V46', 168969), ('V47', 168969), ('V48', 168969), ('V49', 168969), ('V50', 168969), ('V51', 168969), ('V52', 168969), ('V53', 77096), ('V54', 77096), ('V55', 77096), ('V56', 77096), ('V57', 77096), ('V58', 77096), ('V59', 77096), ('V60', 77096), ('V61', 77096), ('V62', 77096), ('V63', 77096), ('V64', 77096), ('V65', 77096), ('V66', 77096), ('V67', 77096), ('V68', 77096), ('V69', 77096), ('V70', 77096), ('V71', 77096), ('V72', 77096), ('V73', 77096), ('V74', 77096), ('V75', 89164), ('V76', 89164), ('V77', 89164), ('V78', 89164), ('V79', 89164), ('V80', 89164), ('V81', 89164), ('V82', 89164), ('V83', 89164), ('V84', 89164), ('V85', 
89164), ('V86', 89164), ('V87', 89164), ('V88', 89164), ('V89', 89164), ('V90', 89164), ('V91', 89164), ('V92', 89164), ('V93', 89164), ('V94', 89164), ('V95', 314), ('V96', 314), ('V97', 314), ('V98', 314), ('V99', 314), ('V100', 314), ('V101', 314), ('V102', 314), ('V103', 314), ('V104', 314), ('V105', 314), ('V106', 314), ('V107', 314), ('V108', 314), ('V109', 314), ('V110', 314), ('V111', 314), ('V112', 314), ('V113', 314), ('V114', 314), ('V115', 314), ('V116', 314), ('V117', 314), ('V118', 314), ('V119', 314), ('V120', 314), ('V121', 314), ('V122', 314), ('V123', 314), ('V124', 314), ('V125', 314), ('V126', 314), ('V127', 314), ('V128', 314), ('V129', 314), ('V130', 314), ('V131', 314), ('V132', 314), ('V133', 314), ('V134', 314), ('V135', 314), ('V136', 314), ('V137', 314), ('V138', 508595), ('V139', 508595), ('V140', 508595), ('V141', 508595), ('V142', 508595), ('V143', 508589), ('V144', 508589), ('V145', 508589), ('V146', 508595), ('V147', 508595), ('V148', 508595), ('V149', 508595), ('V150', 508589), ('V151', 508589), ('V152', 508589), ('V153', 508595), ('V154', 508595), ('V155', 508595), ('V156', 508595), ('V157', 508595), ('V158', 508595), ('V159', 508589), ('V160', 508589), ('V161', 508595), ('V162', 508595), ('V163', 508595), ('V164', 508589), ('V165', 508589), ('V166', 508589), ('V167', 450909), ('V168', 450909), ('V169', 450721), ('V170', 450721), ('V171', 450721), ('V172', 450909), ('V173', 450909), ('V174', 450721), ('V175', 450721), ('V176', 450909), ('V177', 450909), ('V178', 450909), ('V179', 450909), ('V180', 450721), ('V181', 450909), ('V182', 450909), ('V183', 450909), ('V184', 450721), ('V185', 450721), ('V186', 450909), ('V187', 450909), ('V188', 450721), ('V189', 450721), ('V190', 450909), ('V191', 450909), ('V192', 450909), ('V193', 450909), ('V194', 450721), ('V195', 450721), ('V196', 450909), ('V197', 450721), ('V198', 450721), ('V199', 450909), ('V200', 450721), ('V201', 450721), ('V202', 450909), ('V203', 450909), ('V204', 450909), 
('V205', 450909), ('V206', 450909), ('V207', 450909), ('V208', 450721), ('V209', 450721), ('V210', 450721), ('V211', 450909), ('V212', 450909), ('V213', 450909), ('V214', 450909), ('V215', 450909), ('V216', 450909), ('V217', 460110), ('V218', 460110), ('V219', 460110), ('V220', 449124), ('V221', 449124), ('V222', 449124), ('V223', 460110), ('V224', 460110), ('V225', 460110), ('V226', 460110), ('V227', 449124), ('V228', 460110), ('V229', 460110), ('V230', 460110), ('V231', 460110), ('V232', 460110), ('V233', 460110), ('V234', 449124), ('V235', 460110), ('V236', 460110), ('V237', 460110), ('V238', 449124), ('V239', 449124), ('V240', 460110), ('V241', 460110), ('V242', 460110), ('V243', 460110), ('V244', 460110), ('V245', 449124), ('V246', 460110), ('V247', 460110), ('V248', 460110), ('V249', 460110), ('V250', 449124), ('V251', 449124), ('V252', 460110), ('V253', 460110), ('V254', 460110), ('V255', 449124), ('V256', 449124), ('V257', 460110), ('V258', 460110), ('V259', 449124), ('V260', 460110), ('V261', 460110), ('V262', 460110), ('V263', 460110), ('V264', 460110), ('V265', 460110), ('V266', 460110), ('V267', 460110), ('V268', 460110), ('V269', 460110), ('V270', 449124), ('V271', 449124), ('V272', 449124), ('V273', 460110), ('V274', 460110), ('V275', 460110), ('V276', 460110), ('V277', 460110), ('V278', 460110), ('V279', 12), ('V280', 12), ('V281', 1269), ('V282', 1269), ('V283', 1269), ('V284', 12), ('V285', 12), ('V286', 12), ('V287', 12), ('V288', 1269), ('V289', 1269), ('V290', 12), ('V291', 12), ('V292', 12), ('V293', 12), ('V294', 12), ('V295', 12), ('V296', 1269), ('V297', 12), ('V298', 12), ('V299', 12), ('V300', 1269), ('V301', 1269), ('V302', 12), ('V303', 12), ('V304', 12), ('V305', 12), ('V306', 12), ('V307', 12), ('V308', 12), ('V309', 12), ('V310', 12), ('V311', 12), ('V312', 12), ('V313', 1269), ('V314', 1269), ('V315', 1269), ('V316', 12), ('V317', 12), ('V318', 12), ('V319', 12), ('V320', 12), ('V321', 12), ('V322', 508189), ('V323', 508189), 
('V324', 508189), ('V325', 508189), ('V326', 508189), ('V327', 508189), ('V328', 508189), ('V329', 508189), ('V330', 508189), ('V331', 508189), ('V332', 508189), ('V333', 508189), ('V334', 508189), ('V335', 508189), ('V336', 508189), ('V337', 508189), ('V338', 508189), ('V339', 508189), ('id_01', 446307), ('id_02', 449668), ('id_03', 524216), ('id_04', 524216), ('id_05', 453675), ('id_06', 453675), ('id_07', 585385), ('id_08', 585385), ('id_09', 515614), ('id_10', 515614), ('id_11', 449562), ('id_12', 446307), ('id_13', 463220), ('id_14', 510496), ('id_15', 449555), ('id_16', 461200), ('id_17', 451171), ('id_18', 545427), ('id_19', 451222), ('id_20', 451279), ('id_21', 585381), ('id_22', 585371), ('id_23', 585371), ('id_24', 585793), ('id_25', 585408), ('id_26', 585377), ('id_27', 585371), ('id_28', 449562), ('id_29', 449562), ('id_30', 512975), ('id_31', 450258), ('id_32', 512954), ('id_33', 517251), ('id_34', 512735), ('id_35', 449555), ('id_36', 449555), ('id_37', 449555), ('id_38', 449555), ('DeviceType', 449730), ('DeviceInfo', 471874)]

In [0]:
for column, null_count in null_columns_count_list:
    if null_count > 0.8 * total_count:
        if column in categorical_cols:
            categorical_cols.remove(column)
        if column in contiguous_cols:
            contiguous_cols.remove(column) 
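
To make the 80% cutoff concrete, the sketch below applies the same comparison to a few entries copied from `null_columns_count_list`, using the training row count of 590540 reported earlier. The variable names here are illustrative only.

```python
total_count = 590540  # number of training rows reported earlier
threshold = 0.8 * total_count

# A few (column, null count) pairs copied from the list above
sample_null_counts = [
    ("dist2", 552913),
    ("D7", 551623),
    ("DeviceInfo", 471874),
    ("addr1", 65706),
]

dropped = [name for name, nulls in sample_null_counts if nulls > threshold]
kept = [name for name, nulls in sample_null_counts if nulls <= threshold]
print("dropped:", dropped)
print("kept:", kept)
```

`DeviceInfo` falls just under the cutoff, so it stays in the feature set even though about 80% of its values are missing.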

In [0]:
"""
Fill NA
titanic_df = titanic_df.withColumn("Age", when((titanic_df["Initial"] == "Miss") & (titanic_df["Age"].isNull()), 22).otherwise(titanic_df["Age"]))
titanic_df = titanic_df.withColumn("Age", when((titanic_df["Initial"] == "Other") & (titanic_df["Age"].isNull()), 46).otherwise(titanic_df["Age"]))
titanic_df = titanic_df.withColumn("Age", when((titanic_df["Initial"] == "Master") & (titanic_df["Age"].isNull()), 5).otherwise(titanic_df["Age"]))
titanic_df = titanic_df.withColumn("Age", when((titanic_df["Initial"] == "Mr") & (titanic_df["Age"].isNull()), 33).otherwise(titanic_df["Age"]))
titanic_df = titanic_df.withColumn("Age", when((titanic_df["Initial"] == "Mrs") & (titanic_df["Age"].isNull()), 36).otherwise(titanic_df["Age"]))

titanic_df = titanic_df.na.fill({"Embarked" : 'S'})
"""

# categorical_null_columns = [x[0] for x in categorical_null_columns_count_list]


In [0]:
from pyspark.sql.functions import mean as _mean

def fill_na(df, categorical_feats, contiguous_feats):
    for f in df.columns:
        if f in categorical_feats:
            # Sentinel value marks a missing category
            df = df.na.fill({f: -999})
        elif f in contiguous_feats:
            # Impute missing numeric values with the column mean
            mean = df.select(_mean(col(f)).alias('mean')).collect()[0]['mean']
            if mean is not None:  # mean is None when every value is null
                df = df.na.fill({f: mean})
        elif f != label_col:
            # Drop columns that are neither features nor the label
            df = df.drop(f)

    return df

In [0]:
train = fill_na(train, categorical_cols, contiguous_cols)

In [0]:
train.select("TransactionDT", "TransactionAmt", "ProductCD", "DeviceType", "DeviceInfo", "Emaildomain_match").show()

In [0]:
train.select("card1", "card2", "card3", "card4", "card5", "card6", "addr1", "addr2", "dist1").show()

In [0]:
test = test.withColumn("Emaildomain_match", when(test["P_emaildomain"].isNull() | test["R_emaildomain"].isNull(), 2). 
                         when(test["P_emaildomain"] == test["R_emaildomain"], 1).otherwise(0))

In [0]:
test = fill_na(test, categorical_cols, contiguous_cols)

In [0]:
# categorical_cols

In [0]:
M_feats = ["M{:d}".format(i) for i in range(1, 10)]
train.select(*M_feats).show()

In [0]:
id_feats = [
'id_12',
'id_13',
'id_15',
'id_16',
'id_17',
'id_19',
'id_20',
'id_28',
'id_29',
'id_31',
'id_35',
'id_36',
'id_37',
'id_38'
]
train.select(*id_feats).show()

In [0]:
from pyspark.ml import Pipeline

feat_stages = []
for f in categorical_cols:
    # handleInvalid="keep" maps labels unseen during fitting to an extra index
    indexer = StringIndexer(inputCol=f, outputCol=f + "_numeric", handleInvalid="keep")
    feat_stages.append(indexer)

pipeline = Pipeline(stages=feat_stages)
pipeline_fitted = pipeline.fit(train)

indexed_train = pipeline_fitted.transform(train)
indexed_test = pipeline_fitted.transform(test)

indexed_categorical_cols = [x + "_numeric" for x in categorical_cols]
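
With its default ordering (`frequencyDesc`), `StringIndexer` assigns index 0 to the most frequent label, 1 to the next, and so on. The plain-Python sketch below reproduces that ordering for `ProductCD`, using the counts from the `groupBy("ProductCD")` output earlier. It is an illustration of the idea, not Spark's implementation, and it ignores how Spark breaks ties.

```python
# ProductCD counts copied from the groupBy output above
product_cd_counts = {"C": 68519, "W": 439670, "S": 11628, "R": 37699, "H": 33024}

# Sort labels by descending frequency and number them from 0
ordered_labels = sorted(product_cd_counts, key=product_cd_counts.get, reverse=True)
label_to_index = {label: float(i) for i, label in enumerate(ordered_labels)}
print(label_to_index)
```

The resulting indices carry frequency rank rather than any meaningful order, which is worth keeping in mind when they are fed directly to a linear model.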

In [0]:
indexed_train.printSchema()

In [0]:
indexed_train.select("DeviceType_numeric").show()

In [0]:
feature = VectorAssembler(inputCols=indexed_categorical_cols + contiguous_cols, outputCol="features")
train_data = feature.transform(indexed_train)
test_data = feature.transform(indexed_test)

### Fit the ML model

In [0]:
from pyspark.ml.classification import LogisticRegression

# The Kaggle test set has no isFraud column, so hold out part of the
# labeled training data for evaluation.
train_split, valid_split = train_data.randomSplit([0.8, 0.2], seed=42)

lr = LogisticRegression(labelCol="isFraud", featuresCol="features")
# Train the model
lrModel = lr.fit(train_split)

lr_prediction = lrModel.transform(valid_split)
lr_prediction.select("prediction", "isFraud", "features").show()
evaluator = MulticlassClassificationEvaluator(labelCol="isFraud", predictionCol="prediction", metricName="accuracy")

In [0]:
lr_accuracy = evaluator.evaluate(lr_prediction)
print("Validation accuracy of LogisticRegression = %g" % (lr_accuracy))
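
Accuracy can be hard to interpret on a dataset this imbalanced. As a reference point, the sketch below computes the accuracy a model would reach on the training data by always predicting 0, using the counts reported earlier in this notebook (590540 rows, 569877 of them with `isFraud == 0`). A held-out split would likely show a similar, though not identical, ratio.

```python
total_rows = 590540      # training rows
non_fraud_rows = 569877  # rows with isFraud == 0

fraud_rows = total_rows - non_fraud_rows
baseline_accuracy = non_fraud_rows / total_rows

print(f"fraud rows: {fraud_rows}, fraud rate: {fraud_rows / total_rows:.4f}")
print(f"accuracy of always predicting 0: {baseline_accuracy:.4f}")
```

Any reported accuracy is best read against this baseline; a ranking metric such as area under the ROC curve may be more informative for comparing models.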