In [0]:
# Parameter Setup and Configuration
# dbutils.widgets.text("catalog", "") #smuralik_catalog
# dbutils.widgets.text("schema", "") #jpmc_ccbr
dbutils.widgets.text("candidate_table", "") #ccbr_migration_table_candidates
dbutils.widgets.text("partition_audit_table","") #89055_ctg_prod_exp.default.dataset_tags
dbutils.widgets.text("dataset_name","") 
dbutils.widgets.text("parquet_schema_table","")
dbutils.widgets.text("inventory_table","") #ccbr_migration_table_inventory
dbutils.widgets.text("managed_table","")
dbutils.widgets.text("schema_comparision_results","")
dbutils.widgets.text("dataset_mapping_table","")

# CATALOG_NAME = dbutils.widgets.get("catalog")
# SCHEMA_NAME = dbutils.widgets.get("schema")
parquet_schema_table = dbutils.widgets.get("parquet_schema_table")
inventory_table = dbutils.widgets.get("inventory_table")
dataset_name=dbutils.widgets.get("dataset_name")
managed_table=dbutils.widgets.get("managed_table")
candidate_table=dbutils.widgets.get("candidate_table")
partition_audit_table=dbutils.widgets.get("partition_audit_table")
schema_comparision_results=dbutils.widgets.get("schema_comparision_results")
dataset_mapping_table=dbutils.widgets.get("dataset_mapping_table")




# 89055_ctg_prod_managed.89055_trusted_db_hl_hm_loan_orgn_hcd_dora_fdl.vls_rt_acct_ext_m (done)
# 89055_ctg_prod_managed.89055_trusted_db_hl_hm_loan_srvc_hcd_dora_fdl.alsc_rt_acct_fee_fin_m (done)
# 89055_ctg_prod_managed.89055_trusted_db_hl_hm_loan_srvc_hcd_dora_fdl.rmi_loan_dim (done)
# 89055_ctg_prod_managed.89055_trusted_db_hl_hm_loan_srvc_hcd_dora_fdl.rmi_loan_mo_fct

In [0]:
from functools import reduce
from pyspark.sql.functions import col, regexp_extract, collect_list, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, BooleanType, DecimalType, DateType, TimestampType, BinaryType, ShortType, ByteType
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
import logging
import time
import traceback

spark.conf.set("spark.sql.files.ignoreCorruptFiles", "true")

In [0]:
# this method is used for getting the base schema 
def parse_type(dtype):
    dtype = dtype.lower().strip()

    if dtype.startswith("decimal"):
        scale = dtype[dtype.find("(") + 1 : dtype.find(")")].split(",")
        return DecimalType(int(scale[0]), int(scale[1]))

    if dtype in ("int8", "byte"):
        return ByteType()
    
    if dtype.startswith("bytetype"):
        return ByteType()

    if dtype in ("int16", "smallint"):
        return ShortType()
    
    if dtype.startswith("shorttype"):
        return ShortType()
    
    if dtype in ("int32", "integer"):
        return IntegerType()
    
    if dtype.startswith("integertype"):
        return IntegerType()
    
    if dtype == "int64":
        return LongType()

    if dtype.startswith("longtype"):
        return LongType()
    
    if dtype in ("string"):
        return StringType()

    if dtype.startswith("stringtype"):
        return StringType()

    if dtype.startswith("date32") or dtype == "date":
        return DateType()
    
    if dtype.startswith("datetype"):
        return DateType()

    if dtype.startswith("timestamp"):
        # Handles 'timestamp', 'timestamp[ms]', 'timestamp[us]' etc.
        return TimestampType()

    if dtype == "bool":
        return BooleanType()
    
    if dtype.startswith("booleantype"):
        return BooleanType()
    
    if dtype == "binary":
        return BinaryType()
    
    if dtype.startswith("binarytype"):
        return BinaryType()
    
    raise ValueError(f"Unsupported type: {dtype}")

In [0]:
from pyspark.sql import Row
df_table_candidates = spark.sql(f"""
                                select distinct s3_bucket_name, bucket_prefix, table_name
                                from {candidate_table}
                                where  
                                table_name = '{dataset_name}'
                                """)

# Collect all rows into Python memory
rows = df_table_candidates.collect()

