# Lesson 2: Exercise 3 Solution - Generate an ER Diagram from DDL

## Goal

Produce a **schema diagram** (Mermaid ER) from your DDL so you can include it in your **final project report**. This notebook parses `CREATE TABLE` statements and infers **fact-to-dimension relationships** by scanning `*_sk` columns.

## What You Will Build

A process that:

1. Reads SQL DDL (inline or from files)
2. Parses `CREATE TABLE` statements to collect table names and columns
3. Identifies columns ending with `_sk` as foreign key hints
4. Produces a Mermaid ER diagram with tables and relationships
5. Outputs the diagram to a `.mmd` file for use in documentation


### Acceptance Criteria

- Diagram shows `dw_dim_rider` and `dw_fact_trips` with an edge from trips to rider
- Works even if **no explicit foreign keys** are declared in SQL
- Output can be previewed at [Mermaid Live](https://mermaid.live) or in VS Code
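The acceptance criteria can also be checked in code. The sketch below is a minimal example. It assumes the diagram is available as a string in the layout this notebook's generator produces: entity blocks open with `name {`, and relationship lines are written dimension-first. `meets_acceptance()` is a hypothetical helper and is not part of the exercise solution.

```python
import re


def meets_acceptance(mermaid_text: str,
                     dim: str = "dw_dim_rider",
                     fact: str = "dw_fact_trips") -> bool:
    """Check that both entities are declared and a relationship line joins them."""
    def entity_declared(name: str) -> bool:
        # Entity blocks open with a line like:  dw_dim_rider {
        pattern = r"^\s*" + re.escape(name) + r"\s*\{"
        return re.search(pattern, mermaid_text, re.MULTILINE) is not None

    # Relationship lines look like:  dw_dim_rider ||--o{ dw_fact_trips : "rider_sk"
    edge = r"^\s*" + re.escape(dim) + r"\s+\S+\s+" + re.escape(fact) + r"\s*:"
    has_edge = re.search(edge, mermaid_text, re.MULTILINE) is not None
    return entity_declared(dim) and entity_declared(fact) and has_edge


sample = """erDiagram
    dw_dim_rider {
        BIGINT rider_sk
    }
    dw_fact_trips {
        BIGINT rider_sk
    }
    dw_dim_rider ||--o{ dw_fact_trips : "rider_sk"
"""
print(meets_acceptance(sample))
```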

---

## Lesson 2 Exercise 3: Generate an ER Diagram Solution

## Imports and Dependencies

Run this cell first to import all required libraries.

In [1]:
# ========= Imports
import os
import re
from pathlib import Path
from typing import Dict, List, Tuple, Any
from datetime import datetime

print("All imports successful!")

All imports successful!


---
## Configuration

Set up paths for input DDL and output diagram files.

In [2]:
# ========= CONFIG
BASE_DIR = os.getenv("PROJECT_BASE_DIR", ".")
OUTPUT_MERMAID = os.path.join(BASE_DIR, "schema.mmd")

print("Configuration loaded!")
print(f"   - BASE_DIR: {BASE_DIR}")
print(f"   - Output file: {OUTPUT_MERMAID}")

Configuration loaded!
   - BASE_DIR: .
   - Output file: ./schema.mmd


---
## Helper Functions for Reporting

This helper follows the same pattern used in the final project for generating Markdown reports.

In [3]:
# ========= Report Helper Functions (same pattern as project solution)

def md_table(rows: List[Dict[str, Any]]) -> str:
    """
    Convert a list of dicts to a Markdown table string.
    
    This is the same helper used in the final project for generating reports.
    
    Args:
        rows: List of dictionaries with consistent keys
    
    Returns:
        Markdown table as a string
    """
    if not rows:
        return "_no rows_\n"
    cols = list(rows[0].keys())
    lines = [
        "| " + " | ".join(cols) + " |",
        "| " + " | ".join(["---"] * len(cols)) + " |"
    ]
    for r in rows:
        lines.append("| " + " | ".join(str(r[c]) for c in cols) + " |")
    return "\n".join(lines) + "\n"


print("Helper function defined: md_table()")

Helper function defined: md_table()
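`md_table()` writes cell values verbatim. A value that contains `|`, such as a Mermaid relationship marker like `||--o{`, would therefore split a cell into extra columns. The sketch below shows one possible variant, a hypothetical `md_table_escaped()`, which escapes pipes as `\|`. GitHub-flavored Markdown treats `\|` as a literal pipe inside a table cell.

```python
from typing import Any, Dict, List


def md_table_escaped(rows: List[Dict[str, Any]]) -> str:
    """Like md_table(), but escapes '|' so cell values cannot split columns."""
    if not rows:
        return "_no rows_\n"
    cols = list(rows[0].keys())

    def cell(value: Any) -> str:
        # A bare pipe would start a new column in a Markdown table
        return str(value).replace("|", "\\|")

    lines = [
        "| " + " | ".join(cell(c) for c in cols) + " |",
        "| " + " | ".join(["---"] * len(cols)) + " |",
    ]
    for r in rows:
        lines.append("| " + " | ".join(cell(r[c]) for c in cols) + " |")
    return "\n".join(lines) + "\n"


print(md_table_escaped([{"relationship": "dw_dim_rider ||--o{ dw_fact_trips"}]))
```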


---
## DDL Parser Functions

Functions to extract table names and columns from `CREATE TABLE` statements.

In [4]:
# ========= DDL Parser

# Regex to match CREATE TABLE statements
CREATE_TABLE_RE = re.compile(
    r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?P<tbl_name>[A-Za-z0-9_\."]+)\s*\((?P<body>.*?)\)\s*;',
    re.IGNORECASE | re.DOTALL,
)

# Regex to match column definitions: a column name followed by a type name
# and an optional parenthesized parameter list, e.g. VARCHAR(32) or DECIMAL(10,6)
COLUMN_RE = re.compile(
    r'^\s*(?P<col>[A-Za-z0-9_"]+)\s+(?P<type>[A-Za-z_][A-Za-z0-9_]*(?:\s*\([^)]*\))?)',
    re.IGNORECASE
)

# Lines to skip (constraints, not columns)
SKIP_PREFIXES = ("PRIMARY", "UNIQUE", "FOREIGN", "CONSTRAINT", "DISTKEY", "SORTKEY", "--")


def parse_tables(sql_text: str) -> Dict[Tuple[str, str], List[Tuple[str, str]]]:
    """
    Parse CREATE TABLE statements from SQL text.
    
    Args:
        sql_text: String containing one or more CREATE TABLE statements
    
    Returns:
        Dict mapping (schema, table_name) to list of (column_name, data_type).
        Commas inside a type's parameter list are replaced with hyphens
        (DECIMAL(10,6) becomes DECIMAL(10-6)) because Mermaid attribute
        types cannot contain commas.
    """
    tables = {}
    
    for match in CREATE_TABLE_RE.finditer(sql_text):
        # Remove all quotes so "public"."t" becomes public.t
        full_name = match.group("tbl_name").replace('"', "")
        body = match.group("body")
        
        # Parse schema.table format
        if "." in full_name:
            schema, table_name = full_name.split(".", 1)
        else:
            schema = ""
            table_name = full_name
        
        # Extract columns
        columns = []
        for line in body.splitlines():
            line = line.strip()
            
            # Skip empty lines and constraints
            if not line or line.upper().startswith(SKIP_PREFIXES):
                continue
            
            col_match = COLUMN_RE.match(line)
            if col_match:
                col_name = col_match.group("col").strip('"')
                # Keep the full type (e.g. DECIMAL(10,6)), drop inner spaces, and
                # swap commas for hyphens so the type is a single Mermaid token
                col_type = re.sub(r"\s+", "", col_match.group("type")).replace(",", "-")
                columns.append((col_name, col_type))
        
        tables[(schema, table_name)] = columns
    
    return tables


print("Parser function defined: parse_tables()")

Parser function defined: parse_tables()
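A column pattern must treat a parenthesized parameter list as part of the type. A lazy pattern that ends the type at the first comma cuts `DECIMAL(10,6)` down to `DECIMAL(10`. The standalone comparison below runs both patterns on lines taken from the sample DDL. Both pattern names are hypothetical.

```python
import re

# Lazy type pattern: the type ends at the first comma or the end of the line
LAZY_TYPE_RE = re.compile(r'^\s*(?P<col>\w+)\s+(?P<type>[A-Za-z0-9_\(\), ]+?)(?:,|$)')
# Type name plus an optional parenthesized parameter list
PAREN_TYPE_RE = re.compile(r'^\s*(?P<col>\w+)\s+(?P<type>[A-Za-z_]\w*(?:\s*\([^)]*\))?)')

line = "latitude    DECIMAL(10,6),"
print(LAZY_TYPE_RE.match(line).group("type"))   # DECIMAL(10
print(PAREN_TYPE_RE.match(line).group("type"))  # DECIMAL(10,6)

# IDENTITY(1,1) is not directly after the type name, so only BIGINT is captured
print(PAREN_TYPE_RE.match("rider_sk BIGINT IDENTITY(1,1),").group("type"))  # BIGINT
```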


---
## Relationship Inference

Infer fact-to-dimension relationships by matching `*_sk` columns in fact tables to dimension table names.

In [5]:
# ========= Relationship Inference

def infer_relationships(
    tables: Dict[Tuple[str, str], List[Tuple[str, str]]],
    dim_prefix: str = "dw_dim_",
    fact_prefix: str = "dw_fact_"
) -> List[Tuple[str, str, str]]:
    """
    Infer relationships between fact and dimension tables.
    
    Logic: If a fact table has a column like `rider_sk`, look for `dw_dim_rider`.
    
    Args:
        tables: Output from parse_tables()
        dim_prefix: Prefix for dimension tables (default: "dw_dim_")
        fact_prefix: Prefix for fact tables (default: "dw_fact_")
    
    Returns:
        List of (fact_table, dim_table, column_name) tuples
    """
    # Find all dimension tables
    dims = {name for (_, name) in tables if name.startswith(dim_prefix)}
    
    relationships = []
    
    for (schema, table_name), columns in tables.items():
        # Only process fact tables
        if not table_name.startswith(fact_prefix):
            continue
        
        for col_name, _ in columns:
            # Look for *_sk columns
            if col_name.endswith("_sk"):
                # Derive dimension name: rider_sk -> dw_dim_rider
                dim_candidate = f"{dim_prefix}{col_name[:-3]}"  # Strip "_sk"
                
                if dim_candidate in dims:
                    relationships.append((table_name, dim_candidate, col_name))
    
    return relationships


print("Inference function defined: infer_relationships()")

Inference function defined: infer_relationships()
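`infer_relationships()` links a column only when `dw_dim_` plus the column stem names an existing table. In the sample schema, the role-playing keys `origin_station_sk` and `destination_station_sk` stay unlinked, and so do the `*_date_key` columns. The sketch below shows a hypothetical extension, `match_dimension()`. It tries the full stem first, then drops leading role words one at a time, and it also accepts a `_key` suffix. This is one possible heuristic and not part of the exercise solution. It could produce false matches in schemas where a role word is itself a dimension name.

```python
from typing import Optional, Set


def match_dimension(col_name: str, dims: Set[str],
                    dim_prefix: str = "dw_dim_") -> Optional[str]:
    """Map a key column to a dimension table, trying progressively shorter stems.

    e.g. origin_station_sk: tries dw_dim_origin_station, then dw_dim_station
    """
    for suffix in ("_sk", "_key"):
        if col_name.endswith(suffix):
            parts = col_name[: -len(suffix)].split("_")
            # Full stem first, then drop leading role words one at a time
            for i in range(len(parts)):
                candidate = dim_prefix + "_".join(parts[i:])
                if candidate in dims:
                    return candidate
    return None


dims = {"dw_dim_station", "dw_dim_date", "dw_dim_rider"}
for col in ["rider_sk", "origin_station_sk", "board_date_key", "trip_sk"]:
    print(col, "->", match_dimension(col, dims))
```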


---
## Mermaid Generator

Generate Mermaid ER diagram syntax from parsed tables and relationships.

In [6]:
# ========= Mermaid Generator

def generate_mermaid(
    tables: Dict[Tuple[str, str], List[Tuple[str, str]]],
    relationships: List[Tuple[str, str, str]]
) -> str:
    """
    Generate Mermaid ER diagram syntax.
    
    Args:
        tables: Output from parse_tables()
        relationships: Output from infer_relationships()
    
    Returns:
        String containing the Mermaid diagram
    """
    lines = ["erDiagram"]
    
    # Generate table blocks
    for (schema, table_name), columns in sorted(tables.items()):
        lines.append(f"    {table_name} {{")
        for col_name, col_type in columns:
            # Mermaid format: TYPE column_name
            mermaid_type = col_type.upper()
            lines.append(f"        {mermaid_type} {col_name}")
        lines.append("    }")
        lines.append("")  # Blank line between tables
    
    # Generate relationships (many fact rows to one dim row)
    for fact_table, dim_table, col_name in relationships:
        # ||--o{ : exactly one dim row relates to zero or more fact rows
        lines.append(f"    {dim_table} ||--o{{ {fact_table} : \"{col_name}\"")
    
    return "\n".join(lines)


print("Generator function defined: generate_mermaid()")

Generator function defined: generate_mermaid()
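In Mermaid ER diagrams, an attribute can carry a key marker (`PK`, `FK`, or `UK`) after the column name. `generate_mermaid()` does not emit these markers. The sketch below shows how an entity block could include them, using a hypothetical `entity_block()` helper. The key sets are passed in by hand because `parse_tables()` skips `PRIMARY KEY` lines. The FK set could be built from the third element of each `infer_relationships()` tuple.

```python
from typing import List, Set, Tuple


def entity_block(table_name: str,
                 columns: List[Tuple[str, str]],
                 pk_cols: Set[str],
                 fk_cols: Set[str]) -> List[str]:
    """Render one Mermaid entity, tagging key columns with PK or FK."""
    lines = [f"    {table_name} {{"]
    for col_name, col_type in columns:
        # PK takes precedence if a column appears in both sets
        key = " PK" if col_name in pk_cols else (" FK" if col_name in fk_cols else "")
        lines.append(f"        {col_type.upper()} {col_name}{key}")
    lines.append("    }")
    return lines


block = entity_block(
    "dw_fact_trips",
    [("trip_sk", "BIGINT"), ("rider_sk", "BIGINT"), ("transfers", "INTEGER")],
    pk_cols={"trip_sk"},
    fk_cols={"rider_sk"},
)
print("\n".join(block))
```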


---
## Sample DDL

Use the DDL from Exercises 1 and 2, plus additional dimension tables to create a complete schema.

Note: We use the `public` schema with `dw_` prefixes to match the pattern from Exercises 1 and 2.

In [7]:
# ========= Sample DDL (from Exercises 1 and 2, plus supporting dims)

SAMPLE_DDL = """
-- =============================================================
-- DIMENSION TABLES
-- =============================================================

CREATE TABLE public.dw_dim_rider (
    rider_sk        BIGINT IDENTITY(1,1),
    rider_id        VARCHAR(32),
    rider_segment   VARCHAR(16),
    effective_from  TIMESTAMP,
    effective_to    TIMESTAMP,
    is_current      BOOLEAN,
    PRIMARY KEY (rider_sk)
);

CREATE TABLE public.dw_dim_route (
    route_sk   BIGINT IDENTITY(1,1),
    route_id   VARCHAR(32),
    PRIMARY KEY (route_sk)
);

CREATE TABLE public.dw_dim_mode (
    mode_sk   BIGINT IDENTITY(1,1),
    mode      VARCHAR(32),
    PRIMARY KEY (mode_sk)
);

CREATE TABLE public.dw_dim_station (
    station_sk  BIGINT IDENTITY(1,1),
    station_id  VARCHAR(32),
    city        VARCHAR(64),
    province    VARCHAR(32),
    latitude    DECIMAL(10,6),
    longitude   DECIMAL(10,6),
    PRIMARY KEY (station_sk)
);

CREATE TABLE public.dw_dim_date (
    date_key     INTEGER,
    date_actual  DATE,
    year         SMALLINT,
    quarter      SMALLINT,
    month        SMALLINT,
    day          SMALLINT,
    day_of_week  SMALLINT,
    is_weekend   BOOLEAN,
    PRIMARY KEY (date_key)
);

CREATE TABLE public.dw_dim_payment_method (
    payment_method_sk  BIGINT IDENTITY(1,1),
    payment_method     VARCHAR(32),
    PRIMARY KEY (payment_method_sk)
);

CREATE TABLE public.dw_dim_fare_class (
    fare_class_sk  BIGINT IDENTITY(1,1),
    fare_class     VARCHAR(32),
    PRIMARY KEY (fare_class_sk)
);

-- =============================================================
-- FACT TABLES
-- =============================================================

CREATE TABLE public.dw_fact_trips (
    trip_sk                 BIGINT IDENTITY(1,1),
    trip_id                 VARCHAR(32),
    rider_sk                BIGINT,
    route_sk                BIGINT,
    mode_sk                 BIGINT,
    origin_station_sk       BIGINT,
    destination_station_sk  BIGINT,
    board_date_key          INTEGER,
    alight_date_key         INTEGER,
    transfers               INTEGER,
    zones_charged           INTEGER,
    distance_km             DECIMAL(10,2),
    total_fare_cad          DECIMAL(12,2),
    payment_method_sk       BIGINT,
    fare_class_sk           BIGINT,
    on_time_arrival         BOOLEAN,
    service_disruption      BOOLEAN
);
"""

print(f"Sample DDL defined ({len(SAMPLE_DDL)} characters)")
print("Contains: dw_dim_rider, dw_dim_route, dw_dim_mode, dw_dim_station, dw_dim_date,")
print("          dw_dim_payment_method, dw_dim_fare_class, dw_fact_trips")

Sample DDL defined (2356 characters)
Contains: dw_dim_rider, dw_dim_route, dw_dim_mode, dw_dim_station, dw_dim_date,
          dw_dim_payment_method, dw_dim_fare_class, dw_fact_trips
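The build steps mention reading DDL from files, but this notebook only uses the inline `SAMPLE_DDL` string. The sketch below shows one way to load files instead. It assumes each table's DDL lives in a `.sql` file inside a single directory. Both that directory layout and the `load_ddl()` helper are hypothetical.

```python
import tempfile
from pathlib import Path


def load_ddl(ddl_dir: str, pattern: str = "*.sql") -> str:
    """Concatenate every matching file in ddl_dir, in sorted filename order."""
    paths = sorted(Path(ddl_dir).glob(pattern))
    # Join with blank lines so statements from different files never run together
    return "\n\n".join(p.read_text(encoding="utf-8") for p in paths)


# Demo with a throwaway directory standing in for a real DDL folder
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "01_dims.sql").write_text(
        "CREATE TABLE public.dw_dim_rider (rider_sk BIGINT);", encoding="utf-8")
    Path(tmp, "02_facts.sql").write_text(
        "CREATE TABLE public.dw_fact_trips (rider_sk BIGINT);", encoding="utf-8")
    ddl_text = load_ddl(tmp)

print(ddl_text)
```

The returned string can be passed to `parse_tables()` in the same way as `SAMPLE_DDL`.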


---
## Step 1: Parse the DDL

Extract tables and columns from the sample DDL.

In [8]:
tables = parse_tables(SAMPLE_DDL)

print(f"Parsed {len(tables)} tables:")
print("-" * 40)
for (schema, name), columns in sorted(tables.items()):
    print(f"  {schema}.{name}: {len(columns)} columns")

Parsed 8 tables:
----------------------------------------
  public.dw_dim_date: 8 columns
  public.dw_dim_fare_class: 2 columns
  public.dw_dim_mode: 2 columns
  public.dw_dim_payment_method: 2 columns
  public.dw_dim_rider: 6 columns
  public.dw_dim_route: 2 columns
  public.dw_dim_station: 6 columns
  public.dw_fact_trips: 17 columns


---
## Step 2: Infer Relationships

Find fact-to-dimension relationships based on `*_sk` column naming.

In [9]:
relationships = infer_relationships(tables)

print(f"Inferred {len(relationships)} relationships:")
print("-" * 40)

# Display as a table using md_table pattern
rel_data = [
    {"dimension": dim, "fact": fact, "fk_column": col}
    for fact, dim, col in relationships
]
print(md_table(rel_data))

Inferred 5 relationships:
----------------------------------------
| dimension | fact | fk_column |
| --- | --- | --- |
| dw_dim_rider | dw_fact_trips | rider_sk |
| dw_dim_route | dw_fact_trips | route_sk |
| dw_dim_mode | dw_fact_trips | mode_sk |
| dw_dim_payment_method | dw_fact_trips | payment_method_sk |
| dw_dim_fare_class | dw_fact_trips | fare_class_sk |
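Step 2 finds 5 relationships, but `dw_fact_trips` has more `*_sk` columns than that. The sketch below shows a quick way to list the columns that were left unlinked. `unmatched_keys()` is a hypothetical helper that takes the same `tables` and `relationships` structures used above.

```python
from typing import Dict, List, Tuple


def unmatched_keys(tables: Dict[Tuple[str, str], List[Tuple[str, str]]],
                   relationships: List[Tuple[str, str, str]],
                   fact_prefix: str = "dw_fact_") -> List[Tuple[str, str]]:
    """List (fact_table, column) pairs for *_sk columns that got no relationship."""
    linked = {(fact, col) for fact, _dim, col in relationships}
    missing = []
    for (_schema, table_name), columns in tables.items():
        if not table_name.startswith(fact_prefix):
            continue
        for col_name, _type in columns:
            if col_name.endswith("_sk") and (table_name, col_name) not in linked:
                missing.append((table_name, col_name))
    return missing


# Small stand-in for the parsed tables and inferred relationships
demo_tables = {
    ("public", "dw_fact_trips"): [("trip_sk", "BIGINT"), ("rider_sk", "BIGINT"),
                                  ("origin_station_sk", "BIGINT")],
}
demo_rels = [("dw_fact_trips", "dw_dim_rider", "rider_sk")]
print(unmatched_keys(demo_tables, demo_rels))
```

When called on this notebook's `tables` and `relationships`, it should report `trip_sk` (the fact table's own surrogate key, which is expected), `origin_station_sk`, and `destination_station_sk`.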



---
## Step 3: Generate the Mermaid Diagram

Create the Mermaid ER diagram syntax.

In [10]:
mermaid_diagram = generate_mermaid(tables, relationships)

print("Generated Mermaid ER Diagram:")
print("=" * 60)
print(mermaid_diagram)

Generated Mermaid ER Diagram:
erDiagram
    dw_dim_date {
        INTEGER date_key
        DATE date_actual
        SMALLINT year
        SMALLINT quarter
        SMALLINT month
        SMALLINT day
        SMALLINT day_of_week
        BOOLEAN is_weekend
    }

    dw_dim_fare_class {
        BIGINT fare_class_sk
        VARCHAR(32) fare_class
    }

    dw_dim_mode {
        BIGINT mode_sk
        VARCHAR(32) mode
    }

    dw_dim_payment_method {
        BIGINT payment_method_sk
        VARCHAR(32) payment_method
    }

    dw_dim_rider {
        BIGINT rider_sk
        VARCHAR(32) rider_id
        VARCHAR(16) rider_segment
        TIMESTAMP effective_from
        TIMESTAMP effective_to
        BOOLEAN is_current
    }

    dw_dim_route {
        BIGINT route_sk
        VARCHAR(32) route_id
    }

    dw_dim_station {
        BIGINT station_sk
        VARCHAR(32) station_id
        VARCHAR(64) city
        VARCHAR(32) province
        DECIMAL(10-6) latitude
        DECIMAL(10-6) longitude

---
## Step 4: Save to File

Write the diagram to a `.mmd` file for use in documentation.

In [11]:
output_path = Path(OUTPUT_MERMAID)
output_path.write_text(mermaid_diagram, encoding="utf-8")

print(f"Diagram saved to: {output_path.absolute()}")
print("\nNext steps:")
print("  1. Preview at https://mermaid.live")
print("  2. Or use VS Code with Mermaid extension")
print("  3. Or render to PNG: mmdc -i schema.mmd -o schema.png")

Diagram saved to: /workspace/schema.mmd

Next steps:
  1. Preview at https://mermaid.live
  2. Or use VS Code with Mermaid extension
  3. Or render to PNG: mmdc -i schema.mmd -o schema.png
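The `mmdc` command printed above can also be run from Python, so the PNG is regenerated each time the notebook runs. The sketch below assumes mermaid-cli (`@mermaid-js/mermaid-cli`) is installed and on `PATH`. `mmdc_command()` and `render_diagram()` are hypothetical helpers, and they use only the `-i` and `-o` flags shown in the output above.

```python
import shutil
import subprocess
from pathlib import Path
from typing import List


def mmdc_command(input_path: str, output_path: str) -> List[str]:
    """Build the mermaid-cli command shown in the Step 4 output."""
    return ["mmdc", "-i", input_path, "-o", output_path]


def render_diagram(input_path: str, output_path: str) -> bool:
    """Render input_path with mmdc if it is installed; return True on success."""
    if shutil.which("mmdc") is None:
        print("mmdc not found; install @mermaid-js/mermaid-cli or use mermaid.live")
        return False
    result = subprocess.run(mmdc_command(input_path, output_path), check=False)
    return result.returncode == 0 and Path(output_path).exists()


print(mmdc_command("schema.mmd", "schema.png"))
```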


---
## Step 5: Display in Notebook

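Display the Mermaid code as a fenced `mermaid` block so it is easy to copy. Front ends that support Mermaid in Markdown may render it as a diagram instead.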

In [12]:
# Display the raw Mermaid code in a code block for easy copying
from IPython.display import Markdown, display

display(Markdown(f"""```mermaid
{mermaid_diagram}
```"""))

```mermaid
erDiagram
    dw_dim_date {
        INTEGER date_key
        DATE date_actual
        SMALLINT year
        SMALLINT quarter
        SMALLINT month
        SMALLINT day
        SMALLINT day_of_week
        BOOLEAN is_weekend
    }

    dw_dim_fare_class {
        BIGINT fare_class_sk
        VARCHAR(32) fare_class
    }

    dw_dim_mode {
        BIGINT mode_sk
        VARCHAR(32) mode
    }

    dw_dim_payment_method {
        BIGINT payment_method_sk
        VARCHAR(32) payment_method
    }

    dw_dim_rider {
        BIGINT rider_sk
        VARCHAR(32) rider_id
        VARCHAR(16) rider_segment
        TIMESTAMP effective_from
        TIMESTAMP effective_to
        BOOLEAN is_current
    }

    dw_dim_route {
        BIGINT route_sk
        VARCHAR(32) route_id
    }

    dw_dim_station {
        BIGINT station_sk
        VARCHAR(32) station_id
        VARCHAR(64) city
        VARCHAR(32) province
        DECIMAL(10-6) latitude
        DECIMAL(10-6) longitude
    }

    dw_fact_trips {
        BIGINT trip_sk
        VARCHAR(32) trip_id
        BIGINT rider_sk
        BIGINT route_sk
        BIGINT mode_sk
        BIGINT origin_station_sk
        BIGINT destination_station_sk
        INTEGER board_date_key
        INTEGER alight_date_key
        INTEGER transfers
        INTEGER zones_charged
        DECIMAL(10-2) distance_km
        DECIMAL(12-2) total_fare_cad
        BIGINT payment_method_sk
        BIGINT fare_class_sk
        BOOLEAN on_time_arrival
        BOOLEAN service_disruption
    }

    dw_dim_rider ||--o{ dw_fact_trips : "rider_sk"
    dw_dim_route ||--o{ dw_fact_trips : "route_sk"
    dw_dim_mode ||--o{ dw_fact_trips : "mode_sk"
    dw_dim_payment_method ||--o{ dw_fact_trips : "payment_method_sk"
    dw_dim_fare_class ||--o{ dw_fact_trips : "fare_class_sk"
```

---
## Bonus: Generate a Report Summary

This shows how you might include the diagram in a markdown report, similar to the final project's reporting section.

In [13]:
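# Generate a report like the project solution does
from datetime import timezone

report_path = Path(BASE_DIR) / "schema_report.md"

with open(report_path, "w", encoding="utf-8") as f:
    f.write("# Schema Design Report\n\n")
    # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC timestamp
    generated = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    f.write(f"_Generated: {generated}_\n\n")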
    
    f.write("## Schema Diagram (Mermaid)\n\n")
    f.write("```mermaid\n" + mermaid_diagram + "\n```\n\n")
    
    f.write("## Tables Summary\n\n")
    table_summary = [
        {"schema": schema, "table": name, "columns": len(cols)}
        for (schema, name), cols in sorted(tables.items())
    ]
    f.write(md_table(table_summary))
    
    f.write("\n## Relationships\n\n")
    f.write(md_table(rel_data))

print(f"Report saved to: {report_path.absolute()}")

Report saved to: /workspace/schema_report.md
