In [2]:
##Sales Analytics Data Pipeline - ETL with PySpark and Star Schema Modeling
##Author: Tariqul Ismail, Data Engineering Team Lead
##Description: Complete ETL/ELT pipeline for transforming raw sales data into a star schema model

import os
import sys
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
from pyspark.sql.functions import sum as spark_sum
import logging
import builtins

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SalesDataETL:
    """Sales analytics ETL: raw CSV extracts -> validated star schema.

    Pipeline order (each step reads the instance attributes set by the
    previous one): ``extract_data`` -> ``transform_data`` ->
    ``validate_data`` -> ``save_data`` -> ``generate_summary_stats`` ->
    ``cleanup``.
    """

    # Fraction of list value (quantity * unit_price) treated as cost when
    # deriving per-sale ``profit``.
    # NOTE(review): flat modelling assumption; products.csv carries a
    # per-product ``cost_price`` that may be a better source.
    ASSUMED_COST_RATIO = 0.6

    def __init__(self, app_name="SalesAnalyticsETL"):
        """Initialize (or reuse) the Spark session.

        Args:
            app_name: Application name used to build a session, or an
                existing ``SparkSession`` to reuse as-is (e.g. one created
                with a specific master / filesystem config).
        """
        if isinstance(app_name, SparkSession):
            # Reuse the caller's session. Static configs such as the
            # warehouse dir cannot change on a live session, but the
            # adaptive-execution settings are runtime SQL configs.
            self.spark = app_name
            self.spark.conf.set("spark.sql.adaptive.enabled", "true")
            self.spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
            app_name = self.spark.sparkContext.appName
        else:
            self.spark = SparkSession.builder \
                .appName(app_name) \
                .config("spark.sql.adaptive.enabled", "true") \
                .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
                .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse") \
                .getOrCreate()

        # Set log level to reduce noise
        self.spark.sparkContext.setLogLevel("WARN")

        logger.info(f"Spark session initialized: {app_name}")

    def extract_data(self, data_path):
        """Read the four source CSVs from ``data_path`` and cache them.

        Sets ``raw_sales``, ``raw_customers``, ``raw_products`` and
        ``raw_locations``. Sales use an explicit schema; the dimension files
        rely on schema inference.

        Raises:
            Exception: any read/parse error is logged and re-raised.
        """
        logger.info("Starting data extraction...")

        try:
            # Explicit schema for the large fact source. NOTE: Spark's CSV
            # reader does not enforce nullable=False, so null checks still
            # happen during cleaning.
            sales_schema = StructType([
                StructField("transaction_id", StringType(), False),
                StructField("customer_id", StringType(), False),
                StructField("product_id", StringType(), False),
                StructField("location_id", StringType(), False),
                StructField("transaction_date", DateType(), False),
                StructField("quantity", IntegerType(), False),
                StructField("unit_price", DoubleType(), False),
                StructField("gross_amount", DoubleType(), False),
                StructField("discount_amount", DoubleType(), True),
                StructField("net_amount", DoubleType(), False),
                StructField("channel", StringType(), True),
                StructField("payment_method", StringType(), True),
                StructField("sales_rep_id", StringType(), True),
                StructField("promo_code", StringType(), True)
            ])

            self.raw_sales = self.spark.read \
                .option("header", "true") \
                .option("inferSchema", "false") \
                .schema(sales_schema) \
                .csv(f"{data_path}/sales.csv")

            self.raw_customers = self.spark.read \
                .option("header", "true") \
                .option("inferSchema", "true") \
                .csv(f"{data_path}/customers.csv")

            self.raw_products = self.spark.read \
                .option("header", "true") \
                .option("inferSchema", "true") \
                .csv(f"{data_path}/products.csv")

            self.raw_locations = self.spark.read \
                .option("header", "true") \
                .option("inferSchema", "true") \
                .csv(f"{data_path}/locations.csv")

            # Every raw dataset is read several times by later steps.
            self.raw_sales.cache()
            self.raw_customers.cache()
            self.raw_products.cache()
            self.raw_locations.cache()

            logger.info("Data extraction completed:")
            logger.info(f"  - Sales: {self.raw_sales.count():,} records")
            logger.info(f"  - Customers: {self.raw_customers.count():,} records")
            logger.info(f"  - Products: {self.raw_products.count():,} records")
            logger.info(f"  - Locations: {self.raw_locations.count():,} records")

        except Exception as e:
            logger.error(f"Error during data extraction: {str(e)}")
            raise

    def transform_data(self):
        """Clean all inputs and build the fact and dimension DataFrames.

        Requires ``extract_data`` to have run. Sets ``clean_*``,
        ``dim_dates``, ``fact_sales`` and ``dim_*_final`` attributes.
        """
        logger.info("Starting data transformation...")

        try:
            # 1. Clean Sales Data
            self.clean_sales = self._clean_sales_data()

            # 2. Clean Dimension Tables
            self.clean_customers = self._clean_customers_data()
            self.clean_products = self._clean_products_data()
            self.clean_locations = self._clean_locations_data()

            # 3. Create Date Dimension (depends on clean_sales)
            self.dim_dates = self._create_date_dimension()

            # 4. Create Star Schema Tables
            self.fact_sales = self._create_fact_sales()
            self.dim_customers_final = self._create_dim_customers()
            self.dim_products_final = self._create_dim_products()
            self.dim_locations_final = self._create_dim_locations()

            logger.info("Data transformation completed successfully")

        except Exception as e:
            logger.error(f"Error during data transformation: {str(e)}")
            raise

    def _clean_sales_data(self):
        """Validate, deduplicate and enrich raw sales transactions.

        Returns:
            DataFrame of valid sales with ``discount_amount`` defaulted to 0,
            ``revenue``/``profit`` measures and calendar attributes.
        """
        logger.info("Cleaning sales data...")

        # Validate BEFORE deduplicating: otherwise an invalid row could win
        # row_number() over a valid duplicate and then be filtered out,
        # silently losing the transaction.
        valid_sales = self.raw_sales \
            .filter(col("transaction_id").isNotNull()) \
            .filter(col("customer_id").isNotNull()) \
            .filter(col("product_id").isNotNull()) \
            .filter(col("location_id").isNotNull()) \
            .filter(col("transaction_date").isNotNull()) \
            .filter(col("quantity") > 0) \
            .filter(col("unit_price") > 0) \
            .filter(col("net_amount") >= 0)

        # Business-key deduplication. Ordering by transaction_id makes the
        # surviving row deterministic (ordering by a partition column, as
        # before, ties every row in the partition).
        window_spec = Window \
            .partitionBy("customer_id", "product_id", "location_id", "transaction_date", "net_amount") \
            .orderBy("transaction_id")

        deduped_sales = valid_sales \
            .withColumn("row_num", row_number().over(window_spec)) \
            .filter(col("row_num") == 1) \
            .drop("row_num")

        cleaned_sales = deduped_sales \
            .withColumn("discount_amount",
                        when(col("discount_amount").isNull(), 0.0).otherwise(col("discount_amount"))) \
            .withColumn("revenue", col("net_amount")) \
            .withColumn("profit",
                        col("net_amount") - (col("quantity") * col("unit_price") * self.ASSUMED_COST_RATIO)) \
            .withColumn("year", year(col("transaction_date"))) \
            .withColumn("month", month(col("transaction_date"))) \
            .withColumn("day", dayofmonth(col("transaction_date"))) \
            .withColumn("quarter", quarter(col("transaction_date"))) \
            .withColumn("day_of_week", dayofweek(col("transaction_date"))) \
            .withColumn("is_weekend", when(col("day_of_week").isin([1, 7]), True).otherwise(False))

        logger.info(f"Sales data cleaned: {cleaned_sales.count():,} records after deduplication and validation")
        return cleaned_sales

    def _clean_customers_data(self):
        """Validate customers, keep the latest registration per email, derive name/age."""
        logger.info("Cleaning customer data...")

        # Filter first so a row missing customer_id cannot displace a valid
        # record for the same email during deduplication; customer_id breaks
        # ties on registration_date deterministically.
        window_spec = Window.partitionBy("email") \
            .orderBy(desc("registration_date"), col("customer_id"))

        cleaned_customers = self.raw_customers \
            .filter(col("customer_id").isNotNull()) \
            .filter(col("email").isNotNull()) \
            .withColumn("row_num", row_number().over(window_spec)) \
            .filter(col("row_num") == 1) \
            .drop("row_num") \
            .withColumn("full_name", concat(col("first_name"), lit(" "), col("last_name"))) \
            .withColumn("customer_age",
                        floor(datediff(current_date(), col("date_of_birth")) / 365.25))

        logger.info(f"Customer data cleaned: {cleaned_customers.count():,} records")
        return cleaned_customers

    def _clean_products_data(self):
        """Keep priced products and derive ``profit_margin`` and ``is_active``."""
        logger.info("Cleaning product data...")

        # unit_price > 0 also guards the profit_margin division below.
        cleaned_products = self.raw_products \
            .filter(col("product_id").isNotNull()) \
            .filter(col("unit_price") > 0) \
            .withColumn("profit_margin",
                        (col("unit_price") - col("cost_price")) / col("unit_price")) \
            .withColumn("is_active", when(col("status") == "Active", True).otherwise(False))

        logger.info(f"Product data cleaned: {cleaned_products.count():,} records")
        return cleaned_products

    def _clean_locations_data(self):
        """Keep identified locations and bucket them by store size."""
        logger.info("Cleaning location data...")

        # A missing size is reported as "Unknown"; without the explicit
        # null branch it fell through to "Large".
        cleaned_locations = self.raw_locations \
            .filter(col("location_id").isNotNull()) \
            .withColumn("store_size_category",
                        when(col("store_size_sqft").isNull(), "Unknown")
                        .when(col("store_size_sqft") < 5000, "Small")
                        .when(col("store_size_sqft") < 20000, "Medium")
                        .otherwise("Large"))

        logger.info(f"Location data cleaned: {cleaned_locations.count():,} records")
        return cleaned_locations

    def _create_date_dimension(self):
        """Build one row per calendar day between the first and last sale.

        Also registers the ``date_range`` temp view of raw dates. With no
        sales the dimension is empty.
        """
        logger.info("Creating date dimension...")

        # Compute the calendar inside Spark: no driver collect() and no
        # Python values interpolated into SQL. On empty input the bounds are
        # NULL, sequence() yields NULL and explode() emits no rows, instead of
        # parsing the literal 'None' as a date.
        date_range = self.clean_sales \
            .agg(min("transaction_date").alias("min_date"),
                 max("transaction_date").alias("max_date")) \
            .selectExpr("explode(sequence(min_date, max_date, interval 1 day)) AS date_value")
        date_range.createOrReplaceTempView("date_range")

        dim_dates = self.spark.sql("""
            SELECT 
                date_format(date_value, 'yyyyMMdd') as date_id,
                date_value as full_date,
                year(date_value) as year,
                month(date_value) as month,
                dayofmonth(date_value) as day,
                quarter(date_value) as quarter,
                dayofweek(date_value) as day_of_week,
                date_format(date_value, 'EEEE') as day_name,
                date_format(date_value, 'MMMM') as month_name,
                case when dayofweek(date_value) in (1, 7) then true else false end as is_weekend,
                case when month(date_value) in (12, 1, 2) then 'Winter'
                     when month(date_value) in (3, 4, 5) then 'Spring'
                     when month(date_value) in (6, 7, 8) then 'Summer'
                     else 'Fall' end as season,
                weekofyear(date_value) as week_of_year
            FROM date_range
        """)

        logger.info(f"Date dimension created: {dim_dates.count():,} records")
        return dim_dates

    def _create_fact_sales(self):
        """Project cleaned sales onto the fact-table grain (keys + measures).

        The result is cached because validation, saving and reporting all
        scan it repeatedly. Caching also keeps ``load_timestamp`` consistent
        across those actions.
        """
        logger.info("Creating fact sales table...")

        fact_sales = self.clean_sales \
            .select(
                col("transaction_id"),
                col("customer_id"),
                col("product_id"),
                col("location_id"),
                date_format(col("transaction_date"), "yyyyMMdd").alias("date_id"),
                col("quantity"),
                col("unit_price"),
                col("gross_amount"),
                col("discount_amount"),
                col("net_amount").alias("revenue"),
                col("profit"),
                col("channel"),
                col("payment_method"),
                col("sales_rep_id"),
                col("promo_code"),
                current_timestamp().alias("load_timestamp")
            ) \
            .cache()

        logger.info(f"Fact sales table created: {fact_sales.count():,} records")
        return fact_sales

    def _create_dim_customers(self):
        """Select customer dimension attributes plus a load timestamp."""
        logger.info("Creating customer dimension...")

        dim_customers = self.clean_customers \
            .select(
                col("customer_id"),
                col("full_name"),
                col("first_name"),
                col("last_name"),
                col("email"),
                col("phone"),
                col("customer_age"),
                col("gender"),
                col("age_group"),
                col("customer_segment"),
                col("registration_date"),
                col("preferred_contact"),
                col("loyalty_points"),
                col("city"),
                col("state"),
                col("zip_code"),
                current_timestamp().alias("load_timestamp")
            )

        logger.info(f"Customer dimension created: {dim_customers.count():,} records")
        return dim_customers

    def _create_dim_products(self):
        """Select product dimension attributes plus a load timestamp."""
        logger.info("Creating product dimension...")

        dim_products = self.clean_products \
            .select(
                col("product_id"),
                col("product_name"),
                col("category"),
                col("subcategory"),
                col("brand"),
                col("unit_price"),
                col("cost_price"),
                col("profit_margin"),
                col("weight_lbs"),
                col("launch_date"),
                col("is_active"),
                current_timestamp().alias("load_timestamp")
            )

        logger.info(f"Product dimension created: {dim_products.count():,} records")
        return dim_products

    def _create_dim_locations(self):
        """Select location dimension attributes plus a load timestamp."""
        logger.info("Creating location dimension...")

        dim_locations = self.clean_locations \
            .select(
                col("location_id"),
                col("city"),
                col("state"),
                col("region"),
                col("country"),
                col("zip_code"),
                col("store_type"),
                col("store_size_sqft"),
                col("store_size_category"),
                col("opening_date"),
                current_timestamp().alias("load_timestamp")
            )

        logger.info(f"Location dimension created: {dim_locations.count():,} records")
        return dim_locations

    def validate_data(self):
        """Run data-quality checks on the star schema.

        Returns:
            dict with ``fact_sales_count``, ``null_foreign_keys``,
            ``business_rules``, ``referential_integrity`` and ``date_range``.

        Raises:
            AssertionError: a critical check failed (null FKs, negative
                revenue, non-positive quantity). Raised explicitly, so the
                checks are not stripped under ``python -O``. Referential
                integrity issues are reported but not fatal.
        """
        logger.info("Starting data validation...")

        validation_results = {}

        try:
            def count_where(condition):
                # Number of rows where ``condition`` is true; 0 (not NULL)
                # when the fact table is empty.
                return coalesce(spark_sum(when(condition, 1).otherwise(0)), lit(0))

            # 1-3 & 5. One pass over the fact table for all row-level metrics
            # instead of a separate count() job per check.
            metrics = self.fact_sales.agg(
                count(lit(1)).alias("fact_count"),
                count_where(col("customer_id").isNull()).alias("null_customer_ids"),
                count_where(col("product_id").isNull()).alias("null_product_ids"),
                count_where(col("location_id").isNull()).alias("null_location_ids"),
                count_where(col("date_id").isNull()).alias("null_date_ids"),
                count_where(col("revenue") < 0).alias("negative_revenue"),
                count_where(col("quantity") <= 0).alias("invalid_quantities"),
                min("date_id").alias("min_date"),
                max("date_id").alias("max_date")
            ).collect()[0]

            fact_count = metrics["fact_count"]
            null_customer_ids = metrics["null_customer_ids"]
            null_product_ids = metrics["null_product_ids"]
            null_location_ids = metrics["null_location_ids"]
            null_date_ids = metrics["null_date_ids"]
            negative_revenue = metrics["negative_revenue"]
            invalid_quantities = metrics["invalid_quantities"]

            validation_results['fact_sales_count'] = fact_count
            validation_results['null_foreign_keys'] = {
                'customer_id': null_customer_ids,
                'product_id': null_product_ids,
                'location_id': null_location_ids,
                'date_id': null_date_ids
            }
            validation_results['business_rules'] = {
                'negative_revenue': negative_revenue,
                'invalid_quantities': invalid_quantities
            }

            # 4. Referential integrity: fact keys with no matching dimension row.
            orphan_customers = self.fact_sales.join(
                self.dim_customers_final, "customer_id", "left_anti"
            ).count()

            orphan_products = self.fact_sales.join(
                self.dim_products_final, "product_id", "left_anti"
            ).count()

            orphan_locations = self.fact_sales.join(
                self.dim_locations_final, "location_id", "left_anti"
            ).count()

            validation_results['referential_integrity'] = {
                'orphan_customers': orphan_customers,
                'orphan_products': orphan_products,
                'orphan_locations': orphan_locations
            }

            validation_results['date_range'] = {
                'min_date': metrics['min_date'],
                'max_date': metrics['max_date']
            }

            # builtins.sum: the pyspark star import shadows the builtin sum.
            logger.info("=== DATA VALIDATION SUMMARY ===")
            logger.info(f"Total fact records: {fact_count:,}")
            logger.info(f"Null foreign keys: {builtins.sum(validation_results['null_foreign_keys'].values())}")
            logger.info(f"Business rule violations: {builtins.sum(validation_results['business_rules'].values())}")
            logger.info(f"Referential integrity issues: {builtins.sum(validation_results['referential_integrity'].values())}")
            logger.info(f"Date range: {validation_results['date_range']['min_date']} to {validation_results['date_range']['max_date']}")

            # Critical checks: report every failure at once.
            critical_checks = [
                (null_customer_ids, f"Found {null_customer_ids} null customer IDs"),
                (null_product_ids, f"Found {null_product_ids} null product IDs"),
                (null_location_ids, f"Found {null_location_ids} null location IDs"),
                (null_date_ids, f"Found {null_date_ids} null date IDs"),
                (negative_revenue, f"Found {negative_revenue} negative revenue records"),
                (invalid_quantities, f"Found {invalid_quantities} invalid quantity records"),
            ]
            failures = [message for bad_rows, message in critical_checks if bad_rows]
            if failures:
                raise AssertionError("; ".join(failures))

            logger.info("✓ All critical validations passed")

        except AssertionError as e:
            logger.error(f"Validation failed: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error during validation: {str(e)}")
            raise

        return validation_results

    def save_data(self, output_path, format="parquet"):
        """Write the star schema as managed tables under ``output_path``.

        Args:
            output_path: Base directory; each table goes to
                ``{output_path}/{table_name}``.
            format: Spark data source to write with (e.g. ``"parquet"``,
                ``"orc"``, ``"delta"``). Previously this argument was ignored
                and the session default source was always used.

        The fact table is partitioned by ``year``/``month`` taken from
        ``date_id`` (``yyyyMMdd``). Existing tables are overwritten.
        """
        logger.info(f"Saving data to {output_path} in {format} format...")

        def write_table(df, table_name, partition_cols=()):
            # Single place for writer settings so every table honors ``format``.
            writer = df.write.format(format).mode("overwrite")
            if partition_cols:
                writer = writer.partitionBy(*partition_cols)
            writer.option("path", f"{output_path}/{table_name}").saveAsTable(table_name)

        try:
            fact_with_partitions = self.fact_sales \
                .withColumn("year", substring(col("date_id"), 1, 4)) \
                .withColumn("month", substring(col("date_id"), 5, 2))
            write_table(fact_with_partitions, "fact_sales", ("year", "month"))

            write_table(self.dim_customers_final, "dim_customers")
            write_table(self.dim_products_final, "dim_products")
            write_table(self.dim_locations_final, "dim_locations")
            write_table(self.dim_dates, "dim_dates")

            logger.info("✓ Data saved successfully")

        except Exception as e:
            logger.error(f"Error saving data: {str(e)}")
            raise

    def generate_summary_stats(self):
        """Show monthly, category and regional revenue summaries.

        Returns:
            dict mapping ``"monthly_sales"``, ``"category_performance"`` and
            ``"regional_performance"`` to the summary DataFrames. Callers that
            ignored the previous ``None`` are unaffected.
        """
        logger.info("Generating summary statistics...")

        try:
            monthly_sales = self.fact_sales \
                .groupBy(substring(col("date_id"), 1, 6).alias("year_month")) \
                .agg(
                    spark_sum("revenue").alias("total_revenue"),
                    count("transaction_id").alias("total_transactions"),
                    avg("revenue").alias("avg_transaction_value")
                ) \
                .orderBy("year_month")

            # Inner joins: orphaned fact rows (see validate_data) are excluded.
            category_performance = self.fact_sales \
                .join(self.dim_products_final, "product_id") \
                .groupBy("category") \
                .agg(
                    spark_sum("revenue").alias("total_revenue"),
                    count("transaction_id").alias("total_transactions"),
                    avg("revenue").alias("avg_transaction_value")
                ) \
                .orderBy(desc("total_revenue"))

            regional_performance = self.fact_sales \
                .join(self.dim_locations_final, "location_id") \
                .groupBy("region") \
                .agg(
                    spark_sum("revenue").alias("total_revenue"),
                    count("transaction_id").alias("total_transactions")
                ) \
                .orderBy(desc("total_revenue"))

            # NOTE: show() prints to stdout while logging goes to stderr, so
            # headers and tables can interleave in a notebook.
            logger.info("=== SUMMARY STATISTICS ===")
            # Ordered by year_month, so these are the first 10 months.
            logger.info("\nMonthly Sales Trend (first 10 months):")
            monthly_sales.show(10)

            logger.info("\nTop Product Categories:")
            category_performance.show()

            logger.info("\nRegional Performance:")
            regional_performance.show()

            return {
                "monthly_sales": monthly_sales,
                "category_performance": category_performance,
                "regional_performance": regional_performance,
            }

        except Exception as e:
            logger.error(f"Error generating summary statistics: {str(e)}")
            raise

    def cleanup(self):
        """Stop the Spark session (releasing cached data)."""
        logger.info("Cleaning up resources...")
        self.spark.stop()
        logger.info("Spark session stopped")


