# CCBR Migration Data Reconciliation Notebook
Purpose:
This notebook performs comprehensive data reconciliation between S3 source data and 
Databricks managed tables for the CCBR migration project. It executes multiple 
reconciliation types and stores results in a clean, simplified enhanced results table.

Reconciliation Types Performed:

1. ROW_COUNT
   - Compares total record counts between S3 source and managed target tables
   - Status: PASS/FAIL based on exact match
   - Metrics: Source count vs Target count

2. COLUMN_COUNT  
   - Compares number of columns between source and target schemas
   - Status: PASS/FAIL based on exact match
   - Metrics: Source column count vs Target column count

3. SCHEMA_VALIDATION
   - Validates schema compatibility and identifies column mismatches
   - Status: PASS if schemas match exactly, FAIL if differences found
   - Summary: Lists specific missing columns in source/target with names
   - Example: "Missing in source: [col1, col2] | Missing in target: [col3]"

4. VERIFIED
   - Compares Databricks managed table statistics with computed S3 equivalents
   - Statistics: num_nulls, distinct_count, min/max, avg_col_len, max_col_len
   - Comprehensive statistical comparison across all available metrics
   - Summary: "Statistics comparison: X/Y stats match (Z%)"
   - Methods: abs_sum (numeric)
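
A minimal sketch of how the summary strings described above could be assembled. The helper names `schema_summary` and `stats_summary`, and the "Schemas match" text for the PASS case, are assumptions made for illustration; they are not functions defined in this notebook.

```python
def schema_summary(missing_in_source, missing_in_target):
    """Return (status, summary) for a SCHEMA_VALIDATION check (hypothetical helper)."""
    if not missing_in_source and not missing_in_target:
        return "PASS", "Schemas match"
    parts = []
    if missing_in_source:
        parts.append(f"Missing in source: [{', '.join(missing_in_source)}]")
    if missing_in_target:
        parts.append(f"Missing in target: [{', '.join(missing_in_target)}]")
    return "FAIL", " | ".join(parts)


def stats_summary(matched, compared):
    """Return the VERIFIED summary line (hypothetical helper)."""
    pct = (matched / compared * 100.0) if compared else 0.0
    return f"Statistics comparison: {matched}/{compared} stats match ({pct:.1f}%)"


status, summary = schema_summary(["col1", "col2"], ["col3"])
print(status, "-", summary)
# FAIL - Missing in source: [col1, col2] | Missing in target: [col3]
```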



# Parameter Setup and Configuration

Initializes all required Databricks widget parameters for dynamic runtime configuration. Captures catalog, schema, and table names for results, mapping, and candidate metadata. These parameters allow the notebook to be executed across different environments without code changes.

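**Key Parameters (widget names as defined in the cell below):**
- `catalog`: Target Unity Catalog name
- `schema`: Schema that holds the reconciliation metadata tables
- `recon_results_table`: Main results table for reconciliation outcomes
- `dataset_mapping_table`: Source-to-target mapping configuration
- `candidate_table`: Table metadata and partition information
- `inventory_table`: File inventory with load status per snapshot
- `parquet_schema_table`: Stored Parquet schemas (`schema_json`) per file path
- `partition_audit_table`: Audit table of run tags; used as given, without the catalog/schema prefix
- `bucket_name`: S3 bucket that holds the source data
- `dataset_name`: Dataset name(s) to reconcile; interpolated as-is into a SQL `IN (...)` list by `load_mapping_data()`, so each name must already be quoted
- `run_id`, `job_id`, `managed_table_schema`: Currently commented out in the code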
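
Because every widget defaults to an empty string, an unset parameter silently produces a malformed name such as `..my_table`. The sketch below shows one way to validate the parts before building a fully qualified name. `build_fqn` is a hypothetical helper, not part of this notebook, and the example values are placeholders.

```python
def build_fqn(catalog: str, schema: str, table: str) -> str:
    """Join catalog, schema and table into 'catalog.schema.table', rejecting empty parts."""
    parts = {"catalog": catalog, "schema": schema, "table": table}
    empty = [name for name, value in parts.items() if not value or not value.strip()]
    if empty:
        raise ValueError(f"Empty widget value(s): {', '.join(empty)}")
    return ".".join(value.strip() for value in parts.values())


print(build_fqn("my_catalog", "my_schema", "recon_results"))
# my_catalog.my_schema.recon_results
```

In the notebook this could wrap the values returned by `dbutils.widgets.get(...)` so that a missing parameter fails early instead of surfacing later as a table-not-found error.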

In [0]:
# Parameter Setup and Configuration
dbutils.widgets.text("catalog", "") #smuralik_catalog
dbutils.widgets.text("schema", "") #jpmc_ccbr
dbutils.widgets.text("recon_results_table", "") #ccbr_migration_recon_results
dbutils.widgets.text("dataset_mapping_table", "") #ccbr_migration_dataset_mapping
dbutils.widgets.text("candidate_table", "") #ccbr_migration_table_candidates
dbutils.widgets.text("partition_audit_table","") #89055_ctg_prod_exp.default.dataset_tags
dbutils.widgets.text("bucket_name","") 
dbutils.widgets.text("parquet_schema_table","")
dbutils.widgets.text("inventory_table","") #ccbr_migration_table_inventory
dbutils.widgets.text("dataset_name","") # quoted dataset name(s), used in a SQL IN (...) list


# dbutils.widgets.text("run_id", "")
# dbutils.widgets.text("job_id", "")
# dbutils.widgets.text("managed_table_schema", "") #jpmc_ccbr_managed_tables
# dbutils.widgets.text("dataset_inventory_mapping_volume_location", "")

CATALOG_NAME = dbutils.widgets.get("catalog")
SCHEMA_NAME = dbutils.widgets.get("schema")
RESULTS_TABLE = dbutils.widgets.get("recon_results_table")
MAPPING_TABLE = dbutils.widgets.get("dataset_mapping_table")
CANDIDATES_TABLE = dbutils.widgets.get("candidate_table")
inventory_table = dbutils.widgets.get("inventory_table")
PARTITION_AUDIT_TABLE = dbutils.widgets.get("partition_audit_table")
parquet_schema_table = dbutils.widgets.get("parquet_schema_table")
BUCKET_NAME=dbutils.widgets.get("bucket_name") 
dataset_name = dbutils.widgets.get("dataset_name") # quoted dataset name(s), used in a SQL IN (...) list


# RUN_ID = dbutils.widgets.get("run_id")
# JOB_ID = dbutils.widgets.get("job_id")
#Partition_Schema = 'dbx_89055_trusted_db_mdas_ais_hcd_dora_fdl_prod_exp'
#parquet_schema_table = 'ccbr_migration_parquet_schemas'

# MANAGED_SCHEMA = dbutils.widgets.get("managed_table_schema")
# VOLUME_LOCATION = dbutils.widgets.get("dataset_inventory_mapping_volume_location")


RESULTS_TABLE_FQN = f"{CATALOG_NAME}.{SCHEMA_NAME}.{RESULTS_TABLE}"
MAPPING_TABLE_FQN = f"{CATALOG_NAME}.{SCHEMA_NAME}.{MAPPING_TABLE}"
CANDIDATES_TABLE_FQN = f"{CATALOG_NAME}.{SCHEMA_NAME}.{CANDIDATES_TABLE}"
PARTITION_AUDIT_TABLE_FQN = f"{PARTITION_AUDIT_TABLE}"
inventory_table_FQN = f"{CATALOG_NAME}.{SCHEMA_NAME}.{inventory_table}"
# execution_id=""

print(f"Results table: {RESULTS_TABLE_FQN}")
print(f"Mapping table: {MAPPING_TABLE_FQN}")
print(f"Candidates table: {CANDIDATES_TABLE_FQN}")
print(f"Audit table: {PARTITION_AUDIT_TABLE_FQN}")
print(f"Inventory table: {inventory_table_FQN}")
print(f"dataset_name : {dataset_name}")
#89055_ctg_prod_exp.dbx_89055_trusted_db_mdas_ais_hcd_dora_fdl_prod_exp.ccbr_migration_table_inventory

# Schema Definition

Defines the comprehensive schema structure for the reconciliation results table. This enhanced schema captures detailed metrics, execution metadata, and error information for each reconciliation check.

**Schema Components:**
- Execution metadata (ID, timestamp, duration)
- Reconciliation type and status
- Source and target metrics (JSON format)
- Statistical comparison results
- Summary and error logging
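
To make the components above concrete, here is a hedged sketch of a single result record whose keys follow the field names of `recon_results_schema`. The helper `make_result_row` and its summary wording are assumptions for illustration; the notebook's own result factory may format things differently.

```python
import json
from datetime import datetime


def make_result_row(execution_id, table_name, recon_type,
                    source_metrics, target_metrics, duration_s=0.0):
    """Build one result record keyed by the recon_results_schema field names (hypothetical)."""
    keys = sorted(set(source_metrics) | set(target_metrics))
    # A metric matches only if both sides report it with the same value
    matched = sum(
        1 for k in keys
        if k in source_metrics and k in target_metrics
        and source_metrics[k] == target_metrics[k]
    )
    return {
        "execution_id": execution_id,
        "table_name": table_name,
        "recon_type": recon_type,
        "exec_timestamp": datetime.now(),
        "exec_duration": float(duration_s),
        "status": "PASS" if matched == len(keys) else "FAIL",
        "source_metrics": json.dumps(source_metrics, sort_keys=True),  # stored as a JSON string
        "target_metrics": json.dumps(target_metrics, sort_keys=True),
        "metrics_compared": len(keys),
        "metrics_matched": matched,
        "metrics_different": len(keys) - matched,
        "summary": f"{matched}/{len(keys)} metrics match",
        "error_log": None,
    }


row = make_result_row("run-1", "credit_losses", "ROW_COUNT",
                      {"row_count": 100}, {"row_count": 98}, duration_s=1.5)
print(row["status"], row["source_metrics"], row["target_metrics"])
```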

In [0]:
# Define recon results schema for enhanced results table with error_log column
from pyspark.sql.types import *
from datetime import datetime
import json

recon_results_schema = StructType([
    StructField("execution_id", StringType(), True),
    StructField("table_name", StringType(), True),
    StructField("recon_type", StringType(), True),
    StructField("exec_timestamp", TimestampType(), True),
    StructField("exec_duration", DoubleType(), True),
    StructField("status", StringType(), True),
    StructField("source_metrics", StringType(), True),
    StructField("target_metrics", StringType(), True),
    StructField("metrics_compared", IntegerType(), True),
    StructField("metrics_matched", IntegerType(), True),
    StructField("metrics_different", IntegerType(), True),
    StructField("summary", StringType(), True),
    StructField("error_log", StringType(), True)
])

# Core Helper Functions

Implements fundamental utility functions for data operations and schema analysis. These functions provide the foundation for reading data, comparing schemas, and performing basic reconciliation calculations.

**Key Functions:**
- `read_source_parquet()`: Reads S3 parquet data with error handling
- `get_schema_dict()`: Extracts detailed schema information including precision/scale
- `compare_schemas()`: Comprehensive schema comparison with type and precision validation
- `compute_variance()`: Calculates percentage differences between metrics
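
The sketch below illustrates the dictionary shape that `get_schema_dict()` produces and a compact, Spark-free version of the case-insensitive comparison. The column names and the `diff_schemas` helper are hypothetical; the notebook's `compare_schemas()` returns four lists rather than a dictionary.

```python
# Hand-written dictionaries in the shape returned by get_schema_dict()
src_schema = {
    "Account_ID": {"data_type": "string", "precision": None, "scale": None},
    "balance": {"data_type": "decimal(18,2)", "precision": 18, "scale": 2},
}
tgt_schema = {
    "account_id": {"data_type": "string", "precision": None, "scale": None},
    "balance": {"data_type": "decimal(20,2)", "precision": 20, "scale": 2},
    "load_ts": {"data_type": "timestamp", "precision": None, "scale": None},
}


def diff_schemas(src, tgt):
    """Compare two schema dictionaries, ignoring column-name case (hypothetical helper)."""
    src = {k.lower(): v for k, v in src.items()}
    tgt = {k.lower(): v for k, v in tgt.items()}
    common = src.keys() & tgt.keys()
    return {
        "missing_in_source": sorted(tgt.keys() - src.keys()),
        "missing_in_target": sorted(src.keys() - tgt.keys()),
        "type_mismatches": sorted(
            c for c in common if src[c]["data_type"] != tgt[c]["data_type"]
        ),
    }


print(diff_schemas(src_schema, tgt_schema))
# {'missing_in_source': ['load_ts'], 'missing_in_target': [], 'type_mismatches': ['balance']}
```

Note that a decimal's `data_type` string already embeds its precision and scale, so a precision difference also shows up as a data type mismatch.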

In [0]:
# DBTITLE 1,Helper Functions
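# Core utility functions for data operations
# NOTE: this wildcard import shadows Python's built-in min, max, sum and abs with the Spark
# column functions of the same name; the statistics helpers in later cells rely on that.
from pyspark.sql.functions import *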
from pyspark.sql import DataFrame
from functools import reduce
from pyspark.sql.functions import col, regexp_extract, collect_list, lit, to_date
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, BooleanType, DecimalType, DateType, TimestampType, BinaryType, ShortType, ByteType, FloatType, DoubleType
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
import logging
import time
import traceback
# Configure Spark to display full content without truncation
spark.conf.set("spark.sql.repl.eagerEval.truncate", 10000)
spark.conf.set("spark.sql.debug.maxToStringFields", 1000)

# Configure pandas display options
import pandas as pd
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)

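# Maps a type name from the stored schema JSON (e.g. "int32", "decimal(10,2)") to a Spark DataType
def parse_type(dtype):
    dtype = dtype.lower().strip()

    if dtype.startswith("decimal"):
        # Expects an explicit "(precision,scale)" suffix; the text between the parentheses is split on the comma
        precision, scale = dtype[dtype.find("(") + 1 : dtype.find(")")].split(",")
        return DecimalType(int(precision), int(scale))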

    if dtype in ("int8", "byte"):
        return ByteType()
    
    if dtype.startswith("bytetype"):
        return ByteType()

    if dtype in ("int16", "smallint"):
        return ShortType()
    
    if dtype.startswith("shorttype"):
        return ShortType()
    
    if dtype in ("int32", "integer"):
        return IntegerType()
    
    if dtype.startswith("integertype"):
        return IntegerType()
    
    if dtype == "int64":
        return LongType()

    if dtype.startswith("longtype"):
        return LongType()
    
    # Exact match; `dtype in ("string")` would be a substring test against the str "string"
    if dtype == "string":
        return StringType()

    if dtype.startswith("stringtype"):
        return StringType()

    if dtype.startswith("date32") or dtype == "date":
        return DateType()
    
    if dtype.startswith("datetype"):
        return DateType()
    
    if dtype.startswith("floattype"):
        return FloatType()
    
    if dtype.startswith("doubletype"):
        return DoubleType()

    if dtype.startswith("timestamp"):
        # Handles 'timestamp', 'timestamp[ms]', 'timestamp[us]' etc.
        return TimestampType()

    if dtype == "bool":
        return BooleanType()
    
    if dtype.startswith("booleantype"):
        return BooleanType()
    
    if dtype == "binary":
        return BinaryType()
    
    if dtype.startswith("binarytype"):
        return BinaryType()
    
    raise ValueError(f"Unsupported type: {dtype}")

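def read_source_parquet(bucket: str, prefix: str, dataset_name: str) -> DataFrame:
    # Reads parquet files from the S3 path and returns a DataFrame limited to snapshots on or after 2020-01-01.
    # Returns None when no loaded snapshot is found for the given bucket/prefix and dataset.
    # Look up the latest (edp_run_id, snapshot_date) pair; its stored schema is used to read the files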
    partition_key_combination = "edp_run_id, snapshot_date"
    inventory_schema_df = spark.sql(f"""SELECT distinct edp_run_id, snapshot_date
                                FROM 
                                (
                                    select distinct edp_run_id, try_cast(snapshot_date as date) as snapshot_date
                                    from {inventory_table_FQN}
                                    where 1 = 1
                                    and extension is not null 
                                    and lower(partition_key) = '{partition_key_combination}'
                                    and s3_bucket_name = '{bucket}' 
                                    and bucket_prefix = '{prefix}'
                                    and load_status ='loaded'
                                    and try_cast(snapshot_date as date) >= '2020-01-01'
                                ) inventory
                                join
                                (
                                    select distinct run_id, try_cast(run_tag_value as date) as run_tag_value 
                                    from {PARTITION_AUDIT_TABLE_FQN}
                                    where dataset_name = '{dataset_name}'
                                    and lower(run_tag_key) = 'snapshot_date'
                                    and try_cast(run_tag_value as date) >= '2020-01-01'
                                ) partition
                                on inventory.edp_run_id = partition.run_id
                                and inventory.snapshot_date = partition.run_tag_value
                                """)
    if not inventory_schema_df.isEmpty():
        df_latest_snapshot = (inventory_schema_df
                                .orderBy(col("snapshot_date").desc())
                                .limit(1)
                                )

        latest_snapshot_row = df_latest_snapshot.collect()[0]

        latest_edp_run_id = latest_snapshot_row["edp_run_id"]
        latest_snapshot_date = latest_snapshot_row["snapshot_date"]
        schema_json_df = spark.sql(f"""select schema_json 
                                    from {CATALOG_NAME}.{SCHEMA_NAME}.{parquet_schema_table}
                                    where 1 = 1
                                    and s3_bucket_name = '{bucket}'
                                    and bucket_prefix = '{prefix}'
                                    and lower(file_path) like '%edp_run_id={latest_edp_run_id}/snapshot_date={latest_snapshot_date}%'
                                    """)

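        # Extract the schema_json value into a Python string variable
        schema_row = schema_json_df.first()
        if schema_row is None:
            raise ValueError(
                f"No stored parquet schema found for s3://{bucket}/{prefix} "
                f"(edp_run_id={latest_edp_run_id}, snapshot_date={latest_snapshot_date})"
            )
        schema_json = schema_row['schema_json']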

        schema_dict = json.loads(schema_json)

        # Convert to StructType
        base_schema = StructType([
            StructField(f["name"], parse_type(f["type"]), f["nullable"])
            for f in schema_dict["fields"]
        ])

        partition_columns = [partition_column.strip() for partition_column in partition_key_combination.split(",") if partition_column.strip()]

        # You can define mapping for known types here
        partition_type_map = {
            "edp_run_id": StringType(),
            "snapshot_date": DateType()
        }

        extended_fields = base_schema.fields.copy()

        for partition_column in partition_columns:
            col_type = partition_type_map.get(partition_column, StringType())  # default to STRING
            extended_fields.append(StructField(partition_column, col_type, True))

        extended_schema = StructType(extended_fields)
        # print(extended_schema)

        path = f"s3://{bucket}/{prefix}" 
        s3_df=spark.read.format("parquet").schema(extended_schema).option("ignoreCorruptFiles", "true").load(path)
        #display(s3_df)

        partition_audit_df=spark.sql(f""" 
                                        select distinct dataset_name, run_id, run_tag_value from {PARTITION_AUDIT_TABLE_FQN}
                                        where lower(run_tag_key) = 'snapshot_date' 
                                        AND dataset_name ='{dataset_name}'
                                        AND try_cast(run_tag_value as date) >= '2020-01-01'
                                    """)
        # display(partition_audit_df)
        #print(bucket)
        Inventory_df=spark.sql(f""" 
                                    select distinct edp_run_id,snapshot_date from {inventory_table_FQN}
                                    where load_status in ('loaded','failed')
                                    and s3_bucket_name = '{bucket}'
                                    and bucket_prefix = '{prefix}'
                                    and extension is not null
                                    and try_cast(snapshot_date as date) >= '2020-01-01'
                                """)

        if partition_audit_df.count() > 0 and Inventory_df.count()>0: 
            # display(partition_audit_df)
            filter_df = (
                s3_df.join(
                    partition_audit_df,
                    on=[
                        to_date(s3_df["snapshot_date"], 'yyyy-MM-dd') == to_date(partition_audit_df["run_tag_value"],'yyyy-MM-dd'),
                        s3_df["edp_run_id"] == partition_audit_df["run_id"]
                    ],
                    how="inner"
                )
                .join (
                    Inventory_df,
                    on=[
                        to_date(s3_df["snapshot_date"], 'yyyy-MM-dd') == to_date(Inventory_df["snapshot_date"],'yyyy-MM-dd'),
                        s3_df["edp_run_id"] == Inventory_df["edp_run_id"]
                    ],
                    how="inner"
                )
                .select(s3_df["*"])
                .filter(to_date(col("snapshot_date"), 'yyyy-MM-dd') >= lit('2020-01-01'))
            )
            return filter_df
        else:
            return s3_df.filter(to_date(col("snapshot_date"), 'yyyy-MM-dd') >= lit('2020-01-01'))
    else:
        return None
        

# df=read_source_parquet("app-id-89055-dep-id-109792-uu-id-wj1in46hrkbl", "trusted/ops/auto_srvc_coll/credit_losses/","credit_losses")
# display(df)
# # df.select("snapshot_date", "edp_run_id").distinct().show(truncate=False)

def get_schema_dict(df: DataFrame) -> dict:
    # Converts DataFrame schema to a dictionary {column_name: (data_type, precision, scale)}
    schema_info = {}
    for field in df.schema.fields:
        data_type = field.dataType.simpleString()
        precision = None
        scale = None
        
        # Extract precision and scale for decimal types
        if hasattr(field.dataType, 'precision') and hasattr(field.dataType, 'scale'):
            precision = field.dataType.precision
            scale = field.dataType.scale
            
        schema_info[field.name] = {
            'data_type': data_type,
            'precision': precision,
            'scale': scale
        }
    
    return schema_info

def compare_schemas(src_schema: dict, tgt_schema: dict):
    # Enhanced schema comparison including precision and scale
    src_schema = {k.lower() : v for k,v in src_schema.items()}
    tgt_schema = {k.lower() : v for k,v in tgt_schema.items()}
    missing_in_src = []
    missing_in_tgt = []
    mismatched = []
    precision_mismatched = []
    
    # Check for missing columns
    for col in tgt_schema:
        if col not in src_schema:
            missing_in_src.append(col)
    
    for col in src_schema:
        if col not in tgt_schema:
            missing_in_tgt.append(col)
    
    # Check for data type and precision mismatches
    for col in src_schema:
        if col in tgt_schema:
            src_info = src_schema[col]
            tgt_info = tgt_schema[col]
            
            # Check data type mismatch
            if src_info['data_type'] != tgt_info['data_type']:
                mismatched.append(f"{col} (src: {src_info['data_type']}, tgt: {tgt_info['data_type']})")
            
            # Check precision/scale mismatch for decimal types
            elif (src_info['precision'] != tgt_info['precision'] or 
                  src_info['scale'] != tgt_info['scale']):
                if src_info['precision'] is not None or tgt_info['precision'] is not None:
                    precision_mismatched.append(
                        f"{col} (src: p={src_info['precision']},s={src_info['scale']}, "
                        f"tgt: p={tgt_info['precision']},s={tgt_info['scale']})"
                    )
    
    return missing_in_src, missing_in_tgt, mismatched, precision_mismatched

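def compute_variance(src_cnt: int, tgt_cnt: int) -> float:
    # Computes percentage variance between source and target counts, relative to the source count
    # Note: returns 0.0 when the source count is 0, even if the target has rows,
    # so a 0.0 result alone does not prove that the counts match
    if src_cnt == 0:
        return 0.0
    return round(abs(src_cnt - tgt_cnt) / src_cnt * 100.0, 2)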

# Advanced Column Validation Functions

Computes column-level statistics for both the Databricks managed table and the S3 source data so that the two can be compared. For each selected column the functions compute the minimum and maximum values, plus the sum of absolute values (`abs_sum`) for numeric columns.

**Key Functions:**
- `get_table_stats()`: Computes MIN, MAX, and ABS_SUM for the selected columns of the target table
- `compute_source_stats()`: Computes the same statistics for the source S3 DataFrame
- Both functions lowercase column names and run a single aggregation query over all selected columns
- `abs_sum` is computed only for columns listed under the `int` or `decimal` keys of `column_types`
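
As an illustration of the output shape, the sketch below computes the same three statistics over plain Python rows and counts how many of them agree between two sides. It is a Spark-free approximation with hypothetical helper names and sample data; the notebook's functions run the equivalent aggregation on DataFrames.

```python
def column_stats(rows, columns, numeric_columns=()):
    """Compute min/max (and abs_sum for numeric columns) as strings, ignoring None values."""
    stats = {}
    for name in columns:
        values = [row[name] for row in rows if row.get(name) is not None]
        entry = {
            "min": str(min(values)) if values else "null",
            "max": str(max(values)) if values else "null",
        }
        if name in numeric_columns:
            entry["abs_sum"] = str(sum(abs(v) for v in values)) if values else "0"
        stats[name] = entry
    return stats


def count_matching_stats(src_stats, tgt_stats):
    """Return (matched, compared) over the statistics present on both sides."""
    matched = compared = 0
    for column, src_entry in src_stats.items():
        for stat, src_value in src_entry.items():
            if stat in tgt_stats.get(column, {}):
                compared += 1
                if src_value == tgt_stats[column][stat]:
                    matched += 1
    return matched, compared


source_rows = [{"amount": -5, "state": "NY"}, {"amount": 3, "state": "CA"}, {"amount": None, "state": "TX"}]
target_rows = [{"amount": -5, "state": "NY"}, {"amount": 4, "state": "CA"}, {"amount": None, "state": "TX"}]

src = column_stats(source_rows, ["amount", "state"], numeric_columns={"amount"})
tgt = column_stats(target_rows, ["amount", "state"], numeric_columns={"amount"})
print(src["amount"])                   # {'min': '-5', 'max': '3', 'abs_sum': '8'}
print(count_matching_stats(src, tgt))  # (3, 5)
```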

In [0]:
# Advanced column-level validation and statistics computation
from pyspark.sql import functions as F
from pyspark.sql.types import *
import time
from pyspark.sql import DataFrame
numeric_cats = {"int", "long", "double", "float", "decimal", "bigint"}
DEFAULT_NUM_COLUMNS = 20

def get_table_stats(df: DataFrame, selected_columns: list = None, column_types: dict = None):
    """
    OPTIMIZED: Get MIN, MAX, ABS_SUM statistics for selected columns only
    Computes all stats in a single aggregation query for performance
    """
    try:
        target_df = df
        target_df = target_df.toDF(*[c.lower() for c in target_df.columns])
        if selected_columns:
            selected_columns = [c.lower() for c in selected_columns]
        # If no columns specified, return empty
        if not selected_columns:
            print(f"Warning: No columns selected for stats computation")
            return {}
        
        print(f"Computing target stats for {len(selected_columns)} columns...")
        
        # Build aggregation expressions for all columns in ONE query
        agg_exprs = []
        
        for col_name in selected_columns:
            # Check if column exists in DataFrame
            if col_name not in target_df.columns:
                print(f"Warning: Column {col_name} not found in target table")
                continue
            
            # MIN and MAX for all columns
            agg_exprs.append(min(col(col_name)).alias(f"{col_name}_min"))
            agg_exprs.append(max(col(col_name)).alias(f"{col_name}_max"))
            
            # ABS_SUM only for numeric columns
            is_numeric = False
            if column_types:
                for cat in ['int', 'decimal']:
                    if col_name in column_types.get(cat, []):
                        is_numeric = True
                        break
            
            if is_numeric:
                agg_exprs.append(sum(abs(col(col_name))).alias(f"{col_name}_abs_sum"))
        
        if not agg_exprs:
            print("Warning: No valid aggregation expressions built")
            return {}
        
        # Execute single aggregation query
        stats_row = target_df.agg(*agg_exprs).collect()[0]
        
        # Parse results into structured format
        col_stats = {}
        for col_name in selected_columns:
            if col_name not in target_df.columns:
                continue
                
            col_stats[col_name] = {}
            
            # Get MIN and MAX
            min_key = f"{col_name}_min"
            max_key = f"{col_name}_max"
            
            if min_key in stats_row.asDict():
                col_stats[col_name]['min'] = str(stats_row[min_key]) if stats_row[min_key] is not None else 'null'
            
            if max_key in stats_row.asDict():
                col_stats[col_name]['max'] = str(stats_row[max_key]) if stats_row[max_key] is not None else 'null'
            
            # Get ABS_SUM if it exists
            abs_sum_key = f"{col_name}_abs_sum"
            if abs_sum_key in stats_row.asDict():
                col_stats[col_name]['abs_sum'] = str(stats_row[abs_sum_key]) if stats_row[abs_sum_key] is not None else '0'
        
        return col_stats
        
    except Exception as e:
        print(f"Error getting stats {e}")
        import traceback
        print(traceback.format_exc())
        return {}
    
def compute_source_stats(df: DataFrame, selected_columns: list = None, column_types: dict = None):
    """
    OPTIMIZED: Compute MIN, MAX, ABS_SUM statistics for selected columns only
    Computes all stats in a single aggregation query for performance
    """
    try:
        df = df.toDF(*[c.lower() for c in df.columns])
        if selected_columns:
            selected_columns = [c.lower() for c in selected_columns]
        # If no columns specified, return empty
        if not selected_columns:
            print(f"Warning: No columns selected for stats computation")
            return {}
        
        print(f"Computing source stats for {len(selected_columns)} columns...")
        
        # Build aggregation expressions for all columns in ONE query
        agg_exprs = []
        
        for col_name in selected_columns:
            # Check if column exists in DataFrame
            if col_name not in df.columns:
                print(f"Warning: Column {col_name} not found in source DataFrame")
                continue
            
            # MIN and MAX for all columns
            agg_exprs.append(min(col(col_name)).alias(f"{col_name}_min"))
            agg_exprs.append(max(col(col_name)).alias(f"{col_name}_max"))
            
            # ABS_SUM only for numeric columns
            is_numeric = False
            if column_types:
                for cat in ['int', 'decimal']:
                    if col_name in column_types.get(cat, []):
                        is_numeric = True
                        break
            
            if is_numeric:
                agg_exprs.append(sum(abs(col(col_name))).alias(f"{col_name}_abs_sum"))
        
        if not agg_exprs:
            print("Warning: No valid aggregation expressions built")
            return {}
        
        # Execute single aggregation query
        stats_row = df.agg(*agg_exprs).collect()[0]
        
        # Parse results into structured format
        col_stats = {}
        for col_name in selected_columns:
            if col_name not in df.columns:
                continue
                
            col_stats[col_name] = {}
            
            # Get MIN and MAX
            min_key = f"{col_name}_min"
            max_key = f"{col_name}_max"
            
            if min_key in stats_row.asDict():
                col_stats[col_name]['min'] = str(stats_row[min_key]) if stats_row[min_key] is not None else 'null'
            
            if max_key in stats_row.asDict():
                col_stats[col_name]['max'] = str(stats_row[max_key]) if stats_row[max_key] is not None else 'null'
            
            # Get ABS_SUM if it exists
            abs_sum_key = f"{col_name}_abs_sum"
            if abs_sum_key in stats_row.asDict():
                col_stats[col_name]['abs_sum'] = str(stats_row[abs_sum_key]) if stats_row[abs_sum_key] is not None else '0'
        
        return col_stats
        
    except Exception as e:
        print(f"Error computing source stats: {e}")
        import traceback
        print(traceback.format_exc())
        return {}

# Load and Validate Mapping Data

Retrieves the source-to-target table mapping from the metadata tables. The query joins the mapping table to the candidate tables that already have a managed table created, filtered to the configured bucket and dataset names.

**Functionality:**
- Selects candidates with `managed_table_created = true` for the configured `bucket_name` and `dataset_name`
- Joins them to the mapping table on dataset name
- Returns the execution ID, the target catalog, schema, and table, and the source bucket and prefix for each mapping
- Falls back to an empty DataFrame if the query fails
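
The mapping query places the `dataset_name` value directly inside a SQL `IN (...)` list, so the value has to arrive as a quoted, comma-separated list. The sketch below shows one way to build such a list from plain names while rejecting anything that is not a simple identifier. `to_sql_in_list` and the sample names are hypothetical and not part of this notebook.

```python
import re

# Only letters, digits and underscores are accepted, which keeps quotes out of the SQL text
_NAME_RE = re.compile(r"^[A-Za-z0-9_]+$")


def to_sql_in_list(names):
    """Turn ['a', 'b'] into "'a', 'b'" for use inside IN (...)."""
    cleaned = [n.strip() for n in names if n and n.strip()]
    if not cleaned:
        raise ValueError("At least one dataset name is required")
    invalid = [n for n in cleaned if not _NAME_RE.match(n)]
    if invalid:
        raise ValueError(f"Invalid dataset name(s): {invalid}")
    return ", ".join(f"'{n}'" for n in cleaned)


print(to_sql_in_list(["credit_losses", " auto_loans "]))
# 'credit_losses', 'auto_loans'
```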

In [0]:
#Load and validate source-to-target mapping data
# Handles both direct S3-to-table mappings and indirect mappings through staging

def load_mapping_data():
   
    mapping_sql = f"""
        WITH managed_tables AS (
            SELECT DISTINCT table_name, execution_id
            from {CANDIDATES_TABLE_FQN} where  
            managed_table_created =true 
            and s3_bucket_name='{BUCKET_NAME}'
            and table_name in ({dataset_name})
          
        )
        SELECT 
        m.execution_id,
        i.dbx_catalog AS target_catalog_name, 
        i.dbx_managed_table_schema AS target_schema_name, 
        i.dataset_name AS target_table_name, 
        i.s3_bucket_name AS source_s3_bucket,
        i.bucket_prefix AS source_bucket_prefix
        -- concat(b.volume_name,"/",i.bucket_prefix) as source_bucket_prefix
        FROM {MAPPING_TABLE_FQN} i
        INNER JOIN managed_tables m ON i.dataset_name = m.table_name
    """
    
    # concat(i.dbx_catalog,'.', i.dbx_managed_table_schema,'.',i.dataset_name) AS target_table_name, 
    # print(mapping_sql)
    try:
        return spark.sql(mapping_sql)
    except Exception as e:
        # Fall back to an empty DataFrame with the same columns as the query above
        print(f"Error loading mapping data: {e}")
        return spark.createDataFrame([], StructType([
            StructField("execution_id", StringType()),
            StructField("target_catalog_name", StringType()),
            StructField("target_schema_name", StringType()),
            StructField("target_table_name", StringType()),
            StructField("source_s3_bucket", StringType()),
            StructField("source_bucket_prefix", StringType())
        ]))

mapping_df = load_mapping_data()
display(mapping_df)
#print(f"Loaded {mapping_df.count()} table mappings")


In [0]:
# =========================================================
# CHECKSUM HELPER FUNCTIONS
# =========================================================
from pyspark.sql import functions as F

def prepare_for_checksum(
    df,
    partition_cols=None
):
    """
    Prepares a DataFrame for checksum calculation:
    - Excludes partition columns
    - Orders the remaining columns alphabetically so both sides hash in the same order
    Values are passed through unchanged: the trimming, string casting, and
    NULL replacement steps are currently commented out below.
    """
    # partition_cols = partition_cols or []
    # cols_to_hash = [c for c in df.columns if c not in partition_cols]
    # cols_to_hash = sorted(cols_to_hash)
    # for c in cols_to_hash:
    #     df = df.withColumn(
    #         c,
    #         F.when(F.col(c).isNull(), F.lit("NULL"))
    #          .otherwise(F.trim(F.col(c).cast("string")))
    #     )
    # return df.select(cols_to_hash)
    partition_cols = partition_cols or []
    cols_to_hash = [c for c in df.columns if c not in partition_cols]
    cols_to_hash = sorted(cols_to_hash)
    # exprs = [
    #     F.when(F.col(c).isNull(), F.lit("NULL"))
    #      .otherwise(F.trim(F.col(c).cast("string")))
    #      .alias(c)
    #     for c in cols_to_hash
    # ]
    exprs=[
      F.col(c)
      for c in cols_to_hash
    ]
    return df.select(exprs)
    
def calculate_row_checksum(df, exclude_cols=None):
    """
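    Returns a DataFrame with a single column '__checksum' (SHA-256 hex string) for each row.
    exclude_cols: list of partition columns to exclude from hashing.
    Note: concat_ws skips NULL values, so two rows that differ only in which
    column is NULL can produce the same checksum.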
    """
    try:
        exclude_cols = exclude_cols or []
        cols_to_hash = [c for c in df.columns if c not in exclude_cols]
        if not cols_to_hash:
            raise ValueError("No columns available for checksum after excluding partition columns")

        # concat_expr = F.concat_ws("||", *[F.coalesce(F.col(c).cast("string"), F.lit("NULL")) for c in cols_to_hash])
        concat_expr = F.concat_ws("||", *[F.col(c) for c in cols_to_hash])
        return df.withColumn("__checksum", F.sha2(concat_expr, 256)).select("__checksum")
    except Exception as e:
        print(f"Error calculating checksums: {e}")
        return None


# def compare_checksums(src_checksum_df, tgt_checksum_df, partition_dict=None):
#     try:
#         src_col = [c for c in src_checksum_df.columns if "checksum" in c.lower()][0]
#         tgt_col = [c for c in tgt_checksum_df.columns if "checksum" in c.lower()][0]

#         src_count = src_checksum_df.count()
#         tgt_count = tgt_checksum_df.count()

#         matched_df = src_checksum_df.join(
#             tgt_checksum_df, src_checksum_df[src_col] == tgt_checksum_df[tgt_col], "inner"
#         )
#         matched_count = matched_df.count()

#         ##status = "PASS" if src_count == tgt_count else "FAIL"
#         status = "PASS" if (src_count == tgt_count == matched_count) else "FAIL"

#         return {
#             "status": status,
#             "partition": partition_dict,
#             "source_record_count": src_count,
#             "target_record_count": tgt_count,
#             "matched_record_count": matched_count
#         }
#     except Exception as e:
#         return {"status": "ERROR", "error_message": str(e), "partition": partition_dict}
def compare_checksums(src_checksum_df, tgt_checksum_df, partition_dict=None):
    try:
        src_col = [c for c in src_checksum_df.columns if "checksum" in c.lower()][0]
        tgt_col = [c for c in tgt_checksum_df.columns if "checksum" in c.lower()][0]

        total_src_rows = src_checksum_df.count()
        total_tgt_rows = tgt_checksum_df.count()

        src_distinct = src_checksum_df.select(src_col).distinct().withColumnRenamed(src_col, "checksum")
        tgt_distinct = tgt_checksum_df.select(tgt_col).distinct().withColumnRenamed(tgt_col, "checksum")
        
        src_distinct_count = src_distinct.count()
        tgt_distinct_count = tgt_distinct.count()
        
        matched = src_distinct.join(tgt_distinct, "checksum", "inner")
        matched_count = matched.count()
        
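        # PASS only if both sides contain the same set of distinct checksums.
        # Duplicate rows are not detected here: matched_count counts distinct checksums, not rows.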
        status = "PASS" if (src_distinct_count == tgt_distinct_count == matched_count) else "FAIL"

        return {
            "status": status,
            "partition": partition_dict,
            "source_record_count": total_src_rows,
            "target_record_count": total_tgt_rows,
            "matched_record_count": matched_count
        }
    except Exception as e:
        return {"status": "ERROR", "error_message": str(e), "partition": partition_dict}


# Reconciliation Helper Functions

Helper functions shared by all reconciliation types. They cover three areas: building standardized result records, validating partition columns and per-partition record counts, and comparing column statistics between source and target.

**Result Creation:**
- `create_recon_result()`: Standardized result factory for all reconciliation types
- Generates type-specific summaries and metrics comparison statistics
- Handles error logging and JSON formatting
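
As a rough illustration of the record shape, the sketch below builds a simplified result for a ROW_COUNT check in plain Python. `make_row_count_result` is a hypothetical stand-in, not the notebook's `create_recon_result()`, and it covers only a subset of the fields; the table name and execution id are made up.

```python
import json
from datetime import datetime


def make_row_count_result(table_name, execution_id, source_count, target_count):
    """Hypothetical, simplified stand-in for a ROW_COUNT result record."""
    difference = abs(int(source_count) - int(target_count))
    status = "PASS" if difference == 0 else "FAIL"
    if difference == 0:
        summary = f"Row count matches: Source={source_count}, Target={target_count}"
    else:
        summary = (
            f"Row count difference: {difference} "
            f"(Source={source_count}, Target={target_count})"
        )
    return {
        "execution_id": execution_id,
        "table_name": table_name,
        "recon_type": "ROW_COUNT",
        "exec_timestamp": datetime.now(),
        "status": status,
        # Metrics are stored as JSON strings so every recon type shares one schema
        "source_metrics": json.dumps(source_count),
        "target_metrics": json.dumps(target_count),
        "summary": summary,
        "error_log": None,
    }


# Illustrative values only
result = make_row_count_result("orders", "run-1", 100, 98)
print(result["status"], "-", result["summary"])
```

Keeping every reconciliation type in the same dictionary shape is what allows all results to land in a single results table.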

**Partition Column Validation:**
- `get_s3_partition_columns()`: Extracts partition keys from S3 directory structure
- `get_databricks_partition_columns()`: Retrieves partition metadata from Databricks tables
- `compare_partition_columns()`: Identifies partition schema differences
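
The idea behind partition column validation can be sketched without Spark. The example assumes a Hive-style `key=value` directory layout; `partition_keys_from_path` and `diff_partition_columns` are hypothetical names used only here, and the object key is invented.

```python
def partition_keys_from_path(path):
    """Return partition column names from key=value path segments, in path order."""
    keys = []
    for segment in path.strip("/").split("/"):
        if "=" in segment:
            key = segment.split("=", 1)[0]
            if key and key not in keys:
                keys.append(key)
    return keys


def diff_partition_columns(src_partitions, tgt_partitions):
    """Return (missing_in_source, missing_in_target) as sorted lists."""
    src, tgt = set(src_partitions), set(tgt_partitions)
    return sorted(tgt - src), sorted(src - tgt)


# Hypothetical object key, used only for illustration
src_cols = partition_keys_from_path(
    "warehouse/orders/snapshot_date=2024-01-31/edp_run_id=7/part-0000.parquet"
)
missing_in_src, missing_in_tgt = diff_partition_columns(
    src_cols, ["snapshot_date", "region"]
)
print(src_cols, missing_in_src, missing_in_tgt)
```

An empty pair of lists means both sides are partitioned by the same columns; any entry in either list is a partition schema difference.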

**Partition Record Count Validation:**
- `get_partition_counts_s3()`: Computes per-partition record counts from S3 data
- `get_partition_counts_dbx()`: Extracts partition counts from managed tables
- `compare_partition_record_counts()`: Validates partition-level data consistency
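
Once per-partition counts are collected into dictionaries keyed by partition-value tuples, the comparison itself is plain Python. The sketch below assumes that shape; `diff_partition_counts` is a hypothetical helper and the counts are invented.

```python
def diff_partition_counts(source_counts, target_counts):
    """Return partitions whose record counts differ; a missing partition counts as 0."""
    mismatches = []
    for key in sorted(set(source_counts) | set(target_counts)):
        src = source_counts.get(key, 0)
        tgt = target_counts.get(key, 0)
        if src != tgt:
            mismatches.append(
                {"partition": key, "source_count": src, "target_count": tgt}
            )
    return mismatches


# Hypothetical counts keyed by (snapshot_date, edp_run_id)
s3_counts = {("2024-01-30", "6"): 1200, ("2024-01-31", "7"): 1500}
dbx_counts = {
    ("2024-01-30", "6"): 1200,
    ("2024-01-31", "7"): 1498,
    ("2024-02-01", "8"): 10,
}

for mismatch in diff_partition_counts(s3_counts, dbx_counts):
    print(mismatch)
```

Taking the union of keys means a partition present on only one side is reported as a mismatch against a count of 0 instead of being silently skipped.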

**Supported Reconciliation Types:**
- ROW_COUNT, COLUMN_COUNT, SCHEMA_VALIDATION
- VERIFIED (statistical comparison)
- PARTITION_COLUMNS, PARTITION_RECORD_COUNT
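
For the VERIFIED type, the `abs_sum` statistic is compared with a percentage tolerance instead of an exact match. The following is a minimal sketch of that idea, using the larger absolute value as the denominator; `within_pct_tolerance` is a hypothetical simplification that omits the null handling and fallbacks of the notebook's own comparison function.

```python
import math


def within_pct_tolerance(source_value, target_value, tolerance_pct):
    """Return (is_match, diff_pct) using max(|a|, |b|) as the denominator."""
    a, b = float(source_value), float(target_value)
    if a == b:
        return True, 0.0
    # a != b guarantees at least one value is non-zero, so the denominator is > 0
    denominator = max(abs(a), abs(b))
    diff_pct = abs(a - b) / denominator * 100.0
    if math.isnan(diff_pct) or math.isinf(diff_pct):
        return False, diff_pct
    return diff_pct <= tolerance_pct, diff_pct


# A tolerance of 0.001 means 0.001 percent, not a fraction of 1
print(within_pct_tolerance("1000000.00", "1000000.01", 0.001))
print(within_pct_tolerance("1000000.00", "1000100.00", 0.001))
```

Because the difference is scaled to a percentage before the comparison, the tolerance constant must also be expressed in percent.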

In [0]:
# Reconciliation Helper Functions - Complete collection supporting all reconciliation types

# =====================================
# RESULT CREATION FUNCTIONS
# =====================================


# Configuration constant - change here to adjust tolerance globally
# Expressed in percent: 0.001 means a 0.001% tolerance for abs_sum comparisons
ABS_SUM_TOLERANCE_PCT = 0.001

def compare_numeric_values(src_val, tgt_val, tolerance_pct=ABS_SUM_TOLERANCE_PCT):
    """
    High-performance numeric comparison with configurable tolerance
    Uses Python built-in abs() to avoid PySpark column conflict
    
    Args:
        src_val: Source value (string/numeric)
        tgt_val: Target value (string/numeric)
        tolerance_pct: Tolerance in percent (default ABS_SUM_TOLERANCE_PCT, i.e. 0.001%)
    
    Returns:
        tuple: (is_match: bool, abs_diff: float, diff_pct: float)
    """
    import builtins  # Import Python builtins to access built-in abs()
    
    # Fast path: exact string match
    if src_val == tgt_val:
        return True, 0.0, 0.0
    
    try:
        # Convert to string first for consistent handling
        src_str = str(src_val).strip()
        tgt_str = str(tgt_val).strip()
        
        # Handle null/empty values
        if src_str.lower() in ('null', 'none', ''):
            src_str = '0'
        if tgt_str.lower() in ('null', 'none', ''):
            tgt_str = '0'
        
        # Convert to float
        src_float = float(src_str)
        tgt_float = float(tgt_str)
        
        # Fast path: numeric equality
        if src_float == tgt_float:
            return True, 0.0, 0.0
        
        # Calculate absolute difference using Python's built-in abs()
        abs_diff = builtins.abs(src_float - tgt_float)
        
        # Use max absolute value for stable percentage calculation
        denominator = builtins.max(builtins.abs(src_float), builtins.abs(tgt_float))
        
        # Handle edge case: both very close to zero
        if denominator < 1e-10:
            return True, abs_diff, 0.0
        
        # Calculate percentage difference
        diff_pct = (abs_diff / denominator) * 100.0
        
        # Check if diff_pct is valid (not inf or nan)
        if diff_pct != diff_pct or diff_pct == float('inf'):  # Check for NaN or inf
            if abs_diff < 0.01:
                return True, abs_diff, 0.0
            else:
                return False, abs_diff, 100.0
        
        # Single comparison
        return diff_pct <= tolerance_pct, abs_diff, diff_pct
        
    except Exception as e:
        # Debug: print the error
        print(f"Error comparing values: src={src_val}, tgt={tgt_val}, error={str(e)}")
        # Fallback to string comparison
        return str(src_val) == str(tgt_val), 0.0, 0.0

def create_recon_result(table_name, recon_type, status,execution_id,source_data=None, target_data=None, 
                       error_msg=None, exec_duration=0.0, missing_src=None, missing_tgt=None,
                       mismatched=None, precision_mismatched=None,partitions_used=None,checksum_result=None):
    """Create a standardized result record for reconciliation outcomes"""
    
    import builtins  # Import builtins to access Python's built-in functions
    
    # execution_id is supplied by the caller (previously built as RUN_ID + "-" + JOB_ID)
    exec_time = datetime.now()
    
    source_metrics = json.dumps(source_data) if source_data else None
    target_metrics = json.dumps(target_data) if target_data else None
    
    metrics_compared = 0
    metrics_matched = 0
    metrics_different = 0
    summary = ""
    error_log = error_msg if error_msg else None
    
    # Calculate metrics based on reconciliation type
    if recon_type == "ROW_COUNT":
        if source_data is not None and target_data is not None:
            # Use Python's built-in abs function explicitly
            difference = builtins.abs(int(source_data) - int(target_data))
            if difference == 0:
                summary = f"Row count matches: Source={source_data}, Target={target_data}"
            else:
                summary = f"Row count difference: {difference} (Source={source_data}, Target={target_data})"
                status = "FAIL"
    
    elif recon_type == "COLUMN_COUNT":
        if source_data is not None and target_data is not None:
            # Use Python's built-in abs function explicitly
            difference = builtins.abs(int(source_data) - int(target_data))
            if difference == 0:
                summary = f"Column count matches: Source={source_data}, Target={target_data}"
            else:
                summary = f"Column count difference: {difference} (Source={source_data}, Target={target_data})"
                status = "FAIL"
    
    elif recon_type == "SCHEMA_VALIDATION":
        issues = []
        
        if missing_src:
            issues.append(f"Missing in source: {len(missing_src)} columns {missing_src[:3]}{'...' if len(missing_src) > 3 else ''}")
        
        if missing_tgt:
            issues.append(f"Missing in target: {len(missing_tgt)} columns {missing_tgt[:3]}{'...' if len(missing_tgt) > 3 else ''}")
        
        if mismatched:
            issues.append(f"Type mismatches: {len(mismatched)} columns {mismatched[:2]}{'...' if len(mismatched) > 2 else ''}")
        
        if precision_mismatched:
            issues.append(f"Precision mismatches: {len(precision_mismatched)} columns {precision_mismatched[:2]}{'...' if len(precision_mismatched) > 2 else ''}")
        
        if issues:
            summary = f"Schema differences found: {' | '.join(issues)}"
            status = "FAIL"
        else:
            summary = "Schema matches completely: All columns, types, and precision/scale are identical"
    
    elif recon_type == "Partiton_Data_count":
        if status == "PASS":
            error_msg = ''
            summary = (
                f"S3 distinct partition count = {source_data}, "
                f"Managed table distinct partition count = {target_data}. "
                "All partitions match."
            )
        else:
            status = "FAIL"
            summary = error_msg
            error_msg = ''
            error_log = None
        
    elif recon_type == "CHECKSUM_VALIDATION":
        if checksum_result is not None:
            try:
                summary = json.dumps({
                    "partition": checksum_result.get("partition", {}),
                    "source_partition_record_count": checksum_result.get("source_record_count", 0),
                    "target_partition_record_count": checksum_result.get("target_record_count", 0)
                })
                #source_metrics = None
                #target_metrics = None
                if checksum_result.get("status") != "PASS":
                    status = "FAIL"
            except Exception:
                summary = "Error processing checksum validation"
                status = "ERROR"
        else:
            summary = "Checksum validation: No checksum_result provided"
            status = "ERROR"    

    # VERIFIED section with partition info in summary and tolerance for abs_sum
    elif recon_type == "VERIFIED" and source_data and target_data and isinstance(source_data, dict) and isinstance(target_data, dict):
        # Extract metadata
        if partitions_used is None:
            partitions_used = []
        columns_verified = source_data.get('columns_verified', [])
        src_stats = source_data.get('statistics', {})
        tgt_stats = target_data.get('statistics', {})
        
        # Build partition summary efficiently
        partition_summary = ""
        if partitions_used:
            partition_summary = f"Partitions [{len(partitions_used)}]: " + '; '.join(
                f"({', '.join(f'{k}={v}' for k, v in part.items())})" 
                for part in partitions_used
            )
        
        # Get common columns using set intersection (faster than loops)
        common_cols = (set(src_stats.keys()) - {'_metadata'}) & set(tgt_stats.keys())
        
        # Initialize result lists
        matched_stats = []
        tolerance_matched_stats = []
        unmatched_stats = []
        
        # Single-pass comparison loop
        for col_name in common_cols:
            src_col_stats = src_stats.get(col_name)
            tgt_col_stats = tgt_stats.get(col_name)
            
            # Skip if not both dictionaries
            if not (isinstance(src_col_stats, dict) and isinstance(tgt_col_stats, dict)):
                continue
            
            # Get common stat types
            common_stat_types = set(src_col_stats.keys()) & set(tgt_col_stats.keys())
            
            for stat_name in common_stat_types:
                metrics_compared += 1
                src_stat = str(src_col_stats[stat_name])
                tgt_stat = str(tgt_col_stats[stat_name])
                
                # Apply tolerance for abs_sum, exact match for others
                if stat_name == 'abs_sum':
                    # Use tolerance comparison (ABS_SUM_TOLERANCE_PCT, expressed in percent)
                    is_match, abs_diff, diff_pct = compare_numeric_values(src_stat, tgt_stat)
                    
                    if is_match:
                        metrics_matched += 1
                        if abs_diff == 0.0:
                            # Exact match
                            matched_stats.append(f"{col_name}.{stat_name}={src_stat}")
                        else:
                            # Matched within tolerance
                            tolerance_matched_stats.append(
                                f"{col_name}.{stat_name}≈{src_stat}(Δ{diff_pct:.6f}%)"
                            )
                    else:
                        # Outside tolerance
                        metrics_different += 1
                        unmatched_stats.append(
                            f"{col_name}.{stat_name}(src={src_stat},tgt={tgt_stat},Δ{diff_pct:.6f}%)"
                        )
                else:
                    # Exact string comparison for min/max (fastest)
                    if src_stat == tgt_stat:
                        metrics_matched += 1
                        matched_stats.append(f"{col_name}.{stat_name}={src_stat}")
                    else:
                        metrics_different += 1
                        unmatched_stats.append(
                            f"{col_name}.{stat_name}(src={src_stat},tgt={tgt_stat})"
                        )
        
        # Build summary efficiently using list join
        if metrics_compared > 0:
            if metrics_different > 0:
                status = "FAIL"
            
            match_rate = (metrics_matched / metrics_compared) * 100.0
            
            # Build summary parts
            summary_parts = [
                f"Stats: {metrics_matched}/{metrics_compared} match ({match_rate:.1f}%)",
                f"Columns [{len(columns_verified)}]: {', '.join(columns_verified)}",
                partition_summary
            ]
            
            if matched_stats:
                summary_parts.append(f"Matched [{len(matched_stats)}]: {', '.join(matched_stats)}")
            
            if tolerance_matched_stats:
                summary_parts.append(
                    f"Tolerance-Matched [{len(tolerance_matched_stats)}]: {', '.join(tolerance_matched_stats)}"
                )
            
            if unmatched_stats:
                summary_parts.append(f"Unmatched [{len(unmatched_stats)}]: {', '.join(unmatched_stats)}")
            
            # Join all parts efficiently
            summary = " | ".join(summary_parts)
        else:
            summary = f"No common statistics to compare | {partition_summary}"
            status = "WARNING" if status == "PASS" else status
    
    
    # Handle error cases - CHANGED: FAIL to ERROR
    if error_msg and not summary:
        summary = f"{recon_type}: ERROR - {error_msg}"
        status = "ERROR"  # Changed from "FAIL" to "ERROR"
    
    # Default summary if not set
    if not summary:
        summary = f"{recon_type}: {status}"
    
    return {
        "execution_id": execution_id,
        "table_name": table_name,
        "recon_type": recon_type,
        "exec_timestamp": exec_time,
        "exec_duration": exec_duration,
        "status": status,
        "source_metrics": source_metrics,
        "target_metrics": target_metrics,
        "metrics_compared": metrics_compared,
        "metrics_matched": metrics_matched,
        "metrics_different": metrics_different,
        "summary": summary,
        "error_log": error_log
    }

# =====================================
# PARTITION COLUMN VALIDATION FUNCTIONS
# =====================================

def get_s3_partition_columns(bucket, prefix):
    """Extract partition columns from S3 path structure"""
    try:
        # First try to get from the candidates table
        candidates_query = f"""
        SELECT partition_key 
        FROM {CANDIDATES_TABLE_FQN}
        WHERE s3_bucket_name = '{bucket}' 
        AND bucket_prefix = '{prefix}'
        """ 
        # --code_change
        # candidates_query = f"""
        # SELECT partition_key 
        # FROM {CANDIDATES_TABLE_FQN}
        # WHERE s3_bucket_name = '{bucket}' 
        # AND table_name = '{table_name}'
        # """
        partition_cols = []
        candidates_df = spark.sql(candidates_query)
        if candidates_df.count() > 0:
            partition_key = candidates_df.first()["partition_key"]
            if partition_key and partition_key.strip():
                # Split by comma and clean up whitespace
                partition_cols = [col.strip() for col in partition_key.split(',')]
                return sorted(partition_cols)
        
        # Fallback (currently disabled): infer partition structure from the S3 directory layout
        # path = f"s3://{bucket}/{prefix}" ##code_change
        # path = prefix
        
        # List files to understand partition structure
        # files = dbutils.fs.ls(path)
        # partition_cols = []
        
        # ##Look for partition patterns like key=value in directory structure
        # for file_info in files:
        #     if file_info.isDir():
        #         dir_name = file_info.name.rstrip('/')
        #         if '=' in dir_name:
        #             partition_key = dir_name.split('=')[0]
        #             if partition_key not in partition_cols:
        #                 partition_cols.append(partition_key)
        #             # Iterate one more directory level
        #             sub_files = dbutils.fs.ls(file_info.path)
        #             for sub_file_info in sub_files:
        #                 if sub_file_info.isDir():
        #                     sub_dir_name = sub_file_info.name.rstrip('/')
        #                     if '=' in sub_dir_name:
        #                         sub_partition_key = sub_dir_name.split('=')[0]
        #                         if sub_partition_key not in partition_cols:
        #                             partition_cols.append(sub_partition_key)
        # return sorted(partition_cols)
        return []  # No partition key found in the candidates table; avoid returning None
    except Exception as e:
        print(f"Error getting S3 partition columns for {bucket}/{prefix}: {e}")
        return []

def get_databricks_partition_columns(table_name):
    """Get partition columns from Databricks table"""
    try:
        
        describe_df = spark.sql(f"DESCRIBE DETAIL {table_name}")
        partition_cols_row = describe_df.select("partitionColumns").first()
        
        if partition_cols_row and partition_cols_row["partitionColumns"]:
            # Parse the partition columns (they come as array)
            partition_cols = partition_cols_row["partitionColumns"]
            return sorted(partition_cols) if partition_cols else []
        else:
            return []
    except Exception as e:
        print(f"Error getting Databricks partition columns for {table_name}: {e}")
        return []

def compare_partition_columns(src_partitions, tgt_partitions):
    """Compare partition columns between source and target"""
    missing_in_src = [col for col in tgt_partitions if col not in src_partitions]
    missing_in_tgt = [col for col in src_partitions if col not in tgt_partitions]
    
    return missing_in_src, missing_in_tgt

# =====================================
# PARTITION RECORD COUNT VALIDATION FUNCTIONS
# =====================================

from pyspark.sql import functions as F

def get_partition_counts_s3(s3_path: str, partition_cols: list) -> dict:
    """
    Returns a dictionary of partition values -> record counts from S3 data.
    
    Parameters:
        s3_path (str): Base S3 path to the dataset.
        partition_cols (list): List of partition column names (can be one or more).
    
    Returns:
        dict: {(partition_value1, partition_value2, ...): count, ...}
    """
    try:
        # Read the data from S3
        df = spark.read.parquet(s3_path)

        # Compute counts per partition
        counts_df = df.groupBy(*partition_cols).agg(F.count("*").alias("cnt"))
        counts_df = counts_df.sort(*partition_cols)

        # Collect results and create dictionary with tuple keys
        partition_counts = {}
        for row in counts_df.collect():
            key = tuple(row[c] for c in partition_cols)
            partition_counts[key] = row["cnt"]

        return partition_counts
    except Exception as e:
        print(f"Error getting S3 partition counts: {e}")
        return {}


def get_partition_counts_dbx(target_table: str, partition_cols: list) -> dict:
    """
    Returns a dictionary of partition values -> record counts from a Databricks managed table.
    
    Parameters:
        target_table (str): Name of the Databricks table.
        partition_cols (list): List of partition column names (can be one or more).
    
    Returns:
        dict: {(partition_value1, partition_value2, ...): count, ...}
    """
    try:
        # Read table
        df = spark.read.table(target_table)

        # Compute counts per partition
        counts_df = df.groupBy(*partition_cols).agg(F.count("*").alias("cnt"))
        counts_df = counts_df.sort(*partition_cols)

        # Collect results into a dictionary
        partition_counts = {}
        for row in counts_df.collect():
            key = tuple(row[c] for c in partition_cols)
            partition_counts[key] = row["cnt"]

        return partition_counts
    except Exception as e:
        print(f"Error getting Databricks partition counts: {e}")
        return {}


def compare_partition_record_counts(s3_counts: dict, dbx_counts: dict, partition_cols: list):
    """
    Compare partition record counts between S3 and Databricks table.
    
    Parameters:
        s3_counts (dict): S3 partition counts
        dbx_counts (dict): Databricks partition counts  
        partition_cols (list): List of partition column names
    
    Returns:
        tuple: (status, summary, mismatched_partitions)
    """
    mismatched_partitions = []
    
    # Get all partition keys from both sources
    all_partitions = set(s3_counts.keys()) | set(dbx_counts.keys())
    
    # Check each partition
    for partition_key in sorted(all_partitions):
        s3_count = s3_counts.get(partition_key, 0)
        dbx_count = dbx_counts.get(partition_key, 0)
        
        if s3_count != dbx_count:
            # Format partition key for display
            if len(partition_cols) == 1:
                part_display = f"{partition_cols[0]}={partition_key[0]}"
            else:
                part_display = ", ".join([f"{col}={val}" for col, val in zip(partition_cols, partition_key)])
            
            mismatched_partitions.append({
                'partition': part_display,
                'source_count': s3_count,
                'target_count': dbx_count,
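                # Plain arithmetic rather than abs(), which a wildcard import of pyspark.sql.functions would shadow
                'difference': (s3_count - dbx_count) if s3_count > dbx_count else (dbx_count - s3_count)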
            })
    
    # Generate status and summary
    if not mismatched_partitions:
        status = "PASS"
        summary = f"All {len(all_partitions)} partitions have matching record counts"
    else:
        status = "FAIL"
        total_partitions = len(all_partitions)
        mismatched_count = len(mismatched_partitions)
        matched_count = total_partitions - mismatched_count
        
        # Show first few mismatched partitions in summary
        mismatch_details = []
        for mismatch in mismatched_partitions[:3]:  # Show first 3
            mismatch_details.append(
                f"{mismatch['partition']}: Source={mismatch['source_count']}, Target={mismatch['target_count']}"
            )
        
        remaining = len(mismatched_partitions) - 3
        if remaining > 0:
            mismatch_details.append(f"and {remaining} more partitions")
        
        summary = f"Partition record count differences: {matched_count}/{total_partitions} partitions match - Mismatched: [{'; '.join(mismatch_details)}]"
    
    return status, summary, mismatched_partitions

In [0]:

def select_columns_for_verification(df: DataFrame, partition_cols: list = None, max_columns: int = 8) -> tuple:
    """
    Select up to 8 columns for verification: 2 int, 2 decimal, 2 string, 2 date
    Excludes partition columns dynamically
    
    Args:
        df: DataFrame to analyze
        partition_cols: List of partition column names to exclude (if None, no columns are excluded)
        max_columns: Maximum columns to select (default 8)
    
    Returns:
        tuple: (list of selected column names, dict with column types)
    """
    from pyspark.sql.types import IntegerType, LongType, DecimalType, DoubleType, FloatType, StringType, DateType, TimestampType
    
    schema = df.schema
    
    # If partition columns not provided, use empty list
    if partition_cols is None:
        partition_cols = []
    
    # Convert partition columns to lowercase for case-insensitive comparison
    partition_cols_lower = [pc.lower() for pc in partition_cols]
    
    # Categorize columns by type (exclude partition columns)
    int_cols = []
    decimal_cols = []
    string_cols = []
    date_cols = []
    
    for field in schema.fields:
        col_name = field.name
        
        # Skip partition columns (case-insensitive)
        if col_name.lower() in partition_cols_lower:
            continue
            
        col_type = field.dataType
        
        if isinstance(col_type, (IntegerType, LongType)):
            int_cols.append(col_name)
        elif isinstance(col_type, (DecimalType, DoubleType, FloatType)):
            decimal_cols.append(col_name)
        elif isinstance(col_type, StringType):
            string_cols.append(col_name)
        elif isinstance(col_type, (DateType, TimestampType)):
            date_cols.append(col_name)
    
    # Select 2 from each category
    selected = {
        'int': int_cols[:2],
        'decimal': decimal_cols[:2],
        'string': string_cols[:2],
        'date': date_cols[:2]
    }
    
    # Flatten to list
    all_selected = []
    for cols in selected.values():
        all_selected.extend(cols)
    
    # print(f"Column selection: Found {len(int_cols)} int, {len(decimal_cols)} decimal, {len(string_cols)} string, {len(date_cols)} date columns")
    # print(f"Excluded {len(partition_cols)} partition columns: {partition_cols}")
    
    return all_selected[:max_columns], selected

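def get_latest_partitions(df: DataFrame, partition_cols: list = None, n: int = 5) -> tuple:
    """
    Filter DataFrame to include only the latest N partitions
    Uses partition pruning for optimal performance
    
    Args:
        df: DataFrame to filter
        partition_cols: Partition column names (default ["snapshot_date", "edp_run_id"])
        n: Number of most recent partitions to keep (default 5)
    
    Returns:
        tuple: (filtered_df, list of partition dictionaries used)
    """
    # Avoid a mutable default argument; fall back to the standard partition columns
    if partition_cols is None:
        partition_cols = ["snapshot_date", "edp_run_id"]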
    try:
        # Get distinct partition values with minimal data movement
        latest_partitions = df.select(*partition_cols) \
            .distinct() \
            .orderBy(*[col(c).desc() for c in partition_cols]) \
            .limit(n) \
            .collect()  # Small result set, safe to collect
        
        if not latest_partitions:
            print("Warning: No partitions found")
            return df, []
        
        # NEW: Store partition values for summary
        partitions_used = []
        for partition_row in latest_partitions:
            partition_dict = {}
            for pc in partition_cols:
                value = partition_row[pc]
                if hasattr(value,'isoformat'):
                    partition_dict[pc] = value.isoformat()
                else:
                    partition_dict[pc] = str(value)
            partitions_used.append(partition_dict)
        
        # Build filter condition using OR logic without functools
        filter_condition = None
        
        for partition_row in latest_partitions:
            # Build AND condition for this partition
            partition_condition = None
            
            for pc in partition_cols:
                current_condition = (col(pc) == partition_row[pc])
                
                if partition_condition is None:
                    partition_condition = current_condition
                else:
                    partition_condition = partition_condition & current_condition
            
            # Add to overall OR condition
            if filter_condition is None:
                filter_condition = partition_condition
            else:
                filter_condition = filter_condition | partition_condition
        
        # Apply filter - this will leverage partition pruning in Parquet/Delta
        filtered_df = df.filter(filter_condition)
        
        # Only needed for the debug print below; left disabled to avoid an extra Spark job
        # partition_count = filtered_df.select(*partition_cols).distinct().count()
        # print(f"Filtered to {partition_count} partitions from latest {n}")
        
        return filtered_df, partitions_used  # **RETURN BOTH**
        
    except Exception as e:
        print(f"Error filtering latest partitions: {e}")
        return df, []
    
def prepare_dataframe_for_stats(df: DataFrame, partition_cols: list, selected_columns: list) -> DataFrame:
    """
    Prepare DataFrame with only needed columns and cache
    """
    try:
        # Select only partition columns + selected columns for verification
        columns_to_select = list(set(partition_cols + selected_columns))
        
        # Project early to reduce data size
        df_projected = df.select(*columns_to_select)
        
        # Cache the filtered, projected DataFrame
        df_projected.cache()
        
        # Force materialization with a count
        record_count = df_projected.count()
        # print(f"Cached {record_count:,} records with {len(columns_to_select)} columns for statistics computation")
        
        return df_projected
        
    except Exception as e:
        print(f"Error preparing DataFrame: {e}")
        return df
    


# Execute Reconciliation for Table

Main orchestration function that executes all reconciliation types for a single table mapping. Performs comprehensive validation including row counts, column counts, schema validation, statistical verification, and partition analysis.

**Reconciliation Types Executed:**
1. **ROW_COUNT**: Compares total record counts
2. **COLUMN_COUNT**: Validates column count consistency  
3. **SCHEMA_VALIDATION**: Checks data types, precision, and scale
4. **VERIFIED**: Comprehensive statistical comparison
5. **PARTITION_COLUMNS**: Validates partition schema consistency
6. **PARTITION_RECORD_COUNT**: Ensures partition-level data completeness

**Error Handling:**
- Graceful degradation on individual check failures
- Comprehensive error logging and status tracking
- Continues execution even if individual validations fail
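
The error-handling behaviour listed above can be sketched as a loop that isolates each check in its own `try`/`except`. This is a minimal illustration in plain Python, not the notebook's orchestration function; `run_checks` and the hard-coded checks are hypothetical.

```python
def run_checks(table_name, checks):
    """Run each (name, callable) check; record ERROR and keep going on failure."""
    results = []
    for name, check in checks:
        try:
            passed, summary = check()
            status = "PASS" if passed else "FAIL"
        except Exception as exc:  # one failing check must not stop the rest
            status, summary = "ERROR", f"{name}: ERROR - {exc}"
        results.append(
            {
                "table_name": table_name,
                "recon_type": name,
                "status": status,
                "summary": summary,
            }
        )
    return results


def broken_schema_check():
    """Simulates a check that raises instead of returning a result."""
    raise ValueError("schema unavailable")


# Hypothetical checks with hard-coded outcomes, for illustration only
checks = [
    ("ROW_COUNT", lambda: (True, "Row count matches")),
    ("SCHEMA_VALIDATION", broken_schema_check),
    ("COLUMN_COUNT", lambda: (False, "Column count difference: 1")),
]

for r in run_checks("orders", checks):
    print(r["recon_type"], r["status"], r["summary"])
```

Separating ERROR (the check could not run) from FAIL (the check ran and found a difference) keeps infrastructure problems distinguishable from genuine data mismatches in the results table.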

In [0]:
# DBTITLE 1,Execute Reconciliation for Table
# Execute all reconciliation types for a single table (Row Count, Column Count, Schema Validation, VERIFIED, Partition Columns)
def execute_reconciliation_for_table(row):
    """Execute comprehensive reconciliation for a single table mapping"""
    src_bucket = row.source_s3_bucket
    src_prefix = row.source_bucket_prefix
    tgt_table = row.target_table_name
    tgt_catalog_name=row.target_catalog_name
    tgt_schema_name=row.target_schema_name
    execution_id=row.execution_id
    print(f"In execute_reconciliation_for_table method: {execution_id}")
    
    results = []
    
    if not src_bucket or not src_prefix:
        results.append(create_recon_result(
            tgt_table, "source_read_error", "ERROR",  # Changed from "FAIL" to "ERROR"
            error_msg="Missing S3 bucket or prefix in mapping",execution_id=execution_id
        ))
        return results
    
    # Initialize variables to store counts, schemas, and dataframes
    src_row_cnt = None
    tgt_row_cnt = None
    src_col_cnt = None
    tgt_col_cnt = None
    src_schema = None
    tgt_schema = None
    src_df = None
    tgt_df = None
    try:
        start_time = datetime.now()
        src_df = read_source_parquet(src_bucket, src_prefix,tgt_table)
        tgt_df = spark.table(f"{tgt_catalog_name}.{tgt_schema_name}.{tgt_table}")
        
        # Keep only managed-table rows with snapshot_date on or after 2020-01-01
        if "snapshot_date" in tgt_df.columns:
            tgt_df = tgt_df.filter(to_date(col("snapshot_date"), 'yyyy-MM-dd') >= lit('2020-01-01'))
        
        if src_df is None:
            results.append(create_recon_result(
                tgt_table, "data_read_error", "ERROR",  # Changed from "FAIL" to "ERROR"
                error_msg="Source data is not present",execution_id=execution_id
            ))
            return results
        else:
            
            src_row_cnt = src_df.count()
            print(src_row_cnt)
            tgt_row_cnt = tgt_df.count()
            print(tgt_row_cnt)
            src_col_cnt = len(src_df.columns)
            tgt_col_cnt = len(tgt_df.columns)
            src_schema = get_schema_dict(src_df)
            tgt_schema = get_schema_dict(tgt_df)

            # --- Prepare distinct partition-level DataFrames ---
            partition_cols=["snapshot_date", "edp_run_id"]
            s3_partitions_df = (
                src_df.select(*partition_cols)
                    .distinct()
                    .withColumn("source", F.lit("S3"))
            )

            managed_partitions_df = (
                tgt_df.select(*partition_cols)
                        .distinct()
                        .withColumn("source", F.lit("MANAGED"))
            )

            # --- Distinct partition counts ---
            s3_distinct_count = s3_partitions_df.count()
            managed_distinct_count = managed_partitions_df.count()
            # --- Identify mismatched partitions ---
            missing_in_managed = (
                s3_partitions_df.select(*partition_cols)
                .subtract(managed_partitions_df.select(*partition_cols))
            )
            missing_in_s3 = (
                managed_partitions_df.select(*partition_cols)
                .subtract(s3_partitions_df.select(*partition_cols))
            )

            
            data_load_duration = (datetime.now() - start_time).total_seconds()
            
    except Exception as e:
        error_message = f"Failed to read source or target data: {str(e)}"
        results.append(create_recon_result(
            tgt_table, "data_read_error", "ERROR",  # Changed from "FAIL" to "ERROR"
            error_msg=error_message,execution_id=execution_id
        ))
        return results
    
    # Row Count Check - only proceed if we have valid counts
    if src_row_cnt is not None and tgt_row_cnt is not None:
        try:
            start_time = datetime.now()
            row_status = "PASS" if src_row_cnt == tgt_row_cnt else "FAIL"
            duration = (datetime.now() - start_time).total_seconds()
            
            results.append(create_recon_result(
                tgt_table, "ROW_COUNT", row_status,
                source_data=src_row_cnt,
                target_data=tgt_row_cnt,
                exec_duration=duration,execution_id=execution_id
            ))
            
        except Exception as e:
            results.append(create_recon_result(
                tgt_table, "ROW_COUNT", "ERROR",  # Changed from "FAIL" to "ERROR"
                source_data=src_row_cnt,
                target_data=tgt_row_cnt,
                error_msg=f"Row count comparison failed: {str(e)}",
                execution_id=execution_id
            ))
    else:
        results.append(create_recon_result(
            tgt_table, "ROW_COUNT", "ERROR",  # Changed from "FAIL" to "ERROR"
            error_msg="Could not obtain row counts from source or target",execution_id=execution_id
        ))
    
    # Column Count Check - only proceed if we have valid counts
    if src_col_cnt is not None and tgt_col_cnt is not None:
        try:
            start_time = datetime.now()
            col_status = "PASS" if src_col_cnt == tgt_col_cnt else "FAIL"
            duration = (datetime.now() - start_time).total_seconds()
            
            results.append(create_recon_result(
                tgt_table, "COLUMN_COUNT", col_status,
                source_data=src_col_cnt,
                target_data=tgt_col_cnt,
                exec_duration=duration,
                execution_id=execution_id
            ))
            
        except Exception as e:
            results.append(create_recon_result(
                tgt_table, "COLUMN_COUNT", "ERROR",  # Changed from "FAIL" to "ERROR"
                source_data=src_col_cnt,
                target_data=tgt_col_cnt,
                error_msg=f"Column count comparison failed: {str(e)}",
                execution_id=execution_id
            ))
    else:
        results.append(create_recon_result(
            tgt_table, "COLUMN_COUNT", "ERROR",  # Changed from "FAIL" to "ERROR"
            error_msg="Could not obtain column counts from source or target",
            execution_id=execution_id
        ))
    
    # Schema Validation (including precision validation) - only proceed if we have valid schemas
    if src_schema is not None and tgt_schema is not None:
        try:
            start_time = datetime.now()
            miss_src, miss_tgt, mismatched, precision_mismatched = compare_schemas(src_schema, tgt_schema)
            schema_status = "PASS" if not miss_src and not miss_tgt and not mismatched and not precision_mismatched else "FAIL"
            duration = (datetime.now() - start_time).total_seconds()
            
            results.append(create_recon_result(
                tgt_table,"SCHEMA_VALIDATION", schema_status,
                source_data=src_schema,
                target_data=tgt_schema,
                missing_src=miss_src,
                missing_tgt=miss_tgt,
                mismatched=mismatched,
                precision_mismatched=precision_mismatched,
                exec_duration=duration,
                execution_id=execution_id
            ))
            
        except Exception as e:
            results.append(create_recon_result(
                tgt_table, "SCHEMA_VALIDATION", "ERROR",  # Changed from "FAIL" to "ERROR"
                error_msg=f"Schema validation failed: {str(e)}",
                execution_id=execution_id
            ))
    
    else:
        results.append(create_recon_result(
            tgt_table, "SCHEMA_VALIDATION", "ERROR",  # Changed from "FAIL" to "ERROR"
            error_msg="Could not obtain schema information from source or target",
            execution_id=execution_id
        ))
    # Partition data count: compare the distinct (snapshot_date, edp_run_id) sets of S3 and managed data.
    # The recon_type keeps its existing spelling because the status roll-up cell matches on it.
    try:
        start_time = datetime.now()
        # Collect each difference once; every extra count()/collect() re-runs the Spark job
        missing_in_managed_rows = missing_in_managed.collect()
        missing_in_s3_rows = missing_in_s3.collect()
        duration = (datetime.now() - start_time).total_seconds()

        if (s3_distinct_count == managed_distinct_count
                and not missing_in_managed_rows and not missing_in_s3_rows):
            results.append(create_recon_result(
                tgt_table, "Partiton_Data_count", "PASS",
                source_data=s3_distinct_count,
                target_data=managed_distinct_count,
                exec_duration=duration,
                execution_id=execution_id
            ))
        else:
            error_msg = ""
            if missing_in_managed_rows:
                error_msg += f"Missing in Managed: {missing_in_managed_rows}. "
            if missing_in_s3_rows:
                error_msg += f"Missing in S3: {missing_in_s3_rows}. "
            summary = (
                f"S3 distinct partition count = {s3_distinct_count}, "
                f"Managed distinct partition count = {managed_distinct_count}. "
                f"Partition mismatch detected. "
                f"Error_Msg: {error_msg}"
            )
            results.append(create_recon_result(
                tgt_table, "Partiton_Data_count", "FAIL",
                source_data=s3_distinct_count,
                target_data=managed_distinct_count,
                error_msg=summary,
                # Report partitions missing on either side, not only the last set collected
                mismatched=missing_in_managed_rows + missing_in_s3_rows,
                exec_duration=duration,
                execution_id=execution_id
            ))
    except Exception as e:
        results.append(create_recon_result(
            tgt_table, "Partiton_Data_count", "ERROR",  # Changed from "FAIL" to "ERROR"
            error_msg=f"Partiton_Data_count validation failed: {str(e)}",
            execution_id=execution_id
        ))
    if src_df is not None and tgt_df is not None:
        try:
            start_time = datetime.now()

            # Fully qualified table name for managed table
            # fq_table = f"{tgt_table}"
            # print(f" table name for checksum : {fq_table}")

            
            partition_cols=["snapshot_date", "edp_run_id"]

            if not partition_cols:
                results.append(create_recon_result(
                    tgt_table,
                    "CHECKSUM_VALIDATION",
                    "SKIP",
                    error_msg="Table not partitioned or no partition columns found",
                    execution_id=execution_id
                ))
            else:
                # Get the latest partition from both the source and target DataFrames using the existing helper
                src_df_latest, src_partitions_latest = get_latest_partitions(
                    src_df, 
                    partition_cols=partition_cols,
                    n=1
                )
                tgt_df_latest, tgt_partitions_latest = get_latest_partitions(
                    tgt_df, 
                    partition_cols=partition_cols,
                    n=1
                )
                
                # The source partitions drive the comparison below (they should match the target)

                if not src_partitions_latest:
                    results.append(create_recon_result(
                        tgt_table,
                        "CHECKSUM_VALIDATION",
                        "SKIP",
                        error_msg="No latest partition found",
                        execution_id=execution_id
                    ))
                else:
                    # src_df_latest and tgt_df_latest are already restricted to the latest partition,
                    # so no further partition filtering is needed here.

                    # Compute checksums (exclude partition columns)
                    src_df_std = prepare_for_checksum(src_df_latest, partition_cols=["snapshot_date", "edp_run_id"])
                    tgt_df_std = prepare_for_checksum(tgt_df_latest, partition_cols=["snapshot_date", "edp_run_id"])

                    src_checksum_df = calculate_row_checksum(src_df_std, exclude_cols=[])
                    tgt_checksum_df = calculate_row_checksum(tgt_df_std, exclude_cols=[])

                    if src_checksum_df is None or tgt_checksum_df is None:
                        results.append(create_recon_result(
                            tgt_table,
                            "CHECKSUM_VALIDATION",
                            "ERROR",
                            error_msg="Failed to calculate checksums",
                            execution_id=execution_id
                        ))
                    else:
                        checksum_result = compare_checksums(src_checksum_df, tgt_checksum_df, src_partitions_latest[0])
                        duration = (datetime.now() - start_time).total_seconds()

                        results.append(create_recon_result(
                            tgt_table,
                            "CHECKSUM_VALIDATION",
                            checksum_result.get("status", "ERROR"),
                            #source_data={"checksum_count": checksum_result.get("source_record_count")},
                            #target_data={"checksum_count": checksum_result.get("target_record_count")},
                            source_data=None,
                            target_data=None,
                            checksum_result=checksum_result,
                            exec_duration=duration,
                            execution_id=execution_id
                        ))

        except Exception as e:
            results.append(create_recon_result(
                tgt_table,
                "CHECKSUM_VALIDATION",
                "ERROR",
                error_msg=f"Checksum validation failed: {str(e)}",
                execution_id=execution_id
            ))
    else:
        results.append(create_recon_result(
            tgt_table, "CHECKSUM_VALIDATION", "ERROR",  # Changed from "FAIL" to "ERROR"
            error_msg="Could not obtain information from source or target",
            execution_id=execution_id
        ))

    # VERIFIED Check (Table Statistics) - OPTIMIZED for billions of records
    if src_df is not None and tgt_df is not None:
        try:
            start_time = datetime.now()
            
            # print(f"\n{'='*60}")
            # print(f"Starting VERIFIED check for: {tgt_table}")
            # print(f"{'='*60}")
            
            # **DYNAMIC PARTITION DETECTION**
            try:
                tgt_table_fqn = f"{tgt_catalog_name}.{tgt_schema_name}.{tgt_table}"
                describe_df = spark.sql(f"DESCRIBE DETAIL {tgt_table_fqn}")
                partition_cols_row = describe_df.select("partitionColumns").first()
                
                if partition_cols_row and partition_cols_row["partitionColumns"]:
                    partition_cols = partition_cols_row["partitionColumns"]
                    # print(f"Detected partition columns: {partition_cols}")
                else:
                    partition_cols = ["snapshot_date", "edp_run_id"]
                    # print(f"Using default partition columns: {partition_cols}")
            except Exception as e:
                # print(f"Could not detect partition columns, using defaults: {e}")
                partition_cols = ["snapshot_date", "edp_run_id"]
            
            # Step 1: Select columns for verification (max 8 columns)
            selected_columns, column_types = select_columns_for_verification(
                src_df, 
                partition_cols=partition_cols,
                max_columns=8
            )
            
            if not selected_columns:
                # print(f"Warning: No columns selected for verification for {tgt_table}")
                results.append(create_recon_result(
                    tgt_table, "VERIFIED", "ERROR",
                    error_msg="No suitable columns found for verification",
                    execution_id=execution_id
                ))
            else:
                # print(f"Selected {len(selected_columns)} columns for verification:")
                for cat, cols in column_types.items():
                    if cols:
                        print(f"  {cat}: {cols}")
                
                # Step 2: Filter to the latest partition (n=1) - **CAPTURE PARTITION VALUES**
                partition_filter_start = datetime.now()
                
                src_df_filtered, src_partitions_used = get_latest_partitions(
                    src_df, 
                    partition_cols=partition_cols,
                    n=1
                )
                tgt_df_filtered, tgt_partitions_used = get_latest_partitions(
                    tgt_df, 
                    partition_cols=partition_cols,
                    n=1
                )
                
                # **Use source partitions for summary (should match target)**
                partitions_used = src_partitions_used
                
                partition_filter_duration = (datetime.now() - partition_filter_start).total_seconds()
                # print(f"Partition filtering completed in {partition_filter_duration:.2f}s")
                # print(f"Partitions used: {partitions_used}")
                
                # Step 3: Project and cache only needed columns
                prep_start = datetime.now()
                src_df_prepared = prepare_dataframe_for_stats(
                    src_df_filtered, 
                    partition_cols,
                    selected_columns
                )
                tgt_df_prepared = prepare_dataframe_for_stats(
                    tgt_df_filtered, 
                    partition_cols,
                    selected_columns
                )
                prep_duration = (datetime.now() - prep_start).total_seconds()
                # print(f"DataFrame preparation completed in {prep_duration:.2f}s")
                
                # Step 4: Compute statistics using separate functions
                stats_start = datetime.now()
                
                # print("Computing SOURCE statistics...")
                src_stats = compute_source_stats(
                    src_df_prepared, 
                    selected_columns=selected_columns,
                    column_types=column_types
                )
                
                # print("Computing TARGET statistics...")
                # Target statistics are computed from the prepared DataFrame, so no temporary view is needed
                
                tgt_stats = get_table_stats(
                    tgt_df_prepared,
                    selected_columns=selected_columns,
                    column_types=column_types
                )
                
                # Drop temporary view
                # spark.catalog.dropTempView(temp_view_name)
                
                stats_duration = (datetime.now() - stats_start).total_seconds()
                # print(f"Statistics computation completed in {stats_duration:.2f}s")
                
                # Cleanup: Unpersist cached DataFrames
                src_df_prepared.unpersist()
                tgt_df_prepared.unpersist()
                
                total_duration = (datetime.now() - start_time).total_seconds()
                print(f"Total VERIFIED check duration: {total_duration:.2f}s")
                print(f"{'='*60}\n")
                
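                # Determine status: PASS here only means both statistic sets were computed;
                # this block does not compare the source and target values itself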
                verified_status = "PASS" if (src_stats and tgt_stats) else "ERROR"
                
                # **NEW: Package stats with metadata for summary**
                src_stats_with_metadata = {
                    "statistics": src_stats,
                    #"partitions_used": partitions_used,
                    "columns_verified": selected_columns
                }
                
                tgt_stats_with_metadata = {
                    "statistics": tgt_stats,
                    #"partitions_used": partitions_used,
                    "columns_verified": selected_columns
                }
                
                results.append(create_recon_result(
                    tgt_table, "VERIFIED", verified_status, 
                    source_data=src_stats_with_metadata, 
                    target_data=tgt_stats_with_metadata,
                    exec_duration=total_duration,
                    partitions_used=partitions_used,
                    execution_id=execution_id
                ))
                
        except Exception as e:
            print(f"ERROR in VERIFIED check for {tgt_table}: {e}")
            import traceback
            print(traceback.format_exc())
            
            results.append(create_recon_result(
                tgt_table, "VERIFIED", "ERROR",
                error_msg=f"Statistics verification failed: {str(e)}",
                execution_id=execution_id
            ))
    else:
        results.append(create_recon_result(
            tgt_table, "VERIFIED", "ERROR",
            error_msg="Could not access source or target data for statistics computation",
            execution_id=execution_id
        ))

    
    return results
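
The partition check in the function above reduces to a two-way set difference over distinct `(snapshot_date, edp_run_id)` pairs. As a rough illustration only, the same logic can be sketched in plain Python without Spark; `compare_partitions` and the sample values are hypothetical and are not part of the notebook.

```python
from typing import Dict, Iterable, Tuple

Partition = Tuple[str, str]  # (snapshot_date, edp_run_id)


def compare_partitions(source: Iterable[Partition],
                       target: Iterable[Partition]) -> Dict[str, object]:
    """Compare two partition collections the way DataFrame.subtract would."""
    src_set, tgt_set = set(source), set(target)  # set() plays the role of distinct()
    missing_in_target = sorted(src_set - tgt_set)
    missing_in_source = sorted(tgt_set - src_set)
    status = "PASS" if not missing_in_target and not missing_in_source else "FAIL"
    return {
        "status": status,
        "source_count": len(src_set),
        "target_count": len(tgt_set),
        "missing_in_target": missing_in_target,
        "missing_in_source": missing_in_source,
    }


# Hypothetical sample data: one duplicate on the S3 side, one partition missing on each side
s3_partitions = [("2024-01-01", "r1"), ("2024-01-02", "r2"), ("2024-01-02", "r2")]
managed_partitions = [("2024-01-01", "r1"), ("2024-01-03", "r3")]
outcome = compare_partitions(s3_partitions, managed_partitions)
print(outcome)
```

Note that equal distinct counts alone do not imply a match, as the sample shows: both sides hold two distinct partitions, yet each is missing one that the other has.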

# Run Full Reconciliation

Orchestrates the complete reconciliation process across all mapped tables. Iterates through the mapping configuration and executes comprehensive validation for each source-to-target table pair.

**Process Flow:**
- Loads all table mappings from configuration
- Executes reconciliation for each table sequentially
- Provides progress tracking and error reporting
- Collects all results for batch processing
- Handles critical errors with detailed logging
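
The process flow above amounts to a loop that isolates failures per table, so one broken table does not stop the run. A minimal sketch of that pattern in plain Python follows; `run_all`, `fake_reconcile`, and the dictionary-shaped results are hypothetical stand-ins for the notebook's Spark rows and `create_recon_result` output.

```python
from typing import Callable, Dict, List

Result = Dict[str, object]


def run_all(mappings: List[Dict[str, str]],
            reconcile_one: Callable[[Dict[str, str]], List[Result]]) -> List[Result]:
    """Run reconcile_one for every mapping and keep going when one table fails."""
    all_results: List[Result] = []
    for index, mapping in enumerate(mappings, start=1):
        table = mapping["target_table_name"]
        print(f"Processing table {index}: {table}")
        try:
            all_results.extend(reconcile_one(mapping))
        except Exception as exc:
            # Record the failure as its own result row instead of aborting the run
            all_results.append({
                "table_name": table,
                "recon_type": "critical_error",
                "status": "ERROR",
                "error_msg": f"Critical processing error: {exc}",
                "execution_id": mapping.get("execution_id"),
            })
    return all_results


def fake_reconcile(mapping: Dict[str, str]) -> List[Result]:
    """Hypothetical per-table check that fails for one table name."""
    if mapping["target_table_name"] == "broken":
        raise ValueError("boom")
    return [{"table_name": mapping["target_table_name"],
             "recon_type": "ROW_COUNT", "status": "PASS"}]


demo_mappings = [
    {"target_table_name": "orders", "execution_id": "e1"},
    {"target_table_name": "broken", "execution_id": "e1"},
    {"target_table_name": "customers", "execution_id": "e1"},
]
demo_results = run_all(demo_mappings, fake_reconcile)
```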

In [0]:
# Execute reconciliation across all mapped tables with progress tracking
def run_full_reconciliation():
    """Execute reconciliation for all tables in the mapping configuration"""
    all_results = []
    processed_count = 0
    
    for row in mapping_df.collect():
        try:
            processed_count += 1
            table_name = row.target_table_name
            print(f"Processing table {processed_count}: {table_name}")
            table_results = execute_reconciliation_for_table(row)
            all_results.extend(table_results)
            
        except Exception as e:
            print(f"Error processing {row.target_table_name}: {str(e)}")
            import traceback
            print(f"Full traceback: {traceback.format_exc()}")
            
            error_result = create_recon_result(
                table_name=row.target_table_name,
                recon_type="critical_error",
                status="ERROR",  # Changed from "FAIL" to "ERROR"
                error_msg=f"Critical processing error: {str(e)}", 
                execution_id=row.execution_id
            )
            all_results.append(error_result)
    
    print(f"Processed {processed_count} tables, generated {len(all_results)} results")
    return all_results

reconciliation_results = run_full_reconciliation()

# Save Results to Enhanced Table

Persists all reconciliation results to the Delta table for analysis and reporting. Creates a comprehensive audit trail of all validation activities with detailed metrics and error information.

**Features:**
- Batch insert of all reconciliation results
- Delta table format for ACID compliance
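- Schema evolution intended via a Delta write option (note that `overwriteSchema` only takes effect with `mode("overwrite")`; appends rely on `mergeSchema`)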
- Comprehensive error handling for save operations
- Returns DataFrame for immediate analysis and display
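
Because the save step returns nothing when there are no results or when the write fails, any caller that post-processes the return value needs a guard. The sketch below shows that contract in plain Python under stated assumptions: `save_results`, `truncate_fields`, and the `write_batch` callable are hypothetical substitutes for the Delta write and the Spark `substring` calls.

```python
from typing import Callable, Dict, List, Optional, Sequence

Row = Dict[str, Optional[str]]


def save_results(results: List[Row],
                 write_batch: Callable[[List[Row]], None]) -> Optional[List[Row]]:
    """Write all results in one batch; return None when nothing was saved."""
    if not results:
        print("No results to save")
        return None
    try:
        write_batch(results)
        return results
    except Exception as exc:
        print(f"Error saving results: {exc}")
        return None


def truncate_fields(rows: List[Row],
                    fields: Sequence[str] = ("summary", "error_log"),
                    limit: int = 1000) -> List[Row]:
    """Shorten long text fields for display without touching the stored rows."""
    return [
        {key: (value[:limit] if key in fields and isinstance(value, str) else value)
         for key, value in row.items()}
        for row in rows
    ]


def failing_writer(rows: List[Row]) -> None:
    raise IOError("table unavailable")


stored: List[Row] = []
saved = save_results([{"summary": "x" * 1500, "error_log": None}], stored.extend)
preview = truncate_fields(saved) if saved is not None else []  # guard against None
failed = save_results([{"summary": "ok", "error_log": None}], failing_writer)
```

Truncating only the preview keeps the full text in the stored rows, which mirrors applying `substring` after the write rather than before it.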

In [0]:
# Save comprehensive reconciliation results to the main Delta table
def save_results_to_enhanced_table(results):
    """Save reconciliation results to the enhanced Delta table"""
    if not results:
        print("No results to save")
        return None
    
    try:
        results_df = spark.createDataFrame(results, recon_results_schema)
        
        results_df.write \
            .format("delta") \
            .mode("append") \
            .option("mergeSchema", "true") \
            .saveAsTable(RESULTS_TABLE_FQN)

        return results_df
    except Exception as e:
        print(f"Error saving results: {e}")
        return None

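results_df = save_results_to_enhanced_table(reconciliation_results)
# The save returns None when there is nothing to save or the write fails
if results_df is not None:
    # Truncate long text columns so the displayed table stays readable
    results_df = results_df.withColumn("summary", substring("summary", 1, 1000)) \
                           .withColumn("error_log", substring("error_log", 1, 1000))
    display(results_df)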
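
The following cell rolls the per-check results up into one `job_run` flag and one `recon_status` per table and execution. A plain-Python sketch of that decision logic may help when reasoning about edge cases without a cluster; the function name `roll_up` is hypothetical, and the sets stand in for the `collect_set` aggregates.

```python
from typing import Set, Tuple

EXPECTED_RECON_TYPES = {
    "ROW_COUNT", "COLUMN_COUNT", "SCHEMA_VALIDATION",
    "Partiton_Data_count", "CHECKSUM_VALIDATION", "VERIFIED",
}


def roll_up(recon_types: Set[str], statuses: Set[str]) -> Tuple[str, str]:
    """Return (job_run, recon_status) for one table and execution."""
    read_error = "data_read_error" in recon_types
    # The job counts as run when every expected check reported, or the read failed outright
    job_run = EXPECTED_RECON_TYPES <= recon_types or read_error
    if job_run and "FAIL" not in statuses and "ERROR" not in statuses:
        recon_status = "PASS"
    elif job_run and read_error:
        recon_status = "SKIPPED"
    else:
        recon_status = "FAIL"
    return ("true" if job_run else "false"), recon_status


all_pass = roll_up(set(EXPECTED_RECON_TYPES), {"PASS"})
one_fail = roll_up(set(EXPECTED_RECON_TYPES), {"PASS", "FAIL"})
unreadable = roll_up({"data_read_error"}, {"ERROR"})
incomplete = roll_up({"ROW_COUNT"}, {"PASS"})
print(all_pass, one_fail, unreadable, incomplete)
```

Under this logic a table whose source could not be read is reported as `SKIPPED` rather than `FAIL`, and a run that is missing any expected check is reported as `FAIL` even when every recorded status is `PASS`.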

In [0]:
from pyspark.sql import functions as F
mapping_count = mapping_df.count()  # count once; each count() triggers a Spark job
print(mapping_count)
# Skip the roll-up when there is nothing mapped or the save step returned no results
if mapping_count > 0 and results_df is not None:
    # Expected recon types
    expected_recon_types = {"ROW_COUNT", "COLUMN_COUNT", "SCHEMA_VALIDATION","Partiton_Data_count","CHECKSUM_VALIDATION","VERIFIED"}
    recon_df = results_df
    candidate_df = spark.table(CANDIDATES_TABLE_FQN)
    # Step 1: Aggregate recon results per table and execution
    agg_recon = (
        recon_df.groupBy("execution_id", "table_name")
        .agg(
            F.collect_set("recon_type").alias("recon_types"),
            F.collect_set("status").alias("statuses")
        )
    )
    
    # Step 2: Derive job_run and recon_status from the aggregated recon types and statuses
    agg_recon = (
        agg_recon.withColumn(
            "job_run",
            F.when(
                (F.size(F.array_except(F.array(*[F.lit(x) for x in expected_recon_types]), F.col("recon_types"))) == 0)
                | F.array_contains(F.col("recon_types"), "data_read_error"),
                "true"
            ).otherwise("false")
        )
        .withColumn(
            "recon_status",
            F.when(
                (F.col("job_run") == "true") & (~F.array_contains(F.col("statuses"), "FAIL")) & (~F.array_contains(F.col("statuses"), "ERROR")),
                "PASS"
            ).when(
                (F.col("job_run") == "true") & (F.array_contains(F.col("recon_types") ,"data_read_error")),
                "SKIPPED"
            )
            .otherwise("FAIL")
        )
    )
    # Step 3: Use a simple MERGE to update the candidate table
    agg_recon.createOrReplaceTempView("agg_recon_updates")

    sql = f"""
    MERGE INTO {CANDIDATES_TABLE_FQN} AS c
    USING agg_recon_updates AS r
    ON c.table_name = r.table_name AND c.execution_id = r.execution_id  
    WHEN MATCHED THEN UPDATE SET
        c.recon_job_run = r.job_run
        ,c.recon_status = r.recon_status
        ,c.recon_execution_time = now()
    """
    spark.sql(sql)
