# 🟫 Bronze Layer Data Pipeline with PySpark

This repository contains a PySpark script that implements the **Bronze Layer** of a data lakehouse architecture. It ingests raw `.csv` files, adds metadata, handles corrupt records, and stores clean data in Parquet format.

## 📁 Directory Structure

```
data/
├── Raw/        # Source CSV files (tab-delimited)
├── Bronze/     # Cleaned and structured data (Parquet)
├── Silver/     # [Reserved for further processing]
├── Gold/       # [Reserved for final curated datasets]
logs/
└── bronze_layer_errors.log  # Log file for errors and warnings
```

## ⚙️ What the Script Does

### 1. **Initialization**
- Ensures all required directories exist (Raw, Bronze, Logs).
- Sets up structured logging to both console and file.

### 2. **Spark Session Setup**
- Creates a Spark session with the legacy time-parser policy (`spark.sql.legacy.timeParserPolicy=LEGACY`); the error-handling options are applied later, when each CSV file is read.
- If the session fails, the script exits gracefully.

### 3. **CSV Processing (Raw ➡️ Bronze)**
For each `.csv` file in the `Raw/` folder:
- Reads the file using tab (`\t`) as the delimiter.
- Infers schema and detects corrupt rows.
- Adds an `ingestion_timestamp` column with the current timestamp.
- Logs any corrupt records and saves them separately.
- Writes the cleaned data as a Parquet file into `Bronze/` (one subfolder per file).
- Stores each DataFrame in a dictionary for potential downstream use.

### 4. **Schema Logging**
- Logs a header line and prints the schema of the last processed DataFrame to the console (via `printSchema()`) for verification.

### 5. **Shutdown**
- Stops the Spark session cleanly and logs the end of processing.

## 🪵 Logging & Error Handling

- All warnings and errors are saved to:
  ```
  logs/bronze_layer_errors.log
  ```
- Corrupt rows are automatically redirected to:
  ```
  logs/bad_records_from_bronze/
  ```

## 🧪 Example Usage

Simply drop one or more `.csv` files (tab-delimited) into the `data/Raw/` directory and run the script:

```bash
python bronze_layer.py
```

## 🛠 Requirements

- Python 3.7+
- PySpark
- Local or distributed Spark environment

## 📌 Notes

- Silver and Gold layers are defined for future use but not yet implemented.
- The script is modular and can be extended for scheduling, validation, or transformation logic.


In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import current_timestamp, lit
from pyspark.sql.types import StructType, StructField, StringType
import os
import logging
from datetime import datetime

# --- Configuration ---
# Paths for the data layers. raw_dir is scanned for input files and
# bronze_dir receives one Parquet folder per input file.
raw_dir = 'data/Raw/'
bronze_dir = 'data/Bronze'
silver_dir = 'data/Silver' # Defined for completeness, not used in Bronze layer logic yet
gold_dir = 'data/Gold'     # Defined for completeness, not used in Bronze layer logic yet

# Path of the log file that receives everything the logger emits
error_log_dir = 'logs'
error_log_file_path = os.path.join(error_log_dir, 'bronze_layer_errors.log')

# Ensure directories exist
os.makedirs(raw_dir, exist_ok=True)
os.makedirs(bronze_dir, exist_ok=True)
os.makedirs(error_log_dir, exist_ok=True) # Ensure log directory exists

# --- Logging Setup ---
# Clear existing handlers to prevent duplicate logs if run multiple times in a notebook
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(error_log_file_path), # Log to the specified file
        logging.StreamHandler()                    # Also print to console
    ]
)
logger = logging.getLogger(__name__)

logger.info("Starting PySpark Data Layer Processing (Bronze Layer)...")

# --- Spark Session Initialization ---
try:
    spark = SparkSession.builder.appName("TakeHomeExamBronzeLayer") \
        .config("spark.sql.legacy.timeParserPolicy", "LEGACY") \
        .getOrCreate()
    logger.info("SparkSession created successfully.")
except Exception as e:
    # logger.exception keeps the same message but also records the traceback.
    logger.exception(f"Error initializing SparkSession: {e}")
    # `exit()` is a helper installed by the `site` module for interactive use;
    # raising SystemExit works regardless of how the interpreter was started.
    raise SystemExit(1)

# --- Data Processing: Raw to Bronze Layer ---

# Directory handed to the CSV reader as `badRecordsPath`; it lives under 'logs'.
bad_records_path = os.path.join(error_log_dir, 'bad_records_from_bronze')
os.makedirs(bad_records_path, exist_ok=True)

file_names = [f for f in os.listdir(raw_dir) if f.endswith('.csv')]

if not file_names:
    logger.warning(f"No CSV files found in the raw directory: {raw_dir}. Please ensure files are present.")

bronze_dfs = {} # DataFrames keyed by output folder name, for easier access later

for file_name in file_names:
    raw_file_path = os.path.join(raw_dir, file_name)
    # Strip only the trailing extension; str.replace would also remove a
    # '.csv' that appears in the middle of the name.
    table_name = os.path.splitext(file_name)[0]
    bronze_output_path = os.path.join(bronze_dir, table_name)

    logger.info(f"Processing file: {file_name}")

    try:
        # Read raw tab-delimited CSV with a header row, schema inference and
        # permissive parsing.
        bronze_df = spark.read \
            .option("header", "true") \
            .option("inferSchema", "true") \
            .option("mode", "PERMISSIVE") \
            .option("columnNameOfCorruptRecord", "_corrupt_record") \
            .option("badRecordsPath", bad_records_path) \
            .option("sep", "\t") \
            .csv(raw_file_path)

        # Add ingestion timestamp
        bronze_df = bronze_df.withColumn("ingestion_timestamp", current_timestamp())

        # Check for corrupt records and log if any are found.
        # NOTE(review): with inferSchema and no explicit schema, Spark may never
        # add the corrupt-record column for CSV input, and `badRecordsPath` is
        # presumably only honoured on some runtimes -- confirm against the
        # target Spark version before relying on this branch.
        if "_corrupt_record" in bronze_df.columns:
            corrupt_count = bronze_df.filter(bronze_df["_corrupt_record"].isNotNull()).count()
            if corrupt_count > 0:
                logger.warning(f"Found {corrupt_count} corrupt records in {file_name}. Details in {bad_records_path}.")
            bronze_df = bronze_df.drop("_corrupt_record") # Drop the corrupt record column after logging

        # Write to Bronze layer in Parquet format
        bronze_df.write.mode("overwrite").parquet(bronze_output_path)
        logger.info(f"Successfully processed and wrote {file_name} to Bronze layer: {bronze_output_path}")

        bronze_dfs[table_name] = bronze_df # Store DF by its new folder name

    except Exception as e:
        # One bad file must not stop the others; keep the traceback in the log.
        logger.exception(f"Error processing {file_name}: {e}")

# --- Schema Display (example for one DataFrame) ---
if bronze_dfs:
    # Get the key (folder name) of the last processed DataFrame
    last_df_key = list(bronze_dfs.keys())[-1]
    logger.info(f"Schema of the last processed Bronze DataFrame ({last_df_key}):")
    # printSchema writes to stdout, not to the logger.
    bronze_dfs[last_df_key].printSchema()
else:
    logger.warning("No Bronze DataFrames were created.")

logger.info("Bronze Layer processing finished.")

# --- Stop Spark Session ---
spark.stop()
logger.info("SparkSession stopped.")

# Silver layer

# 🥈 Silver Layer Data Pipeline with PySpark

This repository contains a PySpark script that processes the **Silver Layer** in a data lakehouse architecture. It refines raw data ingested into the Bronze Layer by performing validation, cleaning, standardization, and enhancement of data quality.

---

## 📂 Directory Structure

```
data/
├── Bronze/         # Cleaned data from raw (input to Silver layer)
├── Silver/         # Refined, validated data (output from Silver layer)
logs/
└── silver_layer_errors.log  # Log file for errors, warnings, and stats
```

---

## ⚙️ What the Script Does

### 🔁 Main Flow

1. **Initial Setup**
   - Creates required directories (`Silver/`, `logs/`)
   - Sets up logging to file and console
   - Starts a SparkSession

2. **Table-Specific Processing**
   - Reads Parquet files from `Bronze/`
   - Cleans string fields (e.g. trimming, null handling)
   - Validates fields: dates, numerics, booleans, business rules
   - Adds metadata like `processing_timestamp`, `data_quality_score`
   - Applies transformations and writes back to `Silver/` as Parquet

3. **Tables Covered**
   - `address`
   - `customer`
   - `person`
   - `product`
   - `salesorderheader`
   - `salesorderdetail`
   - `creditcard`
   - Any other table → handled by a generic processor

4. **Data Quality Checks**
   - Null percentage calculation per column
   - Logical validations like:
     - Prices (e.g. `ListPrice >= StandardCost`)
     - Monetary totals (`TotalDue = SubTotal + Tax + Freight`)
     - Date sequences (e.g. `OrderDate <= ShipDate`)
     - Card masking (e.g. `****-****-****-1234`)

5. **Logs Summary**
   - Full DQ stats (record count changes, null percentages)
   - Errors are caught and logged per table

---

## 🧠 Key Validation Examples

| Validation Type      | Example Columns               | Logic |
|----------------------|-------------------------------|-------|
| **Boolean cleanup**  | `MakeFlag`, `OnlineOrderFlag` | Normalize 1/0, true/false, yes/no |
| **Date parsing**     | `ModifiedDate`, `OrderDate`   | Convert to `DateType`, check sequence |
| **Numeric check**    | `ListPrice`, `StandardCost`   | Ensure non-negative |
| **Business logic**   | `LineTotal`, `TotalDue`       | Match expected formulas |
| **Text cleaning**    | `Name`, `Address`             | Trim, set empty to null |
| **Card masking**     | `CardNumber`                  | Format to `****-****-****-1234` |

---

## 📝 How to Run

```bash
python silver_layer.py
```

Make sure Bronze layer data is already prepared in `data/Bronze/` as Parquet files.

---

## 🧪 Output Example

For each table, a folder named after the table is created in `data/Silver/`, containing the validated data as Parquet files.

Logging will summarize:
- Number of records in vs. out
- % of nulls per column
- Any high-null columns (>20%) flagged
- Any failed validations

---

## 📌 Notes

- Designed for modular extension to Gold Layer and analytics
- Validation logic is centralized in helper functions
- Each table is processed independently to isolate issues


In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    current_timestamp, lit, col, when, isnan, isnull, trim, upper, lower,
    regexp_replace, to_date, to_timestamp, coalesce, length, substring,
    round as spark_round, abs as spark_abs, split, size, concat, initcap
)
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType, TimestampType, BooleanType
import os
import logging
from datetime import datetime

# --- Configuration ---
# Paths for the data layers. bronze_dir is the input, silver_dir the output.
bronze_dir = 'data/Bronze'
silver_dir = 'data/Silver'
gold_dir = 'data/Gold'     # Defined for completeness, not used in Silver layer logic yet

# Path of the log file that receives everything the logger emits
error_log_dir = 'logs'
error_log_file_path = os.path.join(error_log_dir, 'silver_layer_errors.log')

# Ensure directories exist
os.makedirs(silver_dir, exist_ok=True)
os.makedirs(error_log_dir, exist_ok=True)

