# Sales ETL - End-to-End PySpark Demo 

This notebook demonstrates an end-to-end ETL pipeline using PySpark.

Pipeline features:
- Generate mock sales data
- Extract from CSV
- Data quality checks + schema mapping
- Hash sensitive IDs with SHA-256 (one-way hashing, not reversible encryption)
- Incremental load using watermark
- Write final table into Postgres
- Save output into Parquet (partitioned)
- Log ETL runs into Postgres table

In [1]:
# ---------------- Imports + Setup -----------------------------

import os
import json
import time
import random
import logging
import pandas as pd

from datetime import datetime, timedelta

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import (
    StructType, StructField,
    StringType, IntegerType, TimestampType
)

In [2]:
import os

# Put the Postgres JDBC driver on the classpath of the JVM that PySpark launches.
# The jar is resolved relative to the working directory (the project root, like
# the data/ paths used below) instead of a user-specific absolute path, so the
# notebook runs on any machine. This must be set before the first SparkSession
# is created.
os.environ["PYSPARK_SUBMIT_ARGS"] = (
    f"--jars {os.path.abspath('jars/postgresql-42.7.3.jar')} pyspark-shell"
)

In [3]:
from pyspark.sql import SparkSession
# NOTE(review): this starts a session with default settings. getOrCreate()
# reuses a live session, so the configured builder in the "SPARK SESSION" cell
# below only logs "Using an existing Spark session; only runtime SQL
# configurations will take effect" - its master/appName/spark.jars settings
# are not applied. Consider removing this early session.
spark = SparkSession.builder.appName("ETL").getOrCreate()

26/02/10 13:37:35 WARN Utils: Your hostname, Thiagos-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 192.168.1.11 instead (on interface en0)
26/02/10 13:37:35 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
26/02/10 13:37:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/02/10 13:37:38 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
26/02/10 13:37:38 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


In [4]:
#-------------------- CONFIGURATION ---------------------------------


# Central configuration. Each section is built with keyword arguments; the
# resulting nested mapping is what later cells read through CFG[section][key].
CFG = dict(
    app=dict(name="sales_etl_demo"),
    paths=dict(
        raw_data_dir="data/raw",
        state_dir="data/state",
        parquet_output_dir="data/output/sales_parquet",
    ),
    spark=dict(
        master="local[*]",
        app_name="sales_etl_spark_demo",
        shuffle_partitions=4,
    ),
    db=dict(
        url="jdbc:postgresql://localhost:5432/sales_db",
        user="sales_user",
        # Taken from the DB_PASSWORD environment variable; the literal is only
        # the fallback used when that variable is not set.
        password=os.getenv("DB_PASSWORD", "sales_pass"),
        driver="org.postgresql.Driver",
        table="sales",
        log_table="etl_run_log",
    ),
)

# JSON file holding the incremental-load watermark
WATERMARK_PATH = os.path.join(CFG["paths"]["state_dir"], "watermark.json")


In [5]:
# ------------------- SPARK SESSION --------------------------

# Resolve the driver jar relative to the working directory so the notebook is
# portable. JDBC_DRIVER_PATH used to point at a user-specific absolute path and
# a different driver version (42.6.0) than the jar resolved here, while
# postgres_jar was computed but never used.
postgres_jar = os.path.abspath("jars/postgresql-42.7.3.jar")
JDBC_DRIVER_PATH = postgres_jar  # kept as an alias in case other cells use it

# Build (or reuse) the session from the values in CFG["spark"].
spark = (
    SparkSession.builder
    .master(CFG["spark"]["master"])
    .appName(CFG["spark"]["app_name"])
    .config("spark.jars", JDBC_DRIVER_PATH)
    .config("spark.sql.shuffle.partitions", CFG["spark"]["shuffle_partitions"])
    .getOrCreate()
)

# Keep the notebook output readable: only errors from Spark's own logging
spark.sparkContext.setLogLevel("ERROR")
print("Spark created with Postgres JDBC driver!")
# Show which jar(s) the running context actually picked up
print("spark.jars =", spark.sparkContext.getConf().get("spark.jars"))


