# Lesson 4 - Exercise 2: Write INSERT...SELECT to Populate Dimension and Fact Tables

## Learning Objectives

1. Understand the staging-to-warehouse transformation pattern
2. Write INSERT...SELECT statements that join staging to dimension tables
3. Replace natural keys with surrogate keys during the insert
4. Derive integer date keys from timestamps for efficient filtering
5. Validate referential integrity between facts and dimensions

## Prerequisites

- `aws_config.py` with your credentials
- `data/van_transit_trips_postgres.csv` data file

---
## Lesson 4 Exercise 2: Write INSERT...SELECT to Populate Dimension and Fact Tables Solution

In [1]:
# ========= SETUP: Imports, Config, and Functions (Run this cell first!) =========

import os
import time
from datetime import datetime
from typing import Dict, Any, List

import pandas as pd
import numpy as np
import boto3

# Load AWS credentials from aws_config.py
import aws_config

# Configuration
AWS_REGION = os.getenv('AWS_REGION', 'us-east-1')
REDSHIFT_DATABASE = os.getenv('REDSHIFT_DATABASE', 'dev')
REDSHIFT_WORKGROUP = os.getenv('REDSHIFT_WORKGROUP', 'udacity-dwh-wg')
REDSHIFT_SECRET_ARN = os.getenv('REDSHIFT_SECRET_ARN', None)  # Optional

print("Configuration:")
print(f"   Region: {AWS_REGION}")
print(f"   Database: {REDSHIFT_DATABASE}")
print(f"   Workgroup: {REDSHIFT_WORKGROUP}")

# Redshift Data API client
session_boto = boto3.Session(region_name=AWS_REGION)
rsd = session_boto.client("redshift-data", region_name=AWS_REGION)


def _rs_kwargs() -> Dict[str, Any]:
    """Build connection arguments for Redshift Data API."""
    base = dict(Database=REDSHIFT_DATABASE)
    if REDSHIFT_WORKGROUP:
        base["WorkgroupName"] = REDSHIFT_WORKGROUP
        if REDSHIFT_SECRET_ARN:
            base["SecretArn"] = REDSHIFT_SECRET_ARN
    return base


def rs_exec(sql: str, return_results=False, timeout_s=900):
    """Execute SQL on Redshift via the Data API."""
    sql = sql.strip()
    if not sql:
        return None
    
    kwargs = _rs_kwargs()
    kwargs["Sql"] = sql
    
    sid = rsd.execute_statement(**kwargs)["Id"]
    
    start = time.time()
    while True:
        d = rsd.describe_statement(Id=sid)
        if d["Status"] in ("FINISHED", "FAILED", "ABORTED"):
            break
        if time.time() - start > timeout_s:
            raise TimeoutError("Redshift statement timeout")
        time.sleep(0.5)
    
    if d["Status"] != "FINISHED":
        raise RuntimeError(f"Redshift SQL failed: {d.get('Error')}\n---\n{sql}")
    
    if return_results or sql.strip().lower().startswith("select"):
        out, next_token = [], None
        while True:
            args = dict(Id=sid)
            if next_token:
                args["NextToken"] = next_token
            r = rsd.get_statement_result(**args)
            cols = [c["name"] for c in r["ColumnMetadata"]]
            for rec in r["Records"]:
                row = [next(iter(cell.values())) for cell in rec]
                out.append(dict(zip(cols, row)))
            next_token = r.get("NextToken")
            if not next_token:
                break
        return out
    
    return None


def rs_batch_insert(table: str, df: pd.DataFrame, batch_size: int = 50):
    """Insert DataFrame rows into Redshift using batched INSERT statements."""
    if df is None or df.empty:
        return 0
    
    cols = list(df.columns)
    total = 0
    
    def format_value(val):
        if pd.isna(val) or val is None:
            return "NULL"
        elif isinstance(val, bool):
            return "TRUE" if val else "FALSE"
        elif isinstance(val, (int, float)):
            return str(val)
        else:
            return "'" + str(val).replace("'", "''") + "'"
    
    rows = df.to_dict(orient="records")
    
    for i in range(0, len(rows), batch_size):
        chunk = rows[i:i+batch_size]
        value_rows = []
        
        for r in chunk:
            values = [format_value(r.get(c)) for c in cols]
            value_rows.append(f"({', '.join(values)})")
        
        sql = f"INSERT INTO {table} ({', '.join(cols)}) VALUES {', '.join(value_rows)}"
        rs_exec(sql)
        total += len(chunk)
        
        if total % 500 == 0:
            print(f"      Progress: {total:,}/{len(rows):,} rows")
    
    return total


