Fraud_Detector

# Trade Settlement (Spark Streaming app that consumes stock settlement data from Kafka and stores them into the  VAST Database

In [1]:
import os

# Load environment variables for Kafka and VastDB connectivity
DOCKER_HOST_OR_IP = os.getenv("DOCKER_HOST_OR_IP", "localhost")
VASTDB_ENDPOINT = os.getenv("VASTDB_ENDPOINT")
VASTDB_ACCESS_KEY = os.getenv("VASTDB_ACCESS_KEY")
VASTDB_SECRET_KEY = os.getenv("VASTDB_SECRET_KEY")

VASTDB_SIEM_BUCKET = os.getenv("VASTDB_SIEM_BUCKET", 'csnow-db')
VASTDB_SIEM_SCHEMA = os.getenv("VASTDB_SIEM_SCHEMA", 'zeek-live-logs')
VASTDB_SIEM_TABLE_PREFIX = 'zeek_'

use_vastkafka = True
if use_vastkafka:
    VAST_KAFKA_BROKER = os.getenv("VAST_KAFKA_BROKER")
else:
    VAST_KAFKA_BROKER = f"{DOCKER_HOST_OR_IP}:19092"

kafka_brokers = VAST_KAFKA_BROKER
topic = 'zeek-live-logs'

# Print configurations
print(f"""
---
DOCKER_HOST_OR_IP={DOCKER_HOST_OR_IP}
---
VASTDB_ENDPOINT={VASTDB_ENDPOINT}
VASTDB_ACCESS_KEY==****{VASTDB_ACCESS_KEY[-4:]}
VASTDB_SECRET_KEY=****{VASTDB_SECRET_KEY[-4:]}
VASTDB_SIEM_BUCKET={VASTDB_SIEM_BUCKET}
VASTDB_SIEM_SCHEMA={VASTDB_SIEM_SCHEMA}
VASTDB_SIEM_TABLE_PREFIX={VASTDB_SIEM_TABLE_PREFIX}
---
VAST_KAFKA_BROKER={VAST_KAFKA_BROKER}
topic={topic}
""")




---
DOCKER_HOST_OR_IP=10.143.11.241
---
VASTDB_ENDPOINT=http://172.200.204.2:80
VASTDB_ACCESS_KEY==****QXN5
VASTDB_SECRET_KEY=****oLGr
VASTDB_SIEM_BUCKET=csnow-db
VASTDB_SIEM_SCHEMA=zeek-live-logs
VASTDB_SIEM_TABLE_PREFIX=zeek_
---
VAST_KAFKA_BROKER=172.200.204.1:9092
topic=zeek-live-logs



Create Vast DB schema if it doesn't exist.

In [2]:
%%capture --no-stderr
%pip install --quiet -U vastdb

In [3]:
import vastdb

session = vastdb.connect(endpoint=VASTDB_ENDPOINT, access=VASTDB_ACCESS_KEY, secret=VASTDB_SECRET_KEY)
with session.transaction() as tx:
    bucket = tx.bucket(VASTDB_SIEM_BUCKET)
    bucket.schema(VASTDB_SIEM_SCHEMA, fail_if_missing=False) or bucket.create_schema(VASTDB_SIEM_SCHEMA)

In [4]:
import json
import socket
import pyspark
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, count, get_json_object
from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType, BooleanType
import threading
import time

