In [None]:
%load_ext autoreload
%autoreload 2
from autora.doc.runtime.predict_hf import Predictor, preprocess_code
from autora.doc.runtime.prompts import PROMPTS, PromptIds, PromptBuilder, SYS_GUIDES
from autora.doc.pipelines.main import evaluate_documentation
from autora.doc.pipelines.main import eval_prompt, load_data

In [None]:
# Identifier of the chat model handed to Predictor in the next cell.
# NOTE(review): looks like a Hugging Face Hub id (gated Llama 2 weights) — confirm access is set up.
model = "meta-llama/Llama-2-7b-chat-hf"

In [None]:
# Single predictor instance shared by every generation cell in this notebook.
# NOTE(review): presumably loads the model weights here, so this cell may be slow — confirm.
pred = Predictor(model)

## Test generation for the variable declaration only

In [None]:
# Code snippet containing only the variable declarations; used as the model
# input for the variable-declaration prompts.
TEST_VAR_CODE = """
iv = Variable(name="x", value_range=(0, 2 * np.pi), allowed_values=np.linspace(0, 2 * np.pi, 30))
dv = Variable(name="y", type=ValueType.REAL)
variables = VariableCollection(independent_variables=[iv], dependent_variables=[dv])
"""
# Reference documentation that the generated text is scored against.
# Raw string: the LaTeX backslashes (\in, \pi) are not valid Python escape
# sequences and would otherwise trigger an invalid-escape warning.
LABEL = r"The discovery problem is defined by a single independent variable $x \in [0, 2 \pi]$ and dependent variable $y$."

In [None]:
def test(promptid, code, label):
    """Generate documentation for one code snippet and print it with its scores.

    Args:
        promptid: Key into ``PROMPTS`` selecting the prompt to use.
        code: Source code string submitted as the only input.
        label: Reference documentation passed to ``evaluate_documentation``.

    Uses the notebook-level ``pred`` predictor. Returns nothing; every entry
    of ``output[0]`` is printed together with the BLEU and METEOR values.
    """
    # Fixed generation settings for these quick checks.
    generation_args = {
        "do_sample": 0,
        "max_new_tokens": 100,
        "temperature": 0.05,
        "top_k": 10,
        "num_ret_seq": 1,
    }
    output = pred.predict(PROMPTS[promptid], [code], **generation_args)

    scores = evaluate_documentation(output, [label])
    bleu, meteor = scores

    # output[0] presumably holds the returned sequences for the single input.
    first_input_outputs = output[0]
    for i, o in enumerate(first_input_outputs):
        print(f"{promptid}\n******* Output {i} ********. bleu={bleu}, meteor={meteor}\n{o}\n*************\n")

In [None]:
# Zero-shot prompt applied to the variable-declaration snippet
test(
    promptid=PromptIds.AUTORA_VARS_ZEROSHOT,
    code=TEST_VAR_CODE,
    label=LABEL,
)

In [None]:
# One-shot prompt applied to the same variable-declaration snippet
test(
    promptid=PromptIds.AUTORA_VARS_ONESHOT,
    code=TEST_VAR_CODE,
    label=LABEL,
)

## One-shot generation for the complete code sample

In [None]:
# Dataset path, relative to this notebook's working directory.
data_file = "../data/autora/data.jsonl"
inputs, labels = load_data(data_file)
# preprocessing removes comments, import statements and empty lines
inputs = [preprocess_code(code) for code in inputs]
INSTR = "Generate high-level, one or two paragraph documentation for the following experiment."
# One-shot prompt: the first dataset sample and its label serve as the worked
# example. The sample is passed directly; wrapping it in an f-string was redundant.
prompt = PromptBuilder(SYS_GUIDES, INSTR).add_example(inputs[0], labels[0]).build()
print(prompt)

In [None]:
# Run the one-shot prompt over the dataset file and collect outputs plus BLEU/METEOR.
# max_new_tokens is given as a float; presumably it is cast to int downstream — TODO confirm.
# NOTE(review): the prompt embeds sample 0 of this same data_file as its example, so if
# eval_prompt scores every sample, the metrics include that sample and out[0][0] is
# presumably the generation for it — confirm and consider excluding it.
out, bleu, meteor = eval_prompt(data_file, pred, prompt, {"max_new_tokens": 800.0})
# Show the aggregate scores followed by the first returned sequence for the first input.
print(f"bleu={bleu}, meteor={meteor}\n{out[0][0]}\n*************\n")