# Lesson 2: Exercise 2 - Design the `dw_fact_trips` Table

## Goal

Define the core trips fact table (1 row per `trip_id`) with the correct **grain**, **foreign-key columns** (surrogate keys, or SKs, referencing the dimension tables), **measures**, and **Redshift physical design** (distribution style, sort key, and column encodings).

## What You Will Build

Create `dw_fact_trips` with:

- **Grain**: 1 row per `trip_id`
- **Foreign keys**: `rider_sk`, `route_sk`, `mode_sk`, `origin_station_sk`, `destination_station_sk`, `payment_method_sk`, `fare_class_sk`
- **Date keys**: Integer `YYYYMMDD` format (`board_date_key`, `alight_date_key`)
- **Measures**: fares, distance, transfers, zones
- **Flags**: `on_time_arrival`, `service_disruption`
- **Distribution**: `DISTKEY (rider_sk)` to collocate each rider's trips for rider-centric queries
- **Sort**: `SORTKEY (board_date_key)` so date-range filters can skip blocks outside the range
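The two date keys are plain integers derived from the board and alight timestamps. A minimal sketch of the conversion (the helper `to_date_key` is hypothetical, for illustration only) also shows why the table carries two keys: a trip that crosses midnight boards and alights on different dates.

```python
from datetime import date, datetime
from typing import Union


def to_date_key(value: Union[date, datetime]) -> int:
    """Convert a date or timestamp to an integer YYYYMMDD key (e.g. 2024-01-15 -> 20240115)."""
    # datetime is a subclass of date, so .year/.month/.day work for both
    return value.year * 10000 + value.month * 100 + value.day


# A trip that boards just before midnight and alights after it gets two different keys
board_ts = datetime(2024, 1, 15, 23, 50)
alight_ts = datetime(2024, 1, 16, 0, 20)
print(to_date_key(board_ts), to_date_key(alight_ts))  # 20240115 20240116
```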

---

## Imports and Dependencies

Run this cell first to import all required libraries.

In [None]:
# ========= Imports
import os
import time
from typing import Dict, Any, List

import pandas as pd
import boto3
from IPython.display import display  # renders DataFrames in the validation cells

print("All imports successful!")
print(f"   - pandas version: {pd.__version__}")

---
## Configuration

Configure your Redshift connection. The same pattern is used in the final project.

In [None]:
# ========= CONFIG (edit for your environment)
# Set your AWS credentials in the aws_config.py file
from aws_config import *  # This sets all AWS env vars

# ---- Read configuration from environment
AWS_ACCESS_KEY_ID           = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY       = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_SESSION_TOKEN           = os.getenv("AWS_SESSION_TOKEN")
AWS_REGION                  = os.getenv("AWS_REGION")
REDSHIFT_DATABASE           = os.getenv("REDSHIFT_DATABASE")
REDSHIFT_WORKGROUP          = os.getenv("REDSHIFT_WORKGROUP")
REDSHIFT_SECRET_ARN         = os.getenv("REDSHIFT_SECRET_ARN")            # Optional
REDSHIFT_CLUSTER_IDENTIFIER = os.getenv("REDSHIFT_CLUSTER_IDENTIFIER")    # For provisioned
REDSHIFT_DB_USER            = os.getenv("REDSHIFT_DB_USER")               # For provisioned

print("Configuration loaded!")
print(f"   - AWS Region: {AWS_REGION}")
print(f"   - Redshift: {REDSHIFT_DATABASE} (workgroup: {REDSHIFT_WORKGROUP})")
print()
if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY:
    print(f"   AWS credentials found (Key ID: {AWS_ACCESS_KEY_ID[:8]}...)")
    if AWS_SESSION_TOKEN:
        print("   AWS session token found (temporary credentials)")
else:
    print("   WARNING: AWS credentials NOT FOUND in environment variables!")
    print("      Unless boto3 finds credentials elsewhere (e.g., ~/.aws/credentials),")
    print("      Redshift operations will fail with 'NoCredentialsError'.")

---
## Redshift Functions

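These helper functions match the patterns used in the final project, so learning them here will prepare you for the capstone. `_rs_kwargs()` builds the connection arguments for either Redshift Serverless or a provisioned cluster, and `rs_exec()` submits a statement through the Redshift Data API, polls until it finishes, and returns query results as a list of dicts.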
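The least obvious part of `rs_exec()` is decoding results. `get_statement_result` returns each cell as a single-key dict such as `{"longValue": 32}` or `{"stringValue": "trip_id"}`, and a SQL NULL as `{"isNull": True}`. The standalone sketch that follows decodes one page of results and maps NULLs to `None`. It runs without AWS because the response is hand-built to mimic that shape; the function name `records_to_dicts` is hypothetical.

```python
from typing import Any, Dict, List, Optional


def records_to_dicts(response: Dict[str, Any]) -> List[Dict[str, Optional[Any]]]:
    """Decode one page of a Redshift Data API get_statement_result response."""
    cols = [c["name"] for c in response["ColumnMetadata"]]
    rows = []
    for rec in response["Records"]:
        values = []
        for cell in rec:
            # Each cell holds exactly one typed value; NULL is signalled by isNull
            if cell.get("isNull"):
                values.append(None)
            else:
                values.append(next(iter(cell.values())))
        rows.append(dict(zip(cols, values)))
    return rows


# Hand-built response with the same shape the Data API returns (illustrative data)
sample = {
    "ColumnMetadata": [{"name": "column_name"}, {"name": "character_maximum_length"}],
    "Records": [
        [{"stringValue": "trip_id"}, {"longValue": 32}],
        [{"stringValue": "rider_sk"}, {"isNull": True}],
    ],
}
print(records_to_dicts(sample))
# [{'column_name': 'trip_id', 'character_maximum_length': 32}, {'column_name': 'rider_sk', 'character_maximum_length': None}]
```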

In [None]:
# ========= Redshift Functions

session_boto = boto3.Session(region_name=AWS_REGION)
rsd = session_boto.client("redshift-data", region_name=AWS_REGION)


def _rs_kwargs() -> Dict[str, Any]:
    """
    Shared Redshift Data API connection args.
    
    Supports both:
    - Serverless: uses WorkgroupName (and optionally SecretArn)
    - Provisioned: uses ClusterIdentifier and DbUser
    """
    base = dict(Database=REDSHIFT_DATABASE)
    if REDSHIFT_WORKGROUP:
        base["WorkgroupName"] = REDSHIFT_WORKGROUP
        if REDSHIFT_SECRET_ARN:
            base["SecretArn"] = REDSHIFT_SECRET_ARN
    elif REDSHIFT_CLUSTER_IDENTIFIER and REDSHIFT_DB_USER:
        base["ClusterIdentifier"] = REDSHIFT_CLUSTER_IDENTIFIER
        base["DbUser"] = REDSHIFT_DB_USER
    else:
        raise RuntimeError("Configure Redshift serverless OR provisioned for Data API.")
    return base


def rs_exec(sql: str, params: List[Dict[str, Any]] = None, return_results=False, timeout_s=900):
    """
    Execute SQL on Redshift via the Data API.
    
    Args:
        sql: SQL statement to execute
        params: Optional list of parameter dicts for parameterized queries
        return_results: If True, fetch and return query results
        timeout_s: Maximum seconds to wait for query completion (default 15 min)
    
    Returns:
        List of dicts if return_results=True or query is SELECT, else None
    """
    sql = sql.strip()
    if not sql:
        return None
    
    # Build request kwargs
    kwargs = _rs_kwargs()
    kwargs["Sql"] = sql
    if params:
        kwargs["Parameters"] = params
    
    # Execute statement
    sid = rsd.execute_statement(**kwargs)["Id"]
    
    # Poll for completion
    start = time.time()
    while True:
        d = rsd.describe_statement(Id=sid)
        if d["Status"] in ("FINISHED", "FAILED", "ABORTED"):
            break
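        if time.time() - start > timeout_s:
            rsd.cancel_statement(Id=sid)  # stop the server-side query before giving up
            raise TimeoutError(f"Redshift statement {sid} timed out after {timeout_s}s")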
        time.sleep(0.5)
    
    # Check for errors
    if d["Status"] != "FINISHED":
        raise RuntimeError(f"Redshift SQL failed: {d.get('Error')}\n---\n{sql}")
    
    # Return results for SELECT/WITH queries or when explicitly requested
    if return_results or sql.lower().startswith(("select", "with")):
        out, next_token = [], None
        while True:
            args = dict(Id=sid)
            if next_token:
                args["NextToken"] = next_token
            r = rsd.get_statement_result(**args)
            cols = [c["name"] for c in r["ColumnMetadata"]]
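            for rec in r["Records"]:
                row = []
                for cell in rec:
                    # NULLs arrive as {"isNull": True}; map them to None instead of True
                    if cell.get("isNull"):
                        row.append(None)
                    else:
                        row.append(next(iter(cell.values())))
                out.append(dict(zip(cols, row)))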
            next_token = r.get("NextToken")
            if not next_token:
                break
        return out
    
    return None


print("Redshift functions defined: _rs_kwargs(), rs_exec()")

---
## Step 1: Design the dw_fact_trips Table DDL

The fact table captures **one row per trip** with:

| Component | Purpose |
|-----------|----------|
| **Surrogate key** (`trip_sk`) | Warehouse-generated ID |
| **Natural key** (`trip_id`) | Original ID from source system |
| **Dimension FKs** | Surrogate keys linking to dimension tables |
| **Date keys** | Integer `YYYYMMDD` format for efficient joins/filters |
| **Measures** | Numeric facts (fares, distance, transfers) |
| **Flags** | Boolean indicators for operational metrics |
| **DISTKEY/SORTKEY** | Physical design for query performance |

**TODO**: Write the DDL statement to create `public.dw_fact_trips`. Your DDL should include:

1. `DROP TABLE IF EXISTS` to allow re-running
2. `CREATE TABLE` with:
   - **Surrogate key**: `trip_sk` - BIGINT IDENTITY(1,1)
   - **Natural key**: `trip_id` - VARCHAR(32) ENCODE zstd
   - **Dimension FKs** (all BIGINT ENCODE zstd):
     - `rider_sk`, `route_sk`, `mode_sk`
     - `origin_station_sk`, `destination_station_sk`
   - **Date keys** (INTEGER ENCODE zstd):
     - `board_date_key`, `alight_date_key`
   - **Trip measures**:
     - `transfers` - INTEGER
     - `zones_charged` - INTEGER
     - `distance_km` - DECIMAL(10,2)
   - **Fare measures** (all DECIMAL with ENCODE zstd):
     - `base_fare_cad` - DECIMAL(12,2)
     - `discount_rate` - DECIMAL(5,3)
     - `discount_amount_cad` - DECIMAL(12,2)
     - `yvr_addfare_cad` - DECIMAL(12,2)
     - `total_fare_cad` - DECIMAL(12,2)
   - **Additional FKs**: `payment_method_sk`, `fare_class_sk` - BIGINT
   - **Flags**: `on_time_arrival`, `service_disruption` - BOOLEAN
3. `DISTKEY (rider_sk)` - collocate for rider-centric analytics
4. `SORTKEY (board_date_key)` - enable time-range query optimization
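Because the exercise spells out every column, a quick text check can catch omissions before you send the DDL to Redshift. This is a minimal sketch, and the helper `check_fact_ddl` is hypothetical, not part of the course code. It only searches the DDL text, so it won't catch type or encoding mistakes, a name that appears only in a comment still counts, and it recognizes only the table-level `DISTKEY (...)`/`SORTKEY (...)` form used in the list above.

```python
import re
from typing import List

# Column names taken from the TODO list above
REQUIRED_COLUMNS = [
    "trip_sk", "trip_id",
    "rider_sk", "route_sk", "mode_sk", "origin_station_sk", "destination_station_sk",
    "board_date_key", "alight_date_key",
    "transfers", "zones_charged", "distance_km",
    "base_fare_cad", "discount_rate", "discount_amount_cad", "yvr_addfare_cad", "total_fare_cad",
    "payment_method_sk", "fare_class_sk",
    "on_time_arrival", "service_disruption",
]


def check_fact_ddl(ddl: str) -> List[str]:
    """Return a list of problems found in the DDL text (empty list = looks complete)."""
    text = ddl.lower()
    # \b boundaries stop e.g. "rider_sk" from matching inside a longer identifier
    problems = [f"missing column: {c}" for c in REQUIRED_COLUMNS
                if not re.search(rf"\b{re.escape(c)}\b", text)]
    if not re.search(r"distkey\s*\(\s*rider_sk\s*\)", text):
        problems.append("missing DISTKEY (rider_sk)")
    if not re.search(r"sortkey\s*\(\s*board_date_key\s*\)", text):
        problems.append("missing SORTKEY (board_date_key)")
    if "drop table if exists" not in text:
        problems.append("missing DROP TABLE IF EXISTS")
    return problems


print(len(check_fact_ddl("-- TODO: Write your DDL here")))  # 24: every check fails on the placeholder
```

After filling in the next cell, `check_fact_ddl(DDL_FACT_TRIPS)` should return an empty list.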

In [None]:
DDL_FACT_TRIPS = """
-- TODO: Write your DDL here

"""

print("DDL for dw_fact_trips:")
print("=" * 60)
print(DDL_FACT_TRIPS)

---
## Step 2: Execute the DDL

Create the `dw_fact_trips` table in Redshift.

In [None]:
rs_exec(DDL_FACT_TRIPS)
print("Table public.dw_fact_trips created successfully!")

---
## Step 3: Validate the Table Structure

Verify the table was created with all expected columns.
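Scanning the output by eye is easy to get wrong with 21 columns. A small spot check compares a few columns against the `data_type` strings that PostgreSQL-style `information_schema` views report for these types (`bigint`, `integer`, `numeric`, `character varying`, `boolean`). This is a sketch: `type_mismatches` and `EXPECTED_TYPES` are hypothetical names, and the expected map covers only a subset of the columns.

```python
from typing import Any, Dict, List

# Spot-check a few columns against the types information_schema is expected to report
EXPECTED_TYPES = {
    "trip_sk": "bigint",
    "trip_id": "character varying",
    "rider_sk": "bigint",
    "board_date_key": "integer",
    "distance_km": "numeric",
    "total_fare_cad": "numeric",
    "on_time_arrival": "boolean",
}


def type_mismatches(rows: List[Dict[str, Any]], expected: Dict[str, str]) -> Dict[str, str]:
    """Map each expected column to a problem description if it is missing or mistyped."""
    actual = {r["column_name"]: r["data_type"] for r in rows}
    problems = {}
    for col, want in expected.items():
        got = actual.get(col)
        if got is None:
            problems[col] = "missing"
        elif got != want:
            problems[col] = f"expected {want}, got {got}"
    return problems


# Illustrative rows shaped like the validation query's output
sample_rows = [
    {"column_name": "trip_sk", "data_type": "bigint"},
    {"column_name": "trip_id", "data_type": "integer"},  # deliberately wrong type
]
print(type_mismatches(sample_rows, {"trip_sk": "bigint", "trip_id": "character varying", "rider_sk": "bigint"}))
# {'trip_id': 'expected character varying, got integer', 'rider_sk': 'missing'}
```

Once the validation query in the next cell has run, `type_mismatches(columns or [], EXPECTED_TYPES)` should return an empty dict.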

In [None]:
validation_sql = """
SELECT 
    column_name,
    data_type,
    character_maximum_length,
    numeric_precision,
    numeric_scale,
    is_nullable
FROM information_schema.columns
WHERE table_schema = 'public'
  AND table_name = 'dw_fact_trips'
ORDER BY ordinal_position;
"""

columns = rs_exec(validation_sql, return_results=True)

print("Table Structure for dw_fact_trips:")
print("-" * 60)
if columns:
    df = pd.DataFrame(columns)
    display(df)
    print(f"\nTotal columns: {len(columns)}")
else:
    print("No columns found. Check if table was created.")

---
## Step 4: Check Distribution and Sort Keys

Verify that the DISTKEY and SORTKEY were applied correctly.
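`pg_table_def` exposes `distkey` as a boolean and `sortkey` as the column's position in a compound sort key (1 for the leading column, 0 if the column is not in the sort key). Note that it only lists tables in schemas on your `search_path`, which includes `public` by default. A minimal sketch that turns the query result into a single pass/fail check (`physical_design_ok` is a hypothetical helper):

```python
from typing import Any, Dict, List


def physical_design_ok(props: List[Dict[str, Any]],
                       distkey: str = "rider_sk",
                       first_sortkey: str = "board_date_key") -> bool:
    """Check pg_table_def rows: exactly one DISTKEY column and the expected leading sort key."""
    dist_cols = [p["column"] for p in props if p.get("distkey")]
    # In pg_table_def, sortkey = 1 marks the first column of a compound sort key
    lead_sort = [p["column"] for p in props if p.get("sortkey") == 1]
    return dist_cols == [distkey] and lead_sort == [first_sortkey]


# Illustrative rows shaped like the properties query's output
sample_props = [
    {"column": "rider_sk", "distkey": True, "sortkey": 0},
    {"column": "board_date_key", "distkey": False, "sortkey": 1},
    {"column": "trip_id", "distkey": False, "sortkey": 0},
]
print(physical_design_ok(sample_props))  # True
```

After the next cell runs, `physical_design_ok(properties or [])` should return `True`.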

In [None]:
properties_sql = """
SELECT 
    "column",
    type,
    encoding,
    distkey,
    sortkey
FROM pg_table_def
WHERE schemaname = 'public'
  AND tablename = 'dw_fact_trips'
ORDER BY 
    CASE WHEN distkey THEN 0 ELSE 1 END,
    sortkey DESC,
    "column";
"""

properties = rs_exec(properties_sql, return_results=True)

print("Distribution and Sort Key Configuration:")
print("-" * 60)
if properties:
    df = pd.DataFrame(properties)
    display(df)
    
    # Summarize key columns
    distkey_col = [p["column"] for p in properties if p.get("distkey")]
    sortkey_col = [p["column"] for p in properties if p.get("sortkey") and p["sortkey"] > 0]
    
    print(f"\nDISTKEY column(s): {distkey_col}")
    print(f"SORTKEY column(s): {sortkey_col}")
else:
    print("Could not retrieve table properties.")

---

## Design Rationale

### Why This Design?

| Design Choice | Rationale |
|---------------|------------|
| **Grain: 1 row per trip** | Each row represents one rider's journey from origin to destination. This matches the business event we want to measure. |
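| **DISTKEY on `rider_sk`** | Places all of a rider's trips on the same slice. If `dw_dim_rider` is also distributed on `rider_sk`, fact and dimension rows are collocated, so joins for rider cohort analysis, lifetime value, and behavior patterns avoid redistributing data across nodes. |
| **SORTKEY on `board_date_key`** | Lets Redshift skip blocks outside the requested range in time-window queries (e.g., "trips last month"). The same benefit applies to any query or materialized view definition that filters on the board date. |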
| **Integer date keys** | `YYYYMMDD` format (e.g., 20240115) is compact, sortable, and joins efficiently with `dw_dim_date`. |
| **Separate fare columns** | Storing `base_fare_cad`, `discount_rate`, `discount_amount_cad`, `yvr_addfare_cad`, and `total_fare_cad` as separate columns enables detailed revenue analysis. |
| **Boolean flags** | `on_time_arrival` and `service_disruption` support operational KPIs without additional joins. |
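Because `YYYYMMDD` keys sort in calendar order, a whole month maps to one contiguous integer range, which is what lets a `BETWEEN` filter on the sort key skip blocks using Redshift's per-block min/max metadata (zone maps). A sketch of building such a filter follows; `month_key_range` is a hypothetical helper. The bounds are computed locally as integers, so inlining them is safe here; user-supplied values would be better passed through the `params` argument of `rs_exec()`.

```python
import calendar
from typing import Tuple


def month_key_range(year: int, month: int) -> Tuple[int, int]:
    """Return (first_day_key, last_day_key) as YYYYMMDD integers for one calendar month."""
    last_day = calendar.monthrange(year, month)[1]  # number of days in the month
    base = year * 10000 + month * 100
    return base + 1, base + last_day


start_key, end_key = month_key_range(2024, 2)
# Integer bounds compare directly against the sort key column
sql = f"""
SELECT COUNT(*) AS trips, SUM(total_fare_cad) AS revenue_cad
FROM public.dw_fact_trips
WHERE board_date_key BETWEEN {start_key} AND {end_key};
"""
print(start_key, end_key)  # 20240201 20240229
```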

### Grain

**1 row = 1 rider trip** (from board to alight)
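Redshift treats primary key and unique constraints as informational only and does not enforce them, so nothing in the DDL prevents two rows with the same `trip_id`. A minimal pandas sketch of a pre-load grain check (`grain_violations` and the sample data are illustrative); against the loaded table, `SELECT trip_id, COUNT(*) FROM public.dw_fact_trips GROUP BY trip_id HAVING COUNT(*) > 1;` performs the same test.

```python
import pandas as pd


def grain_violations(df: pd.DataFrame, key: str = "trip_id") -> pd.DataFrame:
    """Return every row whose key appears more than once (empty frame = grain holds)."""
    return df[df[key].duplicated(keep=False)].sort_values(key)


# Illustrative staging rows; T2 appears twice, which would break the 1-row-per-trip grain
staged = pd.DataFrame({"trip_id": ["T1", "T2", "T2", "T3"],
                       "total_fare_cad": [3.15, 4.45, 4.45, 3.15]})
print(grain_violations(staged)["trip_id"].tolist())  # ['T2', 'T2']
```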

### Dimensional Relationships

```
dw_fact_trips connects to:
  - dw_dim_rider (rider_sk)
  - dw_dim_route (route_sk)
  - dw_dim_mode (mode_sk)
  - dw_dim_station (origin_station_sk, destination_station_sk)
  - dw_dim_date (board_date_key, alight_date_key)
  - dw_dim_payment_method (payment_method_sk)
  - dw_dim_fare_class (fare_class_sk)
```
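Foreign key constraints are likewise not enforced by Redshift, so each relationship above is best verified with an anti-join that counts fact rows with no matching dimension row. The sketch below generates one such query per FK column and excludes NULL FKs from the count. The key column names for `dw_dim_station` (`station_sk`) and `dw_dim_date` (`date_key`) are assumptions, so adjust them to match your dimension DDL.

```python
from typing import Dict, Tuple

# FK column -> (dimension table, dimension key column).
# ASSUMPTION: station_sk and date_key are guesses at the dimension key names.
FK_MAP: Dict[str, Tuple[str, str]] = {
    "rider_sk": ("dw_dim_rider", "rider_sk"),
    "route_sk": ("dw_dim_route", "route_sk"),
    "mode_sk": ("dw_dim_mode", "mode_sk"),
    "origin_station_sk": ("dw_dim_station", "station_sk"),
    "destination_station_sk": ("dw_dim_station", "station_sk"),
    "board_date_key": ("dw_dim_date", "date_key"),
    "alight_date_key": ("dw_dim_date", "date_key"),
    "payment_method_sk": ("dw_dim_payment_method", "payment_method_sk"),
    "fare_class_sk": ("dw_dim_fare_class", "fare_class_sk"),
}


def orphan_check_sql(fk_col: str, schema: str = "public") -> str:
    """Build a query counting fact rows whose non-NULL FK has no matching dimension row."""
    dim_table, dim_key = FK_MAP[fk_col]
    return (
        f"SELECT '{fk_col}' AS fk_column, COUNT(*) AS orphans\n"
        f"FROM {schema}.dw_fact_trips f\n"
        f"LEFT JOIN {schema}.{dim_table} d ON f.{fk_col} = d.{dim_key}\n"
        f"WHERE f.{fk_col} IS NOT NULL AND d.{dim_key} IS NULL;"
    )


print(orphan_check_sql("origin_station_sk"))
```

Once the dimensions exist and are loaded, `rs_exec(orphan_check_sql(col))` runs each check and returns a one-row result per FK column.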