In [24]:
import os
from functools import reduce

from pyspark.sql import Window
from pyspark.sql.functions import col, avg, count, lit, current_timestamp, row_number


# Pattern-detection cell.
#
# Reads CustomerImportance.csv and every parquet chunk under Files/chunks/.
# For each chunk it:
#   * appends per-merchant, per-customer/merchant and per-gender summaries to
#     MySQL over JDBC;
#   * builds detections for three patterns (PatId1 UPGRADE, PatId2 CHILD,
#     PatId3 DEI-NEEDED).
# All detections are then written as JSON files of DETECTION_BATCH_SIZE rows.

# --- Connection settings -----------------------------------------------------
# Credentials come from the environment rather than being committed with the
# notebook. The previously hard-coded root password should be rotated.
jdbc_url = os.environ.get(
    "MYSQL_JDBC_URL", "jdbc:mysql://localhost:3306/transaction_db"
)
_mysql_password = os.environ.get("MYSQL_PASSWORD")
if _mysql_password is None:
    raise RuntimeError(
        "Set the MYSQL_PASSWORD environment variable before running this cell."
    )
connection_properties = {
    "user": os.environ.get("MYSQL_USER", "root"),
    "password": _mysql_password,
    "driver": "com.mysql.cj.jdbc.Driver",
}

# --- Pattern thresholds ------------------------------------------------------
UPGRADE_MIN_MERCHANT_TXNS = 50000  # merchant must have strictly more txns
UPGRADE_TXN_QUANTILE = 0.9         # customer txn_count >= this quantile
UPGRADE_WEIGHT_QUANTILE = 0.1      # customer avg_weight <= this quantile
QUANTILE_REL_ERROR = 0.01          # relativeError passed to approxQuantile
CHILD_MAX_AVG_AMOUNT = 23          # avg_amount strictly below this
CHILD_MIN_TXNS = 80                # txn_count at least this
DEI_MIN_FEMALE = 100               # Female count strictly above this
DETECTION_BATCH_SIZE = 50          # rows per output JSON batch


def _append_jdbc(df, table):
    """Append *df* to MySQL table *table* using the cell's JDBC settings.

    Mode is "append", so re-running the cell adds duplicate rows.
    """
    df.write.jdbc(
        url=jdbc_url,
        table=table,
        mode="append",
        properties=connection_properties,
    )


def _as_detection(df, pattern_id, action_type, customer_name, merchant_id):
    """Project *df* onto the shared detection schema.

    Output columns, in order: YStartTime, detectionTime, patternId,
    ActionType, customerName, MerchantId. Both timestamps come from a single
    current_timestamp() expression, so they are equal within one query.

    Args:
        df: frame holding the detected rows.
        pattern_id: literal written to patternId.
        action_type: literal written to ActionType.
        customer_name: Column expression for customerName.
        merchant_id: Column expression for MerchantId.
    """
    now = current_timestamp()
    return df.select(
        now.alias("YStartTime"),
        now.alias("detectionTime"),
        lit(pattern_id).alias("patternId"),
        lit(action_type).alias("ActionType"),
        customer_name.alias("customerName"),
        merchant_id.alias("MerchantId"),
    )


importance_df = spark.read.option("header", True).csv(
    "Files/raw_data/CustomerImportance.csv"
)

# binaryFile is used only to list paths. Each path is read below as parquet.
chunk_files_df = spark.read.format("binaryFile").load("Files/chunks/")
chunk_files = [row.path for row in chunk_files_df.select("path").collect()]
if not chunk_files:
    print("No chunk files found under Files/chunks/.")

detections = []

