In [0]:
# Logger Utility
# Provides logging for the data pipeline with color-coded output.

import logging
import os
import functools
import time
from pyspark.sql import SparkSession

# Get the Spark session. `getOrCreate()` presumably reuses an already-active
# session (e.g. the notebook's own) rather than starting a second one — confirm.
# NOTE(review): `spark` is not referenced by the helpers defined in this cell;
# presumably it is kept as a global for other cells — verify before removing.
spark = SparkSession.builder.getOrCreate()

# ANSI SGR color escape codes used when formatting log lines.
# Keys are either logging level names (used to color the level name) or
# component names, i.e. the last dotted segment of a logger name such as
# "landing_to_bronze.data_loading" (used to color the logger name).
# 'RESET' ends a colored span. Every value is distinct so that each level and
# component stays visually distinguishable.
COLORS = {
    'DEBUG': '\033[37m',    # White
    'INFO': '\033[32m',     # Green
    'WARNING': '\033[33m',  # Yellow
    'ERROR': '\033[31m',    # Red
    'CRITICAL': '\033[1;31m',  # Red bold
    'data_loading': '\033[36m',  # Cyan
    'transformation': '\033[35m',  # Magenta
    # SGR 34 is blue; the previous 33 was yellow, identical to WARNING.
    'data_writing': '\033[34m',  # Blue
    'RESET': '\033[0m'      # Reset
}

class NotebookFormatter(logging.Formatter):
    """Logging formatter that colors the level name and the logger name.

    Lines are laid out as ``[LEVEL] logger.name: message [asctime]``.

    * The level name is colored with ``COLORS[<LEVEL>]``. Unknown levels fall
      back to the reset code, i.e. no color.
    * The logger name is colored according to its *component*, meaning the
      text after the last ``.`` (e.g. ``data_loading`` in
      ``landing_to_bronze.data_loading``). This applies only when that
      component has an entry in ``COLORS``.

    The incoming record is left untouched. A copy is colored and formatted,
    so other handlers still see plain values.
    """

    def __init__(self, notebook_name):
        # Kept on the instance; format() itself does not read it.
        self.notebook_name = notebook_name
        super().__init__("[%(levelname)s] %(name)s: %(message)s [%(asctime)s]")

    def format(self, record):
        reset = COLORS['RESET']

        # Component = last dotted segment, only if the name contains a dot.
        _, dot, last_segment = record.name.rpartition('.')
        component = last_segment if dot else None

        level_color = COLORS.get(record.levelname.upper(), reset)
        if component:
            name_color = COLORS.get(component, '')
        else:
            name_color = ''

        # Color a copy so the shared original record stays uncolored.
        colored = logging.makeLogRecord(vars(record))
        colored.levelname = f"{level_color}{colored.levelname}{reset}"
        colored.name = f"{name_color}{colored.name}{reset}"
        return super().format(colored)

# Component levels applied when `component_log_levels` is omitted. A fresh copy
# is made on every call, so this shared dict is never handed to callers.
_DEFAULT_COMPONENT_LOG_LEVELS = {
    "data_loading": "DEBUG",
    "transformation": "INFO",
    "data_writing": "INFO",
}

# Sentinel meaning "argument omitted". It lets an explicit None or {} still mean
# "configure no component loggers", as before.
_USE_DEFAULT_COMPONENTS = object()


def _resolve_log_level(level):
    """Return the numeric logging level for ``level``, or None if it is invalid.

    Accepts a level name in any case ("debug", "INFO", ...) or an int such as
    ``logging.DEBUG``. bool is rejected even though it subclasses int.
    """
    if isinstance(level, bool):
        return None
    if isinstance(level, int):
        return level
    if isinstance(level, str):
        numeric = getattr(logging, level.upper(), None)
        return numeric if isinstance(numeric, int) else None
    return None


