In [1]:
import json
import re
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

In [4]:
def load_jsonl(path: str) -> pd.DataFrame:
    df = pd.read_json(path, lines=True)
    assert {"gt", "top_ids", "top_probs"}.issubset(df.columns), \
        "Input must have keys: gt, top_ids, top_probs"
    # Convert top_probs to float if it's a string
    if df["top_probs"].dtype == object:
        df["top_probs"] = df["top_probs"].apply(lambda x: [float(p) for p in x])
    return df

def print_last_k_lists(acc: np.ndarray, labels: List[str], support: np.ndarray, thresholds=(0.10, 0.05, 0.01)):
    """For each threshold, print the deepest rank whose correctness exceeds it, per bin.

    Bins are visited in descending order of their right edge, which is parsed
    from the bin label (e.g. "0.9–1.0" -> 1.0). For a visited bin b, the printed
    value is the 1-based index of the last rank r with ``acc[b, r] > threshold``
    (NaNs ignored), floored at 1. Bins with fewer than 10 examples always
    report 1.
    """
    def _upper_bound(text: str) -> float:
        # Drop any trailing "(...)" annotation, then take the number after the
        # en-dash (or hyphen); fall back to the last number found anywhere.
        head = text.split("(")[0].strip()
        delimiter = "–" if "–" in head else "-"
        pieces = [piece.strip() for piece in head.split(delimiter)]
        try:
            return float(pieces[-1])
        except Exception:
            numbers = re.findall(r"[-+]?\d*\.\d+|\d+", head)
            if numbers:
                return float(numbers[-1])
            return np.nan

    upper_edges = np.array([_upper_bound(text) for text in labels])
    visit_order = np.argsort(upper_edges)[::-1]

    for threshold in thresholds:
        depths = []
        for bin_idx in visit_order:
            rank_acc = acc[bin_idx, :]
            above = np.flatnonzero(~np.isnan(rank_acc) & (rank_acc > threshold))
            if support[bin_idx] < 10:
                # Too few examples to trust the rank profile.
                depths.append(1)
                continue
            last_rank = int(above[-1] + 1) if above.size > 0 else 0
            depths.append(max(1, last_rank))
        joined = ",".join(str(depth) for depth in depths)
        print(f"last_top_k_where_c_gt_{threshold:.2f} = [{joined}]")

def make_bins_top1(
    top1: np.ndarray,
    bins: int = 10,
    quantile: bool = False,
) -> Tuple[np.ndarray, List[str], np.ndarray]:
    """Assign each example to a bin of its top-1 probability.

    Parameters
    ----------
    top1 : array-like of float
        Top-1 probability per example.
    bins : int
        Number of bins. In quantile mode this is the requested number of
        quantiles; duplicate edges are dropped, so fewer bins may result.
    quantile : bool
        If True, use equal-count (quantile) bins; otherwise use equal-width
        bins over [0, 1].

    Returns
    -------
    codes : np.ndarray of int
        Bin index per example. In fixed mode values are clipped into
        [0, bins - 1]. In quantile mode NaN inputs get code -1.
    labels : list of str
        "lo–hi" label per bin.
    edges : np.ndarray
        ``len(labels) + 1`` bin edges.
    """
    if quantile:
        # pd.qcut on an ndarray returns a pandas.Categorical (not a Series);
        # it exposes .codes/.categories directly and has no .cat accessor.
        categorical = pd.qcut(np.asarray(top1), q=bins, duplicates="drop")
        codes = np.asarray(categorical.codes)
        intervals = categorical.categories
        labels = [f"{iv.left:.2f}–{iv.right:.2f}" for iv in intervals]
        edges = np.array([iv.left for iv in intervals] + [intervals[-1].right])
        return codes, labels, edges

    edges = np.linspace(0.0, 1.0, bins + 1)
    # right=True makes bins right-closed: (edges[i], edges[i+1]]. Exactly 0.0
    # lands at index -1 and is clipped into the first bin.
    codes = np.digitize(top1, edges, right=True) - 1
    codes = np.clip(codes, 0, bins - 1)
    # One decimal is enough when every edge is a multiple of 0.1 (e.g. bins=10).
    # Otherwise use two decimals so labels stay distinct (e.g. bins=20 -> 0.05 steps).
    decimals = 1 if np.allclose(edges, np.round(edges, 1)) else 2
    labels = [f"{edges[i]:.{decimals}f}–{edges[i+1]:.{decimals}f}" for i in range(bins)]
    return codes, labels, edges