print("\nFunctions defined: rs_exec(), rs_batch_insert()")
print("Setup complete!")

Configuration:
   Region: us-east-1
   Database: dev
   Workgroup: udacity-dwh-wg

Functions defined: rs_exec(), rs_batch_insert()
Setup complete!


---
## Step 1: Create and Load Staging Table

In [3]:
# ========= STEP 1: Check/Create Staging Table =========

def check_staging_exists():
    try:
        result = rs_exec("SELECT COUNT(*) AS cnt FROM public.stg_trips_raw;")
        return result[0]['cnt'] > 0
    except:
        return False

print("Checking staging table...")
print("=" * 60)

if check_staging_exists():
    result = rs_exec("SELECT COUNT(*) AS cnt FROM public.stg_trips_raw;")
    print(f"   Staging table exists with {result[0]['cnt']} rows")
else:
    print("   Staging table not found. Creating and loading...")
    
    rs_exec("DROP TABLE IF EXISTS public.stg_trips_raw;")
    rs_exec("""
    CREATE TABLE public.stg_trips_raw (
        trip_id               VARCHAR(32),
        rider_id              VARCHAR(32),
        route_id              VARCHAR(32),
        mode                  VARCHAR(16),
        origin_station_id     VARCHAR(32),
        destination_station_id VARCHAR(32),
        board_datetime        TIMESTAMP,
        alight_datetime       TIMESTAMP,
        country               VARCHAR(8),
        province              VARCHAR(8),
        fare_class            VARCHAR(16),
        payment_method        VARCHAR(32),
        transfers             INTEGER,
        zones_charged         INTEGER,
        distance_km           DECIMAL(10,2),
        base_fare_cad         DECIMAL(10,2),
        discount_rate         DECIMAL(5,3),
        discount_amount_cad   DECIMAL(10,2),
        yvr_addfare_cad       DECIMAL(10,2),
        total_fare_cad        DECIMAL(10,2),
        on_time_arrival       BOOLEAN,
        service_disruption    BOOLEAN,
        polyline_stations     VARCHAR(512)
    );
    """)
    print("   Created stg_trips_raw table")
    
    csv_path = "data/van_transit_trips_postgres.csv"
    print(f"   Loading data from {csv_path}...")
    trips_df = pd.read_csv(csv_path)
    print(f"   Read {len(trips_df):,} rows from CSV")
    
    rows = rs_batch_insert("public.stg_trips_raw", trips_df, batch_size=50)
    print(f"   Loaded {rows:,} rows into staging table")

Checking staging table...
   Staging table not found. Creating and loading...
   Created stg_trips_raw table
   Loading data from data/van_transit_trips_postgres.csv...
   Read 2,500 rows from CSV
      Progress: 500/2,500 rows
      Progress: 1,000/2,500 rows
      Progress: 1,500/2,500 rows
      Progress: 2,000/2,500 rows
      Progress: 2,500/2,500 rows
   Loaded 2,500 rows into staging table


---
## Step 2: Create Dimension Tables

In [4]:
# ========= STEP 2: Create all dimension tables =========

print("Creating dimension tables...")
print("=" * 60)

# dim_date
rs_exec("DROP TABLE IF EXISTS public.dw_dim_date;")
rs_exec("""
CREATE TABLE public.dw_dim_date (
    date_key        INTEGER       NOT NULL,
    date_actual     DATE          NOT NULL,
    year            SMALLINT      NOT NULL,
    quarter         SMALLINT      NOT NULL,
    month           SMALLINT      NOT NULL,
    day             SMALLINT      NOT NULL,
    week_of_year    SMALLINT      NOT NULL,
    day_of_week     SMALLINT      NOT NULL,
    is_weekend      BOOLEAN       NOT NULL,
    PRIMARY KEY (date_key)
) DISTSTYLE ALL SORTKEY (date_key);
""")
print("   Created: dw_dim_date")

# dim_rider
rs_exec("DROP TABLE IF EXISTS public.dw_dim_rider;")
rs_exec("""
CREATE TABLE public.dw_dim_rider (
    rider_sk        BIGINT IDENTITY(1,1),
    rider_id        VARCHAR(32)   ENCODE zstd,
    rider_segment   VARCHAR(16)   ENCODE zstd,
    effective_from  TIMESTAMP     ENCODE zstd,
    effective_to    TIMESTAMP     ENCODE zstd,
    is_current      BOOLEAN       ENCODE zstd,
    PRIMARY KEY (rider_sk)
) DISTKEY (rider_id) SORTKEY (rider_id);
""")
print("   Created: dw_dim_rider")

