#### Add state starting capacities
Update state starting capacities to reflect the current state of solar + storage penetration

In [None]:
# Imports
import sys, os
import pandas as pd
import numpy as np
from sqlalchemy import create_engine, text

sys.path.append(os.path.abspath(".."))

from input_data_functions import stacked_sectors

In [None]:
# Read the state-level starting capacities CSV into a DataFrame
df = pd.read_csv(
    filepath_or_buffer="../../../data/state_starting_capacities_to_model.csv",
)

In [None]:
# Coerce net_metering to pandas' nullable boolean dtype; string spellings of
# the flag are mapped onto real bools first.
bool_spellings = {"TRUE": True, "FALSE": False, "True": True, "False": False}
net_metering_flags = df["net_metering"].replace(bool_spellings)
df["net_metering"] = net_metering_flags.astype("boolean")

# Force the metric columns to numeric; entries that cannot be parsed become NaN
numeric_cols = ("system_mw", "system_mwh", "systems_count")
for metric in numeric_cols:
    df[metric] = pd.to_numeric(df[metric], errors="coerce")

# ---- Enforce MWh logic ----
# Lower-case the tech label once so both masks compare case-insensitively
tech_lower = df["tech"].str.lower()

# Solar rows carry no energy capacity: blank out system_mwh
is_solar = tech_lower.eq("solar")
df.loc[is_solar, "system_mwh"] = np.nan

# Storage rows must have system_mwh >= 2 x system_mw
is_storage = tech_lower.eq("storage")
storage_mw = df.loc[is_storage, "system_mw"]
storage_mwh = df.loc[is_storage, "system_mwh"]
required_mwh = 2.0 * storage_mw

# A row is bumped when its MWh is missing or below the required floor
too_small = storage_mwh < required_mwh
needs_bump = storage_mwh.isna() | too_small
df.loc[is_storage, "system_mwh"] = np.where(needs_bump, required_mwh, storage_mwh)

# --- 1) ratio (systems per MW) by tech, using only net_metering == True ---
true_nm = df[df["net_metering"] == True].copy()

# Only rows that report BOTH metrics may contribute to the ratio.  "sum" skips
# NaN column by column, so a row with MW but no count (or a count but no MW)
# would otherwise inflate one side of the fraction; a tech whose counts are all
# missing would even get ratio 0 and be back-filled with zero systems.
paired_nm = true_nm.dropna(subset=["systems_count", "system_mw"])

ratio_by_tech = (
    paired_nm.groupby("tech", dropna=False)
             .agg(systems_count_sum=("systems_count", "sum"),
                  system_mw_sum=("system_mw", "sum"))
             # Ratio is left as NaN when the tech has no positive MW to divide by
             .assign(ratio=lambda g: np.where(
                 g["system_mw_sum"] > 0,
                 g["systems_count_sum"] / g["system_mw_sum"],
                 np.nan
             ))["ratio"]
)
# No global fallback — strictly tech-specific

# --- 2) fill NA systems_count where net_metering == False using ratio * system_mw ---
df_filled = df.copy()

# Target rows: flagged as not net-metered AND lacking a system count
not_net_metered = df_filled["net_metering"] == False
count_missing = df_filled["systems_count"].isna()
mask_fill = not_net_metered & count_missing

# Look up each target row's tech-specific ratio; techs absent from
# ratio_by_tech (or with a NaN ratio) leave the count as NaN
tech_ratio = df_filled.loc[mask_fill, "tech"].map(ratio_by_tech)
estimated_counts = df_filled.loc[mask_fill, "system_mw"] * tech_ratio
df_filled.loc[mask_fill, "systems_count"] = estimated_counts

# Integer system counts (nullable Int64 so missing values survive the cast)
rounded_counts = df_filled["systems_count"].round()
df_filled["systems_count"] = rounded_counts.astype("Int64")

# --- 3) aggregate by (tech, sector_abbr, state_abbr), summing metrics ---
group_cols = ["tech", "sector_abbr", "state_abbr"]

