# Pass@k curves

In [None]:
from typing import Optional, TypeVar
import matplotlib.pyplot as plt
from pathlib import Path
from coderm.eval.metrics import get_pass_ks, pass_at_k, get_pass_ks_given_public, get_num_completions_per_problem, get_num_pass_public_per_problem
from coderm.utils import gunzip_json_read

from adjustText import adjust_text
import numpy as np
from math import comb
from itertools import product
import os

def calcEstVar(n, k, c):
    """Return the variance of the unbiased pass@k estimator for one problem.

    The estimator of ``(1 - p) ** k`` is ``X = comb(n - C, k) / comb(n, k)`` with
    ``C ~ Binomial(n, p)`` correct samples out of ``n``. pass@k is ``1 - X``, so
    both have the same variance. The unknown ``p`` is plugged in as ``c / n``.

    Args:
        n: number of samples drawn for the problem (must be >= 1).
        k: the k in pass@k (0 <= k <= n).
        c: observed number of correct samples (0 <= c <= n).

    Returns:
        ``E[X**2] - E[X]**2`` for the estimator ``X`` described above.

    Raises:
        ValueError: if the arguments are outside the ranges above.
    """
    from math import exp, log, log1p

    if n < 1 or not 0 <= k <= n or not 0 <= c <= n:
        raise ValueError(f"need n >= 1, 0 <= k <= n and 0 <= c <= n; got n={n}, k={k}, c={c}")
    p = c / n
    total = comb(n, k)
    second_moment = 0.0
    # comb(n - i, k) == 0 once i > n - k, so those terms contribute nothing.
    for i in range(n - k + 1):
        # Exact big-int true division stays correctly rounded even when both
        # binomials exceed the float range (n around 1200), where multiplying a
        # huge int by a float would raise OverflowError.
        ratio = comb(n - i, k) / total
        # Binomial(n, p) pmf at i, evaluated in log space for the same reason.
        if p == 0.0:
            pmf = 1.0 if i == 0 else 0.0
        elif p == 1.0:
            pmf = 1.0 if i == n else 0.0
        else:
            pmf = exp(log(comb(n, i)) + i * log(p) + (n - i) * log1p(-p))
        second_moment += pmf * ratio * ratio
    return second_moment - (1 - p) ** (2 * k)

def convert_basic_prompting(method: str):
    """Map a method name to the prefix used in result directory names.

    This is currently the identity. An older mapping of ``"basic_prompting"`` to
    ``"default"`` is disabled, so every method name comes back unchanged.
    """
    directory_prefix = method
    return directory_prefix

T = TypeVar('T')


def split_dict_by_datasets(d: dict[tuple[str, ...], T]) -> dict[str, dict[tuple[str, ...], T]]:
    """Group a tuple-keyed dict by the first element of each key.

    ``{(a, *rest): v}`` becomes ``{a: {tuple(rest): v}}``. Despite the name, this
    works for any key length. Applied to ``(dataset, model, method)`` keys, it
    groups by dataset. Applied again to the resulting ``(model, method)`` keys, it
    groups by model and leaves ``(method,)`` keys.
    """
    grouped: dict[str, dict[tuple[str, ...], T]] = {}
    for key, value in d.items():
        grouped.setdefault(key[0], {})[key[1:]] = value
    return grouped

def calc_pass_k(ns, cs, k):
    """Score every problem with ``pass_at_k``.

    ``ns[i]`` and ``cs[i]`` are the sample count and correct-sample count for
    problem ``i``. They are paired with ``zip``, so the result is as long as the
    shorter input.
    """
    return [pass_at_k(n, c, k) for n, c in zip(ns, cs)]

def count_ns(items, is_public: bool = False) -> Optional[list[int]]:
    """Count samples per problem.

    Args:
        items: problem records, each with a ``"results"`` list of per-sample dicts.
        is_public: if False, count every sample. If True, count only samples whose
            ``"passing_public"`` value is truthy.

    Returns:
        One count per item, in order. When ``is_public`` is True and any sample
        has no ``"passing_public"`` value (missing or None), returns None instead.
    """
    if not is_public:
        return [len(item["results"]) for item in items]
    counts = []
    for item in items:
        results = item["results"]
        # Public-test outcomes were not recorded for this run.
        if any(result.get("passing_public") is None for result in results):
            return None
        counts.append(sum(1 for result in results if result["passing_public"]))
    return counts

