In [0]:
%sql
-- Make `dev` the session's current catalog. TableMetadata._resolve_name reads
-- current_catalog() to qualify one- and two-part table names.
USE CATALOG dev;

In [0]:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from itertools import combinations
from typing import List, Optional

In [0]:
class TableMetadata:
    """Lazily loaded metadata and read access for a single Spark table.

    The table name is normalised to a three-part ``catalog.schema.table``
    identifier at construction time. The row count and the column list are
    fetched from Spark on first access and then cached on the instance.

    Relies on a module-level ``spark`` session (as provided by the notebook
    runtime).
    """

    def __init__(self, raw_table_name):
        """Resolve ``raw_table_name`` and initialise the (empty) caches.

        Args:
            raw_table_name: Table identifier made of one, two or three
                dot-separated parts.

        Raises:
            ValueError: If the name does not have 1-3 non-empty parts.
            RuntimeError: If the current catalog/schema is required to
                qualify the name but cannot be queried.
        """
        self.full_table_name = self._resolve_name(raw_table_name)
        self._count = None
        self._columns = None


    def _resolve_name(self, raw_table_name: str) -> str:
        """Return ``raw_table_name`` qualified as ``catalog.schema.table``.

        Three-part names are returned untouched. Two-part names are prefixed
        with the current catalog, one-part names with the current catalog and
        schema.

        Raises:
            ValueError: If the name has no parts, more than three parts, or
                an empty part (e.g. ``"a..b"``).
            RuntimeError: If the current catalog/schema lookup fails.
        """
        parts = raw_table_name.split(".")

        # Validate before touching Spark, so that a malformed name is always
        # reported as such instead of being masked by a session error.
        if not 1 <= len(parts) <= 3 or any(not part for part in parts):
            raise ValueError(f"Invalid table name format: {raw_table_name}")

        if len(parts) == 3:
            return raw_table_name

        try:
            current_catalog = spark.sql("SELECT current_catalog()").collect()[0][0]
            schema = spark.sql("SELECT current_schema()").collect()[0][0]
        except Exception as exc:
            # RuntimeError is still an ``Exception`` for existing callers, and
            # chaining keeps the underlying Spark error visible.
            raise RuntimeError("Unable to determine current catalog") from exc

        if len(parts) == 2:
            return f"{current_catalog}.{parts[0]}.{parts[1]}"
        return f"{current_catalog}.{schema}.{parts[0]}"


    @property
    def row_count(self) -> int:
        """Number of rows in the table (computed once, then cached).

        First tries the ``numRecords`` field of ``DESCRIBE DETAIL``; if that
        lookup fails or yields NULL, falls back to a full ``count()``.
        """
        if self._count is None:
            count = None
            try:
                # NOTE(review): confirm that DESCRIBE DETAIL exposes a
                # ``numRecords`` column on this runtime; if it does not, this
                # fast path always fails and the full count below is used.
                count = spark.sql(f"DESCRIBE DETAIL {self.full_table_name}") \
                    .select("numRecords").collect()[0][0]
            except Exception:
                # Best-effort fast path only; any failure falls through to
                # the exact count.
                count = None

            if count is None:
                # Never cache None: callers compare the result numerically.
                count = spark.read.table(self.full_table_name).count()
            self._count = count
        return self._count


    @property
    def columns(self) -> List[str]:
        """Column names of the table (read once, then cached)."""
        if self._columns is None:
            self._columns = spark.read.table(self.full_table_name).columns
        return self._columns


    def read_clean(self, select_cols: List[str]) -> "DataFrame":
        """Read the table and project it onto ``select_cols``.

        Args:
            select_cols: Names of the columns to keep, in the desired order.

        Returns:
            A Spark DataFrame containing only the requested columns.
        """
        return spark.read.table(self.full_table_name).select(select_cols)


    def __repr__(self) -> str:
        """Short debug representation showing the fully qualified name."""
        return f"<Table: {self.full_table_name}>"


