# In [1]:
import os
import sys
import json
import math
import argparse
from collections import Counter
from datetime import datetime
import numpy as np
import pandas as pd
from sqlalchemy import create_engine, text
from scipy.stats import ks_2samp, chisquare

# In [2]:
def parse_args():
    """Parse command-line options for the sample-vs-full validation run.

    ``--digest-analysis`` is mandatory. ``--db-url`` defaults to the
    ``DATABASE_URL`` environment variable (read at call time), or "" if unset.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    parser = argparse.ArgumentParser(description="Validate sample vs. full dataset (bias, KS/Chi² tests).")
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--digest-analysis", dict(required=True, help="Path to digest_analysis.jsonl (sample)")),
        ("--unique-repos", dict(help="Path to unique_repos.jsonl (full; fallback if no precomputed CSV)")),
        ("--combined-tags", dict(help="Path to combined_tags.jsonl (full; fallback if no precomputed CSV)")),
        ("--db-url", dict(default=os.environ.get("DATABASE_URL", ""), help="Postgres URL for sample DB")),
        ("--outdir", dict(default="analysis_outputs", help="Output directory for CSVs")),
        ("--precomputed-dir", dict(default="", help="Directory containing precomputed CSVs from Go")),
        ("--max-ks-samples", dict(type=int, default=500000, help="Max reservoir sample size for JSONL fallback")),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()

# ---------- helpers ----------

def ensure_outdir(p):
    """Create directory *p*, including missing parents, if it is absent.

    Calling it again on an existing directory is harmless (``exist_ok=True``).
    """
    os.makedirs(name=p, exist_ok=True)
    return None

def month_str(ts: pd.Timestamp) -> str:
    """Return the ``YYYY-MM`` bucket (zero-padded) of a timestamp-like value.

    The input is passed through ``pd.to_datetime`` first, so strings and
    datetimes are accepted. Year and month are taken as-is; no timezone
    conversion happens here.
    """
    stamp = pd.to_datetime(ts)
    return "%04d-%02d" % (stamp.year, stamp.month)

def read_hist_csv(path: str) -> Counter:
    """Load a ``key,count`` histogram CSV into a Counter.

    Keys are read verbatim as strings. Type inference is disabled so keys
    such as ``"01"`` keep their leading zero, and ``"NA"``/``"null"`` are not
    turned into NaN. Rows repeating a key are summed rather than overwriting
    each other.

    Args:
        path: CSV file with ``key`` and ``count`` columns.

    Returns:
        Counter mapping key -> int count. It is empty when the file has no
        bytes, no rows, or lacks either column.
    """
    try:
        # Read everything as text so keys are never coerced to numbers/NaN.
        df = pd.read_csv(path, dtype=str, keep_default_na=False)
    except pd.errors.EmptyDataError:
        return Counter()  # zero-byte file: nothing to count
    if df.empty or "key" not in df.columns or "count" not in df.columns:
        return Counter()
    hist = Counter()
    for key, n in zip(df["key"], pd.to_numeric(df["count"])):
        hist[str(key)] += int(n)
    return hist

def read_series_csv(path: str) -> np.ndarray:
    """Load the ``value`` column of a CSV as a float ndarray.

    Args:
        path: CSV file expected to contain a ``value`` column.

    Returns:
        1-D float array of the column. It is empty when the file has no
        bytes, no rows, or no ``value`` column.
    """
    try:
        df = pd.read_csv(path)
    except pd.errors.EmptyDataError:
        # Zero-byte file: treat like a header-only file instead of crashing.
        return np.array([], dtype=float)
    if "value" not in df.columns or df.empty:
        return np.array([], dtype=float)
    return df["value"].astype(float).to_numpy()

def normalize_cat(s):
    """Return *s* as a whitespace-stripped string.

    ``None`` and strings that are blank after stripping become ``"unknown"``.
    """
    cleaned = "" if s is None else str(s).strip()
    return cleaned or "unknown"

def make_bias_table(full_counts: Counter, sample_counts: Counter) -> pd.DataFrame:
    """Compare categorical distributions of the full dataset and the sample.

    For each key in either counter (sorted), report the raw counts, the
    proportion within each population, and ``p_full / p_sample``. When the
    sample proportion is zero, that ratio is ``inf`` if the full proportion is
    positive, and ``1.0`` if both are zero.

    Returns:
        DataFrame with columns key, count_full, count_sample, p_full,
        p_sample, bias_full_over_sample (no columns when both inputs are empty).
    """
    total_full = sum(full_counts.values())
    total_sample = sum(sample_counts.values())

    def share(count, total):
        return count / total if total > 0 else 0.0

    def ratio(p_f, p_s):
        if p_s > 0:
            return p_f / p_s
        return np.inf if p_f > 0 else 1.0

    records = []
    for category in sorted(set(full_counts.keys()).union(sample_counts.keys())):
        count_f = full_counts.get(category, 0)
        count_s = sample_counts.get(category, 0)
        p_f = share(count_f, total_full)
        p_s = share(count_s, total_sample)
        records.append({
            "key": category,
            "count_full": count_f,
            "count_sample": count_s,
            "p_full": p_f,
            "p_sample": p_s,
            "bias_full_over_sample": ratio(p_f, p_s),
        })
    return pd.DataFrame(records)

def make_numeric_bias_bins(full_vals: np.ndarray, sample_vals: np.ndarray, nbins: int = 20) -> pd.DataFrame:
    """Bin a numeric variable on full-data quantiles and compare proportions.

    Bin edges are the (de-duplicated) quantiles of the full data. If fewer
    than 3 unique edges result, ``nbins`` equal-width bins over
    [min, max] are used instead. If the full data is constant, a single
    [v, v] bin is used. NaNs are dropped from both inputs.

    Sample values outside the full-data range are clipped into the outermost
    bins. Otherwise ``np.histogram`` would silently drop them, and
    ``p_sample`` would no longer sum to 1, skewing every bias ratio.

    Args:
        full_vals: Full-population values (array-like of numbers).
        sample_vals: Sample values (array-like of numbers).
        nbins: Target number of quantile bins.

    Returns:
        DataFrame with columns bin_left, bin_right, count_full, count_sample,
        p_full, p_sample, bias_full_over_sample. The bias is p_full / p_sample;
        it is ``inf`` when only the full data has mass and ``1.0`` when neither
        does. The frame is empty when either input has no non-NaN values.
    """
    columns = ["bin_left", "bin_right", "count_full", "count_sample", "p_full", "p_sample", "bias_full_over_sample"]
    full_vals = np.asarray(full_vals, dtype=float)
    sample_vals = np.asarray(sample_vals, dtype=float)
    full_vals = full_vals[~np.isnan(full_vals)]
    sample_vals = sample_vals[~np.isnan(sample_vals)]
    if len(full_vals) == 0 or len(sample_vals) == 0:
        return pd.DataFrame(columns=columns)

    edges = np.unique(np.quantile(full_vals, np.linspace(0, 1, nbins + 1)))
    if len(edges) < 3:
        lo, hi = float(full_vals.min()), float(full_vals.max())
        # Constant data: one [v, v] bin instead of many zero-width bins.
        edges = np.linspace(lo, hi, nbins + 1) if hi > lo else np.array([lo, hi])

    # Keep out-of-range sample mass in the edge bins (histogram would drop it).
    sample_clipped = np.clip(sample_vals, edges[0], edges[-1])
    cf, _ = np.histogram(full_vals, bins=edges)
    cs, _ = np.histogram(sample_clipped, bins=edges)
    n_full, n_sam = cf.sum(), cs.sum()

    rows = []
    for i in range(len(edges) - 1):
        pf = cf[i] / n_full if n_full > 0 else 0.0
        ps = cs[i] / n_sam if n_sam > 0 else 0.0
        bias = (pf / ps) if ps > 0 else (np.inf if pf > 0 else 1.0)
        rows.append({
            "bin_left": edges[i],
            "bin_right": edges[i + 1],
            "count_full": int(cf[i]),
            "count_sample": int(cs[i]),
            "p_full": pf,
            "p_sample": ps,
            "bias_full_over_sample": bias,
        })
    return pd.DataFrame(rows, columns=columns)

def save_series_counts(counts: Counter, out_path: str):
    """Write a Counter to CSV as ``key,count,proportion`` rows sorted by key.

    ``proportion`` is ``count / total``, or 0.0 when the total is not positive.
    The DataFrame index is not written.
    """
    grand_total = sum(counts.values())
    table = [
        {"key": name, "count": tally, "proportion": (tally / grand_total if grand_total > 0 else 0.0)}
        for name, tally in sorted(counts.items(), key=lambda item: item[0])
    ]
    pd.DataFrame(table).to_csv(out_path, index=False)

def compute_unique_layer_share_from_digest(digest_analysis_path):
    """Count JSONL records and how many are referenced by exactly one repo.

    Each line is parsed as JSON. Blank lines, undecodable lines and JSON
    values that are not objects are skipped. A record's repo count is
    ``repo_count`` when present, otherwise ``len(repos)`` (missing or null
    ``repos`` counts as 0).

    Args:
        digest_analysis_path: Path to the digest analysis JSONL file.

    Returns:
        Tuple ``(total, uniq, share)``: ``total`` is the number of records
        counted, ``uniq`` is how many have a repo count of 1, and ``share`` is
        ``uniq / total`` (NaN when ``total`` is 0).
    """
    total = 0
    uniq = 0
    with open(digest_analysis_path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            if not isinstance(obj, dict):
                continue  # e.g. a bare list/number line; .get() would fail
            total += 1
            rc = obj.get("repo_count")
            if rc is None:
                rc = len(obj.get("repos") or [])
            if rc == 1:
                uniq += 1
    share = (uniq / total) if total > 0 else math.nan
    return total, uniq, share

# ---------- FULL (prefer precomputed CSVs) ----------

def load_full_tags(pre_dir: str, combined_tags_path: str, max_ks_samples: int):
    """Load full-dataset tag statistics, preferring precomputed CSVs.

    If ``pre_dir`` is set and all four precomputed CSVs there are non-empty,
    they are returned directly. Otherwise the ``combined_tags`` JSONL file is
    streamed (slow). That scan counts ``last_pushed`` months and
    ``status`` values, and reservoir-samples ``last_pushed`` (epoch seconds,
    UTC) and ``size`` to at most ``max_ks_samples`` values each.

    Args:
        pre_dir: Directory with precomputed CSVs, or "" to skip them.
        combined_tags_path: JSONL fallback; required when the CSVs are unusable.
        max_ks_samples: Maximum reservoir size per numeric series.

    Returns:
        dict with ``last_pushed_month_counts`` (Counter),
        ``last_pushed_epochs_sample`` (float ndarray), ``size_sample``
        (float ndarray) and ``status_counts`` (Counter).

    Raises:
        ValueError: if the JSONL fallback is needed but no path was given.
    """
    if pre_dir:
        lp_month = read_hist_csv(os.path.join(pre_dir, "full_last_pushed_month.csv"))
        status = read_hist_csv(os.path.join(pre_dir, "full_status.csv"))
        lp_epochs = read_series_csv(os.path.join(pre_dir, "full_last_pushed_epoch_reservoir.csv"))
        size = read_series_csv(os.path.join(pre_dir, "full_size_reservoir.csv"))
        if sum(lp_month.values()) > 0 and len(lp_epochs) > 0 and len(size) > 0 and sum(status.values()) > 0:
            return {
                "last_pushed_month_counts": lp_month,
                "last_pushed_epochs_sample": lp_epochs.astype(float),
                "size_sample": size.astype(float),
                "status_counts": status,
            }
        # Precomputed data incomplete: fall through to the JSONL scan.
    if not combined_tags_path:
        raise ValueError(
            "Full tag statistics unavailable: precomputed CSVs are missing or empty "
            "and no combined_tags JSONL path was provided."
        )

    last_pushed_month_counts = Counter()
    last_pushed_reservoir = []
    size_reservoir = []
    status_counts = Counter()
    n_last = 0
    n_size = 0
    # One seeded generator is shared by both reservoirs, so results are
    # reproducible for a given input file.
    rng = np.random.default_rng(42)

    def reservoir_add(reservoir, seen, value):
        # Algorithm R: keep the first max_ks_samples values, then overwrite a
        # random slot with probability max_ks_samples / (seen + 1).
        if seen < max_ks_samples:
            reservoir.append(value)
        else:
            j = rng.integers(0, seen + 1)
            if j < max_ks_samples:
                reservoir[j] = value
        return seen + 1

    with open(combined_tags_path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            if not isinstance(obj, dict):
                continue  # not a record; .get() would fail on it
            lp = obj.get("last_pushed")
            if lp:
                ts = pd.to_datetime(lp, utc=True, errors="coerce")
                if pd.notna(ts):
                    last_pushed_month_counts[month_str(ts)] += 1
                    # Timestamp.value is nanoseconds since the epoch (UTC).
                    n_last = reservoir_add(last_pushed_reservoir, n_last, ts.value / 1e9)
            sz = obj.get("size")
            if sz is not None:
                try:
                    val = float(sz)
                except (TypeError, ValueError):
                    val = None  # non-numeric size: skip this value only
                if val is not None:
                    n_size = reservoir_add(size_reservoir, n_size, val)
            status_counts[normalize_cat(obj.get("status"))] += 1
    return {
        "last_pushed_month_counts": last_pushed_month_counts,
        "last_pushed_epochs_sample": np.array(last_pushed_reservoir, dtype=float),
        "size_sample": np.array(size_reservoir, dtype=float),
        "status_counts": status_counts,
    }

def load_full_repos(pre_dir: str, unique_repos_path: str, max_ks_samples: int):
    """Load full-dataset repository statistics, preferring precomputed CSVs.

    If ``pre_dir`` is set and both precomputed CSVs there are non-empty, they
    are returned directly. Otherwise the ``unique_repos`` JSONL file is
    streamed (slow). That scan reservoir-samples ``pull_count`` (at most
    ``max_ks_samples`` values) and counts records as ``official`` or
    ``unofficial`` by the truthiness of ``is_official``.

    Args:
        pre_dir: Directory with precomputed CSVs, or "" to skip them.
        unique_repos_path: JSONL fallback; required when the CSVs are unusable.
        max_ks_samples: Maximum reservoir size for ``pull_count``.

    Returns:
        dict with ``pull_count_sample`` (float ndarray) and
        ``is_official_counts`` (Counter).

    Raises:
        ValueError: if the JSONL fallback is needed but no path was given.
    """
    if pre_dir:
        io_counts = read_hist_csv(os.path.join(pre_dir, "full_is_official.csv"))
        pull = read_series_csv(os.path.join(pre_dir, "full_pull_count_reservoir.csv"))
        if sum(io_counts.values()) > 0 and len(pull) > 0:
            return {
                "pull_count_sample": pull.astype(float),
                "is_official_counts": io_counts,
            }
        # Precomputed data incomplete: fall through to the JSONL scan.
    if not unique_repos_path:
        raise ValueError(
            "Full repository statistics unavailable: precomputed CSVs are missing or empty "
            "and no unique_repos JSONL path was provided."
        )

    pull_reservoir = []
    n_pull = 0
    is_official_counts = Counter()
    rng = np.random.default_rng(43)

    def reservoir_add(reservoir, seen, value):
        # Algorithm R: keep the first max_ks_samples values, then overwrite a
        # random slot with probability max_ks_samples / (seen + 1).
        if seen < max_ks_samples:
            reservoir.append(value)
        else:
            j = rng.integers(0, seen + 1)
            if j < max_ks_samples:
                reservoir[j] = value
        return seen + 1

    with open(unique_repos_path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            if not isinstance(obj, dict):
                continue  # not a record; .get() would fail on it
            pc = obj.get("pull_count")
            if pc is not None:
                try:
                    val = float(pc)
                except (TypeError, ValueError):
                    val = None  # non-numeric pull_count: skip this value only
                if val is not None:
                    n_pull = reservoir_add(pull_reservoir, n_pull, val)
            iso = bool(obj.get("is_official"))
            is_official_counts["official" if iso else "unofficial"] += 1
    return {
        "pull_count_sample": np.array(pull_reservoir, dtype=float),
        "is_official_counts": is_official_counts,
    }

# ---------- SAMPLE (DB) ----------

def load_sample_from_db(db_url):
    """Load sample-side statistics from the sample database.

    Queries ``tags`` (last_pushed, size, status) and ``repositories``
    (pull_count, is_official) through SQLAlchemy. The engine is disposed
    afterwards so pooled connections are not left open.

    ``last_pushed`` is normalised to UTC: naive values are taken as UTC and
    aware values are converted. This matches the full-dataset JSONL loader,
    which parses with ``utc=True``, so month buckets line up on both sides.

    Args:
        db_url: SQLAlchemy database URL.

    Returns:
        dict with:
          - ``last_pushed_epochs``: float ndarray, seconds since the Unix epoch.
          - ``last_pushed_month_counts``: Counter of ``YYYY-MM`` strings (UTC).
          - ``size_vals``, ``pull_vals``: float ndarrays.
          - ``status_counts``: Counter; blank/NULL status becomes ``unknown``.
          - ``is_official_counts``: Counter of ``official``/``unofficial``.
    """
    engine = create_engine(db_url)
    try:
        with engine.connect() as conn:
            # last_pushed: parse once, in UTC.
            df_lp = pd.read_sql_query(text("SELECT last_pushed FROM tags WHERE last_pushed IS NOT NULL"), conn)
            lp_ts = pd.to_datetime(df_lp["last_pushed"], utc=True)
            # Timedelta division is independent of the datetime resolution
            # (ns/us), unlike astype("int64") / 1e9.
            lp_epochs = ((lp_ts - pd.Timestamp(0, tz="UTC")) / pd.Timedelta(seconds=1)).to_numpy(dtype=float)
            month_counts = Counter(
                f"{int(y):04d}-{int(m):02d}" for y, m in zip(lp_ts.dt.year, lp_ts.dt.month)
            )

            # size
            df_sz = pd.read_sql_query(text("SELECT size FROM tags WHERE size IS NOT NULL"), conn)
            size_vals = df_sz["size"].astype(float).to_numpy()

            # status (same "unknown" normalisation as normalize_cat)
            df_st = pd.read_sql_query(text("SELECT COALESCE(NULLIF(TRIM(status), ''), 'unknown') AS status FROM tags"), conn)
            status_counts = Counter(df_st["status"].tolist())

            # pull_count
            df_pc = pd.read_sql_query(text("SELECT pull_count FROM repositories WHERE pull_count IS NOT NULL"), conn)
            pull_vals = df_pc["pull_count"].astype(float).to_numpy()

            # is_official: NULL falls into the ELSE branch -> 'unofficial'
            df_io = pd.read_sql_query(text("SELECT CASE WHEN is_official THEN 'official' ELSE 'unofficial' END AS cat FROM repositories"), conn)
            io_counts = Counter(df_io["cat"].tolist())
    finally:
        # The engine is local to this call; release its connection pool.
        engine.dispose()

    return {
        "last_pushed_epochs": lp_epochs,
        "last_pushed_month_counts": month_counts,
        "size_vals": size_vals,
        "status_counts": status_counts,
        "pull_vals": pull_vals,
        "is_official_counts": io_counts,
    }
