# Pass@k curves

In [None]:
from typing import Optional, TypeVar
import matplotlib.pyplot as plt
from pathlib import Path
from coderm.eval.metrics import get_pass_ks, pass_at_k, get_pass_ks_given_public, get_num_completions_per_problem, get_num_pass_public_per_problem
from coderm.utils import gunzip_json_read

from adjustText import adjust_text
import numpy as np
from math import comb
from itertools import product
import os

def calcEstVar(n, k, c):
    """Evaluate a finite sum over ``i = 0..n`` and subtract ``(1 - p)**(2*k)``.

    ``p`` is the ratio ``c / n``. Each term of the sum multiplies
    ``comb(n-i, k) * p**i / comb(n, k)`` by ``comb(n-k, i) * (1-p)**(n-i)``.

    NOTE(review): presumably the variance of a pass@k-style estimate for a
    problem with ``n`` samples of which ``c`` pass -- confirm against the
    derivation this formula was taken from.
    """
    p = c / n
    q = 1 - p
    total = 0
    for i in range(n + 1):
        # Keep the same left-to-right float evaluation as the one-line formula.
        weight = comb(n - i, k) * p**i / comb(n, k)
        mass = comb(n - k, i) * q**(n - i)
        total += weight * mass
    return total - q**(2 * k)

def convert_basic_prompting(method: str):
    """Return the method name to use when building result directory names.

    This is currently the identity mapping: a special case that turned
    ``"basic_prompting"`` into ``"default"`` is disabled.
    """
    return method

T = TypeVar('T')
def split_dict_by_datasets(d: dict[tuple[str, str, str], T]) -> dict[tuple[str, str], T]:
    """Group a tuple-keyed dict by the first element of each key.

    The result maps ``key[0]`` to an inner dict keyed by the remaining
    elements ``key[1:]``. Despite the name, nothing here is dataset-specific:
    the same function is reused on 2-tuple keys to group by model.

    NOTE: the value actually returned is nested (``{head: {tail: value}}``),
    which the return annotation does not express.
    """
    grouped = {}
    for key, value in d.items():
        head, tail = key[0], key[1:]
        grouped.setdefault(head, {})[tail] = value
    return grouped

def calc_pass_k(ns, cs, k):
    """Return ``pass_at_k(n, c, k)`` for every ``(n, c)`` pair zipped from ``ns`` and ``cs``."""
    return [pass_at_k(n, c, k) for n, c in zip(ns, cs)]

def count_ns(items, is_public: bool = False) -> list[int]:
    """Count completions per item.

    With ``is_public`` false, the count is simply ``len(item["results"])``.
    With ``is_public`` true, only results whose ``"passing_public"`` value is
    truthy are counted; if any result has no ``"passing_public"`` value
    (missing or ``None``), the whole call returns ``None`` instead of a list.
    """
    if not is_public:
        return [len(item["results"]) for item in items]

    counts = []
    for item in items:
        passed_public = 0
        for result in item["results"]:
            flag = result.get("passing_public", None)
            if flag is None:
                # Public-test information is unavailable for this run.
                return None
            if flag:
                passed_public += 1
        counts.append(passed_public)
    return counts

def count_cs(items, is_public: bool = False) -> list[int]:
    """Count passing completions per item.

    With ``is_public`` false, the count is the sum of each result's
    ``"passing"`` value. With ``is_public`` true, a result is counted only
    when both ``"passing"`` and ``"passing_public"`` are truthy; a result that
    passes without passing public prints a warning and is not counted. If any
    result has no ``"passing_public"`` value (missing or ``None``), the whole
    call returns ``None`` instead of a list.
    """
    if not is_public:
        totals = []
        for item in items:
            totals.append(sum(res["passing"] for res in item["results"]))
        return totals

    totals = []
    for item in items:
        both = 0
        for res in item["results"]:
            public_flag = res.get("passing_public", None)
            if public_flag is None:
                # Public-test information is unavailable for this run.
                return None
            if not res["passing"]:
                continue
            if public_flag:
                both += 1
            else:
                print("Warning: passes private but does not pass public")
        totals.append(both)
    return totals

def calc_pass_k(ns, cs, k):
    """Compute ``pass_at_k`` for each ``(n, c)`` pair taken in lockstep from ``ns`` and ``cs``.

    NOTE: this rebinds ``calc_pass_k``, which is already defined earlier in
    the file with the same behaviour.
    """
    scores = []
    for total, correct in zip(ns, cs):
        scores.append(pass_at_k(total, correct, k))
    return scores

