# CCBR Migration get_schema_from_parquet_files

Purpose:
This notebook reads the inventory table which contains the S3 buckets and file paths. 

For performance reasons, the notebook uses pyarrow in conjunction with Databricks volumes so we use a 
static map table for the conversion of S3 buckets to DB Volumes

The number of bucket_prefixes would be passed in as a comma-separated string and each prefix is processed.

The process will read the schemas for all the files, determine if the file is a valid parquet file
and also detect if the schema has changed from its baseline using a hash. The baseline is the oldest file.

The schema is stored in the schema table as a json string for detailed comparison as needed. 
 
Once complete, the bucket prefix in the candidate table will be updated with true (ready for migration) or false, which means
that the schema has changed from the baseline, which could cause a table migration to fail.

The compare_schema notebook can compare the schemas within a prefix and provide a detailed view of how the schema has changed over time to help make a decision about how to move forward.

Once all of the prefix parameters are processed, the job will exit.


In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    udf, col, lit, current_timestamp, hash as spark_hash,
    sum as spark_sum, count
)
from pyspark.sql.types import (
    StructType, StructField, StringType, BooleanType, TimestampType
)
import json
import re

In [0]:
# Declare the notebook's text widgets. Each entry is (widget name, label);
# every widget is created with an empty-string default, in the order listed.
for _widget_name, _widget_label in (
    ("catalog", "Catalog Name"),
    ("schema", "Schema Name"),
    ("s3_inventory_table", "S3 Inventory Table"),
    ("parquet_schema_table", "Parquet Schema Table"),
    ("candidate_table", "Candidate Table"),
    ("bucket_name", "Bucket Name"),
    ("file_counts", "File Counts"),
    ("partition_table", "Partition Table"),
):
    dbutils.widgets.text(_widget_name, "", _widget_label)

In [0]:
# Pull each job parameter out of its widget into a module-level name.
# Note that the "schema" widget is exposed as `app_schema`.
_widget_value = dbutils.widgets.get

catalog = _widget_value("catalog")
app_schema = _widget_value("schema")
parquet_schema_table = _widget_value("parquet_schema_table")
s3_inventory_table = _widget_value("s3_inventory_table")
candidate_table = _widget_value("candidate_table")
bucket_name = _widget_value("bucket_name")
file_counts = _widget_value("file_counts")
partition_table = _widget_value("partition_table")

In [0]:
# This function will be called for each prefix to be processed
# This notebook will collect the schemas based on the DB Volume

def getSql(pPrefix):
    """Build the SQL that lists the distinct file locations to inspect for one prefix.

    The query returns distinct (file_path, s3_bucket_name, bucket_prefix) rows where
    file_path is 's3://<bucket>/<key with its last '/segment' removed>'. Rows come
    from the inventory table for `pPrefix` (extension set, load_status null,
    partition_key equal to 'edp_run_id, snapshot_date' ignoring case), joined to:
      * the candidate table rows of the same prefix whose
        candidate_for_managed_table_creation is still null, and
      * the partition table rows tagged 'snapshot_date', matched on the last path
        segment of the prefix (dataset name), the run id and the snapshot date.

    Table names are read from the module-level widget values (catalog, app_schema,
    s3_inventory_table, candidate_table, partition_table).

    Args:
        pPrefix: bucket prefix to filter on; interpolated verbatim into the SQL.

    Returns:
        The SQL statement as a string.
    """
    expected_partition_keys = "edp_run_id, snapshot_date"
    inventory_fqn = f"{catalog}.{app_schema}.{s3_inventory_table}"
    candidate_fqn = f"{catalog}.{app_schema}.{candidate_table}"

    return f"""
                    SELECT distinct
                         inv.file_path,
                         inv.s3_bucket_name,
                         inv.bucket_prefix
                    FROM 
                        (select distinct 
                            CONCAT('s3://', s3_bucket_name, '/', left(key, LENGTH(key) - POSITION('/' IN REVERSE(key)))) as file_path
                            ,edp_run_id, snapshot_date, s3_bucket_name, bucket_prefix
                        from {inventory_fqn}
                        where 
                            bucket_prefix = '{pPrefix}'
                            and extension is not null
                            and load_status is null
                            and lower(partition_key) = '{expected_partition_keys}'
                        ) inv
                    JOIN 
                        (select 
                            s3_bucket_name, bucket_prefix
                        from {candidate_fqn}
                        where 
                            bucket_prefix = '{pPrefix}' 
                            and candidate_for_managed_table_creation is null
                        ) c 
                    ON  inv.s3_bucket_name = c.s3_bucket_name 
                        and inv.bucket_prefix = c.bucket_prefix
                    JOIN
                        (
                            select distinct dataset_name, run_id, try_cast(run_tag_value as date) as run_tag_value 
                            from {partition_table}
                            where lower(run_tag_key) = 'snapshot_date'
                        ) partition
                    on regexp_extract(inv.bucket_prefix, '([^/]+)/?$', 1) = partition.dataset_name
                    and inv.edp_run_id = partition.run_id
                    and try_cast(inv.snapshot_date as date) = partition.run_tag_value
                """