# --- Logging Setup ---
# Clear existing handlers to prevent duplicate logs
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(error_log_file_path),  # Log to file
        logging.StreamHandler()                    # Also print to console
    ]
)
logger = logging.getLogger(__name__)

logger.info("Starting PySpark Data Layer Processing (Silver Layer)...")

# --- Spark Session Initialization ---
try:
    spark = SparkSession.builder.appName("TakeHomeExamSilverLayer") \
        .config("spark.sql.legacy.timeParserPolicy", "LEGACY") \
        .getOrCreate()
    logger.info("SparkSession created successfully.")
except Exception as e:
    # logger.exception keeps the same message but also records the traceback.
    logger.exception(f"Error initializing SparkSession: {e}")
    # `exit()` is a helper installed by the `site` module for interactive use;
    # raising SystemExit works regardless of how the interpreter was started.
    raise SystemExit(1)

# --- Data Quality Functions ---
def clean_string_column(df, column_name):
    """Trim a column and turn blank values into nulls.

    Returns ``df`` untouched when ``column_name`` is not one of its columns.
    Otherwise the column is replaced: values that are null, or empty once
    trimmed, become null; every other value is kept in trimmed form.
    """
    if column_name not in df.columns:
        return df

    trimmed = trim(col(column_name))
    is_blank = col(column_name).isNull() | (trimmed == "")
    return df.withColumn(column_name, when(is_blank, None).otherwise(trimmed))

def standardize_boolean_flags(df, column_name):
    """Normalise a flag column to a BooleanType column.

    Returns ``df`` untouched when ``column_name`` is not one of its columns.
    The value is cast to string and lower-cased before matching, so mixed
    input types are compared uniformly: "1", "true", "t", "y", "yes" map to
    True; "0", "false", "f", "n", "no" map to False; anything else is null.
    """
    if column_name not in df.columns:
        return df

    truthy = ["1", "true", "t", "y", "yes"]
    falsy = ["0", "false", "f", "n", "no"]
    as_text = lower(col(column_name).cast(StringType()))

    flag = (
        when(as_text.isin(truthy), True)
        .when(as_text.isin(falsy), False)
        .otherwise(None)
    )
    normalized = df.withColumn(column_name, flag)
    # Final explicit cast so the resulting column type is Boolean.
    return normalized.withColumn(column_name, col(column_name).cast(BooleanType()))

def validate_dates(df, column_name):
    """Convert a column to dates with ``to_date``, leaving nulls as null.

    Returns ``df`` untouched when ``column_name`` is not one of its columns.
    """
    if column_name not in df.columns:
        return df

    source = col(column_name)
    as_date = when(source.isNull(), None).otherwise(to_date(source))
    return df.withColumn(column_name, as_date)

def validate_numeric_positive(df, column_name):
    """Null out negative values in a numeric column.

    Values below zero become null; zero and positive values are kept as-is.
    Returns ``df`` untouched when ``column_name`` is not one of its columns.
    """
    if column_name not in df.columns:
        return df

    value = col(column_name)
    return df.withColumn(column_name, when(value < 0, None).otherwise(value))

def add_data_quality_flags(df, table_name):
    """Append the monitoring columns shared by every Silver table.

    Adds, in this order: ``data_quality_score`` (constant 1.0),
    ``source_table`` (the given ``table_name``) and ``processing_timestamp``
    (current timestamp).
    """
    return (
        df.withColumn("data_quality_score", lit(1.0))
        .withColumn("source_table", lit(table_name))
        .withColumn("processing_timestamp", current_timestamp())
    )

# --- Silver Layer Processing Functions ---

def process_address(bronze_df):
    """Apply the address-specific cleaning steps to a Bronze DataFrame.

    Steps, in order: blank-to-null trimming of the address text columns,
    nulling of postal codes shorter than three characters, date conversion
    of ``ModifiedDate``, and the shared quality columns tagged 'address'.
    """
    logger.info("Processing address table...")

    cleaned = bronze_df
    for text_col in ('AddressLine1', 'AddressLine2', 'City', 'PostalCode'):
        cleaned = clean_string_column(cleaned, text_col)

    # Basic format check: a postal code under three characters is discarded.
    if 'PostalCode' in cleaned.columns:
        postal = col('PostalCode')
        cleaned = cleaned.withColumn(
            'PostalCode',
            when(length(postal) < 3, None).otherwise(postal),
        )

    cleaned = validate_dates(cleaned, 'ModifiedDate')
    return add_data_quality_flags(cleaned, 'address')

def process_customer(bronze_df):
    """Apply the customer-specific cleaning steps to a Bronze DataFrame.

    Trims ``AccountNumber``, nulls negative values in the ID columns,
    converts ``ModifiedDate`` to a date and appends the shared quality
    columns tagged 'customer'.
    """
    logger.info("Processing customer table...")

    cleaned = clean_string_column(bronze_df, 'AccountNumber')

    # Negative identifiers are treated as invalid and nulled.
    for id_col in ('CustomerID', 'PersonID', 'StoreID', 'TerritoryID'):
        cleaned = validate_numeric_positive(cleaned, id_col)

    cleaned = validate_dates(cleaned, 'ModifiedDate')
    return add_data_quality_flags(cleaned, 'customer')

def process_person(bronze_df):
    """Apply the person-specific cleaning steps to a Bronze DataFrame.

    Name-like columns are trimmed and capitalised with ``initcap``,
    ``PersonType`` is trimmed, ``EmailPromotion`` values outside {0, 1, 2}
    are replaced by 0, ``ModifiedDate`` is converted to a date and the
    shared quality columns tagged 'person' are appended.
    """
    logger.info("Processing person table...")

    cleaned = bronze_df

    for name_col in ('FirstName', 'MiddleName', 'LastName', 'Title', 'Suffix'):
        cleaned = clean_string_column(cleaned, name_col)
        if name_col not in cleaned.columns:
            continue
        # Capitalise the first letter of each word; nulls stay null.
        current = col(name_col)
        cleaned = cleaned.withColumn(
            name_col,
            when(current.isNotNull(), initcap(current)).otherwise(None),
        )

    cleaned = clean_string_column(cleaned, 'PersonType')

    # Anything that is not one of the accepted codes falls back to 0.
    if 'EmailPromotion' in cleaned.columns:
        promo = col('EmailPromotion')
        cleaned = cleaned.withColumn(
            'EmailPromotion',
            when(promo.isin([0, 1, 2]), promo).otherwise(0),
        )

    cleaned = validate_dates(cleaned, 'ModifiedDate')
    return add_data_quality_flags(cleaned, 'person')


def process_product(bronze_df):
    """Apply the product-specific cleaning steps to a Bronze DataFrame.

    Trims the descriptive text columns, normalises the two flag columns to
    booleans, nulls negative numeric values, adds ``price_validation_flag``
    (``ListPrice >= StandardCost``, True when either side is null), converts
    the date columns and appends the shared quality columns tagged 'product'.
    """
    logger.info("Processing product table...")

    cleaned = bronze_df

    for text_col in ('Name', 'ProductNumber', 'Color', 'Size', 'ProductLine', 'Class', 'Style'):
        cleaned = clean_string_column(cleaned, text_col)

    for flag_col in ('MakeFlag', 'FinishedGoodsFlag'):
        cleaned = standardize_boolean_flags(cleaned, flag_col)

    for number_col in ('SafetyStockLevel', 'ReorderPoint', 'StandardCost',
                       'ListPrice', 'Weight', 'DaysToManufacture'):
        cleaned = validate_numeric_positive(cleaned, number_col)

    # Price consistency is only evaluated when both prices are available.
    if 'ListPrice' in cleaned.columns and 'StandardCost' in cleaned.columns:
        list_price = col('ListPrice')
        standard_cost = col('StandardCost')
        both_present = list_price.isNotNull() & standard_cost.isNotNull()
        cleaned = cleaned.withColumn(
            'price_validation_flag',
            when(both_present, list_price >= standard_cost).otherwise(True),
        )

    for date_col in ('SellStartDate', 'SellEndDate', 'DiscontinuedDate', 'ModifiedDate'):
        cleaned = validate_dates(cleaned, date_col)

    return add_data_quality_flags(cleaned, 'product')

def process_sales_order_header(bronze_df):
    """Process sales order header with business logic validation.

    Steps, in order:
      * trim the text columns and normalise ``OnlineOrderFlag`` to boolean;
      * convert the date columns to dates;
      * add ``date_sequence_valid`` (``OrderDate <= DueDate`` and
        ``OrderDate <= ShipDate``; a comparison counts as valid when either
        side is null);
      * null negative monetary amounts;
      * add ``total_calculation_valid`` (``TotalDue`` within 0.01 of
        ``SubTotal + TaxAmt + Freight``; valid when any operand is null);
      * append the shared quality columns tagged 'salesorderheader'.

    Each derived flag is only added when the columns it needs exist, so a
    table missing one of them is still processed instead of raising.
    """
    logger.info("Processing salesorderheader table...")

    df = bronze_df

    # Clean string columns
    string_cols = ['SalesOrderNumber', 'PurchaseOrderNumber', 'AccountNumber', 'Comment']
    for col_name in string_cols:
        df = clean_string_column(df, col_name)

    # Standardize boolean flags
    df = standardize_boolean_flags(df, 'OnlineOrderFlag')

    # Validate dates and date logic
    date_cols = ['OrderDate', 'DueDate', 'ShipDate', 'ModifiedDate']
    for col_name in date_cols:
        df = validate_dates(df, col_name)

    # Validate date sequence (OrderDate <= DueDate, OrderDate <= ShipDate).
    # Only the comparisons whose columns are present are built; with all
    # three columns available this is the same expression as before.
    if 'OrderDate' in df.columns:
        sequence_checks = [
            when(
                (col('OrderDate').isNotNull() & col(later_col).isNotNull()),
                col('OrderDate') <= col(later_col)
            ).otherwise(True)
            for later_col in ('DueDate', 'ShipDate')
            if later_col in df.columns
        ]
        if sequence_checks:
            date_sequence_valid = sequence_checks[0]
            for check in sequence_checks[1:]:
                date_sequence_valid = date_sequence_valid & check
            df = df.withColumn('date_sequence_valid', date_sequence_valid)

    # Validate monetary amounts
    money_cols = ['SubTotal', 'TaxAmt', 'Freight', 'TotalDue']
    for col_name in money_cols:
        df = validate_numeric_positive(df, col_name)

    # Validate total calculation (TotalDue = SubTotal + TaxAmt + Freight)
    if all(col_name in df.columns for col_name in money_cols):
        df = df.withColumn('total_calculation_valid',
            when(
                col('SubTotal').isNotNull() &
                col('TaxAmt').isNotNull() &
                col('Freight').isNotNull() &
                col('TotalDue').isNotNull(),
                # 0.01 tolerance absorbs rounding differences.
                spark_abs(col('TotalDue') - (col('SubTotal') + col('TaxAmt') + col('Freight'))) < 0.01
            ).otherwise(True)
        )

    df = add_data_quality_flags(df, 'salesorderheader')
    return df

