In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, DateType
from pyspark.sql.functions import (
    col, current_date, date_format, dayofmonth, lit, month, quarter,
    regexp_replace, to_date, weekofyear, year
)

# Load Silver Data
# Headered CSV exported from the silver layer. No schema is supplied and inferSchema
# is not requested, so every column arrives as a string.
silver_path = "dbfs:/FileStore/tables/fact_transaction.csv"
df_silver = spark.read.csv(silver_path, header=True)

# ----------------- Dimensions -----------------
# DimCustomer (SCD Type 2)
# The silver data only carries CustomerID, so descriptive attributes are placeholders
# ("Unknown" / NULL). Every customer is emitted as an active, current version that is
# effective today with an open-ended EndDate, one row per CustomerID.
df_customer = (
    df_silver
    .select(
        col("CustomerID"),
        lit("Unknown").alias("CustomerName"),
        lit(None).cast("date").alias("DOB"),
        lit(None).cast("string").alias("Gender"),
        lit(None).cast("string").alias("Address"),
        lit(None).cast("string").alias("Email"),
        lit("Active").alias("Status"),
        current_date().alias("EffectiveDate"),
        lit(None).cast("date").alias("EndDate"),
        lit(True).alias("IsCurrent"),
    )
    .dropDuplicates(subset=["CustomerID"])
)

# DimAccount
# One row per AccountNumber. The ATM id is reused as the branch key, and the
# product / account type / status are fixed defaults.
# NOTE(review): OpenDate comes from processedAt of whichever row dropDuplicates keeps
# for the account, so it is not guaranteed to be the account's earliest record.
df_account = (
    df_silver
    .select(
        col("AccountNumber"),
        col("CustomerID"),
        col("ATMID").alias("BranchID"),
        lit(1).alias("ProductID"),
        lit("Savings").alias("AccountType"),
        lit("Active").alias("Status"),
        to_date(col("processedAt")).alias("OpenDate"),
        lit(None).cast("date").alias("CloseDate"),
    )
    .dropDuplicates(subset=["AccountNumber"])
)

# DimBranch
# Each distinct ATM is treated as a branch. Location supplies both the branch
# name and the location.
df_branch = (
    df_silver
    .select(
        col("ATMID").alias("BranchID"),
        col("Location").alias("BranchName"),
        col("Location").alias("Location"),
    )
    .dropDuplicates(subset=["BranchID"])
)

# DimProduct (Static)
# Hand-maintained product catalogue with columns (ProductID, ProductName, ProductType).
df_product = spark.createDataFrame(
    [
        (1, "Savings Account", "Deposit"),
        (2, "Current Account", "Deposit"),
        (3, "Home Loan", "Loan"),
        (4, "Personal Loan", "Loan"),
    ],
    ["ProductID", "ProductName", "ProductType"],
)

# DimDate
# One row per distinct calendar date found in TransactionTime. DateID is the integer
# yyyyMMdd form of the date (e.g. 2024-03-15 -> 20240315). Rows whose TransactionTime
# is NULL or yields NULL from to_date are dropped.
df_date = (
    df_silver
    .withColumn("DateValue", to_date(col("TransactionTime")))
    # to_date(NULL) is NULL, so this single filter also removes NULL TransactionTime rows.
    .filter(col("DateValue").isNotNull())
    .select("DateValue")
    # DateID is derived deterministically from DateValue, so de-duplicating the dates
    # here already makes DateID unique. No second de-duplication pass is needed.
    .dropDuplicates()
    .withColumn("DateID", date_format(col("DateValue"), "yyyyMMdd").cast("int"))
    .withColumn("Day", dayofmonth(col("DateValue")))
    .withColumn("Month", month(col("DateValue")))
    .withColumn("Year", year(col("DateValue")))
    .withColumn("Quarter", quarter(col("DateValue")))
    .withColumn("WeekOfYear", weekofyear(col("DateValue")))
)

# ----------------- FactTransactions -----------------
# One row per silver transaction, carrying the dimension keys.
# DateID is built exactly like DimDate.DateID (integer yyyyMMdd from to_date), so the
# fact's date key has the same type and values as the dimension key. Previously it was
# a string cut from the raw text, which stayed as the literal "null" for "null" inputs.
df_fact_txn = df_silver.select(
    col("TransactionID"),
    col("AccountNumber"),
    col("CustomerID"),
    col("ATMID").alias("BranchID"),
    lit(1).alias("ProductID"),  # every row gets ProductID 1 at this stage
    date_format(to_date(col("TransactionTime")), "yyyyMMdd").cast("int").alias("DateID"),
    col("TransactionType"),
    col("TransactionAmount").cast("double").alias("Amount"),
    col("Status"),
    col("fraud_flags").alias("FraudFlag"),
)