def aggregate_rank_by_bin_conditional(
    df: pd.DataFrame,
    K: int,
    bins: int = 10,
    quantile_bins: bool = False,
):
    r"""
    Bin examples by top-1 probability and compute rank-wise statistics per bin.

    Computes, for every bin b and rank r (0-based, r < K):
      acc[b, r]      = \bar c_{m,r}  = P(gold at rank r | bin b)
      mu_cond[b, r]  = \tilde p_{m,r}= E[p^(r) | bin b, gold at rank r]   <-- conditional mean
      std_cond[b, r] = std of p^(r) over the same conditional subset
      (also returns unconditional means for reference)
      C_topK[b]      = sum_r mu_cond[b,r] * acc[b,r]  (bin accuracy restricted to gold ∈ top-K)

    Parameters
    ----------
    df : pd.DataFrame
        Must have columns ``gt``, ``top_ids`` and ``top_probs``. Each ``top_*``
        cell is a list, and all lists must share one length, because they are
        stacked with ``np.stack``.
    K : int
        Number of leading ranks to keep; must satisfy 1 <= K <= list length.
    bins : int
        Number of bins, forwarded to ``make_bins_top1``.
    quantile_bins : bool
        If True, use quantile bins; otherwise use fixed bins over [0, 1].

    Returns
    -------
    tuple
        ``(acc, mu_cond, std_cond, support, labels, mu_uncond, std_uncond, N_mr, C_topK)``:

        - ``acc``, ``mu_cond``, ``std_cond``, ``mu_uncond``, ``std_uncond``:
          float arrays of shape [B, K]. They are NaN for empty bins, and
          ``mu_cond``/``std_cond`` are also NaN where gold never sits at rank r.
        - ``support``: int array [B], the number of examples per bin.
        - ``labels``: list of B bin labels.
        - ``N_mr``: int array [B, K], the number of examples in bin b with gold at rank r.
        - ``C_topK``: float array [B].

    Raises
    ------
    ValueError
        If K < 1 or K exceeds the available top-list length.
    AssertionError
        If the stacked ``top_ids`` and ``top_probs`` shapes differ.
    """
    top_ids = np.stack(df["top_ids"].to_numpy())
    top_probs = np.stack(df["top_probs"].to_numpy())
    assert top_ids.shape == top_probs.shape, "top_ids/top_probs shape mismatch"

    if K < 1:
        # K=0 would otherwise fail later with an opaque IndexError on top_probs[:, 0].
        raise ValueError(f"K must be >= 1, got K={K}")
    if K > top_probs.shape[1]:
        raise ValueError(f"K={K} > available top list length {top_probs.shape[1]}")

    top_ids = top_ids[:, :K]
    top_probs = top_probs[:, :K]
    top1 = top_probs[:, 0]
    gt = df["gt"].to_numpy()

    codes, labels, _ = make_bins_top1(top1, bins=bins, quantile=quantile_bins)
    B = len(labels)

    # Rank-wise correctness \bar c_{m,r} and per-bin support N_m.
    acc = np.full((B, K), np.nan)
    support = np.zeros(B, dtype=int)

    # Conditional means \tilde p_{m,r}, plus unconditional \bar p_{m,r} for reference.
    mu_cond = np.full((B, K), np.nan)
    std_cond = np.full((B, K), np.nan)
    mu_uncond = np.full((B, K), np.nan)
    std_uncond = np.full((B, K), np.nan)

    # Counts per (bin, rank) where gold is at rank r.
    N_mr = np.zeros((B, K), dtype=int)

    for b in range(B):
        idx = (codes == b)
        n = int(idx.sum())
        support[b] = n
        if n == 0:
            continue

        probs_b = top_probs[idx, :]      # [n, K]
        ids_b = top_ids[idx, :]          # [n, K]
        gt_b = gt[idx]                   # [n]

        # Unconditional per-rank statistics over every example in the bin.
        mu_uncond[b, :] = probs_b.mean(axis=0)
        std_uncond[b, :] = probs_b.std(axis=0, ddof=0)

        # mask[t, r] is True iff the gold id sits at rank r for example t.
        mask = (ids_b == gt_b[:, None])  # [n, K] bool
        counts = mask.sum(axis=0)        # [K]
        N_mr[b, :] = counts
        acc[b, :] = counts / n

        # Conditional statistics: average p^(r) only over examples whose gold is at rank r.
        # Ranks with zero count stay NaN; their acc is 0, so the nansum below ignores them.
        for r in range(K):
            if counts[r] > 0:
                vals = probs_b[mask[:, r], r]
                mu_cond[b, r] = vals.mean()
                std_cond[b, r] = vals.std(ddof=0)

    # Bin accuracy restricted to top-K (sanity check / proxy):
    # C_topK[b] = sum_r \tilde p_{m,r} * \bar c_{m,r}
    C_topK = np.nansum(mu_cond * acc, axis=1)  # shape [B]

    return (
        acc,            # \bar c_{m,r}
        mu_cond,        # \tilde p_{m,r}  (use this for conditioned calculations)
        std_cond,
        support,
        labels,
        # extras (optional, but handy to inspect)
        mu_uncond,      # \bar p_{m,r}
        std_uncond,
        N_mr,
        C_topK
    )