# system_mwh is summed with min_count=1 so a group whose inputs are all NaN
# (e.g. solar) stays NaN instead of collapsing to 0.
# NOTE(review): system_mw and systems_count use a plain "sum", so a group with
# only missing values presumably comes out as 0 — confirm that is intended.
named_aggs = {
    "system_mw": ("system_mw", "sum"),
    "system_mwh": ("system_mwh", lambda s: s.sum(min_count=1)),
    "systems_count": ("systems_count", "sum"),
}
grouped = df_filled.groupby(group_cols, as_index=False, dropna=False)
agg = grouped.agg(**named_aggs)

In [None]:
# Replace starting capacities in cloud sql

from urllib.parse import quote

# Connection config.
# The defaults target a local Cloud SQL Proxy. Every value can be overridden
# through a DGEN_DB_* environment variable so that real credentials do not have
# to be written into this notebook.
DB_USER = os.environ.get("DGEN_DB_USER", "postgres")
DB_PASS = os.environ.get("DGEN_DB_PASS", "postgres")
DB_NAME = os.environ.get("DGEN_DB_NAME", "dgendb")
DB_PORT = int(os.environ.get("DGEN_DB_PORT", "5432"))
DB_HOST = os.environ.get("DGEN_DB_HOST", "127.0.0.1")  # Cloud SQL Proxy

# Percent-encode user and password so characters such as "@", ":" or "/" in a
# credential cannot be mistaken for URL delimiters.
conn_str = (
    f"postgresql+psycopg2://{quote(DB_USER, safe='')}:{quote(DB_PASS, safe='')}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
engine = create_engine(conn_str)

# Physical table that receives the data, the view that exposes it, and their schema
TABLE_NAME = "state_starting_capacities_to_model_tbl"
VIEW_NAME  = "state_starting_capacities_to_model"
SCHEMA     = "diffusion_template"

def _type_sql(r):
    dt = r["data_type"]
    if dt == "character varying":
        l = r["character_maximum_length"]
        return f"varchar({l})" if l else "varchar"
    if dt == "character":
        l = r["character_maximum_length"]
        return f"char({l})" if l else "char"
    if dt == "numeric":
        p, s = r["numeric_precision"], r["numeric_scale"]
        return f"numeric({p},{s})" if p and s is not None else "numeric"
    if dt in ("double precision", "integer", "bigint", "real", "boolean", "text"):
        return dt
    # fallback to udt_name if needed
    return r["udt_name"]

def _quote_ident(name):
    """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
    return '"' + str(name).replace('"', '""') + '"'


# Refresh the backing table and re-point the view inside one transaction, so a
# failure at any step rolls everything back.  The engine is disposed even when
# the transaction raises.
try:
    with engine.begin() as con:
        # Ensure the table exists, then refresh contents safely
        agg.head(0).to_sql(TABLE_NAME, con=con, schema=SCHEMA, if_exists="append", index=False)
        con.execute(text(f"TRUNCATE TABLE {SCHEMA}.{TABLE_NAME};"))
        agg.to_sql(TABLE_NAME, con=con, schema=SCHEMA, if_exists="append", index=False)

        # Introspect existing view column types so CREATE OR REPLACE keeps types identical
        rows = con.execute(
            text("""
                SELECT column_name, data_type, character_maximum_length,
                       numeric_precision, numeric_scale, udt_name
                FROM information_schema.columns
                WHERE table_schema = :schema AND table_name = :view
                ORDER BY ordinal_position
            """),
            {"schema": SCHEMA, "view": VIEW_NAME},
        ).mappings().all()

        if rows:
            # Build CAST list that matches the current view's column types.
            # Identifiers are quoted so mixed-case or reserved-word column
            # names are reproduced exactly.
            cast_selects = ", ".join(
                "CAST({col} AS {sql_type}) AS {col}".format(
                    col=_quote_ident(r["column_name"]),
                    sql_type=_type_sql(r),
                )
                for r in rows
            )
        else:
            # No columns were found for the view (e.g. it does not exist yet),
            # so there are no types to preserve: expose the table as-is rather
            # than emitting an empty select list.
            cast_selects = "*"

        # Re-point the view to the table with explicit casts to preserve types
        con.execute(text(f"""
            CREATE OR REPLACE VIEW {SCHEMA}.{VIEW_NAME} AS
            SELECT {cast_selects}
            FROM {SCHEMA}.{TABLE_NAME};
        """))
finally:
    engine.dispose()