In [0]:
import json
import logging
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import reduce
from threading import Lock

from delta.tables import DeltaTable
from pyspark.sql.functions import col, collect_list, lit, regexp_extract
from pyspark.sql.types import (
    BinaryType,
    BooleanType,
    ByteType,
    DateType,
    DecimalType,
    DoubleType,
    FloatType,
    IntegerType,
    LongType,
    ShortType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)

In [0]:
# Job parameters: each is a free-text widget with an empty default.
# Notes on how later cells consume them:
#   * catalog / schema qualify most control tables as {catalog}.{schema}.{table}.
#   * partition_table is used as-is in SQL (not prefixed with catalog/schema),
#     so it must be a fully qualified table name.
#   * file_counts is spliced verbatim into `structured_file_count between {file_counts}`,
#     so it must be a SQL range such as `1 and 100`.
dbutils.widgets.text("catalog", "", "Catalog")
dbutils.widgets.text("schema", "", "Schema")
dbutils.widgets.text("bucket_name", "", "Bucket Name")
dbutils.widgets.text("candidate_table", "", "Candidate Table")
dbutils.widgets.text("parquet_schema_table", "", "Parquet Schema Table")
dbutils.widgets.text("s3_inventory_table", "", "S3 Inventory Table")
dbutils.widgets.text("dataset_mapping_table", "", "Dataset Mapping Table")
dbutils.widgets.text("partition_table", "", "Partition Table")
dbutils.widgets.text("file_counts", "", "File Counts")

In [0]:
# Pull the current value of every job-parameter widget, in declaration order,
# into module-level names used by the cells below.
(
    catalog,
    schema,
    bucket_name,
    candidate_table,
    parquet_schema_table,
    s3_inventory_table,
    dataset_mapping_table,
    partition_table,
    file_counts,
) = (
    dbutils.widgets.get(widget_name)
    for widget_name in (
        "catalog",
        "schema",
        "bucket_name",
        "candidate_table",
        "parquet_schema_table",
        "s3_inventory_table",
        "dataset_mapping_table",
        "partition_table",
        "file_counts",
    )
)

In [0]:
# Make unreadable parquet files raise instead of being skipped: a skipped file
# would let the partition write succeed and be marked "loaded" with missing rows.
spark.conf.set(key="spark.sql.files.ignoreCorruptFiles", value="false")

In [0]:
# Function to convert type strings to Spark types
def parse_type(dtype):
    """Map a column type string from the stored parquet schema JSON to a Spark type.

    Accepted spellings (case-insensitive, surrounding whitespace ignored):
      * pyarrow names, e.g. ``int32``, ``double``, ``large_string``,
        ``date32[day]``, ``timestamp[us, tz=UTC]``, ``decimal128(10, 2)``;
      * Spark SQL / JSON names, e.g. ``int``, ``bigint``, ``boolean``,
        ``decimal(10,2)``;
      * Spark type reprs, e.g. ``IntegerType()``, ``DecimalType(10,2)``.

    Args:
        dtype: the type string to convert.

    Returns:
        A new ``pyspark.sql.types`` DataType instance.

    Raises:
        ValueError: if the string is not a recognised type, including a
            ``decimal`` type without a parseable ``(precision[, scale])``.
    """
    dtype = dtype.lower().strip()

    # Decimals carry parameters: "decimal(p,s)", "decimal128(p, s)", "decimaltype(p,s)".
    if dtype.startswith("decimal"):
        open_idx = dtype.find("(")
        close_idx = dtype.find(")", open_idx + 1)
        if open_idx == -1 or close_idx == -1:
            raise ValueError(f"Unsupported type: {dtype}")
        params = [p.strip() for p in dtype[open_idx + 1:close_idx].split(",")]
        try:
            precision = int(params[0])
            # SQL semantics: DECIMAL(p) is DECIMAL(p, 0).
            scale = int(params[1]) if len(params) > 1 else 0
        except ValueError as exc:
            raise ValueError(f"Unsupported type: {dtype}") from exc
        return DecimalType(precision, scale)

    # Whole-string matches. Keys must be compared with ==; a substring test
    # (e.g. `dtype in ("string")`, which is a str, not a tuple) would accept
    # fragments such as "" or "ring" as StringType.
    exact_names = {
        "int8": ByteType, "byte": ByteType, "tinyint": ByteType,
        "int16": ShortType, "smallint": ShortType, "short": ShortType,
        "int32": IntegerType, "integer": IntegerType, "int": IntegerType,
        "int64": LongType, "bigint": LongType, "long": LongType,
        "float": FloatType, "float32": FloatType,
        "double": DoubleType, "float64": DoubleType,
        "string": StringType, "utf8": StringType,
        "large_string": StringType, "large_utf8": StringType,
        "date": DateType,
        "bool": BooleanType, "boolean": BooleanType,
        "binary": BinaryType, "large_binary": BinaryType,
    }
    spark_type = exact_names.get(dtype)
    if spark_type is not None:
        return spark_type()

    # Prefix matches: Spark reprs ("integertype()") and parameterised pyarrow
    # names ("date32[day]", "timestamp[ms]", "timestamp[us, tz=utc]").
    prefixes = (
        ("bytetype", ByteType),
        ("shorttype", ShortType),
        ("integertype", IntegerType),
        ("longtype", LongType),
        ("floattype", FloatType),
        ("doubletype", DoubleType),
        ("stringtype", StringType),
        ("date32", DateType),
        ("datetype", DateType),
        ("timestamp", TimestampType),
        ("booleantype", BooleanType),
        ("binarytype", BinaryType),
    )
    for prefix, spark_type in prefixes:
        if dtype.startswith(prefix):
            return spark_type()

    raise ValueError(f"Unsupported type: {dtype}")