for row in rows:
    bucket_name = row["s3_bucket_name"]
    bucket_prefix = row["bucket_prefix"]
    dataset_name = row["table_name"]


    partition_key_combination = "edp_run_id, snapshot_date"
    print(f"Processing {dataset_name}")
    print(f"Processing {partition_key_combination}")

    inventory_df = spark.sql(f"""SELECT distinct edp_run_id, snapshot_date
                            FROM 
                            (
                                select distinct edp_run_id, try_cast(snapshot_date as date) as snapshot_date
                                from {inventory_table}
                                where 1 = 1
                                and extension is not null 
                                and partition_key = '{partition_key_combination}'
                                and s3_bucket_name = '{bucket_name}' 
                                and bucket_prefix = '{bucket_prefix}'
                                --and (load_status is null or load_status = 'failed')
                            ) inventory
                            join
                            (
                                select distinct run_id, try_cast(run_tag_value as date) as run_tag_value 
                                from {partition_audit_table}
                                where dataset_name = '{dataset_name}'
                                and lower(run_tag_key) = 'snapshot_date'
                            ) partition
                            on inventory.edp_run_id = partition.run_id
                            and inventory.snapshot_date = partition.run_tag_value
                            """)

    df_latest_snapshot = (inventory_df
                            .orderBy(col("snapshot_date").desc())
                            .limit(1)
                            )

    latest_snapshot_row = df_latest_snapshot.collect()[0]

    latest_edp_run_id = latest_snapshot_row["edp_run_id"]
    latest_snapshot_date = latest_snapshot_row["snapshot_date"]

    print (f"Processing Partition {latest_edp_run_id}:{latest_snapshot_date}")

    schema_json_df = spark.sql(f"""select schema_json 
                                from {parquet_schema_table}
                                where 1 = 1
                                and bucket_prefix = '{bucket_prefix}'
                                and file_path like '%edp_run_id={latest_edp_run_id}/snapshot_date={latest_snapshot_date}%'
                                """)
    #and s3_bucket_name = '{bucket_name}'
    # Extract the schema_json value into a Python string variable
    schema_json = schema_json_df.first()['schema_json']

    schema_dict = json.loads(schema_json)

    # Convert to StructType
    base_schema = StructType([
        StructField(f["name"], parse_type(f["type"]), f["nullable"])
        for f in schema_dict["fields"]
    ])

    partition_columns = [partition_column.strip() for partition_column in partition_key_combination.split(",") if partition_column.strip()]

    # You can define mapping for known types here
    partition_type_map = {
        "edp_run_id": StringType(),
        "snapshot_date": DateType()
    }

    extended_fields = base_schema.fields.copy()

    for partition_column in partition_columns:
        col_type = partition_type_map.get(partition_column, StringType())  # default to STRING
        extended_fields.append(StructField(partition_column, col_type, True))

    extended_schema = StructType(extended_fields)
    # print(extended_schema)
    schema_dtls=[]
    for field in extended_schema:
        dtype=str(field.dataType)
        dtype=dtype.replace("Type","")
        schema_dtls.append(Row(Column_Name=field.name.upper(),Data_Type=dtype))

    base_df_schema=spark.createDataFrame(schema_dtls)
    display(base_df_schema)
 
    

In [0]:
# df_failed_partitions=spark.sql (f"""
# select a.* from
# (select run_id, try_cast(run_tag_value as date) run_tag_value from 89055_ctg_prod_exp.default.dataset_tags
#                                 where dataset_name = '{dataset_name}'
#                                 and lower(run_tag_key) = 'snapshot_date'

# ) a
# left join
# (select distinct edp_run_id, try_cast(snapshot_date as date) snapshot_date from {managed_table}) b
# on a.run_id = b.edp_run_id
# and a.run_tag_value = b.snapshot_date
# where b.edp_run_id is null """)

df_failed_partitions = spark.sql(
    f"""
    SELECT DISTINCT
        inv.execution_id,
        inv.s3_bucket_name,
        inv.bucket_prefix,
        regexp_extract(inv.bucket_prefix, '([^/]+)[/]?$', 1) AS dataset_name,
        inv.edp_run_id,
        inv.snapshot_date,
        map.dbx_managed_table_schema AS managed_schema
    FROM {inventory_table} inv
    LEFT JOIN {dataset_mapping_table} map
     -- ON inv.s3_bucket_name = map.s3_bucket_name
     ON inv.bucket_prefix = map.bucket_prefix
     AND regexp_extract(inv.bucket_prefix, '([^/]+)[/]?$', 1) = map.dataset_name
    WHERE inv.load_status = 'failed'
      AND inv.extension IS NOT NULL
      -- AND  inv.extension ='parquet'
      AND inv.edp_run_id IS NOT NULL
      AND inv.snapshot_date IS NOT NULL
      AND regexp_extract(inv.bucket_prefix, '([^/]+)[/]?$', 1)='{dataset_name}'
    """
)
display(df_failed_partitions)


In [0]:
failed_snapshot_row = df_failed_partitions.collect()
total_record_count=0
bucket_name= 'app-id-89055-dep-id-109792-uu-id-n6ph64imx36e'
bucket_prefix='trusted/analytics/data_dlvr/checks_ccbdl/'
for row in failed_snapshot_row:
    edp_run_id = row["edp_run_id"]
    snapshot_date = row["snapshot_date"]
    path=f"s3://{bucket_name}/{bucket_prefix}/edp_run_id={edp_run_id}/snapshot_date={snapshot_date}/"
    failed_df = spark.read.parquet(path)
    record_count=failed_df.count()
    total_record_count+=record_count
    # display(failed_df)

print(total_record_count)

In [0]:
from pyspark.sql.functions import col,lit,when
from pyspark.sql import Row

