# De-corrupt these tables!!
- There are bad data files (either missing in S3 or with odd permissions) referenced in Iceberg table metadata.
- This script identifies the bad files across an entire warehouse and removes them with Iceberg deletes.
- These deletes can be rolled back if necessary, just like any other Iceberg operation.


The hope is that with these corrupted data files gone, we can proceed with deduplication of tables to fully fix them forever.

## Scope
Fixing includes 3 steps. 

Step 1 is to decorrupt the catalogs:
- my_warehouse -- ✅ fixed, had many corruptions 

Step 2 is to deduplicate the tables in the catalogs.
- my_warehouse -- ✅ done!

Step 3 is to promote the deduplicated tables.
- my_warehouse -- ✅ done!

## First, let's configure spark a little better for this EMR instance.


In [None]:
%%configure -f
{
  "conf": {
    "spark.app.name": "Tabular - Iceberg Table Decorruption",
    
    "spark.master": "yarn",
    "spark.executor.memory": "28g",
    "spark.executor.cores": "4",
    "spark.executor.instances": "24",
    "spark.driver.memory": "16g",
    "spark.driver.cores": "4",
    "spark.dynamicAllocation.enabled": "true",
      
    "spark.dynamicAllocation.minExecutors": "1",    
    
    "spark.sql.catalog.my_warehouse": "org.apache.iceberg.spark.SparkCatalog",
    "spark.sql.catalog.my_warehouse.catalog-impl": "org.apache.iceberg.rest.RESTCatalog",
    "spark.sql.catalog.my_warehouse.credential": "t-asdf:asdf-asdf",
    "spark.sql.catalog.my_warehouse.region": "us-east-1",
    "spark.sql.catalog.my_warehouse.uri": "https://api.tabular.io/ws/",
    "spark.sql.catalog.my_warehouse.warehouse": "my_warehouse",
      
    "spark.sql.defaultCatalog": "my_warehouse"
  }
}

## Next, let's create an iceberg database to contain our work and hold logs


In [None]:
// Create the database that holds this notebook's bookkeeping tables (no-op if it already exists).
spark.sql("create database if not exists tabular_support;")
// Create the log table with a CTAS statement: the SELECT only supplies the column names and
// types (event_ts, batch_id, event_type, event_message); `limit 0` means no row is inserted.
spark.sql("""
    create table if not exists tabular_support.logs as (
        select
            current_timestamp() as event_ts,
            'id' as batch_id,
            'log' as event_type,
            'hello' as event_message
        limit 0
    );
""")

## Next, time for some helper functions


In [None]:
import org.apache.iceberg.spark.Spark3Util
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SaveMode

import scala.util.Random

/**
  * Builds a random batch identifier made of uppercase ASCII letters.
  *
  * @param length number of characters to generate (defaults to 6)
  * @return a string of `length` letters, each picked with `Random.nextInt` from A-Z
  */
def getRandomBatchId(length: Int = 6): String = {
  val alphabet = ('A' to 'Z').mkString
  Seq.fill(length)(alphabet(Random.nextInt(alphabet.length))).mkString
}

/**
  * Prints a log line to the notebook output and appends it as a row to `tabular_support.logs`.
  *
  * Every value is escaped before it is embedded in the INSERT statement, so messages that
  * contain single quotes or backslashes (exception messages frequently do) are stored
  * verbatim instead of breaking the statement.
  *
  * @param batch_id batch identifier stored in the `batch_id` column
  * @param msg_type event type stored in the `event_type` column
  * @param msg      message text stored in the `event_message` column
  */
def log_to_tabular(batch_id: String, msg_type: String, msg: String): Unit = {
    // Escapes a value for use inside a single-quoted Spark SQL string literal. This relies on
    // Spark's default backslash escaping (spark.sql.parser.escapedStringLiterals = false).
    // String.valueOf keeps the previous rendering of a null value as the text "null".
    def sqlLiteral(value: String): String =
      String.valueOf(value).replace("\\", "\\\\").replace("'", "\\'")

    println(s"$batch_id -- $msg")
    spark.sql(s"""
      INSERT INTO tabular_support.logs VALUES 
          (current_timestamp(), '${sqlLiteral(batch_id)}', '${sqlLiteral(msg_type)}', '${sqlLiteral(msg)}')
    """)
}