# ----------------- JDBC Connection -----------------
# Connection settings are read from environment variables so credentials never live
# in the notebook source. Each value falls back to an empty string when the variable
# is unset, which matches the previous placeholder values.
import os

jdbc_url = os.environ.get("JDBC_URL", "")
db_properties = {
    "user": os.environ.get("JDBC_USER", ""),
    "password": os.environ.get("JDBC_PASSWORD", ""),
    "driver": os.environ.get("JDBC_DRIVER", ""),
}

# ----------------- Writing Staging Tables -----------------
# Overwrite each staging table in turn, in a fixed order. The SQL MERGE scripts are run
# separately afterwards.
for staging_table, staging_df in [
    ("Staging_DimCustomer", df_customer),
    ("Staging_DimAccount", df_account),
    ("Staging_DimBranch", df_branch),
    ("Staging_DimProduct", df_product),
    ("Staging_DimDate", df_date),
    ("Staging_FactTransactions", df_fact_txn),
]:
    staging_df.write.jdbc(url=jdbc_url, table=staging_table, mode="overwrite", properties=db_properties)

print("Staging tables loaded. Now execute MERGE scripts in SQL.")

In [0]:
# Re-write only the date dimension to its staging table.
df_date.write.mode("overwrite").jdbc(jdbc_url, "Staging_DimDate", properties=db_properties)


In [0]:
from pyspark.sql.functions import col, when

# For FactTransactions: turn the literal string "null" into a real NULL and cast each
# key/measure column to its warehouse type. Columns are processed in a fixed order.
df_fact_txn_clean = df_fact_txn
for column_name, target_type in [("ProductID", "int"), ("DateID", "int"), ("Amount", "double")]:
    df_fact_txn_clean = df_fact_txn_clean.withColumn(
        column_name,
        when(col(column_name) == "null", None).otherwise(col(column_name).cast(target_type)),
    )


In [0]:
# Replace the fact staging table with the cleaned frame.
df_fact_txn_clean.write.mode("overwrite").jdbc(
    jdbc_url,
    "Staging_FactTransactions",
    properties=db_properties,
)


In [0]:
from pyspark.sql.functions import col, concat_ws, monotonically_increasing_id, lit, when

# Keep TransactionIDs that contain "ATM" followed by 6 digits (rlike is not anchored).
# Any other value is replaced by "<AccountNumber>_<TransactionTime>_<generated id>".
# NOTE(review): monotonically_increasing_id is not stable across recomputation, so the
# synthetic IDs can differ between runs — confirm this is acceptable for a fact key.
has_atm_id = col("TransactionID").rlike("ATM\\d{6}")
synthetic_id = concat_ws("_", col("AccountNumber"), col("TransactionTime"), monotonically_increasing_id())

df_fact_txn_clean = (
    df_silver
    .withColumn("TransactionID", when(has_atm_id, col("TransactionID")).otherwise(synthetic_id))
    .dropDuplicates(["TransactionID"])  # Drop exact duplicates if any
)


In [0]:
from pyspark.sql.functions import col, when, lit


product_mapping = {"WITHDRAWAL": 1, "DEPOSIT": 2, "TRANSFER": 3}

# Build the CASE WHEN chain from product_mapping (in insertion order) so the mapping
# lives in one place. TransactionTypes not in the mapping get a NULL ProductID.
product_id_expr = None
for txn_type, product_id in product_mapping.items():
    is_type = col("TransactionType") == txn_type
    if product_id_expr is None:
        product_id_expr = when(is_type, lit(product_id))
    else:
        product_id_expr = product_id_expr.when(is_type, lit(product_id))

df_fact_txn_clean = df_fact_txn_clean.withColumn("ProductID", product_id_expr.otherwise(lit(None)))


In [0]:
from pyspark.sql.functions import col, concat_ws, monotonically_increasing_id, lit