def main(
    data_path="/Users/tariqul/Development/Python/Data transformations using PySpark/input_data",
    output_path="/Users/tariqul/Development/Python/Data transformations using PySpark/output_data",
):
    """Run the sales analytics ETL pipeline end to end.

    Args:
        data_path: Directory containing sales.csv, customers.csv,
            products.csv and locations.csv. The default is the original
            author's local path; pass your own when running elsewhere.
        output_path: Base directory for the written star-schema tables
            (same caveat on the default).

    Raises:
        Exception: any pipeline failure is logged and re-raised. The Spark
            session is stopped in every case.
    """
    # Create the session up front so master / filesystem settings take
    # effect. Those are static configs and are ignored once a session exists.
    # SalesDataETL's builder then reuses this session via getOrCreate().
    SparkSession.builder \
        .appName("SalesAnalyticsETL") \
        .master("local[*]") \
        .config("spark.hadoop.fs.defaultFS", "file:///") \
        .getOrCreate()

    # SalesDataETL expects an application *name*. Passing the session object
    # turned the object's repr into the app name.
    etl = SalesDataETL("SalesAnalyticsETL")

    try:
        logger.info("=== STARTING SALES ANALYTICS ETL PIPELINE ===")

        # Step 1: Extract
        etl.extract_data(data_path)

        # Step 2: Transform
        etl.transform_data()

        # Step 3: Validate
        etl.validate_data()

        # Step 4: Load (Save)
        etl.save_data(output_path)

        # Step 5: Generate Reports
        etl.generate_summary_stats()

        logger.info("=== ETL PIPELINE COMPLETED SUCCESSFULLY ===")

    except Exception as e:
        logger.error(f"Pipeline failed: {str(e)}")
        raise
    finally:
        etl.cleanup()