def count_cs(items, is_public: bool = False) -> Optional[list[int]]:
    """Count correct samples per problem.

    Args:
        items: problem records, each with a ``"results"`` list of per-sample dicts
            carrying a ``"passing"`` flag (and ``"passing_public"`` when public).
        is_public: if False, count samples with ``"passing"`` set. If True, count
            samples that pass both the private and the public tests, and print a
            warning for each sample that passes private but not public tests.

    Returns:
        One count per item, in order. When ``is_public`` is True and any sample
        has no ``"passing_public"`` value (missing or None), returns None instead.
    """
    if not is_public:
        return [sum(ex["passing"] for ex in item["results"]) for item in items]
    counts = []
    for item in items:
        correct = 0
        for result in item["results"]:
            # Public-test outcomes were not recorded for this run.
            if result.get("passing_public") is None:
                return None
            if not result["passing"]:
                continue
            if result["passing_public"]:
                correct += 1
            else:
                print("Warning: passes private but does not pass public")
        counts.append(correct)
    return counts

def calc_pass_k(ns, cs, k):
    """Return ``pass_at_k(n, c, k)`` for each ``(n, c)`` pair from ``zip(ns, cs)``.

    NOTE: this file defines ``calc_pass_k`` twice with the same behavior. This
    later definition is the one bound at module level.
    """
    scores = []
    for sample_count, correct_count in zip(ns, cs):
        scores.append(pass_at_k(sample_count, correct_count, k))
    return scores

