In [None]:
# Install these libraries separately, as they are not part of the container
#!pip install numpy pandas

In [None]:
# Import all relevant libraries here

In [None]:
import time
import pandas as pd

import os

# spark-submit arguments that PySpark uses when it launches the JVM; they must
# be in place before the SparkSession below is created. They pull in the Kafka
# SQL connector (Scala 2.12 build, version 3.0.0).
os.environ.update(
    PYSPARK_SUBMIT_ARGS=(
        "--packages org.apache.spark:spark-sql-kafka-0-10_2.12:3.0.0 pyspark-shell"
    )
)


from pyspark.sql import SQLContext
from pyspark.sql.types import (
    StructType,
    StructField,
    TimestampType,
    StringType,
    IntegerType,
    LongType,
)
from pyspark.sql import SparkSession

from datetime import datetime, timedelta
from pyspark.sql.functions import col, from_json

import math
import numpy
import random
import pyspark.sql.functions as func
from datetime import datetime, timedelta
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql.types import (
    StructField,
    StructType,
    StringType,
    IntegerType,
    TimestampType,
    BooleanType,
)
from pyspark.sql.functions import date_trunc, col
from pyspark.sql.window import Window

In [None]:
# Set up the Apache Spark application

In [None]:
topic_name = "test_topic"
app_name = "anomaly_app"
kafka_servers = "broker:9092"

# Maven coordinate of the Kafka SQL connector. spark.jars.packages requires the
# "groupId:artifactId:version" form; it must agree with PYSPARK_SUBMIT_ARGS.
kafka_package = "org.apache.spark:spark-sql-kafka-0-10_2.12:3.0.0"

spark = (
    SparkSession.builder.master("local")
    .appName(app_name)
    # Add kafka package
    .config("spark.jars.packages", kafka_package)
    .getOrCreate()
)
sc = spark.sparkContext
# Legacy SQL entry point, kept for any code that still expects `sqlContext`.
sqlContext = SQLContext(sc)

In [None]:
# Constants used by the rest of the notebook

In [None]:
# Shape of the JSON payload inside each Kafka message value: camelCase keys,
# with `timestamp` holding epoch milliseconds as a long.
stream_schema = StructType(
    [
        StructField(field_name, field_type, nullable)
        for field_name, field_type, nullable in (
            ("timestamp", LongType(), False),
            ("serviceName", StringType(), False),
            ("logMessage", StringType(), False),
            ("statusCode", IntegerType(), True),
        )
    ]
)

# Normalised snake_case shape, matching the columns preprocess_df produces
# (timestamp already converted to a Spark timestamp).
schema = StructType(
    [
        StructField(field_name, field_type, nullable)
        for field_name, field_type, nullable in (
            ("timestamp", TimestampType(), False),
            ("service_id", StringType(), False),
            ("log_message", StringType(), True),
            ("status_code", IntegerType(), True),
        )
    ]
)

# Size of the sliding window
window_size_mins = 3
window_size = window_size_mins * 60  # seconds
window_size_millisec = window_size * 1000

# Period of the sliding window.
# In this case the window slides by 1 min and has a size of 3 min.
window_period = 60  # seconds
window_period_millisec = window_period * 1000

# How far back the rolling statistics look (three window lengths).
total_mins_to_look_back = window_size_mins * 3


def rolling_time_in_seconds(minute):
    """Convert a duration expressed in minutes to seconds."""
    return minute * 60


# Number of Kafka offsets requested per read.
offset_window_size = 200

# List of service_ids for which we have to detect anomalies
service_ids = ["service_three"]

In [None]:
## Helper Functions

In [None]:
def convert_kafka_to_service_df(df):
    """
    Decode Kafka source rows into service log records.

    The ``key`` and ``value`` columns are cast to strings, ``value`` is parsed
    as JSON with ``stream_schema``, and the parsed struct is flattened so that
    every schema field becomes its own column.
    """
    as_text = df.withColumn("key", df["key"].cast(StringType()))
    as_text = as_text.withColumn("value", df["value"].cast(StringType()))
    parsed = as_text.withColumn("jsonData", from_json(col("value"), stream_schema))
    return parsed.select("jsonData.*")