# Rebuild from the silver data. Keep TransactionIDs containing "ATM" + 6 digits;
# otherwise synthesise "<AccountNumber>_<TransactionTime>_<generated id>".
df_fact_txn_clean = df_silver.withColumn(
    "TransactionID",
    when(col("TransactionID").rlike("ATM\\d{6}"), col("TransactionID"))
    .otherwise(concat_ws("_", col("AccountNumber"), col("TransactionTime"), monotonically_increasing_id())),
)

# Remove rows that share a TransactionID
df_fact_txn_clean = df_fact_txn_clean.dropDuplicates(subset=["TransactionID"])


In [0]:
from pyspark.sql.functions import when, to_date, date_format, lit

# Treat both real NULLs and the literal string "null" as missing, then:
#   Amount    <- TransactionAmount as double
#   DateID    <- TransactionTime as integer yyyyMMdd
#   ProductID <- TransactionType mapped to 1/2/3 (NULL for anything else)
df_fact_txn_clean = (
    df_fact_txn_clean
    .withColumn(
        "Amount",
        when(col("TransactionAmount").isNull() | (col("TransactionAmount") == "null"), None)
        .otherwise(col("TransactionAmount").cast("double")),
    )
    .withColumn(
        "DateID",
        when(col("TransactionTime").isNull() | (col("TransactionTime") == "null"), None)
        .otherwise(date_format(to_date(col("TransactionTime")), "yyyyMMdd").cast("int")),
    )
    # Map TransactionType to ProductID
    .withColumn(
        "ProductID",
        when(col("TransactionType") == "WITHDRAWAL", lit(1))
        .when(col("TransactionType") == "DEPOSIT", lit(2))
        .when(col("TransactionType") == "TRANSFER", lit(3))
        .otherwise(lit(None)),
    )
)


In [0]:
# NOTE(review): df_fact_txn_clean still carries every silver column at this point, and
# overwrite mode may recreate the staging table with that layout rather than the
# Staging_FactTransactions schema — confirm before relying on this write.
df_fact_txn_clean.write.mode("overwrite").jdbc(
    jdbc_url,
    "Staging_FactTransactions",
    properties=db_properties,
)


In [0]:
# Load DimBranch from Azure SQL into PySpark (used to resolve BranchID by Location).
df_dim_branch = spark.read.jdbc(jdbc_url, "DimBranch", properties=db_properties)


In [0]:
# Join silver data with DimBranch using Location to get BranchID. Left join keeps
# transactions whose Location has no branch (their BranchID is NULL).
branch_lookup = df_dim_branch.select("Location", "BranchID")
df_fact_txn_final = df_fact_txn_clean.join(branch_lookup, "Location", "left")


In [0]:

# Rename to the warehouse column name, project onto the staging layout and keep one
# row per TransactionID.
df_fact_txn_final = (
    df_fact_txn_final
    .withColumnRenamed("fraud_flags", "FraudFlag")
    .select(
        "TransactionID", "AccountNumber", "CustomerID", "BranchID", "ProductID",
        "DateID", "TransactionType", "Amount", "Status", "FraudFlag",
    )
    .dropDuplicates(["TransactionID"])
)


In [0]:

# Replace the fact staging table with the shaped frame.
df_fact_txn_final.write.mode("overwrite").jdbc(
    jdbc_url,
    "Staging_FactTransactions",
    properties=db_properties,
)


In [0]:
from pyspark.sql.functions import to_timestamp, date_format, col, when

# Correct DateID generation: parse the ISO-8601 timestamp (pattern letter X accepts a
# "Z" or numeric offset) and keep the integer yyyyMMdd form. Rows with no
# TransactionTime get NULL (when() without otherwise() yields NULL).
parsed_txn_ts = to_timestamp(col("TransactionTime"), "yyyy-MM-dd'T'HH:mm:ssX")
df_fact_txn_clean = df_fact_txn_clean.withColumn(
    "DateID",
    when(col("TransactionTime").isNotNull(), date_format(parsed_txn_ts, "yyyyMMdd").cast("int")),
)


In [0]:
from pyspark.sql.functions import col

# Attach BranchID by Location (left join), rename fraud_flags to FraudFlag, keep only
# the staging columns and de-duplicate on TransactionID.
df_fact_txn_final = (
    df_fact_txn_clean
    .join(df_dim_branch.select("Location", "BranchID"), on="Location", how="left")
    .withColumnRenamed("fraud_flags", "FraudFlag")
    .select(
        "TransactionID", "AccountNumber", "CustomerID", "BranchID", "ProductID",
        "DateID", "TransactionType", "Amount", "Status", "FraudFlag",
    )
    .dropDuplicates(["TransactionID"])
)