def process_sales_order_detail(bronze_df):
    """Apply the sales-order-detail cleaning steps to a Bronze DataFrame.

    Trims ``CarrierTrackingNumber``, nulls negative quantities and amounts,
    adds ``line_total_valid`` (``LineTotal`` within 0.01 of
    ``OrderQty * UnitPrice * (1 - UnitPriceDiscount)``; True when any operand
    is null), converts ``ModifiedDate`` and appends the shared quality
    columns tagged 'salesorderdetail'.
    """
    logger.info("Processing salesorderdetail table...")

    cleaned = clean_string_column(bronze_df, 'CarrierTrackingNumber')

    amount_cols = ['OrderQty', 'UnitPrice', 'UnitPriceDiscount', 'LineTotal']
    for amount_col in amount_cols:
        cleaned = validate_numeric_positive(cleaned, amount_col)

    # The flag needs all four operands to exist as columns.
    if all(amount_col in cleaned.columns for amount_col in amount_cols):
        qty = col('OrderQty')
        unit_price = col('UnitPrice')
        discount = col('UnitPriceDiscount')
        line_total = col('LineTotal')

        all_present = (
            qty.isNotNull() &
            unit_price.isNotNull() &
            discount.isNotNull() &
            line_total.isNotNull()
        )
        expected_total = qty * unit_price * (1 - discount)
        cleaned = cleaned.withColumn(
            'line_total_valid',
            when(all_present, spark_abs(line_total - expected_total) < 0.01).otherwise(True),
        )

    cleaned = validate_dates(cleaned, 'ModifiedDate')
    return add_data_quality_flags(cleaned, 'salesorderdetail')

def process_creditcard(bronze_df, min_exp_year=2020, max_exp_year=2050):
    """Process credit card with sensitive data handling.

    Args:
        bronze_df: Bronze-layer credit card DataFrame.
        min_exp_year: Lowest ``ExpYear`` kept (inclusive). Defaults to 2020.
        max_exp_year: Highest ``ExpYear`` kept (inclusive). Defaults to 2050.
            Both bounds were previously hard-coded; pass different values
            for datasets whose cards legitimately fall outside that window,
            otherwise those years are nulled.

    Returns:
        The DataFrame with ``CardType`` trimmed, ``CardNumber`` masked to
        its last four characters, out-of-range ``ExpMonth``/``ExpYear``
        nulled, ``ModifiedDate`` converted to a date, and the shared quality
        columns tagged 'creditcard' appended.
    """
    logger.info("Processing creditcard table...")

    df = bronze_df

    # Clean card type
    df = clean_string_column(df, 'CardType')

    # Mask credit card number (keep last 4 characters). Null values and
    # values shorter than four characters are masked completely.
    if 'CardNumber' in df.columns:
        df = df.withColumn('CardNumber',
            when(col('CardNumber').isNotNull() & (length(col('CardNumber')) >= 4),
                concat(lit("****-****-****-"), substring(col('CardNumber'), -4, 4))
            ).otherwise(lit("****-****-****-****"))
        )

    # Validate expiration month (1-12); anything else becomes null
    if 'ExpMonth' in df.columns:
        df = df.withColumn('ExpMonth',
            when((col('ExpMonth') >= 1) & (col('ExpMonth') <= 12), col('ExpMonth'))
            .otherwise(None)
        )

    # Validate expiration year against the configurable inclusive window
    if 'ExpYear' in df.columns:
        df = df.withColumn('ExpYear',
            when((col('ExpYear') >= min_exp_year) & (col('ExpYear') <= max_exp_year), col('ExpYear'))
            .otherwise(None)
        )

    df = validate_dates(df, 'ModifiedDate')
    df = add_data_quality_flags(df, 'creditcard')
    return df

def process_generic_table(bronze_df, table_name):
    """Fallback processing for tables without a dedicated processor.

    Trims every column whose dtype is 'string', converts ``ModifiedDate``
    to a date when that column exists, and appends the shared quality
    columns tagged with ``table_name``.
    """
    logger.info(f"Processing {table_name} table...")

    # Collect the string-typed columns up front, then clean them one by one.
    text_columns = [name for name, dtype in bronze_df.dtypes if dtype == 'string']

    cleaned = bronze_df
    for text_col in text_columns:
        cleaned = clean_string_column(cleaned, text_col)

    cleaned = validate_dates(cleaned, 'ModifiedDate')
    return add_data_quality_flags(cleaned, table_name)

# --- Main Processing Logic ---

# Aggregate count is not part of this cell's import list, so bring it in
# here under an alias that does not shadow anything.
from pyspark.sql.functions import count as spark_count

# Dictionary mapping table names to their specific processing functions;
# any table not listed here goes through process_generic_table.
table_processors = {
    'address': process_address,
    'customer': process_customer,
    'person': process_person,
    'product': process_product,
    'salesorderheader': process_sales_order_header,
    'salesorderdetail': process_sales_order_detail,
    'creditcard': process_creditcard
}

# Get list of Bronze layer tables (one sub-directory per table)
bronze_tables = [d for d in os.listdir(bronze_dir) if os.path.isdir(os.path.join(bronze_dir, d))]

if not bronze_tables:
    logger.warning(f"No Bronze tables found in directory: {bronze_dir}")
    # `exit()` is a helper installed by the `site` module for interactive use;
    # raising SystemExit works regardless of how the interpreter was started.
    raise SystemExit(1)

silver_dfs = {}
processing_stats = {}

# Pipeline metadata columns are left out of the null-percentage statistics.
metadata_cols = ['ingestion_timestamp', 'processing_timestamp', 'data_quality_score', 'source_table']

for table_name in bronze_tables:
    bronze_table_path = os.path.join(bronze_dir, table_name)
    silver_table_path = os.path.join(silver_dir, table_name)

    try:
        logger.info(f"Processing Bronze table: {table_name}")

        # Read Bronze layer data
        bronze_df = spark.read.parquet(bronze_table_path)

        # Get initial record count
        initial_count = bronze_df.count()

        # Apply specific processing function or generic processing
        if table_name in table_processors:
            silver_df = table_processors[table_name](bronze_df)
        else:
            silver_df = process_generic_table(bronze_df, table_name)

        # Get final record count
        final_count = silver_df.count()

        # Null percentage per data column, computed in ONE aggregation
        # instead of one Spark job per column. count() ignores nulls, so
        # counting `when(is null, 1)` yields the number of null values.
        checked_cols = [c for c in silver_df.columns if c not in metadata_cols]
        null_counts = {}
        if checked_cols:
            null_row = silver_df.agg(
                *[spark_count(when(col(c).isNull(), lit(1))) for c in checked_cols]
            ).first()
            for col_name, null_count in zip(checked_cols, null_row):
                null_percentage = (null_count / final_count * 100) if final_count > 0 else 0
                null_counts[col_name] = null_percentage

        # Store processing statistics
        processing_stats[table_name] = {
            'initial_count': initial_count,
            'final_count': final_count,
            'records_dropped': initial_count - final_count,
            'null_percentages': null_counts
        }

        # Write to Silver layer
        silver_df.write.mode("overwrite").parquet(silver_table_path)
        logger.info(f"Successfully processed {table_name}: {initial_count} -> {final_count} records")

        # Store DataFrame for potential further use
        silver_dfs[table_name] = silver_df

    except Exception as e:
        # Tables are isolated from each other: record the failure (with
        # traceback) and continue with the next one.
        logger.exception(f"Error processing {table_name}: {e}")
        processing_stats[table_name] = {'error': str(e)}

# --- Data Quality Report ---
logger.info("=== SILVER LAYER DATA QUALITY REPORT ===")
for table_name, stats in processing_stats.items():
    if 'error' in stats:
        logger.error(f"{table_name}: Processing failed - {stats['error']}")
    else:
        logger.info(f"{table_name}:")
        logger.info(f"  Records: {stats['initial_count']} -> {stats['final_count']}")
        if stats['records_dropped'] > 0:
            logger.warning(f"  Dropped records: {stats['records_dropped']}")

        # Report columns with high null percentages (above 20%)
        high_null_cols = {k: v for k, v in stats['null_percentages'].items() if v > 20}
        if high_null_cols:
            logger.warning(f"  High null percentage columns: {high_null_cols}")

# --- Schema Display (example for one DataFrame) ---
if silver_dfs:
    sample_table = list(silver_dfs.keys())[0]
    logger.info(f"Sample Silver DataFrame schema ({sample_table}):")
    # printSchema writes to stdout, not to the logger.
    silver_dfs[sample_table].printSchema()

logger.info("Silver Layer processing finished.")

# --- Stop Spark Session ---
spark.stop()
logger.info("SparkSession stopped.")

# Gold layer


In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    current_timestamp, lit, col, when, isnan, isnull, trim, upper, lower,
    regexp_replace, to_date, to_timestamp, coalesce, length, substring,
    round as spark_round, abs as spark_abs, split, size, concat, sum as spark_sum,
    count, max as spark_max, min as spark_min, avg, year, month, dayofmonth,
    dayofweek, quarter, weekofyear, date_format, row_number, rank, dense_rank,
    lag, lead, first, last, collect_list, collect_set, explode, array_contains,
    struct, desc, asc, monotonically_increasing_id, hash, md5, sha1,
    regexp_extract, split as spark_split, slice as spark_slice,
    date_add, date_sub, datediff, months_between, next_day, last_day,
    from_unixtime, unix_timestamp, date_trunc, concat_ws
)
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType, TimestampType, BooleanType, LongType
from pyspark.sql.window import Window
import os
import logging
from datetime import datetime, date, timedelta

# --- Configuration ---
# silver_dir is the input layer, gold_dir the output layer.
silver_dir = 'data/Silver'
gold_dir = 'data/Gold'

# Path of the log file that receives everything the logger emits
error_log_dir = 'logs'
error_log_file_path = os.path.join(error_log_dir, 'gold_layer_errors.log')

# Ensure directories exist
os.makedirs(gold_dir, exist_ok=True)
os.makedirs(error_log_dir, exist_ok=True)

# --- Logging Setup ---
# Clear existing handlers to prevent duplicate logs
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(error_log_file_path),  # Log to file
        logging.StreamHandler()                    # Also print to console
    ]
)
logger = logging.getLogger(__name__)

logger.info("Starting PySpark Data Layer Processing (Gold Layer with SCD and Upsert)...")

# --- Spark Session Initialization ---
try:
    spark = SparkSession.builder.appName("TakeHomeExamGoldLayer") \
        .config("spark.sql.legacy.timeParserPolicy", "LEGACY") \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
        .getOrCreate()
    logger.info("SparkSession created successfully.")
except Exception as e:
    # logger.exception keeps the same message but also records the traceback.
    logger.exception(f"Error initializing SparkSession: {e}")
    # `exit()` is a helper installed by the `site` module for interactive use;
    # raising SystemExit works regardless of how the interpreter was started.
    raise SystemExit(1)

def generate_surrogate_key(df, columns):
    """Generate surrogate key based on multiple columns.

    Each column is cast to string (nulls become the text "NULL"), the parts
    are joined with a "||" separator and the absolute value of Spark's
    ``hash`` of that string is stored in a new ``surrogate_key`` column.

    The separator keeps different value splits apart: without it,
    ("ab", "c") and ("a", "bc") concatenate to the same string and would
    receive the same key. Keys built from a single column are unaffected,
    because joining one part adds no separator.

    Args:
        df: DataFrame to extend.
        columns: Names of the columns that make up the key.

    Returns:
        ``df`` with an added ``surrogate_key`` column.
    """
    key_parts = [coalesce(col(c).cast(StringType()), lit("NULL")) for c in columns]
    return df.withColumn("surrogate_key",
                         spark_abs(hash(concat_ws("||", *key_parts))))