/**
  * Counts the rows of a table's `files` metadata table.
  *
  * @param spark           Spark session used to read the metadata table
  * @param tableIdentifier table identifier; `.files` is appended to it
  * @return number of rows in `<tableIdentifier>.files`
  */
def getFileCount(spark: SparkSession, tableIdentifier: String): Long = {
  val filesMetadataTable = s"$tableIdentifier.files"
  spark.table(filesMetadataTable).count()
}

/**
  * Lists the tables of every database returned by `spark.catalog.listDatabases()`, skipping
  * the `system`, `tabular_support` and `examples` databases and any database whose name ends
  * with `_deduplicated`.
  *
  * - Accepts a spark session and a batch_id that is attached to the log rows written here
  * - Returns a list of table identifiers in the format of "{database_name}.{table_name}"
 **/ 
def listTablesInCatalog(spark: SparkSession, batch_id: String): List[String] = {
  val skippedDatabases = Set("system", "tabular_support", "examples")

  // Keep only the databases that should be scanned
  val candidateDatabases = spark.catalog.listDatabases().collect().filterNot { database =>
    skippedDatabases.contains(database.name) || database.name.endsWith("_deduplicated")
  }

  // One log row per database, immediately followed by that database's table listing
  candidateDatabases.toList.flatMap { database =>
    log_to_tabular(batch_id, "decorruption.listTablesInCatalog", s"Getting tables for database ${database.name}")
    val tablesInDatabase = spark.catalog.listTables(database.name).collect()
    tablesInDatabase.map(table => s"${database.name}.${table.name}").toList
  }
}

/**
  * Finds data files that an Iceberg table's metadata references but that cannot be found
  * through the table's FileIO, and optionally removes them with an Iceberg delete commit.
  *
  * A file counts as missing when `exists()` returns false OR throws an exception, so files
  * that cannot be checked (for example because of permissions) are treated as missing too.
  *
  * @param spark            active Spark session
  * @param table_identifier identifier of the Iceberg table to inspect
  * @param dry_run          when true (the default) only the number of missing files is logged;
  *                         when false the missing files are deleted from the table metadata
  * @param batch_id         batch id attached to every log row written by this function
 **/ 