# --- Expected accuracy vs. bin frequency: data preparation and dual-y-axis plot ---
def compute_expected_accuracy(mu: np.ndarray, acc: np.ndarray) -> np.ndarray:
    """
    Expected accuracy per bin: sum_k (w_k * c_k), where w_k = mu_k / sum_j mu_j.

    Each row of ``mu`` is first normalised to sum to 1, and the result then
    weights the rank-wise correctness ``acc``. NaN entries in either input are
    treated as 0, so they contribute nothing.

    Parameters
    ----------
    mu : np.ndarray, shape [B, K]
        Per-bin, per-rank probability means (may contain NaN).
    acc : np.ndarray, shape [B, K]
        Per-bin, per-rank correctness rates (may contain NaN).

    Returns
    -------
    np.ndarray, shape [B]
        Expected accuracy per bin. Rows whose ``mu`` sums to 0 (e.g. empty or
        all-NaN bins) have no defined weights and yield NaN.
    """
    mu_safe = np.nan_to_num(np.asarray(mu, dtype=float), nan=0.0)
    acc_safe = np.nan_to_num(np.asarray(acc, dtype=float), nan=0.0)
    row_mass = mu_safe.sum(axis=1, keepdims=True)  # [B, 1]
    # Divide only where the row mass is non-zero. Zero-mass rows stay NaN
    # instead of going through 0/0, which avoids the RuntimeWarning.
    weights = np.divide(
        mu_safe, row_mass,
        out=np.full_like(mu_safe, np.nan),
        where=row_mass != 0,
    )
    return (weights * acc_safe).sum(axis=1)  # shape: [B]

def calculate_expected_acc_and_freq_data(
    model_specs: List[Dict[str, str]],
    K: int = 10,
    bins: int = 10,
    quantile_bins: bool = False,
):
    """
    Load each model's trace file and derive the series needed for plotting.

    For every spec, in order, the function:

    - loads the JSONL file,
    - bins examples by top-1 probability,
    - computes the expected accuracy per bin, and
    - computes the share of examples per bin (rounded to 4 decimals, then
      scaled to percent).

    The per-bin proportions are printed for each model.

    Parameters
    ----------
    model_specs : list of {"label": str, "path": str}
        e.g., [{"label":"0.5B","path":"/path/a.jsonl"}, ...]
    K : int
        Number of top-ranked entries to use from each file.
    bins : int
        Number of confidence bins (default 10 for 0.0–0.1 ... 0.9–1.0).
    quantile_bins : bool
        If True, use quantile bins; else fixed 0..1 bins.

    Returns
    -------
    dict with keys:
        - 'all_expected': List[np.ndarray] - expected accuracy for each model
        - 'all_freq_pct': List[np.ndarray] - frequency percentages for each model
        - 'x_labels_common': List[str] - bin labels shared by all models (None if no specs)
        - 'model_labels': List[str] - model labels

    Raises
    ------
    ValueError
        If the bin labels differ between models.
    """
    collected = {
        'all_expected': [],
        'all_freq_pct': [],
        'x_labels_common': None,
        'model_labels': [],
    }

    for spec in model_specs:
        name = spec["label"]
        trace_path = spec["path"]
        frame = load_jsonl(trace_path)
        (acc, mu, _std, support, bin_labels,
         _mu_uncond, _std_uncond, _n_mr, _c_topk) = aggregate_rank_by_bin_conditional(
            frame, K=K, bins=bins, quantile_bins=quantile_bins
        )

        expected = compute_expected_accuracy(mu[:, :K], acc[:, :K])  # [B]

        n_total = support.sum()
        if n_total > 0:
            share = support / n_total
        else:
            share = np.zeros_like(support, dtype=float)
        share = np.round(share, 4)

        # All models must share one binning so they can be drawn on a common x-axis.
        if collected['x_labels_common'] is None:
            collected['x_labels_common'] = bin_labels
        elif bin_labels != collected['x_labels_common']:
            raise ValueError("Bin labels differ across models; ensure consistent binning settings.")

        collected['all_expected'].append(expected)
        collected['all_freq_pct'].append(100.0 * share)
        collected['model_labels'].append(name)

        print(f"[{name}] frequency per bin (proportion, 4 dp): {share.tolist()}")

    return collected

