# Pass@k curves

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
from coderm.eval.metrics import get_pass_ks
from coderm.utils import gunzip_json_read
import numpy as np
from math import comb
import os
import json

def calcEstVar(n, k, c):
    """Return the variance of the unbiased pass@k estimator for one problem.

    The per-problem pass@k estimator is ``1 - C(n - c, k) / C(n, k)``. Modelling
    the number of passing samples X as ``Binomial(n, p)`` with the plug-in
    estimate ``p = c / n``, its variance equals that of ``C(n - X, k) / C(n, k)``:

        Var = sum_i P(X = i) * (C(n - i, k) / C(n, k))**2 - (1 - p)**(2k)

    The second moment is evaluated via the identity
    ``C(n, i) * C(n - i, k) == C(n, k) * C(n - k, i)``, so each term needs a
    single ratio of binomial coefficients.

    Args:
        n: Number of samples drawn for the problem (must be positive).
        k: The k in pass@k; must satisfy ``0 <= k <= n``.
        c: Number of passing samples; must satisfy ``0 <= c <= n``.

    Returns:
        The variance as a non-negative float.

    Raises:
        ValueError: If ``n`` is not positive or ``k``/``c`` lie outside ``[0, n]``.
    """
    if n <= 0:
        raise ValueError(f"n must be positive; got n={n}")
    if not 0 <= k <= n:
        raise ValueError(f"k must be in [0, n]; got k={k}, n={n}")
    if not 0 <= c <= n:
        raise ValueError(f"c must be in [0, n]; got c={c}, n={n}")

    p = c / n
    q = 1 - p
    denom = comb(n, k)
    second_moment = 0.0
    # Terms with i > n - k vanish because C(n - i, k) == 0 there.
    for i in range(n - k + 1):
        # Integer product divided exactly (int / int true division) before
        # mixing in floats, so large binomials are not rounded early.
        weight = comb(n - i, k) * comb(n - k, i) / denom
        second_moment += weight * p**i * q ** (n - i)
    # Floating-point cancellation can leave a tiny negative residue, which
    # would become NaN once a caller takes the square root.
    return max(second_moment - q ** (2 * k), 0.0)


# Result sets to analyse. The suffix of each directory names the sweep
# setting (e.g. `_5words`); entries containing "base" are treated specially
# (length 0) when computing average solution lengths.
paths_to_results = [
    "sweeps/backtranslate_dsil_base",
    "sweeps/backtranslate_dsil_5words",
    "sweeps/backtranslate_dsil_10words",
    "sweeps/backtranslate_dsil_25words",
    "sweeps/backtranslate_dsil_35words",
    "sweeps/backtranslate_dsil_50words",
    "sweeps/backtranslate_dsil_75words",
    "sweeps/backtranslate_dsil_100words",
    "sweeps/backtranslate_dsil_125words",
    "sweeps/backtranslate_dsil_150words",
    "sweeps/backtranslate_dsil_175words",
    "sweeps/backtranslate_dsil_200words",
    "sweeps/backtranslate_dsil_250words",
    "sweeps/backtranslate_dsil_500words",
    "sweeps/backtranslate_dsil_750words",
    "sweeps/backtranslate_dsil_all",
    # "temp_sweeps/backtranslate_base",
    # "temp_sweeps/backtranslate_5words",
    # "temp_sweeps/backtranslate_10words",
    # "temp_sweeps/backtranslate_25words",
    # "temp_sweeps/backtranslate_35words",
    # "temp_sweeps/backtranslate_50words",
    # "temp_sweeps/backtranslate_75words",
    # "temp_sweeps/backtranslate_100words",
    # "temp_sweeps/backtranslate_125words",
    # "temp_sweeps/backtranslate_150words",
    # "temp_sweeps/backtranslate_175words",
    # "temp_sweeps/backtranslate_200words",
    # "temp_sweeps/backtranslate_250words",
    # "temp_sweeps/backtranslate_500words",
    # "temp_sweeps/backtranslate_750words",
    # "temp_sweeps/backtranslate_all",
]
# Optional extra result sets overlaid for comparison; when non-empty, the
# main sweep runs are drawn dashed in the pass@k-vs-k plot.
compare = [
    # "test_results/simpleidea_small_temp0.4",
    # "test_results/simpleidea_small_temp0.5",
    # "test_results/simpleidea_small_temp0.6",
]