Spark created with Postgres JDBC driver!
spark.jars = file:///Users/thiagomoura/Documents/projetos/sales-etl-pyspark/jars/postgresql-42.7.3.jar


26/02/10 13:37:41 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [6]:
# ---------------------------- MOCK DATA GENERATOR -------------------------------


def generate_mock_sales_csv(output_path: str, n_rows: int, seed: int = None):
    """
    Generates a mock sales CSV file.

    Why this is useful:
    - Allows a full ETL demo without real sensitive data
    - Enables repeatable demo runs
    - Produces a sequential transaction_id (used as watermark)

    Args:
        output_path: Destination CSV path. Missing parent directories are
            created; a bare filename (no directory part) is also accepted.
        n_rows: Number of rows to generate. With 0 the file still gets the
            header row, so downstream readers find the expected columns.
        seed: Optional seed for the random columns. When None (default) the
            module-level ``random`` state is used, as before. Timestamps are
            offsets from ``datetime.now()``, so they differ between runs even
            with a seed.

    Columns written: transaction_id (1..n_rows), customer_id (1000-5000),
    product_id (1-300), quantity (1-10), timestamp (ISO-8601, within the
    last 10 days).
    """
    columns = ["transaction_id", "customer_id", "product_id", "quantity", "timestamp"]

    # os.makedirs("") raises, so only create a directory when the path has one
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # A private generator keeps a seeded run from disturbing global random state
    rng = random.Random(seed) if seed is not None else random

    base_time = datetime.now() - timedelta(days=10)
    minutes_in_window = 60 * 24 * 10  # 10 days expressed in minutes

    rows = []
    for transaction_id in range(1, n_rows + 1):
        customer_id = rng.randint(1000, 5000)
        product_id = rng.randint(1, 300)
        quantity = rng.randint(1, 10)
        ts = base_time + timedelta(minutes=rng.randint(0, minutes_in_window))

        rows.append({
            "transaction_id": transaction_id,
            "customer_id": customer_id,
            "product_id": product_id,
            "quantity": quantity,
            "timestamp": ts.isoformat()
        })

    # Explicit columns guarantee a header even when no rows were generated
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(output_path, index=False)

    print(f"Mock CSV generated: {output_path} ({n_rows} rows)")


RAW_CSV_PATH = os.path.join(CFG["paths"]["raw_data_dir"], "sales1.csv")

# An existing file is left untouched; mock data is only generated when the
# file is missing.
if os.path.exists(RAW_CSV_PATH):
    print(f"CSV already exists: {RAW_CSV_PATH}")
else:
    generate_mock_sales_csv(RAW_CSV_PATH, n_rows=500)

CSV already exists: data/raw/sales1.csv


In [7]:
# ------------------ EXTRACT -----------------------------


def extract_sales_csv(spark, path: str):
    """
    Load the raw sales CSV at ``path`` into a Spark DataFrame.

    The file is read with a header row and with type inference enabled
    (convenient for a demo; prefer an explicit schema in production).
    Only the five expected columns are returned, in a fixed order; any
    extra columns present in the file are dropped.
    """
    expected_columns = (
        "transaction_id",
        "customer_id",
        "product_id",
        "quantity",
        "timestamp",
    )

    reader = spark.read.option("header", True).option("inferSchema", True)
    csv_df = reader.csv(path)

    # Project down to the known columns
    return csv_df.select(*(F.col(name) for name in expected_columns))


# Read the raw file so its row count and a sample can be inspected
raw_df = extract_sales_csv(spark, RAW_CSV_PATH)

# Row count plus the first 10 rows, with untruncated cell values
print("Raw rows:", raw_df.count())
raw_df.show(10, truncate=False)


                                                                                