class Result():
    """Lazily loaded evaluation results for one (dataset, method, model, temp) run.

    The constructor only derives file-system locations. Pass@k curves,
    public-filtered curves and diversity arrays are read from disk the first
    time the corresponding getter is called and are cached on the instance.
    """

    def __init__(self, base_directory: str, diversity_directory: Optional[str], dataset: str, method: str, model: str, temp: float = 0.9) -> None:
        """Store the run identity and compute its result / diversity paths."""
        self.dataset = dataset
        self.method = method
        self.model = model
        self.temp = temp

        # The same run name is used under both the results and the diversity tree.
        run_name = convert_basic_prompting(method) + "_" + model + f"_temp{self.temp}"
        self.path = os.path.join(base_directory, dataset, run_name)
        self.diversity_path = (
            os.path.join(diversity_directory, dataset, run_name, "results.npy")
            if diversity_directory is not None
            else None
        )

        # Caches populated on demand by the loaders below.
        self.pass_ks = None
        self.pass_ks_given_public = None
        self.num_pass_public = None
        self.stds = None
        self.diversities = None

    def pass_k_exists(self) -> bool:
        """Return whether the results path exists on disk."""
        return Path(self.path).exists()

    def diversity_exists(self) -> bool:
        """Return whether a diversity path was configured and exists on disk."""
        if self.diversity_path is None:
            return False
        return Path(self.diversity_path).exists()

    def _load_pass_k_results(self):
        """Fill ``self.pass_ks`` with the mean pass@k for k = 1..max completions per item.

        Does nothing when the curve is already cached.
        """
        if self.pass_ks is not None:
            return

        items = gunzip_json_read(self.path)["items"]
        ns = count_ns(items)
        cs = count_cs(items)
        upper_k = max(len(item["results"]) for item in items)
        self.pass_ks = np.array(
            [np.mean(calc_pass_k(ns, cs, k)) for k in range(1, upper_k + 1)]
        )

        # Standard-deviation estimation is disabled, so ``self.stds`` stays None.
        # vars = []
        # for item in items:
        #     single_problem = []
        #     for k in range(1, upper_k+1):
        #         single_problem.append(calcEstVar(len(items[0]["results"]), k, sum(i["passing"] for i in item["results"])))
        #     vars.append(single_problem)
        # vars = np.array(vars)
        # self.stds = np.sqrt(np.sum(vars, axis=0) / len(items) ** 2) * 2.5

    def _load_pass_k_public_results(self):
        """Fill ``self.pass_ks_given_public`` and ``self.num_pass_public``.

        The curve uses, per item, the number of completions passing public
        tests as ``n`` and the number passing both as ``c``. Asserts that the
        results path exists and that public-test flags are present.
        """
        assert self.pass_k_exists()
        items = gunzip_json_read(self.path)["items"]
        upper_k = max(len(item["results"]) for item in items)
        ns = count_ns(items, is_public=True)
        cs = count_cs(items, is_public=True)

        means = []
        for k in range(1, upper_k + 1):
            assert (ns is not None) and (cs is not None)
            per_problem = calc_pass_k(ns, cs, k)
            # per_problem = get_pass_ks_given_public(items, k)
            assert per_problem is not None
            means.append(np.mean(per_problem))

        self.pass_ks_given_public = np.array(means)
        # NOTE(review): this stores ``cs`` (completions passing both public and
        # private tests), not ``ns`` (completions passing public tests) --
        # confirm which one "num_pass_public" is meant to be.
        self.num_pass_public = np.array(cs)

    def _load_diversity_results(self):
        """Load the diversity array from ``self.diversity_path`` (must exist)."""
        assert self.diversity_exists()
        self.diversities = np.load(self.diversity_path)

    def get_diversities(self) -> np.ndarray:
        """Return the diversity array, loading it on first use."""
        if self.diversities is None:
            self._load_diversity_results()
        return self.diversities

    def get_pass_ks(self) -> np.ndarray:
        """Return the mean pass@k curve, loading it on first use."""
        if self.pass_ks is None:
            self._load_pass_k_results()
        return self.pass_ks

    def get_num_pass_public(self) -> Optional[np.ndarray]:
        """Return the per-item counts stored by the public-results loader."""
        if self.pass_ks_given_public is None:
            self._load_pass_k_public_results()
        return self.num_pass_public

    def get_pass_ks_given_public(self) -> Optional[np.ndarray]:
        """Return the public-filtered mean pass@k curve, loading it on first use."""
        if self.pass_ks_given_public is None:
            self._load_pass_k_public_results()
        return self.pass_ks_given_public

    def get_pass_ks_stds(self) -> np.ndarray:
        """Return ``self.stds`` after making sure the pass@k curve is loaded.

        The loader never assigns ``stds`` (that code is commented out), so
        this returns ``None`` unless ``stds`` was set externally.
        """
        if self.pass_ks is None:
            self._load_pass_k_results()
        return self.stds