# Fail fast, and report every missing results path at once rather than
# stopping at the first one. An explicit raise (unlike `assert`) is not
# stripped when Python runs with -O.
_missing_paths = [p for p in paths_to_results + compare if not Path(p).exists()]
if _missing_paths:
    raise FileNotFoundError(f"Path(s) don't exist: {', '.join(_missing_paths)}")

In [None]:
# Average completion length (in tokens) of the solution log for each run,
# keyed by results path.
avg_solution_length = {}

for path in paths_to_results:
    # Runs whose path contains "base" are assigned length 0 without reading
    # any logs (the plotting cell adds 1 before using a log x-axis).
    if "base" in path:
        avg_solution_length[path] = 0.
        continue

    query_path = os.path.join("logs", path, "queries")

    # os.listdir returns entries in arbitrary order; sort so that the file
    # chosen when several match is reproducible across machines and runs.
    solution_files = sorted(f for f in os.listdir(query_path) if f.startswith("solution"))
    if not solution_files:
        raise FileNotFoundError(f"No solution file found in {query_path}")
    for name in solution_files:
        print(f"Found solution file: {os.path.join(query_path, name)}")
    # As before, the last listed file is the one used.
    solution_path = os.path.join(query_path, solution_files[-1])

    with open(solution_path, "r", encoding="utf-8") as fh:
        solutions = json.load(fh)

    num_tokens_list = [e["completion"]["num_tokens"] for e in solutions]
    if not num_tokens_list:
        raise ValueError(f"{solution_path} contains no solutions")
    avg_solution_length[path] = sum(num_tokens_list) / len(num_tokens_list)

In [None]:
# Display the per-run average completion length (in tokens), keyed by results path.
avg_solution_length

In [None]:
# Multiplier applied to the standard error of mean pass@k to get the
# half-width of the shaded band in the plots. Note that `all_std` therefore
# holds band half-widths (2.5 standard errors), not raw standard deviations.
BAND_WIDTH_IN_STD = 2.5


def _pass_k_curve_and_band(items):
    """Return ``(pass_ks, band)`` for one result set.

    ``pass_ks`` maps k -> mean pass@k over problems for k = 1..upper_k, where
    upper_k is the smallest number of samples any problem has (pass@k is
    undefined beyond that). ``band[k - 1]`` is ``BAND_WIDTH_IN_STD`` times the
    standard error of that mean, treating per-problem estimators as independent.

    Raises:
        ValueError: If ``items`` is empty.
    """
    if not items:
        raise ValueError("result file contains no items")
    upper_k = min(len(item["results"]) for item in items)
    ks = range(1, upper_k + 1)
    pass_ks = {k: np.mean(get_pass_ks(items, k)) for k in ks}

    rows = []
    for item in items:
        # n and c must come from the same problem.
        n = len(item["results"])
        c = sum(trial["passing"] for trial in item["results"])
        rows.append([calcEstVar(n, k, c) for k in ks])
    variances = np.array(rows)
    # Var(mean of N independent estimators) = sum of variances / N**2.
    band = np.sqrt(variances.sum(axis=0) / len(items) ** 2) * BAND_WIDTH_IN_STD
    return pass_ks, band


all_pass_ks = {}
all_std = {}
for r in paths_to_results + compare:
    # Read each (gzipped) results file once and derive both statistics from it.
    print("Reading", r)
    items = gunzip_json_read(r)["items"]
    all_pass_ks[r], all_std[r] = _pass_k_curve_and_band(items)

In [None]:
plt.figure(figsize=(12, 6))

# Runs to plot; use paths_to_results[1:] instead to drop the "base" run.
select_ones = paths_to_results

for k in (1, 5):
    # One row per run: (avg solution tokens, pass@k, band half-width), sorted
    # by length so each line is drawn left to right.
    curve = np.array(sorted(
        [avg_solution_length[label], all_pass_ks[label][k], all_std[label][k - 1]]
        for label in select_ones
    ))
    # +1 keeps the zero-length "base" run representable on the log x-axis.
    x = curve[:, 0] + 1
    y, band = curve[:, 1], curve[:, 2]
    plt.plot(x, y, linestyle='-', label=f'Pass@{k}', marker='o')
    plt.fill_between(x, y - band, y + band, alpha=0.2)