Raw rows: 500
+--------------+-----------+----------+--------+--------------------------+
|transaction_id|customer_id|product_id|quantity|timestamp                 |
+--------------+-----------+----------+--------+--------------------------+
|1             |4115       |128       |8       |2026-02-03 10:50:58.620735|
|2             |4632       |115       |8       |2026-02-03 06:11:58.620735|
|3             |1696       |106       |7       |2026-02-03 22:34:58.620735|
|4             |4920       |100       |6       |2026-02-04 19:56:58.620735|
|5             |1304       |208       |9       |2026-02-06 06:35:58.620735|
|6             |2778       |255       |5       |2026-02-07 19:22:58.620735|
|7             |1005       |155       |10      |2026-02-01 17:19:58.620735|
|8             |4834       |277       |8       |2026-02-08 04:53:58.620735|
|9             |4000       |128       |2       |2026-02-05 16:18:58.620735|
|10            |1633       |246       |5       |2026-02-08 13:57:58.620735

In [8]:
# ------------------------ TRANSFORM (DQ + schema mapping) --------------------------


def transform_sales(df):
    """
    Clean the raw sales DataFrame and project it to the final column set.

    Steps, in order:
    - add any missing required column as an all-null string column
    - cast the ID and quantity columns to integer and parse ``timestamp``
    - drop rows where any required column is null (after casting)
    - keep only rows with 0 < quantity <= 100
    - derive ``sale_date`` from ``timestamp``

    Returns a DataFrame with transaction_id, customer_id, product_id,
    quantity and sale_date (``timestamp`` itself is not kept).
    """
    integer_cols = ["transaction_id", "customer_id", "product_id", "quantity"]
    required_cols = integer_cols + ["timestamp"]

    # Schema evolution: a column absent from the input is added as nulls,
    # which means the null filter further down removes every row.
    for name in required_cols:
        if name not in df.columns:
            df = df.withColumn(name, F.lit(None).cast(StringType()))

    # Type enforcement
    for name in integer_cols:
        df = df.withColumn(name, F.col(name).cast(IntegerType()))
    df = df.withColumn("timestamp", F.to_timestamp("timestamp"))

    # Data quality: every required column must be populated
    for name in required_cols:
        df = df.filter(F.col(name).isNotNull())

    # Outlier handling (business rule): quantity within (0, 100]
    quantity = F.col("quantity")
    df = df.filter((quantity > 0) & (quantity <= 100))

    # Add the calendar date and keep only the final schema
    return (
        df.withColumn("sale_date", F.to_date("timestamp"))
          .select("transaction_id", "customer_id", "product_id", "quantity", "sale_date")
    )


# Apply casting, null/outlier filtering and the sale_date derivation
clean_df = transform_sales(raw_df)

# Compare this count with "Raw rows" above to see how many rows were dropped
print("Clean rows:", clean_df.count())
clean_df.show(10, truncate=False)


Clean rows: 500
+--------------+-----------+----------+--------+----------+
|transaction_id|customer_id|product_id|quantity|sale_date |
+--------------+-----------+----------+--------+----------+
|1             |4115       |128       |8       |2026-02-03|
|2             |4632       |115       |8       |2026-02-03|
|3             |1696       |106       |7       |2026-02-03|
|4             |4920       |100       |6       |2026-02-04|
|5             |1304       |208       |9       |2026-02-06|
|6             |2778       |255       |5       |2026-02-07|
|7             |1005       |155       |10      |2026-02-01|
|8             |4834       |277       |8       |2026-02-08|
|9             |4000       |128       |2       |2026-02-05|
|10            |1633       |246       |5       |2026-02-08|
+--------------+-----------+----------+--------+----------+
only showing top 10 rows



In [9]:
# --------------------------- ENCRYPTION (HASH) ------------------------------

