In [3]:
# Install required packages if not already installed
%pip install pyspark python-dotenv

# Import required libraries
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *
from pyspark.sql.window import Window
import os
import sys
from dotenv import load_dotenv
from urllib.parse import urlparse

# ---- Helper Functions ----

def parse_connection(conn_str: str):
    """Parse and validate a PostgreSQL connection string, returning JDBC URL and credentials.
    Raises ValueError with informative message if required components are missing."""
    if not conn_str:
        raise ValueError("Empty connection string. Set DB_CONNECTION_STRING in environment.")
    parsed = urlparse(conn_str)
    missing = []
    if not parsed.hostname: missing.append("host")
    if not parsed.port: missing.append("port")
    if not parsed.path or parsed.path == '/': missing.append("database name")
    if not parsed.username: missing.append("username")
    if not parsed.password: missing.append("password")
    if missing:
        raise ValueError(f"Connection string missing required components: {', '.join(missing)}")
    # Build JDBC URL with SSL enforced (Azure Postgres typical requirement)
    jdbc_url = f"jdbc:postgresql://{parsed.hostname}:{parsed.port}{parsed.path}?sslmode=require"
    return jdbc_url, parsed.username, parsed.password

# Load environment variables
load_dotenv()

# Database connection string (avoid hard-coded credentials; fallback only for demo)
DB_CONNECTION = os.getenv(
    'DB_CONNECTION_STRING',
    'postgresql://postgressadmin:wf**F!$3dGdf14@copilot-workshop-db.postgres.database.azure.com:5432/workshop_db'
)

# IMPORTANT: Set up Hadoop for Windows BEFORE creating Spark session
if sys.platform.startswith('win'):
    # Create a minimal Hadoop directory structure for Windows
    hadoop_home = os.path.join(os.path.expanduser('~'), '.hadoop')
    os.makedirs(hadoop_home, exist_ok=True)
    os.makedirs(os.path.join(hadoop_home, 'bin'), exist_ok=True)
    os.environ['HADOOP_HOME'] = hadoop_home
    
    # Download winutils.exe if not present (required for Windows)
    winutils_path = os.path.join(hadoop_home, 'bin', 'winutils.exe')
    if not os.path.exists(winutils_path):
        print("⚠️ winutils.exe not found. Downloading...")
        import urllib.request
        try:
            urllib.request.urlretrieve(
                'https://github.com/steveloughran/winutils/raw/master/hadoop-3.0.0/bin/winutils.exe',
                winutils_path
            )
            print("✅ winutils.exe downloaded successfully!")
        except Exception as e:
            print(f"⚠️ Could not download winutils.exe automatically: {e}")
            print("Please download manually from: https://github.com/steveloughran/winutils")

# Initialize Spark Session with PostgreSQL driver
# Note: The driver will be downloaded on first run, which may take a moment
spark = SparkSession.builder \
    .appName("DataPipelineDebugging") \
    .config("spark.jars.packages", "org.postgresql:postgresql:42.7.3") \
    .config("spark.driver.memory", "4g") \
    .config("spark.sql.shuffle.partitions", "4") \
    .master("local[*]") \
    .getOrCreate()

print("✅ Spark session initialized successfully!")
print(f"Spark version: {spark.version}")
print(f"Running on: {sys.platform}")

# Attempt to parse connection & report errors clearly
try:
    JDBC_URL, DB_USER, DB_PASSWORD = parse_connection(DB_CONNECTION)
    print("✅ Parsed JDBC connection successfully.")
except Exception as e:
    print(f"❌ Failed to parse DB connection string: {e}")
    # Abort early: downstream cells depend on this
    raise

# Shared options dict for JDBC reads
JDBC_OPTIONS_BASE = {
    "url": JDBC_URL,
    "user": DB_USER,
    "password": DB_PASSWORD,
    "driver": "org.postgresql.Driver",
    # Potential tuning knobs (commented for workshop):
    # "fetchsize": 10000,
    # Partitioning options can be added for large tables.
}