# dim_route
rs_exec("DROP TABLE IF EXISTS public.dw_dim_route;")
rs_exec("""
CREATE TABLE public.dw_dim_route (
    route_sk        BIGINT IDENTITY(1,1),
    route_id        VARCHAR(32)   ENCODE zstd,
    mode            VARCHAR(16)   ENCODE zstd,
    PRIMARY KEY (route_sk)
) DISTSTYLE ALL SORTKEY (route_id);
""")
print("   Created: dw_dim_route")

# dim_station
rs_exec("DROP TABLE IF EXISTS public.dw_dim_station;")
rs_exec("""
CREATE TABLE public.dw_dim_station (
    station_sk      BIGINT IDENTITY(1,1),
    station_id      VARCHAR(32)   ENCODE zstd,
    PRIMARY KEY (station_sk)
) DISTSTYLE ALL SORTKEY (station_id);
""")
print("   Created: dw_dim_station")

# dim_fare_class
rs_exec("DROP TABLE IF EXISTS public.dw_dim_fare_class;")
rs_exec("""
CREATE TABLE public.dw_dim_fare_class (
    fare_class_sk   BIGINT IDENTITY(1,1),
    fare_class      VARCHAR(16)   ENCODE zstd,
    PRIMARY KEY (fare_class_sk)
) DISTSTYLE ALL SORTKEY (fare_class);
""")
print("   Created: dw_dim_fare_class")

# dim_payment_method
rs_exec("DROP TABLE IF EXISTS public.dw_dim_payment_method;")
rs_exec("""
CREATE TABLE public.dw_dim_payment_method (
    payment_method_sk  BIGINT IDENTITY(1,1),
    payment_method     VARCHAR(32)  ENCODE zstd,
    PRIMARY KEY (payment_method_sk)
) DISTSTYLE ALL SORTKEY (payment_method);
""")
print("   Created: dw_dim_payment_method")

print("\nAll dimension tables created!")

Creating dimension tables...
   Created: dw_dim_date
   Created: dw_dim_rider
   Created: dw_dim_route
   Created: dw_dim_station
   Created: dw_dim_fare_class
   Created: dw_dim_payment_method

All dimension tables created!


---
## Step 3: Populate Dimension Tables

In [5]:
# ========= STEP 3: Populate all dimension tables =========

print("Populating dimension tables...")
print("=" * 60)