def encrypt_ids(df, salt: str = ""):
    """
    Adds SHA-256 hashes of customer_id and product_id as new columns.

    Why:
    - pseudonymizes identifiers
    - still allows joins and analytics (same input -> same hash)

    Args:
        df: DataFrame containing ``customer_id`` and ``product_id``.
        salt: Optional secret prefix mixed into every value before hashing.
            With the default (empty) the hashes are identical to the plain
            unsalted SHA-256 this function has always produced.

    Returns:
        ``df`` with two extra columns, ``customer_id_hash`` and
        ``product_id_hash``. The original ID columns are kept.

    Note:
    - This is hashing, not reversible encryption.
    - Unsalted hashes of small integer IDs can be recovered by hashing every
      candidate value; pass a secret ``salt`` when that matters.
    """

    def _sha256_of(column_name):
        # Cast to string first so numeric IDs hash by their decimal text
        value = F.col(column_name).cast("string")
        if salt:
            value = F.concat(F.lit(salt), value)
        return F.sha2(value, 256)

    return (
        df.withColumn("customer_id_hash", _sha256_of("customer_id"))
          .withColumn("product_id_hash", _sha256_of("product_id"))
    )


# Show the plain integer IDs before hashing
print("Before hashing:")
clean_df.select("customer_id", "product_id").show(5)

# encrypt_ids adds the *_hash columns next to the existing ID columns
hashed_df = encrypt_ids(clean_df)

# truncate=False so the full 64-character hex digests are visible
print("After hashing:")
hashed_df.select("customer_id_hash", "product_id_hash").show(5, truncate=False)


Before hashing:
+-----------+----------+
|customer_id|product_id|
+-----------+----------+
|       4115|       128|
|       4632|       115|
|       1696|       106|
|       4920|       100|
|       1304|       208|
+-----------+----------+
only showing top 5 rows

After hashing:
+----------------------------------------------------------------+----------------------------------------------------------------+
|customer_id_hash                                                |product_id_hash                                                 |
+----------------------------------------------------------------+----------------------------------------------------------------+
|481d60bc802f51580e81af5dc1c0534d6eb255fcbb82bdaf646e8549b7cce4f3|2747b7c718564ba5f066f0523b03e17f6a496b06851333d2d59ab6d863225848|
|9509368b9a5172b8d96f18644356e636b4999607ec09c62b6d92d365169cedfa|28dae7c8bde2f3ca608f86d0e16a214dee74c74bee011cdfdd46bc04b655bc14|
|e64474fd91f16a0891fb1de23dff06e8bd0a0ee495f1add2ec08a07f1f

In [10]:
# ------------------- WATERMARK (incremental state) -------------------------


def read_watermark(path):
    """
    Return the stored ``last_transaction_id`` from the watermark JSON file.

    Returns None when the file does not exist, or when the JSON object has
    no ``last_transaction_id`` key.
    """
    if os.path.exists(path):
        with open(path, "r") as handle:
            state = json.load(handle)
        return state.get("last_transaction_id")

    return None


def write_watermark(path, last_id):
    """
    Writes last_transaction_id to the watermark JSON file.

    Args:
        path: Destination file. Missing parent directories are created; a
            bare filename (no directory part) is also accepted.
        last_id: Value convertible with ``int()``.

    Raises:
        ValueError / TypeError: if ``last_id`` cannot be converted to int.
            The existing watermark file is left untouched in that case.
    """
    # Validate before touching the file so a bad value cannot truncate it
    payload = {"last_transaction_id": int(last_id)}

    # os.makedirs("") raises, so only create a directory when the path has one
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Write to a sibling temp file and swap it in, so an interrupted write
    # never leaves a half-written watermark behind.
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(payload, f)
    os.replace(tmp_path, path)


In [11]:
# -------------------- ETL RUN LOGGING INTO POSTGRES --------------------
from datetime import datetime
from pyspark.sql.types import (
    StructType, StructField,
    StringType, IntegerType, TimestampType
)

