In [1]:
import pandas as pd
from tabulate import tabulate
from sklearn.metrics import accuracy_score, f1_score

from pyspark import SparkContext
from pyspark.sql import SparkSession
import pyspark.sql.functions as f
from pyspark.ml import PipelineModel

import os

# Request the Kafka source connector package for Spark. The variable is set
# here, before the SparkSession is created further down, so the package list
# is already in the environment when Spark starts.
os.environ["PYSPARK_SUBMIT_ARGS"] = (
    "--packages org.apache.spark:spark-sql-kafka-0-10_2.12:3.1.2 pyspark-shell"
)

In [2]:
!pip3 install -q tabulate

[0m

In [3]:
# Obtain the session for this application (an existing one is reused if present).
spark = (
    SparkSession.builder
    .appName("CS4830_project")
    .getOrCreate()
)

# Limit Spark's own logging to warnings and above to keep cell output readable.
spark.sparkContext.setLogLevel("WARN")

:: loading settings :: url = jar:file:/usr/lib/spark/jars/ivy-2.4.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /root/.ivy2/cache
The jars for the packages stored in: /root/.ivy2/jars
org.apache.spark#spark-sql-kafka-0-10_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-638babf7-b71a-4ffb-82be-1a644986c1a5;1.0
	confs: [default]
	found org.apache.spark#spark-sql-kafka-0-10_2.12;3.1.2 in central
	found org.apache.spark#spark-token-provider-kafka-0-10_2.12;3.1.2 in central
	found org.apache.kafka#kafka-clients;2.6.0 in central
	found com.github.luben#zstd-jni;1.4.8-1 in central
	found org.lz4#lz4-java;1.7.1 in central
	found org.xerial.snappy#snappy-java;1.1.8.2 in central
	found org.slf4j#slf4j-api;1.7.30 in central
	found org.spark-project.spark#unused;1.0.0 in central
	found org.apache.commons#commons-pool2;2.6.2 in central
:: resolution report :: resolve 728ms :: artifacts dl 16ms
	:: modules in use:
	com.github.luben#zstd-jni;1.4.8-1 from central in [default]
	org.apache.commons#commons-pool2;2.6.2 from central in [default]


In [4]:
# TODO: update the broker address and topic below for the demo environment.
# Kafka connection settings used by the streaming reader: BROKER is passed as
# the "kafka.bootstrap.servers" option and TOPIC as the "subscribe" option.
BROKER = "10.128.0.40:9092"
TOPIC = "CS4830-project"

In [5]:
# Storage location of the saved pipeline that PipelineModel.load() reads.
MODEL_PATH = "gs://big-data-cs4830/project/final_model"

In [6]:
# Open the Kafka topic as a streaming DataFrame.
kafka_stream = spark.readStream.format("kafka")
kafka_stream = kafka_stream.option("kafka.bootstrap.servers", BROKER)
kafka_stream = kafka_stream.option("subscribe", TOPIC)
df = kafka_stream.load()

# Each message value is split on commas into positional fields.
split_cols = f.split(df.value, ",")

# Names for fields 1..9 of the split record, in positional order.
# Field 0 is not extracted.
FIELD_NAMES = [
    "Feet From Curb",
    "Violation In Front Of Or Opposite",
    "Issuing Agency",
    "Violation County",
    "Plate Type",
    "Violation Code",
    "Registration State",
    "Issuer Squad",
    "Violation Precinct",
]

# Attach one named column per field, keeping the original column order.
for field_position, field_name in enumerate(FIELD_NAMES, start=1):
    df = df.withColumn(field_name, split_cols.getItem(field_position))

In [7]:
# Restore the saved pipeline from MODEL_PATH; it is later applied to the
# streaming DataFrame through pipeline.transform().
pipeline = PipelineModel.load(MODEL_PATH)

22/05/15 12:17:11 WARN org.apache.hadoop.util.concurrent.ExecutorHelper: Thread (Thread[GetFileInfo #1,5,main]) interrupted: 
java.lang.InterruptedException
	at com.google.common.util.concurrent.AbstractFuture.get(AbstractFuture.java:510)
	at com.google.common.util.concurrent.FluentFuture$TrustedFuture.get(FluentFuture.java:88)
	at org.apache.hadoop.util.concurrent.ExecutorHelper.logThrowableFromAfterExecute(ExecutorHelper.java:48)
	at org.apache.hadoop.util.concurrent.HadoopThreadPoolExecutor.afterExecute(HadoopThreadPoolExecutor.java:90)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1157)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:750)


In [8]:
# Run the loaded pipeline over the stream, then keep only the two columns the
# per-batch report needs: the precinct carried in the message and the
# pipeline's prediction.
scored_stream = pipeline.transform(df)
df = scored_stream.select("Violation Precinct", "prediction")

In [9]:
def batch_function(df, batch_id):
    """Print the predictions and quality metrics for one streaming micro-batch.

    Used as the ``foreachBatch`` callback of the streaming query.

    Args:
        df: Micro-batch DataFrame exposing ``toPandas()`` and containing the
            columns "Violation Precinct" (the value received in the message,
            used as the reference label) and "prediction".
        batch_id: Identifier of the micro-batch; only printed.

    Returns:
        None. Nothing is printed for an empty batch.
    """
    # Materialise the batch exactly once. The row count and the emptiness
    # check are derived from the pandas frame instead of separate df.count()
    # calls, each of which would evaluate the micro-batch again.
    batch_pd = df.toPandas()
    n_rows = len(batch_pd)
    if n_rows == 0:
        return

    # scikit-learn metrics take (y_true, y_pred) in that order. The order
    # matters for average="weighted", which weights each class by its support
    # in y_true.
    y_true = batch_pd["Violation Precinct"]
    y_pred = batch_pd["prediction"]
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average="weighted")

    output = pd.DataFrame(
        [["Accuracy", acc], ["F1-Score", f1]], columns=["Metric", "Value"]
    )

    print("#" * 88)
    print("Batch:", batch_id, "| COUNT:", n_rows)

    print(tabulate(batch_pd, headers="keys", tablefmt="psql", showindex=False))
    print(tabulate(output, headers="keys", tablefmt="psql", showindex=False))

    print("#" * 88)

