### Imports

In [None]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, countDistinct, count, isnan, lit
from pyspark.sql.utils import AnalysisException
import json

In [None]:
# Load the table metadata (JSON) that drives the checks further down.
# The location is hard-coded here but can be parameterised.
#
# Forward slashes are used on purpose: Python resolves them on Windows as well
# as on POSIX systems, whereas a backslash is an ordinary filename character on
# Linux/macOS and would make this relative path unresolvable there.
metadata_path = '../output_obj/tbl_metadata.json'

# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default text encoding.
with open(metadata_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)

print("Loaded JSON data:", metadata)

### Set up Spark session

In [None]:
# Local-mode alternative, kept for reference:
# spark = (
#     SparkSession.builder
#     .appName("LocalTest")
#     .master("local[*]")
#     .getOrCreate()
# )

# Create (or reuse) the session; spark.jars.packages carries the Maven
# coordinates of the Azure storage connector jars.
spark = (
    SparkSession.builder
    .appName("MyDockerSparkApp")
    .config(
        "spark.jars.packages",
        "org.apache.hadoop:hadoop-azure:3.3.0,com.microsoft.azure:azure-storage:8.6.6",
    )
    .getOrCreate()
)

# Debug via docker
spark.conf.set(
    "fs.azure.impl",
    "org.apache.hadoop.fs.azure.NativeAzureFileSystem",
)

# Report the versions in play: PySpark from the Python package, Hadoop from
# the JVM that backs this session.
print("PySpark version:", pyspark.__version__)
hadoop_version = (
    spark.sparkContext._jvm.org.apache.hadoop.util.VersionInfo.getVersion()
)
print("Hadoop version:", hadoop_version)

# spark.stop()

### Read from UC

In [None]:
# Catalog and schema whose tables are checked below.
# Hard-coded here; can be parameterised in an actual use case.
catalog_name, schema_name = "data_foundation_dev", "raw"

In [None]:
# Build per-table lookups from the metadata: table name -> its 'pk' entry and
# table name -> its 'fk' entry.
pk = {table: table_meta['pk'] for table, table_meta in metadata.items()}
fk = {table: table_meta['fk'] for table, table_meta in metadata.items()}

In [None]:

# Run the data-quality checks on every table of the schema and collect one
# result dict per table in `results`. A table that raises AnalysisException is
# recorded with an "error" entry instead of check results.
results = []

# The table listing lives on the session's `catalog` attribute.
tables = [t.name for t in spark.catalog.listTables(f"{catalog_name}.{schema_name}")]

for tbl_name in tables:
    three_pt_name = f"{catalog_name}.{schema_name}.{tbl_name}"
    print(f"Checking table: {three_pt_name}")

    try:
        df = spark.read.table(three_pt_name)
        row_count = df.count()
        table_result = {"table": tbl_name, "row_count": row_count}

        # Check 1: Pk uniqueness
        # A table present in the catalog but absent from the metadata (or with
        # an empty pk entry) is reported as SKIPPED instead of aborting the
        # whole run with a KeyError.
        pk_col = pk.get(tbl_name)
        if pk_col:
            pk_count = df.select(pk_col).dropDuplicates().count()
            is_unique = pk_count == row_count
            table_result["pk_check"] = "PASS" if is_unique else "FAIL"
        else:
            table_result["pk_check"] = "SKIPPED"

        # Check 2: Not nulls
        # isnan() only accepts float/double input; applying it to e.g. a
        # date or timestamp column raises AnalysisException, so the NaN test
        # is limited to floating-point columns.
        float_cols = {name for name, dtype in df.dtypes if dtype in ("float", "double")}
        nulls = {}
        for col_name in df.columns:
            is_missing = col(col_name).isNull()
            if col_name in float_cols:
                is_missing = is_missing | isnan(col(col_name))
            nulls[col_name] = df.filter(is_missing).count()
        null_columns = [k for k, v in nulls.items() if v > 0]
        table_result["null_columns"] = ", ".join(null_columns) if null_columns else "None"

        # Check 3: Fk existence
        fk_ref = fk.get(tbl_name)
        if fk_ref:
            fk_results = []
            for fk_col, ref_table in fk_ref.items():
                ref_df = spark.read.table(f"{catalog_name}.{schema_name}.{ref_table}")
                # Assumes the referenced table exposes the key under the same
                # column name as the referencing table — TODO: validate
                ref_keys = ref_df.select(fk_col).distinct()
                # A left-anti join keeps the rows whose key has no match in the
                # referenced table, without collecting the keys to the driver.
                # NULL foreign keys are excluded first so they are not counted
                # as missing references.
                missing_count = (
                    df.filter(col(fk_col).isNotNull())
                    .join(ref_keys, on=fk_col, how="left_anti")
                    .count()
                )
                fk_results.append(f"{fk_col}: {'PASS' if missing_count == 0 else 'FAIL'}")
            table_result["fk_check"] = ", ".join(fk_results)
        else:
            table_result["fk_check"] = "SKIPPED"

        # Check 4: inferred data types
        # TODO: compare with dtype col in metadata
        types = [f"{field.name}:{field.dataType.simpleString()}" for field in df.schema.fields]
        table_result["columns"] = ", ".join(types)

    except AnalysisException as e:
        # Discard any partial results for this table and record the error.
        table_result = {
            "table": tbl_name,
            "error": str(e)
        }

    results.append(table_result)