# Lesson 2: Exercise 3 - Generate an ER Diagram from DDL

## Goal

Produce a **schema diagram** (Mermaid ER) from your DDL so you can include it in your **final project report**. This notebook parses `CREATE TABLE` statements and infers **fact-to-dimension relationships** by scanning `*_sk` columns.

## What You Will Build

A process that:

1. Reads SQL DDL (inline or from files)
2. Parses `CREATE TABLE` statements to collect table names and columns
3. Identifies columns ending with `_sk` as foreign key hints
4. Produces a Mermaid ER diagram with tables and relationships
5. Outputs the diagram to a `.mmd` file for use in documentation

---

## Imports and Dependencies

Run this cell first to import all required libraries.

In [None]:
# ========= Imports
import os
import re
from pathlib import Path
from typing import Dict, List, Tuple, Any
from datetime import datetime

print("All imports successful!")

---
## Configuration

Set up paths for input DDL and output diagram files.

In [None]:
# ========= CONFIG
BASE_DIR = os.getenv("PROJECT_BASE_DIR", ".")
OUTPUT_MERMAID = os.path.join(BASE_DIR, "schema.mmd")

print("Configuration loaded!")
print(f"   - BASE_DIR: {BASE_DIR}")
print(f"   - Output file: {OUTPUT_MERMAID}")

---
## Helper Functions for Reporting

The `md_table()` helper below mirrors the one used in the final project to generate Markdown reports.

In [None]:
# ========= Report Helper Functions (same pattern as project solution)

def md_table(rows: List[Dict[str, Any]]) -> str:
    """
    Convert a list of dicts to a Markdown table string.
    
    This is the same helper used in the final project for generating reports.
    
    Args:
        rows: List of dictionaries with consistent keys
    
    Returns:
        Markdown table as a string
    """
    if not rows:
        return "_no rows_\n"
    cols = list(rows[0].keys())
    lines = [
        "| " + " | ".join(cols) + " |",
        "| " + " | ".join(["---"] * len(cols)) + " |"
    ]
    for r in rows:
        lines.append("| " + " | ".join(str(r[c]) for c in cols) + " |")
    return "\n".join(lines) + "\n"


print("Helper function defined: md_table()")
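
To see what the helper returns, here is a small self-contained sketch. It repeats a compact copy of `md_table()` so the snippet runs on its own; the `sample_rows` values are made up for illustration.

```python
from typing import Any, Dict, List


def md_table(rows: List[Dict[str, Any]]) -> str:
    """Compact copy of the helper above, repeated so this snippet runs standalone."""
    if not rows:
        return "_no rows_\n"
    cols = list(rows[0].keys())
    lines = [
        "| " + " | ".join(cols) + " |",
        "| " + " | ".join(["---"] * len(cols)) + " |",
    ]
    for r in rows:
        lines.append("| " + " | ".join(str(r[c]) for c in cols) + " |")
    return "\n".join(lines) + "\n"


# Hypothetical rows, for illustration only
sample_rows = [
    {"dimension": "dw_dim_rider", "fact": "dw_fact_trips", "fk_column": "rider_sk"},
    {"dimension": "dw_dim_route", "fact": "dw_fact_trips", "fk_column": "route_sk"},
]

table_text = md_table(sample_rows)
print(table_text)
print(md_table([]))  # An empty list yields the "_no rows_" placeholder
```

The column order comes from the keys of the first row, so every row should use the same keys.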

---
## DDL Parser Functions

Functions to extract table names and columns from `CREATE TABLE` statements.

In [None]:
# ========= DDL Parser

# Regex to match CREATE TABLE statements
CREATE_TABLE_RE = re.compile(
    r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?P<tbl_name>[A-Za-z0-9_\."]+)\s*\((?P<body>.*?)\)\s*;',
    re.IGNORECASE | re.DOTALL,
)

# Regex to match column definitions
COLUMN_RE = re.compile(
    r'^\s*(?P<col>[A-Za-z0-9_"]+)\s+(?P<type>[A-Za-z0-9_\(\), ]+?)(?:,|$)',
    re.IGNORECASE
)

# Lines to skip (constraints, not columns)
SKIP_PREFIXES = ("PRIMARY", "UNIQUE", "FOREIGN", "CONSTRAINT", "DISTKEY", "SORTKEY", "--")


def parse_tables(sql_text: str) -> Dict[Tuple[str, str], List[Tuple[str, str]]]:
    """
    Parse CREATE TABLE statements from SQL text.
    
    Args:
        sql_text: String containing one or more CREATE TABLE statements
    
    Returns:
        Dict mapping (schema, table_name) to list of (column_name, data_type)
    """
    tables = {}
    
    for match in CREATE_TABLE_RE.finditer(sql_text):
        # Remove all double quotes so "schema"."table" splits cleanly on the dot
        full_name = match.group("tbl_name").replace('"', "")
        body = match.group("body")
        
        # Parse schema.table format
        if "." in full_name:
            schema, table_name = full_name.split(".", 1)
        else:
            schema = ""
            table_name = full_name
        
        # Extract columns
        columns = []
        for line in body.splitlines():
            line = line.strip()
            
            # Skip empty lines and constraints
            if not line or line.upper().startswith(SKIP_PREFIXES):
                continue
            
            col_match = COLUMN_RE.match(line)
            if col_match:
                col_name = col_match.group("col").strip('"')
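                # Keep only the base type name, e.g. DECIMAL(10,2) -> DECIMAL
                # (the lazy regex stops at the first comma, so "DECIMAL(10" must be trimmed)
                col_type = re.split(r"[\s(]", col_match.group("type").strip(), maxsplit=1)[0]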
                columns.append((col_name, col_type))
        
        tables[(schema, table_name)] = columns
    
    return tables


print("Parser function defined: parse_tables()")
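
As a quick illustration of the two-stage approach (match the whole statement, then read the body line by line), the sketch below runs on a one-table DDL string. `MINI_DDL`, `STATEMENT_RE`, and `list_columns` are hypothetical names for this example, and the pattern is a simplified stand-in for the regexes above, not a general SQL parser.

```python
import re
from typing import List, Tuple

# Hypothetical one-table DDL used only for this illustration
MINI_DDL = """
CREATE TABLE public.dw_dim_route (
    route_sk   BIGINT IDENTITY(1,1),
    route_id   VARCHAR(32),
    PRIMARY KEY (route_sk)
);
"""

# Simplified statement pattern: the lazy body stops at the first ")" that is
# followed by ";", so inner parentheses such as IDENTITY(1,1) stay in the body.
STATEMENT_RE = re.compile(
    r"CREATE\s+TABLE\s+(?P<name>[\w.]+)\s*\((?P<body>.*?)\)\s*;",
    re.IGNORECASE | re.DOTALL,
)

CONSTRAINT_PREFIXES = ("PRIMARY", "FOREIGN", "UNIQUE", "CONSTRAINT", "--")


def list_columns(body: str) -> List[Tuple[str, str]]:
    """Return (column_name, base_type) pairs, skipping constraint lines."""
    columns = []
    for line in body.splitlines():
        line = line.strip().rstrip(",")
        if not line or line.upper().startswith(CONSTRAINT_PREFIXES):
            continue
        name, type_part = line.split(None, 1)
        # Keep only the base type: "VARCHAR(32)" -> "VARCHAR"
        base_type = re.split(r"[\s(]", type_part, maxsplit=1)[0]
        columns.append((name, base_type))
    return columns


mini_match = STATEMENT_RE.search(MINI_DDL)
table_name = mini_match.group("name")
mini_columns = list_columns(mini_match.group("body"))
print(table_name, mini_columns)
```

Regex-based parsing like this is convenient for well-formatted DDL with one column per line. It is likely to miss columns in statements written on a single line.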

---
## Relationship Inference

Infer fact-to-dimension relationships by matching `*_sk` columns in fact tables to dimension table names.

**TODO**: Complete the `infer_relationships` function. The logic should:

1. Find all dimension tables (tables starting with `dim_prefix`)
2. Loop through all fact tables (tables starting with `fact_prefix`)
3. For each column ending with `_sk` in a fact table:
   - Derive the dimension name by stripping the `_sk` suffix and prepending `dim_prefix`
   - Example: `rider_sk` -> `dw_dim_rider`
   - If that dimension exists, add a relationship tuple
4. Return a list of `(fact_table, dim_table, column_name)` tuples
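
The naming rule in step 3 can be tried in isolation before wiring it into the function. The sketch below only illustrates that rule and is not the full solution: `candidate_dim` is a hypothetical helper, and `dim_names` and `fact_columns` are a hand-picked subset of the sample schema used later in this notebook.

```python
from typing import Optional


def candidate_dim(col_name: str, dim_prefix: str = "dw_dim_") -> Optional[str]:
    """Return the dimension table a *_sk column would point to, or None."""
    suffix = "_sk"
    if not col_name.endswith(suffix):
        return None
    # rider_sk -> rider -> dw_dim_rider
    return dim_prefix + col_name[: -len(suffix)]


# Hand-picked subset of the sample schema, for illustration
dim_names = {"dw_dim_rider", "dw_dim_station", "dw_dim_date"}
fact_columns = ["rider_sk", "origin_station_sk", "board_date_key", "trip_id"]

for col in fact_columns:
    target = candidate_dim(col)
    status = "match" if target in dim_names else "no match"
    print(f"{col:20} -> {target}: {status}")
```

Only `rider_sk` resolves to a table in `dim_names`; the other rows show the cases the rule rejects.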

In [None]:
# ========= Relationship Inference

def infer_relationships(
    tables: Dict[Tuple[str, str], List[Tuple[str, str]]],
    dim_prefix: str = "dw_dim_",
    fact_prefix: str = "dw_fact_"
) -> List[Tuple[str, str, str]]:
    """
    Infer relationships between fact and dimension tables.
    
    Logic: If a fact table has a column like `rider_sk`, look for `dw_dim_rider`.
    
    Args:
        tables: Output from parse_tables()
        dim_prefix: Prefix for dimension tables (default: "dw_dim_")
        fact_prefix: Prefix for fact tables (default: "dw_fact_")
    
    Returns:
        List of (fact_table, dim_table, column_name) tuples
    """
    # TODO: Find all dimension tables (hint: use a set comprehension)
    dims = set()
    
    relationships = []
    
    # TODO: Loop through tables and find fact tables
    for (schema, table_name), columns in tables.items():
        # TODO: Skip non-fact tables
        pass
        
        # TODO: For each column ending with "_sk", check if matching dimension exists
        for col_name, _ in columns:
            pass
    
    return relationships


print("Inference function defined: infer_relationships()")

---
## Mermaid Generator

Generate Mermaid ER diagram syntax from parsed tables and relationships.

**TODO**: Complete the `generate_mermaid` function. The output should:

1. Start with `erDiagram`
2. For each table, create a block like:
   ```
   table_name {
       TYPE column_name
       TYPE column_name
   }
   ```
3. For each relationship, add a line like:
   ```
   dim_table ||--o{ fact_table : "fk_column"
   ```
   (This reads as "exactly one dimension row relates to zero or more fact rows".)
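
To make the target format concrete, the sketch below builds one table block and one relationship line for a toy input. `table_block` and `relationship_line` are hypothetical helpers for illustration, not part of the exercise skeleton, and the indentation is there for readability.

```python
from typing import List, Tuple


def table_block(table_name: str, columns: List[Tuple[str, str]]) -> List[str]:
    """Return the lines of one Mermaid entity block (TYPE before column name)."""
    lines = [f"    {table_name} {{"]
    for col_name, col_type in columns:
        lines.append(f"        {col_type} {col_name}")
    lines.append("    }")
    return lines


def relationship_line(fact_table: str, dim_table: str, col_name: str) -> str:
    """Return a one-to-many line: one dimension row, zero or more fact rows."""
    return f'    {dim_table} ||--o{{ {fact_table} : "{col_name}"'


diagram_lines = ["erDiagram"]
diagram_lines += table_block("dw_dim_mode", [("mode_sk", "BIGINT"), ("mode", "VARCHAR")])
diagram_lines += table_block("dw_fact_trips", [("trip_sk", "BIGINT"), ("mode_sk", "BIGINT")])
diagram_lines.append(relationship_line("dw_fact_trips", "dw_dim_mode", "mode_sk"))

print("\n".join(diagram_lines))
```

The doubled braces (`{{`) are how an f-string emits a literal `{`, which both the entity block and the `||--o{` connector need.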

In [None]:
# ========= Mermaid Generator

def generate_mermaid(
    tables: Dict[Tuple[str, str], List[Tuple[str, str]]],
    relationships: List[Tuple[str, str, str]]
) -> str:
    """
    Generate Mermaid ER diagram syntax.
    
    Args:
        tables: Output from parse_tables()
        relationships: Output from infer_relationships()
    
    Returns:
        String containing the Mermaid diagram
    """
    lines = ["erDiagram"]
    
    # TODO: Generate table blocks
    for (schema, table_name), columns in sorted(tables.items()):
        # TODO: Add table header
        # TODO: Add columns (format: "        TYPE column_name")
        # TODO: Close table block
        pass
    
    # TODO: Generate relationships
    # Format: dim_table ||--o{ fact_table : "column_name"
    for fact_table, dim_table, col_name in relationships:
        pass
    
    return "\n".join(lines)


print("Generator function defined: generate_mermaid()")

---
## Sample DDL

Use the DDL from Exercises 1 and 2, plus additional dimension tables to create a complete schema.

Note: We use the `public` schema with `dw_` prefixes to match the pattern from Exercises 1 and 2.

In [None]:
# ========= Sample DDL (from Exercises 1 and 2, plus supporting dims)

SAMPLE_DDL = """
-- =============================================================
-- DIMENSION TABLES
-- =============================================================

CREATE TABLE public.dw_dim_rider (
    rider_sk        BIGINT IDENTITY(1,1),
    rider_id        VARCHAR(32),
    rider_segment   VARCHAR(16),
    effective_from  TIMESTAMP,
    effective_to    TIMESTAMP,
    is_current      BOOLEAN,
    PRIMARY KEY (rider_sk)
);

CREATE TABLE public.dw_dim_route (
    route_sk   BIGINT IDENTITY(1,1),
    route_id   VARCHAR(32),
    PRIMARY KEY (route_sk)
);

CREATE TABLE public.dw_dim_mode (
    mode_sk   BIGINT IDENTITY(1,1),
    mode      VARCHAR(32),
    PRIMARY KEY (mode_sk)
);

CREATE TABLE public.dw_dim_station (
    station_sk  BIGINT IDENTITY(1,1),
    station_id  VARCHAR(32),
    city        VARCHAR(64),
    province    VARCHAR(32),
    latitude    DECIMAL(10,6),
    longitude   DECIMAL(10,6),
    PRIMARY KEY (station_sk)
);

CREATE TABLE public.dw_dim_date (
    date_key     INTEGER,
    date_actual  DATE,
    year         SMALLINT,
    quarter      SMALLINT,
    month        SMALLINT,
    day          SMALLINT,
    day_of_week  SMALLINT,
    is_weekend   BOOLEAN,
    PRIMARY KEY (date_key)
);

CREATE TABLE public.dw_dim_payment_method (
    payment_method_sk  BIGINT IDENTITY(1,1),
    payment_method     VARCHAR(32),
    PRIMARY KEY (payment_method_sk)
);

CREATE TABLE public.dw_dim_fare_class (
    fare_class_sk  BIGINT IDENTITY(1,1),
    fare_class     VARCHAR(32),
    PRIMARY KEY (fare_class_sk)
);

-- =============================================================
-- FACT TABLES
-- =============================================================

CREATE TABLE public.dw_fact_trips (
    trip_sk                 BIGINT IDENTITY(1,1),
    trip_id                 VARCHAR(32),
    rider_sk                BIGINT,
    route_sk                BIGINT,
    mode_sk                 BIGINT,
    origin_station_sk       BIGINT,
    destination_station_sk  BIGINT,
    board_date_key          INTEGER,
    alight_date_key         INTEGER,
    transfers               INTEGER,
    zones_charged           INTEGER,
    distance_km             DECIMAL(10,2),
    total_fare_cad          DECIMAL(12,2),
    payment_method_sk       BIGINT,
    fare_class_sk           BIGINT,
    on_time_arrival         BOOLEAN,
    service_disruption      BOOLEAN
);
"""

print(f"Sample DDL defined ({len(SAMPLE_DDL)} characters)")
print("Contains: dw_dim_rider, dw_dim_route, dw_dim_mode, dw_dim_station, dw_dim_date,")
print("          dw_dim_payment_method, dw_dim_fare_class, dw_fact_trips")

---
## Step 1: Parse the DDL

Extract tables and columns from the sample DDL.

In [None]:
tables = parse_tables(SAMPLE_DDL)

print(f"Parsed {len(tables)} tables:")
print("-" * 40)
for (schema, name), columns in sorted(tables.items()):
    print(f"  {schema}.{name}: {len(columns)} columns")

---
## Step 2: Infer Relationships

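Find fact-to-dimension relationships based on `*_sk` column naming. For the sample DDL, the rule as specified should yield five relationships (`rider`, `route`, `mode`, `payment_method`, and `fare_class`). `dw_dim_station` and `dw_dim_date` stay unlinked because `origin_station_sk`, `destination_station_sk`, and the `*_date_key` columns do not map to an existing dimension name under that rule.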

In [None]:
relationships = infer_relationships(tables)

print(f"Inferred {len(relationships)} relationships:")
print("-" * 40)

# Display as a table using md_table pattern
rel_data = [
    {"dimension": dim, "fact": fact, "fk_column": col}
    for fact, dim, col in relationships
]
print(md_table(rel_data))

---
## Step 3: Generate the Mermaid Diagram

Create the Mermaid ER diagram syntax.

In [None]:
mermaid_diagram = generate_mermaid(tables, relationships)

print("Generated Mermaid ER Diagram:")
print("=" * 60)
print(mermaid_diagram)

---
## Step 4: Save to File

Write the diagram to a `.mmd` file for use in documentation.

In [None]:
output_path = Path(OUTPUT_MERMAID)
output_path.write_text(mermaid_diagram, encoding="utf-8")

print(f"Diagram saved to: {output_path.absolute()}")
print("\nNext steps:")
print("  1. Preview at https://mermaid.live")
print("  2. Or use VS Code with Mermaid extension")
print("  3. Or render to PNG: mmdc -i schema.mmd -o schema.png")

---
## Step 5: Display in Notebook

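Display the diagram source as a fenced `mermaid` block. Front ends with Mermaid support render it as a diagram; others show the raw code, which is still easy to copy.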

In [None]:
# Display the raw Mermaid code in a code block for easy copying
from IPython.display import Markdown

display(Markdown(f"""```mermaid
{mermaid_diagram}
```"""))

---

## How to Use This in Your Project

### Option 1: From SQL Files

```python
# Read DDL from files
sql_files = ["dw_dim_rider.sql", "dw_dim_station.sql", "dw_fact_trips.sql"]
combined_sql = "\n\n".join(Path(f).read_text(encoding="utf-8") for f in sql_files)

tables = parse_tables(combined_sql)
relationships = infer_relationships(tables)
mermaid = generate_mermaid(tables, relationships)

Path("my_schema.mmd").write_text(mermaid, encoding="utf-8")
```
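
If your DDL is spread across many files, listing them by hand gets tedious. One possible variation, sketched below, collects every `.sql` file in a folder in sorted order. `read_ddl_folder` and the `ddl/` folder name are assumptions for illustration, so adjust them to your project layout.

```python
from pathlib import Path


def read_ddl_folder(folder: str) -> str:
    """Concatenate all *.sql files in a folder, sorted by file name."""
    sql_paths = sorted(Path(folder).glob("*.sql"))
    if not sql_paths:
        raise FileNotFoundError(f"No .sql files found in {folder!r}")
    return "\n\n".join(p.read_text(encoding="utf-8") for p in sql_paths)


# Usage (assumes a ddl/ folder exists next to the notebook):
# combined_sql = read_ddl_folder("ddl")
# tables = parse_tables(combined_sql)
```

Raising on an empty folder makes a wrong path visible immediately instead of silently producing an empty diagram.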

### Option 2: Use Pre-made Diagram

The final project includes `project-mermaid-diagram.md`, which you can use directly:

```python
MERMAID_MD = os.path.join(BASE_DIR, "project-mermaid-diagram.md")
if os.path.exists(MERMAID_MD):
    with open(MERMAID_MD, "r", encoding="utf-8") as f: 
        mermaid = f.read().strip()
```

### Tips for Your Final Report

1. **Include the Mermaid source** in your documentation (renders in GitHub, Notion, etc.)
2. **Export to PNG/SVG** for PDF reports using Mermaid CLI
3. **Document the grain** for each fact table in accompanying text
4. **Highlight conformed dimensions** that span multiple facts
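
For tip 4, conformed dimensions can be read straight off the relationship list once you have more than one fact table. The sketch below assumes the `(fact_table, dim_table, column_name)` tuple shape returned by `infer_relationships()`. The `conformed_dimensions` helper is hypothetical, and the `example_relationships` data, including `dw_fact_payments`, is invented for illustration.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple


def conformed_dimensions(
    relationships: List[Tuple[str, str, str]]
) -> Dict[str, List[str]]:
    """Map each dimension used by two or more fact tables to those facts."""
    facts_by_dim: Dict[str, Set[str]] = defaultdict(set)
    for fact_table, dim_table, _col in relationships:
        facts_by_dim[dim_table].add(fact_table)
    return {
        dim: sorted(facts)
        for dim, facts in sorted(facts_by_dim.items())
        if len(facts) >= 2
    }


# Invented example data; dw_fact_payments is not part of this exercise
example_relationships = [
    ("dw_fact_trips", "dw_dim_rider", "rider_sk"),
    ("dw_fact_trips", "dw_dim_route", "route_sk"),
    ("dw_fact_payments", "dw_dim_rider", "rider_sk"),
]

print(conformed_dimensions(example_relationships))
```

With a single fact table, as in this exercise, the result is empty; the check becomes useful once the project schema adds more facts.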

### Naming Convention

This exercise uses the same naming convention as the final project:
- All tables in the `public` schema
- Dimension tables prefixed with `dw_dim_` (e.g., `dw_dim_rider`)
- Fact tables prefixed with `dw_fact_` (e.g., `dw_fact_trips`)

### Connection to Project Solution

The project solution uses `md_table()` to generate markdown tables in the final report. The pattern you learned here will help you understand and extend the project's reporting capabilities.