def log_etl_run_to_postgres(
    spark,
    cfg,
    run_id: str,
    status: str,
    row_count: int,
    started_at,
    finished_at,
    error_message: str = None
):
    """
    Appends one ETL execution record to the run-log table in Postgres.

    Args:
        spark: Active SparkSession used to build and write the one-row DataFrame.
        cfg: Configuration mapping; reads cfg["db"] url/user/password/driver
            and cfg["db"]["log_table"] (falls back to "etl_run_log" if absent).
        run_id: Identifier of the run.
        status: Run outcome label (the pipeline cell uses "SUCCESS"/"FAILED").
        row_count: Rows loaded; None or any falsy value is stored as 0.
        started_at / finished_at: datetime.datetime objects, or ISO-8601
            strings, which are parsed with datetime.fromisoformat.
        error_message: Optional error text; stored as NULL when None.
    """

    # Accept ISO strings as well as datetime objects
    if isinstance(started_at, str):
        started_at = datetime.fromisoformat(started_at)
    if isinstance(finished_at, str):
        finished_at = datetime.fromisoformat(finished_at)

    # Explicit schema so types do not depend on inference from a single row
    log_schema = StructType([
        StructField("run_id", StringType(), False),
        StructField("status", StringType(), False),
        StructField("row_count", IntegerType(), True),
        StructField("started_at", TimestampType(), True),
        StructField("finished_at", TimestampType(), True),
        StructField("error_message", StringType(), True),
    ])

    # Ensure row_count is int
    row_count = int(row_count or 0)

    # One-row DataFrame describing this run
    log_df = spark.createDataFrame([{
        "run_id": run_id,
        "status": status,
        "row_count": row_count,
        "started_at": started_at,
        "finished_at": finished_at,
        "error_message": error_message
    }], schema=log_schema)

    # JDBC properties
    props = {
        "user": cfg["db"]["user"],
        "password": cfg["db"]["password"],
        "driver": cfg["db"]["driver"]
    }

    # Use the configured table name so the writer and the reader (which reads
    # cfg["db"]["log_table"]) cannot drift apart; the literal is the old
    # hardcoded value, kept as fallback for configs without the key.
    log_table = cfg["db"].get("log_table", "etl_run_log")

    # Write to Postgres
    log_df.write.jdbc(
        url=cfg["db"]["url"],
        table=log_table,
        mode="append",
        properties=props
    )

In [12]:
# ------------------- LOAD STEP ------------------- 

def load_sales_incremental(df, cfg, watermark_path):
    """
    Loads the final dataset incrementally into Postgres and Parquet.

    Incremental logic:
    - If a watermark exists: load only transaction_id > watermark
    - Else: full load (first run)

    Args:
        df: Cleaned sales DataFrame (must contain ``transaction_id`` and
            ``sale_date`` plus the ID columns that encrypt_ids hashes).
        cfg: Configuration mapping; reads cfg["db"] (url, user, password,
            driver, table) and cfg["paths"]["parquet_output_dir"].
        watermark_path: JSON state file read/written via read_watermark /
            write_watermark.

    Returns:
        Number of rows loaded in this run (0 when there was nothing new).

    Notes:
    - The watermark is only advanced after both writes succeeded. If the
      Parquet write fails after the Postgres append, a re-run appends the
      same rows to Postgres again.
    - NOTE(review): encrypt_ids() adds *_hash columns and keeps the raw
      customer_id / product_id, so both are written. Confirm this is intended
      and that it matches the target table's columns.
    """

    last_id = read_watermark(watermark_path)

    # `is not None` so that a stored watermark of 0 is still honoured
    if last_id is not None:
        print(f"Incremental mode enabled. last_transaction_id={last_id}")
        df_to_load = df.filter(F.col("transaction_id") > F.lit(last_id))
    else:
        print("No watermark found. Running full load (first run).")
        df_to_load = df

    # Hash IDs, fix the partitioning (demo parallelism) and cache: the batch
    # is consumed by several actions below and should be computed only once.
    df_to_load = encrypt_ids(df_to_load).repartition(4).cache()

    try:
        # One pass yields both the row count and the new watermark candidate
        stats = df_to_load.agg(
            F.count(F.lit(1)).alias("row_count"),
            F.max("transaction_id").alias("max_id"),
        ).collect()[0]
        row_count = stats["row_count"]
        max_id = stats["max_id"]

        # If no new rows, exit early
        if row_count == 0:
            print("No new rows to load. Skipping load.")
            return 0

        props = {
            "user": cfg["db"]["user"],
            "password": cfg["db"]["password"],
            "driver": cfg["db"]["driver"]
        }

        # 1) Write to Postgres
        df_to_load.write.jdbc(
            url=cfg["db"]["url"],
            table=cfg["db"]["table"],
            mode="append",
            properties=props
        )

        # 2) Write to Parquet (partitioned)
        parquet_path = cfg["paths"]["parquet_output_dir"]

        (
            df_to_load
            .write
            .mode("append")
            .partitionBy("sale_date")  # partition strategy for analytics
            .parquet(parquet_path)
        )

        # Update watermark (explicit None check: an id of 0 is a valid value)
        if max_id is not None:
            write_watermark(watermark_path, max_id)
            print(f"Watermark updated: last_transaction_id={max_id}")

        return row_count
    finally:
        # Release the cached batch whether the load succeeded or failed
        df_to_load.unpersist()


