In [0]:
%pip install databricks-labs-dqx==0.9.2

In [0]:
'''
DQ framework runner

1. 'run_validations_for_tables' orchestrates the whole run (one or many tables)
2. 'execute_builtin_rule' runs the BUILT-IN rules
3. 'execute_custom_rule' runs the custom rules (Python or SQL)
4. 'run_validations_for_table' runs every rule of a single table, logging traceability
'''

# 1. Imports
import json, time, uuid
from datetime import datetime
from typing import Callable, Tuple
from pyspark.sql import DataFrame,SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, lag, when, col, lit, expr, sum as sum_
import importlib
import databricks.labs.dqx.check_funcs as dqx
import inspect
import utils.custom_rules_library_py as cr
import utils.dq_utils as dq
importlib.reload(cr)
importlib.reload(dq)

# Notebook-wide Spark session (getOrCreate returns the active session when one exists).
spark = SparkSession.builder.getOrCreate()

# 2. Widgets: Unity Catalog location of the framework tables

dbutils.widgets.text("catalog_name", "workspace", "Catálogo de UC donde residen las tablas")
dbutils.widgets.text("schema_name", "framework_dq", "Esquema de UC donde residen las tablas")

CATALOG, SCHEMA = (
    dbutils.widgets.get(widget_key) for widget_key in ("catalog_name", "schema_name")
)

# 3. Fully-qualified names of the framework tables, all under <catalog>.<schema>
_UC_NAMESPACE = ".".join((CATALOG, SCHEMA))
TABLE_CONFIG = _UC_NAMESPACE + ".dq_tables_config"
RULE_LIB_TABLE = _UC_NAMESPACE + ".dq_rules_library"
CATALOG_TABLE = _UC_NAMESPACE + ".dq_validations_catalog"
EXECUTION_TABLE = _UC_NAMESPACE + ".dq_execution_traceability"
VALIDATIONS_TABLE = _UC_NAMESPACE + ".dq_validations_traceability"

# 4. Load the rule library.
# Every row and every column is read: the registry loop below needs the CUSTOM
# rows as well as the optional columns (`sql_function`,
# `requires_literal_conversion`). Filtering to BUILT-IN and projecting only two
# columns left the custom whitelist and the literal-conversion list always empty.
df_rules_library = spark.table(RULE_LIB_TABLE).collect()

# 5. Build the dynamic registries:
#    - BUILT_IN_RULE_FUNCTIONS: technical rule name -> DQX check function
#    - ROW_LEVEL_FUNCS / DATASET_LEVEL_FUNCS: built-in rule names split by rule level
#    - CUSTOM_RULE_FUNCTIONS: technical rule name -> Python callable from `cr`,
#      or the fully-qualified name (str) of a SQL function in the catalog
#    - RULES_REQUIRING_LITERAL_CONVERSION: rules whose string "value"/"limit"
#      parameters must be wrapped in lit() before calling the check

BUILT_IN_RULE_FUNCTIONS = {}
ROW_LEVEL_FUNCS = set()
DATASET_LEVEL_FUNCS = set()
CUSTOM_RULE_FUNCTIONS = {}
RULES_REQUIRING_LITERAL_CONVERSION = []