class Result():
    """Lazily loaded evaluation results for one (dataset, method, model, temperature) run.

    The pass@k results are read from
    ``<base_directory>/<dataset>/<method>_<model>_temp<temp>``. An optional
    diversity array is read from
    ``<diversity_directory>/<dataset>/<same run name>/results.npy``.
    Nothing is read from disk until a ``get_*`` method is called.
    """

    def __init__(self, base_directory: str, diversity_directory: Optional[str], dataset: str, method: str, model: str, temp: float = 0.9) -> None:
        self.dataset = dataset
        self.method = method
        self.model = model
        self.temp = temp
        run_name = convert_basic_prompting(method) + "_" + model + f"_temp{self.temp}"
        self.path = os.path.join(base_directory, dataset, run_name)
        if diversity_directory is None:
            self.diversity_path = None
        else:
            self.diversity_path = os.path.join(diversity_directory, dataset, run_name, "results.npy")

        # Lazily populated caches; None means "not loaded yet".
        self.pass_ks = None
        self.pass_ks_given_public = None
        self.num_pass_public = None
        # Never populated: no loader in this class computes standard deviations.
        self.stds = None
        self.diversities = None

    def pass_k_exists(self) -> bool:
        """Return True if the pass@k results file is present on disk."""
        return Path(self.path).exists()

    def diversity_exists(self) -> bool:
        """Return True if a diversity path was configured and the file exists."""
        if self.diversity_path is not None:
            return Path(self.diversity_path).exists()
        return False

    def _load_pass_k_results(self):
        """Cache the mean (over problems) pass@k for k = 1..max samples per problem."""
        if self.pass_ks is not None:
            return
        if not self.pass_k_exists():
            raise FileNotFoundError(f"pass@k results not found: {self.path}")

        items = gunzip_json_read(self.path)["items"]
        ns = count_ns(items)
        cs = count_cs(items)
        upper_k = max(len(item["results"]) for item in items)
        self.pass_ks = np.array([np.mean(calc_pass_k(ns, cs, k)) for k in range(1, upper_k + 1)])
        # TODO: populate self.stds (e.g. per-problem calcEstVar); not implemented.

    def _load_pass_k_public_results(self):
        """Cache pass@k restricted to samples that pass the public tests.

        Raises:
            FileNotFoundError: if the results file is missing.
            ValueError: if some sample has no ``"passing_public"`` value.
        """
        if self.pass_ks_given_public is not None:
            return
        if not self.pass_k_exists():
            raise FileNotFoundError(f"pass@k results not found: {self.path}")

        items = gunzip_json_read(self.path)["items"]
        upper_k = max(len(item["results"]) for item in items)
        ns = count_ns(items, is_public=True)
        cs = count_cs(items, is_public=True)
        if ns is None or cs is None:
            raise ValueError(f"{self.path}: some results lack a 'passing_public' value")

        # NOTE(review): a problem with no public-passing sample has n == 0 here; the
        # resulting score depends on how pass_at_k treats n - c < k -- confirm.
        self.pass_ks_given_public = np.array(
            [np.mean(calc_pass_k(ns, cs, k)) for k in range(1, upper_k + 1)]
        )
        # NOTE(review): this stores the per-problem count of samples passing BOTH
        # public and private tests (cs), not the count passing public tests (ns);
        # confirm which one callers of get_num_pass_public expect.
        self.num_pass_public = np.array(cs)

    def _load_diversity_results(self):
        """Load the diversity array from ``self.diversity_path``."""
        if not self.diversity_exists():
            raise FileNotFoundError(f"diversity results not found: {self.diversity_path}")
        self.diversities = np.load(self.diversity_path)

    def get_diversities(self) -> np.ndarray:
        """Return the (cached) diversity array."""
        if self.diversities is None:
            self._load_diversity_results()
        return self.diversities

    def get_pass_ks(self) -> np.ndarray:
        """Return mean pass@k values; index j holds pass@(j+1)."""
        if self.pass_ks is None:
            self._load_pass_k_results()
        return self.pass_ks

    def get_num_pass_public(self) -> Optional[np.ndarray]:
        """Return the per-problem counts cached by the public loader."""
        if self.pass_ks_given_public is None:
            self._load_pass_k_public_results()
        return self.num_pass_public

    def get_pass_ks_given_public(self) -> Optional[np.ndarray]:
        """Return public-filtered mean pass@k values; index j holds pass@(j+1)."""
        if self.pass_ks_given_public is None:
            self._load_pass_k_public_results()
        return self.pass_ks_given_public

    def get_pass_ks_stds(self) -> Optional[np.ndarray]:
        """Return pass@k standard deviations.

        This currently always returns None because no loader populates
        ``self.stds``. Calling it still loads the pass@k results.
        """
        if self.pass_ks is None:
            self._load_pass_k_results()
        return self.stds

