In [235]:
import google.generativeai as genai
import numpy as np
from sentence_transformers import SentenceTransformer
from time import sleep
import csv
import traceback
from abc import ABC, abstractmethod
from tqdm import tqdm

In [236]:
um_api_key = "your Gemini api key here"
genai.configure(api_key=um_api_key)
model = genai.GenerativeModel('gemini-pro')


## data loading

In [237]:
def get_openbook_training_data():
    out = []
    with open("train.tsv") as data:
        reader = csv.reader(data, delimiter="\t", quotechar='"')
        # TODO: understand dataset a bit more; what other things can I use?
        for i, row in enumerate(reader):
            if i == 0: continue
            ex = []
            ex.append(row[3])
            ex.append(row[4])
            out.append(ex)
    return out


def get_openbook_facts():
    facts = []
    with open("openbook.txt") as f:
        facts = f.readlines()
    return list([fact.strip().replace('"', "") for fact in facts])

In [238]:
question_data = get_openbook_training_data()
openbook_facts = get_openbook_facts() 

In [239]:
num_training = int(0.9 * len(question_data))
training_data = question_data[:num_training]
test_data = question_data[num_training:]

In [240]:
print(len(training_data))
print(len(test_data))

4461
496


In [256]:
training_data[3][0]

'Stars are (A) warm lights that float (B) made out of nitrate (C) great balls of gas burning billions of miles away (D) lights in the sky'

In [260]:
openbook_facts[33]

'An example of migration is birds flying south in the winter'

## Embedding Model

In [241]:
# returns a bert model
def get_fresh_bert_transformer():
    return SentenceTransformer("sentence-transformers/all-mpnet-base-v2")


def embed_facts(facts: list[str], model):
    return model.encode(facts, show_progress_bar=True)

global_bert = get_fresh_bert_transformer()
fact_embeddings = embed_facts(openbook_facts, global_bert)

Batches: 100%|██████████| 42/42 [00:04<00:00,  9.05it/s]


In [242]:
fact_embeddings.shape

(1326, 768)

In [243]:
# some type definitions for readability
from enum import Enum
sequence = str
prompt = str

class modelPrediction(Enum):
    A = "A"
    B = "B"
    C = "C"
    D = "D"
    safety = "S"
    fail = "F"
    parse_fail = "PF"

class trueValue(Enum):
    A = "A"
    B = "B"
    C = "C"
    D = "D"

## Prompting Modules

In [244]:
class PromptGenerator(ABC):
    @abstractmethod
    def generate(self, question, prepared_question: str) -> prompt: ...

    @abstractmethod
    def prepare_examples(self, sequences: list[str]): ...

In [245]:
class RelatedFactsPrompter(PromptGenerator):
    def __init__(self, embedded_facts, model, k=4):
        # TODO: really should take in openbook_facts and perform the embedding here, also storing those facts
        self.k = k
        self.embedded_facts = embedded_facts
        self.model = model

    def most_similar_facts(self, embedded_seq):
        # TODO: preprocessed_seq type (also returned by prepare_examples)
        dists = []
        for i, embedded_fact in enumerate(self.embedded_facts):
            dists.append((i, np.linalg.norm(embedded_seq - embedded_fact)))

        most_similar = sorted(dists, key=lambda x: x[1])[:self.k]
        return[ openbook_facts[a[0]] for a in most_similar]

    def generate(self, question, prepared_question: str) -> prompt:
        related = self.most_similar_facts(prepared_question)
        joined_facts = '. '.join(related)
        return f"Please answer the following question with the following facts in consideration {joined_facts}. The question is {question}."

    # returns a list of embedded sequences
    def prepare_examples(self, sequences):
        return self.model.encode(sequences, show_progress_bar=True)

In [246]:
rprompter = RelatedFactsPrompter(embedded_facts=fact_embeddings, model=global_bert)

In [249]:
pp = rprompter.prepare_examples([test_data[2][0]])
rprompter.generate(test_data[2][0], pp)

Batches: 100%|██████████| 1/1 [00:00<00:00,  5.28it/s]


