In [1]:
import pandas as pd
from summarizer import config
from langchain.llms import VertexAI
from langchain.prompts import PromptTemplate

# Re-writing the TAT-QA dataset responses

The TAT-QA dataset, one component of our QA fine-tuning dataset, contains expert-curated questions and answers grounded in context from financial reports. However, many of its answers are text spans copied verbatim from that context, so they are not complete sentences. Because we want the fine-tuned model to generate coherent sentences, we augment the dataset by using an LLM to rephrase these answers.

In [2]:
tatqa_df = pd.read_csv("gs://finetuningllama/tat_qa_concatenated.csv")
tatqa_df.head()

| Unnamed: 0 | uid | order | question | answer | answer_type | answer_from | rel_paragraphs | req_comparison | related_text | answer_count_words |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 23801627-ff77-4597-8d24-1c99e2452082 | 1 | What is the company paid on a cost-plus type c... | ['our allowable incurred costs plus a profit w... | span | text | ['2'] | False | On a fixed-price type contract, we agree to pe... | 25 |
| 1 | 86ae8d77-4dcd-4f82-baac-61c6a2551760 | 1 | How is industry end market information presented? | ['consistently with our internal management re... | span | text | ['1', '2'] | False | (1) Industry end market information is present... | 13 |
| 2 | 682e2b7e-2b40-4849-8671-2b85dd3c7f10 | 1 | How is the discount rate for domestic plans de... | ['By comparison against the FTSE pension liabi... | span | text | ['2'] | False | For domestic plans, the discount rate was dete... | 11 |
| 3 | de34ea96-2a39-4cd5-b4a3-62b0f730103a | 2 | How is the discount rate for international pla... | ['By comparison against country specific AA co... | span | text | ['2'] | False | For domestic plans, the discount rate was dete... | 11 |
| 4 | bde0702e-2847-485b-be4a-fb037790bd59 | 2 | Which countries does the group operate defined... | ['Germany, Ghana, India, Ireland, Italy, the U... | span | text | ['10'] | False | The Group operates defined benefit schemes in ... | 3 |
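The `answer_type` column records how each answer was derived, and the first rows shown are all `span`. A minimal sketch of selecting only the extracted-text answers before rewriting, assuming TAT-QA labels copied-text answers as `span` or `multi-span` (with `arithmetic` and `count` marking numeric answers that would not need rephrasing). `SPAN_TYPES`, `span_answer_rows`, and the toy frame are illustrative, not taken from the dataset:

```python
import pandas as pd

# Assumed TAT-QA answer types whose answers are copied text spans
# rather than computed numbers.
SPAN_TYPES = {"span", "multi-span"}

def span_answer_rows(df: pd.DataFrame) -> pd.DataFrame:
    """Return only the rows whose answers are extracted text spans."""
    return df[df["answer_type"].isin(SPAN_TYPES)]

# Illustrative toy frame (not drawn from the dataset).
toy = pd.DataFrame({
    "question": ["q1", "q2", "q3"],
    "answer": ["['a span']", "['one', 'two']", "12.5"],
    "answer_type": ["span", "multi-span", "arithmetic"],
})
print(span_answer_rows(toy)["answer_type"].tolist())  # ['span', 'multi-span']
```

On the real data this would be `span_answer_rows(tatqa_df)`; whether numeric answers should also be turned into sentences is a separate design decision.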


In [3]:
# Answers are stored as the string form of a list, e.g. "['...']"; slice off the "['" and "']"
tatqa_df["answer"] = tatqa_df["answer"].apply(lambda a: a[2:-2])
tatqa_df.answer.tolist()[0]

'our allowable incurred costs plus a profit which can be fixed or variable depending on the contract’s fee arrangement up to predetermined funding levels determined by the customer'
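Slicing off two characters at each end works for a single-element list like the one above, but it assumes every stored answer is the string form of a one-element list. A more defensive sketch, assuming answers are stored as Python literals (a list of spans, or a bare number or string): `ast.literal_eval` recovers the value, and multiple spans are joined with `", "`. The `parse_answer` name and the separator are illustrative choices:

```python
import ast

def parse_answer(raw: str) -> str:
    """Turn a stored answer such as "['span one', 'span two']" into plain text."""
    try:
        value = ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        # Not a Python literal (e.g. free text): keep it unchanged.
        return raw
    if isinstance(value, (list, tuple)):
        # Join multi-span answers into one string.
        return ", ".join(str(v) for v in value)
    return str(value)

print(parse_answer("['our allowable incurred costs']"))  # our allowable incurred costs
print(parse_answer("['Germany', 'Ghana']"))              # Germany, Ghana
print(parse_answer("2.9"))                                # 2.9
```

Applied to the notebook's data this would be `tatqa_df["answer"] = tatqa_df["answer"].apply(parse_answer)`.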

## Zero-Shot Prompting

Since rephrasing sentences should be a routine task for an LLM, we applied zero-shot prompting with PaLM 2 (the `text-bison` model on Vertex AI) to rewrite the answers as sentences that read well.

In [4]:
# Zero-shot instruction; {question} and {answer} are filled in for each row
template = ("You are a helpful AI assistant. Given a question and an answer, you will rewrite "
            "the answer so that the answer consists of well-written and coherent sentence or sentences. "
            "DO NOT add any new information to the answer. Just make the answer read well given the question. "
            "Respond with the re-written answer.\n\nQuestion: {question}\n\nAnswer: {answer}")

# PaLM 2 text model on Vertex AI; temperature=0 keeps the rewrites deterministic
plan_llm = VertexAI(
    project=config.GCP_PROJECT,
    temperature=0,
    model_name="text-bison",
    max_output_tokens=256,
)
prompt = PromptTemplate.from_template(template)
chain = prompt | plan_llm  # prompt -> LLM pipeline
chain.invoke({
    "question": tatqa_df.question.iloc[0],
    "answer": tatqa_df.answer.iloc[0]
})

' The company is paid all of its allowable incurred costs, plus a profit, which can be either fixed or variable depending on the contract’s fee arrangement, up to predetermined funding levels determined by the customer.'

In [5]:
tatqa_df.question.iloc[0]

'What is the company paid on a cost-plus type contract?'
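`PromptTemplate.from_template` uses f-string-style `{name}` placeholders by default, so the exact text sent to the model can be previewed with plain `str.format`, without calling Vertex AI. A minimal sketch reusing the instruction text of the zero-shot prompt; `ZERO_SHOT_TEMPLATE` and `render_prompt` are illustrative names:

```python
# Same instruction text as the zero-shot prompt defined earlier.
ZERO_SHOT_TEMPLATE = (
    "You are a helpful AI assistant. Given a question and an answer, you will rewrite "
    "the answer so that the answer consists of well-written and coherent sentence or sentences. "
    "DO NOT add any new information to the answer. Just make the answer read well given the question. "
    "Respond with the re-written answer.\n\nQuestion: {question}\n\nAnswer: {answer}"
)

def render_prompt(question: str, answer: str) -> str:
    """Fill the placeholders the same way an f-string PromptTemplate would."""
    return ZERO_SHOT_TEMPLATE.format(question=question, answer=answer)

text = render_prompt(
    "What is the company paid on a cost-plus type contract?",
    "our allowable incurred costs plus a profit",
)
print(text.splitlines()[-1])  # Answer: our allowable incurred costs plus a profit
```

Previewing a few rendered prompts this way is a cheap check that the question and the stripped answer land where expected before spending API quota.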

## Re-writing the Dataset

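The zero-shot chain is then applied to every row of the dataset. If a Vertex AI request fails, for example because the request quota is exhausted, the call is retried after a pause.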
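Each row costs one API call, so a full run takes a long time and can be interrupted before `progress_apply` returns anything, discarding all finished rewrites. A minimal sketch of making the run resumable, assuming each row has a unique `uid` (as the dataset's `uid` column suggests): each rewritten answer is appended to a JSON Lines file as soon as it is produced, and a rerun skips uids already in the file. `rewrite_with_cache` and the file layout are hypothetical, not part of LangChain or pandas:

```python
import json
import os
from typing import Callable, Dict, Iterable, Mapping

def rewrite_with_cache(
    rows: Iterable[Mapping],
    rewrite_fn: Callable[[Mapping], str],
    cache_path: str,
) -> Dict[str, str]:
    """Rewrite each row once, appending results to a JSON Lines file keyed by uid.

    Rows whose uid is already in the file are skipped, so an interrupted
    run can be restarted without repeating finished API calls.
    """
    done: Dict[str, str] = {}
    if os.path.exists(cache_path):
        with open(cache_path, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    record = json.loads(line)
                    done[record["uid"]] = record["rewritten"]

    with open(cache_path, "a", encoding="utf-8") as f:
        for row in rows:
            uid = row["uid"]
            if uid in done:
                continue  # already rewritten in an earlier run
            rewritten = rewrite_fn(row)
            done[uid] = rewritten
            # Flush after every row so finished rewrites survive an interrupt.
            f.write(json.dumps({"uid": uid, "rewritten": rewritten}) + "\n")
            f.flush()
    return done
```

In this notebook that could look like `cache = rewrite_with_cache(tatqa_df.to_dict("records"), rewrite_answer, "tatqa_rewrites.jsonl")` followed by `tatqa_df["answer"] = tatqa_df["uid"].map(cache)`, where `rewrite_answer` is the per-row function defined in the next cell and the file name is arbitrary.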

In [6]:
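from tqdm import tqdm
import time

tqdm.pandas()  # registers DataFrame.progress_apply / Series.progress_apply

def rewrite_answer(row, max_attempts=5, wait_seconds=30):
    """Rephrase one row's answer with the zero-shot chain, pausing and retrying on API errors."""
    for attempt in range(1, max_attempts + 1):
        try:
            result = chain.invoke({
                "question": row["question"],
                "answer": row["answer"]
            })
            # The model output can start with a space (see the sample above), so strip it
            return result.strip()
        except Exception as e:  # e.g. 429 ResourceExhausted once the quota is hit
            print(f"Attempt {attempt}/{max_attempts} failed: {e}")
            if attempt == max_attempts:
                raise  # give up instead of retrying forever
            time.sleep(wait_seconds)

tatqa_df["answer"] = tatqa_df.progress_apply(rewrite_answer, axis=1)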

  2%|▏         | 100/4346 [01:24<59:39,  1.19it/s] Retrying langchain.llms.vertexai.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised ResourceExhausted: 429 Quota exceeded for aiplatform.googleapis.com/online_prediction_requests_per_base_model with base model: text-bison. Please submit a quota increase request. https://cloud.google.com/vertex-ai/docs/quotas..

KeyboardInterrupt: 
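The log shows requests being rejected with `429 ResourceExhausted` once the quota for `text-bison` is reached, while the retry loop above waits a fixed 30 seconds after each failure. A common alternative is exponential backoff with jitter: the wait grows after each consecutive failure, up to a cap, and is randomized so that parallel workers do not retry in lockstep. A minimal, library-agnostic sketch; `call_with_backoff` and its parameters are illustrative, and in practice `retry_on` would be narrowed to the quota error rather than all exceptions:

```python
import random
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")

def call_with_backoff(
    fn: Callable[[], T],
    max_attempts: int = 6,
    base_delay: float = 2.0,
    max_delay: float = 60.0,
    retry_on: Tuple[Type[BaseException], ...] = (Exception,),
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn(), retrying on failure with exponentially growing, jittered waits."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(max_attempts):
        try:
            return fn()
        except retry_on:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            # Cap doubles each attempt (2, 4, 8, ... s) up to max_delay;
            # "full jitter" then waits a random time in [0, cap].
            delay = min(max_delay, base_delay * 2 ** attempt)
            sleep(random.uniform(0, delay))
    raise AssertionError("unreachable")
```

Inside `rewrite_answer`, the call could be wrapped as `call_with_backoff(lambda: chain.invoke({"question": row["question"], "answer": row["answer"]}))`. Note that the LangChain `VertexAI` wrapper already performs its own short retries, as the `Retrying ... in 4.0 seconds` messages show; outer backoff like this only helps when the quota stays exhausted for longer than those retries last.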