In [0]:
# Parameter Setup and Configuration
dbutils.widgets.text("parquet_schema_table","")
dbutils.widgets.text("inventory_table","") 
dbutils.widgets.text("bucket_name","")
dbutils.widgets.text("managed_catalog","")
dbutils.widgets.text("partition_audit_table","")
dbutils.widgets.text("recon_remediation_table","")
dbutils.widgets.text("candidate_table","")
dbutils.widgets.text("dataset_mapping_table","")
dbutils.widgets.text("datasets","")
dbutils.widgets.text("catalog","")
dbutils.widgets.text("schema","")
##dbutils.widgets.text("source_volume_path","")
# dbutils.widgets.text("managed_schema","")

parquet_schema_table = dbutils.widgets.get("parquet_schema_table")
inventory_table = dbutils.widgets.get("inventory_table")
bucket_name=dbutils.widgets.get("bucket_name")
partition_audit_table=dbutils.widgets.get("partition_audit_table")
managed_catalog_name=dbutils.widgets.get("managed_catalog")
recon_remediation_table=dbutils.widgets.get("recon_remediation_table")
candidate_table=dbutils.widgets.get("candidate_table")
dataset_mapping_table=dbutils.widgets.get("dataset_mapping_table")
datasets=dbutils.widgets.get("datasets")
# managed_schema_wdgt=dbutils.widgets.get("managed_schema")
catalog=dbutils.widgets.get("catalog")
schema=dbutils.widgets.get("schema")

print(f"bucket_name: {bucket_name}")
print(f"candidate_table: {candidate_table}")
print(f"inventory_table: {inventory_table}")
print(f"managed_catalog: {managed_catalog_name}")
# print(f"managed_schema: {managed_schema}")
print(f"parquet_schema_table: {parquet_schema_table}")
print(f"recon_remediation_table: {recon_remediation_table}")
print(f"partition_audit_table: {partition_audit_table}")
print(f"dataset_mapping_table: {dataset_mapping_table}")

##MANAGED_SCHEMA_NAME=dbutils.widgets.get("managed_schema")
##volume=dbutils.widgets.get("source_volume_path")

In [0]:
from functools import reduce
from pyspark.sql.functions import col, regexp_extract, collect_list, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, BooleanType, DecimalType, DateType, TimestampType, BinaryType, ShortType, ByteType, FloatType, DoubleType
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
import logging
import time
import traceback
from pyspark.sql import Row
from datetime import datetime

spark.conf.set("spark.sql.files.ignoreCorruptFiles", "true")

In [0]:
df_failed_partitions = spark.sql(
    f"""
    SELECT DISTINCT
        inv.execution_id,
        inv.s3_bucket_name,
        inv.bucket_prefix,
        regexp_extract(inv.bucket_prefix, '([^/]+)[/]?$', 1) AS dataset_name,
        inv.edp_run_id,
        inv.snapshot_date,
        map.dbx_managed_table_schema AS managed_schema
    FROM {catalog}.{schema}.{inventory_table} inv
    LEFT JOIN {catalog}.{schema}.{dataset_mapping_table} map
      ON inv.s3_bucket_name = map.s3_bucket_name
     AND inv.bucket_prefix = map.bucket_prefix
     AND regexp_extract(inv.bucket_prefix, '([^/]+)[/]?$', 1) = map.dataset_name
    WHERE inv.load_status = 'failed'
      AND inv.extension IS NOT NULL
      AND inv.edp_run_id IS NOT NULL
      AND inv.snapshot_date >= '2020-01-01'
      AND inv.s3_bucket_name = '{bucket_name}'
      AND map.dataset_name in ({datasets})
    """
)
display(df_failed_partitions)