In [13]:
# ==============================
# RUN ETL PIPELINE
# ==============================

run_id = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
started_at = datetime.now().isoformat()

# Pessimistic defaults: the `finally` block below always has values to log,
# even when the run is interrupted by something that is not an `Exception`
# (e.g. KeyboardInterrupt), which previously caused a NameError there.
status = "FAILED"
loaded_rows = 0
error_message = None

try:
    print(f"Starting ETL run_id={run_id}")

    # Extract
    raw_df = extract_sales_csv(spark, RAW_CSV_PATH)

    # Transform
    clean_df = transform_sales(raw_df)

    # Load (incremental)
    loaded_rows = load_sales_incremental(clean_df, CFG, WATERMARK_PATH)

    status = "SUCCESS"

except Exception as e:
    error_message = str(e)
    raise

finally:
    finished_at = datetime.now().isoformat()

    # Log run into Postgres
    try:
        log_etl_run_to_postgres(
            spark=spark,
            cfg=CFG,
            run_id=run_id,
            status=status,
            row_count=loaded_rows,
            started_at=started_at,
            finished_at=finished_at,
            error_message=error_message
        )
    except Exception as log_error:
        # A successful run whose log write fails still surfaces that failure.
        # For a failed run the ETL error is the one that matters, so the
        # logging problem is only reported and the original error propagates.
        if status == "SUCCESS":
            raise
        print(f"WARNING: could not write ETL run log: {log_error}")

    print(f"ETL finished with status={status}, loaded_rows={loaded_rows}")


Starting ETL run_id=run_20260210_133756
Incremental mode enabled. last_transaction_id=50000


                                                                                

No new rows to load. Skipping load.


[Stage 15:>                                                         (0 + 8) / 8]

ETL finished with status=SUCCESS, loaded_rows=0


                                                                                

In [14]:
# ==============================
# READ FINAL TABLE FROM POSTGRES
# ==============================

# JDBC connection properties, taken from the shared configuration
props = {
    "user": CFG["db"]["user"],
    "password": CFG["db"]["password"],
    "driver": CFG["db"]["driver"]
}

# Read the sales table (name from CFG["db"]["table"]) back from Postgres
sales_db_df = spark.read.jdbc(
    url=CFG["db"]["url"],
    table=CFG["db"]["table"],
    properties=props
)

# Total row count plus the first 10 rows, with untruncated values
print("Rows in Postgres table:", sales_db_df.count())
sales_db_df.show(10, truncate=False)

Rows in Postgres table: 50000
+--------------+----------------------------------------------------------------+----------------------------------------------------------------+--------+----------+
|transaction_id|customer_id                                                     |product_id                                                      |quantity|sale_date |
+--------------+----------------------------------------------------------------+----------------------------------------------------------------+--------+----------+
|41696         |75ab3e0a8f99144f9c4420698980b6f76f85996b520937c7043a29e18527bfef|a68b412c4282555f15546cf6e1fc42893b7e07f271557ceb021821098dd66c1b|4       |2026-01-28|
|932           |96382608813353bee4eeaf0635a3b3356276fece94fe8c9cf048871078f8fd14|44cb730c420480a0477b505ae68af508fb90f96cf0ec54c6ad16949dd427f13a|6       |2026-02-02|
|1592          |15b9e0db83ca5103d0d81f272584932103478a4850cb788ee4fff20b7ab5c5ba|768b84ef05f655d57fe22d488451f075365f6cd18a13073466aa82