plt.xlabel('Avg Solution Token Length (+1)')
plt.xscale('log')
plt.ylabel('Pass@k')
plt.title('Pass@1 and Pass@5 vs avg solution token length in backtranslation')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Per-problem empirical pass rate (fraction of sampled solutions that pass)
# for a single results file, shown as a histogram; the sorted rates are
# displayed as the cell output.
items = gunzip_json_read("test_results/simpleidea_small")["items"]
ps = [
    sum(trial["passing"] for trial in item["results"]) / len(item["results"])
    for item in items
]

plt.hist(ps)
sorted(ps)

In [None]:
plt.figure(figsize=(12, 6))
# When comparison runs are present, the main sweep runs are drawn dashed so
# the two groups are distinguishable. Exact membership (rather than substring
# matching) avoids misclassifying a path that merely contains another's text.
main_runs = set(paths_to_results)
dash_main_runs = len(compare) > 0
for label, values in all_pass_ks.items():
    ks = list(values.keys())
    pass_at_k = np.array(list(values.values()))
    band = np.asarray(all_std[label])
    linestyle = '--' if dash_main_runs and label in main_runs else '-'
    plt.plot(ks, pass_at_k, label=str(Path(label)), linestyle=linestyle)
    plt.fill_between(ks, pass_at_k - band, pass_at_k + band, alpha=0.2)

plt.xlabel('k')
plt.xscale('log')
plt.ylabel('Pass@k')
plt.title('Pass@k vs k for different models')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Pass@sel_k for each (idea temperature, code temperature) pair of the 5 x 5
# sweep. Runs whose results path is missing are stored as NaN (not 0) so they
# are neither mistaken for a genuine 0% pass rate nor allowed to stretch the
# color scale.
idea_temps = [0.2, 0.3, 0.4, 0.5, 0.6]
code_temps = [0.2, 0.3, 0.4, 0.5, 0.6]
heatmap_data = np.full((len(idea_temps), len(code_temps)), np.nan)

sel_k = 90
for i, it in enumerate(idea_temps):
    for j, ct in enumerate(code_temps):
        path = f"sweeps/simpleidea_small_it{it}_ct{ct}"
        if Path(path).exists():
            items = gunzip_json_read(path)["items"]
            heatmap_data[i, j] = np.mean(get_pass_ks(items, sel_k))

cmap = plt.get_cmap('plasma')

# Plotting the heatmap (imshow masks NaN cells and autoscales over the rest).
plt.figure(figsize=(10, 8))
plt.imshow(heatmap_data, cmap=cmap, origin='lower')
plt.colorbar(label=f'Pass@{sel_k}')
plt.xticks(ticks=np.arange(len(code_temps)), labels=code_temps)
plt.yticks(ticks=np.arange(len(idea_temps)), labels=idea_temps)
plt.xlabel('Code Temperature (ct)')
plt.ylabel('Idea Temperature (it)')
plt.title(f'Pass@{sel_k} Heatmap for Different Temperatures')


def get_text_color(value, cmap, norm=None):
    """Return 'black' or 'white', whichever reads better on cmap's color for value.

    ``norm`` maps values onto [0, 1]; by default it spans the non-NaN range of
    ``heatmap_data``, matching the autoscaled heatmap colors.
    """
    if norm is None:
        norm = plt.Normalize(vmin=np.nanmin(heatmap_data), vmax=np.nanmax(heatmap_data))
    rgba = cmap(norm(value))
    # Perceived luminance with ITU-R BT.601 weights.
    brightness = 0.299 * rgba[0] + 0.587 * rgba[1] + 0.114 * rgba[2]
    return 'black' if brightness > 0.5 else 'white'


# Build the normalization once, only from cells that have data.
finite_vals = heatmap_data[~np.isnan(heatmap_data)]
norm = plt.Normalize(vmin=finite_vals.min(), vmax=finite_vals.max()) if finite_vals.size else None
for i in range(len(idea_temps)):
    for j in range(len(code_temps)):
        value = heatmap_data[i, j]
        if np.isnan(value):
            # Missing run: the cell is masked, so label it explicitly.
            plt.text(j, i, "n/a", ha='center', va='center', color='black', fontsize=12)
        else:
            text_color = get_text_color(value, cmap, norm)
            plt.text(j, i, f"{value:.2f}", ha='center', va='center', color=text_color, fontsize=12)
plt.show()

In [None]:
import json

In [None]:
# Load an ad-hoc results file for inspection. JSON text is UTF-8, so do not
# rely on the platform's default text encoding when opening it.
with open("test_results/temp.json", "r", encoding="utf-8") as f:
    data = json.load(f)