In [None]:
import findspark
# Make the local Spark installation importable as `pyspark`. This must run
# before `pyspark` is imported in the next cell. findspark presumably locates
# Spark via SPARK_HOME — confirm for your environment.
findspark.init()

In [None]:
from pyspark.sql import SparkSession

# Run RDD operations and SQL shuffles as a single partition each, instead of
# the default of spark.sql.shuffle.partitions (200). This is presumably sized
# for a small local stream.
# NOTE(review): format("kafka") later needs the spark-sql-kafka connector on
# the classpath, and this builder does not add it. Presumably it is supplied
# via spark-defaults / PYSPARK_SUBMIT_ARGS — confirm.
spark = (
    SparkSession.builder
    .appName("KafkaStructuredStreaming")
    .config("spark.default.parallelism", 1)
    .config("spark.sql.shuffle.partitions", 1)
    .getOrCreate()
)

In [None]:
spark  # show the active SparkSession's notebook repr to confirm it started

In [None]:
import nltk

# Models used by get_entities_count: word_tokenize -> pos_tag -> ne_chunk.
# NLTK >= 3.9 loads pickle-free replacements ("*_tab" / "*_eng") instead of
# the legacy packages. Request both generations so the notebook works on old
# and new NLTK alike. nltk.download reports a name that the installed
# version's index doesn't know and returns False; it does not raise.
NLTK_RESOURCES = [
    "punkt",                           # sentence/word tokenizer (legacy)
    "punkt_tab",                       # sentence/word tokenizer (NLTK >= 3.8.2)
    "averaged_perceptron_tagger",      # word tagging, pos_tag (legacy)
    "averaged_perceptron_tagger_eng",  # word tagging, pos_tag (NLTK >= 3.9)
    "maxent_ne_chunker",               # named-entity chunker (legacy)
    "maxent_ne_chunker_tab",           # named-entity chunker (NLTK >= 3.9)
    "words",                           # word list used for entity extraction
]
for nltk_resource in NLTK_RESOURCES:
    nltk.download(nltk_resource)
from functools import reduce

def get_entities_count(sentence):
    """Return the lower-cased words that NLTK marks as part of a named entity.

    The text is tokenized, POS-tagged and run through ``nltk.ne_chunk``.
    Every leaf of every named-entity subtree contributes one entry, so a
    multi-word entity yields one entry per word. Despite the name, nothing is
    counted here; the streaming query explodes this list and counts per word.

    Args:
        sentence: Text to analyse. ``None`` or ``""`` (e.g. a Kafka record
            whose JSON had no ``body`` field) yields an empty list.

    Returns:
        list[str]: Entity words in order of appearance, duplicates kept.
    """
    # Spark hands SQL NULL to a Python UDF as None. Tokenizing None raises
    # TypeError, which would fail the entire streaming query, so treat
    # missing text as "no entities".
    if not sentence:
        return []

    tokens = nltk.word_tokenize(sentence)
    tagged_words = nltk.pos_tag(tokens)
    chunks = nltk.ne_chunk(tagged_words)

    # Entities come back as nltk.Tree subtrees whose leaves are the tagged
    # tokens; leaf[0] is the word itself. A flat comprehension replaces the
    # quadratic reduce(list.__add__, ...) concatenation.
    return [
        leaf[0].lower()
        for subtree in chunks
        if isinstance(subtree, nltk.Tree)
        for leaf in subtree.leaves()
    ]

In [None]:
from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, ArrayType
from pyspark.sql.functions import from_json, to_json, struct, col, explode, count

kafka_bootstrap_servers = "localhost:9092"
kafka_read_topic = "topic1"    # input: JSON messages with "body" and "timestamp"
kafka_write_topic = "topic2"   # output: per-entity counts as JSON
# Expected shape of each input message value. Both fields are nullable, and
# from_json returns null for a key that is missing from the message.
# "timestamp" is 64-bit: epoch values in milliseconds (or seconds past 2038)
# do not fit IntegerType and would otherwise parse as null.
schema = StructType([
    StructField("body", StringType(), True),
    StructField("timestamp", LongType(), True)
])
# Directory where the streaming query keeps its offsets and aggregation state.
checkpoint_location = "./kafka_checkpoints"

# Wrap get_entities_count as a Spark UDF returning array<string>. The `False`
# declares that the array never contains null elements.
map_udf = udf(get_entities_count, ArrayType(StringType(), False))

In [None]:
# Pipeline: Kafka record -> JSON fields -> entity words -> running count per
# word, ranked by count and re-encoded as JSON in the `value` column that the
# Kafka sink writes.
streaming_df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", kafka_bootstrap_servers)
    .option("subscribe", kafka_read_topic)
    # Rate limit: at most 1000 offsets per micro-batch.
    .option("maxOffsetsPerTrigger", 1000)
    .load()
    # Kafka delivers `value` as binary; decode it before parsing the JSON.
    .selectExpr("CAST(value AS STRING)")
    .select(from_json(col("value"), schema).alias("data"))
    .select("data.*")
    .withColumn("entities", map_udf(col("body")))
    # One row per entity word.
    .select(explode(col("entities")).alias("entity"), "timestamp", "body")
    .groupBy("entity")
    .agg(count("entity").alias("count"))
    .select("entity", "count")
    # Sorting a streaming aggregate is only allowed with the "complete"
    # output mode, which the writer cell uses.
    .orderBy(col("count").desc())
    .withColumn("value", to_json(struct("entity", "count")))
)

In [None]:
# Final columns: entity (string), count, and value (JSON string sent to Kafka).
streaming_df.printSchema()

In [None]:
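# Debug alternative to the Kafka sink below: print each trigger's full
# ranking with the console sink. Its output goes to the Spark driver's stdout,
# typically the terminal running Jupyter rather than this cell's output.
# query = streaming_df \
#     .writeStream \
#     .outputMode("complete") \
#     .format("console") \
#     .start()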

In [None]:
# query.awaitTermination()

In [None]:
# Publish the full ranking to Kafka on every trigger. "complete" output mode
# is required because streaming_df sorts an aggregate. The query starts
# running in the background as soon as start() returns.
query = (
    streaming_df
    # Collapse to one partition WITHOUT a shuffle so the descending order
    # produced by orderBy is the order written to the topic. The previous
    # repartition(1) added a round-robin shuffle, which does not guarantee
    # that the sort survives.
    .coalesce(1)
    .writeStream
    .outputMode("complete")
    .format("kafka")
    .option("kafka.bootstrap.servers", kafka_bootstrap_servers)
    .option("topic", kafka_write_topic)
    # Offsets and aggregation state persist here across restarts.
    .option("checkpointLocation", checkpoint_location)
    .start()
)

In [None]:
# The stream to topic2 was already started by .start() in the previous cell.
# This call only blocks the cell until the query stops; if the query failed,
# its exception is raised here.
query.awaitTermination()