In [1]:
from pcgen.algorithms import create_scorgen_pipeline
from pcgen.baselines.clm.uncertainty import create_clm_pipeline, generate_clm
from pcgen.algorithms.base import compute_alphas
import os
from pcgen.triviaqa.paths import DATA_DIR
import pickle
import numpy as np
from pcgen.utils import set_seed, load_configs_from_jsonl


In [2]:
# Fix the seed before any pipeline is built so that re-running the notebook
# from a fresh kernel is repeatable.
# NOTE(review): set_seed comes from pcgen.utils; presumably it seeds every RNG
# used downstream -- confirm.
set_seed(0)

In [3]:
# --- notebook configuration -------------------------------------------------
# `score` is forwarded to both pipelines and to generate_clm.
score = "sum"
# `stages` is forwarded to create_scorgen_pipeline.
stages = ["generation", "quality", "remove_dupl"]
# Two equally sized splits.
split_ratios = [0.5, 0.5]
# number of samples to show
N = 10
# Level handed to compute_alphas; K is the number of splits.
# NOTE(review): the meaning of M=5 is not visible in this notebook -- confirm.
alpha = 0.2
alphas = compute_alphas(M=5, K=len(split_ratios), alpha=alpha)

In [4]:
# Location of the pickled dataset inside the processed-data directory.
data_dir = os.path.join(DATA_DIR, "processed")
data_path = os.path.join(data_dir, "data.pkl")

# Number of stages without 'remove_dupl': duplicate removal does not need to
# be calibrated.
K = len(stages) - int('remove_dupl' in stages)

# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(data_path, mode='rb') as pkl_file:
    data = pickle.load(pkl_file)

# Calibration set: entries 600 (inclusive) to 1800 (exclusive) of the data.
# NOTE(review): the slice bounds are hard-coded -- confirm they match the
# split used in the quantitative experiments.
data_cal = data[600:1800]

In [5]:
# delta_1 and delta_2 are taken from the quantitative experiments
delta_1 = 0.29
delta_2 = 0.02

# Scorgen pipeline built on the calibration data; only the 'pipeline' entry
# of the returned mapping is kept.
scorgen_pipeline = create_scorgen_pipeline(
    data=data_cal,
    split_ratios=split_ratios,
    alphas=alphas,
    score=score,
    data_splitting=True,
    verbose=True,
    stages=stages,
    count_adm=False,
    measure_time=False,
)['pipeline']

# CLM baseline built on the same calibration data and with the same score.
clm_pipeline = create_clm_pipeline(
    data=data_cal,
    split_ratio=0.5,
    delta_1=delta_1,
    delta_2=delta_2,
    use_lambda_1=False,
    use_lambda_2=True,
    alt_lambda_1=0.5,
    alt_lambda_2=None,
    reduced_max=20,
    measure_time=False,
    score=score,
    count_adm=False,
)

In [6]:
# Read the example questions together with their raw generations.
jsonl_path = os.path.join(DATA_DIR, "examples.jsonl")
obj = load_configs_from_jsonl(jsonl_path)

# For every example, pair each generation's position with its decoded text.
all_answers = [
    [
        {"idx": position, 'decoded': generation['decoded']}
        for position, generation in enumerate(record['generations'])
    ]
    for record in obj
]

# The processed entries shown are the last N items of `data`.
# NOTE(review): the printing cells assume these line up one-to-one with the
# records in `obj` -- confirm.
processed_data = [data[offset] for offset in range(-N, 0)]

# Attach the answer positions 0..19 to every processed entry (in place).
# NOTE(review): 20 presumably equals the number of generations per question
# -- confirm.
for entry in processed_data:
    entry["idxs"] = np.arange(20)

In [7]:
# get prediction set
# scorgen_out is indexed per example below; its "idxs" entry lists the kept
# answer positions.
scorgen_out = scorgen_pipeline.generate_new(processed_data)
# clm prediction set
# kept_mask is indexed per example below and fed to np.where to recover the
# kept answer positions. Only the first element of clm_pipeline["pipeline"]
# is used, converted to a tuple.
kept_mask = generate_clm(data=processed_data, clm_pipeline=tuple(clm_pipeline["pipeline"][0]), score=score)