def safe_table_read(table_path, table_name):
    """Read a Parquet table, returning None instead of raising on failure.

    Any exception from the read is logged as a warning that names the table
    and path; the caller is expected to handle the ``None`` result.
    """
    try:
        table_df = spark.read.parquet(table_path)
    except Exception as e:
        logger.warning(f"Could not read {table_name} from {table_path}: {e}")
        return None
    return table_df

def add_scd_columns(df):
    """Attach the SCD Type 2 bookkeeping columns for a first version of a row.

    Adds, in this order: ``effective_date`` (current timestamp),
    ``end_date`` (null timestamp), ``is_current`` (True) and ``version`` (1).
    """
    open_ended = lit(None).cast(TimestampType())

    with_dates = df.withColumn("effective_date", current_timestamp())
    with_dates = with_dates.withColumn("end_date", open_ended)
    flagged = with_dates.withColumn("is_current", lit(True))
    return flagged.withColumn("version", lit(1))

def upsert_dimension_table(new_data, table_path, table_name, business_key_columns):
    """
    Perform SCD Type 2 upsert on dimension table.

    When ``table_path`` does not exist the incoming data is returned as an
    initial load (version 1, open-ended). Otherwise the stored table is read
    and the incoming rows are matched to its current rows on the business
    key. The result is the union of:

      * historical rows (``is_current`` False), unchanged;
      * current rows whose business key is not in ``new_data``, unchanged;
      * current rows whose incoming counterpart has identical attributes;
      * current rows whose attributes changed, closed (``end_date`` set,
        ``is_current`` False);
      * a new open version (``version + 1``) for every changed row;
      * incoming rows with no current match, as version 1.

    Args:
        new_data: Incoming dimension rows, without the SCD columns.
        table_path: Location of the stored dimension table (Parquet).
        table_name: Name used in log messages.
        business_key_columns: Columns identifying a dimension member.

    Returns:
        The DataFrame to persist. Nothing is written here. If the merge
        raises, the error is logged and the incoming data is returned as an
        initial load (the original best-effort fallback is kept).
    """
    logger.info(f"Performing SCD Type 2 upsert for {table_name}...")

    if not os.path.exists(table_path):
        # Initial load - add SCD columns
        logger.info(f"Initial load for {table_name}")
        return add_scd_columns(new_data)

    try:
        existing_data = spark.read.parquet(table_path)
        logger.info(f"Found existing {table_name} with {existing_data.count()} records")

        current_records = existing_data.filter(col("is_current") == True)
        historical_records = existing_data.filter(col("is_current") == False)

        # Attributes compared to detect a change: everything in the incoming
        # data except audit/SCD bookkeeping columns and the business key.
        bookkeeping_cols = ['created_date', 'modified_date', 'effective_date',
                            'end_date', 'is_current', 'version', 'surrogate_key']
        comparison_cols = [
            c for c in new_data.columns
            if c not in bookkeeping_cols and c not in business_key_columns
        ]

        # Match incoming rows to current rows on the business key.
        joined_df = new_data.alias("new").join(
            current_records.alias("existing"),
            [col(f"new.{bk}") == col(f"existing.{bk}") for bk in business_key_columns],
            "left"
        )

        # Incoming rows with no current counterpart are brand new members.
        new_records = joined_df.filter(col("existing.surrogate_key").isNull()) \
                               .select("new.*")
        matched_records = joined_df.filter(col("existing.surrogate_key").isNotNull())

        # Null-safe comparison: two nulls count as equal, null vs. a value
        # counts as a change. No string coercion is involved, so a genuine
        # "NULL" text value is not confused with a missing one.
        change_conditions = [
            ~col(f"new.{c}").eqNullSafe(col(f"existing.{c}"))
            for c in comparison_cols
        ]

        if change_conditions:
            changed_filter = change_conditions[0]
            for condition in change_conditions[1:]:
                changed_filter = changed_filter | condition
            actually_changed = matched_records.filter(changed_filter)
            unchanged = matched_records.filter(~changed_filter)
        else:
            # Nothing to compare: every matched row is unchanged.
            actually_changed = None
            unchanged = matched_records

        # Current rows absent from this batch are kept as they are, in line
        # with how the fact upsert keeps unmatched existing rows.
        untouched_existing = current_records.alias("existing").join(
            new_data.alias("new"),
            [col(f"existing.{bk}") == col(f"new.{bk}") for bk in business_key_columns],
            "left_anti"
        )

        unchanged_existing = unchanged.select("existing.*")
        final_new_records = add_scd_columns(new_records)

        # unionByName matches columns by name; the stored table and the
        # freshly built frames do not necessarily share a column order.
        result_df = historical_records.unionByName(untouched_existing) \
                                      .unionByName(unchanged_existing) \
                                      .unionByName(final_new_records)

        if actually_changed is not None:
            # Close the superseded versions.
            closed_records = actually_changed.select("existing.*") \
                .withColumn("end_date", current_timestamp()) \
                .withColumn("is_current", lit(False))

            # Open the next version. The previous version number is already
            # on the joined row, so no second join is required; a temporary
            # column carries it past the point where "version" is assigned.
            new_versions = actually_changed \
                .select("new.*", (col("existing.version") + 1).alias("__next_version")) \
                .withColumn("effective_date", current_timestamp()) \
                .withColumn("end_date", lit(None).cast(TimestampType())) \
                .withColumn("is_current", lit(True)) \
                .withColumn("version", col("__next_version")) \
                .drop("__next_version")

            result_df = result_df.unionByName(closed_records) \
                                 .unionByName(new_versions)

        logger.info(f"SCD upsert completed: {result_df.count()} total records")
        return result_df

    except Exception as e:
        logger.exception(f"Error during SCD upsert for {table_name}: {e}")
        # Fall back to initial load.
        # NOTE(review): if the caller overwrites table_path with this result,
        # the stored history is replaced -- confirm that is acceptable.
        return add_scd_columns(new_data)

def upsert_fact_table(new_data, table_path, table_name, business_key_columns):
    """
    Perform upsert on fact table (Type 1 - overwrite existing records)

    Stored rows whose business keys also occur in ``new_data`` are dropped and
    replaced by the incoming rows; every other stored row is kept.

    Args:
        new_data: DataFrame with the freshly built fact rows.
        table_path: Location of the existing Parquet table (checked with
            ``os.path.exists``).
        table_name: Name used in log messages only.
        business_key_columns: Column names identifying a fact row; each must
            exist on both the stored table and ``new_data``.

    Returns:
        The merged DataFrame, or ``new_data`` unchanged when no table exists
        yet or when the merge raises.
    """
    logger.info(f"Performing fact table upsert for {table_name}...")

    if not os.path.exists(table_path):
        logger.info(f"Initial load for fact table {table_name}")
        return new_data

    try:
        existing_data = spark.read.parquet(table_path)
        logger.info(f"Found existing {table_name} with {existing_data.count()} records")

        # Keep only the stored rows whose business keys are absent from the
        # incoming batch; matching rows are superseded by new_data.
        key_match = [
            col(f"existing.{bk}") == col(f"new.{bk}")
            for bk in business_key_columns
        ]
        existing_to_keep = existing_data.alias("existing").join(
            new_data.alias("new"), key_match, "left_anti"
        )

        # Merge by column name rather than position: the column order stored
        # on disk is not guaranteed to match the order new_data was built in,
        # and a positional union would silently put values in the wrong
        # columns when the two differ.
        result_df = existing_to_keep.unionByName(new_data)

        logger.info(f"Fact upsert completed: {result_df.count()} total records")
        return result_df

    except Exception as e:
        logger.error(f"Error during fact upsert for {table_name}: {e}")
        # Best-effort fallback kept from the original design: hand back the
        # incoming batch so the caller still has something to write.
        return new_data

def create_date_dimension():
    """Create comprehensive date dimension table

    Generates one row per calendar day from 2010-01-01 through 2030-12-31
    (both inclusive) and derives the calendar attributes from that date.

    Returns:
        DataFrame with the columns date_key, date, year, month, day,
        day_of_week, quarter, week_of_year, month_name, day_name, is_weekend,
        year_month, year_quarter, created_date and modified_date.
    """
    logger.info("Creating date dimension...")

    first_day = date(2010, 1, 1)
    last_day = date(2030, 12, 31)

    # One single-field tuple per day; "+ 1" keeps the last day in the range.
    total_days = (last_day - first_day).days + 1
    calendar_rows = [
        (first_day + timedelta(days=offset),) for offset in range(total_days)
    ]

    calendar_df = spark.createDataFrame(
        calendar_rows,
        StructType([StructField("date", DateType(), True)]),
    )

    day = col("date")
    attributes = [
        day.alias("date_key"),
        day,
        year(day).alias("year"),
        month(day).alias("month"),
        dayofmonth(day).alias("day"),
        dayofweek(day).alias("day_of_week"),
        quarter(day).alias("quarter"),
        weekofyear(day).alias("week_of_year"),
        date_format(day, "MMMM").alias("month_name"),
        date_format(day, "EEEE").alias("day_name"),
        # Spark's dayofweek numbers Sunday as 1 and Saturday as 7.
        when(dayofweek(day).isin([1, 7]), True).otherwise(False).alias("is_weekend"),
        date_format(day, "yyyy-MM").alias("year_month"),
        # Year and quarter joined by a literal "Q", e.g. 2010Q1.
        concat_ws("Q", year(day).cast("string"), quarter(day).cast("string")).alias("year_quarter"),
        current_timestamp().alias("created_date"),
        current_timestamp().alias("modified_date"),
    ]

    return calendar_df.select(*attributes)

def create_customer_dimension():
    """Create customer dimension with person info

    Reads the Silver ``customer`` and ``person`` tables. When ``person`` is
    available, customers are left-joined to it on
    ``PersonID == BusinessEntityID`` and missing name parts are defaulted;
    otherwise the name and person-type columns are filled with placeholders.

    Returns:
        The customer dimension after ``generate_surrogate_key`` has been
        applied on ``customer_id``, or ``None`` when the customer table
        could not be read.
    """
    logger.info("Creating customer dimension...")

    # Read Silver layer data
    customer_df = safe_table_read(os.path.join(silver_dir, 'customer'), 'customer')
    person_df = safe_table_read(os.path.join(silver_dir, 'person'), 'person')

    if customer_df is None:
        logger.error("Customer table not found in Silver layer")
        return None

    if person_df is not None:
        # Name parts are reused for the individual columns and for full_name.
        first_name = coalesce(person_df["FirstName"], lit("Unknown"))
        last_name = coalesce(person_df["LastName"], lit("Unknown"))

        # Left join so customers without a person record are preserved.
        customers_with_person = customer_df.join(
            person_df,
            customer_df["PersonID"] == person_df["BusinessEntityID"],
            "left",
        )
        customer_dim = customers_with_person.select(
            customer_df["CustomerID"].alias("customer_key"),
            customer_df["CustomerID"].alias("customer_id"),
            customer_df["AccountNumber"].alias("account_number"),
            first_name.alias("first_name"),
            coalesce(person_df["MiddleName"], lit("")).alias("middle_name"),
            last_name.alias("last_name"),
            concat(first_name, lit(" "), last_name).alias("full_name"),
            coalesce(person_df["PersonType"], lit("Unknown")).alias("person_type"),
            customer_df["TerritoryID"].alias("territory_id"),
            current_timestamp().alias("created_date"),
            current_timestamp().alias("modified_date"),
        )
    else:
        logger.warning("Person table not found, creating customer dimension without person details")
        placeholder_names = [
            lit("Unknown").alias(field)
            for field in ("first_name", "middle_name", "last_name", "full_name", "person_type")
        ]
        customer_dim = customer_df.select(
            col("CustomerID").alias("customer_key"),
            col("CustomerID").alias("customer_id"),
            col("AccountNumber").alias("account_number"),
            *placeholder_names,
            col("TerritoryID").alias("territory_id"),
            current_timestamp().alias("created_date"),
            current_timestamp().alias("modified_date"),
        )

    # Add surrogate key
    return generate_surrogate_key(customer_dim, ["customer_id"])