class ResultSeries():
    """A collection of ``Result`` objects keyed by ``(dataset, model, method)``.

    The constructor builds one ``Result`` for every combination of the given
    datasets, models, methods and temperatures, keeping only those whose
    results path exists.

    NOTE(review): the key does not include the temperature, so when several
    temperatures are given, a later one replaces an earlier one for the same
    (dataset, model, method) -- confirm this is intended.
    """

    def __init__(self, base_directory: str, diversity_directory: Optional[str], datasets: list[str], models: list[str], methods: list[str], temps: list[float] = None) -> None:
        """Discover available results; ``temps`` defaults to ``[0.9]``."""
        self.base_directory = base_directory
        self.diversity_directory = diversity_directory

        self.big_dict = {}
        self.datasets = datasets
        self.models = models
        self.methods = methods
        self.temps = [0.9] if temps is None else temps

        self.the_dict: dict[tuple[str, str, str], Result] = {}
        for dataset, model, method, temp in product(self.datasets, self.models, self.methods, self.temps):
            candidate = Result(self.base_directory, self.diversity_directory, dataset, method, model, temp=temp)
            if not candidate.pass_k_exists():
                print(f"Warning, not adding {(dataset, model, method)}.")
                continue
            self.the_dict[(dataset, model, method)] = candidate

    def add_results(self, r: list[Result]):
        """Add individual results; keys must be new, missing paths are skipped with a warning."""
        for result in r:
            key = (result.dataset, result.model, result.method)
            assert key not in self.the_dict
            if result.pass_k_exists():
                self.the_dict[key] = result
            else:
                print(f"Warning, not adding {key}.")

    def add_result_series(self, rs: "ResultSeries"):
        """Merge another series into this one; overlapping keys trip an assertion."""
        for key, result in rs.the_dict.items():
            assert key not in self.the_dict
            self.the_dict[key] = result

    def get_pass_ks(self, with_public: bool = False) -> dict[tuple[str, str, str], np.ndarray]:
        """Return the pass@k curve of every result.

        With ``with_public`` true, each result additionally contributes its
        public-filtered curve under the same key with the method name
        prefixed by ``"public_filtered_"``.
        """
        curves = {}
        for key, result in self.the_dict.items():
            curves[key] = result.get_pass_ks()
            if not with_public:
                continue
            filtered = result.get_pass_ks_given_public()
            assert filtered is not None
            curves[(key[0], key[1], "public_filtered_" + key[2])] = filtered
        return curves

    def get_pass_ks_stds(self) -> dict[tuple[str, str, str], np.ndarray]:
        """Return ``get_pass_ks_stds()`` of every result, keyed like ``the_dict``."""
        stds = {}
        for key, result in self.the_dict.items():
            stds[key] = result.get_pass_ks_stds()
        return stds

    def get_num_pass_public(self) -> dict[tuple[str, str, str], np.ndarray]:
        """Return ``get_num_pass_public()`` of every result; each must be non-None."""
        counts = {}
        for key, result in self.the_dict.items():
            value = result.get_num_pass_public()
            assert value is not None
            counts[key] = value
        return counts

    def get_diversities(self) -> dict[tuple[str, str, str], np.ndarray]:
        """Return diversity arrays for the results that have a diversity file.

        Results without one are omitted, so the returned dict may have fewer
        keys than ``the_dict``.
        """
        return {
            key: result.get_diversities()
            for key, result in self.the_dict.items()
            if result.diversity_exists()
        }

# Root directories, relative to the notebook's working directory, passed to
# ResultSeries below: DIVER_DIR for the diversity logs, BASE_DIR for the
# evaluation results.
DIVER_DIR = "../../other_logs/similar_logs/final_logs"
BASE_DIR = "../../final_results"

# result_series = ResultSeries("../../final_results/base_v_instruct", "../../other_logs/similar_logs/final_logs/base_v_instruct",
#     ["human_eval_plus", "mbpp_plus", "livecodebench_lite_v3"],
#     ["baby-deepseek-b_sgl", "baby-deepseek-i_sgl", "llama318b_sgl", "llama318bi_sgl", "llama3170b_sgl", "llama3170bi_sgl"],
#     ["basic_prompting225", ]
# )

# result_series = ResultSeries("../../final_results/base_v_instruct", "../../other_logs/similar_logs/final_logs/base_v_instruct",
#     ["livecodebench_lite_v3"],
#     ["baby-deepseek-b_sgl", "baby-deepseek-i_sgl", "llama318b_sgl", "llama318bi_sgl"],
#     ["basic_prompting10000", ]
# )

# result_series.add_exps([Result("../../final_results/llama405bi", None, "livecodebench_lite_v3", "basic_prompting10", "llama405bi_fire")])
# result_series = ResultSeries("../../final_results/llama405bi", None,
#     ["livecodebench_lite_v3"],
#     ["llama405bi_fire", "llama"],
#     ["basic_prompting10"]
# )