'Please answer the following question with the following facts in consideration when a mineral is rubbed on a streak plate , some of the material breaks off and forms a powder. measuring the hardness of minerals requires scratching those materials. if one mineral can scratch another mineral then that other mineral is softer than that one mineral. pencil lead contains mineral graphite. The question is When writing with an instrument one sharpens, the leftovers when pressed to paper is (A) a squid (B) glowing (C) a mineral (D) bright white.'

In [130]:
class UnmodifiedQuestionPrompter(PromptGenerator):
    def __init__(self):
        ...

    def generate(self, question, prepared_question):
        return question

    def prepare_examples(self, sequences: list[str]):
        return sequences

In [162]:
class SimilarQuestionsPrompter(PromptGenerator):
    def __init__(self, training_questions: list[str, str], bert, k=2):
        self.model = bert
        questions, labels = zip(*training_questions)
        self.question_embeddings = self.model.encode(questions, show_progress_bar=True)
        self.questions = questions
        self.labels = labels
        self.k = k

    def most_similar_facts(self, embedded_seq):
        dists = []
        for i, embedded_question in enumerate(self.question_embeddings):
            dists.append((i, np.linalg.norm(embedded_seq - embedded_question)))

        most_similar = sorted(dists, key=lambda x: x[1])[:self.k]
        return [(self.questions[a[0]], self.labels[a[0]]) for a in most_similar]
    
    def generate(self, question, prepared_question):
        related_questions = self.most_similar_facts(prepared_question)
        promptstr = ""
        for q, label in related_questions:
            promptstr += q + "\n\n"
            promptstr += f"The correct answer is {label}.\n\n"
        
        promptstr += "Now answer the following:\n"
        promptstr += question
        return promptstr

    def prepare_examples(self, sequences):
        return self.model.encode(sequences, show_progress_bar=True)

In [250]:
sq = SimilarQuestionsPrompter(training_data, global_bert, k=3)

Batches: 100%|██████████| 140/140 [00:34<00:00,  4.02it/s]


In [251]:
pp = sq.prepare_examples([test_data[2][0]])
sq.generate(test_data[2][0], pp)

Batches: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]


'Rubbing calcium on a streak plate (A) describes a white mineral (B) leaves behind bits of white (C) tells a lot about calcium (D) breaks the calcium into chunks\n\nThe correct answer is B.\n\nWhich part of a pencil comes most directly from rocks? (A) eraser (B) logo (C) the middle (D) wood\n\nThe correct answer is C.\n\na student leaves a nail line on a mineral sample, so that mineral can be described as what? (A) a mineral (B) a soft mineral (C) a liquid mineral (D) a mineral melt\n\nThe correct answer is B.\n\nNow answer the following:\nWhen writing with an instrument one sharpens, the leftovers when pressed to paper is (A) a squid (B) glowing (C) a mineral (D) bright white'

In [221]:
class ChainOfThoughPrompter(PromptGenerator):
    def __init__(self, training_questions: list[str, str], bert, gemini):
        self.bert = bert
        self.gemini = gemini
        questions, labels = zip(*training_questions)
        self.questions = questions
        self.question_embeddings = self.bert.encode(questions, show_progress_bar=True)
        self.labels = labels

    def prepare_examples(self, sequences: list[str]):
        return self.bert.encode(sequences, show_progress_bar=True)

    def get_similar_question(self, embedded_seq):
        dists = []
        for i, embedded_question in enumerate(self.question_embeddings):
            dists.append((i, np.linalg.norm(embedded_seq - embedded_question)))

        most_similar = sorted(dists, key=lambda x: x[1])[:1]
        return [(self.questions[a[0]], self.labels[a[0]]) for a in most_similar][0]

    def get_answer_explanation(self, question, answer):
        p = f"Can you please explain why the answer to the following question is {answer}?\n\n{question}"
        try:
            r = self.gemini.generate_content(p)
            sleep(1.05)
            o = r.candidates[0].content.parts[0].text
            return f"I will show you a question and explain its answer.\n\n{question}\n\n{o}\n\n Now, please answer the following question without explaining it: \n\n"
            # return f"{r.candidates[0].content.parts[0].text}\n\n"
        except Exception as e:
            # some default thing for when we can't explain
            return f"Please answer the following question without explaining it:\n\n"
            # return "\n"

    def generate(self, question, prepared_question):
        # find a similar question
        similar_question, similar_question_answer = self.get_similar_question(prepared_question)
        # ask gemini to explain
        explanation = self.get_answer_explanation(similar_question, similar_question_answer)
        return f"{explanation}{question}"
        # return f"{similar_question}\n\n Explanation: \n\n {explanation}\n\nNow, please answer the following question: \n\n {question}"

