In [0]:
from pyspark.sql.functions import col, when, lit, current_timestamp, array

# Load the enriched silver-layer inputs (transactions first, then clients).
transactions, clients = (
    spark.table(table_name)
    for table_name in ("silver.transactions_enriched", "silver.clients_enriched")
)

In [0]:
# Rule-based risk flagging.
# Each rule contributes one boolean flag column on `transactions` (rule_*) and,
# when it fires, one trigger label in the matching `trigger_reason_N` column of
# `risk_events`. A transaction becomes a risk event if ANY rule fires.

# Rule thresholds, kept in one place so they can be tuned without touching logic.
HIGH_AMOUNT_THRESHOLD = 10000     # strict ">" on transaction_amount
ROLLING_COUNT_THRESHOLD = 10      # ">=" on rolling_txn_count_30d
ROLLING_SUM_THRESHOLD = 50000     # ">=" on rolling_txn_sum_30d
# NOTE(review): exact, case-sensitive match on this value — confirm it matches
# the value set of aml_risk_score_counterparty.
HIGH_RISK_COUNTERPARTY_LABEL = "High"

# (flag column, trigger label, condition). List order defines trigger_reason_1..N.
RISK_RULES = [
    ("rule_high_amount", "high_amount",
     col("transaction_amount") > HIGH_AMOUNT_THRESHOLD),
    ("rule_high_risk_counterparty", "high_risk_counterparty",
     col("aml_risk_score_counterparty") == HIGH_RISK_COUNTERPARTY_LABEL),
    ("rule_rw_count", "rw_count",
     col("rolling_txn_count_30d") >= ROLLING_COUNT_THRESHOLD),
    ("rule_rw_sum", "rw_sum",
     col("rolling_txn_sum_30d") >= ROLLING_SUM_THRESHOLD),
]

# Rule flags: withColumn replaces an existing column of the same name, so
# re-running this cell does not duplicate the flags.
for flag_name, _label, condition in RISK_RULES:
    transactions = transactions.withColumn(flag_name, condition)

# OR of all flags. Seeding with False keeps Spark's three-valued semantics
# identical to `flag1 | flag2 | ...` (False OR x == x, including null).
any_rule_fired = lit(False)
for flag_name, _label, _condition in RISK_RULES:
    any_rule_fired = any_rule_fired | col(flag_name)

# One nullable string column per rule: the rule's label when it fired, else null
# (`when` without `otherwise` already yields null for unmatched rows).
trigger_reason_cols = [
    when(col(flag_name), lit(label)).alias(f"trigger_reason_{position}")
    for position, (flag_name, label, _condition) in enumerate(RISK_RULES, start=1)
]

# Risk events: flagged transactions with their trigger reasons and evaluation time.
risk_events = (
    transactions
    .filter(any_rule_fired)
    .select(col("transaction_id"), col("client_id"), *trigger_reason_cols)
    .withColumn("evaluated_at", current_timestamp())
)

In [0]:
# Persist the risk events, fully replacing any previous contents of the table.
(
    risk_events.write
    .mode("overwrite")
    .saveAsTable("silver.risk_events")
)