In [0]:
def getPrefixes():
    """Build the SQL that selects the (execution_id, bucket_prefix) pairs still to analyse.

    Rows are taken from the candidate table for the configured `bucket_name` where
    both managed_table_created and candidate_for_managed_table_creation are null and
    structured_file_count falls in the range given by `file_counts`. The
    `file_counts` value is placed directly after BETWEEN, so it has to be a
    '<low> and <high>' expression.

    Returns:
        The SQL statement as a string.
    """
    candidate_fqn = f"{catalog}.{app_schema}.{candidate_table}"

    return f"""
                    SELECT 
                         distinct c.execution_id, c.bucket_prefix
                     FROM 
                        {candidate_fqn} c 
                     WHERE 
                         c.managed_table_created is null           AND
                         c.candidate_for_managed_table_creation is null AND
                         c.s3_bucket_name = '{bucket_name}' AND
                         c.structured_file_count between {file_counts}
                    """

In [0]:
# --------------------------------------------------------------
# 0. Loop-invariant helpers and constants
# --------------------------------------------------------------
def partition_structure(fp):
    """Return the partition keys found in path *fp*, sorted and joined by '/'.

    A partition key is the part before '=' in a 'key=value' path segment.
    Returns "" when the path has no such segment.
    """
    parts = re.findall(r'([^/]+)=[^/]+', fp)
    return '/'.join(sorted(parts)) if parts else ""


def _schema_to_canonical_json(schema):
    """Serialise a Spark schema to a JSON string that is stable across files.

    Field names are lower-cased, fields are sorted by name and metadata is
    blanked, so two files only differ when a name, type or nullability differs.
    """
    fields = sorted(
        (
            {
                "name": f.name.lower(),
                "type": str(f.dataType),
                "nullable": f.nullable,
                "metadata": {},
            }
            for f in schema.fields
        ),
        key=lambda x: x["name"],
    )
    return json.dumps({"type": "struct", "fields": fields}, sort_keys=True)


part_udf = udf(partition_structure, StringType())

# Layout of the rows appended to the parquet schema table (before run_timestamp).
schema_record_type = StructType([
    StructField("file_path", StringType(), False),
    StructField("schema_json", StringType(), True),
    StructField("is_corrupted", BooleanType(), False),
    StructField("error_message", StringType(), True),
    StructField("bucket_prefix", StringType(), False),
    StructField("execution_id", StringType(), False),
    StructField("s3_bucket_name", StringType(), False),
])

schema_table_fqn = f"{catalog}.{app_schema}.{parquet_schema_table}"
candidate_table_fqn = f"{catalog}.{app_schema}.{candidate_table}"

# --------------------------------------------------------------
# 1. Get prefixes
# --------------------------------------------------------------
dfBucketPrefixes = spark.sql(getPrefixes())
bucket_prefixes = [(row['execution_id'], row['bucket_prefix']) for row in dfBucketPrefixes.collect()]