def safe_load_table(table_name: str):
    """Load a table via Spark JDBC with error handling.
    Returns DataFrame or empty DataFrame with schema=None on failure."""
    try:
        df = spark.read.format("jdbc") \
            .option("url", JDBC_OPTIONS_BASE["url"]) \
            .option("dbtable", table_name) \
            .option("user", JDBC_OPTIONS_BASE["user"]) \
            .option("password", JDBC_OPTIONS_BASE["password"]) \
            .option("driver", JDBC_OPTIONS_BASE["driver"]) \
            .load()
        count = df.count()  # triggers load; acceptable here for confirmation
        print(f"✅ Loaded {table_name} (rows={count})")
        return df
    except Exception as e:
        print(f"❌ Error loading {table_name}: {e}")
        from pyspark.sql import DataFrame
        # Return an empty DataFrame to allow pipeline to proceed gracefully
        empty_rdd = spark.sparkContext.emptyRDD()
        return spark.createDataFrame(empty_rdd, schema=StructType([]))

# Load required base tables
customers = safe_load_table("raw.customers")
orders = safe_load_table("raw.orders")
order_items = safe_load_table("raw.order_items")
products = safe_load_table("raw.products")

# Optional caching for reuse in multiple analyses
for df_name in ["customers", "orders", "order_items", "products"]:
    df_obj = globals()[df_name]
    if df_obj.head(1):  # only cache if not empty
        globals()[df_name] = df_obj.cache()
        print(f"🗃️ Cached DataFrame: {df_name}")

Note: you may need to restart the kernel to use updated packages.
✅ Spark session initialized successfully!
Spark version: 3.4.1
Running on: win32
✅ Parsed JDBC connection successfully.



[notice] A new release of pip is available: 24.0 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


✅ Loaded raw.customers (rows=1000)
✅ Loaded raw.orders (rows=5000)
✅ Loaded raw.orders (rows=5000)
✅ Loaded raw.order_items (rows=15072)
✅ Loaded raw.products (rows=200)
✅ Loaded raw.products (rows=200)
🗃️ Cached DataFrame: customers
🗃️ Cached DataFrame: customers
🗃️ Cached DataFrame: orders
🗃️ Cached DataFrame: order_items
🗃️ Cached DataFrame: products


# 🐛 Data Pipeline Debugging Exercise

This notebook contains a data pipeline with **several bugs and performance issues**. Your task is to use GitHub Copilot to identify and fix them.

## Your Mission:
Use GitHub Copilot Chat to:
1. Review the code and identify issues
2. Understand what each section is trying to accomplish
3. Fix bugs and optimize performance
4. Add proper error handling and validation

## Hints:
- Try asking Copilot to review specific cells
- Ask about performance optimization
- Request explanations for suspicious code patterns
- Use Copilot to suggest best practices

Good luck! 🚀

## Step 1: Load Data from Database

In [2]:
# Step 1: Load Data from Database (Refactored with error handling)
# Using helper and safe_load_table defined earlier.

required_tables = [
    "raw.customers", "raw.orders", "raw.order_items", "raw.products"
]

loaded = {}
for t in required_tables:
    df = safe_load_table(t)
    loaded[t.split('.')[-1]] = df

customers = loaded["customers"]
orders = loaded["orders"]
order_items = loaded["order_items"]
products = loaded["products"]

# Validate minimal presence of rows (demo criteria)
if customers.rdd.isEmpty():
    print("⚠️ customers table is empty or failed to load; downstream analyses may be limited.")
if orders.rdd.isEmpty():
    print("⚠️ orders table is empty or failed to load; trend and RFM analysis will be skipped.")
if order_items.rdd.isEmpty():
    print("⚠️ order_items table is empty; revenue analysis will be skipped.")
