# Continuous Applications with Structured Streaming Python APIs in Apache Spark: Predicting Credit Card Fraud

We have historical credit card transaction data in which some transactions have been identified as fraudulent. We want to train a model on this historical data that can flag potentially fraudulent transactions arriving as a live stream. We then want to deploy that model in a data pipeline that continuously scores the stream of transaction data to identify potential fraud hotspots.

This dataset has four columns. Besides `time`, which we'll only use to split the data, there are three columns we'll be using:

**pcaVector:** The PCA transformation of the raw transaction data, stored as a 28-dimensional vector. Principal component analysis (PCA) reduces the dimensionality of a dataset made up of many correlated variables by projecting it onto a smaller set of uncorrelated components that capture most of the variance. Put simply, it is a way of summarizing the data.
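To make the idea concrete, here is a minimal pure-Python sketch of how a single principal component can be found: center the data, build the covariance matrix, and approximate its top eigenvector with power iteration. It is an illustration under simplifying assumptions (one component, a fixed iteration count, a fixed starting vector), the helper names are hypothetical, and it is not necessarily how the pcaVector column in this dataset was produced.

```python
import math

def center(rows):
    """Subtract each column's mean so the data is centered at the origin."""
    n, d = len(rows), len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(d)]
    return [[r[j] - means[j] for j in range(d)] for r in rows]

def covariance(centered):
    """Sample covariance matrix of already-centered rows."""
    n, d = len(centered), len(centered[0])
    return [[sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)]
            for i in range(d)]

def first_component(cov, iterations=100):
    """Approximate the top eigenvector of `cov` with power iteration.

    Assumes the fixed starting vector is not orthogonal to that eigenvector.
    """
    d = len(cov)
    v = [1.0] * d
    for _ in range(iterations):
        w = [sum(cov[i][j] * v[j] for j in range(d)) for i in range(d)]
        norm = math.sqrt(sum(x * x for x in w))
        v = [x / norm for x in w]
    return v

def project(centered, component):
    """Score of each row along the component: a 1-D summary of the data."""
    return [sum(a * b for a, b in zip(r, component)) for r in centered]

# Two perfectly correlated variables collapse onto a single direction.
rows = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]]
centered = center(rows)
pc1 = first_component(covariance(centered))
scores = project(centered, pc1)
print([round(c, 4) for c in pc1])  # [0.4472, 0.8944]
```

Keeping the top k components instead of one gives a k-dimensional summary such as the 28-element pcaVector.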

**amountRange:** This column is a value between 0 and 7 and tells us the approximate amount of a transaction. The values correspond to 0-1, 1-5, 5-10, 10-20, 20-50, 50-100, 100-200, and 200+ in dollars.
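The bucketing behind amountRange can be sketched as follows. This is an illustration, not the code that produced the dataset: the helper name is hypothetical, and we assume each boundary value belongs to the higher range (for example, exactly $5.00 falls in range 2). On a Spark DataFrame, `pyspark.ml.feature.Bucketizer` performs this kind of bucketing.

```python
import bisect

# Lower bounds (in dollars) of ranges 1..7; anything below $1 is range 0.
# Assumption: a boundary value belongs to the range that starts at it.
RANGE_LOWER_BOUNDS = [1, 5, 10, 20, 50, 100, 200]

def amount_to_range(amount):
    """Map a dollar amount to the 0-7 amountRange bucket (hypothetical helper)."""
    if amount < 0:
        raise ValueError("amount must be non-negative")
    # bisect_right counts how many lower bounds are <= amount.
    return bisect.bisect_right(RANGE_LOWER_BOUNDS, amount)

print([amount_to_range(a) for a in [0.5, 1, 4.99, 5, 150, 200, 1250]])
# [0, 1, 1, 2, 6, 7, 7]
```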

**label:** 1 if the transaction was fraudulent, 0 otherwise.

We want to build a model that predicts the label from the pcaVector and amountRange data. We'll do this with an ML pipeline built around three main stages (a `VectorSizeHint` stage is added later for streaming):

1. A **OneHotEncoder** to build a vector from our _amountRange_ column. One-hot encoding maps each category index to a binary vector with a single 1, a form that ML algorithms can consume directly.
2. A **VectorAssembler** to merge our _pcaVector_ and _amountRange_ vectors into a single features vector. VectorAssembler is a transformer that combines a given list of columns into a single vector column. It is useful for combining raw features and features generated by different feature transformers into one feature vector, in order to train ML models such as logistic regression and decision trees.
3. A **GBTClassifier** (gradient-boosted tree classifier) to serve as our Estimator. It is a learning algorithm for classification that supports binary labels, as well as both continuous and categorical features.
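The two feature stages can be sketched in plain Python to show the shape of the resulting features vector. Assumptions: the encoder sees 8 categories (amountRange 0 through 7) and, like Spark's one-hot encoder with its default `dropLast=True`, encodes the last category as all zeros. Spark produces sparse vectors; dense lists are used here for readability, and the helper names are hypothetical.

```python
def one_hot(category, num_categories, drop_last=True):
    """Encode an integer category as a 0/1 vector (last category dropped by default)."""
    if not 0 <= category < num_categories:
        raise ValueError(f"category {category} outside [0, {num_categories})")
    size = num_categories - 1 if drop_last else num_categories
    vec = [0.0] * size
    if category < size:  # the dropped category stays all zeros
        vec[category] = 1.0
    return vec

def assemble(*columns):
    """Concatenate vector columns in the given order, like VectorAssembler's inputCols."""
    features = []
    for column in columns:
        features.extend(column)
    return features

amount_vect = one_hot(2, 8)   # amountRange 2 ($5-10)
pca_vector = [0.0] * 28       # placeholder for a 28-dimensional pcaVector
features = assemble(amount_vect, pca_vector)
print(len(amount_vect), len(features))  # 7 35
```

Under these assumptions the assembled vector has 7 + 28 = 35 entries, ordered as listed in `inputCols`.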

## Set up input and output paths

```python
input_data = "/databricks-datasets/credit-card-fraud/data"
output_test_parquet_data = "/tmp/pydata/credit-card-frauld-test-data"
```

```python
# Take a look at a sample of the historical dataset that will be used
data = spark.read.parquet(input_data)
display(data)
```

```text
time,amountRange,label,pcaVector
52972,2,0,"List(1, 28, List(), List(-0.775460885847953, 0.709595287942808, 1.61012185737375, 1.23179224889532, 0.316177505485586, 0.11441004444483, 0.410964424898908, 0.0482463120477359, 0.0160424613851936, 0.590293137181086, 1.06963744560403, -0.0970015458011077, -2.05395393083259, 0.296433513650427, -0.242125977744976, -1.10398301441835, 0.438081447824287, -0.0900982397359751, 0.972524799861282, -0.0095254201744326, -0.0313171825598146, 0.289930092783562, -0.215881463615024, 0.196378667308146, -0.267693005539635, -0.298693369594354, 0.0111946468450594, -0.0233199829808697))"
41768,6,0,"List(1, 28, List(), List(0.873553941009571, -1.37751021440628, 1.07203129318284, 0.898335056199892, -1.37767734218766, 1.1341181119011, -1.04950509839334, 0.303593246528786, 0.391174678156379, 0.30645706157889, -1.74332056201875, 0.496823246907654, 0.582594013119008, -1.04298111816545, -1.09304763035534, -2.15804820205749, 0.552883029565036, 0.941659626723935, -0.878165967036597, -0.188112036138337, -0.206446493695106, -0.0208600286360115, -0.294525977933785, -0.419986418387049, 0.530461547583973, -0.0953143171911623, 0.102469833198419, 0.0613929441495575))"
40769,7,0,"List(1, 28, List(), List(0.890896976094619, -0.528186780186241, -0.678653636768693, 0.168700533941972, -0.141127297738687, -1.03734480197013, 0.786720278659216, -0.463098544159764, -0.231256401727913, -0.210526809038417, -0.579607229088567, 0.320748654905059, 0.739512636179789, 0.385612185997922, 0.760072091416543, 0.0562860882084142, -0.365932329362014, -0.553207062295103, 0.225507413957292, 0.469011543311791, 0.0724908019085573, -0.283582883371654, -0.367601705300131, -0.00374237527132833, 0.548135925665169, 1.07383485281209, -0.156020333403545, 0.0292101280471998))"
40682,3,0,"List(1, 28, List(), List(-0.572954136732452, 0.458245881426411, 0.239597590908972, -1.56195349253073, 2.68806291830668, 3.5295007407309, 0.386101316501381, 0.773473570610461, 0.0616701240466451, -0.425875341967141, -0.594523365105178, -0.149708184135855, -0.36862982219484, -0.285347030554428, -0.600321621523469, 0.106861235199372, -0.793362789636541, -0.20800446692985, 0.0823295987619518, 0.22946073939933, -0.290700364845232, -0.627816849509753, -0.206939602558071, 1.02190677932925, 0.191260521903609, 0.219702773949549, 0.122007212047274, -0.0629967134534869))"
50032,7,0,"List(1, 28, List(), List(-2.05305889575943, 0.504530324242776, -0.111155667869986, -1.04073848891865, -2.68458996781652, 0.0484387282726955, 2.15221951164238, 0.497990320120156, -0.802817358608386, -1.26277208511232, 1.10103363698183, 0.617549672582074, -0.478061657274302, 0.968828948292094, -0.337550992626504, 1.02051282591686, -0.476523815268263, 0.0486324920932011, -0.389469468851195, -0.142086381219838, -0.0511834359345369, -0.370120210625707, 0.583369828558284, 0.556770951700458, -0.250298033074819, 0.631668074606364, 0.103109972579581, -0.186395682097962))"
53637,7,0,"List(1, 28, List(), List(0.593549262043115, -1.0453650732411, -0.610745896437447, 0.277111810627197, -0.583285921921287, -0.995132318101401, 0.732746548626124, -0.283726190338305, -0.289380804099332, -0.0467347939072331, 1.28447592477553, 0.303680861144012, -1.05582055096572, 0.964711022411289, 0.217875962525365, 0.199900565624148, -0.338876909890541, 0.00498898793577377, 0.302888768406764, 0.588759417414082, 0.206397605941981, -0.273923925471225, -0.398510675654682, 0.401092217682281, 0.35465218080797, 1.01903912167697, -0.191281030233503, 0.0402095546132128))"
39160,6,0,"List(1, 28, List(), List(-0.923235255508135, -0.445387319467476, 1.42286311943015, -1.40924129397362, 0.498512239591899, 0.746435856726508, 1.08478835512332, -0.3917344097041, -0.827226697036681, 0.291464719918271, -1.75828968368404, -0.593144215020114, 1.0819478139191, -0.891649065339908, 0.172534238233272, -0.52787562413733, -1.5225349679792, 1.41081928493877, -2.7834715943276, -0.49615003392596, -0.299173667118795, -0.0188754189035958, 0.0207146225103749, -1.38135104828589, -0.0112325452959756, -0.612703415055321, -0.219640107829619, -0.330579950319831))"
52811,1,0,"List(1, 28, List(), List(-2.07310249785767, -1.36550893823072, 1.85414834269174, -0.231115522038623, 2.25914600510552, -0.977836997718314, -0.896331127320243, 0.545240778428714, -0.439603636128481, -0.497722090720462, 1.53315939546202, 0.65189935041798, -0.712463182969143, 0.568422112962742, 0.410804574936842, 0.11164056725874, -0.272601612926559, -0.724889765306901, -1.00499248948732, 0.322088208116428, 0.0458272371473575, -0.57053267571224, 0.434595092959784, -0.326171106205633, -0.446507323034388, 0.0221653443257277, -0.00463022663145894, 0.13516570293773))"
44704,4,0,"List(1, 28, List(), List(-0.631938626188563, 0.836257331608831, 1.98163791119715, 0.503599364178845, 0.0723555095019509, 0.14612224077591, 0.518540203373895, 0.197419049440126, -0.836070863705218, -0.346767091029642, 1.56581374831052, 0.94489412120318, 0.438374643048311, 0.223138718974657, 0.382418693355949, 0.0312397525866345, -0.472760782825448, 0.286902696689211, -0.566480401791849, 0.0195774486251498, 0.382765302148085, 1.08912603558255, -0.192762383672016, 0.2451108760633, -0.0178318033591119, -0.273238226198606, 0.139159044347052, 0.119843903937563))"
53397,3,0,"List(1, 28, List(), List(0.0850686577522656, 0.253272476584915, 0.269375279060897, -1.84929178865663, 0.295126856250259, -1.27094140950399, 0.848296016535243, -0.482323052194861, -1.06490514841405, 0.630489515857669, -1.00967251605995, -1.37479973948769, -0.511257460697217, -0.156859944074501, -0.282808097186376, 0.97140843897321, -0.187555100335406, -1.29966214715816, 0.132488576992436, 0.159666244875025, 0.239007750667715, 0.872221450881813, -0.129988349004438, -0.106777625070603, -0.615862507650666, -0.41912494176333, 0.261702215960486, 0.0487670238871042))"
```


```python
data.count()
```

We're using PySpark, so we import the appropriate classes.

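```python
# Spark 2.3/2.4 API. In Spark 3.0+, OneHotEncoderEstimator was renamed to OneHotEncoder
# (with the same inputCols/outputCols parameters).
from pyspark.ml.feature import OneHotEncoderEstimator, VectorAssembler, VectorSizeHint
from pyspark.ml.classification import GBTClassifier

from pyspark.sql.types import *
# Note: this import shadows Python's built-in `sum` for the rest of the notebook.
from pyspark.sql.functions import count, rand, collect_list, explode, struct, when, sum
```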

Because we intend to use this model in a streaming context, there are a few things we should be aware of.

`VectorAssembler` has a limitation in a streaming context: it can only assemble Vector columns whose size is known from the column metadata, because the contents of a streaming DataFrame are not available for inspection until the stream starts. To address this, we can explicitly declare the size of the pcaVector column so that we'll be able to use our pipeline with Structured Streaming. To do this, we'll use the `VectorSizeHint` transformer.

```python
oneHot = OneHotEncoderEstimator(inputCols=["amountRange"], outputCols=["amountVect"])

vectorAssembler = VectorAssembler(inputCols=["amountVect", "pcaVector"], outputCol="features")

estimator = GBTClassifier(labelCol="label", featuresCol="features")
```

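```python
from pyspark.ml.feature import VectorSizeHint

# Declare that every pcaVector has 28 elements so VectorAssembler can size its output in a stream
vectorSizeHint = VectorSizeHint(inputCol="pcaVector", size=28)
```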
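VectorSizeHint also decides what happens when a row's vector does not match the declared size, through its `handleInvalid` parameter (`"error"` by default, `"skip"`, or `"optimistic"`). A plain-Python mimic of those three behaviors is sketched below; it models vectors as lists, the function name is hypothetical, and it is not Spark's implementation.

```python
def apply_size_hint(vectors, size, handle_invalid="error"):
    """Mimic VectorSizeHint's handleInvalid modes on a list of vectors (lists)."""
    if handle_invalid not in ("error", "skip", "optimistic"):
        raise ValueError(f"unknown handleInvalid mode: {handle_invalid}")
    if handle_invalid == "optimistic":
        return list(vectors)  # trust the hint and keep every row unchecked
    kept = []
    for v in vectors:
        if v is not None and len(v) == size:
            kept.append(v)
        elif handle_invalid == "error":
            raise ValueError(f"expected a vector of size {size}")
        # handle_invalid == "skip": silently drop the invalid or null row
    return kept

rows = [[0.1] * 28, [0.1] * 27, None]
print(len(apply_size_hint(rows, 28, "skip")))        # 1
print(len(apply_size_hint(rows, 28, "optimistic")))  # 3
```

`"optimistic"` skips the check entirely, so if the hint is wrong the column can end up with size metadata that does not match its contents.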

### Build an ML pipeline and split the data

```python
from pyspark.ml import Pipeline
from pyspark.sql.functions import col

# VectorSizeHint must run before VectorAssembler so the assembler knows pcaVector's size
pipeline = Pipeline(stages=[oneHot, vectorSizeHint, vectorAssembler, estimator])

# Split the data into training and test datasets using the last digit of `time`
# (about 80/20 if that digit is roughly uniform). We will save the test dataset for later.
train = data.filter(col("time") % 10 < 8)
test = data.filter(col("time") % 10 >= 8)

# Write the test set as 20 partitions (one Parquet file each) so we can read it back as a stream
(test.repartition(20).write
  .mode("overwrite")
  .parquet(output_test_parquet_data))
```
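The modulo filter gives a deterministic split: the same row always lands in the same set, so the test files are reproducible across runs. A plain-Python sketch of the same rule, assuming the last digit of `time` is roughly uniform so the split comes out near 80/20 (the helper name is hypothetical):

```python
def split_by_time(times, train_digits=8):
    """Rows whose time ends in 0..train_digits-1 go to train; the rest are held out."""
    train = [t for t in times if t % 10 < train_digits]
    test = [t for t in times if t % 10 >= train_digits]
    return train, test

train_times, test_times = split_by_time(range(1000))
print(len(train_times), len(test_times))  # 800 200
```

`DataFrame.randomSplit([0.8, 0.2], seed=...)` is the usual random alternative.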

```python
train.count()
```

```python
test.count()
```

## Fit the model with our training data

```python
pipelineModel = pipeline.fit(train)
```
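`Pipeline.fit` runs the stages in order: a Transformer stage is applied as-is, while an Estimator stage is first fit on the output of the earlier stages and replaced by the model it produces. The result is a `PipelineModel` that chains all the fitted stages. The toy classes below sketch that control flow in plain Python; they are hypothetical stand-ins, not Spark's classes, and they operate on lists of dicts instead of DataFrames.

```python
class Transformer:
    def transform(self, rows):
        raise NotImplementedError

class Estimator:
    def fit(self, rows):
        raise NotImplementedError

class ToyPipelineModel(Transformer):
    def __init__(self, stages):
        self.stages = stages

    def transform(self, rows):
        for stage in self.stages:  # every stage is now a fitted Transformer
            rows = stage.transform(rows)
        return rows

class ToyPipeline(Estimator):
    def __init__(self, stages):
        self.stages = stages

    def fit(self, rows):
        fitted = []
        for stage in self.stages:
            # Estimators are fit on the output of the preceding stages.
            model = stage.fit(rows) if isinstance(stage, Estimator) else stage
            fitted.append(model)
            rows = model.transform(rows)  # feed the next stage
        return ToyPipelineModel(fitted)

class AddSquare(Transformer):
    """Stateless transformer: adds x_sq = x * x."""
    def transform(self, rows):
        return [{**r, "x_sq": r["x"] ** 2} for r in rows]

class ThresholdModel(Transformer):
    def __init__(self, threshold):
        self.threshold = threshold

    def transform(self, rows):
        return [{**r, "prediction": 1.0 if r["x_sq"] > self.threshold else 0.0}
                for r in rows]

class MeanThreshold(Estimator):
    """Estimator: learns the mean of x_sq and predicts 1.0 above it."""
    def fit(self, rows):
        return ThresholdModel(sum(r["x_sq"] for r in rows) / len(rows))

model = ToyPipeline([AddSquare(), MeanThreshold()]).fit([{"x": 1}, {"x": 2}, {"x": 3}])
print([r["prediction"] for r in model.transform([{"x": 1}, {"x": 3}])])  # [0.0, 1.0]
```

As with `pipelineModel` here, the fitted model can then be applied to new data, such as the stream, without refitting.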

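Since we don't have a Kafka cluster available right now, we simulate a stream by reading our test data from files. The rest of the pipeline is unaffected: in Structured Streaming, a file source and a Kafka source both produce a streaming DataFrame, so once the Kafka message values are parsed into the same columns, the transformations and queries that follow would be the same.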

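First, we need to define the schema. By default, Structured Streaming file sources require the schema to be specified up front rather than inferring it from the files.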

```python
from pyspark.sql.types import *
from pyspark.ml.linalg import VectorUDT

schema = StructType([StructField("time", IntegerType(), True),
                     StructField("amountRange", IntegerType(), True),
                     StructField("label", IntegerType(), True),
                     StructField("pcaVector", VectorUDT(), True)])
```

## Read streaming test data

Read the test files as a stream, one file per trigger (`maxFilesPerTrigger` set to 1), to simulate a Kafka stream.

```python
streamingData = (spark.readStream
                 .schema(schema)
                 .option("maxFilesPerTrigger", 1)  # process one file per micro-batch
                 .parquet(output_test_parquet_data))  # our test data
```
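With `maxFilesPerTrigger` set to 1, each micro-batch picks up at most one new file, so the test files written earlier are replayed as a sequence of small batches. The grouping can be sketched as below; processing files in name order and the file names themselves are simplifying assumptions, and the function name is hypothetical.

```python
def plan_micro_batches(files, max_files_per_trigger):
    """Split a backlog of new files into micro-batches of at most N files each."""
    if max_files_per_trigger < 1:
        raise ValueError("max_files_per_trigger must be at least 1")
    backlog = sorted(files)  # simplification: process in name order
    return [backlog[i:i + max_files_per_trigger]
            for i in range(0, len(backlog), max_files_per_trigger)]

files = [f"part-{i:05d}.parquet" for i in range(20)]  # hypothetical file names
print(len(plan_micro_batches(files, 1)))               # 20
print([len(b) for b in plan_micro_batches(files, 3)])  # [3, 3, 3, 3, 3, 3, 2]
```

A larger value means fewer, bigger micro-batches; 1 makes the replay as gradual as possible.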

Transform the streaming DataFrame with the fitted pipeline model, then query the result using the PySpark DataFrame API.

```python
# Brings udf, col, when, sum, etc. into scope (shadowing Python's built-in sum)
from pyspark.sql.functions import *

stream = pipelineModel.transform(streamingData)
```

## Do aggregations using PySpark DataFrame APIs

1. _groupBy_("label", "prediction") and _count()_
2. _sort_("label", "prediction")

Finally, _display()_ the counts, which update in real time as the stream is scored.

```python
streamPredictions = (pipelineModel.transform(streamingData)
                     .groupBy("label", "prediction")
                     .count()
                     .sort("label", "prediction"))
```

```python
display(streamPredictions)
```

| label | prediction | count |
|---|---|---|
| 0 | 0.0 | 57131 |
| 0 | 1.0 | 9 |
| 1 | 0.0 | 25 |
| 1 | 1.0 | 71 |
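Because this is a streaming aggregation, Spark keeps the per-group counts as state and updates them with every micro-batch, similar to the complete output mode, which re-emits the whole result table after each trigger. A plain-Python sketch of that running count with `collections.Counter` (the helper names and sample batches are made up for illustration):

```python
from collections import Counter

def update_counts(running, batch):
    """Fold one micro-batch of (label, prediction) pairs into the running counts."""
    running.update((int(label), float(prediction)) for label, prediction in batch)
    return running

def result_table(running):
    """Rows of (label, prediction, count), sorted like .sort("label", "prediction")."""
    return [(label, prediction, n) for (label, prediction), n in sorted(running.items())]

counts = Counter()
batches = [[(0, 0.0), (0, 0.0), (1, 1.0)],   # micro-batch 1
           [(1, 0.0), (0, 1.0), (1, 1.0)]]   # micro-batch 2
for batch in batches:
    update_counts(counts, batch)
    print(result_table(counts))  # the full, updated table after each trigger
```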


### Compute the Precision, Recall and F1 score
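The metrics use the standard definitions, written in terms of the confusion-matrix counts TP, FP, TN, and FN:

$$
\text{Precision} = \frac{TP}{TP + FP}, \qquad
\text{Recall} = \frac{TP}{TP + FN}, \qquad
F_1 = \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} = \frac{2\,TP}{2\,TP + FP + FN}
$$

Precision and recall share the form $x / (x + y)$, which is why a single helper can compute both.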

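```python
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

# Define UDFs for the evaluation metrics.
# Precision = TP / (TP + FP) and recall = TP / (TP + FN) share the form x / (x + y).
# Return None (null) when the denominator is 0, e.g. before any positive prediction
# or fraud case has arrived, instead of failing with ZeroDivisionError.
def get_precision_recall(x, y):
  if x is None or y is None or x + y == 0:
    return None
  return float(x) / (x + y)

# F1 score: the harmonic mean of precision and recall.
def get_f1_score(x, y):
  if x is None or y is None or x + y == 0:
    return None
  return 2 * x * y / (x + y)

get_precision_recall_udf = udf(get_precision_recall, FloatType())
get_f1_score_udf = udf(get_f1_score, FloatType())
```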

```python
# One global aggregation counts every confusion-matrix cell; the UDFs then derive the scores
streamingMetrics = (pipelineModel.transform(streamingData)
    .select('label', 'prediction')
    .groupBy()
    .agg(
        sum(when((col('prediction') == 1) & (col('label') == 1), 1).otherwise(0)).alias('TP'),
        sum(when((col('prediction') == 1) & (col('label') == 0), 1).otherwise(0)).alias('FP'),
        sum(when((col('prediction') == 0) & (col('label') == 0), 1).otherwise(0)).alias('TN'),
        sum(when((col('prediction') == 0) & (col('label') == 1), 1).otherwise(0)).alias('FN'))
    .withColumn('Precision', get_precision_recall_udf(col('TP'), col('FP')))
    .withColumn('Recall', get_precision_recall_udf(col('TP'), col('FN')))
    .withColumn('F1_Score', get_f1_score_udf(col('Precision'), col('Recall'))))
```

```python
display(streamingMetrics)
```

| TP | FP | TN | FN | Precision | Recall | F1_Score |
|---|---|---|---|---|---|---|
| 71 | 9 | 57131 | 25 | 0.8875 | 0.7395833 | 0.8068182 |
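As a sanity check, the scores in the table follow directly from the four counts. The short computation below reproduces them and also shows why accuracy is not a useful headline metric here: with so many true negatives it is close to 1, even though about a quarter of the fraud cases (25 of 96) are missed. The function name is hypothetical.

```python
def scores(tp, fp, tn, fn):
    """Precision, recall, F1 and accuracy from confusion-matrix counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    accuracy = (tp + tn) / (tp + fp + tn + fn)
    return precision, recall, f1, accuracy

precision, recall, f1, accuracy = scores(tp=71, fp=9, tn=57131, fn=25)
print(f"{precision:.4f} {recall:.4f} {f1:.4f} {accuracy:.4f}")  # 0.8875 0.7396 0.8068 0.9994
```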


```python
# Clean up: recursively delete the test files written earlier
dbutils.fs.rm(output_test_parquet_data, True)
```