In [0]:
%sql
-- Set the session's default catalog and schema. Unqualified table names used later in
-- this notebook (e.g. "employees") are resolved against current_catalog()/current_schema().
USE CATALOG dev;
USE SCHEMA test;

In [0]:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from itertools import combinations
from typing import List, Optional, Literal

In [0]:
class TableMetadata:
    """Fully qualified name plus lazily cached row count and column list for one table.

    Uses the notebook-global ``spark`` session for all lookups.
    """

    def __init__(self, raw_table_name: str):
        self.full_table_name = self._resolve_name(raw_table_name)
        self._count = None    # filled on first access of row_count
        self._columns = None  # filled on first access of columns


    def _resolve_name(self, raw_table_name: str) -> str:
        """Expand a 1-, 2- or 3-part table name to ``catalog.schema.table``.

        Missing leading parts are taken from the session's current catalog/schema.

        Raises:
            ValueError: if the name is empty, has an empty part, or has more than 3 parts.
            RuntimeError: if the current catalog/schema cannot be queried.
        """
        parts = raw_table_name.split(".")
        length = len(parts)

        # Validate before touching Spark so malformed names fail fast with a clear error.
        if length > 3 or any(not part.strip() for part in parts):
            raise ValueError(f"Invalid table name format: {raw_table_name}")

        if length == 3:
            return raw_table_name

        try:
            current_catalog = spark.sql("SELECT current_catalog()").collect()[0][0]
            schema = spark.sql("SELECT current_schema()").collect()[0][0]
        except Exception as exc:
            # Keep the underlying Spark error attached for debugging.
            raise RuntimeError("Unable to determine current catalog") from exc

        if length == 2:
            return f"{current_catalog}.{parts[0]}.{parts[1]}"
        return f"{current_catalog}.{schema}.{parts[0]}"


    @property
    def row_count(self) -> int:
        """Number of rows in the table, computed once and cached.

        Tries the cheap ``DESCRIBE DETAIL`` metadata first and falls back to a full
        ``count()`` when that lookup fails or yields no value.
        """
        if self._count is None:
            count = None
            try:
                count = spark.sql(f"DESCRIBE DETAIL {self.full_table_name}") \
                    .select("numRecords").collect()[0][0]
            except Exception:
                # Best effort only. NOTE(review): DESCRIBE DETAIL may not expose a
                # numRecords column for every table/runtime — confirm.
                count = None
            if count is None:
                count = spark.read.table(self.full_table_name).count()
            self._count = int(count)
        return self._count


    @property
    def columns(self) -> List[str]:
        """Column names of the table, read once and cached."""
        if self._columns is None:
            self._columns = spark.read.table(self.full_table_name).columns
        return self._columns


    def read_clean(self, select_cols: List[str]) -> DataFrame:
        """Read the table projected onto ``select_cols`` (in that order)."""
        return spark.read.table(self.full_table_name).select(select_cols)


    def __repr__(self) -> str:
        return f"Table: {self.full_table_name}"