def performDeletes(spark: SparkSession, table_identifier: String, dry_run: Boolean = true, batch_id: String): Unit = {    
  import scala.util.control.NonFatal

  val maxAttempts = 3
  val curried_log = (log_to_tabular _).curried
  val log = curried_log(batch_id)(s"decorruption.performDeletes.table=$table_identifier")

  // UDF that checks each path through the table's FileIO; any exception counts as "missing"
  val iceberg_table_io = Spark3Util.loadIcebergTable(spark, table_identifier).io
  val fileExists = udf((path: String) => {
    try {
      iceberg_table_io.newInputFile(path).exists()
    } catch {
      case _: Exception => false
    }
  })

  // Deliberately not persisted: this frame is consumed by exactly one action below, so caching
  // it would only hold executor storage that nothing ever released.
  val missing_files = spark.table(s"$table_identifier.files")
    .select("file_path")
    .withColumn("exists", fileExists(col("file_path")))
    .where("exists != true")

  if (dry_run) {
    val num_missing_files = missing_files.count()
    log(s"Table: $table_identifier\t\tnum_missing_files:$num_missing_files")
  } else {
    log(s"Performing live decorruption of $table_identifier")

    // Collect missing file paths
    val files = missing_files.select("file_path").as[String].collect()

    if (files.isEmpty) {
      log(s"✅ no missing files found for ${table_identifier}. Moving on 💪\n")
    } else {
      val num_files_to_fix = files.length
      var attempts_left = maxAttempts
      var committed = false

      while (!committed && attempts_left > 0) {
        try {
          // Reload the table on every attempt so a retry starts from fresh table state
          val empty_delete = Spark3Util.loadIcebergTable(spark, table_identifier).newDelete()
          val pending_delete = files.zipWithIndex.foldLeft(empty_delete) { case (del, (file, index)) =>
            log(s"\t🚧 fixing file number ${index + 1} out of $num_files_to_fix total to fix")
            del.deleteFile(file)
          }

          log(s"📬 all files handled for ${table_identifier}. Sending commit to Tabular")
          pending_delete.commit()
          // Mark success right after the commit so a later logging problem cannot trigger
          // a retry of a delete that has already been committed.
          committed = true
        } catch {
          // NonFatal lets OutOfMemoryError, InterruptedException, etc. propagate
          case NonFatal(e) =>
            attempts_left -= 1
            // Full detail goes to the notebook output; the log row only carries the exception
            // class so the row stays a short, quote-free message.
            println(s"$batch_id -- delete attempt failed for $table_identifier: $e")
            log(s"⚠️ Delete attempt failed for ${table_identifier} with ${e.getClass.getName}; $attempts_left attempts left")
        }
      }

      if (committed) {
        log(s"✅ Commit succeeded for ${table_identifier}. Moving on 💪\n")
      } else {
        log(s"❌ Commit failed after $maxAttempts retries for ${table_identifier}. Moving on 💔\n")
      }
    }
  }
}



/**
  * Builds a de-duplicated copy of `table` in the sibling database `<database>_deduplicated`.
  *
  * Two rows are duplicates when every column, cast to string, is equal. The comparison is done
  * on the string values themselves rather than on a 32-bit `hash(...)` of them, because with a
  * hash key two different rows that collide would be collapsed into one and silently lost.
  *
  * Nothing is done when the de-duplicated table already exists.
  *
  * @param spark    active Spark session
  * @param table    identifier in the form "{database_name}.{table_name}"
  * @param dry_run  when true, neither the target database nor the target table is created
  * @param batch_id batch id attached to every log row written by this function
  */
def deduplicate_table(spark: SparkSession, table: String, dry_run: Boolean, batch_id: String): Unit = {
  val curried_log = (log_to_tabular _).curried
  val log = curried_log(batch_id)(s"deduplicate.table=$table")
  log(s"\nStarted deduplication for table $table")

  // Extract database name and table name
  val tableParts = table.split("\\.")
  val databaseName = tableParts(0)
  val tableName = tableParts(1)
  val deduplicatedDatabaseName = s"${databaseName}_deduplicated"
  val deduplicatedTableName = s"$deduplicatedDatabaseName.$tableName"

  if (spark.catalog.tableExists(deduplicatedTableName)) {
    log(s"✅ Skipping deduplication for table $table because deduplicated version already exists")
  } else {
    // Load the table into a DataFrame
    val df = spark.table(table)

    // Temporary key: a struct holding every column cast to string. The unusual column name
    // keeps withColumn from replacing a real column of the source table.
    val rowKeyColumn = "__dedup_row_key"
    val rowKey = struct(df.columns.zipWithIndex.map { case (c, i) =>
      col(c).cast("string").alias(s"c$i")
    }: _*)

    // Keep one row per distinct key, then remove the temporary key again
    val deduplicated_df = df
      .withColumn(rowKeyColumn, rowKey)
      .dropDuplicates(rowKeyColumn)
      .drop(rowKeyColumn)

    if (!dry_run) {
      // The target database is only created for a live run, so a dry run leaves the catalog untouched
      spark.sql(s"CREATE DATABASE IF NOT EXISTS $deduplicatedDatabaseName")

      // Write the deduplicated DataFrame to the new database
      deduplicated_df.write.mode(SaveMode.Overwrite).saveAsTable(deduplicatedTableName)
      log(s"Table $table has been deduplicated and saved to $deduplicatedTableName")
    } else {
      log(s"Dry run: nothing was written to $deduplicatedTableName")
    }

    log(s"✅ Finished deduplication for table $table")
  }
}