In [252]:
cotr = ChainOfThoughPrompter(training_data, global_bert, model)

Batches: 100%|██████████| 140/140 [00:34<00:00,  4.08it/s]


In [253]:
pp = cotr.prepare_examples([test_data[2][0]])
cotr.generate(test_data[2][0], pp)

Batches: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]


'I will show you a question and explain its answer.\n\nRubbing calcium on a streak plate (A) describes a white mineral (B) leaves behind bits of white (C) tells a lot about calcium (D) breaks the calcium into chunks\n\nThe correct answer is **(B) leaves behind bits of white**.\n\nThe streak plate test is used to identify minerals. It involves rubbing a mineral on an unglazed porcelain plate to produce a streak of powder. The color of the streak can help identify the mineral. In this case, rubbing calcium on a streak plate would leave behind bits of white, which indicates that the mineral is white in color.\n\n Now, please answer the following question without explaining it: \n\nWhen writing with an instrument one sharpens, the leftovers when pressed to paper is (A) a squid (B) glowing (C) a mineral (D) bright white'

## Evaluation Code

In [132]:
def get_model_selection(response:str) -> modelPrediction:
  # TODO: the parsing on this could be better...
  # should read docs to determine all possible forms of the response
  try:
    cand = response.candidates[0]
    if cand.finish_reason != 1: return modelPrediction.parse_fail
    output = cand.content.parts[0].text
    if "A" in output: return modelPrediction.A
    if "B" in output: return modelPrediction.B
    if "C" in output: return modelPrediction.C
    if "D" in output: return modelPrediction.D

    return modelPrediction.parse_fail
  except Exception as e:
    return modelPrediction.parse_fail

In [133]:
def query_model(q: prompt) -> modelPrediction:
    # TODO: look into alternative querying methods (fucntions on genai model)
    try:
        r = model.generate_content(q)
        sleep(1.1) # for API usage rate limit
        return r, get_model_selection(r)
    except Exception as e:
        print(e)
        print(traceback.format_exc())
        return "", modelPrediction.fail

In [134]:
def evaluate_single_sequence(question, prepared_question: sequence, prompter: PromptGenerator):
    p = prompter.generate(question, prepared_question)
    r, guess = query_model(p)
    return p, r, guess

In [135]:
def build_retry_map(outputs: list[sequence, prompt, modelPrediction, trueValue, any], num_retries = 3):
    s_to_retries = {item[0] : num_retries for item in outputs if item[2] == modelPrediction.parse_fail}
    return s_to_retries

In [136]:
def get_accuracy(eval_outputs: list[sequence, prompt, modelPrediction, trueValue, any, str]) -> float:
    num_right = 0
    for _, _, pred, trueval, _, _ in eval_outputs:
        if pred.value == trueval: num_right += 1
    return num_right / len(eval_outputs)

In [137]:
def retry_questions(
    examples: list[sequence, prompt, modelPrediction, trueValue, any, str],
    retry_map: dict[str, any],
    prompter,
) -> list[sequence, prompt, modelPrediction, trueValue, any]:
    n = len(retry_map.keys())
    num_examples = len(examples)
    num_processed = 0
    i = 0

    print(f"retrying {n} prompts")

    while num_processed < n:
        question, p, modelPred, label, prepared_question, _ = examples[i]
        if question not in retry_map or modelPred != modelPrediction.parse_fail: 
            i+= 1
            if i == num_examples: i = 0
            continue
    
        p, generation, modelGuess = evaluate_single_sequence(question, prepared_question, prompter)
        examples[i] = (question, p, modelGuess, label, prepared_question, generation)

        if modelGuess == modelPrediction.parse_fail: retry_map[question] -= 1

        if retry_map[question] == 0 or modelGuess != modelPrediction.parse_fail:
            num_processed += 1
            del retry_map[question]

        print(f"processed {num_processed} / {n}", end="\r")

        if num_processed >= n:
            print()
            return examples

        i += 1
        if i == num_examples:
            i = 0

