In [1]:
import demolib
from demolib import spark, cfg, schema
from demolib.streams import *
from demolib.mongo import *
from demolib.udf import *
from demolib.lib import *
from time import sleep

# `to_timestamp` and `round` are imported explicitly (after the star imports, so these
# bindings win) instead of relying on demolib to re-export them; `round` is aliased so
# that this import does not itself shadow the Python builtin.
from pyspark.sql.functions import from_json, col, expr, lit, broadcast, to_timestamp
from pyspark.sql.functions import round as spark_round
from pyspark.ml import PipelineModel
from pyspark.ml.classification import RandomForestClassificationModel

In [2]:
# Static customer table read from MongoDB, extended with an `age` column produced by
# `udf_age()` (presumably from the demolib.udf star import). It is cached so the
# stream-static join below does not re-read MongoDB for every micro-batch; as a
# consequence, later changes to the collection are likely not picked up by this run.
customer_df = (
    mongo_read(cfg.db.customer)
    .withColumn('age', udf_age())
    .cache()
)

In [3]:
# Subscribe to the configured Kafka topic and parse each message value as JSON using
# the event_transaction schema; the parsed struct is stored in the `transaction` column.
raw_stream = (
    spark.readStream.format("kafka")
    .option("kafka.bootstrap.servers", cfg.kafka.config['bootstrap.servers'])
    .option("subscribe", cfg.kafka.topic)
    # A bare "group.id" is not a Kafka source option (Kafka consumer properties need
    # the "kafka." prefix), so it was silently ignored. A fixed "kafka.group.id" would
    # be shared by both sink queries reading this source, and Spark warns that sources
    # sharing a group id are likely to interfere. "groupIdPrefix" (Spark 3.0+) keeps
    # Spark's unique per-query group ids while naming them after the configured group.
    .option("groupIdPrefix", cfg.kafka.groupid)
    .load()
    .withColumn('transaction', from_json(col('value').cast("string"), schema.event_transaction.schema))
)


In [4]:
# Flatten the parsed JSON struct into top-level columns (keeping Kafka's partition and
# offset), cast the numeric fields to double, and drop the `first`/`last` columns
# (presumably cardholder names; not selected by any later cell).
# `lit()` is not needed around a Column: it returns Columns unchanged.
transaction_stream = (
    raw_stream
    .selectExpr("transaction.*", "partition", "offset")
    .withColumn("amt", col("amt").cast("double"))
    .withColumn("merch_lat", col("merch_lat").cast("double"))
    .withColumn("merch_long", col("merch_long").cast("double"))
    .drop("first", "last")
)


In [5]:
# Allow Spark to auto-broadcast relations up to 50 MiB in joins.
# Note: the customer join below already carries an explicit broadcast() hint, which
# Spark applies irrespective of this size threshold, so this setting mainly affects
# joins without a hint.
AUTO_BROADCAST_JOIN_THRESHOLD_BYTES = 50 * 1024 * 1024  # == 52428800
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", AUTO_BROADCAST_JOIN_THRESHOLD_BYTES)

In [6]:
# Enrich each transaction with its customer's record and derive the model inputs.
# The join is inner (the default), so transactions whose cc_num has no matching
# customer are dropped. broadcast() ships the cached customer table to the executors
# instead of shuffling the stream.
processed_transactions_df = (
    transaction_stream
    .join(broadcast(customer_df), "cc_num")
    # Distance between (lat, long), presumably the customer's location from
    # customer_df, and the merchant, computed by the project UDF udf_dist (units are
    # not visible here) and rounded to 2 decimals with Spark's round.
    .withColumn(
        "distance",
        spark_round(udf_dist(col("lat"), col("long"), col("merch_lat"), col("merch_long")), 2),
    )
    .select(
        "cc_num",
        "trans_num",
        to_timestamp("trans_time", "yyyy-MM-dd HH:mm:ss").alias("trans_time"),
        "category",
        "merchant",
        "amt",
        "merch_lat",
        "merch_long",
        "distance",
        "age",
        "partition",
        "offset",
    )
)

In [7]:
# Score the enriched stream in two stages, each model loaded from its configured path.
# Stage 1: the fitted preprocessing PipelineModel turns the columns into model features.
preprocessing_transformer_model = PipelineModel.load(cfg.model.preprocessing.path)
feature_df = preprocessing_transformer_model.transform(processed_transactions_df)

# Stage 2: the random forest adds a "prediction" column, exposed here as "is_fraud".
# Only the business columns are kept, and trans_num is copied into "_id" (presumably
# so each transaction maps to a single MongoDB document).
random_forest_model = RandomForestClassificationModel.load(cfg.model.predict.path)
prediction_df = (
    random_forest_model.transform(feature_df)
    .withColumnRenamed("prediction", "is_fraud")
    .select(
        "cc_num", "trans_num", "trans_time", "category", "merchant", "amt",
        "merch_lat", "merch_long", "distance", "age", "is_fraud",
    )
    .withColumn("_id", col("trans_num"))
)