for row in df_rules_library:
    # asDict() + .get() tolerates optional columns that are missing from the table.
    rule_def = row.asDict()
    tech_name = rule_def.get("technical_rule_name")
    if not tech_name:
        continue

    implementation_type = (rule_def.get("implementation_type") or "").upper()
    rule_level = (rule_def.get("rule_level") or "").upper()

    if implementation_type == "BUILT-IN":
        if hasattr(dqx, tech_name):
            BUILT_IN_RULE_FUNCTIONS[tech_name] = getattr(dqx, tech_name)
            if rule_level == "ROW":
                ROW_LEVEL_FUNCS.add(tech_name)
            elif rule_level == "DATASET":
                DATASET_LEVEL_FUNCS.add(tech_name)
        else:
            print(f"Regla BUILT-IN '{tech_name}' no existe en DQX check_funcs; se omite")

    elif implementation_type == "CUSTOM":
        if hasattr(cr, tech_name):
            # Python implementation shipped in the custom rules module.
            CUSTOM_RULE_FUNCTIONS[tech_name] = getattr(cr, tech_name)
        else:
            # SQL function registered in the catalog. When the library does not
            # name the function explicitly, fall back to the technical rule name
            # (this mirrors the former hard-coded mapping, e.g.
            # "is_valid_nif_es" -> "<catalog>.<schema>.is_valid_nif_es").
            sql_function = rule_def.get("sql_function") or tech_name
            CUSTOM_RULE_FUNCTIONS[tech_name] = f"{CATALOG}.{SCHEMA}.{sql_function}"

    # Rules flagged in the library as needing literal conversion.
    if rule_def.get("requires_literal_conversion") and tech_name not in RULES_REQUIRING_LITERAL_CONVERSION:
        RULES_REQUIRING_LITERAL_CONVERSION.append(tech_name)

# 6. Helper that wraps a user-supplied SQL query
def sql_query_wrapper(user_query: str, merge_columns: list[str]) -> Tuple[col, Callable]:
    """Wrap a user SQL query so it does not have to define a 'condition' column itself.

    Args:
        user_query: SQL text executed through ``spark.sql``.
        merge_columns: Must be non-empty; it is only validated here and is not
            used to merge anything.

    Returns:
        A ``(condition_col, apply_func)`` pair. ``apply_func(df, spark, _)``
        ignores ``df``, runs ``user_query`` and adds a constant-True column
        named ``condition_tmp``; ``condition_col`` references that column.

    Raises:
        ValueError: If ``merge_columns`` is empty.
    """
    if not merge_columns:
        raise ValueError("merge_columns no puede estar vacío")

    flag_column = "condition_tmp"

    def run_user_query(df: DataFrame, spark: SparkSession, _ref_dfs: dict) -> DataFrame:
        # Every row returned by the query is flagged with a literal True.
        return spark.sql(user_query).withColumn(flag_column, lit(True))

    return col(flag_column), run_user_query

# 7. Execution of BUILT-IN rules
def execute_builtin_rule(df, rule_name, rule_type, params, primary_key):
    """Run a BUILT-IN (DQX) rule against ``df`` and return the offending rows.

    Args:
        df: DataFrame to validate.
        rule_name: Technical rule name; must be registered in ROW_LEVEL_FUNCS
            or DATASET_LEVEL_FUNCS.
        rule_type: Not used by this function; kept for call compatibility.
        params: Keyword arguments forwarded to the check function. The dict is
            copied before being adapted, so the caller's dict is not modified.
        primary_key: Not used by this function; kept for call compatibility.

    Returns:
        ``(df_failed, status)``. ``df_failed`` contains the rows whose
        condition column is not null (presumably the DQX convention of a
        non-null message on failure -- confirm against the DQX docs), or None
        when the rule is unknown. ``status`` is "PASSED" when the rule was
        executed (the caller decides FAILED from the row count) and "ERROR"
        when ``rule_name`` is not a registered built-in.

    Raises:
        RuntimeError: If applying a dataset-level rule fails.
    """
    if rule_name in ROW_LEVEL_FUNCS:
        rule_func = BUILT_IN_RULE_FUNCTIONS[rule_name]
        call_params = dict(params)

        if "df" in inspect.signature(rule_func).parameters:
            condition_col = rule_func(df, **call_params)
        else:
            if "column" in call_params:
                call_params["column"] = col(call_params["column"])
            condition_col = rule_func(**call_params)

        # Apply the condition for BOTH call styles. Previously the filter was
        # only applied in the non-df branch, so df-taking checks always
        # returned None and were reported as passed.
        return df.filter(condition_col.isNotNull()), "PASSED"

    if rule_name in DATASET_LEVEL_FUNCS:
        # 1. Obtain the condition column and the function that builds the checked DataFrame
        if rule_name == "sql_query":
            condition_col, apply_func_dqx = sql_query_wrapper(params["query"], params["merge_columns"])
        else:
            condition_col, apply_func_dqx = BUILT_IN_RULE_FUNCTIONS[rule_name](**params)

        # 2. Adapt to the arity of the returned function: (df), (df, spark) or
        #    (df, spark, context). Any other parameter count gets all three.
        declared_params = len(inspect.signature(apply_func_dqx).parameters)
        apply_args = (df, spark, {})
        if declared_params in (1, 2):
            apply_args = apply_args[:declared_params]

        try:
            df_failed = apply_func_dqx(*apply_args).filter(condition_col.isNotNull())
        except Exception as e:
            # Chain the original exception so the root cause stays in the traceback.
            raise RuntimeError(f"Error ejecutando la lógica interna de {rule_name}: {str(e)}") from e
        return df_failed, "PASSED"

    print(f"BUILT-IN desconocida: {rule_name}")
    return None, "ERROR"