def create_logger(notebook_name="default", log_level="INFO",
                  component_log_levels=_USE_DEFAULT_COMPONENTS):
    """
    Create a logger with color-coded output.

    Args:
        notebook_name: Name of notebook/script (optional). It is used as the
            logger name and as the prefix of component logger names
            ("<notebook_name>.<component>").
        log_level: Level of the notebook logger. Either a name (DEBUG, INFO,
            WARNING, ERROR, CRITICAL; case-insensitive) or a numeric
            logging level.
        component_log_levels: Dict of component names to log levels
            (e.g., {"data_loading": "DEBUG"}). A component level may be
            lower or higher than ``log_level`` and takes effect either way.
            When omitted, the defaults are data_loading=DEBUG,
            transformation=INFO and data_writing=INFO. Pass None or {} to
            configure no component loggers. Entries with an invalid level
            are skipped with a warning.

    Returns:
        Logger instance. It has ``log_start(operation, component=None)`` and
        ``log_end(operation, duration=None, component=None)`` attached.

    Raises:
        ValueError: if ``log_level`` is not a valid level.
    """
    if component_log_levels is _USE_DEFAULT_COMPONENTS:
        component_log_levels = dict(_DEFAULT_COMPONENT_LOG_LEVELS)

    numeric_level = _resolve_log_level(log_level)
    if numeric_level is None:
        raise ValueError(f"Invalid log level: {log_level}")

    # Logger named after the notebook (this is not the logging root logger).
    logger = logging.getLogger(notebook_name)
    logger.setLevel(numeric_level)
    logger.handlers = []  # Drop handlers left over from a previous call
    logger.propagate = False  # Prevent duplicates via ancestor handlers

    # One console handler is shared by the notebook logger and every component
    # logger. It is deliberately left at NOTSET: each logger's own level does
    # the filtering. If the handler were set to `numeric_level`, records from
    # components configured *below* the notebook level (e.g. the default
    # data_loading=DEBUG under log_level=INFO) would be silently dropped, even
    # though isEnabledFor() reports them as enabled.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.NOTSET)
    console_handler.setFormatter(NotebookFormatter(notebook_name))
    logger.addHandler(console_handler)

    logger.debug(f"Root logger {notebook_name} setup with level {log_level}, {len(logger.handlers)} handlers")

    if component_log_levels:
        for component, level in component_log_levels.items():
            component_numeric_level = _resolve_log_level(level)
            if component_numeric_level is None:
                # Previously skipped silently; surface the misconfiguration.
                logger.warning(f"Ignoring invalid log level {level!r} for component {component!r}")
                continue
            component_logger = logging.getLogger(f"{notebook_name}.{component}")
            component_logger.setLevel(component_numeric_level)
            component_logger.handlers = []  # Clear handlers from earlier calls
            component_logger.addHandler(console_handler)  # Share the single handler
            component_logger.propagate = False  # Avoid a second copy via the parent
            component_logger.debug(f"Component logger {notebook_name}.{component} setup with level {level}")

    def log_start(operation, component=None):
        """Log "▶️ Starting <operation>" at INFO (component logger if given)."""
        msg = f"▶️ Starting {operation}"
        if component:
            logging.getLogger(f"{notebook_name}.{component}").info(msg)
        else:
            logger.info(msg)

    def log_end(operation, duration=None, component=None):
        """Log "✅ Completed <operation>" at INFO, with duration (seconds) if given."""
        duration_str = f" (Duration: {duration:.2f}s)" if duration is not None else ""
        msg = f"✅ Completed {operation}{duration_str}"
        if component:
            logging.getLogger(f"{notebook_name}.{component}").info(msg)
        else:
            logger.info(msg)

    logger.log_start = log_start
    logger.log_end = log_end

    return logger

def log_execution(logger):
    """Return a decorator that logs each call's start, duration and failure.

    The wrapped function's name selects which logger receives the messages.
    Rules are checked in order and the first match wins; matching is a
    case-sensitive substring test:

    * "load" or "read"            -> ``<logger.name>.data_loading``
    * "transform"                 -> ``<logger.name>.transformation``
    * "write", "save" or "upsert" -> ``<logger.name>.data_writing``
    * otherwise                   -> ``logger`` itself

    On an exception, the message is logged at ERROR, cut at the first "JVM"
    (presumably to drop the Java part of Py4J errors), and the original
    exception is re-raised unchanged.
    """
    component_rules = (
        (("load", "read"), "data_loading"),
        (("transform",), "transformation"),
        (("write", "save", "upsert"), "data_writing"),
    )

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            func_name = func.__name__
            component = next(
                (comp for keywords, comp in component_rules
                 if any(word in func_name for word in keywords)),
                None,
            )
            if component:
                target = logging.getLogger(f"{logger.name}.{component}")
            else:
                target = logger

            target.info(f"🔄 Running {func_name}")
            started_at = time.perf_counter()
            try:
                value = func(*args, **kwargs)
                target.info(f"✅ Finished {func_name} in {time.perf_counter() - started_at:.2f}s")
                return value
            except Exception as exc:
                short_message = str(exc).split("JVM")[0]
                target.error(f"❌ Failed {func_name}: {short_message}")
                raise
        return wrapper
    return decorator

def log_dataframe_info(df, name, logger, component=None, sample_rows=5):
    """Log a DataFrame's row and column counts, plus schema and sample at DEBUG.

    Args:
        df: DataFrame to describe. It must provide ``count()``, ``columns``,
            ``schema.fields`` (items with ``name``/``dataType``) and
            ``limit(n).toPandas()``, as a Spark DataFrame does.
        name: Label used in the log messages.
        logger: Base logger. When ``component`` is given, the child logger
            ``"<logger.name>.<component>"`` is used instead.
        component: Optional component suffix (e.g. "data_loading").
        sample_rows: Number of rows shown in the DEBUG sample. The default of 5
            matches the previously hard-coded value; 0 or less disables the
            sample.

    Returns:
        ``df`` unchanged, so the call can be chained.
    """
    log = logging.getLogger(f"{logger.name}.{component}") if component else logger
    # count() is evaluated once and reused for the sample check below.
    count = df.count()
    columns = len(df.columns)
    log.info(f"📊 DataFrame '{name}' has {count} rows, {columns} columns")

    # Everything below only produces DEBUG output. Skip building the schema
    # string and collecting sample rows when DEBUG is disabled for this logger.
    if not log.isEnabledFor(logging.DEBUG):
        return df

    schema_str = "\n  " + "\n  ".join(f"{field.name}: {field.dataType}" for field in df.schema.fields)
    log.debug(f"🔍 DataFrame '{name}' schema: {schema_str}")

    if count > 0 and sample_rows > 0:
        try:
            sample = df.limit(sample_rows).toPandas()
        except Exception as exc:
            # Best-effort: the sample is optional. Keep only the text before
            # "JVM" (presumably to drop a Py4J Java stack trace).
            reason = str(exc).split("JVM")[0]
            log.debug(f"Could not convert sample to pandas: {reason}")
        else:
            log.debug(f"🔍 DataFrame '{name}' sample data:\n{sample}")
    return df