In [0]:
# Create the silver schema (where the next cell writes its enriched tables) if it does not exist yet
spark.sql("CREATE SCHEMA IF NOT EXISTS silver")

# Show which catalog and database the session currently resolves unqualified names against
spark.sql("SELECT current_catalog(), current_database()").show()

In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import col, sum as _sum, count as _count, unix_timestamp

# Source tables from the bronze layer: client master data, raw transactions,
# and the per-country AML risk scores used to enrich both.
clients_df, transactions_df, aml_risk_df = (
    spark.table("bronze.clients"),
    spark.table("bronze.transactions"),
    spark.table("bronze.aml_risk_score"),
)

# Clients enriched with the AML risk score of their own country (left join keeps
# clients whose country has no score).
clients_enriched_df = clients_df.join(aml_risk_df, on="country", how="left")

# AML lookup re-keyed on the counterparty's country, with the score column renamed
# so it does not collide with the client-side score joined below.
counterparty_risk_df = (
    aml_risk_df
    .withColumnRenamed("country", "counterparty_country")
    .withColumnRenamed("aml_risk_score", "aml_risk_score_counterparty")
)

# Only the client's own AML score is carried over from the enriched clients.
client_risk_df = clients_enriched_df.select("client_id", "aml_risk_score")

# Transactions get the counterparty-country score first, then the client's score.
transactions_enriched_df = (
    transactions_df
    .join(counterparty_risk_df, on="counterparty_country", how="left")
    .join(client_risk_df, on="client_id", how="left")
)

# Length of the per-client rolling look-back window, in days (resolves the former
# "TODO: set as parameter"). Ideally this constant moves to a config cell at the top
# of the notebook.
ROLLING_WINDOW_DAYS = 30
SECONDS_PER_DAY = 86_400

# Make sure transaction_date is a timestamp; the window below orders on its epoch seconds.
# NOTE(review): values that fail the cast become null transaction_ts; confirm how such
# rows should be treated by the rolling aggregates below.
transactions_enriched_df = transactions_enriched_df.withColumn(
    "transaction_ts", col("transaction_date").cast("timestamp")
)

# Per-client rolling window covering the previous ROLLING_WINDOW_DAYS days.
# rangeBetween bounds are measured in the units of the order-by value (seconds here)
# and are inclusive on both ends: a transaction exactly ROLLING_WINDOW_DAYS days
# earlier is counted, as are other transactions sharing the current timestamp.
rolling_window = (
    Window.partitionBy("client_id")
    .orderBy(col("transaction_ts").cast("long"))
    .rangeBetween(-ROLLING_WINDOW_DAYS * SECONDS_PER_DAY, 0)
)

# Output column names carry the window length so they stay accurate if it changes;
# with the default of 30 they remain rolling_txn_count_30d / rolling_txn_sum_30d.
# count() over a column counts only non-null transaction_id values.
window_suffix = f"{ROLLING_WINDOW_DAYS}d"
transactions_enriched_df = (
    transactions_enriched_df
    .withColumn(f"rolling_txn_count_{window_suffix}", _count("transaction_id").over(rolling_window))
    .withColumn(f"rolling_txn_sum_{window_suffix}", _sum("transaction_amount").over(rolling_window))
)

# Persist both enriched frames to the silver schema, replacing any previous contents.
# Clients are written first, then transactions.
for silver_df, silver_table in (
    (clients_enriched_df, "silver.clients_enriched"),
    (transactions_enriched_df, "silver.transactions_enriched"),
):
    silver_df.write.mode("overwrite").saveAsTable(silver_table)