In [0]:
# Replace the fact staging table with the shaped frame.
df_fact_txn_final.write.mode("overwrite").jdbc(
    jdbc_url,
    "Staging_FactTransactions",
    properties=db_properties,
)


In [0]:
# Inspect up to 50 distinct raw TransactionTime values to see which formats occur.
df_silver.select("TransactionTime").dropDuplicates().show(n=50, truncate=False)


In [0]:
from pyspark.sql.functions import to_timestamp, date_format, col

# Rebuild from silver and derive DateID as integer yyyyMMdd from an ISO-8601 timestamp.
# Pattern letter X handles the "Z" timezone designator.
iso_ts = to_timestamp(col("TransactionTime"), "yyyy-MM-dd'T'HH:mm:ssX")
df_fact_txn_clean = df_silver.withColumn("DateID", date_format(iso_ts, "yyyyMMdd").cast("int"))


In [0]:
from pyspark.sql.functions import when

# Same DateID derivation, guarded so NULL TransactionTime values map straight to NULL.
df_fact_txn_clean = df_fact_txn_clean.withColumn(
    "DateID",
    when(
        col("TransactionTime").isNotNull(),
        date_format(
            to_timestamp(col("TransactionTime"), "yyyy-MM-dd'T'HH:mm:ssX"),
            "yyyyMMdd",
        ).cast("int"),
    ),
)


In [0]:
# How many rows still have no usable DateID?
df_fact_txn_clean.where(col("DateID").isNull()).count()


In [0]:
from pyspark.sql.functions import when, col, to_timestamp, date_format

# Normalise missing timestamps: NULL, the string "null" and "" all become NULL in
# TransactionTime_clean. Any other value is copied through unchanged.
is_missing_time = (
    col("TransactionTime").isNull()
    | (col("TransactionTime") == "null")
    | (col("TransactionTime") == "")
)
df_fact_txn_clean = df_silver.withColumn(
    "TransactionTime_clean",
    when(is_missing_time, None).otherwise(col("TransactionTime")),
)


In [0]:
from pyspark.sql.functions import when, col, to_timestamp, date_format

# Rebuild the cleaned fact frame from silver:
#   1. TransactionTime_clean: NULL, "null" and "" timestamps become NULL.
#   2. DateID: integer yyyyMMdd parsed from TransactionTime_clean (X accepts a "Z" suffix).
# This cell replaces df_fact_txn_clean wholesale, which discards any DateID computed
# earlier. DateID is therefore derived here so the following DateID filter can resolve
# the column. to_timestamp(NULL) is NULL, so missing times give a NULL DateID.
df_fact_txn_clean = (
    df_silver
    .withColumn(
        "TransactionTime_clean",
        when(
            (col("TransactionTime").isNull())
            | (col("TransactionTime") == "null")
            | (col("TransactionTime") == ""),
            None,
        ).otherwise(col("TransactionTime")),
    )
    .withColumn(
        "DateID",
        date_format(
            to_timestamp(col("TransactionTime_clean"), "yyyy-MM-dd'T'HH:mm:ssX"),
            "yyyyMMdd",
        ).cast("int"),
    )
)


In [0]:
# Keep only rows that have a usable date key.
df_fact_txn_clean = df_fact_txn_clean.where(col("DateID").isNotNull())


In [0]:
# Sanity check after the filter: the count of NULL DateIDs should be 0.
df_fact_txn_clean.where(col("DateID").isNull()).count()


In [0]:
# Resolve BranchID by Location; unmatched Locations keep a NULL BranchID.
df_fact_txn_final = df_fact_txn_clean.join(
    df_dim_branch.select("Location", "BranchID"), "Location", "left"
)


In [0]:
# Align the fraud column name with the warehouse schema.
df_fact_txn_final = (
    df_fact_txn_final
    .withColumnRenamed("fraud_flags", "FraudFlag")
)


In [0]:

from pyspark.sql.functions import when

# Map each TransactionType to its ProductID. Any other type becomes NULL.
txn_type_col = col("TransactionType")
df_fact_txn_final = df_fact_txn_final.withColumn(
    "ProductID",
    when(txn_type_col == "WITHDRAWAL", 1)
    .when(txn_type_col == "DEPOSIT", 2)
    .when(txn_type_col == "TRANSFER", 3)
    .otherwise(None),
)