# 8. Execution of Custom rules
def execute_custom_rule(df, rule_name, params):
    """Run a Custom rule (Python callable or SQL function) on ``df``.

    Args:
        df: DataFrame to validate.
        rule_name: Key looked up in CUSTOM_RULE_FUNCTIONS (the whitelist).
        params: Rule parameters; ``params["columns"]`` lists the columns the
            rule is applied to and must be non-empty.

    Returns:
        ``(df_failed, status)``. For a SQL rule ``df_failed`` holds the rows
        where the function evaluates to false; for a Python rule it is
        whatever the callable returns. ``status`` is "PASSED" when the rule
        ran, and ``(None, "ERROR")`` is returned when the rule is
        misconfigured, not whitelisted, or of an unsupported kind.
    """
    target_columns = params.get("columns", [])
    if not target_columns:
        print(f"Regla custom mal configurada: {rule_name}, no se especificaron columnas")
        return None, "ERROR"

    # Only whitelisted rules may run.
    registered = CUSTOM_RULE_FUNCTIONS.get(rule_name)
    if not registered:
        print(f"UDF '{rule_name}' no está en la whitelist y no se puede ejecutar")
        return None, "ERROR"

    # SQL rule: the registry stores the fully-qualified function name as a string.
    if isinstance(registered, str):
        quoted_args = ", ".join(f"`{column_name}`" for column_name in target_columns)
        rule_passes = expr(f"{registered}({quoted_args})")
        return df.filter(~rule_passes), "PASSED"

    # Python rule: the registry stores the callable itself.
    if callable(registered):
        return registered(df, params.get("columns")), "PASSED"

    # Registry entry of an unsupported kind.
    print(f"Regla custom desconocida: {rule_name}")
    return None, "ERROR"