In [0]:
# this method is used for getting the base schema 
def parse_type(dtype):
    dtype = dtype.lower().strip()

    if dtype.startswith("decimal"):
        scale = dtype[dtype.find("(") + 1 : dtype.find(")")].split(",")
        return DecimalType(int(scale[0]), int(scale[1]))

    if dtype in ("int8", "byte"):
        return ByteType()
    
    if dtype.startswith("bytetype"):
        return ByteType()

    if dtype in ("int16", "smallint"):
        return ShortType()
    
    if dtype.startswith("shorttype"):
        return ShortType()
    
    if dtype in ("int32", "integer"):
        return IntegerType()
    
    if dtype.startswith("integertype"):
        return IntegerType()
    
    if dtype == "int64":
        return LongType()

    if dtype.startswith("longtype"):
        return LongType()
    
    if dtype in ("string"):
        return StringType()

    if dtype.startswith("stringtype"):
        return StringType()

    if dtype.startswith("date32") or dtype == "date":
        return DateType()
    
    if dtype.startswith("datetype"):
        return DateType()
    
    if dtype.startswith("floattype"):
        return FloatType()
    
    if dtype.startswith("doubletype"):
        return DoubleType()

    if dtype.startswith("timestamp"):
        # Handles 'timestamp', 'timestamp[ms]', 'timestamp[us]' etc.
        return TimestampType()

    if dtype == "bool":
        return BooleanType()
    
    if dtype.startswith("booleantype"):
        return BooleanType()
    
    if dtype == "binary":
        return BinaryType()
    
    if dtype.startswith("binarytype"):
        return BinaryType()
    
    raise ValueError(f"Unsupported type: {dtype}")

In [0]:
from pyspark.sql.functions import col
import json

partition_key_combination = "edp_run_id, snapshot_date"
dataset_schemas = {}

failed_datasets = (
    df_failed_partitions
    .select("s3_bucket_name", "bucket_prefix", "dataset_name")
    .distinct()
    .collect()
)

for row in failed_datasets:
    bucket_name = row["s3_bucket_name"]
    bucket_prefix = row["bucket_prefix"]
    dataset_name = row["dataset_name"]

    inventory_df = spark.sql(
        f"""
        SELECT DISTINCT edp_run_id, snapshot_date
        FROM (
            SELECT DISTINCT edp_run_id, try_cast(snapshot_date AS DATE) AS snapshot_date
            FROM {catalog}.{schema}.{inventory_table}
            WHERE extension IS NOT NULL
              AND partition_key = '{partition_key_combination}'
              AND s3_bucket_name = '{bucket_name}'
              AND bucket_prefix = '{bucket_prefix}'
              AND load_status ='loaded'
        ) inventory
        JOIN (
            SELECT DISTINCT run_id, try_cast(run_tag_value AS DATE) AS run_tag_value
            FROM {partition_audit_table}
            WHERE dataset_name = '{dataset_name}'
              AND lower(run_tag_key) = 'snapshot_date'
        ) partition
        ON inventory.edp_run_id = partition.run_id
           AND inventory.snapshot_date = partition.run_tag_value
        """
    )


    if not inventory_df.isEmpty():
        df_latest_snapshot = (
            inventory_df
            .orderBy(col("snapshot_date").desc())
            .limit(1)
        )

        latest_snapshot_row = df_latest_snapshot.collect()[0]
        latest_edp_run_id = latest_snapshot_row["edp_run_id"]
        latest_snapshot_date = latest_snapshot_row["snapshot_date"]

        print(latest_edp_run_id)
        print(latest_snapshot_date)

        schema_json_df = spark.sql(
            f"""
            SELECT schema_json
            FROM {catalog}.{schema}.{parquet_schema_table}
            WHERE bucket_prefix = '{bucket_prefix}'
              AND file_path LIKE '%edp_run_id={latest_edp_run_id}/snapshot_date={latest_snapshot_date}%'
            """
        )

        schema_json = schema_json_df.first()['schema_json']
        schema_dict = json.loads(schema_json)

        base_schema = StructType([
            StructField(f["name"], parse_type(f["type"]), f["nullable"])
            for f in schema_dict["fields"]
        ])

        partition_columns = [
            partition_column.strip()
            for partition_column in partition_key_combination.split(",")
            if partition_column.strip()
        ]

        partition_type_map = {
            "edp_run_id": StringType(),
            "snapshot_date": DateType()
        }

        extended_fields = base_schema.fields.copy()
        for partition_column in partition_columns:
            col_type = partition_type_map.get(partition_column, StringType())
            extended_fields.append(StructField(partition_column, col_type, True))

        extended_schema = StructType(extended_fields)

        # print(extended_schema)

        # Add to dictionary
        dataset_schemas[(bucket_name, bucket_prefix, dataset_name)] = extended_schema

# Now dataset_schemas is ready to use in your processing loop