def create_product_dimension():
    """Create product dimension with category information

    Reads the Silver ``product`` table and, when both ``productcategory`` and
    ``productsubcategory`` are available, left-joins them to resolve the
    subcategory and category names. Without them those two columns are set
    to "Unknown".

    Returns:
        The product dimension after ``generate_surrogate_key`` has been
        applied on ``product_id``, or ``None`` when the product table could
        not be read.
    """
    logger.info("Creating product dimension...")

    product_df = safe_table_read(os.path.join(silver_dir, 'product'), 'product')

    if product_df is None:
        logger.error("Product table not found in Silver layer")
        return None

    # Optional lookup tables for the category hierarchy.
    product_category_df = safe_table_read(os.path.join(silver_dir, 'productcategory'), 'productcategory')
    product_subcategory_df = safe_table_read(os.path.join(silver_dir, 'productsubcategory'), 'productsubcategory')

    if product_category_df is not None and product_subcategory_df is not None:
        # product -> subcategory -> category, left joins keep every product.
        source_df = (
            product_df.alias("p")
            .join(
                product_subcategory_df.alias("psc"),
                col("p.ProductSubcategoryID") == col("psc.ProductSubcategoryID"),
                "left",
            )
            .join(
                product_category_df.alias("pc"),
                col("psc.ProductCategoryID") == col("pc.ProductCategoryID"),
                "left",
            )
        )
        prefix = "p."
        subcategory_name = coalesce(col("psc.Name"), lit("Unknown"))
        category_name = coalesce(col("pc.Name"), lit("Unknown"))
    else:
        # No hierarchy available: product attributes only.
        source_df = product_df
        prefix = ""
        subcategory_name = lit("Unknown")
        category_name = lit("Unknown")

    def product_col(name):
        """Reference a product column, qualified with the alias when joined."""
        return col(f"{prefix}{name}")

    product_dim = source_df.select(
        product_col("ProductID").alias("product_key"),
        product_col("ProductID").alias("product_id"),
        product_col("Name").alias("product_name"),
        product_col("ProductNumber").alias("product_number"),
        coalesce(product_col("Color"), lit("Unknown")).alias("color"),
        coalesce(product_col("Size"), lit("Unknown")).alias("size"),
        coalesce(product_col("Weight"), lit(0.0)).alias("weight"),
        coalesce(product_col("ListPrice"), lit(0.0)).alias("list_price"),
        coalesce(product_col("StandardCost"), lit(0.0)).alias("standard_cost"),
        subcategory_name.alias("subcategory_name"),
        category_name.alias("category_name"),
        current_timestamp().alias("created_date"),
        current_timestamp().alias("modified_date"),
    )

    # Add surrogate key
    return generate_surrogate_key(product_dim, ["product_id"])

def create_geography_dimension():
    """Create geography dimension

    Reads the Silver ``address`` table and, when ``salesterritory``,
    ``stateprovince`` and ``countryregion`` are all available, left-joins
    them to resolve state, country and territory attributes. If any of the
    three is missing, those attributes are filled with placeholders.

    Returns:
        The geography dimension after ``generate_surrogate_key`` has been
        applied on ``address_id``, or ``None`` when the address table could
        not be read.
    """
    logger.info("Creating geography dimension...")

    # Read address data
    address_df = safe_table_read(os.path.join(silver_dir, 'address'), 'address')

    if address_df is None:
        logger.error("Address table not found in Silver layer")
        return None

    # Optional geography reference tables
    territory_df = safe_table_read(os.path.join(silver_dir, 'salesterritory'), 'salesterritory')
    state_df = safe_table_read(os.path.join(silver_dir, 'stateprovince'), 'stateprovince')
    country_df = safe_table_read(os.path.join(silver_dir, 'countryregion'), 'countryregion')

    if all(ref is not None for ref in (territory_df, state_df, country_df)):
        # address -> state -> (country, territory); left joins keep every address.
        source_df = (
            address_df.alias("a")
            .join(state_df.alias("s"), col("a.StateProvinceID") == col("s.StateProvinceID"), "left")
            .join(country_df.alias("c"), col("s.CountryRegionCode") == col("c.CountryRegionCode"), "left")
            .join(territory_df.alias("t"), col("s.TerritoryID") == col("t.TerritoryID"), "left")
        )
        prefix = "a."
        region_columns = [
            coalesce(col("s.Name"), lit("Unknown")).alias("state_name"),
            coalesce(col("s.StateProvinceCode"), lit("UN")).alias("state_code"),
            coalesce(col("c.Name"), lit("Unknown")).alias("country_name"),
            coalesce(col("c.CountryRegionCode"), lit("UN")).alias("country_code"),
            coalesce(col("t.Name"), lit("Unknown")).alias("territory_name"),
        ]
    else:
        # Simplified dimension: address attributes plus placeholder regions.
        source_df = address_df
        prefix = ""
        region_columns = [
            lit("Unknown").alias("state_name"),
            lit("UN").alias("state_code"),
            lit("Unknown").alias("country_name"),
            lit("UN").alias("country_code"),
            lit("Unknown").alias("territory_name"),
        ]

    def address_col(name):
        """Reference an address column, qualified with the alias when joined."""
        return col(f"{prefix}{name}")

    geography_dim = source_df.select(
        address_col("AddressID").alias("geography_key"),
        address_col("AddressID").alias("address_id"),
        coalesce(address_col("AddressLine1"), lit("Unknown")).alias("address_line1"),
        coalesce(address_col("AddressLine2"), lit("")).alias("address_line2"),
        coalesce(address_col("City"), lit("Unknown")).alias("city"),
        coalesce(address_col("PostalCode"), lit("Unknown")).alias("postal_code"),
        *region_columns,
        current_timestamp().alias("created_date"),
        current_timestamp().alias("modified_date"),
    )

    # Add surrogate key
    return generate_surrogate_key(geography_dim, ["address_id"])

def create_sales_fact_table():
    """Create main sales fact table

    Inner-joins the Silver ``salesorderdetail`` rows to their
    ``salesorderheader`` on ``SalesOrderID`` and projects keys, measures,
    attributes and audit columns, one output row per joined detail row.

    Returns:
        The sales fact DataFrame, or ``None`` when either Silver table could
        not be read.
    """
    logger.info("Creating sales fact table...")

    # Read Silver layer data
    header_df = safe_table_read(os.path.join(silver_dir, 'salesorderheader'), 'salesorderheader')
    detail_df = safe_table_read(os.path.join(silver_dir, 'salesorderdetail'), 'salesorderdetail')

    if header_df is None or detail_df is None:
        logger.error("Required sales tables not found in Silver layer")
        return None

    def or_default(column_name, default):
        """Column value with nulls replaced by the given literal default."""
        return coalesce(col(column_name), lit(default))

    order_lines = detail_df.alias("od").join(
        header_df.alias("oh"),
        col("od.SalesOrderID") == col("oh.SalesOrderID"),
    )

    key_columns = [
        # Surrogate key
        # NOTE(review): monotonically_increasing_id() is regenerated on every
        # build - confirm nothing downstream relies on it being stable.
        monotonically_increasing_id().alias("sales_fact_key"),
        # Business keys
        col("oh.SalesOrderID").alias("sales_order_id"),
        col("od.SalesOrderDetailID").alias("sales_order_detail_id"),
        # Foreign keys; each address/date falls back to its counterpart when null
        col("oh.CustomerID").alias("customer_key"),
        col("od.ProductID").alias("product_key"),
        coalesce(col("oh.BillToAddressID"), col("oh.ShipToAddressID")).alias("bill_to_geography_key"),
        coalesce(col("oh.ShipToAddressID"), col("oh.BillToAddressID")).alias("ship_to_geography_key"),
        col("oh.OrderDate").alias("order_date_key"),
        coalesce(col("oh.DueDate"), col("oh.OrderDate")).alias("due_date_key"),
        coalesce(col("oh.ShipDate"), col("oh.OrderDate")).alias("ship_date_key"),
    ]

    # (source column, output name, default used for nulls)
    measure_specs = [
        ("od.OrderQty", "order_quantity", 0),
        ("od.UnitPrice", "unit_price", 0.0),
        ("od.UnitPriceDiscount", "unit_price_discount", 0.0),
        ("od.LineTotal", "line_total", 0.0),
        ("oh.SubTotal", "order_subtotal", 0.0),
        ("oh.TaxAmt", "tax_amount", 0.0),
        ("oh.Freight", "freight", 0.0),
        ("oh.TotalDue", "total_due", 0.0),
    ]
    measure_columns = [
        or_default(source, default).alias(output_name)
        for source, output_name, default in measure_specs
    ]

    calculated_columns = [
        (or_default("od.OrderQty", 0) * or_default("od.UnitPrice", 0.0)).alias("gross_revenue"),
        or_default("od.LineTotal", 0.0).alias("net_revenue"),
    ]

    attribute_columns = [
        or_default("oh.Status", 0).alias("order_status"),
        or_default("oh.OnlineOrderFlag", False).alias("online_order_flag"),
    ]

    audit_columns = [
        current_timestamp().alias("created_date"),
        current_timestamp().alias("modified_date"),
    ]

    return order_lines.select(
        *key_columns,
        *measure_columns,
        *calculated_columns,
        *attribute_columns,
        *audit_columns,
    )

