In [0]:
#%pip install databricks-labs-dqx==0.9.2 

In [0]:
#dbutils.library.restartPython()

In [0]:
#%run /Workspace/Users/glamero17@gmail.com/test_dq/utils/custom_rules_library_py

In [0]:
#%run /Workspace/Users/glamero17@gmail.com/test_dq/utils/dq_utils

In [0]:
# ===============================
# dq_framework_runner_final.py
# ===============================

import json, time, uuid
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, lag, when, sum as sum_

# --- Import DQX dinámico ---
import databricks.labs.dqx.check_funcs as dqx

# --- Import funciones custom ---
from utils.custom_rules_library_py import custom_validate_dni
from utils.dq_utils import log_execution_start, log_execution_finish, log_validations_traceability, log_evidences_staging

spark = SparkSession.builder.getOrCreate()

# --- Table constants: every framework table lives under CATALOG.SCHEMA ---
CATALOG = "workspace"
SCHEMA = "dq_framework"
_FQ_PREFIX = f"{CATALOG}.{SCHEMA}"

TABLE_CONFIG = f"{_FQ_PREFIX}.dq_tables_config"
RULE_LIB_TABLE = f"{_FQ_PREFIX}.dq_rules_library"
CATALOG_TABLE = f"{_FQ_PREFIX}.dq_validations_catalog"
EXECUTION_TABLE = f"{_FQ_PREFIX}.dq_execution_traceability"
VALIDATIONS_TABLE = f"{_FQ_PREFIX}.dq_validations_traceability"

# Read the BUILT-IN rules registered in the rules library.
df_rules_library = spark.table(RULE_LIB_TABLE) \
                .filter(col("implementation_type") == "BUILT-IN") \
                .select("technical_rule_name", "rule_level") \
                .collect()

# technical rule name -> DQX check function, plus the names that run at
# row level and at dataset level respectively.
BUILT_IN_RULE_FUNCTIONS = {}
ROW_LEVEL_FUNCS = set()
DATASET_LEVEL_FUNCS = set()

for row in df_rules_library:
    func_name = row.technical_rule_name
    # hasattr() raises TypeError on a NULL name, so test for it first.
    if not func_name or not hasattr(dqx, func_name):
        print(f"⚠️ Función DQX {func_name} no encontrada")
        continue
    BUILT_IN_RULE_FUNCTIONS[func_name] = getattr(dqx, func_name)
    # rule_level may be NULL in the library table; treat it as "no level"
    # instead of failing the whole module on .upper().
    rule_level = (row.rule_level or "").strip().upper()
    if rule_level == "ROW":
        ROW_LEVEL_FUNCS.add(func_name)
    elif rule_level == "DATASET":
        DATASET_LEVEL_FUNCS.add(func_name)

# Whitelist of CUSTOM / CUSTOM_SQL_UDF rule names. Each name is resolved from the
# custom rules module first, then from the notebook globals (e.g. functions defined
# via %run). A name that cannot be resolved is reported and left out of the
# whitelist, instead of aborting module load with a NameError.
import utils.custom_rules_library_py as _custom_rules_module

_CUSTOM_RULE_NAMES = ("custom_validate_dni", "is_valid_nif_es")
CUSTOM_RULE_FUNCTIONS = {}
for _rule_name in _CUSTOM_RULE_NAMES:
    _rule_func = getattr(_custom_rules_module, _rule_name, None) or globals().get(_rule_name)
    if _rule_func is None:
        print(f"⚠️ Función custom {_rule_name} no encontrada")
    else:
        CUSTOM_RULE_FUNCTIONS[_rule_name] = _rule_func

# Rules whose string "value"/"limit" parameters must be wrapped in lit() before
# being passed to the DQX check function.
RULES_REQUIRING_LITERAL_CONVERSION = [
    "is_equal_to", "is_not_equal_to", "is_not_less_than", "is_not_greater_than"
]

# ===============================
# Ejecutores auxiliares
# ===============================

