# Lesson 4 - Exercise 2: Write INSERT...SELECT to Populate Dimension and Fact Tables

## Learning Objectives

1. Understand the staging-to-warehouse transformation pattern
2. Write INSERT...SELECT statements that join staging to dimension tables
3. Replace natural keys with surrogate keys during the insert
4. Derive integer date keys from timestamps for efficient filtering
5. Validate referential integrity between facts and dimensions

## Prerequisites

- `aws_config.py` with your credentials
- `data/van_transit_trips_postgres.csv` data file

In [None]:
# ========= SETUP: Imports, Config, and Functions (Run this cell first!) =========

import os
import time
from datetime import datetime
from typing import Dict, Any, List

import pandas as pd
import numpy as np
import boto3

# Load AWS credentials from aws_config.py
import aws_config

# Configuration -- each setting may be overridden through an environment variable.
AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
REDSHIFT_DATABASE = os.getenv("REDSHIFT_DATABASE", "dev")
REDSHIFT_WORKGROUP = os.getenv("REDSHIFT_WORKGROUP", "udacity-dwh-wg")
REDSHIFT_SECRET_ARN = os.getenv("REDSHIFT_SECRET_ARN")  # Optional: None when unset

# Echo the effective settings so the notebook output shows what was used.
print("Configuration:")
print(f"   Region: {AWS_REGION}")
print(f"   Database: {REDSHIFT_DATABASE}")
print(f"   Workgroup: {REDSHIFT_WORKGROUP}")

# One boto3 session and one Redshift Data API client, both pinned to AWS_REGION;
# every helper below talks to Redshift through `rsd`.
session_boto = boto3.Session(region_name=AWS_REGION)
rsd = session_boto.client("redshift-data", region_name=AWS_REGION)


def _rs_kwargs() -> Dict[str, Any]:
    """Return the connection keyword arguments shared by every Data API call.

    ``Database`` is always present. ``WorkgroupName`` is added when
    ``REDSHIFT_WORKGROUP`` is non-empty, and ``SecretArn`` is added only in
    that case as well (a secret ARN without a workgroup is ignored).
    """
    conn: Dict[str, Any] = {"Database": REDSHIFT_DATABASE}
    if not REDSHIFT_WORKGROUP:
        return conn
    conn["WorkgroupName"] = REDSHIFT_WORKGROUP
    if REDSHIFT_SECRET_ARN:
        conn["SecretArn"] = REDSHIFT_SECRET_ARN
    return conn


def _field_value(cell: Dict[str, Any]) -> Any:
    """Unwrap one Data API result cell (a single-key dict such as {"longValue": 3}).

    SQL NULL is delivered as {"isNull": True}; it is mapped to None instead of
    leaking the flag value True into the row.
    """
    if cell.get("isNull"):
        return None
    return next(iter(cell.values()))


def rs_exec(sql: str, return_results=False, timeout_s=900):
    """Execute SQL on Redshift via the Data API.

    Submits ``sql`` using the connection arguments from ``_rs_kwargs()``,
    polls ``describe_statement`` every 0.5 s until the statement reaches a
    terminal state, and optionally fetches every page of the result set.

    Args:
        sql: Statement text. Surrounding whitespace is stripped; an empty
            string is a no-op that returns None.
        return_results: Fetch and return the result set even when the
            statement does not start with ``select``.
        timeout_s: Seconds to wait for the statement before giving up.

    Returns:
        For statements starting with ``select`` (case-insensitive) or when
        ``return_results`` is true: a list of ``{column_name: value}`` dicts,
        with SQL NULL returned as None. Otherwise None.

    Raises:
        TimeoutError: The statement did not finish within ``timeout_s``
            (a best-effort cancel is sent first).
        RuntimeError: The statement ended as FAILED or ABORTED.
    """
    sql = sql.strip()
    if not sql:
        return None

    kwargs = _rs_kwargs()
    kwargs["Sql"] = sql

    sid = rsd.execute_statement(**kwargs)["Id"]

    start = time.time()
    while True:
        d = rsd.describe_statement(Id=sid)
        if d["Status"] in ("FINISHED", "FAILED", "ABORTED"):
            break
        if time.time() - start > timeout_s:
            # Best effort: stop the server-side statement rather than leave it
            # running after we stop waiting. A failure here (e.g. it finished
            # in the meantime) must not mask the timeout itself.
            try:
                rsd.cancel_statement(Id=sid)
            except Exception:
                pass
            raise TimeoutError("Redshift statement timeout")
        time.sleep(0.5)

    if d["Status"] != "FINISHED":
        raise RuntimeError(f"Redshift SQL failed: {d.get('Error')}\n---\n{sql}")

    # `sql` is already stripped above.
    if return_results or sql.lower().startswith("select"):
        out, next_token = [], None
        while True:
            args = dict(Id=sid)
            if next_token:
                args["NextToken"] = next_token
            r = rsd.get_statement_result(**args)
            cols = [c["name"] for c in r["ColumnMetadata"]]
            for rec in r["Records"]:
                row = [_field_value(cell) for cell in rec]
                out.append(dict(zip(cols, row)))
            # Results are paginated; keep fetching until no token remains.
            next_token = r.get("NextToken")
            if not next_token:
                break
        return out

    return None


def rs_batch_insert(table: str, df: pd.DataFrame, batch_size: int = 50,
                    *, progress_every: int = 500):
    """Insert DataFrame rows into Redshift using batched INSERT statements.

    Rows are sent ``batch_size`` at a time. Each batch becomes one multi-row
    ``INSERT INTO <table> (<cols>) VALUES (...), (...)`` statement executed
    through ``rs_exec``.

    Values are rendered as SQL literals:
      * None / NaN / NaT                   -> NULL
      * booleans (Python or NumPy)         -> TRUE / FALSE
      * ints and floats (Python or NumPy)  -> unquoted str(value)
      * anything else                      -> single-quoted text, quotes doubled

    NOTE: ``table`` and the column names are interpolated verbatim, and values
    are only quote-escaped. This is meant for trusted local data (such as the
    exercise CSV), not arbitrary user input.

    Args:
        table: Target table name, e.g. "public.stg_trips_raw".
        df: Rows to insert; its column names become the INSERT column list.
        batch_size: Maximum number of rows per INSERT statement.
        progress_every: Print a progress line each time the running total
            crosses a multiple of this many rows (<= 0 disables it).

    Returns:
        Number of rows sent (0 when ``df`` is None or empty).
    """
    if df is None or df.empty:
        return 0

    cols = list(df.columns)
    col_list = ", ".join(cols)
    total = 0

    def format_value(val):
        # pd.isna covers None, NaN and NaT.
        if pd.isna(val):
            return "NULL"
        # bool must be checked before numbers because bool subclasses int;
        # np.bool_ is included in case pandas hands back NumPy scalars.
        if isinstance(val, (bool, np.bool_)):
            return "TRUE" if val else "FALSE"
        if isinstance(val, (int, float, np.integer, np.floating)):
            return str(val)
        return "'" + str(val).replace("'", "''") + "'"

    rows = df.to_dict(orient="records")
    n_rows = len(rows)

    for i in range(0, n_rows, batch_size):
        chunk = rows[i:i + batch_size]
        values_sql = ", ".join(
            "(" + ", ".join(format_value(r.get(c)) for c in cols) + ")"
            for r in chunk
        )
        rs_exec(f"INSERT INTO {table} ({col_list}) VALUES {values_sql}")

        previous = total
        total += len(chunk)

        # Report when a multiple of progress_every is crossed. The former
        # `total % 500 == 0` check only fired when batch_size divided 500.
        if progress_every > 0 and total // progress_every > previous // progress_every:
            print(f"      Progress: {total:,}/{n_rows:,} rows")

    return total


# Final banner for the setup cell (same two lines of output as separate prints).
print(
    "\nFunctions defined: rs_exec(), rs_batch_insert()",
    "Setup complete!",
    sep="\n",
)

---
## Step 1: Create and Load Staging Table

In [None]:
# ========= STEP 1: Check/Create Staging Table =========

def check_staging_exists():
    """Return True if public.stg_trips_raw exists and contains at least one row.

    Any error raised while running the count query (for example because the
    table has not been created yet) is treated as "not loaded", giving False.
    An existing but empty table also gives False.
    """
    try:
        result = rs_exec("SELECT COUNT(*) AS cnt FROM public.stg_trips_raw;")
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit
        # (e.g. interrupting a hung cell) are no longer swallowed.
        return False
    # COUNT(*) always yields exactly one non-NULL row.
    return result[0]['cnt'] > 0

# Reuse stg_trips_raw when it already holds rows; otherwise (re)create it and
# bulk-load it from the CSV. check_staging_exists() returns False both when the
# count query errors (e.g. table missing) and when the table exists but is
# empty, so an empty table is dropped and rebuilt here.
print("Checking staging table...")
print("=" * 60)

if check_staging_exists():
    # A second COUNT(*) round trip, used only to report the row count.
    result = rs_exec("SELECT COUNT(*) AS cnt FROM public.stg_trips_raw;")
    print(f"   Staging table exists with {result[0]['cnt']} rows")
else:
    print("   Staging table not found. Creating and loading...")
    
    rs_exec("DROP TABLE IF EXISTS public.stg_trips_raw;")
    # Raw landing table: natural keys and source values only; surrogate keys
    # are assigned later in the dw_dim_* tables.
    rs_exec("""
    CREATE TABLE public.stg_trips_raw (
        trip_id               VARCHAR(32),
        rider_id              VARCHAR(32),
        route_id              VARCHAR(32),
        mode                  VARCHAR(16),
        origin_station_id     VARCHAR(32),
        destination_station_id VARCHAR(32),
        board_datetime        TIMESTAMP,
        alight_datetime       TIMESTAMP,
        country               VARCHAR(8),
        province              VARCHAR(8),
        fare_class            VARCHAR(16),
        payment_method        VARCHAR(32),
        transfers             INTEGER,
        zones_charged         INTEGER,
        distance_km           DECIMAL(10,2),
        base_fare_cad         DECIMAL(10,2),
        discount_rate         DECIMAL(5,3),
        discount_amount_cad   DECIMAL(10,2),
        yvr_addfare_cad       DECIMAL(10,2),
        total_fare_cad        DECIMAL(10,2),
        on_time_arrival       BOOLEAN,
        service_disruption    BOOLEAN,
        polyline_stations     VARCHAR(512)
    );
    """)
    print("   Created stg_trips_raw table")
    
    csv_path = "data/van_transit_trips_postgres.csv"
    print(f"   Loading data from {csv_path}...")
    # rs_batch_insert() uses the DataFrame's column names as the INSERT column
    # list, so the CSV header must match the staging columns above
    # (assumed -- verify against the file).
    trips_df = pd.read_csv(csv_path)
    print(f"   Read {len(trips_df):,} rows from CSV")
    
    # 50 rows per INSERT statement, i.e. one Data API call per 50 rows.
    rows = rs_batch_insert("public.stg_trips_raw", trips_df, batch_size=50)
    print(f"   Loaded {rows:,} rows into staging table")

---
## Step 2: Create Dimension Tables

We need to create dimension tables for: date, rider, route, station, fare_class, and payment_method.

In [None]:
# ========= STEP 2: Create all dimension tables =========
# Every table is dropped and recreated, so re-running this cell discards any
# dimension rows loaded earlier.

print("Creating dimension tables...")
print("=" * 60)

# dim_date (provided)
# date_key is the calendar date encoded as a YYYYMMDD integer (see the Step 3
# load). DISTSTYLE ALL keeps a full copy of this small table on every node.
# NOTE: Redshift treats PRIMARY KEY as informational only; it is not enforced.
rs_exec("DROP TABLE IF EXISTS public.dw_dim_date;")
rs_exec("""
CREATE TABLE public.dw_dim_date (
    date_key        INTEGER       NOT NULL,
    date_actual     DATE          NOT NULL,
    year            SMALLINT      NOT NULL,
    quarter         SMALLINT      NOT NULL,
    month           SMALLINT      NOT NULL,
    day             SMALLINT      NOT NULL,
    week_of_year    SMALLINT      NOT NULL,
    day_of_week     SMALLINT      NOT NULL,
    is_weekend      BOOLEAN       NOT NULL,
    PRIMARY KEY (date_key)
) DISTSTYLE ALL SORTKEY (date_key);
""")
print("   Created: dw_dim_date")

# dim_rider (provided)
# rider_sk is an IDENTITY surrogate key that Redshift assigns on insert; rider_id
# is the natural key from staging. effective_from / effective_to / is_current
# are validity columns for keeping rider versions over time (Step 3 loads one
# current row per rider).
# NOTE(review): dw_fact_trips (Step 4) is distributed on rider_sk, while this
# table is distributed on rider_id, so fact-to-rider joins on rider_sk are not
# co-located -- confirm that is intended.
rs_exec("DROP TABLE IF EXISTS public.dw_dim_rider;")
rs_exec("""
CREATE TABLE public.dw_dim_rider (
    rider_sk        BIGINT IDENTITY(1,1),
    rider_id        VARCHAR(32)   ENCODE zstd,
    rider_segment   VARCHAR(16)   ENCODE zstd,
    effective_from  TIMESTAMP     ENCODE zstd,
    effective_to    TIMESTAMP     ENCODE zstd,
    is_current      BOOLEAN       ENCODE zstd,
    PRIMARY KEY (rider_sk)
) DISTKEY (rider_id) SORTKEY (rider_id);
""")
print("   Created: dw_dim_rider")

**TODO**: Create the remaining dimension tables:

1. **dw_dim_route**: route_sk (IDENTITY), route_id (VARCHAR(32)), mode (VARCHAR(16))
2. **dw_dim_station**: station_sk (IDENTITY), station_id (VARCHAR(32))
3. **dw_dim_fare_class**: fare_class_sk (IDENTITY), fare_class (VARCHAR(16))
4. **dw_dim_payment_method**: payment_method_sk (IDENTITY), payment_method (VARCHAR(32))

In [None]:
# NOTE: rs_exec() only skips SQL that is empty after strip(). Each placeholder
# below still contains a "-- TODO" SQL comment, so it is sent to Redshift as-is
# and will most likely be rejected. Fill in every CREATE TABLE before running.
# Pattern to follow (see dw_dim_rider): a BIGINT IDENTITY(1,1) surrogate key
# declared as PRIMARY KEY, plus the natural-key column(s) listed above.

# TODO: Create dim_route
# Columns: route_sk (IDENTITY), route_id VARCHAR(32), mode VARCHAR(16)
rs_exec("DROP TABLE IF EXISTS public.dw_dim_route;")
rs_exec("""
-- TODO: Write CREATE TABLE for dw_dim_route

""")
print("   Created: dw_dim_route")

# TODO: Create dim_station
# Columns: station_sk (IDENTITY), station_id VARCHAR(32). One table serves
# both the origin and the destination role in the fact table.
rs_exec("DROP TABLE IF EXISTS public.dw_dim_station;")
rs_exec("""
-- TODO: Write CREATE TABLE for dw_dim_station

""")
print("   Created: dw_dim_station")

# TODO: Create dim_fare_class
# Columns: fare_class_sk (IDENTITY), fare_class VARCHAR(16)
rs_exec("DROP TABLE IF EXISTS public.dw_dim_fare_class;")
rs_exec("""
-- TODO: Write CREATE TABLE for dw_dim_fare_class

""")
print("   Created: dw_dim_fare_class")

# TODO: Create dim_payment_method
# Columns: payment_method_sk (IDENTITY), payment_method VARCHAR(32)
rs_exec("DROP TABLE IF EXISTS public.dw_dim_payment_method;")
rs_exec("""
-- TODO: Write CREATE TABLE for dw_dim_payment_method

""")
print("   Created: dw_dim_payment_method")

print("\nAll dimension tables created!")

---
## Step 3: Populate Dimension Tables

Use INSERT...SELECT to populate dimensions from staging data.

In [None]:
# ========= STEP 3: Populate dimension tables =========

print("Populating dimension tables...")
print("=" * 60)

# dim_date (provided)
# One row per distinct calendar date on which any trip boarded or alighted
# (the inner UNION already removes duplicate dates).
# date_key = YYYYMMDD as an integer, e.g. 2024-03-15 -> 20240315.
# Redshift's DOW is 0 = Sunday ... 6 = Saturday, so (0, 6) marks weekends.
rs_exec("""
INSERT INTO public.dw_dim_date (
    date_key, date_actual, year, quarter, month, day, 
    week_of_year, day_of_week, is_weekend
)
SELECT DISTINCT
    CAST(TO_CHAR(dt, 'YYYYMMDD') AS INTEGER) AS date_key,
    dt AS date_actual,
    EXTRACT(YEAR FROM dt)::SMALLINT AS year,
    EXTRACT(QUARTER FROM dt)::SMALLINT AS quarter,
    EXTRACT(MONTH FROM dt)::SMALLINT AS month,
    EXTRACT(DAY FROM dt)::SMALLINT AS day,
    EXTRACT(WEEK FROM dt)::SMALLINT AS week_of_year,
    EXTRACT(DOW FROM dt)::SMALLINT AS day_of_week,
    CASE WHEN EXTRACT(DOW FROM dt) IN (0, 6) THEN TRUE ELSE FALSE END AS is_weekend
FROM (
    SELECT board_datetime::DATE AS dt FROM public.stg_trips_raw WHERE board_datetime IS NOT NULL
    UNION
    SELECT alight_datetime::DATE AS dt FROM public.stg_trips_raw WHERE alight_datetime IS NOT NULL
) dates WHERE dt IS NOT NULL;
""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_date;")
print(f"   dw_dim_date: {result[0]['cnt']} rows")

# dim_rider (provided)
# One current row per non-empty rider_id. effective_from is the rider's
# earliest board_datetime in staging. rider_segment and effective_to have no
# staging source and are loaded as NULL. rider_sk is filled by IDENTITY.
# (GROUP BY already makes rows unique; the DISTINCT is redundant but harmless.)
rs_exec("""
INSERT INTO public.dw_dim_rider (rider_id, rider_segment, effective_from, effective_to, is_current)
SELECT DISTINCT
    rider_id,
    CAST(NULL AS VARCHAR(16)) AS rider_segment,
    MIN(board_datetime) AS effective_from,
    CAST(NULL AS TIMESTAMP) AS effective_to,
    TRUE AS is_current
FROM public.stg_trips_raw
WHERE rider_id IS NOT NULL AND rider_id != ''
GROUP BY rider_id;
""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_rider;")
print(f"   dw_dim_rider: {result[0]['cnt']} rows")

**TODO**: Write INSERT...SELECT statements to populate the remaining dimensions:

1. **dw_dim_route**: SELECT DISTINCT route_id, mode from staging
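2. **dw_dim_station**: the distinct station IDs that appear in either `origin_station_id` or `destination_station_id` (staging has no single `station_id` column, so combine the two columns with UNION)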
3. **dw_dim_fare_class**: SELECT DISTINCT fare_class from staging
4. **dw_dim_payment_method**: SELECT DISTINCT payment_method from staging

In [None]:
# NOTE: each placeholder SQL string below still holds "-- TODO" comments, which
# rs_exec() sends to Redshift as-is (it only skips SQL that is empty after
# strip()). Until the INSERT...SELECTs are written, these calls will most likely
# fail. List only the natural-key columns in each INSERT; the *_sk surrogate
# keys are generated by IDENTITY, as in the dw_dim_rider load above.

# TODO: Populate dim_route
rs_exec("""
-- TODO: INSERT INTO public.dw_dim_route (route_id, mode)
-- SELECT DISTINCT route_id, mode FROM staging

""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_route;")
print(f"   dw_dim_route: {result[0]['cnt']} rows")

# TODO: Populate dim_station (UNION of origin and destination)
# UNION (not UNION ALL) de-duplicates stations that appear in both columns.
rs_exec("""
-- TODO: INSERT INTO public.dw_dim_station (station_id)
-- Hint: Use UNION to combine origin_station_id and destination_station_id

""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_station;")
print(f"   dw_dim_station: {result[0]['cnt']} rows")

# TODO: Populate dim_fare_class
rs_exec("""
-- TODO: INSERT INTO public.dw_dim_fare_class (fare_class)
-- SELECT DISTINCT fare_class FROM staging

""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_fare_class;")
print(f"   dw_dim_fare_class: {result[0]['cnt']} rows")

# TODO: Populate dim_payment_method
rs_exec("""
-- TODO: INSERT INTO public.dw_dim_payment_method (payment_method)
-- SELECT DISTINCT payment_method FROM staging

""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_payment_method;")
print(f"   dw_dim_payment_method: {result[0]['cnt']} rows")

print("\nAll dimensions populated!")

---
## Step 4: Create Fact Table

The fact table stores the trip transactions with surrogate keys referencing dimensions.

**TODO**: Create the fact table `dw_fact_trips` with these columns:

| Column | Type | Description |
|--------|------|-------------|
| trip_sk | BIGINT IDENTITY | Surrogate key |
| trip_id | VARCHAR(32) | Natural key |
| rider_sk | BIGINT | FK to dim_rider |
| route_sk | BIGINT | FK to dim_route |
| origin_station_sk | BIGINT | FK to dim_station |
| destination_station_sk | BIGINT | FK to dim_station |
| fare_class_sk | BIGINT | FK to dim_fare_class |
| payment_method_sk | BIGINT | FK to dim_payment_method |
| board_date_key | INTEGER | FK to dim_date |
| alight_date_key | INTEGER | FK to dim_date |
| transfers | INTEGER | Measure |
| zones_charged | INTEGER | Measure |
| distance_km | DECIMAL(10,2) | Measure |
| base_fare_cad | DECIMAL(10,2) | Measure |
| discount_rate | DECIMAL(5,3) | Measure |
| discount_amount_cad | DECIMAL(10,2) | Measure |
| yvr_addfare_cad | DECIMAL(10,2) | Measure |
| total_fare_cad | DECIMAL(10,2) | Measure |
| on_time_arrival | BOOLEAN | Flag |
| service_disruption | BOOLEAN | Flag |

In [None]:
# ========= STEP 4: Create fact_trips table =========
# Dropping and recreating the table means Step 5 always inserts into an empty
# fact table.

print("Creating fact table...")
print("=" * 60)

rs_exec("DROP TABLE IF EXISTS public.dw_fact_trips;")

# TODO: Write CREATE TABLE for dw_fact_trips
# The trailing DISTKEY (rider_sk) SORTKEY (board_date_key) clause names columns
# that must be declared in the column list (see the table in the markdown
# above), otherwise the statement fails. The SQL comments inside the string
# are only placeholders.
rs_exec("""
CREATE TABLE public.dw_fact_trips (
    -- TODO: Add all columns with appropriate types
    -- Include surrogate keys for dimension lookups
    -- Include date keys as INTEGER for efficient filtering
    -- Include all measure columns
    
) DISTKEY (rider_sk) SORTKEY (board_date_key);
""")

print("   Created: dw_fact_trips")

---
## Step 5: INSERT...SELECT to Populate Fact Table

This is the core of the exercise - joining staging to dimensions and inserting into facts.

**Key concepts:**
- Join staging table to each dimension table to get surrogate keys
- Convert timestamps to integer date keys by formatting the date with `TO_CHAR(date, 'YYYYMMDD')` and casting the result to INTEGER
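- Use LEFT JOINs so every staging row is kept even when a dimension lookup finds no match (that row's surrogate key is then NULL, which is what the NULL-foreign-key checks in Step 6 look for)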

**TODO**: Write the INSERT...SELECT statement to populate the fact table.

Join staging to:
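- `dw_dim_rider` on rider_id, keeping only the current version (put `dr.is_current = TRUE` in the ON clause, not in WHERE; a WHERE filter turns the LEFT JOIN into an inner join and silently drops trips without a matching rider)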
- `dw_dim_route` on route_id
- `dw_dim_station` (twice) on origin_station_id and destination_station_id
- `dw_dim_fare_class` on fare_class
- `dw_dim_payment_method` on payment_method

Derive date keys:
```sql
CAST(TO_CHAR(t.board_datetime::DATE, 'YYYYMMDD') AS INTEGER) AS board_date_key
```

In [None]:
# ========= STEP 5: INSERT...SELECT to populate fact_trips =========

print("Populating fact table with INSERT...SELECT...")
print("=" * 60)

start_time = time.time()

# TODO: Write the INSERT...SELECT statement
# Until the TODOs are completed, the INSERT column list has no matching SELECT
# expressions, so Redshift will reject the statement and rs_exec() raises
# RuntimeError.
# Hints:
#   * Select the expressions in exactly the order of the INSERT column list.
#   * Join dw_dim_station twice (e.g. aliases for origin and destination).
#   * Put `dr.is_current = TRUE` in the rider join's ON clause; filtering on it
#     in WHERE would turn the LEFT JOIN into an inner join and drop trips that
#     have no matching rider.
rs_exec("""
INSERT INTO public.dw_fact_trips (
    trip_id, rider_sk, route_sk, origin_station_sk, destination_station_sk,
    fare_class_sk, payment_method_sk, board_date_key, alight_date_key,
    transfers, zones_charged, distance_km, base_fare_cad, discount_rate,
    discount_amount_cad, yvr_addfare_cad, total_fare_cad, on_time_arrival, service_disruption
)
SELECT
    -- TODO: Select trip_id from staging
    -- TODO: Select rider_sk from dim_rider join
    -- TODO: Select route_sk from dim_route join
    -- TODO: Select origin_station_sk from dim_station join
    -- TODO: Select destination_station_sk from dim_station join
    -- TODO: Select fare_class_sk from dim_fare_class join
    -- TODO: Select payment_method_sk from dim_payment_method join
    -- TODO: Derive board_date_key using TO_CHAR
    -- TODO: Derive alight_date_key using TO_CHAR
    -- TODO: Select all measure columns from staging
    
FROM public.stg_trips_raw t
-- TODO: Add LEFT JOINs to all dimension tables

;
""")

# Measures only the INSERT...SELECT call (submit + polling), not the count below.
elapsed = time.time() - start_time
# Step 4 recreated the table empty, so the table total equals the rows inserted
# by this run. Re-running this cell without Step 4 would append duplicates.
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_fact_trips;")

print(f"   Inserted {result[0]['cnt']:,} rows into dw_fact_trips")
print(f"   Elapsed time: {elapsed:.1f} seconds")

---
## Step 6: Validate the Load

In [None]:
# ========= STEP 6: Validate the load =========
# Three checks: staging vs. fact row counts, NULL surrogate keys left behind by
# unmatched LEFT JOIN lookups, and a sample of facts joined back to their
# dimensions. `display` is the IPython/Jupyter display helper.

print("Validation")
print("=" * 60)

# Row count comparison: every staging row should produce exactly one fact row.
print("\n1. Row Count Comparison:")
results = rs_exec("""
SELECT 'stg_trips_raw' AS table_name, COUNT(*) AS row_count FROM public.stg_trips_raw
UNION ALL
SELECT 'dw_fact_trips', COUNT(*) FROM public.dw_fact_trips;
""")
display(pd.DataFrame(results))

# Referential integrity: a NULL surrogate key means the dimension lookup found
# no match. Both station roles are checked as well, since each is a separate
# join to dw_dim_station and can fail independently.
print("\n2. NULL Foreign Keys (should all be 0):")
results = rs_exec("""
SELECT 'rider_sk' AS column_name, COUNT(*) AS null_count FROM public.dw_fact_trips WHERE rider_sk IS NULL
UNION ALL SELECT 'route_sk', COUNT(*) FROM public.dw_fact_trips WHERE route_sk IS NULL
UNION ALL SELECT 'origin_station_sk', COUNT(*) FROM public.dw_fact_trips WHERE origin_station_sk IS NULL
UNION ALL SELECT 'destination_station_sk', COUNT(*) FROM public.dw_fact_trips WHERE destination_station_sk IS NULL
UNION ALL SELECT 'fare_class_sk', COUNT(*) FROM public.dw_fact_trips WHERE fare_class_sk IS NULL
UNION ALL SELECT 'payment_method_sk', COUNT(*) FROM public.dw_fact_trips WHERE payment_method_sk IS NULL;
""")
display(pd.DataFrame(results))

# Sample data: resolve surrogate keys back to natural keys as a spot check.
print("\n3. Sample Fact Data with Dimension Lookups:")
results = rs_exec("""
SELECT f.trip_id, dr.rider_id, drt.route_id, drt.mode, dfc.fare_class, f.total_fare_cad
FROM public.dw_fact_trips f
LEFT JOIN public.dw_dim_rider dr ON f.rider_sk = dr.rider_sk
LEFT JOIN public.dw_dim_route drt ON f.route_sk = drt.route_sk
LEFT JOIN public.dw_dim_fare_class dfc ON f.fare_class_sk = dfc.fare_class_sk
LIMIT 5;
""")
display(pd.DataFrame(results))

---
## Summary

### Lesson 4 - Exercise 2 Complete

You successfully:

1. **Created dimension tables** (date, rider, route, station, fare_class, payment_method)
2. **Populated dimensions** from staging data
3. **Created the fact table** with surrogate key references
4. **Wrote INSERT...SELECT** to populate fact_trips
5. **Validated row counts** and referential integrity