In [3]:
# Run the pipeline only when executed as a script or notebook cell, not on import.
if __name__ == "__main__":
    main()

25/06/10 12:38:42 WARN Utils: Your hostname, tariquls-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 192.168.0.9 instead (on interface en0)
25/06/10 12:38:42 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/06/10 12:38:44 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/06/10 12:38:49 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
INFO:__main__:Spark session initialized: <pyspark.sql.session.SparkSession object at 0x1045d9c90>
INFO:__main__:=== STARTING SALES ANALYTICS ETL PIPELINE ===
INFO:__main__:Starting data extraction...
INFO:__main__:Data extraction completed:                                        
INFO:__main__:  - Sales: 100,100 records                                      

+----------+------------------+------------------+---------------------+
|year_month|     total_revenue|total_transactions|avg_transaction_value|
+----------+------------------+------------------+---------------------+
|    202201|        1810908.38|              2745|    659.7116138433515|
|    202202|1728242.4099999997|              2519|    686.0827352123857|
|    202203|         1939811.6|              2890|    671.2150865051904|
|    202204|1842994.4199999995|              2746|    671.1560160233064|
|    202205|1910003.5300000014|              2807|    680.4430103313151|
|    202206|1834977.8999999994|              2677|    685.4605528576762|
|    202207|1905808.2000000002|              2791|    682.8406305983519|
|    202208| 1937517.629999999|              2829|    684.8772110286317|
|    202209|1864006.0200000005|              2705|    689.0964953789281|
|    202210|1947986.1400000001|              2832|    687.8482132768362|
+----------+------------------+------------------+-

