# Pass@k curves

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
from coderm.eval.metrics import get_pass_ks
from coderm.utils import gunzip_json_read
import numpy as np
from math import comb
import os
import json

def calcEstVar(n, k, c):
    """Return the plug-in variance of the unbiased pass@k estimate for one problem.

    With ``c`` passing samples out of ``n``, the unbiased pass@k estimate is
    ``1 - C(n - c, k) / C(n, k)``. Treating the pass count as
    ``Binomial(n, p)`` with the plug-in ``p = c / n``, its variance equals
    ``E[X**2] - (1 - p)**(2k)`` where ``X = C(n - C, k) / C(n, k)`` and
    ``E[X] = (1 - p)**k``. The loop below evaluates ``E[X**2]`` after applying
    the identity ``C(n, i) * C(n - i, k) / C(n, k) == C(n - k, i)``.

    Args:
        n: number of samples drawn for the problem (``n >= 1``).
        k: the k in pass@k (``0 <= k <= n``).
        c: number of passing samples (``0 <= c <= n``).

    Returns:
        The non-negative plug-in variance of the pass@k estimate.

    Raises:
        ValueError: if ``n``, ``k`` or ``c`` are outside the ranges above.
    """
    if n < 1 or not 0 <= k <= n or not 0 <= c <= n:
        raise ValueError(
            f"need n >= 1, 0 <= k <= n and 0 <= c <= n; got n={n}, k={k}, c={c}"
        )
    p = c / n
    total = comb(n, k)
    second_moment = 0.0
    # Terms with i > n - k vanish (C(n - k, i) == 0), so the sum stops there.
    for i in range(n - k + 1):
        second_moment += comb(n - i, k) * p**i / total * (comb(n - k, i) * (1 - p) ** (n - i))
    # E[X**2] and E[X]**2 can be nearly equal; clamp float cancellation noise so
    # callers never take the square root of a tiny negative number.
    return max(second_moment - (1 - p) ** (2 * k), 0.0)


# Result files to load (read later with gunzip_json_read). The entry whose path
# contains "base" is treated as the baseline by the later cells; alternative
# runs are kept commented out.
paths_to_results = [
    # "../../sweeps/backtranslate_dsil_base",
    # "../../sweeps/backtranslate_dsil_5words",
    # "../../sweeps/backtranslate_dsil_10words",
    # "../../sweeps/backtranslate_dsil_25words",
    # "../../sweeps/backtranslate_dsil_35words",
    # "../../sweeps/backtranslate_dsil_50words",
    # "../../sweeps/backtranslate_dsil_75words",
    # "../../sweeps/backtranslate_dsil_100words",
    # "../../sweeps/backtranslate_dsil_125words",
    # "../../sweeps/backtranslate_dsil_150words",
    # "../../sweeps/backtranslate_dsil_175words",
    # "../../sweeps/backtranslate_dsil_200words",
    # "../../sweeps/backtranslate_dsil_250words",
    # "../../sweeps/backtranslate_dsil_500words",
    # "../../sweeps/backtranslate_dsil_750words",
    # "../../sweeps/backtranslate_dsil_all",
    "../../temp_sweeps/backtranslate_base",
    "../../temp_sweeps/backtranslate_5words",
    "../../temp_sweeps/backtranslate_10words",
    "../../temp_sweeps/backtranslate_25words",
    "../../temp_sweeps/backtranslate_35words",
    "../../temp_sweeps/backtranslate_50words",
    "../../temp_sweeps/backtranslate_75words",
    "../../temp_sweeps/backtranslate_100words",
    "../../temp_sweeps/backtranslate_125words",
    "../../temp_sweeps/backtranslate_150words",
    "../../temp_sweeps/backtranslate_175words",
    "../../temp_sweeps/backtranslate_200words",
    "../../temp_sweeps/backtranslate_250words",
    "../../temp_sweeps/backtranslate_500words",
    "../../temp_sweeps/backtranslate_750words",
    "../../temp_sweeps/backtranslate_all",
]
# Extra runs whose pass@k statistics are also computed, but which the plot
# cell does not draw.
compare = [
    # "test_results/simpleidea_small_temp0.4",
    # "test_results/simpleidea_small_temp0.5",
    # "test_results/simpleidea_small_temp0.6",
]

# Fail fast, reporting every missing path at once instead of only the first.
# (An explicit raise also survives `python -O`, unlike assert.)
missing_paths = [p for p in paths_to_results + compare if not Path(p).exists()]
if missing_paths:
    raise FileNotFoundError("These result paths don't exist:\n" + "\n".join(missing_paths))

In [None]:
# Average completion length (in tokens) of the NL solutions logged for each
# non-baseline run. Runs whose path contains "base" are skipped, so they get no
# entry (presumably they have no NL-solution step - confirm).
avg_solution_length = {}

