### Phase 2: Data Transformation & Integration (1–1.5 Hours)

In [0]:
%python
from pyspark.sql.functions import col, to_date, when

# Load the raw tables from the default schema.
# The order is significant: a later cell unpacks `dfs` positionally as
# accounts, customers, branches, transactions. The third entry previously
# repeated 'raw_customers', which made the "branches" DataFrame a second copy
# of customers.
# NOTE(review): 'raw_branches' is inferred from the raw_*/silver_* naming used
# elsewhere in this notebook — confirm the table exists under that name.
table_names = ['raw_accounts', 'raw_customers', 'raw_branches', 'raw_transactions']
dfs = [spark.table(f'default.{name}') for name in table_names]

# Columns to standardize, grouped by target type. A column is only touched in
# the DataFrames that actually contain it.
date_columns = ['created_date', 'updated_date', 'transaction_date']
numeric_columns = ['amount', 'balance', 'income']
boolean_columns = ['is_active', 'is_joint_account']

# String spellings that are interpreted as boolean True.
truthy_values = ('true', 'True', '1')

for i, df in enumerate(dfs):
    present = set(df.columns)

    # Standardize date formats to 'yyyy-MM-dd'
    for date_col in date_columns:
        if date_col in present:
            df = df.withColumn(date_col, to_date(col(date_col), 'yyyy-MM-dd'))

    # Standardize numeric types to DoubleType
    for num_col in numeric_columns:
        if num_col in present:
            df = df.withColumn(num_col, col(num_col).cast('double'))

    # Standardize boolean fields to BooleanType. Any value outside
    # `truthy_values` (including NULL, for which the condition is not
    # satisfied) falls through to False.
    for bool_col in boolean_columns:
        if bool_col in present:
            is_truthy = col(bool_col).cast('string').isin(*truthy_values)
            df = df.withColumn(bool_col, when(is_truthy, True).otherwise(False))

    dfs[i] = df

for df in dfs:
    display(df)

In [0]:
# Handle invalid/empty values: drop rows with a NULL in any column, or a NaN in
# a numeric column, using SQL predicates.
#
# The predicates are applied to the DataFrames already held in `dfs` instead of
# re-querying default.<table>: re-reading the raw tables would silently discard
# the date/numeric/boolean standardization done in the previous cell.
for i, frame in enumerate(dfs):
    # Back-quote identifiers so unusual column names still parse.
    predicates = [f"`{name}` IS NOT NULL" for name in frame.columns]

    # Spark SQL evaluates NaN = NaN as true, so a self-equality comparison
    # cannot detect NaN; isnan() has to be used instead.
    predicates.extend(
        f"NOT isnan(`{name}`)"
        for name in numeric_columns
        if name in frame.columns
    )

    # A frame without columns yields no predicates; leave it untouched rather
    # than passing an empty filter expression to Spark.
    if predicates:
        dfs[i] = frame.filter(" AND ".join(predicates))

for df in dfs:
    display(df)

In [0]:
# De-duplicate every DataFrame: rows that are identical across all columns are
# collapsed to one. The list is updated in place so its order is preserved.
dfs[:] = [frame.dropDuplicates() for frame in dfs]

for deduplicated in dfs:
    display(deduplicated)

In [0]:
from pyspark.sql.functions import regexp_replace, upper, trim

