In [None]:
import os

from dotenv import load_dotenv
from pyspark import SparkContext, SparkConf
from pyspark.ml.pipeline import PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import (
    col, format_string, lit, lower, split, struct, to_json, udf, when,
)
from pyspark.sql.types import StringType
load_dotenv()  # load variables from a .env file into os.environ before the os.getenv() lookups below

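## Constants

Most settings can be overridden with environment variables (also read from a `.env` file by `load_dotenv()`).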

In [None]:
# --- Input files and model artifacts -------------------------------------
TRAINING_FILE = os.getenv("TRAINING_FILE", "dataset/dataset.csv")
PRETRAINED_MODEL_PATH = os.getenv("PRETRAINED_MODEL_PATH", "model/trained.model")
STOPWORDS_PATH = os.getenv("STOPWORDS_PATH", "dataset/stopwords.txt")

# --- Spark cluster -------------------------------------------------------
SPARK_MASTER = os.getenv("SPARK_MASTER", "spark://gpu3.esw:7077")
SPARK_APP_NAME = "Final - PSPD - Predict"
# Kafka connector for Structured Streaming, fetched via spark.jars.packages.
PACKAGES = "org.apache.spark:spark-sql-kafka-0-10_2.12:3.2.0"

# --- Kafka ---------------------------------------------------------------
KAFKA_SERVER = os.getenv("KAFKA_SERVER", "localhost:9092")
PREDICT_TOPIC = os.getenv("PREDICT_TOPIC", "election")        # input messages
STATS_TOPIC = os.getenv("STATS_TOPIC", "elasticsearch-sink")  # aggregated counts

# --- Streaming trigger ---------------------------------------------------
INTERVAL = os.getenv("INTERVAL", "10 seconds")

## Startup

In [None]:
# Build the Spark configuration (app name, master URL and the extra Kafka
# connector package), then start the SparkContext from it and show only
# ERROR-level logs.
conf = (
    SparkConf()
    .setAppName(SPARK_APP_NAME)
    .setMaster(SPARK_MASTER)
    .set("spark.jars.packages", PACKAGES)
)

context = SparkContext(conf=conf)
context.setLogLevel("ERROR")

In [None]:
# NOTE(review): getOrCreate() should reuse the SparkContext created in the
# previous cell (and therefore its conf) instead of building a new one — confirm.
spark = SparkSession.builder.getOrCreate()

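## Cleaner

Text cleaning applied to each message before prediction: remove punctuation, symbols and digits, then drop stopwords.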

In [None]:
import re

# Characters stripped from every message: punctuation, brackets, quotes,
# symbols and digits (\d). `-` is escaped so it is a literal; unescaped,
# `+-:` was parsed as the range '+'..':', which covered the same characters
# only by accident.
CLEAN_REGEX = r"[.,/\\\[\]\{\}`~^\d&!@#$%*\)\(\'\"<>=+\-:;?]"
_CLEAN_PATTERN = re.compile(CLEAN_REGEX)


def _load_stopwords(path):
    """Read one stopword per line from ``path``; return them stripped and lower-cased."""
    # Explicit UTF-8: the platform default encoding is not guaranteed to be
    # UTF-8, and stopword lists commonly contain accented words.
    with open(path, "r", encoding="utf-8") as stop_file:
        return {line.strip().lower() for line in stop_file}


stopwords = _load_stopwords(STOPWORDS_PATH)


def cleaner(sentence):
    """Remove CLEAN_REGEX characters and stopwords from ``sentence``.

    Args:
        sentence: message text (the Prediction cell lower-cases it before
            calling this UDF), or None for a SQL NULL.

    Returns:
        The remaining words joined by single spaces, or None when
        ``sentence`` is None (previously this raised TypeError and failed
        the whole batch).
    """
    if sentence is None:
        return None
    words = _CLEAN_PATTERN.sub('', sentence).split()
    return " ".join(word for word in words if word not in stopwords)


cleaner_col = udf(cleaner, StringType())

## Load Pre-trained Model

In [None]:
# Load the fitted pipeline saved at PRETRAINED_MODEL_PATH; foreach_batch_func
# applies it to every micro-batch.
model = PipelineModel.load(path=PRETRAINED_MODEL_PATH)

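## Prediction

For each micro-batch: split the Kafka value into candidate and message, clean the message, score it with the model, print the results, and publish per-candidate counts to the stats topic.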

In [None]:
def _parse_messages(df: DataFrame) -> DataFrame:
    """Split raw records into ``candidate`` and cleaned ``sentence`` columns.

    Each record value is expected to look like ``"<candidate>,<message>"``.
    Records with no comma (no message part) are dropped so the Python cleaner
    UDF is never handed a NULL.
    """
    # The Kafka source exposes `value` as binary: cast explicitly instead of
    # relying on implicit coercion. limit=2 keeps further commas in the message.
    candidate_message = split(col("value").cast("string"), ",", 2)
    message = candidate_message.getItem(1)
    return df \
        .where(message.isNotNull()) \
        .withColumn("candidate", candidate_message.getItem(0)) \
        .withColumn("sentence", cleaner_col(lower(message)))


def _to_elastic_records(prediction: DataFrame) -> DataFrame:
    """Count predictions per (candidate, label) and encode them as Kafka rows.

    Returns ``key`` (always ``"1"``) and ``value``, a JSON object keyed by the
    label, e.g. ``{"candidate":"X","positive":3}``. Labels are expected to be
    only "positive"/"negative", as produced by foreach_batch_func.
    """
    counts = prediction.groupBy("candidate", "prediction").count()
    # to_json escapes quotes/backslashes in the (untrusted) candidate name;
    # the previous hand-built format_string emitted invalid JSON for them.
    positive = to_json(struct(col("candidate"), col("count").alias("positive")))
    negative = to_json(struct(col("candidate"), col("count").alias("negative")))
    value = when(col("prediction") == "positive", positive).otherwise(negative)
    return counts.select(lit('1').alias("key"), value.alias("value"))


def foreach_batch_func(df: DataFrame, _):
    """Score one streaming micro-batch and publish the results.

    Passed to ``foreachBatch``; the second argument (batch id) is unused.
    Parses and cleans the messages, runs the pre-trained ``model``, prints the
    scored rows to the console, then writes per-candidate counts as JSON to
    the Kafka ``STATS_TOPIC``.
    """
    sentences = _parse_messages(df)

    prediction = model.transform(sentences).select(
        "candidate",
        "sentence",
        "probability",
        when(col("prediction") == 1.0, "positive").otherwise("negative").alias("prediction"),
    )

    # `prediction` feeds two sinks; persist it so the model and the Python
    # cleaner UDF run once per batch instead of once per write.
    prediction.persist()
    try:
        prediction.write.format("console").save()

        _to_elastic_records(prediction).write \
            .format("kafka") \
            .option("kafka.bootstrap.servers", KAFKA_SERVER) \
            .option("topic", STATS_TOPIC) \
            .save()
    finally:
        prediction.unpersist()

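## Sink

Start the streaming query that reads `PREDICT_TOPIC` from Kafka and processes every micro-batch with `foreach_batch_func`.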

In [None]:
# Directory for the streaming query's checkpoint. Defaults to /tmp as before;
# override with the CHECKPOINT_LOCATION environment variable (e.g. to point at
# durable storage).
CHECKPOINT_LOCATION = os.getenv("CHECKPOINT_LOCATION", "/tmp/spark/mllib-predict")

# Read PREDICT_TOPIC from Kafka and run foreach_batch_func on each micro-batch,
# triggered every INTERVAL. failOnDataLoss=false keeps the query running
# instead of failing when expected offsets are no longer available.
# `lines` holds the handle returned by start(); it is stopped in the End cell.
lines = spark \
    .readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", KAFKA_SERVER) \
    .option("subscribe", PREDICT_TOPIC) \
    .option("failOnDataLoss", "false") \
    .load() \
    .writeStream \
    .foreachBatch(foreach_batch_func) \
    .option("checkpointLocation", CHECKPOINT_LOCATION) \
    .trigger(processingTime=INTERVAL) \
    .start()

## End

Stop the streaming query, then the Spark session and context.

In [None]:
lines.stop()  # stop the streaming query started in the Sink cell

In [None]:
# Shut down the Spark session, then the SparkContext created in the Startup cell.
# NOTE(review): SparkSession.stop() normally stops its underlying SparkContext
# too, so context.stop() is likely a harmless no-op kept as a safeguard — confirm.
spark.stop()
context.stop()