def read_from_kafka(start_offset, end_offset, partition=0):
    """
    Batch-read records of ``topic_name`` between two Kafka offsets.

    Parameters
    ----------
    start_offset : int
        First offset to read (inclusive).
    end_offset : int
        Offset at which reading stops (exclusive).
    partition : int, optional
        Topic partition to read; defaults to 0.

    Returns
    -------
    pyspark.sql.DataFrame
        Raw Kafka source rows (``key``, ``value``, ...); decode them with
        ``convert_kafka_to_service_df``.
    """
    import json

    def offsets_json(offset):
        # Derive the spec from topic_name so it always matches the topic being
        # read, instead of a hard-coded "test_topic".
        return json.dumps({topic_name: {str(partition): int(offset)}})

    df = (
        spark.read.format("kafka")
        .option("kafka.bootstrap.servers", kafka_servers)  # kafka server
        # "assign" reads exactly the partition the offset specs refer to;
        # "subscribe" would require offsets for every partition of the topic.
        .option("assign", json.dumps({topic_name: [partition]}))
        .option("startingOffsets", offsets_json(start_offset))
        .option("endingOffsets", offsets_json(end_offset))
        .load()
    )
    return df


def preprocess_df(df):
    """
    Normalise a decoded Kafka frame for the anomaly detectors.

    Renames the camelCase payload columns to snake_case (``serviceName`` ->
    ``service_id``, ``statusCode`` -> ``status_code``, ``logMessage`` ->
    ``log_message``) and turns ``timestamp`` from epoch milliseconds into a
    Spark timestamp with whole-second precision.
    """
    column_renames = (
        ("serviceName", "service_id"),
        ("statusCode", "status_code"),
        ("logMessage", "log_message"),
    )
    for source_name, target_name in column_renames:
        df = df.withColumnRenamed(source_name, target_name)

    # epoch ms -> "yyyy-MM-dd HH:mm:ss" string -> timestamp; the
    # to_utc_timestamp(..., "UTC") step applies a zero offset.
    # NOTE(review): the string is rendered and re-parsed in the Spark session
    # time zone, so records inside a DST fall-back hour could shift by an hour
    # if that zone observes DST — confirm the session runs in UTC.
    seconds_text = func.from_unixtime(
        func.col("timestamp") / 1000, "yyyy-MM-dd HH:mm:ss"
    )
    return df.withColumn("timestamp", func.to_utc_timestamp(seconds_text, "UTC"))



def remove_timestamp_window(timestamp_min_window):
    """Return the first field of a window struct (for func.window, its start)."""
    window_start = timestamp_min_window[0]
    return window_start


# Spark UDF: maps a window struct column to its start, as a TimestampType column.
timestamp_min_udf = func.udf(
    lambda window_struct: remove_timestamp_window(window_struct),
    returnType=TimestampType(),
)


def check_none(x, y, z):
    """
    Return True if any of the three values is missing.

    A value counts as missing when it is ``None`` (SQL NULL) or NaN; windowed
    ``avg``/``stddev`` can yield either when there is too little history.
    """
    values = (x, y, z)
    if any(value is None for value in values):
        return True
    return any(math.isnan(value) for value in values)


def is_anomaly_std(
    traffic_count, rolling_mean, rolling_std, num_std=2, max_relative_std=0.2
):
    """
    Decide whether a window's traffic count is anomalous.

    Two rules are applied in order:

    1. Volatile history: when ``rolling_std`` exceeds ``max_relative_std``
       times ``rolling_mean`` (20% by default), any count below the mean is
       flagged.
    2. Otherwise the count is flagged when it lies outside
       ``rolling_mean +/- num_std * rolling_std`` (2 standard deviations by
       default).

    Parameters
    ----------
    traffic_count, rolling_mean, rolling_std : number or None
        Current window count and the rolling statistics of earlier windows.
    num_std : float, optional
        Width of the accepted band, in standard deviations.
    max_relative_std : float, optional
        Std/mean ratio above which the history is treated as volatile.

    Returns
    -------
    bool
        False when any input is None/NaN (not enough history to judge).
    """
    if check_none(traffic_count, rolling_mean, rolling_std):
        return False
    if rolling_std > max_relative_std * rolling_mean and traffic_count < rolling_mean:
        return True
    lower_bound = rolling_mean - num_std * rolling_std
    upper_bound = rolling_mean + num_std * rolling_std
    return not (lower_bound <= traffic_count <= upper_bound)


# Spark UDF wrapper around is_anomaly_std; yields a BooleanType column.
is_anomaly_udf = func.udf(f=is_anomaly_std, returnType=BooleanType())