In [0]:
# Serialises the Delta appends issued by the worker threads: only one partition
# write runs at a time, so the thread pool mainly overlaps the read setup.
# NOTE(review): presumably chosen to avoid concurrent-commit conflicts when
# mergeSchema evolves the table schema — confirm before removing.
write_lock = Lock()

def process_partition(edp_run_id, snapshot_date, bucket_name, bucket_prefix, s3_path,
                      extended_schema, managed_table_name, max_retries=3,
                      retry_backoff_seconds=1.0):
    """Load one edp_run_id/snapshot_date partition into a Delta table, with retries.

    Reads ``{s3_path}edp_run_id={edp_run_id}/snapshot_date={snapshot_date}`` as
    parquet with ``extended_schema`` (``basePath`` = ``s3_path`` so the partition
    columns come from the directory names) and appends it to ``managed_table_name``
    with ``mergeSchema`` enabled, holding ``write_lock`` during the write.

    Args:
        edp_run_id, snapshot_date: partition values; used to build the path and
            echoed back in the result tuple.
        bucket_name, bucket_prefix: echoed back in the result tuple.
        s3_path: dataset root, expected to end with "/".
        extended_schema: read schema (file columns plus partition columns).
        managed_table_name: fully qualified target table.
        max_retries: maximum number of attempts; values below 1 are treated as 1.
        retry_backoff_seconds: sleep before the 2nd attempt; doubles for each
            later attempt.

    Returns:
        ``(bucket_name, bucket_prefix, edp_run_id, snapshot_date, status)`` where
        status is ``"loaded"`` or ``"failed"``. Never raises for load errors.
    """
    path = f"{s3_path}edp_run_id={edp_run_id}/snapshot_date={snapshot_date}"
    attempts = max(1, max_retries)

    for attempt in range(1, attempts + 1):
        try:
            print(f"Reading partition path: {path}")

            filtered_df = (
                spark.read
                .schema(extended_schema)
                .option("basePath", s3_path)
                .parquet(path)
            )

            # A Delta append commits atomically, so a failed attempt leaves no
            # rows behind and is safe to retry.
            with write_lock:
                (
                    filtered_df.write
                    .format("delta")
                    .mode("append")
                    .option("mergeSchema", "true")
                    .saveAsTable(managed_table_name)
                )

            print(f"Successfully processed for {path}")
            return (bucket_name, bucket_prefix, edp_run_id, snapshot_date, "loaded")

        except Exception as e:
            print(f"Failed to process {path} (attempt {attempt}/{attempts}): {e}")
            if attempt < attempts:
                # Back off outside the lock so other threads can keep writing.
                time.sleep(retry_backoff_seconds * 2 ** (attempt - 1))

    return (bucket_name, bucket_prefix, edp_run_id, snapshot_date, "failed")

