In [1]:
import os

# AWS credentials taken from the environment (each is None when the variable
# is unset). Presumably intended for the AWS Keyspaces sink, which is
# currently commented out; the active Postgres write path does not use them.
aws_access_key, aws_secret_key = (
    os.getenv(env_var) for env_var in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
)

In [2]:
from pyspark.sql import SparkSession

# Maven coordinates resolved when the session starts: the Structured Streaming
# Kafka source and the PostgreSQL JDBC driver used by the sink.
SPARK_PACKAGES = [
    "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.5",
    "org.postgresql:postgresql:42.7.8",
]

session_builder = SparkSession.builder.appName("Kafka-Keyspaces-Write").master("local[*]")
session_builder = session_builder.config("spark.jars.packages", ",".join(SPARK_PACKAGES))
# Allow Spark to delete the temporary checkpoint directories it creates for
# streaming queries started without an explicit checkpointLocation.
session_builder = session_builder.config(
    "spark.sql.streaming.forceDeleteTempCheckpointLocation", "true"
)
spark = session_builder.getOrCreate()

spark.sparkContext.setLogLevel("INFO")

print("Spark started successfully")
print("Version:", spark.version)


Spark started successfully
Version: 3.5.5


In [None]:
# Display the active SparkSession via its notebook representation.
spark

In [None]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType , DoubleType

def _nullable_struct(*fields):
    """Return a StructType of nullable StructFields built, in order, from (name, type) pairs."""
    return StructType([StructField(name, data_type, True) for name, data_type in fields])


# Row image of public.fund_metadata as carried in the "before"/"after" fields
# of a change event. updated_at is a raw epoch integer; it is converted to a
# timestamp downstream.
fund_schema = _nullable_struct(
    ("fund_id", IntegerType()),
    ("fund_name", StringType()),
    ("fund_code", StringType()),
    ("fund_description", StringType()),
    ("updated_at", LongType()),
    ("fund_price", DoubleType()),
)

# Provenance block of a change event: connector/database/table identifiers
# plus transaction and position fields.
source_schema = _nullable_struct(
    ("version", StringType()),
    ("connector", StringType()),
    ("name", StringType()),
    ("ts_ms", LongType()),
    ("snapshot", StringType()),
    ("db", StringType()),
    ("sequence", StringType()),
    ("schema", StringType()),
    ("table", StringType()),
    ("txId", LongType()),
    ("lsn", LongType()),
    ("xmin", StringType()),
)

# Complete Debezium-style change envelope found in the Kafka message value.
# "op" holds the operation code; "d" is treated as a delete downstream.
cdc_schema = _nullable_struct(
    ("before", fund_schema),
    ("after", fund_schema),
    ("source", source_schema),
    ("op", StringType()),
    ("ts_ms", LongType()),
    ("transaction", StringType()),
    ("source_system", StringType()),
)


In [None]:
# Streaming source: the CDC topic for public.fund_metadata.
#   startingOffsets=earliest -> replay every retained message, then follow new
#                               ones (only used when the query has no
#                               checkpoint to resume from)
#   failOnDataLoss=false     -> keep the query running when offsets are no
#                               longer available instead of failing it; the
#                               missing messages are skipped
kafka_source_options = {
    "kafka.bootstrap.servers": "localhost:29092",
    "subscribe": "pgsrc.public.fund_metadata",
    "startingOffsets": "earliest",
    "failOnDataLoss": "false",
}

df = spark.readStream.format("kafka").options(**kafka_source_options).load()

In [None]:
from pyspark.sql.functions import col, from_json
# Kafka delivers the message value as binary: decode it to a JSON string,
# parse it against the CDC envelope schema, then promote the envelope fields
# to top-level columns.
parsed_df = (
    df
    .select(col("value").cast("string").alias("json_str"))
    .select(from_json(col("json_str"), cdc_schema).alias("data"))
    .select("data.*")
)

In [None]:
from pyspark.sql.functions import col, when, to_timestamp

from pyspark.sql.functions import timestamp_millis

# Reduce each change event to: the operation code, the event time, and the
# single row image ("Values") that describes the affected row.
actual_df = (
    parsed_df.select(
        col("op").alias("Operation"),
        # ts_ms is epoch milliseconds. timestamp_millis converts the long
        # exactly; dividing by 1000 and casting goes through a double and a
        # truncating cast, which can land one microsecond off.
        timestamp_millis(col("ts_ms")).alias("ProcessTime"),
        col("before"),
        col("after"),
    )
    .withColumn(
        "Values",
        # Deletes ('d') take their row image from "before" (presumably "after"
        # is null for deletes); every other operation uses "after".
        when(col("Operation") == "d", col("before")).otherwise(col("after")),
    )
    .drop("before", "after")
)




In [None]:
from pyspark.sql.functions import current_timestamp
from pyspark.sql.functions import timestamp_micros