failed_snapshot_row = df_failed_partitions.collect()
final_report_df_schema=None
for row in failed_snapshot_row:
    failed_edp_run_id = row["edp_run_id"]
    failed_snapshot_date = row["snapshot_date"]

    print (f"Processing Partition {failed_edp_run_id}:{failed_snapshot_date}")

    schema_json_df = spark.sql(f"""select schema_json 
                                from {parquet_schema_table}
                                where 1 = 1
                                and bucket_prefix = '{bucket_prefix}'
                                and file_path like '%edp_run_id={failed_edp_run_id}/snapshot_date={failed_snapshot_date}%'
                                """)
    #and s3_bucket_name = '{bucket_name}'
    # Extract the schema_json value into a Python string variable
    schema_json = schema_json_df.first()['schema_json']

    schema_dict = json.loads(schema_json)

    # Convert to StructType
    base_schema = StructType([
        StructField(f["name"], parse_type(f["type"]), f["nullable"])
        for f in schema_dict["fields"]
    ])

    partition_columns = [partition_column.strip() for partition_column in partition_key_combination.split(",") if partition_column.strip()]

    # You can define mapping for known types here
    partition_type_map = {
        "edp_run_id": StringType(),
        "snapshot_date": DateType()
    }

    extended_fields = base_schema.fields.copy()

    for partition_column in partition_columns:
        col_type = partition_type_map.get(partition_column, StringType())  # default to STRING
        extended_fields.append(StructField(partition_column, col_type, True))

    extended_schema = StructType(extended_fields)
    # print(extended_schema)
    schema_dtls=[]
    for field in extended_schema:
        dtype=str(field.dataType)
        dtype=dtype.replace("Type","")
        schema_dtls.append(Row(Column_Name=field.name.upper(),Data_Type=dtype))

    failed_df_schema=spark.createDataFrame(schema_dtls)
    matched_colums_df=None
    matched_colums_df=base_df_schema.alias("s").join(
                                            failed_df_schema.alias("t"), 
                                            on="Column_Name", 
                                            how="inner"
                                        )
    matched_colums_df=matched_colums_df.select(
                            col("s.Column_Name").alias("Base_Column_Name"),
                            col("s.Data_Type").alias("Base_Data_Type"),
                            col("t.Column_Name").alias("Failed_Column_Name"),
                            col("t.Data_Type").alias("Failed_Data_Type")
                        )

    matched_colums_df=matched_colums_df.withColumn("Status",
                            when(
                                col("Base_Data_Type")==col("Failed_Data_Type"), lit("PASS")
                            ).otherwise(lit("FAIL"))
                        ).withColumn("dataset_name", lit(dataset_name))
    
    # matched
    
    if final_report_df_schema is None:
        final_report_df_schema=matched_colums_df
    else:
        final_report_df_schema=final_report_df_schema.unionAll(matched_colums_df)

base_schema_column_count=base_df_schema.count()
display(final_report_df_schema)
passed_final_report_df_schema=final_report_df_schema.filter("Status='PASS'")
passed_final_report_df_schema=passed_final_report_df_schema.groupBy("Base_Column_Name").count()
matched_schema_columns_count=passed_final_report_df_schema.count()

failed_final_report_df_schema=final_report_df_schema.filter("Status='FAIL'")
failed_final_report_df_schema=failed_final_report_df_schema.groupBy("Failed_Column_Name").count()
unmatched_schema_columns_count=failed_final_report_df_schema.count()
display(failed_final_report_df_schema)
print(f"Total Columns:{base_schema_column_count} \n Matched Columns:{matched_schema_columns_count} \n Unmatched Columns:{unmatched_schema_columns_count}")


[0;31m---------------------------------------------------------------------------[0m
[0;31mAttributeError[0m                            Traceback (most recent call last)
File [0;32m<command-4284970994240863>, line 81[0m
[1;32m     79[0m base_schema_column_count[38;5;241m=[39mbase_df_schema[38;5;241m.[39mcount()
[1;32m     80[0m display(final_report_df_schema)
[0;32m---> 81[0m passed_final_report_df_schema[38;5;241m=[39mfinal_report_df_schema[38;5;241m.[39mfilter([38;5;124m"[39m[38;5;124mStatus=[39m[38;5;124m'[39m[38;5;124mPASS[39m[38;5;124m'[39m[38;5;124m"[39m)
[1;32m     82[0m passed_final_report_df_schema[38;5;241m=[39mpassed_final_report_df_schema[38;5;241m.[39mgroupBy([38;5;124m"[39m[38;5;124mBase_Column_Name[39m[38;5;124m"[39m)[38;5;241m.[39mcount()
[1;32m     83[0m matched_schema_columns_count[38;5;241m=[39mpassed_final_report_df_schema[38;5;241m.[39mcount()

[0;31mAttributeError[0m: 'NoneType' object has no attribute 'filter'

In [0]:
final_report_df_schema.write.mode("append").saveAsTable(schema_comparision_results)