def get_empty_df(start_time, end_time):
    """
    Build a zero-traffic frame covering every minute between two epoch times.

    One row is produced per ``(service_id, minute)`` pair. The services come
    from ``service_ids``; the minutes run from the minute containing
    ``start_time`` through the minute containing ``end_time``. Unioned with
    the observed per-minute counts, it makes minutes without Kafka data show
    up as ``traffic_count = 0`` instead of being absent.

    Parameters
    ----------
    start_time, end_time : int
        Epoch timestamps in milliseconds.

    Returns
    -------
    pyspark.sql.DataFrame
        Columns ``service_id`` (string), ``timestamp_min`` (timestamp) and
        ``traffic_count`` (int, always 0).
    """
    # Floor (not round) to the minute so the grid lines up with the
    # date_trunc("minute", ...) buckets computed from the real data.
    first_minute = pd.Timestamp(start_time, unit="ms", tz="UTC").floor("min")
    last_minute = pd.Timestamp(end_time, unit="ms", tz="UTC").floor("min")
    minutes = pd.date_range(first_minute, last_minute, freq="min")

    # Explicit schemas avoid schema-inference failures on empty inputs.
    # Timezone-aware datetimes are converted to the right instant regardless
    # of the driver's local time zone (naive ones are read as local time).
    df_services = spark.createDataFrame(
        [(service_id,) for service_id in service_ids], "service_id string"
    )
    df_minutes = spark.createDataFrame(
        [(minute.to_pydatetime(),) for minute in minutes], "timestamp_min timestamp"
    )
    return df_services.crossJoin(df_minutes).select(
        "service_id", "timestamp_min", func.lit(0).alias("traffic_count")
    )