class ResultSeries():
    """A collection of ``Result`` objects keyed by ``(dataset, model, method)``.

    Runs whose pass@k results file is missing are skipped with a printed warning.
    """

    def __init__(self, base_directory: str, diversity_directory: Optional[str], datasets: list[str], models: list[str], methods: list[str], temp: float = 0.9) -> None:
        self.base_directory = base_directory
        self.diversity_directory = diversity_directory

        self.big_dict = {}  # not read by any method of this class
        self.datasets = datasets
        self.models = models
        self.methods = methods

        self.the_dict: dict[tuple[str, str, str], Result] = {}
        for key in product(self.datasets, self.models, self.methods):
            dataset, model, method = key
            candidate = Result(self.base_directory, self.diversity_directory, dataset, method, model, temp=temp)
            if not candidate.pass_k_exists():
                print(f"Warning, not adding {key}.")
                continue
            self.the_dict[key] = candidate

    def add_results(self, r: list[Result]):
        """Add each result whose file exists; duplicate keys trip an assertion."""
        for result in r:
            key = (result.dataset, result.model, result.method)
            assert key not in self.the_dict
            if result.pass_k_exists():
                self.the_dict[key] = result
            else:
                print(f"Warning, not adding {key}.")

    def add_result_series(self, rs: "ResultSeries"):
        """Merge another series in; duplicate keys trip an assertion."""
        for key, result in rs.the_dict.items():
            assert key not in self.the_dict
            self.the_dict[key] = result

    def get_pass_ks(self, with_public: bool = False) -> dict[tuple[str, str, str], np.ndarray]:
        """Map each key to its pass@k curve.

        With ``with_public``, also add a ``(dataset, model, "public_filtered_" + method)``
        entry right after each key, holding the public-filtered curve.
        """
        curves = {}
        for key, result in self.the_dict.items():
            curves[key] = result.get_pass_ks()
            if not with_public:
                continue
            public_curve = result.get_pass_ks_given_public()
            assert public_curve is not None
            curves[(key[0], key[1], "public_filtered_" + key[2])] = public_curve
        return curves

    def get_pass_ks_stds(self) -> dict[tuple[str, str, str], np.ndarray]:
        """Map each key to whatever ``Result.get_pass_ks_stds`` returns."""
        stds = {}
        for key, result in self.the_dict.items():
            stds[key] = result.get_pass_ks_stds()
        return stds

    def get_num_pass_public(self) -> dict[tuple[str, str, str], np.ndarray]:
        """Map each key to ``Result.get_num_pass_public``; None values trip an assertion."""
        counts = {}
        for key, result in self.the_dict.items():
            per_problem = result.get_num_pass_public()
            assert per_problem is not None
            counts[key] = per_problem
        return counts

    def get_diversities(self) -> dict[tuple[str, str, str], np.ndarray]:
        """Map each key that has a diversity file to its diversity array."""
        return {
            key: result.get_diversities()
            for key, result in self.the_dict.items()
            if result.diversity_exists()
        }

# Root of per-run diversity arrays (<dataset>/<run name>/results.npy).
DIVER_DIR = "../../other_logs/similar_logs/final_logs"
# Root of pass@k result files (<dataset>/<method>_<model>_temp<T>), read via gunzip_json_read.
BASE_DIR = "../../final_results"

result_series = ResultSeries(
    BASE_DIR,
    DIVER_DIR,
    datasets=["livecodebench_lite_v3"],
    models=["gpt-4o-mini", "gpt-4o", "deepseek-coder", "sonnet-3-5"],
    methods=["basic_prompting1200", "combo_observation_no"],
    temp=0.9,
)

In [None]:
# Human-readable names for dataset identifiers.
datasets = {
    "mbpp_plus": "MBPP+",
    "human_eval_plus": "HumanEval+",
    "livecodebench_lite_v3": "LiveCodeBench",
}

# Human-readable names for method identifiers.
label_to_str = {
    "basic_prompting225": "Repeated Sampling",
    "simple_idea225": "IdeaSearch",
    "combo_observation_no": "PlanSearch",
}

# Human-readable names for model identifiers.
model_to_str = {
    "gpt-4o-mini": "GPT-4o-mini",
    "gpt-4o": "GPT-4o",
    "deepseek-coder": "DeepSeek-Coder-V2",
    "sonnet-3-5": "Sonnet-3.5",
    "baby-deepseek-b_sgl": "DeepSeek-Coder-V2-Lite-Base",
    "baby-deepseek-i_sgl": "DeepSeek-Coder-V2-Lite-Instruct",
    "llama318b_sgl": "Llama-3.1-8B-Base",
    "llama318bi_sgl": "Llama-3.1-8B-Instruct",
    "llama3170b_sgl": "Llama-3.1-70B-Base",
    "llama3170bi_sgl": "Llama-3.1-70B-Instruct",
}

# Plot colors per method: slightly darkened blue, purple and orange.
color_scheme = {
    'basic_prompting225': '#4DA6FF',
    'simple_idea225': '#A64DFF',
    'combo_observation_no': '#FF704D',
}

# Token-cost scale factors used for the x-axis of the compute-normalized plot.
# Presumably the average tokens per sample for repeated sampling / PlanSearch -- confirm.
BASIC_AVG = 244.05671435131
COMBO_AVG = 1427.981831636804

