In [31]:
import os
os.environ['HADOOP_HOME'] = r"C:\hadoop"
os.environ['PATH'] += r";C:\hadoop\bin"

In [32]:
import findspark
# Raw string: in a plain literal "\s" is an invalid escape sequence
# (SyntaxWarning on Python 3.12+). The resulting path is identical.
findspark.init(r"C:\spark-3.5.5-bin-hadoop3")
import pyspark
from pyspark.sql import SparkSession
from pyspark.ml.feature import RegexTokenizer, StopWordsRemover, CountVectorizerModel
from pyspark.ml.clustering import LDAModel

# Connector coordinates: the Kafka SQL connector is built per Scala/Spark
# version, so these must match the Spark install initialised above.
scala_version = "2.12"
spark_version = "3.5.5"
packages = [
    f"org.apache.spark:spark-sql-kafka-0-10_{scala_version}:{spark_version}",
    "org.apache.kafka:kafka-clients:3.6.0",
]

# Alternative settings that were previously kept commented out here, for
# reference if local-filesystem issues show up on Windows:
#   spark.hadoop.fs.file.impl = org.apache.hadoop.fs.LocalFileSystem
#   spark.hadoop.fs.AbstractFileSystem.file.impl = org.apache.hadoop.fs.local.LocalFs
spark = (
    SparkSession.builder
    .appName("RedditTopicModeling")
    .master("local")
    # Resolve the Kafka connector jars listed above at session start.
    .config("spark.jars.packages", ",".join(packages))
    # Allow Spark to delete temporary checkpoint directories it creates.
    .config("spark.sql.streaming.forceDeleteTempCheckpointLocation", "true")
    # NOTE(review): presumably disables Hadoop native IO to avoid Windows
    # native-library problems — confirm.
    .config("spark.hadoop.io.nativeio.enabled", "false")
    # Broadcast timeout in seconds.
    .config("spark.sql.broadcastTimeout", "600")
    .getOrCreate()
)


In [33]:
from pyspark.sql.functions import from_json, col, concat_ws, from_unixtime, timestamp_seconds
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, TimestampType
import os

# JSON layout expected in each Kafka message value.
# NOTE(review): Reddit APIs commonly report created_utc as a float; if the
# producer sends floats, LongType may not parse them — confirm with producer.
schema = (
    StructType()
    .add("id", StringType())
    .add("title", StringType())
    .add("selftext", StringType())
    .add("created_utc", LongType())
)

# Unbounded stream of raw Kafka records from the `reddit_posts` topic.
raw = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "localhost:9092")
    .option("subscribe", "reddit_posts")
    .load()
)

# Parse each message and derive the columns used downstream.
posts = (
    raw
    .selectExpr("CAST(value AS STRING) AS json")
    .select(from_json(col("json"), schema).alias("data"))
    .select("data.id", "data.title", "data.selftext", "data.created_utc")
    # Title and body joined with a single space; concat_ws skips null parts.
    .withColumn("text", concat_ws(" ", col("title"), col("selftext")))
    # Convert epoch seconds directly to a timestamp. Formatting with
    # from_unixtime and casting the string back goes through session-local
    # wall-clock time, which is ambiguous during DST fall-back.
    .withColumn("timestamp", timestamp_seconds(col("created_utc")))
)

In [34]:
from bertopic import BERTopic
topic_model = BERTopic.load("modeling/bertopic")
info = topic_model.get_topic_info()  
label_map = dict(zip(info["Topic"], info["Name"]))

In [None]:
from pyspark.sql.streaming import StreamingQuery
import pandas as pd
import numpy as np
from pyspark.sql.functions import (
    from_json, col, concat_ws, from_unixtime, pandas_udf
)
from pyspark.sql.functions import split, regexp_replace, col

# Single CSV file that foreach_batch appends every micro-batch to.
csv_path = "output/all_results.csv"
# Checkpoint directory for the streaming query started below.
checkpoint_path = "output/checkpoints/bertopic_fbb"

@pandas_udf(StringType())
def predict_label(texts: pd.Series) -> pd.Series:
    """Predict the BERTopic topic name for each text in an Arrow batch.

    Uses the notebook-level ``topic_model`` for inference and ``label_map``
    (topic id -> topic name) for the lookup.

    Returns a Series of the same length as ``texts``. An element is null when
    the predicted topic is None or its id is absent from ``label_map``.
    """
    if texts.empty:
        # Nothing to embed; avoid calling transform() with an empty list.
        return pd.Series([], dtype=object)

    topics, _ = topic_model.transform(texts.tolist())
    # For each topic t: keep None as None, otherwise look up label_map[t].
    return pd.Series(
        [label_map.get(int(t)) if t is not None else None for t in topics],
        dtype=object,
    )

# Attach the predicted topic label (via the pandas UDF) to every streamed post.
annotated = posts.withColumn("topic", predict_label(col("text")))

def foreach_batch(df, epoch_id):
    """Append one micro-batch of (text, topic) rows to ``csv_path``.

    Spark's ``foreachBatch`` sink calls this with the micro-batch DataFrame and
    its batch id. Empty batches are skipped. Topic names have their leading
    "<id>_" prefix removed and their remaining underscores turned into spaces.

    The header row is written whenever the CSV does not exist yet or is empty,
    instead of only when ``epoch_id == 0``. Batch 0 can be empty and return
    early, and a query resumed from a checkpoint never sees batch 0 again.
    Either case previously produced a CSV without a header.

    foreachBatch is at-least-once, so a retried batch may be appended twice.
    """
    pdf = df.select("text", "topic").toPandas()
    if pdf.empty:
        return

    pdf["topic"] = (
        pdf["topic"]
        .str.split("_", n=1).str.get(1)      # drop the leading "<id>_"
        .str.replace("_", " ", regex=False)  # underscores -> spaces
    )

    # to_csv does not create missing parent directories.
    out_dir = os.path.dirname(csv_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    pdf.to_csv(csv_path, mode="a", index=False, header=write_header)

if posts.isStreaming:
    print("We are streaming!")

    # Keep only the columns that foreach_batch writes out.
    cleaned = annotated.select(col("text"), col("topic"))

    # foreachBatch hands each micro-batch to foreach_batch, which appends it to
    # a single CSV. Sink options such as "header" have no effect here, so none
    # are set. The checkpoint records progress so a restart resumes where it
    # stopped.
    query = (
        cleaned.writeStream
        .foreachBatch(foreach_batch)
        .option("checkpointLocation", checkpoint_path)
        .trigger(processingTime="5 seconds")
        .start()
    )

    # Block this cell until the query stops; re-raises the query's exception
    # if it terminated with one.
    query.awaitTermination()

We are streaming!