def create_comprehensive_revenue_table():
    """Create comprehensive table with all revenue analysis info merged

    Reads ``fact_sales`` and the customer, product, geography and date
    dimensions from the Gold directory, keeps only the ``is_current`` rows of
    the three SCD dimensions, and left-joins everything onto the fact rows
    (geography via the bill-to key, date via the order date).

    Returns:
        The denormalised revenue-analysis DataFrame, or ``None`` when a
        required Gold table is missing or cannot be read.
    """
    logger.info("Creating comprehensive revenue analysis table...")

    # Check if all required tables exist
    required_tables = ['fact_sales', 'dim_customer', 'dim_product', 'dim_geography', 'dim_date']
    missing_tables = []

    for table in required_tables:
        table_path = os.path.join(gold_dir, table)
        if not os.path.exists(table_path):
            missing_tables.append(table)
            logger.error(f"Required table {table} not found at {table_path}")

    if missing_tables:
        logger.error(f"Cannot create comprehensive revenue table. Missing tables: {missing_tables}")
        return None

    try:
        # Read fact and dimension tables - get current records only for dimensions
        sales_fact = spark.read.parquet(os.path.join(gold_dir, 'fact_sales'))
        logger.info(f"Successfully read sales fact table with {sales_fact.count()} records")

        # For SCD tables, only get current records for the comprehensive view
        # Date dimension doesn't have SCD columns, so read it directly
        customer_dim = spark.read.parquet(os.path.join(gold_dir, 'dim_customer')).filter(col("is_current") == True)
        logger.info(f"Successfully read customer dimension with {customer_dim.count()} current records")

        product_dim = spark.read.parquet(os.path.join(gold_dir, 'dim_product')).filter(col("is_current") == True)
        logger.info(f"Successfully read product dimension with {product_dim.count()} current records")

        geography_dim = spark.read.parquet(os.path.join(gold_dir, 'dim_geography')).filter(col("is_current") == True)
        logger.info(f"Successfully read geography dimension with {geography_dim.count()} current records")

        date_dim = spark.read.parquet(os.path.join(gold_dir, 'dim_date'))  # No SCD filter for date dimension
        logger.info(f"Successfully read date dimension with {date_dim.count()} records")

    except Exception as e:
        logger.error(f"Error reading dimension tables: {e}")
        return None

    # Left joins throughout: a fact row is kept even when a dimension lookup
    # finds no match (its dimension columns are then null).
    comprehensive_df = sales_fact.alias("sf") \
        .join(customer_dim.alias("cd"), col("sf.customer_key") == col("cd.customer_key"), "left") \
        .join(product_dim.alias("pd"), col("sf.product_key") == col("pd.product_key"), "left") \
        .join(geography_dim.alias("gd"), col("sf.bill_to_geography_key") == col("gd.geography_key"), "left") \
        .join(date_dim.alias("dd"), col("sf.order_date_key") == col("dd.date_key"), "left")

    # Select all relevant columns for revenue analysis
    revenue_analysis_table = comprehensive_df.select(
        # Sales metrics
        col("sf.sales_fact_key"),
        col("sf.sales_order_id"),
        col("sf.sales_order_detail_id"),
        col("sf.order_quantity"),
        col("sf.unit_price"),
        col("sf.unit_price_discount"),
        col("sf.line_total"),
        col("sf.net_revenue"),
        col("sf.gross_revenue"),
        col("sf.order_subtotal"),
        col("sf.tax_amount"),
        col("sf.freight"),
        col("sf.total_due"),
        col("sf.order_status"),

        # Order status description
        when(col("sf.order_status") == 1, "In Process")
        .when(col("sf.order_status") == 2, "Approved")
        .when(col("sf.order_status") == 3, "Backordered")
        .when(col("sf.order_status") == 4, "Rejected")
        .when(col("sf.order_status") == 5, "Shipped")
        .when(col("sf.order_status") == 6, "Cancelled")
        .otherwise("Unknown").alias("order_status_desc"),

        col("sf.online_order_flag"),

        # Customer information
        col("cd.customer_id"),
        col("cd.account_number"),
        col("cd.first_name"),
        col("cd.middle_name"),
        col("cd.last_name"),
        col("cd.full_name"),
        col("cd.person_type"),

        # Product information
        col("pd.product_id"),
        col("pd.product_name"),
        col("pd.product_number"),
        col("pd.color"),
        col("pd.size"),
        col("pd.weight"),
        col("pd.list_price"),
        col("pd.standard_cost"),
        col("pd.subcategory_name"),
        col("pd.category_name"),

        # Geography information
        col("gd.address_id"),
        col("gd.address_line1"),
        col("gd.address_line2"),
        col("gd.city"),
        col("gd.postal_code"),
        col("gd.state_name"),
        col("gd.state_code"),
        col("gd.country_name"),
        col("gd.country_code"),
        col("gd.territory_name"),

        # Date information
        col("sf.order_date_key").alias("order_date"),
        col("dd.year").alias("order_year"),
        col("dd.month").alias("order_month"),
        col("dd.quarter").alias("order_quarter"),
        col("dd.month_name"),
        col("dd.day_name"),
        col("dd.year_month"),
        col("dd.year_quarter"),

        # Calculated fields for analysis
        # net_revenue is the order line's LineTotal, i.e. already a per-line
        # amount; multiplying it by order_quantity again would count the
        # quantity twice, so the line revenue is the net revenue itself.
        col("sf.net_revenue").alias("total_line_revenue"),
        # Per-unit difference between list price and standard cost.
        (col("pd.list_price") - col("pd.standard_cost")).alias("profit_margin"),
        current_timestamp().alias("created_date")
    )

    return revenue_analysis_table

def _overwrite_via_staging(final_df, table_path):
    """Replace the existing table directory with final_df via a staging directory."""
    import shutil  # local import: only this helper needs it

    staging_path = f"{os.path.normpath(table_path)}__staging"
    try:
        final_df.write.mode("overwrite").parquet(staging_path)
    except Exception:
        # The old table is still untouched here; only drop the partial output.
        shutil.rmtree(staging_path, ignore_errors=True)
        raise

    # The new data is fully written, so the old directory can be swapped out.
    shutil.rmtree(table_path)
    os.replace(staging_path, table_path)


def write_table_with_logging(df, table_path, table_name, is_dimension=False, business_keys=None):
    """Write table with logging, error handling, and upsert support

    Args:
        df: DataFrame to persist; ``None`` is logged as an error and skipped.
        table_path: Target Parquet location.
        table_name: Name used in log messages.
        is_dimension: With ``business_keys``, selects the SCD dimension upsert;
            otherwise the fact upsert is used.
        business_keys: Key columns for the upsert. When empty or ``None`` the
            DataFrame is written as-is.

    Returns:
        None. Failures are logged, not raised.
    """
    if df is None:
        logger.error(f"Failed to write {table_name} - DataFrame is None")
        return

    try:
        if is_dimension and business_keys:
            # Use SCD Type 2 upsert for dimensions
            final_df = upsert_dimension_table(df, table_path, table_name, business_keys)
        elif business_keys:
            # Use Type 1 upsert for facts
            final_df = upsert_fact_table(df, table_path, table_name, business_keys)
        else:
            # Default to overwrite for tables without business keys
            final_df = df

        record_count = final_df.count()

        if os.path.isdir(table_path):
            # The upsert result is evaluated lazily and may still read from
            # table_path. Overwriting that path directly would remove the
            # files the write still has to read, so the write can fail or
            # lose data. Write next to it first, then swap the directories.
            _overwrite_via_staging(final_df, table_path)
        else:
            # Nothing at this local path yet, so nothing can be clobbered.
            final_df.write.mode("overwrite").parquet(table_path)

        logger.info(f"{table_name} written successfully with {record_count} records")
    except Exception as e:
        logger.error(f"Error writing {table_name}: {e}")

def main():
    """Main processing function with SCD and upsert capabilities

    Builds and writes, in order: the date dimension, the customer, product
    and geography dimensions (SCD upsert), the sales fact table (fact
    upsert) and, if every prerequisite table exists on disk, the
    comprehensive revenue analysis table. Finishes by logging record and
    column counts for every table directory under the Gold directory.

    Raises:
        Exception: any error escaping the steps above is logged and re-raised.
    """
    logger.info("Starting Gold layer processing with SCD and upsert...")

    try:
        # Date dimension is written without business keys (plain overwrite).
        logger.info("Creating Date Dimension...")
        write_table_with_logging(
            create_date_dimension(),
            os.path.join(gold_dir, 'dim_date'),
            'Date Dimension',
        )

        # SCD Type 2 dimensions: (label, builder, folder, business keys)
        scd_dimensions = (
            ('Customer Dimension', create_customer_dimension, 'dim_customer', ['customer_id']),
            ('Product Dimension', create_product_dimension, 'dim_product', ['product_id']),
            ('Geography Dimension', create_geography_dimension, 'dim_geography', ['address_id']),
        )
        for label, build_dimension, folder, keys in scd_dimensions:
            logger.info(f"Creating {label}...")
            dimension_df = build_dimension()
            if dimension_df is None:
                logger.error(f"Failed to create {label}")
                continue
            write_table_with_logging(
                dimension_df,
                os.path.join(gold_dir, folder),
                label,
                is_dimension=True,
                business_keys=keys,
            )

        # Sales fact table, upserted on order id + order detail id.
        logger.info("Creating Sales Fact Table...")
        sales_fact = create_sales_fact_table()
        if sales_fact is None:
            logger.error("Failed to create Sales Fact Table")
        else:
            write_table_with_logging(
                sales_fact,
                os.path.join(gold_dir, 'fact_sales'),
                'Sales Fact Table',
                is_dimension=False,
                business_keys=['sales_order_id', 'sales_order_detail_id'],
            )

        # Comprehensive revenue table needs every table above to be on disk.
        logger.info("Checking if all tables exist for comprehensive revenue analysis...")
        required_tables = ['fact_sales', 'dim_customer', 'dim_product', 'dim_geography', 'dim_date']
        missing_tables = [
            table for table in required_tables
            if not os.path.exists(os.path.join(gold_dir, table))
        ]

        if missing_tables:
            logger.warning(f"Not all required tables exist, skipping comprehensive revenue table creation. Missing: {missing_tables}")
        else:
            logger.info("All required tables exist, creating comprehensive revenue analysis table...")
            revenue_analysis_table = create_comprehensive_revenue_table()
            if revenue_analysis_table is None:
                logger.warning("Failed to create comprehensive revenue analysis table")
            else:
                write_table_with_logging(
                    revenue_analysis_table,
                    os.path.join(gold_dir, 'revenue_analysis_comprehensive'),
                    'Comprehensive Revenue Analysis Table',
                )

        logger.info("Gold layer processing with SCD and upsert completed successfully!")

        # Summary: one log line per table directory found under gold_dir.
        logger.info("=== GOLD LAYER SUMMARY ===")
        for table in os.listdir(gold_dir):
            if not os.path.isdir(os.path.join(gold_dir, table)):
                continue
            try:
                df = spark.read.parquet(os.path.join(gold_dir, table))

                # Dimension tables carrying SCD columns get a current/historical split.
                has_scd_info = table.startswith('dim_') and 'is_current' in df.columns
                if not has_scd_info:
                    logger.info(f"{table}: {df.count()} records, {len(df.columns)} columns")
                    continue

                current_count = df.filter(col("is_current") == True).count()
                total_count = df.count()
                logger.info(f"{table}: {total_count} total records ({current_count} current, {total_count - current_count} historical), {len(df.columns)} columns")

            except Exception as e:
                logger.error(f"Error reading {table}: {e}")

    except Exception as e:
        logger.error(f"Error in main processing: {e}")
        raise

if __name__ == "__main__":
    # Script entry point: run the Gold-layer pipeline, report any failure with
    # a non-zero exit status, and always stop the Spark session afterwards.
    try:
        main()
    except Exception:
        logger.error("🔥 Unhandled exception occurred in main()", exc_info=True)
        import sys
        import traceback
        traceback.print_exc()
        # sys.exit instead of the bare exit(): exit() is a helper injected by
        # the site module for interactive use and is not guaranteed to exist.
        sys.exit(1)
    finally:
        # Runs on success and on failure (including the SystemExit above).
        spark.stop()
        logger.info("SparkSession stopped.")