/**
  * Moves every table of `databaseName` into `<databaseName>_archived` by renaming it.
  *
  * The databases `system`, `examples` and `tabular_support`, and any database whose name ends
  * with `_deduplicated` or `_archived`, are refused.
  *
  * @param spark        active Spark session
  * @param databaseName database whose tables should be archived
  * @return a human-readable message describing the refusal or the successful archival
  */
def archive_database(spark: SparkSession, databaseName: String): String = {
  val protectedDatabases = Set("system", "examples", "tabular_support")
  val isProtected = protectedDatabases.contains(databaseName)
  val isDerived = databaseName.endsWith("_deduplicated") || databaseName.endsWith("_archived")

  if (isProtected || isDerived) {
    s"Refusing to archive database: $databaseName -- not a valid database for archival"
  } else {
    val archivedDatabaseName = s"${databaseName}_archived"

    // Make sure the destination exists before any table is moved
    spark.sql(s"CREATE DATABASE IF NOT EXISTS $archivedDatabaseName")

    // Rename each table of the source database into the archive database
    for (table <- spark.catalog.listTables(databaseName).collect()) {
      spark.sql(s"ALTER TABLE $databaseName.${table.name} RENAME TO $archivedDatabaseName.${table.name}")
    }

    s"Successfully archived $databaseName to $archivedDatabaseName"
  }
}

/**
  * Moves every table of a `*_deduplicated` database into the database of the same name
  * without the `_deduplicated` suffix.
  *
  * The promotion is refused when `databaseName` does not end with `_deduplicated`, when the
  * target database does not exist, or when the target database still contains tables.
  *
  * @param spark        active Spark session
  * @param databaseName name of the `*_deduplicated` database to promote
  * @return a human-readable message describing the refusal or the successful promotion
  */
def promote_deduplicated_database(spark: SparkSession, databaseName: String): String = {
  if (!databaseName.endsWith("_deduplicated")) {
    s"❌ Refusing to promote database: $databaseName -- only *_deduplicated databases can be promoted"
  } else {
    // Define the target database name by removing `_deduplicated`
    val targetDatabaseName = databaseName.stripSuffix("_deduplicated")

    if (!spark.catalog.databaseExists(targetDatabaseName)) {
      s"❌ Target database $targetDatabaseName does not exist"
    } else {
      val live_tables = spark.catalog.listTables(targetDatabaseName).collect()

      if (live_tables.nonEmpty) {
        // Interpolating the array itself would print its JVM identity ("[L...;@1a2b3c"),
        // so spell out the table names instead.
        val liveTableNames = live_tables.map(_.name).mkString(", ")
        s"❌ Target database $targetDatabaseName is not empty! Found tables $liveTableNames"
      } else {
        // Move each table of the deduplicated database to the target database
        spark.catalog.listTables(databaseName).collect().foreach { table =>
          val tableName = table.name
          spark.sql(s"ALTER TABLE $databaseName.$tableName RENAME TO $targetDatabaseName.$tableName")
        }

        s"Successfully promoted $databaseName to $targetDatabaseName"
      }
    }
  }
}

/**
  * Drops every table of a database and then the database itself.
  *
  * Each drop is attempted independently: an exception is printed to the notebook output and
  * the function continues with the next step.
  *
  * @param spark        active Spark session
  * @param databaseName database to drop together with its tables
  */