# Spark Configuration
conf = SparkConf()
conf.setAll([
    ("spark.driver.host", socket.gethostbyname(socket.gethostname())),
    ("spark.sql.execution.arrow.pyspark.enabled", "false"),
    # VASTDB
    ("spark.sql.catalog.ndb", 'spark.sql.catalog.ndb.VastCatalog'),
    ("spark.ndb.endpoint", VASTDB_ENDPOINT),
    ("spark.ndb.data_endpoints", VASTDB_ENDPOINT),
    ("spark.ndb.access_key_id", VASTDB_ACCESS_KEY),
    ("spark.ndb.secret_access_key", VASTDB_SECRET_KEY),
    ("spark.driver.extraClassPath", '/usr/local/spark/jars/spark3-vast-3.4.1-f93839bfa38a/*'),
    ("spark.executor.extraClassPath", '/usr/local/spark/jars/spark3-vast-3.4.1-f93839bfa38a/*'),
    ("spark.sql.extensions", 'ndb.NDBSparkSessionExtension'),
    # Kafka
    ("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.13:3.4.3," 
                            "org.apache.logging.log4j:log4j-slf4j2-impl:2.19.0," 
                            "org.apache.logging.log4j:log4j-api:2.19.0," 
                            "org.apache.logging.log4j:log4j-core:2.19.0"),
    ("spark.jars.excludes", "org.slf4j:slf4j-api,org.slf4j:slf4j-log4j12"),
    ("spark.hadoop.fs.file.impl", "org.apache.hadoop.fs.RawLocalFileSystem"),
])

spark = SparkSession.builder \
    .master("local") \
    .appName("KafkaStreamingToVastDB") \
    .config(conf=conf) \
    .enableHiveSupport() \
    .getOrCreate()

sc = spark.sparkContext
sc.setLogLevel("DEBUG")

print("Spark successfully loaded\n")


Spark successfully loaded



In [5]:
destination_table_name_prefix = f"`ndb`.`{VASTDB_SIEM_BUCKET}`.`{VASTDB_SIEM_SCHEMA}`.`{VASTDB_SIEM_TABLE_PREFIX}`"
destination_table_name_prefix

'`ndb`.`csnow-db`.`zeek-live-logs`.`zeek_`'

In [6]:
import os
import signal
import time
import threading
import pyspark
from pyspark.sql.functions import col, from_json
from pyspark.sql.types import *

# Create checkpoint directory with absolute path
checkpoint_dir = os.path.abspath("/tmp/spark_checkpoint")
os.makedirs(checkpoint_dir, exist_ok=True)

# Global variables for tracking
total_message_count = 0
table_row_counts = {}  # Track row counts per table
last_batch_id = 0
last_batch_size = 0
processed_log_types = set()  # Track which log types we've seen
created_tables = set()  # Track which tables we've already created

should_shutdown = False

# Print a comprehensive status update
def print_status(source=""):
    global total_message_count, table_row_counts, last_batch_id, last_batch_size, processed_log_types
    if not should_shutdown:
        current_time = time.strftime("%H:%M:%S", time.localtime())
        total_db_rows = sum(table_row_counts.values())
        
        # Create summary of table counts
        table_summary = ", ".join([f"{log_type}: {count}" for log_type, count in table_row_counts.items()])
        if not table_summary:
            table_summary = "No tables yet"
            
        print(f"\rLast update: {current_time} | Batch {last_batch_id}: {last_batch_size} records | "
              f"Total messages: {total_message_count} | Total VastDB rows: {total_db_rows} | "
              f"Log types: {len(processed_log_types)} ({', '.join(sorted(processed_log_types))}) | "
              f"Tables: [{table_summary}]     ", end="")
        
        import sys
        sys.stdout.flush()

# Helper function to create safe VastDB table names
def create_vastdb_table_name(log_type):
    """Create a VastDB table name for the log type"""
    # Clean up the log type for SQL compatibility
    clean_log_type = log_type.replace("-", "_").replace(".", "_")
    return f"`ndb`.`{VASTDB_SIEM_BUCKET}`.`{VASTDB_SIEM_SCHEMA}`.`{clean_log_type}`"

# Helper function to create table schema in VastDB if it doesn't exist
def ensure_table_exists(log_type, sample_data):
    """Ensure the VastDB table exists for this log type"""
    global created_tables
    
    table_name = create_vastdb_table_name(log_type)
    table_key = f"{VASTDB_SIEM_BUCKET}.{VASTDB_SIEM_SCHEMA}.{log_type}"
    
    if table_key in created_tables:
        return table_name
    
    try:
        # Try to query the table to see if it exists
        spark.sql(f"SELECT 1 FROM {table_name} LIMIT 1")
        created_tables.add(table_key)
        return table_name
    except:
        # Table doesn't exist, we'll let Spark create it dynamically
        created_tables.add(table_key)
        return table_name