# The series every cell below reads from: 3 datasets x 4 models x 3 methods at
# temperature 0.9. Combinations whose results path does not exist are skipped
# with a printed warning by ResultSeries.
# NOTE: "basic_prompting_cot225" has no entry in label_to_str / color_scheme
# (defined in a later cell).
result_series = ResultSeries(BASE_DIR, DIVER_DIR, 
    ["human_eval_plus", "mbpp_plus", "livecodebench_lite_v3"],
    ["gpt-4o-mini", "gpt-4o", "deepseek-coder", "sonnet-3-5"],
    ["basic_prompting225", "basic_prompting_cot225", "simple_idea225"],
    temps=[0.9]
)
# result_series = ResultSeries(BASE_DIR, DIVER_DIR, 
#     ["human_eval_plus", "mbpp_plus", "livecodebench_lite_v3"],
#     ["gpt-4o-mini", "gpt-4o", "deepseek-coder", "sonnet-3-5"],
#     ["basic_prompting225", "simple_idea225", "combo_observation_no"],
#     # ["basic_prompting50"],
#     temps=[0.9]
# )
# result_series.add_result_series(result_seriesd)

In [None]:
# Display tables shared by the plotting cells below.

# Dataset directory name -> title used in plot headings. The plotting cells
# also iterate over this dict to decide which datasets to draw.
datasets = {"mbpp_plus": "MBPP+", "human_eval_plus": "HumanEval+", "livecodebench_lite_v3": "LiveCodeBench"}
# Method name -> legend label. Methods missing here are shown under their raw
# name by the cells that look them up with a fallback.
label_to_str = {"basic_prompting225": "Repeated Sampling", "simple_idea225": "IdeaSearch", "combo_observation_no": "PlanSearch"}
# Model directory name -> display name.
model_to_str = {
    "gpt-4o-mini": "GPT-4o-mini",
    "gpt-4o": "GPT-4o",
    "deepseek-coder": "DeepSeek-Coder-V2",
    "sonnet-3-5": "Sonnet-3.5",
    "baby-deepseek-b_sgl": "DeepSeek-Coder-V2-Lite-Base",
    "baby-deepseek-i_sgl": "DeepSeek-Coder-V2-Lite-Instruct",
    "llama318b_sgl": "Llama-3.1-8B-Base",
    "llama318bi_sgl": "Llama-3.1-8B-Instruct",
    "llama3170b_sgl": "Llama-3.1-70B-Base",
    "llama3170bi_sgl": "Llama-3.1-70B-Instruct"
}
# Method name -> line / marker colour.
color_scheme = {
    'basic_prompting225': '#4DA6FF',  # Slightly darker blue
    'simple_idea225': '#A64DFF',      # Slightly darker purple
    'combo_observation_no': '#FF704D' # Slightly darker orange
}

In [None]:
# Per dataset, plot the improvement of pass@k over the Repeated Sampling
# pass@1 baseline, in percent, averaged across models.
use_public = False
# Largest k shown on the x-axis.
MAX_X = 200 if not use_public else 20

for dataset in datasets:
    plt.figure(figsize=(13, 10))
    all_pass_ks = split_dict_by_datasets(result_series.get_pass_ks(with_public=use_public))[dataset]
    # After the second split: model -> {(method,): pass@k curve}.
    all_pass_ks_by_model = split_dict_by_datasets(all_pass_ks)
    # method label -> list of per-model percent-improvement curves.
    avg = {}
    for model in result_series.models:
        final_data = all_pass_ks_by_model[model]
        # Baseline: this model's "basic_prompting225" value at k = 1.
        starting_val = final_data[("basic_prompting225",)][0]
        for label, values in final_data.items():
            to_plot = values[:MAX_X]
            ks = np.arange(len(to_plot)) + 1  # not used; recomputed in the plotting loop below
            pass_k = (to_plot - starting_val) / starting_val * 100
            if label[0] not in avg:
                avg[label[0]] = []
            avg[label[0]].append(pass_k)
        
    # NOTE: min_ylim / max_ylim are tracked below but never applied to the axes.
    min_ylim, max_ylim = float('inf'), float('-inf')
    for label, values in avg.items():
        is_public_filtered = 'public_filtered' in label
        if is_public_filtered:
            use_label = label.split("public_filtered_")[1]
        else:
            use_label = label

        # Average across models; np.stack needs every curve to have the same length.
        pass_k = np.stack(values, axis=0).mean(axis=0)
        ks = np.arange(len(pass_k)) + 1

        str_label = use_label if use_label not in label_to_str else label_to_str[use_label]
        str_label = str_label if not is_public_filtered else "Public Filtered " + str_label

        # Methods without a configured colour are drawn in black.
        color = color_scheme.get(use_label, 'black')
        
        # With public filtering on, unfiltered curves are dashed and faded.
        if use_public:
            linestyle = '-' if is_public_filtered else '--'
            lw = 2.5 if is_public_filtered else 2
            a = 1 if is_public_filtered else 0.55
        else:
            linestyle = '-'
            lw = 2.5
            a = 1
        
        plt.plot(ks[:MAX_X], pass_k[:MAX_X], label=str_label, linestyle=linestyle, color=color, linewidth=lw, alpha=a)
        
        min_ylim = min(min_ylim, pass_k[:MAX_X].min())
        max_ylim = max(max_ylim, pass_k[:MAX_X].max())

    plt.xlabel('k', fontsize=22)
    plt.xscale('log')
    plt.ylabel('Percent Improvement (%)', fontsize=24)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    if use_public:
        plt.title(f'Average Improvement with Public Filtering over Pass@1 on {datasets[dataset]}', fontsize=22)
    else:
        plt.title(f'Average Improvement over Pass@1 on {datasets[dataset]}', fontsize=22)
    plt.legend(fontsize=18, loc='lower right', bbox_to_anchor=(1, 0))
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.grid(which='minor', linestyle=':', linewidth='0.5', color='gray', alpha=0.5)
    plt.minorticks_on()
    plt.xlim(1, MAX_X)
    plt.tight_layout()
    # Saving is currently disabled for this figure.
    # if use_public:
    #     plt.savefig(f"plots/public_avg_perfimprovement_{dataset}.pdf", format="pdf", dpi=300, bbox_inches='tight')
    # else:
    #     plt.savefig(f"plots/avg_perfimprovement_{dataset}.pdf", format="pdf", dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Print pass@K (K = 20) for every (model, method) of every dataset, including
