# Extract DiD Tables from MySQL

Extracts the difference-in-differences (DiD) panel and estimate tables from the MySQL database `causal_analytics`.

**Inputs:**
- MySQL database connection (credentials from `.env` in repo root)
- Schema: `causal_analytics`

**Outputs:**
- Parquet files in `data/intermediate/` (CSV files instead when `pyarrow` is not installed)
- Manifest JSON in `results/run_manifests/`


## Imports


In [12]:
import os
import sys
import json
import logging
from pathlib import Path
from datetime import datetime, timezone
from typing import Dict, List, Optional

import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.engine import URL
from dotenv import load_dotenv

# pyarrow is optional: HAS_PYARROW records whether it could be imported, and
# the export cell writes CSV instead of Parquet when it is False.
try:
    import pyarrow as pa
    import pyarrow.parquet as pq
    HAS_PYARROW = True
except ImportError:
    HAS_PYARROW = False

# Configure logging at INFO level with timestamped, level-tagged lines.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Logger used by every cell below.
logger = logging.getLogger(__name__)


## Repository Root Detection


In [13]:
def find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above ``start`` that holds README.md.

    ``start`` is resolved first; the search then inspects that directory and
    its ancestors, looking at no more than 12 directories in total.

    Raises:
        RuntimeError: If none of the inspected directories contains README.md.
    """
    origin = start.resolve()
    # The resolved start plus at most 11 ancestors, nearest first.
    candidates = [origin, *origin.parents][:12]
    for directory in candidates:
        marker = directory / "README.md"
        if marker.exists():
            return directory
    raise RuntimeError("Repository root not found (README.md missing)")


# Locate the repository by searching upward from the current working
# directory, so the notebook has to be launched from inside the repository.
REPO_ROOT = find_repo_root(Path.cwd())
logger.info(f"Repository root: {REPO_ROOT}")


2026-01-14 14:31:50,100 - INFO - Repository root: /Users/rajnishpanwar/Desktop/Casual Analytics


## Load Environment Variables


In [14]:
# Credentials are read from a .env file at the repository root; stop early
# with setup instructions when that file is absent.
env_path = REPO_ROOT / ".env"
if not env_path.exists():
    raise FileNotFoundError(
        f"Environment file not found: {env_path}. "
        f"Create .env in repo root with MYSQL_HOST, MYSQL_PORT, MYSQL_DB, MYSQL_USER, MYSQL_PASSWORD."
    )

load_dotenv(env_path)

MYSQL_HOST = os.getenv("MYSQL_HOST")
MYSQL_PORT = os.getenv("MYSQL_PORT", "3306")  # the only setting with a default
MYSQL_DB = os.getenv("MYSQL_DB")
MYSQL_USER = os.getenv("MYSQL_USER")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")

# Every setting except the port must be present and non-empty.
_required_settings = {
    "MYSQL_HOST": MYSQL_HOST,
    "MYSQL_DB": MYSQL_DB,
    "MYSQL_USER": MYSQL_USER,
    "MYSQL_PASSWORD": MYSQL_PASSWORD
}
missing = [name for name, value in _required_settings.items() if not value]
if missing:
    raise ValueError(f"Missing required environment variables: {', '.join(missing)}")


## Environment Variable Diagnostics


In [15]:
# Echo the connection settings for debugging. The password itself is never
# printed, only whether a value is set.
print(f"MYSQL_HOST: {repr(MYSQL_HOST)}")
print(f"MYSQL_PORT: {MYSQL_PORT}")
print(f"MYSQL_DB: {repr(MYSQL_DB)}")
print(f"MYSQL_USER: {repr(MYSQL_USER)}")
print(f"MYSQL_PASSWORD set: {MYSQL_PASSWORD is not None}")

# Reject a host value that begins with '@'.
if MYSQL_HOST.startswith('@'):
    raise ValueError(
        f"MYSQL_HOST starts with '@': {repr(MYSQL_HOST)}. "
        f"Fix {REPO_ROOT}/.env - remove leading '@' from MYSQL_HOST."
    )

# Reject a host value containing a space character (other whitespace is not checked).
if ' ' in MYSQL_HOST:
    raise ValueError(
        f"MYSQL_HOST contains spaces: {repr(MYSQL_HOST)}. "
        f"Fix {REPO_ROOT}/.env - remove spaces from MYSQL_HOST."
    )

# The port must parse as an integer in 1-65535. The out-of-range ValueError is
# raised inside the try, so the except below catches it too: both failure modes
# surface with the same "Invalid MYSQL_PORT" message, and the specific cause
# is kept as the chained exception.
try:
    port_int = int(MYSQL_PORT)
    if not (1 <= port_int <= 65535):
        raise ValueError(f"MYSQL_PORT out of range: {MYSQL_PORT}")
except ValueError as e:
    raise ValueError(f"Invalid MYSQL_PORT: {MYSQL_PORT}. Must be integer 1-65535.") from e


MYSQL_HOST: 'localhost'
MYSQL_PORT: 3306
MYSQL_DB: 'causal_analytics'
MYSQL_USER: 'root'
MYSQL_PASSWORD set: True


## Create Database Engine


In [16]:
# Assemble the connection URL from its components instead of formatting a
# connection string by hand; the driver is PyMySQL.
url = URL.create(
    drivername="mysql+pymysql",
    username=MYSQL_USER,
    password=MYSQL_PASSWORD,
    host=MYSQL_HOST,
    port=int(MYSQL_PORT),
    database=MYSQL_DB
)

# NOTE(review): per SQLAlchemy's pooling docs, pool_pre_ping=True tests a
# pooled connection before handing it out - confirm this is the intent.
engine = create_engine(url, pool_pre_ping=True)
# The log line names user, host, port and database but never the password.
logger.info(f"Engine created for {MYSQL_USER}@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DB}")


2026-01-14 14:31:50,117 - INFO - Engine created for root@localhost:3306/causal_analytics


## Connectivity Check


In [17]:
# Smoke test: open a connection and run a trivial query before doing any real
# work. Every failure is logged and re-raised as a RuntimeError that points at
# the .env file, with the original exception attached as its cause.
try:
    with engine.connect() as conn:
        result = conn.execute(text("SELECT 1"))
        result.fetchone()
    logger.info("Database connection successful")
except Exception as e:
    logger.error(f"Database connection failed: {e}")
    raise RuntimeError(
        f"Cannot connect to database. Check credentials in {REPO_ROOT}/.env and ensure MySQL server is running."
    ) from e


2026-01-14 14:31:50,137 - INFO - Database connection successful


## Configuration


In [18]:
# Schema that owns the DiD tables. It is hard-coded here rather than taken
# from MYSQL_DB. NOTE(review): both are expected to name the same database -
# confirm.
SCHEMA = "causal_analytics"

# Tables to pull; the extraction cell treats every one of them as required.
TABLES_TO_EXTRACT = [
    "did_campaign_panel_purchase",
    "did_event_study_weekly_purchase",
    "did_campaign_estimates_purchase",
    "did_campaign_estimates_purchase_valid",
    "did_campaign_estimates_purchase_invalid",
    "did_panel_coverage_purchase",
    "did_overall_cell_means_purchase",
    "did_overall_estimate_purchase"
]

# Destination for the exported table files; created if it does not exist.
OUTPUT_DIR = REPO_ROOT / "data" / "intermediate"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Destination for the per-run manifest JSON; created if it does not exist.
MANIFEST_DIR = REPO_ROOT / "results" / "run_manifests"
MANIFEST_DIR.mkdir(parents=True, exist_ok=True)


## Extract Tables


In [19]:
def extract_table(engine, schema: str, table_name: str) -> Optional[pd.DataFrame]:
    """Read an entire table from MySQL into a DataFrame.

    Args:
        engine: SQLAlchemy engine (or connection) for the MySQL server.
        schema: Schema (database) that owns the table.
        table_name: Name of the table to read.

    Returns:
        The full table contents, or None when the table does not exist in
        ``schema``.

    Raises:
        Exception: Any failure other than a missing table (lost connection,
            insufficient privileges, ...) is logged and re-raised, so the
            caller cannot mistake it for a missing table.
    """
    # Local import: only this function needs the inspector.
    from sqlalchemy import inspect as sa_inspect

    # Check existence explicitly so that None really means "table is missing"
    # rather than "something went wrong".
    if not sa_inspect(engine).has_table(table_name, schema=schema):
        logger.error(
            f"Failed to extract {table_name}: table `{schema}`.`{table_name}` does not exist"
        )
        return None

    query = f"SELECT * FROM `{schema}`.`{table_name}`"
    try:
        df = pd.read_sql(query, engine)
    except Exception:
        # Log with traceback, then let the real error surface.
        logger.exception(f"Failed to extract {table_name}")
        raise

    logger.info(f"Extracted {table_name}: {len(df)} rows, {len(df.columns)} columns")
    return df


# Pull every configured table. A None result marks the table as missing, and
# any missing table aborts the run after the whole list has been attempted.
extracted_data = {}
missing_tables = []

for table_name in TABLES_TO_EXTRACT:
    df = extract_table(engine, SCHEMA, table_name)
    if df is None:
        missing_tables.append(table_name)
        continue
    extracted_data[table_name] = df

if missing_tables:
    raise RuntimeError(
        f"Missing required tables in schema {SCHEMA}: {', '.join(missing_tables)}. "
        f"Ensure SQL scripts have been executed to create these tables."
    )


2026-01-14 14:31:51,202 - INFO - Extracted did_campaign_panel_purchase: 122536 rows, 11 columns
2026-01-14 14:31:51,371 - INFO - Extracted did_event_study_weekly_purchase: 27183 rows, 7 columns
2026-01-14 14:31:51,392 - INFO - Extracted did_campaign_estimates_purchase: 1584 rows, 20 columns
2026-01-14 14:31:51,394 - INFO - Extracted did_campaign_estimates_purchase_valid: 9 rows, 23 columns
2026-01-14 14:31:51,414 - INFO - Extracted did_campaign_estimates_purchase_invalid: 1575 rows, 24 columns
2026-01-14 14:31:51,416 - INFO - Extracted did_panel_coverage_purchase: 1 rows, 5 columns
2026-01-14 14:31:51,417 - INFO - Extracted did_overall_cell_means_purchase: 4 rows, 6 columns
2026-01-14 14:31:51,419 - INFO - Extracted did_overall_estimate_purchase: 1 rows, 10 columns


## Validation


In [20]:
# Validation outcome per table name; embedded in the run manifest later.
validation_results = {}

# Panel table validation
if "did_campaign_panel_purchase" in extracted_data:
    df_panel = extracted_data["did_campaign_panel_purchase"]
    errors = []

    # Columns the DiD panel must provide.
    required_cols = ["campaign_id", "household_id", "week_number", "treated", "post"]
    missing_cols = [col for col in required_cols if col not in df_panel.columns]
    if missing_cols:
        errors.append(f"Missing required columns: {missing_cols}")

    # Null counts for the required columns that are present (zero counts omitted).
    null_counts = {}
    for col in required_cols:
        if col in df_panel.columns:
            null_count = df_panel[col].isnull().sum()
            if null_count > 0:
                null_counts[col] = int(null_count)
    if null_counts:
        errors.append(f"Null values in required columns: {null_counts}")

    # The 2x2 summaries index the frame by column name. Compute them only when
    # their columns exist, so a schema problem is reported through `errors`
    # instead of aborting the cell with a KeyError.
    treated_post_crosstab = None
    crosstab_summary = {}
    distinct_households = {}
    if "treated" in df_panel.columns and "post" in df_panel.columns:
        treated_post_crosstab = pd.crosstab(df_panel["treated"], df_panel["post"], margins=True)
        crosstab_summary = treated_post_crosstab.to_dict()

        if "household_id" in df_panel.columns:
            # Distinct households in each treated x post cell.
            for t in [0, 1]:
                for p in [0, 1]:
                    subset = df_panel[(df_panel["treated"] == t) & (df_panel["post"] == p)]
                    distinct_households[f"treated_{t}_post_{p}"] = subset["household_id"].nunique()

    validation_results["did_campaign_panel_purchase"] = {
        "valid": len(errors) == 0,
        "errors": errors,
        "summary": {
            "treated_post_crosstab": crosstab_summary,
            "distinct_households_by_cell": distinct_households
        }
    }

    if errors:
        logger.warning(f"Panel validation errors: {errors}")
    else:
        logger.info("Panel table validation passed")

# Estimates table validation
if "did_campaign_estimates_purchase" in extracted_data:
    df_estimates = extracted_data["did_campaign_estimates_purchase"]
    has_incomplete_flag = "did_incomplete_2x2" in df_estimates.columns
    has_sales_lift = "did_sales_lift" in df_estimates.columns

    # Both columns are mandatory; each absent one adds one error entry.
    errors = []
    if not has_incomplete_flag:
        errors.append("Missing column: did_incomplete_2x2")
    if "campaign_id" not in df_estimates.columns:
        errors.append("Missing column: campaign_id")

    # Summary counts; a count stays at 0 when its source column is absent.
    n_campaigns = len(df_estimates)
    n_incomplete = int(df_estimates["did_incomplete_2x2"].sum()) if has_incomplete_flag else 0
    n_null_lift = int(df_estimates["did_sales_lift"].isnull().sum()) if has_sales_lift else 0

    estimates_summary = {
        "n_campaigns": n_campaigns,
        "n_incomplete_2x2": n_incomplete,
        "n_null_did_sales_lift": n_null_lift
    }
    validation_results["did_campaign_estimates_purchase"] = {
        "valid": not errors,
        "errors": errors,
        "summary": estimates_summary
    }

    print(f"Campaigns: {n_campaigns}, Incomplete 2x2: {n_incomplete}, Null lift: {n_null_lift}")

    if not errors:
        logger.info("Estimates table validation passed")
    else:
        logger.warning(f"Estimates validation errors: {errors}")

# Valid/invalid counts
# These are informational only: the row counts are logged, not stored in
# validation_results.
if "did_campaign_estimates_purchase_valid" in extracted_data:
    df_valid = extracted_data["did_campaign_estimates_purchase_valid"]
    n_valid = len(df_valid)
    logger.info(f"Valid campaigns: {n_valid}")

if "did_campaign_estimates_purchase_invalid" in extracted_data:
    df_invalid = extracted_data["did_campaign_estimates_purchase_invalid"]
    n_invalid = len(df_invalid)
    logger.info(f"Invalid campaigns: {n_invalid}")
    
    # Break the invalid rows down by reason when the column is available.
    if "invalid_reason" in df_invalid.columns:
        reason_counts = df_invalid["invalid_reason"].value_counts().to_dict()
        print(f"Invalid reason breakdown: {reason_counts}")


2026-01-14 14:31:51,447 - INFO - Panel table validation passed
2026-01-14 14:31:51,447 - INFO - Estimates table validation passed
2026-01-14 14:31:51,448 - INFO - Valid campaigns: 9
2026-01-14 14:31:51,448 - INFO - Invalid campaigns: 1575


Campaigns: 1584, Incomplete 2x2: 1569, Null lift: 1569
Invalid reason breakdown: {'missing_2x2_cell': 1569, 'control_hh_pre_below_threshold': 4, 'treated_hh_pre_below_threshold': 2}


## Export to Parquet


In [21]:
# Path written for each table name; recorded in the run manifest afterwards.
files_written = {}

for table_name, df in extracted_data.items():
    try:
        if not HAS_PYARROW:
            # Without pyarrow, write CSV so the run still produces output.
            logger.warning("pyarrow not available, falling back to CSV")
            output_path = OUTPUT_DIR / f"{table_name}.csv"
            df.to_csv(output_path, index=False)
        else:
            output_path = OUTPUT_DIR / f"{table_name}.parquet"
            df.to_parquet(output_path, index=False, engine='pyarrow')

        files_written[table_name] = str(output_path)
        logger.info(f"Exported {table_name} to {output_path}")
    except Exception as e:
        # Log which table failed, then stop the run with the original error.
        logger.error(f"Failed to export {table_name}: {e}")
        raise


2026-01-14 14:31:51,479 - INFO - Exported did_campaign_panel_purchase to /Users/rajnishpanwar/Desktop/Casual Analytics/data/intermediate/did_campaign_panel_purchase.parquet
2026-01-14 14:31:51,482 - INFO - Exported did_event_study_weekly_purchase to /Users/rajnishpanwar/Desktop/Casual Analytics/data/intermediate/did_event_study_weekly_purchase.parquet
2026-01-14 14:31:51,484 - INFO - Exported did_campaign_estimates_purchase to /Users/rajnishpanwar/Desktop/Casual Analytics/data/intermediate/did_campaign_estimates_purchase.parquet
2026-01-14 14:31:51,489 - INFO - Exported did_campaign_estimates_purchase_valid to /Users/rajnishpanwar/Desktop/Casual Analytics/data/intermediate/did_campaign_estimates_purchase_valid.parquet
2026-01-14 14:31:51,492 - INFO - Exported did_campaign_estimates_purchase_invalid to /Users/rajnishpanwar/Desktop/Casual Analytics/data/intermediate/did_campaign_estimates_purchase_invalid.parquet
2026-01-14 14:31:51,493 - INFO - Exported did_panel_coverage_purchase to /U

## Create Manifest


In [22]:
def get_package_versions() -> Dict[str, str]:
    """Report the versions of the packages this extraction relies on.

    Returns:
        Mapping of package name to version string. A package that cannot be
        imported maps to "not_installed"; one without a ``__version__``
        attribute maps to "unknown". "pyarrow" is included only when
        HAS_PYARROW is true.
    """
    def _probe(name: str) -> str:
        # Import by name so an absent package is reported instead of raising.
        try:
            module = __import__(name)
        except ImportError:
            return "not_installed"
        return getattr(module, "__version__", "unknown")

    found = {name: _probe(name) for name in ["pandas", "sqlalchemy", "pymysql"]}
    if HAS_PYARROW:
        found["pyarrow"] = pa.__version__
    return found


# Read the clock once so the manifest's "timestamp" field and its filename
# always describe the same instant.
run_timestamp = datetime.now(timezone.utc)

# Record of this run: environment, what was extracted, where it was written,
# and the validation outcome.
manifest = {
    "timestamp": run_timestamp.isoformat(),
    "python_version": sys.version,
    "package_versions": get_package_versions(),
    "tables_extracted": {
        name: {
            "row_count": len(df),
            "column_count": len(df.columns)
        }
        for name, df in extracted_data.items()
    },
    "files_written": files_written,
    "missing_tables": missing_tables,
    "validation_results": validation_results
}

manifest_path = MANIFEST_DIR / f"extraction_manifest_{run_timestamp.strftime('%Y%m%d_%H%M%S')}.json"

# Explicit encoding keeps the file independent of the platform default.
# default=str stringifies any value the json module cannot encode natively.
with open(manifest_path, 'w', encoding='utf-8') as f:
    json.dump(manifest, f, indent=2, default=str)

logger.info(f"Manifest written to {manifest_path}")
print(f"\nManifest: {manifest_path}")


2026-01-14 14:31:51,500 - INFO - Manifest written to /Users/rajnishpanwar/Desktop/Casual Analytics/results/run_manifests/extraction_manifest_20260114_143151.json



Manifest: /Users/rajnishpanwar/Desktop/Casual Analytics/results/run_manifests/extraction_manifest_20260114_143151.json