# Process each microbatch with dynamic table routing
def process_microbatch(raw_df, epoch_id):
    global total_message_count, last_batch_id, last_batch_size, processed_log_types
    if not should_shutdown:
        try:
            batch_size = raw_df.count()
            if batch_size == 0:
                return
                
            total_message_count += batch_size
            last_batch_id = epoch_id
            last_batch_size = batch_size
            
            # Collect all JSON strings to determine log types
            json_strings = [row.json for row in raw_df.collect()]
            
            # Group messages by log type
            log_type_groups = {}
            for json_str in json_strings:
                try:
                    parsed = json.loads(json_str)
                    # Get the top-level key (log type)
                    log_type = list(parsed.keys())[0]
                    processed_log_types.add(log_type)
                    
                    if log_type not in log_type_groups:
                        log_type_groups[log_type] = []
                    log_type_groups[log_type].append(json_str)
                except Exception as e:
                    print(f"\nError parsing JSON: {e}")
                    continue
            
            # Process each log type group
            for log_type, json_list in log_type_groups.items():
                try:
                    # Create DataFrame for this log type
                    log_type_rdd = spark.sparkContext.parallelize([(json_str,) for json_str in json_list])
                    log_type_df = spark.createDataFrame(log_type_rdd, ["json"])
                    
                    # Extract the nested object for this log type using get_json_object
                    extracted_df = log_type_df.select(
                        get_json_object(col("json"), f"$.{log_type}").alias("log_data")
                    ).filter(col("log_data").isNotNull())
                    
                    if extracted_df.count() > 0:
                        # Use from_json with schema inference
                        sample_json = extracted_df.select("log_data").first()
                        if sample_json and sample_json.log_data:
                            try:
                                # Parse sample to create a basic schema
                                sample_dict = json.loads(sample_json.log_data)
                                
                                # Create a flexible schema that accommodates common types
                                fields = []
                                for key, value in sample_dict.items():
                                    # Keep original field names but handle special characters
                                    if isinstance(value, str):
                                        fields.append(StructField(key, StringType(), True))
                                    elif isinstance(value, int):
                                        fields.append(StructField(key, LongType(), True))
                                    elif isinstance(value, float):
                                        fields.append(StructField(key, DoubleType(), True))
                                    elif isinstance(value, bool):
                                        fields.append(StructField(key, BooleanType(), True))
                                    else:
                                        # Default to string for complex types
                                        fields.append(StructField(key, StringType(), True))
                                
                                inferred_schema = StructType(fields)
                                
                                # Parse with inferred schema
                                parsed_df = extracted_df.select(
                                    from_json(col("log_data"), inferred_schema).alias("parsed")
                                ).select("parsed.*")
                                
                                # Ensure table exists and get table name
                                table_name = ensure_table_exists(log_type, sample_dict)
                                
                                # Write to VastDB table specific to this log type
                                parsed_df.write.mode("append").saveAsTable(table_name)
                                
                            except Exception as schema_error:
                                print(f"\nSchema inference error for {log_type}: {schema_error}")
                                # Fallback: store as raw JSON string
                                try:
                                    fallback_df = extracted_df.select(col("log_data").alias("raw_json"))
                                    table_name = f"`ndb`.`{VASTDB_SIEM_BUCKET}`.`{VASTDB_SIEM_SCHEMA}`.`{log_type}_raw`"
                                    fallback_df.write.mode("append").saveAsTable(table_name)
                                except Exception as fallback_error:
                                    print(f"\nFallback failed for {log_type}: {fallback_error}")
                    
                except Exception as e:
                    print(f"\nError processing log type {log_type}: {e}")
                    continue
            
            print_status("Batch")
            
        except Exception as e:
            print(f"\nException in process_microbatch: {e}")