def plot_expected_acc_and_freq(
    data: Dict,
    figsize: Tuple[float, float] = (12, 6),
    save_path: str | None = None,
    colors: List[str] | None = None,
    title: str = "",
    show_grid: bool = True,
    grid_alpha: float = 0.2,
    linewidth: float = 1.0,
    marker_size: int = 6,
    freq_ymax: float | None = None,
):
    """
    Plot expected accuracy (left y-axis) and bin frequency (right y-axis).

    - X-axis ticks at 0.0, 0.1, ..., 1.0
    - Data points placed at bin midpoints, assuming equal-width bins over [0, 1]
    - Expected accuracy is drawn dashed and frequency solid; one color per model
    - Two legends in the upper left: model colors and line styles

    Parameters
    ----------
    data : dict
        Output from calculate_expected_acc_and_freq_data().
    figsize : (w, h)
        Figure size in inches.
    save_path : str or None
        If provided, saves the figure (dpi=200); otherwise calls plt.show().
    colors : List[str] or None
        Custom colors for models, reused cyclically if there are fewer colors
        than models. If None or empty, darker (x0.7) versions of the tab10
        colors are used.
    title : str
        Plot title.
    show_grid : bool
        Whether to draw horizontal grid lines.
    grid_alpha : float
        Grid transparency.
    linewidth : float
        Line width for plots.
    marker_size : int
        Unused; kept for backward compatibility, since no markers are drawn.
    freq_ymax : float or None
        Upper limit of the frequency axis in percent. If None, uses 55, or
        about 5% above the largest frequency when that exceeds 55, so no
        data is clipped.
    """
    all_expected = data['all_expected']
    all_freq_pct = data['all_freq_pct']
    x_labels_common = data['x_labels_common']
    model_labels = data['model_labels']
    n_models = len(model_labels)

    if not colors:
        # Darker (x0.7 RGB) versions of the tab10 color cycle.
        color_cycle = plt.cm.tab10.colors
        colors = [
            tuple(c * 0.7 for c in color_cycle[i % len(color_cycle)][:3])
            for i in range(n_models)
        ]
    # Cycle the palette so a short custom list does not raise IndexError.
    model_colors = [colors[i % len(colors)] for i in range(n_models)]

    num_bins = len(x_labels_common)
    # Bin midpoints in [0, 1].
    # NOTE(review): assumes equal-width bins; with quantile bins these are nominal positions.
    x_mid = (np.arange(num_bins) + 0.5) / num_bins

    # X ticks exactly at 0., 0.1, ..., 1.0 (match two-subplot style).
    x_ticks = np.linspace(0.0, 1.0, 11)
    x_tick_labels = []
    for t in x_ticks:
        if np.isclose(t, 0.0):
            x_tick_labels.append("0.")
        elif np.isclose(t, 1.0):
            x_tick_labels.append("1.0")
        else:
            x_tick_labels.append(f"{t:.1f}")

    dashed = (0, (3, 2))

    fig, ax_left = plt.subplots(figsize=figsize)
    ax_right = ax_left.twinx()

    # Left y-axis: expected accuracy (0..1), dashed lines.
    for i, model_label in enumerate(model_labels):
        ax_left.plot(
            x_mid, all_expected[i],
            linestyle=dashed,
            linewidth=linewidth, color=model_colors[i],
            label=f"{model_label} (expected acc)"
        )

    ax_left.set_xlim(0.0, 1.0)
    ax_left.set_ylim(0.0, 1.0)
    ax_left.set_xticks(x_ticks)
    ax_left.set_xticklabels(x_tick_labels)
    ax_left.set_xlabel("Confidence bin")
    ax_left.set_ylabel("Expected accuracy")

    # Right y-axis: frequency in percent, solid lines.
    for i, model_label in enumerate(model_labels):
        ax_right.plot(
            x_mid, all_freq_pct[i],
            linestyle="-", linewidth=linewidth, color=model_colors[i],
            label=f"{model_label} (freq)"
        )

    if freq_ymax is None:
        # A fixed 55% cap would clip bins holding more than 55% of examples.
        peak = max((float(np.max(f)) for f in all_freq_pct if np.size(f) > 0), default=0.0)
        freq_ymax = max(55.0, 1.05 * peak)
    ax_right.set_ylim(0.0, freq_ymax)
    ax_right.set_ylabel("Frequency of occurrences (%)")

    # Legends: (1) colors map to models, (2) line styles map to metrics.
    model_handles = [
        Line2D([0], [0], color=model_colors[i], lw=3, label=model_labels[i])
        for i in range(n_models)
    ]
    style_handles = [
        Line2D([0], [0], color="black", lw=2, linestyle=dashed, label="Expected accuracy"),
        Line2D([0], [0], color="black", lw=2, linestyle="-", label="Frequency"),
    ]

    # Attach legends to the twin (top-most) axes. Legends on ax_left would be
    # drawn underneath ax_right's frequency lines. Axes-fraction anchors are
    # identical because twinx shares the same position.
    first_legend = ax_right.legend(handles=model_handles, title="Model size", loc="upper left", bbox_to_anchor=(0.03, 0.98), borderaxespad=0.0)
    ax_right.add_artist(first_legend)
    ax_right.legend(handles=style_handles, loc="upper left", bbox_to_anchor=(0.2, 0.98), borderaxespad=0.0, title=None)

    if show_grid:
        ax_left.grid(True, axis="y", alpha=grid_alpha)
    ax_left.set_title(title)

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
        print(f"Saved figure to: {save_path}")
    else:
        plt.show()