def execute_builtin_rule(df, rule_name, rule_type, params, primary_key):
    """Run a BUILT-IN (DQX) rule and return the rows it flags.

    Args:
        df: DataFrame to validate.
        rule_name: technical name of the DQX check function; must be present in
            BUILT_IN_RULE_FUNCTIONS and in ROW_LEVEL_FUNCS or DATASET_LEVEL_FUNCS.
        rule_type: implementation type (unused here; kept for interface compatibility).
        params: keyword arguments for the DQX check function, as parsed from the
            validation definition. The caller's dict is not modified.
        primary_key: primary key column (unused here; kept for interface compatibility).

    Returns:
        (df_failed, status): df_failed holds the rows for which the condition column
        returned by the check is not null, or None when the rule is unknown; status
        is "PASSED" or "ERROR".
    """
    # Work on a copy so converting "column" to a Column does not leak back to the caller.
    call_params = dict(params)

    if rule_name in ROW_LEVEL_FUNCS:
        if "column" in call_params:
            call_params["column"] = col(call_params["column"])
        condition_col = BUILT_IN_RULE_FUNCTIONS[rule_name](**call_params)
        return df.filter(condition_col.isNotNull()), "PASSED"

    if rule_name in DATASET_LEVEL_FUNCS:
        # Dataset-level checks return the condition plus a function that
        # prepares the DataFrame the condition is evaluated on.
        condition_col, apply_func = BUILT_IN_RULE_FUNCTIONS[rule_name](**call_params)
        df_failed = apply_func(df, spark, {}).filter(condition_col.isNotNull())
        return df_failed, "PASSED"

    print(f"BUILT-IN desconocida: {rule_name}")
    return None, "ERROR"

def execute_custom_rule(df, rule_name, params):
    """Run a CUSTOM (Python) rule registered in CUSTOM_RULE_FUNCTIONS.

    The registered function is called as ``func(df, params.get("columns"))`` and
    whatever it returns is handed back as the failing-rows DataFrame.

    Returns:
        (df_failed, status): (function result, "PASSED") when the rule is
        registered, otherwise (None, "ERROR").
    """
    rule_func = CUSTOM_RULE_FUNCTIONS.get(rule_name)
    if not rule_func:
        print(f"CUSTOM desconocida: {rule_name}")
        return None, "ERROR"
    return rule_func(df, params.get("columns")), "PASSED"

def execute_sql_rule(df, rule_name, params):
    """Run a CUSTOM_SQL_UDF rule and return the rows it rejects.

    Builds the SQL expression ``rule_name(`col1`, `col2`, ...)`` and keeps the rows
    where ``NOT expression`` is true, i.e. where the UDF returns false. Rows where
    the UDF returns NULL are dropped by the filter.

    Args:
        df: DataFrame to validate.
        rule_name: name of the SQL UDF; must be whitelisted in CUSTOM_RULE_FUNCTIONS
            because it is interpolated into a SQL expression.
        params: validation parameters; "columns" is a list of column names (a single
            string is accepted as a one-column list).

    Returns:
        (df_failed, status): (failing rows, "PASSED"), or (None, "ERROR") when the
        rule has no columns configured.

    Raises:
        ValueError: if rule_name is not in the whitelist.
    """
    columns = params.get("columns", [])
    if isinstance(columns, str):
        # A bare string would otherwise be iterated character by character.
        columns = [columns]

    if not CUSTOM_RULE_FUNCTIONS.get(rule_name):
        raise ValueError(f"UDF '{rule_name}' no está en la whitelist y no se puede ejecutar")

    if not (rule_name and columns):
        print(f"CUSTOM_SQL_UDF mal configurada: {rule_name}")
        return None, "ERROR"

    # `expr` is not imported at module level in this file.
    from pyspark.sql.functions import expr

    # Backtick-quote each column, doubling embedded backticks (Spark SQL identifier escaping).
    quoted_columns = ", ".join("`{}`".format(str(c).replace("`", "``")) for c in columns)
    df_failed = df.filter(~expr(f"{rule_name}({quoted_columns})"))
    return df_failed, "PASSED"
    
# ===============================
# Runner por tabla
# ===============================

