In [0]:
%sql
-- Make `dev` the current catalog for this session.
USE CATALOG dev;

In [0]:
class TableMetadata:
    def __init__(self, raw_table_name):
        parts = raw_table_name.split(".")

        try:
            current_catalog = spark.sql("SELECT current_catalog()").collect()[0][0]
        except:
            raise Exception("Unable to determine current catalog")

        match len(parts):
            case 3:
                self.catalog, self.schema, self.table = parts
            case 2:
                self.catalog = current_catalog
                self.schema, self.table = parts
            case 1:
                self.catalog = current_catalog
                self.schema = spark.sql("SELECT current_schema()").collect()[0][0]
                self.table = raw_table_name
            case _:
                raise ValueError(f"Invalid table name format: {raw_table_name}")
        
        self.full_table_name = f"{self.catalog}.{self.schema}.{self.table}"
    
    def fetch_count(self):
        try:
            return spark.sql(f"DESCRIBE DETAIL {self.full_table_name}") \
                .select("numRecords").collect()[0][0]
        except:
            return spark.read.table(self.full_table_name).count()
    
    def __repr__(self):
        return f"<Table: {self.full_table_name}>"

In [0]:
from pyspark.sql import functions as F
from itertools import combinations

class TableComparator:
    def __init__(self, source_table_name: str, target_table_name: str, primary_keys: list = None):
        self.source_meta = TableMetadata(source_table_name)
        self.target_meta = TableMetadata(target_table_name)
        self.primary_keys = primary_keys if primary_keys else []

        self.source_count = None
        self.target_count = None

        self.common_columns = sorted(self._identify_common_columns())
        self.failed_rows_df = None

        self.MAX_SAMPLE_SIZE = 10000


    def _identify_common_columns(self):
        cols_source = set(spark.table(self.source_meta.full_table_name).columns)
        cols_target = set(spark.table(self.target_meta.full_table_name).columns)
        return list(cols_source.intersection(cols_target))
    

    def _get_counts(self):
        if self.source_count is None:
            self.source_count = self.source_meta.fetch_count()
        if self.target_count is None:
            self.target_count = self.target_meta.fetch_count()


    def _discover_primary_keys(self, df):
        """
        Identifies the best column(s) to use as a unique identifier
        """
        if self.primary_keys:
            print(f"\tUsing provided Primary keys: {self.primary_keys}")
            return self.primary_keys
        
        print("\tDiscovery: Analysing columns...")
        sample_rows = df.count()

        exprs = [F.countDistinct(c).alias(c) for c in self.common_columns]
        distinct_counts = df.agg(*exprs).collect()[0].asDict()

        candidates = [col for col, count in distinct_counts.items() if count == sample_rows]
        valid_keys = [k for k in candidates if df.filter(F.col(k).isNull()).count() == 0]

        if valid_keys:
            best_key = next((k for k in valid_keys if "id" in k.lower()), valid_keys[0])
            self.primary_keys = [best_key]
            print(f"\tKeys Identified: ['{best_key}']")
            return self.primary_keys
        
        print("\tNo single key found. Checking composite keys...")

        high_cardinality_cols = [col for col, count in distinct_counts.items() if count > sample_rows * 0.9]
        if len(high_cardinality_cols) < 2:
            high_cardinality_cols = sorted(distinct_counts, key=distinct_counts.get, reverse=True)[:5]
        
        for combo in combinations(high_cardinality_cols, 2):
            if df.select(combo).distinct().count() == sample_rows:
                print(f"\tComposite Key Verified: {list(combo)}")
                return list(combo)
        
        print(" \WARNING: No unique key found. Using best-guess.")
        self.primary_keys = sorted(distinct_counts, key=distinct_counts.get, reverse=True)[:2]
        return self.primary_keys


    def _align_and_sample_data(self):
        self._get_counts()

        def read_clean(table_meta):
            return spark.read.table(table_meta.full_table_name).select(self.common_columns)
        
        if self.source_count <= self.MAX_SAMPLE_SIZE:
            print("\tSmall table. Using FULL comparison.")
            df_source = read_clean(self.source_meta).alias("source")

            self.primary_keys = self._discover_primary_keys(df_source)

            df_target = read_clean(self.target_meta).alias("target")
            return df_source, df_target
        
        print("\tLarge table. ({self.source_count} rows). Sampling...")
        fraction = self.MAX_SAMPLE_SIZE / self.source_count

        df_source = read_clean(self.source_meta) \
            .sample(withReplacement=False, fraction=fraction, seed=42) \
            .alias("source")
        
        self.primary_keys = self._discover_primary_keys(df_source)

        df_target = read_clean(self.target_meta) \
            .join(df_source.select(self.primary_keys), on=self.primary_keys, how="inner") \
            .alias("target")

        return df_source, df_target


    def execute_comparison(self):
        print(f"Starting Validatation: '{self.source_meta.full_table_name}' vs '{self.target_meta.full_table_name}'")

        if self.source_count != self.target_count:
            print(f"WARNING: Table sizes do not match. Source: {self.source_count} | Target: {self.target_count}")
        
        df_source, df_target = self._align_and_sample_data()

        data_cols = [c for c in self.common_columns if c not in self.primary_keys]

        def _generate_fingerprint(df):
            return df.withColumn("fingerprint", F.sha2(F.concat_ws("||", *data_cols), 256))
        
        df_source_hashed = _generate_fingerprint(df_source).select(*self.primary_keys, "fingerprint")
        df_target_hashed = _generate_fingerprint(df_target).select(*self.primary_keys, "fingerprint")

        diff_source = df_source_hashed.exceptAll(df_target_hashed).withColumn("origin", F.lit("source"))
        diff_target = df_target_hashed.exceptAll(df_source_hashed).withColumn("origin", F.lit("target"))

        self.failed_rows_df = diff_source.unionByName(diff_target)
        discrepancy_count = self.failed_rows_df.count()

        if discrepancy_count == 0:
            print("SUCCESS: Data matches.")
            return True
        else:
            print(f"FAILED: Found {discrepancy_count} discrepencies.")
            diff_source.show(5)
            return False
    

    def get_mismatched_rows(self, limit=20):
        if self.failed_rows_df is None or self.failed_rows_df.count() == 0:
            print("No discrepancies to show.")
            return None
        
        print(f"Showing top {limit} mismatched rows...")
        full_details = spark.read.table(self.source_meta.full_table_name) \
            .join(self.failed_rows_df, on=self.primary_keys, how="inner") \
            .limit(limit)

        return full_details

In [0]:
# Compare dev.test.employees with itself. No primary keys are passed, so the
# comparator discovers them from the data during execute_comparison().
validator = TableComparator("dev.test.employees", "dev.test.employees")
validator.execute_comparison()

In [0]:
# Source rows behind the discrepancies recorded by the run above (default
# limit of 20); evaluates to None when nothing differed.
validator.get_mismatched_rows()