if products.rdd.isEmpty():
    print("⚠️ products table is empty; product enrichment will be skipped.")

Loaded 1000 customers
Loaded 5000 orders
Loaded 15072 order items
Loaded 200 products


## Step 2: Calculate Product Revenue

In [None]:
# Join order items with product information
product_sales = order_items.join(
    products,
    order_items.order_id == products.product_id,
    "inner"
)

# Calculate line total with discount applied
product_sales = product_sales.withColumn(
    "line_total",
    F.col("quantity") * F.col("unit_price") * (1 - F.col("discount_percent"))
)

# Aggregate revenue by product
revenue_by_product = product_sales.groupBy("product_id", "product_name", "category") \
    .agg(
        F.sum("line_total").alias("total_revenue"),
        F.sum("quantity").alias("total_quantity"),
        F.count("order_item_id").alias("num_orders")
    )

print("Top 10 Products by Revenue:")
revenue_by_product.orderBy(F.desc("total_revenue")).show(10)

## Step 3: Customer Segmentation (RFM Analysis)

In [None]:
from datetime import datetime

# Calculate RFM metrics for customer segmentation
reference_date = datetime(2024, 1, 1)

# Join customers with their orders
customer_orders = customers.join(
    orders,
    customers.customer_id == orders.customer_id,
    "left"
)

# Calculate RFM metrics
rfm = customer_orders.groupBy("customer_id", "customer_name", "country") \
    .agg(
        F.datediff(F.lit(reference_date), F.max("order_date")).alias("recency"),
        F.count("order_id").alias("frequency"),
        F.sum("total_amount").alias("monetary")
    )

# Score recency (1-5 scale)
rfm = rfm.withColumn(
    "r_score",
    F.when(F.col("recency") < 30, 1)
     .when(F.col("recency") < 60, 2)
     .when(F.col("recency") < 90, 3)
     .when(F.col("recency") < 180, 4)
     .otherwise(5)
).withColumn(
    "f_score",
    F.when(F.col("frequency") >= 10, 5)
     .when(F.col("frequency") >= 5, 4)
     .when(F.col("frequency") >= 3, 3)
     .when(F.col("frequency") >= 2, 2)
     .otherwise(1)
).withColumn(
    "m_score",
    F.when(F.col("monetary") >= 10000, 5)
     .when(F.col("monetary") >= 5000, 4)
     .when(F.col("monetary") >= 2000, 3)
     .when(F.col("monetary") >= 1000, 2)
     .otherwise(1)
)

# Calculate overall RFM score
rfm = rfm.withColumn(
    "rfm_score",
    F.col("r_score") + F.col("f_score") + F.col("m_score")
)

print("Customer Segmentation Results:")
rfm.orderBy(F.desc("rfm_score")).show(10)

## Step 4: Sales Trend Analysis

In [None]:
# Calculate monthly sales trends
monthly_sales = orders.withColumn(
    "month",
    F.date_format("order_date", "yyyy-MM")
)

# Aggregate sales by month
monthly_sales = monthly_sales.groupBy("month") \
    .agg(
        F.count("order_id").alias("total_orders"),
        F.sum("order_id").alias("unique_customers"),
        F.sum("total_amount").alias("revenue"),
        F.sum("total_amount").alias("avg_order_value")
    )

print("Monthly Sales Trends:")
monthly_sales.orderBy("month").show(12)

# Calculate month-over-month growth rate
windowSpec = Window.orderBy("month")
monthly_sales = monthly_sales.withColumn(
    "prev_month_revenue",
    F.lead("revenue").over(windowSpec)
)

monthly_sales = monthly_sales.withColumn(
    "growth_rate",
    ((F.col("revenue") - F.col("prev_month_revenue")) / F.col("prev_month_revenue") * 100)
)

print("\nMonthly Growth Rates:")
monthly_sales.select("month", "revenue", "prev_month_revenue", "growth_rate").show()