# 9. Per-table runner
def run_validations_for_table(table_name, exec_timestamp, execution_id, severity_param=None, validation_id_param=None):
    """Run every active validation configured for one table.

    Args:
        table_name: Value matched against ``table_name`` in the table-config table.
        exec_timestamp: datetime of the run; its date is logged with each result.
        execution_id: Identifier of the enclosing execution, passed to the loggers.
        severity_param: Optional; when truthy only validations with this severity run.
        validation_id_param: Optional; when truthy only this validation runs.

    Returns:
        ``(status, validations_executed, validations_failed)``. ``status`` is
        "SKIPPED" when the table has no configuration, "SUCCESS" when no
        validation failed or errored (including when there is nothing to run),
        and "FAILED" otherwise.
    """
    spark = SparkSession.builder.getOrCreate()
    exec_date = exec_timestamp.date()

    # Read the table configuration
    config_row = spark.table(TABLE_CONFIG).filter(col("table_name") == table_name).first()
    if config_row is None:
        print(f"No se encontró configuración para la tabla '{table_name}'.")
        return "SKIPPED", 0, 0

    table_id = config_row.table_id
    table_tech_name = config_row.table_name_tech
    staging_table = config_row.staging_evidences_table
    primary_key = config_row.primary_key

    # Select the active validations. The filter is built from Column
    # expressions instead of string-concatenated SQL so that widget-supplied
    # values (severity / validation id) cannot break or alter the predicate.
    active_filter = (col("table_id") == lit(table_id)) & (col("is_active") == lit(True))
    if severity_param:
        active_filter = active_filter & (col("severity") == lit(severity_param))
    if validation_id_param:
        active_filter = active_filter & (col("validation_id") == lit(validation_id_param))

    df_validations = spark.table(CATALOG_TABLE).filter(active_filter)
    if df_validations.limit(1).count() == 0:
        print(f"No hay validaciones activas para {table_name}.")
        # NOTE(review): this logs the execution as finished from inside the
        # per-table runner; run_validations_for_tables logs the finish again
        # once all tables are done. Kept for callers that use this function
        # directly -- confirm whether the early log is still wanted.
        dq.log_execution_finish(execution_id, "SUCCESS", 0, 0, 0, EXECUTION_TABLE)
        return "SUCCESS", 0, 0

    df_rules_lib = spark.table(RULE_LIB_TABLE).select("rule_id", "technical_rule_name", col("implementation_type").alias("rule_type"))
    rules_to_run = df_validations.join(df_rules_lib, "rule_id", "left").collect()

    df = spark.table(table_tech_name)
    # Counted once; reused as the perimeter size whenever no perimeter applies.
    total_records = df.count()
    validations_executed = 0
    validations_failed = 0

    for rule in rules_to_run:
        validations_executed += 1
        validation_id = rule.validation_id
        rule_id = rule.rule_id
        rule_type = rule.rule_type
        technical_rule_name = rule.technical_rule_name

        # Bound before the try so the error handler always has a value for
        # THIS rule (never an unbound name or a leftover from the previous rule).
        total_records_perimeter = total_records

        try:
            # Parsed inside the try: a malformed definition marks only this
            # validation as ERROR instead of aborting the whole table.
            validation_params = json.loads(rule.validation_definition or "{}") or {}
            failed_field = validation_params.get("column") or (validation_params.get("columns") or [primary_key])[0]

            # Literal conversion: some rules need a literal, not a column
            # reference, for their string "value"/"limit" parameters.
            if technical_rule_name in RULES_REQUIRING_LITERAL_CONVERSION:
                for key in ["value", "limit"]:
                    if key in validation_params and isinstance(validation_params[key], str):
                        validation_params[key] = lit(validation_params[key])

            # Apply the perimeter filter when one is defined
            perimeter_sql = getattr(rule, "perimeter_definition", None)
            df_filtered = df

            if perimeter_sql:
                try:
                    df_perimeter = df.filter(perimeter_sql)
                    total_records_perimeter = df_perimeter.count()
                    df_filtered = df_perimeter
                    print(f"Aplicando filtro de perímetro: '{perimeter_sql}' -> {total_records_perimeter} registros")
                except Exception as e:
                    # Best effort: an invalid perimeter falls back to the full dataset.
                    print(f"Filtro de perímetro inválido '{perimeter_sql}': {e}")
                    print("Ejecutando validación sobre el dataset completo")
                    df_filtered = df
                    total_records_perimeter = total_records

            # Run the rule on the perimeter-filtered data. Previously the
            # unfiltered `df` was passed, so the perimeter was logged but never applied.
            df_failed = None
            status = "PASSED"
            if rule_type == "BUILT-IN":
                df_failed, status = execute_builtin_rule(df_filtered, technical_rule_name, rule_type, validation_params, primary_key)
            elif rule_type == "CUSTOM":
                df_failed, status = execute_custom_rule(df_filtered, technical_rule_name, validation_params)
            else:
                status = "ERROR"
                print(f"Tipo de regla desconocido: {rule_type}")

            failed_count = 0
            if status == "ERROR":
                # A rule that could not be executed counts as failed, in line
                # with the exception path below.
                validations_failed += 1
            elif df_failed is not None and df_failed.limit(1).count() > 0:
                failed_count = df_failed.count()
                validations_failed += 1
                status = "FAILED"
                dq.log_evidences_staging(df_failed, staging_table, execution_id, exec_date, validation_id, primary_key, failed_field, CATALOG, SCHEMA)

            dq.log_validations_traceability(execution_id, exec_date, validation_id, rule_id, status, total_records_perimeter, failed_count, VALIDATIONS_TABLE)

        except Exception as e:
            validations_failed += 1
            print(f"Error crítico en regla {technical_rule_name}: {e}")
            # -1 marks the failed-record count as unknown for an errored rule.
            dq.log_validations_traceability(execution_id, exec_date, validation_id, rule_id, "ERROR", total_records_perimeter, -1, VALIDATIONS_TABLE)

    status = "SUCCESS" if validations_failed == 0 else "FAILED"
    return status, validations_executed, validations_failed