In [0]:
# Datasets in this bucket that are still waiting for a managed table.
# NOTE(review): all SQL here is string-built from widget/control-table values;
# a single quote in any of them breaks the statement.
df_table_candidates = spark.sql(f"""
                                select distinct s3_bucket_name, bucket_prefix, table_name
                                from  {catalog}.{schema}.{candidate_table}
                                where s3_bucket_name = '{bucket_name}' 
                                and candidate_for_managed_table_creation in ('true', 'false') 
                                and managed_table_created is null
                                and structured_file_count between {file_counts}
                                """)

# Collect all rows into Python memory
rows = df_table_candidates.collect()

inventory_table_name = f"{catalog}.{schema}.{s3_inventory_table}"
inventory_table = DeltaTable.forName(spark, inventory_table_name)

# ---- Loop-invariant configuration (identical for every dataset) ----
load_status_schema = StructType([
    StructField("bucket_name", StringType(), True),
    StructField("bucket_prefix", StringType(), True),
    StructField("edp_run_id", StringType(), True),
    StructField("snapshot_date", DateType(), True),
    StructField("load_status", StringType(), True)
])

partition_key_combination = "edp_run_id, snapshot_date"
partition_columns = [c.strip() for c in partition_key_combination.split(",") if c.strip()]

# Spark types for the partition columns appended to the file schema.
partition_type_map = {
    "edp_run_id": StringType(),
    "snapshot_date": DateType()
}

max_workers = 40  # threads for parallel partition processing

