Retrieval-Augmented Generation with GPT LMs

We build a retriever-reader pipeline for retrieval-augmented question answering with large language models (LLMs) such as OpenAI's ChatGPT (`gpt-3.5-turbo`) and InstructGPT (`text-davinci-003`). The retriever fetches relevant passages from a document collection, and the LLM reader generates an answer conditioned on those passages.

For each question, we show ChatGPT's answer both without and with the retrieved passages.
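To make "conditioned on the retrieved passages" concrete, here is a minimal sketch of how a reader prompt could be assembled from an instruction prefix, the retrieved passages, and the question. The function `build_rag_prompt` and its layout are hypothetical illustrations; PrimeQA's `PromptGPTReader` may format its prompt differently.

```python
def build_rag_prompt(question, passages=None, prefix=""):
    """Assemble a prompt for an LLM reader (hypothetical layout, not PrimeQA's).

    Without passages, the model must answer from its own knowledge;
    with passages, they are placed before the question as context.
    """
    parts = []
    if prefix:
        parts.append(prefix)
    if passages:
        # number each passage so the model (and a human) can tell them apart
        context = "\n".join(f"[{i}] {p}" for i, p in enumerate(passages, start=1))
        parts.append(f"Text:\n{context}")
    parts.append(f"Question: {question}")
    parts.append("Answer:")
    return "\n\n".join(parts)


prompt = build_rag_prompt(
    "How many area codes are in New Hampshire?",
    passages=["Area code 603 is the sole area code for the U.S. state of New Hampshire."],
    prefix="Answer the following question after looking at the text.",
)
print(prompt)
```

Without passages, the same function produces a closed-book prompt containing only the question, which mirrors the `use_retriever=False` runs below.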


In [1]:
from primeqa.components.retriever.dense import ColBERTRetriever
from primeqa.components.reader.prompt import PromptGPTReader
from primeqa.pipelines.qa_pipeline import QAPipeline
import json

{"time":"2023-03-06 10:52:19,329", "name": "faiss.loader", "level": "INFO", "message": "Loading faiss with AVX2 support."}
{"time":"2023-03-06 10:52:19,725", "name": "faiss.loader", "level": "INFO", "message": "Successfully loaded faiss with AVX2 support."}


In [2]:
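# set up the ColBERT retriever over a prebuilt index
# (these paths are specific to the original environment; point them at your own index and collection)
index_root = "/dccstor/colbert-ir/franzm/indexes/oct2_10_11/oct2_10_11_exp/indexes/"
index_name = "oct2_10_11_indname"
collection = "/dccstor/avi7/neural_ir/colbert/data/psgs_w100.tsv"

retriever = ColBERTRetriever(
    index_root=index_root,
    index_name=index_name,
    collection=collection,
    max_num_documents=3,  # maximum number of passages returned per question
)
retriever.load()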

[Mar 06, 10:52:39] #> base_config.py from_path /dccstor/colbert-ir/franzm/indexes/oct2_10_11/oct2_10_11_exp/indexes//oct2_10_11_indname/metadata.json
[Mar 06, 10:52:39] #> base_config.py from_path args loaded! 
[Mar 06, 10:52:39] #> base_config.py from_path args replaced ! 
[Mar 06, 10:52:50] #>>>>> at ColBERT name (model type) : /dccstor/colbert-ir/franzm/experiments/oct2_7_12_1.5e-06/none/2022-10/09/15.21.39/checkpoints/colbert.dnn.batch_91287.model
[Mar 06, 10:52:50] #>>>>> at BaseColBERT name (model type) : /dccstor/colbert-ir/franzm/experiments/oct2_7_12_1.5e-06/none/2022-10/09/15.21.39/checkpoints/colbert.dnn.batch_91287.model
[Mar 06, 10:52:57] factory model type: xlm-roberta-large
[Mar 06, 10:53:12] get query model type: xlm-roberta-large
[Mar 06, 10:53:13] get doc model type: xlm-roberta-large
[Mar 06, 10:53:43] #> Loading codec...
[Mar 06, 10:53:43] #> base_config.py from_path /dccstor/colbert-ir/franzm/indexes/oct2_10_11/oct2_10_11_exp/indexes//oct2_10_11_indname/metadata.js
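ColBERT ranks passages by late interaction: each query token embedding is matched against its most similar passage token embedding, and these maximum similarities are summed (often called MaxSim). The toy sketch below uses tiny hand-written 2-d vectors in place of the learned embeddings (produced here by an `xlm-roberta-large`-based model, per the logs) and keeps the top `k` passages, mirroring the role of `max_num_documents=3`. It illustrates the scoring rule only; it is not how the real index is searched at scale, and all names in it are hypothetical.

```python
from math import sqrt


def normalize(v):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def maxsim_score(query_embs, doc_embs):
    """Late interaction: for each query token, take the maximum similarity
    over all document tokens, then sum over query tokens."""
    return sum(
        max(sum(q * d for q, d in zip(q_vec, d_vec)) for d_vec in doc_embs)
        for q_vec in query_embs
    )


def top_k(query_embs, docs, k=3):
    """Rank documents by MaxSim score and return the k best (doc_id, score) pairs."""
    scored = [(doc_id, maxsim_score(query_embs, embs)) for doc_id, embs in docs.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# toy 2-d token embeddings (real ColBERT embeddings are learned and much larger)
query = [normalize([1.0, 0.0]), normalize([0.0, 1.0])]
docs = {
    "doc_a": [normalize([1.0, 0.1]), normalize([0.1, 1.0])],  # matches both query tokens
    "doc_b": [normalize([1.0, 0.0])],  # matches only the first query token
    "doc_c": [normalize([-1.0, -1.0])],  # matches neither
    "doc_d": [normalize([0.0, 1.0]), normalize([0.5, 0.5])],
}
print(top_k(query, docs, k=3))
```

Because every query token contributes its best match, a passage that covers all query terms (`doc_a`) outranks one that matches a single term perfectly (`doc_b`).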

In [3]:
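import os

# set up a prompt-based reader backed by an OpenAI GPT model
# supported models: gpt-3.5-turbo and text-davinci-003
# read the API key from an environment variable instead of hardcoding it in the notebook
reader = PromptGPTReader(model_name='gpt-3.5-turbo', api_key=os.environ["OPENAI_API_KEY"])
reader.load()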

In [4]:
# combine the retriever and the reader into a QA pipeline
pipeline = QAPipeline(retriever, reader)

21015325it [01:07, 311127.86it/s]
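The `use_retriever` flag in the following cells switches between closed-book answering and retrieval-augmented answering. The sketch below is a hypothetical, dependency-free stand-in (`ToyQAPipeline`, `KeywordRetriever`, and `EchoReader` are invented for illustration and are not PrimeQA classes) that shows the control flow and returns a dictionary shaped like the outputs printed in this notebook: answers keyed by question index, with `passages` included only when retrieval is on.

```python
class KeywordRetriever:
    """Hypothetical retriever: ranks passages by how many question words they contain."""

    def __init__(self, passages, max_num_documents=3):
        self.passages = passages
        self.max_num_documents = max_num_documents

    def retrieve(self, question):
        words = set(question.lower().split())
        ranked = sorted(
            self.passages,
            key=lambda p: len(words & set(p.lower().split())),
            reverse=True,
        )
        return ranked[: self.max_num_documents]


class EchoReader:
    """Hypothetical reader: stands in for an LLM call by reporting what it was given."""

    def predict(self, question, passages=None, prefix=""):
        n = len(passages) if passages else 0
        return f"answer to {question!r} using {n} passage(s)"


class ToyQAPipeline:
    """Hypothetical pipeline mirroring the retriever -> reader flow."""

    def __init__(self, retriever, reader):
        self.retriever = retriever
        self.reader = reader

    def run(self, questions, prefix="", use_retriever=True):
        results = {}
        for i, question in enumerate(questions):
            if use_retriever:
                # retrieval-augmented: the reader sees the top passages as context
                passages = self.retriever.retrieve(question)
                text = self.reader.predict(question, passages, prefix)
                results[str(i)] = {"answers": {"text": text}, "passages": passages}
            else:
                # closed-book: the reader sees only the question
                text = self.reader.predict(question, prefix=prefix)
                results[str(i)] = {"answers": {"text": text}}
        return results


corpus = [
    "Area code 603 is the sole area code for New Hampshire.",
    "The 2017 Tour de France riders came from 32 countries.",
    "Paris is the capital of France.",
]
toy = ToyQAPipeline(KeywordRetriever(corpus, max_num_documents=2), EchoReader())
closed_book = toy.run(["area code for new hampshire"], use_retriever=False)
with_retrieval = toy.run(["area code for new hampshire"], use_retriever=True)
print(closed_book)
print(with_retrieval)
```

Keeping retrieval behind a flag makes side-by-side comparisons like the ones below a one-argument change.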


In [5]:
# ask a question without the retriever: the LLM answers from its own knowledge
questions = ["number of participating countries in tour de france 2017 ?"]
answers = pipeline.run(questions, use_retriever=False)
print(json.dumps(answers, indent=4))

{
    "0": {
        "answers": {
            "text": "\n\nAnswer: There were a total of 22 participating countries in the Tour de France 2017."
        }
    }
}
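The reader's raw text can carry leading newlines or an echoed `Answer:` label, as in the output above. A small post-processing helper (hypothetical, written against the output structure shown in this notebook rather than any documented PrimeQA format) can pull out a clean answer string per question:

```python
def extract_answers(results):
    """Map question index -> cleaned answer text from a pipeline result dict."""
    cleaned = {}
    for idx, result in results.items():
        text = result["answers"]["text"].strip()
        # drop an echoed "Answer:" label if the model produced one
        if text.lower().startswith("answer:"):
            text = text[len("answer:"):].strip()
        cleaned[idx] = text
    return cleaned


example = {
    "0": {
        "answers": {
            "text": "\n\nAnswer: There were a total of 22 participating countries in the Tour de France 2017."
        }
    }
}
print(extract_answers(example))
```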


In [6]:
# ask the same question with the retriever: the retrieved passages are given to the LLM as context
questions = ["number of participating countries in tour de france 2017 ?"]
prompt_prefix = "Answer the following question after looking at the text."
answers = pipeline.run(questions, prefix=prompt_prefix, use_retriever=True)
print(json.dumps(answers, indent=4))

[Mar 06, 11:04:51] #> XMLR QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==
[Mar 06, 11:04:51] #> Input: $ number of participating countries in tour de france 2017 ?, 		 True, 		 None
[Mar 06, 11:04:51] #> Output IDs: torch.Size([32]), tensor([    0,  9748, 14012,   111, 42938,   214, 76726,    23,  9742,     8,
        37863,    13,   505,   705,     2,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1])
[Mar 06, 11:04:51] #> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
[Mar 06, 11:04:51] #>>>> colbert query ==
[Mar 06, 11:04:51] #>>>>> input_ids: torch.Size([32]), tensor([    0,  9748, 14012,   111, 42938,   214, 76726,    23,  9742,     8,
        37863,    13,   505,   705,     2,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,   

100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.45it/s]


{
    "0": {
        "answers": {
            "text": "How many countries participated in the Tour de France 2017?\n\n32 countries participated in the Tour de France 2017."
        },
        "passages": [
            "\"2017 Tour de France\"\n \"de France. The total number of riders that finished the race was 167. The riders came from 32 countries. Six countries had more than 10 riders in the race: France (39), Italy (18), Belgium (16), Germany (16), the Netherlands (15), and Spain (13). The average age of riders in the race was 29.4 years, ranging from the 22-year-old \u00c9lie Gesbert () to the 40-year-old Haimar Zubeldia (). had the youngest average age while had the oldest. The teams entering the race were: In the lead up to the 2017 Tour de France, Chris Froome () was seen by many pundits\"",
            "\"2018 Tour de France\"\n \"reduced the number of riders per team for Grand Tours from 9 to 8, resulting in a start list total of 176, instead of the usual 198. Of these, 35 com

In [7]:
# ask a question without the retriever
questions = ["How many area codes are in New Hampshire ?"]
answers = pipeline.run(questions, use_retriever=False)
print(json.dumps(answers, indent=4))

{
    "0": {
        "answers": {
            "text": "\n\nAs an AI language model, I don't have the most up-to-date information, but according to the North American Numbering Plan Administration (NANPA) as of August 2021, there are two area codes assigned to New Hampshire: 603 and 364."
        }
    }
}


In [8]:
# ask the same question with retrieved passages
questions = ["How many area codes are in New Hampshire?"]
prompt_prefix = "Answer the following question after looking at the text."
answers = pipeline.run(questions, prefix=prompt_prefix, use_retriever=True)
print(json.dumps(answers, indent=4))

100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.12it/s]


{
    "0": {
        "answers": {
            "text": "There is only one area code in New Hampshire, which is 603."
        },
        "passages": [
            "\"Area code 603\"\n \"previously allocated telephone numbers, as well as number pooling, the exhaustion time frame has been moved beyond 2025. Since New Hampshire has only one area code, callers in the state can reach any other telephone within the 603 area code with 7-digit dialing. Area code 603 Area code 603 is the sole area code for the U.S. state of New Hampshire in the North American Numbering Plan (NANP). It was created as one of the original 86 numbering plan areas in October 1947. As of April 2011, the 603 code was nearing exhaustion and a second area code for New\"",
            "\"Area code 603\"\n \"Area code 603 Area code 603 is the sole area code for the U.S. state of New Hampshire in the North American Numbering Plan (NANP). It was created as one of the original 86 numbering plan areas in October 1947. As of Apr