In [None]:
import matplotlib.patches as mpatches
# Keep LiveCodeBench only and regroup as {model: {(method,): mean pass@k array}}.
pass_k_dict = result_series.get_pass_ks()
pass_k_dict = split_dict_by_datasets(split_dict_by_datasets(pass_k_dict)["livecodebench_lite_v3"])

MAX_K = 200
COMBO_TO_BASIC_MULTIPLIER = COMBO_AVG / BASIC_AVG
BASIC_COL = "basic_prompting1200"
COMBO_COL = "combo_observation_no"

model_colors = {
    "gpt-4o-mini": "#1f77b4",  # muted blue
    "gpt-4o": "#ff7f0e",       # safety orange
    "deepseek-coder": "#2ca02c", # cooked asparagus green
    "sonnet-3-5": "#d62728"    # brick red
}

plt.figure(figsize=(14, 10), dpi=300)
plotted_models = []
for model, curves in pass_k_dict.items():
    combo_y = curves.get((COMBO_COL, ))
    basic_y = curves.get((BASIC_COL, ))
    # ResultSeries skips runs whose results file is missing, so a model may have only one curve.
    if combo_y is None or basic_y is None:
        print(f"Warning, skipping {model}: missing {COMBO_COL} or {BASIC_COL} results.")
        continue
    plotted_models.append(model)

    # Index j of a curve is pass@(j+1); scale the sample count by the method's
    # average token count (COMBO_AVG / BASIC_AVG).
    combo_x = (np.arange(len(combo_y)) + 1) * COMBO_AVG
    basic_x = (np.arange(len(basic_y)) + 1) * BASIC_AVG

    color = model_colors.get(model, "black")
    plt.plot(combo_x, combo_y, label=f'{model_to_str[model]} - PlanSearch', linestyle='-', color=color, linewidth=2)
    plt.plot(basic_x, basic_y, label=f'{model_to_str[model]} - Repeated Sampling', linestyle='-.', color=color, linewidth=2, alpha=0.75)

x_min, x_max = BASIC_AVG, MAX_K * COMBO_AVG
plt.xlim(x_min, x_max)

plt.xlabel('Average Tokens Used (per problem)', fontsize=22, fontweight='medium')
plt.ylabel('Solve-rate', fontsize=22, fontweight='medium')
plt.xscale('log')

plt.title('Compute-Normalized Repeated Sampling vs PlanSearch', fontsize=27, fontweight='medium')

# Legend for the line styles (one per search method).
line_handles = [
    plt.Line2D([0], [0], linestyle='-', color='black', linewidth=2.5, label='PlanSearch'),
    plt.Line2D([0], [0], linestyle='-.', color='black', linewidth=2.5, label='Repeated Sampling')
]
first_legend = plt.legend(handles=line_handles, fontsize=20, loc='upper left', frameon=True, title='Search Methods', title_fontsize='22')
plt.gca().add_artist(first_legend)  # keep the first legend when the second one is added

# Legend for the colors, listing only models that were actually plotted.
color_handles = [
    plt.Line2D([0], [0], color=model_colors.get(model, "black"), lw=4, label=model_to_str[model])
    for model in plotted_models
]
plt.legend(handles=color_handles, fontsize=20, loc='lower right', frameon=True, title='Models', title_fontsize='22')

plt.grid(True, linestyle='--', linewidth=0.5)
# Prepend a tick at the smallest x value and keep only ticks inside the x-limits.
flat_basic_avg = int(np.ceil(BASIC_AVG))
current_locs, current_labels = plt.xticks()
ticks = [flat_basic_avg] + current_locs.tolist()
tick_labels = [f'{flat_basic_avg}'] + list(current_labels)
good_ticks = [i for i, x in enumerate(ticks) if x_min <= x <= x_max]
plt.xticks([ticks[i] for i in good_ticks], [tick_labels[i] for i in good_ticks], fontsize=16)
plt.yticks(fontsize=16)
plt.tight_layout()

# savefig does not create missing directories.
os.makedirs("plots", exist_ok=True)
plt.savefig("plots/compute_normalized_plansearch.pdf", format='pdf', bbox_inches='tight', dpi=300)

plt.show()