In [None]:
class BaseAnomalyDetector:
    """
    Skeleton class of the anomaly detectors.

    Every anomaly detector exposes ``__init__``, ``fit`` and
    ``detect_anomaly``. ``fit`` is a no-op by default; ``detect_anomaly``
    must be implemented by subclasses.
    """

    def __init__(self):
        pass

    def fit(self, X=None, *args, **kwargs):
        """
        Train the detector on historical data.

        The base implementation does nothing, so detectors that need no
        training still share a uniform interface. Returns ``self`` to allow
        chaining.
        """
        return self

    def detect_anomaly(self, X, start_time, end_time):
        """
        Detect anomalies in ``X`` for the window ``[start_time, end_time]``.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement detect_anomaly()"
        )


class VanillaAnomalyDetector(BaseAnomalyDetector):
    """
    Threshold-based anomaly detector.

    Traffic of each service is counted per minute. If a service's total
    traffic inside the current window falls below ``traffic_threshold``, it is
    reported. Services listed in ``service_ids`` that produced no traffic at
    all in the window are reported with a count of 0.
    """

    def __init__(self, traffic_threshold=100):
        self.traffic_threshold = traffic_threshold
        emptyRDD = spark.sparkContext.emptyRDD()
        agg_schema = StructType(
            [
                StructField("service_id", StringType(), False),
                StructField("timestamp_min", TimestampType(), False),
                StructField("traffic_count", IntegerType(), False),
            ]
        )
        # Per-(service, minute) counts carried over between calls.
        self.df_cached = spark.createDataFrame(emptyRDD, agg_schema)

    def preprocess_df(self, X, start_time, end_time):
        """
        Count traffic per service per minute for the current window.

        X is expected to carry ``timestamp``, ``service_id`` and
        ``status_code`` columns (the output of the module-level
        ``preprocess_df``). ``start_time`` is epoch milliseconds. Minutes that
        begin before the minute containing ``start_time`` are dropped. This
        restricts the result to the current window and keeps the cache from
        growing without bound.
        """
        # Epoch seconds of the minute containing start_time.
        window_start_minute_sec = int(start_time // 60000) * 60

        df_recent = X.withColumn(
            "timestamp_min", date_trunc("minute", col("timestamp"))
        )
        df_recent = df_recent.groupBy(["service_id", "timestamp_min"]).agg(
            func.count("status_code").alias("traffic_count")
        )
        df_recent = df_recent.union(self.df_cached)

        # A minute can appear both in the fresh counts and in the cache
        # (overlapping sliding windows); keep the larger count.
        df_recent = df_recent.groupBy(["service_id", "timestamp_min"]).agg(
            func.max("traffic_count").alias("traffic_count")
        )
        df_recent = df_recent.filter(
            col("timestamp_min").cast("long") >= window_start_minute_sec
        )
        self.df_cached = df_recent.cache()

        return self.df_cached

    def detect_anomaly(self, X, start_time, end_time):
        """
        Return the services whose window traffic is below the threshold.

        Returns
        -------
        list of dict
            One dict per offending service with keys ``service_id``,
            ``traffic_count`` and ``traffic_threshold``.
        """
        df = self.preprocess_df(X, start_time, end_time)

        df = df.groupBy(["service_id"]).agg(
            func.sum("traffic_count").alias("traffic_count")
        )
        totals = {row["service_id"]: row["traffic_count"] for row in df.collect()}
        # A monitored service with no traffic never appears in the aggregate,
        # yet silence is the strongest signal of all: count it as 0.
        for service_id in service_ids:
            totals.setdefault(service_id, 0)

        service_errors = []
        for service_id, traffic_count in totals.items():
            if traffic_count < self.traffic_threshold:
                service_errors.append(
                    {
                        "service_id": service_id,
                        "traffic_count": traffic_count,
                        "traffic_threshold": self.traffic_threshold,
                    }
                )
        return service_errors


class RollingAnomalyDetector(BaseAnomalyDetector):
    """
    Anomaly detector whose threshold adapts to recent history.

    Per-minute traffic is summed into tumbling ``window_size_mins``-minute
    windows. Each window is compared, via ``is_anomaly_udf``, against the mean
    and standard deviation of the windows that started in the preceding
    ``total_mins_to_look_back`` minutes. This replaces the manually chosen
    threshold of ``VanillaAnomalyDetector``.
    """

    def __init__(self):
        emptyRDD = spark.sparkContext.emptyRDD()
        agg_schema = StructType(
            [
                StructField("service_id", StringType(), False),
                StructField("timestamp_min", TimestampType(), False),
                StructField("traffic_count", IntegerType(), False),
            ]
        )
        # Per-(service, minute) traffic history accumulated across calls.
        self.df_cached = spark.createDataFrame(emptyRDD, agg_schema)

    def preprocess_df(self, X, start_time, end_time):
        """
        Count traffic per service per minute and merge it into the history.

        X is expected to carry ``timestamp``, ``service_id`` and
        ``status_code`` columns (the output of the module-level
        ``preprocess_df``). ``start_time``/``end_time`` (epoch ms) are used to
        add zero-count rows from ``get_empty_df``, so minutes with no data are
        present with a count of 0.

        NOTE(review): ``df_cached`` is never pruned or unpersisted, so the
        cached data and the query plan both grow with every call.
        """
        df_empty = get_empty_df(start_time, end_time)

        df_recent = X.withColumn(
            "timestamp_min", date_trunc("minute", col("timestamp"))
        )
        df_recent = df_recent.groupBy(["service_id", "timestamp_min"]).agg(
            func.count("status_code").alias("traffic_count")
        )
        df_recent = df_recent.union(df_empty)
        df_recent = df_recent.union(self.df_cached)

        # Fresh counts, zero placeholders and cached history can all hold the
        # same (service, minute); max() keeps the largest, so a real count
        # always wins over a 0 placeholder.
        df_recent = df_recent.groupBy(["service_id", "timestamp_min"]).agg(
            func.max("traffic_count").alias("traffic_count")
        )
        self.df_cached = df_recent
        self.df_cached = self.df_cached.cache()

        return self.df_cached

    def detect_anomaly(self, X, start_time, end_time):
        """
        Score every window in the history and flag anomalies.

        Returns
        -------
        pyspark.sql.DataFrame
            One row per (window start, service) with columns
            ``timestamp_min``, ``service_id``, ``traffic_count``,
            ``rolling_average``, ``rolling_std`` and ``is_anomaly``. Rows for
            all cached windows are returned, not only the current one.
        """
        df = self.preprocess_df(X, start_time, end_time)
        df_window = df.groupBy(
            [func.window("timestamp_min", f"{window_size_mins} minutes"), "service_id"]
        ).agg(func.sum("traffic_count").alias("traffic_count"))
        # Keep only the window's start bound. Reading it natively from the
        # struct avoids a Python UDF round trip; the column order is unchanged.
        df_window = df_window.select(
            col("window.start").alias("timestamp_min"), "service_id", "traffic_count"
        )

        # Statistics over windows that started within the look-back period
        # before the current one (offsets are in seconds; -1 excludes the
        # current window itself).
        window_spec = (
            Window.partitionBy(col("service_id"))
            .orderBy(col("timestamp_min").cast("long"))
            .rangeBetween(-rolling_time_in_seconds(total_mins_to_look_back), -1)
        )
        df_anomaly = (
            df_window.withColumn(
                "rolling_average", func.avg("traffic_count").over(window_spec)
            )
            .withColumn("rolling_std", func.stddev("traffic_count").over(window_spec))
            .withColumn(
                "is_anomaly",
                is_anomaly_udf("traffic_count", "rolling_average", "rolling_std"),
            )
        )
        return df_anomaly.cache()

In [None]:
def df_to_dict_list(df_res):
    """
    Collect the rows flagged as anomalies into a list of plain dicts.

    Only rows whose ``is_anomaly`` column equals true are kept; each collected
    Row is converted with ``asDict()``.
    """
    anomalous_rows = df_res.filter(df_res.is_anomaly == True).collect()  # noqa: E712
    return [row.asDict() for row in anomalous_rows]


def raise_alerts(dict_list, curr_window_long):
    """
    Print an alert for every anomalous entry in ``dict_list``.

    Parameters
    ----------
    dict_list : list of dict
        Each entry needs ``is_anomaly``; anomalous entries also need
        ``service_id``, ``timestamp_min``, ``traffic_count`` and
        ``rolling_average``.
    curr_window_long : int or float
        Start of the current window in epoch milliseconds, shown in the
        "no alerts" message.
    """
    window_label = datetime.utcfromtimestamp(curr_window_long / 1000)
    alerts_raised = 0
    for entry in dict_list:
        if not entry["is_anomaly"]:
            continue
        alerts_raised += 1
        print(
            f"{entry['service_id']} is anomalous in window {entry['timestamp_min']}"
        )
        print(
            f"Current Traffic Count: {entry['traffic_count']}. Expected Traffic Count: {entry['rolling_average']}"
        )

    # An empty dict_list also leaves alerts_raised at 0.
    if alerts_raised == 0:
        print(f"No Alerts Found in window: {window_label}")

In [None]:
# Detector instance used by the main loop below. It keeps its per-minute
# traffic history in `df_cached` across calls, so re-running this cell starts
# that history over from empty.
anomaly_detector = RollingAnomalyDetector()

In [None]:
def unix_to_time(x):
    """Convert an epoch timestamp in milliseconds to a naive UTC datetime."""
    seconds_since_epoch = x / 1000
    return datetime.utcfromtimestamp(seconds_since_epoch)

In [None]:
# This cell contains the main loop: it continuously consumes data from Kafka,
# assembles sliding windows and sends each window to the anomaly detector to
# classify whether the current window contains an anomaly.

# The first window starts at the timestamp of the very first record.
df = read_from_kafka(0, 1)
service_df = convert_kafka_to_service_df(df)
first_record = service_df.first()
if first_record is None:
    raise RuntimeError(f"No records found at offset 0 of topic {topic_name!r}")
window_start_time = first_record["timestamp"]

# The loop below runs n - 2 times (i = 2 .. n - 1).
n = 100

# Create window DataFrame
window_df = spark.createDataFrame([], stream_schema)

# Kafka offset to start consuming from.
# NOTE(review): offsets 1..2999 are never read although the first window starts
# at the timestamp of offset 0 — confirm this skip is intentional.
offset_counter = 3000

# Seconds to wait before re-polling when a requested offset range is empty.
empty_batch_sleep_secs = 5

for i in range(2, n):
    # Create the time window
    window_end_time = window_start_time + window_size_millisec
    window_df_end_time = 0
    # Pull fixed-size offset batches until the newest record read reaches the
    # end of the current window.
    while window_end_time > window_df_end_time:
        df = convert_kafka_to_service_df(
            read_from_kafka(offset_counter, offset_counter + offset_window_size)
        )
        # max() ignores null timestamps (e.g. from unparseable JSON) and
        # returns None for an empty batch instead of raising.
        batch_end_time = df.agg(func.max("timestamp")).first()[0]
        if batch_end_time is None:
            # No usable records yet: wait and retry the same offsets rather
            # than skipping them.
            time.sleep(empty_batch_sleep_secs)
            continue
        window_df = window_df.union(df)
        window_df_end_time = batch_end_time
        offset_counter = offset_counter + offset_window_size

    # Remove values before the sliding window, then materialise the result.
    # localCheckpoint() truncates the lineage, so earlier Kafka batches are not
    # re-read on every subsequent action.
    window_df = window_df.filter(window_df.timestamp >= window_start_time)
    window_df = window_df.localCheckpoint()

    # Only send rows that are within the window
    df_window_final = window_df.filter(window_df.timestamp <= window_end_time)

    df_window_final = preprocess_df(df_window_final)
    df_anomalies_detected = anomaly_detector.detect_anomaly(
        df_window_final, start_time=window_start_time, end_time=window_end_time
    )
    anomaly_result = df_to_dict_list(df_anomalies_detected)
    raise_alerts(anomaly_result, window_start_time)
    print("\n")

    # Update window start time
    window_start_time = window_start_time + window_period_millisec