In [0]:
class KeyDiscoverer:
    """Heuristics for finding column(s) that uniquely identify rows."""

    @staticmethod
    def discover_primary_keys(df: DataFrame, candidates: List[str]) -> List[str]:
        """Identify the best column(s) to use as a unique identifier.

        Strategy:
          1. Any single candidate whose distinct count equals the row count
             is a key. Among those, the first whose name contains ``"id"``
             (case-insensitive) wins, otherwise the first one found.
          2. Otherwise, combinations of 2 and then 3 columns drawn from the
             five highest-cardinality candidates are tested for uniqueness.
          3. If nothing is unique, the three highest-cardinality columns are
             returned as a best guess (NOT guaranteed to be unique).

        Args:
            df: DataFrame to analyse.
            candidates: Column names of ``df`` to consider.

        Returns:
            A list with the chosen key column name(s).

        Raises:
            ValueError: If ``candidates`` is empty.
        """
        if not candidates:
            raise ValueError("No candidate columns supplied for key discovery.")

        print("\tDiscovery: Analysing columns...")
        sample_rows = df.count()

        exprs = [F.countDistinct(c).alias(c) for c in candidates]
        distinct_counts = df.agg(*exprs).collect()[0].asDict()

        # countDistinct ignores NULLs, so a distinct count equal to the row
        # count already implies the column has no NULLs; a separate NULL scan
        # per column (one Spark job each) is unnecessary.
        valid_keys = [col for col, count in distinct_counts.items() if count == sample_rows]

        if valid_keys:
            best_key = next((k for k in valid_keys if "id" in k.lower()), valid_keys[0])
            print(f"\tKeys Identified: ['{best_key}']")
            return [best_key]

        print("\tNo single key found. Checking composite keys...")

        sorted_cols = sorted(distinct_counts, key=distinct_counts.get, reverse=True)
        high_cardinality_cols = sorted_cols[:5]

        # Try the smallest composite first so the returned key is minimal.
        for size in (2, 3):
            for combo in combinations(high_cardinality_cols, size):
                # select() only unpacks a list argument, not a tuple, so the
                # combination must be splatted.
                if df.select(*combo).distinct().count() == sample_rows:
                    print(f"\tComposite Key Verified: {list(combo)}")
                    return list(combo)

        print("\tWARNING: No unique key found. Using best-guess.")
        return sorted_cols[:3]