In [15]:
# ==============================
# READ ETL LOG TABLE
# ==============================

# Reuses `props`, the JDBC properties dict defined in the previous cell
log_db_df = spark.read.jdbc(
    url=CFG["db"]["url"],
    table=CFG["db"]["log_table"],
    properties=props
)

# Most recent runs first (up to 20 rows)
log_db_df.orderBy(F.col("started_at").desc()).show(20, truncate=False)

+-------------------+-------+---------+-------------------------+--------------------------+-------------+
|run_id             |status |row_count|started_at               |finished_at               |error_message|
+-------------------+-------+---------+-------------------------+--------------------------+-------------+
|run_20260210_133756|SUCCESS|0        |2026-02-10 13:37:56.66402|2026-02-10 13:38:00.999985|null         |
+-------------------+-------+---------+-------------------------+--------------------------+-------------+



In [26]:
# ==========================================
# SAVE FINAL DF TO PARQUET (PARTITIONED)
# ==========================================

import os

# Output paths for Parquet (from CFG).
# The log snapshot needs its own directory. It used to be looked up under the
# same "parquet_output_dir" key as the sales data, so both paths were identical:
# the sales overwrite below destroyed the log files and the read-back of
# parquet_path_log returned sales rows.
parquet_path = CFG["paths"].get("parquet_output_dir", "data/output/sales_parquet")
parquet_path_log = CFG["paths"].get("parquet_log_output_dir", "data/output/sales_parquet_log")

# Create folder if it does not exist
os.makedirs(parquet_path, exist_ok=True)

# Save the tables read back from Postgres: the run log as-is, the sales table
# partitioned by sale_date.
# NOTE(review): "overwrite" replaces everything under parquet_path, including
# files that load_sales_incremental appended to the same directory.
log_db_df.write.mode("overwrite").parquet(parquet_path_log)
sales_db_df.write.mode("overwrite").partitionBy("sale_date").parquet(parquet_path)

print(f"Parquet successfully saved at: {parquet_path}")

# ==========================================
# READ SAVED PARQUET
# ==========================================

# Read the Parquet back for verification
parquet_df_log = spark.read.parquet(parquet_path_log)
parquet_df = spark.read.parquet(parquet_path)

print("Number of rows in Parquet:", parquet_df.count())
print("Sample data:")
parquet_df.show(10, truncate=False)

parquet_df_log.show(10, truncate=False)


# ==========================================
# FUTURE IMPROVEMENT
# ==========================================
# 1) Save detailed ETL logs (run_id, status, processed row count) in a separate table.
# 2) Enable Parquet compression (e.g., parquet(..., compression='snappy')) to save space.
# 3) Add incremental control when writing Parquet (append) if ETL runs multiple times.
# 4) Validate schema when reading Parquet: spark.read.schema(schema).parquet(path)
# 5) Consider partitioning by multiple columns for more efficient queries.


                                                                                

Parquet successfully saved at: data/output/sales_parquet
Number of rows in Parquet: 50000
Sample data:
+--------------+----------------------------------------------------------------+----------------------------------------------------------------+--------+----------+
|transaction_id|customer_id                                                     |product_id                                                      |quantity|sale_date |
+--------------+----------------------------------------------------------------+----------------------------------------------------------------+--------+----------+
|41696         |75ab3e0a8f99144f9c4420698980b6f76f85996b520937c7043a29e18527bfef|a68b412c4282555f15546cf6e1fc42893b7e07f271557ceb021821098dd66c1b|4       |2026-01-28|
|20611         |0a6b81782b5c6d04236250f24ed2b3a27c3afd1a80371ea238ba029cbe5aca0d|4523540f1504cd17100c4835e85b7eefd49911580f8efff0599a8f283be6b9e3|10      |2026-01-28|
|41881         |591d48eb061e4e3e4ca2b55451e2353c3922ff23264f22

