Fraud_Detector

# Trade Settlement (Spark Streaming app that consumes stock settlement data from Kafka and stores it in the VAST Database)

In [1]:
import os

# Load environment variables for Kafka and VastDB connectivity
DOCKER_HOST_OR_IP = os.getenv("DOCKER_HOST_OR_IP", "localhost")
VASTDB_ENDPOINT = os.getenv("VASTDB_ENDPOINT")
VASTDB_ACCESS_KEY = os.getenv("VASTDB_ACCESS_KEY")
VASTDB_SECRET_KEY = os.getenv("VASTDB_SECRET_KEY")

VASTDB_FRAUD_DETECTION_BUCKET = os.getenv("VASTDB_FRAUD_DETECTION_BUCKET")
VASTDB_FRAUD_DETECTION_SCHEMA = os.getenv("VASTDB_FRAUD_DETECTION_SCHEMA")
VASTDB_FRAUD_DETECTION_TABLE = 'fraud'

# Print configurations
print(f"""
---
DOCKER_HOST_OR_IP={DOCKER_HOST_OR_IP}
---
VASTDB_ENDPOINT={VASTDB_ENDPOINT}
VASTDB_ACCESS_KEY==****{VASTDB_ACCESS_KEY[-4:]}
VASTDB_SECRET_KEY=****{VASTDB_SECRET_KEY[-4:]}
VASTDB_FRAUD_DETECTION_BUCKET={VASTDB_FRAUD_DETECTION_BUCKET}
VASTDB_FRAUD_DETECTION_SCHEMA={VASTDB_FRAUD_DETECTION_SCHEMA}
# VASTDB_FRAUD_DETECTION_TABLE={VASTDB_FRAUD_DETECTION_TABLE}
---
""")

# Kafka Configuration
kafka_brokers = f'{DOCKER_HOST_OR_IP}:19092'
topic = 'stock-settlement'


---
DOCKER_HOST_OR_IP=10.143.11.241
---
VASTDB_ENDPOINT=http://172.200.204.2:80
VASTDB_ACCESS_KEY==****QXN5
VASTDB_SECRET_KEY=****oLGr
VASTDB_FRAUD_DETECTION_BUCKET=csnow-db
VASTDB_FRAUD_DETECTION_SCHEMA=fraud_detection
# VASTDB_FRAUD_DETECTION_TABLE=fraud
---



Create the VAST DB schema if it doesn't exist.

In [2]:
%%capture --no-stderr
# Install/upgrade the vastdb package from within the notebook; stdout is captured, stderr stays visible.
# NOTE(review): the version is not pinned, so a re-run may pick up a different release.
%pip install --quiet -U vastdb

In [3]:
import vastdb

# Connect to VAST DB and make sure the target schema exists in the configured bucket.
session = vastdb.connect(
    endpoint=VASTDB_ENDPOINT,
    access=VASTDB_ACCESS_KEY,
    secret=VASTDB_SECRET_KEY,
)
with session.transaction() as tx:
    bucket = tx.bucket(VASTDB_FRAUD_DETECTION_BUCKET)
    # The lookup is made with fail_if_missing=False; a falsy result triggers creation.
    if not bucket.schema(VASTDB_FRAUD_DETECTION_SCHEMA, fail_if_missing=False):
        bucket.create_schema(VASTDB_FRAUD_DETECTION_SCHEMA)

In [4]:
import socket
import pyspark
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, count
from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType, BooleanType
import threading
import time

# Spark Configuration
# Location of the VAST connector jars, used for both the driver and executor classpaths.
_vast_connector_classpath = '/usr/local/spark/jars/spark3-vast-3.4.1-f93839bfa38a/*'

# Maven coordinates passed to spark.jars.packages: Kafka source plus Log4j 2 artifacts.
_extra_packages = ",".join([
    "org.apache.spark:spark-sql-kafka-0-10_2.13:3.4.3",
    "org.apache.logging.log4j:log4j-slf4j2-impl:2.19.0",
    "org.apache.logging.log4j:log4j-api:2.19.0",
    "org.apache.logging.log4j:log4j-core:2.19.0",
])

_spark_settings = {
    "spark.driver.host": socket.gethostbyname(socket.gethostname()),
    "spark.sql.execution.arrow.pyspark.enabled": "false",
    # VASTDB
    "spark.sql.catalog.ndb": 'spark.sql.catalog.ndb.VastCatalog',
    "spark.ndb.endpoint": VASTDB_ENDPOINT,
    "spark.ndb.data_endpoints": VASTDB_ENDPOINT,
    "spark.ndb.access_key_id": VASTDB_ACCESS_KEY,
    "spark.ndb.secret_access_key": VASTDB_SECRET_KEY,
    "spark.driver.extraClassPath": _vast_connector_classpath,
    "spark.executor.extraClassPath": _vast_connector_classpath,
    "spark.sql.extensions": 'ndb.NDBSparkSessionExtension',
    # Kafka
    "spark.jars.packages": _extra_packages,
    "spark.jars.excludes": "org.slf4j:slf4j-api,org.slf4j:slf4j-log4j12",
    "spark.hadoop.fs.file.impl": "org.apache.hadoop.fs.RawLocalFileSystem",
}