INFO:__main__:                                                                  
Regional Performance:


+-------------+-----------------+------------------+---------------------+
|     category|    total_revenue|total_transactions|avg_transaction_value|
+-------------+-----------------+------------------+---------------------+
|   Automotive|9789641.330000019|             15205|    643.8435600131548|
|Home & Garden|9603142.490000024|             14288|     672.112436310192|
|        Books|9176482.970000006|             13425|    683.5369065176914|
|     Clothing|9001828.860000018|             12657|    711.2134676463631|
|       Beauty|8135679.610000008|             12050|    675.1601336099592|
|  Electronics|8089031.240000006|             11487|     704.190061809002|
|       Sports|7768540.860000004|             11197|    693.8055604179694|
|         Toys|6558976.750000004|              9691|    676.8111392013212|
+-------------+-----------------+------------------+---------------------+



INFO:__main__:=== ETL PIPELINE COMPLETED SUCCESSFULLY ===
INFO:__main__:Cleaning up resources...


+-------+--------------------+------------------+
| region|       total_revenue|total_transactions|
+-------+--------------------+------------------+
|  South|1.8183891300000045E7|             26775|
|Central|1.7634833179999985E7|             25902|
|  North| 1.565890121000002E7|             22980|
|   West|1.1812267070000019E7|             17222|
|   East|   4833431.350000001|              7121|
+-------+--------------------+------------------+



INFO:__main__:Spark session stopped