# Flatten the selected row image into the columns written to
# public.fund_metadata_trail, and add the audit columns.
final_df = (
    actual_df
    # Drop events with no operation code (e.g. null-value messages or JSON that
    # did not parse) and events whose row image has no fund_id.
    .filter(
        col("Operation").isNotNull() &
        col("Values.fund_id").isNotNull()
    )
    .select(
        col("Operation"),
        col("ProcessTime"),
        col("Values.fund_id").alias("fund_id"),
        col("Values.fund_name").alias("fund_name"),
        col("Values.fund_code").alias("fund_code"),
        col("Values.fund_description").alias("fund_description"),
        # updated_at is treated as epoch microseconds (presumably a Debezium
        # MicroTimestamp — confirm). timestamp_micros converts the long exactly;
        # dividing by 1_000_000 and casting goes through a double and a
        # truncating cast, which can lose a microsecond.
        timestamp_micros(col("Values.updated_at")).alias("updated_at"),
        col("Values.fund_price").alias("fund_price"),
    )
    # 'Y' marks rows produced by a delete event, 'N' everything else.
    .withColumn("delete_flag", when(col("Operation") == "d", "Y").otherwise("N"))
    # Processing time assigned by this job via current_timestamp().
    .withColumn("effective_date", current_timestamp())
)



In [None]:
# Postgres sink connection settings. Each value can be overridden through an
# environment variable, so real credentials never need to be edited into (and
# committed with) the notebook. The fallbacks are local-development values only.
jdbc_url = os.getenv("PG_JDBC_URL", "jdbc:postgresql://localhost:5432/finance")
jdbc_properties = {
    "user": os.getenv("PG_USER", "pguser"),
    "password": os.getenv("PG_PASSWORD", "pgpassword"),
    "driver": "org.postgresql.Driver",
}

In [None]:
# The log level is left as configured when the session was created. DEBUG is
# extremely verbose for a long-running stream; only switch to it manually with
# spark.sparkContext.setLogLevel("DEBUG") while troubleshooting.

def write_to_multiple_sink(batch_df, batch_id):
    """Write one micro-batch of fund change rows to the Postgres trail table.

    Used as the ``foreachBatch`` callback of the streaming query, so Spark
    calls it once per micro-batch.

    Args:
        batch_df: static DataFrame holding this micro-batch's rows
            (the columns of ``final_df``).
        batch_id: micro-batch id supplied by Spark.

    Raises:
        Exception: any failure of the JDBC write is re-raised so the streaming
            query stops instead of silently dropping the batch.
    """
    print(f"\n--- Writing Batch {batch_id} ---")

    # The batch is consumed by more than one action (show + JDBC write).
    # Without caching, each action re-reads the Kafka offsets and re-parses
    # the messages.
    batch_df.persist()
    try:
        batch_df.show(truncate=False)
        print("COLUMNS:", batch_df.columns)

        try:
            batch_df.write \
                .mode("append") \
                .jdbc(url=jdbc_url, table="public.fund_metadata_trail", properties=jdbc_properties)

            print(f"Batch {batch_id} written to Postgres.")
        except Exception as e:
            print("[FAIL] ERROR while writing batch in Postgres:", batch_id)
            print(e)
            raise  # bare raise re-raises with the original traceback intact

        # TODO: AWS Keyspaces sink, currently disabled. The
        # "org.apache.spark.sql.cassandra" format presumably needs the Spark
        # Cassandra Connector in spark.jars.packages, which the session does
        # not load — confirm before re-enabling.
        # try:
        #     batch_df.write \
        #         .format("org.apache.spark.sql.cassandra") \
        #         .options(
        #             table="fund_metadata_trail",
        #             keyspace="fund_metadata"
        #         ) \
        #         .mode("append") \
        #         .save()
        #
        #     print(f"Batch {batch_id} written to AWS Keyspaces.")
        # except Exception as e:
        #     print ("[FAIL] ERROR while writing batch in Cassandra:", batch_id)
        #     print(e)
        #     raise
    finally:
        batch_df.unpersist()


# Without a checkpointLocation, Spark uses a temporary checkpoint. A restarted
# query then replays the topic from its startingOffsets and appends the same
# rows to the trail table again. Set FUND_TRAIL_CHECKPOINT_DIR to a durable
# directory to resume from the last committed offsets instead. foreachBatch
# remains at-least-once, so a failed batch may still be written twice.
checkpoint_dir = os.getenv("FUND_TRAIL_CHECKPOINT_DIR")

stream_writer = final_df.writeStream.foreachBatch(write_to_multiple_sink)
if checkpoint_dir:
    stream_writer = stream_writer.option("checkpointLocation", checkpoint_dir)
query = stream_writer.start()


In [None]:
# Health check for the streaming query. If it reports False unexpectedly,
# query.exception() returns the error that stopped it.
query.isActive  # True if running, False if stopped


In [None]:
# Stop the streaming query; no further micro-batches are written to Postgres.
query.stop()

In [None]:
# Shut down the SparkSession. Run this only after the streaming query has been stopped.
spark.stop()