In [1]:
import sys
sys.path.append('..')
import ir_datasets as ir
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
from finetune.auto_label import augment_docs
from finetune import config

# Seed both the stdlib and the NumPy global random generators with the same
# fixed value so random draws made through them are repeatable between runs.
random.seed(93)
np.random.seed(93)
# Register tqdm's pandas integration (adds progress_apply / progress_map
# to pandas objects).
tqdm.pandas()

In [2]:
# Load the training split of the BEIR FiQA dataset through ir_datasets.
train_corpus = ir.load("beir/fiqa/train")
# Peek at the first query record only, to see which fields it carries.
for query in train_corpus.queries_iter():
    print(query)
    break

GenericQuery(query_id='0', text='What is considered a business expense on a business trip?')


In [3]:
# Peek at the first document record only, to see which fields it carries.
# The loop variable is named for what it holds (a document, not a query).
for doc in train_corpus.docs_iter():
    print(doc)
    break

GenericDoc(doc_id='3', text="I'm not saying I don't like the idea of on-the-job training too, but you can't expect the company to do that. Training workers is not their job - they're building software. Perhaps educational systems in the U.S. (or their students) should worry a little about getting marketable skills in exchange for their massive investment in education, rather than getting out with thousands in student debt and then complaining that they aren't qualified to do anything.")


In [4]:
# Peek at the first relevance judgment only, to see which fields it carries.
# The loop variable is named for what it holds (a qrel, not a query).
for qrel in train_corpus.qrels_iter():
    print(qrel)
    break

TrecQrel(query_id='0', doc_id='18850', relevance=1, iteration='0')


In [5]:
def convert_queries(corpus):
    """Collect the queries of ``corpus`` into a two-column DataFrame.

    Each query text is stripped of surrounding whitespace; queries whose
    stripped text is shorter than 5 characters are dropped.

    Args:
        corpus: object exposing ``queries_iter()``, which yields records with
            ``query_id`` and ``text`` attributes.

    Returns:
        pd.DataFrame with columns ``query_id`` (converted with ``int``) and
        ``query_text`` (the stripped text), in iteration order.
    """
    cleaned = ((q.query_id, q.text.strip()) for q in corpus.queries_iter())
    # The length filter runs before int(), so ids of dropped queries are
    # never parsed.
    kept = [(int(raw_id), text) for raw_id, text in cleaned if len(text) >= 5]
    return pd.DataFrame({
        "query_id": [query_id for query_id, _ in kept],
        "query_text": [text for _, text in kept],
    })


def convert_docs(corpus):
    """Collect the documents of ``corpus`` into a two-column DataFrame.

    Each document text is stripped of surrounding whitespace; documents whose
    stripped text is shorter than 10 characters are dropped.

    Args:
        corpus: object exposing ``docs_iter()``, which yields records with
            ``doc_id`` and ``text`` attributes.

    Returns:
        pd.DataFrame with columns ``doc_id`` (converted with ``int``) and
        ``doc_text`` (the stripped text), in iteration order.
    """
    cleaned = ((d.doc_id, d.text.strip()) for d in corpus.docs_iter())
    # The length filter runs before int(), so ids of dropped documents are
    # never parsed.
    kept = [(int(raw_id), text) for raw_id, text in cleaned if len(text) >= 10]
    return pd.DataFrame({
        "doc_id": [doc_id for doc_id, _ in kept],
        "doc_text": [text for _, text in kept],
    })


def convert_relations(corpus):
    """Collect the relevance judgments of ``corpus`` into a DataFrame.

    Args:
        corpus: object exposing ``qrels_iter()``, which yields records with
            ``query_id``, ``doc_id`` and ``relevance`` attributes.

    Returns:
        pd.DataFrame with columns ``query_id`` and ``doc_id`` (converted with
        ``int``) and ``relevance`` (converted with ``float``), one row per
        judgment in iteration order. No judgments are filtered out.
    """
    triples = [
        (int(rel.query_id), int(rel.doc_id), float(rel.relevance))
        for rel in corpus.qrels_iter()
    ]
    # zip(*rows) turns rows into columns; with no rows there is nothing to
    # unpack, so fall back to three empty columns.
    query_ids, doc_ids, scores = zip(*triples) if triples else ((), (), ())
    return pd.DataFrame({
        "query_id": list(query_ids),
        "doc_id": list(doc_ids),
        "relevance": list(scores),
    })