In [0]:
class KeyDiscoverer:
    """Heuristics for inferring a unique row identifier from a DataFrame."""

    @staticmethod
    def discover_primary_keys(df: DataFrame, candidates: List[str],
                              top_n: int = 5, max_composite_size: int = 3) -> tuple[List[str], bool]:
        """
        Identifies the best column(s) to use as a unique identifier.

        Args:
            df: Rows to analyse (possibly a sample of the full table).
            candidates: Column names that may serve as keys.
            top_n: Composite keys are built only from the ``top_n`` columns with the
                most distinct values.
            max_composite_size: Largest number of columns tried in a composite key.

        Returns:
            ``(keys, True)`` when a single column or a composite of columns is unique
            and NULL-free in ``df``; ``([], False)`` otherwise.
        """
        print("\tDiscovery: Analysing columns...")
        if not candidates:
            # df.agg() with no expressions raises; there is nothing to analyse anyway.
            print("\tWARNING: No unique key found.")
            return [], False

        sample_rows = df.count()
        if sample_rows == 0:
            # With zero rows every column trivially "matches" the row count; that is not evidence of a key.
            print("\tWARNING: No unique key found.")
            return [], False

        exprs = [F.countDistinct(c).alias(c) for c in candidates]
        distinct_counts = df.agg(*exprs).collect()[0].asDict()

        # countDistinct ignores NULLs, so a distinct count equal to the row count already
        # implies the column is both unique and NULL-free; no separate NULL scan is needed.
        valid_keys = [col for col, count in distinct_counts.items() if count == sample_rows]

        if valid_keys:
            # Prefer a column whose name looks like an identifier.
            best_key = next((k for k in valid_keys if "id" in k.lower()), valid_keys[0])
            print(f"\tKeys Identified: ['{best_key}']")
            return [best_key], True

        print("\tNo single key found. Checking composite keys...")
        sorted_cols = sorted(distinct_counts, key=distinct_counts.get, reverse=True)
        top_candidates = sorted_cols[:top_n]

        # Try the smallest combinations first so the minimal composite key wins.
        for size in range(2, max_composite_size + 1):
            for combo in combinations(top_candidates, size):
                # dropna() removes rows with a NULL in any combo column (NULL key values never
                # match in an equi-join), so equality with the row count means the combination
                # is unique AND NULL-free — the same rule applied to single keys above.
                if df.select(*combo).dropna().distinct().count() == sample_rows:
                    print(f"\tComposite Key Verified: {list(combo)}")
                    return list(combo), True

        print("\tWARNING: No unique key found.")
        return [], False