In [0]:
class TableComparator:
    """Check that two tables hold the same data in the columns they share.

    The source table is read in full, or down-sampled when it has more than
    ``MAX_SAMPLE_SIZE`` rows. Each row's non-key columns are hashed into a
    fingerprint, and the (key, fingerprint) pairs of both sides are diffed.
    """

    def __init__(self, source_table_name: str, target_table_name: str,
                 primary_keys: Optional[List[str]] = None,
                 max_sample_size: int = 10000):
        """Set up a comparison between two tables.

        Args:
            source_table_name: Source table (1-3 part name).
            target_table_name: Target table (1-3 part name).
            primary_keys: Columns identifying a row. When omitted, keys are
                discovered from the source data.
            max_sample_size: Row count above which the source is sampled
                rather than compared in full.

        Raises:
            ValueError: If ``max_sample_size`` is not positive.
        """
        if max_sample_size <= 0:
            raise ValueError("max_sample_size must be a positive integer")

        self.source = TableMetadata(source_table_name)
        self.target = TableMetadata(target_table_name)
        self.primary_keys = primary_keys
        self.failed_rows_df = None

        self.MAX_SAMPLE_SIZE = max_sample_size
        # Keys actually used by the last comparison (user-supplied or
        # discovered); needed to join mismatches back to the source table.
        self._resolved_keys = None


    def _identify_common_columns(self) -> List[str]:
        """Return the column names present in both tables, sorted by name."""
        common_cols = set(self.source.columns).intersection(set(self.target.columns))
        return sorted(list(common_cols))


    def _sample_data(self):
        """Build the source/target DataFrames to compare.

        Returns:
            Tuple ``(df_source, df_target, keys, common_columns)``.

        Raises:
            ValueError: If the tables share no columns, or a supplied primary
                key is not among the shared columns.
        """
        common_columns = self._identify_common_columns()
        if not common_columns:
            raise ValueError(
                f"No common columns between '{self.source.full_table_name}' "
                f"and '{self.target.full_table_name}'."
            )

        source_count = self.source.row_count
        is_sampled = source_count > self.MAX_SAMPLE_SIZE

        if is_sampled:
            print(f"\tLarge table. ({source_count} rows). Sampling...")
            fraction = self.MAX_SAMPLE_SIZE / source_count
            df_source = self.source.read_clean(common_columns) \
                .sample(withReplacement=False, fraction=fraction, seed=42) \
                .alias("source")
        else:
            print("\tSmall table. Using FULL comparison.")
            df_source = self.source.read_clean(common_columns).alias("source")

        if self.primary_keys:
            # Accept a single column name as well as a list of names.
            if isinstance(self.primary_keys, str):
                keys = [self.primary_keys]
            else:
                keys = list(self.primary_keys)
            missing = [k for k in keys if k not in common_columns]
            if missing:
                raise ValueError(f"Primary key column(s) not present in both tables: {missing}")
        else:
            keys = KeyDiscoverer.discover_primary_keys(df_source, common_columns)

        df_target = self.target.read_clean(common_columns)
        if is_sampled:
            # Restrict the target to the sampled keys. A semi join keeps only
            # target columns and never duplicates target rows, even when the
            # keys are a non-unique best guess.
            df_target = df_target.join(df_source.select(keys), on=keys, how="left_semi")
        # In full mode the target is left unfiltered so that rows existing
        # only in the target are reported as discrepancies too.
        df_target = df_target.alias("target")

        return df_source, df_target, keys, common_columns


    def _execute_comparison(self, df_source, df_target, primary_keys, common_columns) -> bool:
        """Diff the two DataFrames by key + fingerprint of the non-key columns.

        Stores the differing (key, fingerprint, origin) rows in
        ``self.failed_rows_df``.

        Returns:
            True when no discrepancies were found, False otherwise.
        """
        data_cols = [c for c in common_columns if c not in primary_keys]
        null_token = "<NULL>"

        def _generate_hash(df):
            # concat_ws silently skips NULLs, which would make rows such as
            # (NULL, 'a') and ('a', NULL) hash identically. Every column is
            # therefore cast to string and NULLs replaced by a sentinel.
            # Caveat: values containing the sentinel or the '||' separator
            # can still collide.
            normalised = [
                F.coalesce(F.col(c).cast("string"), F.lit(null_token))
                for c in data_cols
            ]
            return df.withColumn("fingerprint", F.sha2(F.concat_ws("||", *normalised), 256))

        df_source_hashed = _generate_hash(df_source).select(*primary_keys, "fingerprint")
        df_target_hashed = _generate_hash(df_target).select(*primary_keys, "fingerprint")

        # exceptAll is a multiset difference, so duplicated rows are counted.
        diff_source = df_source_hashed.exceptAll(df_target_hashed).withColumn("origin", F.lit("source"))
        diff_target = df_target_hashed.exceptAll(df_source_hashed).withColumn("origin", F.lit("target"))

        self.failed_rows_df = diff_source.unionByName(diff_target)
        # Remember the keys that go with failed_rows_df for later lookups.
        self._resolved_keys = list(primary_keys)
        discrepancy_count = self.failed_rows_df.count()

        if discrepancy_count == 0:
            print("SUCCESS: Data matches.")
            return True
        else:
            print(f"FAILED: Found {discrepancy_count} discrepencies.")
            diff_source.show(5)
            return False


    def compare(self) -> bool:
        """Run the full validation and return True when the data matches."""
        print(f"Starting Validatation: '{self.source.full_table_name}' vs '{self.target.full_table_name}'")
        return self._execute_comparison(*self._sample_data())


    def get_mismatched_rows(self, limit=20):
        """Return source rows whose keys appear in the last comparison's diff.

        Args:
            limit: Maximum number of rows to return.

        Returns:
            A DataFrame of source rows joined with the recorded mismatches,
            or None when there is nothing to show (``compare()`` has not run
            or found no discrepancies).

        Raises:
            ValueError: If mismatches exist but the join keys are unknown.
        """
        # limit(1) avoids counting every mismatch just to test for emptiness.
        if self.failed_rows_df is None or self.failed_rows_df.limit(1).count() == 0:
            print("No discrepancies to show.")
            return None

        # self.primary_keys is None when keys were auto-discovered; joining
        # on None would produce an unconditioned join.
        keys = self._resolved_keys or self.primary_keys
        if not keys:
            raise ValueError("Join keys are unknown; run compare() first.")
        if isinstance(keys, str):
            keys = [keys]

        print(f"Showing top {limit} mismatched rows...")
        full_details = spark.read.table(self.source.full_table_name) \
            .join(self.failed_rows_df, on=keys, how="inner") \
            .limit(limit)

        return full_details

In [0]:
# Compare dev.test.employees against itself. No primary keys are passed, so
# TableComparator falls back to KeyDiscoverer to pick them from the data.
validator = TableComparator("dev.test.employees", "dev.test.employees")
# Last expression of the cell: the boolean result of the comparison.
validator.compare()

In [0]:
# Prints a notice and returns None when the last compare() recorded no
# differences; otherwise returns up to 20 (the default limit) source rows
# joined with the recorded mismatches.
validator.get_mismatched_rows()