# Function to periodically check and update row counts for all VastDB tables
def check_row_counts():
    global table_row_counts
    while not should_shutdown:
        time.sleep(3)  # Check every 3 seconds
        try:
            for log_type in processed_log_types:
                try:
                    table_name = create_vastdb_table_name(log_type)
                    new_count = spark.sql(f"SELECT count(*) FROM {table_name}").collect()[0][0]
                    if table_row_counts.get(log_type, 0) != new_count:
                        table_row_counts[log_type] = new_count
                        print_status("DB Count")
                except Exception:
                    # Table might not exist yet or be accessible
                    pass
        except Exception:
            # Ignore errors in checking
            pass

# Read data from Kafka stream
raw_stream = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", kafka_brokers) \
    .option("subscribe", topic) \
    .option("startingOffsets", "earliest") \
    .option("failOnDataLoss", "true") \
    .load()

# Decode Kafka messages as JSON strings
decoded_stream = raw_stream.selectExpr("CAST(value AS STRING) as json")

# Main processing query - using the dynamic approach
zeek_query = decoded_stream.writeStream \
    .foreachBatch(process_microbatch) \
    .outputMode("append") \
    .trigger(processingTime='2 seconds') \
    .option("maxFilesPerTrigger", 1000) \
    .option("checkpointLocation", checkpoint_dir) \
    .start()

# Print initial status message
print("\nStarting Zeek log streaming job to VastDB...")
print("This will dynamically create VastDB tables for each Zeek log type (conn, analyzer, weird, etc.)")
print(f"Tables will be created in: ndb.{VASTDB_SIEM_BUCKET}.{VASTDB_SIEM_SCHEMA}")
print_status("Init")

# Start thread for checking row counts
row_count_thread = threading.Thread(target=check_row_counts)
row_count_thread.daemon = True
row_count_thread.start()

shutdown_flag = threading.Event()

def signal_handler(sig, frame):
    global should_shutdown
    print("\nGraceful shutdown initiated...")
    should_shutdown = True
    shutdown_flag.set()

signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

# Main loop
try:
    while zeek_query.isActive and not shutdown_flag.is_set():
        time.sleep(1)
    if zeek_query.isActive:
        zeek_query.stop()
    zeek_query.awaitTermination()

except Exception as e:
    print(f"Error during awaitTermination: {e}")

print("\nFinal status:")
for log_type, count in table_row_counts.items():
    print(f"  {VASTDB_SIEM_BUCKET}.{VASTDB_SIEM_SCHEMA}.{log_type}: {count} rows")
print("VastDB Zeek streaming completed. Goodbye!")


Starting Zeek log streaming job to VastDB...
This will dynamically create VastDB tables for each Zeek log type (conn, analyzer, weird, etc.)
Tables will be created in: ndb.csnow-db.zeek-live-logs
Last update: 17:54:42 | Batch 0: 0 records | Total messages: 0 | Total VastDB rows: 0 | Log types: 0 () | Tables: [No tables yet]     
Error parsing JSON: name 'json' is not defined
Last update: 17:54:57 | Batch 7: 1 records | Total messages: 1 | Total VastDB rows: 0 | Log types: 0 () | Tables: [No tables yet]     
Error parsing JSON: name 'json' is not defined
Last update: 17:54:58 | Batch 8: 1 records | Total messages: 2 | Total VastDB rows: 0 | Log types: 0 () | Tables: [No tables yet]     
Error parsing JSON: name 'json' is not defined

Error parsing JSON: name 'json' is not defined
Last update: 17:55:00 | Batch 9: 2 records | Total messages: 4 | Total VastDB rows: 0 | Log types: 0 () | Tables: [No tables yet]     
Error parsing JSON: name 'json' is not defined
Last update: 17:55:02 | Bat