In [13]:
# Go through the example questions and print, for each one, the answers kept
# in the Scorgen prediction set as a LaTeX-style set "\{a (\cmark), b (\xmark)\}".
# An answer is tagged \cmark when its label in processed_data is truthy,
# otherwise \xmark.
print("Scorgen Answers")
for idx, line in enumerate(obj):
    print(f"Question: {line['question']}")
    print("Generated answers:")
    # Build the separators with join so they depend on position only. The
    # previous "is this index equal to the last index?" test would drop a
    # separator whenever an earlier index had the same value as the last one.
    str_ = "\\{" + ", ".join(
        line["generations"][idx_]["decoded"]
        + " (" + ("\\cmark" if bool(processed_data[idx]["labels"][idx_]) else "\\xmark") + ")"
        for idx_ in scorgen_out[idx]["idxs"]
    ) + "\\}"
    print(str_)

Scorgen Answers
Question: What former U.S. president is known for his staunch support of Habitat for Humanity?
Generated answers:
\{Jimmy Carter (\cmark)\}
Question: What cat food “tastes so good, cats ask for it by name”?
Generated answers:
\{Friskies (\xmark), Whiskas (\xmark), Fancy Feast (\xmark), Sheba (\xmark), Sheba (\xmark), Felix (\xmark)\}
Question: What is the name of the giraffe that Toys-r-us uses as its' mascot?
Generated answers:
\{Geoffrey (\cmark)\}
Question: Where do you find the Bridal Veil, American, and Horseshoe Falls?
Generated answers:
\{Niagara Falls (\cmark), Niagara Falls, Canada (\cmark), Niagra Falls (\cmark)\}
Question: The worlds largest marketer of fruit juices, what is the juice arm of the Coca Cola company?
Generated answers:
\{Minute Maid (\cmark)\}
Question: Whose backing band is known as The Miami Sound Machine?
Generated answers:
\{Gloria Estefan (\cmark), Gloria Estefan & Miami Sound Machine (\cmark)\}
Question: With a motto of Always Ready, Alway

In [11]:
# Go through the example questions and print, for each one, the answers kept
# by the CLM baseline as a LaTeX-style set "\{a (\cmark), b (\xmark)\}".
# An answer is tagged \cmark when its label in processed_data is truthy,
# otherwise \xmark.
print("CLM Answers")
for idx, line in enumerate(obj):
    print(f"Question: {line['question']}")
    print("Generated answers:")
    # Positions at which this example's keep-mask is set.
    idxs = np.where(kept_mask[idx])[0]
    str_ = "\\{" + ", ".join(
        line["generations"][idx_]["decoded"]
        + " (" + ("\\cmark" if bool(processed_data[idx]["labels"][idx_]) else "\\xmark") + ")"
        for idx_ in idxs
    ) + "\\}"
    print(str_)

CLM Answers
Question: What former U.S. president is known for his staunch support of Habitat for Humanity?
Generated answers:
\{Jimmy Carter (\cmark)\}
Question: What cat food “tastes so good, cats ask for it by name”?
Generated answers:
\{Friskies (\xmark), Sheba (\xmark), Whiskas (\xmark), Fancy Feast (\xmark), Felix (\xmark), Purina (\xmark)\}
Question: What is the name of the giraffe that Toys-r-us uses as its' mascot?
Generated answers:
\{Geoffrey (\cmark), Geoffrey the Giraffe (\xmark), George (\xmark)\}
Question: Where do you find the Bridal Veil, American, and Horseshoe Falls?
Generated answers:
\{Niagara Falls (\cmark), Niagara Falls, Canada (\cmark), Niagra Falls (\cmark), Niagara (\xmark)\}
Question: The worlds largest marketer of fruit juices, what is the juice arm of the Coca Cola company?
Generated answers:
\{Minute Maid (\cmark)\}
Question: Whose backing band is known as The Miami Sound Machine?
Generated answers:
\{Gloria Estefan (\cmark), Gloria Estefan & Miami Sound M