In [0]:
from functools import reduce
from pyspark.sql.functions import col, regexp_extract, collect_list, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, BooleanType, DecimalType, DateType, TimestampType, BinaryType, ShortType, ByteType
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
import logging
import time
import traceback

In [0]:
# Declare the notebook's text widgets. Every widget defaults to an empty string.
# Each entry is (widget name, label shown next to the widget).
_WIDGET_SPECS = [
    ("catalog", "Catalog"),
    ("schema", "Schema"),
    ("bucket_name", "Bucket Name"),
    ("candidate_table", "Candidate Table"),  # label typo fixed (was "Candiate Table")
    ("parquet_schema_table", "Parquet Schema Table"),
    ("s3_inventory_table", "S3 Inventory Table"),
    ("dataset_mapping_table", "Dataset Mapping Table"),
    ("s3_bucket_volume_mapping_table", "Bucket To Mapping Table"),
    ("partition_table", "Partition Table"),
]

for _widget_name, _widget_label in _WIDGET_SPECS:
    dbutils.widgets.text(_widget_name, "", _widget_label)

In [0]:
# Read the current value of every widget into a module-level setting.
# The targets on the left line up one-to-one with the widget names on the right.
(
    catalog,
    schema,
    bucket_name,
    candidate_table,
    parquet_schema_table,
    s3_inventory_table,
    dataset_mapping_table,
    s3_bucket_volume_mapping_table,
    partition_table,
) = (
    dbutils.widgets.get(widget_name)
    for widget_name in (
        "catalog",
        "schema",
        "bucket_name",
        "candidate_table",
        "parquet_schema_table",
        "s3_inventory_table",
        "dataset_mapping_table",
        "s3_bucket_volume_mapping_table",
        "partition_table",
    )
)

In [0]:
# Session-level setting: let reads continue past corrupt files instead of failing.
# NOTE(review): nothing in this notebook reports which files were skipped — confirm
# that silently dropping them is acceptable for these loads.
spark.conf.set(
    "spark.sql.files.ignoreCorruptFiles",
    "true",
)

In [0]:
# Function to convert type strings to Spark types
def parse_type(dtype):
    """Convert a type-name string into the matching Spark SQL data type.

    Matching is case-insensitive and ignores surrounding whitespace.

    Args:
        dtype: Type name, e.g. "int32", "bigint", "string", "timestamp[us]",
            "date32[day]" or "decimal128(10, 2)".

    Returns:
        A new instance of the corresponding Spark type class.

    Raises:
        ValueError: If the name is not recognised, or a decimal type does not
            carry exactly one "(precision, scale)" pair of integers.
    """
    dtype = dtype.lower().strip()

    if dtype.startswith("decimal"):
        # Expect "<decimal...>(precision, scale)". Anything else is reported with the
        # same "Unsupported type" error as every other unknown name, rather than
        # leaking an IndexError or an int() parsing error.
        open_idx = dtype.find("(")
        close_idx = dtype.find(")")
        parts = dtype[open_idx + 1 : close_idx].split(",") if -1 < open_idx < close_idx else []
        if len(parts) != 2:
            raise ValueError(f"Unsupported type: {dtype}")
        try:
            precision, scale = int(parts[0]), int(parts[1])
        except ValueError:
            raise ValueError(f"Unsupported type: {dtype}") from None
        return DecimalType(precision, scale)

    # Exact-name matches. Each Spark type is reachable through both the fixed-width
    # spelling ("int32") and the Spark SQL spellings ("int", "integer").
    # NOTE(review): floating-point names ("float", "double") are still rejected;
    # mapping them requires FloatType/DoubleType to be imported in this notebook.
    simple_types = {
        "int8": ByteType,
        "byte": ByteType,
        "tinyint": ByteType,
        "int16": ShortType,
        "smallint": ShortType,
        "short": ShortType,
        "int32": IntegerType,
        "integer": IntegerType,
        "int": IntegerType,
        "int64": LongType,
        "bigint": LongType,
        "long": LongType,
        "string": StringType,
        "date": DateType,
        "bool": BooleanType,
        "boolean": BooleanType,
        "binary": BinaryType,
    }
    spark_type = simple_types.get(dtype)
    if spark_type is not None:
        return spark_type()

    if dtype.startswith("date32"):
        # Handles 'date32' and unit-qualified forms such as 'date32[day]'.
        return DateType()

    if dtype.startswith("timestamp"):
        # Handles 'timestamp', 'timestamp[ms]', 'timestamp[us]' etc.
        return TimestampType()

    raise ValueError(f"Unsupported type: {dtype}")

