In [0]:
%sql
-- Set the session's current catalog and schema; unqualified table names used later are resolved against them.
USE CATALOG dev;
USE SCHEMA test;

In [0]:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from itertools import combinations
from typing import List, Optional, Tuple, Literal
from datetime import datetime

In [0]:
class TableMetadata:
    """Resolve a table name to ``catalog.schema.table`` and expose basic metadata.

    Relies on the notebook-provided global ``spark`` session for every lookup.
    """

    def __init__(self, table_name):
        """Store the fully qualified form of *table_name*.

        Raises:
            ValueError: if *table_name* is not a 1-, 2- or 3-part dotted name.
            RuntimeError: if the current catalog/schema cannot be queried.
        """
        self.full_name = self._resolve_name(table_name)
        # Filled lazily by ``row_count`` so the lookup runs at most once.
        self._count = None

    def _resolve_name(self, raw_name: str) -> str:
        """Return *raw_name* qualified with the current catalog/schema as needed.

        A three-part name is returned unchanged; missing leading parts are taken
        from ``current_catalog()`` / ``current_schema()``.
        """
        parts = raw_name.split(".")

        # Validate the shape before touching Spark so malformed names fail fast
        # with a ValueError instead of being masked by a session lookup error.
        if not 1 <= len(parts) <= 3 or not all(part.strip() for part in parts):
            raise ValueError(f"Invalid table name format: {raw_name}")

        if len(parts) == 3:
            return raw_name

        try:
            catalog = spark.sql("SELECT current_catalog()").collect()[0][0]
            if len(parts) == 2:
                schema = parts[0]
            else:
                schema = spark.sql("SELECT current_schema()").collect()[0][0]
        except Exception as err:
            # Keep the original cause attached for debugging.
            raise RuntimeError("Unable to determine current catalog") from err

        return f"{catalog}.{schema}.{parts[-1]}"

    @property
    def row_count(self) -> int:
        """Number of rows in the table (computed once, then cached).

        A metadata lookup through ``DESCRIBE DETAIL`` is tried first; when it
        fails or yields no value, the table is counted instead.
        """
        if self._count is None:
            count = None
            try:
                count = spark.sql(f"DESCRIBE DETAIL {self.full_name}") \
                    .select("numRecords").collect()[0][0]
            except Exception:
                # Best effort only: any failure falls through to a real count.
                count = None

            if count is None:
                count = spark.read.table(self.full_name).count()
            self._count = count
        return self._count

    @property
    def columns(self) -> List[str]:
        """Column names of the table, in table order."""
        return spark.read.table(self.full_name).columns

    def read_selected(self, cols: List[str]) -> "DataFrame":
        """Read the table and project it onto *cols*."""
        return spark.read.table(self.full_name).select(cols)

    def __repr__(self) -> str:
        return f"Table: {self.full_name}"


In [0]:
class KeyDiscoverer:
    """Heuristic search for a column (or column combination) that uniquely identifies rows."""

    @staticmethod
    def find_keys(df: "DataFrame", candidates: List[str]) -> Tuple[List[str], bool]:
        """Identify the best column(s) to use as a unique identifier.

        Args:
            df: data to analyse.
            candidates: column names of *df* that may be part of a key.

        Returns:
            ``(key_columns, True)`` when a NULL-free unique key was found,
            ``([], False)`` otherwise.
        """
        print("\tDiscovery: Analysing columns...")

        # Nothing to analyse: an aggregation without expressions would fail.
        if not candidates:
            print("\tWARNING: No unique key found.")
            return [], False

        sample_count = df.count()

        exprs = [F.countDistinct(c).alias(c) for c in candidates]
        distinct_counts = df.agg(*exprs).collect()[0].asDict()

        # countDistinct ignores NULLs, so a distinct count equal to the row
        # count already proves the column is both unique and NULL-free; no
        # separate per-column NULL scan is needed.
        valid_keys = [col for col, count in distinct_counts.items() if count == sample_count]

        if valid_keys:
            # Prefer a column that looks like an identifier.
            best_key = next((k for k in valid_keys if "id" in k.lower()), valid_keys[0])
            print(f"\tSingle key Identified: ['{best_key}']")
            return [best_key], True

        print("\tNo single key found. Checking composite keys...")
        sorted_cols = sorted(distinct_counts, key=distinct_counts.get, reverse=True)
        top_candidates = sorted_cols[:5]

        # Try the smallest combinations first so a two-column key is preferred
        # over a three-column superset of it.
        for size in (2, 3):
            for combo in combinations(top_candidates, size):
                # dropna() removes rows with a NULL in any key column: such rows
                # can never match in an equi-join, so they disqualify the combo.
                if df.select(*combo).dropna().distinct().count() == sample_count:
                    print(f"\tComposite Key Verified: {list(combo)}")
                    return list(combo), True

        print("\tWARNING: No unique key found.")
        return [], False