# Define normalization functions for phone, PAN, and Aadhar
def normalize_columns(df):
    """Normalize the phone, PAN, and Aadhar identifier columns when present.

    Columns that are absent from ``df`` are skipped, so the function can be
    applied uniformly to DataFrames with different schemas.

    Args:
        df: DataFrame that may contain ``phone_number``, ``pan`` and/or
            ``aadhar`` columns. All other columns are left untouched.

    Returns:
        The DataFrame with each present identifier column rewritten:

        * ``phone_number``: non-digits removed, then the last 10 digits kept.
        * ``pan``: uppercased, then every character outside ``A-Z0-9`` removed.
        * ``aadhar``: non-digits removed, then the first 12 digits kept.

        No pattern validation is performed; values shorter than the target
        length are kept as they are after cleaning.
    """
    # Normalize phone numbers: remove non-digit characters, keep last 10 digits
    if 'phone_number' in df.columns:
        phone_digits = regexp_replace(col('phone_number'), r'\D', '')
        # A negative start position counts back from the end of the string, so
        # (-10, 10) selects the trailing 10 characters. The previous start of
        # -9 returned only 9 digits.
        df = df.withColumn('phone_number', phone_digits.substr(-10, 10))

    # Normalize PAN: uppercase, remove spaces and other non-alphanumerics
    if 'pan' in df.columns:
        # Uppercase BEFORE filtering: the pattern only whitelists A-Z and 0-9,
        # so filtering first would delete lowercase letters instead of
        # normalizing them. Whitespace is removed by the same pattern, which
        # makes a separate trim unnecessary.
        df = df.withColumn(
            'pan',
            regexp_replace(upper(col('pan')), r'[^A-Z0-9]', '')
        )

    # Normalize Aadhar: remove non-digit characters, keep first 12 digits
    if 'aadhar' in df.columns:
        aadhar_digits = regexp_replace(col('aadhar'), r'\D', '')
        df = df.withColumn('aadhar', aadhar_digits.substr(1, 12))

    return df

# Run the identifier normalization over each DataFrame, replacing the list
# entries in place so the positional order of `dfs` is unchanged.
dfs[:] = [normalize_columns(frame) for frame in dfs]

for normalized in dfs:
    display(normalized)

In [0]:
# Logical joins between standardized DataFrames.
# `dfs` is unpacked positionally, so this relies on the order in which the
# tables were loaded: accounts, customers, branches, transactions.
accounts_df, customers_df, branches_df, transactions_df = dfs

# Join accounts with customers on customer_id
accounts_customers_df = accounts_df.join(
    customers_df,
    accounts_df.customer_id == customers_df.customer_id,
    'inner'
)

# Join transactions with accounts on account_id.
# Both sides of the condition are qualified with their source DataFrame: an
# unqualified col('account_id') == col('account_id') compares a column with
# itself and cannot express which side each reference belongs to.
accounts_customers_transactions_df = accounts_customers_df.join(
    transactions_df,
    accounts_df.account_id == transactions_df.account_id,
    'inner'
)

# Rename duplicate columns so every column name in the joined DataFrame is
# unique: the first occurrence keeps its name, the second becomes '<name>_1',
# the third '<name>_2', and so on. (Always appending '_1' would leave
# duplicates behind whenever a name occurs three or more times.)
from collections import Counter

cols = accounts_customers_transactions_df.columns
seen = Counter()  # number of times each column name has been emitted so far
new_cols = []

for col_name in cols:
    occurrence = seen[col_name]
    new_cols.append(col_name if occurrence == 0 else f"{col_name}_{occurrence}")
    seen[col_name] += 1

# toDF assigns the new names positionally, which is what allows columns that
# share a name to be told apart.
accounts_customers_transactions_df = accounts_customers_transactions_df.toDF(*new_cols)

display(accounts_customers_transactions_df)

In [0]:
%python

# Persist each cleaned DataFrame as a Delta table, overwriting any existing
# table of the same name. Writes happen in the order listed.
silver_targets = [
    (accounts_df, "silver_accounts"),
    (customers_df, "silver_customers"),
    (branches_df, "silver_branches"),
    (transactions_df, "silver_transactions"),
]

for cleaned_df, target_table in silver_targets:
    cleaned_df.write.format("delta").mode("overwrite").saveAsTable(target_table)

In [0]:
%python
# Drop the renamed duplicate 'customer_id_1' column, when the rename step
# produced one, so the saved table carries a single customer_id.
has_duplicate_customer_id = (
    "customer_id_1" in accounts_customers_transactions_df.columns
)
accounts_customers_transactions_df_clean = (
    accounts_customers_transactions_df.drop("customer_id_1")
    if has_duplicate_customer_id
    else accounts_customers_transactions_df
)

# Save the joined result as a Delta table, replacing any previous contents.
(
    accounts_customers_transactions_df_clean.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("silver_customer_account_txn")
)

In [0]:
%sql
-- Preview the joined silver table written by the previous cell
SELECT *
FROM silver_customer_account_txn