In [0]:
# Lock shared by all worker threads so that only one Delta append runs at a time.
write_lock = Lock()

def process_partition(edp_run_id, snapshot_date, volume_path, extended_schema, managed_table_name, max_retries=3):
    """Read one edp_run_id/snapshot_date partition directory and append it to a Delta table.

    Args:
        edp_run_id: Value of the ``edp_run_id=`` path segment.
        snapshot_date: Value of the ``snapshot_date=`` path segment.
        volume_path: Root of the dataset. Also passed to the reader as ``basePath``.
            A trailing "/" is added when missing before the partition segments
            are appended.
        extended_schema: Schema handed to the parquet reader.
        managed_table_name: Target table for ``saveAsTable``.
        max_retries: Maximum number of read+write attempts.

    Returns:
        ``(edp_run_id, snapshot_date, status, error)`` where status is "SUCCESS"
        (error is None) or "FAILED" (error is the last exception text, or
        "Unknown error" when no attempt was made because max_retries < 1).
        Exceptions are never propagated to the caller.
    """
    # Guarantee exactly one separator between the root and the partition segments;
    # plain concatenation yields a wrong path when the root has no trailing slash.
    root = volume_path if volume_path.endswith("/") else f"{volume_path}/"
    path = f"{root}edp_run_id={edp_run_id}/snapshot_date={snapshot_date}"

    for attempt in range(1, max_retries + 1):
        try:
            print(f"Reading partition path: {path}")

            filtered_df = (
                spark.read
                .schema(extended_schema)
                .option("basePath", volume_path)
                .parquet(path)
            )

            # Acquire lock before writing to Delta table to ensure thread-safe writes.
            # NOTE(review): DataFrame reads are lazy, so the parquet scan presumably
            # executes inside this lock as part of the write — confirm whether the
            # surrounding thread pool gives any real parallelism.
            with write_lock:
                # Write the dataframe to the Delta table (mergeSchema ensures evolution safety)
                (
                    filtered_df.write
                    .format("delta")
                    .mode("append")
                    .option("mergeSchema", "true")
                    .saveAsTable(managed_table_name)
                )

            print(f"Successfully processed for {path}")
            return (edp_run_id, snapshot_date, "SUCCESS", None)

        except Exception as e:
            error_msg = f"Error processing partition {path} (Attempt {attempt}/{max_retries}): {str(e)}"
            print(error_msg)

            if attempt < max_retries:
                # Exponential backoff: 2s, 4s, 8s, ...
                wait_time = 2 ** attempt
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
                print(f"Failed to process {path} after {max_retries} attempts")
                return (edp_run_id, snapshot_date, "FAILED", str(e))

    # Only reached when the loop body never ran (max_retries < 1).
    return (edp_run_id, snapshot_date, "FAILED", "Unknown error")



In [0]:
# Candidate datasets for the selected bucket that have not yet been turned into a
# managed table (managed_table_created is still null).
candidate_query = f"""
                                select distinct s3_bucket_name, bucket_prefix, table_name
                                from  {catalog}.{schema}.{candidate_table}
                                where s3_bucket_name = '{bucket_name}' 
                                and candidate_for_managed_table_creation in ('true', 'false') 
                                and managed_table_created is null
                                """
df_table_candidates = spark.sql(candidate_query)

# Bring the candidate rows to the driver so they can be iterated in plain Python.
rows = df_table_candidates.collect()