# the "public_filtered_" variants.
lol = result_series.get_pass_ks(with_public=True)
K = 20
splitted = split_dict_by_datasets(lol)

for k, d in splitted.items():
    # k is the dataset name; d maps (model, method) -> pass@k curve.
    print(k)
    for (model, method), v in d.items():
        # if "public" in method:
        print(model, method, v[K-1])  # curves are 0-indexed, so index K-1 is pass@K

In [None]:
# Build the data for the LiveCodeBench bar chart: model -> {method label -> score}.
pass_ks = result_series.get_pass_ks()
pass_ks_split = {}
for key, value in pass_ks.items():
    dataset, model, method = key
    # The model entry is created before the dataset filter, so a model with
    # no LiveCodeBench results still appears, with an empty dict.
    if model not in pass_ks_split:
        pass_ks_split[model] = {}
    if dataset != "livecodebench_lite_v3":
        continue
    pass_ks_split[model][method] = value
pass_ks = pass_ks_split

K = 200
for k, v in pass_ks.items():
    di = {}
    for k1, v1 in v.items():
        # "simple_idea*" methods are left out of this chart.
        if "simple_idea" in k1:
            continue
        if "basic_prompting" in k1:
            # Any method containing "basic_prompting" (this also matches
            # "basic_prompting_cot225") writes the same two entries, both
            # read from the "basic_prompting225" curve: pass@1 and pass@K.
            di["basic_prompting1"] = v["basic_prompting225"][0]
            di[f"basic_prompting{K}"] = v["basic_prompting225"][K-1]
        else:
            di[k1] = v1[K-1]
    # Replaces the value of an existing key, so the dict size is unchanged
    # while it is being iterated.
    pass_ks[k] = di

def method_to_str(method: str, default_k: Optional[int] = None) -> str:
    """Return the legend label for ``method``.

    Without ``default_k`` this is the plain ``label_to_str`` entry. With
    ``default_k``, methods starting with ``"basic_prompting"`` become
    ``"Repeated Sampling@<suffix>"`` where the suffix is the text following
    that prefix in the method name; every other method becomes its
    ``label_to_str`` entry followed by ``"@<default_k>"``.
    """
    if default_k is not None and method.startswith("basic_prompting"):
        suffix = method.split("basic_prompting")[1]
        return f"Repeated Sampling@{suffix}"

    base_label = label_to_str[method]
    if default_k is None:
        return base_label
    return base_label + f"@{default_k}"

def method_to_color(method: str):
    """Return the bar colour for a legend label, or grey ``"#808080"`` if it has none."""
    # Built on every call, so both colour-scheme entries must exist.
    known = {}
    known["Repeated Sampling@1"] = '#4CAF50'  # Slightly darker green color
    known["Repeated Sampling@200"] = color_scheme["basic_prompting225"]
    known["PlanSearch@200"] = color_scheme["combo_observation_no"]
    if method in known:
        return known[method]
    return "#808080"
# Grouped bar chart: one group per model, one bar per method label.
models = list(pass_ks.keys())
# The method labels of the first model define the bars for every model.
methods = list(pass_ks[models[0]].keys())
x = np.arange(len(models))
width = 0.25
multiplier = 0

fig, ax = plt.subplots(figsize=(12, 9))

for method in methods:
    values = [pass_ks[model][method] for model in models]
    # Shift each method's bars right by one bar width within the group.
    offset = width * multiplier
    rects = ax.bar(x + offset, values, width, label=method_to_str(method, K), 
                   color=method_to_color(method_to_str(method, K)), edgecolor='black', linewidth=0.4, alpha=0.81)
    multiplier += 1