In [10]:
# Start the streaming query; batch_function is invoked for every micro-batch.
query = df.writeStream.foreachBatch(batch_function).start()
try:
    # Block this cell while the stream is being processed.
    query.awaitTermination()
finally:
    # Always stop the query when the wait ends, including when the cell is
    # interrupted from the notebook, so the stream is not left running after
    # the cell has returned. The exception itself still propagates.
    query.stop()

22/05/15 12:17:23 WARN org.apache.spark.sql.streaming.StreamingQueryManager: Temporary checkpoint location created which is deleted normally when the query didn't fail: /tmp/temporary-bb8eb881-4d1e-4ccb-a176-768453351132. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
22/05/15 12:17:23 WARN org.apache.spark.sql.streaming.StreamingQueryManager: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
                                                                                

########################################################################################
Batch: 1 | COUNT: 1
+----------------------+--------------+
|   Violation Precinct |   prediction |
|----------------------+--------------|
|                   71 |           79 |
+----------------------+--------------+
+----------+---------+
| Metric   |   Value |
|----------+---------|
| Accuracy |       0 |
| F1-Score |       0 |
+----------+---------+
########################################################################################
########################################################################################
Batch: 2 | COUNT: 50
+----------------------+--------------+
|   Violation Precinct |   prediction |
|----------------------+--------------|
|                  108 |          109 |
|                  109 |          110 |
|                   71 |           77 |
|                    0 |            0 |
|                   34 |           19 |
|                  115 |          

KeyboardInterrupt: 