# Data Quality Framework

This notebook runs data quality tests defined in a YAML configuration file **in parallel** using Spark SQL.  
Each test executes as a separate concurrent task, speeding up validation across multiple tables and checks.

Tests supported include:
- Row count minimum
- Null checks on specified columns
- Minimum and maximum value checks on columns

Results are collected with a pass/fail status and metric value per test, displayed when `debug` is enabled, and appended to the `dbo.data_quality_results` Delta table.

### Example YAML Config Structure (`checks.yml`)

```yaml
tables:
  - name: silver.customers
    tests:
      - type: row_count
        min: 1

      - type: not_null
        column: customer_id

      - type: min_value
        column: total_amount
        min: 0

      - type: max_value
        column: total_amount
        max: 10000

  - name: silver.orders
    tests:
      - type: row_count
        min: 1

      - type: not_null
        column: order_id
```

In [None]:
# Notebook parameters
debug = True  # when true, per-check results are printed and the results DataFrame is shown
yaml_file_path = "/lakehouse/default/Files/Tests/checks.yml"  # YAML file listing tables and their tests
parallelism = 5  # intended number of checks to run concurrently

In [None]:
import yaml
from concurrent.futures import ThreadPoolExecutor, as_completed

# Load the check definitions. The encoding is pinned so parsing does not depend
# on the platform's default locale (YAML without a BOM is UTF-8).
with open(yaml_file_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# An empty file parses to None, and a file without a `tables` list has nothing
# to run. Fail here with a clear message instead of a TypeError/KeyError later.
if not isinstance(config, dict) or not isinstance(config.get("tables"), list):
    raise ValueError(f"{yaml_file_path}: expected a top-level 'tables' list")

# Accumulates one (table, test_name, passed, value) tuple per executed check.
results = []

def run_test(table_cfg, test):
    """Run a single data quality check against one table via Spark SQL.

    Args:
        table_cfg: Table entry from the config; only its ``"name"`` key is
            read here.
        test: Test entry from the config. ``"type"`` selects the check:
            ``row_count`` reads ``"min"``; ``not_null`` reads ``"column"``;
            ``min_value`` reads ``"column"`` and ``"min"``; ``max_value``
            reads ``"column"`` and ``"max"``.

    Returns:
        A ``(table, test_name, passed, value)`` tuple. ``value`` is the
        measured metric (row count, NULL count, column minimum or maximum).
        It is ``None`` for an unknown test type, or when the minimum/maximum
        is NULL, in which case the check is reported as failed.

    Raises:
        KeyError: If a key required by the selected test type is missing.
    """
    # NOTE(review): table and column names are interpolated into the SQL text
    # unescaped, so the YAML config must come from a trusted source.
    table = table_cfg["name"]
    test_type = test["type"]

    if test_type == "row_count":
        sql = f"SELECT COUNT(*) AS cnt FROM {table}"
        count = spark.sql(sql).collect()[0]["cnt"]
        return (table, "row_count", count >= test["min"], count)

    if test_type == "not_null":
        col = test["column"]
        sql = f"SELECT COUNT(*) AS cnt_nulls FROM {table} WHERE {col} IS NULL"
        null_count = spark.sql(sql).collect()[0]["cnt_nulls"]
        return (table, f"not_null_{col}", null_count == 0, null_count)

    if test_type == "min_value":
        col = test["column"]
        sql = f"SELECT MIN({col}) AS min_val FROM {table}"
        min_val = spark.sql(sql).collect()[0]["min_val"]
        # MIN() is NULL for an empty table or an all-NULL column. Comparing
        # None with a number raises TypeError, so report the check as failed.
        passed = min_val is not None and min_val >= test["min"]
        return (table, f"min_value_{col}", passed, min_val)

    if test_type == "max_value":
        col = test["column"]
        sql = f"SELECT MAX({col}) AS max_val FROM {table}"
        max_val = spark.sql(sql).collect()[0]["max_val"]
        # Same NULL handling as the min_value check.
        passed = max_val is not None and max_val <= test["max"]
        return (table, f"max_value_{col}", passed, max_val)

    # Unrecognised test types are reported as failures rather than raising.
    return (table, f"unknown_test_type_{test_type}", False, None)

# Fan the checks out across a thread pool sized by the `parallelism` notebook
# parameter, so changing that parameter actually changes the concurrency.
with ThreadPoolExecutor(max_workers=parallelism) as executor:
    # Map each future back to its table and test so a failure can be attributed.
    future_to_test = {}
    for table_cfg in config["tables"]:
        for test in table_cfg["tests"]:
            future = executor.submit(run_test, table_cfg, test)
            future_to_test[future] = (table_cfg["name"], test)

    # Results are appended in completion order, not config order.
    for future in as_completed(future_to_test):
        table_name, test = future_to_test[future]
        try:
            results.append(future.result())
        except Exception as exc:
            # One broken check (missing table, bad column, malformed entry)
            # must not discard the results of every other check. Record it as
            # a failed result and keep collecting.
            test_type = test.get("type", "unknown") if isinstance(test, dict) else "unknown"
            print(f"ERROR: {table_name} - {test_type}: {exc!r}")
            results.append((table_name, f"{test_type}_error", False, None))

# When debugging, print one PASS/FAIL line per collected result
if debug:
    for table, test_name, passed, val in results:
        status = "PASS" if passed else "FAIL"
        print(f"{table} - {test_name} => {status} (Value: {val})")


In [None]:
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, StringType, BooleanType, DoubleType, TimestampType
import pyspark.sql.functions as F

import datetime

# Explicit schema for the persisted results; only the metric value is nullable
schema = StructType(
    [
        StructField("table_name", StringType(), False),
        StructField("test_name", StringType(), False),
        StructField("passed", BooleanType(), False),
        StructField("value", DoubleType(), True),
        StructField("run_timestamp", TimestampType(), False),
    ]
)

# A single timestamp is taken once so every row of this run shares it
now = datetime.datetime.now()


def _to_row(result):
    """Convert one (table, test_name, passed, value) tuple into a Row."""
    table_name, test_name, passed, value = result
    # The value column is a double, so present metrics are cast to float
    metric = None if value is None else float(value)
    return Row(
        table_name=table_name,
        test_name=test_name,
        passed=passed,
        value=metric,
        run_timestamp=now,
    )


rows = [_to_row(result) for result in results]

# Build the DataFrame against the explicit schema
results_df = spark.createDataFrame(rows, schema=schema)

# Append this run's results to the Delta table
(
    results_df.write
    .format("delta")
    .mode("append")
    .saveAsTable("dbo.data_quality_results")
)

if debug:
    # Display the rows that were just written, without truncating long values
    results_df.show(truncate=False)
