In [0]:
%sql
-- Make `dev` the session's current catalog. TableDetails below resolves one- and
-- two-part table names through current_catalog(), so this setting determines
-- where unqualified names are looked up.
USE CATALOG dev;

In [0]:
class TableDetails:
    def __init__(self, table):
        parts = table.split(".")
        match len(parts):
            case 3:
                self.catalog, self.schema, self.table_name = parts
            case 2:
                self.catalog = spark.sql("SELECT current_catalog()").collect()[0][0]
                self.schema, self.table_name = parts
            case 1:
                self.catalog = spark.sql("SELECT current_catalog()").collect()[0][0]
                self.schema = spark.sql("SELECT current_schema()").collect()[0][0]
                self.table_name = table
            case _:
                raise Exception(f"Invalid table name: {table}")
        self.three_level_namespace = f"{self.catalog}.{self.schema}.{self.table_name}"
    
    def get_details(self):
        return spark.sql(f"DESCRIBE DETAIL {self.three_level_namespace}").collect()[0].asDict()
    
    def __repr__(self):
        return f"TableStatus(table_name='{self.table_name}', columns={self.columns}, count={self.count})"

In [0]:
from pyspark.sql import functions as F

class TableComparator:
    """Spot-check that the rows of an "expected" table also exist, unchanged, in an "actual" table.

    The comparison works on the columns both tables share:

    1. Take up to ``max_sample_size`` rows (approximately) from the expected table.
    2. Pick key column(s) from that sample (see ``_discover_keys``).
    3. Fetch the actual-table rows whose keys occur in the sample.
    4. Hash the common columns of each row and compare the two sides as multisets.

    Limitations: rows of the actual table whose keys do not occur in the
    expected sample are never examined; rows whose key is NULL never match in
    the key join; column names are matched case-sensitively.

    Args:
        expected_table: Name of the reference table (1-3 dot-separated parts).
        actual_table: Name of the table under test (1-3 dot-separated parts).
        max_sample_size: Approximate upper bound on sampled expected rows.

    Attributes:
        expected_table / actual_table: Resolved ``TableDetails`` objects.
        common_cols: Sorted list of column names present in both tables.
        keys: Key columns chosen by the most recent ``compare()`` call.
        MAX_SAMPLE_SIZE: The sample-size bound in effect.
        expected_row_count / actual_table_row_num: Full row counts of each table.

    Raises:
        ValueError: If ``max_sample_size`` is not positive.
    """

    # Stand-in for NULL when hashing. ``concat_ws`` silently skips NULL inputs,
    # which would make (NULL, 'a') and ('a', NULL) hash identically.
    _NULL_TOKEN = "__NULL__"

    def __init__(self, expected_table: str, actual_table: str, max_sample_size: int = 10000):
        if max_sample_size <= 0:
            raise ValueError("max_sample_size must be a positive integer")

        self.expected_table = TableDetails(expected_table)
        self.actual_table = TableDetails(actual_table)

        self.common_cols = sorted(self.get_commmon_cols(self.expected_table.three_level_namespace, self.actual_table.three_level_namespace))
        self.keys = []

        self.MAX_SAMPLE_SIZE = max_sample_size

        self.expected_row_count = self.get_row_count(self.expected_table)
        self.actual_table_row_num = self.get_row_count(self.actual_table)

    def get_row_count(self, table):
        """Return the total number of rows of ``table`` (a ``TableDetails``)."""
        return spark.read.table(table.three_level_namespace).count()

    def get_commmon_cols(self, table1, table2):
        """Return the columns of ``table1`` that also exist in ``table2``, in ``table1`` order.

        Both arguments are fully qualified table names.
        """
        # Resolve the second table's schema once rather than once per column.
        other_columns = set(spark.table(table2).columns)
        return [c for c in spark.table(table1).columns if c in other_columns]

    def _discover_keys(self, df):
        """Choose key column(s) for matching rows of ``df`` against the actual table.

        Every common column whose distinct count equals the number of rows in
        ``df`` is returned as a key. If there is no such column, the (up to) two
        columns with the highest distinct counts are returned instead; these
        are not guaranteed to identify rows uniquely.

        Raises:
            ValueError: If the tables share no columns.
        """
        columns = self.common_cols
        if not columns:
            raise ValueError("Cannot discover keys: the tables have no columns in common")

        # Uniqueness must be judged against the rows actually present in ``df``
        # (the sample), not against the full-table row count.
        sample_row_count = df.count()
        distinct_counts = {c: df.select(c).distinct().count() for c in columns}

        candidate_keys = [c for c in columns if distinct_counts[c] == sample_row_count]

        if not candidate_keys:
            # Stable sort: ties keep the (alphabetical) order of common_cols.
            ranked = sorted(columns, key=lambda c: distinct_counts[c], reverse=True)
            candidate_keys = ranked[:2]

        return candidate_keys

    def _get_samples(self, expected, actual):
        """Return ``(expected_sample, matching_actual_rows)`` and set ``self.keys``.

        The expected table is sampled only when it has more rows than
        ``MAX_SAMPLE_SIZE``; smaller (including empty) tables are used whole,
        which also avoids dividing by a zero row count.
        """
        expected_df = spark.read.table(expected.three_level_namespace)
        total_rows = self.expected_row_count
        if total_rows > self.MAX_SAMPLE_SIZE:
            expected_df = expected_df.sample(withReplacement=False, fraction=self.MAX_SAMPLE_SIZE / total_rows)
        expected_df_sample = expected_df.alias("expected")

        self.keys = self._discover_keys(expected_df_sample)

        # A semi join keeps each actual row at most once, even when the keys
        # are not unique in the sample; an inner join would multiply rows and
        # fabricate differences.
        sampled_keys = expected_df_sample.select(*self.keys)
        actual_df_sample = (
            spark.read.table(actual.three_level_namespace)
            .join(sampled_keys, on=self.keys, how="left_semi")
            .alias("actual")
        )
        return expected_df_sample, actual_df_sample

    def compare(self):
        """Compare the sampled expected rows with their counterparts in the actual table.

        Prints ``SUCCESS`` or a failure message and returns ``True`` only when
        no row hash is missing on either side. The reported number is the count
        of unmatched rows on the expected side plus the count on the actual
        side, so a row whose content changed is counted once per side.
        """
        expected_df_sample, actual_df_sample = self._get_samples(self.expected_table, self.actual_table)

        data_cols = self.common_cols
        null_token = self._NULL_TOKEN

        def add_hash(df):
            # Cast to string and replace NULLs explicitly so that a NULL in one
            # column cannot be confused with a NULL in another.
            normalized = [F.coalesce(F.col(c).cast("string"), F.lit(null_token)) for c in data_cols]
            return df.withColumn("row_hash", F.sha2(F.concat_ws("||", *normalized), 256))

        expected_df_hashed = add_hash(expected_df_sample).select(*self.keys, "row_hash")
        actual_df_hashed = add_hash(actual_df_sample).select(*self.keys, "row_hash")

        diff_expected = expected_df_hashed.exceptAll(actual_df_hashed)
        diff_actual = actual_df_hashed.exceptAll(expected_df_hashed)

        # Differences on the two sides must be added; subtracting them lets
        # equal numbers of mismatches cancel out and report a false SUCCESS.
        issue_count = diff_expected.count() + diff_actual.count()

        if issue_count == 0:
            print("SUCCESS")
            return True
        else:
            print(f"FAILED: Found {issue_count} discrepencies")
            return False

In [0]:
# Smoke test: compare dev.test.employees against itself.
tc = TableComparator(
    expected_table="dev.test.employees",
    actual_table="dev.test.employees",
)
tc.compare()