In [23]:
import json
from typing import List, Union
import itertools

import numpy as np
from collections import defaultdict


In [24]:
def estimate_pass_at_k(
    num_samples: Union[int, List[int], np.ndarray],
    num_correct: Union[List[int], np.ndarray],
    k: int
) -> np.ndarray:
    """
    Estimates pass@k of each problem and returns them in an array.

    Uses the unbiased estimator pass@k = 1 - C(n - c, k) / C(n, k), where n is
    the number of samples generated for a problem and c is how many of those
    samples are correct.

    Args:
        num_samples: Samples per problem (n). Either a single integer (a Python
            int or a NumPy integer scalar) shared by every problem, or a
            sequence with one entry per problem.
        num_correct: Correct samples per problem (c), one entry per problem.
        k: Number of draws allowed by the metric.

    Returns:
        Float array with one pass@k estimate per problem, in input order.

    Raises:
        AssertionError: If ``num_samples`` is a sequence whose length differs
            from ``num_correct``.
        ValueError: If any problem has c < 0 or c > n.
    """

    def estimator(n: int, c: int, k: int) -> float:
        """
        Calculates 1 - comb(n - c, k) / comb(n, k).
        """
        if c < 0 or c > n:
            # Without this check c > n would fall into the `n - c < k` branch
            # and silently report a perfect score.
            raise ValueError(f"num_correct must be in [0, num_samples], got c={c}, n={n}")
        if n - c < k:
            # Fewer than k incorrect samples: every k-subset contains a correct one.
            return 1.0
        # Product form of C(n - c, k) / C(n, k); avoids huge intermediate binomials.
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    # Accept NumPy integer scalars (e.g. np.int64 pulled from an array) as the
    # shared per-problem count, instead of calling len() on a scalar.
    if isinstance(num_samples, (int, np.integer)):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)

    return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])


In [25]:
# Runs to score; each is read from ./{name}_Compiled_Result.json.
names = ["Text_DaVinci_Raw_Output","Text_DaVinci_Refined_Output","GPT3.5_Raw_Output","GPT3.5_Refined_Output"]
# k values to report; a pass@k is only reported when every problem has >= k samples.
ks = [1, 3, 5, 10]

for name in names:
    print(name)
    outfilename = f'./{name}_Compiled_Result.json'

    with open(outfilename, 'r') as f:
        records = json.load(f)

    # Group records by problem id so each problem contributes one (n, c) pair.
    results = defaultdict(list)
    for record in records:
        results[record['id']].append(record)

    # Calculate pass@k: n = samples per problem, c = samples whose "is_valid" is truthy.
    total = np.array([len(group) for group in results.values()])
    correct = np.array([sum(r["is_valid"] for r in group) for group in results.values()])

    pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
                 for k in ks if (total >= k).all()}
    print(pass_at_k)

Text_DaVinci_Raw_Output
{'pass@1': 0.1757217847769029, 'pass@3': 0.2667541557305337, 'pass@5': 0.30535245594300714, 'pass@10': 0.3543307086614173}
Text_DaVinci_Refined_Output
{'pass@1': 0.26968503937007876, 'pass@3': 0.373611111111111, 'pass@5': 0.4144586093404991, 'pass@10': 0.463254593175853}
GPT3.5_Raw_Output
{'pass@1': 0.3341207349081365, 'pass@3': 0.44496937882764653, 'pass@5': 0.4849914594009082, 'pass@10': 0.5249343832020997}
GPT3.5_Refined_Output
{'pass@1': 0.33556430446194224, 'pass@3': 0.4437226596675415, 'pass@5': 0.48302816314627345, 'pass@10': 0.5288713910761155}
