In [1]:
from datasets import load_dataset
from langchain.llms import OpenAI, VertexAI
from langchain.prompts import PromptTemplate

In [2]:
# Pull the training split of the WebGLM-QA dataset from the Hugging Face Hub.
web_qa = load_dataset(path="THUDM/webglm-qa", split="train")
# Bare last expression so the notebook renders the dataset summary.
web_qa

Dataset({
    features: ['question', 'answer', 'references'],
    num_rows: 43579
})

In [3]:
from langchain.output_parsers import CommaSeparatedListOutputParser

# --- Prompt text -----------------------------------------------------------
# Task description shown to the model before the few-shot example.
instructions = (
    "You are a helpful AI assistant that will decide which of the given questions would "
    "likely be asked by a retail investor doing research on possible investment opportunities. "
    "Respond with the question numbers of the likely asked questions. DO NOT respond with the question itself."
)
# Few-shot example: ten numbered questions followed by the expected reply.
example_questions = [
    "1. Why does the Earth orbit the sun?",
    "2. Which companies are making AI products?",
    "3. How does protein shakes affect muscle growth?",
    "4. How is Microsoft generating revenue via AI?",
    "5. Does NEM faces legal risks with their new mining developments?",
    "6. Who invented the telescope?",
    "7. Where is New York City Located?",
    "8. How was Wells Fargo impacted in terms of business by the recent leak incident?",
    "9. Why did Paypal stock crash?",
    "10. When is the next olympic occurring?",
]
example_answer = "2, 4, 5, 8, 9"
# Empty entries produce the blank lines that separate the prompt sections.
prompt_text = "\n".join([
    instructions,
    "Example:",
    "",
    *example_questions,
    "",
    example_answer,
    "",
    "{format_instructions}",
    "",
    "Follow the example for the following list of questions:",
    "{questions}",
])

# --- Credentials -----------------------------------------------------------
# The OpenAI API key is read from a file one directory above the notebook.
with open("../key") as key_file:
    key = key_file.read().strip()

# --- Chain assembly --------------------------------------------------------
# The parser supplies the format instructions for the prompt and also turns
# the model's comma-separated reply into a Python list of strings.
output_parser = CommaSeparatedListOutputParser()
plan_llm = OpenAI(
    model_name="gpt-3.5-turbo-instruct",
    openai_api_key=key,
    temperature=0,
    max_tokens=512,
)
# Pre-fill {format_instructions}; only {questions} is left for invoke().
prompt = PromptTemplate.from_template(prompt_text).partial(
    format_instructions=output_parser.get_format_instructions()
)
chain = prompt | plan_llm | output_parser

In [4]:
def format_questions(questions, start=0):
    """Render questions as a numbered list, one ``"<num>. <question>"`` per line.

    Args:
        questions: Iterable of question strings.
        start: Number assigned to the first question. The default of 0 keeps
            each number equal to the question's position, which is what
            ``decide_relevance`` relies on when it uses the numbers in the
            model's reply as list indices. The few-shot example in the prompt
            numbers its questions from 1; a caller that passes ``start=1`` to
            match it must subtract 1 when mapping reply numbers back to
            positions.

    Returns:
        The numbered lines joined with newlines (no trailing newline); an
        empty string when ``questions`` is empty.
    """
    return "\n".join(
        f"{num}. {question}" for num, question in enumerate(questions, start)
    )


# Quick manual check of the chain on four hand-picked questions; the last one
# is the clearly investment-related one.
# NOTE(review): the list is passed as-is rather than through format_questions,
# so the prompt receives the list's plain string rendering - confirm intended.
sample_questions = [
    "in football whats the point of wasting the first two plays with a rush - up the middle - not regular rush plays i get those",
    "Why are different tiers (regular < mid < premium) of gas' prices almost always 10 cents different?",
    "Why do you see weird colors when you press your eyes?",
    "Which oil and gas companies are developing new oil fields?",
]
chain.invoke({"questions": sample_questions})

['4']

In [5]:
def decide_relevance(batch):
    """Flag which questions in a batch the LLM selects as investor questions.

    The batch's questions are numbered with ``format_questions`` and sent
    through ``chain``; every number in the parsed reply is treated as the
    position of a selected question.

    Args:
        batch: Mapping with a ``"question"`` entry holding the batch's
            question strings (the shape ``Dataset.map(..., batched=True)``
            passes in).

    Returns:
        ``{"relevant": flags}`` where ``flags`` has one bool per question.
        This is best-effort: if the LLM call or parsing fails, the whole
        batch is returned as not relevant instead of raising.
    """
    import logging  # local import: keeps the cell self-contained

    batch_questions = batch["question"]
    relevant = [False] * len(batch_questions)
    try:
        result = chain.invoke({"questions": format_questions(batch_questions)})
    except Exception as exc:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt/SystemExit still stop a long-running map, and
        # record the failure instead of hiding it.
        logging.getLogger(__name__).warning(
            "decide_relevance: chain call failed (%r); marking batch as not relevant",
            exc,
        )
        return {"relevant": relevant}

    for token in result:
        # Parse each token on its own so that one malformed entry does not
        # discard the valid numbers that follow it.
        try:
            index = int(token)
        except (TypeError, ValueError):
            continue
        # Ignore numbers outside the batch; in particular a negative number
        # must not wrap around and flag a question at the end of the list.
        if 0 <= index < len(relevant):
            relevant[index] = True
    return {"relevant": relevant}


# Add a boolean "relevant" column by sending the questions to the LLM in
# batches of ten, the same number of questions as the prompt's example.
web_qa = web_qa.map(
    function=decide_relevance,
    batched=True,
    batch_size=10,
)

Map:   0%|          | 0/43579 [00:00<?, ? examples/s]

In [None]:
# Repeatedly re-screen the surviving questions with the LLM and drop the ones
# not flagged as relevant. Stop once a pass removes fewer than `delta` rows.
delta = 100
prev_len = len(web_qa)
web_qa_filtered = web_qa.filter(lambda batch: batch["relevant"], batched=True)
while True:
    # Rows removed by the most recent filtering pass.
    dropped = prev_len - len(web_qa_filtered)
    if dropped < delta:
        break
    print(f"Delta: {dropped}")
    prev_len = len(web_qa_filtered)
    # Re-classify the survivors, then keep only those flagged again.
    web_qa_filtered = (
        web_qa_filtered
        .map(decide_relevance, batched=True, batch_size=10)
        .filter(lambda batch: batch["relevant"], batched=True)
    )

Delta: 14802


Map:   0%|          | 0/28777 [00:00<?, ? examples/s]