def dropDatabase(spark: SparkSession, databaseName: String): Unit = {
  val tableNames = spark.catalog.listTables(databaseName).collect().map(_.name)

  // Tables first
  for (tableName <- tableNames) {
    try {
      spark.sql(s"DROP TABLE IF EXISTS $databaseName.$tableName")
      println(s"Successfully dropped table: $tableName")
    } catch {
      case dropError: Exception =>
        println(s"Error dropping table: $tableName - ${dropError.getMessage}")
    }
  }

  // Then the database itself
  try {
    spark.sql(s"DROP DATABASE IF EXISTS $databaseName")
    println(s"Successfully dropped database: $databaseName")
  } catch {
    case dropError: Exception =>
      println(s"Error dropping database: $databaseName - ${dropError.getMessage}")
  }
}


/**
  * Issues `DROP DATABASE IF EXISTS` for every database listed by `SHOW DATABASES` whose name
  * ends with `_deduplicated`, printing the outcome of each attempt.
  *
  * An exception raised by one drop is printed and does not stop the remaining drops.
  */
def cleanup_deduplicated_databases(): Unit = {
    val databaseNames = spark.sql("SHOW DATABASES").collect().map(row => row.getString(0))

    for (db <- databaseNames if db.endsWith("_deduplicated")) {
      try {
        spark.sql(s"DROP DATABASE IF EXISTS $db")
        println(s"✅ Successfully dropped database $db")
      } catch {
        case dropError: Exception =>
          println(s"❌ Error dropping database $db: ${dropError.getMessage}")
      }
    }
}


/**
  * Archives every live database and then promotes every `*_deduplicated` database.
  *
  * Steps, each logged under a fresh batch id:
  *   1. every live database is passed to `archive_database`;
  *   2. every `*_deduplicated` database is passed to `promote_deduplicated_database`;
  *   3. `cleanup_deduplicated_databases` is called.
  *
  * "Live" means: not `system`, `examples` or `tabular_support`, and not ending with
  * `_deduplicated` or `_archived`. Databases starting with `bronze_metadata` are ignored.
  * An exception raised for one database is logged and the run continues with the next one.
  *
  * @param spark active Spark session
  */
def archive_and_promote_catalog(spark: SparkSession): Unit = {
    val current_catalog = spark.catalog.currentCatalog
    val batch_id = getRandomBatchId() + s" -- $current_catalog"
    
    val curried_log = (log_to_tabular _).curried
    val log = curried_log(batch_id)(s"archive_and_promote_catalog")
    log(s"Started batch $batch_id ")

    // The typical exceptions have full stack traces in the error message, so grab just the
    // first line. getMessage may be null; fall back to the exception class name so that the
    // error handler itself can never throw a NullPointerException.
    def firstLineOf(e: Exception): String =
        Option(e.getMessage).flatMap(_.split("\n").headOption).getOrElse(e.getClass.getName)
    
    // get databases to work with
    val databases = spark.sql("SHOW DATABASES").collect().map(_.getString(0)).filter {
            db => !db.startsWith("bronze_metadata")
    }
    val deduplicated_databases = databases.filter(db => db.endsWith("_deduplicated"))
    val live_databases = databases.filter { db => 
        !db.endsWith("_deduplicated") && 
        !db.endsWith("_archived") && 
        db != "system" && 
        db != "examples" && 
        db != "tabular_support"
    }
    
    
    // archive the live tables
    live_databases.foreach { live_db =>
        try {
            log(archive_database(spark, live_db) + "\n")
        } catch {
          case e: Exception =>
            log(s"❌ Exception during archival for database $live_db: ${firstLineOf(e)}" + "\n")
        }
    }  
    
    
    log(s"\n\n💎 All live databases have been archived! Moving on to promote deduplicated databases\n\n")
    
    
    // promote the deduplicated tables
    deduplicated_databases.foreach { dedup_db =>
        try {
            log(promote_deduplicated_database(spark, dedup_db))
        } catch {
          case e: Exception =>
            // This is the promotion step, so the message says "promotion" (it used to say "archival")
            log(s"❌ Exception during promotion for database $dedup_db: ${firstLineOf(e)}")
        }
    }  
    
    
    // cleanup empty deduplicated databases
    log(s"\n💪 Removing all empty deduplicated databases : )")
    cleanup_deduplicated_databases()
    
    log(s"\n✅ Completed batch $batch_id")
}