In [0]:
class TableComparator:
    """Validate that a source table and a target table hold the same data.

    Only columns present in both tables are compared. When a trusted key is
    available (passed by the caller or discovered on the source data), rows are
    matched on it and compared through a per-row hash; otherwise rows are
    compared as multisets with ``exceptAll``. Sources larger than
    ``MAX_SAMPLE_SIZE`` rows are sampled unless ``compare(full_scan=True)``.

    Relies on the notebook-provided global ``spark`` session.
    """

    # Shared by every instance: one record per finished comparison.
    history = []

    def __init__(self, source_table: str, target_table: str, keys: Optional[List[str]] = None):
        """Prepare a comparison of *source_table* against *target_table*.

        Args:
            source_table: 1-, 2- or 3-part name of the source table.
            target_table: 1-, 2- or 3-part name of the target table.
            keys: optional key columns; when omitted, keys are auto-discovered.
        """
        self.source = TableMetadata(source_table)
        self.target = TableMetadata(target_table)
        self.keys = keys

        self.source_count = None
        self.join_keys = []
        self.is_keyless_mode = False

        self._df_mismatch_source = None
        self._df_mismatch_target = None
        self._df_mismatches_merged = None

        self._count_source_only = 0
        # -1 is a sentinel meaning "not measured" (keyless comparison on a sample).
        self._count_target_only = 0
        self.rows_analyzed = 0

        self.MAX_SAMPLE_SIZE = 10000

    def _get_common_columns(self) -> List[str]:
        """Return the column names present in both tables, sorted alphabetically."""
        return sorted(set(self.source.columns) & set(self.target.columns))

    def _log_result(self, status: str, mode: str):
        """Append one record describing the finished comparison to ``history``."""
        record = {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "source": self.source.full_name,
            "target": self.target.full_name,
            "status": status,
            "mode": mode,
            "keys": str(self.join_keys) if self.join_keys else "N/A",
            "source_mismatch_count": self._count_source_only,
            "target_mismatch_count": self._count_target_only if self._count_target_only != -1 else "N/A"
        }
        TableComparator.history.append(record)

    def _report_outcome(self, mode: str) -> bool:
        """Print the verdict for the stored mismatch counts, log it, and return it.

        Returns ``True`` when neither side has mismatches. A target count of
        ``-1`` (not measured) is never treated as a failure.
        """
        failed = (self._count_source_only > 0) or (self._count_target_only > 0)

        if not failed:
            print(f"SUCCESS: Data matches. {self.source_count} rows")
            self._log_result("PASS", mode)
            return True

        target_display = "N/A (Sampling)" if self._count_target_only == -1 else self._count_target_only
        print("FAILED:")
        print(f"\tSource mismatch: {self._count_source_only}\n\tTarget mismatch: {target_display}\n\tTotal rows in source: {self.source_count}")
        self._log_result("FAIL", mode)
        return False

    def compare(self, full_scan=False) -> bool:
        """Run the comparison and return ``True`` when no mismatch was found.

        Args:
            full_scan: compare every source row even when the source exceeds
                ``MAX_SAMPLE_SIZE`` rows.

        Note:
            In keyed mode the target is restricted to keys present in the
            analysed source rows, so rows whose key exists only in the target
            are not reported as mismatches.

        Raises:
            ValueError: if the two tables have no column in common.
        """
        print(f"\n\nStarting Validation: '{self.source.full_name}' vs '{self.target.full_name}'")
        self.source_count = self.source.row_count
        common_columns = self._get_common_columns()
        if not common_columns:
            raise ValueError("Source and target tables share no columns to compare.")

        is_sampling = (self.source_count > self.MAX_SAMPLE_SIZE) and (not full_scan)

        if is_sampling:
            print(f"\tLarge table ({self.source_count} rows). Sampling...")
            # Fraction chosen so the sample holds roughly MAX_SAMPLE_SIZE rows.
            fraction = self.MAX_SAMPLE_SIZE / self.source_count
            df_source = self.source.read_selected(common_columns) \
                .sample(withReplacement=False, fraction=fraction, seed=42) \
                .alias("source")
        else:
            print(f"\tSmall table. Using FULL comparison ({self.source_count} rows).")
            df_source = self.source.read_selected(common_columns).alias("source")

        self.rows_analyzed = df_source.count()
        print(f"\tActual rows Analyzed: {self.rows_analyzed}")

        if self.keys:
            # Caller-supplied keys are trusted as-is.
            self.join_keys, is_trusted = self.keys, True
        else:
            self.join_keys, is_trusted = KeyDiscoverer.find_keys(df_source, common_columns)

        if is_trusted:
            self.is_keyless_mode = False
            print(f"\tMode: Key-based validation. (Hash comparison) (Keys: {self.join_keys})")
            # A semi-join keeps each matching target row exactly once, even when
            # the supplied keys are not unique in the source; an inner join
            # would multiply target rows and fabricate mismatches.
            df_target = self.target.read_selected(common_columns) \
                .join(df_source.select(self.join_keys), on=self.join_keys, how="left_semi") \
                .alias("target")

            return self._run_keyed_comparison(df_source, df_target, common_columns)

        self.is_keyless_mode = True
        print("\tMode: Keyless validation. (Set Difference)")
        return self._run_keyless_comparison(df_source, common_columns, is_sampling)

    def _run_keyed_comparison(self, df_source, df_target, cols) -> bool:
        """Compare rows by key plus a SHA-256 fingerprint of the non-key columns."""
        hash_cols = [c for c in cols if c not in self.join_keys]
        null_token = "<NULL>"

        def _hash(df):
            # concat_ws silently skips NULLs, which would give (x, NULL) and
            # (NULL, x) the same fingerprint; replace NULLs with a placeholder
            # so every column keeps its position in the hashed string.
            parts = [F.coalesce(F.col(c).cast("string"), F.lit(null_token)) for c in hash_cols]
            return df.withColumn("fingerprint", F.sha2(F.concat_ws("||", *parts), 256))

        df_source_hashed = _hash(df_source).select(*self.join_keys, "fingerprint")
        df_target_hashed = _hash(df_target).select(*self.join_keys, "fingerprint")

        self._df_mismatch_source = df_source_hashed.exceptAll(df_target_hashed).withColumn("origin", F.lit("source"))
        self._df_mismatch_target = df_target_hashed.exceptAll(df_source_hashed).withColumn("origin", F.lit("target"))

        self._count_source_only = self._df_mismatch_source.count()
        self._count_target_only = self._df_mismatch_target.count()

        self._df_mismatches_merged = self._df_mismatch_source.unionByName(self._df_mismatch_target)

        return self._report_outcome("KEYED")

    def _run_keyless_comparison(self, df_source, cols, is_sampling: bool) -> bool:
        """Compare whole rows as multisets when no trusted key is available.

        On a sample only the source -> target direction is meaningful, so the
        target-side count is set to the ``-1`` sentinel.
        """
        df_target = self.target.read_selected(cols).alias("target")

        self._df_mismatch_source = df_source.exceptAll(df_target).withColumn("origin", F.lit("source"))
        self._count_source_only = self._df_mismatch_source.count()

        if is_sampling:
            self._df_mismatch_target = None
            self._count_target_only = -1
            self._df_mismatches_merged = self._df_mismatch_source
        else:
            self._df_mismatch_target = df_target.exceptAll(df_source).withColumn("origin", F.lit("target"))
            self._count_target_only = self._df_mismatch_target.count()
            self._df_mismatches_merged = self._df_mismatch_source.unionByName(self._df_mismatch_target)

        return self._report_outcome("KEYLESS")

    def get_all_mismatches(self, limit: int = 100) -> Optional["DataFrame"]:
        """Return mismatching rows from the last ``compare()``, or ``None`` if there are none.

        Args:
            limit: maximum number of rows returned (applied per side in keyed mode).
        """
        # The stored counts already tell whether anything mismatched, so the
        # merged DataFrame does not need to be counted again.
        nothing_found = self._count_source_only <= 0 and self._count_target_only <= 0
        if self._df_mismatches_merged is None or nothing_found:
            print("No discrepancies to show.")
            return None

        if self.is_keyless_mode:
            return self._df_mismatches_merged.limit(limit)

        print(f"\tFetching full row details for combined mismatches...")

        final_df = None

        if self._count_source_only > 0:
            final_df = spark.read.table(self.source.full_name) \
                .join(self._df_mismatch_source, on=self.join_keys, how="inner") \
                .limit(limit)

        if self._count_target_only > 0:
            df_target_errs = spark.read.table(self.target.full_name) \
                .join(self._df_mismatch_target, on=self.join_keys, how="inner") \
                .limit(limit)

            # Explicit None check: a DataFrame must not be used as a boolean.
            if final_df is not None:
                cols = self._get_common_columns() + ["origin"]
                final_df = final_df.select(cols).unionByName(df_target_errs.select(cols))
            else:
                final_df = df_target_errs

        return final_df

    def get_summary_report(self) -> Optional["DataFrame"]:
        """Build a two-column (metric, value) report of the last ``compare()``.

        Returns ``None`` when ``compare()`` has not been run yet.
        """
        if self.source_count is None:
            print("Summary not available. Please run `compare()` first.")
            return None

        print("Generating Summary Report...")
        target_count = self.target.row_count
        row_diff = self.source_count - target_count

        matched_count = self.rows_analyzed - self._count_source_only
        # The percentage is relative to the rows actually analysed; with
        # sampling these are fewer than the total source rows.
        match_percentage = (matched_count / self.rows_analyzed) * 100.0 if self.rows_analyzed > 0 else 0.0

        data = [
            ("Total rows compared", str(self.rows_analyzed)),
            ("Matched rows", f"{matched_count} ({match_percentage:.2f}%)"),
            ("Source Rows (Total)", str(self.source_count)),
            ("Target Rows (Total)", str(target_count)),
            ("Total row count diff", str(row_diff)),
            ("Mismatch in source", str(self._count_source_only)),
            ("Mismatch in target", "N/A (Sampling)" if self._count_target_only == -1 else str(self._count_target_only))
        ]

        if not self.is_keyless_mode and (self._count_source_only > 0 or self._count_target_only > 0):
            # Keys flagged on both sides: same key, different data.
            modified_ids = self._df_mismatch_source.join(self._df_mismatch_target, on=self.join_keys, how="inner") \
                .select(self.join_keys)
            mod_count = modified_ids.count()

            if mod_count > 0:
                data.append(("Mismatch rows (ID match, data diff)", str(mod_count)))

                common = self._get_common_columns()
                cols_to_check = [c for c in common if c not in self.join_keys]

                # An aggregation without expressions would fail, so skip the
                # per-column breakdown when every common column is a key.
                if cols_to_check:
                    source_data = self.source.read_selected(common).join(modified_ids, on=self.join_keys, how="inner")
                    target_data = self.target.read_selected(common).join(modified_ids, on=self.join_keys, how="inner")

                    # Null-safe comparison so NULL-vs-value counts as a
                    # difference and NULL-vs-NULL does not.
                    diff_exprs = [
                        F.sum(
                            F.when(~F.col(f"source.{c}").eqNullSafe(F.col(f"target.{c}")), 1).otherwise(0)
                        ).alias(c)
                        for c in cols_to_check
                    ]

                    joined = source_data.alias("source").join(target_data.alias("target"), on=self.join_keys, how="inner")
                    diff_sums = joined.agg(*diff_exprs).collect()[0].asDict()

                    # Report only the five columns with the most differences.
                    sorted_diffs = sorted(diff_sums.items(), key=lambda x: x[1], reverse=True)[:5]
                    for col, count in sorted_diffs:
                        if count > 0:
                            data.append((f"\t Col '{col}' mismatched", f"{count} ({(count / mod_count) * 100:.1f}%)"))

        return spark.createDataFrame(data, ["metric", "value"])

    @classmethod
    def get_history(cls) -> Optional["DataFrame"]:
        """Return the logged comparisons as a DataFrame, or ``None`` when empty."""
        if not cls.history:
            print("No history available yet.")
            return None

        fields = (
            "timestamp", "source", "target", "status", "mode", "keys",
            "source_mismatch_count", "target_mismatch_count",
        )
        # Records store the target count either as an int or as the string
        # "N/A"; schema inference cannot merge both in one column, so the
        # schema is given explicitly and "N/A" becomes NULL.
        schema = (
            "`timestamp` string, `source` string, `target` string, `status` string, "
            "`mode` string, `keys` string, `source_mismatch_count` long, "
            "`target_mismatch_count` long"
        )
        rows = []
        for record in cls.history:
            target_count = record["target_mismatch_count"]
            rows.append(
                tuple(record[name] for name in fields[:-1])
                + (target_count if isinstance(target_count, int) else None,)
            )
        return spark.createDataFrame(rows, schema=schema)

    @classmethod
    def clear_history(cls):
        """Discard every logged comparison."""
        cls.history = []
        print("History cleared.")

    def __repr__(self):
        return f"{self.source.full_name} vs {self.target.full_name}"


In [0]:
# Compare the two tables; no keys are passed, so compare() tries to discover them.
comparator = TableComparator("employees", "employees_2")
# compare() returns True when no mismatch was found, False otherwise.
comparator.compare()

In [0]:
logs = TableComparator.get_history()
# get_history() returns None when no comparison has been logged yet.
if logs is not None:
    logs.display()

In [0]:
# get_all_mismatches() returns None when there is nothing to show (for
# example after a passing comparison), so guard before displaying.
mismatches = comparator.get_all_mismatches()
if mismatches is not None:
    mismatches.display()

In [0]:
# Show the (metric, value) summary of the comparison run above.
comparator.get_summary_report().display()

In [0]:
# Echo the comparator: its repr is "<source full name> vs <target full name>".
comparator