for file_path in chunk_files:
    chunk_df = spark.read.parquet(file_path)
    # NOTE(review): if CustomerImportance has several rows per customerId,
    # this left join duplicates transactions and inflates every count below.
    # Confirm the CSV is one row per customer.
    joined_df = chunk_df.join(importance_df, on="customerId", how="left")

    # -- PatId1 / UPGRADE: within busy merchants, customers in the top
    #    txn-count quantile AND the bottom avg-weight quantile.
    merchant_counts = joined_df.groupBy("merchantId").agg(
        count("*").alias("txn_count")
    )
    _append_jdbc(merchant_counts, "merchant_txn_counts")

    busy_merchants = merchant_counts.filter(
        col("txn_count") > UPGRADE_MIN_MERCHANT_TXNS
    )
    for row in busy_merchants.collect():
        merchant_id = row["merchantId"]
        txn_by_customer = (
            joined_df.filter(col("merchantId") == merchant_id)
            .groupBy("customerId")
            .agg(count("*").alias("txn_count"), avg("weight").alias("avg_weight"))
        )
        txn_cut = txn_by_customer.approxQuantile(
            "txn_count", [UPGRADE_TXN_QUANTILE], QUANTILE_REL_ERROR
        )
        weight_cut = txn_by_customer.approxQuantile(
            "avg_weight", [UPGRADE_WEIGHT_QUANTILE], QUANTILE_REL_ERROR
        )
        # approxQuantile ignores nulls and returns [] when no values remain,
        # e.g. when no customer of this merchant matched CustomerImportance.
        # There is then nothing to threshold against, so skip the merchant.
        if not txn_cut or not weight_cut:
            continue

        upgrade_df = txn_by_customer.filter(
            (col("txn_count") >= txn_cut[0]) & (col("avg_weight") <= weight_cut[0])
        )
        detections.append(
            _as_detection(
                upgrade_df, "PatId1", "UPGRADE", col("customerId"), lit(merchant_id)
            )
        )

    # -- PatId2 / CHILD: low average amount with many transactions per
    #    customer/merchant pair.
    avg_txn_df = joined_df.groupBy("customerId", "merchantId").agg(
        avg("amount").alias("avg_amount"), count("*").alias("txn_count")
    )
    _append_jdbc(avg_txn_df, "customer_txn_summary")

    child_df = avg_txn_df.filter(
        (col("avg_amount") < CHILD_MAX_AVG_AMOUNT)
        & (col("txn_count") >= CHILD_MIN_TXNS)
    )
    detections.append(
        _as_detection(child_df, "PatId2", "CHILD", col("customerId"), col("merchantId"))
    )

    # -- PatId3 / DEI-NEEDED: merchants with fewer distinct female than male
    #    customers, but more than DEI_MIN_FEMALE female customers.
    gender_counts = (
        joined_df.select("merchantId", "customerId", "gender")
        .dropDuplicates()
        .groupBy("merchantId", "gender")
        .agg(count("customerId").alias("gender_count"))
    )
    _append_jdbc(gender_counts, "gender_counts")

    # Explicit pivot values avoid a separate distinct() job. They also always
    # produce both columns, so a chunk missing one gender gets 0 instead of
    # needing a column-existence check. 0 can never satisfy the filter, so the
    # result matches skipping the chunk.
    gender_pivot = (
        gender_counts.groupBy("merchantId")
        .pivot("gender", ["Female", "Male"])
        .sum("gender_count")
        .fillna(0)
    )
    dei_df = gender_pivot.filter(
        (col("Female") < col("Male")) & (col("Female") > DEI_MIN_FEMALE)
    )
    detections.append(
        _as_detection(dei_df, "PatId3", "DEI-NEEDED", lit(""), col("merchantId"))
    )


if not detections:
    print("No detections found. Check your pattern logic or input data.")
else:
    final_df = reduce(lambda left, right: left.unionByName(right), detections)

    # Number rows once and cache the result, so that count, show and every
    # batch read the same materialized rows. Otherwise each action re-runs the
    # whole pipeline and re-evaluates current_timestamp(). The global window
    # moves all detections to one partition, which is acceptable only while
    # the detection set stays small.
    numbered_df = final_df.withColumn(
        "_row_num",
        row_number().over(Window.orderBy("patternId", "MerchantId", "customerName")),
    ).cache()

    total = numbered_df.count()
    print("Total detections:", total)

    if total == 0:
        print("No detections found. Check your pattern logic or input data.")
    else:
        numbered_df.drop("_row_num").show(5)

        # Ceiling division: no trailing empty batch when total is an exact
        # multiple of the batch size.
        batches = (total + DETECTION_BATCH_SIZE - 1) // DETECTION_BATCH_SIZE
        for i in range(batches):
            start = i * DETECTION_BATCH_SIZE
            # row_number is 1-based: batch i holds rows start+1 .. start+size.
            batch_df = numbered_df.filter(
                (col("_row_num") > start)
                & (col("_row_num") <= start + DETECTION_BATCH_SIZE)
            ).drop("_row_num")
            try:
                batch_df.write.mode("overwrite").json(
                    f"Files/detections/detection_batch_{i}.json"
                )
                print(f"Batch {i} written successfully.")
            except Exception as e:
                # Best effort: report the failed batch and continue with the rest.
                print(f"Error writing batch {i}:", e)


StatementMeta(, 0ac6c71f-3d0d-454e-b47c-eccf207ac5f8, 26, Finished, Available, Finished)

No detections found. Check your pattern logic or input data.