for path in paths_to_results:
    if "base" in path:
        continue

    # The log directory mirrors the results path under ../../logs.
    log_path = os.path.join("../../logs", path.replace("../../", ""))
    query_path = os.path.join(log_path, "queries")

    # os.listdir returns entries in arbitrary order; sort so the file chosen
    # when several exist is reproducible.
    solution_files = sorted(f for f in os.listdir(query_path) if f.startswith("solution"))
    if not solution_files:
        raise FileNotFoundError(f"No solution* file found in {query_path}")
    for name in solution_files:
        print(f"Found solution file: {os.path.join(query_path, name)}")
    if len(solution_files) > 1:
        print(f"Warning: {len(solution_files)} solution files in {query_path}; "
              f"using {solution_files[-1]}")
    solution_path = os.path.join(query_path, solution_files[-1])

    with open(solution_path, "r", encoding="utf-8") as solution_file:
        solutions = json.load(solution_file)

    # Each logged entry carries completion["num_tokens"].
    num_tokens_list = [e["completion"]["num_tokens"] for e in solutions]
    if not num_tokens_list:
        raise ValueError(f"{solution_path} contains no solutions")
    avg_solution_length[path] = sum(num_tokens_list) / len(num_tokens_list)

In [None]:
# Display the per-run average NL solution length in tokens (baseline runs have no entry).
avg_solution_length

In [None]:
# For every run, compute:
#   all_pass_ks[r][k] -> mean pass@k over problems, for k = 1..upper_k
#   all_std[r][k - 1] -> scaled standard error of that mean (array indexed from 0)
all_pass_ks = {}
all_std = {}
for r in (paths_to_results + compare):
    print("Reading", r)
    # Parse each results file once and reuse it for both statistics.
    items = gunzip_json_read(r)["items"]
    if not items:
        raise ValueError(f"{r} contains no items")

    # Largest k every problem supports. This equals the per-problem sample
    # count whenever all problems have the same number of samples, and it keeps
    # calcEstVar from being asked for k > n.
    upper_k = min(len(item["results"]) for item in items)
    ks = range(1, upper_k + 1)

    all_pass_ks[r] = {k: np.mean(get_pass_ks(items, k)) for k in ks}

    # variance_rows[i][k - 1]: variance of the pass@k estimate on problem i,
    # using that problem's own sample and pass counts.
    variance_rows = []
    for item in items:
        n_samples = len(item["results"])
        n_passing = sum(res["passing"] for res in item["results"])
        variance_rows.append([calcEstVar(n_samples, k, n_passing) for k in ks])
    variances = np.array(variance_rows)

    # Standard error of a mean over independent problems: sqrt(sum var_i) / N.
    # NOTE(review): the 2.5 factor widens this into an error-bar half-width;
    # its intended confidence level isn't documented here - confirm.
    all_std[r] = np.sqrt(np.sum(variances, axis=0) / len(items) ** 2) * 2.5

In [None]:
# Plot pass@1 and pass@5 of each non-baseline run against the average token
# length of its NL solutions (log x-axis). The baseline run's pass@k is drawn as
# a dashed horizontal line in the matching colour. The figure is saved as a PDF
# under plots/.
plt.figure(figsize=(10, 6), dpi=300)

select_ones = [path for path in paths_to_results if "base" not in path]
base_candidates = [path for path in paths_to_results if "base" in path]
if not base_candidates:
    raise ValueError("paths_to_results has no baseline run (no path contains 'base')")
base_path = base_candidates[0]

# k -> (marker, colour). Dict order fixes the legend order:
# Pass@1, Baseline Pass@1, Pass@5, Baseline Pass@5.
curve_styles = {1: ("o", "#1f77b4"), 5: ("s", "#ff7f0e")}
curves = {}
for k, (marker, color) in curve_styles.items():
    # Rows are [avg_len, pass@k, error half-width]. all_std is 0-indexed over
    # k = 1..upper_k, hence k - 1. Sorting orders the line by avg_len.
    curves[k] = np.array(sorted(
        [avg_solution_length[label], all_pass_ks[label][k], all_std[label][k - 1]]
        for label in select_ones
    ))
    plt.plot(curves[k][:, 0], curves[k][:, 1], linestyle='-', marker=marker,
             markersize=6, label=f'Pass@{k}', color=color)
    plt.axhline(y=all_pass_ks[base_path][k], color=color, linestyle='--',
                linewidth=2, label=f'Baseline Pass@{k}')

# Keep the per-curve names used by the (optional) error-bar lines below.
plot_line, plot_line_pa5 = curves[1], curves[5]
baseline_pass_at_1 = all_pass_ks[base_path][1]
baseline_pass_at_5 = all_pass_ks[base_path][5]

plt.xlabel('Average Solution Token Length', fontsize=14, fontweight='bold')
plt.xscale('log')
plt.ylabel('Pass@k', fontsize=14, fontweight='bold')
plt.title('Backtranslation Performance vs Average NL Solution Length', fontsize=16, fontweight='bold')
plt.legend(fontsize=12, frameon=True, fancybox=True)
plt.grid(True, which='both', linestyle='--', alpha=0.3)

plt.tick_params(axis='both', which='major', labelsize=12)

# Add error bars
# plt.errorbar(plot_line[:, 0], plot_line[:, 1], yerr=plot_line[:, 2], fmt='none', ecolor='gray', alpha=0.5, capsize=3)
# plt.errorbar(plot_line_pa5[:, 0], plot_line_pa5[:, 1], yerr=plot_line_pa5[:, 2], fmt='none', ecolor='gray', alpha=0.5, capsize=3)

# Adjust layout and save with high DPI. savefig does not create missing
# directories, so make sure plots/ exists first.
plt.tight_layout()
os.makedirs('plots', exist_ok=True)
plt.savefig('plots/backtranslation_performance.pdf', dpi=300, bbox_inches='tight', format="pdf")
# plt.savefig('plots/backtranslation_performance.png', dpi=300, bbox_inches='tight', format="png")
plt.show()