def _build_validations_filter(table_id, severity_param=None, validation_id_param=None):
    """Return the Column predicate that selects a table's active validations.

    Column comparisons are used instead of an interpolated SQL string, so widget
    values containing quotes cannot break (or inject into) the filter.
    """
    predicate = (col("table_id") == table_id) & (col("is_active") == lit(True))
    if severity_param:
        predicate = predicate & (col("severity") == severity_param)
    if validation_id_param:
        predicate = predicate & (col("validation_id") == validation_id_param)
    return predicate


def _apply_perimeter(df, perimeter_sql, total_records):
    """Restrict df to a validation's perimeter, falling back to the full dataset.

    Args:
        df: full DataFrame of the validated table.
        perimeter_sql: SQL filter expression, or a falsy value for "no perimeter".
        total_records: already-computed row count of df, reused when no (valid)
            perimeter applies instead of running another count.

    Returns:
        (df_filtered, record_count) for the rows the validation applies to.
    """
    if not perimeter_sql:
        return df, total_records
    try:
        df_filtered = df.filter(perimeter_sql)
        total_records_perimeter = df_filtered.count()
    except Exception as e:
        print(f"⚠️ Filtro de perímetro inválido '{perimeter_sql}': {e}")
        print("Ejecutando validación sobre el dataset completo.")
        return df, total_records
    print(f"Aplicando filtro de perímetro: '{perimeter_sql}' -> {total_records_perimeter} registros")
    return df_filtered, total_records_perimeter


def _run_rule(df, rule_type, technical_rule_name, validation_params, primary_key):
    """Dispatch a rule to its executor by implementation type; returns (df_failed, status)."""
    if rule_type == "BUILT-IN":
        return execute_builtin_rule(df, technical_rule_name, rule_type, validation_params, primary_key)
    if rule_type == "CUSTOM":
        return execute_custom_rule(df, technical_rule_name, validation_params)
    if rule_type == "CUSTOM_SQL_UDF":
        return execute_sql_rule(df, technical_rule_name, validation_params)
    print(f"Tipo de regla desconocido: {rule_type}")
    return None, "ERROR"