def convert_to_tables(corpus):
    """Convert a corpus into query, document and relevance DataFrames.

    The document table is restricted to documents that appear in at least
    one relevance judgment, with each such document listed exactly once.

    Args:
        corpus: object exposing ``queries_iter()``, ``docs_iter()`` and
            ``qrels_iter()``, as consumed by ``convert_queries``,
            ``convert_docs`` and ``convert_relations``.

    Returns:
        Tuple ``(queries, filtered_docs, rels)``:
            queries: output of ``convert_queries``.
            filtered_docs: columns ``doc_id`` and ``doc_text``; judged
                documents only, unique by ``doc_id``, with a fresh 0..n-1
                index.
            rels: output of ``convert_relations``, returned unfiltered (it may
                still reference queries or documents dropped by the length
                filters).
    """
    queries = convert_queries(corpus)
    docs = convert_docs(corpus)
    rels = convert_relations(corpus)
    # The inner join keeps only documents referenced by some judgment.
    docs_with_queries = pd.merge(left=docs, right=rels, on="doc_id")
    # The join emits one row per (document, judgment) pair, so a document
    # judged for several queries would be repeated; keep its first occurrence.
    filtered_docs = (
        docs_with_queries[["doc_id", "doc_text"]]
        .drop_duplicates(subset="doc_id")
        .reset_index(drop=True)
    )
    return queries, filtered_docs, rels


# Build the query / document / relevance tables for the training split.
train_query, train_doc, train_rel = convert_to_tables(train_corpus)
# Display a fixed random sample of three queries as a sanity check.
train_query.sample(3, random_state=93)

Unnamed: 0,query_id,query_text
4553,8363,Dividend Yield
1622,3173,How can I find a checking account that allows ...
4922,9445,When is the best time to put a large amount of...


In [6]:
# The API key is read from a file rather than written into the notebook;
# it is later passed to ChatOpenAI as openai_api_key.
API_KEY_PATH = "../key"

with open(API_KEY_PATH, "r") as fp:
    # Strip surrounding whitespace (e.g. a trailing newline) from the key.
    API_KEY = fp.read().strip()

In [21]:
# Output locations of the final parquet tables written at the end of this
# notebook: one document, query and relevance file per split.

# Training split.
FIQA_TRAIN_DOC_PATH = "../data/fiqa-augmented-traindoc.parquet"
FIQA_TRAIN_Q_PATH = "../data/fiqa-augmented-trainq.parquet"
FIQA_TRAIN_REL_PATH = "../data/fiqa-augmented-trainrel.parquet"

# Dev split.
FIQA_DEV_DOC_PATH = "../data/fiqa-augmented-devdoc.parquet"
FIQA_DEV_Q_PATH = "../data/fiqa-augmented-devq.parquet"
FIQA_DEV_REL_PATH = "../data/fiqa-augmented-devrel.parquet"

# Test split.
FIQA_TEST_DOC_PATH = "../data/fiqa-augmented-testdoc.parquet"
FIQA_TEST_Q_PATH = "../data/fiqa-augmented-testq.parquet"
FIQA_TEST_REL_PATH = "../data/fiqa-augmented-testrel.parquet"

In [7]:
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI

# Chat model used for augmentation: the key comes from the key file, the
# temperature and model name from finetune.config.
plan_llm = ChatOpenAI(openai_api_key=API_KEY, temperature=config.FIQA_TEMPERATURE,
                      model_name=config.FIQA_MODEL)
# The system and user message templates are also defined in finetune.config.
system_prompt = SystemMessagePromptTemplate.from_template(config.FIQA_SYSTEM)
user_prompt = HumanMessagePromptTemplate.from_template(config.FIQA_USER)
# Combine both messages into one chat prompt and bind it to the model.
chat_prompt = ChatPromptTemplate.from_messages([system_prompt, user_prompt])
augment_chain = LLMChain(llm=plan_llm, prompt=chat_prompt)
# Smoke test: run the chain once on the first training document and display
# the raw result before starting the full augmentation run.
augment_chain(train_doc.iloc[0]["doc_text"])