In [138]:
def evaluate_prompting_method(
    examples: list[(sequence, trueValue)], prompter: PromptGenerator
) -> list[sequence, prompt, modelPrediction, trueValue, any, str]:
    outputs: list[sequence, prompt, modelPrediction, trueValue] = []
    num_right = 0
    parsing_failures = 0

    questions, answers = zip(*examples)
    prepared_questions = prompter.prepare_examples(questions)

    print("initial evaluation loop")
    loop = tqdm(zip(questions, prepared_questions, answers))
    for i, (question, prepared_question, label) in enumerate(loop):
        p, generation, modelGuess = evaluate_single_sequence(question, prepared_question, prompter)
        outputs.append((question, p, modelGuess, label, prepared_question, generation))
        if modelGuess.value == label: num_right += 1
        elif modelGuess == modelPrediction.parse_fail:
            parsing_failures += 1
        loop.set_description(f"accuracy: {num_right / (i+1)}, parsing failures: {parsing_failures}")

    retry_map = build_retry_map(outputs)
    outputs = retry_questions(outputs, retry_map, prompter)
    acc = get_accuracy(outputs)
    print(f"Final accuracy of guesses: {acc}")

    return outputs, acc

In [141]:
related_prompter = RelatedFactsPrompter(embedded_facts=fact_embeddings, model=global_bert, k=4)

In [142]:
related_eval_outputs = evaluate_prompting_method(test_data, related_prompter)

Batches: 100%|██████████| 16/16 [00:03<00:00,  4.39it/s]


initial evaluation loop


accuracy: 0.8004032258064516, parsing failures: 28: : 496it [20:17,  2.46s/it]


retrying 28 prompts
processed 28 / 28
Final accuracy of guesses: 0.8064516129032258


In [143]:
unmodified_prompter = UnmodifiedQuestionPrompter()
unmodified_eval_outputs = evaluate_prompting_method(test_data, unmodified_prompter)

initial evaluation loop


accuracy: 0.7379032258064516, parsing failures: 46: : 496it [20:00,  2.42s/it]


retrying 46 prompts
processed 46 / 46
Final accuracy of guesses: 0.7721774193548387


In [168]:
simq = SimilarQuestionsPrompter(training_data, global_bert, k=3)
simq_eval_outputs = evaluate_prompting_method(test_data, simq)

Batches: 100%|██████████| 140/140 [00:32<00:00,  4.34it/s]
Batches: 100%|██████████| 16/16 [00:03<00:00,  4.55it/s]


initial evaluation loop


accuracy: 0.8125, parsing failures: 37: : 496it [21:50,  2.64s/it]            


retrying 37 prompts
processed 37 / 37
Final accuracy of guesses: 0.8689516129032258


In [222]:
cotr = ChainOfThoughPrompter(training_data, global_bert, model)

Batches: 100%|██████████| 140/140 [01:09<00:00,  2.01it/s]


In [224]:
cotr_eval_outputs = evaluate_prompting_method(test_data, cotr)

Batches: 100%|██████████| 16/16 [00:07<00:00,  2.08it/s]


initial evaluation loop


accuracy: 0.6270161290322581, parsing failures: 130: : 496it [58:50,  7.12s/it]


retrying 130 prompts
processed 130 / 130
Final accuracy of guesses: 0.842741935483871


In [225]:
cotr_eval_outputs

([('Animals are just like humans in that if they run out of oxygen, breathing is impossible and (A) They will perish (B) they will type. (C) they will program (D) they will Laugh',
   'I will show you a question and explain its answer.\n\nIf you find an animal that isnt breathing, it is safe to assume that thing (A) watched TV (B) cried (C) perished (D) laughed\n\nThe correct answer is C.\n\nIf an animal is not breathing, it means that it is not alive. Therefore, it is safe to assume that the animal has perished.\n\n Now, please answer the following question without explaining it: \n\nAnimals are just like humans in that if they run out of oxygen, breathing is impossible and (A) They will perish (B) they will type. (C) they will program (D) they will Laugh',
   <modelPrediction.A: 'A'>,
   'A',
   array([ 7.23904520e-02,  4.57751751e-02, -7.99057190e-04, -6.45854175e-02,
           4.39395793e-02,  1.13036213e-02, -5.12367859e-02,  8.13407004e-02,
           4.82104942e-02, -5.74322566