# Plots

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path
import warnings
import logging
import os
from datetime import datetime

# --- Logging Setup ---
# Plot-generation messages go to logs/plots_generation.log and to the console.
log_dir = 'logs'
plots_log_file_path = os.path.join(log_dir, 'plots_generation.log')

# The log directory must exist before the file handler is created.
os.makedirs(log_dir, exist_ok=True)

# Detach handlers left on the root logger by earlier configuration so that
# messages are not emitted twice.
while logging.root.handlers:
    logging.root.removeHandler(logging.root.handlers[0])

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(plots_log_file_path), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

logger.info("Starting Plot Generation for Gold Layer Visualizations...")

# --- Plot Style ---
# Shared look for every figure produced below.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
})

class GoldLayerVisualizer:
    """Render revenue charts and summary statistics from a parquet dataset.

    The dataset is read once into ``self.df`` (``None`` when loading fails).
    Each ``plot_*`` method aggregates ``net_revenue`` by one dimension, draws
    the chart with matplotlib and saves it as a PNG under ``obligated_plots``
    or ``additional_plots``.  A plot whose input columns are missing, or whose
    aggregated data is empty, is skipped with a warning instead of raising;
    the figure is only opened once the data has been validated, so a skipped
    plot never leaves an open figure behind.
    """

    # Columns converted to float when they were read as strings (object dtype).
    _NUMERIC_COLUMNS = (
        'unit_price', 'line_total', 'net_revenue', 'gross_revenue',
        'order_subtotal', 'tax_amount', 'freight', 'total_due',
        'list_price', 'standard_cost', 'weight',
    )

    def __init__(self, data_path, output_dir=None):
        """
        Load the data and prepare the plot output directories.

        Args:
            data_path: Parquet file or directory passed to ``pd.read_parquet``.
            output_dir: Optional base directory for the PNG files.  When
                omitted, ``<parent of the script directory>/plots`` is used.
        """
        self.data_path = Path(data_path)
        self.df = None
        self.load_data()

        if output_dir is not None:
            self.base_output_dir = Path(output_dir)
        else:
            try:
                anchor_dir = Path(__file__).parent
            except NameError:
                # __file__ does not exist in a notebook / interactive session;
                # use the working directory as the "script directory" there.
                anchor_dir = Path.cwd()
            self.base_output_dir = anchor_dir.parent / 'plots'

        self.obligated_plots_dir = self.base_output_dir / 'obligated_plots'
        self.additional_plots_dir = self.base_output_dir / 'additional_plots'

        self.obligated_plots_dir.mkdir(parents=True, exist_ok=True)
        self.additional_plots_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Obligated plots will be saved to: {self.obligated_plots_dir}")
        logger.info(f"Additional plots will be saved to: {self.additional_plots_dir}")

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_decimal_strings(values):
        """Convert a string Series that may use ',' as decimal mark to float.

        Missing entries stay missing (NaN) instead of being stringified to
        ``'None'``/``'nan'`` first, which would make the float conversion fail.
        """
        as_text = values.astype(str).str.replace(',', '.', regex=False)
        return as_text.where(values.notna()).astype(float)

    def load_data(self):
        """Load the parquet data into ``self.df`` and normalise column types.

        On any failure the error is logged and ``self.df`` is set to ``None``;
        no exception is propagated.
        """
        try:
            # pd.read_parquet is given the path as-is whether it is a single
            # file or a directory.
            self.df = pd.read_parquet(self.data_path)

            logger.info(f"Data loaded successfully: {len(self.df)} records, {len(self.df.columns)} columns")

            for name in self._NUMERIC_COLUMNS:
                if name in self.df.columns and self.df[name].dtype == 'object':
                    self.df[name] = self._parse_decimal_strings(self.df[name])

            if 'order_date' in self.df.columns:
                self.df['order_date'] = pd.to_datetime(self.df['order_date'])

        except Exception as e:
            logger.error(f"Error loading data: {e}", exc_info=True)
            logger.error("Please check the data path and ensure parquet files exist")
            self.df = None  # Ensure df is None if loading fails

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _save_plot(self, filename, plot_type='obligated'):
        """Save the current matplotlib figure to disk and close it.

        Args:
            filename: File name (not a path) of the image to write.
            plot_type: ``'obligated'`` or ``'additional'``; selects the target
                folder.  Any other value falls back to the obligated folder.
        """
        if plot_type == 'obligated':
            file_path = self.obligated_plots_dir / filename
        elif plot_type == 'additional':
            file_path = self.additional_plots_dir / filename
        else:
            logger.warning(f"Unknown plot type '{plot_type}'. Saving to obligated_plots folder.")
            file_path = self.obligated_plots_dir / filename

        try:
            plt.savefig(file_path)
            logger.info(f"Plot saved successfully: {file_path}")
        except Exception as e:
            logger.error(f"Error saving plot {filename}: {e}", exc_info=True)
        finally:
            # Always release the figure, even when saving failed.
            plt.close()

    def _ready(self, caller, columns):
        """Return True when ``self.df`` is loaded and has every column in *columns*.

        Otherwise a warning naming *caller* is logged and False is returned.
        """
        if self.df is None:
            logger.warning(f"DataFrame is None, skipping {caller}.")
            return False
        missing = [name for name in columns if name not in self.df.columns]
        if missing:
            logger.warning(f"Missing column(s) {missing}, skipping {caller}.")
            return False
        return True

    @staticmethod
    def _positive_slices(totals, caller):
        """Return only the strictly positive entries of *totals*.

        A pie chart cannot draw zero or negative wedges, so those groups are
        dropped (with a warning) rather than letting matplotlib raise.
        """
        keep = totals > 0
        if not keep.all():
            logger.warning(f"Dropping {int((~keep).sum())} non-positive group(s) from the pie chart in {caller}.")
        return totals[keep]

    @staticmethod
    def _monthly_revenue(df):
        """Return ``net_revenue`` summed per calendar month.

        The result is indexed by the first day of each month, in ascending
        order.  Rows without an order date are ignored.  *df* is not modified.
        """
        order_dates = pd.to_datetime(df['order_date'])
        month_start = order_dates.dt.to_period('M').dt.to_timestamp()
        return df['net_revenue'].groupby(month_start).sum()

    def _plot_revenue_bar(self, group_col, caller, title, axis_label, filename, color,
                          horizontal=False, top_n=None, sort=True, figsize=(14, 8),
                          tick_rotation=45, plot_type='obligated'):
        """Draw and save a bar chart of total ``net_revenue`` per *group_col*.

        Args:
            group_col: Column to group by.
            caller: Name of the public method, used in log messages.
            title: Chart title.
            axis_label: Label of the category axis.
            filename: Output image file name.
            color: Bar fill colour.
            horizontal: Draw horizontal bars (``barh``) instead of vertical.
            top_n: Keep only the first *top_n* groups; ``None`` keeps all.
            sort: Order groups by descending revenue (otherwise keep the
                group-key order produced by ``groupby``).
            figsize: Figure size in inches.
            tick_rotation: Rotation of the category tick labels on vertical
                charts; 0 leaves them horizontal.
            plot_type: Passed through to ``_save_plot``.
        """
        if not self._ready(caller, (group_col, 'net_revenue')):
            return

        revenue = self.df.groupby(group_col)['net_revenue'].sum()
        if sort:
            revenue = revenue.sort_values(ascending=False)
        if top_n is not None:
            revenue = revenue.head(top_n)
        if revenue.empty:
            logger.warning(f"No '{group_col}' groups to plot, skipping {caller}.")
            return

        plt.figure(figsize=figsize)
        ax = revenue.plot(kind='barh' if horizontal else 'bar', color=color, edgecolor='black')
        plt.title(title, fontsize=16, fontweight='bold')

        # Value labels sit 1% of the largest bar beyond each bar's end.
        label_offset = revenue.max() * 0.01
        if horizontal:
            plt.xlabel('Revenue ($)', fontsize=12)
            plt.ylabel(axis_label, fontsize=12)
            for position, amount in enumerate(revenue.values):
                ax.text(amount + label_offset, position, f'${amount:,.0f}',
                        ha='left', va='center', fontweight='bold')
        else:
            plt.xlabel(axis_label, fontsize=12)
            plt.ylabel('Revenue ($)', fontsize=12)
            if tick_rotation:
                plt.xticks(rotation=tick_rotation, ha='right')
            else:
                plt.xticks(rotation=0)
            for position, amount in enumerate(revenue.values):
                ax.text(position, amount + label_offset, f'${amount:,.0f}',
                        ha='center', va='bottom', fontweight='bold')

        plt.tight_layout()
        self._save_plot(filename, plot_type=plot_type)

    # ------------------------------------------------------------------
    # Obligated plots
    # ------------------------------------------------------------------
    def plot_revenue_by_category(self):
        """Plot revenue by all product categories"""
        self._plot_revenue_bar(
            'category_name',
            caller='plot_revenue_by_category',
            title='Revenue by Product Category',
            axis_label='Product Category',
            filename='revenue_by_category.png',
            color='skyblue',
        )

    def plot_top_subcategories(self, top_n=10):
        """Plot top N subcategories by revenue"""
        self._plot_revenue_bar(
            'subcategory_name',
            caller='plot_top_subcategories',
            title=f'Top {top_n} Subcategories by Revenue',
            axis_label='Product Subcategory',
            filename=f'top_{top_n}_subcategories.png',
            color='lightcoral',
            horizontal=True,
            top_n=top_n,
        )

    def plot_top_customers(self, top_n=10):
        """Plot top N customers by revenue"""
        self._plot_revenue_bar(
            'full_name',
            caller='plot_top_customers',
            title=f'Top {top_n} Customers by Revenue',
            axis_label='Customer Name',
            filename=f'top_{top_n}_customers.png',
            color='lightgreen',
            horizontal=True,
            top_n=top_n,
        )

    def plot_revenue_by_order_status(self):
        """Plot revenue by all order statuses (pie chart)"""
        caller = 'plot_revenue_by_order_status'
        if not self._ready(caller, ('order_status_desc', 'net_revenue')):
            return

        status_revenue = self.df.groupby('order_status_desc')['net_revenue'].sum().sort_values(ascending=False)
        status_revenue = self._positive_slices(status_revenue, caller)
        if status_revenue.empty:
            logger.warning(f"No positive revenue per order status, skipping {caller}.")
            return

        plt.figure(figsize=(12, 8))
        colors = plt.cm.Set3(np.linspace(0, 1, len(status_revenue)))
        wedges, texts, autotexts = plt.pie(status_revenue.values, labels=status_revenue.index,
                                           autopct='%1.1f%%', colors=colors, startangle=90)

        plt.title('Revenue Distribution by Order Status', fontsize=16, fontweight='bold')

        for autotext in autotexts:
            autotext.set_color('white')
            autotext.set_fontweight('bold')

        plt.axis('equal')
        self._save_plot('revenue_by_order_status.png', plot_type='obligated')

    def plot_top_countries(self, top_n=10):
        """Plot top N countries by revenue"""
        self._plot_revenue_bar(
            'country_name',
            caller='plot_top_countries',
            title=f'Top {top_n} Countries by Revenue',
            axis_label='Country',
            filename=f'top_{top_n}_countries.png',
            color='gold',
            top_n=top_n,
        )

    def plot_top_states(self, top_n=10):
        """Plot top N states by revenue"""
        self._plot_revenue_bar(
            'state_name',
            caller='plot_top_states',
            title=f'Top {top_n} States by Revenue',
            axis_label='State',
            filename=f'top_{top_n}_states.png',
            color='mediumpurple',
            horizontal=True,
            top_n=top_n,
        )

    def plot_top_cities(self, top_n=10):
        """Plot top N cities by revenue"""
        self._plot_revenue_bar(
            'city',
            caller='plot_top_cities',
            title=f'Top {top_n} Cities by Revenue',
            axis_label='City',
            filename=f'top_{top_n}_cities.png',
            color='orange',
            horizontal=True,
            top_n=top_n,
        )

    # ------------------------------------------------------------------
    # Additional insightful plots
    # ------------------------------------------------------------------
    def plot_monthly_revenue_trend(self):
        """Plot monthly revenue over time, with a linear trend line.

        The trend line is only drawn when at least two months are available,
        since a straight-line fit needs two points.
        """
        caller = 'plot_monthly_revenue_trend'
        if not self._ready(caller, ('order_date', 'net_revenue')):
            return

        monthly = self._monthly_revenue(self.df)
        if monthly.empty:
            logger.warning(f"No dated revenue to plot, skipping {caller}.")
            return

        plt.figure(figsize=(16, 8))
        plt.plot(monthly.index, monthly.values,
                 marker='o', linewidth=2, markersize=6, color='darkblue')
        plt.title('Monthly Revenue Trend', fontsize=16, fontweight='bold')
        plt.xlabel('Date', fontsize=12)
        plt.ylabel('Revenue ($)', fontsize=12)
        plt.grid(True, alpha=0.3)
        plt.xticks(rotation=45)

        if len(monthly) >= 2:
            # Fit against the month's ordinal position, not the date itself.
            positions = np.arange(len(monthly))
            trend = np.poly1d(np.polyfit(positions, monthly.values.astype(float), 1))
            plt.plot(monthly.index, trend(positions),
                     "r--", alpha=0.8, linewidth=2, label='Trend Line')
            plt.legend()

        plt.tight_layout()
        self._save_plot('monthly_revenue_trend.png', plot_type='additional')

    def plot_revenue_by_quarter(self):
        """Plot revenue by quarter.

        When the data has no ``order_quarter`` column it is derived from
        ``order_date`` and added to ``self.df``.
        """
        caller = 'plot_revenue_by_quarter'
        if not self._ready(caller, ('net_revenue',)):
            return

        # Ensure order_quarter is available, if not, create it
        if 'order_quarter' not in self.df.columns:
            if 'order_date' not in self.df.columns:
                logger.warning(f"Missing column(s) ['order_quarter', 'order_date'], skipping {caller}.")
                return
            self.df['order_quarter'] = pd.to_datetime(self.df['order_date']).dt.quarter
            logger.info("Created 'order_quarter' column as it was missing.")

        self._plot_revenue_bar(
            'order_quarter',
            caller=caller,
            title='Revenue by Quarter',
            axis_label='Quarter',
            filename='revenue_by_quarter.png',
            color='teal',
            sort=False,  # keep quarters in key order rather than by revenue
            figsize=(12, 8),
            tick_rotation=0,
            plot_type='additional',
        )

    def plot_product_performance_matrix(self):
        """Plot product performance matrix (Revenue vs Quantity)"""
        caller = 'plot_product_performance_matrix'
        required = ('product_name', 'net_revenue', 'order_quantity', 'category_name')
        if not self._ready(caller, required):
            return

        product_perf = self.df.groupby('product_name').agg({
            'net_revenue': 'sum',
            'order_quantity': 'sum',
            'category_name': 'first'
        }).reset_index()
        if product_perf.empty:
            logger.warning(f"No products to plot, skipping {caller}.")
            return

        plt.figure(figsize=(14, 10))

        # One colour / legend entry per category; products without a category
        # cannot be matched by equality and are therefore not drawn.
        categories = product_perf['category_name'].dropna().unique()
        colors = plt.cm.Set1(np.linspace(0, 1, len(categories)))

        for i, category in enumerate(categories):
            cat_data = product_perf[product_perf['category_name'] == category]
            plt.scatter(cat_data['order_quantity'], cat_data['net_revenue'],
                        alpha=0.7, s=60, color=colors[i], label=category, edgecolors='black')

        plt.title('Product Performance Matrix\n(Revenue vs Quantity Sold)', fontsize=16, fontweight='bold')
        plt.xlabel('Total Quantity Sold', fontsize=12)
        plt.ylabel('Total Revenue ($)', fontsize=12)
        if len(categories):
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        self._save_plot('product_performance_matrix.png', plot_type='additional')

    def plot_customer_type_analysis(self):
        """Plot revenue share and mean revenue per customer type"""
        caller = 'plot_customer_type_analysis'
        if not self._ready(caller, ('person_type', 'net_revenue')):
            return

        revenue_by_type = self.df.groupby('person_type')['net_revenue']
        customer_type_revenue = self._positive_slices(revenue_by_type.sum(), caller)
        # NOTE(review): this is the mean of net_revenue per row; if rows are
        # order lines rather than orders it is not a true "order value" --
        # confirm the grain of the dataset.
        avg_order_value = revenue_by_type.mean()
        if customer_type_revenue.empty or avg_order_value.empty:
            logger.warning(f"No customer-type revenue to plot, skipping {caller}.")
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        ax1.pie(customer_type_revenue.values, labels=customer_type_revenue.index,
                autopct='%1.1f%%', startangle=90, colors=plt.cm.Pastel1.colors)
        ax1.set_title('Revenue by Customer Type', fontsize=14, fontweight='bold')

        bars = ax2.bar(avg_order_value.index, avg_order_value.values,
                       color='lightblue', edgecolor='black')
        ax2.set_title('Average Order Value by Customer Type', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Customer Type')
        ax2.set_ylabel('Average Order Value ($)')

        label_offset = avg_order_value.max() * 0.01
        for bar in bars:
            height = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width() / 2., height + label_offset,
                     f'${height:,.0f}', ha='center', va='bottom', fontweight='bold')

        plt.tight_layout()
        self._save_plot('customer_type_analysis.png', plot_type='additional')

    def plot_territory_performance(self):
        """Plot territory performance comparison"""
        self._plot_revenue_bar(
            'territory_name',
            caller='plot_territory_performance',
            title='Revenue by Sales Territory',
            axis_label='Territory',
            filename='territory_performance.png',
            color='darkorange',
            plot_type='additional',
        )

    # ------------------------------------------------------------------
    # Summary and orchestration
    # ------------------------------------------------------------------
    def generate_summary_stats(self):
        """Log summary statistics of the loaded data.

        Every statistic is guarded by a column-existence check, so a missing
        column only suppresses its own line.
        """
        if self.df is None:
            logger.warning("DataFrame is None, skipping summary statistics generation.")
            return

        df = self.df
        has_revenue = 'net_revenue' in df.columns

        logger.info("=" * 50)
        logger.info("GOLD LAYER DATA SUMMARY STATISTICS")
        logger.info("=" * 50)

        logger.info(f"Total Records: {len(df):,}")

        # min/max are only meaningful when at least one date is present
        if 'order_date' in df.columns and df['order_date'].notna().any():
            logger.info(f"Date Range: {df['order_date'].min()} to {df['order_date'].max()}")
        else:
            logger.warning("Order date column is missing or empty, date range not available.")

        if has_revenue:
            logger.info(f"Total Revenue: ${df['net_revenue'].sum():,.2f}")
            # NOTE(review): mean per row, not per distinct sales_order_id --
            # confirm this is the intended definition of "order value".
            logger.info(f"Average Order Value: ${df['net_revenue'].mean():,.2f}")
        else:
            logger.warning("Net revenue column missing, skipping revenue totals.")

        # (label, column, format spec for the count)
        unique_counts = (
            ('Total Orders', 'sales_order_id', ','),
            ('Unique Customers', 'customer_id', ','),
            ('Unique Products', 'product_id', ','),
            ('Countries', 'country_name', ''),
            ('States', 'state_name', ''),
            ('Cities', 'city', ''),
        )
        for label, column, spec in unique_counts:
            if column in df.columns:
                logger.info(f"{label}: {df[column].nunique():{spec}}")
            else:
                logger.warning(f"Column '{column}' missing, skipping {label}.")

        logger.info("\nTop 5 Revenue Generating:")
        # (log label, column, wording used in the "missing" warning)
        top_dimensions = (
            ('Categories', 'category_name', 'Category name', 'top categories'),
            ('Countries', 'country_name', 'Country name', 'top countries'),
            ('States', 'state_name', 'State name', 'top states'),
        )
        for label, column, column_title, what in top_dimensions:
            if column in df.columns and has_revenue:
                leaders = df.groupby(column)['net_revenue'].sum().sort_values(ascending=False).head().index
                logger.info(f"{label}: {list(leaders)}")
            else:
                logger.warning(f"{column_title} or net revenue column missing, skipping {what}.")

    def _run_plots(self, steps):
        """Run each ``(method, args)`` in *steps*; return the names that raised.

        A failing plot is logged and does not prevent the remaining ones from
        running.  Any figure it left open is closed.
        """
        failed = []
        for method, args in steps:
            try:
                method(*args)
            except Exception as e:
                logger.error(f"{method.__name__} failed: {e}", exc_info=True)
                plt.close('all')
                failed.append(method.__name__)
        return failed

    def generate_all_plots(self):
        """Generate all the requested plots plus additional ones.

        Does nothing (beyond logging an error) when no data is loaded.  The
        success message is only logged when no plot raised.
        """
        if self.df is None:
            logger.error("No data loaded; no visualizations were generated.")
            return

        logger.info("Generating comprehensive visualizations...")

        self.generate_summary_stats()

        logger.info("\nGenerating Obligated Plots:")
        failed = self._run_plots([
            (self.plot_revenue_by_category, ()),
            (self.plot_top_subcategories, (10,)),
            (self.plot_top_customers, (10,)),
            (self.plot_revenue_by_order_status, ()),
            (self.plot_top_countries, (10,)),
            (self.plot_top_states, (10,)),
            (self.plot_top_cities, (10,)),
        ])

        logger.info("\nGenerating Additional Insightful Plots:")
        failed += self._run_plots([
            (self.plot_monthly_revenue_trend, ()),
            (self.plot_revenue_by_quarter, ()),
            (self.plot_product_performance_matrix, ()),
            (self.plot_customer_type_analysis, ()),
            (self.plot_territory_performance, ()),
        ])

        if failed:
            logger.warning(f"Visualization run finished with {len(failed)} failed plot(s): {failed}")
        else:
            logger.info("\nAll visualizations generated successfully!")

# Usage example: build every chart from the comprehensive Gold-layer dataset
if __name__ == "__main__":
    gold_dir = 'data/Gold'
    comprehensive_data_path = Path(gold_dir).joinpath('revenue_analysis_comprehensive')

    # Loading happens in the constructor; plotting is a separate explicit step.
    visualizer = GoldLayerVisualizer(data_path=comprehensive_data_path)
    visualizer.generate_all_plots()