ax.set_ylabel('Pass@k', fontsize=16, fontweight='medium')
ax.set_title('Pass@k Scores by Method on LiveCodeBench', fontsize=26, fontweight='medium')
# Tick at x + width, i.e. under the second bar of each group.
ax.set_xticks(x + width)
ax.set_xticklabels([model_to_str[model] for model in models], rotation=45, ha='right', fontsize=12)
ax.legend(loc='upper left', fontsize=14, frameon=True, edgecolor='black')

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.5)
ax.spines['bottom'].set_linewidth(1.5)

ax.tick_params(axis='both', which='major', labelsize=16, width=1.5, length=6)
ax.tick_params(axis='both', which='minor', width=1, length=4)

# Write each bar's height above it.
for rect in ax.patches:
    height = rect.get_height()
    ax.text(rect.get_x() + rect.get_width()/2., height,
            f'{height:.2f}',
            ha='center', va='bottom', fontsize=13, fontweight='bold')

ax.grid(axis='y', linestyle='--', alpha=0.5, zorder=0)
plt.tight_layout()
# Writes into the relative "plots/" directory.
plt.savefig("plots/pass_k_scores_by_model_method_livecodebench.pdf", format="pdf", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Per dataset, scatter each (model, method) run: x = 1 - mean of its diversity
# array, y = relative gain from pass@1 to pass@SELECTED_K.
import matplotlib.patches as mpatches
SELECTED_K = 200
marker_styles = ['d', 's', 'o', 'P', 'v', 'X', 'H', '8', 'd']  # Define different marker styles with uniform area
# One marker per known model, cycling through marker_styles.
model_to_marker = {model: marker_styles[i % len(marker_styles)] for i, model in enumerate(model_to_str.keys())}

for dataset in datasets:
    diversities = split_dict_by_datasets(result_series.get_diversities())[dataset]
    diversities = {k: 1 - v.mean() for k, v in diversities.items()}

    all_pass_ks = split_dict_by_datasets(result_series.get_pass_ks())[dataset]
    all_pass_0s = {k: (v[SELECTED_K-1] - v[0]) / v[0] for k, v in all_pass_ks.items()}

    # get_diversities() omits runs without a diversity file, so the two dicts
    # can have different key sets. Pair the values by key instead of zipping
    # them positionally, which would silently attach a diversity to the wrong run.
    plotted_keys = [k for k in all_pass_0s if k in diversities]
    skipped_keys = [k for k in all_pass_0s if k not in diversities]
    if skipped_keys:
        print(f"Warning, no diversity results for {skipped_keys} on {dataset}; not plotting them.")
    
    plt.figure(figsize=(12, 8), dpi=300)
    labels_to_idx = {v: i for i, v in enumerate(list(label_to_str))}
    # Methods without a configured colour fall back to black, as in the other plots.
    colors = [color_scheme.get(label[1], 'black') for label in plotted_keys]
    markers = [model_to_marker[label[0]] for label in plotted_keys]
    # Ordered set of the models that actually have a point in the plot.
    avail_models = dict((label[0], None) for label in plotted_keys)
    
    for (label, color, marker) in zip(plotted_keys, colors, markers):
        plt.scatter(diversities[label], all_pass_0s[label], c=color, edgecolor='k', alpha=0.81, s=300, marker=marker, zorder=3)
    
    # Create a legend for the colors
    color_handles = [mpatches.Patch(color=color_scheme[label], label=label_to_str[label], alpha=0.7) for label, i in labels_to_idx.items()]
    
    # Create a legend for the marker styles
    marker_handles = [plt.Line2D([0], [0], marker=model_to_marker[model], color='w', markerfacecolor='w', markersize=14, label=model_to_str[model], markeredgecolor='k', markeredgewidth=1.5, alpha=0.7) for model in avail_models]
    
    first_legend = plt.legend(handles=color_handles, fontsize=20, loc='upper left', frameon=True)
    # A second plt.legend() call replaces the first, so re-add it as an artist.
    ax = plt.gca().add_artist(first_legend)  # Add the first legend to the axes
    plt.legend(handles=marker_handles, fontsize=20, loc='lower right', frameon=True)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.ylabel(f'Relative Gains (Pass@1 to Pass@{SELECTED_K})', fontsize=22, fontweight='medium')
    plt.xlabel('Idea Diversity', fontsize=22, fontweight='medium')
    plt.title(f'Idea Diversity vs Relative Gains from Search (on {datasets[dataset]})', fontsize=25, fontweight='medium')
    plt.grid(True, linestyle='--', alpha=0.6, zorder=0)
    
    plt.tight_layout()
    plt.savefig(f"plots/diversity_vs_improvement_{dataset}.pdf", format="pdf", dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
def split_into_public_private(d: dict[tuple[str, str], np.ndarray]) -> dict[tuple[str, str], np.ndarray]:
    """Stack each unfiltered curve with its public-filtered counterpart.

    For every key ``(model, method)`` whose method does not contain
    ``"public_filtered"``, the result holds ``np.stack([d[key],
    d[(model, "public_filtered_" + method)]], axis=0)``: row 0 is the
    unfiltered curve, row 1 the public-filtered one. A missing counterpart
    raises ``KeyError``.
    """
    CONSTANT = "public_filtered"

    # First gather every pair (so a missing counterpart fails before any
    # stacking), then stack.
    paired = {
        key: [curve, d[(key[0], CONSTANT + "_" + key[1])]]
        for key, curve in d.items()
        if CONSTANT not in key[1]
    }
    return {key: np.stack(pair, axis=0) for key, pair in paired.items()}

# Compare, on HumanEval+ only, public-filtered pass@k against unfiltered
# pass@(OFFSET * k) for k = 1..MAX_LEN.
pp_data = split_into_public_private(split_dict_by_datasets(result_series.get_pass_ks(with_public=True))["human_eval_plus"])

MAX_LEN = 10
OFFSET = 2

# k values for the public-filtered curve, and OFFSET times those for the
# unfiltered curve.
public_idx = np.arange(MAX_LEN) + 1
private_idx = public_idx * OFFSET


plt.figure(figsize=(9, 9))
for label, values in pp_data.items():
    linestyle = '-'
    # Row 0 is the unfiltered curve, row 1 the public-filtered one; subtract 1
    # because the curves are 0-indexed.
    plt.plot(values[0, private_idx-1], values[1, public_idx-1], label=f'|'.join(label), linestyle=linestyle)
# Reference diagonal.
x = np.linspace(0.3, 1, 40)
y = x
plt.plot(x, y, label='x = y', linestyle='--', color='red')


plt.xlabel('private score')
plt.ylabel('public score')
plt.title(f'public vs private score')
plt.legend(fontsize='small', loc='lower center')
plt.grid(True)
plt.show()

In [None]:
# Per dataset, draw one pass@k-vs-k panel per model with one curve per method.
use_public = False
# Largest k shown on the x-axis.
MAX_X = 200 if not use_public else 20

# plt.style.use('seaborn-whitegrid')
# colors = plt.cm.Set2(np.linspace(0, 1, 10))

for dataset in datasets:
    all_pass_ks = split_dict_by_datasets(result_series.get_pass_ks(with_public=use_public))[dataset]
    # After the second split: model -> {(method,): pass@k curve}.
    all_pass_ks_by_model = split_dict_by_datasets(all_pass_ks)
    
    num_models = len(result_series.models)
    num_cols = 2
    # Keep the 2x2 layout for up to four models and add rows beyond that, so
    # indexing axs[row, col] below cannot run past the grid.
    num_rows = max(2, int(np.ceil(num_models / num_cols)))
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(14, 7 * num_rows), squeeze=False)
    if use_public:
        fig.suptitle(f'Pass@k vs k for Methods with Public Filtering on {datasets[dataset]}', fontsize=24, fontweight='medium')
    else:
        fig.suptitle(f'Pass@k vs k for Methods on {datasets[dataset]}', fontsize=24, fontweight='medium')
    
    # Find the global y-axis limits
    y_min, y_max = float('inf'), float('-inf')
    for model in result_series.models:
        final_data = all_pass_ks_by_model[model]
        for values in final_data.values():
            y_min = min(y_min, np.min(values[:MAX_X]))
            y_max = max(y_max, np.max(values[:MAX_X]))
    
    # Extend y_max slightly upward, but cap at 1
    y_max = min(y_max + 0.015, 1.0)
    
    for idx, model in enumerate(result_series.models):
        row = idx // num_cols
        col = idx % num_cols
        ax = axs[row, col]
        final_data = all_pass_ks_by_model[model]
        for i, (label, values) in enumerate(final_data.items()):
            to_plot = values[:MAX_X]
            # Size the x values from the data actually plotted: a curve
            # shorter than MAX_X would otherwise make ax.plot fail on
            # mismatched lengths.
            ks = np.arange(1, len(to_plot) + 1)

            is_public_filtered = 'public_filtered' in label[0]
            if is_public_filtered:
                use_label = label[0].split("public_filtered_")[1]
            else:
                use_label = label[0]

            str_label = use_label if use_label not in label_to_str else label_to_str[use_label]
            str_label = str_label if not is_public_filtered else "Public Filtered " + str_label
            
            # Determine color and linestyle
            color = color_scheme.get(use_label, 'black')
            
            # With public filtering on, unfiltered curves are dashed and faded.
            if use_public:
                linestyle = '-' if is_public_filtered else '--'
                lw = 2.3 if is_public_filtered else 1.8
                a = 0.95 if is_public_filtered else 0.5
            else:
                linestyle = '-'
                lw = 2.3
                a = 0.95
            
            ax.plot(ks, to_plot, label=str_label, linestyle=linestyle, linewidth=lw, color=color, alpha=a)
        
        # Only the bottom row gets an x label and only the left column a y label.
        if row == num_rows - 1:
            ax.set_xlabel('k', fontsize=17)
        ax.set_xscale('log')
        ax.set_xlim(1, MAX_X)
        if col == 0:
            ax.set_ylabel('Pass@k', fontsize=17)
        ax.set_title(f'{model_to_str[model]}', fontsize=19)
        ax.legend(fontsize=14, loc='lower right', frameon=True, fancybox=True)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.set_ylim(y_min, y_max)
        
        # Improve tick labels
        ax.tick_params(axis='both', which='major', labelsize=14)
        ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: format(int(x), ',')))
        
        # Add minor gridlines
        ax.xaxis.grid(True, which='minor', linestyle=':', alpha=0.4)
        ax.yaxis.grid(True, which='minor', linestyle=':', alpha=0.4)
    
    # Leave room at the top for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    if use_public:
        plt.savefig(f"plots/public_pass_at_k_{dataset}.pdf", format="pdf", dpi=300, bbox_inches='tight')
    else:
        plt.savefig(f"plots/pass_at_k_{dataset}.pdf", format="pdf", dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# For each model group, plot pass@k vs k of the group's models on one figure.
# Largest k shown on the x-axis.
MAX_X = 10000

# plt.style.use('seaborn-whitegrid')
# colors = plt.cm.Set2(np.linspace(0, 1, 10))
# NOTE: this rebinds the notebook-global `datasets` to LiveCodeBench only.
datasets = {"livecodebench_lite_v3": "LiveCodeBench"}
# Group title -> model-name prefixes belonging to the group.
model_groups = {
    "DeepSeek-Coder-V2-Lite": ["baby-deepseek-b", "baby-deepseek-i"],
    "Llama-3.1-8B": ["llama318b", "llama318bi"],
    # "Llama-3.1-70B": ["llama3170b", "llama3170bi"]
}

for dataset in datasets:
    all_pass_ks = split_dict_by_datasets(result_series.get_pass_ks(with_public=False))[dataset]
    # After the second split: model -> {(method,): pass@k curve}.
    all_pass_ks_by_model = split_dict_by_datasets(all_pass_ks)
    
    # Plot each model group separately
    for group_name, group_models in model_groups.items():
        # Models of the current series whose name starts with one of the group's prefixes.
        matched_models = [
            model for model in result_series.models
            if model.startswith(tuple(group_models))
        ]
        if not matched_models:
            # Without any curve the y-limits would stay at +/-inf and
            # set_ylim would fail, so skip the group instead.
            print(f"Warning, no models in result_series match group {group_name}; skipping plot.")
            continue

        fig, ax = plt.subplots(figsize=(12, 7))
        fig.suptitle(f'Pass@k vs k for {group_name} Models on {datasets[dataset]}', fontsize=23, fontweight='medium')
        
        # Find the y-axis limits for the current group
        y_min, y_max = float('inf'), float('-inf')
        for model in matched_models:
            final_data = all_pass_ks_by_model[model]
            for values in final_data.values():
                y_min = min(y_min, np.min(values[:MAX_X]))
                y_max = max(y_max, np.max(values[:MAX_X]))
        
        # Extend y_max slightly upward, but cap at 1
        y_max = min(y_max + 0.015, 1.0)
        
        for model in matched_models:
            final_data = all_pass_ks_by_model[model]
            for i, (label, values) in enumerate(final_data.items()):
                to_plot = values[:MAX_X]
                # Size the x values from the data actually plotted: a curve
                # shorter than MAX_X would otherwise make ax.plot fail on
                # mismatched lengths.
                ks = np.arange(1, len(to_plot) + 1)
                # Curves are labelled by model, not by method.
                ax.plot(ks, to_plot, label=f'{model_to_str[model]}', linestyle='-', linewidth=3.5)
        
        ax.set_xlabel('k', fontsize=22)
        ax.set_xscale('log')
        ax.set_xlim(1, MAX_X)
        ax.set_ylabel('Pass@k', fontsize=22)
        ax.legend(fontsize=16, loc='lower right', frameon=True, fancybox=True)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.set_ylim(y_min, y_max)
        
        # Improve tick labels
        ax.tick_params(axis='both', which='major', labelsize=16)
        ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: format(int(x), ',')))
        
        # Add minor gridlines
        ax.xaxis.grid(True, which='minor', linestyle=':', alpha=0.4)
        ax.yaxis.grid(True, which='minor', linestyle=':', alpha=0.4)
        
        # Leave room at the top for the suptitle.
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.savefig(f"plots/basevinstruct_big_{dataset}_{group_name}.pdf", format="pdf", dpi=300, bbox_inches='tight')
        # plt.savefig(f"plots/pass_at_k_{dataset}_{group_name}.png", format="png", dpi=300, bbox_inches='tight')
        plt.show()