def _run_partitions_in_parallel(partitions, volume_path, extended_schema, managed_table_name,
                                max_workers, log_progress=False):
    """Run process_partition for each (edp_run_id, snapshot_date) pair on a thread pool.

    Args:
        partitions: List of (edp_run_id, snapshot_date) tuples.
        volume_path: Dataset root passed through to process_partition.
        extended_schema: Read schema passed through to process_partition.
        managed_table_name: Target table passed through to process_partition.
        max_workers: Size of the thread pool.
        log_progress: When True, print a progress line every 50 completions and at the end.

    Returns:
        (succeeded, failed): succeeded holds (edp_run_id, snapshot_date) tuples,
        failed holds (edp_run_id, snapshot_date, error) tuples.
    """
    succeeded, failed = [], []
    total = len(partitions)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                process_partition,
                edp_run_id,
                snapshot_date,
                volume_path,
                extended_schema,
                managed_table_name
            )
            for edp_run_id, snapshot_date in partitions
        ]

        # Handle results in completion order.
        for completed, future in enumerate(as_completed(futures), start=1):
            edp_run_id, snapshot_date, status, error = future.result()

            if status == "SUCCESS":
                succeeded.append((edp_run_id, snapshot_date))
            else:
                failed.append((edp_run_id, snapshot_date, error))

            if log_progress and (completed % 50 == 0 or completed == total):
                print(f"Progress: {completed}/{total} partitions processed")

    return succeeded, failed