for row in rows:
    bucket_name = row["s3_bucket_name"]
    bucket_prefix = row["bucket_prefix"]
    dataset_name = row["table_name"]

    try:
        print(f"Processing {dataset_name}")

        # Not-yet-loaded partitions that exist both in the S3 inventory and in
        # the run registry (partition_table) for this dataset.
        inventory_df = spark.sql(f"""SELECT distinct edp_run_id, snapshot_date
                                FROM 
                                (
                                    select distinct edp_run_id, try_cast(snapshot_date as date) as snapshot_date
                                    from {catalog}.{schema}.{s3_inventory_table}
                                    where 1 = 1
                                    and extension is not null 
                                    and lower(partition_key) = '{partition_key_combination}'
                                    and s3_bucket_name = '{bucket_name}' 
                                    and bucket_prefix = '{bucket_prefix}'
                                    and load_status is null
                                ) inventory
                                join
                                (
                                    select distinct run_id, try_cast(run_tag_value as date) as run_tag_value 
                                    from {partition_table}
                                    where dataset_name = '{dataset_name}'
                                    and lower(run_tag_key) = 'snapshot_date'
                                ) partition
                                on inventory.edp_run_id = partition.run_id
                                and inventory.snapshot_date = partition.run_tag_value
                                """)

        # Evaluate the join once and reuse it for every check below.
        partition_filters = [(r.edp_run_id, r.snapshot_date) for r in inventory_df.collect()]

        # Check for work before looking up the latest snapshot, otherwise an
        # empty inventory would surface as an IndexError and be recorded as a failure.
        if not partition_filters:
            print("No files found — No data to write.")
            continue

        # The newest snapshot's stored parquet schema is the read schema for all partitions.
        # (The join condition guarantees snapshot_date is non-null here.)
        latest_edp_run_id, latest_snapshot_date = max(partition_filters, key=lambda p: p[1])

        schema_row = spark.sql(f"""select schema_json 
                                    from {catalog}.{schema}.{parquet_schema_table}
                                    where 1 = 1
                                    and s3_bucket_name = '{bucket_name}'
                                    and bucket_prefix = '{bucket_prefix}'
                                    and file_path like '%edp_run_id={latest_edp_run_id}/snapshot_date={latest_snapshot_date}%'
                                    """).first()
        if schema_row is None:
            raise ValueError(
                f"No parquet schema found for edp_run_id={latest_edp_run_id}/snapshot_date={latest_snapshot_date}"
            )

        schema_dict = json.loads(schema_row["schema_json"])
        base_schema = StructType([
            StructField(f["name"], parse_type(f["type"]), f["nullable"])
            for f in schema_dict["fields"]
        ])

        # File columns followed by the partition columns (unknown ones default to STRING).
        extended_schema = StructType(
            base_schema.fields
            + [StructField(c, partition_type_map.get(c, StringType()), True) for c in partition_columns]
        )

        target_row = spark.sql(f"""select 
                            concat(dbx_catalog, '.', dbx_managed_table_schema, '.', dataset_name) as managed_table_name,
                            concat('s3://', s3_bucket_name, '/', bucket_prefix) as s3_path
                            from {catalog}.{schema}.{dataset_mapping_table}
                            where 1 = 1
                            and s3_bucket_name = '{bucket_name}'
                            and bucket_prefix = '{bucket_prefix}'
                            """).first()

        # Without this guard the previous dataset's table name and path would be
        # reused and its data appended to the wrong table.
        if target_row is None or not target_row["managed_table_name"] or not target_row["s3_path"]:
            raise ValueError(f"No dataset mapping found for s3://{bucket_name}/{bucket_prefix}")

        managed_table_name = target_row["managed_table_name"]
        # NOTE(review): partition paths are built as f"{s3_path}edp_run_id=...",
        # so bucket_prefix is assumed to end with "/" — confirm.
        s3_path = target_row["s3_path"]

        total_partitions = len(partition_filters)
        print("Files identified — Starting the data load.")
        print(f"Total Partition to Load: {total_partitions}")
        print(f"Creating Table: {managed_table_name}")

        spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {managed_table_name} (
                edp_run_id string,
                snapshot_date date
            )
            USING DELTA
            TBLPROPERTIES (
                'delta.columnMapping.mode' = 'name',
                'delta.enableIcebergCompatV2' = 'true',
                'delta.universalFormat.enabledFormats' = 'iceberg'
            )
            cluster by (edp_run_id, snapshot_date)
        """)

        # ============================================================
        # PARALLELIZED PARTITION PROCESSING
        # ============================================================
        successful_partitions = []
        failed_partitions = []

        print(f"Starting parallel processing of {total_partitions} partitions with {max_workers} threads")

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    process_partition,
                    edp_run_id,
                    snapshot_date,
                    bucket_name,
                    bucket_prefix,
                    s3_path,
                    extended_schema,
                    managed_table_name
                )
                for edp_run_id, snapshot_date in partition_filters
            ]

            for completed, future in enumerate(as_completed(futures), start=1):
                # (bucket_name, bucket_prefix, edp_run_id, snapshot_date, load_status);
                # kept as a tuple so the outer bucket_name/bucket_prefix are not rebound.
                status_row = future.result()
                if status_row[4] == "loaded":
                    successful_partitions.append(status_row)
                else:
                    failed_partitions.append(status_row)

                # Progress update every 50 partitions
                if completed % 50 == 0 or completed == total_partitions:
                    print(f"Progress: {completed}/{total_partitions} partitions processed")

        load_status_df = spark.createDataFrame(successful_partitions + failed_partitions, load_status_schema)

        # Record each partition's outcome in the inventory (only rows not yet stamped).
        inventory_table.alias("target").merge(
            load_status_df.alias("updates"),
            """
            target.s3_bucket_name = updates.bucket_name AND
            target.bucket_prefix = updates.bucket_prefix AND
            target.edp_run_id = updates.edp_run_id AND
            target.snapshot_date = updates.snapshot_date
            """
        ).whenMatchedUpdate(
            condition="target.load_status is null",
            set={"load_status": col("updates.load_status")}
        ).execute()

        # Summary
        print("\n" + "="*80)
        print(f"Processing Complete for {dataset_name}")
        print(f"Total partitions: {total_partitions}")
        print(f"Successful: {len(successful_partitions)}")
        print(f"Failed: {len(failed_partitions)}")
        print("="*80)

        spark.sql(f"""update {catalog}.{schema}.{candidate_table}
                set managed_table_created = 'true'
                where 1 = 1
                and managed_table_created is null
                and s3_bucket_name = '{bucket_name}'
                and bucket_prefix = '{bucket_prefix}'""")

        print("Migration Completed")

    except Exception as e:
        error_message = str(e)
        stack = traceback.format_exc()
        print(f"Failed to process {dataset_name}: {error_message}")
        print(stack)
        spark.sql(f"""update {catalog}.{schema}.{candidate_table}
                    set error_message = 'Failed to Process'
                    where 1 = 1
                    and managed_table_created is null
                    and s3_bucket_name = '{bucket_name}'
                    and bucket_prefix = '{bucket_prefix}'""")