# Demo of Chain-of-Table

Paper: https://arxiv.org/abs/2401.04398

## Import libraries

In [1]:
# %pip install openai==0.28  # pinned: the direct-prompt cells below use the pre-1.0 `openai.ChatCompletion` API

In [None]:
import pandas as pd
import requests
import zipfile
import io
import os
import re
import openai

from utils.load_data import wrap_input_for_demo
from utils.llm import ChatGPT
from utils.helper import *
from utils.evaluate import *
from utils.chain import *
from operations import *

## Define model

In [None]:
# User parameters
model_name: str = "gpt-3.5-turbo"
# Read the key from the environment so it never has to be typed into (and saved with)
# the notebook; the placeholder is only a fallback so the cell still runs without it.
openai_api_key: str = os.environ.get("OPENAI_API_KEY", "YOUR_API_KEY")

In [3]:
# Shared LLM wrapper used by every Chain-of-Table call below.
gpt_llm = ChatGPT(model_name=model_name, key=openai_api_key)

## Prepare WikiTQ dataset

In [4]:
wiki_tq_dir = "WikiTableQuestions/"
if os.path.isdir(wiki_tq_dir):
    print("WikiTableQuestions is already downloaded")
else:
    # Step 1: Download the zip file. Fail loudly on HTTP errors instead of letting
    # ZipFile choke on an error page, and don't hang forever on a stalled connection.
    url = "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip"
    response = requests.get(url, timeout=300)
    response.raise_for_status()

    # Step 2: Unzip so that the data files end up directly under wiki_tq_dir.
    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
        dataset_root = os.path.normpath(wiki_tq_dir)
        top_level = {name.split("/", 1)[0] for name in zip_ref.namelist()}
        if top_level == {os.path.basename(dataset_root)}:
            # The archive already nests everything under a folder with that name;
            # extract into its parent to avoid WikiTableQuestions/WikiTableQuestions/.
            extract_to = os.path.dirname(dataset_root) or "."
        else:
            extract_to = dataset_root
        zip_ref.extractall(extract_to)

    print("Download and extraction complete!")

WikiTableQuestions is already downloaded


Here we use a subset of the data: the first 100 questions of the `random-split-4-dev` split.

In [19]:
# Keep the first 100 questions, indexed by their question id (e.g. "nt-2").
test_cases = (
    pd.read_csv(wiki_tq_dir + "data/random-split-4-dev.tsv", sep="\t")
    .head(100)
    .set_index("id")
)
test_cases.head(20)

Unnamed: 0_level_0,utterance,context,targetValue
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
nt-2,which team won previous to crettyard?,csv/204-csv/772.csv,Wolfe Tones
nt-9,which players played the same position as ardo...,csv/203-csv/116.csv,Siim Ennemuist|Andri Aganits
nt-24,who ranked right after turkey?,csv/203-csv/812.csv,Sweden
nt-36,who was the top winner in 2002 of the division...,csv/204-csv/879.csv,Princeton
nt-42,what is the total number of popular votes cast...,csv/203-csv/558.csv,459640
nt-43,which division three team also played in the d...,csv/202-csv/73.csv,Seaford Town
nt-54,does theodis or david play center?,csv/204-csv/847.csv,Theodis Tarver
nt-72,what is the number of formula one series races...,csv/203-csv/198.csv,2
nt-75,how many places list no zip code in either the...,csv/204-csv/356.csv,18
nt-80,has the dominican republic won more or less me...,csv/203-csv/535.csv,less


## Helper functions

In [6]:
def convert_df_to_table_text(df):
    """Convert a DataFrame into a list-of-rows table of strings.

    The first row is the header (column names as str); each following row holds
    the cell values as str. Missing cells (NaN/None, which ``pd.read_csv``
    produces for empty fields) become "" instead of the literal text "nan" or
    "None", so an empty cell reaches the LLM as empty.

    Args:
        df: Table to convert.

    Returns:
        list[list[str]]: ``[header, row_1, row_2, ...]``.
    """
    header = [str(column) for column in df.columns]
    cells = df.astype(str).to_numpy(dtype=object, copy=True)
    # Positional mask: avoids label alignment, so duplicate column names are fine.
    cells[df.isna().to_numpy()] = ""
    return [header] + cells.tolist()

def normalize_answer(ans, normalize_numbers=True):
    """Canonicalize an answer so model outputs and gold answers compare fairly.

    Steps: drop a leading "Answer:" label, lowercase, turn " and " into the "|"
    list separator, strip punctuation other than "|", optionally drop trailing
    text after a number inside each "|"-separated item, and trim whitespace
    around each item.

    Args:
        ans: Answer to normalize; non-str values (e.g. ints) are converted with str().
        normalize_numbers: If True, an item containing exactly one run of digits
            is cut right after those digits ("18 places" -> "18",
            "top 3 teams" -> "top 3"). Items with several numbers
            ("12 march 1998") are left intact rather than fused into one number.

    Returns:
        str: The normalized answer.
    """
    ans = str(ans)

    # Remove a leading "Answer:" label (case-insensitive)
    ans = re.sub(r"^\s*Answer:\s*", "", ans, flags=re.IGNORECASE)

    # Lowercase
    ans = ans.lower()

    # Replace " and " with "|"
    ans = ans.replace(" and ", "|")

    # Remove punctuation (except "|")
    ans = re.sub(r"[^\w\s|]", "", ans)

    # Trim whitespace around each "|"-separated item
    parts = [part.strip() for part in ans.split("|")]

    if normalize_numbers:
        # Only items with a single digit run are shortened, so separate numbers
        # (e.g. day and year of a date) are never concatenated.
        parts = [re.sub(r"^(\D*\d+)\D*$", r"\1", part) for part in parts]

    return "|".join(parts).strip()

## Benchmark on WikiTQ

In [44]:
# Evaluate Chain-of-Table on the first `max_eval_cases` questions of `test_cases`
# (set it to None to run them all); every question makes several LLM calls.
max_eval_cases = 4

acc = 0
test_count = 0
for testcase_id, testcase in test_cases.iterrows():
    if max_eval_cases is not None and test_count >= max_eval_cases:
        break

    df_path = wiki_tq_dir + testcase["context"]
    df = pd.read_csv(df_path)
    statement = testcase["utterance"]
    answer = testcase["targetValue"]

    table_caption = ""
    table_text = convert_df_to_table_text(df)

    demo_sample = wrap_input_for_demo(
        statement=statement, table_caption=table_caption, table_text=table_text
    )
    proc_sample, dynamic_chain_log = dynamic_chain_exec_one_sample(
        sample=demo_sample, llm=gpt_llm
    )
    output_sample = simple_query(
        sample=proc_sample,
        table_info=get_table_info(proc_sample),
        llm=gpt_llm,
        use_demo=False,
        llm_options=gpt_llm.get_model_options(
            temperature=0.0, per_example_max_decode_steps=200, per_example_top_p=1.0
        ),
    )
    cotable_log = get_table_log(output_sample)

    # The last log entry carries the final answer text; drop surrounding quotes
    # and anything up to an "Answer:" label.
    response = cotable_log[-1]['cotable_result']
    response = response.strip().strip("'\"")
    final_response = response.split("Answer:")[-1].strip()

    print(f"ID: {testcase_id} | Response: {final_response} | Ground Truth: {answer}")
    # Exact string equality would count differences in case, punctuation or list
    # separators ("a and b" vs "a|b") as errors, so compare normalized forms.
    if normalize_answer(final_response) == normalize_answer(str(answer)):
        acc += 1
    test_count += 1

print(f"number of correct responses = {acc} / {test_count}")

ID: nt-2 | Response: Confey | Ground Truth: Wolfe Tones
ID: nt-9 | Response: Martti Juhkami, Robert Täht | Ground Truth: Siim Ennemuist|Andri Aganits
ID: nt-24 | Response: Sweden | Ground Truth: Sweden
ID: nt-36 | Response: Michigan | Ground Truth: Princeton
number of correct response = 1


## Examine some test cases

In [28]:
### Example with chain-of-table
def get_chain_of_table(table_text, wrap_input, llm, answer):
    """Run Chain-of-Table on one prepared sample and print every step of the chain.

    Args:
        table_text: Original table as a list of rows, header row first.
        wrap_input: Sample to process (e.g. the output of ``wrap_input_for_demo``).
        llm: Model wrapper used both for the dynamic operation chain and for the
            final query.
        answer: Ground-truth answer, printed at the end for comparison.

    The function prints the question and the original table first. For each
    logged step that has actions and is not a skip, it prints either the query
    result or the operation with its intermediate table. When the step recorded
    a ``group_sub_table``, it also prints a group summary. The ground truth is
    printed last. Returns None.
    """
    def to_frame(rows):
        # rows[0] is the header; the remaining rows are data.
        return pd.DataFrame(rows[1:], columns=rows[0])

    proc_sample, dynamic_chain_log = dynamic_chain_exec_one_sample(
        sample=wrap_input, llm=llm
    )
    sample_table_info = get_table_info(proc_sample)
    query_options = llm.get_model_options(
        temperature=0.0, per_example_max_decode_steps=200, per_example_top_p=1.0
    )
    output_sample = simple_query(
        sample=proc_sample,
        table_info=sample_table_info,
        llm=llm,
        use_demo=False,
        llm_options=query_options,
    )
    cotable_log = get_table_log(output_sample)

    print(f'Question: {output_sample["statement"]}\n')
    print(f'Table: {output_sample["table_caption"]}')
    print(f"{to_frame(table_text)}\n")

    for step in cotable_log:
        actions = step["act_chain"]
        if not actions:
            continue
        table_text = step["table_text"]
        last_action = actions[-1]
        if "skip" in last_action:
            continue
        if "query" in last_action:
            print(f"{step['cotable_result']}")
            continue

        print(f"-> {last_action}\n{to_frame(table_text)}")
        if "group_sub_table" in step:
            group_column, group_info = step["group_sub_table"]
            group_rows = [
                [
                    f"Group {number}",
                    value if value.strip() != "" else "[Empty Cell]",
                    str(count),
                ]
                for number, (value, count) in enumerate(group_info, start=1)
            ]
            group_frame = pd.DataFrame(
                group_rows, columns=["Group ID", group_column, "Count"]
            )
            print(f"{group_frame}")
        print()

    print(f"Groundtruth: {answer}")

### Testcase nt-120

In [29]:
testcase_id = "nt-120"
testcase = test_cases.loc[testcase_id]
df_path = wiki_tq_dir + testcase["context"]
statement = testcase["utterance"]
# The dataset's targetValue for nt-120 is wrong, so the correct answer is set by hand.
answer = "South Korea"
df = pd.read_csv(df_path)

table_caption = ""
table_text = convert_df_to_table_text(df)

demo_sample = wrap_input_for_demo(
    table_text=table_text,
    table_caption=table_caption,
    statement=statement,
)

In [30]:
get_chain_of_table(table_text=table_text, wrap_input=demo_sample, llm=gpt_llm, answer=answer)

Question: which opponent has the most wins

Table: 
     #               Date     Opponent     Score Result  \
0    1               1988    Indonesia       4-0    Won   
1    2               1988  South Korea       6-1    Won   
2    3               1988  South Korea       6-1    Won   
3    4               1988  South Korea       6-1    Won   
4    5               1988  South Korea       6-1    Won   
5    6               1988         Iraq       2-1    Won   
6    7               1988      Bahrain       2-0    Won   
7    8               1988      Bahrain       2-0    Won   
8    9      June 12, 1989       Guinea       2-2   Draw   
9   10      June 12, 1989     Colombia       1-0    Won   
10  11      June 12, 1989      Bahrain       1-0    Won   
11  12  February 22, 1989     Portugal       3-0    Won   
12  13               1989       Kuwait       1-0    Won   
13  14               1989      Bahrain  1-1(4-3)   Lost   

                                          Competition  
0     

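For comparison, answer the same question with a direct prompt: the whole table is serialized as CSV and sent with the question in a single LLM call, with no Chain-of-Table operations.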

In [31]:
# Serialize the table (you can adjust formatting)
def serialize_table(df):
    return df.to_csv(index=False)

# Prompt construction for the direct (single-call) baseline
def build_prompt(df, question):
    """Build a single-turn prompt: the table serialized as CSV, then the question.

    The prompt ends with "Answer: " so the model's reply follows that label.
    """
    table_block = serialize_table(df)
    prompt_lines = [
        "Here's a serialized table.",
        "",
        table_block,
        "",
        f"Please answer the question: {question}",
        "Answer: ",
    ]
    return "\n".join(prompt_lines)

# Single chat-completion call for the direct baseline
def query_openai(prompt, model="gpt-3.5-turbo", temperature=0):
    """Send ``prompt`` as one user message to the OpenAI chat API and return the reply.

    Uses the legacy (openai<1.0) ``openai.ChatCompletion`` interface and expects
    ``openai.api_key`` to have been set beforehand.

    Returns:
        str: The first choice's message content, stripped of surrounding whitespace.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    completion = openai.ChatCompletion.create(
        model=model,
        messages=chat_messages,
        temperature=temperature,
    )
    reply = completion['choices'][0]['message']['content']
    return reply.strip()

In [32]:
# query_openai relies on the module-level key of the legacy openai client.
openai.api_key = openai_api_key
prompt = build_prompt(df, statement)
response = query_openai(prompt)
print(f"response: {response}")

response: South Korea


### Testcase nt-234

In [58]:
testcase_id = "nt-234"
testcase = test_cases.loc[testcase_id]
df_path = wiki_tq_dir + testcase["context"]
statement = testcase["utterance"]
answer = testcase["targetValue"]
df = pd.read_csv(df_path)

table_caption = ""
table_text = convert_df_to_table_text(df)

demo_sample = wrap_input_for_demo(
    table_text=table_text,
    table_caption=table_caption,
    statement=statement,
)

get_chain_of_table(table_text=table_text, wrap_input=demo_sample, llm=gpt_llm, answer=answer)

Statements: where was the match held immediately before 2014's at guizhou olympic stadium?

Table: 
   Season              Date                  Jia-A/CSL Winner  \
0    1995   9 December 1995                  Shanghai Shenhua   
1    1996      9 March 1997  Dalian Wanda\n(now Dalian Shide)   
2    1997     12 March 1998  Dalian Wanda\n(now Dalian Shide)   
3    1998      7 March 1999  Dalian Wanda\n(now Dalian Shide)   
4    1999      4 March 2000                   Shandong Luneng   
5    2000  30 December 2000                      Dalian Shide   
6    2001  26 February 2002                      Dalian Shide   
7    2001      2 March 2002                      Dalian Shide   
8    2002   6 February 2003                      Dalian Shide   
9    2003   18 January 2004                  Shanghai Shenhua   
10   2012  25 February 2012              Guangzhou Evergrande   
11   2013      3 March 2013              Guangzhou Evergrande   
12   2014  17 February 2014              Guangzhou Ever