# In [19]:  (Jupyter cell prompt retained from the notebook export)
import os
import re
import sys
import math
import glob
import argparse
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import Ridge
from sklearn.ensemble import RandomForestRegressor, HistGradientBoostingRegressor
from sklearn.inspection import permutation_importance


# -----------------------------
# Column resolution (robust to schema variants)
# -----------------------------
# Maps each canonical field name used by the pipeline to the header aliases
# that may carry it. Alias order is the lookup priority: _pick_first_existing
# returns the first alias that matches exactly and only then falls back to a
# case-insensitive match. To support a new schema, append its header name to
# the relevant list.
CANONICAL_COLS = {
    "placekey": ["PLACEKEY", "placekey", "SAFEGRAPH_PLACE_ID", "safegraph_place_id", "PLACE_ID", "place_id"],
    "visits": ["RAW_VISIT_COUNTS", "raw_visit_counts", "VISIT_COUNTS", "visit_counts"],
    "state": ["REGION", "region", "STATE", "state"],
    "lat": ["LATITUDE", "latitude", "LAT", "lat"],
    "lon": ["LONGITUDE", "longitude", "LON", "lon"],
    # The category falls back from the most specific column to broader ones.
    "category": [
        "TOP_CATEGORY", "top_category",
        "SUB_CATEGORY", "sub_category",
        "NAICS_DESCRIPTION", "naics_description",
        "CATEGORY", "category",
        "BRANDS", "brands",
    ],
}


def _pick_first_existing(df_cols: List[str], candidates: List[str]) -> Optional[str]:
    s = set(df_cols)
    for c in candidates:
        if c in s:
            return c
    # try case-insensitive match
    upper_map = {c.upper(): c for c in df_cols}
    for c in candidates:
        if c.upper() in upper_map:
            return upper_map[c.upper()]
    return None


def resolve_columns(header_cols: List[str]) -> Dict[str, Optional[str]]:
    """Map every canonical field in CANONICAL_COLS to a real header name.

    The value is the column from ``header_cols`` chosen by
    ``_pick_first_existing`` for that field's alias list, or ``None`` when
    the header contains none of the aliases.
    """
    return {
        canonical: _pick_first_existing(header_cols, aliases)
        for canonical, aliases in CANONICAL_COLS.items()
    }


def read_csv_header(path: str) -> List[str]:
    """Return the column names of the CSV at ``path`` without reading data rows."""
    # nrows=0 makes pandas parse only the header line.
    header_only = pd.read_csv(path, nrows=0)
    return header_only.columns.tolist()