# dim_date
rs_exec("""
INSERT INTO public.dw_dim_date (
    date_key, date_actual, year, quarter, month, day, 
    week_of_year, day_of_week, is_weekend
)
SELECT DISTINCT
    CAST(TO_CHAR(dt, 'YYYYMMDD') AS INTEGER) AS date_key,
    dt AS date_actual,
    EXTRACT(YEAR FROM dt)::SMALLINT AS year,
    EXTRACT(QUARTER FROM dt)::SMALLINT AS quarter,
    EXTRACT(MONTH FROM dt)::SMALLINT AS month,
    EXTRACT(DAY FROM dt)::SMALLINT AS day,
    EXTRACT(WEEK FROM dt)::SMALLINT AS week_of_year,
    EXTRACT(DOW FROM dt)::SMALLINT AS day_of_week,
    CASE WHEN EXTRACT(DOW FROM dt) IN (0, 6) THEN TRUE ELSE FALSE END AS is_weekend
FROM (
    SELECT board_datetime::DATE AS dt FROM public.stg_trips_raw WHERE board_datetime IS NOT NULL
    UNION
    SELECT alight_datetime::DATE AS dt FROM public.stg_trips_raw WHERE alight_datetime IS NOT NULL
) dates WHERE dt IS NOT NULL;
""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_date;")
print(f"   dw_dim_date: {result[0]['cnt']} rows")

# dim_rider
rs_exec("""
INSERT INTO public.dw_dim_rider (rider_id, rider_segment, effective_from, effective_to, is_current)
SELECT DISTINCT
    rider_id,
    CAST(NULL AS VARCHAR(16)) AS rider_segment,
    MIN(board_datetime) AS effective_from,
    CAST(NULL AS TIMESTAMP) AS effective_to,
    TRUE AS is_current
FROM public.stg_trips_raw
WHERE rider_id IS NOT NULL AND rider_id != ''
GROUP BY rider_id;
""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_rider;")
print(f"   dw_dim_rider: {result[0]['cnt']} rows")

# dim_route
rs_exec("""
INSERT INTO public.dw_dim_route (route_id, mode)
SELECT DISTINCT route_id, mode
FROM public.stg_trips_raw
WHERE route_id IS NOT NULL AND route_id != '';
""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_route;")
print(f"   dw_dim_route: {result[0]['cnt']} rows")

# dim_station
rs_exec("""
INSERT INTO public.dw_dim_station (station_id)
SELECT DISTINCT station_id FROM (
    SELECT origin_station_id AS station_id FROM public.stg_trips_raw WHERE origin_station_id IS NOT NULL
    UNION
    SELECT destination_station_id AS station_id FROM public.stg_trips_raw WHERE destination_station_id IS NOT NULL
) s WHERE station_id != '';
""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_station;")
print(f"   dw_dim_station: {result[0]['cnt']} rows")

# dim_fare_class
rs_exec("""
INSERT INTO public.dw_dim_fare_class (fare_class)
SELECT DISTINCT fare_class
FROM public.stg_trips_raw
WHERE fare_class IS NOT NULL AND fare_class != '';
""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_fare_class;")
print(f"   dw_dim_fare_class: {result[0]['cnt']} rows")

# dim_payment_method
rs_exec("""
INSERT INTO public.dw_dim_payment_method (payment_method)
SELECT DISTINCT payment_method
FROM public.stg_trips_raw
WHERE payment_method IS NOT NULL AND payment_method != '';
""")
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_dim_payment_method;")
print(f"   dw_dim_payment_method: {result[0]['cnt']} rows")

print("\nAll dimensions populated!")

Populating dimension tables...
   dw_dim_date: 541 rows
   dw_dim_rider: 2455 rows
   dw_dim_route: 423 rows
   dw_dim_station: 30 rows
   dw_dim_fare_class: 5 rows
   dw_dim_payment_method: 4 rows

All dimensions populated!


---
## Step 4: Create Fact Table

In [6]:
# ========= STEP 4: Create fact_trips table =========

print("Creating fact table...")
print("=" * 60)

rs_exec("DROP TABLE IF EXISTS public.dw_fact_trips;")
rs_exec("""
CREATE TABLE public.dw_fact_trips (
    trip_sk                 BIGINT IDENTITY(1,1),
    trip_id                 VARCHAR(32)   ENCODE zstd,
    rider_sk                BIGINT        ENCODE zstd,
    route_sk                BIGINT        ENCODE zstd,
    origin_station_sk       BIGINT        ENCODE zstd,
    destination_station_sk  BIGINT        ENCODE zstd,
    fare_class_sk           BIGINT        ENCODE zstd,
    payment_method_sk       BIGINT        ENCODE zstd,
    board_date_key          INTEGER       ENCODE zstd,
    alight_date_key         INTEGER       ENCODE zstd,
    transfers               INTEGER       ENCODE zstd,
    zones_charged           INTEGER       ENCODE zstd,
    distance_km             DECIMAL(10,2) ENCODE zstd,
    base_fare_cad           DECIMAL(10,2) ENCODE zstd,
    discount_rate           DECIMAL(5,3)  ENCODE zstd,
    discount_amount_cad     DECIMAL(10,2) ENCODE zstd,
    yvr_addfare_cad         DECIMAL(10,2) ENCODE zstd,
    total_fare_cad          DECIMAL(10,2) ENCODE zstd,
    on_time_arrival         BOOLEAN       ENCODE zstd,
    service_disruption      BOOLEAN       ENCODE zstd,
    PRIMARY KEY (trip_sk)
) DISTKEY (rider_sk) SORTKEY (board_date_key);
""")

print("   Created: dw_fact_trips")

Creating fact table...
   Created: dw_fact_trips


---
## Step 5: INSERT...SELECT to Populate Fact Table

This is the core of the exercise - joining staging to dimensions and inserting into facts.

In [7]:
# ========= STEP 5: INSERT...SELECT to populate fact_trips =========

print("Populating fact table with INSERT...SELECT...")
print("=" * 60)

start_time = time.time()

rs_exec("""
INSERT INTO public.dw_fact_trips (
    trip_id, rider_sk, route_sk, origin_station_sk, destination_station_sk,
    fare_class_sk, payment_method_sk, board_date_key, alight_date_key,
    transfers, zones_charged, distance_km, base_fare_cad, discount_rate,
    discount_amount_cad, yvr_addfare_cad, total_fare_cad, on_time_arrival, service_disruption
)
SELECT
    t.trip_id,
    dr.rider_sk,
    drt.route_sk,
    ds_orig.station_sk AS origin_station_sk,
    ds_dest.station_sk AS destination_station_sk,
    dfc.fare_class_sk,
    dpm.payment_method_sk,
    CAST(TO_CHAR(t.board_datetime::DATE, 'YYYYMMDD') AS INTEGER) AS board_date_key,
    CAST(TO_CHAR(t.alight_datetime::DATE, 'YYYYMMDD') AS INTEGER) AS alight_date_key,
    t.transfers,
    t.zones_charged,
    t.distance_km,
    t.base_fare_cad,
    t.discount_rate,
    t.discount_amount_cad,
    t.yvr_addfare_cad,
    t.total_fare_cad,
    t.on_time_arrival,
    t.service_disruption
FROM public.stg_trips_raw t
LEFT JOIN public.dw_dim_rider dr ON dr.rider_id = t.rider_id AND dr.is_current = TRUE
LEFT JOIN public.dw_dim_route drt ON drt.route_id = t.route_id
LEFT JOIN public.dw_dim_station ds_orig ON ds_orig.station_id = t.origin_station_id
LEFT JOIN public.dw_dim_station ds_dest ON ds_dest.station_id = t.destination_station_id
LEFT JOIN public.dw_dim_fare_class dfc ON dfc.fare_class = t.fare_class
LEFT JOIN public.dw_dim_payment_method dpm ON dpm.payment_method = t.payment_method;
""")

elapsed = time.time() - start_time
result = rs_exec("SELECT COUNT(*) AS cnt FROM public.dw_fact_trips;")

print(f"   Inserted {result[0]['cnt']:,} rows into dw_fact_trips")
print(f"   Elapsed time: {elapsed:.1f} seconds")

Populating fact table with INSERT...SELECT...
   Inserted 8,883 rows into dw_fact_trips
   Elapsed time: 14.3 seconds


---
## Step 6: Validate the Load

In [8]:
# ========= STEP 6: Validate the load =========

print("Validation")
print("=" * 60)

# Row count comparison
print("\n1. Row Count Comparison:")
results = rs_exec("""
SELECT 'stg_trips_raw' AS table_name, COUNT(*) AS row_count FROM public.stg_trips_raw
UNION ALL
SELECT 'dw_fact_trips', COUNT(*) FROM public.dw_fact_trips;
""")
display(pd.DataFrame(results))

# Referential integrity
print("\n2. NULL Foreign Keys (should all be 0):")
results = rs_exec("""
SELECT 'rider_sk' AS column_name, COUNT(*) AS null_count FROM public.dw_fact_trips WHERE rider_sk IS NULL
UNION ALL SELECT 'route_sk', COUNT(*) FROM public.dw_fact_trips WHERE route_sk IS NULL
UNION ALL SELECT 'fare_class_sk', COUNT(*) FROM public.dw_fact_trips WHERE fare_class_sk IS NULL
UNION ALL SELECT 'payment_method_sk', COUNT(*) FROM public.dw_fact_trips WHERE payment_method_sk IS NULL;
""")
display(pd.DataFrame(results))

# Sample data
print("\n3. Sample Fact Data with Dimension Lookups:")
results = rs_exec("""
SELECT f.trip_id, dr.rider_id, drt.route_id, drt.mode, dfc.fare_class, f.total_fare_cad
FROM public.dw_fact_trips f
LEFT JOIN public.dw_dim_rider dr ON f.rider_sk = dr.rider_sk
LEFT JOIN public.dw_dim_route drt ON f.route_sk = drt.route_sk
LEFT JOIN public.dw_dim_fare_class dfc ON f.fare_class_sk = dfc.fare_class_sk
LIMIT 5;
""")
display(pd.DataFrame(results))

Validation

1. Row Count Comparison:


Unnamed: 0,table_name,row_count
0,stg_trips_raw,2500
1,dw_fact_trips,8883



2. NULL Foreign Keys (should all be 0):


Unnamed: 0,column_name,null_count
0,rider_sk,0
1,route_sk,0
2,payment_method_sk,0
3,fare_class_sk,0



3. Sample Fact Data with Dimension Lookups:


Unnamed: 0,trip_id,rider_id,route_id,mode,fare_class,total_fare_cad
0,T100560,R41921,R071,seabus,senior,2.03
1,T100560,R41921,R071,wce,senior,2.03
2,T100560,R41921,R071,skytrain,senior,2.03
3,T100560,R41921,R071,bus,senior,2.03
4,T100105,R26163,R114,seabus,adult,3.2


---
## Summary

### Lesson 4 - Exercise 2 Complete

You successfully:

1. **Created dimension tables** (date, rider, route, station, fare_class, payment)
2. **Populated dimensions** from staging data
3. **Created the fact table** with surrogate key references
4. **Wrote INSERT...SELECT** to populate fact_trips
5. **Validated row counts** and referential integrity