def plot_expected_acc_and_freq_from_specs(
    model_specs: List[Dict[str, str]],
    K: int = 10,
    bins: int = 10,
    quantile_bins: bool = False,
    figsize: Tuple[int, int] = (12, 6),
    save_path: str | None = None,
    **plot_kwargs
):
    """
    One-shot helper: compute the per-model series, then plot them.

    Equivalent to calling calculate_expected_acc_and_freq_data() followed by
    plot_expected_acc_and_freq(). It keeps the original single-call entry point
    for backward compatibility. Extra keyword arguments go to the plotting
    function unchanged.
    """
    plot_expected_acc_and_freq(
        calculate_expected_acc_and_freq_data(
            model_specs=model_specs,
            K=K,
            bins=bins,
            quantile_bins=quantile_bins,
        ),
        figsize=figsize,
        save_path=save_path,
        **plot_kwargs,
    )

In [None]:
import os

# Locations of the per-model trace files and of the output figure.
TRACES_DIR = "/selective_greedy/softmax_values"
FIGURES_DIR = "/selective_greedy/post_processing/figures"
TOP_K = 20    # number of ranked candidates used per example (must not exceed the trace list length)
N_BINS = 10   # equal-width confidence bins: 0.0–0.1, ..., 0.9–1.0

models_to_plot = [
    {"label": "0.5B", "path": f"{TRACES_DIR}/topk_traces_20250904_202648.jsonl"},
    {"label": "1.5B", "path": f"{TRACES_DIR}/topk_traces_20250904_202131.jsonl"},
    {"label": "3B",   "path": f"{TRACES_DIR}/topk_traces_20250904_194548.jsonl"},
    {"label": "7B",   "path": f"{TRACES_DIR}/topk_traces_20250904_190655.jsonl"},
    {"label": "14B",  "path": f"{TRACES_DIR}/topk_traces_20250904_192814.jsonl"},
]

# Step 1: compute the per-model series once; later cells re-plot from `data`.
print("Calculating data...")
data = calculate_expected_acc_and_freq_data(
    model_specs=models_to_plot,
    K=TOP_K,
    bins=N_BINS,
    quantile_bins=False,  # fixed, equal-width bin edges
)

# Step 2: plot. Okabe–Ito colorblind-friendly palette (one color per model; extra entries unused).
custom_colors = ["#0072B2", "#E69F00", "#009E73", "#D55E00", "#CC79A7", "#56B4E9"]

# savefig does not create missing directories.
os.makedirs(FIGURES_DIR, exist_ok=True)
plot_expected_acc_and_freq(
    data=data,
    figsize=(8, 4),
    colors=custom_colors,
    linewidth=1.2,  # slightly thicker than the 1.0 default
    show_grid=True,
    grid_alpha=0.1,
    save_path=f"{FIGURES_DIR}/var_size.pdf",
)

In [None]:
# Example: re-plot with different styling, reusing `data` computed in the previous cell
# (run that cell first). Re-run this cell with other parameters without recalculating.

# Okabe–Ito colorblind-friendly palette (one color per model; extra entries are unused).
custom_colors = ["#0072B2", "#E69F00", "#009E73", "#D55E00", "#CC79A7", "#56B4E9"]

plot_expected_acc_and_freq(
    data=data,  # reuse the data computed above
    figsize=(8, 3.5),
    colors=custom_colors,
    linewidth=1.2,  # slightly thicker than the 1.0 default
    show_grid=True,
    grid_alpha=0.1,
    save_path="var_size.pdf",  # relative to the kernel's current working directory
)