In [0]:
class TableComparator:
    history = []

    def __init__(self, source_table_name: str, target_table_name: str, primary_keys: Optional[List[str]] = None):
        self.source = TableMetadata(source_table_name)
        self.target = TableMetadata(target_table_name)
        self.primary_keys = primary_keys
        self.source_count = None

        self._source_only_raw = None
        self._target_only_raw = None
        self._source_only_count = 0
        self._target_only_count = 0
        self.active_keys = []
        self.is_keyless_mode = False

        self.MAX_SAMPLE_SIZE = 10000


    def _identify_common_columns(self) -> List[str]:
        common_cols = set(self.source.columns).intersection(set(self.target.columns))
        return sorted(list(common_cols))

    def _log_result(self, status: str, mode: str, source_err: int, target_err: int):
        record = {
            "source_table": self.source.full_table_name,
            "target_table": self.target.full_table_name,
            "status": status,
            "mode": mode,
            "keys_used": str(self.active_keys) if self.active_keys else "N/A",
            "source_only_count": source_err,
            "target_only_count": target_err if target_err is not None else -1
        }
        TableComparator.history.append(record)

    def compare(self) -> bool:
        print(f"\n\nStarting Validatation: '{self.source.full_table_name}' vs '{self.target.full_table_name}'")
        common_columns = self._identify_common_columns()
        self.source_count = self.source.row_count
        
        if self.source_count > self.MAX_SAMPLE_SIZE:
            print(f"\tLarge table. ({self.source_count} rows). Sampling...")
            fraction = self.MAX_SAMPLE_SIZE / self.source_count
            df_source = self.source.read_clean(common_columns) \
                .sample(withReplacement=False, fraction=fraction, seed=42) \
                .alias("source")
        else:
            print("\tSmall table. Using FULL comparison.")
            df_source = self.source.read_clean(common_columns).alias("source")
        
        if self.primary_keys:
            self.active_keys, is_trusted = self.primary_keys, True
        else:
            self.active_keys, is_trusted = KeyDiscoverer.discover_primary_keys(df_source, common_columns)
        
        if is_trusted:
            self.is_keyless_mode = False
            print("\tMode: Key based validation. (Hash comparison)")
            df_target = self.target.read_clean(common_columns) \
                .join(df_source.select(self.active_keys), on=self.active_keys, how="inner") \
                .alias("target")
            
            return self._execute_hash_comparison(df_source, df_target, self.active_keys, common_columns)
        else:
            self.is_keyless_mode = True
            print("\tMode: Keyless validation. (Set Difference)")
            return self._execute_keyless_comparison(df_source, common_columns)


    def _execute_hash_comparison(self, df_source, df_target, primary_keys, common_columns) -> bool:
        data_cols = [c for c in common_columns if c not in primary_keys]

        def _generate_hash(df):
            return df.withColumn("fingerprint", F.sha2(F.concat_ws("||", *data_cols), 256))
        
        df_source_hashed = _generate_hash(df_source).select(*primary_keys, "fingerprint")
        df_target_hashed = _generate_hash(df_target).select(*primary_keys, "fingerprint")

        self._source_only_raw = df_source_hashed.exceptAll(df_target_hashed).withColumn("origin", F.lit("source"))
        self._target_only_raw = df_target_hashed.exceptAll(df_source_hashed).withColumn("origin", F.lit("target"))

        self._source_only_count = self._source_only_raw.count()
        self._target_only_count = self._target_only_raw.count()

        if self._source_only_count + self._target_only_count == 0:
            print(f"SUCCESS: Data matches. {self.source_count} rows")
            self._log_result("PASS", "KEYED", 0, 0)
            return True
        else:
            print(f"FAILED: {self._source_only_count + self._target_only_count} / {self.source_count} rows does not match.")
            self._log_result("FAIL", "KEYED", self._source_only_count, self._target_only_count)
            print("\nSource:")
            self._source_only_raw.show(5)
            print("\nTarget:")
            self._target_only_raw.show(5)
            return False


    def _execute_keyless_comparison(self, df_source, common_columns) -> bool:
        df_target = self.target.read_clean(common_columns).alias("target")
        
        self._source_only_raw = df_source.exceptAll(df_target)
        self._target_only_raw = None
        self._source_only_count = self._source_only_raw.count()

        if self._source_only_count == 0:
            print(f"SUCCESS: Data matches. {self.source_count} rows")
            self._log_result("PASS", "KEYLESS", 0, None)
            return True
        else:
            print(f"FAILED: {self._source_only_count} / {self.source_count} rows does not match.")
            self._log_result("FAIL", "KEYLESS", self._source_only_count, None)
            print("\nSource:")
            self._source_only_raw.show(5)
            return False


    def get_source_only(self, limit: int = 20) -> DataFrame:
        if self._source_only_raw is None or self._source_only_count == 0:
            print("No discrepancies to show.")
            return None
        
        print(f"Showing top {limit} mismatched rows...")
        if self.is_keyless_mode:
            return self._source_only_raw.limit(limit)
        else:
            keys = self.active_keys
            return spark.read.table(self.source.full_table_name) \
                .join(self._source_only_raw, on=keys, how="inner") \
                .limit(limit)


    def get_target_only(self, limit: int = 20) -> DataFrame:
        if self._target_only_raw is None or self._target_only_count == 0:
            print("No discrepancies to show.")
            return None
        
        print(f"Showing top {limit} mismatched rows...")
        if self.is_keyless_mode:
            return spark.read.table(self.target.full_table_name).exceptAll(self._source_only_raw).limit(limit)
        else:
            keys = self.active_keys
            return spark.read.table(self.target.full_table_name) \
                .join(self._target_only_raw, on=keys, how="inner") \
                .limit(limit)


    def get_error_rows(self, from_tbl: Literal["source", "target"], limit: int = 20) -> DataFrame:
        match from_tbl:
            case "source":
                self.get_source_only(limit)
            case "target":
                self.get_target_only(limit)
            case _:
                raise ValueError(f"Invalid. Please enter either 'source' or 'target'")
        return True


    @classmethod
    def get_history(cls) -> DataFrame:
        if not cls.history:
            print("No history available yet.")
            return None
        return spark.createDataFrame(cls.history)


    @classmethod
    def clear_history(cls):
        cls.history = []
        print("History cleared.")


In [0]:
# Sanity check: both names resolve to the same table, so the comparison is expected to pass.
validator = TableComparator(source_table_name="employees", target_table_name="employees")
validator.compare()

In [0]:
# Render the mismatching source rows; non-DataFrame results mean there is nothing to display.
source_errors = validator.get_error_rows(from_tbl="source")
if isinstance(source_errors, DataFrame):
    source_errors.display()

In [0]:
# get_history() returns None when no comparison has been recorded yet.
logs = TableComparator.get_history()
if logs is not None:
    logs.display()