# Sales ETL - End-to-End PySpark Demo 

This notebook demonstrates an end-to-end ETL pipeline using PySpark.

Pipeline features:
- Generate mock sales data
- Extract from CSV
- Data quality checks + schema mapping
- Hash (encrypt) sensitive IDs
- Incremental load using watermark
- Write final table into Postgres
- Save output into Parquet (partitioned)
- Log ETL runs into Postgres table

In [1]:
import os
from pyspark.sql import SparkSession

postgres_jar = os.path.abspath("jars/postgresql-42.7.3.jar")

spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("sales_etl_spark_demo")
    .config("spark.jars", postgres_jar)
    .config("spark.driver.extraClassPath", postgres_jar)
    .getOrCreate()
)

spark.sparkContext.setLogLevel("ERROR")
print("Spark created with Postgres JDBC driver!")
print("spark.jars =", spark.sparkContext.getConf().get("spark.jars"))


26/02/10 13:09:46 WARN Utils: Your hostname, Thiagos-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 192.168.1.11 instead (on interface en0)
26/02/10 13:09:46 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
26/02/10 13:09:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
26/02/10 13:09:47 WARN DependencyUtils: Local jar /Users/thiagomoura/Documents/projetos/sales-etl-pyspark/notebooks/jars/postgresql-42.7.3.jar does not exist, skipping.
26/02/10 13:09:47 INFO SparkContext: Running Spark version 3.4.1
26/02/10 13:09:47 INFO ResourceUtils: No custom resources configured for spark.driver.
26/02/10 13:09:47 INFO SparkContext: Submitted application: sales_etl_spark_demo
26/02/10 13:09:47 INFO ResourceProfile: Default ResourceProfile created, executor resources: Map(cores -> name: cores, amount: 1, script: , vendor: , memory -> name: memory, amount: 1024, script: , vendor: ,

Spark created with Postgres JDBC driver!
spark.jars = /Users/thiagomoura/Documents/projetos/sales-etl-pyspark/notebooks/jars/postgresql-42.7.3.jar


In [3]:
# ---------------- Imports + Setup -----------------------------

import os
import json
import time
import random
import logging
import pandas as pd

from datetime import datetime, timedelta

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StringType

In [4]:
#-------------------- CONFIGURATION ---------------------------------


CFG = {
    "app": {"name": "sales_etl_demo"},
    "paths": {
        "raw_data_dir": "data/raw",
        "state_dir": "data/state",
        "parquet_output_dir": "data/output/sales_parquet"
    },
    "spark": {
        "master": "local[*]",
        "app_name": "sales_etl_spark_demo",
        "shuffle_partitions": 4
    },
    "db": {
        "url": "jdbc:postgresql://localhost:5432/sales_db",
        "user": "sales_user",
        "password": os.getenv("DB_PASSWORD", "sales_pass"),  
        "driver": "org.postgresql.Driver",
        "table": "sales",
        "log_table": "etl_run_log"
    }
}

# State file for watermark
WATERMARK_PATH = os.path.join(CFG["paths"]["state_dir"], "watermark.json")


In [None]:
# ------------------- SPARK SESSION --------------------------


spark = (
    SparkSession.builder
    .master(CFG["spark"]["master"])
    .appName(CFG["spark"]["app_name"])
    .config("spark.sql.shuffle.partitions", CFG["spark"]["shuffle_partitions"])
    .getOrCreate()
)

spark.sparkContext.setLogLevel("ERROR")

print("Spark session created!")


In [None]:
# ---------------------------- MOCK DATA GENERATOR -------------------------------


def generate_mock_sales_csv(output_path: str, n_rows: int):
    """
    Generates a mock sales CSV file.

    Why this is useful:
    - Allows a full ETL demo without real sensitive data
    - Enables repeatable interview execution
    - Produces a sequential transaction_id (used as watermark)
    """

    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    base_time = datetime.now() - timedelta(days=10)

    rows = []
    for i in range(n_rows):
        transaction_id = i + 1
        customer_id = random.randint(1000, 5000)
        product_id = random.randint(1, 300)
        quantity = random.randint(1, 10)

        ts = base_time + timedelta(minutes=random.randint(0, 60 * 24 * 10))

        rows.append({
            "transaction_id": transaction_id,
            "customer_id": customer_id,
            "product_id": product_id,
            "quantity": quantity,
            "timestamp": ts.isoformat()
        })

    df = pd.DataFrame(rows)
    df.to_csv(output_path, index=False)

    print(f"Mock CSV generated: {output_path} ({n_rows} rows)")


RAW_CSV_PATH = os.path.join(CFG["paths"]["raw_data_dir"], "sales1.csv")

# Generate data only if file does not exist (avoid overwriting)
if not os.path.exists(RAW_CSV_PATH):
    generate_mock_sales_csv(RAW_CSV_PATH, n_rows=500)
else:
    print(f"CSV already exists: {RAW_CSV_PATH}")

In [None]:
# ------------------ EXTRACT -----------------------------


def extract_sales_csv(spark, path: str):
    """
    Reads the raw CSV into a Spark DataFrame.

    - header=True: reads column names
    - inferSchema=True: Spark infers types (good for demo, not for prod)
    """

    df = (
        spark.read
        .option("header", True)
        .option("inferSchema", True)
        .csv(path)
    )

    # Select only the expected columns (ignore extra columns if any)
    df = df.select(
        F.col("transaction_id"),
        F.col("customer_id"),
        F.col("product_id"),
        F.col("quantity"),
        F.col("timestamp")
    )

    return df


raw_df = extract_sales_csv(spark, RAW_CSV_PATH)

print("Raw rows:", raw_df.count())
raw_df.show(10, truncate=False)


In [None]:
# ------------------------ TRANSFORM (DQ + schema mapping) --------------------------


def transform_sales(df):
    """
    Applies:
    - schema enforcement
    - data quality filters
    - outlier filtering
    - adds sale_date
    """

    # Schema evolution: ensure required columns exist
    required_cols = ["transaction_id", "customer_id", "product_id", "quantity", "timestamp"]

    for c in required_cols:
        if c not in df.columns:
            df = df.withColumn(c, F.lit(None).cast(StringType()))

    # Casts
    df = (
        df.withColumn("transaction_id", F.col("transaction_id").cast(IntegerType()))
          .withColumn("customer_id", F.col("customer_id").cast(IntegerType()))
          .withColumn("product_id", F.col("product_id").cast(IntegerType()))
          .withColumn("quantity", F.col("quantity").cast(IntegerType()))
          .withColumn("timestamp", F.to_timestamp("timestamp"))
    )

    # Data quality: remove nulls on critical columns
    df = df.filter(F.col("transaction_id").isNotNull())
    df = df.filter(F.col("customer_id").isNotNull())
    df = df.filter(F.col("product_id").isNotNull())
    df = df.filter(F.col("quantity").isNotNull())
    df = df.filter(F.col("timestamp").isNotNull())

    # Outlier handling (business rule)
    df = df.filter((F.col("quantity") > 0) & (F.col("quantity") <= 100))

    # Derive sale_date (common analytical column)
    df = df.withColumn("sale_date", F.to_date("timestamp"))

    # Final schema
    df = df.select(
        "transaction_id",
        "customer_id",
        "product_id",
        "quantity",
        "sale_date"
    )

    return df


clean_df = transform_sales(raw_df)

print("Clean rows:", clean_df.count())
clean_df.show(10, truncate=False)


In [None]:
# --------------------------- ENCRYPTION (HASH) ------------------------------

def encrypt_ids(df):
    """
    Hashes customer_id and product_id using SHA-256.

    Why:
    - protects sensitive identifiers
    - still allows joins and analytics (consistent hash)

    Note:
    - This is hashing, not reversible encryption.
    """

    return (
        df.withColumn("customer_id_hash", F.sha2(F.col("customer_id").cast("string"), 256))
          .withColumn("product_id_hash", F.sha2(F.col("product_id").cast("string"), 256))
    )


# Show before hashing
print("Before hashing:")
clean_df.select("customer_id", "product_id").show(5)

hashed_df = encrypt_ids(clean_df)

print("After hashing:")
hashed_df.select("customer_id_hash", "product_id_hash").show(5, truncate=False)


In [None]:
# ------------------- WATERMARK (incremental state) -------------------------


def read_watermark(path):
    """
    Reads last_transaction_id from watermark JSON file.
    Returns None if file doesn't exist.
    """
    if not os.path.exists(path):
        return None

    with open(path, "r") as f:
        return json.load(f).get("last_transaction_id")


def write_watermark(path, last_id):
    """
    Writes last_transaction_id to watermark JSON file.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)

    with open(path, "w") as f:
        json.dump({"last_transaction_id": int(last_id)}, f)


In [None]:
# -------------------- ETL RUN LOGGING INTO POSTGRES --------------------


def log_etl_run_to_postgres(
    spark,
    cfg,
    run_id: str,
    status: str,
    row_count: int,
    started_at: str,
    finished_at: str,
    error_message: str = None
):
    """
    Writes ETL execution logs into Postgres table.

    This is a very strong feature for interviews because it shows:
    - observability
    - operational maturity
    """

    log_df = spark.createDataFrame([{
        "run_id": run_id,
        "status": status,
        "row_count": int(row_count),
        "started_at": started_at,
        "finished_at": finished_at,
        "error_message": error_message
    }])

    props = {
        "user": cfg["db"]["user"],
        "password": cfg["db"]["password"],
        "driver": cfg["db"]["driver"]
    }

    log_df.write.jdbc(
        url=cfg["db"]["url"],
        table=cfg["db"]["log_table"],
        mode="append",
        properties=props
    )


In [None]:
# ------------------- LOAD STEP ------------------- 

def load_sales_incremental(df, cfg, watermark_path):
    """
    Loads the final dataset incrementally into Postgres.

    Incremental logic:
    - If watermark exists: load only transaction_id > watermark
    - Else: full load (first run)

    Also writes the final dataset into partitioned Parquet.
    """

    last_id = read_watermark(watermark_path)

    if last_id:
        print(f"Incremental mode enabled. last_transaction_id={last_id}")
        df_to_load = df.filter(F.col("transaction_id") > F.lit(last_id))
    else:
        print("No watermark found. Running full load (first run).")
        df_to_load = df

    # If no new rows, exit early
    if df_to_load.rdd.isEmpty():
        print("No new rows to load. Skipping load.")
        return 0

    # Hash IDs before loading
    df_to_load = encrypt_ids(df_to_load)

    # Repartition (demo parallelism)
    df_to_load = df_to_load.repartition(4)

    props = {
        "user": cfg["db"]["user"],
        "password": cfg["db"]["password"],
        "driver": cfg["db"]["driver"]
    }

    
    # 1) Write to Postgres
    
    df_to_load.write.jdbc(
        url=cfg["db"]["url"],
        table=cfg["db"]["table"],
        mode="append",
        properties=props
    )

    
    # 2) Write to Parquet (partitioned)
    
    parquet_path = cfg["paths"]["parquet_output_dir"]

    (
        df_to_load
        .write
        .mode("append")
        .partitionBy("sale_date")  # partition strategy for analytics
        .parquet(parquet_path)
    )

    # Update watermark
    max_id = df_to_load.agg(F.max("transaction_id").alias("max_id")).collect()[0]["max_id"]
    if max_id:
        write_watermark(watermark_path, max_id)
        print(f"Watermark updated: last_transaction_id={max_id}")

    return df_to_load.count()


In [None]:
# ==============================
# RUN ETL PIPELINE
# ==============================

run_id = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
started_at = datetime.now().isoformat()

try:
    print(f"Starting ETL run_id={run_id}")

    # Extract
    raw_df = extract_sales_csv(spark, RAW_CSV_PATH)

    # Transform
    clean_df = transform_sales(raw_df)

    # Load (incremental)
    loaded_rows = load_sales_incremental(clean_df, CFG, WATERMARK_PATH)

    status = "SUCCESS"
    error_message = None

except Exception as e:
    status = "FAILED"
    loaded_rows = 0
    error_message = str(e)
    raise

finally:
    finished_at = datetime.now().isoformat()

    # Log run into Postgres
    log_etl_run_to_postgres(
        spark=spark,
        cfg=CFG,
        run_id=run_id,
        status=status,
        row_count=loaded_rows,
        started_at=started_at,
        finished_at=finished_at,
        error_message=error_message
    )

    print(f"ETL finished with status={status}, loaded_rows={loaded_rows}")


In [None]:
# ==============================
# READ FINAL TABLE FROM POSTGRES
# ==============================

props = {
    "user": CFG["db"]["user"],
    "password": CFG["db"]["password"],
    "driver": CFG["db"]["driver"]
}

sales_db_df = spark.read.jdbc(
    url=CFG["db"]["url"],
    table=CFG["db"]["table"],
    properties=props
)

print("Rows in Postgres table:", sales_db_df.count())
sales_db_df.show(10, truncate=False)

In [None]:
# ==============================
# READ ETL LOG TABLE
# ==============================

log_db_df = spark.read.jdbc(
    url=CFG["db"]["url"],
    table=CFG["db"]["log_table"],
    properties=props
)

log_db_df.orderBy(F.col("started_at").desc()).show(20, truncate=False)

In [None]:
# ==============================
# READ PARQUET OUTPUT
# ==============================

parquet_path = CFG["paths"]["parquet_output_dir"]

parquet_df = spark.read.parquet(parquet_path)

print("Rows in Parquet:", parquet_df.count())
parquet_df.show(10, truncate=False)

print("Partitions available (sale_date):")
parquet_df.select("sale_date").distinct().orderBy("sale_date").show(20, truncate=False)

In [None]:
# ==============================
# SIMPLE ANALYTICS (Spark)
# ==============================

# Top products by total quantity
top_products = (
    parquet_df.groupBy("product_id_hash")
    .agg(F.sum("quantity").alias("total_quantity"))
    .orderBy(F.desc("total_quantity"))
)

top_products.show(10, truncate=False)

# Sales per day
sales_per_day = (
    parquet_df.groupBy("sale_date")
    .agg(
        F.count("*").alias("transactions"),
        F.sum("quantity").alias("total_quantity")
    )
    .orderBy("sale_date")
)

sales_per_day.show(20, truncate=False)


In [None]:
# Future Improvements (Production-Ready Ideas)

### 1) Config + Validation
- Use Pydantic to validate config types
- Separate configs per environment (dev/uat/prod)

### 2) Better Incremental Strategy
- Use timestamp watermark (event time) instead of transaction_id
- Support upserts (merge) instead of append-only

### 3) Observability
- Add metrics (records read, records dropped, invalid rows)
- Send logs to a monitoring tool (Datadog, Prometheus, ELK)

### 4) Security
- Use Secrets Manager / Key Vault
- Use proper encryption (KMS) instead of only hashing

### 5) Data Lake Layout
- Use a layered architecture:
  - bronze/raw
  - silver/clean
  - gold/analytics


In [5]:
props = {
    "user": CFG["db"]["user"],
    "password": CFG["db"]["password"],
    "driver": CFG["db"]["driver"]
}

test_df = spark.read.jdbc(
    url=CFG["db"]["url"],
    table="(select 1 as ok) t",
    properties=props
)

test_df.show()

Py4JJavaError: An error occurred while calling o69.jdbc.
: java.lang.ClassNotFoundException: org.postgresql.Driver
	at java.net.URLClassLoader.findClass(URLClassLoader.java:387)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:418)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:351)
	at org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry$.register(DriverRegistry.scala:46)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.$anonfun$driverClass$1(JDBCOptions.scala:103)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.$anonfun$driverClass$1$adapted(JDBCOptions.scala:103)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.<init>(JDBCOptions.scala:103)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.<init>(JDBCOptions.scala:41)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider.createRelation(JdbcRelationProvider.scala:34)
	at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:346)
	at org.apache.spark.sql.DataFrameReader.loadV1Source(DataFrameReader.scala:229)
	at org.apache.spark.sql.DataFrameReader.$anonfun$load$2(DataFrameReader.scala:211)
	at scala.Option.getOrElse(Option.scala:189)
	at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:211)
	at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:172)
	at org.apache.spark.sql.DataFrameReader.jdbc(DataFrameReader.scala:249)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Thread.java:750)