In [8]:
# Split the scored stream on the model's verdict so each class gets its own sink.
# (`where` is an alias of `filter`.)
fraud_prediction_df = prediction_df.where('is_fraud = 1.0')
nonfraud_prediction_df = prediction_df.where('is_fraud = 0.0')

## Persist Transactions

In [9]:
def foreachBatch_sink(dataset_name, output_dir='../data/out'):
    """Build a ``foreachBatch`` callback that persists each micro-batch of a stream.

    Every micro-batch is written twice: to the MongoDB dataset ``dataset_name``
    (database and URI taken from ``cfg.db``) and appended as JSON files under
    ``{output_dir}/{dataset_name}``.

    Args:
        dataset_name: Target dataset/collection name and JSON sub-directory name.
        output_dir: Parent directory of the JSON copy. Defaults to ``'../data/out'``.

    Returns:
        A ``persist_batch(batchDF, batch_id)`` function suitable for
        ``DataStreamWriter.foreachBatch``.
    """
    json_path = f'{output_dir}/{dataset_name}'

    def persist_batch(batchDF, batch_id):
        # Cache the batch so the two writes below do not each recompute the upstream
        # plan (for this notebook's streams: Kafka read, customer join, model scoring).
        batchDF.persist()
        try:
            mongo_save(batchDF, dataset_name, database=cfg.db.name, uri=cfg.db.uri)
            # NOTE(review): batch_id is unused, so a replayed batch is appended to the
            # JSON output again; the Mongo side presumably tolerates replays because
            # documents carry `_id` = trans_num. Confirm against mongo_save.
            (
                batchDF.write
                .mode('append')
                .format('json')
                .save(json_path)
            )
        finally:
            # Release the cached batch even when one of the writes fails.
            batchDF.unpersist()

    return persist_batch

def get_write_stream(df, dataset_name, checkpoint_location=None):
    """Configure, but do not start, a streaming write of ``df`` into ``dataset_name``.

    The returned writer is named ``sink__{dataset_name}``, triggers micro-batches
    back to back (``processingTime='0 seconds'``), uses append output mode and
    hands every micro-batch to ``foreachBatch_sink(dataset_name)``. Call
    ``.start()`` on the result to launch the streaming query.

    Args:
        df: Streaming DataFrame to persist.
        dataset_name: Dataset name passed to ``foreachBatch_sink``; also used in
            the query name.
        checkpoint_location: Optional checkpoint directory. When omitted, behavior
            is unchanged: Spark chooses the checkpoint location itself (a session-level
            default or a temporary directory), which may not preserve progress
            across restarts.

    Returns:
        The configured ``DataStreamWriter``.
    """
    writer = (
        df.writeStream
        .queryName(f'sink__{dataset_name}')
        .trigger(processingTime='0 seconds')
        .foreachBatch(foreachBatch_sink(dataset_name))
        .outputMode("append")
    )
    if checkpoint_location is not None:
        writer = writer.option("checkpointLocation", checkpoint_location)
    return writer

## Persist `fraud` and `nonfraud` Streams

In [10]:
# Launch one foreachBatch sink per prediction class: non-fraud first, then fraud.
for sink_df, sink_name in (
    (nonfraud_prediction_df, 'predict_nonfraud'),
    (fraud_prediction_df, 'predict_fraud'),
):
    get_write_stream(sink_df, sink_name).start()

In [11]:
# Show the active streaming queries (streams_list presumably comes from the
# demolib.streams star import); both sinks started in the previous cell should be listed.
streams_list()

['sink__predict_nonfraud', 'sink__predict_fraud']

We need to be able to stop the streaming job gracefully. For this purpose we use "flags", which are simply files on the filesystem.

The presence of the file means the flag is set.

For example, to set the flag `predictor_stop`, create a file named `__predictor_stop__`. Once the flag has been processed, it is cleared (see the `clear_on_poll` argument of the `flag_poll` function).

In [12]:
# Block until streams_list() reports no active queries, or until an operator sets the
# 'predictor_stop' flag (the flag is cleared again once it has been observed).
POLL_INTERVAL_SECONDS = 2
while True:
    if not streams_list():
        break
    if flag_poll('predictor_stop', clear_on_poll=True):
        break
    sleep(POLL_INTERVAL_SECONDS)

# Whichever condition ended the wait, stop any query that is still running.
streams_stop_all()

Query sink__predict_nonfraud stopped
Query sink__predict_fraud stopped