# For every candidate dataset: resolve its schema, target table and the partitions
# listed in the partition tracker, append those partitions to the managed Delta
# table, and record the outcome. A failure in one dataset is logged and recorded,
# then the loop moves on to the next dataset.
for row in rows:
    bucket_name = row["s3_bucket_name"]
    bucket_prefix = row["bucket_prefix"]
    dataset_name = row["table_name"]

    try:

        partition_key_combination = "edp_run_id, snapshot_date"
        print(f"Processing {dataset_name}")

        # Latest not-yet-loaded parquet inventory entry with the expected partition key.
        partition_key_df = spark.sql(f"""select partition_key from {catalog}.{schema}.{s3_inventory_table}
                                        where 1 = 1
                                        and extension = 'parquet' 
                                        and s3_bucket_name = '{bucket_name}' 
                                        and bucket_prefix = '{bucket_prefix}' 
                                        and  partition_key = '{partition_key_combination}'
                                        and load_status is null
                                        and last_modified_time = (select max(last_modified_time) from {catalog}.{schema}.{s3_inventory_table} 
                                        where 1 = 1
                                        and extension = 'parquet' 
                                        and s3_bucket_name = '{bucket_name}' 
                                        and bucket_prefix = '{bucket_prefix}'
                                        and partition_key = '{partition_key_combination}'
                                        and load_status is null)
                                        limit 1
                                    """)

        partition_key_row = partition_key_df.first()
        if partition_key_row is None:
            # first() returns None on an empty result; fail with a readable message.
            raise ValueError(
                f"No unloaded parquet inventory entry with partition key "
                f"'{partition_key_combination}' for prefix '{bucket_prefix}'"
            )
        partition_key_str = partition_key_row["partition_key"]

        # Not-yet-loaded parquet keys for this prefix, prefixed with the bucket's volume name.
        inventory_df = spark.sql(f"""SELECT distinct inventory.key, buc_vol.volume_name
                                ,CONCAT(buc_vol.volume_name, '/', inventory.key) AS volume_path_key 
                                FROM 
                                (select distinct s3_bucket_name, bucket_prefix, key 
                                from {catalog}.{schema}.{s3_inventory_table}
                                where 1 = 1
                                and extension = 'parquet' 
                                and partition_key = '{partition_key_str}'
                                and s3_bucket_name = '{bucket_name}' 
                                and bucket_prefix = '{bucket_prefix}'
                                and load_status is null) inventory
                                join 
                                (select s3_bucket_name, volume_name from  {catalog}.{schema}.{s3_bucket_volume_mapping_table}
                                where 1 = 1
                                and s3_bucket_name = '{bucket_name}') buc_vol
                                on inventory.s3_bucket_name = buc_vol.s3_bucket_name
                                """)

        inventory_df = (
            inventory_df
            # Must have edp_run_id=.../snapshot_date=.../ structure
            .filter(col("volume_path_key").rlike(r"/edp_run_id=[^/]+/snapshot_date=[^/]+/"))
            # Exclude files that have multiple snapshot_date entries
            .filter(~col("volume_path_key").rlike(r"snapshot_date=[^/]+/.*snapshot_date="))
        )

        # Reduce the file list to the distinct (edp_run_id, snapshot_date) pairs found in the paths.
        inventory_df = (
            inventory_df
            .withColumn("edp_run_id", regexp_extract(col("volume_path_key"), r"edp_run_id=([^/]+)", 1))
            .withColumn("snapshot_date", regexp_extract(col("volume_path_key"), r"snapshot_date=([^/]+)", 1))
            .select("edp_run_id", "snapshot_date").distinct()
        )

        # Most recently recorded parquet schema for this prefix.
        schema_json_df = spark.sql(f"""select schema_json 
                                    from {catalog}.{schema}.{parquet_schema_table}
                                    where 1 = 1
                                    and s3_bucket_name = '{bucket_name}'
                                    and bucket_prefix = '{bucket_prefix}'
                                    and last_modified_time = (select max(last_modified_time) 
                                                            from {catalog}.{schema}.{parquet_schema_table} 
                                                            where 1 = 1
                                                            and s3_bucket_name = '{bucket_name}'
                                                            and bucket_prefix = '{bucket_prefix}')
                                    limit 1
                                    """)

        schema_json_row = schema_json_df.first()
        if schema_json_row is None:
            raise ValueError(f"No parquet schema recorded for prefix '{bucket_prefix}'")

        # Extract the schema_json value into a Python string variable
        schema_json = schema_json_row["schema_json"]
        schema_dict = json.loads(schema_json)

        # Convert to StructType
        base_schema = StructType([
            StructField(f["name"], parse_type(f["type"]), f["nullable"])
            for f in schema_dict["fields"]
        ])

        partition_columns = [
            partition_column.strip()
            for partition_column in partition_key_str.split(",")
            if partition_column.strip()
        ]

        # Spark types for the known partition columns; anything else is read as a string.
        partition_type_map = {
            "edp_run_id": StringType(),
            "snapshot_date": DateType()
        }

        # Read schema = data columns followed by the (nullable) partition columns.
        extended_fields = base_schema.fields.copy()
        for partition_column in partition_columns:
            col_type = partition_type_map.get(partition_column, StringType())  # default to STRING
            extended_fields.append(StructField(partition_column, col_type, True))

        extended_schema = StructType(extended_fields)

        # Target managed table name and source volume path for this prefix.
        target_df = spark.sql(f"""select 
                            concat(dbx_catalog, '.', dbx_managed_table_schema, '.', dataset_name) as managed_table_name,
                            concat(buc_vol.volume_name, '/', bucket_prefix) as volume_path
                            from {catalog}.{schema}.{dataset_mapping_table} src_trgt
                            inner join {catalog}.{schema}.{s3_bucket_volume_mapping_table} buc_vol
                            on src_trgt.s3_bucket_name = buc_vol.s3_bucket_name
                            where 1 = 1
                            and src_trgt.s3_bucket_name = '{bucket_name}'
                            and src_trgt.bucket_prefix = '{bucket_prefix}'
                            """)
        target_row = target_df.first()

        # Read partition tracker
        partition_df = spark.sql(f"""
                                select run_id, run_tag_value from {partition_table}
                                where dataset_name = '{dataset_name}'
                                and lower(run_tag_key) = 'snapshot_date'
                                """)

        dataset_snapshot_col = "snapshot_date"
        dataset_run_id_col = "edp_run_id"

        tracker_snapshot_col = "run_tag_value"
        tracker_run_id_col = "run_id"

        # Align tracker columns to dataset schema
        partition_df_renamed = (
            partition_df
            .withColumnRenamed(tracker_snapshot_col, dataset_snapshot_col)
            .withColumnRenamed(tracker_run_id_col, dataset_run_id_col)
            .select(dataset_run_id_col, dataset_snapshot_col)
            .distinct()
        )

        # Perform join to get files for blessed partitions only
        print("Fetching files for blessed partitions")
        inventory_filtered_df = (
            inventory_df.alias("a")
            .join(
                partition_df_renamed.alias("b"),
                (col("a.snapshot_date") == col("b.snapshot_date")) &
                (col("a.edp_run_id") == col("b.edp_run_id")),
                how="inner"
            )
            .select("a.edp_run_id", "a.snapshot_date").distinct()
        )

        # Check if tracker has partitions for this dataset
        if partition_df.count() == 0:
            print("No partition information found in tracker — No data to write.")

        else:
            print("Partition information found — reading only matching partitions.")

            # Without a mapping row there is no target table. Fail here instead of
            # hitting a NameError or silently reusing the previous dataset's
            # managed_table_name / volume_path.
            if target_row is None:
                raise ValueError(
                    f"No dataset mapping found for bucket '{bucket_name}' and prefix '{bucket_prefix}'"
                )
            managed_table_name, volume_path = target_row["managed_table_name"], target_row["volume_path"]

            # Collect partition values as tuples (collected once; the count is taken from the list).
            partition_filters = [
                (partition_row.edp_run_id, partition_row.snapshot_date)
                for partition_row in inventory_filtered_df.collect()
            ]
            print(f"Total Partition to Load: {len(partition_filters)}")
            print(f"Creating Table: {managed_table_name}")

            spark.sql(f"""
                CREATE TABLE IF NOT EXISTS {managed_table_name} (
                    edp_run_id string,
                    snapshot_date date
                )
                USING DELTA
                TBLPROPERTIES (
                    'delta.columnMapping.mode' = 'name',
                    'delta.enableIcebergCompatV2' = 'true',
                    'delta.universalFormat.enabledFormats' = 'iceberg'
                )
                cluster by (edp_run_id, snapshot_date)
            """)

            # ============================================================
            # PARALLELIZED PARTITION PROCESSING WITH 40 THREADS
            # ============================================================

            max_workers = 40  # 40 threads for parallel processing

            print(f"Starting parallel processing of {len(partition_filters)} partitions with {max_workers} threads")

            successful_partitions, failed_partitions = _run_partitions_in_parallel(
                partition_filters,
                volume_path,
                extended_schema,
                managed_table_name,
                max_workers,
                log_progress=True,
            )

            # Retry failed partitions if any (one extra pass over the failures only)
            if failed_partitions:
                print(f"\nRetrying {len(failed_partitions)} failed partitions...")
                retry_filters = [(edp_run_id, snapshot_date) for edp_run_id, snapshot_date, _ in failed_partitions]

                retried_ok, failed_partitions = _run_partitions_in_parallel(
                    retry_filters,
                    volume_path,
                    extended_schema,
                    managed_table_name,
                    max_workers,
                )
                successful_partitions.extend(retried_ok)

            # Summary
            print("\n" + "="*80)
            print(f"Processing Complete for {dataset_name}")
            print(f"Total partitions: {len(partition_filters)}")
            print(f"Successful: {len(successful_partitions)}")
            print(f"Failed: {len(failed_partitions)}")
            print("="*80)

            if failed_partitions:
                print("\nFailed partitions:")
                for edp_run_id, snapshot_date, error in failed_partitions:
                    print(f"  - edp_run_id={edp_run_id}, snapshot_date={snapshot_date}: {error}")

                # Raise BEFORE touching the status tables: the dataset must stay
                # "not created" so the handler below can record the error (its
                # filter requires managed_table_created to be null) and a later
                # run can pick the dataset up again.
                # NOTE(review): process_partition appends, so a later run will
                # re-append partitions that already succeeded — confirm whether
                # de-duplication is needed for reruns.
                raise Exception(f"Failed to process {len(failed_partitions)} partitions. Check logs for details.")

            # Every partition was written: mark the dataset and its inventory as done.
            spark.sql(f"""update {catalog}.{schema}.{candidate_table}
                    set managed_table_created = 'true'
                    where 1 = 1
                    and managed_table_created is null
                    and s3_bucket_name = '{bucket_name}'
                    and bucket_prefix = '{bucket_prefix}'""")

            spark.sql(f"""update {catalog}.{schema}.{s3_inventory_table}
                    set load_status = 'loaded'
                    where 1 = 1
                    and load_status is null
                    and extension = 'parquet'
                    and s3_bucket_name = '{bucket_name}'
                    and bucket_prefix = '{bucket_prefix}'""")

            print("All partitions processed successfully!")

    except Exception as e:
        # Log the failure with its traceback, flag the candidate row, then move on
        # to the next dataset.
        error_message = str(e)
        stack = traceback.format_exc()
        print(f"Failed to process {dataset_name}: {error_message}")
        print(stack)
        spark.sql(f"""update {catalog}.{schema}.{candidate_table}
                    set error_message = 'Failed to Process'
                    where 1 = 1
                    and managed_table_created is null
                    and s3_bucket_name = '{bucket_name}'
                    and bucket_prefix = '{bucket_prefix}'""")