# 10. Multi-table runner

def run_validations_for_tables(tables_param=None, severity_param=None, validation_id_param=None):
    """Run the validations of the requested tables under one execution id.

    Args:
        tables_param: Comma-separated table names. None, empty/blank or "all"
            (case-insensitive) runs every table listed in the table-config table.
        severity_param: Optional severity filter forwarded to each table run.
        validation_id_param: Optional single-validation filter forwarded to
            each table run.

    Returns:
        None. The execution start and finish are recorded through the
        ``dq.log_execution_start`` / ``dq.log_execution_finish`` helpers.
    """
    spark = SparkSession.builder.getOrCreate()
    tables_param_clean = tables_param.strip() if tables_param else None

    if not tables_param_clean or tables_param_clean.lower() == "all":
        tables_to_run = [r.table_name for r in spark.table(TABLE_CONFIG).select("table_name").collect()]
    else:
        tables_to_run = [t.strip() for t in tables_param_clean.split(",") if t.strip()]

    exec_timestamp = datetime.now()
    execution_id = str(uuid.uuid4())
    dq.log_execution_start(execution_id, exec_timestamp, tables_to_run, EXECUTION_TABLE)

    total_validations_executed = 0
    total_validations_failed = 0
    table_statuses = []

    for table_name in tables_to_run:
        table_status, validations_executed, validations_failed = run_validations_for_table(
            table_name, exec_timestamp, execution_id, severity_param, validation_id_param
        )
        table_statuses.append(table_status)
        total_validations_executed += validations_executed
        total_validations_failed += validations_failed

    # Aggregate over ALL tables. Previously only the last table's status was
    # logged (hiding earlier failures) and the name was unbound when no table
    # was resolved.
    if "FAILED" in table_statuses:
        status = "FAILED"
    elif table_statuses and all(s == "SKIPPED" for s in table_statuses):
        status = "SKIPPED"
    else:
        status = "SUCCESS"

    duration_seconds = round((datetime.now() - exec_timestamp).total_seconds(), 2)
    dq.log_execution_finish(execution_id, status, duration_seconds, total_validations_executed, total_validations_failed, EXECUTION_TABLE)

# 11. Entry point
if __name__ == "__main__":
    # Declare the run widgets; all of them default to an empty string.
    for _widget_key, _widget_label in (
        ("severity_filter", "Filtro de severidad (Opcional)"),
        ("validation_id_filter", "Filtro para 1 validación específica (Opcional)"),
        ("table_name", "Id de la tabla a validar"),
    ):
        dbutils.widgets.text(_widget_key, "", _widget_label)

    # Read the widget values and launch the run.
    tables_param, severity_param, validation_id_param = (
        dbutils.widgets.get(_key)
        for _key in ("table_name", "severity_filter", "validation_id_filter")
    )
    run_validations_for_tables(tables_param, severity_param, validation_id_param)