In [0]:
# Load DimProduct from the SQL database.
df_dim_product = spark.read.jdbc(jdbc_url, "DimProduct", properties=db_properties)


In [0]:
# Look up ProductID by matching TransactionType against DimProduct.ProductName.
# NOTE(review): this starts again from df_fact_txn_clean, discarding the branch join and
# rename done on df_fact_txn_final above. Also, the static product list built earlier
# uses names like "Savings Account", which never equal TransactionType values such as
# "WITHDRAWAL" — confirm DimProduct's names, or ProductID will be NULL everywhere.
product_lookup = df_dim_product.select("ProductName", "ProductID")
df_fact_txn_final = df_fact_txn_clean.join(
    product_lookup,
    df_fact_txn_clean["TransactionType"] == df_dim_product["ProductName"],
    how="left",
)
    

In [0]:
# Resolve BranchID by Location on the product-joined frame (left join keeps all rows).
branch_lookup = df_dim_branch.select("Location", "BranchID")
df_fact_txn_final = df_fact_txn_final.join(branch_lookup, "Location", "left")


In [0]:
# Align the fraud column name with the warehouse schema.
df_fact_txn_final = (
    df_fact_txn_final
    .withColumnRenamed("fraud_flags", "FraudFlag")
)


In [0]:
from pyspark.sql.functions import col

# Project onto the Staging_FactTransactions layout. TransactionAmount is exposed as
# Amount to match the SQL schema.
df_fact_txn_final = df_fact_txn_final.select(
    "TransactionID", "AccountNumber", "CustomerID", "BranchID", "ProductID",
    "DateID", "TransactionType",
    col("TransactionAmount").alias("Amount"),
    "Status", "FraudFlag",
)


In [0]:
# Keep one row per TransactionID.
df_fact_txn_final = df_fact_txn_final.dropDuplicates(subset=["TransactionID"])


In [0]:
# Replace the fact staging table with the shaped frame.
df_fact_txn_final.write.mode("overwrite").jdbc(
    jdbc_url,
    "Staging_FactTransactions",
    properties=db_properties,
)


In [0]:
from pyspark.sql.functions import when, col, to_timestamp, date_format

# Normalise missing timestamps: NULL, "null" and "" all become NULL in
# TransactionTime_clean. Other values pass through unchanged.
df_fact_txn_clean = df_silver.withColumn(
    "TransactionTime_clean",
    when(
        col("TransactionTime").isNull() | col("TransactionTime").isin("null", ""),
        None,
    ).otherwise(col("TransactionTime")),
)


In [0]:
# Integer yyyyMMdd DateID from the cleaned ISO-8601 timestamp. Missing times stay NULL
# (when() without otherwise() yields NULL).
clean_ts = to_timestamp(col("TransactionTime_clean"), "yyyy-MM-dd'T'HH:mm:ssX")
df_fact_txn_clean = df_fact_txn_clean.withColumn(
    "DateID",
    when(col("TransactionTime_clean").isNotNull(), date_format(clean_ts, "yyyyMMdd").cast("int")),
)


In [0]:
# Keep only rows that have a usable date key.
df_fact_txn_clean = df_fact_txn_clean.where(col("DateID").isNotNull())


In [0]:
# Sanity check after the filter: the count of NULL DateIDs should be 0.
df_fact_txn_clean.where(col("DateID").isNull()).count()

In [0]:
# Resolve BranchID by Location; unmatched Locations keep a NULL BranchID.
df_fact_txn_final = df_fact_txn_clean.join(
    df_dim_branch.select("Location", "BranchID"), "Location", "left"
)


In [0]:
# Align the fraud column name with the warehouse schema.
df_fact_txn_final = (
    df_fact_txn_final
    .withColumnRenamed("fraud_flags", "FraudFlag")
)


In [0]:
# Look up ProductID by matching TransactionType against DimProduct.ProductName.
# NOTE(review): this rebuilds df_fact_txn_final from df_fact_txn_clean, so the branch
# join and rename in the two preceding cells are discarded. TransactionType values
# such as "WITHDRAWAL" are also unlikely to equal product names — verify the key.
product_lookup = df_dim_product.select("ProductName", "ProductID")
df_fact_txn_final = df_fact_txn_clean.join(
    product_lookup,
    df_fact_txn_clean["TransactionType"] == df_dim_product["ProductName"],
    how="left",
)
    

