# In [None]:
# NOTE: This is test code.
import great_expectations as ge
from great_expectations.checkpoint.checkpoint import Checkpoint
from great_expectations.core.batch import RuntimeBatchRequest
import pandas as pd
from datetime import datetime
import pytz
from pyspark.sql.functions import col, upper, lit, current_timestamp, row_number
from pyspark.sql.types import StringType
from pyspark.sql.window import Window
from uuid import uuid4
import json
import logging
from typing import Dict, List, Optional, Tuple, Any


class DataQualityValidator:
    """
    A class to handle data quality validation using Great Expectations
    and write clean data to Silver layer.
    """
    
    def __init__(self, config: Dict[str, Any], spark_session=None):
        """
        Initialize the Data Quality Validator

        Args:
            config: Configuration dictionary containing paths and settings
            spark_session: Spark session instance (if None, a module-level
                ``spark`` variable is tried first, then the active session)

        Raises:
            RuntimeError: If no session was passed and none could be located.
        """
        self.config = config
        self.logger = self._setup_logging()
        self.context = ge.get_context()
        self.data_assets_cache = {}

        if spark_session is not None:
            self.spark = spark_session
        else:
            self.spark = self._resolve_spark_session()

        # Set required Spark configurations: Arrow-based pandas conversion is
        # switched off and Parquet datetime rebasing is set to LEGACY for both
        # reads and writes. NOTE: these are session-wide settings, so they also
        # affect other code sharing this Spark session.
        self.spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")
        self.spark.conf.set("spark.sql.parquet.datetimeRebaseModeInRead", "LEGACY")
        self.spark.conf.set("spark.sql.parquet.datetimeRebaseModeInWrite", "LEGACY")

        # Suppress GE logging
        logging.getLogger("great_expectations").setLevel(logging.ERROR)

    @staticmethod
    def _resolve_spark_session():
        """Locate a Spark session when the caller did not supply one."""
        # Notebook runtimes commonly inject a module-level ``spark`` variable.
        try:
            session = spark  # noqa: F821 - optional global provided by the runtime
        except NameError:
            session = None
        if session is not None:
            return session

        # Fall back to whatever session is currently active in this process.
        try:
            from pyspark.sql import SparkSession
            session = SparkSession.getActiveSession()
        except Exception as exc:
            raise RuntimeError(f"Could not initialize Spark session: {str(exc)}") from exc

        if session is None:
            raise RuntimeError(
                "Could not initialize Spark session: No active Spark session found"
            )
        return session
        
    def _setup_logging(self) -> logging.Logger:
        """Setup logging configuration"""
        logger = logging.getLogger(self.__class__.__name__)
        logger.setLevel(self.config.get('log_level', logging.INFO))
        
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)
            
        return logger
    
    def load_validation_expectations(self) -> pd.DataFrame:
        """
        Read the expectations Delta table and return its selected rows.

        The table is loaded from ``config['expectations_table_path']``. A row
        is kept when its ``Selected`` column, cast to string and upper-cased,
        equals ``"TRUE"``. The filtered result is collected to pandas.

        Returns:
            pandas DataFrame holding the selected validation expectations
        """
        rules_sdf = self.spark.read.format("delta").load(
            self.config['expectations_table_path']
        )
        is_selected = upper(col("Selected").cast("string")) == "TRUE"
        expectations_df = rules_sdf.filter(is_selected).toPandas()

        self.logger.info(f"Loaded {len(expectations_df)} validation expectations")
        return expectations_df
    
    def get_selected_tables(self, expectations_df: pd.DataFrame) -> List[str]:
        """
        Get list of selected tables for processing
        
        Args:
            expectations_df: DataFrame containing validation expectations
            
        Returns:
            List of table names to process
        """
        selected_tables = expectations_df['Table_Name'].unique().tolist()
        self.logger.info(f"Found {len(selected_tables)} tables to process")
        return selected_tables
    
    def create_data_source(self, table_name: str) -> str:
        """
        Create a Great Expectations data source for the table
        
        Args:
            table_name: Name of the table
            
        Returns:
            Data source name
        """
        data_source_name = f"spark_src_{table_name}_{uuid4().hex[:8]}"
        
        self.context.add_datasource(
            name=data_source_name,
            class_name="Datasource",
            execution_engine={"class_name": "SparkDFExecutionEngine"},
            data_connectors={
                "default_runtime_data_connector_name": {
                    "class_name": "RuntimeDataConnector",
                    "batch_identifiers": ["default_identifier_name"],
                }
            }
        )
        
        return data_source_name
    
    def load_and_prepare_data(self, table_name: str) -> Any:
        """
        Load and prepare data for validation

        Steps: read the Delta table at ``<source_base_path>/<table_name>``,
        drop columns that contain no non-null value, keep only rows whose
        delete flag equals 0 (when a column with ``delete_flag`` in its name
        exists), and add a zero-based ``__row_idx`` column used to track rows
        through validation.

        Args:
            table_name: Name of the table to load

        Returns:
            Cached Spark DataFrame with prepared data and a ``__row_idx`` column
        """
        def column_ref(name: str):
            # Backtick-quote the name so columns containing spaces, dots or
            # reserved words resolve as one top-level column instead of
            # failing to parse as a SQL expression.
            return col("`" + name.replace("`", "``") + "`")

        # Load data
        table_path = f"{self.config['source_base_path']}/{table_name}"
        clean_df = self.spark.read.format("delta").load(table_path)

        # Remove null columns (one short-circuiting probe per column)
        non_null_cols = [
            c for c in clean_df.columns
            if clean_df.filter(column_ref(c).isNotNull()).limit(1).count() > 0
        ]
        clean_df = clean_df.select(*[column_ref(c) for c in non_null_cols])

        # Filter by delete flag if exists (first matching column wins)
        delete_flag_column = next(
            (c for c in clean_df.columns if 'delete_flag' in c.lower()),
            None
        )
        if delete_flag_column:
            clean_df = clean_df.filter(column_ref(delete_flag_column) == 0)
            self.logger.info(f"Filtered {delete_flag_column} == 0 for table {table_name}")

        # Add row index for tracking. The window orders by a constant, so the
        # numbering is arbitrary and may differ if the plan is recomputed.
        # Caching keeps the indices stable across the later actions that refer
        # to them (best effort: evicted partitions would be recomputed).
        indexed_df = clean_df.withColumn(
            "__row_idx",
            row_number().over(Window.orderBy(lit(1))) - 1
        ).cache()

        # count() also materializes the cache.
        self.logger.info(f"Prepared {indexed_df.count()} rows for table {table_name}")
        return indexed_df
    
    def create_expectation_metadata(self, expectation_for_table: pd.DataFrame) -> Dict[Tuple, Dict]:
        """
        Create metadata mapping for expectations
        
        Args:
            expectation_for_table: DataFrame containing expectations for current table
            
        Returns:
            Dictionary mapping (rule, column) to metadata
        """
        expectation_metadata_map = {}
        
        for _, row in expectation_for_table.iterrows():
            col_nm = row['Column_Name'].strip()
            rule = row['validation_rule'].strip()
            cond = row['Condition'].strip() if pd.notna(row['Condition']) else None
            
            meta = {
                "Source_System": self.config.get('source_system', 'Netforum'),
                "Table_Name": row.get('Table_Name', ''),
                "Column_Name": col_nm,
                "AI_ColumnName": row.get('AI_ColumnName', ""),
                "Index_Key": row.get('Index_Key', ""),
                "Val_Description": row.get('Val_Description', ""),
                "Condition": cond,
                "Minimum_Range": row.get('Minimum_Range'),
                "Maximum_Range": row.get('Maximum_Range'),
                "AI_Reasoning": row.get('AI_Reasoning', ""),
                "SampleData": row.get('SampleData', ""),
            }
            expectation_metadata_map[(rule, col_nm)] = meta
            
        return expectation_metadata_map
    
    def add_expectations_to_validator(self, validator: Any, expectation_for_table: pd.DataFrame) -> None:
        """
        Add expectations to the validator based on validation rules
        
        Args:
            validator: Great Expectations validator
            expectation_for_table: DataFrame containing expectations for current table
        """
        for _, row in expectation_for_table.iterrows():
            col_nm = row['Column_Name'].strip()
            rule = row['validation_rule'].strip()
            cond = row['Condition'].strip() if pd.notna(row['Condition']) else None
            
            if rule == 'expect_column_values_to_not_be_null':
                validator.expect_column_values_to_not_be_null(column=col_nm)
                
            elif rule == 'expect_column_values_to_match_regex' and cond:
                validator.expect_column_values_to_match_regex(column=col_nm, regex=cond)
                
            elif rule == 'expect_column_values_to_be_unique':
                validator.expect_column_values_to_be_unique(column=col_nm)
                
            elif rule == 'expect_column_values_to_be_between':
                validator.expect_column_values_to_be_between(
                    column=col_nm,
                    min_value=row['Minimum_Range'],
                    max_value=row['Maximum_Range']
                )
                
            elif rule == 'expect_column_values_to_be_in_set' and cond:
                raw_vals = cond.strip()
                if raw_vals.startswith("[") and raw_vals.endswith("]"):
                    raw_vals = raw_vals[1:-1]
                
                value_set = [int(v.strip()) for v in raw_vals.split(',')]
                validator.expect_column_values_to_be_in_set(
                    column=col_nm,
                    value_set=value_set
                )
    
    def process_validation_results(self, validation_result: Any, indexed_df: Any, 
                                 expectation_metadata_map: Dict, table_name: str) -> Tuple[Any, List[int]]:
        """
        Process validation results and handle failed records
        
        Args:
            validation_result: Great Expectations validation result
            indexed_df: Indexed DataFrame
            expectation_metadata_map: Metadata mapping for expectations
            table_name: Name of the current table
            
        Returns:
            Tuple of (clean_df, all_unexpected_indices)
        """
        all_unexpected_indices = []
        failed_columns = set()
        
        # Extract failed indices and columns
        for result in validation_result.to_json_dict()["results"]:
            col_name = result.get("expectation_config", {}).get("kwargs", {}).get("column")
            unexpected_list = result["result"].get("unexpected_index_list", [])
            
            if unexpected_list and col_name:
                failed_columns.add(col_name)
                for entry in unexpected_list:
                    if "__row_idx" in entry:
                        all_unexpected_indices.append(entry["__row_idx"])
        
        if failed_columns:
            self.logger.warning(f"Columns that failed validation in {table_name}: {sorted(failed_columns)}")
        
        # Process failed records if any
        if all_unexpected_indices:
            self._write_audit_records(
                validation_result, indexed_df, expectation_metadata_map, 
                table_name, set(all_unexpected_indices)
            )
            
            # Remove failed records
            unexpected_indices_df = self.spark.createDataFrame(
                [(i,) for i in set(all_unexpected_indices)], ["__row_idx"]
            )
            indexed_df = indexed_df.join(unexpected_indices_df, on="__row_idx", how="left_anti")
        
        clean_df = indexed_df.drop("__row_idx")
        return clean_df, all_unexpected_indices
    
    def _write_audit_records(self, validation_result: Any, indexed_df: Any, 
                           expectation_metadata_map: Dict, table_name: str, 
                           unexpected_index_set: set) -> None:
        """
        Write audit records for failed validations

        For each expectation in the validation result, only the rows that
        failed *that* expectation (and are part of ``unexpected_index_set``)
        are appended to ``config['audit_table_name']``, tagged with the
        expectation's metadata. Expectations without failing rows write nothing.

        Args:
            validation_result: Great Expectations validation result
            indexed_df: Indexed DataFrame
            expectation_metadata_map: Metadata mapping for expectations
            table_name: Name of the current table
            unexpected_index_set: Set of unexpected indices
        """
        def text_or_blank(value: Any) -> Any:
            # Blank cells in the rules table arrive as None/NaN; keep audit
            # columns string-typed by writing "" instead.
            if pd.api.types.is_scalar(value) and pd.isna(value):
                return ""
            return value

        allowed_indices = set(unexpected_index_set)

        source_system_col = (
            col("Source_System") if "Source_System" in indexed_df.columns
            else lit(self.config.get('source_system', 'Netforum'))
        )

        # US Pacific wall-clock time (DST-aware) for the *_PST column.
        # NOTE(review): assumes the Python process and the Spark session use
        # the same time zone when the naive value is converted - confirm.
        load_time_pst = datetime.now(pytz.timezone("America/Los_Angeles")).replace(tzinfo=None)

        for result in validation_result.to_json_dict()["results"]:
            exp_config = result.get("expectation_config") or {}
            col_name = (exp_config.get("kwargs") or {}).get("column", "unknown_column")
            exp_type = exp_config.get("expectation_type", "unknown_expectation")

            # Rows that failed this particular expectation.
            index_entries = (result.get("result") or {}).get("unexpected_index_list") or []
            failed_indices = {
                entry["__row_idx"] for entry in index_entries
                if isinstance(entry, dict) and "__row_idx" in entry
            } & allowed_indices
            if not failed_indices:
                continue

            failed_indices_df = self.spark.createDataFrame(
                [(i,) for i in sorted(failed_indices)], ["__row_idx"]
            )
            failed_rows_df = indexed_df.join(failed_indices_df, on="__row_idx", how="inner")

            meta = expectation_metadata_map.get((exp_type, col_name), {})

            # Handle potential null index key
            index_key = text_or_blank(meta.get("Index_Key", ""))
            index_value_col = (
                col(index_key).cast("string")
                if index_key and index_key in failed_rows_df.columns
                else lit("")
            )

            audit_spark_df = failed_rows_df.select(
                source_system_col.alias("Source_System"),
                lit(table_name).alias("Table_Name"),
                lit(col_name).alias("Column_Name"),
                lit(text_or_blank(meta.get("AI_ColumnName", ""))).alias("AI_ColumnName"),
                lit(index_key).alias("Index_Key"),
                index_value_col.alias("Index_Value"),
                lit(exp_type).alias("validation_rules"),
                lit(text_or_blank(meta.get("Val_Description", ""))).alias("Val_Description"),
                lit(text_or_blank(meta.get("AI_Reasoning", ""))).alias("ai_reasoning"),
                lit(text_or_blank(meta.get("SampleData", ""))).alias("SampleData"),
                current_timestamp().alias("LoadTime_UTC"),
                lit(load_time_pst).alias("LoadTime_PST")
            )

            # Write with schema evolution enabled. "mergeSchema" is the option
            # that applies to appends; "overwriteSchema" only takes effect in
            # overwrite mode.
            audit_spark_df.write \
                .mode("append") \
                .option("mergeSchema", "true") \
                .saveAsTable(self.config['audit_table_name'])

            # One audit row per failing index; avoids a second Spark job just to log.
            self.logger.info(
                f"Written {len(failed_indices)} audit records for {table_name} "
                f"({exp_type} on {col_name})"
            )
    
    def write_to_silver_layer(self, clean_df: Any, table_name: str) -> None:
        """
        Write clean data to Silver layer
        
        Args:
            clean_df: Clean DataFrame
            table_name: Name of the table
        """
        if clean_df.rdd.isEmpty():
            self.logger.warning(f"Clean DataFrame for {table_name} is empty after validation")
            return
        
        # Generate silver table name and path
        base_table_name = table_name
        if base_table_name.lower().startswith("bronze_"):
            base_table_name = base_table_name[7:]
        
        silver_table_name = f"SilverStage_{base_table_name}"
        silver_table_path = f"{self.config['silver_base_path']}/{silver_table_name}"
        
        # Write to silver layer
        clean_df.write.format("delta") \
            .mode("overwrite") \
            .option("overwriteSchema", "true") \
            .save(silver_table_path)
        
        self.logger.info(f"Written clean data to Silver Layer: {silver_table_path}")
    
    def record_passed_table_info(self, table_name: str, row_count: int) -> None:
        """
        Record information about tables that passed all validations

        Appends one row (``Table_Name``, ``LoadDate_PST``, ``Count``) to the
        location given by ``config['passed_tables_path']``.

        Args:
            table_name: Name of the table
            row_count: Number of rows in the clean table
        """
        # US Pacific wall-clock time (DST-aware), stored as a naive timestamp
        # so the column really holds Pacific time rather than whatever time
        # zone the driver happens to run in.
        load_time_pst = datetime.now(pytz.timezone("America/Los_Angeles")).replace(tzinfo=None)

        data = {
            "Table_Name": [table_name],
            "LoadDate_PST": [load_time_pst],
            "Count": [row_count]
        }

        pdf = pd.DataFrame(data)
        sdf = self.spark.createDataFrame(pdf)

        # Pin the column types so every append has the same schema.
        sdf = sdf.select(
            col("Table_Name").cast("string"),
            col("LoadDate_PST").cast("timestamp"),
            col("Count").cast("long")
        )

        # NOTE(review): no explicit format is set, so the session's default
        # data source is used; "overwriteSchema" is documented for overwrite
        # mode and presumably has no effect on this append - confirm.
        sdf.write.mode("append").option("overwriteSchema", "true").save(
            self.config['passed_tables_path']
        )

        self.logger.info(f"Table {table_name} passed all DQ checks and logged to passed table")
    
    def validate_single_table(self, table_name: str, expectations_df: pd.DataFrame) -> None:
        """
        Validate a single table

        Builds a fresh expectation suite and runtime datasource for the table,
        validates the prepared data, writes the surviving rows to the Silver
        layer and, when no row failed, records the table as passed.

        Args:
            table_name: Name of the table to validate
            expectations_df: DataFrame containing all validation expectations

        Raises:
            Exception: Any error raised while processing is logged with its
                traceback and re-raised unchanged.
        """
        indexed_df = None
        try:
            self.logger.info(f"Processing table: {table_name}")

            # Get expectations for this table
            expectation_for_table = expectations_df[expectations_df['Table_Name'] == table_name]
            if expectation_for_table.empty:
                self.logger.warning(f"No expectations found for table {table_name}")
                return

            # Create suite and data source (unique names per run)
            suite_name = f"suite_{table_name}_{uuid4().hex[:8]}"
            self.context.create_expectation_suite(suite_name, overwrite_existing=True)
            data_source_name = self.create_data_source(table_name)

            # Load and prepare data
            indexed_df = self.load_and_prepare_data(table_name)

            # Create expectation metadata
            expectation_metadata_map = self.create_expectation_metadata(expectation_for_table)

            # Create validator
            runtime_batch_request = RuntimeBatchRequest(
                datasource_name=data_source_name,
                data_connector_name="default_runtime_data_connector_name",
                data_asset_name=f"{table_name}_asset",
                runtime_parameters={"batch_data": indexed_df},
                batch_identifiers={"default_identifier_name": f"id_{uuid4().hex[:8]}"}
            )

            validator = self.context.get_validator(
                batch_request=runtime_batch_request,
                expectation_suite_name=suite_name
            )

            # Add expectations
            self.add_expectations_to_validator(validator, expectation_for_table)

            # Check if any expectations were added
            if not validator.get_expectation_suite().expectations:
                self.logger.warning(f"No expectations added for {table_name}, skipping")
                return

            # Run validation; failing rows are reported by their __row_idx value
            validation_result = validator.validate(
                result_format={
                    "result_format": "COMPLETE",
                    "unexpected_index_column_names": ["__row_idx"],
                    "unexpected_index_list": True
                }
            )

            # Process results
            clean_df, all_unexpected_indices = self.process_validation_results(
                validation_result, indexed_df, expectation_metadata_map, table_name
            )

            # Write to silver layer
            self.write_to_silver_layer(clean_df, table_name)

            # Count once: each count() is a separate Spark job.
            clean_row_count = clean_df.count()

            # Record passed table if no failures
            if not all_unexpected_indices:
                self.record_passed_table_info(table_name, clean_row_count)
            else:
                self.logger.info(f"Table {table_name} had failed rows, not logging to passed table")

            self.logger.info(f"Completed processing {table_name} with {clean_row_count} clean records")

        except Exception as ex:
            # logger.exception keeps the traceback alongside the message.
            self.logger.exception(f"Error processing table {table_name}: {str(ex)}")
            raise
        finally:
            # Release any cached/persisted data held for the prepared
            # DataFrame; this is a no-op when nothing was cached.
            if indexed_df is not None:
                indexed_df.unpersist()
    
    def run_validation_pipeline(self) -> None:
        """
        Run the complete validation pipeline for all selected tables
        """
        try:
            self.logger.info("Starting data quality validation pipeline")
            
            # Load validation expectations
            expectations_df = self.load_validation_expectations()
            
            # Get selected tables
            selected_tables = self.get_selected_tables(expectations_df)
            
            # Process each table
            for table_name in selected_tables:
                self.validate_single_table(table_name, expectations_df)
            
            self.logger.info("Data quality validation pipeline completed successfully")
            
        except Exception as ex:
            self.logger.error(f"Pipeline failed: {str(ex)}")
            raise