/**
  * Runs `performDeletes` for every table returned by `listTablesInCatalog`, from the table
  * with the fewest rows in its `files` metadata table to the one with the most.
  *
  * A failure while counting files or while processing one table is logged under the
  * `decorruption.error` event type and the batch continues with the next table, so a single
  * problematic table cannot abort the whole run.
  *
  * @param spark   active Spark session
  * @param dry_run forwarded to `performDeletes`; when true (the default) nothing is deleted
  */
def decorrupt_catalog(spark: SparkSession, dry_run: Boolean = true): Unit = {
  val current_catalog = spark.catalog.currentCatalog
  val batch_id = getRandomBatchId() + s" -- $current_catalog"
  log_to_tabular(batch_id, s"decorruption.start", s"Started batch $batch_id")

  // Logs a per-table failure using only the first line of the exception message (the typical
  // messages carry a full stack trace); a null message falls back to the exception class name.
  def logTableError(table: String, e: Exception): Unit = {
    val reason = Option(e.getMessage).flatMap(_.split("\n").headOption).getOrElse(e.getClass.getName)
    log_to_tabular(batch_id, s"decorruption.error", s"❌ Error during decorruption for table $table: $reason")
  }

  // List all tables in the catalog
  val tables = listTablesInCatalog(spark, batch_id)

  // Get file counts for each table; tables whose file count cannot be read are logged and skipped
  val tablesWithFileCounts = tables.flatMap { table =>
    try {
      Some((table, getFileCount(spark, table)))
    } catch {
      case e: Exception =>
        logTableError(table, e)
        None
    }
  }

  // Process tables from smallest to largest
  tablesWithFileCounts.sortBy(_._2).foreach { case (table, _) =>
    try {
      performDeletes(spark, table, dry_run, batch_id)
    } catch {
      case e: Exception => logTableError(table, e)
    }
  }

  log_to_tabular(batch_id, s"decorruption.finish", s"\n✅ Completed batch $batch_id")
}

/**
  * Runs `deduplicate_table` for every table returned by `listTablesInCatalog`.
  *
  * An exception raised for one table is logged under the `deduplicate_table.error` event
  * type and the batch continues with the next table.
  *
  * @param spark   active Spark session
  * @param dry_run forwarded to `deduplicate_table` (defaults to true)
  */
def deduplicate_catalog(spark: SparkSession, dry_run: Boolean = true): Unit = {
  val current_catalog = spark.catalog.currentCatalog
  val batch_id = getRandomBatchId() + s" -- $current_catalog"
  log_to_tabular(batch_id, s"deduplicate_catalog.start", s"Started batch $batch_id")
    
  // List all tables in the catalog
  val tables = listTablesInCatalog(spark, batch_id)
    
  // build deduplicated versions of each table
  tables.foreach { table =>
    try {
      deduplicate_table(spark, table, dry_run, batch_id)
    } catch {
      case e: Exception =>
        // The typical exceptions have full stack traces in the error message, so grab just the
        // first line. getMessage may be null; fall back to the exception class name so the
        // handler itself cannot throw and abort the remaining tables.
        val reason = Option(e.getMessage).flatMap(_.split("\n").headOption).getOrElse(e.getClass.getName)
        log_to_tabular(batch_id, s"deduplicate_table.error", s"❌ Error during deduplication for table $table: $reason")
    }
  }  
    
  log_to_tabular(batch_id, s"deduplicate_catalog.finish", s"\n✅ Completed batch $batch_id")
}


## Now, let's actually do the work 💪

In [None]:
// Live run: dry_run = false means missing files are actually deleted from table metadata.
decorrupt_catalog(spark, dry_run = false)
// The other two entry points are left commented out in this cell:
// deduplicate_catalog(spark, false)
// archive_and_promote_catalog(spark)