In [0]:
# If DimBranch exists in Azure SQL
# NOTE(review): the frame read here is replaced immediately by the hard-coded branch
# list below, so only the second assignment reaches the downstream joins.
df_dim_branch = spark.read.jdbc(jdbc_url, "DimBranch", properties=db_properties)

from pyspark.sql import Row

# Hard-coded branch dimension: one branch per city, keyed by BranchID.
branch_data = [
    Row(BranchID=1, Location="Kolkata", BranchName="Kolkata Main"),
    Row(BranchID=2, Location="Hyderabad", BranchName="Hyderabad Main"),
    Row(BranchID=3, Location="Pune", BranchName="Pune Main"),
    Row(BranchID=4, Location="Delhi", BranchName="Delhi Main"),
]

df_dim_branch = spark.createDataFrame(branch_data)
df_dim_branch.show()


In [0]:
# Resolve BranchID by Location against the hard-coded branch dimension (left join).
branch_lookup = df_dim_branch.select("Location", "BranchID")
df_fact_txn_final = df_fact_txn_clean.join(branch_lookup, "Location", "left")


In [0]:
from pyspark.sql.functions import when

# Map each TransactionType to its ProductID. Any other type becomes NULL.
txn_type_col = col("TransactionType")
df_fact_txn_final = df_fact_txn_final.withColumn(
    "ProductID",
    when(txn_type_col == "WITHDRAWAL", 1)
    .when(txn_type_col == "DEPOSIT", 2)
    .when(txn_type_col == "TRANSFER", 3)
    .otherwise(None),
)


In [0]:
# Rename silver column names to the staging schema's names.
df_fact_txn_final = (
    df_fact_txn_final
    .withColumnRenamed("TransactionAmount", "Amount")
    .withColumnRenamed("fraud_flags", "FraudFlag")
)


In [0]:
# Staging frame: exactly the Staging_FactTransactions columns, one row per TransactionID.
df_staging_fact = (
    df_fact_txn_final
    .select(
        "TransactionID", "AccountNumber", "CustomerID", "BranchID", "ProductID",
        "DateID", "TransactionType", "Amount", "Status", "FraudFlag",
    )
    .dropDuplicates(subset=["TransactionID"])
)


In [0]:
# Replace the fact staging table with the staging frame.
df_staging_fact.write.mode("overwrite").jdbc(
    jdbc_url,
    "Staging_FactTransactions",
    properties=db_properties,
)

print("Loaded into Staging_FactTransactions")


In [0]:
from pyspark.sql.functions import trim, upper, col

# Normalize Location on both sides (upper-case, then trim) so the join key compares equal.
normalized_location = trim(upper(col("Location")))
df_fact_txn_clean = df_fact_txn_clean.withColumn("Location", normalized_location)
df_dim_branch = df_dim_branch.withColumn("Location", normalized_location)

# Join to get BranchID (left join keeps unmatched transactions with NULL BranchID)
df_fact_txn_final = df_fact_txn_clean.join(
    df_dim_branch.select("Location", "BranchID"), "Location", "left"
)


In [0]:
# How many transactions found no branch for their Location?
df_fact_txn_final.where(col("BranchID").isNull()).count()


In [0]:
# Find missing Locations: the distinct fact Locations that matched no DimBranch row.
missing_locations = (
    df_fact_txn_final
    .where(col("BranchID").isNull())
    .select("Location")
    .dropDuplicates()
)

missing_locations.show()


In [0]:
from pyspark.sql import Row

# Branches for the Locations that had no DimBranch match. Location is upper-case to
# agree with the trim/upper normalisation used for the Location join.
new_branches = [
    Row(BranchID=5, Location="BANGALORE", BranchName="Bangalore Main"),
    Row(BranchID=6, Location="CHENNAI", BranchName="Chennai Main"),
    Row(BranchID=7, Location="MUMBAI", BranchName="Mumbai Main")
]

df_new_branches = spark.createDataFrame(new_branches)

# Union with existing DimBranch, matching columns by name. union() is purely positional,
# so any difference in column order between the two frames (e.g. if df_dim_branch is
# read from SQL instead of built from Rows) would silently put values into the wrong
# columns.
df_dim_branch = df_dim_branch.unionByName(df_new_branches)


In [0]:
from pyspark.sql.functions import trim, upper, col