{'input_text': 'Here are the SEC requirements: The federal securities laws define the term accredited investor in   Rule 501 of Regulation D as: a bank, insurance company, registered investment company, business development company, or small business investment company; an employee benefit plan, within the meaning of the Employee Retirement Income Security Act, if a bank, insurance company, or   registered investment adviser makes the investment decisions, or if   the plan has total assets in excess of $5 million; a charitable organization, corporation, or partnership with assets exceeding $5 million; a director, executive officer, or general partner of the company selling the securities; a business in which all the equity owners are accredited investors; a natural person who has individual net worth, or joint net worth with the person’s spouse, that exceeds $1 million at the time of the   purchase, excluding the value of the primary residence of such person; a natural person with inco

In [8]:
# Run the augmentation chain over the training documents. The two results are
# later appended to train_doc and train_rel respectively.
# NOTE(review): presumably these are LLM-generated documents plus matching
# relevance rows; confirm against finetune.auto_label.augment_docs.
# Long-running: the recorded run iterated over 14128 items at several seconds
# each and hit API read timeouts that were retried.
train_aug_docs, train_aug_qrel = augment_docs(augment_chain, train_doc, train_rel)

  0%|          | 10/14128 [00:49<15:04:22,  3.84s/it]Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised Timeout: Request timed out: HTTPSConnectionPool(host='api.openai.com', port=443): Read timed out. (read timeout=600).
 41%|████      | 5771/14128 [7:38:56<10:44:48,  4.63s/it]Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised Timeout: Request timed out: HTTPSConnectionPool(host='api.openai.com', port=443): Read timed out. (read timeout=600).
 41%|████      | 5811/14128 [7:53:08<18:40:37,  8.08s/it]  Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised Timeout: Request timed out: HTTPSConnectionPool(host='api.openai.com', port=443): Read timed out. (read timeout=600).
 69%|██████▉   | 9764/14128 [16:43:07<8:16:20,  6.82s/it]  Retrying langchain.chat_m

In [14]:
# Append the augmented rows to the original training tables.
# ignore_index=True already gives the result a fresh 0..n-1 index, so no
# additional reset_index call is needed.
train_doc_final = pd.concat([train_doc, train_aug_docs], ignore_index=True)
# Queries are not augmented; they are saved unchanged.
train_query_final = train_query
train_rel_final = pd.concat([train_rel, train_aug_qrel], ignore_index=True)
# Persist the three tables; index=False keeps the index out of the files.
train_doc_final.to_parquet(FIQA_TRAIN_DOC_PATH, index=False)
train_query_final.to_parquet(FIQA_TRAIN_Q_PATH, index=False)
train_rel_final.to_parquet(FIQA_TRAIN_REL_PATH, index=False)

In [17]:
# Load the dev split of BEIR FiQA and build its query / document / relevance
# tables the same way as for the training split.
dev_corpus = ir.load("beir/fiqa/dev")
dev_query, dev_doc, dev_rel = convert_to_tables(dev_corpus)
# Display a fixed random sample of three queries as a sanity check.
dev_query.sample(3, random_state=93)

[INFO] [starting] opening zip file
[INFO] [finished] opening zip file [1ms]


Unnamed: 0,query_id,query_text
390,1325,How far do I go with a mortgage approval proce...
66,5381,Question about large capital gain
65,8215,"Where to invest, that compounds interest more ..."


In [18]:
# Run the augmentation chain over the dev documents; the two results are
# later appended to dev_doc and dev_rel respectively.
# Long-running: the recorded run iterated over 1236 items in about two hours.
dev_aug_docs, dev_aug_qrel = augment_docs(augment_chain, dev_doc, dev_rel)

100%|██████████| 1236/1236 [2:05:04<00:00,  6.07s/it] 


In [19]:
# Load the test split of BEIR FiQA and build its query / document / relevance
# tables the same way as for the training split.
test_corpus = ir.load("beir/fiqa/test")
test_query, test_doc, test_rel = convert_to_tables(test_corpus)
# Display a fixed random sample of three queries as a sanity check.
test_query.sample(3, random_state=93)

[INFO] [starting] opening zip file
[INFO] [finished] opening zip file [0ms]


Unnamed: 0,query_id,query_text
67,8539,Can the risk of investing in an asset be diffe...
82,1198,What are the consequences of IRS “reclassifica...
377,2513,How does revenue shared with someone else go i...


In [20]:
# Run the augmentation chain over the test documents; the two results are
# later appended to test_doc and test_rel respectively.
# Long-running: the recorded run iterated over 1705 items in roughly three
# and a half hours, with one retried API read timeout.
test_aug_docs, test_aug_qrel = augment_docs(augment_chain, test_doc, test_rel)

 68%|██████▊   | 1164/1705 [2:04:54<51:07,  5.67s/it]  Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised Timeout: Request timed out: HTTPSConnectionPool(host='api.openai.com', port=443): Read timed out. (read timeout=600).
100%|██████████| 1705/1705 [3:23:59<00:00,  7.18s/it]    


In [22]:
# Append the augmented rows to the original dev tables.
# ignore_index=True already gives the result a fresh 0..n-1 index, so no
# additional reset_index call is needed.
dev_doc_final = pd.concat([dev_doc, dev_aug_docs], ignore_index=True)
# Queries are not augmented; they are saved unchanged.
dev_query_final = dev_query
dev_rel_final = pd.concat([dev_rel, dev_aug_qrel], ignore_index=True)
# Persist the three tables; index=False keeps the index out of the files.
dev_doc_final.to_parquet(FIQA_DEV_DOC_PATH, index=False)
dev_query_final.to_parquet(FIQA_DEV_Q_PATH, index=False)
dev_rel_final.to_parquet(FIQA_DEV_REL_PATH, index=False)

In [23]:
# Append the augmented rows to the original test tables.
# ignore_index=True already gives the result a fresh 0..n-1 index, so no
# additional reset_index call is needed.
test_doc_final = pd.concat([test_doc, test_aug_docs], ignore_index=True)
# Queries are not augmented; they are saved unchanged.
test_query_final = test_query
test_rel_final = pd.concat([test_rel, test_aug_qrel], ignore_index=True)
# Persist the three tables; index=False keeps the index out of the files.
test_doc_final.to_parquet(FIQA_TEST_DOC_PATH, index=False)
test_query_final.to_parquet(FIQA_TEST_Q_PATH, index=False)
test_rel_final.to_parquet(FIQA_TEST_REL_PATH, index=False)