In [28]:
# ==============================
# SIMPLE ANALYTICS (Spark)
# ==============================

# Total quantity sold per product, largest first
top_products = (
    parquet_df
    .groupBy("product_id")
    .agg(F.sum("quantity").alias("total_quantity"))
    .sort(F.col("total_quantity").desc())
)

top_products.show(10, truncate=False)

# Transaction count and total quantity per calendar day, in date order
sales_per_day = (
    parquet_df
    .groupBy("sale_date")
    .agg(
        F.count("*").alias("transactions"),
        F.sum("quantity").alias("total_quantity"),
    )
    .sort("sale_date")
)

sales_per_day.show(20, truncate=False)


                                                                                

+----------------------------------------------------------------+--------------+
|product_id                                                      |total_quantity|
+----------------------------------------------------------------+--------------+
|16badfc6202cb3f8889e0f2779b19218af4cbb736e56acadce8148aba9a7a9f8|1202          |
|d029fa3a95e174a19934857f535eb9427d967218a36ea014b70ad704bc6c8d1c|1142          |
|011af72a910ac4acf367eef9e6b761e0980842c30d4e9809840f4141d5163ede|1139          |
|f6e0a1e2ac41945a9aa7ff8a8aaa0cebc12a3bcc981a929ad5cf810a090e11ae|1101          |
|f8809aff4d69bece79dabe35be0c708b890d7eafb841f121330667b77d2e2590|1098          |
|37c20f19f3272b5ccc3a5d80587eb9deb3f4afcf568c4280fb195568da8eb1a2|1088          |
|61a229bae1e90331edd986b6bbbe617f7035de88a5bf7c018c3add6c762a6e8d|1083          |
|41e521adf8ae7a0f419ee06e1d9fb794162369237b46f64bf5b2b9969b0bcd2e|1074          |
|2858dcd1057d3eae7f7d5f782167e24b61153c01551450a628cee722509f6529|1072          |
|09895de0407bcb0

[Stage 69:>                                                         (0 + 6) / 6]

+----------+------------+--------------+
|sale_date |transactions|total_quantity|
+----------+------------+--------------+
|2026-01-27|2648        |14569         |
|2026-01-28|5192        |28752         |
|2026-01-29|5024        |27680         |
|2026-01-30|4900        |27007         |
|2026-01-31|5074        |27344         |
|2026-02-01|5015        |27438         |
|2026-02-02|4989        |27177         |
|2026-02-03|4886        |26964         |
|2026-02-04|4856        |26571         |
|2026-02-05|4996        |27525         |
|2026-02-06|2420        |13340         |
+----------+------------+--------------+



                                                                                

# Future Improvements (Production-Ready Ideas)

### 1) Config + Validation
- Use Pydantic to validate config types
- Separate configs per environment (dev/uat/prod)

### 2) Better Incremental Strategy
- Use timestamp watermark (event time) instead of transaction_id
- Support upserts (merge) instead of append-only

### 3) Observability
- Add metrics (records read, records dropped, invalid rows)
- Send logs to a monitoring tool (Datadog, Prometheus, ELK)

### 4) Security
- Use Secrets Manager / Key Vault
- Use proper encryption (KMS) instead of only hashing

### 5) Data Lake Layout
- Use a layered architecture:
  - bronze/raw
  - silver/clean
  - gold/analytics


In [None]:
# JDBC connectivity smoke test: same connection properties as the ETL cells
props = {
    "user": CFG["db"]["user"],
    "password": CFG["db"]["password"],
    "driver": CFG["db"]["driver"]
}

# A parenthesised subquery with an alias is passed where a table name is
# expected, so no real table is needed for the check
test_df = spark.read.jdbc(
    url=CFG["db"]["url"],
    table="(select 1 as ok) t",
    properties=props
)

# Displays the single-row result when the connection and driver work
test_df.show()