# Normalize strings (upper-case, then trim) on both sides of the Location join.
normalized_location = trim(upper(col("Location")))
df_fact_txn_clean = df_fact_txn_clean.withColumn("Location", normalized_location)
df_dim_branch = df_dim_branch.withColumn("Location", normalized_location)

# Join to get BranchID
df_fact_txn_final = df_fact_txn_clean.join(
    df_dim_branch.select("Location", "BranchID"), "Location", "left"
)

# Verify no more null BranchID
df_fact_txn_final.where(col("BranchID").isNull()).count()  # should return 0


In [0]:
from pyspark.sql.functions import when, col

# Map each TransactionType to its ProductID. Any other type becomes NULL.
txn_type_col = col("TransactionType")
product_id_col = (
    when(txn_type_col == "WITHDRAWAL", 1)
    .when(txn_type_col == "DEPOSIT", 2)
    .when(txn_type_col == "TRANSFER", 3)
    .otherwise(None)
)
df_fact_txn_final = df_fact_txn_final.withColumn("ProductID", product_id_col)


In [0]:
# Rename to staging names, keep exactly the Staging_FactTransactions columns, and keep
# one row per TransactionID.
df_staging_fact = (
    df_fact_txn_final
    .withColumnRenamed("TransactionAmount", "Amount")
    .withColumnRenamed("fraud_flags", "FraudFlag")
    .select(
        "TransactionID", "AccountNumber", "CustomerID", "BranchID", "ProductID",
        "DateID", "TransactionType", "Amount", "Status", "FraudFlag",
    )
    .dropDuplicates(["TransactionID"])
)


In [0]:
# Replace the fact staging table (switch to "append" for an incremental load).
df_staging_fact.write.mode("overwrite").jdbc(
    jdbc_url,
    "Staging_FactTransactions",
    properties=db_properties,
)

print("Loaded into Staging_FactTransactions successfully")


In [0]:
from pyspark.sql.functions import regexp_extract

# Derive BranchID from the trailing digits of ATMID, cast to int.
# NOTE(review): this overwrites the BranchID obtained from the Location join, so the
# two BranchID conventions should be reconciled.
trailing_digits = regexp_extract(col("ATMID"), "([0-9]+)$", 1)
df_fact_txn_final = df_fact_txn_final.withColumn("BranchID", trailing_digits.cast("int"))


In [0]:

from pyspark.sql.functions import when

# Bucket transaction types into a product name, then look up ProductID in DimProduct.
# NOTE(review): the names produced here ("Savings", "Current", "Other") must equal
# DimProduct.ProductName exactly. The static product list built earlier in this
# notebook uses "Savings Account" / "Current Account" — confirm what the SQL table holds.
product_name_expr = (
    when(col("TransactionType").isin("WITHDRAWAL", "DEPOSIT"), "Savings")
    .when(col("TransactionType") == "TRANSFER", "Current")
    .otherwise("Other")
)
fact_with_product_name = df_fact_txn_clean.withColumn("ProductName", product_name_expr)

# Join with DimProduct to get ProductID, then drop the helper name column(s).
df_fact_txn_clean = fact_with_product_name.join(
    df_dim_product,
    fact_with_product_name["ProductName"] == df_dim_product["ProductName"],
    how="left",
).drop("ProductName")


In [0]:
# Drop duplicate TransactionID rows, keeping the latest TransactionTime
# NOTE(review): TransactionTime is compared as a string here. This orders chronologically
# only if all values share one ISO-8601 format and offset — confirm.
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, desc

window = Window.partitionBy("TransactionID").orderBy(col("TransactionTime").desc())

df_fact_txn_dedup = (
    df_fact_txn_final
    .withColumn("rn", row_number().over(window))
    .where(col("rn") == 1)
    .drop("rn")
)

# Write df_fact_txn_dedup to Staging_FactTransactions


In [0]:
# Write the de-duplicated fact rows to the staging table.
# The original cell referenced an undefined `table_name` (NameError). It also reset
# jdbc_url and passed its own empty credentials, which bypassed the notebook-wide
# connection settings. The target table is named explicitly here, and the shared
# jdbc_url / db_properties are reused so every write uses the same connection.
table_name = "Staging_FactTransactions"

df_fact_txn_dedup.write.jdbc(
    url=jdbc_url,
    table=table_name,
    mode="overwrite",
    properties=db_properties,
)