[0;31m---------------------------------------------------------------------------[0m
[0;31mTypeError[0m                                 Traceback (most recent call last)
File [0;32m<command-811489465029817>, line 65[0m
[1;32m     54[0m [38;5;28mprint[39m(latest_snapshot_date)
[1;32m     56[0m schema_json_df [38;5;241m=[39m spark[38;5;241m.[39msql(
[1;32m     57[0m     [38;5;124mf[39m[38;5;124m"""[39m
[1;32m     58[0m [38;5;124m    SELECT schema_json[39m
[0;32m   (...)[0m
[1;32m     62[0m [38;5;124m    [39m[38;5;124m"""[39m
[1;32m     63[0m )
[0;32m---> 65[0m schema_json [38;5;241m=[39m schema_json_df[38;5;241m.[39mfirst()[[38;5;124m'[39m[38;5;124mschema_json[39m[38;5;124m'[39m]
[1;32m     66[0m schema_dict [38;5;241m=[39m json[38;5;241m.[39mloads(schema_json)
[1;32m     68[0m base_schema [38;5;241m=[39m StructType([
[1;32m     69[0m     StructField(f[[38;5;124m"[39m[38;5;124mname[39m[38;5;124m"[39m], parse_type(f[[38;5;

In [0]:
%python
def cast_failed_df_to_base_schema(failed_df, base_schema):
    exprs = []
    mismatched = []
    # Map failed_df columns to lowercase for matching
    failed_schema = {f.name.lower(): f.dataType for f in failed_df.schema.fields}
    # Create a mapping from lowercase to actual column name
    failed_col_map = {f.name.lower(): f.name for f in failed_df.schema.fields}
    for field in base_schema.fields:
        col_name = field.name
        col_name_lc = col_name.lower()
        target_type = field.dataType
        if col_name_lc in failed_schema:
            failed_type = failed_schema[col_name_lc]
            actual_col_name = failed_col_map[col_name_lc]
            if type(failed_type) != type(target_type) or (
                hasattr(target_type, "precision") and hasattr(failed_type, "precision") and
                (target_type.precision != failed_type.precision or target_type.scale != failed_type.scale)
            ):
                mismatched.append(col_name)
                exprs.append(col(actual_col_name).cast(target_type).alias(col_name))
            else:
                exprs.append(col(actual_col_name).alias(col_name))
    print(f"✓ {len(mismatched)} columns will be typecasted: {mismatched}")
    return failed_df.select(exprs)

In [0]:
# Complete updated processing cell

from pyspark.sql.functions import lit
from datetime import datetime
from pyspark.sql import Row
from collections import defaultdict

rows = df_failed_partitions.collect()
partitions_by_dataset = defaultdict(list)
for row in rows:
    key = (
        row['s3_bucket_name'],
        row['bucket_prefix'],
        row['dataset_name'],
        row['managed_schema']
    )
    partitions_by_dataset[key].append(row)

remediation_status_rows = []

for dataset_key, partition_rows in partitions_by_dataset.items():
    s3_bucket_name, bucket_prefix, dataset_name, managed_schema = dataset_key
    print(f"\n{'='*60}")
    print(f"Processing dataset: {dataset_name}")
    print(f"Bucket: {s3_bucket_name}")
    print(f"Prefix: {bucket_prefix}")
    print(f"Managed schema: {managed_schema}")
    print(f"Partition key: edp_run_id, snapshot_date")
    print(f"Total partitions found for dataset: {len(partition_rows)}")
    print(f"{'='*60}")

    extended_schema = dataset_schemas.get((s3_bucket_name, bucket_prefix, dataset_name))
    if extended_schema is None:
        print(f"✗ No schema found for {(s3_bucket_name, bucket_prefix, dataset_name)}, skipping ALL partitions for this dataset.")
        continue

    partitions_attempted = len(partition_rows)
    partitions_remediated = 0
    partitions_failed = 0
    error_msg = []
    execution_ids = set()

    for partition_index, row in enumerate(partition_rows, 1):
        execution_id = row['execution_id']
        run_id = row['edp_run_id']
        snapshot_date = row['snapshot_date']
        execution_ids.add(execution_id)

        print(f"\nProcessing partition {partition_index}/{partitions_attempted} for dataset {dataset_name}")
        print(f"run_id={run_id}, snapshot_date={snapshot_date}")
        
        try:
            path = f"s3://{s3_bucket_name}/{bucket_prefix}/edp_run_id={run_id}/snapshot_date={snapshot_date}/"
            print(f"Reading from s3 path {path}")
            failed_df = spark.read.parquet(path)
            print(f"✓ Read {failed_df.count()} records from failed partition")

            # Add partition columns if not present
            for partition_col in ["edp_run_id", "snapshot_date"]:
                if partition_col not in failed_df.columns:
                    if partition_col == "edp_run_id":
                        failed_df = failed_df.withColumn("edp_run_id", lit(run_id))
                    elif partition_col == "snapshot_date":
                        failed_df = failed_df.withColumn("snapshot_date", lit(snapshot_date))

            # Cast and align schema
            remediated_df = cast_failed_df_to_base_schema(failed_df, extended_schema)
            print(f"✓ Final remediated DataFrame has {len(remediated_df.columns)} columns")

            remediated_df.write\
            .mode("append")\
            .saveAsTable(f"{managed_catalog_name}.{managed_schema}.{dataset_name}")

            print(f"✓ Successfully appended {remediated_df.count()} records to {managed_catalog_name}.{managed_schema}.{dataset_name}")

            update_audit_query = f"""
            UPDATE {catalog}.{schema}.{inventory_table}
            SET load_status = 'loaded',
                last_modified_time = current_timestamp()
            WHERE edp_run_id = '{run_id}'
              AND snapshot_date = '{snapshot_date}'
              AND bucket_prefix = '{bucket_prefix}'
            """
            spark.sql(update_audit_query)

            print(f"✓ Updated inventory table for run_id: {run_id}, snapshot_date: {snapshot_date}")

            partitions_remediated += 1

        except Exception as e:
            print(f"✗ ERROR processing partition run_id={run_id}, snapshot_date={snapshot_date}: {str(e)}")
            error_msg.append(f"ERROR processing partition run_id={run_id}, snapshot_date={snapshot_date}: {str(e)}")
            partitions_failed += 1
            continue

    if partitions_remediated == partitions_attempted and partitions_failed == 0:
        update_candidates_query = f"""
        UPDATE {catalog}.{schema}.{candidate_table}
        SET recon_job_run = NULL
        WHERE execution_id IN ({','.join([f"'{eid}'" for eid in execution_ids])})
          AND table_name = '{dataset_name}'
        """
        spark.sql(update_candidates_query)
        
        print(f"✓ Updated ccbr_migration_table_candidates for table_name='{dataset_name}'")
    else:
        print(f"✗ Not updating candidate table for {dataset_name} as not all partitions loaded successfully")

    remediation_status_rows.append(
        Row(
            execution_id=",".join(execution_ids),
            s3_bucket_name=s3_bucket_name,
            bucket_prefix=bucket_prefix,
            dataset_name=dataset_name,
            remediation_attempted_time=datetime.now(),
            partitions_attempted=partitions_attempted,
            partitions_remediated=partitions_remediated,
            partitions_failed=partitions_failed,
            remediation_status="PASS" if partitions_failed == 0 else "FAIL",
            error_msg="; ".join(error_msg) if error_msg else None
        )
    )

print(f"\n{'='*60}")
print("ALL PARTITIONS PROCESSED!")
print(f"✓ Remediation summary written per dataset")
print(f"{'='*60}")

if remediation_status_rows:
    schema_df = StructType([
        StructField("execution_id", StringType(), True),
        StructField("s3_bucket_name", StringType(), True),
        StructField("bucket_prefix", StringType(), True),
        StructField("dataset_name", StringType(), True),
        StructField("remediation_attempted_time", TimestampType(), True),
        StructField("partitions_attempted", IntegerType(), True),
        StructField("partitions_remediated", IntegerType(), True),
        StructField("partitions_failed", IntegerType(), True),
        StructField("remediation_status", StringType(), True),
        StructField("error_msg", StringType(), True)
    ])
    load_remediation_df = spark.createDataFrame(remediation_status_rows, schema_df)
    display(load_remediation_df)

    load_remediation_df.write\
        .mode("append")\
        .saveAsTable(f"{catalog}.{schema}.{recon_remediation_table}")

    # Data loaded into recon_remediation table