conf = SparkConf().setAll(list(_spark_settings.items()))

# Single local-mode session with Hive support enabled; reuses an existing session if one is active.
spark = (
    SparkSession.builder
    .master("local")
    .appName("KafkaStreamingToVastDB")
    .config(conf=conf)
    .enableHiveSupport()
    .getOrCreate()
)

sc = spark.sparkContext
# NOTE(review): DEBUG is a very verbose log level - confirm it is still wanted outside of troubleshooting.
sc.setLogLevel("DEBUG")

print("Spark successfully loaded\n")


Spark successfully loaded



In [5]:
# Fully qualified destination table: each of catalog / bucket / schema / table is backtick-quoted.
destination_table_name = ".".join(
    f"`{part}`"
    for part in (
        "ndb",
        VASTDB_FRAUD_DETECTION_BUCKET,
        VASTDB_FRAUD_DETECTION_SCHEMA,
        VASTDB_FRAUD_DETECTION_TABLE,
    )
)
destination_table_name

'`ndb`.`csnow-db`.`fraud_detection`.`fraud`'

In [None]:
import os
import time
import threading
import pyspark
from pyspark.sql.functions import col, from_json
from pyspark.sql.types import *

# Resolve the checkpoint location to an absolute path and create it when it is not there yet
checkpoint_dir = os.path.abspath("/tmp/spark_checkpoint")
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir, exist_ok=True)

# Define schema for Kafka message
# Layout: record metadata fields plus `key` and `value` structs, each carrying a
# `payload` and an `encoding`; the value payload holds the trade settlement fields.
# NOTE(review): the stream-parsing code visible in this notebook declares its own flat
# payload schema and does not reference `schema` - confirm whether it is still needed.
_key_struct = (
    StructType()
    .add("payload", StringType(), True)
    .add("encoding", StringType(), True)
)

_settlement_payload_struct = (
    StructType()
    .add("transaction_id", StringType(), True)
    .add("settlement_date", StringType(), True)
    .add("stock_symbol", StringType(), True)
    .add("quantity", LongType(), True)
    .add("price", DoubleType(), True)
    .add("buyer", StringType(), True)
    .add("seller", StringType(), True)
    .add("trade_date", StringType(), True)
    .add("status", StringType(), True)
)

_value_struct = (
    StructType()
    .add("payload", _settlement_payload_struct)
    .add("encoding", StringType(), True)
)

schema = (
    StructType()
    .add("partitionID", LongType(), True)
    .add("offset", LongType(), True)
    .add("timestamp", LongType(), True)
    .add("compression", StringType(), True)
    .add("isTransactional", BooleanType(), True)
    .add("key", _key_struct)
    .add("value", _value_struct)
)

# Module-level counters shared by the batch handler, the row-count poller and the status printer.
# All start at zero (ints are immutable, so the chained assignment creates no aliasing).
total_message_count = vast_table_row_count = last_batch_id = last_batch_size = 0

# Print a single comprehensive status update
def print_status(source=""):
    """Write a one-line progress summary to stdout, overwriting the previous one.

    The line reports the current wall-clock time together with the module-level
    counters (last batch id and size, total messages, destination row count).

    Args:
        source: Label supplied by the caller; it is accepted but not used in the output.
    """
    timestamp = time.strftime("%H:%M:%S")  # defaults to the current local time

    # Leading "\r" returns the cursor to the start of the line; the trailing
    # spaces blank out leftovers from a longer previous line.
    status_line = f"\rLast update: {timestamp} | Batch {last_batch_id}: {last_batch_size} records | Total messages: {total_message_count} | DB rows: {vast_table_row_count}     "

    # No newline, and flush immediately so the update shows up right away.
    print(status_line, end="", flush=True)

# Process each microbatch
def process_microbatch(parsed_df, epoch_id):
    """Append one micro-batch to the destination table and update the progress counters.

    The batch is persisted for the duration of the call because two actions run
    on it (the table write and the count); without caching, the batch would be
    evaluated once per action. The counters are only updated after the write has
    succeeded, so a failed write does not inflate the running totals.

    Args:
        parsed_df: DataFrame holding the rows of this micro-batch.
        epoch_id: Identifier of the micro-batch, as passed in by foreachBatch.

    Raises:
        Any exception raised by the write or the count is propagated unchanged.
    """
    global total_message_count, last_batch_id, last_batch_size

    parsed_df.persist()
    try:
        parsed_df.write.mode("append").saveAsTable(destination_table_name)
        batch_size = parsed_df.count()
    finally:
        # Always release the cached batch, even when the write fails.
        parsed_df.unpersist()

    total_message_count += batch_size
    last_batch_id = epoch_id
    last_batch_size = batch_size
    print_status("Batch")