# --------------------------------------------------------------
# 2. Main loop
# --------------------------------------------------------------
for execution_id, prefix in bucket_prefixes:
    print(f"\n=== Processing prefix: {prefix} ===")

    # ---- 2.1 Get file list and COLLECT to driver --------------------------
    # The query is executed once; the count is taken from the collected rows
    # instead of running a separate count() job over the same query.
    try:
        file_df = spark.sql(getSql(prefix))
        file_paths = [
            (row.file_path, row.s3_bucket_name, row.bucket_prefix)
            for row in file_df.select("file_path", "s3_bucket_name", "bucket_prefix").collect()
        ]
    except Exception as e:
        print(f"Error querying file list: {e}")
        raise

    file_count = len(file_paths)
    print(f" file paths retrieved for {prefix} : {file_count}")

    if not file_paths:
        print(" No files found. Skipping.")
        continue

    print(f" Collected {len(file_paths)} file paths to driver...")

    # ---- 2.2 Extract schema ON DRIVER -------------------------------------
    # A path whose schema cannot be read is recorded as corrupted (with the
    # first line of the error) rather than aborting the whole prefix.
    schema_records = []
    for file_path, bucket, _bucket_prefix in file_paths:
        print(f"Reading schema: {file_path}")
        try:
            schema_json = _schema_to_canonical_json(spark.read.parquet(file_path).schema)
            is_corrupted = False
            error_msg = None
        except Exception as e:
            schema_json = "corrupted"
            is_corrupted = True
            error_msg = str(e).split("\n")[0]

        schema_records.append((
            file_path, schema_json, is_corrupted, error_msg,
            prefix, execution_id, bucket
        ))

    # ---- 2.3 Create DF ----------------------------------------------------
    schema_df = spark.createDataFrame(
        schema_records, schema_record_type
    ).withColumn("run_timestamp", current_timestamp())

    # ---- 2.4 Write --------------------------------------------------------
    try:
        schema_df.write.mode("append").saveAsTable(schema_table_fqn)
        print(f"Schemas written")
    except Exception as e:
        print(f"Delta write error: {e}")
        raise

    # ---- 2.5 Baseline -----------------------------------------------------
    # One query returns both the baseline schema and its path, so the two
    # values always come from the same row. The result is empty when every
    # file of this prefix/execution is corrupted.
    baseline_rows = spark.sql(f"""
        SELECT schema_json AS baseline_schema, file_path AS baseline_filepath
        FROM {schema_table_fqn}
        WHERE bucket_prefix = '{prefix}' AND execution_id = '{execution_id}' AND is_corrupted = false
        ORDER BY lower(file_path)
        LIMIT 1
    """).collect()

    # ---- 2.6 Counts (records are already on the driver) -------------------
    corrupted = sum(1 for record in schema_records if record[2])
    total_good = len(schema_records) - corrupted

    # ---- 2.7 Consistency --------------------------------------------------
    if baseline_rows and total_good > 0:
        baseline_schema = baseline_rows[0]["baseline_schema"]
        baseline_filepath = baseline_rows[0]["baseline_filepath"]
        baseline_part = partition_structure(baseline_filepath)

        analysis_df = schema_df \
            .withColumn("partition_structure", part_udf(col("file_path"))) \
            .filter((~col("is_corrupted")) & (col("execution_id") == lit(execution_id)))

        opt = analysis_df \
            .withColumn("schema_hash", spark_hash(col("schema_json"))) \
            .withColumn("base_hash", spark_hash(lit(baseline_schema))) \
            .withColumn("schema_match", col("schema_hash") == col("base_hash")) \
            .withColumn("part_match", col("partition_structure") == lit(baseline_part))

        stats = opt.agg(
            spark_sum(col("schema_match").cast("int")).alias("match_schema"),
            spark_sum(col("part_match").cast("int")).alias("match_part"),
            count("*").alias("total")
        ).collect()[0]

        schema_ok = stats["match_schema"] == total_good
        part_ok   = stats["match_part"]   == total_good
    else:
        # Nothing readable to compare against: report the prefix as
        # inconsistent instead of failing on an empty baseline lookup.
        print(" No non-corrupted files available to establish a baseline.")
        schema_ok = False
        part_ok = False

    # A prefix is a candidate only when no file is corrupted and both the
    # schema and the partition layout match the baseline for every file.
    candidate = corrupted == 0 and schema_ok and part_ok

    # ---- 2.8 Summary ------------------------------------------------------
    result = {
        "total_non_corrupted_files": total_good,
        "corrupted_files": corrupted,
        "schema_consistent": schema_ok,
        "partition_consistent": part_ok,
    }
    result_json = json.dumps(result)

    print(f"Schema consistency: good={total_good}, corrupted={corrupted}, candidate={candidate}")

    # ---- 2.9 MERGE --------------------------------------------------------
    exists = spark.sql(f"""
        SELECT 1 FROM {candidate_table_fqn}
        WHERE bucket_prefix = '{prefix}' AND execution_id = '{execution_id}'
    """).limit(1).collect()

    if not exists:
        print("SKIPPED  no row in candidate table")
        print("-" * 80)
        continue

    try:
        spark.sql(f"""
            MERGE INTO {candidate_table_fqn} AS tgt
            USING (
                SELECT
                    '{prefix}'      AS bucket_prefix,
                    {str(candidate).lower()}::boolean AS candidate_for_managed_table_creation,
                    '{execution_id}' AS execution_id,
                    '{result_json}'  AS schema_analysis_results
            ) AS src
            ON tgt.bucket_prefix = src.bucket_prefix
           AND tgt.execution_id   = src.execution_id
            WHEN MATCHED THEN UPDATE SET
                candidate_for_managed_table_creation = src.candidate_for_managed_table_creation,
                schema_analysis_results               = src.schema_analysis_results
        """)

    except Exception as e:
        print(f"MERGE error: {e}")
        raise

    print(f"Processed {len(file_paths)} files")
    print("-" * 80)

print("\n=== Completed ===")