def run_validations_for_table(table_name, severity_param=None, validation_id_param=None):
    """Run every active validation configured for one table and log the results.

    Reads the table's row from TABLE_CONFIG and logs the execution start. It then
    selects the active validations, optionally filtered by severity and
    validation_id, and runs each rule on the rows inside the validation's
    perimeter. When a rule flags rows, they are passed to log_evidences_staging.
    One traceability row is logged per validation, followed by the execution
    finish.

    A validation counts as failed when it flags rows, ends in status "ERROR"
    (unknown rule or rule type), or raises.

    Args:
        table_name: functional table name as stored in TABLE_CONFIG.table_name.
        severity_param: optional severity filter; falsy means no filter.
        validation_id_param: optional validation_id filter; falsy means no filter.

    Returns:
        (execution_id, status) with status "SUCCESS" or "FAILED", or None when the
        table has no configuration (no execution is logged in that case).
    """
    spark = SparkSession.builder.getOrCreate()
    execution_id = str(uuid.uuid4())
    exec_timestamp = datetime.now()
    exec_date = exec_timestamp.date()

    # --- Table configuration ---
    config_row = spark.table(TABLE_CONFIG).filter(col("table_name") == table_name).first()
    if not config_row:
        print(f"No se encontró configuración para tabla '{table_name}'.")
        return None

    table_id = config_row.table_id
    table_tech_name = config_row.table_name_tech
    staging_table = config_row.staging_evidences_table
    primary_key = config_row.primary_key

    log_execution_start(execution_id, exec_timestamp, table_id, EXECUTION_TABLE)

    # --- Validations to run ---
    df_validations = spark.table(CATALOG_TABLE).filter(
        _build_validations_filter(table_id, severity_param, validation_id_param)
    )
    if df_validations.isEmpty():
        print(f"No hay validaciones activas para {table_name}.")
        log_execution_finish(execution_id, "SUCCESS", 0, 0, 0, EXECUTION_TABLE)
        return execution_id, "SUCCESS"

    df_rules_lib = spark.table(RULE_LIB_TABLE).select(
        "rule_id", "technical_rule_name", col("implementation_type").alias("rule_type")
    )
    rules_to_run = df_validations.join(df_rules_lib, "rule_id", "left").collect()

    df = spark.table(table_tech_name)
    total_records = df.count()
    validations_executed = 0
    validations_failed = 0

    for v in rules_to_run:
        validations_executed += 1
        validation_id = v.validation_id
        rule_id = v.rule_id
        rule_type = v.rule_type
        technical_rule_name = v.technical_rule_name
        # Value logged by the error path if anything below raises before the
        # perimeter is resolved (never left unbound or stale from a previous rule).
        total_records_perimeter = total_records

        try:
            # Parsing is inside the try so one malformed definition is logged as
            # ERROR instead of aborting the whole table run.
            validation_params = json.loads(v.validation_definition or "{}")
            columns = validation_params.get("columns")
            if isinstance(columns, str):
                columns = [columns]
            failed_field = validation_params.get("column") or (columns or [primary_key])[0]

            # Literal conversion for comparison rules.
            if technical_rule_name in RULES_REQUIRING_LITERAL_CONVERSION:
                for key in ("value", "limit"):
                    if isinstance(validation_params.get(key), str):
                        validation_params[key] = lit(validation_params[key])

            # The rule runs on the perimeter rows, so the failing rows and the
            # logged record count cover the same set of rows.
            df_filtered, total_records_perimeter = _apply_perimeter(
                df, getattr(v, "perimeter_definition", None), total_records
            )

            df_failed, rule_status = _run_rule(
                df_filtered, rule_type, technical_rule_name, validation_params, primary_key
            )

            failed_count = df_failed.count() if df_failed is not None else 0
            if failed_count > 0:
                rule_status = "FAILED"
                log_evidences_staging(df_failed, staging_table, execution_id, exec_date, validation_id, primary_key, failed_field, CATALOG, SCHEMA)
            if rule_status in ("FAILED", "ERROR"):
                validations_failed += 1

            log_validations_traceability(execution_id, exec_date, validation_id, rule_id, rule_status, total_records_perimeter, failed_count, VALIDATIONS_TABLE)

        except Exception as e:
            validations_failed += 1
            print(f"Error crítico en regla {technical_rule_name}: {e}")
            log_validations_traceability(execution_id, exec_date, validation_id, rule_id, "ERROR", total_records_perimeter, -1, VALIDATIONS_TABLE)

    # --- End of table execution ---
    duration_seconds = round((datetime.now() - exec_timestamp).total_seconds(), 2)
    status = "SUCCESS" if validations_failed == 0 else "FAILED"
    log_execution_finish(execution_id, status, duration_seconds, validations_executed, validations_failed, EXECUTION_TABLE)
    print(f"Tabla {table_name} completada con estado: {status}")
    return execution_id, status

# ===============================
# Runner por job (múltiples tablas)
# ===============================

def run_validations_for_tables(tables_param=None, severity_param=None, validation_id_param=None):
    """Run the validations of several tables, one after another.

    Args:
        tables_param: comma-separated table names as stored in TABLE_CONFIG.
            None, an empty or blank string, or "all" (any case) selects every
            configured table.
        severity_param: optional severity filter forwarded to each table run.
        validation_id_param: optional validation_id filter forwarded to each table run.
    """
    spark = SparkSession.builder.getOrCreate()
    requested = (tables_param or "").strip()

    if requested and requested.lower() != "all":
        # Explicit list: trim names and drop empty entries such as "a,,b".
        tables_to_run = list(filter(None, (name.strip() for name in requested.split(","))))
    else:
        config_rows = spark.table(TABLE_CONFIG).select("table_name").collect()
        tables_to_run = [config_row.table_name for config_row in config_rows]

    print(f"Tablas a ejecutar: {tables_to_run}")

    for current_table in tables_to_run:
        run_validations_for_table(current_table, severity_param, validation_id_param)

# ===============================
# Punto de entrada
# ===============================

if __name__ == "__main__":
    # Job parameters come from Databricks notebook widgets. "table_name" may be
    # empty or "ALL"; empty filter values are treated downstream as "no filter".
    widget_names = ("table_name", "severity_filter", "validation_id_filter")
    widget_values = [dbutils.widgets.get(widget_name) for widget_name in widget_names]
    run_validations_for_tables(*widget_values)