# Function to periodically check and update row count
def check_row_count(stop_event=None, poll_interval=2):
    """Poll the destination table's row count and refresh the status line when it changes.

    Args:
        stop_event: Optional threading.Event. When it is set, the loop exits at
            the next wake-up. When omitted, the function polls forever, as it
            did before this parameter existed.
        poll_interval: Seconds to wait between polls (default 2).

    Side effects:
        Updates the module-level `vast_table_row_count` and calls `print_status`
        whenever the count differs from the last value seen.
    """
    global vast_table_row_count

    if stop_event is None:
        # No event supplied: use one that is never set, i.e. run until the process ends.
        stop_event = threading.Event()

    # Event.wait() serves as the sleep between polls and returns True once the event is set.
    while not stop_event.wait(poll_interval):
        try:
            new_count = spark.sql(f"SELECT count(*) FROM {destination_table_name}").collect()[0][0]
        except pyspark.errors.AnalysisException:
            # Ignore - the table will eventually get created
            continue

        if new_count != vast_table_row_count:
            vast_table_row_count = new_count
            print_status("DB Count")

# Read data from Kafka stream
# Source options: broker list, subscribed topic, start from the earliest
# available offsets, and failOnDataLoss set to "true".
_kafka_source_options = {
    "kafka.bootstrap.servers": kafka_brokers,
    "subscribe": topic,
    "startingOffsets": "earliest",
    "failOnDataLoss": "true",
}

raw_stream = (
    spark.readStream
    .format("kafka")
    .options(**_kafka_source_options)
    .load()
)

# Parse the Kafka message
# The record value is cast to a string and parsed as a flat JSON object with these fields.
_settlement_json_schema = StructType([
    StructField("transaction_id", StringType(), True),
    StructField("settlement_date", StringType(), True),
    StructField("stock_symbol", StringType(), True),
    StructField("quantity", LongType(), True),
    StructField("price", DoubleType(), True),
    StructField("buyer", StringType(), True),
    StructField("seller", StringType(), True),
    StructField("trade_date", StringType(), True),
    StructField("status", StringType(), True),
])

_json_stream = raw_stream.selectExpr("CAST(value AS STRING) as json")

# Result: a single struct column named `payload` holding the parsed fields.
decoded_stream = _json_stream.select(
    from_json(col("json"), _settlement_json_schema).alias("payload")
)

# Prepare data to match VastDB table schema
# Flatten the `payload` struct into top-level columns, keeping this column order.
_vastdb_columns = (
    "transaction_id",
    "settlement_date",
    "stock_symbol",
    "quantity",
    "price",
    "buyer",
    "seller",
    "trade_date",
    "status",
)

vastdb_stream = decoded_stream.select(
    *[col(f"payload.{name}").alias(name) for name in _vastdb_columns]
)

# Main processing query
# Each micro-batch (triggered every second) is handed to process_microbatch.
#
# checkpointLocation points at the checkpoint directory created earlier in this cell, so the
# query's progress is stored in a known place instead of a throwaway temporary directory.
# Delete that directory to reprocess the topic from the earliest offsets.
#
# The former `maxFilesPerTrigger` option was dropped: it is a file-source read option and has
# no effect on this writer. To rate-limit a Kafka source, set `maxOffsetsPerTrigger` on the
# readStream instead.
vastdb_query = (
    vastdb_stream.writeStream
    .foreachBatch(process_microbatch)
    .outputMode("append")
    .trigger(processingTime='1 second')
    .option("checkpointLocation", checkpoint_dir)
    .start()
)

# Print initial status message
print("\nStarting Spark streaming job...")
print_status("Init")

# Start thread for checking row count.
# It is a daemon thread, so it does not keep the process alive on exit.
row_count_thread = threading.Thread(target=check_row_count, daemon=True)
row_count_thread.start()

# Wait for termination.
# Interrupting the cell would otherwise leave the streaming query running in the
# background (and a re-run would start a second one), so the query is always
# stopped on the way out.
try:
    vastdb_query.awaitTermination()
except KeyboardInterrupt:
    print("\nInterrupted - stopping streaming query...")
finally:
    vastdb_query.stop()


Starting Spark streaming job...
Last update: 08:50:30 | Batch 0: 0 records | Total messages: 0 | DB rows: 0     