# -----------------------------
# Utilities
# -----------------------------
def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in kilometres between two points given in degrees.

    Uses the haversine formula with an Earth radius of 6371 km. All four
    arguments go through NumPy ufuncs, so scalars, arrays and pandas Series
    are accepted and broadcast against each other; the result has the
    broadcast shape/type. NaN inputs yield NaN outputs.
    """
    R = 6371.0  # Earth radius in km
    lat1 = np.deg2rad(lat1)
    lon1 = np.deg2rad(lon1)
    lat2 = np.deg2rad(lat2)
    lon2 = np.deg2rad(lon2)
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat / 2.0) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2.0) ** 2
    # Floating-point rounding can leave `a` a hair above 1 for (near-)antipodal
    # points, which would make arcsin return NaN. Clamp to the valid domain.
    # np.maximum/np.minimum (unlike fmax/fmin) still propagate genuine NaNs.
    a = np.minimum(1.0, np.maximum(0.0, a))
    c = 2 * np.arcsin(np.sqrt(a))
    return R * c


def ensure_dir(p: str):
    """Create directory ``p`` (including parents) if it does not exist yet.

    An empty path is treated as "the current directory" and is a no-op.
    Callers pass ``os.path.dirname(out_path)``, which is ``""`` for a bare
    filename, and ``os.makedirs("")`` would raise ``FileNotFoundError``.
    """
    if not p:
        return
    os.makedirs(p, exist_ok=True)


def list_monthly_files(data_dir: str, year: int) -> Dict[int, str]:
    """Locate the monthly CSV files of ``year`` inside ``data_dir``.

    Files must be named ``<year>-<MM>--<anything>.csv`` where ``MM`` is a
    two-digit month between 01 and 12; anything else is ignored.

    Returns a dict ``{month_number: path}`` ordered by month. If several
    files exist for the same month, the one sorting last by path is kept.
    An empty dict is returned when nothing matches.
    """
    # Escape the directory part so glob metacharacters in the path
    # (e.g. "data[v1]") are taken literally instead of as a pattern.
    pattern = os.path.join(glob.escape(data_dir), f"{year}-??--*.csv")
    month_re = re.compile(rf"{year}-(\d\d)--")

    month_map: Dict[int, str] = {}
    for path in sorted(glob.glob(pattern)):
        m = month_re.match(os.path.basename(path))
        if not m:
            # "??" in the glob also matches non-digits.
            continue
        month = int(m.group(1))
        if not 1 <= month <= 12:
            # Not a calendar month (e.g. "00" or "13").
            continue
        month_map[month] = path
    return dict(sorted(month_map.items()))


# -----------------------------
# Core pipeline steps
# -----------------------------
@dataclass
class PanelResult:
    """Bundle of the two data products returned by ``run_pipeline``."""

    # Wide per-POI panel as produced by build_panel_from_monthly_csvs.
    panel: pd.DataFrame
    # Model-comparison table as produced by build_time_ordered_table2.
    table2: pd.DataFrame

def build_panel_from_monthly_csvs(
    data_dir: str,
    state: str,
    year: int,
    chunksize: int = 200_000,
) -> pd.DataFrame:
    """Build a wide per-POI panel of monthly visits for one state and year.

    The monthly CSVs found by ``list_monthly_files`` are streamed in chunks
    of ``chunksize`` rows. Column names are resolved once from the header of
    the first file (see ``resolve_columns``); later files may spell the same
    headers with different letter case. Rows are kept only when the state
    column equals ``state`` (case-insensitive); if no state column can be
    resolved, no filtering happens and a warning is printed.

    Returns a DataFrame indexed by ``placekey`` with columns:
      - ``m01`` .. ``m12``: summed visits per month (0.0 where a POI has no
        rows, and for months without a file),
      - ``poi_category`` (missing -> "Unknown"), ``latitude``, ``longitude``:
        first non-missing value seen for the POI,
      - ``avg_monthly_visits``, ``min_monthly_visits``, ``max_monthly_visits``,
      - ``seasonality_index`` = (max - min) / (max + min + 1e-9),
      - ``dist_to_center_km``: haversine distance to the mean lat/lon of all
        POIs in the panel.

    Raises:
        FileNotFoundError: no monthly file exists for ``year``.
        ValueError: the placekey or visits column cannot be resolved.
    """
    month_files = list_monthly_files(data_dir, year)
    if not month_files:
        raise FileNotFoundError(f"No monthly files found for {year} in: {data_dir}")

    print(f"Found {len(month_files)} monthly files for {year}:")
    for mm, fp in month_files.items():
        print(f"  {mm:02d}: {os.path.basename(fp)}")

    # sniff header from first file
    header_cols = read_csv_header(next(iter(month_files.values())))
    colmap = resolve_columns(header_cols)

    # required
    if colmap["placekey"] is None or colmap["visits"] is None:
        raise ValueError(
            "Could not find required columns for placekey/visits.\n"
            f"Resolved mapping: {colmap}\n"
            "Please open one CSV header and add the correct aliases to CANONICAL_COLS."
        )
    pk = colmap["placekey"]
    vc = colmap["visits"]

    # optional but strongly preferred
    st_col = colmap["state"]
    cat_col = colmap["category"]
    lat_col = colmap["lat"]
    lon_col = colmap["lon"]

    if not st_col:
        # Without a state column every row of every file ends up in the panel.
        print("WARNING: no state/region column resolved; rows are NOT filtered by state.")
    state_upper = state.upper()

    # Upper-cased header -> resolved column name. Used both to select the
    # minimal set of columns while reading and to normalise header case.
    wanted = [pk, vc] + [c for c in (st_col, cat_col, lat_col, lon_col) if c]
    canon_by_upper = {c.upper(): c for c in wanted}

    def usecols(c):
        return c.upper() in canon_by_upper

    # (source column, output name) for the metadata fields that are available
    meta_pairs = [
        (src, name)
        for src, name in (
            (cat_col, "poi_category"),
            (lat_col, "latitude"),
            (lon_col, "longitude"),
        )
        if src
    ]
    meta_src = [src for src, _ in meta_pairs]
    meta_names = [name for _, name in meta_pairs]

    # store monthly visit series
    month_series: Dict[int, pd.Series] = {}
    # store meta per placekey
    meta: Dict[str, dict] = {}

    for mm, fp in month_files.items():
        print(f"\nReading month {mm:02d}: {fp}")
        acc: Dict[str, float] = {}

        # Context manager guarantees the file handle is released even if
        # processing a chunk raises.
        with pd.read_csv(fp, chunksize=chunksize, usecols=usecols, low_memory=False) as reader:
            for chunk in reader:
                # usecols matches case-insensitively, so a file whose headers
                # differ only in case from the first file must be renamed to
                # the resolved names before the columns are accessed below.
                # Never rename onto a name that is already present.
                present = set(chunk.columns)
                renames = {}
                for col in chunk.columns:
                    target = canon_by_upper.get(str(col).upper())
                    if target is not None and target != col and target not in present:
                        renames[col] = target
                        present.add(target)
                if renames:
                    chunk = chunk.rename(columns=renames)

                if st_col:
                    # state filter (rows with a missing state are dropped)
                    chunk_state = chunk[st_col].astype("string").str.upper()
                    keep = (chunk_state == state_upper).fillna(False)
                    # .copy() so the column assignments below act on an
                    # independent frame, not on a slice of the reader's chunk
                    # (avoids SettingWithCopyWarning).
                    chunk = chunk.loc[keep].copy()
                    if chunk.empty:
                        continue

                # clean core columns
                chunk[pk] = chunk[pk].astype("string")
                chunk[vc] = pd.to_numeric(chunk[vc], errors="coerce").fillna(0.0)

                # accumulate visits per placekey within this month
                # (groupby drops rows whose placekey is missing)
                grp = chunk.groupby(pk)[vc].sum()
                for k, v in grp.items():
                    k = str(k)
                    acc[k] = acc.get(k, 0.0) + float(v)

                # meta: take first-seen non-missing values
                if meta_src:
                    # GroupBy.first() yields, per placekey and per column, the
                    # first non-null value within this chunk.
                    # NOTE(review): fields of one POI may therefore come from
                    # different rows; this presumes a POI's category and
                    # coordinates do not vary between rows -- confirm.
                    firsts = chunk[[pk] + meta_src].groupby(pk, sort=False).first()
                    rows = firsts.itertuples(index=False, name=None)
                    for k, values in zip(firsts.index, rows):
                        k = str(k)
                        entry = meta.get(k)
                        if entry is None:
                            meta[k] = dict(zip(meta_names, values))
                            continue
                        # Only back-fill fields that are still missing.
                        for name, val in zip(meta_names, values):
                            if pd.isna(entry[name]) and not pd.isna(val):
                                entry[name] = val

        s = pd.Series(acc, name=f"m{mm:02d}", dtype="float64")
        month_series[mm] = s
        print(f"  accumulated POIs: {len(s):,}")

    # build wide panel
    all_keys = pd.Index(sorted(set().union(*(s.index for s in month_series.values()))))
    panel = pd.DataFrame(index=all_keys)

    for mm in range(1, 13):
        col = f"m{mm:02d}"
        if mm in month_series:
            panel[col] = month_series[mm].reindex(all_keys).fillna(0.0).astype("float64")
        else:
            # No file for this month: the column is all zeros, which also
            # feeds into the avg/min/seasonality figures below.
            panel[col] = 0.0

    # meta columns
    meta_df = pd.DataFrame.from_dict(meta, orient="index")
    panel = panel.join(meta_df, how="left")

    # clean meta types
    if "poi_category" in panel.columns:
        panel["poi_category"] = panel["poi_category"].astype("string").fillna("Unknown")
    else:
        panel["poi_category"] = "Unknown"

    for c in ["latitude", "longitude"]:
        if c in panel.columns:
            panel[c] = pd.to_numeric(panel[c], errors="coerce")
        else:
            panel[c] = np.nan

    # derived targets
    mcols = [f"m{mm:02d}" for mm in range(1, 13)]
    panel["avg_monthly_visits"] = panel[mcols].mean(axis=1)
    panel["min_monthly_visits"] = panel[mcols].min(axis=1)
    panel["max_monthly_visits"] = panel[mcols].max(axis=1)
    eps = 1e-9  # keeps the ratio finite for POIs with zero visits all year
    panel["seasonality_index"] = (
        (panel["max_monthly_visits"] - panel["min_monthly_visits"]) /
        (panel["max_monthly_visits"] + panel["min_monthly_visits"] + eps)
    )

    # distance to "state center" (mean lat/lon over POIs with coords)
    lat_mu = panel["latitude"].mean(skipna=True)
    lon_mu = panel["longitude"].mean(skipna=True)
    panel["dist_to_center_km"] = haversine_km(panel["latitude"], panel["longitude"], lat_mu, lon_mu)

    panel.index.name = "placekey"
    return panel


def make_fig2_visits_by_category(panel: pd.DataFrame, out_path: str, topk: int = 8):
    """Save a box plot of average monthly visits for the most common categories.

    The ``topk`` categories with the most POIs (rows of ``panel``) are
    selected; for each one the ``avg_monthly_visits`` values are drawn as a
    box (outliers hidden) on a log-scaled y axis. The figure is written to
    ``out_path`` at 200 dpi; the parent directory is created if needed.
    """
    ensure_dir(os.path.dirname(out_path))

    categories = panel["poi_category"]
    # choose top categories by POI count
    top_cats = categories.value_counts().head(topk).index.tolist()

    # Values below 1 (including 0) are raised to 1 so they remain drawable
    # on the logarithmic axis.
    data_by_cat = [
        panel.loc[categories == cat, "avg_monthly_visits"].astype(float).clip(lower=1.0).values
        for cat in top_cats
    ]

    fig, ax = plt.subplots()
    # NOTE(review): `tick_labels` is the newer spelling of boxplot's `labels`
    # keyword; confirm the minimum supported Matplotlib version.
    ax.boxplot(data_by_cat, tick_labels=top_cats, showfliers=False)
    ax.set_yscale("log")
    ax.set_xlabel("POI category (top by count)")
    ax.set_ylabel("Average monthly visits (log scale)")
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(30)
        tick_label.set_ha("right")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def make_fig3_seasonality_by_category(panel: pd.DataFrame, out_path: str, topk: int = 8):
    """Save a box plot of the seasonality index for the most common categories.

    The ``topk`` categories with the most POIs are selected; for each one
    the ``seasonality_index`` values, clipped to [0, 1], are drawn as a box
    (outliers hidden). The figure is written to ``out_path`` at 200 dpi; the
    parent directory is created if needed.
    """
    ensure_dir(os.path.dirname(out_path))

    categories = panel["poi_category"]
    top_cats = categories.value_counts().head(topk).index.tolist()

    data_by_cat = [
        panel.loc[categories == cat, "seasonality_index"]
        .astype(float)
        .clip(lower=0.0, upper=1.0)
        .values
        for cat in top_cats
    ]

    fig, ax = plt.subplots()
    # NOTE(review): `tick_labels` is the newer spelling of boxplot's `labels`
    # keyword; confirm the minimum supported Matplotlib version.
    ax.boxplot(data_by_cat, tick_labels=top_cats, showfliers=False)
    ax.set_xlabel("POI category (top by count)")
    ax.set_ylabel(r"Seasonality index $S_i$")
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(30)
        tick_label.set_ha("right")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def make_fig4_cluster_centroids(panel: pd.DataFrame, out_path: str, k: int = 4):
    """Cluster within-year visit profiles with k-means and plot the centroids.

    Each POI's twelve monthly values (``m01`` .. ``m12``) are divided by the
    POI's own yearly mean, so clusters reflect the shape of the year rather
    than the visit volume. POIs with a zero mean or non-finite values are
    excluded. The ``k`` centroids are plotted against month and saved to
    ``out_path`` at 200 dpi; the parent directory is created if needed.
    """
    ensure_dir(os.path.dirname(out_path))

    month_cols = [f"m{mm:02d}" for mm in range(1, 13)]
    visits = panel[month_cols].to_numpy(dtype=float)

    # standardize per POI by its mean to highlight within-year pattern
    row_mean = visits.mean(axis=1, keepdims=True)
    profiles = visits / (row_mean + 1e-9)

    # keep only reasonable rows (avoid all-zero)
    has_visits = row_mean.ravel() > 0
    all_finite = np.isfinite(profiles).all(axis=1)
    usable = profiles[all_finite & has_visits]

    kmeans = KMeans(n_clusters=k, random_state=42, n_init="auto")
    kmeans.fit(usable)
    centroids = kmeans.cluster_centers_  # k x 12

    fig, ax = plt.subplots()
    months = np.arange(1, 13)
    for number, centroid in enumerate(centroids, start=1):
        ax.plot(months, centroid, marker="o", label=f"Cluster {number}")
    ax.set_xlabel("Month")
    ax.set_ylabel("Standardized monthly visits")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def fit_static_predictor_and_importance(panel: pd.DataFrame, out_path: str):
    """Fit a static visit model and plot permutation feature importance.

    Target: ``log1p(avg_monthly_visits)``.
    Inputs:
      - ``poi_category`` (one-hot encoded, missing -> "Unknown"),
      - ``latitude``, ``longitude``, ``dist_to_center_km`` (numeric, missing
        values replaced by the column median).
    A HistGradientBoostingRegressor is trained on a random 80 % split.
    Permutation importance is then computed on the remaining 20 % by
    permuting the ORIGINAL input columns through the whole pipeline, so the
    scores map one-to-one to the four inputs. The horizontal bar chart is
    saved to ``out_path`` at 200 dpi.
    """
    ensure_dir(os.path.dirname(out_path))

    cat_cols = ["poi_category"]
    num_cols = ["latitude", "longitude", "dist_to_center_km"]

    # features
    X = panel[cat_cols + num_cols].copy()
    X["poi_category"] = X["poi_category"].astype("string").fillna("Unknown")

    # numeric cleanup
    for col in num_cols:
        X[col] = pd.to_numeric(X[col], errors="coerce")
    # NOTE(review): medians are taken over all rows, i.e. before the
    # train/test split below.
    medians = {col: X[col].median() for col in num_cols}
    X = X.fillna(medians)

    y = np.log1p(panel["avg_monthly_visits"].astype(float))

    # model
    encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
    pre = ColumnTransformer(
        transformers=[("cat", encoder, cat_cols), ("num", "passthrough", num_cols)],
        remainder="drop",
    )
    pipe = Pipeline(
        [
            ("pre", pre),
            ("model", HistGradientBoostingRegressor(random_state=42)),
        ]
    )

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    pipe.fit(X_train, y_train)

    result = permutation_importance(pipe, X_test, y_test, n_repeats=5, random_state=42)
    imp = pd.Series(result.importances_mean, index=X.columns).sort_values(ascending=False)

    # barh draws bottom-up, so reverse to put the most important feature on top
    fig, ax = plt.subplots()
    ax.barh(imp.index[::-1], imp.values[::-1])
    ax.set_xlabel("Permutation importance (relative)")
    ax.set_ylabel("Input feature")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def build_time_ordered_table2(panel: pd.DataFrame, out_csv_path: str) -> pd.DataFrame:
    """Compute Table 2: a time-ordered comparison of three regressors.

    The wide panel is reshaped into one row per (POI, month t) for
    t = 4..12, with the previous three months as lag features
    (``log1p`` of visits). Rows with t <= 10 are used for training, rows
    with t >= 11 for testing.

    Targets:
      - Visit intensity: ``log1p(visits_t)``
      - Seasonality proxy: ``visits_t / mean(visits over months 1..10)``;
        the denominator uses training months only to avoid leakage.
    Models: Ridge, RandomForest, HistGBDT.

    Returns a DataFrame with columns ``Model``, ``Target``, ``RMSE`` and
    ``R2``; the same table is written to ``out_csv_path``.
    """
    ensure_dir(os.path.dirname(out_csv_path))

    static_cols = ["placekey", "poi_category", "latitude", "longitude", "dist_to_center_km"]
    month_cols = [f"m{mm:02d}" for mm in range(1, 13)]

    # build long format
    base = panel.reset_index()[static_cols + month_cols].copy()
    base["poi_category"] = base["poi_category"].astype("string").fillna("Unknown")

    # mean over TRAIN months (1..10) per POI for standardization (avoid leakage)
    base["mu_train"] = base[month_cols[:10]].mean(axis=1).astype(float)

    def _rows_for_month(t: int) -> pd.DataFrame:
        # One row per POI with month t as target and t-1..t-3 as lags.
        frame = base[static_cols].copy()
        frame["month"] = t
        for lag in (1, 2, 3):
            frame[f"lag{lag}"] = np.log1p(base[f"m{t - lag:02d}"].astype(float))
        visits_t = base[f"m{t:02d}"].astype(float)
        frame["y_intensity"] = np.log1p(visits_t)
        frame["y_seasonality"] = visits_t / (base["mu_train"].astype(float) + 1e-9)
        return frame

    # need 3 lags, so the first usable target month is 4
    longdf = pd.concat([_rows_for_month(t) for t in range(4, 13)], ignore_index=True)

    # split time-ordered
    train_df = longdf[longdf["month"] <= 10].copy()
    test_df = longdf[longdf["month"] >= 11].copy()

    # features
    feat_cols_cat = ["poi_category", "month"]
    geo_cols = ["latitude", "longitude", "dist_to_center_km"]
    lag_cols = ["lag1", "lag2", "lag3"]
    feat_cols_num = geo_cols + lag_cols

    X_train = train_df[feat_cols_cat + feat_cols_num].copy()
    X_test = test_df[feat_cols_cat + feat_cols_num].copy()

    # clean numeric
    for part in (X_train, X_test):
        for col in feat_cols_num:
            part[col] = pd.to_numeric(part[col], errors="coerce")

    # fill missing: geography with the TRAIN median, lags with zero
    for col in geo_cols:
        train_median = X_train[col].median()
        X_train[col] = X_train[col].fillna(train_median)
        X_test[col] = X_test[col].fillna(train_median)
    for part in (X_train, X_test):
        for col in lag_cols:
            part[col] = part[col].fillna(0.0)

    # categorical clean
    # NOTE(review): test months (11, 12) never occur in training, so with
    # handle_unknown="ignore" their month dummies are all zero.
    for part in (X_train, X_test):
        for col in feat_cols_cat:
            part[col] = part[col].astype("string").fillna("Unknown")

    pre = ColumnTransformer(
        transformers=[
            ("cat", OneHotEncoder(handle_unknown="ignore", sparse_output=False), feat_cols_cat),
            ("num", "passthrough", feat_cols_num),
        ],
        remainder="drop",
    )

    # The same estimator instances are refit for each target.
    models = {
        "Ridge": Ridge(alpha=1.0, random_state=42),
        "RandomForest": RandomForestRegressor(n_estimators=200, random_state=42, n_jobs=-1),
        "HistGBDT": HistGradientBoostingRegressor(random_state=42),
    }

    def eval_one(y_train, y_test, target_name: str) -> List[dict]:
        # Fit every model on the training rows and score it on the test rows.
        scores = []
        for model_name, model in models.items():
            pipe = Pipeline([("pre", pre), ("model", model)])
            pipe.fit(X_train, y_train)
            pred = pipe.predict(X_test)
            scores.append(
                {
                    "Model": model_name,
                    "Target": target_name,
                    "RMSE": math.sqrt(mean_squared_error(y_test, pred)),
                    "R2": r2_score(y_test, pred),
                }
            )
        return scores

    targets = [
        ("y_intensity", "Visit intensity (log1p)"),
        ("y_seasonality", "Seasonality proxy (v/mu_train)"),
    ]
    res = []
    for column, label in targets:
        res.extend(eval_one(train_df[column].to_numpy(), test_df[column].to_numpy(), label))

    table2 = pd.DataFrame(res)
    table2.to_csv(out_csv_path, index=False)
    return table2


def run_pipeline(
    data_dir: str,
    state: str = "VA",
    year: int = 2021,
    chunksize: int = 200_000,
    topk_categories: int = 8,
    k_clusters: int = 4,
) -> PanelResult:
    """Run the full analysis: panel, figures 2-5 and Table 2.

    Outputs are written relative to the current working directory:
      - ``out/panel_<state>_<year>.csv.gz`` (gzip-compressed panel),
      - ``figs/fig2_*.png`` .. ``figs/fig5_*.png``,
      - ``out/table2_predictive_performance.csv``.
    Progress is reported with ``print``.

    Returns a PanelResult holding the panel and the Table 2 DataFrame.
    """
    for folder in ("out", "figs"):
        ensure_dir(folder)

    panel = build_panel_from_monthly_csvs(
        data_dir=data_dir,
        state=state,
        year=year,
        chunksize=chunksize,
    )

    out_panel = os.path.join("out", f"panel_{state}_{year}.csv.gz")
    panel.to_csv(out_panel, compression="gzip")
    print(f"\nSaved panel: {out_panel} (rows={len(panel):,})")

    # Figures (fig1 is produced elsewhere; 2-5 are generated here), in order.
    figure_jobs = [
        (make_fig2_visits_by_category, "fig2_visits_by_poi_type.png", {"topk": topk_categories}),
        (make_fig3_seasonality_by_category, "fig3_seasonality_by_poi_type.png", {"topk": topk_categories}),
        (make_fig4_cluster_centroids, "fig4_cluster_centroids.png", {"k": k_clusters}),
        (fit_static_predictor_and_importance, "fig5_feature_importance.png", {}),
    ]
    for make_figure, filename, options in figure_jobs:
        make_figure(panel, os.path.join("figs", filename), **options)
        print(f"Saved figs/{filename}")

    # Table 2
    table2_path = os.path.join("out", "table2_predictive_performance.csv")
    table2 = build_time_ordered_table2(panel, table2_path)
    print("Saved out/table2_predictive_performance.csv")

    return PanelResult(panel=panel, table2=table2)


# -----------------------------
# CLI - safe with Jupyter
# -----------------------------
def cli_main(argv=None):
    """Parse command-line options and run the pipeline.

    ``argv`` is the list of arguments to parse; ``None`` lets argparse read
    ``sys.argv``. ``--data_dir`` is required; every other option defaults to
    the corresponding ``run_pipeline`` default. Unrecognised arguments are
    ignored rather than rejected.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        required=True,
        help="Directory containing monthly CSVs like 2021-01--*.csv",
    )
    parser.add_argument("--state", default="VA")
    int_options = (
        ("--year", 2021),
        ("--chunksize", 200_000),
        ("--topk_categories", 8),
        ("--k_clusters", 4),
    )
    for flag, default in int_options:
        parser.add_argument(flag, type=int, default=default)

    # IMPORTANT: parse_known_args to ignore ipykernel -f ...json
    known, _ignored = parser.parse_known_args(argv)

    # The option names match run_pipeline's keyword parameters one-to-one.
    run_pipeline(**vars(known))


